From c72780d9b92bea096b98c83a6abc031305637665 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Mon, 11 Dec 2023 20:22:23 -0800 Subject: [PATCH] b-and, stats, reinsert variable to heap, debugging --- src/math/lp/lp_api.h | 2 ++ src/sat/smt/arith_axioms.cpp | 34 ++++++++++-------- src/sat/smt/arith_internalize.cpp | 1 - src/sat/smt/arith_solver.cpp | 5 +-- src/sat/smt/intblast_solver.cpp | 55 +++++++++++++++++++++++++++-- src/sat/smt/intblast_solver.h | 8 ++++- src/sat/smt/polysat/core.cpp | 18 ++++++++-- src/sat/smt/polysat/core.h | 3 ++ src/sat/smt/polysat_internalize.cpp | 9 ++++- src/sat/smt/polysat_model.cpp | 35 +++++++----------- src/sat/smt/polysat_solver.cpp | 6 ++++ src/sat/smt/polysat_solver.h | 3 +- 12 files changed, 132 insertions(+), 47 deletions(-) diff --git a/src/math/lp/lp_api.h b/src/math/lp/lp_api.h index 2a4e5058d..0eb8b6b37 100644 --- a/src/math/lp/lp_api.h +++ b/src/math/lp/lp_api.h @@ -108,6 +108,7 @@ namespace lp_api { unsigned m_gomory_cuts; unsigned m_assume_eqs; unsigned m_branch; + unsigned m_band_axioms; stats() { reset(); } void reset() { memset(this, 0, sizeof(*this)); @@ -128,6 +129,7 @@ namespace lp_api { st.update("arith-gomory-cuts", m_gomory_cuts); st.update("arith-assume-eqs", m_assume_eqs); st.update("arith-branch", m_branch); + st.update("arith-band-axioms", m_band_axioms); } }; diff --git a/src/sat/smt/arith_axioms.cpp b/src/sat/smt/arith_axioms.cpp index b8bffa5f2..046470000 100644 --- a/src/sat/smt/arith_axioms.cpp +++ b/src/sat/smt/arith_axioms.cpp @@ -215,10 +215,15 @@ namespace arith { } theory_var vx = expr2enode(x)->get_th_var(get_id()); theory_var vy = expr2enode(y)->get_th_var(get_id()); - theory_var xn = expr2enode(n)->get_th_var(get_id()); - rational valx = get_value(vx); - rational valy = get_value(vy); - rational valn = get_value(xn); + theory_var vn = expr2enode(n)->get_th_var(get_id()); + rational N = rational::power_of_two(sz); + SASSERT(get_value(vx).is_int()); + SASSERT(get_value(vy).is_int()); + SASSERT(get_value(vn).is_int()); + rational valx = mod(get_value(vx), N); + rational valy = mod(get_value(vy), N); + rational valn = get_value(vn); + SASSERT(0 <= valn && valn < N); // x mod 2^{i + 1} >= 2^i means the i'th bit is 1. auto bitof = [&](expr* x, unsigned i) { @@ -230,26 +235,25 @@ namespace arith { bool xb = valx.get_bit(i); bool yb = valy.get_bit(i); bool nb = valn.get_bit(i); - if (xb && yb && !nb) { + if (xb && yb && !nb) add_clause(~bitof(x, i), ~bitof(y, i), bitof(n, i)); - return false; - } - if (nb && !xb) { + else if (nb && !xb) add_clause(~bitof(n, i), bitof(x, i)); - return false; - } - if (nb && !yb) { + else if (nb && !yb) add_clause(~bitof(n, i), bitof(y, i)); - return false; - } + else + continue; + return false; } return true; } bool solver::check_band_terms() { for (app* n : m_band_terms) { - if (!check_band_term(n)) - return false; + if (!check_band_term(n)) { + ++m_stats.m_band_axioms; + return false; + } } return true; } diff --git a/src/sat/smt/arith_internalize.cpp b/src/sat/smt/arith_internalize.cpp index 4d0943d65..decd49019 100644 --- a/src/sat/smt/arith_internalize.cpp +++ b/src/sat/smt/arith_internalize.cpp @@ -253,7 +253,6 @@ namespace arith { st.to_ensure_var().push_back(n2); } else if (a.is_band(n)) { - // unsupported for now. m_band_terms.push_back(to_app(n)); mk_band_axiom(to_app(n)); ctx.push(push_back_vector(m_band_terms)); diff --git a/src/sat/smt/arith_solver.cpp b/src/sat/smt/arith_solver.cpp index 9e03bbee4..306a6cce0 100644 --- a/src/sat/smt/arith_solver.cpp +++ b/src/sat/smt/arith_solver.cpp @@ -1042,6 +1042,9 @@ namespace arith { if (!check_delayed_eqs()) return sat::check_result::CR_CONTINUE; + if (!check_band_terms()) + return sat::check_result::CR_CONTINUE; + if (ctx.get_config().m_arith_ignore_int && int_undef) return sat::check_result::CR_GIVEUP; if (m_not_handled != nullptr) { @@ -1197,8 +1200,6 @@ namespace arith { default: UNREACHABLE(); } - if (lia_check == l_true && !check_band_terms()) - lia_check = l_false; return lia_check; } diff --git a/src/sat/smt/intblast_solver.cpp b/src/sat/smt/intblast_solver.cpp index 65dc56e00..db5798236 100644 --- a/src/sat/smt/intblast_solver.cpp +++ b/src/sat/smt/intblast_solver.cpp @@ -13,6 +13,7 @@ Author: #include "ast/ast_util.h" #include "ast/for_each_expr.h" +#include "ast/rewriter/bv_rewriter.h" #include "params/bv_rewriter_params.hpp" #include "sat/smt/intblast_solver.h" #include "sat/smt/euf_solver.h" @@ -104,6 +105,8 @@ namespace intblast { lbool r = m_solver->check_sat(es); + m_solver->collect_statistics(m_stats); + IF_VERBOSE(2, verbose_stream() << "(sat.intblast :result " << r << ")\n"); if (r == l_false) { @@ -472,9 +475,32 @@ namespace intblast { else m_trail.push_back(a.mk_mod(x, y)); break; - } + } + // + // ashr(x, y) + // if y = k & x >= 0 -> x / 2^k + // if y = k & x < 0 -> - (x / 2^k) + // + + case OP_BASHR: { + expr* x = args.get(0), * y = args.get(1); + rational N = rational::power_of_two(bv.get_bv_size(e)); + bv_expr = ap; + x = mk_mod(x); + y = mk_mod(y); + expr* signbit = a.mk_ge(x, a.mk_int(N/2)); + expr* r = m.mk_ite(signbit, a.mk_int(N - 1), a.mk_int(0)); + for (unsigned i = 0; i < bv.get_bv_size(e); ++i) { + expr* d = a.mk_idiv(x, a.mk_int(rational::power_of_two(i))); + r = m.mk_ite(a.mk_eq(y, a.mk_int(i)), + m.mk_ite(signbit, a.mk_uminus(d), d), + r); + } + m_trail.push_back(r); + break; + } case OP_BCOMP: - case OP_BASHR: + case OP_ROTATE_LEFT: case OP_ROTATE_RIGHT: case OP_EXT_ROTATE_LEFT: @@ -524,6 +550,27 @@ namespace intblast { return val; } + void solver::add_value(euf::enode* n, model& mdl, expr_ref_vector& values) { + expr_ref value(m); + if (n->interpreted()) + value = n->get_expr(); + else if (to_app(n->get_expr())->get_family_id() == bv.get_family_id()) { + bv_rewriter rw(m); + expr_ref_vector args(m); + for (auto arg : euf::enode_args(n)) + args.push_back(values.get(arg->get_root_id())); + rw.mk_app(n->get_decl(), args.size(), args.data(), value); + VERIFY(value); + } + else { + rational r = get_value(n->get_expr()); + verbose_stream() << ctx.bpp(n) << " := " << r << "\n"; + value = bv.mk_numeral(r, bv.get_bv_size(n->get_expr())); + } + values.set(n->get_root_id(), value); + TRACE("model", tout << "add_value " << ctx.bpp(n) << " := " << value << "\n"); + } + sat::literal_vector const& solver::unsat_core() { return m_core; } @@ -534,4 +581,8 @@ namespace intblast { return out; } + void solver::collect_statistics(statistics& st) const { + st.copy(m_stats); + } + } diff --git a/src/sat/smt/intblast_solver.h b/src/sat/smt/intblast_solver.h index a093713bb..b87724cc8 100644 --- a/src/sat/smt/intblast_solver.h +++ b/src/sat/smt/intblast_solver.h @@ -25,6 +25,7 @@ Author: #include "ast/bv_decl_plugin.h" #include "solver/solver.h" #include "sat/smt/sat_th.h" +#include "util/statistics.h" namespace euf { class solver; @@ -49,11 +50,14 @@ namespace intblast { expr_ref_vector m_trail; ast_ref_vector m_pinned; sat::literal_vector m_core; + statistics m_stats; bool is_bv(sat::literal lit); void translate(expr_ref_vector& es); void sorted_subterms(expr_ref_vector& es, ptr_vector& sorted); + rational get_value(expr* e) const; + public: solver(euf::solver& ctx); @@ -61,9 +65,11 @@ namespace intblast { sat::literal_vector const& unsat_core(); - rational get_value(expr* e) const; + void add_value(euf::enode* n, model& mdl, expr_ref_vector& values); std::ostream& display(std::ostream& out) const; + + void collect_statistics(statistics& st) const; }; } diff --git a/src/sat/smt/polysat/core.cpp b/src/sat/smt/polysat/core.cpp index 8e779923d..a552bb9ab 100644 --- a/src/sat/smt/polysat/core.cpp +++ b/src/sat/smt/polysat/core.cpp @@ -119,7 +119,7 @@ namespace polysat { m_activity.pop_back(); m_justification.pop_back(); m_watch.pop_back(); - m_values.pop_back(); + m_values.pop_back(); m_var_queue.del_var_eh(v); } @@ -160,6 +160,7 @@ namespace polysat { s.add_eq_literal(m_var, m_value); return sat::check_result::CR_CONTINUE; case find_t::resource_out: + m_var_queue.unassign_var_eh(m_var); return sat::check_result::CR_GIVEUP; } UNREACHABLE(); @@ -342,8 +343,21 @@ namespace polysat { for (auto const& [sc, d, value] : m_constraint_index) out << sc << " " << d << " := " << value << "\n"; for (unsigned i = 0; i < m_vars.size(); ++i) - out << "p" << m_vars[i] << " := " << m_values[i] << " " << m_justification[i] << "\n"; + out << m_vars[i] << " := " << m_values[i] << " " << m_justification[i] << "\n"; + m_var_queue.display(out << "vars ") << "\n"; return out; } + bool core::try_eval(pdd const& p, rational& r) { + auto q = subst(p); + if (!q.is_val()) + return false; + r = q.val(); + return true; + } + + void core::collect_statistics(statistics& st) const { + + } + } diff --git a/src/sat/smt/polysat/core.h b/src/sat/smt/polysat/core.h index c6442a290..c3dddfece 100644 --- a/src/sat/smt/polysat/core.h +++ b/src/sat/smt/polysat/core.h @@ -104,6 +104,9 @@ namespace polysat { pdd value(rational const& v, unsigned sz); pdd subst(pdd const&); + bool try_eval(pdd const& p, rational& r); + + void collect_statistics(statistics& st) const; signed_constraint eq(pdd const& p) { return m_constraints.eq(p); } signed_constraint eq(pdd const& p, pdd const& q) { return m_constraints.eq(p - q); } diff --git a/src/sat/smt/polysat_internalize.cpp b/src/sat/smt/polysat_internalize.cpp index 4496dc759..5e5647bd3 100644 --- a/src/sat/smt/polysat_internalize.cpp +++ b/src/sat/smt/polysat_internalize.cpp @@ -86,6 +86,7 @@ namespace polysat { case OP_BSUB: internalize_binary(a, [&](pdd const& p, pdd const& q) { return p - q; }); break; case OP_BLSHR: internalize_lshr(a); break; case OP_BSHL: internalize_shl(a); break; + case OP_BASHR: internalize_ashr(a); break; case OP_BAND: internalize_band(a); break; case OP_BOR: internalize_bor(a); break; case OP_BXOR: internalize_bxor(a); break; @@ -148,7 +149,7 @@ namespace polysat { case OP_BSDIV_I: case OP_BSREM_I: case OP_BSMOD_I: - case OP_BASHR: + IF_VERBOSE(0, verbose_stream() << mk_pp(a, m) << "\n"); NOT_IMPLEMENTED_YET(); return; @@ -254,6 +255,12 @@ namespace polysat { auto sc = m_core.lshr(expr2pdd(x), expr2pdd(y), expr2pdd(n)); } + void solver::internalize_ashr(app* n) { + expr* x, * y; + VERIFY(bv.is_bv_ashr(n, x, y)); + auto sc = m_core.ashr(expr2pdd(x), expr2pdd(y), expr2pdd(n)); + } + void solver::internalize_shl(app* n) { expr* x, * y; VERIFY(bv.is_bv_shl(n, x, y)); diff --git a/src/sat/smt/polysat_model.cpp b/src/sat/smt/polysat_model.cpp index 9a44e0abf..5bd8d4dc9 100644 --- a/src/sat/smt/polysat_model.cpp +++ b/src/sat/smt/polysat_model.cpp @@ -26,32 +26,18 @@ namespace polysat { void solver::add_value(euf::enode* n, model& mdl, expr_ref_vector& values) { if (m_use_intblast_model) { - expr_ref value(m); - if (n->interpreted()) - value = n->get_expr(); - else if (to_app(n->get_expr())->get_family_id() == bv.get_family_id()) { - bv_rewriter rw(m); - expr_ref_vector args(m); - for (auto arg : euf::enode_args(n)) - args.push_back(values.get(arg->get_root_id())); - rw.mk_app(n->get_decl(), args.size(), args.data(), value); - VERIFY(value); - } - else { - rational r = m_intblast.get_value(n->get_expr()); - verbose_stream() << ctx.bpp(n) << " := " << r << "\n"; - value = bv.mk_numeral(r, get_bv_size(n)); - } - values.set(n->get_root_id(), value); - TRACE("model", tout << "add_value " << ctx.bpp(n) << " := " << value << "\n"); + m_intblast.add_value(n, mdl, values); return; } -#if 0 auto p = expr2pdd(n->get_expr()); rational val; - VERIFY(m_polysat.try_eval(p, val)); - values[n->get_root_id()] = bv.mk_numeral(val, get_bv_size(n)); -#endif + if (!m_core.try_eval(p, val)) { + ctx.s().display(verbose_stream()); + verbose_stream() << ctx.bpp(n) << " := " << p << "\n"; + UNREACHABLE(); + } + VERIFY(m_core.try_eval(p, val)); + values.set(n->get_root_id(), bv.mk_numeral(val, get_bv_size(n))); } bool solver::add_dep(euf::enode* n, top_sort& dep) { @@ -78,6 +64,11 @@ namespace polysat { } + void solver::collect_statistics(statistics& st) const { + m_intblast.collect_statistics(st); + m_core.collect_statistics(st); + } + std::ostream& solver::display_justification(std::ostream& out, sat::ext_justification_idx idx) const { return out; } diff --git a/src/sat/smt/polysat_solver.cpp b/src/sat/smt/polysat_solver.cpp index 43f156c7d..9f185b22d 100644 --- a/src/sat/smt/polysat_solver.cpp +++ b/src/sat/smt/polysat_solver.cpp @@ -184,6 +184,9 @@ namespace polysat { void solver::new_eq_eh(euf::th_eq const& eq) { auto v1 = eq.v1(), v2 = eq.v2(); + euf::enode* n = var2enode(v1); + if (!bv.is_bv(n->get_expr())) + return; pdd p = var2pdd(v1); pdd q = var2pdd(v2); auto sc = m_core.eq(p, q); @@ -197,6 +200,9 @@ namespace polysat { void solver::new_diseq_eh(euf::th_eq const& ne) { euf::theory_var v1 = ne.v1(), v2 = ne.v2(); + euf::enode* n = var2enode(v1); + if (!bv.is_bv(n->get_expr())) + return; pdd p = var2pdd(v1); pdd q = var2pdd(v2); auto sc = ~m_core.eq(p, q); diff --git a/src/sat/smt/polysat_solver.h b/src/sat/smt/polysat_solver.h index 7cf176b0c..f54bafb1c 100644 --- a/src/sat/smt/polysat_solver.h +++ b/src/sat/smt/polysat_solver.h @@ -121,6 +121,7 @@ namespace polysat { void internalize_bxnor(app* n); void internalize_band(app* n); void internalize_lshr(app* n); + void internalize_ashr(app* n); void internalize_shl(app* n); template void internalize_le(app* n); @@ -174,7 +175,7 @@ namespace polysat { std::ostream& display(std::ostream& out) const override; std::ostream& display_justification(std::ostream& out, sat::ext_justification_idx idx) const override; std::ostream& display_constraint(std::ostream& out, sat::ext_constraint_idx idx) const override; - void collect_statistics(statistics& st) const override {} + void collect_statistics(statistics& st) const override; euf::th_solver* clone(euf::solver& ctx) override { return alloc(solver, ctx, get_id()); } extension* copy(sat::solver* s) override { throw default_exception("nyi"); } void find_mutexes(literal_vector& lits, vector & mutexes) override {}