diff --git a/src/ast/sls/CMakeLists.txt b/src/ast/sls/CMakeLists.txt index e1f35e583..ae1533085 100644 --- a/src/ast/sls/CMakeLists.txt +++ b/src/ast/sls/CMakeLists.txt @@ -7,7 +7,8 @@ z3_add_component(ast_sls sat_ddfw.cpp sls_arith_base.cpp sls_arith_plugin.cpp - sls_bv.cpp + sls_basic_plugin.cpp + sls_bv_plugin.cpp sls_cc.cpp sls_engine.cpp sls_smt.cpp diff --git a/src/ast/sls/bv_sls_eval.cpp b/src/ast/sls/bv_sls_eval.cpp index 27a4acf4e..7c7afbeea 100644 --- a/src/ast/sls/bv_sls_eval.cpp +++ b/src/ast/sls/bv_sls_eval.cpp @@ -26,7 +26,7 @@ namespace bv { {} void sls_eval::init_eval(std::function const& eval) { - for (expr* e : terms.subterms()) { + for (expr* e : ctx.subterms()) { if (!is_app(e)) continue; app* a = to_app(e); @@ -68,6 +68,7 @@ namespace bv { m_tmp2.push_back(0); m_tmp3.push_back(0); m_tmp4.push_back(0); + m_mul_tmp.push_back(0); m_zero.push_back(0); m_one.push_back(0); m_a.push_back(0); @@ -272,31 +273,46 @@ namespace bv { break; } case OP_BAND: { - SASSERT(e->get_num_args() == 2); + SASSERT(e->get_num_args() >= 2); auto const& a = wval(e->get_arg(0)); auto const& b = wval(e->get_arg(1)); for (unsigned i = 0; i < a.nw; ++i) val.eval[i] = a.bits()[i] & b.bits()[i]; + for (unsigned j = 2; j < e->get_num_args(); ++j) { + auto const& c = wval(e->get_arg(j)); + for (unsigned i = 0; i < a.nw; ++i) + val.eval[i] &= c.bits()[i]; + } break; } case OP_BOR: { - SASSERT(e->get_num_args() == 2); + SASSERT(e->get_num_args() >= 2); auto const& a = wval(e->get_arg(0)); auto const& b = wval(e->get_arg(1)); for (unsigned i = 0; i < a.nw; ++i) val.eval[i] = a.bits()[i] | b.bits()[i]; + for (unsigned j = 2; j < e->get_num_args(); ++j) { + auto const& c = wval(e->get_arg(j)); + for (unsigned i = 0; i < a.nw; ++i) + val.eval[i] |= c.bits()[i]; + } break; } case OP_BXOR: { - SASSERT(e->get_num_args() == 2); + SASSERT(e->get_num_args() >= 2); auto const& a = wval(e->get_arg(0)); auto const& b = wval(e->get_arg(1)); for (unsigned i = 0; i < a.nw; ++i) val.eval[i] = a.bits()[i] ^ b.bits()[i]; + for (unsigned j = 2; j < e->get_num_args(); ++j) { + auto const& c = wval(e->get_arg(j)); + for (unsigned i = 0; i < a.nw; ++i) + val.eval[i] ^= c.bits()[i]; + } break; } case OP_BNAND: { - SASSERT(e->get_num_args() == 2); + VERIFY(e->get_num_args() == 2); auto const& a = wval(e->get_arg(0)); auto const& b = wval(e->get_arg(1)); for (unsigned i = 0; i < a.nw; ++i) @@ -304,10 +320,15 @@ namespace bv { break; } case OP_BADD: { - SASSERT(e->get_num_args() == 2); + SASSERT(e->get_num_args() >= 2); auto const& a = wval(e->get_arg(0)); auto const& b = wval(e->get_arg(1)); - val.set_add(val.eval, a.bits(), b.bits()); + for (unsigned i = 0; i < a.nw; ++i) + val.set_add(val.eval, a.bits(), b.bits()); + for (unsigned j = 2; j < e->get_num_args(); ++j) { + auto const& c = wval(e->get_arg(j)); + val.set_add(val.eval, val.eval, c.bits()); + } break; } case OP_BSUB: { @@ -318,11 +339,14 @@ namespace bv { break; } case OP_BMUL: { - SASSERT(e->get_num_args() == 2); auto const& a = wval(e->get_arg(0)); auto const& b = wval(e->get_arg(1)); - val.set_mul(m_tmp2, a.bits(), b.bits()); - val.set(m_tmp2); + for (unsigned i = 0; i < a.nw; ++i) + val.set_mul(val.eval, a.bits(), b.bits()); + for (unsigned j = 2; j < e->get_num_args(); ++j) { + auto const& c = wval(e->get_arg(j)); + val.set_mul(val.eval, val.eval, c.bits()); + } break; } case OP_CONCAT: { @@ -600,17 +624,43 @@ namespace bv { bool sls_eval::try_repair_bv(app* e, unsigned i) { switch (e->get_decl_kind()) { case OP_BAND: - return try_repair_band(eval_value(e), wval(e, i), wval(e, 1 - i)); + SASSERT(e->get_num_args() >= 2); + if (e->get_num_args() == 2) + return try_repair_band(eval_value(e), wval(e, i), wval(e, 1 - i)); + else + return try_repair_band(e, i); case OP_BOR: - return try_repair_bor(eval_value(e), wval(e, i), wval(e, 1 - i)); + SASSERT(e->get_num_args() >= 2); + if (e->get_num_args() == 2) + return try_repair_bor(eval_value(e), wval(e, i), wval(e, 1 - i)); + else + return try_repair_bor(e, i); case OP_BXOR: - return try_repair_bxor(eval_value(e), wval(e, i), wval(e, 1 - i)); + SASSERT(e->get_num_args() >= 2); + if (e->get_num_args() == 2) + return try_repair_bxor(eval_value(e), wval(e, i), wval(e, 1 - i)); + else + return try_repair_bxor(e, i); case OP_BADD: - return try_repair_add(eval_value(e), wval(e, i), wval(e, 1 - i)); + SASSERT(e->get_num_args() >= 2); + if (e->get_num_args() == 2) + return try_repair_add(eval_value(e), wval(e, i), wval(e, 1 - i)); + else + return try_repair_add(e, i); case OP_BSUB: return try_repair_sub(eval_value(e), wval(e, 0), wval(e, 1), i); case OP_BMUL: - return try_repair_mul(eval_value(e), wval(e, i), wval(e, 1 - i)); + SASSERT(e->get_num_args() >= 2); + if (e->get_num_args() == 2) + return try_repair_mul(eval_value(e), wval(e, i), eval_value(to_app(e->get_arg(1 - i)))); + else { + auto const& a = wval(e, 0); + auto f = [&](bvect& out, bvval const& c) { + a.set_mul(out, out, c.bits()); + }; + fold_oper(m_mul_tmp, e, i, f); + return try_repair_mul(eval_value(e), wval(e, i), m_mul_tmp); + } case OP_BNOT: return try_repair_bnot(eval_value(e), wval(e, i)); case OP_BNEG: @@ -734,8 +784,9 @@ namespace bv { case OP_BSDIV_I: case OP_BSDIV0: // these are currently compiled to udiv and urem. - UNREACHABLE(); - return false; + // there is an equation that enforces equality between the semantics + // of these operators. + return true; default: return false; } @@ -787,6 +838,19 @@ namespace bv { } } + void sls_eval::fold_oper(bvect& out, app* t, unsigned i, std::function const& f) { + auto i2 = i == 0 ? 1 : 0; + auto const& c = wval(t->get_arg(i2)); + for (unsigned j = 0; j < c.nw; ++j) + out[j] = c.bits()[j]; + for (unsigned k = 1; k < t->get_num_args(); ++k) { + if (k == i || k == i2) + continue; + bvval const& c = wval(t->get_arg(k)); + f(out, c); + } + } + // // e = a & b // e[i] = 1 -> a[i] = 1 @@ -800,6 +864,21 @@ namespace bv { return a.set_repair(random_bool(), m_tmp); } + bool sls_eval::try_repair_band(app* t, unsigned i) { + bvect const& e = eval_value(t); + auto f = [&](bvect& out, bvval const& c) { + for (unsigned j = 0; j < c.nw; ++j) + out[j] &= c.bits()[j]; + }; + fold_oper(m_tmp2, t, i, f); + + bvval& a = wval(t, i); + for (unsigned j = 0; j < a.nw; ++j) + m_tmp[j] = ~a.fixed[j] & (e[j] | (~m_tmp2[j] & random_bits())); + + return a.set_repair(random_bool(), m_tmp); + } + // // e = a | b // set a[i] to 1 where b[i] = 0, e[i] = 1 @@ -811,6 +890,20 @@ namespace bv { return a.set_repair(random_bool(), m_tmp); } + bool sls_eval::try_repair_bor(app* t, unsigned i) { + bvect const& e = eval_value(t); + auto f = [&](bvect& out, bvval const& c) { + for (unsigned j = 0; j < c.nw; ++j) + out[j] |= c.bits()[j]; + }; + fold_oper(m_tmp2, t, i, f); + bvval& a = wval(t, i); + for (unsigned j = 0; j < a.nw; ++j) + m_tmp[j] = e[i] & (~m_tmp2[i] | random_bits()); + + return a.set_repair(random_bool(), m_tmp); + } + bool sls_eval::try_repair_bxor(bvect const& e, bvval& a, bvval const& b) { for (unsigned i = 0; i < a.nw; ++i) m_tmp[i] = e[i] ^ b.bits()[i]; @@ -818,6 +911,23 @@ namespace bv { } + + bool sls_eval::try_repair_bxor(app* t, unsigned i) { + bvect const& e = eval_value(t); + auto f = [&](bvect& out, bvval const& c) { + for (unsigned j = 0; j < c.nw; ++j) + out[j] ^= c.bits()[j]; + }; + fold_oper(m_tmp2, t, i, f); + + bvval& a = wval(t, i); + for (unsigned j = 0; j < a.nw; ++j) + m_tmp[j] = e[i] ^ m_tmp2[i]; + + return a.set_repair(random_bool(), m_tmp); + } + + // // first try to set a := e - b // If this fails, set a to a random value @@ -831,6 +941,22 @@ namespace bv { return a.set_random(m_rand); } + bool sls_eval::try_repair_add(app* t, unsigned i) { + bvval& a = wval(t, i); + bvect const& e = eval_value(t); + if (m_rand(20) != 0) { + auto f = [&](bvect& out, bvval const& c) { + a.set_add(m_tmp2, m_tmp2, c.bits()); + }; + fold_oper(m_tmp2, t, i, f); + a.set_sub(m_tmp, e, m_tmp2); + if (a.try_set(m_tmp)) + return true; + } + return a.set_random(m_rand); + + } + bool sls_eval::try_repair_sub(bvect const& e, bvval& a, bvval & b, unsigned i) { if (m_rand(20) != 0) { if (i == 0) @@ -850,11 +976,11 @@ namespace bv { * e = a*b, then a = e * b^-1 * 8*e = a*(2b), then a = 4e*b^-1 */ - bool sls_eval::try_repair_mul(bvect const& e, bvval& a, bvval const& b) { - unsigned parity_e = b.parity(e); - unsigned parity_b = b.parity(b.bits()); + bool sls_eval::try_repair_mul(bvect const& e, bvval& a, bvect const& b) { + unsigned parity_e = a.parity(e); + unsigned parity_b = a.parity(b); - if (b.is_zero(e)) { + if (a.is_zero(e)) { a.get_variant(m_tmp, m_rand); if (m_rand(10) != 0) for (unsigned i = 0; i < b.bw - parity_b; ++i) @@ -862,7 +988,7 @@ namespace bv { return a.set_repair(random_bool(), m_tmp); } - if (b.is_zero() || m_rand(20) == 0) { + if (m_rand(20) == 0) { a.get_variant(m_tmp, m_rand); return a.set_repair(random_bool(), m_tmp); } @@ -890,9 +1016,9 @@ namespace bv { // x*ta + y*tb = x - b.get(y); + b.copy_to(a.nw, y); if (parity_b > 0) { - b.shift_right(y, parity_b); + a.shift_right(y, parity_b); #if 0 for (unsigned i = parity_b; i < b.bw; ++i) y.set(i, m_rand(2) == 0); @@ -937,15 +1063,15 @@ namespace bv { tb.set_bw(0); #if Z3DEBUG - b.get(y); + b.copy_to(a.nw, y); if (parity_b > 0) - b.shift_right(y, parity_b); + a.shift_right(y, parity_b); a.set_mul(m_tmp, tb, y); SASSERT(a.is_one(m_tmp)); #endif e.copy_to(b.nw, m_tmp2); if (parity_e > 0 && parity_b > 0) - b.shift_right(m_tmp2, std::min(parity_b, parity_e)); + a.shift_right(m_tmp2, std::min(parity_b, parity_e)); a.set_mul(m_tmp, tb, m_tmp2); if (a.set_repair(random_bool(), m_tmp)) return true; @@ -1773,17 +1899,16 @@ namespace bv { return expr_ref(m); } - std::ostream& sls_eval::display(std::ostream& out, expr_ref_vector const& es) { -#if 0 - auto& terms = sort_assertions(es); + std::ostream& sls_eval::display(std::ostream& out) { + auto& terms = ctx.subterms(); for (expr* e : terms) { + if (!bv.is_bv(e)) + continue; out << e->get_id() << ": " << mk_bounded_pp(e, m, 1) << " "; if (is_fixed0(e)) out << "f "; display_value(out, e) << "\n"; } - terms.reset(); -#endif return out; } diff --git a/src/ast/sls/bv_sls_eval.h b/src/ast/sls/bv_sls_eval.h index 9f087c4f4..943d731a6 100644 --- a/src/ast/sls/bv_sls_eval.h +++ b/src/ast/sls/bv_sls_eval.h @@ -47,7 +47,7 @@ namespace bv { scoped_ptr_vector m_values; // expr-id -> bv valuation - mutable bvect m_tmp, m_tmp2, m_tmp3, m_tmp4, m_zero, m_one, m_minus_one; + mutable bvect m_tmp, m_tmp2, m_tmp3, m_tmp4, m_mul_tmp, m_zero, m_one, m_minus_one; bvect m_a, m_b, m_nextb, m_nexta, m_aux; using bvval = sls_valuation; @@ -64,16 +64,21 @@ namespace bv { //bool bval1_basic(app* e) const; bool bval1_bv(app* e) const; + void fold_oper(bvect& out, app* e, unsigned i, std::function const& f); /** * Repair operations */ bool try_repair_bv(app * e, unsigned i); bool try_repair_band(bvect const& e, bvval& a, bvval const& b); + bool try_repair_band(app* t, unsigned i); bool try_repair_bor(bvect const& e, bvval& a, bvval const& b); + bool try_repair_bor(app* t, unsigned i); bool try_repair_add(bvect const& e, bvval& a, bvval const& b); + bool try_repair_add(app* t, unsigned i); bool try_repair_sub(bvect const& e, bvval& a, bvval& b, unsigned i); - bool try_repair_mul(bvect const& e, bvval& a, bvval const& b); + bool try_repair_mul(bvect const& e, bvval& a, bvect const& b); bool try_repair_bxor(bvect const& e, bvval& a, bvval const& b); + bool try_repair_bxor(app* t, unsigned i); bool try_repair_bnot(bvect const& e, bvval& a); bool try_repair_bneg(bvect const& e, bvval& a); bool try_repair_ule(bool e, bvval& a, bvval const& b); @@ -125,7 +130,7 @@ namespace bv { /** * Retrieve evaluation based on immediate children. */ - bool bval1(app* e) const; + bool can_eval1(app* e) const; public: @@ -158,6 +163,8 @@ namespace bv { bool re_eval_is_correct(app* e); expr_ref get_value(app* e); + + bool bval1(app* e) const; /* * Try to invert value of child to repair value assignment of parent. @@ -171,7 +178,7 @@ namespace bv { bool repair_up(expr* e); - std::ostream& display(std::ostream& out, expr_ref_vector const& es); + std::ostream& display(std::ostream& out); std::ostream& display_value(std::ostream& out, expr* e); }; diff --git a/src/ast/sls/bv_sls_fixed.cpp b/src/ast/sls/bv_sls_fixed.cpp index 7613797dc..530fc4b43 100644 --- a/src/ast/sls/bv_sls_fixed.cpp +++ b/src/ast/sls/bv_sls_fixed.cpp @@ -28,7 +28,7 @@ namespace bv { {} void sls_fixed::init() { - for (auto e : terms.subterms()) + for (auto e : ctx.subterms()) set_fixed(e); for (auto const& c : ctx.clauses()) { @@ -37,13 +37,12 @@ namespace bv { auto a = ctx.atom(lit.var()); if (!a) continue; - a = terms.translated(a); if (is_app(a)) init_range(to_app(a), lit.sign()); ev.m_fixed.setx(a->get_id(), true, false); } } - for (auto e : terms.subterms()) + for (auto e : ctx.subterms()) propagate_range_up(e); } diff --git a/src/ast/sls/bv_sls_terms.cpp b/src/ast/sls/bv_sls_terms.cpp index 084fc11b8..2df076a20 100644 --- a/src/ast/sls/bv_sls_terms.cpp +++ b/src/ast/sls/bv_sls_terms.cpp @@ -3,14 +3,11 @@ Copyright (c) 2024 Microsoft Corporation Module Name: - bv_sls.cpp + bv_sls_terms.cpp Abstract: - A Stochastic Local Search (SLS) engine - Uses invertibility conditions, - interval annotations - don't care annotations + normalize bit-vector expressions to use only binary operators. Author: @@ -19,7 +16,7 @@ Author: --*/ #include "ast/ast_ll_pp.h" -#include "ast/sls/bv_sls.h" +#include "ast/sls/bv_sls_terms.h" #include "ast/rewriter/bool_rewriter.h" #include "ast/rewriter/bv_rewriter.h" @@ -29,38 +26,16 @@ namespace bv { ctx(ctx), m(ctx.get_manager()), bv(m), - m_translated(m) {} + m_axioms(m) {} - void sls_terms::init() { - for (auto t : ctx.subterms()) - ensure_binary(t); - - m_subterms.reset(); - expr_fast_mark1 visited; - for (auto t : ctx.subterms()) - m_subterms.push_back(translated(t)); - for (auto t : m_subterms) - visited.mark(t, true); - for (unsigned i = 0; i < m_subterms.size(); ++i) { - auto t = m_subterms[i]; - if (!is_app(t)) - continue; - app* a = to_app(t); - for (expr* arg : *a) { - if (visited.is_marked(arg)) - continue; - visited.mark(arg, true); - m_subterms.push_back(arg); - } - } - std::stable_sort(m_subterms.begin(), m_subterms.end(), - [](expr* a, expr* b) { return a->get_id() < b->get_id(); }); + void sls_terms::register_term(expr* e) { + auto r = ensure_binary(e); + if (r != e) + m_axioms.push_back(m.mk_eq(e, r)); } - void sls_terms::ensure_binary(expr* e) { - if (m_translated.get(e->get_id(), nullptr)) - return; - + expr_ref sls_terms::ensure_binary(expr* e) { + app* a = to_app(e); auto arg = [&](unsigned i) { return a->get_arg(i); @@ -72,22 +47,7 @@ namespace bv { for (unsigned i = 1; i < num_args; ++i)\ r = oper(r, arg(i)); \ - if (bv.is_bv_and(e)) { - FOLD_OP(bv.mk_bv_and); - } - else if (bv.is_bv_or(e)) { - FOLD_OP(bv.mk_bv_or); - } - else if (bv.is_bv_xor(e)) { - FOLD_OP(bv.mk_bv_xor); - } - else if (bv.is_bv_add(e)) { - FOLD_OP(bv.mk_bv_add); - } - else if (bv.is_bv_mul(e)) { - FOLD_OP(bv.mk_bv_mul); - } - else if (bv.is_concat(e)) { + if (bv.is_concat(e)) { FOLD_OP(bv.mk_concat); } else if (bv.is_bv_sdiv(e) || bv.is_bv_sdiv0(e) || bv.is_bv_sdivi(e)) { @@ -101,7 +61,7 @@ namespace bv { } else r = e; - m_translated.setx(e->get_id(), r); + return r; } expr_ref sls_terms::mk_sdiv(expr* x, expr* y) { @@ -118,14 +78,16 @@ namespace bv { unsigned sz = bv.get_bv_size(x); rational N = rational::power_of_two(sz); expr_ref z(bv.mk_zero(sz), m); - expr* signx = bvr.mk_ule(bv.mk_numeral(N / 2, sz), x); - expr* signy = bvr.mk_ule(bv.mk_numeral(N / 2, sz), y); - expr* absx = br.mk_ite(signx, bvr.mk_bv_neg(x), x); - expr* absy = br.mk_ite(signy, bvr.mk_bv_neg(y), y); - expr* d = bv.mk_bv_udiv(absx, absy); - expr_ref r(br.mk_ite(br.mk_eq(signx, signy), d, bvr.mk_bv_neg(d)), m); + expr_ref o(bv.mk_one(sz), m); + expr_ref n1(bv.mk_numeral(N - 1, sz), m); + expr_ref signx = bvr.mk_ule(bv.mk_numeral(N / 2, sz), x); + expr_ref signy = bvr.mk_ule(bv.mk_numeral(N / 2, sz), y); + expr_ref absx = br.mk_ite(signx, bvr.mk_bv_neg(x), x); + expr_ref absy = br.mk_ite(signy, bvr.mk_bv_neg(y), y); + expr_ref d = expr_ref(bv.mk_bv_udiv(absx, absy), m); + expr_ref r = br.mk_ite(br.mk_eq(signx, signy), d, bvr.mk_bv_neg(d)); r = br.mk_ite(br.mk_eq(z, y), - br.mk_ite(signx, bv.mk_one(sz), bv.mk_numeral(N - 1, sz)), + br.mk_ite(signx, o, n1), br.mk_ite(br.mk_eq(x, z), z, r)); return r; } @@ -142,9 +104,9 @@ namespace bv { bv_rewriter bvr(m); unsigned sz = bv.get_bv_size(x); expr_ref z(bv.mk_zero(sz), m); - expr_ref abs_x(br.mk_ite(bvr.mk_sle(z, x), x, bvr.mk_bv_neg(x)), m); - expr_ref abs_y(br.mk_ite(bvr.mk_sle(z, y), y, bvr.mk_bv_neg(y)), m); - expr_ref u(bvr.mk_bv_urem(abs_x, abs_y), m); + expr_ref abs_x = br.mk_ite(bvr.mk_sle(z, x), x, bvr.mk_bv_neg(x)); + expr_ref abs_y = br.mk_ite(bvr.mk_sle(z, y), y, bvr.mk_bv_neg(y)); + expr_ref u = bvr.mk_bv_urem(abs_x, abs_y); expr_ref r(m); r = br.mk_ite(br.mk_eq(u, z), z, br.mk_ite(br.mk_eq(y, z), x, diff --git a/src/ast/sls/bv_sls_terms.h b/src/ast/sls/bv_sls_terms.h index a35ea5025..93b703e37 100644 --- a/src/ast/sls/bv_sls_terms.h +++ b/src/ast/sls/bv_sls_terms.h @@ -32,10 +32,9 @@ namespace bv { sls::context& ctx; ast_manager& m; bv_util bv; - expr_ref_vector m_translated; - ptr_vector m_subterms; + expr_ref_vector m_axioms; - void ensure_binary(expr* e); + expr_ref ensure_binary(expr* e); expr_ref mk_sdiv(expr* x, expr* y); expr_ref mk_smod(expr* x, expr* y); @@ -44,14 +43,8 @@ namespace bv { public: sls_terms(sls::context& ctx); - /** - * Initialize structures: assertions, parents, terms - */ - void init(); - - expr* translated(expr* e) const { return m_translated.get(e->get_id(), nullptr); } - - ptr_vector const& subterms() const { return m_subterms; } + void register_term(expr* e); + expr_ref_vector& axioms() { return m_axioms; } }; } diff --git a/src/ast/sls/sls_arith_base.cpp b/src/ast/sls/sls_arith_base.cpp index 1c6e7e404..1188e6511 100644 --- a/src/ast/sls/sls_arith_base.cpp +++ b/src/ast/sls/sls_arith_base.cpp @@ -27,13 +27,6 @@ namespace sls { m_fid = a.get_family_id(); } - template - void arith_base::reset() { - m_bool_vars.reset(); - m_vars.reset(); - m_expr2var.reset(); - } - template void arith_base::save_best_values() { for (auto& v : m_vars) @@ -1070,11 +1063,6 @@ namespace sls { template void arith_base::mk_model(model& mdl) { - for (auto const& v : m_vars) { - expr* e = v.m_expr; - if (is_uninterp_const(e)) - mdl.register_decl(to_app(e)->get_decl(), get_value(e)); - } } } diff --git a/src/ast/sls/sls_arith_base.h b/src/ast/sls/sls_arith_base.h index 135e8e2e7..21806f574 100644 --- a/src/ast/sls/sls_arith_base.h +++ b/src/ast/sls/sls_arith_base.h @@ -189,7 +189,6 @@ namespace sls { expr_ref get_value(expr* e) override; lbool check() override; bool is_sat() override; - void reset() override; void on_rescale() override; void on_restart() override; std::ostream& display(std::ostream& out) const override; diff --git a/src/ast/sls/sls_arith_plugin.cpp b/src/ast/sls/sls_arith_plugin.cpp index e8d237fb0..31303818d 100644 --- a/src/ast/sls/sls_arith_plugin.cpp +++ b/src/ast/sls/sls_arith_plugin.cpp @@ -86,13 +86,6 @@ namespace sls { return m_arith64->is_sat(); return m_arith->is_sat(); } - void arith_plugin::reset() { - if (m_arith) - m_arith->reset(); - else - m_arith64->reset(); - m_shared.reset(); - } void arith_plugin::on_rescale() { if (m_arith) diff --git a/src/ast/sls/sls_arith_plugin.h b/src/ast/sls/sls_arith_plugin.h index 1686cf3b2..4a1d71deb 100644 --- a/src/ast/sls/sls_arith_plugin.h +++ b/src/ast/sls/sls_arith_plugin.h @@ -36,7 +36,6 @@ namespace sls { expr_ref get_value(expr* e) override; lbool check() override; bool is_sat() override; - void reset() override; void on_rescale() override; void on_restart() override; diff --git a/src/ast/sls/sls_basic_plugin.cpp b/src/ast/sls/sls_basic_plugin.cpp new file mode 100644 index 000000000..18c5599bd --- /dev/null +++ b/src/ast/sls/sls_basic_plugin.cpp @@ -0,0 +1,313 @@ +/*++ +Copyright (c) 2024 Microsoft Corporation + +Module Name: + + sls_basic_plugin.cpp + +Abstract: + + Local search dispatch for Boolean connectives + +Author: + + Nikolaj Bjorner (nbjorner) 2024-07-07 + +--*/ + +#include "ast/sls/sls_basic_plugin.h" +#include "ast/ast_ll_pp.h" + +namespace sls { + + expr_ref basic_plugin::get_value(expr* e) { + return expr_ref(m.mk_bool_val(bval0(e)), m); + } + + lbool basic_plugin::check() { + init(); + for (sat::literal lit : ctx.root_literals()) + repair_literal(lit); + repair_defs_and_updates(); + return ctx.unsat().empty() ? l_true : l_undef; + } + + void basic_plugin::init() { + m_repair_down = UINT_MAX; + m_repair_roots.reset(); + m_repair_up.reset(); + if (m_initialized) + return; + m_initialized = true; + for (auto t : ctx.subterms()) + if (is_app(t) && m.is_bool(t) && to_app(t)->get_family_id() == basic_family_id) + m_values.setx(t->get_id(), bval1(to_app(t)), false); + } + + bool basic_plugin::is_sat() { + for (auto t : ctx.subterms()) + if (is_app(t) && + m.is_bool(t) && + to_app(t)->get_family_id() == basic_family_id && + bval0(t) != bval1(to_app(t))) + return false; + return true; + } + + + std::ostream& basic_plugin::display(std::ostream& out) const { + for (auto t : ctx.subterms()) + if (is_app(t) && m.is_bool(t) && to_app(t)->get_family_id() == basic_family_id) + out << mk_bounded_pp(t, m) << " " << bval0(t) << " ~ " << bval1(to_app(t)) << "\n"; + return out; + } + + void basic_plugin::set_value(expr* e, expr* v) { + if (!m.is_bool(e)) + return; + SASSERT(m.is_bool(v)); + SASSERT(m.is_true(v) || m.is_false(v)); + if (bval0(e) != m.is_true(v)) + return; + set_value(e, m.is_true(v)); + m_repair_roots.insert(e->get_id()); + } + + bool basic_plugin::bval1(app* e) const { + SASSERT(m.is_bool(e)); + SASSERT(e->get_family_id() == basic_family_id); + + auto id = e->get_id(); + switch (e->get_decl_kind()) { + case OP_TRUE: + return true; + case OP_FALSE: + return false; + case OP_AND: + return all_of(*to_app(e), [&](expr* arg) { return bval0(arg); }); + case OP_OR: + return any_of(*to_app(e), [&](expr* arg) { return bval0(arg); }); + case OP_NOT: + return !bval0(e->get_arg(0)); + case OP_XOR: { + bool r = false; + for (auto* arg : *to_app(e)) + r ^= bval0(arg); + return r; + } + case OP_IMPLIES: { + auto a = e->get_arg(0); + auto b = e->get_arg(1); + return !bval0(a) || bval0(b); + } + case OP_ITE: { + auto c = bval0(e->get_arg(0)); + return bval0(c ? e->get_arg(1) : e->get_arg(2)); + } + case OP_EQ: { + auto a = e->get_arg(0); + auto b = e->get_arg(1); + if (m.is_bool(a)) + return bval0(a) == bval0(b); + return ctx.get_value(a) == ctx.get_value(b); + } + case OP_DISTINCT: { + for (unsigned i = 0; i < e->get_num_args(); ++i) + for (unsigned j = i + 1; j < e->get_num_args(); ++j) + if (ctx.get_value(e->get_arg(i)) == ctx.get_value(e->get_arg(j))) + return false; + return true; + } + default: + verbose_stream() << mk_bounded_pp(e, m) << "\n"; + UNREACHABLE(); + break; + } + UNREACHABLE(); + return false; + } + + bool basic_plugin::bval0(expr* e) const { + SASSERT(m.is_bool(e)); + sat::bool_var v = ctx.atom2bool_var(e); + if (v == sat::null_bool_var) + return m_values.get(e->get_id(), false); + else + return ctx.is_true(sat::literal(v, false)); + } + + bool basic_plugin::try_repair(app* e, unsigned i) { + switch (e->get_decl_kind()) { + case OP_AND: + return try_repair_and_or(e, i); + case OP_OR: + return try_repair_and_or(e, i); + case OP_NOT: + return try_repair_not(e); + case OP_FALSE: + return false; + case OP_TRUE: + return false; + case OP_EQ: + return try_repair_eq(e, i); + case OP_IMPLIES: + return try_repair_implies(e, i); + case OP_XOR: + return try_repair_xor(e, i); + case OP_ITE: + return try_repair_ite(e, i); + case OP_DISTINCT: + NOT_IMPLEMENTED_YET(); + return false; + default: + UNREACHABLE(); + return false; + } + } + + bool basic_plugin::try_repair_and_or(app* e, unsigned i) { + auto b = bval0(e); + auto child = e->get_arg(i); + if (b == bval0(child)) + return false; + set_value(child, b); + return true; + } + + bool basic_plugin::try_repair_not(app* e) { + auto child = e->get_arg(0); + set_value(child, !bval0(e)); + return true; + } + + bool basic_plugin::try_repair_eq(app* e, unsigned i) { + auto child = e->get_arg(i); + auto sibling = e->get_arg(1 - i); + if (!m.is_bool(child)) + return false; + set_value(child, bval0(e) == bval0(sibling)); + return true; + } + + bool basic_plugin::try_repair_xor(app* e, unsigned i) { + bool ev = bval0(e); + bool bv = bval0(e->get_arg(1 - i)); + auto child = e->get_arg(i); + set_value(child, ev != bv); + return true; + } + + bool basic_plugin::try_repair_ite(app* e, unsigned i) { + auto child = e->get_arg(i); + bool c = bval0(e->get_arg(0)); + if (i == 0) { + set_value(child, !c); + return true; + } + if (c != (i == 1)) + return false; + if (m.is_bool(e)) { + set_value(child, bval0(e)); + return true; + } + return false; + } + + bool basic_plugin::try_repair_implies(app* e, unsigned i) { + auto child = e->get_arg(i); + bool ev = bval0(e); + bool av = bval0(child); + bool bv = bval0(e->get_arg(1 - i)); + if (i == 0) { + if (ev == (!av || bv)) + return false; + } + else if (ev != (!bv || av)) + return false; + set_value(child, ev); + return true; + } + + bool basic_plugin::repair_up(expr* e) { + if (!m.is_bool(e)) + return false; + auto b = bval1(to_app(e)); + set_value(e, b); + return true; + } + + void basic_plugin::repair_down(app* e) { + SASSERT(m.is_bool(e)); + unsigned n = e->get_num_args(); + if (n == 0 || e->get_family_id() != m.get_basic_family_id()) { + for (auto p : ctx.parents(e)) + m_repair_up.insert(p->get_id()); + ctx.set_value(e, m.mk_bool_val(bval0(e))); + return; + } + if (bval0(e) == bval1(e)) + return; + unsigned s = ctx.rand(n); + for (unsigned i = 0; i < n; ++i) { + auto j = (i + s) % n; + if (try_repair(e, j)) { + m_repair_down = e->get_arg(j)->get_id(); + return; + } + } + m_repair_up.insert(e->get_id()); + } + + + void basic_plugin::repair_defs_and_updates() { + if (!m_repair_roots.empty() || + !m_repair_up.empty() || + m_repair_down != UINT_MAX) { + + while (m_repair_down != UINT_MAX) { + auto e = ctx.term(m_repair_down); + repair_down(to_app(e)); + } + + while (!m_repair_up.empty()) { + auto id = m_repair_up.elem_at(rand() % m_repair_up.size()); + auto e = ctx.term(id); + m_repair_up.remove(id); + repair_up(to_app(e)); + } + + if (!m_repair_roots.empty()) { + auto id = m_repair_roots.elem_at(rand() % m_repair_roots.size()); + m_repair_roots.remove(id); + m_repair_down = id; + } + } + } + + void basic_plugin::set_value(expr* e, bool b) { + sat::bool_var v = ctx.atom2bool_var(e); + if (v == sat::null_bool_var) { + if (m_values.get(e->get_id(), b) != b) { + m_values.set(e->get_id(), b); + ctx.set_value(e, m.mk_bool_val(b)); + } + } + else if (ctx.is_true(sat::literal(v, false)) != b) { + ctx.flip(v); + ctx.set_value(e, m.mk_bool_val(b)); + } + } + + void basic_plugin::repair_literal(sat::literal lit) { + if (!ctx.is_true(lit)) + return; + auto a = ctx.atom(lit.var()); + if (!a || !is_app(a)) + return; + if (to_app(a)->get_family_id() != basic_family_id) + return; + if (bval1(to_app(a)) != bval0(to_app(a))) + m_repair_roots.insert(a->get_id()); + } + +} diff --git a/src/ast/sls/sls_basic_plugin.h b/src/ast/sls/sls_basic_plugin.h new file mode 100644 index 000000000..568ae2877 --- /dev/null +++ b/src/ast/sls/sls_basic_plugin.h @@ -0,0 +1,61 @@ +/*++ +Copyright (c) 2020 Microsoft Corporation + +Module Name: + + sls_basic_plugin.h + +Author: + + Nikolaj Bjorner (nbjorner) 2024-07-05 + +--*/ +#pragma once + +#include "ast/sls/sls_smt.h" + +namespace sls { + + class basic_plugin : public plugin { + bool_vector m_values; + indexed_uint_set m_repair_up, m_repair_roots; + unsigned m_repair_down = UINT_MAX; + bool m_initialized = false; + + void init(); + bool bval1(app* e) const; + bool bval0(expr* e) const; + bool repair_up(expr* e); + bool try_repair(app* e, unsigned i); + bool try_repair_and_or(app* e, unsigned i); + bool try_repair_not(app* e); + bool try_repair_eq(app* e, unsigned i); + bool try_repair_xor(app* e, unsigned i); + bool try_repair_ite(app* e, unsigned i); + bool try_repair_implies(app* e, unsigned i); + void set_value(expr* e, bool b); + + void repair_down(app* e); + void repair_defs_and_updates(); + void repair_literal(sat::literal lit); + + public: + basic_plugin(context& ctx) : + plugin(ctx) { + } + ~basic_plugin() override {} + void init_bool_var(sat::bool_var v) override {} + void register_term(expr* e) override {} + expr_ref get_value(expr* e) override; + lbool check() override; + bool is_sat() override; + + void on_rescale() override {} + void on_restart() override {} + std::ostream& display(std::ostream& out) const override; + void mk_model(model& mdl) override {} + void set_shared(expr* e) override {} + void set_value(expr* e, expr* v) override; + }; + +} diff --git a/src/ast/sls/sls_bv.cpp b/src/ast/sls/sls_bv.cpp deleted file mode 100644 index b910ee1c5..000000000 --- a/src/ast/sls/sls_bv.cpp +++ /dev/null @@ -1,63 +0,0 @@ - -#include "ast/sls/sls_bv.h" - -namespace sls { - - bv_plugin::bv_plugin(context& ctx): - plugin(ctx), - bv(m), - m_terms(ctx), - m_eval(m_terms, ctx) - {} - - void bv_plugin::init_bool_var(sat::bool_var v) { - } - - void bv_plugin::register_term(expr* e) { - } - - expr_ref bv_plugin::get_value(expr* e) { - return expr_ref(m); - } - - lbool bv_plugin::check() { - return l_undef; - } - - bool bv_plugin::is_sat() { - return false; - } - - void bv_plugin::reset() { - } - - void bv_plugin::on_rescale() { - - } - - void bv_plugin::on_restart() { - } - - std::ostream& bv_plugin::display(std::ostream& out) const { - return out; - } - - void bv_plugin::mk_model(model& mdl) { - - } - - void bv_plugin::set_shared(expr* e) { - - } - - void bv_plugin::set_value(expr* e, expr* v) { - - } - - std::pair bv_plugin::next_to_repair() { - - - return { false, nullptr }; - } - -} diff --git a/src/ast/sls/sls_bv_plugin.cpp b/src/ast/sls/sls_bv_plugin.cpp new file mode 100644 index 000000000..a97d4f736 --- /dev/null +++ b/src/ast/sls/sls_bv_plugin.cpp @@ -0,0 +1,210 @@ +/*++ +Copyright (c) 2024 Microsoft Corporation + +Module Name: + + sls_bv_plugin.cpp + +Abstract: + + Theory plugin for bit-vector local search + +Author: + + Nikolaj Bjorner (nbjorner) 2024-07-06 + +--*/ +#include "ast/sls/sls_bv_plugin.h" +#include "ast/ast_ll_pp.h" + +namespace sls { + + bv_plugin::bv_plugin(context& ctx): + plugin(ctx), + bv(m), + m_terms(ctx), + m_eval(m_terms, ctx) + {} + + void bv_plugin::register_term(expr* e) { + m_terms.register_term(e); + } + + expr_ref bv_plugin::get_value(expr* e) { + return expr_ref(m); + } + + lbool bv_plugin::check() { + + if (!m_initialized) { + auto eval = [&](expr* e, unsigned idx) { return false; }; + m_eval.init_eval(eval); + m_initialized = true; + } + + auto& axioms = m_terms.axioms(); + if (!axioms.empty()) { + for (auto* e : axioms) + ctx.add_constraint(e); + axioms.reset(); + return l_undef; + } + + // repair each root literal + for (sat::literal lit : ctx.root_literals()) + repair_literal(lit); + + repair_defs_and_updates(); + + // update literal assignment based on current model + for (unsigned v = 0; v < ctx.num_bool_vars(); ++v) + init_bool_var_assignment(v); + + return ctx.unsat().empty() ? l_true : l_undef; + } + + void bv_plugin::repair_literal(sat::literal lit) { + if (!ctx.is_true(lit)) + return; + auto a = ctx.atom(lit.var()); + if (!a || !is_app(a)) + return; + if (to_app(a)->get_family_id() != bv.get_family_id()) + return; + if (!m_eval.eval_is_correct(to_app(a))) + m_repair_roots.insert(a->get_id()); + } + + void bv_plugin::repair_defs_and_updates() { + if (!m_repair_roots.empty() || + !m_repair_up.empty() || + m_repair_down != UINT_MAX) { + + while (m_repair_down != UINT_MAX) { + auto e = ctx.term(m_repair_down); + try_repair_down(to_app(e)); + } + + while (!m_repair_up.empty()) { + auto id = m_repair_up.elem_at(rand() % m_repair_up.size()); + auto e = ctx.term(id); + m_repair_up.remove(id); + try_repair_up(to_app(e)); + } + + if (!m_repair_roots.empty()) { + auto id = m_repair_roots.elem_at(rand() % m_repair_roots.size()); + m_repair_roots.remove(id); + m_repair_down = id; + } + } + } + + void bv_plugin::init_bool_var_assignment(sat::bool_var v) { + auto a = ctx.atom(v); + if (!a || !is_app(a)) + return; + if (to_app(a)->get_family_id() != bv.get_family_id()) + return; + bool is_true = m_eval.bval1(to_app(a)); + + if (is_true != ctx.is_true(sat::literal(v, false))) + ctx.flip(v); + } + + bool bv_plugin::is_sat() { + return false; + } + + std::ostream& bv_plugin::display(std::ostream& out) const { + // m_eval.display(out); + return out; + } + + void bv_plugin::set_shared(expr* e) { + } + + void bv_plugin::set_value(expr* e, expr* v) { + } + + void bv_plugin::try_repair_down(app* e) { + + unsigned n = e->get_num_args(); + if (n == 0 || m_eval.eval_is_correct(e)) { + m_eval.commit_eval(e); + if (!m.is_bool(e)) + for (auto p : ctx.parents(e)) + m_repair_up.insert(p->get_id()); + return; + } + + if (m.is_bool(e)) { + NOT_IMPLEMENTED_YET(); + return; + } + + if (n == 2) { + auto d1 = get_depth(e->get_arg(0)); + auto d2 = get_depth(e->get_arg(1)); + unsigned s = ctx.rand(d1 + d2 + 2); + if (s <= d1 && m_eval.try_repair(e, 0)) { + set_repair_down(e->get_arg(0)); + return; + } + if (m_eval.try_repair(e, 1)) { + set_repair_down(e->get_arg(1)); + return; + } + if (m_eval.try_repair(e, 0)) { + set_repair_down(e->get_arg(0)); + return; + } + } + else { + unsigned s = ctx.rand(n); + for (unsigned i = 0; i < n; ++i) { + auto j = (i + s) % n; + if (m_eval.try_repair(e, j)) { + set_repair_down(e->get_arg(j)); + return; + } + } + } + IF_VERBOSE(3, verbose_stream() << "init-repair " << mk_bounded_pp(e, m) << "\n"); + // repair was not successful, so reset the state to find a different way to repair + m_repair_down = UINT_MAX; + } + + void bv_plugin::try_repair_up(app* e) { + if (m.is_bool(e)) + ; + else if (m_eval.repair_up(e)) { + if (!m_eval.eval_is_correct(e)) { + verbose_stream() << "incorrect eval #" << e->get_id() << " " << mk_bounded_pp(e, m) << "\n"; + } + SASSERT(m_eval.eval_is_correct(e)); + for (auto p : ctx.parents(e)) + m_repair_up.insert(p->get_id()); + } + else if (ctx.rand(10) != 0) { + IF_VERBOSE(2, verbose_stream() << "repair-up "; trace_repair(true, e)); + m_eval.set_random(e); + m_repair_roots.insert(e->get_id()); + } + } + + std::ostream& bv_plugin::trace_repair(bool down, expr* e) { + verbose_stream() << (down ? "d #" : "u #") + << e->get_id() << ": " + << mk_bounded_pp(e, m, 1) << " "; + return m_eval.display_value(verbose_stream(), e) << "\n"; + } + + void bv_plugin::trace() { + IF_VERBOSE(2, verbose_stream() + << "(bvsls :restarts " << m_stats.m_restarts + << " :repair-up " << m_repair_up.size() + << " :repair-roots " << m_repair_roots.size() << ")\n"); + } + +} diff --git a/src/ast/sls/sls_bv.h b/src/ast/sls/sls_bv_plugin.h similarity index 62% rename from src/ast/sls/sls_bv.h rename to src/ast/sls/sls_bv_plugin.h index c591c4a7d..99cf4cf12 100644 --- a/src/ast/sls/sls_bv.h +++ b/src/ast/sls/sls_bv_plugin.h @@ -3,7 +3,7 @@ Copyright (c) 2020 Microsoft Corporation Module Name: - sls_bv.h + sls_bv_plugin.h Abstract: @@ -31,23 +31,34 @@ namespace sls { indexed_uint_set m_repair_up, m_repair_roots; unsigned m_repair_down = UINT_MAX; + bool m_initialized = false; - std::pair next_to_repair(); + void repair_literal(sat::literal lit); + + void repair_defs_and_updates(); + + void init_bool_var_assignment(sat::bool_var v); + + void try_repair_down(app* e); + void set_repair_down(expr* e) { m_repair_down = e->get_id(); } + void try_repair_up(app* e); + + std::ostream& bv_plugin::trace_repair(bool down, expr* e); + void trace(); public: bv_plugin(context& ctx); ~bv_plugin() override {} - void init_bool_var(sat::bool_var v) override; + void init_bool_var(sat::bool_var v) override {} void register_term(expr* e) override; expr_ref get_value(expr* e) override; lbool check() override; bool is_sat() override; - void reset() override; - void on_rescale() override; - void on_restart() override; + void on_rescale() override {} + void on_restart() override {} std::ostream& display(std::ostream& out) const override; - void mk_model(model& mdl) override; + void mk_model(model& mdl) override {} void set_shared(expr* e) override; void set_value(expr* e, expr* v) override; }; diff --git a/src/ast/sls/sls_cc.cpp b/src/ast/sls/sls_cc.cpp index 0d5ebf4c7..a9c2b3002 100644 --- a/src/ast/sls/sls_cc.cpp +++ b/src/ast/sls/sls_cc.cpp @@ -34,10 +34,6 @@ namespace sls { UNREACHABLE(); return expr_ref(m); } - - void cc_plugin::reset() { - m_app.reset(); - } void cc_plugin::register_term(expr* e) { if (!is_app(e)) diff --git a/src/ast/sls/sls_cc.h b/src/ast/sls/sls_cc.h index 06204bb28..381652a39 100644 --- a/src/ast/sls/sls_cc.h +++ b/src/ast/sls/sls_cc.h @@ -41,7 +41,6 @@ namespace sls { expr_ref get_value(expr* e) override; lbool check() override; bool is_sat() override; - void reset() override; void register_term(expr* e) override; void init_bool_var(sat::bool_var v) override {} std::ostream& display(std::ostream& out) const override; diff --git a/src/ast/sls/sls_smt.cpp b/src/ast/sls/sls_smt.cpp index 20138874a..bb7b047bb 100644 --- a/src/ast/sls/sls_smt.cpp +++ b/src/ast/sls/sls_smt.cpp @@ -19,18 +19,22 @@ Author: #include "ast/sls/sls_smt.h" #include "ast/sls/sls_cc.h" #include "ast/sls/sls_arith_plugin.h" +#include "ast/sls/sls_bv_plugin.h" +#include "ast/sls/sls_basic_plugin.h" namespace sls { plugin::plugin(context& c): ctx(c), m(c.get_manager()) { - reset(); } context::context(ast_manager& m, sat_solver_context& s) : m(m), s(s), m_atoms(m), m_allterms(m) { - reset(); + register_plugin(alloc(cc_plugin, *this)); + register_plugin(alloc(arith_plugin, *this)); + register_plugin(alloc(bv_plugin, *this)); + register_plugin(alloc(basic_plugin, *this)); } void context::register_plugin(plugin* p) { @@ -43,19 +47,6 @@ namespace sls { m_atom2bool_var.setx(e->get_id(), v, sat::null_bool_var); } - void context::reset() { - m_plugins.reset(); - m_atoms.reset(); - m_atom2bool_var.reset(); - m_initialized = false; - m_parents.reset(); - m_relevant.reset(); - m_visited.reset(); - m_allterms.reset(); - register_plugin(alloc(cc_plugin, *this)); - register_plugin(alloc(arith_plugin, *this)); - } - lbool context::check() { // // initialize data-structures if not done before. @@ -75,6 +66,9 @@ namespace sls { return l_undef; if (all_of(m_plugins, [&](auto* p) { return !p || p->is_sat(); })) { model_ref mdl = alloc(model, m); + for (expr* e : subterms()) + if (is_uninterp_const(e)) + mdl->register_decl(to_app(e)->get_decl(), get_value(e)); for (auto p : m_plugins) if (p) p->mk_model(*mdl); @@ -99,10 +93,6 @@ namespace sls { } expr_ref context::get_value(expr* e) { - if (m.is_bool(e)) { - auto v = m_atom2bool_var[e->get_id()]; - return expr_ref(is_true(sat::literal(v, false)) ? m.mk_true() : m.mk_false(), m); - } sort* s = e->get_sort(); auto fid = s->get_family_id(); auto p = m_plugins.get(fid, nullptr); diff --git a/src/ast/sls/sls_smt.h b/src/ast/sls/sls_smt.h index 3b0e16c7a..1b8e389b7 100644 --- a/src/ast/sls/sls_smt.h +++ b/src/ast/sls/sls_smt.h @@ -41,7 +41,6 @@ namespace sls { virtual void init_bool_var(sat::bool_var v) = 0; virtual lbool check() = 0; virtual bool is_sat() = 0; - virtual void reset() {}; virtual void on_rescale() {}; virtual void on_restart() {}; virtual std::ostream& display(std::ostream& out) const = 0; @@ -98,7 +97,7 @@ namespace sls { // Between SAT/SMT solver and context. void register_atom(sat::bool_var v, expr* e); - void reset(); + // void reset(); lbool check(); // expose sat_solver to plugins @@ -109,6 +108,8 @@ namespace sls { unsigned num_bool_vars() const { return s.num_vars(); } bool is_true(sat::literal lit) { return s.is_true(lit); } expr* atom(sat::bool_var v) { return m_atoms.get(v, nullptr); } + expr* term(unsigned id) const { return m_allterms.get(id); } + sat::bool_var atom2bool_var(expr* e) const { return m_atom2bool_var.get(e->get_id(), sat::null_bool_var); } void flip(sat::bool_var v) { s.flip(v); } double reward(sat::bool_var v) { return s.reward(v); } indexed_uint_set const& unsat() const { return s.unsat(); } @@ -118,6 +119,11 @@ namespace sls { void reinit_relevant(); + ptr_vector const& parents(expr* e) { + m_parents.reserve(e->get_id() + 1); + return m_parents[e->get_id()]; + } + // Between plugin solvers expr_ref get_value(expr* e); bool is_true(expr* e);