From 08a87b102c12acec85c19c462a4c3827ac569e58 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Thu, 1 Oct 2020 17:47:50 -0700 Subject: [PATCH] more fpa Signed-off-by: Nikolaj Bjorner --- src/ast/fpa/fpa2bv_converter.cpp | 92 ++++++++++++++++++ src/ast/fpa/fpa2bv_converter.h | 3 + src/model/fpa_factory.h | 6 +- src/model/model.cpp | 3 + src/sat/sat_extension.h | 2 +- src/sat/smt/euf_model.cpp | 2 + src/sat/smt/euf_solver.cpp | 7 ++ src/sat/smt/fpa_solver.cpp | 162 +++++++------------------------ src/sat/smt/fpa_solver.h | 19 ++-- src/sat/smt/sat_th.h | 5 + 10 files changed, 163 insertions(+), 138 deletions(-) diff --git a/src/ast/fpa/fpa2bv_converter.cpp b/src/ast/fpa/fpa2bv_converter.cpp index 89023c8f5..6749773c3 100644 --- a/src/ast/fpa/fpa2bv_converter.cpp +++ b/src/ast/fpa/fpa2bv_converter.cpp @@ -4358,3 +4358,95 @@ void fpa2bv_converter_wrapped::mk_rm_const(func_decl* f, expr_ref& result) { } } + +expr* fpa2bv_converter_wrapped::bv2rm_value(expr* b) { + app* result = nullptr; + unsigned bv_sz; + rational val(0); + VERIFY(m_bv_util.is_numeral(b, val, bv_sz)); + SASSERT(bv_sz == 3); + + switch (val.get_uint64()) { + case BV_RM_TIES_TO_AWAY: result = m_util.mk_round_nearest_ties_to_away(); break; + case BV_RM_TIES_TO_EVEN: result = m_util.mk_round_nearest_ties_to_even(); break; + case BV_RM_TO_NEGATIVE: result = m_util.mk_round_toward_negative(); break; + case BV_RM_TO_POSITIVE: result = m_util.mk_round_toward_positive(); break; + case BV_RM_TO_ZERO: + default: result = m_util.mk_round_toward_zero(); + } + + TRACE("t_fpa", tout << "result: " << mk_ismt2_pp(result, m) << std::endl;); + return result; +} + +expr* fpa2bv_converter_wrapped::bv2fpa_value(sort* s, expr* a, expr* b, expr* c) { + mpf_manager& mpfm = m_util.fm(); + unsynch_mpz_manager& mpzm = mpfm.mpz_manager(); + app* result; + unsigned ebits = m_util.get_ebits(s); + unsigned sbits = m_util.get_sbits(s); + + scoped_mpz bias(mpzm); + mpzm.power(mpz(2), ebits - 1, bias); + mpzm.dec(bias); + + scoped_mpz sgn_z(mpzm), sig_z(mpzm), exp_z(mpzm); + unsigned bv_sz; + + if (b == nullptr) { + SASSERT(m_bv_util.is_bv(a)); + SASSERT(m_bv_util.get_bv_size(a) == (ebits + sbits)); + + rational all_r(0); + scoped_mpz all_z(mpzm); + + VERIFY(m_bv_util.is_numeral(a, all_r, bv_sz)); + SASSERT(bv_sz == (ebits + sbits)); + SASSERT(all_r.is_int()); + mpzm.set(all_z, all_r.to_mpq().numerator()); + + mpzm.machine_div2k(all_z, ebits + sbits - 1, sgn_z); + mpzm.mod(all_z, mpfm.m_powers2(ebits + sbits - 1), all_z); + + mpzm.machine_div2k(all_z, sbits - 1, exp_z); + mpzm.mod(all_z, mpfm.m_powers2(sbits - 1), all_z); + + mpzm.set(sig_z, all_z); + } + else { + SASSERT(b); + SASSERT(c); + rational sgn_r(0), exp_r(0), sig_r(0); + + bool r = m_bv_util.is_numeral(a, sgn_r, bv_sz); + SASSERT(r && bv_sz == 1); + r = m_bv_util.is_numeral(b, exp_r, bv_sz); + SASSERT(r && bv_sz == ebits); + r = m_bv_util.is_numeral(c, sig_r, bv_sz); + SASSERT(r && bv_sz == sbits - 1); + (void)r; + + SASSERT(mpzm.is_one(sgn_r.to_mpq().denominator())); + SASSERT(mpzm.is_one(exp_r.to_mpq().denominator())); + SASSERT(mpzm.is_one(sig_r.to_mpq().denominator())); + + mpzm.set(sgn_z, sgn_r.to_mpq().numerator()); + mpzm.set(exp_z, exp_r.to_mpq().numerator()); + mpzm.set(sig_z, sig_r.to_mpq().numerator()); + } + + scoped_mpz exp_u = exp_z - bias; + SASSERT(mpzm.is_int64(exp_u)); + + scoped_mpf f(mpfm); + mpfm.set(f, ebits, sbits, mpzm.is_one(sgn_z), mpzm.get_int64(exp_u), sig_z); + result = m_util.mk_value(f); + + TRACE("t_fpa", tout << "result: [" << + mpzm.to_string(sgn_z) << "," << + mpzm.to_string(exp_z) << "," << + mpzm.to_string(sig_z) << "] --> " << + mk_ismt2_pp(result, m) << std::endl;); + + return result; +} diff --git a/src/ast/fpa/fpa2bv_converter.h b/src/ast/fpa/fpa2bv_converter.h index 67e1834a8..31a5f7f40 100644 --- a/src/ast/fpa/fpa2bv_converter.h +++ b/src/ast/fpa/fpa2bv_converter.h @@ -238,6 +238,9 @@ class fpa2bv_converter_wrapped : public fpa2bv_converter { void mk_rm_const(func_decl * f, expr_ref & result) override; app_ref wrap(expr * e); app_ref unwrap(expr * e, sort * s); + + expr* bv2rm_value(expr* b); + expr* bv2fpa_value(sort* s, expr* a, expr* b = nullptr, expr* c = nullptr); }; diff --git a/src/model/fpa_factory.h b/src/model/fpa_factory.h index 5930cb575..ff2f18e35 100644 --- a/src/model/fpa_factory.h +++ b/src/model/fpa_factory.h @@ -30,9 +30,9 @@ class fpa_value_factory : public value_factory { } public: - fpa_value_factory(ast_manager & m, family_id fid) : - value_factory(m, fid), - m_util(m) {} + fpa_value_factory(ast_manager & m, family_id fid) : + value_factory(m, fid), + m_util(m) {} ~fpa_value_factory() override {} diff --git a/src/model/model.cpp b/src/model/model.cpp index 4542cd244..647dc444f 100644 --- a/src/model/model.cpp +++ b/src/model/model.cpp @@ -36,6 +36,7 @@ Revision History: #include "model/seq_factory.h" #include "model/datatype_factory.h" #include "model/numeral_factory.h" +#include "model/fpa_factory.h" model::model(ast_manager & m): @@ -99,11 +100,13 @@ bool model::eval_expr(expr * e, expr_ref & result, bool model_completion) { value_factory* model::get_factory(sort* s) { if (m_factories.plugins().empty()) { seq_util su(m); + fpa_util fu(m); m_factories.register_plugin(alloc(array_factory, m, *this)); m_factories.register_plugin(alloc(datatype_factory, m, *this)); m_factories.register_plugin(alloc(bv_factory, m)); m_factories.register_plugin(alloc(arith_factory, m)); m_factories.register_plugin(alloc(seq_factory, m, su.get_family_id(), *this)); + m_factories.register_plugin(alloc(fpa_value_factory, m, fu.get_family_id())); } family_id fid = s->get_family_id(); return m_factories.get_plugin(fid); diff --git a/src/sat/sat_extension.h b/src/sat/sat_extension.h index 8b5ef3fe3..fd341059e 100644 --- a/src/sat/sat_extension.h +++ b/src/sat/sat_extension.h @@ -96,7 +96,7 @@ namespace sat { virtual std::ostream& display(std::ostream& out) const = 0; virtual std::ostream& display_justification(std::ostream& out, ext_justification_idx idx) const = 0; virtual std::ostream& display_constraint(std::ostream& out, ext_constraint_idx idx) const = 0; - virtual void collect_statistics(statistics& st) const = 0; + virtual void collect_statistics(statistics& st) const {} virtual extension* copy(solver* s) { UNREACHABLE(); return nullptr; } virtual void find_mutexes(literal_vector& lits, vector & mutexes) {} virtual void gc() {} diff --git a/src/sat/smt/euf_model.cpp b/src/sat/smt/euf_model.cpp index 83d02fada..411087d1e 100644 --- a/src/sat/smt/euf_model.cpp +++ b/src/sat/smt/euf_model.cpp @@ -29,6 +29,8 @@ namespace euf { deps.topological_sort(); dependencies2values(deps, mdl); values2model(deps, mdl); + for (auto* mb : m_solvers) + mb->finalize_model(*mdl); } bool solver::include_func_interp(func_decl* f) { diff --git a/src/sat/smt/euf_solver.cpp b/src/sat/smt/euf_solver.cpp index bb5d9d4a1..85bbd72e2 100644 --- a/src/sat/smt/euf_solver.cpp +++ b/src/sat/smt/euf_solver.cpp @@ -24,6 +24,7 @@ Author: #include "sat/smt/euf_solver.h" #include "sat/smt/array_solver.h" #include "sat/smt/q_solver.h" +#include "sat/smt/fpa_solver.h" namespace euf { @@ -97,6 +98,7 @@ namespace euf { pb_util pb(m); bv_util bvu(m); array_util au(m); + fpa_util fpa(m); if (pb.get_family_id() == fid) { ext = alloc(sat::ba_solver, *this, fid); if (use_drat()) @@ -112,6 +114,11 @@ namespace euf { if (use_drat()) s().get_drat().add_theory(fid, symbol("array")); } + else if (fpa.get_family_id() == fid) { + ext = alloc(fpa::solver, *this); + if (use_drat()) + s().get_drat().add_theory(fid, symbol("fpa")); + } if (ext) { ext->set_solver(m_solver); ext->push_scopes(s().num_scopes()); diff --git a/src/sat/smt/fpa_solver.cpp b/src/sat/smt/fpa_solver.cpp index bfaab1316..122437fe4 100644 --- a/src/sat/smt/fpa_solver.cpp +++ b/src/sat/smt/fpa_solver.cpp @@ -26,19 +26,19 @@ namespace fpa { euf::th_euf_solver(ctx, ctx.get_manager().mk_family_id("fpa")), m_th_rw(ctx.get_manager()), m_converter(ctx.get_manager(), m_th_rw), - m_rw(ctx.get_manager(), m_converter, params_ref()), + m_rw(ctx.get_manager(), m_converter, params_ref()), m_fpa_util(m_converter.fu()), m_bv_util(m_converter.bu()), m_arith_util(m_converter.au()) { params_ref p; p.set_bool("arith_lhs", true); - m_th_rw.updt_params(p); + m_th_rw.updt_params(p); } solver::~solver() { - dec_ref_map_key_values(m, m_conversions); - SASSERT(m_conversions.empty()); + dec_ref_map_key_values(m, m_conversions); + SASSERT(m_conversions.empty()); } @@ -82,7 +82,7 @@ namespace fpa { return conds; } - void solver::attach_new_th_var(enode * n) { + void solver::attach_new_th_var(enode* n) { theory_var v = mk_var(n); ctx.attach_th_var(n, this, v); TRACE("t_fpa", tout << "new theory var: " << mk_ismt2_pp(n->get_expr(), m) << " := " << v << "\n";); @@ -125,7 +125,7 @@ namespace fpa { sat::literal atom(ctx.get_si().add_bool_var(e), false); atom = ctx.attach_lit(atom, e); sat::literal bv_atom = b_internalize(m_rw.convert_atom(m_th_rw, e)); - sat::literal_vector conds = mk_side_conditions(); + sat::literal_vector conds = mk_side_conditions(); conds.push_back(bv_atom); add_equiv_and(atom, conds); if (root) { @@ -133,48 +133,47 @@ namespace fpa { atom.neg(); add_unit(atom); } - return true; } - - switch (a->get_decl_kind()) { + else { + switch (a->get_decl_kind()) { case OP_FPA_TO_FP: case OP_FPA_TO_UBV: case OP_FPA_TO_SBV: case OP_FPA_TO_REAL: case OP_FPA_TO_IEEE_BV: { expr_ref conv = convert(e); - expr_ref eq = ctx.mk_eq(e, conv); + expr_ref eq = ctx.mk_eq(e, conv); add_unit(b_internalize(eq)); add_units(mk_side_conditions()); break; } - default: /* ignore */; + default: /* ignore */ + break; + } } return true; } - void solver::apply_sort_cnstr(enode * n, sort * s) { + void solver::apply_sort_cnstr(enode* n, sort* s) { TRACE("t_fpa", tout << "apply sort cnstr for: " << mk_ismt2_pp(n->get_expr(), m) << "\n";); SASSERT(s->get_family_id() == get_id()); SASSERT(m_fpa_util.is_float(s) || m_fpa_util.is_rm(s)); SASSERT(m_fpa_util.is_float(n->get_expr()) || m_fpa_util.is_rm(n->get_expr())); SASSERT(n->get_decl()->get_range() == s); - expr * owner = n->get_expr(); + expr* owner = n->get_expr(); if (is_attached_to_var(n)) return; attach_new_th_var(n); - if (m_fpa_util.is_rm(s)) { + if (m_fpa_util.is_rm(s) && !m_fpa_util.is_bv2rm(owner)) { // For every RM term, we need to make sure that it's // associated bit-vector is within the valid range. - if (!m_fpa_util.is_bv2rm(owner)) { - expr_ref valid(m), limit(m); - limit = m_bv_util.mk_numeral(4, 3); - valid = m_bv_util.mk_ule(m_converter.wrap(owner), limit); - add_unit(b_internalize(valid)); - } + expr_ref valid(m), limit(m); + limit = m_bv_util.mk_numeral(4, 3); + valid = m_bv_util.mk_ule(m_converter.wrap(owner), limit); + add_unit(b_internalize(valid)); } activate(owner); } @@ -307,19 +306,22 @@ namespace fpa { expr* a = values.get(n->get_arg(0)->get_root_id()); expr* b = values.get(n->get_arg(1)->get_root_id()); expr* c = values.get(n->get_arg(2)->get_root_id()); - value = bvs2fpa_value(m.get_sort(e), a, b, c); + value = m_converter.bv2fpa_value(m.get_sort(e), a, b, c); } else if (m_fpa_util.is_bv2rm(e)) { SASSERT(n->num_args() == 1); - value = bv2rm_value(values.get(n->get_arg(0)->get_root_id())); + value = m_converter.bv2rm_value(values.get(n->get_arg(0)->get_root_id())); } - else if (m_fpa_util.is_rm(e) && is_wrapped()) - value = bv2rm_value(values.get(expr2enode(wrapped)->get_root_id())); + else if (m_fpa_util.is_rm(e) && is_wrapped()) + value = m_converter.bv2rm_value(values.get(expr2enode(wrapped)->get_root_id())); + else if (m_fpa_util.is_rm(e)) + value = m_fpa_util.mk_round_toward_zero(); else if (m_fpa_util.is_float(e) && is_wrapped()) { expr* a = values.get(expr2enode(wrapped)->get_root_id()); - value = bvs2fpa_value(m.get_sort(e), a, nullptr, nullptr); + value = m_converter.bv2fpa_value(m.get_sort(e), a); } else { + SASSERT(m_fpa_util.is_float(e)); unsigned ebits = m_fpa_util.get_ebits(m.get_sort(e)); unsigned sbits = m_fpa_util.get_sbits(m.get_sort(e)); value = m_fpa_util.mk_pzero(ebits, sbits); @@ -327,99 +329,6 @@ namespace fpa { values.set(n->get_root_id(), value); } - expr* solver::bv2rm_value(expr* b) { - app* result = nullptr; - unsigned bv_sz; - rational val(0); - VERIFY(m_bv_util.is_numeral(b, val, bv_sz)); - SASSERT(bv_sz == 3); - - switch (val.get_uint64()) { - case BV_RM_TIES_TO_AWAY: result = m_fpa_util.mk_round_nearest_ties_to_away(); break; - case BV_RM_TIES_TO_EVEN: result = m_fpa_util.mk_round_nearest_ties_to_even(); break; - case BV_RM_TO_NEGATIVE: result = m_fpa_util.mk_round_toward_negative(); break; - case BV_RM_TO_POSITIVE: result = m_fpa_util.mk_round_toward_positive(); break; - case BV_RM_TO_ZERO: - default: result = m_fpa_util.mk_round_toward_zero(); - } - - TRACE("t_fpa", tout << "result: " << mk_ismt2_pp(result, m) << std::endl;); - return result; - } - - - expr* solver::bvs2fpa_value(sort* s, expr* a, expr* b, expr* c) { - mpf_manager& mpfm = m_fpa_util.fm(); - unsynch_mpz_manager& mpzm = mpfm.mpz_manager(); - app* result; - unsigned ebits = m_fpa_util.get_ebits(s); - unsigned sbits = m_fpa_util.get_sbits(s); - - scoped_mpz bias(mpzm); - mpzm.power(mpz(2), ebits - 1, bias); - mpzm.dec(bias); - - scoped_mpz sgn_z(mpzm), sig_z(mpzm), exp_z(mpzm); - unsigned bv_sz; - - if (b == nullptr) { - SASSERT(m_bv_util.is_bv(a)); - SASSERT(m_bv_util.get_bv_size(a) == (ebits + sbits)); - - rational all_r(0); - scoped_mpz all_z(mpzm); - - VERIFY(m_bv_util.is_numeral(a, all_r, bv_sz)); - SASSERT(bv_sz == (ebits + sbits)); - SASSERT(all_r.is_int()); - mpzm.set(all_z, all_r.to_mpq().numerator()); - - mpzm.machine_div2k(all_z, ebits + sbits - 1, sgn_z); - mpzm.mod(all_z, mpfm.m_powers2(ebits + sbits - 1), all_z); - - mpzm.machine_div2k(all_z, sbits - 1, exp_z); - mpzm.mod(all_z, mpfm.m_powers2(sbits - 1), all_z); - - mpzm.set(sig_z, all_z); - } - else { - SASSERT(b); - SASSERT(c); - rational sgn_r(0), exp_r(0), sig_r(0); - - bool r = m_bv_util.is_numeral(a, sgn_r, bv_sz); - SASSERT(r && bv_sz == 1); - r = m_bv_util.is_numeral(b, exp_r, bv_sz); - SASSERT(r && bv_sz == ebits); - r = m_bv_util.is_numeral(c, sig_r, bv_sz); - SASSERT(r && bv_sz == sbits - 1); - (void)r; - - SASSERT(mpzm.is_one(sgn_r.to_mpq().denominator())); - SASSERT(mpzm.is_one(exp_r.to_mpq().denominator())); - SASSERT(mpzm.is_one(sig_r.to_mpq().denominator())); - - mpzm.set(sgn_z, sgn_r.to_mpq().numerator()); - mpzm.set(exp_z, exp_r.to_mpq().numerator()); - mpzm.set(sig_z, sig_r.to_mpq().numerator()); - } - - scoped_mpz exp_u = exp_z - bias; - SASSERT(mpzm.is_int64(exp_u)); - - scoped_mpf f(mpfm); - mpfm.set(f, ebits, sbits, mpzm.is_one(sgn_z), mpzm.get_int64(exp_u), sig_z); - result = m_fpa_util.mk_value(f); - - TRACE("t_fpa", tout << "result: [" << - mpzm.to_string(sgn_z) << "," << - mpzm.to_string(exp_z) << "," << - mpzm.to_string(sig_z) << "] --> " << - mk_ismt2_pp(result, m) << std::endl;); - - return result; - } - void solver::add_dep(euf::enode* n, top_sort& dep) { expr* e = n->get_expr(); if (m_fpa_util.is_fp(e)) { @@ -432,13 +341,13 @@ namespace fpa { dep.add(n, n->get_arg(0)); } else if (m_fpa_util.is_rm(e) || m_fpa_util.is_float(e)) { - app_ref wrapped = m_converter.wrap(e); - if (expr2enode(wrapped)) - dep.add(n, expr2enode(wrapped)); + euf::enode* wrapped = expr2enode(m_converter.wrap(e)); + if (wrapped) + dep.add(n, wrapped); } } - std::ostream& solver::display(std::ostream & out) const { + std::ostream& solver::display(std::ostream& out) const { bool first = true; for (enode* n : ctx.get_egraph().nodes()) { theory_var v = n->get_th_var(m_fpa_util.get_family_id()); @@ -450,11 +359,11 @@ namespace fpa { } } // if there are no fpa theory variables, was fp ever used? - if (first) + if (first) return out; out << "bv theory variables:" << std::endl; - for (enode * n : ctx.get_egraph().nodes()) { + for (enode* n : ctx.get_egraph().nodes()) { theory_var v = n->get_th_var(m_bv_util.get_family_id()); if (v != -1) out << v << " -> " << mk_ismt2_pp(n->get_expr(), m) << std::endl; @@ -468,14 +377,13 @@ namespace fpa { } out << "equivalence classes:\n"; - for (enode * n : ctx.get_egraph().nodes()) { - expr * e = n->get_expr(); + for (enode* n : ctx.get_egraph().nodes()) { + expr* e = n->get_expr(); out << n->get_root_id() << " --> " << mk_ismt2_pp(e, m) << std::endl; } return out; } - void solver::finalize_model(model& mdl) { model new_model(m); diff --git a/src/sat/smt/fpa_solver.h b/src/sat/smt/fpa_solver.h index 68ceb20c7..f09a1dd8c 100644 --- a/src/sat/smt/fpa_solver.h +++ b/src/sat/smt/fpa_solver.h @@ -45,12 +45,7 @@ namespace fpa { sat::literal_vector mk_side_conditions(); void attach_new_th_var(enode* n); void activate(expr* e); - void ensure_equality_relation(theory_var x, theory_var y); - expr* bv2rm_value(expr* b); - expr* bvs2fpa_value(sort* s, expr* a, expr* b, expr* c); - - void finalize_model(model& mdl); - + void ensure_equality_relation(theory_var x, theory_var y); public: solver(euf::solver& ctx); @@ -66,10 +61,20 @@ namespace fpa { void apply_sort_cnstr(euf::enode* n, sort* s) override; std::ostream& display(std::ostream& out) const override; - + std::ostream& display_justification(std::ostream& out, sat::ext_justification_idx idx) const override { UNREACHABLE(); return out; } + std::ostream& display_constraint(std::ostream& out, sat::ext_constraint_idx idx) const override { UNREACHABLE(); return out; } void add_value(euf::enode* n, model& mdl, expr_ref_vector& values) override; void add_dep(euf::enode* n, top_sort& dep) override; + void finalize_model(model& mdl) override; + + bool unit_propagate() override { return false; } + void get_antecedents(sat::literal l, sat::ext_justification_idx idx, sat::literal_vector& r, bool probing) override { UNREACHABLE(); } + sat::check_result check() override { return sat::check_result::CR_DONE; } + + euf::th_solver* clone(sat::solver*, euf::solver& ctx) { return alloc(solver, ctx); } + }; } + diff --git a/src/sat/smt/sat_th.h b/src/sat/smt/sat_th.h index f18d99dd4..033ca85ac 100644 --- a/src/sat/smt/sat_th.h +++ b/src/sat/smt/sat_th.h @@ -86,6 +86,11 @@ namespace euf { \brief should function be included in model. */ virtual bool include_func_interp(func_decl* f) const { return false; } + + /** + \brief conclude model building + */ + virtual void finalize_model(model& mdl) {} }; class th_solver : public sat::extension, public th_model_builder, public th_decompile, public th_internalizer {