From 5dfe86fc2d73aaf0fd48048114cbb097972962df Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Wed, 13 Dec 2023 14:13:16 -0800 Subject: [PATCH] bugfixes in intblast solver Signed-off-by: Nikolaj Bjorner --- src/math/lp/int_solver.cpp | 6 +- src/sat/smt/CMakeLists.txt | 1 + src/sat/smt/arith_axioms.cpp | 1 + src/sat/smt/arith_solver.cpp | 35 +++++-- src/sat/smt/arith_solver.h | 2 + src/sat/smt/dt_solver.cpp | 2 +- src/sat/smt/intblast_solver.cpp | 171 ++++++++++++++++++++++---------- src/sat/smt/intblast_solver.h | 13 ++- src/smt/theory_datatype.cpp | 2 +- src/util/trail.h | 6 +- 10 files changed, 163 insertions(+), 76 deletions(-) diff --git a/src/math/lp/int_solver.cpp b/src/math/lp/int_solver.cpp index c324af5b6..9cbc765d4 100644 --- a/src/math/lp/int_solver.cpp +++ b/src/math/lp/int_solver.cpp @@ -207,8 +207,10 @@ namespace lp { #endif m_cut_vars.reset(); - if (r == lia_move::undef) r = int_branch(*this)(); - if (settings().get_cancel_flag()) r = lia_move::undef; + if (settings().get_cancel_flag()) + return lia_move::undef; + if (r == lia_move::undef) + r = int_branch(*this)(); return r; } diff --git a/src/sat/smt/CMakeLists.txt b/src/sat/smt/CMakeLists.txt index 2302a6c39..1ed9c05ca 100644 --- a/src/sat/smt/CMakeLists.txt +++ b/src/sat/smt/CMakeLists.txt @@ -5,6 +5,7 @@ z3_add_component(sat_smt arith_internalize.cpp arith_sls.cpp arith_solver.cpp + arith_value.cpp array_axioms.cpp array_diagnostics.cpp array_internalize.cpp diff --git a/src/sat/smt/arith_axioms.cpp b/src/sat/smt/arith_axioms.cpp index 09db74f75..0150824b2 100644 --- a/src/sat/smt/arith_axioms.cpp +++ b/src/sat/smt/arith_axioms.cpp @@ -250,6 +250,7 @@ namespace arith { add_clause(~bitof(n, i), bitof(y, i)); else continue; + verbose_stream() << "added b-and clause\n"; return false; } return true; diff --git a/src/sat/smt/arith_solver.cpp b/src/sat/smt/arith_solver.cpp index 306a6cce0..37aef2bf8 100644 --- a/src/sat/smt/arith_solver.cpp +++ b/src/sat/smt/arith_solver.cpp @@ -619,17 +619,20 @@ namespace arith { } } - void solver::add_value(euf::enode* n, model& mdl, expr_ref_vector& values) { + bool solver::get_value(euf::enode* n, expr_ref& value) { theory_var v = n->get_th_var(get_id()); expr* o = n->get_expr(); - expr_ref value(m); + if (m.is_value(n->get_root()->get_expr())) { value = n->get_root()->get_expr(); } else if (use_nra_model() && lp().external_to_local(v) != lp::null_lpvar) { anum const& an = nl_value(v, m_nla->tmp1()); + + + if (a.is_int(o) && !m_nla->am().is_int(an)) - value = a.mk_numeral(rational::zero(), a.is_int(o)); + value = a.mk_numeral(rational::zero(), a.is_int(o)); else value = a.mk_numeral(m_nla->am(), nl_value(v, m_nla->tmp1()), a.is_int(o)); } @@ -637,24 +640,35 @@ namespace arith { rational r = get_value(v); TRACE("arith", tout << mk_pp(o, m) << " v" << v << " := " << r << "\n";); SASSERT("integer variables should have integer values: " && (ctx.get_config().m_arith_ignore_int || !a.is_int(o) || r.is_int() || m_not_handled != nullptr || m.limit().is_canceled())); - if (a.is_int(o) && !r.is_int()) + if (a.is_int(o) && !r.is_int()) r = floor(r); value = a.mk_numeral(r, o->get_sort()); } + else + return false; + + return true; + } + + + void solver::add_value(euf::enode* n, model& mdl, expr_ref_vector& values) { + expr_ref value(m); + expr* o = n->get_expr(); + if (get_value(n, value)) + ; else if (a.is_arith_expr(o) && reflect(o)) { expr_ref_vector args(m); for (auto* arg : *to_app(o)) { if (m.is_value(arg)) args.push_back(arg); - else + else args.push_back(values.get(ctx.get_enode(arg)->get_root_id())); } value = m.mk_app(to_app(o)->get_decl(), args.size(), args.data()); ctx.get_rewriter()(value); } - else { - value = mdl.get_fresh_value(o->get_sort()); - } + else + value = mdl.get_fresh_value(n->get_sort()); mdl.register_value(value); values.set(n->get_root_id(), value); } @@ -1042,7 +1056,7 @@ namespace arith { if (!check_delayed_eqs()) return sat::check_result::CR_CONTINUE; - if (!check_band_terms()) + if (!int_undef && !check_band_terms()) return sat::check_result::CR_CONTINUE; if (ctx.get_config().m_arith_ignore_int && int_undef) @@ -1195,7 +1209,8 @@ namespace arith { lia_check = l_undef; break; case lp::lia_move::continue_with_check: - lia_check = l_undef; + TRACE("arith", tout << "continue-with-check\n"); + lia_check = l_false; break; default: UNREACHABLE(); diff --git a/src/sat/smt/arith_solver.h b/src/sat/smt/arith_solver.h index 50cdc63ef..022dbeaea 100644 --- a/src/sat/smt/arith_solver.h +++ b/src/sat/smt/arith_solver.h @@ -526,6 +526,8 @@ namespace arith { bool add_eq(lpvar u, lpvar v, lp::explanation const& e, bool is_fixed); void consume(rational const& v, lp::constraint_index j); bool bound_is_interesting(unsigned vi, lp::lconstraint_kind kind, const rational& bval) const; + + bool get_value(euf::enode* n, expr_ref& val); }; diff --git a/src/sat/smt/dt_solver.cpp b/src/sat/smt/dt_solver.cpp index daecb7325..52c4ed953 100644 --- a/src/sat/smt/dt_solver.cpp +++ b/src/sat/smt/dt_solver.cpp @@ -400,7 +400,7 @@ namespace dt { return; } SASSERT(val == l_undef || (val == l_false && !d->m_constructor)); - ctx.push(set_vector_idx_trail(d->m_recognizers, c_idx)); + ctx.push(set_vector_idx_trail(d->m_recognizers, c_idx)); d->m_recognizers[c_idx] = recognizer; if (val == l_false) propagate_recognizer(v, recognizer); diff --git a/src/sat/smt/intblast_solver.cpp b/src/sat/smt/intblast_solver.cpp index 32bf52f79..9960197fb 100644 --- a/src/sat/smt/intblast_solver.cpp +++ b/src/sat/smt/intblast_solver.cpp @@ -17,6 +17,7 @@ Author: #include "params/bv_rewriter_params.hpp" #include "sat/smt/intblast_solver.h" #include "sat/smt/euf_solver.h" +#include "sat/smt/arith_value.h" namespace intblast { @@ -29,7 +30,8 @@ namespace intblast { bv(m), a(m), m_translate(m), - m_args(m) + m_args(m), + m_pinned(m) {} euf::theory_var solver::mk_var(euf::enode* n) { @@ -89,40 +91,70 @@ namespace intblast { expr* x, * y; VERIFY(m.is_eq(n->get_expr(), x, y)); SASSERT(bv.is_bv(x)); - ensure_translated(x); - ensure_translated(y); - m_args.reset(); - m_args.push_back(a.mk_sub(translated(x), translated(y))); - expr_ref lhs(umod(x, 0), m); - ctx.get_rewriter()(lhs); - add_equiv(expr2literal(e), eq_internalize(lhs, a.mk_int(0))); + if (!is_translated(e)) { + ensure_translated(x); + ensure_translated(y); + m_args.reset(); + m_args.push_back(a.mk_sub(translated(x), translated(y))); + set_translated(e, m.mk_eq(umod(x, 0), a.mk_int(0))); + } + m_preds.push_back(e); + ctx.push(push_back_vector(m_preds)); + } + + void solver::set_translated(expr* e, expr* r) { + SASSERT(r); + SASSERT(!is_translated(e)); + m_translate.setx(e->get_id(), r); + ctx.push(set_vector_idx_trail(m_translate, e->get_id())); } void solver::internalize_bv(app* e) { ensure_translated(e); - - // possibly wait until propagation? if (m.is_bool(e)) { - expr_ref r(translated(e), m); - ctx.get_rewriter()(r); - add_equiv(expr2literal(e), mk_literal(r)); + m_preds.push_back(e); + ctx.push(push_back_vector(m_preds)); } - add_bound_axioms(); } - void solver::add_bound_axioms() { + bool solver::add_bound_axioms() { if (m_vars_qhead == m_vars.size()) - return; + return false; ctx.push(value_trail(m_vars_qhead)); for (; m_vars_qhead < m_vars.size(); ++m_vars_qhead) { auto v = m_vars[m_vars_qhead]; auto w = translated(v); auto sz = rational::power_of_two(bv.get_bv_size(v->get_sort())); - add_unit(ctx.mk_literal(a.mk_ge(w, a.mk_int(0)))); - add_unit(ctx.mk_literal(a.mk_le(w, a.mk_int(sz - 1)))); + auto lo = ctx.mk_literal(a.mk_ge(w, a.mk_int(0))); + auto hi = ctx.mk_literal(a.mk_le(w, a.mk_int(sz - 1))); + ctx.mark_relevant(lo); + ctx.mark_relevant(hi); + add_unit(lo); + add_unit(hi); } + return true; } + bool solver::add_predicate_axioms() { + if (m_preds_qhead == m_preds.size()) + return false; + ctx.push(value_trail(m_preds_qhead)); + for (; m_preds_qhead < m_preds.size(); ++m_preds_qhead) { + expr* e = m_preds[m_preds_qhead]; + expr_ref r(translated(e), m); + ctx.get_rewriter()(r); + auto a = expr2literal(e); + auto b = mk_literal(r); + ctx.mark_relevant(b); + add_equiv(a, b); + } + return true; + } + + bool solver::unit_propagate() { + return add_bound_axioms() || add_predicate_axioms(); + } + void solver::ensure_translated(expr* e) { if (m_translate.get(e->get_id(), nullptr)) return; @@ -200,7 +232,6 @@ namespace intblast { } m_core.reset(); - m_translate.reset(); m_is_plugin = false; m_solver = mk_smt2_solver(m, s.params(), symbol::null); @@ -256,6 +287,8 @@ namespace intblast { void solver::sorted_subterms(expr_ref_vector& es, ptr_vector& sorted) { expr_fast_mark1 visited; for (expr* e : es) { + if (is_translated(e)) + continue; sorted.push_back(e); visited.mark(e); } @@ -264,7 +297,7 @@ namespace intblast { if (is_app(e)) { app* a = to_app(e); for (expr* arg : *a) { - if (!visited.is_marked(arg)) { + if (!visited.is_marked(arg) && !is_translated(arg)) { visited.mark(arg); sorted.push_back(arg); } @@ -287,7 +320,7 @@ namespace intblast { expr* r = n->get_root()->get_expr(); es.push_back(m.mk_eq(e, r)); r = es.back(); - if (!visited.is_marked(r)) { + if (!visited.is_marked(r) && !is_translated(r)) { visited.mark(r); sorted.push_back(r); } @@ -295,7 +328,7 @@ namespace intblast { else if (is_quantifier(e)) { quantifier* q = to_quantifier(e); expr* b = q->get_expr(); - if (!visited.is_marked(b)) { + if (!visited.is_marked(b) && !is_translated(b)) { visited.mark(b); sorted.push_back(b); } @@ -333,7 +366,11 @@ namespace intblast { continue; if (sib->get_arg(0)->get_root() == r1) continue; - add_clause(~eq_internalize(n, sib), eq_internalize(sib->get_arg(0), n->get_arg(0)), nullptr); + auto a = eq_internalize(n, sib); + auto b = eq_internalize(sib->get_arg(0), n->get_arg(0)); + ctx.mark_relevant(a); + ctx.mark_relevant(b); + add_clause(~a, b, nullptr); return sat::check_result::CR_CONTINUE; } } @@ -350,7 +387,9 @@ namespace intblast { auto nBv2int = ctx.get_enode(bv2int); auto nxModN = ctx.get_enode(xModN); if (nBv2int->get_root() != nxModN->get_root()) { - add_unit(eq_internalize(nBv2int, nxModN)); + auto a = eq_internalize(nBv2int, nxModN); + ctx.mark_relevant(a); + add_unit(a); return sat::check_result::CR_CONTINUE; } } @@ -366,7 +405,7 @@ namespace intblast { return x; return a.mk_int(mod(r, N)); } - if (any_of(m_vars, [&](expr* v) { return translated(v) == x; })) + if (any_of(m_vars, [&](expr* v) { return translated(v) == x && bv.get_bv_size(v) == bv.get_bv_size(bv_expr); })) return x; return a.mk_mod(x, a.mk_int(N)); } @@ -481,6 +520,7 @@ namespace intblast { m_new_funs.insert(f, g); } f = g; + m_pinned.push_back(f); } set_translated(e, m.mk_app(f, m_args)); } @@ -578,14 +618,14 @@ namespace intblast { } case OP_BUREM: case OP_BUREM_I: { - expr* x = arg(0), * y = arg(1); + expr* x = arg(0), * y = umod(e, 1); r = m.mk_ite(m.mk_eq(y, a.mk_int(0)), x, a.mk_mod(x, y)); break; } case OP_BUDIV: case OP_BUDIV_I: { - expr* x = arg(0), * y = arg(1); - r = m.mk_ite(m.mk_eq(y, a.mk_int(0)), a.mk_int(1), a.mk_idiv(x, umod(bv_expr, 1))); + expr* x = arg(0), * y = umod(e, 1); + r = m.mk_ite(m.mk_eq(y, a.mk_int(0)), a.mk_int(1), a.mk_idiv(x, y)); break; } case OP_BUMUL_NO_OVFL: { @@ -594,24 +634,24 @@ namespace intblast { break; } case OP_BSHL: { - expr* x = arg(0), * y = arg(1); + expr* x = arg(0), * y = umod(e, 1); r = a.mk_int(0); for (unsigned i = 0; i < bv.get_bv_size(e); ++i) - r = m.mk_ite(m.mk_eq(y, a.mk_int(i)), a.mk_mul(x, a.mk_int(rational::power_of_two(i))), r); + r = m.mk_ite(m.mk_eq(y, a.mk_int(i)), a.mk_mul(x, a.mk_int(rational::power_of_two(i))), r); break; } case OP_BNOT: r = bnot(arg(0)); break; case OP_BLSHR: { - expr* x = arg(0), * y = arg(1); + expr* x = arg(0), * y = umod(e, 1); r = a.mk_int(0); for (unsigned i = 0; i < bv.get_bv_size(e); ++i) r = m.mk_ite(m.mk_eq(y, a.mk_int(i)), a.mk_idiv(x, a.mk_int(rational::power_of_two(i))), r); break; - } - // Or use (p + q) - band(p, q)? + } case OP_BOR: { + // p | q := (p + q) - band(p, q) r = arg(0); for (unsigned i = 1; i < args.size(); ++i) r = a.mk_sub(a.mk_add(r, arg(i)), a.mk_band(bv.get_bv_size(e), r, arg(i))); @@ -623,11 +663,9 @@ namespace intblast { case OP_BAND: r = band(args); break; - // From "Hacker's Delight", section 2-2. Addition Combined with Logical Operations; - // found via Int-Blasting paper; see https://doi.org/10.1007/978-3-030-94583-1_24 - // (p + q) - 2*band(p, q); case OP_BXNOR: case OP_BXOR: { + // p ^ q := (p + q) - 2*band(p, q); unsigned sz = bv.get_bv_size(e); r = arg(0); for (unsigned i = 1; i < args.size(); ++i) { @@ -691,7 +729,7 @@ namespace intblast { case OP_BSMOD_I: case OP_BSMOD: { bv_expr = e; - expr* x = umod(bv_expr, 0), *y = umod(bv_expr, 0); + expr* x = umod(bv_expr, 0), *y = umod(bv_expr, 1); rational N = rational::power_of_two(bv.get_bv_size(bv_expr)); expr* signx = a.mk_ge(x, a.mk_int(N/2)); expr* signy = a.mk_ge(y, a.mk_int(N/2)); @@ -721,7 +759,7 @@ namespace intblast { // x > 0, y > 0 -> d // x < 0, y < 0 -> d bv_expr = e; - expr* x = umod(bv_expr, 0), * y = umod(bv_expr, 0); + expr* x = umod(bv_expr, 0), * y = umod(bv_expr, 1); rational N = rational::power_of_two(bv.get_bv_size(bv_expr)); expr* signx = a.mk_ge(x, a.mk_int(N / 2)); expr* signy = a.mk_ge(y, a.mk_int(N / 2)); @@ -735,7 +773,7 @@ namespace intblast { // y = 0 -> x // else x - sdiv(x, y) * y bv_expr = e; - expr* x = umod(bv_expr, 0), * y = umod(bv_expr, 0); + expr* x = umod(bv_expr, 0), * y = umod(bv_expr, 1); rational N = rational::power_of_two(bv.get_bv_size(bv_expr)); expr* signx = a.mk_ge(x, a.mk_int(N / 2)); expr* signy = a.mk_ge(y, a.mk_int(N / 2)); @@ -751,8 +789,7 @@ namespace intblast { case OP_EXT_ROTATE_RIGHT: case OP_REPEAT: case OP_BREDOR: - case OP_BREDAND: - + case OP_BREDAND: verbose_stream() << mk_pp(e, m) << "\n"; NOT_IMPLEMENTED_YET(); break; @@ -804,26 +841,46 @@ namespace intblast { } bool solver::add_dep(euf::enode* n, top_sort& dep) { - // bv2int - auto e = ctx.get_enode(translated(n->get_expr())); - if (!e) + if (!is_app(n->get_expr())) return false; - dep.add(n, e); + app* e = to_app(n->get_expr()); + if (n->num_args() == 0) { + dep.insert(n, nullptr); + return true; + } + if (e->get_family_id() != bv.get_family_id()) + return false; + for (euf::enode* arg : euf::enode_args(n)) + dep.add(n, arg->get_root()); return true; } // TODO: handle dependencies properly by using arithmetical model to retrieve values of translated // bit-vectors directly. - void solver::add_value_plugin(euf::enode* n, model& mdl, expr_ref_vector& values) { - SASSERT(bv.is_bv(n->get_expr())); - rational N = rational::power_of_two(bv.get_bv_size(n->get_expr())); - auto e = ctx.get_enode(translated(n->get_expr())); + void solver::add_value_solver(euf::enode* n, model& mdl, expr_ref_vector& values) { + expr* e = n->get_expr(); + SASSERT(bv.is_bv(e)); + + if (bv.is_numeral(e)) { + values.setx(n->get_root_id(), e); + return; + } + + rational r, N = rational::power_of_two(bv.get_bv_size(e)); + expr* te = translated(e); + model_ref mdlr; + m_solver->get_model(mdlr); expr_ref value(m); - value = values.get(e->get_root_id()); - values.setx(n->get_root_id(), value); + if (mdlr->eval_expr(te, value, true) && a.is_numeral(value, r)) { + values.setx(n->get_root_id(), bv.mk_numeral(mod(r, N), bv.get_bv_size(e))); + return; + } + ctx.s().display(verbose_stream()); + verbose_stream() << "failed to evaluate " << mk_pp(te, m) << " " << value << "\n"; + UNREACHABLE(); } - void solver::add_value_solver(euf::enode* n, model& mdl, expr_ref_vector& values) { + void solver::add_value_plugin(euf::enode* n, model& mdl, expr_ref_vector& values) { expr_ref value(m); if (n->interpreted()) value = n->get_expr(); @@ -833,10 +890,16 @@ namespace intblast { for (auto arg : euf::enode_args(n)) args.push_back(values.get(arg->get_root_id())); rw.mk_app(n->get_decl(), args.size(), args.data(), value); - VERIFY(value); } else { - rational r = get_value(n->get_expr()); + expr_ref bv2int(bv.mk_bv2int(n->get_expr()), m); + euf::enode* b2i = ctx.get_enode(bv2int); + if (!b2i) verbose_stream() << bv2int << "\n"; + SASSERT(b2i); + VERIFY(b2i); + arith::arith_value av(ctx); + rational r; + VERIFY(av.get_value(b2i->get_expr(), r)); verbose_stream() << ctx.bpp(n) << " := " << r << "\n"; value = bv.mk_numeral(r, bv.get_bv_size(n->get_expr())); } diff --git a/src/sat/smt/intblast_solver.h b/src/sat/smt/intblast_solver.h index 7dd37d5a7..493b1f3c5 100644 --- a/src/sat/smt/intblast_solver.h +++ b/src/sat/smt/intblast_solver.h @@ -54,6 +54,7 @@ namespace intblast { scoped_ptr<::solver> m_solver; obj_map m_new_funs; expr_ref_vector m_translate, m_args; + ast_ref_vector m_pinned; sat::literal_vector m_core; ptr_vector m_bv2int, m_int2bv; statistics m_stats; @@ -65,8 +66,9 @@ namespace intblast { rational get_value(expr* e) const; + bool is_translated(expr* e) const { return !!m_translate.get(e->get_id(), nullptr); } expr* translated(expr* e) const { expr* r = m_translate.get(e->get_id(), nullptr); SASSERT(r); return r; } - void set_translated(expr* e, expr* r) { SASSERT(r); m_translate.setx(e->get_id(), r); } + void set_translated(expr* e, expr* r); expr* arg(unsigned i) { return m_args.get(i); } expr* umod(expr* bv_expr, unsigned i); @@ -83,9 +85,10 @@ namespace intblast { void ensure_translated(expr* e); void internalize_bv(app* e); - unsigned m_vars_qhead = 0; - ptr_vector m_vars; - void add_bound_axioms(); + unsigned m_vars_qhead = 0, m_preds_qhead = 0; + ptr_vector m_vars, m_preds; + bool add_bound_axioms(); + bool add_predicate_axioms(); euf::theory_var mk_var(euf::enode* n) override; @@ -109,7 +112,7 @@ namespace intblast { void collect_statistics(statistics& st) const override; - bool unit_propagate() override { return false; } + bool unit_propagate() override; void get_antecedents(sat::literal l, sat::ext_justification_idx idx, sat::literal_vector& r, bool probing) override {} diff --git a/src/smt/theory_datatype.cpp b/src/smt/theory_datatype.cpp index cfc1f06f2..b794a44b5 100644 --- a/src/smt/theory_datatype.cpp +++ b/src/smt/theory_datatype.cpp @@ -915,7 +915,7 @@ namespace smt { } SASSERT(val == l_undef || (val == l_false && d->m_constructor == nullptr)); d->m_recognizers[c_idx] = recognizer; - m_trail_stack.push(set_vector_idx_trail(d->m_recognizers, c_idx)); + m_trail_stack.push(set_vector_idx_trail(d->m_recognizers, c_idx)); if (val == l_false) { propagate_recognizer(v, recognizer); } diff --git a/src/util/trail.h b/src/util/trail.h index 1aa7e4441..43e698234 100644 --- a/src/util/trail.h +++ b/src/util/trail.h @@ -219,12 +219,12 @@ public: } }; -template +template class set_vector_idx_trail : public trail { - ptr_vector & m_vector; + V & m_vector; unsigned m_idx; public: - set_vector_idx_trail(ptr_vector & v, unsigned idx): + set_vector_idx_trail(V & v, unsigned idx): m_vector(v), m_idx(idx) { }