From 586343ce643e7d9fb5865cc5f5b1f2829cce63af Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Sun, 14 Jul 2024 15:38:56 -0700 Subject: [PATCH] na Signed-off-by: Nikolaj Bjorner --- src/ast/bv_decl_plugin.cpp | 4 +- src/ast/bv_decl_plugin.h | 4 +- src/ast/sls/CMakeLists.txt | 3 +- src/ast/sls/bv_sls_eval.cpp | 72 ++-- src/ast/sls/bv_sls_eval.h | 13 +- src/ast/sls/bv_sls_fixed.h | 2 +- src/ast/sls/bv_sls_terms.cpp | 2 +- src/ast/sls/bv_sls_terms.h | 2 +- src/ast/sls/sat_ddfw.h | 4 + src/ast/sls/sls_arith_base.cpp | 328 ++++++++++--------- src/ast/sls/sls_arith_base.h | 30 +- src/ast/sls/sls_arith_plugin.cpp | 94 +++--- src/ast/sls/sls_arith_plugin.h | 13 +- src/ast/sls/sls_basic_plugin.cpp | 185 +++++------ src/ast/sls/sls_basic_plugin.h | 24 +- src/ast/sls/sls_bv_plugin.cpp | 129 +++----- src/ast/sls/sls_bv_plugin.h | 25 +- src/ast/sls/sls_cc.cpp | 4 +- src/ast/sls/sls_cc.h | 10 +- src/ast/sls/{sls_smt.cpp => sls_context.cpp} | 127 +++++-- src/ast/sls/{sls_smt.h => sls_context.h} | 44 ++- src/sat/smt/sls_solver.cpp | 2 +- src/tactic/sls/sls_tactic.cpp | 98 +++++- src/tactic/sls/sls_tactic.h | 2 + 24 files changed, 708 insertions(+), 513 deletions(-) rename src/ast/sls/{sls_smt.cpp => sls_context.cpp} (68%) rename src/ast/sls/{sls_smt.h => sls_context.h} (79%) diff --git a/src/ast/bv_decl_plugin.cpp b/src/ast/bv_decl_plugin.cpp index 5dd9f6080..aee03ed62 100644 --- a/src/ast/bv_decl_plugin.cpp +++ b/src/ast/bv_decl_plugin.cpp @@ -932,13 +932,13 @@ unsigned bv_util::get_int2bv_size(parameter const& p) { return static_cast(sz); } -app * bv_util::mk_bv2int(expr* e) { +app * bv_util::mk_bv2int(expr* e) const { sort* s = m_manager.mk_sort(m_manager.mk_family_id("arith"), INT_SORT); parameter p(s); return m_manager.mk_app(get_fid(), OP_BV2INT, 1, &p, 1, &e); } -app* bv_util::mk_int2bv(unsigned sz, expr* e) { +app* bv_util::mk_int2bv(unsigned sz, expr* e) const { parameter p(sz); return m_manager.mk_app(get_fid(), OP_INT2BV, 1, &p, 1, &e); } diff --git a/src/ast/bv_decl_plugin.h b/src/ast/bv_decl_plugin.h index 58445afda..b8dde9361 100644 --- a/src/ast/bv_decl_plugin.h +++ b/src/ast/bv_decl_plugin.h @@ -549,8 +549,8 @@ public: app * mk_bv_ashr(expr* arg1, expr* arg2) { return m_manager.mk_app(get_fid(), OP_BASHR, arg1, arg2); } app * mk_bv_lshr(expr* arg1, expr* arg2) { return m_manager.mk_app(get_fid(), OP_BLSHR, arg1, arg2); } - app * mk_bv2int(expr* e); - app * mk_int2bv(unsigned sz, expr* e); + app * mk_bv2int(expr* e) const; + app * mk_int2bv(unsigned sz, expr* e) const; app* mk_bv_rotate_left(expr* arg1, expr* arg2) { return m_manager.mk_app(get_fid(), OP_EXT_ROTATE_LEFT, arg1, arg2); } app* mk_bv_rotate_right(expr* arg1, expr* arg2) { return m_manager.mk_app(get_fid(), OP_EXT_ROTATE_RIGHT, arg1, arg2); } diff --git a/src/ast/sls/CMakeLists.txt b/src/ast/sls/CMakeLists.txt index ae1533085..3306ea216 100644 --- a/src/ast/sls/CMakeLists.txt +++ b/src/ast/sls/CMakeLists.txt @@ -10,8 +10,9 @@ z3_add_component(ast_sls sls_basic_plugin.cpp sls_bv_plugin.cpp sls_cc.cpp + sls_context.cpp sls_engine.cpp - sls_smt.cpp + sls_smt_solver.cpp sls_valuation.cpp COMPONENT_DEPENDENCIES ast diff --git a/src/ast/sls/bv_sls_eval.cpp b/src/ast/sls/bv_sls_eval.cpp index 7c7afbeea..85c0fd5c0 100644 --- a/src/ast/sls/bv_sls_eval.cpp +++ b/src/ast/sls/bv_sls_eval.cpp @@ -14,6 +14,7 @@ Author: #include "ast/ast_pp.h" #include "ast/ast_ll_pp.h" #include "ast/sls/bv_sls.h" +#include "ast/rewriter/th_rewriter.h" namespace bv { @@ -25,29 +26,28 @@ namespace bv { m_fix(*this, terms, ctx) {} - void sls_eval::init_eval(std::function const& eval) { - for (expr* e : ctx.subterms()) { - if (!is_app(e)) - continue; - app* a = to_app(e); - if (!bv.is_bv(e)) - continue; - add_bit_vector(a); - if (a->get_family_id() == bv.get_family_id()) - init_eval_bv(a); - else if (is_uninterp(e)) { - auto& v = wval(e); - for (unsigned i = 0; i < v.bw; ++i) - m_tmp.set(i, eval(e, i)); - v.set_repair(random_bool(), m_tmp); - } + + void sls_eval::register_term(expr* e) { + if (!is_app(e)) + return; + app* a = to_app(e); + add_bit_vector(a); + if (a->get_family_id() == bv.get_family_id()) + init_eval_bv(a); + else if (bv.is_bv(e)) { + auto& v = wval(e); + for (unsigned i = 0; i < v.bw; ++i) + m_tmp.set(i, false); + v.set_repair(random_bool(), m_tmp); } } - bool sls_eval::add_bit_vector(app* e) { + void sls_eval::add_bit_vector(app* e) { + if (!bv.is_bv(e)) + return; m_values.reserve(e->get_id() + 1); if (m_values.get(e->get_id())) - return false; + return; auto v = alloc_valuation(e); m_values.set(e->get_id(), v); expr* x, * y; @@ -57,7 +57,7 @@ namespace bv { else if (bv.is_bv_ashr(e, x, y) && bv.is_numeral(y, val) && val.is_unsigned() && val.get_unsigned() <= bv.get_bv_size(e)) v->set_signed(val.get_unsigned()); - return true; + return; } sls_valuation* sls_eval::alloc_valuation(app* e) { @@ -575,11 +575,19 @@ namespace bv { val.set(val.eval, 0); break; } + case OP_INT2BV: { + expr_ref v = ctx.get_value(e->get_arg(0)); + th_rewriter rw(m); + v = bv.mk_int2bv(bv.get_bv_size(e), v); + rw(v); + rational r; + VERIFY(bv.is_numeral(v, r)); + val.set_value(m_tmp, r); + break; + } case OP_BREDAND: case OP_BREDOR: case OP_BXNOR: - case OP_INT2BV: - verbose_stream() << mk_bounded_pp(e, m) << "\n"; NOT_IMPLEMENTED_YET(); break; @@ -672,7 +680,7 @@ namespace bv { case OP_BV2INT: return false; case OP_INT2BV: - return false; + return try_repair_int2bv(eval_value(e), e->get_arg(0)); case OP_ULEQ: if (i == 0) return try_repair_ule(bval0(e), wval(e, i), wval(e, 1 - i)); @@ -1804,6 +1812,16 @@ namespace bv { return a.set_random(m_rand); } + bool sls_eval::try_repair_int2bv(bvect const& e, expr* arg) { + expr_ref intval(m); + intval = bv.mk_bv2int(bv.mk_numeral(e.get_value(e.nw), e.bw)); + th_rewriter rw(m); + rw(intval); + verbose_stream() << "repair " << mk_pp(arg, m) << " " << intval << "\n"; + ctx.set_value(arg, intval); + return true; + } + void sls_eval::set_div(bvect const& a, bvect const& b, unsigned bw, bvect& quot, bvect& rem) const { unsigned nw = (bw + 8 * sizeof(digit_t) - 1) / (8 * sizeof(digit_t)); @@ -1853,8 +1871,10 @@ namespace bv { } void sls_eval::commit_eval(app* e) { - if (bv.is_bv(e)) - VERIFY(wval(e).commit_eval()); + if (!bv.is_bv(e)) + return; + VERIFY(wval(e).commit_eval()); + // todo: if e is shared, then ctx.set_value(). } void sls_eval::set_random(app* e) { @@ -1899,7 +1919,7 @@ namespace bv { return expr_ref(m); } - std::ostream& sls_eval::display(std::ostream& out) { + std::ostream& sls_eval::display(std::ostream& out) const { auto& terms = ctx.subterms(); for (expr* e : terms) { if (!bv.is_bv(e)) @@ -1912,7 +1932,7 @@ namespace bv { return out; } - std::ostream& sls_eval::display_value(std::ostream& out, expr* e) { + std::ostream& sls_eval::display_value(std::ostream& out, expr* e) const { if (bv.is_bv(e)) return out << wval(e); return out << "?"; diff --git a/src/ast/sls/bv_sls_eval.h b/src/ast/sls/bv_sls_eval.h index 943d731a6..995bd4aef 100644 --- a/src/ast/sls/bv_sls_eval.h +++ b/src/ast/sls/bv_sls_eval.h @@ -19,7 +19,7 @@ Author: #include "ast/ast.h" #include "ast/sls/sls_valuation.h" #include "ast/sls/bv_sls_fixed.h" -#include "ast/sls/sls_smt.h" +#include "ast/sls/sls_context.h" #include "ast/bv_decl_plugin.h" namespace bv { @@ -58,7 +58,7 @@ namespace bv { * Register e as a bit-vector. * Return true if not already registered, false if already registered. */ - bool add_bit_vector(app* e); + void add_bit_vector(app* e); sls_valuation* alloc_valuation(app* e); //bool bval1_basic(app* e) const; @@ -109,6 +109,7 @@ namespace bv { bool try_repair_extract(bvect const& e, bvval& a, unsigned lo); bool try_repair_comp(bvect const& e, bvval& a, bvval& b, unsigned i); bool try_repair_eq(bool is_true, bvval& a, bvval const& b); + bool try_repair_int2bv(bvect const& e, expr* arg); void add_p2_1(bvval const& a, bvect& t) const; bool add_overflow_on_fixed(bvval const& a, bvect const& t); @@ -136,10 +137,12 @@ namespace bv { public: sls_eval(sls_terms& terms, sls::context& ctx); - void init_eval(std::function const& eval); +// void init_eval(std::function const& eval); void tighten_range() { m_fix.init(); } + void register_term(expr* e); + /** * Retrieve evaluation based on cache. * bval - Boolean values @@ -178,8 +181,8 @@ namespace bv { bool repair_up(expr* e); - std::ostream& display(std::ostream& out); + std::ostream& display(std::ostream& out) const; - std::ostream& display_value(std::ostream& out, expr* e); + std::ostream& display_value(std::ostream& out, expr* e) const; }; } diff --git a/src/ast/sls/bv_sls_fixed.h b/src/ast/sls/bv_sls_fixed.h index 01fccecf9..f0dfcd43e 100644 --- a/src/ast/sls/bv_sls_fixed.h +++ b/src/ast/sls/bv_sls_fixed.h @@ -18,7 +18,7 @@ Author: #include "ast/ast.h" #include "ast/sls/sls_valuation.h" -#include "ast/sls/sls_smt.h" +#include "ast/sls/sls_context.h" #include "ast/bv_decl_plugin.h" namespace bv { diff --git a/src/ast/sls/bv_sls_terms.cpp b/src/ast/sls/bv_sls_terms.cpp index 2df076a20..7c489960c 100644 --- a/src/ast/sls/bv_sls_terms.cpp +++ b/src/ast/sls/bv_sls_terms.cpp @@ -26,7 +26,7 @@ namespace bv { ctx(ctx), m(ctx.get_manager()), bv(m), - m_axioms(m) {} + m_axioms(m) {} void sls_terms::register_term(expr* e) { auto r = ensure_binary(e); diff --git a/src/ast/sls/bv_sls_terms.h b/src/ast/sls/bv_sls_terms.h index 93b703e37..8f1e477f0 100644 --- a/src/ast/sls/bv_sls_terms.h +++ b/src/ast/sls/bv_sls_terms.h @@ -24,7 +24,7 @@ Author: #include "ast/sls/sls_stats.h" #include "ast/sls/sls_powers.h" #include "ast/sls/sls_valuation.h" -#include "ast/sls/sls_smt.h" +#include "ast/sls/sls_context.h" namespace bv { diff --git a/src/ast/sls/sat_ddfw.h b/src/ast/sls/sat_ddfw.h index d74da3d54..24af8b207 100644 --- a/src/ast/sls/sat_ddfw.h +++ b/src/ast/sls/sat_ddfw.h @@ -243,10 +243,14 @@ namespace sat { // access clause information and state of Boolean search indexed_uint_set& unsat_set() { return m_unsat; } + indexed_uint_set const& unsat_set() const { return m_unsat; } + vector const& clauses() const { return m_clauses; } clause_info& get_clause_info(unsigned idx) { return m_clauses[idx]; } + clause_info const& get_clause_info(unsigned idx) const { return m_clauses[idx]; } + void remove_assumptions(); void flip(bool_var v); diff --git a/src/ast/sls/sls_arith_base.cpp b/src/ast/sls/sls_arith_base.cpp index 1188e6511..219d0bc31 100644 --- a/src/ast/sls/sls_arith_base.cpp +++ b/src/ast/sls/sls_arith_base.cpp @@ -317,32 +317,16 @@ namespace sls { SASSERT(dtt(sign(bv), ineq) == 0); } vi.m_value = new_value; - if (vi.m_shared) { - sort* s = vi.m_sort == var_sort::INT ? a.mk_int() : a.mk_real(); - expr_ref num = from_num(s, new_value); - ctx.set_value(vi.m_expr, num); - } for (auto idx : vi.m_muls) { - auto const& [v, monomial] = m_muls[idx]; - num_t prod(1); - for (auto w : monomial) - prod *= value(w); - if (value(v) != prod) - m_vars_to_update.push_back({ v, prod }); + auto const& [w, coeff, monomial] = m_muls[idx]; + ctx.new_value_eh(m_vars[w].m_expr); } - for (auto const& idx : vi.m_adds) { + for (auto idx : vi.m_adds) { auto const& ad = m_adds[idx]; - auto const& args = ad.m_args; - auto v = ad.m_var; - num_t sum(ad.m_coeff); - for (auto [c, w] : args) - sum += c * value(w); - if (value(v) != sum) - m_vars_to_update.push_back({ v, sum }); + ctx.new_value_eh(m_vars[ad.m_var].m_expr); } - if (vi.m_def_idx != UINT_MAX) - // add repair actions for additions and multiplications - m_defs_to_update.push_back(v); + expr* e = vi.m_expr; + ctx.new_value_eh(e); } template @@ -375,10 +359,10 @@ namespace sls { template bool arith_base::is_num(expr* e, num_t& i) { + UNREACHABLE(); return false; } - expr_ref arith_base::from_num(sort* s, rational const& n) { return expr_ref(a.mk_numeral(n, s), m); } @@ -389,6 +373,7 @@ namespace sls { template expr_ref arith_base::from_num(sort* s, num_t const& n) { + UNREACHABLE(); return expr_ref(m); } @@ -427,14 +412,14 @@ namespace sls { default: { v = mk_var(e); unsigned idx = m_muls.size(); - m_muls.push_back({ v, m }); - num_t prod(1); + m_muls.push_back({ v, c, m }); + num_t prod(c); for (auto w : m) m_vars[w].m_muls.push_back(idx), prod *= value(w); m_vars[v].m_def_idx = idx; m_vars[v].m_op = arith_op_kind::OP_MUL; m_vars[v].m_value = prod; - add_arg(term, c, v); + add_arg(term, num_t(1), v); break; } } @@ -473,22 +458,21 @@ namespace sls { num_t val; switch (k) { case arith_op_kind::OP_MOD: - if (value(v) != 0) - val = mod(value(w), value(v)); + val = value(v) == 0 ? num_t(0) : mod(value(w), value(v)); break; case arith_op_kind::OP_REM: - if (value(v) != 0) { + if (value(v) == 0) + val = 0; + else { val = value(w); val %= value(v); } break; case arith_op_kind::OP_IDIV: - if (value(v) != 0) - val = div(value(w), value(v)); + val = value(v) == 0 ? num_t(0): div(value(w), value(v)); break; case arith_op_kind::OP_DIV: - if (value(v) != 0) - val = value(w) / value(v); + val = value(v) == 0? num_t(0) : value(w) / value(v); break; case arith_op_kind::OP_ABS: val = abs(value(w)); @@ -511,7 +495,7 @@ namespace sls { return v; linear_term t; add_args(t, e, num_t(1)); - if (t.m_coeff == 1 && t.m_args.size() == 1 && t.m_args[0].first == 1) + if (t.m_coeff == 0 && t.m_args.size() == 1 && t.m_args[0].first == 1) return t.m_args[0].second; v = mk_var(e); auto idx = m_adds.size(); @@ -531,7 +515,7 @@ namespace sls { if (v == UINT_MAX) { v = m_vars.size(); m_expr2var.setx(e->get_id(), v, UINT_MAX); - m_vars.push_back(var_info(e, a.is_int(e) ? var_sort::INT : var_sort::REAL)); + m_vars.push_back(var_info(e, a.is_int(e) ? var_sort::INT : var_sort::REAL)); } return v; } @@ -541,7 +525,7 @@ namespace sls { if (m_bool_vars.get(bv, nullptr)) return; expr* e = ctx.atom(bv); - // verbose_stream() << "bool var " << bv << " " << mk_bounded_pp(e, m) << "\n"; + verbose_stream() << "bool var " << bv << " " << mk_bounded_pp(e, m) << "\n"; if (!e) return; expr* x, * y; @@ -570,6 +554,9 @@ namespace sls { add_args(ineq, y, num_t(-1)); init_ineq(bv, ineq); } + else if (m.is_distinct(e) && a.is_int_real(e->get_arg(0))) { + NOT_IMPLEMENTED_YET(); + } else if (a.is_is_int(e, x)) { NOT_IMPLEMENTED_YET(); @@ -601,77 +588,127 @@ namespace sls { } template - void arith_base::repair(sat::literal lit) { + void arith_base::propagate_literal(sat::literal lit) { + TRACE("sls", tout << "repair is-true: " << ctx.is_true(lit) << " lit: " << lit << "\n"); if (!ctx.is_true(lit)) return; auto const* ineq = atom(lit.var()); if (!ineq) return; + TRACE("sls", tout << "repair lit: " << lit << " ineq-is-true: " << ineq->is_true() << "\n"); if (ineq->is_true() != lit.sign()) - return; - TRACE("sls", tout << "repair " << lit << "\n"); + return; repair(lit, *ineq); } template - void arith_base::repair_defs_and_updates() { - while (!m_defs_to_update.empty() || !m_vars_to_update.empty()) { - repair_updates(); - repair_defs(); + bool arith_base::propagate() { + return false; + } + + template + void arith_base::repair_up(app* e) { + auto v = m_expr2var.get(e->get_id(), UINT_MAX); + if (v == UINT_MAX) + return; + auto const& vi = m_vars[v]; + if (vi.m_def_idx == UINT_MAX) + return; + m_ops.reserve(vi.m_def_idx + 1); + auto const& od = m_ops[vi.m_def_idx]; + num_t v1, v2; + switch (vi.m_op) { + case LAST_ARITH_OP: + break; + case OP_ADD: { + auto const& ad = m_adds[vi.m_def_idx]; + auto const& args = ad.m_args; + num_t sum(ad.m_coeff); + for (auto [c, w] : args) + sum += c * value(w); + update(v, sum); + break; + } + case OP_MUL: { + auto const& [w, coeff, monomial] = m_muls[vi.m_def_idx]; + num_t prod(coeff); + for (auto w : monomial) + prod *= value(w); + update(v, prod); + break; + } + case OP_MOD: + v1 = value(od.m_arg1); + v2 = value(od.m_arg2); + update(v, v2 == 0 ? num_t(0) : mod(v1, v2)); + break; + case OP_DIV: + v1 = value(od.m_arg1); + v2 = value(od.m_arg2); + update(v, v2 == 0 ? num_t(0) : v1 / v2); + break; + case OP_IDIV: + v1 = value(od.m_arg1); + v2 = value(od.m_arg2); + update(v, v2 == 0 ? num_t(0) : div(v1, v2)); + break; + case OP_REM: + v1 = value(od.m_arg1); + v2 = value(od.m_arg2); + update(v, v2 == 0 ? num_t(0) : v1 %= v2); + break; + case OP_ABS: + update(v, abs(value(od.m_arg1))); + break; + default: + NOT_IMPLEMENTED_YET(); } } template - void arith_base::repair_updates() { - while (!m_vars_to_update.empty()) { - auto [w, new_value1] = m_vars_to_update.back(); - m_vars_to_update.pop_back(); - update(w, new_value1); - } - } - - template - void arith_base::repair_defs() { - while (!m_defs_to_update.empty()) { - auto v = m_defs_to_update.back(); - m_defs_to_update.pop_back(); - auto const& vi = m_vars[v]; - switch (vi.m_op) { - case arith_op_kind::LAST_ARITH_OP: - break; - case arith_op_kind::OP_ADD: - repair_add(m_adds[vi.m_def_idx]); - break; - case arith_op_kind::OP_MUL: - repair_mul(m_muls[vi.m_def_idx]); - break; - case arith_op_kind::OP_MOD: - repair_mod(m_ops[vi.m_def_idx]); - break; - case arith_op_kind::OP_REM: - repair_rem(m_ops[vi.m_def_idx]); - break; - case arith_op_kind::OP_POWER: - repair_power(m_ops[vi.m_def_idx]); - break; - case arith_op_kind::OP_IDIV: - repair_idiv(m_ops[vi.m_def_idx]); - break; - case arith_op_kind::OP_DIV: - repair_div(m_ops[vi.m_def_idx]); - break; - case arith_op_kind::OP_ABS: - repair_abs(m_ops[vi.m_def_idx]); - break; - case arith_op_kind::OP_TO_INT: - repair_to_int(m_ops[vi.m_def_idx]); - break; - case arith_op_kind::OP_TO_REAL: - repair_to_real(m_ops[vi.m_def_idx]); - break; - default: - NOT_IMPLEMENTED_YET(); - } + void arith_base::repair_down(app* e) { + auto v = m_expr2var.get(e->get_id(), UINT_MAX); + if (v == UINT_MAX) + return; + auto const& vi = m_vars[v]; + if (vi.m_def_idx == UINT_MAX) + return; + TRACE("sls", tout << "repair def " << mk_bounded_pp(vi.m_expr, m) << "\n"); + switch (vi.m_op) { + case arith_op_kind::LAST_ARITH_OP: + break; + case arith_op_kind::OP_ADD: + repair_add(m_adds[vi.m_def_idx]); + break; + case arith_op_kind::OP_MUL: + repair_mul(m_muls[vi.m_def_idx]); + break; + case arith_op_kind::OP_MOD: + repair_mod(m_ops[vi.m_def_idx]); + break; + case arith_op_kind::OP_REM: + repair_rem(m_ops[vi.m_def_idx]); + break; + case arith_op_kind::OP_POWER: + repair_power(m_ops[vi.m_def_idx]); + break; + case arith_op_kind::OP_IDIV: + repair_idiv(m_ops[vi.m_def_idx]); + break; + case arith_op_kind::OP_DIV: + repair_div(m_ops[vi.m_def_idx]); + break; + case arith_op_kind::OP_ABS: + repair_abs(m_ops[vi.m_def_idx]); + break; + case arith_op_kind::OP_TO_INT: + repair_to_int(m_ops[vi.m_def_idx]); + break; + case arith_op_kind::OP_TO_REAL: + repair_to_real(m_ops[vi.m_def_idx]); + break; + default: + NOT_IMPLEMENTED_YET(); } } @@ -699,23 +736,24 @@ namespace sls { template void arith_base::repair_mul(mul_def const& md) { - num_t product(1); - num_t val = value(md.m_var); - for (auto v : md.m_monomial) + auto const& [v, coeff, monomial] = md; + num_t product(coeff); + num_t val = value(v); + for (auto v : monomial) product *= value(v); if (product == val) return; if (rand() % 20 == 0) { - update(md.m_var, product); + update(v, product); } else if (val == 0) { - auto v = md.m_monomial[rand() % md.m_monomial.size()]; + auto v = monomial[ctx.rand(monomial.size())]; num_t zero(0); update(v, zero); } else if (val == 1 || val == -1) { - product = 1; - for (auto v : md.m_monomial) { + product = coeff; + for (auto v : monomial) { num_t new_value(1); if (rand() % 2 == 0) new_value = -1; @@ -723,14 +761,14 @@ namespace sls { update(v, new_value); } if (product != val) { - auto last = md.m_monomial.back(); + auto last = monomial.back(); update(last, -value(last)); } } else if (rand() % 2 == 0 && product != 0) { // value1(v) * product / value(v) = val // value1(v) = value(v) * val / product - auto w = md.m_monomial[rand() % md.m_monomial.size()]; + auto w = monomial[ctx.rand(monomial.size())]; auto old_value = value(w); num_t new_value; if (m_vars[w].m_sort == var_sort::REAL) @@ -740,15 +778,15 @@ namespace sls { update(w, new_value); } else { - product = 1; - for (auto v : md.m_monomial) { + product = coeff; + for (auto v : monomial) { num_t new_value{ 1 }; if (rand() % 2 == 0) new_value = -1; product *= new_value; update(v, new_value); } - auto v = md.m_monomial[rand() % md.m_monomial.size()]; + auto v = monomial[ctx.rand(monomial.size())]; if ((product < 0 && 0 < val) || (val < 0 && 0 < product)) update(v, -val * value(v)); else @@ -761,8 +799,10 @@ namespace sls { auto val = value(od.m_var); auto v1 = value(od.m_arg1); auto v2 = value(od.m_arg2); - if (v2 == 0) + if (v2 == 0) { + update(od.m_var, num_t(0)); return; + } IF_VERBOSE(0, verbose_stream() << "todo repair rem"); // bail @@ -784,7 +824,11 @@ namespace sls { template void arith_base::repair_to_int(op_def const& od) { - NOT_IMPLEMENTED_YET(); + auto val = value(od.m_var); + auto v1 = value(od.m_arg1); + if (val - 1 < v1 && v1 <= val) + return; + update(od.m_arg1, val); } template @@ -800,8 +844,10 @@ namespace sls { auto val = value(od.m_var); auto v1 = value(od.m_arg1); auto v2 = value(od.m_arg2); - if (v1 == 0 && v2 == 0) + if (v1 == 0 && v2 == 0) { + update(od.m_var, num_t(0)); return; + } IF_VERBOSE(0, verbose_stream() << "todo repair ^"); NOT_IMPLEMENTED_YET(); } @@ -832,10 +878,7 @@ namespace sls { update(od.m_arg1, v1); return; } - if (v2 == 0) - return; - // bail - update(od.m_var, mod(v1, v2)); + update(od.m_var, v2 == 0 ? num_t(0) : mod(v1, v2)); } template @@ -843,11 +886,9 @@ namespace sls { auto val = value(od.m_var); auto v1 = value(od.m_arg1); auto v2 = value(od.m_arg2); - if (v2 == 0) - return; IF_VERBOSE(0, verbose_stream() << "todo repair div"); // bail - update(od.m_var, div(v1, v2)); + update(od.m_var, v2 == 0 ? num_t(0) : div(v1, v2)); } template @@ -855,11 +896,9 @@ namespace sls { auto val = value(od.m_var); auto v1 = value(od.m_arg1); auto v2 = value(od.m_arg2); - if (v2 == 0) - return; IF_VERBOSE(0, verbose_stream() << "todo repair /"); // bail - update(od.m_var, v1 / v2); + update(od.m_var, v2 == 0 ? num_t(0) : v1 / v2); } template @@ -956,27 +995,31 @@ namespace sls { } template - void arith_base::register_term(expr* e) { - } - - template - void arith_base::set_shared(expr* e) { - if (!a.is_int_real(e)) + void arith_base::register_term(expr* _e) { + if (!is_app(_e)) return; - var_t v = m_expr2var.get(e->get_id(), UINT_MAX); - if (v == UINT_MAX) - v = mk_term(e); - m_vars[v].m_shared = true; + app* e = to_app(_e); + auto v = ctx.atom2bool_var(e); + if (v != sat::null_bool_var) + init_bool_var(v); + if (!a.is_arith_expr(e) && !m.is_eq(e) && !m.is_distinct(e)) + for (auto arg : *e) + if (a.is_int_real(arg)) + mk_term(arg); } template void arith_base::set_value(expr* e, expr* v) { - auto w = m_expr2var.get(e->get_id(), UINT_MAX); - if (w == UINT_MAX) + if (!a.is_int_real(e)) return; + var_t w = m_expr2var.get(e->get_id(), UINT_MAX); + if (w == UINT_MAX) + w = mk_term(e); + num_t n; if (!is_num(v, n)) return; + verbose_stream() << "set value " << w << " " << mk_bounded_pp(e, m) << " " << n << " " << value(w) << "\n"; if (n == value(w)) return; update(w, n); @@ -988,21 +1031,6 @@ namespace sls { return expr_ref(a.mk_numeral(rational(m_vars[v].m_value.get_int64(), rational::i64()), a.is_int(e)), m); } - template - lbool arith_base::check() { - // repair each root literal - for (sat::literal lit : ctx.root_literals()) - repair(lit); - - repair_defs_and_updates(); - - // update literal assignment based on current model - for (unsigned v = 0; v < ctx.num_bool_vars(); ++v) - init_bool_var_assignment(v); - - return ctx.unsat().empty() ? l_true : l_undef; - } - template bool arith_base::is_sat() { for (auto const& clause : ctx.clauses()) { @@ -1035,8 +1063,8 @@ namespace sls { } for (unsigned v = 0; v < m_vars.size(); ++v) { auto const& vi = m_vars[v]; - out << "v" << v << " := " << vi.m_value << " " << vi.m_best_value << " "; - out << mk_bounded_pp(vi.m_expr, m) << " - "; + out << "v" << v << " := " << vi.m_value << " (best " << vi.m_best_value << ") "; + out << mk_bounded_pp(vi.m_expr, m) << " : "; for (auto [c, bv] : vi.m_bool_vars) out << c << "@" << bv << " "; out << "\n"; @@ -1049,9 +1077,11 @@ namespace sls { } for (auto ad : m_adds) { out << "v" << ad.m_var << " := "; + bool first = true; for (auto [c, w] : ad.m_args) - out << c << "* v" << w << " + "; - out << ad.m_coeff; + out << (first?"":" + ") << c << "* v" << w; + if (ad.m_coeff != 0) + out << " + " << ad.m_coeff; out << "\n"; } for (auto od : m_ops) { diff --git a/src/ast/sls/sls_arith_base.h b/src/ast/sls/sls_arith_base.h index 21806f574..197cfda26 100644 --- a/src/ast/sls/sls_arith_base.h +++ b/src/ast/sls/sls_arith_base.h @@ -20,7 +20,7 @@ Author: #include "util/checked_int64.h" #include "ast/ast_trail.h" #include "ast/arith_decl_plugin.h" -#include "ast/sls/sls_smt.h" +#include "ast/sls/sls_context.h" namespace sls { @@ -75,11 +75,11 @@ namespace sls { out << " + " << m_coeff; switch (m_op) { case ineq_kind::LE: - return out << " <= " << 0 << "(" << m_args_value << ")"; + return out << " <= " << 0 << "(" << m_args_value + m_coeff << ")"; case ineq_kind::EQ: - return out << " == " << 0 << "(" << m_args_value << ")"; + return out << " == " << 0 << "(" << m_args_value + m_coeff << ")"; default: - return out << " < " << 0 << "(" << m_args_value << ")"; + return out << " < " << 0 << "(" << m_args_value + m_coeff << ")"; } } }; @@ -101,6 +101,7 @@ namespace sls { struct mul_def { unsigned m_var; + num_t m_coeff; unsigned_vector m_monomial; }; @@ -109,8 +110,8 @@ namespace sls { }; struct op_def { - unsigned m_var; - arith_op_kind m_op; + unsigned m_var = UINT_MAX; + arith_op_kind m_op = LAST_ARITH_OP; unsigned m_arg1, m_arg2; }; @@ -124,8 +125,6 @@ namespace sls { unsigned_vector m_expr2var; bool m_dscore_mode = false; arith_util a; - unsigned_vector m_defs_to_update; - vector> m_vars_to_update; unsigned get_num_vars() const { return m_vars.size(); } @@ -139,10 +138,6 @@ namespace sls { void repair_abs(op_def const& od); void repair_to_int(op_def const& od); void repair_to_real(op_def const& od); - void repair_defs_and_updates(); - void repair_defs(); - void repair_updates(); - void repair(sat::literal lit); void repair(sat::literal lit, ineq const& ineq); double reward(sat::literal lit); @@ -179,15 +174,18 @@ namespace sls { bool is_num(expr* e, num_t& i); expr_ref from_num(sort* s, num_t const& n); void check_ineqs(); + void init_bool_var(sat::bool_var v); public: arith_base(context& ctx); - ~arith_base() override {} - void init_bool_var(sat::bool_var v) override; + ~arith_base() override {} void register_term(expr* e) override; - void set_shared(expr* e) override; void set_value(expr* e, expr* v) override; expr_ref get_value(expr* e) override; - lbool check() override; + void initialize() override {} + void propagate_literal(sat::literal lit) override; + bool propagate() override; + void repair_up(app* e) override; + void repair_down(app* e) override; bool is_sat() override; void on_rescale() override; void on_restart() override; diff --git a/src/ast/sls/sls_arith_plugin.cpp b/src/ast/sls/sls_arith_plugin.cpp index 31303818d..a95162093 100644 --- a/src/ast/sls/sls_arith_plugin.cpp +++ b/src/ast/sls/sls_arith_plugin.cpp @@ -21,21 +21,9 @@ Author: namespace sls { - void arith_plugin::init_bool_var(sat::bool_var v) { - if (!m_arith) { - try { - m_arith64->init_bool_var(v); - return; - } - catch (overflow_exception&) { - m_arith = alloc(arith_base, ctx); - for (auto e : m_shared) - m_arith->set_shared(e); - return; // initialization happens on check-sat calls - } - } - m_arith->init_bool_var(v); - + void arith_plugin::init_backup() { + m_arith = alloc(arith_base, ctx); + m_arith->initialize(); } void arith_plugin::register_term(expr* e) { @@ -45,9 +33,7 @@ namespace sls { return; } catch (overflow_exception&) { - m_arith = alloc(arith_base, ctx); - for (auto e : m_shared) - m_arith->set_shared(e); + init_backup(); } } m_arith->register_term(e); @@ -59,32 +45,49 @@ namespace sls { return m_arith64->get_value(e); } catch (overflow_exception&) { - m_arith = alloc(arith_base, ctx); - for (auto e : m_shared) - m_arith->set_shared(e); + init_backup(); } } return m_arith->get_value(e); } - lbool arith_plugin::check() { + void arith_plugin::initialize() { + if (m_arith) + m_arith->initialize(); + else + m_arith64->initialize(); + } + + void arith_plugin::propagate_literal(sat::literal lit) { if (!m_arith) { try { - return m_arith64->check(); + m_arith64->propagate_literal(lit); + return; } catch (overflow_exception&) { - m_arith = alloc(arith_base, ctx); - for (auto e : m_shared) - m_arith->set_shared(e); + init_backup(); } - } - return m_arith->check(); + } + m_arith->propagate_literal(lit); + } + + bool arith_plugin::propagate() { + if (!m_arith) { + try { + return m_arith64->propagate(); + } + catch (overflow_exception&) { + init_backup(); + } + } + return m_arith->propagate(); } bool arith_plugin::is_sat() { - if (!m_arith) + if (m_arith) + return m_arith->is_sat(); + else return m_arith64->is_sat(); - return m_arith->is_sat(); } void arith_plugin::on_rescale() { @@ -115,19 +118,30 @@ namespace sls { m_arith64->mk_model(mdl); } - void arith_plugin::set_shared(expr* e) { + void arith_plugin::repair_down(app* e) { + if (m_arith) + m_arith->repair_down(e); + else + m_arith64->repair_down(e); + } + + void arith_plugin::repair_up(app* e) { if (m_arith) - m_arith->set_shared(e); - else { - m_arith64->set_shared(e); - m_shared.push_back(e); - } + m_arith->repair_up(e); + else + m_arith64->repair_up(e); } void arith_plugin::set_value(expr* e, expr* v) { - if (m_arith) - m_arith->set_value(e, v); - else - m_arith->set_value(e, v); + if (!m_arith) { + try { + m_arith64->set_value(e, v); + return; + } + catch (overflow_exception&) { + init_backup(); + } + } + m_arith->set_value(e, v); } } diff --git a/src/ast/sls/sls_arith_plugin.h b/src/ast/sls/sls_arith_plugin.h index 4a1d71deb..53c1a7ba5 100644 --- a/src/ast/sls/sls_arith_plugin.h +++ b/src/ast/sls/sls_arith_plugin.h @@ -16,7 +16,7 @@ Author: --*/ #pragma once -#include "ast/sls/sls_smt.h" +#include "ast/sls/sls_context.h" #include "ast/sls/sls_arith_base.h" namespace sls { @@ -25,23 +25,28 @@ namespace sls { scoped_ptr>> m_arith64; scoped_ptr> m_arith; expr_ref_vector m_shared; + + void init_backup(); public: arith_plugin(context& ctx) : plugin(ctx), m_shared(ctx.get_manager()) { m_arith64 = alloc(arith_base>,ctx); + m_fid = m_arith64->fid(); } ~arith_plugin() override {} - void init_bool_var(sat::bool_var v) override; void register_term(expr* e) override; expr_ref get_value(expr* e) override; - lbool check() override; + void initialize() override; + void propagate_literal(sat::literal lit) override; + bool propagate() override; + void repair_down(app* e) override; + void repair_up(app* e) override; bool is_sat() override; void on_rescale() override; void on_restart() override; std::ostream& display(std::ostream& out) const override; void mk_model(model& mdl) override; - void set_shared(expr* e) override; void set_value(expr* e, expr* v) override; }; diff --git a/src/ast/sls/sls_basic_plugin.cpp b/src/ast/sls/sls_basic_plugin.cpp index 18c5599bd..278fa55b9 100644 --- a/src/ast/sls/sls_basic_plugin.cpp +++ b/src/ast/sls/sls_basic_plugin.cpp @@ -24,24 +24,21 @@ namespace sls { return expr_ref(m.mk_bool_val(bval0(e)), m); } - lbool basic_plugin::check() { - init(); - for (sat::literal lit : ctx.root_literals()) - repair_literal(lit); - repair_defs_and_updates(); - return ctx.unsat().empty() ? l_true : l_undef; + void basic_plugin::propagate_literal(sat::literal lit) { + auto a = ctx.atom(lit.var()); + if (!a || !is_app(a)) + return; + SASSERT(to_app(a)->get_family_id() != basic_family_id); + if (bval1(to_app(a)) != bval0(to_app(a))) + ctx.new_value_eh(a); } - void basic_plugin::init() { - m_repair_down = UINT_MAX; - m_repair_roots.reset(); - m_repair_up.reset(); - if (m_initialized) - return; - m_initialized = true; - for (auto t : ctx.subterms()) - if (is_app(t) && m.is_bool(t) && to_app(t)->get_family_id() == basic_family_id) - m_values.setx(t->get_id(), bval1(to_app(t)), false); + void basic_plugin::register_term(expr* e) { + if (is_app(e) && m.is_bool(e) && to_app(e)->get_family_id() == basic_family_id) + m_values.setx(e->get_id(), bval1(to_app(e)), false); + } + + void basic_plugin::initialize() { } bool basic_plugin::is_sat() { @@ -70,7 +67,6 @@ namespace sls { if (bval0(e) != m.is_true(v)) return; set_value(e, m.is_true(v)); - m_repair_roots.insert(e->get_id()); } bool basic_plugin::bval1(app* e) const { @@ -133,7 +129,7 @@ namespace sls { if (v == sat::null_bool_var) return m_values.get(e->get_id(), false); else - return ctx.is_true(sat::literal(v, false)); + return ctx.is_true(v); } bool basic_plugin::try_repair(app* e, unsigned i) { @@ -157,8 +153,7 @@ namespace sls { case OP_ITE: return try_repair_ite(e, i); case OP_DISTINCT: - NOT_IMPLEMENTED_YET(); - return false; + return try_repair_distinct(e, i); default: UNREACHABLE(); return false; @@ -167,17 +162,21 @@ namespace sls { bool basic_plugin::try_repair_and_or(app* e, unsigned i) { auto b = bval0(e); + if ((b && m.is_and(e)) || (!b && m.is_or(e))) { + for (auto arg : *e) + if (!set_value(arg, b)) + return false; + return true; + } auto child = e->get_arg(i); if (b == bval0(child)) return false; - set_value(child, b); - return true; + return set_value(child, b); } bool basic_plugin::try_repair_not(app* e) { auto child = e->get_arg(0); - set_value(child, !bval0(e)); - return true; + return set_value(child, !bval0(e)); } bool basic_plugin::try_repair_eq(app* e, unsigned i) { @@ -185,129 +184,111 @@ namespace sls { auto sibling = e->get_arg(1 - i); if (!m.is_bool(child)) return false; - set_value(child, bval0(e) == bval0(sibling)); - return true; + return set_value(child, bval0(e) == bval0(sibling)); } bool basic_plugin::try_repair_xor(app* e, unsigned i) { - bool ev = bval0(e); - bool bv = bval0(e->get_arg(1 - i)); auto child = e->get_arg(i); - set_value(child, ev != bv); - return true; + bool bv = false; + for (unsigned j = 0; j < e->get_num_args(); ++j) + if (j != i) + bv ^= bval0(e->get_arg(j)); + bool ev = bval0(e); + return set_value(child, ev != bv); } bool basic_plugin::try_repair_ite(app* e, unsigned i) { auto child = e->get_arg(i); bool c = bval0(e->get_arg(0)); - if (i == 0) { - set_value(child, !c); - return true; - } + if (i == 0) + return set_value(child, !c); + if (c != (i == 1)) return false; - if (m.is_bool(e)) { - set_value(child, bval0(e)); - return true; - } + if (m.is_bool(e)) + return set_value(child, bval0(e)); return false; } bool basic_plugin::try_repair_implies(app* e, unsigned i) { auto child = e->get_arg(i); + auto sibling = e->get_arg(1 - i); bool ev = bval0(e); bool av = bval0(child); - bool bv = bval0(e->get_arg(1 - i)); - if (i == 0) { - if (ev == (!av || bv)) - return false; - } - else if (ev != (!bv || av)) + bool bv = bval0(sibling); + if (ev) { + + if (i == 0 && (!av || bv)) + return true; + if (i == 1 && (!bv || av)) + return true; + if (i == 0) { + return set_value(child, false); + } + if (i == 1) { + return set_value(child, true); + } return false; - set_value(child, ev); - return true; + } + if (i == 0 && av && !bv) + return true; + if (i == 1 && bv && !av) + return true; + if (i == 0) { + return set_value(child, true) && set_value(sibling, false); + } + if (i == 1) { + return set_value(child, false) && set_value(sibling, true); + } + return false; } - bool basic_plugin::repair_up(expr* e) { - if (!m.is_bool(e)) - return false; - auto b = bval1(to_app(e)); + void basic_plugin::repair_up(app* e) { + if (!m.is_bool(e) || e->get_family_id() != basic_family_id) + return; + auto b = bval1(e); + if (bval0(e) == b) + return; set_value(e, b); - return true; } void basic_plugin::repair_down(app* e) { SASSERT(m.is_bool(e)); unsigned n = e->get_num_args(); - if (n == 0 || e->get_family_id() != m.get_basic_family_id()) { - for (auto p : ctx.parents(e)) - m_repair_up.insert(p->get_id()); - ctx.set_value(e, m.mk_bool_val(bval0(e))); + if (n == 0 || e->get_family_id() != m.get_basic_family_id()) return; - } + if (bval0(e) == bval1(e)) return; unsigned s = ctx.rand(n); for (unsigned i = 0; i < n; ++i) { auto j = (i + s) % n; - if (try_repair(e, j)) { - m_repair_down = e->get_arg(j)->get_id(); - return; - } + if (try_repair(e, j)) + return; } - m_repair_up.insert(e->get_id()); + set_value(e, bval1(e)); } - - void basic_plugin::repair_defs_and_updates() { - if (!m_repair_roots.empty() || - !m_repair_up.empty() || - m_repair_down != UINT_MAX) { - - while (m_repair_down != UINT_MAX) { - auto e = ctx.term(m_repair_down); - repair_down(to_app(e)); - } - - while (!m_repair_up.empty()) { - auto id = m_repair_up.elem_at(rand() % m_repair_up.size()); - auto e = ctx.term(id); - m_repair_up.remove(id); - repair_up(to_app(e)); - } - - if (!m_repair_roots.empty()) { - auto id = m_repair_roots.elem_at(rand() % m_repair_roots.size()); - m_repair_roots.remove(id); - m_repair_down = id; - } - } + bool basic_plugin::try_repair_distinct(app* e, unsigned i) { + return false; } - void basic_plugin::set_value(expr* e, bool b) { + bool basic_plugin::set_value(expr* e, bool b) { + if (m.is_true(e) && !b) + return false; + if (m.is_false(e) && b) + return false; sat::bool_var v = ctx.atom2bool_var(e); if (v == sat::null_bool_var) { if (m_values.get(e->get_id(), b) != b) { m_values.set(e->get_id(), b); - ctx.set_value(e, m.mk_bool_val(b)); + ctx.new_value_eh(e); } } - else if (ctx.is_true(sat::literal(v, false)) != b) { + else if (ctx.is_true(v) != b) { ctx.flip(v); - ctx.set_value(e, m.mk_bool_val(b)); + ctx.new_value_eh(e); } + return true; } - - void basic_plugin::repair_literal(sat::literal lit) { - if (!ctx.is_true(lit)) - return; - auto a = ctx.atom(lit.var()); - if (!a || !is_app(a)) - return; - if (to_app(a)->get_family_id() != basic_family_id) - return; - if (bval1(to_app(a)) != bval0(to_app(a))) - m_repair_roots.insert(a->get_id()); - } - } diff --git a/src/ast/sls/sls_basic_plugin.h b/src/ast/sls/sls_basic_plugin.h index 568ae2877..9aa3aee10 100644 --- a/src/ast/sls/sls_basic_plugin.h +++ b/src/ast/sls/sls_basic_plugin.h @@ -12,20 +12,16 @@ Author: --*/ #pragma once -#include "ast/sls/sls_smt.h" +#include "ast/sls/sls_context.h" namespace sls { class basic_plugin : public plugin { bool_vector m_values; - indexed_uint_set m_repair_up, m_repair_roots; - unsigned m_repair_down = UINT_MAX; bool m_initialized = false; - void init(); bool bval1(app* e) const; bool bval0(expr* e) const; - bool repair_up(expr* e); bool try_repair(app* e, unsigned i); bool try_repair_and_or(app* e, unsigned i); bool try_repair_not(app* e); @@ -33,28 +29,28 @@ namespace sls { bool try_repair_xor(app* e, unsigned i); bool try_repair_ite(app* e, unsigned i); bool try_repair_implies(app* e, unsigned i); - void set_value(expr* e, bool b); - - void repair_down(app* e); - void repair_defs_and_updates(); - void repair_literal(sat::literal lit); + bool try_repair_distinct(app* e, unsigned i); + bool set_value(expr* e, bool b); public: basic_plugin(context& ctx) : plugin(ctx) { + m_fid = basic_family_id; } ~basic_plugin() override {} - void init_bool_var(sat::bool_var v) override {} - void register_term(expr* e) override {} + void register_term(expr* e) override; expr_ref get_value(expr* e) override; - lbool check() override; + void initialize() override; + void propagate_literal(sat::literal lit) override; + bool propagate() override { return false; } + void repair_down(app* e) override; + void repair_up(app* e) override; bool is_sat() override; void on_rescale() override {} void on_restart() override {} std::ostream& display(std::ostream& out) const override; void mk_model(model& mdl) override {} - void set_shared(expr* e) override {} void set_value(expr* e, expr* v) override; }; diff --git a/src/ast/sls/sls_bv_plugin.cpp b/src/ast/sls/sls_bv_plugin.cpp index a97d4f736..0819caf2d 100644 --- a/src/ast/sls/sls_bv_plugin.cpp +++ b/src/ast/sls/sls_bv_plugin.cpp @@ -23,80 +23,44 @@ namespace sls { plugin(ctx), bv(m), m_terms(ctx), - m_eval(m_terms, ctx) - {} + m_eval(m_terms, ctx) { + m_fid = bv.get_family_id(); + } void bv_plugin::register_term(expr* e) { m_terms.register_term(e); } expr_ref bv_plugin::get_value(expr* e) { - return expr_ref(m); + SASSERT(bv.is_bv(e)); + auto const & val = m_eval.wval(e); + return expr_ref(bv.mk_numeral(val.get_value(), e->get_sort()), m); } - - lbool bv_plugin::check() { - if (!m_initialized) { - auto eval = [&](expr* e, unsigned idx) { return false; }; - m_eval.init_eval(eval); - m_initialized = true; - } + void bv_plugin::propagate_literal(sat::literal lit) { + SASSERT(ctx.is_true(lit)); + auto a = ctx.atom(lit.var()); + if (!a || !is_app(a)) + return; + if (!m_eval.eval_is_correct(to_app(a))) + ctx.new_value_eh(a); + } + bool bv_plugin::propagate() { auto& axioms = m_terms.axioms(); if (!axioms.empty()) { for (auto* e : axioms) ctx.add_constraint(e); axioms.reset(); - return l_undef; + return true; } - - // repair each root literal - for (sat::literal lit : ctx.root_literals()) - repair_literal(lit); - - repair_defs_and_updates(); - - // update literal assignment based on current model - for (unsigned v = 0; v < ctx.num_bool_vars(); ++v) - init_bool_var_assignment(v); - - return ctx.unsat().empty() ? l_true : l_undef; + return false; } - void bv_plugin::repair_literal(sat::literal lit) { - if (!ctx.is_true(lit)) - return; - auto a = ctx.atom(lit.var()); - if (!a || !is_app(a)) - return; - if (to_app(a)->get_family_id() != bv.get_family_id()) - return; - if (!m_eval.eval_is_correct(to_app(a))) - m_repair_roots.insert(a->get_id()); - } - - void bv_plugin::repair_defs_and_updates() { - if (!m_repair_roots.empty() || - !m_repair_up.empty() || - m_repair_down != UINT_MAX) { - - while (m_repair_down != UINT_MAX) { - auto e = ctx.term(m_repair_down); - try_repair_down(to_app(e)); - } - - while (!m_repair_up.empty()) { - auto id = m_repair_up.elem_at(rand() % m_repair_up.size()); - auto e = ctx.term(id); - m_repair_up.remove(id); - try_repair_up(to_app(e)); - } - - if (!m_repair_roots.empty()) { - auto id = m_repair_roots.elem_at(rand() % m_repair_roots.size()); - m_repair_roots.remove(id); - m_repair_down = id; - } + void bv_plugin::initialize() { + if (!m_initialized) { + // compute fixed ranges + m_initialized = true; } } @@ -108,38 +72,35 @@ namespace sls { return; bool is_true = m_eval.bval1(to_app(a)); - if (is_true != ctx.is_true(sat::literal(v, false))) + if (is_true != ctx.is_true(v)) ctx.flip(v); } bool bv_plugin::is_sat() { - return false; + for (auto t : ctx.subterms()) + if (is_app(t) && bv.is_bv(t) && !m_eval.eval_is_correct(to_app(t))) + return false; + return true; } std::ostream& bv_plugin::display(std::ostream& out) const { - // m_eval.display(out); - return out; - } - - void bv_plugin::set_shared(expr* e) { + return m_eval.display(out); } void bv_plugin::set_value(expr* e, expr* v) { + if (!bv.is_bv(e)) + return; + rational val; + VERIFY(bv.is_numeral(v, val)); + NOT_IMPLEMENTED_YET(); + // set value of e to val, } - void bv_plugin::try_repair_down(app* e) { + void bv_plugin::repair_down(app* e) { unsigned n = e->get_num_args(); if (n == 0 || m_eval.eval_is_correct(e)) { m_eval.commit_eval(e); - if (!m.is_bool(e)) - for (auto p : ctx.parents(e)) - m_repair_up.insert(p->get_id()); - return; - } - - if (m.is_bool(e)) { - NOT_IMPLEMENTED_YET(); return; } @@ -148,15 +109,15 @@ namespace sls { auto d2 = get_depth(e->get_arg(1)); unsigned s = ctx.rand(d1 + d2 + 2); if (s <= d1 && m_eval.try_repair(e, 0)) { - set_repair_down(e->get_arg(0)); + ctx.new_value_eh(e->get_arg(0)); return; } if (m_eval.try_repair(e, 1)) { - set_repair_down(e->get_arg(1)); + ctx.new_value_eh(e->get_arg(1)); return; } if (m_eval.try_repair(e, 0)) { - set_repair_down(e->get_arg(0)); + ctx.new_value_eh(e->get_arg(0)); return; } } @@ -165,18 +126,16 @@ namespace sls { for (unsigned i = 0; i < n; ++i) { auto j = (i + s) % n; if (m_eval.try_repair(e, j)) { - set_repair_down(e->get_arg(j)); + ctx.new_value_eh(e->get_arg(j)); return; } } } IF_VERBOSE(3, verbose_stream() << "init-repair " << mk_bounded_pp(e, m) << "\n"); - // repair was not successful, so reset the state to find a different way to repair - m_repair_down = UINT_MAX; } - void bv_plugin::try_repair_up(app* e) { - if (m.is_bool(e)) + void bv_plugin::repair_up(app* e) { + if (!bv.is_bv(e)) ; else if (m_eval.repair_up(e)) { if (!m_eval.eval_is_correct(e)) { @@ -184,12 +143,12 @@ namespace sls { } SASSERT(m_eval.eval_is_correct(e)); for (auto p : ctx.parents(e)) - m_repair_up.insert(p->get_id()); + ctx.new_value_eh(p); } else if (ctx.rand(10) != 0) { IF_VERBOSE(2, verbose_stream() << "repair-up "; trace_repair(true, e)); m_eval.set_random(e); - m_repair_roots.insert(e->get_id()); + ctx.new_value_eh(e); } } @@ -202,9 +161,7 @@ namespace sls { void bv_plugin::trace() { IF_VERBOSE(2, verbose_stream() - << "(bvsls :restarts " << m_stats.m_restarts - << " :repair-up " << m_repair_up.size() - << " :repair-roots " << m_repair_roots.size() << ")\n"); + << "(bvsls :restarts " << m_stats.m_restarts << ")\n"); } } diff --git a/src/ast/sls/sls_bv_plugin.h b/src/ast/sls/sls_bv_plugin.h index 99cf4cf12..9657b2e98 100644 --- a/src/ast/sls/sls_bv_plugin.h +++ b/src/ast/sls/sls_bv_plugin.h @@ -16,7 +16,7 @@ Author: --*/ #pragma once -#include "ast/sls/sls_smt.h" +#include "ast/sls/sls_context.h" #include "ast/bv_decl_plugin.h" #include "ast/sls/bv_sls_terms.h" #include "ast/sls/bv_sls_eval.h" @@ -28,38 +28,29 @@ namespace sls { bv::sls_terms m_terms; bv::sls_eval m_eval; bv::sls_stats m_stats; - - indexed_uint_set m_repair_up, m_repair_roots; - unsigned m_repair_down = UINT_MAX; bool m_initialized = false; - void repair_literal(sat::literal lit); - - void repair_defs_and_updates(); - void init_bool_var_assignment(sat::bool_var v); - - void try_repair_down(app* e); - void set_repair_down(expr* e) { m_repair_down = e->get_id(); } - void try_repair_up(app* e); - - std::ostream& bv_plugin::trace_repair(bool down, expr* e); + std::ostream& trace_repair(bool down, expr* e); void trace(); + bool can_propagate(); public: bv_plugin(context& ctx); ~bv_plugin() override {} - void init_bool_var(sat::bool_var v) override {} void register_term(expr* e) override; expr_ref get_value(expr* e) override; - lbool check() override; + void initialize() override; + void propagate_literal(sat::literal lit) override; + bool propagate() override; + void repair_down(app* e) override; + void repair_up(app* e) override; bool is_sat() override; void on_rescale() override {} void on_restart() override {} std::ostream& display(std::ostream& out) const override; void mk_model(model& mdl) override {} - void set_shared(expr* e) override; void set_value(expr* e, expr* v) override; }; diff --git a/src/ast/sls/sls_cc.cpp b/src/ast/sls/sls_cc.cpp index a9c2b3002..609e0bb29 100644 --- a/src/ast/sls/sls_cc.cpp +++ b/src/ast/sls/sls_cc.cpp @@ -84,7 +84,7 @@ namespace sls { return true; } - lbool cc_plugin::check() { + bool cc_plugin::propagate() { bool new_constraint = false; for (auto & [f, ts] : m_app) { if (ts.size() <= 1) @@ -108,7 +108,7 @@ namespace sls { m_values.insert(t); } } - return new_constraint ? l_undef : l_true; + return new_constraint; } std::ostream& cc_plugin::display(std::ostream& out) const { diff --git a/src/ast/sls/sls_cc.h b/src/ast/sls/sls_cc.h index 381652a39..7bcd141d8 100644 --- a/src/ast/sls/sls_cc.h +++ b/src/ast/sls/sls_cc.h @@ -17,7 +17,7 @@ Author: #pragma once #include "util/hashtable.h" -#include "ast/sls/sls_smt.h" +#include "ast/sls/sls_context.h" namespace sls { @@ -39,14 +39,16 @@ namespace sls { ~cc_plugin() override; family_id fid() { return m_fid; } expr_ref get_value(expr* e) override; - lbool check() override; + void initialize() override {} + void propagate_literal(sat::literal lit) override {} + bool propagate() override; bool is_sat() override; void register_term(expr* e) override; - void init_bool_var(sat::bool_var v) override {} std::ostream& display(std::ostream& out) const override; void mk_model(model& mdl) override; void set_value(expr* e, expr* v) override {} - void set_shared(expr* e) override {} + void repair_up(app* e) override {} + void repair_down(app* e) override {} }; } diff --git a/src/ast/sls/sls_smt.cpp b/src/ast/sls/sls_context.cpp similarity index 68% rename from src/ast/sls/sls_smt.cpp rename to src/ast/sls/sls_context.cpp index bb7b047bb..075bc1dfc 100644 --- a/src/ast/sls/sls_smt.cpp +++ b/src/ast/sls/sls_context.cpp @@ -16,11 +16,12 @@ Author: --*/ #pragma once -#include "ast/sls/sls_smt.h" +#include "ast/sls/sls_context.h" #include "ast/sls/sls_cc.h" #include "ast/sls/sls_arith_plugin.h" #include "ast/sls/sls_bv_plugin.h" #include "ast/sls/sls_basic_plugin.h" +#include "ast/ast_ll_pp.h" namespace sls { @@ -30,7 +31,11 @@ namespace sls { } context::context(ast_manager& m, sat_solver_context& s) : - m(m), s(s), m_atoms(m), m_allterms(m) { + m(m), s(s), m_atoms(m), m_allterms(m), + m_gd(*this), + m_ld(*this), + m_repair_down(m.get_num_asts(), m_gd), + m_repair_up(m.get_num_asts(), m_ld) { register_plugin(alloc(cc_plugin, *this)); register_plugin(alloc(arith_plugin, *this)); register_plugin(alloc(bv_plugin, *this)); @@ -56,14 +61,12 @@ namespace sls { // init(); while (unsat().empty()) { - reinit_relevant(); - for (auto p : m_plugins) { - lbool r; - if (p && (r = p->check()) != l_true) - return r; - } - if (m_new_constraint) + + propagate_boolean_assignment(); + + if (m_new_constraint || !unsat().empty()) return l_undef; + if (all_of(m_plugins, [&](auto* p) { return !p || p->is_sat(); })) { model_ref mdl = alloc(model, m); for (expr* e : subterms()) @@ -74,17 +77,72 @@ namespace sls { p->mk_model(*mdl); s.on_model(mdl); verbose_stream() << *mdl << "\n"; + TRACE("sls", display(tout)); return l_true; } } return l_undef; } + void context::propagate_boolean_assignment() { + reinit_relevant(); + + for (sat::literal lit : root_literals()) { + if (m_new_constraint) + break; + propagate_literal(lit); + } + + while (!m_new_constraint && (!m_repair_up.empty() || !m_repair_down.empty())) { + while (!m_repair_down.empty() && !m_new_constraint) { + auto id = m_repair_down.erase_min(); + expr* e = term(id); + if (is_app(e)) { + auto p = m_plugins.get(to_app(e)->get_family_id(), nullptr); + if (p) + p->repair_down(to_app(e)); + } + } + while (!m_repair_up.empty() && !m_new_constraint) { + auto id = m_repair_up.erase_min(); + expr* e = term(id); + if (is_app(e)) { + auto p = m_plugins.get(to_app(e)->get_family_id(), nullptr); + if (p) + p->repair_up(to_app(e)); + } + } + } + // propagate "final checks" + bool propagated = true; + while (propagated && !m_new_constraint) { + propagated = false; + for (auto p : m_plugins) + propagated |= p && !m_new_constraint && p->propagate(); + } + } + + void context::propagate_literal(sat::literal lit) { + if (!is_true(lit)) + return; + auto a = atom(lit.var()); + if (!a || !is_app(a)) + return; + family_id fid = to_app(a)->get_family_id(); + if (m.is_eq(a) || m.is_distinct(a)) + fid = to_app(a)->get_arg(0)->get_sort()->get_family_id(); + auto p = m_plugins.get(fid, nullptr); + if (p) + p->propagate_literal(lit); + } + bool context::is_true(expr* e) { SASSERT(m.is_bool(e)); auto v = m_atom2bool_var.get(e->get_id(), sat::null_bool_var); - SASSERT(v != sat::null_bool_var); - return is_true(sat::literal(v, false)); + if (v != sat::null_bool_var) + return m.is_true(m_plugins[basic_family_id]->get_value(e)); + else + return is_true(v); } bool context::is_fixed(expr* e) { @@ -101,11 +159,11 @@ namespace sls { UNREACHABLE(); return expr_ref(e, m); } - - void context::set_value(expr* e, expr* v) { + + void context::set_value(expr * e, expr * v) { for (auto p : m_plugins) if (p) - p->set_value(e, v); + p->set_value(e, v); } bool context::is_relevant(expr* e) { @@ -148,31 +206,21 @@ namespace sls { v = s.add_var(); register_terms(e); register_atom(v, e); - init_bool_var(v); } return v; } - void context::init_bool_var(sat::bool_var v) { - for (auto p : m_plugins) - if (p) - p->init_bool_var(v); - } - void context::init() { m_new_constraint = false; if (m_initialized) return; m_initialized = true; - register_terms(); - for (sat::bool_var v = 0; v < num_bool_vars(); ++v) - init_bool_var(v); - } - - void context::register_terms() { for (auto a : m_atoms) if (a) register_terms(a); + for (auto p : m_plugins) + if (p) + p->initialize(); } void context::register_terms(expr* e) { @@ -213,6 +261,23 @@ namespace sls { } } + void context::new_value_eh(expr* e) { + DEBUG_CODE( + if (m.is_bool(e)) { + auto v = m_atom2bool_var.get(e->get_id(), sat::null_bool_var); + if (v != sat::null_bool_var) { + SASSERT(m.is_true(get_value(e)) == is_true(v)); + } + } + ); + m_repair_down.reserve(e->get_id() + 1); + m_repair_down.insert(e->get_id()); + for (auto p : parents(e)) { + m_repair_up.reserve(p->get_id() + 1); + m_repair_up.insert(p->get_id()); + } + } + void context::register_term(expr* e) { for (auto p : m_plugins) if (p) @@ -261,10 +326,14 @@ namespace sls { } std::ostream& context::display(std::ostream& out) const { - for (auto p : m_plugins) { + for (auto id : m_repair_down) + out << "d " << mk_bounded_pp(term(id), m) << "\n"; + for (auto id : m_repair_up) + out << "u " << mk_bounded_pp(term(id), m) << "\n"; + for (auto p : m_plugins) if (p) p->display(out); - } + return out; } } diff --git a/src/ast/sls/sls_smt.h b/src/ast/sls/sls_context.h similarity index 79% rename from src/ast/sls/sls_smt.h rename to src/ast/sls/sls_context.h index 1b8e389b7..75a12b4f8 100644 --- a/src/ast/sls/sls_smt.h +++ b/src/ast/sls/sls_context.h @@ -3,7 +3,7 @@ Copyright (c) 2024 Microsoft Corporation Module Name: - smt_sls.h + sls_context.h Abstract: @@ -22,6 +22,7 @@ Author: #include "model/model.h" #include "util/scoped_ptr_vector.h" #include "util/obj_hashtable.h" +#include "util/heap.h" namespace sls { @@ -38,14 +39,16 @@ namespace sls { virtual family_id fid() { return m_fid; } virtual void register_term(expr* e) = 0; virtual expr_ref get_value(expr* e) = 0; - virtual void init_bool_var(sat::bool_var v) = 0; - virtual lbool check() = 0; + virtual void initialize() = 0; + virtual bool propagate() = 0; + virtual void propagate_literal(sat::literal lit) = 0; + virtual void repair_down(app* e) = 0; + virtual void repair_up(app* e) = 0; virtual bool is_sat() = 0; virtual void on_rescale() {}; virtual void on_restart() {}; virtual std::ostream& display(std::ostream& out) const = 0; virtual void mk_model(model& mdl) = 0; - virtual void set_shared(expr* e) = 0; virtual void set_value(expr* e, expr* v) = 0; }; @@ -68,6 +71,22 @@ namespace sls { }; class context { + struct greater_depth { + context& c; + greater_depth(context& c) : c(c) {} + bool operator()(unsigned x, unsigned y) const { + return get_depth(c.term(x)) > get_depth(c.term(y)); + } + }; + + struct less_depth { + context& c; + less_depth(context& c) : c(c) {} + bool operator()(unsigned x, unsigned y) const { + return get_depth(c.term(x)) < get_depth(c.term(y)); + } + }; + ast_manager& m; sat_solver_context& s; scoped_ptr_vector m_plugins; @@ -81,23 +100,27 @@ namespace sls { bool m_new_constraint = false; expr_ref_vector m_allterms; ptr_vector m_subterms; + greater_depth m_gd; + less_depth m_ld; + heap m_repair_down; + heap m_repair_up; void register_plugin(plugin* p); void init(); - void init_bool_var(sat::bool_var v); - void register_terms(); ptr_vector m_todo; void register_terms(expr* e); void register_term(expr* e); sat::bool_var mk_atom(expr* e); + + void propagate_boolean_assignment(); + void propagate_literal(sat::literal lit); public: context(ast_manager& m, sat_solver_context& s); // Between SAT/SMT solver and context. void register_atom(sat::bool_var v, expr* e); - // void reset(); lbool check(); // expose sat_solver to plugins @@ -107,6 +130,7 @@ namespace sls { double get_weight(unsigned clause_idx) { return s.get_weigth(clause_idx); } unsigned num_bool_vars() const { return s.num_vars(); } bool is_true(sat::literal lit) { return s.is_true(lit); } + bool is_true(sat::bool_var v) { return s.is_true(sat::literal(v, false)); } expr* atom(sat::bool_var v) { return m_atoms.get(v, nullptr); } expr* term(unsigned id) const { return m_allterms.get(id); } sat::bool_var atom2bool_var(expr* e) const { return m_atom2bool_var.get(e->get_id(), sat::null_bool_var); } @@ -126,13 +150,15 @@ namespace sls { // Between plugin solvers expr_ref get_value(expr* e); - bool is_true(expr* e); - bool is_fixed(expr* e); void set_value(expr* e, expr* v); + void new_value_eh(expr* e); + bool is_true(expr* e); + bool is_fixed(expr* e); bool is_relevant(expr* e); void add_constraint(expr* e); ptr_vector const& subterms(); ast_manager& get_manager() { return m; } std::ostream& display(std::ostream& out) const; + }; } diff --git a/src/sat/smt/sls_solver.cpp b/src/sat/smt/sls_solver.cpp index 035385b65..3c3d4566c 100644 --- a/src/sat/smt/sls_solver.cpp +++ b/src/sat/smt/sls_solver.cpp @@ -17,7 +17,7 @@ Author: #include "sat/smt/sls_solver.h" #include "sat/smt/euf_solver.h" -#include "ast/sls/sls_smt.h" +#include "ast/sls/sls_context.h" namespace sls { diff --git a/src/tactic/sls/sls_tactic.cpp b/src/tactic/sls/sls_tactic.cpp index f438d33a0..cfe03b6da 100644 --- a/src/tactic/sls/sls_tactic.cpp +++ b/src/tactic/sls/sls_tactic.cpp @@ -29,7 +29,103 @@ Notes: #include "tactic/sls/sls_tactic.h" #include "params/sls_params.hpp" #include "ast/sls/sls_engine.h" -#include "ast/sls/bv_sls.h" +#include "ast/sls/sls_smt_solver.h" + +class sls_smt_tactic : public tactic { + ast_manager& m; + params_ref m_params; + sls::smt_solver* m_sls; + statistics m_st; + +public: + sls_smt_tactic(ast_manager& _m, params_ref const& p) : + m(_m), + m_params(p) { + m_sls = alloc(sls::smt_solver, m, p); + } + + tactic* translate(ast_manager& m) override { + return alloc(sls_smt_tactic, m, m_params); + } + + ~sls_smt_tactic() override { + dealloc(m_sls); + } + + char const* name() const override { return "sls-smt"; } + + void updt_params(params_ref const& p) override { + m_params.append(p); + m_sls->updt_params(m_params); + } + + void collect_param_descrs(param_descrs& r) override { + sls_params::collect_param_descrs(r); + } + + void run(goal_ref const& g, model_converter_ref& mc) { + if (g->inconsistent()) { + mc = nullptr; + return; + } + + for (unsigned i = 0; i < g->size(); i++) + m_sls->assert_expr(g->form(i)); + + + lbool res = m_sls->check(); + m_st.reset(); + m_sls->collect_statistics(m_st); +// report_tactic_progress("Number of flips:", m_sls->get_num_moves()); + IF_VERBOSE(10, verbose_stream() << res << "\n"); + IF_VERBOSE(10, m_sls->display(verbose_stream())); + + if (res == l_true) { + if (g->models_enabled()) { + model_ref mdl = m_sls->get_model(); + mc = model2model_converter(mdl.get()); + TRACE("sls_model", mc->display(tout);); + } + g->reset(); + } + else + mc = nullptr; + + } + + void operator()(goal_ref const& g, + goal_ref_buffer& result) override { + result.reset(); + + TRACE("sls", g->display(tout);); + tactic_report report("sls", *g); + + model_converter_ref mc; + run(g, mc); + g->add(mc.get()); + g->inc_depth(); + result.push_back(g.get()); + } + + void cleanup() override { + auto* d = alloc(sls::smt_solver, m, m_params); + std::swap(d, m_sls); + dealloc(d); + } + + void collect_statistics(statistics& st) const override { + st.copy(m_st); + } + + void reset_statistics() override { + m_sls->reset_statistics(); + m_st.reset(); + } +}; + +tactic* mk_sls_smt_tactic(ast_manager& m, params_ref const& p) { + return alloc(sls_smt_tactic, m, p); +} class sls_tactic : public tactic { ast_manager & m; diff --git a/src/tactic/sls/sls_tactic.h b/src/tactic/sls/sls_tactic.h index 867474319..b6a42a783 100644 --- a/src/tactic/sls/sls_tactic.h +++ b/src/tactic/sls/sls_tactic.h @@ -23,9 +23,11 @@ class ast_manager; class tactic; tactic * mk_qfbv_sls_tactic(ast_manager & m, params_ref const & p = params_ref()); +tactic * mk_sls_smt_tactic(ast_manager & m, params_ref const & p = params_ref()); /* ADD_TACTIC("qfbv-sls", "(try to) solve using stochastic local search for QF_BV.", "mk_qfbv_sls_tactic(m, p)") + ADD_TACTIC("sls-smt", "(try to) solve SMT formulas using local search.", "mk_sls_smt_tactic(m, p)") */