From faa3a7ab4f64ade7bc1d4813512df361cf1c5b2f Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Fri, 15 Dec 2023 13:50:26 -0800 Subject: [PATCH] updates to poly --- src/ast/arith_decl_plugin.cpp | 17 ++++++--- src/ast/arith_decl_plugin.h | 51 ++++++++++++++------------- src/sat/smt/arith_axioms.cpp | 28 +++++++-------- src/sat/smt/arith_solver.cpp | 3 -- src/sat/smt/intblast_solver.h | 4 ++- src/sat/smt/polysat/core.cpp | 5 ++- src/sat/smt/polysat/core.h | 9 ++--- src/sat/smt/polysat/op_constraint.cpp | 16 ++++++++- src/sat/smt/polysat_model.cpp | 10 ++---- src/sat/smt/polysat_solver.cpp | 47 ++++++++++++++---------- src/sat/smt/polysat_solver.h | 3 +- 11 files changed, 111 insertions(+), 82 deletions(-) diff --git a/src/ast/arith_decl_plugin.cpp b/src/ast/arith_decl_plugin.cpp index 2d830d510..8317b37c3 100644 --- a/src/ast/arith_decl_plugin.cpp +++ b/src/ast/arith_decl_plugin.cpp @@ -707,7 +707,16 @@ expr * arith_decl_plugin::get_some_value(sort * s) { return mk_numeral(rational(0), s == m_int_decl); } -bool arith_recognizers::is_numeral(expr const * n, rational & val, bool & is_int) const { +bool arith_util::is_numeral(expr const * n, rational & val, bool & is_int) const { + if (is_irrational_algebraic_numeral(n)) { + scoped_anum an(am()); + is_irrational_algebraic_numeral2(n, an); + if (am().is_rational(an)) { + am().to_rational(an, val); + is_int = val.is_int(); + return true; + } + } if (!is_app_of(n, arith_family_id, OP_NUM)) return false; func_decl * decl = to_app(n)->get_decl(); @@ -738,7 +747,7 @@ bool arith_recognizers::is_int_expr(expr const *e) const { if (is_to_real(e)) { // pass } - else if (is_numeral(e, r) && r.is_int()) { + else if (is_numeral(e) && is_int(e)) { // pass } else if (is_add(e) || is_mul(e)) { @@ -761,14 +770,14 @@ void arith_util::init_plugin() { m_plugin = static_cast(m_manager.get_plugin(arith_family_id)); } -bool arith_util::is_irrational_algebraic_numeral2(expr const * n, algebraic_numbers::anum & val) { +bool arith_util::is_irrational_algebraic_numeral2(expr const * n, algebraic_numbers::anum & val) const { if (!is_app_of(n, arith_family_id, OP_IRRATIONAL_ALGEBRAIC_NUM)) return false; am().set(val, to_irrational_algebraic_numeral(n)); return true; } -algebraic_numbers::anum const & arith_util::to_irrational_algebraic_numeral(expr const * n) { +algebraic_numbers::anum const & arith_util::to_irrational_algebraic_numeral(expr const * n) const { SASSERT(is_irrational_algebraic_numeral(n)); return plugin().aw().to_anum(to_app(n)->get_decl()); } diff --git a/src/ast/arith_decl_plugin.h b/src/ast/arith_decl_plugin.h index b073e205e..25c4977e9 100644 --- a/src/ast/arith_decl_plugin.h +++ b/src/ast/arith_decl_plugin.h @@ -237,26 +237,10 @@ public: family_id get_family_id() const { return arith_family_id; } bool is_arith_expr(expr const * n) const { return is_app(n) && to_app(n)->get_family_id() == arith_family_id; } - bool is_irrational_algebraic_numeral(expr const * n) const; - bool is_unsigned(expr const * n, unsigned& u) const { - rational val; - bool is_int = true; - return is_numeral(n, val, is_int) && is_int && val.is_unsigned() && (u = val.get_unsigned(), true); - } - bool is_numeral(expr const * n, rational & val, bool & is_int) const; - bool is_numeral(expr const * n, rational & val) const { bool is_int; return is_numeral(n, val, is_int); } - bool is_numeral(expr const * n) const { return is_app_of(n, arith_family_id, OP_NUM); } - bool is_zero(expr const * n) const { rational val; return is_numeral(n, val) && val.is_zero(); } - bool is_minus_one(expr * n) const { rational tmp; return is_numeral(n, tmp) && tmp.is_minus_one(); } - // return true if \c n is a term of the form (* -1 r) - bool is_times_minus_one(expr * n, expr * & r) const { - if (is_mul(n) && to_app(n)->get_num_args() == 2 && is_minus_one(to_app(n)->get_arg(0))) { - r = to_app(n)->get_arg(1); - return true; - } - return false; - } + bool is_irrational_algebraic_numeral(expr const* n) const; + + bool is_numeral(expr const* n) const { return is_app_of(n, arith_family_id, OP_NUM); } bool is_int_expr(expr const * e) const; bool is_le(expr const * n) const { return is_app_of(n, arith_family_id, OP_LE); } @@ -399,13 +383,32 @@ public: return *m_plugin; } - algebraic_numbers::manager & am() { + algebraic_numbers::manager & am() const { return plugin().am(); } + // return true if \c n is a term of the form (* -1 r) + bool is_zero(expr const* n) const { rational val; return is_numeral(n, val) && val.is_zero(); } + bool is_minus_one(expr* n) const { rational tmp; return is_numeral(n, tmp) && tmp.is_minus_one(); } + bool is_times_minus_one(expr* n, expr*& r) const { + if (is_mul(n) && to_app(n)->get_num_args() == 2 && is_minus_one(to_app(n)->get_arg(0))) { + r = to_app(n)->get_arg(1); + return true; + } + return false; + } + bool is_unsigned(expr const* n, unsigned& u) const { + rational val; + bool is_int = true; + return is_numeral(n, val, is_int) && is_int && val.is_unsigned() && (u = val.get_unsigned(), true); + } + bool is_numeral(expr const* n) const { return arith_recognizers::is_numeral(n); } + bool is_numeral(expr const* n, rational& val, bool& is_int) const; + bool is_numeral(expr const* n, rational& val) const { bool is_int; return is_numeral(n, val, is_int); } + bool convert_int_numerals_to_real() const { return plugin().convert_int_numerals_to_real(); } - bool is_irrational_algebraic_numeral2(expr const * n, algebraic_numbers::anum & val); - algebraic_numbers::anum const & to_irrational_algebraic_numeral(expr const * n); + bool is_irrational_algebraic_numeral2(expr const * n, algebraic_numbers::anum & val) const; + algebraic_numbers::anum const & to_irrational_algebraic_numeral(expr const * n) const; sort * mk_int() { return m_manager.mk_sort(arith_family_id, INT_SORT); } sort * mk_real() { return m_manager.mk_sort(arith_family_id, REAL_SORT); } @@ -512,11 +515,11 @@ public: if none of them are numerals, then the left-hand-side has a smaller id than the right hand side. */ app * mk_eq(expr * lhs, expr * rhs) { - if (is_numeral(lhs) || (!is_numeral(rhs) && lhs->get_id() > rhs->get_id())) + if (arith_recognizers::is_numeral(lhs) || (!arith_recognizers::is_numeral(rhs) && lhs->get_id() > rhs->get_id())) std::swap(lhs, rhs); if (lhs == rhs) return m_manager.mk_true(); - if (is_numeral(lhs) && is_numeral(rhs)) { + if (arith_recognizers::is_numeral(lhs) && arith_recognizers::is_numeral(rhs)) { SASSERT(lhs != rhs); return m_manager.mk_false(); } diff --git a/src/sat/smt/arith_axioms.cpp b/src/sat/smt/arith_axioms.cpp index 09db74f75..f004422a6 100644 --- a/src/sat/smt/arith_axioms.cpp +++ b/src/sat/smt/arith_axioms.cpp @@ -211,25 +211,23 @@ namespace arith { if (!ctx.is_relevant(expr2enode(n))) return true; VERIFY(a.is_band(n, sz, x, y)); - if (use_nra_model()) { + expr_ref vx(m), vy(m),vn(m); + if (!get_value(expr2enode(x), vx) || !get_value(expr2enode(y), vy) || !get_value(expr2enode(n), vn)) { + IF_VERBOSE(2, verbose_stream() << "could not get value of " << mk_pp(n, m) << "\n"); found_unsupported(n); return true; } - theory_var vx = expr2enode(x)->get_th_var(get_id()); - theory_var vy = expr2enode(y)->get_th_var(get_id()); - theory_var vn = expr2enode(n)->get_th_var(get_id()); - rational N = rational::power_of_two(sz); - if (!get_value(vx).is_int() || !get_value(vy).is_int()) { - - s().display(verbose_stream()); - verbose_stream() << vx << " " << vy << " " << mk_pp(n, m) << "\n"; + rational valn, valx, valy; + bool is_int; + if (!a.is_numeral(vn, valn, is_int) || !is_int || !a.is_numeral(vx, valx, is_int) || !is_int || !a.is_numeral(vy, valy, is_int) || !is_int) { + IF_VERBOSE(2, verbose_stream() << "could not get value of " << mk_pp(n, m) << "\n"); + found_unsupported(n); + return true; } - SASSERT(get_value(vx).is_int()); - SASSERT(get_value(vy).is_int()); - SASSERT(get_value(vn).is_int()); - rational valx = mod(get_value(vx), N); - rational valy = mod(get_value(vy), N); - rational valn = get_value(vn); + // verbose_stream() << "band: " << mk_pp(n, m) << " " << valn << " := " << valx << "&" << valy << "\n"; + rational N = rational::power_of_two(sz); + valx = mod(valx, N); + valy = mod(valy, N); SASSERT(0 <= valn && valn < N); // x mod 2^{i + 1} >= 2^i means the i'th bit is 1. diff --git a/src/sat/smt/arith_solver.cpp b/src/sat/smt/arith_solver.cpp index 37aef2bf8..eff25bc4a 100644 --- a/src/sat/smt/arith_solver.cpp +++ b/src/sat/smt/arith_solver.cpp @@ -628,9 +628,6 @@ namespace arith { } else if (use_nra_model() && lp().external_to_local(v) != lp::null_lpvar) { anum const& an = nl_value(v, m_nla->tmp1()); - - - if (a.is_int(o) && !m_nla->am().is_int(an)) value = a.mk_numeral(rational::zero(), a.is_int(o)); else diff --git a/src/sat/smt/intblast_solver.h b/src/sat/smt/intblast_solver.h index 493b1f3c5..d59dac935 100644 --- a/src/sat/smt/intblast_solver.h +++ b/src/sat/smt/intblast_solver.h @@ -64,7 +64,7 @@ namespace intblast { void translate(expr_ref_vector& es); void sorted_subterms(expr_ref_vector& es, ptr_vector& sorted); - rational get_value(expr* e) const; + bool is_translated(expr* e) const { return !!m_translate.get(e->get_id(), nullptr); } expr* translated(expr* e) const { expr* r = m_translate.get(e->get_id(), nullptr); SASSERT(r); return r; } @@ -136,6 +136,8 @@ namespace intblast { void eq_internalized(euf::enode* n) override; + rational get_value(expr* e) const; + }; } diff --git a/src/sat/smt/polysat/core.cpp b/src/sat/smt/polysat/core.cpp index 0607f530d..c9deb5726 100644 --- a/src/sat/smt/polysat/core.cpp +++ b/src/sat/smt/polysat/core.cpp @@ -177,10 +177,9 @@ namespace polysat { s.set_lemma(m_viable.get_core(), m_viable.explain()); // propagate_unsat_core(); return sat::check_result::CR_CONTINUE; - case find_t::singleton: { + case find_t::singleton: s.propagate(m_constraints.eq(var2pdd(m_var), m_value), m_viable.explain()); - return sat::check_result::CR_CONTINUE; - } + return sat::check_result::CR_CONTINUE; case find_t::multiple: s.add_eq_literal(m_var, m_value); return sat::check_result::CR_CONTINUE; diff --git a/src/sat/smt/polysat/core.h b/src/sat/smt/polysat/core.h index fb0875ec8..46661dc84 100644 --- a/src/sat/smt/polysat/core.h +++ b/src/sat/smt/polysat/core.h @@ -60,10 +60,10 @@ namespace polysat { // attributes associated with variables vector m_vars; // for each variable a pdd vector m_values; // current value of assigned variable - svector m_justification; // justification for assignment - activity m_activity; // activity of variables - var_queue m_var_queue; // priority queue of variables to assign - vector m_watch; // watch lists for variables for constraints on m_prop_queue where they occur + svector m_justification; // justification for assignment + activity m_activity; // activity of variables + var_queue m_var_queue; // priority queue of variables to assign + vector m_watch; // watch lists for variables for constraints on m_prop_queue where they occur // values to split on rational m_value; @@ -101,6 +101,7 @@ namespace polysat { constraint_id register_constraint(signed_constraint& sc, dependency d); bool propagate(); void assign_eh(constraint_id idx, bool sign, unsigned level); + pvar next_var() { return m_var_queue.next_var(); } pdd value(rational const& v, unsigned sz); pdd subst(pdd const&); diff --git a/src/sat/smt/polysat/op_constraint.cpp b/src/sat/smt/polysat/op_constraint.cpp index b7d312d55..175d3a145 100644 --- a/src/sat/smt/polysat/op_constraint.cpp +++ b/src/sat/smt/polysat/op_constraint.cpp @@ -96,7 +96,21 @@ namespace polysat { } lbool op_constraint::eval_ashr(pdd const& p, pdd const& q, pdd const& r) { - NOT_IMPLEMENTED_YET(); + auto& m = p.manager(); + if (r.is_val() && p.is_val() && q.is_val()) { + auto M = m.max_value(); + auto N = M + 1; + if (p.val() >= N/2) { + if (q.val() >= m.power_of_2()) + return to_lbool(r.val() == M); + unsigned k = q.val().get_unsigned(); + return to_lbool(r.val() == p.val() - rational::power_of_two(k)); + } + else + return eval_lshr(p, q, r); + } + if (q.is_val() && q.is_zero() && p == r) + return l_true; return l_undef; } diff --git a/src/sat/smt/polysat_model.cpp b/src/sat/smt/polysat_model.cpp index 5bd8d4dc9..028aeed6b 100644 --- a/src/sat/smt/polysat_model.cpp +++ b/src/sat/smt/polysat_model.cpp @@ -23,12 +23,7 @@ Author: namespace polysat { - void solver::add_value(euf::enode* n, model& mdl, expr_ref_vector& values) { - - if (m_use_intblast_model) { - m_intblast.add_value(n, mdl, values); - return; - } + void solver::add_value(euf::enode* n, model& mdl, expr_ref_vector& values) { auto p = expr2pdd(n->get_expr()); rational val; if (!m_core.try_eval(p, val)) { @@ -82,8 +77,7 @@ namespace polysat { for (unsigned v = 0; v < get_num_vars(); ++v) if (m_var2pdd_valid.get(v, false)) out << ctx.bpp(var2enode(v)) << " := " << m_var2pdd[v] << "\n"; - if (m_use_intblast_model) - m_intblast.display(out); + m_intblast.display(out); return out; } } diff --git a/src/sat/smt/polysat_solver.cpp b/src/sat/smt/polysat_solver.cpp index 75dd09075..219b9017a 100644 --- a/src/sat/smt/polysat_solver.cpp +++ b/src/sat/smt/polysat_solver.cpp @@ -61,25 +61,36 @@ namespace polysat { return sat::check_result::CR_DONE; case sat::check_result::CR_CONTINUE: return sat::check_result::CR_CONTINUE; - case sat::check_result::CR_GIVEUP: { - if (!m.inc()) - return sat::check_result::CR_GIVEUP; - switch (m_intblast.check_solver_state()) { - case l_true: - trail().push(value_trail(m_use_intblast_model)); - m_use_intblast_model = true; - return sat::check_result::CR_DONE; - case l_false: { - auto core = m_intblast.unsat_core(); - for (auto& lit : core) - lit.neg(); - s().add_clause(core.size(), core.data(), sat::status::th(true, get_id(), nullptr)); - return sat::check_result::CR_CONTINUE; - } - case l_undef: - return sat::check_result::CR_GIVEUP; - } + case sat::check_result::CR_GIVEUP: + return intblast(); } + UNREACHABLE(); + return sat::check_result::CR_GIVEUP; + } + + sat::check_result solver::intblast() { + if (!m.inc()) + return sat::check_result::CR_GIVEUP; + switch (m_intblast.check_solver_state()) { + case l_true: { + pvar pv = m_core.next_var(); + auto v = m_pddvar2var[pv]; + auto n = var2expr(v); + auto val = m_intblast.get_value(n); + sat::literal lit = eq_internalize(n, bv.mk_numeral(val, get_bv_size(v))); + s().set_phase(lit); + return sat::check_result::CR_CONTINUE; + } + case l_false: { + IF_VERBOSE(2, verbose_stream() << "unsat core: " << m_intblast.unsat_core() << "\n"); + auto core = m_intblast.unsat_core(); + for (auto& lit : core) + lit.neg(); + s().add_clause(core.size(), core.data(), sat::status::th(true, get_id(), nullptr)); + return sat::check_result::CR_CONTINUE; + } + case l_undef: + return sat::check_result::CR_GIVEUP; } UNREACHABLE(); return sat::check_result::CR_GIVEUP; diff --git a/src/sat/smt/polysat_solver.h b/src/sat/smt/polysat_solver.h index 0ecc8941b..60535207b 100644 --- a/src/sat/smt/polysat_solver.h +++ b/src/sat/smt/polysat_solver.h @@ -59,7 +59,6 @@ namespace polysat { stats m_stats; core m_core; intblast::solver m_intblast; - bool m_use_intblast_model = false; vector m_var2pdd; // theory_var 2 pdd bool_vector m_var2pdd_valid; // valid flag @@ -73,6 +72,8 @@ namespace polysat { unsigned m_lemma_level = 0; expr_ref_vector m_lemma; + sat::check_result intblast(); + // internalize bool visit(expr* e) override; bool visited(expr* e) override;