From c21a2fcf9f98f5b7d1c1a5db034714880b7bfa34 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Wed, 26 Aug 2020 09:40:31 -0700 Subject: [PATCH] sat solver setup Signed-off-by: Nikolaj Bjorner --- scripts/mk_project.py | 8 +- src/CMakeLists.txt | 4 +- src/ast/euf/euf_egraph.cpp | 63 ++- src/ast/euf/euf_egraph.h | 13 + src/ast/euf/euf_justification.h | 14 + src/ast/pb_decl_plugin.h | 1 + src/sat/CMakeLists.txt | 1 - src/sat/ba/CMakeLists.txt | 8 + src/sat/ba/ba_internalize.cpp | 281 ++++++++++++++ src/sat/ba/ba_internalize.h | 52 +++ src/sat/{ => ba}/ba_solver.cpp | 20 +- src/sat/{ => ba}/ba_solver.h | 15 +- src/sat/euf/CMakeLists.txt | 2 +- src/sat/euf/euf_solver.cpp | 318 +++++++++++++-- src/sat/euf/euf_solver.h | 70 +++- src/sat/sat_extension.h | 4 +- src/sat/sat_local_search.cpp | 7 +- src/sat/sat_params.pyg | 1 + src/sat/sat_solver.h | 1 + src/sat/sat_solver_core.h | 2 + src/sat/smt/CMakeLists.txt | 8 + src/sat/{tactic => smt}/atom2bool_var.cpp | 2 +- src/sat/{tactic => smt}/atom2bool_var.h | 0 src/sat/smt/sat_smt.h | 62 +++ src/sat/tactic/CMakeLists.txt | 5 +- src/sat/tactic/goal2sat.cpp | 449 ++++++---------------- src/sat/tactic/goal2sat.h | 12 +- src/shell/dimacs_frontend.cpp | 2 +- 28 files changed, 984 insertions(+), 441 deletions(-) create mode 100644 src/sat/ba/CMakeLists.txt create mode 100644 src/sat/ba/ba_internalize.cpp create mode 100644 src/sat/ba/ba_internalize.h rename src/sat/{ => ba}/ba_solver.cpp (99%) rename src/sat/{ => ba}/ba_solver.h (97%) create mode 100644 src/sat/smt/CMakeLists.txt rename src/sat/{tactic => smt}/atom2bool_var.cpp (99%) rename src/sat/{tactic => smt}/atom2bool_var.h (100%) create mode 100644 src/sat/smt/sat_smt.h diff --git a/scripts/mk_project.py b/scripts/mk_project.py index 59faf0f06..979566614 100644 --- a/scripts/mk_project.py +++ b/scripts/mk_project.py @@ -37,9 +37,11 @@ def init_project_def(): add_lib('parser_util', ['ast'], 'parsers/util') add_lib('proofs', ['rewriter', 'util'], 'ast/proofs') add_lib('solver', ['model', 'tactic', 'proofs']) - add_lib('cmd_context', ['solver', 'rewriter']) - add_lib('sat_tactic', ['tactic', 'sat', 'solver'], 'sat/tactic') - add_lib('sat_euf', ['sat_tactic', 'sat', 'euf'], 'sat/euf') + add_lib('cmd_context', ['solver', 'rewriter']) + add_lib('sat_smt', ['sat', 'tactic'], 'sat/smt') + add_lib('sat_ba', ['sat', 'sat_smt'], 'sat/ba') + add_lib('sat_euf', ['sat', 'euf', 'sat_ba'], 'sat/euf') + add_lib('sat_tactic', ['tactic', 'sat', 'solver', 'sat_euf'], 'sat/tactic') add_lib('smt2parser', ['cmd_context', 'parser_util'], 'parsers/smt2') add_lib('pattern', ['normal_forms', 'smt2parser', 'rewriter'], 'ast/pattern') add_lib('core_tactics', ['tactic', 'macros', 'normal_forms', 'rewriter', 'pattern'], 'tactic/core') diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index e18e38f69..df9d6acaf 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -59,8 +59,10 @@ add_subdirectory(tactic/core) add_subdirectory(math/subpaving/tactic) add_subdirectory(tactic/aig) add_subdirectory(solver) -add_subdirectory(sat/tactic) +add_subdirectory(sat/smt) +add_subdirectory(sat/ba) add_subdirectory(sat/euf) +add_subdirectory(sat/tactic) add_subdirectory(tactic/arith) add_subdirectory(nlsat/tactic) add_subdirectory(ackermannization) diff --git a/src/ast/euf/euf_egraph.cpp b/src/ast/euf/euf_egraph.cpp index c68f743a1..5f073a877 100644 --- a/src/ast/euf/euf_egraph.cpp +++ b/src/ast/euf/euf_egraph.cpp @@ -17,6 +17,7 @@ Author: #include "ast/euf/euf_egraph.h" #include "ast/ast_pp.h" +#include "ast/ast_translation.h" namespace euf { @@ -60,8 +61,10 @@ namespace euf { void egraph::reinsert_equality(enode* p) { SASSERT(is_equality(p)); - if (p->get_arg(0)->get_root() == p->get_arg(1)->get_root()) + if (p->get_arg(0)->get_root() == p->get_arg(1)->get_root()) { m_new_eqs.push_back(p); + ++m_stats.m_num_eqs; + } } bool egraph::is_equality(enode* p) const { @@ -162,6 +165,7 @@ namespace euf { enode* r2 = n2->get_root(); if (r1 == r2) return; + ++m_stats.m_num_merge; if (r1->interpreted() && r2->interpreted()) { set_conflict(n1, n2, j); return; @@ -170,8 +174,10 @@ namespace euf { std::swap(r1, r2); std::swap(n1, n2); } - if ((m.is_true(r2->get_owner()) || m.is_false(r2->get_owner())) && j.is_congruence()) + if ((m.is_true(r2->get_owner()) || m.is_false(r2->get_owner())) && j.is_congruence()) { m_new_lits.push_back(n1); + ++m_stats.m_num_lits; + } for (enode* p : enode_parents(n1)) m_table.erase(p); for (enode* p : enode_parents(n2)) @@ -211,6 +217,7 @@ namespace euf { } void egraph::set_conflict(enode* n1, enode* n2, justification j) { + ++m_stats.m_num_conflicts; if (m_inconsistent) return; m_inconsistent = true; @@ -305,7 +312,8 @@ namespace euf { template void egraph::explain_eq(ptr_vector& justifications, enode* a, enode* b, bool comm) { SASSERT(m_todo.empty()); - push_congruence(a, b, comm); + SASSERT(a->get_root() == b->get_root()); + push_lca(a, b); explain_todo(justifications); } @@ -330,6 +338,10 @@ namespace euf { std::ostream& egraph::display(std::ostream& out) const { m_table.display(out); + unsigned max_args = 0; + for (enode* n : m_nodes) + max_args = std::max(max_args, n->num_args()); + for (enode* n : m_nodes) { out << std::setw(5) << n->get_owner_id() << " := "; @@ -342,14 +354,55 @@ namespace euf { else out << "v "; for (enode* arg : enode_args(n)) - out << arg->get_owner_id() << " "; - out << std::setw(20) << " parents: "; + out << arg->get_owner_id() << " "; + for (unsigned i = n->num_args(); i < max_args; ++i) + out << " "; + out << "\t"; for (enode* p : enode_parents(n)) out << p->get_owner_id() << " "; out << "\n"; } return out; } + + void egraph::collect_statistics(statistics& st) const { + st.update("euf merge", m_stats.m_num_merge); + st.update("euf conflicts", m_stats.m_num_conflicts); + st.update("euf eq prop", m_stats.m_num_eqs); + st.update("euf lit prop", m_stats.m_num_lits); + } + + void egraph::copy_from(egraph const& src, std::function& copy_justification) { + SASSERT(m_scopes.empty()); + SASSERT(src.m_scopes.empty()); + SASSERT(m_nodes.empty()); + ptr_vector old_expr2new_enode, args; + ast_translation tr(src.m, m); + for (unsigned i = 0; i < src.m_nodes.size(); ++i) { + enode* n1 = src.m_nodes[i]; + expr* e1 = src.m_exprs[i]; + args.reset(); + for (unsigned j = 0; j < n1->num_args(); ++j) { + args.push_back(old_expr2new_enode[n1->get_arg(j)->get_owner_id()]); + } + expr* e2 = tr(e1); + enode* n2 = mk(e2, args.size(), args.c_ptr()); + m_exprs.push_back(e2); + m_nodes.push_back(n2); + old_expr2new_enode.setx(e1->get_id(), n2, nullptr); + } + for (unsigned i = 0; i < src.m_nodes.size(); ++i) { + enode* n1 = src.m_nodes[i]; + enode* n1t = n1->m_target; + enode* n2 = m_nodes[i]; + enode* n2t = n1t ? old_expr2new_enode[n1t->get_owner_id()] : nullptr; + SASSERT(!n1t || n2t); + if (n1t && n2->get_root() != n2t->get_root()) { + merge(n2, n2t, n1->m_justification.copy(copy_justification)); + } + } + propagate(); + } } template void euf::egraph::explain(ptr_vector& justifications); diff --git a/src/ast/euf/euf_egraph.h b/src/ast/euf/euf_egraph.h index 6844a2dc5..e52f44da5 100644 --- a/src/ast/euf/euf_egraph.h +++ b/src/ast/euf/euf_egraph.h @@ -24,6 +24,7 @@ Notes: --*/ #pragma once +#include "util/statistics.h" #include "ast/euf/euf_enode.h" #include "ast/euf/euf_etable.h" @@ -43,6 +44,14 @@ namespace euf { unsigned m_num_eqs; unsigned m_num_nodes; }; + struct stats { + unsigned m_num_merge; + unsigned m_num_lits; + unsigned m_num_eqs; + unsigned m_num_conflicts; + stats() { reset(); } + void reset() { memset(this, 0, sizeof(*this)); } + }; ast_manager& m; region m_region; enode_vector m_worklist; @@ -60,6 +69,8 @@ namespace euf { enode_vector m_new_eqs; enode_vector m_new_lits; enode_vector m_todo; + stats m_stats; + void push_eq(enode* r1, enode* n1, unsigned r2_num_parents) { m_eqs.push_back(add_eq_record(r1, n1, r2_num_parents)); @@ -120,7 +131,9 @@ namespace euf { template void explain_eq(ptr_vector& justifications, enode* a, enode* b, bool comm); void invariant(); + void copy_from(egraph const& src, std::function& copy_justification); std::ostream& display(std::ostream& out) const; + void collect_statistics(statistics& st) const; }; inline std::ostream& operator<<(std::ostream& out, egraph const& g) { return g.display(out); } diff --git a/src/ast/euf/euf_justification.h b/src/ast/euf/euf_justification.h index 68d7d214f..20e733de0 100644 --- a/src/ast/euf/euf_justification.h +++ b/src/ast/euf/euf_justification.h @@ -56,5 +56,19 @@ namespace euf { bool is_commutative() const { return m_comm; } template T* ext() const { SASSERT(is_external()); return static_cast(m_external); } + + justification copy(std::function& copy_justification) const { + switch (m_kind) { + case external_t: + return external(copy_justification(m_external)); + case axiom_t: + return axiom(); + case congruence_t: + return congruence(m_comm); + default: + UNREACHABLE(); + return axiom(); + } + } }; } diff --git a/src/ast/pb_decl_plugin.h b/src/ast/pb_decl_plugin.h index e804cbb6b..dafce43f2 100644 --- a/src/ast/pb_decl_plugin.h +++ b/src/ast/pb_decl_plugin.h @@ -94,6 +94,7 @@ public: app * mk_ge(unsigned num_args, rational const * coeffs, expr * const * args, rational const& k); app * mk_eq(unsigned num_args, rational const * coeffs, expr * const * args, rational const& k); app * mk_lt(unsigned num_args, rational const * coeffs, expr * const * args, rational const& k); + bool is_pb(expr* t) const { return is_app(t) && to_app(t)->get_family_id() == get_family_id(); } bool is_at_most_k(func_decl *a) const; bool is_at_most_k(expr *a) const { return is_app(a) && is_at_most_k(to_app(a)->get_decl()); } bool is_at_most_k(expr *a, rational& k) const; diff --git a/src/sat/CMakeLists.txt b/src/sat/CMakeLists.txt index 6e79bbab8..f7dc7485a 100644 --- a/src/sat/CMakeLists.txt +++ b/src/sat/CMakeLists.txt @@ -1,6 +1,5 @@ z3_add_component(sat SOURCES - ba_solver.cpp dimacs.cpp sat_aig_cuts.cpp sat_aig_finder.cpp diff --git a/src/sat/ba/CMakeLists.txt b/src/sat/ba/CMakeLists.txt new file mode 100644 index 000000000..fc94b42f8 --- /dev/null +++ b/src/sat/ba/CMakeLists.txt @@ -0,0 +1,8 @@ +z3_add_component(sat_ba + SOURCES + ba_solver.cpp + ba_internalize.cpp + COMPONENT_DEPENDENCIES + sat +) + diff --git a/src/sat/ba/ba_internalize.cpp b/src/sat/ba/ba_internalize.cpp new file mode 100644 index 000000000..27606ff55 --- /dev/null +++ b/src/sat/ba/ba_internalize.cpp @@ -0,0 +1,281 @@ +/*++ +Copyright (c) 2020 Microsoft Corporation + +Module Name: + + ba_internalize.h + +Abstract: + + INternalize methods for Boolean algebra operators. + +Author: + + Nikolaj Bjorner (nbjorner) 2020-08-25 + +--*/ + + +#include "sat/ba/ba_internalize.h" + +namespace sat { + + literal ba_internalize::internalize(sat_internalizer& si, expr* e, bool sign, bool root) { + m_si = &si; + if (pb.is_pb(e)) + return internalize_pb(e, sign, root); + if (m.is_xor(e)) + return internalize_xor(e, sign, root); + UNREACHABLE(); + return null_literal; + } + + literal ba_internalize::internalize_xor(expr* e, bool sign, bool root) { + sat::literal_vector lits; + sat::bool_var v = m_solver.add_var(true); + lits.push_back(literal(v, true)); + auto add_expr = [&](expr* a) { + literal lit = m_si->internalize(a); + m_solver.set_external(lit.var()); + lits.push_back(lit); + }; + expr* e1 = nullptr; + while (m.is_iff(e, e1, e)) + add_expr(e1); + add_expr(e); + // ensure that = is converted to xor + for (unsigned i = 1; i + 1 < lits.size(); ++i) { + lits[i].neg(); + } + ba.add_xr(lits); + auto* aig = m_solver.get_cut_simplifier(); + if (aig) aig->add_xor(~lits.back(), lits.size() - 1, lits.c_ptr() + 1); + sat::literal lit(v, sign); + return literal(v, sign); + } + + literal ba_internalize::internalize_pb(expr* e, bool sign, bool root) { + SASSERT(pb.is_pb(e)); + app* t = to_app(e); + rational k = pb.get_k(t); + switch (t->get_decl_kind()) { + case OP_AT_MOST_K: + return convert_at_most_k(t, k, root, sign); + case OP_AT_LEAST_K: + return convert_at_least_k(t, k, root, sign); + case OP_PB_LE: + if (pb.has_unit_coefficients(t)) + return convert_at_most_k(t, k, root, sign); + else + return convert_pb_le(t, root, sign); + case OP_PB_GE: + if (pb.has_unit_coefficients(t)) + return convert_at_least_k(t, k, root, sign); + else + return convert_pb_ge(t, root, sign); + case OP_PB_EQ: + if (pb.has_unit_coefficients(t)) + return convert_eq_k(t, k, root, sign); + else + return convert_pb_eq(t, root, sign); + default: + UNREACHABLE(); + } + return null_literal; + } + + void ba_internalize::check_unsigned(rational const& c) { + if (!c.is_unsigned()) { + throw default_exception("unsigned coefficient expected"); + } + } + + void ba_internalize::convert_to_wlits(app* t, sat::literal_vector const& lits, svector& wlits) { + for (unsigned i = 0; i < lits.size(); ++i) { + rational c = pb.get_coeff(t, i); + check_unsigned(c); + wlits.push_back(std::make_pair(c.get_unsigned(), lits[i])); + } + } + + void ba_internalize::convert_pb_args(app* t, literal_vector& lits) { + for (expr* arg : *t) { + lits.push_back(m_si->internalize(arg)); + m_solver.set_external(lits.back().var()); + } + } + + void ba_internalize::convert_pb_args(app* t, svector& wlits) { + sat::literal_vector lits; + convert_pb_args(t, lits); + convert_to_wlits(t, lits, wlits); + } + + literal ba_internalize::convert_pb_le(app* t, bool root, bool sign) { + rational k = pb.get_k(t); + k.neg(); + svector wlits; + convert_pb_args(t, wlits); + for (wliteral& wl : wlits) { + wl.second.neg(); + k += rational(wl.first); + } + check_unsigned(k); + if (root && m_solver.num_user_scopes() == 0) { + unsigned k1 = k.get_unsigned(); + if (sign) { + k1 = 1 - k1; + for (wliteral& wl : wlits) { + wl.second.neg(); + k1 += wl.first; + } + } + ba.add_pb_ge(null_bool_var, wlits, k1); + return null_literal; + } + else { + bool_var v = m_solver.add_var(true); + literal lit(v, sign); + ba.add_pb_ge(v, wlits, k.get_unsigned()); + TRACE("ba", tout << "root: " << root << " lit: " << lit << "\n";); + return lit; + } + } + + + literal ba_internalize::convert_pb_ge(app* t, bool root, bool sign) { + rational k = pb.get_k(t); + check_unsigned(k); + svector wlits; + convert_pb_args(t, wlits); + if (root && m_solver.num_user_scopes() == 0) { + unsigned k1 = k.get_unsigned(); + if (sign) { + k1 = 1 - k1; + for (wliteral& wl : wlits) { + wl.second.neg(); + k1 += wl.first; + } + } + ba.add_pb_ge(sat::null_bool_var, wlits, k1); + return null_literal; + } + else { + sat::bool_var v = m_solver.add_var(true); + sat::literal lit(v, sign); + ba.add_pb_ge(v, wlits, k.get_unsigned()); + TRACE("goal2sat", tout << "root: " << root << " lit: " << lit << "\n";); + return lit; + } + } + + literal ba_internalize::convert_pb_eq(app* t, bool root, bool sign) { + rational k = pb.get_k(t); + SASSERT(k.is_unsigned()); + svector wlits; + convert_pb_args(t, wlits); + bool base_assert = (root && !sign && m_solver.num_user_scopes() == 0); + bool_var v1 = base_assert ? null_bool_var : m_solver.add_var(true); + bool_var v2 = base_assert ? null_bool_var : m_solver.add_var(true); + ba.add_pb_ge(v1, wlits, k.get_unsigned()); + k.neg(); + for (wliteral& wl : wlits) { + wl.second.neg(); + k += rational(wl.first); + } + check_unsigned(k); + ba.add_pb_ge(v2, wlits, k.get_unsigned()); + if (base_assert) { + return null_literal; + } + else { + literal l1(v1, false), l2(v2, false); + bool_var v = m_solver.add_var(false); + literal l(v, false); + m_si->mk_clause(~l, l1); + m_si->mk_clause(~l, l2); + m_si->mk_clause(~l1, ~l2, l); + m_si->cache(t, l); + if (sign) l.neg(); + return l; + } + } + + literal ba_internalize::convert_at_least_k(app* t, rational const& k, bool root, bool sign) { + SASSERT(k.is_unsigned()); + literal_vector lits; + convert_pb_args(t, lits); + unsigned k2 = k.get_unsigned(); + if (root && m_solver.num_user_scopes() == 0) { + if (sign) { + for (literal& l : lits) l.neg(); + k2 = lits.size() + 1 - k2; + } + ba.add_at_least(null_bool_var, lits, k2); + return null_literal; + } + else { + bool_var v = m_solver.add_var(true); + literal lit(v, false); + ba.add_at_least(v, lits, k.get_unsigned()); + m_si->cache(t, lit); + if (sign) lit.neg(); + TRACE("ba", tout << "root: " << root << " lit: " << lit << "\n";); + return lit; + } + } + + literal ba_internalize::convert_at_most_k(app* t, rational const& k, bool root, bool sign) { + SASSERT(k.is_unsigned()); + literal_vector lits; + convert_pb_args(t, lits); + for (literal& l : lits) { + l.neg(); + } + unsigned k2 = lits.size() - k.get_unsigned(); + if (root && m_solver.num_user_scopes() == 0) { + if (sign) { + for (literal& l : lits) l.neg(); + k2 = lits.size() + 1 - k2; + } + ba.add_at_least(null_bool_var, lits, k2); + return null_literal; + } + else { + bool_var v = m_solver.add_var(true); + literal lit(v, false); + ba.add_at_least(v, lits, k2); + m_si->cache(t, lit); + if (sign) lit.neg(); + return lit; + } + } + + literal ba_internalize::convert_eq_k(app* t, rational const& k, bool root, bool sign) { + SASSERT(k.is_unsigned()); + literal_vector lits; + convert_pb_args(t, lits); + bool_var v1 = (root && !sign) ? null_bool_var : m_solver.add_var(true); + bool_var v2 = (root && !sign) ? null_bool_var : m_solver.add_var(true); + ba.add_at_least(v1, lits, k.get_unsigned()); + for (literal& l : lits) { + l.neg(); + } + ba.add_at_least(v2, lits, lits.size() - k.get_unsigned()); + + if (!root || sign) { + literal l1(v1, false), l2(v2, false); + bool_var v = m_solver.add_var(false); + literal l(v, false); + m_si->mk_clause(~l, l1); + m_si->mk_clause(~l, l2); + m_si->mk_clause(~l1, ~l2, l); + m_si->cache(t, l); + if (sign) l.neg(); + return l; + } + else { + return null_literal; + } + } +} diff --git a/src/sat/ba/ba_internalize.h b/src/sat/ba/ba_internalize.h new file mode 100644 index 000000000..eb81486b2 --- /dev/null +++ b/src/sat/ba/ba_internalize.h @@ -0,0 +1,52 @@ +/*++ +Copyright (c) 2020 Microsoft Corporation + +Module Name: + + ba_internalize.h + +Abstract: + + INternalize methods for Boolean algebra operators. + +Author: + + Nikolaj Bjorner (nbjorner) 2020-08-25 + +--*/ + + +#pragma once + +#include "sat/smt/sat_smt.h" +#include "sat/ba/ba_solver.h" +#include "ast/pb_decl_plugin.h" + + +namespace sat { + + class ba_internalize : public th_internalizer { + typedef std::pair wliteral; + ast_manager& m; + pb_util pb; + ba_solver& ba; + solver_core& m_solver; + sat_internalizer* m_si; + literal convert_eq_k(app* t, rational const& k, bool root, bool sign); + literal convert_at_most_k(app* t, rational const& k, bool root, bool sign); + literal convert_at_least_k(app* t, rational const& k, bool root, bool sign); + literal convert_pb_eq(app* t, bool root, bool sign); + literal convert_pb_le(app* t, bool root, bool sign); + literal convert_pb_ge(app* t, bool root, bool sign); + void check_unsigned(rational const& c); + void convert_to_wlits(app* t, sat::literal_vector const& lits, svector& wlits); + void convert_pb_args(app* t, svector& wlits); + void convert_pb_args(app* t, literal_vector& lits); + literal internalize_pb(expr* e, bool sign, bool root); + literal internalize_xor(expr* e, bool sign, bool root); + public: + ba_internalize(ba_solver& ba, solver_core& s, ast_manager& m) : m(m), pb(m), ba(ba), m_solver(s) {} + literal internalize(sat_internalizer& si, expr* e, bool sign, bool root) override; + + }; +} diff --git a/src/sat/ba_solver.cpp b/src/sat/ba/ba_solver.cpp similarity index 99% rename from src/sat/ba_solver.cpp rename to src/sat/ba/ba_solver.cpp index acc7dc777..239ed723d 100644 --- a/src/sat/ba_solver.cpp +++ b/src/sat/ba/ba_solver.cpp @@ -18,7 +18,7 @@ Revision History: --*/ #include -#include "sat/ba_solver.h" +#include "sat/ba/ba_solver.h" #include "sat/sat_types.h" #include "util/mpz.h" #include "sat/sat_simplifier_params.hpp" @@ -128,8 +128,8 @@ namespace sat { // ---------------------- // card - ba_solver::card::card(unsigned id, literal lit, literal_vector const& lits, unsigned k): - pb_base(card_t, id, lit, lits.size(), get_obj_size(lits.size()), k) { + ba_solver::card::card(extension* e, unsigned id, literal lit, literal_vector const& lits, unsigned k): + pb_base(e, card_t, id, lit, lits.size(), get_obj_size(lits.size()), k) { for (unsigned i = 0; i < size(); ++i) { m_lits[i] = lits[i]; } @@ -155,8 +155,8 @@ namespace sat { // ----------------------------------- // pb - ba_solver::pb::pb(unsigned id, literal lit, svector const& wlits, unsigned k): - pb_base(pb_t, id, lit, wlits.size(), get_obj_size(wlits.size()), k), + ba_solver::pb::pb(extension* e, unsigned id, literal lit, svector const& wlits, unsigned k): + pb_base(e, pb_t, id, lit, wlits.size(), get_obj_size(wlits.size()), k), m_slack(0), m_num_watch(0), m_max_sum(0) { @@ -208,8 +208,8 @@ namespace sat { // ----------------------------------- // xr - ba_solver::xr::xr(unsigned id, literal_vector const& lits): - constraint(xr_t, id, null_literal, lits.size(), get_obj_size(lits.size())) { + ba_solver::xr::xr(extension* e, unsigned id, literal_vector const& lits): + constraint(e, xr_t, id, null_literal, lits.size(), get_obj_size(lits.size())) { for (unsigned i = 0; i < size(); ++i) { m_lits[i] = lits[i]; } @@ -1897,7 +1897,7 @@ namespace sat { return nullptr; } void * mem = m_allocator.allocate(card::get_obj_size(lits.size())); - card* c = new (mem) card(next_id(), lit, lits, k); + card* c = new (mem) card(this, next_id(), lit, lits, k); c->set_learned(learned); add_constraint(c); return c; @@ -1966,7 +1966,7 @@ namespace sat { return add_at_least(lit, lits, k, learned); } void * mem = m_allocator.allocate(pb::get_obj_size(wlits.size())); - pb* p = new (mem) pb(next_id(), lit, wlits, k); + pb* p = new (mem) pb(this, next_id(), lit, wlits, k); p->set_learned(learned); add_constraint(p); return p; @@ -2082,7 +2082,7 @@ namespace sat { break; } void * mem = m_allocator.allocate(xr::get_obj_size(lits.size())); - xr* x = new (mem) xr(next_id(), lits); + xr* x = new (mem) xr(this, next_id(), lits); x->set_learned(learned); add_constraint(x); return x; diff --git a/src/sat/ba_solver.h b/src/sat/ba/ba_solver.h similarity index 97% rename from src/sat/ba_solver.h rename to src/sat/ba/ba_solver.h index 8c67e9560..101f2dc52 100644 --- a/src/sat/ba_solver.h +++ b/src/sat/ba/ba_solver.h @@ -24,6 +24,7 @@ Revision History: #include "sat/sat_solver.h" #include "sat/sat_lookahead.h" #include "sat/sat_big.h" +#include "sat/smt/sat_smt.h" #include "util/small_object_allocator.h" #include "util/scoped_ptr_vector.h" #include "util/sorting_network.h" @@ -64,7 +65,7 @@ namespace sat { class xr; class pb_base; - class constraint { + class constraint : public index_base { protected: tag_t m_tag; bool m_removed; @@ -78,7 +79,8 @@ namespace sat { unsigned m_id; bool m_pure; // is the constraint pure (only positive occurrences) public: - constraint(tag_t t, unsigned id, literal l, unsigned sz, size_t osz): + constraint(extension* e, tag_t t, unsigned id, literal l, unsigned sz, size_t osz): + index_base(e), m_tag(t), m_removed(false), m_lit(l), m_watch(null_literal), m_glue(0), m_psm(0), m_size(sz), m_obj_size(osz), m_learned(false), m_id(id), m_pure(false) {} ext_constraint_idx index() const { return reinterpret_cast(this); } unsigned id() const { return m_id; } @@ -132,7 +134,8 @@ namespace sat { protected: unsigned m_k; public: - pb_base(tag_t t, unsigned id, literal l, unsigned sz, size_t osz, unsigned k): constraint(t, id, l, sz, osz), m_k(k) { VERIFY(k < 4000000000); } + pb_base(extension* e, tag_t t, unsigned id, literal l, unsigned sz, size_t osz, unsigned k): + constraint(e, t, id, l, sz, osz), m_k(k) { VERIFY(k < 4000000000); } virtual void set_k(unsigned k) { VERIFY(k < 4000000000); m_k = k; } virtual unsigned get_coeff(unsigned i) const { UNREACHABLE(); return 0; } unsigned k() const { return m_k; } @@ -143,7 +146,7 @@ namespace sat { literal m_lits[0]; public: static size_t get_obj_size(unsigned num_lits) { return sizeof(card) + num_lits * sizeof(literal); } - card(unsigned id, literal lit, literal_vector const& lits, unsigned k); + card(extension* e, unsigned id, literal lit, literal_vector const& lits, unsigned k); literal operator[](unsigned i) const { return m_lits[i]; } literal& operator[](unsigned i) { return m_lits[i]; } literal const* begin() const { return m_lits; } @@ -167,7 +170,7 @@ namespace sat { wliteral m_wlits[0]; public: static size_t get_obj_size(unsigned num_lits) { return sizeof(pb) + num_lits * sizeof(wliteral); } - pb(unsigned id, literal lit, svector const& wlits, unsigned k); + pb(extension* e, unsigned id, literal lit, svector const& wlits, unsigned k); literal lit() const { return m_lit; } wliteral operator[](unsigned i) const { return m_wlits[i]; } wliteral& operator[](unsigned i) { return m_wlits[i]; } @@ -195,7 +198,7 @@ namespace sat { literal m_lits[0]; public: static size_t get_obj_size(unsigned num_lits) { return sizeof(xr) + num_lits * sizeof(literal); } - xr(unsigned id, literal_vector const& lits); + xr(extension* e, unsigned id, literal_vector const& lits); literal operator[](unsigned i) const { return m_lits[i]; } literal const* begin() const { return m_lits; } literal const* end() const { return begin() + m_size; } diff --git a/src/sat/euf/CMakeLists.txt b/src/sat/euf/CMakeLists.txt index 0be16aa8b..4ee444de4 100644 --- a/src/sat/euf/CMakeLists.txt +++ b/src/sat/euf/CMakeLists.txt @@ -3,6 +3,6 @@ z3_add_component(sat_euf euf_solver.cpp COMPONENT_DEPENDENCIES sat - sat_tactic + sat_smt euf ) diff --git a/src/sat/euf/euf_solver.cpp b/src/sat/euf/euf_solver.cpp index ee018a012..1da64970a 100644 --- a/src/sat/euf/euf_solver.cpp +++ b/src/sat/euf/euf_solver.cpp @@ -14,28 +14,83 @@ Author: Nikolaj Bjorner (nbjorner) 2020-08-25 --*/ + +#include "ast/pb_decl_plugin.h" +#include "sat/smt/sat_smt.h" +#include "sat/ba/ba_solver.h" +#include "sat/ba/ba_internalize.h" #include "sat/euf/euf_solver.h" #include "sat/sat_solver.h" #include "tactic/tactic_exception.h" -namespace euf_sat { +namespace euf { - bool solver::propagate(literal l, ext_constraint_idx idx) { - UNREACHABLE(); - return true; + /** + * retrieve extension that is associated with Boolean variable. + */ + sat::extension* solver::get_extension(sat::bool_var v) { + if (v >= m_var2node.size()) + return nullptr; + euf::enode* n = m_var2node[v].first; + if (!n) + return nullptr; + return get_extension(n->get_owner()); } - void solver::get_antecedents(literal l, ext_justification_idx idx, literal_vector & r) { + void solver::add_extension(family_id fid, sat::extension* e) { + m_extensions.push_back(e); + m_id2extension.setx(fid, e, nullptr); + } + + sat::extension* solver::get_extension(expr* e) { + if (is_app(e)) { + auto fid = to_app(e)->get_family_id(); + if (fid == null_family_id) + return nullptr; + auto* ext = m_id2extension.get(fid, nullptr); + if (ext) + return ext; + pb_util pb(m); + if (pb.is_pb(e)) { + auto* ba = alloc(sat::ba_solver); + ba->set_solver(m_solver); + add_extension(pb.get_family_id(), ba); + auto* bai = alloc(sat::ba_internalize, *ba, s(), m); + m_id2internalize.setx(pb.get_family_id(), bai, nullptr); + m_internalizers.push_back(bai); + ba->push_scopes(s().num_scopes()); + return ba; + } + } + return nullptr; + } + + bool solver::propagate(literal l, ext_constraint_idx idx) { + auto* ext = sat::index_base::to_extension(idx); + SASSERT(ext != this); + return ext->propagate(l, idx); + } + + void solver::get_antecedents(literal l, ext_justification_idx idx, literal_vector& r) { + auto* ext = sat::index_base::to_extension(idx); + if (ext == this) + get_antecedents(l, *euf_base::from_idx(idx), r); + else + ext->get_antecedents(l, idx, r); + } + + void solver::get_antecedents(literal l, euf_base& j, literal_vector& r) { m_explain.reset(); euf::enode* n = nullptr; bool sign = false; - if (idx != 0) { + if (j.id() != 0) { auto p = m_var2node[l.var()]; n = p.first; + SASSERT(n); sign = l.sign() != p.second; } - switch (idx) { + switch (j.id()) { case 0: SASSERT(m_egraph.inconsistent()); m_egraph.explain(m_explain); @@ -56,12 +111,20 @@ namespace euf_sat { } void solver::asserted(literal l) { - auto p = m_var2node[l.var()]; + auto* ext = get_extension(l.var()); + if (ext) { + ext->asserted(l); + return; + } + + auto p = m_var2node.get(l.var(), enode_bool_pair(nullptr, false)); + if (!p.first) + return; bool sign = p.second != l.sign(); euf::enode* n = p.first; - expr* e = n->get_owner(); - if (m.is_eq(e) && !sign) { - euf::enode* na = n->get_arg(0); + expr* e = n->get_owner(); + if (m.is_eq(e) && !sign) { + euf::enode* na = n->get_arg(0); euf::enode* nb = n->get_arg(1); m_egraph.merge(na, nb, base_ptr() + l.index()); } @@ -69,16 +132,19 @@ namespace euf_sat { euf::enode* nb = sign ? m_false : m_true; m_egraph.merge(n, nb, base_ptr() + l.index()); } - // TBD: delay propagation? + propagate(); + } + + void solver::propagate() { m_egraph.propagate(); if (m_egraph.inconsistent()) { - s().set_conflict(sat::justification::mk_ext_justification(s().scope_lvl(), 0)); + s().set_conflict(sat::justification::mk_ext_justification(s().scope_lvl(), m_conflict_idx.to_index())); return; } for (euf::enode* eq : m_egraph.new_eqs()) { bool_var v = m_expr2var.to_bool_var(eq->get_owner()); - s().assign(literal(v, false), sat::justification::mk_ext_justification(s().scope_lvl(), 1)); + s().assign(literal(v, false), sat::justification::mk_ext_justification(s().scope_lvl(), m_eq_idx.to_index())); } for (euf::enode* p : m_egraph.new_lits()) { expr* e = p->get_owner(); @@ -86,73 +152,231 @@ namespace euf_sat { SASSERT(m.is_bool(e)); SASSERT(m.is_true(p->get_root()->get_owner()) || sign); bool_var v = m_expr2var.to_bool_var(e); - s().assign(literal(v, sign), sat::justification::mk_ext_justification(s().scope_lvl(), 2)); + s().assign(literal(v, sign), sat::justification::mk_ext_justification(s().scope_lvl(), m_lit_idx.to_index())); } } sat::check_result solver::check() { + bool give_up = false; + bool cont = false; + for (auto* e : m_extensions) + switch (e->check()) { + case sat::CR_CONTINUE: cont = true; break; + case sat::CR_GIVEUP: give_up = true; break; + default: break; + } + if (cont) + return sat::CR_CONTINUE; + if (give_up) + return sat::CR_GIVEUP; return sat::CR_DONE; } + void solver::push() { + for (auto* e : m_extensions) + e->push(); m_egraph.push(); + ++m_num_scopes; } + void solver::pop(unsigned n) { m_egraph.pop(n); + for (auto* e : m_extensions) + e->pop(n); + if (n <= m_num_scopes) { + m_num_scopes -= n; + return; + } + n -= m_num_scopes; + unsigned old_lim = m_bool_var_lim.size() - n; + unsigned old_sz = m_bool_var_lim[old_lim]; + for (unsigned i = m_bool_var_trail.size(); i-- > old_sz; ) + m_var2node[m_bool_var_trail[i]] = enode_bool_pair(nullptr, false); + m_bool_var_trail.shrink(old_sz); + m_bool_var_lim.shrink(old_lim); } - void solver::pre_simplify() {} - void solver::simplify() {} - // have a way to replace l by r in all constraints - void solver::clauses_modifed() {} - lbool solver::get_phase(bool_var v) { return l_undef; } + + void solver::pre_simplify() { + for (auto* e : m_extensions) + e->pre_simplify(); + } + + void solver::simplify() { + for (auto* e : m_extensions) + e->simplify(); + } + + void solver::clauses_modifed() { + for (auto* e : m_extensions) + e->clauses_modifed(); + } + + lbool solver::get_phase(bool_var v) { + auto* ext = get_extension(v); + if (ext) + return ext->get_phase(v); + return l_undef; + } + std::ostream& solver::display(std::ostream& out) const { m_egraph.display(out); + for (auto* e : m_extensions) + e->display(out); return out; } - std::ostream& solver::display_justification(std::ostream& out, ext_justification_idx idx) const { return out; } - std::ostream& solver::display_constraint(std::ostream& out, ext_constraint_idx idx) const { return out; } - void solver::collect_statistics(statistics& st) const {} - sat::extension* solver::copy(sat::solver* s) { return nullptr; } - sat::extension* solver::copy(sat::lookahead* s, bool learned) { return nullptr; } - void solver::find_mutexes(literal_vector& lits, vector & mutexes) {} - void solver::gc() {} - void solver::pop_reinit() {} - bool solver::validate() { return true; } - void solver::init_use_list(sat::ext_use_list& ul) {} - bool solver::is_blocked(literal l, ext_constraint_idx) { return false; } - bool solver::check_model(sat::model const& m) const { return true;} - unsigned solver::max_var(unsigned w) const { return w; } + + std::ostream& solver::display_justification(std::ostream& out, ext_justification_idx idx) const { + auto* ext = sat::index_base::to_extension(idx); + if (ext != this) + return ext->display_justification(out, idx); + return out; + } + + std::ostream& solver::display_constraint(std::ostream& out, ext_constraint_idx idx) const { + auto* ext = sat::index_base::to_extension(idx); + if (ext != this) + return ext->display_constraint(out, idx); + return out; + } + + void solver::collect_statistics(statistics& st) const { + m_egraph.collect_statistics(st); + for (auto* e : m_extensions) + e->collect_statistics(st); + } + + solver* solver::copy_core() { + ast_manager& to = m_translate ? m_translate->to() : m; + atom2bool_var& a2b = m_translate_expr2var ? *m_translate_expr2var : m_expr2var; + auto* r = alloc(solver, to, a2b); + std::function copy_justification = [&](void* x) { return (void*)(r->base_ptr() + ((unsigned*)x - base_ptr())); }; + r->m_egraph.copy_from(m_egraph, copy_justification); + return r; + } + + sat::extension* solver::copy(sat::solver* s) { + auto* r = copy_core(); + r->set_solver(s); + for (auto* e : m_extensions) + r->m_extensions.push_back(e->copy(s)); + return r; + } + + sat::extension* solver::copy(sat::lookahead* s, bool learned) { + (void) learned; + auto* r = copy_core(); + r->set_lookahead(s); + for (auto* e : m_extensions) + r->m_extensions.push_back(e->copy(s, learned)); + return r; + } + + void solver::find_mutexes(literal_vector& lits, vector & mutexes) { + for (auto* e : m_extensions) + e->find_mutexes(lits, mutexes); + } + + void solver::gc() { + for (auto* e : m_extensions) + e->gc(); + } + + void solver::pop_reinit() { + for (auto* e : m_extensions) + e->pop_reinit(); + } + + bool solver::validate() { + for (auto* e : m_extensions) + if (!e->validate()) + return false; + return true; + } + + void solver::init_use_list(sat::ext_use_list& ul) { + for (auto* e : m_extensions) + e->init_use_list(ul); + } + + bool solver::is_blocked(literal l, ext_constraint_idx idx) { + auto* ext = sat::index_base::to_extension(idx); + if (ext != this) + return is_blocked(l, idx); + return false; + } + + bool solver::check_model(sat::model const& m) const { + for (auto* e : m_extensions) + if (!e->check_model(m)) + return false; + return true; + } + + unsigned solver::max_var(unsigned w) const { + for (auto* e : m_extensions) + w = e->max_var(w); + for (unsigned sz = m_var2node.size(); sz-- > 0; ) { + euf::enode* n = m_var2node[sz].first; + if (n && m.is_bool(n->get_owner())) { + w = std::max(w, sz); + break; + } + } + return w; + } + + sat::th_internalizer* solver::get_internalizer(expr* e) { + if (is_app(e)) + return m_id2internalize.get(to_app(e)->get_family_id(), nullptr); + if (m.is_iff(e)) { + pb_util pb(m); + return m_id2internalize.get(pb.get_family_id(), nullptr); + } + return nullptr; + } - void solver::internalize(sat_internalizer& si, expr* e) { + sat::literal solver::internalize(sat::sat_internalizer& si, expr* e, bool sign, bool root) { + auto* ext = get_internalizer(e); + if (ext) + return ext->internalize(si, e, sign, root); + if (!m_true) { + m_true = visit(si, m.mk_true()); + m_false = visit(si, m.mk_false()); + } SASSERT(!si.is_bool_op(e)); + sat::scoped_stack _sc(m_stack); unsigned sz = m_stack.size(); euf::enode* n = visit(si, e); while (m_stack.size() > sz) { loop: if (!m.inc()) throw tactic_exception(m.limit().get_cancel_msg()); - frame & fr = m_stack.back(); + sat::frame & fr = m_stack.back(); expr* e = fr.m_e; if (m_egraph.find(e)) { m_stack.pop_back(); continue; } unsigned num = is_app(e) ? to_app(e)->get_num_args() : 0; - m_args.reset(); + while (fr.m_idx < num) { expr* arg = to_app(e)->get_arg(fr.m_idx); fr.m_idx++; n = visit(si, arg); if (!n) goto loop; - m_args.push_back(n); } + m_args.reset(); + for (unsigned i = 0; i < num; ++i) + m_args.push_back(m_egraph.find(to_app(e)->get_arg(i))); n = m_egraph.mk(e, num, m_args.c_ptr()); attach_bool_var(si, n); } SASSERT(m_egraph.find(e)); + return literal(m_expr2var.to_bool_var(e), sign); } - euf::enode* solver::visit(sat_internalizer& si, expr* e) { + euf::enode* solver::visit(sat::sat_internalizer& si, expr* e) { euf::enode* n = m_egraph.find(e); if (n) return n; @@ -160,11 +384,12 @@ namespace euf_sat { sat::literal lit = si.internalize(e); n = m_egraph.mk(e, 0, nullptr); attach_bool_var(lit.var(), lit.sign(), n); - s().set_external(lit.var()); + if (!m.is_true(e) && !m.is_false(e)) + s().set_external(lit.var()); return n; } if (is_app(e) && to_app(e)->get_num_args() > 0) { - m_stack.push_back(frame(e)); + m_stack.push_back(sat::frame(e)); return nullptr; } n = m_egraph.mk(e, 0, nullptr); @@ -172,7 +397,7 @@ namespace euf_sat { return n; } - void solver::attach_bool_var(sat_internalizer& si, euf::enode* n) { + void solver::attach_bool_var(sat::sat_internalizer& si, euf::enode* n) { expr* e = n->get_owner(); if (m.is_bool(e)) { sat::bool_var v = si.add_bool_var(e); @@ -181,8 +406,17 @@ namespace euf_sat { } void solver::attach_bool_var(sat::bool_var v, bool sign, euf::enode* n) { - m_var2node.reserve(v + 1); + m_var2node.reserve(v + 1, enode_bool_pair(nullptr, false)); + for (; m_num_scopes > 0; --m_num_scopes) + m_bool_var_lim.push_back(m_bool_var_trail.size()); + SASSERT(m_var2node[v].first == nullptr); m_var2node[v] = euf::enode_bool_pair(n, sign); + m_bool_var_trail.push_back(v); + } + + model_converter* solver::get_model() { + NOT_IMPLEMENTED_YET(); + return nullptr; } } diff --git a/src/sat/euf/euf_solver.h b/src/sat/euf/euf_solver.h index 6a98ea93e..3dd36fcac 100644 --- a/src/sat/euf/euf_solver.h +++ b/src/sat/euf/euf_solver.h @@ -16,53 +16,91 @@ Author: --*/ #pragma once +#include "util/scoped_ptr_vector.h" #include "sat/sat_extension.h" #include "ast/euf/euf_egraph.h" -#include "sat/tactic/atom2bool_var.h" -#include "sat/tactic/goal2sat.h" +#include "ast/ast_translation.h" +#include "sat/smt/sat_smt.h" +#include "sat/smt/atom2bool_var.h" +#include "tactic/model_converter.h" -namespace euf_sat { +namespace euf { typedef sat::literal literal; typedef sat::ext_constraint_idx ext_constraint_idx; typedef sat::ext_justification_idx ext_justification_idx; typedef sat::literal_vector literal_vector; typedef sat::bool_var bool_var; - struct frame { - expr* m_e; - unsigned m_idx; - frame(expr* e) : m_e(e), m_idx(0) {} + class euf_base : public sat::index_base { + unsigned m_id; + public: + euf_base(sat::extension* e, unsigned id) : + index_base(e), m_id(id) + {} + unsigned id() const { return m_id; } + static euf_base* from_idx(size_t z) { return reinterpret_cast(z); } }; - class solver : public sat::extension { + class solver : public sat::extension, public sat::th_internalizer { ast_manager& m; atom2bool_var& m_expr2var; euf::egraph m_egraph; sat::solver* m_solver; + sat::lookahead* m_lookahead; + ast_translation* m_translate; + atom2bool_var* m_translate_expr2var; euf::enode* m_true; euf::enode* m_false; svector m_var2node; ptr_vector m_explain; euf::enode_vector m_args; - svector m_stack; + svector m_stack; + unsigned m_num_scopes { 0 }; + unsigned_vector m_bool_var_trail; + unsigned_vector m_bool_var_lim; + scoped_ptr_vector m_extensions; + ptr_vector m_id2extension; + ptr_vector m_id2internalize; + scoped_ptr_vector m_internalizers; + euf_base m_conflict_idx, m_eq_idx, m_lit_idx; sat::solver& s() { return *m_solver; } unsigned * base_ptr() { return reinterpret_cast(this); } - euf::enode* visit(sat_internalizer& si, expr* e); - void attach_bool_var(sat_internalizer& si, euf::enode* n); + euf::enode* visit(sat::sat_internalizer& si, expr* e); + void attach_bool_var(sat::sat_internalizer& si, euf::enode* n); void attach_bool_var(sat::bool_var v, bool sign, euf::enode* n); + solver* copy_core(); + sat::extension* get_extension(sat::bool_var v); + void add_extension(family_id fid, sat::extension* e); + sat::th_internalizer* get_internalizer(expr* e); + + void propagate(); + void get_antecedents(literal l, euf_base& j, literal_vector& r); public: solver(ast_manager& m, atom2bool_var& expr2var): m(m), m_expr2var(expr2var), m_egraph(m), - m_solver(nullptr) + m_solver(nullptr), + m_lookahead(nullptr), + m_translate(nullptr), + m_translate_expr2var(nullptr), + m_true(nullptr), + m_false(nullptr), + m_conflict_idx(this, 0), + m_eq_idx(this, 1), + m_lit_idx(this, 2) {} void set_solver(sat::solver* s) override { m_solver = s; } - void set_lookahead(sat::lookahead* s) override { } + void set_lookahead(sat::lookahead* s) override { m_lookahead = s; } + struct scoped_set_translate { + solver& s; + scoped_set_translate(solver& s, ast_translation& t, atom2bool_var& a2b):s(s) { s.m_translate = &t; s.m_translate_expr2var = &a2b; } + ~scoped_set_translate() { s.m_translate = nullptr; s. m_translate_expr2var = nullptr; } + }; double get_reward(literal l, ext_constraint_idx idx, sat::literal_occs_fun& occs) const override { return 0; } bool is_extended_binary(ext_justification_idx idx, literal_vector & r) override { return false; } @@ -92,7 +130,11 @@ namespace euf_sat { bool check_model(sat::model const& m) const override; unsigned max_var(unsigned w) const override; - void internalize(sat_internalizer& si, expr* e); + sat::literal internalize(sat::sat_internalizer& si, expr* e, bool sign, bool root) override; + model_converter* get_model(); + + + sat::extension* get_extension(expr* e); }; }; diff --git a/src/sat/sat_extension.h b/src/sat/sat_extension.h index c251d769e..f0f15d307 100644 --- a/src/sat/sat_extension.h +++ b/src/sat/sat_extension.h @@ -47,8 +47,9 @@ namespace sat { }; class extension { - public: + public: virtual ~extension() {} + virtual unsigned get_id() const { return 0; } virtual void set_solver(solver* s) = 0; virtual void set_lookahead(lookahead* s) = 0; virtual bool propagate(literal l, ext_constraint_idx idx) = 0; @@ -59,6 +60,7 @@ namespace sat { virtual check_result check() = 0; virtual lbool resolve_conflict() { return l_undef; } // stores result in sat::solver::m_lemma virtual void push() = 0; + void push_scopes(unsigned n) { for (unsigned i = 0; i < n; ++i) push(); } virtual void pop(unsigned n) = 0; virtual void pre_simplify() = 0; virtual void simplify() = 0; diff --git a/src/sat/sat_local_search.cpp b/src/sat/sat_local_search.cpp index cd53e5a1a..a951b728b 100644 --- a/src/sat/sat_local_search.cpp +++ b/src/sat/sat_local_search.cpp @@ -19,7 +19,6 @@ Notes: #include "sat/sat_local_search.h" #include "sat/sat_solver.h" -#include "sat/ba_solver.h" #include "sat/sat_params.hpp" #include "util/timer.h" @@ -415,6 +414,10 @@ namespace sat { } m_num_non_binary_clauses = s.m_clauses.size(); + if (s.get_extension()) + throw default_exception("local search is incompatible with extensions"); + +#if 0 // copy cardinality clauses ba_solver* ext = dynamic_cast(s.get_extension()); if (ext) { @@ -502,6 +505,8 @@ namespace sat { } } } +#endif + if (_init) { init(); } diff --git a/src/sat/sat_params.pyg b/src/sat/sat_params.pyg index 74e5f77b4..24ab6687b 100644 --- a/src/sat/sat_params.pyg +++ b/src/sat/sat_params.pyg @@ -57,6 +57,7 @@ def_module_params('sat', ('cardinality.encoding', SYMBOL, 'grouped', 'encoding used for at-most-k constraints: grouped, bimander, ordered, unate, circuit'), ('pb.resolve', SYMBOL, 'cardinality', 'resolution strategy for boolean algebra solver: cardinality, rounding'), ('pb.lemma_format', SYMBOL, 'cardinality', 'generate either cardinality or pb lemmas'), + ('euf', BOOL, False, 'enable euf solver'), ('ddfw_search', BOOL, False, 'use ddfw local search instead of CDCL'), ('ddfw.init_clause_weight', UINT, 8, 'initial clause weight for DDFW local search'), ('ddfw.use_reward_pct', UINT, 15, 'percentage to pick highest reward variable when it has reward 0'), diff --git a/src/sat/sat_solver.h b/src/sat/sat_solver.h index 0a98bd73b..c52bf4d7c 100644 --- a/src/sat/sat_solver.h +++ b/src/sat/sat_solver.h @@ -635,6 +635,7 @@ namespace sat { void user_pop(unsigned num_scopes) override; void pop_to_base_level() override; unsigned num_user_scopes() const override { return m_user_scope_literals.size(); } + unsigned num_scopes() const override { return m_scopes.size(); } reslimit& rlimit() { return m_rlimit; } params_ref const& params() { return m_params; } // ----------------------- diff --git a/src/sat/sat_solver_core.h b/src/sat/sat_solver_core.h index cbb9c1985..89164f857 100644 --- a/src/sat/sat_solver_core.h +++ b/src/sat/sat_solver_core.h @@ -86,6 +86,8 @@ namespace sat { virtual void user_push() { throw default_exception("optional API not supported"); } virtual void user_pop(unsigned num_scopes) {}; virtual unsigned num_user_scopes() const { return 0;} + virtual unsigned num_scopes() const { return 0; } + // hooks for extension solver. really just ba_solver atm. virtual extension* get_extension() const { return nullptr; } diff --git a/src/sat/smt/CMakeLists.txt b/src/sat/smt/CMakeLists.txt new file mode 100644 index 000000000..495917065 --- /dev/null +++ b/src/sat/smt/CMakeLists.txt @@ -0,0 +1,8 @@ +z3_add_component(sat_smt + SOURCES + atom2bool_var.cpp + COMPONENT_DEPENDENCIES + sat + ast +) + diff --git a/src/sat/tactic/atom2bool_var.cpp b/src/sat/smt/atom2bool_var.cpp similarity index 99% rename from src/sat/tactic/atom2bool_var.cpp rename to src/sat/smt/atom2bool_var.cpp index 08e1258d9..1fa3688a8 100644 --- a/src/sat/tactic/atom2bool_var.cpp +++ b/src/sat/smt/atom2bool_var.cpp @@ -20,7 +20,7 @@ Notes: #include "util/ref_util.h" #include "ast/ast_smt2_pp.h" #include "tactic/goal.h" -#include "sat/tactic/atom2bool_var.h" +#include "sat/smt/atom2bool_var.h" void atom2bool_var::mk_inv(expr_ref_vector & lit2expr) const { for (auto const& kv : m_mapping) { diff --git a/src/sat/tactic/atom2bool_var.h b/src/sat/smt/atom2bool_var.h similarity index 100% rename from src/sat/tactic/atom2bool_var.h rename to src/sat/smt/atom2bool_var.h diff --git a/src/sat/smt/sat_smt.h b/src/sat/smt/sat_smt.h new file mode 100644 index 000000000..9a2f9d1d0 --- /dev/null +++ b/src/sat/smt/sat_smt.h @@ -0,0 +1,62 @@ +/*++ +Copyright (c) 2020 Microsoft Corporation + +Module Name: + + sat_smt.h + +Abstract: + + Header for SMT theories over SAT solver + +Author: + + Nikolaj Bjorner (nbjorner) 2020-08-25 + +--*/ +#pragma once + + +#pragma once +#include "sat/sat_solver.h" +#include "ast/ast.h" + +namespace sat { + + struct frame { + expr* m_e; + unsigned m_idx; + frame(expr* e) : m_e(e), m_idx(0) {} + }; + + struct scoped_stack { + svector& s; + unsigned sz; + scoped_stack(svector& s):s(s), sz(s.size()) {} + ~scoped_stack() { s.shrink(sz); } + }; + + class sat_internalizer { + public: + virtual bool is_bool_op(expr* e) const = 0; + virtual sat::literal internalize(expr* e) = 0; + virtual sat::bool_var add_bool_var(expr* e) = 0; + virtual void mk_clause(literal a, literal b) = 0; + virtual void mk_clause(literal l1, literal l2, literal l3, bool is_lemma = false) = 0; + virtual void cache(app* t, literal l) = 0; + }; + + class th_internalizer { + public: + virtual literal internalize(sat_internalizer& si, expr* e, bool sign, bool root) = 0; + }; + + class index_base { + extension* ex; + public: + index_base(extension* e): ex(e) {} + static extension* to_extension(size_t s) { return from_index(s)->ex; } + static index_base* from_index(size_t s) { return reinterpret_cast(s); } + size_t to_index() const { return reinterpret_cast(this); } + }; +} diff --git a/src/sat/tactic/CMakeLists.txt b/src/sat/tactic/CMakeLists.txt index 848bf814b..434fe854c 100644 --- a/src/sat/tactic/CMakeLists.txt +++ b/src/sat/tactic/CMakeLists.txt @@ -1,13 +1,14 @@ z3_add_component(sat_tactic SOURCES - atom2bool_var.cpp goal2sat.cpp sat_tactic.cpp COMPONENT_DEPENDENCIES sat tactic solver - euf + sat_smt + sat_ba + sat_euf TACTIC_HEADERS sat_tactic.h ) diff --git a/src/sat/tactic/goal2sat.cpp b/src/sat/tactic/goal2sat.cpp index e0ede3523..1f6723240 100644 --- a/src/sat/tactic/goal2sat.cpp +++ b/src/sat/tactic/goal2sat.cpp @@ -34,15 +34,17 @@ Notes: #include "ast/ast_util.h" #include "ast/for_each_expr.h" #include "sat/tactic/goal2sat.h" -#include "sat/ba_solver.h" #include "sat/sat_cut_simplifier.h" +#include "sat/ba/ba_internalize.h" +#include "sat/ba/ba_solver.h" +#include "sat/euf/euf_solver.h" #include "model/model_evaluator.h" #include "model/model_v2_pp.h" #include "tactic/tactic.h" #include "tactic/generic_model_converter.h" #include -struct goal2sat::imp : public sat_internalizer { +struct goal2sat::imp : public sat::sat_internalizer { struct frame { app * m_t; unsigned m_root:1; @@ -53,7 +55,6 @@ struct goal2sat::imp : public sat_internalizer { }; ast_manager & m; pb_util pb; - sat::ba_solver* m_ext; sat::cut_simplifier* m_aig; svector m_frame_stack; svector m_result_stack; @@ -69,12 +70,13 @@ struct goal2sat::imp : public sat_internalizer { expr_ref_vector m_interpreted_atoms; bool m_default_external; bool m_xor_solver; + bool m_euf; bool m_is_lemma; + imp(ast_manager & _m, params_ref const & p, sat::solver_core & s, atom2bool_var & map, dep2asm_map& dep2asm, bool default_external): m(_m), pb(m), - m_ext(nullptr), m_aig(nullptr), m_solver(s), m_map(map), @@ -92,7 +94,7 @@ struct goal2sat::imp : public sat_internalizer { m_ite_extra = p.get_bool("ite_extra", true); m_max_memory = megabytes_to_bytes(p.get_uint("max_memory", UINT_MAX)); m_xor_solver = p.get_bool("xor_solver", false); - if (m_xor_solver) ensure_extension(); + m_euf = false; } void throw_op_not_handled(std::string const& s) { @@ -107,12 +109,12 @@ struct goal2sat::imp : public sat_internalizer { void set_lemma_mode(bool f) { m_is_lemma = f; } - void mk_clause(sat::literal l1, sat::literal l2) { + void mk_clause(sat::literal l1, sat::literal l2) override { TRACE("goal2sat", tout << "mk_clause: " << l1 << " " << l2 << "\n";); m_solver.add_clause(l1, l2, m_is_lemma); } - void mk_clause(sat::literal l1, sat::literal l2, sat::literal l3, bool is_lemma = false) { + void mk_clause(sat::literal l1, sat::literal l2, sat::literal l3, bool is_lemma = false) override { TRACE("goal2sat", tout << "mk_clause: " << l1 << " " << l2 << " " << l3 << "\n";); m_solver.add_clause(l1, l2, l3, m_is_lemma || is_lemma); } @@ -143,6 +145,10 @@ struct goal2sat::imp : public sat_internalizer { return v; } + void cache(app* t, sat::literal l) override { + m_cache.insert(t, l); + } + void convert_atom(expr * t, bool root, bool sign) { SASSERT(m.is_bool(t)); sat::literal l; @@ -166,7 +172,12 @@ struct goal2sat::imp : public sat_internalizer { l = sat::literal(v, sign); TRACE("sat", tout << "new_var: " << v << ": " << mk_bounded_pp(t, m, 2) << " " << is_uninterp_const(t) << "\n";); if (!is_uninterp_const(t)) { - m_interpreted_atoms.push_back(t); + if (m_euf) { + convert_euf(t, root, sign); + return; + } + else + m_interpreted_atoms.push_back(t); } } } @@ -183,8 +194,7 @@ struct goal2sat::imp : public sat_internalizer { } bool convert_app(app* t, bool root, bool sign) { - if (t->get_family_id() == pb.get_family_id()) { - ensure_extension(); + if (pb.is_pb(t)) { m_frame_stack.push_back(frame(to_app(t), root, sign, 0)); return false; } @@ -208,7 +218,6 @@ struct goal2sat::imp : public sat_internalizer { return false; } - bool visit(expr * t, bool root, bool sign) { if (!is_app(t)) { convert_atom(t, root, sign); @@ -216,9 +225,8 @@ struct goal2sat::imp : public sat_internalizer { } if (process_cached(to_app(t), root, sign)) return true; - if (to_app(t)->get_family_id() != m.get_basic_family_id()) { - return convert_app(to_app(t), root, sign); - } + if (to_app(t)->get_family_id() != m.get_basic_family_id()) + return convert_app(to_app(t), root, sign); switch (to_app(t)->get_decl_kind()) { case OP_NOT: case OP_OR: @@ -341,11 +349,11 @@ struct goal2sat::imp : public sat_internalizer { mk_clause(num+1, lits); if (m_aig) { m_aig->add_and(l, num, aig_lits.c_ptr()); - } - unsigned old_sz = m_result_stack.size() - num - 1; - m_result_stack.shrink(old_sz); + } if (sign) l.neg(); + unsigned old_sz = m_result_stack.size() - num - 1; + m_result_stack.shrink(old_sz); m_result_stack.push_back(l); TRACE("goal2sat", tout << m_result_stack << "\n";); } @@ -382,9 +390,9 @@ struct goal2sat::imp : public sat_internalizer { mk_clause(t, e, ~l, false); } if (m_aig) m_aig->add_ite(l, c, t, e); - m_result_stack.shrink(sz-3); if (sign) l.neg(); + m_result_stack.shrink(sz-3); m_result_stack.push_back(l); } } @@ -415,268 +423,70 @@ struct goal2sat::imp : public sat_internalizer { mk_clause(~l, ~l1, l2); mk_clause(l, l1, l2); mk_clause(l, ~l1, ~l2); - if (m_aig) m_aig->add_iff(l, l1, l2); - m_result_stack.shrink(sz-2); + if (m_aig) m_aig->add_iff(l, l1, l2); if (sign) l.neg(); + m_result_stack.shrink(sz - 2); m_result_stack.push_back(l); } } void convert_iff(app * t, bool root, bool sign) { TRACE("goal2sat", tout << "convert_iff " << root << " " << sign << "\n" << mk_bounded_pp(t, m, 2) << "\n";); - unsigned sz = m_result_stack.size(); - unsigned num = get_num_args(t); - SASSERT(sz >= num && num >= 2); - if (num == 2) { + if (is_xor(t)) + convert_ba(t, root, sign); + else convert_iff2(t, root, sign); + } + + void convert_euf(expr* e, bool root, bool sign) { + sat::extension* ext = m_solver.get_extension(); + euf::solver* euf = nullptr; + if (!ext) { + euf = alloc(euf::solver, m, m_map); + m_solver.set_extension(euf); + for (unsigned i = m_solver.num_scopes(); i-- > 0; ) + euf->push(); + } + else { + euf = dynamic_cast(ext); + } + if (!euf) + throw default_exception("cannot convert to euf"); + sat::literal lit = euf->internalize(*this, e, sign, root); + if (root) + m_result_stack.reset(); + if (lit == sat::null_literal) return; - } - sat::literal_vector lits; - sat::bool_var v = m_solver.add_var(true); - lits.push_back(sat::literal(v, true)); - convert_pb_args(num, lits); - // ensure that = is converted to xor - for (unsigned i = 1; i + 1 < lits.size(); ++i) { - lits[i].neg(); - } - ensure_extension(); - m_ext->add_xr(lits); - if (m_aig) m_aig->add_xor(~lits.back(), lits.size() - 1, lits.c_ptr() + 1); - sat::literal lit(v, sign); - if (root) { - m_result_stack.reset(); + if (root) mk_clause(lit); - } - else { - m_result_stack.shrink(sz - num); + else m_result_stack.push_back(lit); - } } - void convert_pb_args(unsigned num_args, sat::literal_vector& lits) { - unsigned sz = m_result_stack.size(); - for (unsigned i = 0; i < num_args; ++i) { - sat::literal lit(m_result_stack[sz - num_args + i]); - if (!m_solver.is_external(lit.var())) { - m_solver.set_external(lit.var()); - } - lits.push_back(lit); - } - } - - typedef std::pair wliteral; - - void check_unsigned(rational const& c) { - if (!c.is_unsigned()) { - throw default_exception("unsigned coefficient expected"); - } - } - - void convert_to_wlits(app* t, sat::literal_vector const& lits, svector& wlits) { - for (unsigned i = 0; i < lits.size(); ++i) { - rational c = pb.get_coeff(t, i); - check_unsigned(c); - wlits.push_back(std::make_pair(c.get_unsigned(), lits[i])); - } - } - - void convert_pb_args(app* t, svector& wlits) { - sat::literal_vector lits; - convert_pb_args(t->get_num_args(), lits); - convert_to_wlits(t, lits, wlits); - } - - void push_result(bool root, sat::literal lit, unsigned num_args) { - if (root) { - m_result_stack.reset(); - mk_clause(lit); + void convert_ba(app* t, bool root, bool sign) { + sat::extension* ext = m_solver.get_extension(); + sat::ba_solver* ba = nullptr; + if (!ext) { + ba = alloc(sat::ba_solver); + m_solver.set_extension(ba); + ba->push_scopes(m_solver.num_scopes()); } else { - m_result_stack.shrink(m_result_stack.size() - num_args); - m_result_stack.push_back(lit); + ba = dynamic_cast(ext); } - } - - void convert_pb_ge(app* t, bool root, bool sign) { - rational k = pb.get_k(t); - check_unsigned(k); - svector wlits; - convert_pb_args(t, wlits); - if (root && m_solver.num_user_scopes() == 0) { + if (!ba) + throw default_exception("cannot convert to pb"); + sat::ba_internalize internalize(*ba, m_solver, m); + sat::literal lit = internalize.internalize(*this, t, sign, root); + if (root) m_result_stack.reset(); - unsigned k1 = k.get_unsigned(); - if (sign) { - k1 = 1 - k1; - for (wliteral& wl : wlits) { - wl.second.neg(); - k1 += wl.first; - } - } - m_ext->add_pb_ge(sat::null_bool_var, wlits, k1); - } - else { - sat::bool_var v = m_solver.add_var(true); - sat::literal lit(v, sign); - m_ext->add_pb_ge(v, wlits, k.get_unsigned()); - TRACE("goal2sat", tout << "root: " << root << " lit: " << lit << "\n";); - push_result(root, lit, t->get_num_args()); - } - } - - void convert_pb_le(app* t, bool root, bool sign) { - rational k = pb.get_k(t); - k.neg(); - svector wlits; - convert_pb_args(t, wlits); - for (wliteral& wl : wlits) { - wl.second.neg(); - k += rational(wl.first); - } - check_unsigned(k); - if (root && m_solver.num_user_scopes() == 0) { - m_result_stack.reset(); - unsigned k1 = k.get_unsigned(); - if (sign) { - k1 = 1 - k1; - for (wliteral& wl : wlits) { - wl.second.neg(); - k1 += wl.first; - } - } - m_ext->add_pb_ge(sat::null_bool_var, wlits, k1); - } - else { - sat::bool_var v = m_solver.add_var(true); - sat::literal lit(v, sign); - m_ext->add_pb_ge(v, wlits, k.get_unsigned()); - TRACE("goal2sat", tout << "root: " << root << " lit: " << lit << "\n";); - push_result(root, lit, t->get_num_args()); - } - } - - void convert_pb_eq(app* t, bool root, bool sign) { - rational k = pb.get_k(t); - SASSERT(k.is_unsigned()); - svector wlits; - convert_pb_args(t, wlits); - bool base_assert = (root && !sign && m_solver.num_user_scopes() == 0); - sat::bool_var v1 = base_assert ? sat::null_bool_var : m_solver.add_var(true); - sat::bool_var v2 = base_assert ? sat::null_bool_var : m_solver.add_var(true); - m_ext->add_pb_ge(v1, wlits, k.get_unsigned()); - k.neg(); - for (wliteral& wl : wlits) { - wl.second.neg(); - k += rational(wl.first); - } - check_unsigned(k); - m_ext->add_pb_ge(v2, wlits, k.get_unsigned()); - if (base_assert) { - m_result_stack.reset(); - } - else { - sat::literal l1(v1, false), l2(v2, false); - sat::bool_var v = m_solver.add_var(false); - sat::literal l(v, false); - mk_clause(~l, l1); - mk_clause(~l, l2); - mk_clause(~l1, ~l2, l); - m_cache.insert(t, l); - if (sign) l.neg(); - push_result(root, l, t->get_num_args()); - } - } - - void convert_at_least_k(app* t, rational const& k, bool root, bool sign) { - SASSERT(k.is_unsigned()); - sat::literal_vector lits; - convert_pb_args(t->get_num_args(), lits); - unsigned k2 = k.get_unsigned(); - if (root && m_solver.num_user_scopes() == 0) { - m_result_stack.reset(); - if (sign) { - for (sat::literal& l : lits) l.neg(); - k2 = lits.size() + 1 - k2; - } - m_ext->add_at_least(sat::null_bool_var, lits, k2); - } - else { - sat::bool_var v = m_solver.add_var(true); - sat::literal lit(v, false); - m_ext->add_at_least(v, lits, k.get_unsigned()); - m_cache.insert(t, lit); - if (sign) lit.neg(); - TRACE("goal2sat", tout << "root: " << root << " lit: " << lit << "\n";); - push_result(root, lit, t->get_num_args()); - } - } - - void convert_at_most_k(app* t, rational const& k, bool root, bool sign) { - SASSERT(k.is_unsigned()); - sat::literal_vector lits; - convert_pb_args(t->get_num_args(), lits); - for (sat::literal& l : lits) { - l.neg(); - } - unsigned k2 = lits.size() - k.get_unsigned(); - if (root && m_solver.num_user_scopes() == 0) { - m_result_stack.reset(); - if (sign) { - for (sat::literal& l : lits) l.neg(); - k2 = lits.size() + 1 - k2; - } - m_ext->add_at_least(sat::null_bool_var, lits, k2); - } - else { - sat::bool_var v = m_solver.add_var(true); - sat::literal lit(v, false); - m_ext->add_at_least(v, lits, k2); - m_cache.insert(t, lit); - if (sign) lit.neg(); - push_result(root, lit, t->get_num_args()); - } - } - - void convert_eq_k(app* t, rational const& k, bool root, bool sign) { - SASSERT(k.is_unsigned()); - sat::literal_vector lits; - convert_pb_args(t->get_num_args(), lits); - sat::bool_var v1 = (root && !sign) ? sat::null_bool_var : m_solver.add_var(true); - sat::bool_var v2 = (root && !sign) ? sat::null_bool_var : m_solver.add_var(true); - m_ext->add_at_least(v1, lits, k.get_unsigned()); - for (sat::literal& l : lits) { - l.neg(); - } - m_ext->add_at_least(v2, lits, lits.size() - k.get_unsigned()); - - if (root && !sign) { - m_result_stack.reset(); - } - else { - sat::literal l1(v1, false), l2(v2, false); - sat::bool_var v = m_solver.add_var(false); - sat::literal l(v, false); - mk_clause(~l, l1); - mk_clause(~l, l2); - mk_clause(~l1, ~l2, l); - m_cache.insert(t, l); - if (sign) l.neg(); - push_result(root, l, t->get_num_args()); - } - } - - void ensure_extension() { - if (!m_ext) { - sat::extension* ext = m_solver.get_extension(); - if (ext) { - m_ext = dynamic_cast(ext); - SASSERT(m_ext); - } - if (!m_ext) { - m_ext = alloc(sat::ba_solver); - m_solver.set_extension(m_ext); - } - } + if (lit == sat::null_literal) + return; + if (root) + mk_clause(lit); + else + m_result_stack.push_back(lit); } void convert(app * t, bool root, bool sign) { @@ -698,95 +508,45 @@ struct goal2sat::imp : public sat_internalizer { UNREACHABLE(); } } - else if (t->get_family_id() == pb.get_family_id()) { - ensure_extension(); - rational k; - switch (t->get_decl_kind()) { - case OP_AT_MOST_K: - k = pb.get_k(t); - convert_at_most_k(t, k, root, sign); - break; - case OP_AT_LEAST_K: - k = pb.get_k(t); - convert_at_least_k(t, k, root, sign); - break; - case OP_PB_LE: - if (pb.has_unit_coefficients(t)) { - k = pb.get_k(t); - convert_at_most_k(t, k, root, sign); - } - else { - convert_pb_le(t, root, sign); - } - break; - case OP_PB_GE: - if (pb.has_unit_coefficients(t)) { - k = pb.get_k(t); - convert_at_least_k(t, k, root, sign); - } - else { - convert_pb_ge(t, root, sign); - } - break; - case OP_PB_EQ: - if (pb.has_unit_coefficients(t)) { - k = pb.get_k(t); - convert_eq_k(t, k, root, sign); - } - else { - convert_pb_eq(t, root, sign); - } - break; - default: - UNREACHABLE(); - } + else if (pb.is_pb(t)) { + convert_ba(t, root, sign); } else { UNREACHABLE(); } } - - unsigned get_num_args(app* t) { - - if (m.is_iff(t) && m_xor_solver) { - unsigned n = 2; - while (m.is_iff(t->get_arg(1))) { - ++n; - t = to_app(t->get_arg(1)); - } - return n; - } - else { - return t->get_num_args(); - } + bool is_xor(app* t) const { + return m_xor_solver && m.is_iff(t) && m.is_iff(t->get_arg(1)); } - expr* get_arg(app* t, unsigned idx) { - if (m.is_iff(t) && m_xor_solver) { - while (idx >= 1) { - SASSERT(m.is_iff(t)); - t = to_app(t->get_arg(1)); - --idx; - } - if (m.is_iff(t)) { - return t->get_arg(idx); - } - else { - return t; + struct scoped_stack { + sat::literal_vector& r; + unsigned rsz; + svector& frames; + unsigned fsz; + bool is_root; + scoped_stack(imp& x, bool is_root) : + r(x.m_result_stack), rsz(r.size()), frames(x.m_frame_stack), fsz(frames.size()), is_root(is_root) + {} + ~scoped_stack() { + if (frames.size() > fsz) { + frames.shrink(fsz); + r.shrink(rsz); + return; } + SASSERT(frames.size() == fsz); + SASSERT(!is_root || rsz == r.size()); + SASSERT(is_root || rsz + 1 == r.size()); } - else { - return t->get_arg(idx); - } - } + }; void process(expr* n, bool is_root) { + scoped_stack _sc(*this, is_root); unsigned sz = m_frame_stack.size(); - if (visit(n, is_root, false)) { - SASSERT(m_result_stack.empty()); + if (visit(n, is_root, false)) return; - } + while (m_frame_stack.size() > sz) { loop: if (!m.inc()) @@ -809,9 +569,14 @@ struct goal2sat::imp : public sat_internalizer { visit(t->get_arg(0), root, !sign); continue; } - unsigned num = get_num_args(t); + if (is_xor(t)) { + convert_ba(t, root, sign); + m_frame_stack.pop_back(); + continue; + } + unsigned num = t->get_num_args(); while (fr.m_idx < num) { - expr * arg = get_arg(t, fr.m_idx); + expr * arg = t->get_arg(fr.m_idx); fr.m_idx++; if (!visit(arg, false, false)) goto loop; @@ -825,11 +590,11 @@ struct goal2sat::imp : public sat_internalizer { } sat::literal internalize(expr* n) override { - SASSERT(m_result_stack.empty()); + unsigned sz = m_result_stack.size(); process(n, false); - SASSERT(m_result_stack.size() == 1); + SASSERT(m_result_stack.size() == sz + 1); sat::literal result = m_result_stack.back(); - m_result_stack.reset(); + m_result_stack.pop_back(); return result; } diff --git a/src/sat/tactic/goal2sat.h b/src/sat/tactic/goal2sat.h index 47a12e93f..a4c9b90b0 100644 --- a/src/sat/tactic/goal2sat.h +++ b/src/sat/tactic/goal2sat.h @@ -32,15 +32,8 @@ Notes: #include "sat/sat_solver.h" #include "tactic/model_converter.h" #include "tactic/generic_model_converter.h" -#include "sat/tactic/atom2bool_var.h" - - -class sat_internalizer { -public: - virtual bool is_bool_op(expr* e) const = 0; - virtual sat::literal internalize(expr* e) = 0; - virtual sat::bool_var add_bool_var(expr* e) = 0; -}; +#include "sat/smt/atom2bool_var.h" +#include "sat/smt/sat_smt.h" class goal2sat { struct imp; @@ -58,7 +51,6 @@ public: static bool has_unsupported_bool(goal const & s); - /** \brief "Compile" the goal into the given sat solver. Store a mapping from atoms to boolean variables into m. diff --git a/src/shell/dimacs_frontend.cpp b/src/shell/dimacs_frontend.cpp index 25e8aa0bc..3b82fbd52 100644 --- a/src/shell/dimacs_frontend.cpp +++ b/src/shell/dimacs_frontend.cpp @@ -25,7 +25,7 @@ Revision History: #include "sat/dimacs.h" #include "sat/sat_params.hpp" #include "sat/sat_solver.h" -#include "sat/ba_solver.h" +#include "sat/ba/ba_solver.h" #include "sat/tactic/goal2sat.h" #include "ast/reg_decl_plugins.h" #include "tactic/tactic.h"