diff --git a/src/ast/sls/sls_arith_base.cpp b/src/ast/sls/sls_arith_base.cpp index 79f60d763..5061cd46b 100644 --- a/src/ast/sls/sls_arith_base.cpp +++ b/src/ast/sls/sls_arith_base.cpp @@ -604,10 +604,9 @@ namespace sls { bool arith_base::update(var_t v, num_t const& new_value) { auto& vi = m_vars[v]; expr* e = vi.m_expr; - SASSERT(!m.is_value(e)); auto old_value = vi.m_value; if (old_value == new_value) - return true; + return true; if (!vi.in_range(new_value)) return false; if (!in_bounds(v, new_value) && in_bounds(v, old_value)) diff --git a/src/ast/sls/sls_basic_plugin.cpp b/src/ast/sls/sls_basic_plugin.cpp index c55785c71..712ce742a 100644 --- a/src/ast/sls/sls_basic_plugin.cpp +++ b/src/ast/sls/sls_basic_plugin.cpp @@ -29,246 +29,85 @@ namespace sls { bool basic_plugin::is_basic(expr* e) const { if (!e || !is_app(e)) return false; - if (to_app(e)->get_family_id() != basic_family_id) - return false; - if (!m.is_bool(e)) - return false; - expr* x, * y; - if (m.is_eq(e, x, y) && !m.is_bool(x)) - return false; - if (m.is_distinct(e) && !m.is_bool(to_app(e)->get_arg(0))) - return false; - if (m.is_not(e, x)) - return is_basic(x); - return true; + if (m.is_ite(e) && !m.is_bool(e)) + return true; + if (m.is_xor(e) && to_app(e)->get_num_args() != 2) + return true; + if (m.is_distinct(e)) + return true; + return false; } void basic_plugin::propagate_literal(sat::literal lit) { - auto a = ctx.atom(lit.var()); - if (!is_basic(a)) - return; - if (bval1(to_app(a)) != bval0(to_app(a))) - ctx.new_value_eh(a); } void basic_plugin::register_term(expr* e) { - if (is_basic(e) && m.is_bool(e)) - m_values.setx(e->get_id(), bval1(to_app(e)), false); } void basic_plugin::initialize() { } bool basic_plugin::propagate() { - for (auto t : ctx.subterms()) - if (is_basic(t) && !m.is_not(t) && - bval0(t) != bval1(to_app(t))) { - add_clause(to_app(t)); - return true; - } - return false; } bool basic_plugin::is_sat() { - for (auto t : ctx.subterms()) - if (is_basic(t) && !m.is_not(t) && - bval0(t) != bval1(to_app(t))) { - verbose_stream() << mk_bounded_pp(t, m) << " := " << (bval0(t) ? "T" : "F") << " eval: " << (bval1(to_app(t)) ? "T" : "F") << "\n"; - return false; - } return true; } - std::ostream& basic_plugin::display(std::ostream& out) const { - for (auto t : ctx.subterms()) - if (is_basic(t)) - out << mk_bounded_pp(t, m) << " := " << (bval0(t)?"T":"F") << " eval: " << (bval1(to_app(t))?"T":"F") << "\n"; return out; } bool basic_plugin::set_value(expr* e, expr* v) { - if (!is_basic(e)) + if (!m.is_bool(e)) return false; SASSERT(m.is_true(v) || m.is_false(v)); return set_value(e, m.is_true(v)); } - bool basic_plugin::bval1(app* e) const { - verbose_stream() << mk_bounded_pp(e, m) << "\n"; - if (m.is_not(e)) - return bval1(to_app(e->get_arg(0))); - SASSERT(m.is_bool(e)); - SASSERT(e->get_family_id() == basic_family_id); + expr_ref basic_plugin::eval_ite(app* e) { + expr* c, * th, * el; + VERIFY(m.is_ite(e, c, th, el)); + if (bval0(c)) + return ctx.get_value(th); + else + return ctx.get_value(el); + } - auto id = e->get_id(); - switch (e->get_decl_kind()) { - case OP_TRUE: - return true; - case OP_FALSE: - return false; - case OP_AND: - return all_of(*to_app(e), [&](expr* arg) { return bval0(arg); }); - case OP_OR: - return any_of(*to_app(e), [&](expr* arg) { return bval0(arg); }); - case OP_NOT: - return !bval0(e->get_arg(0)); - case OP_XOR: { - bool r = false; - for (auto* arg : *to_app(e)) - r ^= bval0(arg); - return r; + expr_ref basic_plugin::eval_distinct(app* e) { + for (unsigned i = 0; i < e->get_num_args(); ++i) { + for (unsigned j = i + 1; j < e->get_num_args(); ++j) { + if (bval0(e->get_arg(i)) == bval0(e->get_arg(j))) + return expr_ref(m.mk_false(), m); + } } - case OP_IMPLIES: { - auto a = e->get_arg(0); - auto b = e->get_arg(1); - return !bval0(a) || bval0(b); - } - case OP_ITE: { - auto c = bval0(e->get_arg(0)); - return bval0(c ? e->get_arg(1) : e->get_arg(2)); - } - case OP_EQ: { - auto a = e->get_arg(0); - auto b = e->get_arg(1); - if (m.is_bool(a)) - return bval0(a) == bval0(b); - verbose_stream() << mk_bounded_pp(e, m) << " " << ctx.get_value(a) << " " << ctx.get_value(b) << "\n"; - return ctx.get_value(a) == ctx.get_value(b); - } - case OP_DISTINCT: { - for (unsigned i = 0; i < e->get_num_args(); ++i) - for (unsigned j = i + 1; j < e->get_num_args(); ++j) - if (ctx.get_value(e->get_arg(i)) == ctx.get_value(e->get_arg(j))) - return false; - return true; - } - default: - verbose_stream() << mk_bounded_pp(e, m) << "\n"; - UNREACHABLE(); - break; - } - UNREACHABLE(); - return false; + return expr_ref(m.mk_true(), m); + } + + expr_ref basic_plugin::eval_xor(app* e) { + bool b = false; + for (expr* arg : *e) + b ^= bval0(arg); + return expr_ref(m.mk_bool_val(b), m); } bool basic_plugin::bval0(expr* e) const { - SASSERT(m.is_bool(e)); - bool b = true; - while (m.is_not(e, e)) - b = !b; - sat::bool_var v = ctx.atom2bool_var(e); - if (v == sat::null_bool_var) - return b == m_values.get(e->get_id(), false); - else - return b == ctx.is_true(v); + SASSERT(m.is_bool(e)); + return ctx.is_true(ctx.mk_literal(e)); } bool basic_plugin::try_repair(app* e, unsigned i) { switch (e->get_decl_kind()) { - case OP_AND: - return try_repair_and_or(e, i); - case OP_OR: - return try_repair_and_or(e, i); - case OP_NOT: - return try_repair_not(e); - case OP_FALSE: - return false; - case OP_TRUE: - return false; - case OP_EQ: - return try_repair_eq(e, i); - case OP_IMPLIES: - return try_repair_implies(e, i); case OP_XOR: return try_repair_xor(e, i); case OP_ITE: return try_repair_ite(e, i); case OP_DISTINCT: return try_repair_distinct(e, i); - default: - UNREACHABLE(); - return false; - } - } - - void basic_plugin::add_clause(app* e) { - expr_ref_vector es(m); - expr_ref fml(m); - expr* x, *y; - switch (e->get_decl_kind()) { - case OP_AND: - for (expr* arg : *e) { - ctx.add_constraint(m.mk_or(m.mk_not(e), arg)); - es.push_back(mk_not(m, arg)); - } - es.push_back(e); - ctx.add_constraint(m.mk_or(es)); - break; - case OP_OR: - for (expr* arg : *e) { - ctx.add_constraint(m.mk_or(mk_not(m, arg), e)); - es.push_back(arg); - } - es.push_back(m.mk_not(e)); - ctx.add_constraint(m.mk_or(es)); - break; - case OP_NOT: - break; - case OP_FALSE: - break; - case OP_TRUE: - break; - case OP_EQ: - VERIFY(m.is_eq(e, x, y)); - ctx.add_constraint(m.mk_or(m.mk_not(e), mk_not(m, x), y)); - ctx.add_constraint(m.mk_or(m.mk_not(e), mk_not(m, y), x)); - ctx.add_constraint(m.mk_or(e, y, x)); - ctx.add_constraint(m.mk_or(e, mk_not(m, x), mk_not(m, y))); - break; - case OP_IMPLIES: - NOT_IMPLEMENTED_YET(); - case OP_XOR: - NOT_IMPLEMENTED_YET(); - case OP_ITE: - - NOT_IMPLEMENTED_YET(); - case OP_DISTINCT: - NOT_IMPLEMENTED_YET(); - default: - UNREACHABLE(); - break; - } - - } - - - bool basic_plugin::try_repair_and_or(app* e, unsigned i) { - auto b = bval0(e); - if ((b && m.is_and(e)) || (!b && m.is_or(e))) { - for (auto arg : *e) - if (!set_value(arg, b)) - return false; + default: return true; - } - auto child = e->get_arg(i); - if (b == bval0(child)) - return false; - return set_value(child, b); - } - - bool basic_plugin::try_repair_not(app* e) { - auto child = e->get_arg(0); - return set_value(child, !bval0(e)); - } - - bool basic_plugin::try_repair_eq(app* e, unsigned i) { - auto child = e->get_arg(i); - auto sibling = e->get_arg(1 - i); - if (!m.is_bool(child)) - return false; - return set_value(child, bval0(e) == bval0(sibling)); + } } bool basic_plugin::try_repair_xor(app* e, unsigned i) { @@ -283,12 +122,7 @@ namespace sls { bool basic_plugin::try_repair_ite(app* e, unsigned i) { if (m.is_bool(e)) - return try_repair_ite_bool(e, i); - else - return try_repair_ite_nonbool(e, i); - } - - bool basic_plugin::try_repair_ite_nonbool(app* e, unsigned i) { + return true; auto child = e->get_arg(i); auto cond = e->get_arg(0); bool c = bval0(cond); @@ -307,85 +141,45 @@ namespace sls { } if (c != (i == 1)) return false; - return ctx.set_value(child, ctx.get_value(e)); - } - - bool basic_plugin::try_repair_ite_bool(app* e, unsigned i) { - auto child = e->get_arg(i); - auto cond = e->get_arg(0); - bool c = bval0(cond); - if (i == 0) { - if (ctx.rand(2) == 0) - return set_value(cond, true) && set_value(e->get_arg(1), bval0(e)); - else - return set_value(cond, false) && set_value(e->get_arg(2), bval0(e)); - } - - if (!set_value(child, bval0(e))) + if (m.is_value(child)) return false; - return (c == (i == 1)) || set_value(cond, !c); - } - - bool basic_plugin::try_repair_implies(app* e, unsigned i) { - auto child = e->get_arg(i); - auto sibling = e->get_arg(1 - i); - bool ev = bval0(e); - bool av = bval0(child); - bool bv = bval0(sibling); - if (ev) { - - if (i == 0 && (!av || bv)) - return true; - if (i == 1 && (!bv || av)) - return true; - if (i == 0) { - return set_value(child, false); - } - if (i == 1) { - return set_value(child, true); - } - return false; - } - if (i == 0 && av && !bv) - return true; - if (i == 1 && bv && !av) - return true; - if (i == 0) - return set_value(child, true) && set_value(sibling, false); - if (i == 1) - return set_value(child, false) && set_value(sibling, true); - return false; + bool r = ctx.set_value(child, ctx.get_value(e)); + verbose_stream() << "repair-ite-down " << mk_bounded_pp(e, m) << " @ " << mk_bounded_pp(child, m) << " := " << ctx.get_value(e) << " success " << r << "\n"; + return r; } void basic_plugin::repair_up(app* e) { + expr* c, * th, * el; + expr_ref val(m); if (!is_basic(e)) return; - auto b = bval1(e); - if (bval0(e) == b) + if (m.is_ite(e, c, th, el) && !m.is_bool(e)) + val = eval_ite(e); + else if (m.is_xor(e)) + val = eval_xor(e); + else if (m.is_distinct(e)) + val = eval_distinct(e); + else return; - set_value(e, b); + verbose_stream() << "repair-up " << mk_bounded_pp(e, m) << " " << val << "\n"; + if (!ctx.set_value(e, val)) + ctx.new_value_eh(e); } void basic_plugin::repair_literal(sat::literal lit) { - auto a = ctx.atom(lit.var()); - if (!is_basic(a)) - return; - if (bval1(to_app(a)) != bval0(to_app(a))) - ctx.flip(lit.var()); } - bool basic_plugin::repair_down(app* e) { - SASSERT(m.is_bool(e)); - - unsigned n = e->get_num_args(); + bool basic_plugin::repair_down(app* e) { if (!is_basic(e)) - return false; - if (n == 0) + return true; + if (m.is_xor(e) && eval_xor(e) == ctx.get_value(e)) return true; - - if (bval0(e) == bval1(e)) + if (m.is_ite(e) && eval_ite(e) == ctx.get_value(e)) + return true; + if (m.is_distinct(e) && eval_distinct(e) == ctx.get_value(e)) return true; verbose_stream() << "basic repair down " << mk_bounded_pp(e, m) << "\n"; + unsigned n = e->get_num_args(); unsigned s = ctx.rand(n); for (unsigned i = 0; i < n; ++i) { auto j = (i + s) % n; @@ -396,23 +190,14 @@ namespace sls { } bool basic_plugin::try_repair_distinct(app* e, unsigned i) { + NOT_IMPLEMENTED_YET(); return false; } bool basic_plugin::set_value(expr* e, bool b) { - if (m.is_true(e) && !b) - return false; - if (m.is_false(e) && b) - return false; - sat::bool_var v = ctx.atom2bool_var(e); - if (v == sat::null_bool_var) { - if (m_values.get(e->get_id(), b) != b) { - m_values.set(e->get_id(), b); - ctx.new_value_eh(e); - } - } - else if (ctx.is_true(v) != b) { - ctx.flip(v); + auto lit = ctx.mk_literal(e); + if (ctx.is_true(lit) != b) { + ctx.flip(lit.var()); ctx.new_value_eh(e); } return true; diff --git a/src/ast/sls/sls_basic_plugin.h b/src/ast/sls/sls_basic_plugin.h index fc36ad629..d640415f4 100644 --- a/src/ast/sls/sls_basic_plugin.h +++ b/src/ast/sls/sls_basic_plugin.h @@ -17,25 +17,20 @@ Author: namespace sls { class basic_plugin : public plugin { - bool_vector m_values; bool m_initialized = false; + expr_mark m_axiomatized; bool is_basic(expr* e) const; - bool bval1(app* e) const; bool bval0(expr* e) const; bool try_repair(app* e, unsigned i); - bool try_repair_and_or(app* e, unsigned i); - bool try_repair_not(app* e); - bool try_repair_eq(app* e, unsigned i); bool try_repair_xor(app* e, unsigned i); bool try_repair_ite(app* e, unsigned i); - bool try_repair_ite_nonbool(app* e, unsigned i); - bool try_repair_ite_bool(app* e, unsigned i); - bool try_repair_implies(app* e, unsigned i); bool try_repair_distinct(app* e, unsigned i); bool set_value(expr* e, bool b); - void add_clause(app* e); + expr_ref eval_ite(app* e); + expr_ref eval_distinct(app* e); + expr_ref eval_xor(app* e); public: basic_plugin(context& ctx) : diff --git a/src/ast/sls/sls_context.cpp b/src/ast/sls/sls_context.cpp index ec26737e3..4d3704db0 100644 --- a/src/ast/sls/sls_context.cpp +++ b/src/ast/sls/sls_context.cpp @@ -21,6 +21,7 @@ Author: #include "ast/sls/sls_bv_plugin.h" #include "ast/sls/sls_basic_plugin.h" #include "ast/ast_ll_pp.h" +#include "ast/ast_pp.h" namespace sls { @@ -96,9 +97,9 @@ namespace sls { reinit_relevant(); for (sat::literal lit : root_literals()) { - if (m_new_constraint) - break; propagate_literal(lit); + if (m_new_constraint) + return; } while (!m_new_constraint && m.inc() && (!m_repair_up.empty() || !m_repair_down.empty())) { @@ -125,15 +126,20 @@ namespace sls { } } } + + // propagate "final checks" bool propagated = true; while (propagated && !m_new_constraint) { propagated = false; for (auto p : m_plugins) propagated |= p && !m_new_constraint && p->propagate(); - } + } - for (sat::bool_var v = 0; v < s.num_vars(); ++v) { + if (m_new_constraint) + return; + + for (sat::bool_var v = 0; v < s.num_vars() && !m_new_constraint; ++v) { auto a = atom(v); if (!a) continue; @@ -148,10 +154,8 @@ namespace sls { if (!is_app(e)) return null_family_id; family_id fid = to_app(e)->get_family_id(); - if (m.is_eq(e) || m.is_distinct(e)) + if (m.is_eq(e)) fid = to_app(e)->get_arg(0)->get_sort()->get_family_id(); - else if (m.is_ite(e)) - fid = to_app(e)->get_arg(1)->get_sort()->get_family_id(); return fid; } @@ -191,7 +195,6 @@ namespace sls { return expr_ref(e, m); } - bool context::set_value(expr * e, expr * v) { for (auto p : m_plugins) if (p && p->set_value(e, v)) @@ -215,34 +218,151 @@ namespace sls { return false; } - void context::add_constraint(expr* e) { - expr_ref _e(e, m); - sat::literal_vector lits; - auto add_literal = [&](expr* e) { - bool is_neg = m.is_not(e, e); - auto v = mk_atom(e); - lits.push_back(sat::literal(v, is_neg)); - }; - if (m.is_or(e)) - for (auto arg : *to_app(e)) - add_literal(arg); - else - add_literal(e); - TRACE("sls", tout << "new clause " << lits << "\n"); - s.add_clause(lits.size(), lits.data()); + void context::add_constraint(expr* e) { + add_clause(e); m_new_constraint = true; } - sat::bool_var context::mk_atom(expr* e) { - auto v = m_atom2bool_var.get(e->get_id(), sat::null_bool_var); - if (v == sat::null_bool_var) { - v = s.add_var(); - register_atom(v, e); - register_terms(e); + void context::add_clause(expr* f) { + expr_ref _e(f, m); + verbose_stream() << "add constraint " << _e << "\n"; + expr* g, * h, * k; + sat::literal_vector clause; + if (m.is_not(f, g) && m.is_not(g, g)) { + add_clause(g); + return; + } + bool sign = m.is_not(f, f); + if (!sign && m.is_or(f)) { + clause.reset(); + for (auto arg : *to_app(f)) + clause.push_back(mk_literal(arg)); + s.add_clause(clause.size(), clause.data()); + } + else if (!sign && m.is_and(f)) { + for (auto arg : *to_app(f)) + add_clause(arg); + } + else if (sign && m.is_or(f)) { + for (auto arg : *to_app(f)) { + expr_ref fml(m.mk_not(arg), m);; + add_clause(fml); + } + } + else if (sign && m.is_and(f)) { + clause.reset(); + for (auto arg : *to_app(f)) + clause.push_back(~mk_literal(arg)); + s.add_clause(clause.size(), clause.data()); + } + else if (m.is_iff(f, g, h)) { + auto lit1 = mk_literal(g); + auto lit2 = mk_literal(h); + sat::literal cls1[2] = { sign ? lit1 : ~lit1, lit2 }; + sat::literal cls2[2] = { sign ? ~lit1 : lit1, ~lit2 }; + s.add_clause(2, cls1); + s.add_clause(2, cls2); + } + else if (m.is_ite(f, g, h, k)) { + auto lit1 = mk_literal(g); + auto lit2 = mk_literal(h); + auto lit3 = mk_literal(k); + // (g -> h) & (~g -> k) + // (g & h) | (~g & k) + // negated: (g -> ~h) & (g -> ~k) + sat::literal cls1[2] = { ~lit1, sign ? ~lit2 : lit2 }; + sat::literal cls2[2] = { lit1, sign ? ~lit3 : lit3 }; + s.add_clause(2, cls1); + s.add_clause(2, cls2); + } + else { + sat::literal lit = mk_literal(f); + if (sign) + lit.neg(); + s.add_clause(1, &lit); } - return v; } + sat::literal context::mk_literal() { + sat::bool_var v = s.add_var(); + return sat::literal(v, false); + } + + sat::literal context::mk_literal(expr* e) { + sat::literal lit; + bool neg = false; + expr* a, * b, * c; + while (m.is_not(e, e)) + neg = !neg; + auto v = m_atom2bool_var.get(e->get_id(), sat::null_bool_var); + if (v != sat::null_bool_var) + return sat::literal(v, neg); + sat::literal_vector clause; + lit = mk_literal(); + if (m.is_true(e)) { + clause.push_back(lit); + s.add_clause(clause.size(), clause.data()); + } + else if (m.is_false(e)) { + clause.push_back(~lit); + s.add_clause(clause.size(), clause.data()); + } + else if (m.is_and(e)) { + for (expr* arg : *to_app(e)) { + auto lit2 = mk_literal(arg); + clause.push_back(~lit2); + sat::literal lits[2] = { ~lit, lit2 }; + s.add_clause(2, lits); + } + clause.push_back(lit); + s.add_clause(clause.size(), clause.data()); + } + else if (m.is_or(e)) { + for (expr* arg : *to_app(e)) { + auto lit2 = mk_literal(arg); + clause.push_back(lit2); + sat::literal lits[2] = { lit, ~lit2 }; + s.add_clause(2, lits); + } + clause.push_back(~lit); + s.add_clause(clause.size(), clause.data()); + } + else if (m.is_iff(e, a, b) || m.is_xor(e, a, b)) { + auto lit1 = mk_literal(a); + auto lit2 = mk_literal(b); + if (m.is_xor(e)) + lit2.neg(); + sat::literal cls1[3] = { ~lit, ~lit1, lit2 }; + sat::literal cls2[3] = { ~lit, lit1, ~lit2 }; + sat::literal cls3[3] = { lit, lit1, lit2 }; + sat::literal cls4[3] = { lit, ~lit1, ~lit2 }; + s.add_clause(3, cls1); + s.add_clause(3, cls2); + s.add_clause(3, cls3); + s.add_clause(3, cls4); + } + else if (m.is_ite(e, a, b, c)) { + auto lit1 = mk_literal(a); + auto lit2 = mk_literal(b); + auto lit3 = mk_literal(c); + sat::literal cls1[3] = { ~lit, ~lit1, lit2 }; + sat::literal cls2[3] = { ~lit, lit1, lit3 }; + sat::literal cls3[3] = { lit, ~lit1, ~lit2 }; + sat::literal cls4[3] = { lit, lit1, ~lit3 }; + s.add_clause(3, cls1); + s.add_clause(3, cls2); + s.add_clause(3, cls3); + s.add_clause(3, cls4); + } + else + register_terms(e); + + register_atom(lit.var(), e); + + return neg ? ~lit : lit; + } + + void context::init() { m_new_constraint = false; if (m_initialized) diff --git a/src/ast/sls/sls_context.h b/src/ast/sls/sls_context.h index 6a3bebacc..52fc8133c 100644 --- a/src/ast/sls/sls_context.h +++ b/src/ast/sls/sls_context.h @@ -104,12 +104,13 @@ namespace sls { random_gen m_rand; bool m_initialized = false; bool m_new_constraint = false; + bool m_dirty = false; expr_ref_vector m_allterms; ptr_vector m_subterms; greater_depth m_gd; less_depth m_ld; heap m_repair_down; - heap m_repair_up; + heap m_repair_up; void register_plugin(plugin* p); @@ -117,12 +118,14 @@ namespace sls { ptr_vector m_todo; void register_terms(expr* e); void register_term(expr* e); - sat::bool_var mk_atom(expr* e); void propagate_boolean_assignment(); void propagate_literal(sat::literal lit); family_id get_fid(expr* e) const; + + + sat::literal mk_literal(); public: context(ast_manager& m, sat_solver_context& s); @@ -142,6 +145,8 @@ namespace sls { expr* atom(sat::bool_var v) { return m_atoms.get(v, nullptr); } expr* term(unsigned id) const { return m_allterms.get(id); } sat::bool_var atom2bool_var(expr* e) const { return m_atom2bool_var.get(e->get_id(), sat::null_bool_var); } + sat::literal mk_literal(expr* e); + void add_clause(expr* f); void flip(sat::bool_var v) { s.flip(v); } double reward(sat::bool_var v) { return s.reward(v); } indexed_uint_set const& unsat() const { return s.unsat(); } diff --git a/src/ast/sls/sls_smt_solver.cpp b/src/ast/sls/sls_smt_solver.cpp index a72e95e72..34dd06e63 100644 --- a/src/ast/sls/sls_smt_solver.cpp +++ b/src/ast/sls/sls_smt_solver.cpp @@ -27,8 +27,9 @@ namespace sls { ast_manager& m; sat::ddfw& m_ddfw; context m_context; - bool m_new_clause_added = false; + bool m_dirty = false; model_ref m_model; + obj_map m_expr2lit; public: solver_ctx(ast_manager& m, sat::ddfw& d) : m(m), m_ddfw(d), m_context(m, *this) { @@ -54,11 +55,11 @@ namespace sls { TRACE("sls", display(tout)); while (unsat().empty()) { m_context.check(); - if (!m_new_clause_added) + if (!m_dirty) break; TRACE("sls", display(tout)); m_ddfw.reinit(); - m_new_clause_added = false; + m_dirty = false; } } @@ -80,17 +81,29 @@ namespace sls { vector const& clauses() const override { return m_ddfw.clauses(); } sat::clause_info const& get_clause(unsigned idx) const override { return m_ddfw.get_clause_info(idx); } ptr_iterator get_use_list(sat::literal lit) override { return m_ddfw.use_list(lit); } - void flip(sat::bool_var v) override { m_ddfw.flip(v); } + void flip(sat::bool_var v) override { if (m_dirty) m_ddfw.reinit(), m_dirty = false; m_ddfw.flip(v); } double reward(sat::bool_var v) override { return m_ddfw.get_reward(v); } double get_weigth(unsigned clause_idx) override { return m_ddfw.get_clause_info(clause_idx).m_weight; } bool is_true(sat::literal lit) override { return m_ddfw.get_value(lit.var()) != lit.sign(); } unsigned num_vars() const override { return m_ddfw.num_vars(); } indexed_uint_set const& unsat() const override { return m_ddfw.unsat_set(); } - sat::bool_var add_var() override { return m_ddfw.add_var(); } - void add_clause(unsigned n, sat::literal const* lits) override { - m_ddfw.add(n, lits); - m_new_clause_added = true; + sat::bool_var add_var() override { m_dirty = true; return m_ddfw.add_var(); } + + + void add_clause(expr* f) { + m_context.add_clause(f); } + + void add_clause(unsigned n, sat::literal const* lits) override { + m_ddfw.add(n, lits); + m_dirty = true; + } + + sat::literal mk_literal() { + sat::bool_var v = add_var(); + return sat::literal(v, false); + } + model_ref get_model() { return m_model; } void collect_statistics(statistics& st) { @@ -129,143 +142,13 @@ namespace sls { // send expression mapping to m_solver_ctx for (auto f : m_assertions) - add_clause(f); + m_solver_ctx->add_clause(f); IF_VERBOSE(10, m_solver_ctx->display(verbose_stream())); auto r = m_ddfw.check(0, nullptr); return r; } - - void smt_solver::add_clause(expr* f) { - expr* g, * h, * k; - sat::literal_vector clause; - if (m.is_not(f, g) && m.is_not(g, g)) { - add_clause(g); - return; - } - bool sign = m.is_not(f, f); - if (!sign && m.is_or(f)) { - clause.reset(); - for (auto arg : *to_app(f)) - clause.push_back(mk_literal(arg)); - m_solver_ctx->add_clause(clause.size(), clause.data()); - } - else if (!sign && m.is_and(f)) { - for (auto arg : *to_app(f)) - add_clause(arg); - } - else if (sign && m.is_or(f)) { - for (auto arg : *to_app(f)) { - expr_ref fml(m.mk_not(arg), m);; - add_clause(fml); - } - } - else if (sign && m.is_and(f)) { - clause.reset(); - for (auto arg : *to_app(f)) - clause.push_back(~mk_literal(arg)); - m_solver_ctx->add_clause(clause.size(), clause.data()); - } - else if (m.is_iff(f, g, h)) { - auto lit1 = mk_literal(g); - auto lit2 = mk_literal(h); - sat::literal cls1[2] = { sign ? lit1 :~lit1, lit2 }; - sat::literal cls2[2] = { sign ? ~lit1 : lit1, ~lit2 }; - m_solver_ctx->add_clause(2, cls1); - m_solver_ctx->add_clause(2, cls2); - } - else if (m.is_ite(f, g, h, k)) { - auto lit1 = mk_literal(g); - auto lit2 = mk_literal(h); - auto lit3 = mk_literal(k); - // (g -> h) & (~g -> k) - // (g & h) | (~g & k) - // negated: (g -> ~h) & (g -> ~k) - sat::literal cls1[2] = { ~lit1, sign ? ~lit2 : lit2 }; - sat::literal cls2[2] = { lit1, sign ? ~lit3 : lit3 }; - m_solver_ctx->add_clause(2, cls1); - m_solver_ctx->add_clause(2, cls2); - } - else { - sat::literal lit = mk_literal(f); - if (sign) - lit.neg(); - m_solver_ctx->add_clause(1, &lit); - } - } - - sat::literal smt_solver::mk_literal(expr* e) { - sat::literal lit; - bool neg = false; - expr* a, * b, * c; - while (m.is_not(e,e)) - neg = !neg; - if (m_expr2lit.find(e, lit)) - return neg ? ~lit : lit; - sat::literal_vector clause; - if (m.is_and(e)) { - lit = mk_literal(); - for (expr* arg : *to_app(e)) { - auto lit2 = mk_literal(arg); - clause.push_back(~lit2); - sat::literal lits[2] = { ~lit, lit2 }; - m_solver_ctx->add_clause(2, lits); - } - clause.push_back(lit); - m_solver_ctx->add_clause(clause.size(), clause.data()); - } - else if (m.is_or(e)) { - lit = mk_literal(); - for (expr* arg : *to_app(e)) { - auto lit2 = mk_literal(arg); - clause.push_back(lit2); - sat::literal lits[2] = { lit, ~lit2 }; - m_solver_ctx->add_clause(2, lits); - } - clause.push_back(~lit); - m_solver_ctx->add_clause(clause.size(), clause.data()); - } - else if (m.is_iff(e, a, b)) { - lit = mk_literal(); - auto lit1 = mk_literal(a); - auto lit2 = mk_literal(b); - sat::literal cls1[3] = { ~lit, ~lit1, lit2 }; - sat::literal cls2[3] = { ~lit, lit1, ~lit2 }; - sat::literal cls3[3] = { lit, lit1, lit2 }; - sat::literal cls4[3] = { lit, ~lit1, ~lit2 }; - m_solver_ctx->add_clause(3, cls1); - m_solver_ctx->add_clause(3, cls2); - m_solver_ctx->add_clause(3, cls3); - m_solver_ctx->add_clause(3, cls4); - } - else if (m.is_ite(e, a, b, c)) { - lit = mk_literal(); - auto lit1 = mk_literal(a); - auto lit2 = mk_literal(b); - auto lit3 = mk_literal(c); - sat::literal cls1[3] = { ~lit, ~lit1, lit2 }; - sat::literal cls2[3] = { ~lit, lit1, lit3 }; - sat::literal cls3[3] = { lit, ~lit1, ~lit2 }; - sat::literal cls4[3] = { lit, lit1, ~lit3 }; - m_solver_ctx->add_clause(3, cls1); - m_solver_ctx->add_clause(3, cls2); - m_solver_ctx->add_clause(3, cls3); - m_solver_ctx->add_clause(3, cls4); - } - else { - sat::bool_var v = m_num_vars++; - lit = sat::literal(v, false); - m_solver_ctx->register_atom(lit.var(), e); - } - m_expr2lit.insert(e, lit); - return neg ? ~lit : lit; - } - - sat::literal smt_solver::mk_literal() { - sat::bool_var v = m_num_vars++; - return sat::literal(v, false); - } model_ref smt_solver::get_model() { return m_solver_ctx->get_model(); diff --git a/src/ast/sls/sls_smt_solver.h b/src/ast/sls/sls_smt_solver.h index 5b7e0d62a..914397fc1 100644 --- a/src/ast/sls/sls_smt_solver.h +++ b/src/ast/sls/sls_smt_solver.h @@ -29,12 +29,7 @@ namespace sls { solver_ctx* m_solver_ctx = nullptr; expr_ref_vector m_assertions; statistics m_st; - obj_map m_expr2lit; - unsigned m_num_vars = 0; - - sat::literal mk_literal(expr* e); - sat::literal mk_literal(); - void add_clause(expr* f); + public: smt_solver(ast_manager& m, params_ref const& p); ~smt_solver(); diff --git a/src/tactic/sls/sls_tactic.cpp b/src/tactic/sls/sls_tactic.cpp index d78fb30f8..a0f5f3b76 100644 --- a/src/tactic/sls/sls_tactic.cpp +++ b/src/tactic/sls/sls_tactic.cpp @@ -78,7 +78,7 @@ public: try { res = m_sls->check(); } - catch (z3_exception& ex) { + catch (z3_exception&) { m_sls->collect_statistics(m_st); throw; }