diff --git a/src/sat/sat_extension.h b/src/sat/sat_extension.h index 9ab1fcd0c..884a49986 100644 --- a/src/sat/sat_extension.h +++ b/src/sat/sat_extension.h @@ -75,7 +75,6 @@ namespace sat { virtual std::ostream& display_constraint(std::ostream& out, ext_constraint_idx idx) const = 0; virtual void collect_statistics(statistics& st) const = 0; virtual extension* copy(solver* s) = 0; - virtual extension* copy(lookahead* s, bool learned) = 0; virtual void find_mutexes(literal_vector& lits, vector & mutexes) = 0; virtual void gc() = 0; virtual void pop_reinit() = 0; diff --git a/src/sat/sat_lookahead.cpp b/src/sat/sat_lookahead.cpp index 8e3f5bddb..a550102a9 100644 --- a/src/sat/sat_lookahead.cpp +++ b/src/sat/sat_lookahead.cpp @@ -22,7 +22,6 @@ Notes: #include #include "sat/sat_solver.h" -#include "sat/sat_extension.h" #include "sat/sat_lookahead.h" #include "sat/sat_scc.h" #include "util/union_find.h" @@ -1037,9 +1036,6 @@ namespace sat { } } - if (m_s.m_ext) { - // m_ext = m_s.m_ext->copy(this, learned); - } propagate(); m_qhead = m_trail.size(); m_init_freevars = m_freevars.size(); diff --git a/src/sat/sat_solver/inc_sat_solver.cpp b/src/sat/sat_solver/inc_sat_solver.cpp index adf8df443..ca6fc98c7 100644 --- a/src/sat/sat_solver/inc_sat_solver.cpp +++ b/src/sat/sat_solver/inc_sat_solver.cpp @@ -126,7 +126,7 @@ public: auto* ext = dynamic_cast(m_solver.get_extension()); if (ext) { auto& si = result->m_goal2sat.si(dst_m, m_params, result->m_solver, result->m_map, result->m_dep2asm, is_incremental()); - euf::solver::scoped_set_translate st(*ext, tr, result->m_map, si); + euf::solver::scoped_set_translate st(*ext, dst_m, result->m_map, si); result->m_solver.copy(m_solver); } else { diff --git a/src/sat/smt/ba_internalize.cpp b/src/sat/smt/ba_internalize.cpp index 00ebd0d19..74e9c76db 100644 --- a/src/sat/smt/ba_internalize.cpp +++ b/src/sat/smt/ba_internalize.cpp @@ -3,7 +3,7 @@ Copyright (c) 2020 Microsoft Corporation Module Name: - ba_internalize.h + ba_core.h Abstract: @@ -16,12 +16,13 @@ Author: --*/ -#include "sat/smt/ba_internalize.h" +#include "sat/smt/ba_solver.h" +#include "ast/pb_decl_plugin.h" namespace sat { - literal ba_internalize::internalize(expr* e, bool sign, bool root) { - if (pb.is_pb(e)) + literal ba_solver::internalize(expr* e, bool sign, bool root) { + if (m_pb.is_pb(e)) return internalize_pb(e, sign, root); if (m.is_xor(e)) return internalize_xor(e, sign, root); @@ -29,13 +30,13 @@ namespace sat { return null_literal; } - literal ba_internalize::internalize_xor(expr* e, bool sign, bool root) { + literal ba_solver::internalize_xor(expr* e, bool sign, bool root) { sat::literal_vector lits; - sat::bool_var v = m_solver.add_var(true); + sat::bool_var v = s().add_var(true); lits.push_back(literal(v, true)); auto add_expr = [&](expr* a) { literal lit = si.internalize(a); - m_solver.set_external(lit.var()); + s().set_external(lit.var()); lits.push_back(lit); }; expr* e1 = nullptr; @@ -46,34 +47,34 @@ namespace sat { for (unsigned i = 1; i + 1 < lits.size(); ++i) { lits[i].neg(); } - ba.add_xr(lits); - auto* aig = m_solver.get_cut_simplifier(); + add_xr(lits); + auto* aig = s().get_cut_simplifier(); if (aig) aig->add_xor(~lits.back(), lits.size() - 1, lits.c_ptr() + 1); sat::literal lit(v, sign); return literal(v, sign); } - literal ba_internalize::internalize_pb(expr* e, bool sign, bool root) { - SASSERT(pb.is_pb(e)); + literal ba_solver::internalize_pb(expr* e, bool sign, bool root) { + SASSERT(m_pb.is_pb(e)); app* t = to_app(e); - rational k = pb.get_k(t); + rational k = m_pb.get_k(t); switch (t->get_decl_kind()) { case OP_AT_MOST_K: return convert_at_most_k(t, k, root, sign); case OP_AT_LEAST_K: return convert_at_least_k(t, k, root, sign); case OP_PB_LE: - if (pb.has_unit_coefficients(t)) + if (m_pb.has_unit_coefficients(t)) return convert_at_most_k(t, k, root, sign); else return convert_pb_le(t, root, sign); case OP_PB_GE: - if (pb.has_unit_coefficients(t)) + if (m_pb.has_unit_coefficients(t)) return convert_at_least_k(t, k, root, sign); else return convert_pb_ge(t, root, sign); case OP_PB_EQ: - if (pb.has_unit_coefficients(t)) + if (m_pb.has_unit_coefficients(t)) return convert_eq_k(t, k, root, sign); else return convert_pb_eq(t, root, sign); @@ -83,35 +84,35 @@ namespace sat { return null_literal; } - void ba_internalize::check_unsigned(rational const& c) { + void ba_solver::check_unsigned(rational const& c) { if (!c.is_unsigned()) { throw default_exception("unsigned coefficient expected"); } } - void ba_internalize::convert_to_wlits(app* t, sat::literal_vector const& lits, svector& wlits) { + void ba_solver::convert_to_wlits(app* t, sat::literal_vector const& lits, svector& wlits) { for (unsigned i = 0; i < lits.size(); ++i) { - rational c = pb.get_coeff(t, i); + rational c = m_pb.get_coeff(t, i); check_unsigned(c); wlits.push_back(std::make_pair(c.get_unsigned(), lits[i])); } } - void ba_internalize::convert_pb_args(app* t, literal_vector& lits) { + void ba_solver::convert_pb_args(app* t, literal_vector& lits) { for (expr* arg : *t) { lits.push_back(si.internalize(arg)); - m_solver.set_external(lits.back().var()); + s().set_external(lits.back().var()); } } - void ba_internalize::convert_pb_args(app* t, svector& wlits) { + void ba_solver::convert_pb_args(app* t, svector& wlits) { sat::literal_vector lits; convert_pb_args(t, lits); convert_to_wlits(t, lits, wlits); } - literal ba_internalize::convert_pb_le(app* t, bool root, bool sign) { - rational k = pb.get_k(t); + literal ba_solver::convert_pb_le(app* t, bool root, bool sign) { + rational k = m_pb.get_k(t); k.neg(); svector wlits; convert_pb_args(t, wlits); @@ -120,7 +121,7 @@ namespace sat { k += rational(wl.first); } check_unsigned(k); - if (root && m_solver.num_user_scopes() == 0) { + if (root && s().num_user_scopes() == 0) { unsigned k1 = k.get_unsigned(); if (sign) { k1 = 1 - k1; @@ -129,25 +130,25 @@ namespace sat { k1 += wl.first; } } - ba.add_pb_ge(null_bool_var, wlits, k1); + add_pb_ge(null_bool_var, wlits, k1); return null_literal; } else { - bool_var v = m_solver.add_var(true); + bool_var v = s().add_var(true); literal lit(v, sign); - ba.add_pb_ge(v, wlits, k.get_unsigned()); + add_pb_ge(v, wlits, k.get_unsigned()); TRACE("ba", tout << "root: " << root << " lit: " << lit << "\n";); return lit; } } - literal ba_internalize::convert_pb_ge(app* t, bool root, bool sign) { - rational k = pb.get_k(t); + literal ba_solver::convert_pb_ge(app* t, bool root, bool sign) { + rational k = m_pb.get_k(t); check_unsigned(k); svector wlits; convert_pb_args(t, wlits); - if (root && m_solver.num_user_scopes() == 0) { + if (root && s().num_user_scopes() == 0) { unsigned k1 = k.get_unsigned(); if (sign) { k1 = 1 - k1; @@ -156,40 +157,40 @@ namespace sat { k1 += wl.first; } } - ba.add_pb_ge(sat::null_bool_var, wlits, k1); + add_pb_ge(sat::null_bool_var, wlits, k1); return null_literal; } else { - sat::bool_var v = m_solver.add_var(true); + sat::bool_var v = s().add_var(true); sat::literal lit(v, sign); - ba.add_pb_ge(v, wlits, k.get_unsigned()); + add_pb_ge(v, wlits, k.get_unsigned()); TRACE("goal2sat", tout << "root: " << root << " lit: " << lit << "\n";); return lit; } } - literal ba_internalize::convert_pb_eq(app* t, bool root, bool sign) { - rational k = pb.get_k(t); + literal ba_solver::convert_pb_eq(app* t, bool root, bool sign) { + rational k = m_pb.get_k(t); SASSERT(k.is_unsigned()); svector wlits; convert_pb_args(t, wlits); - bool base_assert = (root && !sign && m_solver.num_user_scopes() == 0); - bool_var v1 = base_assert ? null_bool_var : m_solver.add_var(true); - bool_var v2 = base_assert ? null_bool_var : m_solver.add_var(true); - ba.add_pb_ge(v1, wlits, k.get_unsigned()); + bool base_assert = (root && !sign && s().num_user_scopes() == 0); + bool_var v1 = base_assert ? null_bool_var : s().add_var(true); + bool_var v2 = base_assert ? null_bool_var : s().add_var(true); + add_pb_ge(v1, wlits, k.get_unsigned()); k.neg(); for (wliteral& wl : wlits) { wl.second.neg(); k += rational(wl.first); } check_unsigned(k); - ba.add_pb_ge(v2, wlits, k.get_unsigned()); + add_pb_ge(v2, wlits, k.get_unsigned()); if (base_assert) { return null_literal; } else { literal l1(v1, false), l2(v2, false); - bool_var v = m_solver.add_var(false); + bool_var v = s().add_var(false); literal l(v, false); si.mk_clause(~l, l1); si.mk_clause(~l, l2); @@ -200,23 +201,23 @@ namespace sat { } } - literal ba_internalize::convert_at_least_k(app* t, rational const& k, bool root, bool sign) { + literal ba_solver::convert_at_least_k(app* t, rational const& k, bool root, bool sign) { SASSERT(k.is_unsigned()); literal_vector lits; convert_pb_args(t, lits); unsigned k2 = k.get_unsigned(); - if (root && m_solver.num_user_scopes() == 0) { + if (root && s().num_user_scopes() == 0) { if (sign) { for (literal& l : lits) l.neg(); k2 = lits.size() + 1 - k2; } - ba.add_at_least(null_bool_var, lits, k2); + add_at_least(null_bool_var, lits, k2); return null_literal; } else { - bool_var v = m_solver.add_var(true); + bool_var v = s().add_var(true); literal lit(v, false); - ba.add_at_least(v, lits, k.get_unsigned()); + add_at_least(v, lits, k.get_unsigned()); si.cache(t, lit); if (sign) lit.neg(); TRACE("ba", tout << "root: " << root << " lit: " << lit << "\n";); @@ -224,7 +225,7 @@ namespace sat { } } - literal ba_internalize::convert_at_most_k(app* t, rational const& k, bool root, bool sign) { + literal ba_solver::convert_at_most_k(app* t, rational const& k, bool root, bool sign) { SASSERT(k.is_unsigned()); literal_vector lits; convert_pb_args(t, lits); @@ -232,39 +233,39 @@ namespace sat { l.neg(); } unsigned k2 = lits.size() - k.get_unsigned(); - if (root && m_solver.num_user_scopes() == 0) { + if (root && s().num_user_scopes() == 0) { if (sign) { for (literal& l : lits) l.neg(); k2 = lits.size() + 1 - k2; } - ba.add_at_least(null_bool_var, lits, k2); + add_at_least(null_bool_var, lits, k2); return null_literal; } else { - bool_var v = m_solver.add_var(true); + bool_var v = s().add_var(true); literal lit(v, false); - ba.add_at_least(v, lits, k2); + add_at_least(v, lits, k2); si.cache(t, lit); if (sign) lit.neg(); return lit; } } - literal ba_internalize::convert_eq_k(app* t, rational const& k, bool root, bool sign) { + literal ba_solver::convert_eq_k(app* t, rational const& k, bool root, bool sign) { SASSERT(k.is_unsigned()); literal_vector lits; convert_pb_args(t, lits); - bool_var v1 = (root && !sign) ? null_bool_var : m_solver.add_var(true); - bool_var v2 = (root && !sign) ? null_bool_var : m_solver.add_var(true); - ba.add_at_least(v1, lits, k.get_unsigned()); + bool_var v1 = (root && !sign) ? null_bool_var : s().add_var(true); + bool_var v2 = (root && !sign) ? null_bool_var : s().add_var(true); + add_at_least(v1, lits, k.get_unsigned()); for (literal& l : lits) { l.neg(); } - ba.add_at_least(v2, lits, lits.size() - k.get_unsigned()); + add_at_least(v2, lits, lits.size() - k.get_unsigned()); if (!root || sign) { literal l1(v1, false), l2(v2, false); - bool_var v = m_solver.add_var(false); + bool_var v = s().add_var(false); literal l(v, false); si.mk_clause(~l, l1); si.mk_clause(~l, l2); @@ -278,12 +279,12 @@ namespace sat { } } - expr_ref ba_decompile::get_card(std::function& lit2expr, ba_solver::card const& c) { + expr_ref ba_solver::get_card(std::function& lit2expr, ba_solver::card const& c) { ptr_buffer lits; for (sat::literal l : c) { lits.push_back(lit2expr(l)); } - expr_ref fml(pb.mk_at_least_k(c.size(), lits.c_ptr(), c.k()), m); + expr_ref fml(m_pb.mk_at_least_k(c.size(), lits.c_ptr(), c.k()), m); if (c.lit() != sat::null_literal) { fml = m.mk_eq(lit2expr(c.lit()), fml); @@ -291,7 +292,7 @@ namespace sat { return fml; } - expr_ref ba_decompile::get_pb(std::function& lit2expr, ba_solver::pb const& p) { + expr_ref ba_solver::get_pb(std::function& lit2expr, ba_solver::pb const& p) { ptr_buffer lits; vector coeffs; for (auto const& wl : p) { @@ -299,7 +300,7 @@ namespace sat { coeffs.push_back(rational(wl.first)); } rational k(p.k()); - expr_ref fml(pb.mk_ge(p.size(), coeffs.c_ptr(), lits.c_ptr(), k), m); + expr_ref fml(m_pb.mk_ge(p.size(), coeffs.c_ptr(), lits.c_ptr(), k), m); if (p.lit() != sat::null_literal) { fml = m.mk_eq(lit2expr(p.lit()), fml); @@ -307,7 +308,7 @@ namespace sat { return fml; } - expr_ref ba_decompile::get_xor(std::function& lit2expr, ba_solver::xr const& x) { + expr_ref ba_solver::get_xor(std::function& lit2expr, ba_solver::xr const& x) { ptr_buffer lits; for (sat::literal l : x) { lits.push_back(lit2expr(l)); @@ -320,16 +321,16 @@ namespace sat { return fml; } - bool ba_decompile::to_formulas(std::function& l2e, expr_ref_vector& fmls) { - for (auto* c : ba.constraints()) { + bool ba_solver::to_formulas(std::function& l2e, expr_ref_vector& fmls) { + for (auto* c : constraints()) { switch (c->tag()) { case ba_solver::card_t: fmls.push_back(get_card(l2e, c->to_card())); break; - case sat::ba_solver::pb_t: + case ba_solver::pb_t: fmls.push_back(get_pb(l2e, c->to_pb())); break; - case sat::ba_solver::xr_t: + case ba_solver::xr_t: fmls.push_back(get_xor(l2e, c->to_xr())); break; } diff --git a/src/sat/smt/ba_internalize.h b/src/sat/smt/ba_internalize.h deleted file mode 100644 index 39864acae..000000000 --- a/src/sat/smt/ba_internalize.h +++ /dev/null @@ -1,73 +0,0 @@ -/*++ -Copyright (c) 2020 Microsoft Corporation - -Module Name: - - ba_internalize.h - -Abstract: - - INternalize methods for Boolean algebra operators. - -Author: - - Nikolaj Bjorner (nbjorner) 2020-08-25 - ---*/ - - -#pragma once - -#include "sat/smt/sat_th.h" -#include "sat/smt/ba_solver.h" -#include "ast/pb_decl_plugin.h" - - -namespace sat { - - class ba_internalize : public th_internalizer { - typedef std::pair wliteral; - ast_manager& m; - pb_util pb; - ba_solver& ba; - solver_core& m_solver; - sat_internalizer& si; - literal convert_eq_k(app* t, rational const& k, bool root, bool sign); - literal convert_at_most_k(app* t, rational const& k, bool root, bool sign); - literal convert_at_least_k(app* t, rational const& k, bool root, bool sign); - literal convert_pb_eq(app* t, bool root, bool sign); - literal convert_pb_le(app* t, bool root, bool sign); - literal convert_pb_ge(app* t, bool root, bool sign); - void check_unsigned(rational const& c); - void convert_to_wlits(app* t, sat::literal_vector const& lits, svector& wlits); - void convert_pb_args(app* t, svector& wlits); - void convert_pb_args(app* t, literal_vector& lits); - literal internalize_pb(expr* e, bool sign, bool root); - literal internalize_xor(expr* e, bool sign, bool root); - - public: - ba_internalize(ba_solver& ba, solver_core& s, sat_internalizer& si, ast_manager& m) : - m(m), pb(m), ba(ba), m_solver(s), si(si) {} - ~ba_internalize() override {} - literal internalize(expr* e, bool sign, bool root) override; - - }; - - class ba_decompile : public sat::th_decompile { - ast_manager& m; - ba_solver& ba; - solver_core& m_solver; - pb_util pb; - - expr_ref get_card(std::function& l2e, ba_solver::card const& c); - expr_ref get_pb(std::function& l2e, ba_solver::pb const& p); - expr_ref get_xor(std::function& l2e, ba_solver::xr const& x); - public: - ba_decompile(ba_solver& ba, solver_core& s, ast_manager& m) : - m(m), ba(ba), m_solver(s), pb(m) {} - - ~ba_decompile() override {} - - bool to_formulas(std::function& l2e, expr_ref_vector& fmls) override; - }; -} diff --git a/src/sat/smt/ba_solver.cpp b/src/sat/smt/ba_solver.cpp index b4c6ee3d6..bcc674d37 100644 --- a/src/sat/smt/ba_solver.cpp +++ b/src/sat/smt/ba_solver.cpp @@ -3,7 +3,7 @@ Copyright (c) 2017 Microsoft Corporation Module Name: - ba_solver.cpp + ba_core.cpp Abstract: @@ -119,8 +119,8 @@ namespace sat { // ---------------------- // card - ba_solver::card::card(extension* e, unsigned id, literal lit, literal_vector const& lits, unsigned k): - pb_base(e, card_t, id, lit, lits.size(), get_obj_size(lits.size()), k) { + ba_solver::card::card(unsigned id, literal lit, literal_vector const& lits, unsigned k): + pb_base(card_t, id, lit, lits.size(), get_obj_size(lits.size()), k) { for (unsigned i = 0; i < size(); ++i) { m_lits[i] = lits[i]; } @@ -146,8 +146,8 @@ namespace sat { // ----------------------------------- // pb - ba_solver::pb::pb(extension* e, unsigned id, literal lit, svector const& wlits, unsigned k): - pb_base(e, pb_t, id, lit, wlits.size(), get_obj_size(wlits.size()), k), + ba_solver::pb::pb(unsigned id, literal lit, svector const& wlits, unsigned k): + pb_base(pb_t, id, lit, wlits.size(), get_obj_size(wlits.size()), k), m_slack(0), m_num_watch(0), m_max_sum(0) { @@ -302,7 +302,7 @@ namespace sat { SASSERT(validate_conflict(c)); if (c.is_xr() && value(lit) == l_true) lit.neg(); SASSERT(value(lit) == l_false); - set_conflict(justification::mk_ext_justification(s().scope_lvl(), c.index()), ~lit); + set_conflict(justification::mk_ext_justification(s().scope_lvl(), c.cindex()), ~lit); SASSERT(inconsistent()); } @@ -327,7 +327,7 @@ namespace sat { ps.push_back(drat::premise(drat::s_ext(), c.lit())); // null_literal case. drat_add(lits, ps); } - assign(lit, justification::mk_ext_justification(s().scope_lvl(), c.index())); + assign(lit, justification::mk_ext_justification(s().scope_lvl(), c.cindex())); break; } } @@ -1730,21 +1730,21 @@ namespace sat { return p; } - ba_solver::ba_solver() - : m_solver(nullptr), m_lookahead(nullptr), + ba_solver::ba_solver(ast_manager& m, sat_internalizer& si) + : m(m), si(si), m_pb(m), + m_solver(nullptr), m_lookahead(nullptr), m_constraint_id(0), m_ba(*this), m_sort(m_ba) { TRACE("ba", tout << this << "\n";); - std::cout << "mk " << this << "\n"; m_num_propagations_since_pop = 0; } ba_solver::~ba_solver() { m_stats.reset(); for (constraint* c : m_constraints) { - m_allocator.deallocate(c->obj_size(), c); + c->deallocate(m_allocator); } for (constraint* c : m_learned) { - m_allocator.deallocate(c->obj_size(), c); + c->deallocate(m_allocator); } } @@ -1763,7 +1763,8 @@ namespace sat { return nullptr; } void * mem = m_allocator.allocate(card::get_obj_size(lits.size())); - card* c = new (mem) card(this, next_id(), lit, lits, k); + constraint_base::initialize(mem, this); + card* c = new (constraint_base::ptr2mem(mem)) card(next_id(), lit, lits, k); c->set_learned(learned); add_constraint(c); return c; @@ -1832,7 +1833,8 @@ namespace sat { return add_at_least(lit, lits, k, learned); } void * mem = m_allocator.allocate(pb::get_obj_size(wlits.size())); - pb* p = new (mem) pb(this, next_id(), lit, wlits, k); + constraint_base::initialize(mem, this); + pb* p = new (constraint_base::ptr2mem(mem)) pb(next_id(), lit, wlits, k); p->set_learned(learned); add_constraint(p); return p; @@ -2108,11 +2110,11 @@ namespace sat { } bool ba_solver::is_watched(literal lit, constraint const& c) const { - return get_wlist(~lit).contains(watched(c.index())); + return get_wlist(~lit).contains(watched(c.cindex())); } void ba_solver::unwatch_literal(literal lit, constraint& c) { - watched w(c.index()); + watched w(c.cindex()); get_wlist(~lit).erase(w); SASSERT(!is_watched(lit, c)); } @@ -2120,7 +2122,7 @@ namespace sat { void ba_solver::watch_literal(literal lit, constraint& c) { if (c.is_pure() && lit == ~c.lit()) return; SASSERT(!is_watched(lit, c)); - watched w(c.index()); + watched w(c.cindex()); get_wlist(~lit).push_back(w); } @@ -2425,7 +2427,7 @@ namespace sat { constraint* c = m_learned[i]; if (!m_constraint_to_reinit.contains(c)) { remove_constraint(*c, "gc"); - m_allocator.deallocate(c->obj_size(), c); + c->deallocate(m_allocator); ++removed; } else { @@ -2639,7 +2641,7 @@ namespace sat { get_wlist(lit).size() == 1 && m_clause_use_list.get(~lit).empty()) { cp->set_pure(); - get_wlist(~lit).erase(watched(cp->index())); // just ignore assignments to false + get_wlist(~lit).erase(watched(cp->cindex())); // just ignore assignments to false } } } @@ -3403,7 +3405,7 @@ namespace sat { if (c.was_removed()) { clear_watch(c); nullify_tracking_literal(c); - m_allocator.deallocate(c.obj_size(), &c); + c.deallocate(m_allocator); } else if (learned && !c.learned()) { m_constraints.push_back(&c); @@ -3537,10 +3539,10 @@ namespace sat { } card& c2 = c->to_card(); - SASSERT(c1.index() != c2.index()); + SASSERT(&c1 != &c2); if (subsumes(c1, c2, slit)) { if (slit.empty()) { - TRACE("ba", tout << "subsume cardinality\n" << c1 << "\n" << c2.index() << ":" << c2 << "\n";); + TRACE("ba", tout << "subsume cardinality\n" << c1 << "\n" << c2 << "\n";); remove_constraint(c2, "subsumed"); ++m_stats.m_num_pb_subsumes; set_non_learned(c1); @@ -3713,22 +3715,14 @@ namespace sat { } extension* ba_solver::copy(solver* s) { - ba_solver* result = alloc(ba_solver); + return fresh(s, m, si); + } + + th_solver* ba_solver::fresh(solver* s, ast_manager& m, sat_internalizer& si) { + ba_solver* result = alloc(ba_solver, m, si); result->set_solver(s); - copy_core(result, false); - return result; - } - - extension* ba_solver::copy(lookahead* s, bool learned) { - ba_solver* result = alloc(ba_solver); - result->set_lookahead(s); - copy_core(result, learned); - return result; - } - - void ba_solver::copy_core(ba_solver* result, bool learned) { copy_constraints(result, m_constraints); - if (learned) copy_constraints(result, m_learned); + return result; } void ba_solver::copy_constraints(ba_solver* result, ptr_vector const& constraints) { @@ -3768,7 +3762,7 @@ namespace sat { void ba_solver::init_use_list(ext_use_list& ul) { ul.init(s().num_vars()); for (constraint const* cp : m_constraints) { - ext_constraint_idx idx = cp->index(); + ext_constraint_idx idx = cp->cindex(); if (cp->lit() != null_literal) { ul.insert(cp->lit(), idx); ul.insert(~cp->lit(), idx); diff --git a/src/sat/smt/ba_solver.h b/src/sat/smt/ba_solver.h index 50d4fda47..a3a028831 100644 --- a/src/sat/smt/ba_solver.h +++ b/src/sat/smt/ba_solver.h @@ -25,15 +25,17 @@ Revision History: #include "sat/sat_lookahead.h" #include "sat/sat_big.h" #include "sat/smt/sat_smt.h" +#include "sat/smt/sat_th.h" #include "util/small_object_allocator.h" #include "util/scoped_ptr_vector.h" #include "util/sorting_network.h" +#include "ast/pb_decl_plugin.h" namespace sat { class xor_finder; - class ba_solver : public extension { + class ba_solver : public th_solver { friend class local_search; @@ -65,7 +67,7 @@ namespace sat { class xr; class pb_base; - class constraint : public index_base { + class constraint { protected: tag_t m_tag; bool m_removed; @@ -79,19 +81,11 @@ namespace sat { unsigned m_id; bool m_pure; // is the constraint pure (only positive occurrences) public: - constraint(extension* e, tag_t t, unsigned id, literal l, unsigned sz, size_t osz): - index_base(e), + constraint(tag_t t, unsigned id, literal l, unsigned sz, size_t osz): m_tag(t), m_removed(false), m_lit(l), m_watch(null_literal), m_glue(0), m_psm(0), m_size(sz), m_obj_size(osz), m_learned(false), m_id(id), m_pure(false) { - std::cout << "constraint ext: " << t << " " << e << "\n"; - size_t idx = reinterpret_cast(this); - std::cout << "index " << idx << "\n"; - std::cout << this << " " << index_base::from_index(idx) << "\n"; - std::cout << e << " " << index_base::to_extension(idx) << "\n"; - std::cout.flush(); - } - ext_constraint_idx index() const { - return reinterpret_cast(this); } + ext_constraint_idx cindex() const { return constraint_base::mem2base(this); } + void deallocate(small_object_allocator& a) { a.deallocate(obj_size(), constraint_base::mem2base_ptr(this)); } unsigned id() const { return m_id; } tag_t tag() const { return m_tag; } literal lit() const { return m_lit; } @@ -143,8 +137,8 @@ namespace sat { protected: unsigned m_k; public: - pb_base(extension* e, tag_t t, unsigned id, literal l, unsigned sz, size_t osz, unsigned k): - constraint(e, t, id, l, sz, osz), m_k(k) { VERIFY(k < 4000000000); } + pb_base(tag_t t, unsigned id, literal l, unsigned sz, size_t osz, unsigned k): + constraint(t, id, l, sz, osz), m_k(k) { VERIFY(k < 4000000000); } virtual void set_k(unsigned k) { VERIFY(k < 4000000000); m_k = k; } virtual unsigned get_coeff(unsigned i) const { UNREACHABLE(); return 0; } unsigned k() const { return m_k; } @@ -154,8 +148,8 @@ namespace sat { class card : public pb_base { literal m_lits[0]; public: - static size_t get_obj_size(unsigned num_lits) { return sizeof(card) + num_lits * sizeof(literal); } - card(extension* e, unsigned id, literal lit, literal_vector const& lits, unsigned k); + static size_t get_obj_size(unsigned num_lits) { return constraint_base::obj_size(sizeof(card) + num_lits * sizeof(literal)); } + card(unsigned id, literal lit, literal_vector const& lits, unsigned k); literal operator[](unsigned i) const { return m_lits[i]; } literal& operator[](unsigned i) { return m_lits[i]; } literal const* begin() const { return m_lits; } @@ -178,8 +172,8 @@ namespace sat { unsigned m_max_sum; wliteral m_wlits[0]; public: - static size_t get_obj_size(unsigned num_lits) { return sizeof(pb) + num_lits * sizeof(wliteral); } - pb(extension* e, unsigned id, literal lit, svector const& wlits, unsigned k); + static size_t get_obj_size(unsigned num_lits) { return constraint_base::obj_size(sizeof(pb) + num_lits * sizeof(wliteral)); } + pb(unsigned id, literal lit, svector const& wlits, unsigned k); literal lit() const { return m_lit; } wliteral operator[](unsigned i) const { return m_wlits[i]; } wliteral& operator[](unsigned i) { return m_wlits[i]; } @@ -206,8 +200,8 @@ namespace sat { class xr : public constraint { literal m_lits[0]; public: - static size_t get_obj_size(unsigned num_lits) { return sizeof(xr) + num_lits * sizeof(literal); } - xr(extension* e, unsigned id, literal_vector const& lits); + static size_t get_obj_size(unsigned num_lits) { return constraint_base::obj_size(sizeof(xr) + num_lits * sizeof(literal)); } + xr(unsigned id, literal_vector const& lits); literal operator[](unsigned i) const { return m_lits[i]; } literal const* begin() const { return m_lits; } literal const* end() const { return begin() + m_size; } @@ -238,6 +232,10 @@ namespace sat { bool contains(literal l) const { for (auto wl : m_wlits) if (wl.second == l) return true; return false; } }; + ast_manager& m; + sat_internalizer& si; + pb_util m_pb; + solver* m_solver; lookahead* m_lookahead; stats m_stats; @@ -343,7 +341,7 @@ namespace sat { void remove_constraint(constraint& c, char const* reason); // constraints - constraint& index2constraint(size_t idx) const { return *reinterpret_cast(idx); } + constraint& index2constraint(size_t idx) const { return *reinterpret_cast(constraint_base::from_index(idx)->mem()); } void pop_constraint(); void unwatch_literal(literal w, constraint& c); void watch_literal(literal w, constraint& c); @@ -545,8 +543,27 @@ namespace sat { void copy_core(ba_solver* result, bool learned); void copy_constraints(ba_solver* result, ptr_vector const& constraints); + // Internalize + literal convert_eq_k(app* t, rational const& k, bool root, bool sign); + literal convert_at_most_k(app* t, rational const& k, bool root, bool sign); + literal convert_at_least_k(app* t, rational const& k, bool root, bool sign); + literal convert_pb_eq(app* t, bool root, bool sign); + literal convert_pb_le(app* t, bool root, bool sign); + literal convert_pb_ge(app* t, bool root, bool sign); + void check_unsigned(rational const& c); + void convert_to_wlits(app* t, sat::literal_vector const& lits, svector& wlits); + void convert_pb_args(app* t, svector& wlits); + void convert_pb_args(app* t, literal_vector& lits); + literal internalize_pb(expr* e, bool sign, bool root); + literal internalize_xor(expr* e, bool sign, bool root); + + // Decompile + expr_ref get_card(std::function& l2e, ba_solver::card const& c); + expr_ref get_pb(std::function& l2e, ba_solver::pb const& p); + expr_ref get_xor(std::function& l2e, ba_solver::xr const& x); + public: - ba_solver(); + ba_solver(ast_manager& m, sat_internalizer& si); ~ba_solver() override; void set_solver(solver* s) override { m_solver = s; } void set_lookahead(lookahead* l) override { m_lookahead = l; } @@ -572,7 +589,6 @@ namespace sat { std::ostream& display_constraint(std::ostream& out, ext_constraint_idx idx) const override; void collect_statistics(statistics& st) const override; extension* copy(solver* s) override; - extension* copy(lookahead* s, bool learned) override; void find_mutexes(literal_vector& lits, vector & mutexes) override; void pop_reinit() override; void gc() override; @@ -583,6 +599,10 @@ namespace sat { bool is_blocked(literal l, ext_constraint_idx idx) override; bool check_model(model const& m) const override; + literal internalize(expr* e, bool sign, bool root) override; + bool to_formulas(std::function& l2e, expr_ref_vector& fmls) override; + th_solver* fresh(solver* s, ast_manager& m, sat_internalizer& si) override; + ptr_vector const & constraints() const { return m_constraints; } std::ostream& display(std::ostream& out, constraint const& c, bool values) const; diff --git a/src/sat/smt/euf_model.cpp b/src/sat/smt/euf_model.cpp index 048030636..ecf7f36ca 100644 --- a/src/sat/smt/euf_model.cpp +++ b/src/sat/smt/euf_model.cpp @@ -30,20 +30,10 @@ namespace euf { values2model(deps, values, mdl); } - sat::th_model_builder* solver::get_model_builder(expr* e) const { - if (is_app(e)) - return get_model_builder(to_app(e)->get_decl()); - return nullptr; - } - - sat::th_model_builder* solver::get_model_builder(func_decl* f) const { - return m_id2model_builder.get(f->get_family_id(), nullptr); - } - - bool solver::include_func_interp(func_decl* f) const { + bool solver::include_func_interp(func_decl* f) { if (f->get_family_id() == null_family_id) return true; - sat::th_model_builder* mb = get_model_builder(f); + sat::th_model_builder* mb = get_solver(f); return mb && mb->include_func_interp(f); } @@ -53,7 +43,7 @@ namespace euf { deps.insert(n, nullptr); continue; } - auto* mb = get_model_builder(n->get_owner()); + auto* mb = get_solver(n->get_owner()); if (mb) mb->add_dep(n, deps); else @@ -87,7 +77,7 @@ namespace euf { } continue; } - auto* mb = get_model_builder(e); + auto* mb = get_solver(e); if (mb) mb->add_value(n, values); else if (m.is_uninterp(m.get_sort(e))) { diff --git a/src/sat/smt/euf_solver.cpp b/src/sat/smt/euf_solver.cpp index c666e5728..e555990ad 100644 --- a/src/sat/smt/euf_solver.cpp +++ b/src/sat/smt/euf_solver.cpp @@ -20,7 +20,6 @@ Author: #include "sat/sat_solver.h" #include "sat/smt/sat_smt.h" #include "sat/smt/ba_solver.h" -#include "sat/smt/ba_internalize.h" #include "sat/smt/euf_solver.h" namespace euf { @@ -32,54 +31,52 @@ namespace euf { /** * retrieve extension that is associated with Boolean variable. */ - sat::extension* solver::get_extension(sat::bool_var v) { + sat::th_solver* solver::get_solver(sat::bool_var v) { if (v >= m_var2node.size()) return nullptr; euf::enode* n = m_var2node[v].first; if (!n) return nullptr; - return get_extension(n->get_owner()); + return get_solver(n->get_owner()); } - void solver::add_extension(family_id fid, sat::extension* e) { - m_extensions.push_back(e); - m_id2extension.setx(fid, e, nullptr); - } - - sat::extension* solver::get_extension(expr* e) { - if (is_app(e)) { - auto fid = to_app(e)->get_family_id(); - if (fid == null_family_id) - return nullptr; - auto* ext = m_id2extension.get(fid, nullptr); - if (ext) - return ext; - pb_util pb(m); - if (pb.is_pb(e)) { - auto* ba = alloc(sat::ba_solver); - ba->set_solver(m_solver); - add_extension(pb.get_family_id(), ba); - auto* bai = alloc(sat::ba_internalize, *ba, s(), si, m); - m_id2internalize.setx(pb.get_family_id(), bai, nullptr); - m_internalizers.push_back(bai); - m_decompilers.push_back(alloc(sat::ba_decompile, *ba, s(), m)); - ba->push_scopes(s().num_scopes()); - std::cout << "extension ba " << ba << "\n"; - return ba; - } - } + sat::th_solver* solver::get_solver(expr* e) { + if (is_app(e)) + return fid2solver(to_app(e)->get_family_id()); return nullptr; } + sat::th_solver* solver::fid2solver(family_id fid) { + if (fid == null_family_id) + return nullptr; + auto* ext = m_id2solver.get(fid, nullptr); + if (ext) + return ext; + pb_util pb(m); + if (pb.get_family_id() == fid) { + sat::ba_solver* ba = alloc(sat::ba_solver, m, si); + ba->set_solver(m_solver); + add_solver(pb.get_family_id(), ba); + ba->push_scopes(s().num_scopes()); + return ba; + } + + return nullptr; + } + + void solver::add_solver(family_id fid, sat::th_solver* th) { + m_solvers.push_back(th); + m_id2solver.setx(fid, th, nullptr); + } + bool solver::propagate(literal l, ext_constraint_idx idx) { - auto* ext = sat::index_base::to_extension(idx); - std::cout << "extension " << ext << " " << idx << "\n"; + auto* ext = sat::constraint_base::to_extension(idx); SASSERT(ext != this); return ext->propagate(l, idx); } void solver::get_antecedents(literal l, ext_justification_idx idx, literal_vector& r) { - auto* ext = sat::index_base::to_extension(idx); + auto* ext = sat::constraint_base::to_extension(idx); if (ext == this) get_antecedents(l, *constraint::from_idx(idx), r); else @@ -110,7 +107,7 @@ namespace euf { break; case 2: SASSERT(m.is_bool(n->get_owner())); - m_egraph.explain_eq(m_explain, n, (sign ? m_false : m_true), false); + m_egraph.explain_eq(m_explain, n, (sign ? mk_false() : mk_true()), false); break; default: UNREACHABLE(); @@ -120,7 +117,7 @@ namespace euf { } void solver::asserted(literal l) { - auto* ext = get_extension(l.var()); + auto* ext = get_solver(l.var()); if (ext) { ext->asserted(l); return; @@ -138,7 +135,7 @@ namespace euf { m_egraph.merge(na, nb, base_ptr() + l.index()); } else { - euf::enode* nb = sign ? m_false : m_true; + euf::enode* nb = sign ? mk_false() : mk_true(); m_egraph.merge(n, nb, base_ptr() + l.index()); } // TBD: delay propagation? @@ -148,7 +145,7 @@ namespace euf { void solver::propagate() { m_egraph.propagate(); if (m_egraph.inconsistent()) { - s().set_conflict(sat::justification::mk_ext_justification(s().scope_lvl(), m_conflict_idx.to_index())); + s().set_conflict(sat::justification::mk_ext_justification(s().scope_lvl(), conflict_constraint().to_index())); return; } for (euf::enode* eq : m_egraph.new_eqs()) { @@ -156,7 +153,7 @@ namespace euf { expr* a = nullptr, *b = nullptr; if (s().value(v) == l_false && m_ackerman && m.is_eq(eq->get_owner(), a, b)) m_ackerman->cg_conflict_eh(a, b); - s().assign(literal(v, false), sat::justification::mk_ext_justification(s().scope_lvl(), m_eq_idx.to_index())); + s().assign(literal(v, false), sat::justification::mk_ext_justification(s().scope_lvl(), eq_constraint().to_index())); } for (euf::enode* p : m_egraph.new_lits()) { expr* e = p->get_owner(); @@ -167,14 +164,31 @@ namespace euf { literal lit(v, sign); if (s().value(lit) == l_false && m_ackerman) m_ackerman->cg_conflict_eh(p->get_owner(), p->get_root()->get_owner()); - s().assign(lit, sat::justification::mk_ext_justification(s().scope_lvl(), m_lit_idx.to_index())); + s().assign(lit, sat::justification::mk_ext_justification(s().scope_lvl(), lit_constraint().to_index())); } } + constraint& solver::mk_constraint(constraint*& c, unsigned id) { + if (!c) { + void* mem = memory::allocate(sat::constraint_base::obj_size(sizeof(constraint))); + c = new (sat::constraint_base::ptr2mem(mem)) constraint(id); + sat::constraint_base::initialize(mem, this); + } + return *c; + } + + enode* solver::mk_true() { + return visit(m.mk_true()); + } + + enode* solver::mk_false() { + return visit(m.mk_false()); + } + sat::check_result solver::check() { bool give_up = false; bool cont = false; - for (auto* e : m_extensions) + for (auto* e : m_solvers) switch (e->check()) { case sat::CR_CONTINUE: cont = true; break; case sat::CR_GIVEUP: give_up = true; break; @@ -188,7 +202,7 @@ namespace euf { } void solver::push() { - for (auto* e : m_extensions) + for (auto* e : m_solvers) e->push(); m_egraph.push(); ++m_num_scopes; @@ -196,7 +210,7 @@ namespace euf { void solver::pop(unsigned n) { m_egraph.pop(n); - for (auto* e : m_extensions) + for (auto* e : m_solvers) e->pop(n); if (n <= m_num_scopes) { m_num_scopes -= n; @@ -212,24 +226,24 @@ namespace euf { } void solver::pre_simplify() { - for (auto* e : m_extensions) + for (auto* e : m_solvers) e->pre_simplify(); } void solver::simplify() { - for (auto* e : m_extensions) + for (auto* e : m_solvers) e->simplify(); if (m_ackerman) m_ackerman->propagate(); } void solver::clauses_modifed() { - for (auto* e : m_extensions) + for (auto* e : m_solvers) e->clauses_modifed(); } lbool solver::get_phase(bool_var v) { - auto* ext = get_extension(v); + auto* ext = get_solver(v); if (ext) return ext->get_phase(v); return l_undef; @@ -237,20 +251,20 @@ namespace euf { std::ostream& solver::display(std::ostream& out) const { m_egraph.display(out); - for (auto* e : m_extensions) + for (auto* e : m_solvers) e->display(out); return out; } std::ostream& solver::display_justification(std::ostream& out, ext_justification_idx idx) const { - auto* ext = sat::index_base::to_extension(idx); + auto* ext = sat::constraint_base::to_extension(idx); if (ext != this) return ext->display_justification(out, idx); return out; } std::ostream& solver::display_constraint(std::ostream& out, ext_constraint_idx idx) const { - auto* ext = sat::index_base::to_extension(idx); + auto* ext = sat::constraint_base::to_extension(idx); if (ext != this) return ext->display_constraint(out, idx); return out; @@ -258,89 +272,68 @@ namespace euf { void solver::collect_statistics(statistics& st) const { m_egraph.collect_statistics(st); - for (auto* e : m_extensions) + for (auto* e : m_solvers) e->collect_statistics(st); st.update("euf dynack", m_stats.m_num_dynack); } - solver* solver::copy_core() { - ast_manager& to = m_translate ? m_translate->to() : m; - atom2bool_var& a2b = m_translate_expr2var ? *m_translate_expr2var : m_expr2var; - sat::sat_internalizer& to_si = m_translate_si ? *m_translate_si : si; - auto* r = alloc(solver, to, a2b, to_si); - r->m_config = m_config; - std::function copy_justification = [&](void* x) { return (void*)(r->base_ptr() + ((unsigned*)x - base_ptr())); }; - r->m_egraph.copy_from(m_egraph, copy_justification); - return r; - } - sat::extension* solver::copy(sat::solver* s) { - auto* r = copy_core(); - r->set_solver(s); - for (unsigned i = 0; i < m_id2extension.size(); ++i) { - auto* e = m_id2extension[i]; - if (e) - r->add_extension(i, e->copy(s)); - } - - return r; - } - - sat::extension* solver::copy(sat::lookahead* s, bool learned) { - (void) learned; - auto* r = copy_core(); - r->set_lookahead(s); - for (unsigned i = 0; i < m_id2extension.size(); ++i) { - auto* e = m_id2extension[i]; - if (e) - r->add_extension(i, e->copy(s, learned)); - } + auto* r = alloc(solver, *m_to_m, *m_to_expr2var, *m_to_si); + r->m_config = m_config; + std::function copy_justification = [&](void* x) { return (void*)(r->base_ptr() + ((unsigned*)x - base_ptr())); }; + r->m_egraph.copy_from(m_egraph, copy_justification); + r->set_solver(s); + for (unsigned i = 0; i < m_id2solver.size(); ++i) { + auto* e = m_id2solver[i]; + if (e) + r->add_solver(i, e->fresh(s, *m_to_m, *m_to_si)); + } return r; } void solver::find_mutexes(literal_vector& lits, vector & mutexes) { - for (auto* e : m_extensions) + for (auto* e : m_solvers) e->find_mutexes(lits, mutexes); } void solver::gc() { - for (auto* e : m_extensions) + for (auto* e : m_solvers) e->gc(); } void solver::pop_reinit() { - for (auto* e : m_extensions) + for (auto* e : m_solvers) e->pop_reinit(); } bool solver::validate() { - for (auto* e : m_extensions) + for (auto* e : m_solvers) if (!e->validate()) return false; return true; } void solver::init_use_list(sat::ext_use_list& ul) { - for (auto* e : m_extensions) + for (auto* e : m_solvers) e->init_use_list(ul); } bool solver::is_blocked(literal l, ext_constraint_idx idx) { - auto* ext = sat::index_base::to_extension(idx); + auto* ext = sat::constraint_base::to_extension(idx); if (ext != this) return is_blocked(l, idx); return false; } bool solver::check_model(sat::model const& m) const { - for (auto* e : m_extensions) + for (auto* e : m_solvers) if (!e->check_model(m)) return false; return true; } unsigned solver::max_var(unsigned w) const { - for (auto* e : m_extensions) + for (auto* e : m_solvers) w = e->max_var(w); for (unsigned sz = m_var2node.size(); sz-- > 0; ) { euf::enode* n = m_var2node[sz].first; @@ -368,25 +361,10 @@ namespace euf { m_egraph.set_used_cc(used_cc); } - - sat::th_internalizer* solver::get_internalizer(expr* e) { - if (is_app(e)) - return m_id2internalize.get(to_app(e)->get_family_id(), nullptr); - if (m.is_iff(e)) { - pb_util pb(m); - return m_id2internalize.get(pb.get_family_id(), nullptr); - } - return nullptr; - } - sat::literal solver::internalize(expr* e, bool sign, bool root) { - auto* ext = get_internalizer(e); + auto* ext = get_solver(e); if (ext) return ext->internalize(e, sign, root); - if (!m_true) { - m_true = visit(m.mk_true()); - m_false = visit(m.mk_false()); - } std::cout << mk_pp(e, m) << "\n"; SASSERT(!si.is_bool_op(e)); sat::scoped_stack _sc(m_stack); @@ -466,7 +444,7 @@ namespace euf { } bool solver::to_formulas(std::function& l2e, expr_ref_vector& fmls) { - for (auto* th : m_decompilers) { + for (auto* th : m_solvers) { if (!th->to_formulas(l2e, fmls)) return false; } @@ -479,9 +457,7 @@ namespace euf { bool solver::extract_pb(std::function& card, std::function& pb) { - if (m_true) - return false; - for (auto* e : m_extensions) + for (auto* e : m_solvers) if (!e->extract_pb(card, pb)) return false; return true; diff --git a/src/sat/smt/euf_solver.h b/src/sat/smt/euf_solver.h index 6c18b87e2..26c41abeb 100644 --- a/src/sat/smt/euf_solver.h +++ b/src/sat/smt/euf_solver.h @@ -33,14 +33,15 @@ namespace euf { typedef sat::literal_vector literal_vector; typedef sat::bool_var bool_var; - class constraint : public sat::index_base { + class constraint { unsigned m_id; public: - constraint(sat::extension* e, unsigned id) : - index_base(e), m_id(id) + constraint(unsigned id) : + m_id(id) {} unsigned id() const { return m_id; } static constraint* from_idx(size_t z) { return reinterpret_cast(z); } + size_t to_index() const { return sat::constraint_base::mem2base(this); } }; class solver : public sat::extension, public sat::th_internalizer, public sat::th_decompile { @@ -60,13 +61,11 @@ namespace euf { stats m_stats; sat::solver* m_solver { nullptr }; sat::lookahead* m_lookahead { nullptr }; - ast_translation* m_translate { nullptr }; - atom2bool_var* m_translate_expr2var { nullptr }; - sat::sat_internalizer* m_translate_si{ nullptr }; - scoped_ptr m_ackerman; + ast_manager* m_to_m { nullptr }; + atom2bool_var* m_to_expr2var { nullptr }; + sat::sat_internalizer* m_to_si{ nullptr }; + scoped_ptr m_ackerman; - euf::enode* m_true { nullptr }; - euf::enode* m_false { nullptr }; svector m_var2node; ptr_vector m_explain; euf::enode_vector m_args; @@ -74,35 +73,33 @@ namespace euf { unsigned m_num_scopes { 0 }; unsigned_vector m_bool_var_trail; unsigned_vector m_bool_var_lim; - scoped_ptr_vector m_extensions; - ptr_vector m_id2extension; - ptr_vector m_id2internalize; - scoped_ptr_vector m_internalizers; - scoped_ptr_vector m_model_builders; - ptr_vector m_id2model_builder; - scoped_ptr_vector m_decompilers; - constraint m_conflict_idx, m_eq_idx, m_lit_idx; + scoped_ptr_vector m_solvers; + ptr_vector m_id2solver; + + constraint* m_conflict { nullptr }; + constraint* m_eq { nullptr }; + constraint* m_lit { nullptr }; sat::solver& s() { return *m_solver; } unsigned * base_ptr() { return reinterpret_cast(this); } // internalization - sat::th_internalizer* get_internalizer(expr* e); euf::enode* visit(expr* e); void attach_bool_var(euf::enode* n); void attach_bool_var(sat::bool_var v, bool sign, euf::enode* n); - solver* copy_core(); + euf::enode* mk_true(); + euf::enode* mk_false(); // extensions - sat::extension* get_extension(sat::bool_var v); - sat::extension* get_extension(expr* e); - void add_extension(family_id fid, sat::extension* e); + sat::th_solver* get_solver(func_decl* f) { return fid2solver(f->get_family_id()); } + sat::th_solver* get_solver(expr* e); + sat::th_solver* get_solver(sat::bool_var v); + sat::th_solver* fid2solver(family_id fid); + void add_solver(family_id fid, sat::th_solver* th); void init_ackerman(); // model building - bool include_func_interp(func_decl* f) const; - sat::th_model_builder* get_model_builder(expr* e) const; - sat::th_model_builder* get_model_builder(func_decl* f) const; + bool include_func_interp(func_decl* f); void register_macros(model& mdl); void dependencies2values(deps_t& deps, expr_ref_vector& values, model_ref const& mdl); void collect_dependencies(deps_t& deps); @@ -112,6 +109,11 @@ namespace euf { void propagate(); void get_antecedents(literal l, constraint& j, literal_vector& r); + constraint& mk_constraint(constraint*& c, unsigned id); + constraint& conflict_constraint() { return mk_constraint(m_conflict, 0); } + constraint& eq_constraint() { return mk_constraint(m_eq, 1); } + constraint& lit_constraint() { return mk_constraint(m_lit, 2); } + public: solver(ast_manager& m, atom2bool_var& expr2var, sat::sat_internalizer& si, params_ref const& p = params_ref()): m(m), @@ -120,31 +122,31 @@ namespace euf { m_egraph(m), m_solver(nullptr), m_lookahead(nullptr), - m_translate(nullptr), - m_translate_expr2var(nullptr), - m_true(nullptr), - m_false(nullptr), - m_conflict_idx(this, 0), - m_eq_idx(this, 1), - m_lit_idx(this, 2) + m_to_m(&m), + m_to_expr2var(&expr2var), + m_to_si(&si) { updt_params(p); } - ~solver() override {} + ~solver() override { + if (m_conflict) dealloc(sat::constraint_base::mem2base_ptr(m_conflict)); + if (m_eq) dealloc(sat::constraint_base::mem2base_ptr(m_eq)); + if (m_lit) dealloc(sat::constraint_base::mem2base_ptr(m_lit)); + } void updt_params(params_ref const& p); void set_solver(sat::solver* s) override { m_solver = s; } void set_lookahead(sat::lookahead* s) override { m_lookahead = s; } struct scoped_set_translate { solver& s; - scoped_set_translate(solver& s, ast_translation& t, atom2bool_var& a2b, sat::sat_internalizer& si) : + scoped_set_translate(solver& s, ast_manager& m, atom2bool_var& a2b, sat::sat_internalizer& si) : s(s) { - s.m_translate = &t; - s.m_translate_expr2var = &a2b; - s.m_translate_si = &si; + s.m_to_m = &m; + s.m_to_expr2var = &a2b; + s.m_to_si = &si; } - ~scoped_set_translate() { s.m_translate = nullptr; s.m_translate_expr2var = nullptr; s.m_translate_si = nullptr; } + ~scoped_set_translate() { s.m_to_m = &s.m; s.m_to_expr2var = &s.m_expr2var; s.m_to_si = &s.si; } }; double get_reward(literal l, ext_constraint_idx idx, sat::literal_occs_fun& occs) const override { return 0; } bool is_extended_binary(ext_justification_idx idx, literal_vector & r) override { return false; } @@ -165,7 +167,6 @@ namespace euf { std::ostream& display_constraint(std::ostream& out, ext_constraint_idx idx) const override; void collect_statistics(statistics& st) const override; extension* copy(sat::solver* s) override; - extension* copy(sat::lookahead* s, bool learned) override; void find_mutexes(literal_vector& lits, vector & mutexes) override; void gc() override; void pop_reinit() override; diff --git a/src/sat/smt/sat_smt.h b/src/sat/smt/sat_smt.h index 4c5bdd246..ec718bf93 100644 --- a/src/sat/smt/sat_smt.h +++ b/src/sat/smt/sat_smt.h @@ -14,9 +14,6 @@ Author: Nikolaj Bjorner (nbjorner) 2020-08-25 --*/ -#pragma once - - #pragma once #include "ast/ast.h" #include "ast/ast_pp.h" @@ -41,19 +38,59 @@ namespace sat { public: virtual ~sat_internalizer() {} virtual bool is_bool_op(expr* e) const = 0; - virtual sat::literal internalize(expr* e) = 0; - virtual sat::bool_var add_bool_var(expr* e) = 0; + virtual literal internalize(expr* e) = 0; + virtual bool_var add_bool_var(expr* e) = 0; virtual void mk_clause(literal a, literal b) = 0; virtual void mk_clause(literal l1, literal l2, literal l3, bool is_lemma = false) = 0; virtual void cache(app* t, literal l) = 0; }; - class index_base { - extension* ex; + class constraint_base { + extension* m_ex; + unsigned m_mem[0]; + static size_t ext_size() { + return sizeof(((constraint_base*)nullptr)->m_ex); + } + public: - index_base(extension* e) : ex(e) { to_index(); } - static extension* to_extension(size_t s) { std::cout << "to_extension: " << from_index(s) << " " << from_index(s)->ex << " " << s << "\n"; return from_index(s)->ex; } - static index_base* from_index(size_t s) { return reinterpret_cast(s); } - size_t to_index() const { std::cout << "to_index " << this << " " << ex << " " << reinterpret_cast(this) << "\n"; return reinterpret_cast(this); } + constraint_base(): m_ex(nullptr) {} + void* mem() { return m_mem; } + + static size_t obj_size(size_t sz) { + return ext_size() + sz; + } + + static extension* to_extension(size_t s) { + return from_index(s)->m_ex; + } + + static constraint_base* from_index(size_t s) { + return reinterpret_cast(s); + } + + size_t to_index() const { + return reinterpret_cast(this); + } + + static constraint_base const* mem2base_ptr(void const* mem) { + return reinterpret_cast((unsigned char const*)(mem) - ext_size()); + } + + static constraint_base* mem2base_ptr(void* mem) { + return reinterpret_cast((unsigned char*)(mem) - ext_size()); + } + + static size_t mem2base(void const* mem) { + return reinterpret_cast(mem2base_ptr(mem)); + } + + static void initialize(void* ptr, extension* ext) { + reinterpret_cast(ptr)->m_ex = ext; + } + + static void* ptr2mem(void* ptr) { + return reinterpret_cast(((unsigned char*) ptr) + ext_size()); + } + }; } diff --git a/src/sat/smt/sat_th.h b/src/sat/smt/sat_th.h index 659e6aeca..220111db5 100644 --- a/src/sat/smt/sat_th.h +++ b/src/sat/smt/sat_th.h @@ -47,12 +47,12 @@ namespace sat { \brief compute the value for enode \c n and store the value in \c values for the root of the class of \c n. */ - virtual void add_value(euf::enode* n, expr_ref_vector& values) = 0; + virtual void add_value(euf::enode* n, expr_ref_vector& values) {} /** \brief compute dependencies for node n */ - virtual void add_dep(euf::enode* n, top_sort& dep) = 0; + virtual void add_dep(euf::enode* n, top_sort& dep) {} /** \brief should function be included in model. @@ -61,24 +61,10 @@ namespace sat { }; class th_solver : public extension, public th_model_builder, public th_decompile, public th_internalizer { - + public: virtual ~th_solver() {} - /** - \brief compute the value for enode \c n and store the value in \c values - for the root of the class of \c n. - */ - virtual void add_value(euf::enode* n, expr_ref_vector& values) = 0; - - /** - \brief compute dependencies for node n - */ - virtual void add_dep(euf::enode* n, top_sort& dep) = 0; - - /** - \brief should function be included in model. - */ - virtual bool include_func_interp(func_decl* f) const { return false; } + virtual th_solver* fresh(solver* s, ast_manager& m, sat_internalizer& si) = 0; }; diff --git a/src/sat/smt/xor_solver.cpp b/src/sat/smt/xor_solver.cpp index 19ac000bc..51b784ff4 100644 --- a/src/sat/smt/xor_solver.cpp +++ b/src/sat/smt/xor_solver.cpp @@ -33,8 +33,8 @@ namespace sat { return static_cast(*this); } - ba_solver::xr::xr(extension* e, unsigned id, literal_vector const& lits): - constraint(e, xr_t, id, null_literal, lits.size(), get_obj_size(lits.size())) { + ba_solver::xr::xr(unsigned id, literal_vector const& lits): + constraint(xr_t, id, null_literal, lits.size(), get_obj_size(lits.size())) { for (unsigned i = 0; i < size(); ++i) { m_lits[i] = lits[i]; } @@ -264,7 +264,8 @@ namespace sat { break; } void * mem = m_allocator.allocate(xr::get_obj_size(lits.size())); - xr* x = new (mem) xr(this, next_id(), lits); + constraint_base::initialize(mem, this); + xr* x = new (constraint_base::ptr2mem(mem)) xr(next_id(), lits); x->set_learned(learned); add_constraint(x); return x; diff --git a/src/sat/tactic/goal2sat.cpp b/src/sat/tactic/goal2sat.cpp index 2d4bbbead..b9effe74e 100644 --- a/src/sat/tactic/goal2sat.cpp +++ b/src/sat/tactic/goal2sat.cpp @@ -35,7 +35,6 @@ Notes: #include "ast/for_each_expr.h" #include "sat/tactic/goal2sat.h" #include "sat/sat_cut_simplifier.h" -#include "sat/smt/ba_internalize.h" #include "sat/smt/ba_solver.h" #include "sat/smt/euf_solver.h" #include "model/model_evaluator.h" @@ -95,7 +94,6 @@ struct goal2sat::imp : public sat::sat_internalizer { m_max_memory = megabytes_to_bytes(p.get_uint("max_memory", UINT_MAX)); m_xor_solver = p.get_bool("xor_solver", false); m_euf = false; - m_euf = true; } void throw_op_not_handled(std::string const& s) { @@ -220,6 +218,7 @@ struct goal2sat::imp : public sat::sat_internalizer { } bool visit(expr * t, bool root, bool sign) { + SASSERT(m.is_bool(t)); if (!is_app(t)) { convert_atom(t, root, sign); return true; @@ -232,10 +231,12 @@ struct goal2sat::imp : public sat::sat_internalizer { case OP_NOT: case OP_OR: case OP_AND: + case OP_ITE: + case OP_XOR: + case OP_IMPLIES: m_frame_stack.push_back(frame(to_app(t), root, sign, 0)); return false; - case OP_ITE: - case OP_EQ: + case OP_EQ: if (m.is_bool(to_app(t)->get_arg(1))) { m_frame_stack.push_back(frame(to_app(t), root, sign, 0)); return false; @@ -244,8 +245,6 @@ struct goal2sat::imp : public sat::sat_internalizer { convert_atom(t, root, sign); return true; } - case OP_XOR: - case OP_IMPLIES: case OP_DISTINCT: { TRACE("goal2sat_not_handled", tout << mk_ismt2_pp(t, m) << "\n";); std::ostringstream strm; @@ -397,7 +396,40 @@ struct goal2sat::imp : public sat::sat_internalizer { } } + void convert_implies(app* t, bool root, bool sign) { + SASSERT(t->get_num_args() == 2); + unsigned sz = m_result_stack.size(); + SASSERT(sz >= 2); + sat::literal l1 = m_result_stack[sz - 1]; + sat::literal l2 = m_result_stack[sz - 2]; + if (root) { + SASSERT(sz == 2); + if (sign) { + mk_clause(l1); + mk_clause(~l2); + } + else { + mk_clause(~l1, l2); + } + m_result_stack.reset(); + } + else { + sat::bool_var k = m_solver.add_var(false); + sat::literal l(k, false); + m_cache.insert(t, l); + // l <=> (l1 => l2) + mk_clause(~l, ~l1, l2); + mk_clause(l1, l); + mk_clause(~l2, l); + if (sign) + l.neg(); + m_result_stack.shrink(sz - 2); + m_result_stack.push_back(l); + } + } + void convert_iff2(app * t, bool root, bool sign) { + SASSERT(t->get_num_args() == 2); TRACE("goal2sat", tout << "convert_iff " << root << " " << sign << "\n" << mk_bounded_pp(t, m, 2) << "\n";); unsigned sz = m_result_stack.size(); SASSERT(sz >= 2); @@ -467,11 +499,10 @@ struct goal2sat::imp : public sat::sat_internalizer { void convert_ba(app* t, bool root, bool sign) { SASSERT(!m_euf); - std::cout << "convert ba\n"; sat::extension* ext = m_solver.get_extension(); sat::ba_solver* ba = nullptr; if (!ext) { - ba = alloc(sat::ba_solver); + ba = alloc(sat::ba_solver, m, *this); m_solver.set_extension(ba); ba->push_scopes(m_solver.num_scopes()); } @@ -480,8 +511,7 @@ struct goal2sat::imp : public sat::sat_internalizer { } if (!ba) throw default_exception("cannot convert to pb"); - sat::ba_internalize internalize(*ba, m_solver, *this, m); - sat::literal lit = internalize.internalize(t, sign, root); + sat::literal lit = ba->internalize(t, sign, root); if (root) m_result_stack.reset(); else @@ -509,6 +539,12 @@ struct goal2sat::imp : public sat::sat_internalizer { case OP_EQ: convert_iff(t, root, sign); break; + case OP_XOR: + convert_iff(t, root, !sign); + break; + case OP_IMPLIES: + convert_implies(t, root, sign); + break; default: UNREACHABLE(); } @@ -614,6 +650,8 @@ struct goal2sat::imp : public sat::sat_internalizer { case OP_TRUE: case OP_FALSE: case OP_NOT: + case OP_IMPLIES: + case OP_XOR: return true; case OP_ITE: case OP_EQ: @@ -657,6 +695,15 @@ struct goal2sat::imp : public sat::sat_internalizer { } void operator()(goal const & g) { + struct scoped_reset { + imp& i; + scoped_reset(imp& i) :i(i) {} + ~scoped_reset() { + i.m_interface_vars.reset(); + i.m_cache.reset(); + } + }; + scoped_reset _reset(*this); collect_boolean_interface(g, m_interface_vars); unsigned size = g.size(); expr_ref f(m), d_new(m); @@ -696,16 +743,6 @@ struct goal2sat::imp : public sat::sat_internalizer { } } -#if 0 - void operator()(unsigned sz, expr * const * fs) { - m_interface_vars.reset(); - collect_boolean_interface(m, sz, fs, m_interface_vars); - - for (unsigned i = 0; i < sz; i++) - process(fs[i]); - } -#endif - }; struct unsupported_bool_proc { @@ -717,8 +754,6 @@ struct unsupported_bool_proc { void operator()(app * n) { if (n->get_family_id() == m.get_basic_family_id()) { switch (n->get_decl_kind()) { - case OP_XOR: - case OP_IMPLIES: case OP_DISTINCT: throw found(); default: @@ -758,19 +793,8 @@ void goal2sat::operator()(goal const & g, params_ref const & p, sat::solver_core if (!m_imp) m_imp = alloc(imp, g.m(), p, t, m, dep2asm, default_external); - struct scoped_reset { - goal2sat& g; - scoped_reset(goal2sat& g):g(g) {} - ~scoped_reset() { - g.m_imp->m_interface_vars.reset(); - g.m_imp->m_cache.reset(); - } - }; - { - scoped_reset _reset(*this); - (*m_imp)(g); - } - + (*m_imp)(g); + m_interpreted_atoms = alloc(expr_ref_vector, g.m()); m_interpreted_atoms->append(m_imp->m_interpreted_atoms); if (!t.get_extension()) { @@ -1021,8 +1045,7 @@ struct sat2goal::imp { expr_ref_vector fmls(m); sat::ba_solver* ba = dynamic_cast(ext); if (ba) { - sat::ba_decompile decompile(*ba, s, m); - decompile.to_formulas(l2e, fmls); + ba->to_formulas(l2e, fmls); } else dynamic_cast(ext)->to_formulas(l2e, fmls); diff --git a/src/smt/seq_regex.cpp b/src/smt/seq_regex.cpp index 4852347bc..784a7efd3 100644 --- a/src/smt/seq_regex.cpp +++ b/src/smt/seq_regex.cpp @@ -490,6 +490,7 @@ namespace smt { expr_ref is_nullable = is_nullable_wrapper(r); if (m.is_true(is_nullable)) return; + literal null_lit = th.mk_literal(is_nullable); expr_ref hd = mk_first(r, n); expr_ref d(m);