diff --git a/src/ast/euf/euf_egraph.cpp b/src/ast/euf/euf_egraph.cpp index b3a1d8516..6d22083dd 100644 --- a/src/ast/euf/euf_egraph.cpp +++ b/src/ast/euf/euf_egraph.cpp @@ -35,7 +35,7 @@ namespace euf { void undo(egraph & ctx) override { theory_var v = m_enode->get_th_var(m_th_id); - SASSERT(v != null_var); + SASSERT(v != null_theory_var); m_enode->del_th_var(m_th_id); enode * root = m_enode->get_root(); if (root != m_enode && root->get_th_var(m_th_id) == v) @@ -43,6 +43,26 @@ namespace euf { } }; + /** + \brief Trail for replace_th_var + */ + class replace_th_var_trail : public trail { + enode * m_enode; + unsigned m_th_id:8; + unsigned m_old_th_var:24; + public: + replace_th_var_trail(enode * n, theory_id th_id, theory_var old_var): + m_enode(n), + m_th_id(th_id), + m_old_th_var(old_var) { + } + + void undo(egraph & ctx) override { + SASSERT(m_enode->get_th_var(m_th_id) != null_theory_var); + m_enode->replace_th_var(m_old_th_var, m_th_id); + } + }; + void egraph::undo_eq(enode* r1, enode* n1, unsigned r2_num_parents) { enode* r2 = r1->get_root(); r2->dec_class_size(r1->class_size()); @@ -84,7 +104,7 @@ namespace euf { void egraph::reinsert_equality(enode* p) { SASSERT(is_equality(p)); if (p->get_arg(0)->get_root() == p->get_arg(1)->get_root()) { - m_new_eqs.push_back(p); + m_new_lits.push_back(enode_bool_pair(p, true)); ++m_stats.m_num_eqs; } } @@ -93,25 +113,17 @@ namespace euf { return m.is_eq(p->get_owner()); } - void egraph::dedup_equalities() { - unsigned j = 0; - for (enode* p : m_new_eqs) { - if (!p->is_marked1()) - m_new_eqs[j++] = p; - p->mark1(); - } - for (enode* p : m_new_eqs) - p->unmark1(); - m_new_eqs.shrink(j); - } - void egraph::force_push() { for (; m_num_scopes > 0; --m_num_scopes) { scope s; s.m_inconsistent = m_inconsistent; s.m_num_eqs = m_eqs.size(); s.m_num_nodes = m_nodes.size(); - s.m_trail_sz = m_trail.size(); + s.m_trail_sz = m_trail.size(); + s.m_new_lits_sz = m_new_lits.size(); + s.m_new_th_eqs_sz = m_new_th_eqs.size(); + s.m_new_lits_qhead = m_new_lits_qhead; + s.m_new_th_eqs_qhead = m_new_th_eqs_qhead; m_scopes.push_back(s); m_region.push_scope(); } @@ -160,10 +172,27 @@ namespace euf { void egraph::add_th_var(enode* n, theory_var v, theory_id id) { force_push(); - SASSERT(null_var == n->get_th_var(id)); - SASSERT(n->class_size() == 1); - n->add_th_var(v, id, m_region); - m_trail.push_back(new (m_region) add_th_var_trail(n, id)); + theory_var w = n->get_th_var(id); + enode* r = n->get_root(); + + if (w == null_theory_var) { + n->add_th_var(v, id, m_region); + m_trail.push_back(new (m_region) add_th_var_trail(n, id)); + if (r != n) { + theory_var u = r->get_th_var(id); + if (u == null_theory_var) + r->add_th_var(v, id, m_region); + else + m_new_th_eqs.push_back(th_eq(id, v, u, n, r)); + } + } + else { + theory_var u = r->get_th_var(id); + SASSERT(u != v && u != null_theory_var); + n->replace_th_var(v, id); + m_trail.push_back(new (m_region) replace_th_var_trail(n, id, u)); + m_new_th_eqs.push_back(th_eq(id, v, u, n, r)); + } } void egraph::pop(unsigned num_scopes) { @@ -187,9 +216,13 @@ namespace euf { } undo_trail_stack(*this, m_trail, s.m_trail_sz); m_inconsistent = s.m_inconsistent; + m_new_lits_qhead = s.m_new_lits_qhead; + m_new_th_eqs_qhead = s.m_new_th_eqs_qhead; m_eqs.shrink(s.m_num_eqs); m_nodes.shrink(s.m_num_nodes); m_exprs.shrink(s.m_num_nodes); + m_new_lits.shrink(s.m_new_lits_sz); + m_new_th_eqs.shrink(s.m_new_th_eqs_sz); m_scopes.shrink(old_lim); m_region.pop_scope(num_scopes); } @@ -212,7 +245,7 @@ namespace euf { std::swap(n1, n2); } if ((m.is_true(r2->get_owner()) || m.is_false(r2->get_owner())) && j.is_congruence()) { - m_new_lits.push_back(n1); + m_new_lits.push_back(enode_pair(n1, false)); ++m_stats.m_num_lits; } for (enode* p : enode_parents(n1)) @@ -235,7 +268,7 @@ namespace euf { for (auto iv : enode_th_vars(n)) { theory_id id = iv.get_id(); theory_var v = root->get_th_var(id); - if (v == null_var) { + if (v == null_theory_var) { root->add_th_var(iv.get_var(), id, m_region); m_trail.push_back(new (m_region) add_th_var_trail(root, id)); } @@ -246,10 +279,7 @@ namespace euf { } } - void egraph::propagate() { - m_new_eqs.reset(); - m_new_lits.reset(); - m_new_th_eqs.reset(); + bool egraph::propagate() { SASSERT(m_num_scopes == 0 || m_worklist.empty()); unsigned head = 0, tail = m_worklist.size(); while (head < tail && m.limit().inc() && !inconsistent()) { @@ -267,7 +297,10 @@ namespace euf { tail = m_worklist.size(); } m_worklist.reset(); - dedup_equalities(); + return + (m_new_lits_qhead < m_new_lits.size()) || + (m_new_th_eqs_qhead < m_new_th_eqs.size()) || + inconsistent(); } void egraph::set_conflict(enode* n1, enode* n2, justification j) { diff --git a/src/ast/euf/euf_egraph.h b/src/ast/euf/euf_egraph.h index 19bee4a06..0fd5c4189 100644 --- a/src/ast/euf/euf_egraph.h +++ b/src/ast/euf/euf_egraph.h @@ -63,6 +63,10 @@ namespace euf { unsigned m_num_eqs; unsigned m_num_nodes; unsigned m_trail_sz; + unsigned m_new_lits_sz; + unsigned m_new_th_eqs_sz; + unsigned m_new_lits_qhead; + unsigned m_new_th_eqs_qhead; }; struct stats { unsigned m_num_merge; @@ -87,8 +91,9 @@ namespace euf { enode *m_n1 { nullptr }; enode *m_n2 { nullptr }; justification m_justification; - enode_vector m_new_eqs; - enode_vector m_new_lits; + unsigned m_new_lits_qhead { 0 }; + unsigned m_new_th_eqs_qhead { 0 }; + svector m_new_lits; svector m_new_th_eqs; enode_vector m_todo; stats m_stats; @@ -107,7 +112,6 @@ namespace euf { void merge_th_eq(enode* n, enode* root); void merge_justification(enode* n1, enode* n2, justification j); void unmerge_justification(enode* n1); - void dedup_equalities(); void reinsert_equality(enode* p); void update_children(enode* n); void push_lca(enode* a, enode* b); @@ -151,11 +155,22 @@ namespace euf { equated nodes are merged. Use then new_eqs() to extract the vector of new equalities. */ - void propagate(); + bool propagate(); bool inconsistent() const { return m_inconsistent; } - enode_vector const& new_eqs() const { return m_new_eqs; } - enode_vector const& new_lits() const { return m_new_lits; } - svector const& new_th_eqs() const { return m_new_th_eqs; } + + /** + \brief Maintain and update cursor into propagated consequences. + The result of get_literal() is a pair (n, is_eq) + where \c n is an enode and \c is_eq indicates whether the enode + is an equality consequence. + */ + bool has_literal() const { return m_new_lits_qhead < m_new_lits.size(); } + bool has_th_eq() const { return m_new_th_eqs_qhead < m_new_th_eqs.size(); } + enode_bool_pair get_literal() const { return m_new_lits[m_new_lits_qhead]; } + th_eq get_th_eq() const { return m_new_th_eqs[m_new_th_eqs_qhead]; } + void next_literal() { SASSERT(m_new_lits_qhead < m_new_lits.size()); m_new_lits_qhead++; } + void next_th_eq() { SASSERT(m_new_th_eqs_qhead < m_new_th_eqs.size()); m_new_th_eqs_qhead++; } + void add_th_var(enode* n, theory_var v, theory_id id); diff --git a/src/ast/euf/euf_enode.h b/src/ast/euf/euf_enode.h index ad330f0b5..064576119 100644 --- a/src/ast/euf/euf_enode.h +++ b/src/ast/euf/euf_enode.h @@ -32,8 +32,8 @@ namespace euf { typedef id_var_list<> th_var_list; typedef int theory_var; typedef int theory_id; - const theory_var null_var = -1; - const theory_id null_id = -1; + const theory_var null_theory_var = -1; + const theory_id null_theory_id = -1; class enode { expr* m_owner; @@ -83,6 +83,7 @@ namespace euf { void set_update_children() { m_update_children = true; } friend class add_th_var_trail; + friend class replace_th_var_trail; void add_th_var(theory_var v, theory_id id, region & r) { m_th_vars.add_var(v, id, r); } void replace_th_var(theory_var v, theory_id id) { m_th_vars.replace(v, id); } void del_th_var(theory_id id) { m_th_vars.del_var(id); } diff --git a/src/sat/sat_binspr.cpp b/src/sat/sat_binspr.cpp index 89826cd38..460b6a4c0 100644 --- a/src/sat/sat_binspr.cpp +++ b/src/sat/sat_binspr.cpp @@ -343,7 +343,7 @@ namespace sat { void binspr::block_binary(literal lit1, literal lit2, bool learned) { IF_VERBOSE(2, verbose_stream() << "SPR: " << learned << " " << ~lit1 << " " << ~lit2 << "\n"); TRACE("sat", tout << "SPR: " << learned << " " << ~lit1 << " " << ~lit2 << "\n";); - s->mk_clause(~lit1, ~lit2, learned); + s->mk_clause(~lit1, ~lit2, learned ? sat::status::redundant() : sat::status::asserted()); ++m_bin_clauses; } diff --git a/src/sat/sat_cut_simplifier.cpp b/src/sat/sat_cut_simplifier.cpp index 68255afb0..bc6d9871f 100644 --- a/src/sat/sat_cut_simplifier.cpp +++ b/src/sat/sat_cut_simplifier.cpp @@ -468,7 +468,7 @@ namespace sat { if (w.is_binary_clause() && v == w.get_literal()) return; certify_implies(u, v, c); - s.mk_clause(~u, v, true); + s.mk_clause(~u, v, sat::status::redundant()); // m_bins owns reference to ~u or v created by certify_implies m_bins.insert(p); ++m_stats.m_num_learned_implies; @@ -524,7 +524,7 @@ namespace sat { void cut_simplifier::track_binary(literal u, literal v) { if (s.m_config.m_drat) { - s.m_drat.add(u, v, true); + s.m_drat.add(u, v, sat::status::redundant()); } } diff --git a/src/sat/sat_drat.cpp b/src/sat/sat_drat.cpp index 869364b76..a3e5fb0ee 100644 --- a/src/sat/sat_drat.cpp +++ b/src/sat/sat_drat.cpp @@ -68,19 +68,23 @@ namespace sat { m_activity = s.get_config().m_drat_activity; } - std::ostream& operator<<(std::ostream& out, drat::status st) { - switch (st) { - case drat::status::learned: return out << "l"; - case drat::status::deleted: return out << "d"; - case drat::status::asserted: return out << "c a"; - case drat::status::ba: return out << "c ba"; - case drat::status::euf: return out << "c euf"; + std::ostream& operator<<(std::ostream& out, status st) { + if (st.is_redundant()) + out << "l"; + else if (st.is_deleted()) + out << "d"; + else if (st.is_asserted()) + out << "a"; + + switch (st.orig) { + case status::orig::ba: return out << " ba"; + case status::orig::euf: return out << " euf"; default: return out; } } void drat::dump(unsigned n, literal const* c, status st) { - if (st == status::asserted && !s.m_ext) + if (st.is_asserted() && !s.m_ext) return; if (m_activity && ((m_num_add % 1000) == 0)) dump_activity(); @@ -90,35 +94,29 @@ namespace sat { char* lastd = digits + sizeof(digits); unsigned len = 0; - switch (st) { - case status::asserted: - buffer[0] = 'c'; - buffer[1] = ' '; - buffer[2] = 'a'; - buffer[3] = ' '; - len = 4; + switch (st.st) { + case status::st::asserted: + buffer[len++] = 'a'; + buffer[len++] = ' '; break; - case status::deleted: - buffer[0] = 'd'; - buffer[1] = ' '; - len = 2; + case status::st::deleted: + buffer[len++] = 'd'; + buffer[len++] = ' '; break; - case status::euf: - buffer[0] = 'c'; - buffer[1] = ' '; - buffer[2] = 'e'; - buffer[3] = 'u'; - buffer[4] = 'f'; - buffer[5] = ' '; - len = 6; + default: break; - case status::ba: - buffer[0] = 'c'; - buffer[1] = ' '; - buffer[2] = 'b'; - buffer[3] = 'a'; - buffer[4] = ' '; - len = 5; + } + switch (st.orig) { + case status::orig::euf: + buffer[len++] = 'e'; + buffer[len++] = 'u'; + buffer[len++] = 'f'; + buffer[len++] = ' '; + break; + case status::orig::ba: + buffer[len++] = 'b'; + buffer[len++] = 'a'; + buffer[len++] = ' '; break; default: break; @@ -147,10 +145,13 @@ namespace sat { buffer[len++] = '0'; buffer[len++] = '\n'; m_out->write(buffer, len); + + m_out->flush(); + } void drat::dump_activity() { - (*m_out) << "c a "; + (*m_out) << "c activity "; for (unsigned v = 0; v < s.num_vars(); ++v) { (*m_out) << s.m_activity[v] << " "; } @@ -159,12 +160,10 @@ namespace sat { void drat::bdump(unsigned n, literal const* c, status st) { unsigned char ch = 0; - switch (st) { - case status::asserted: return; - case status::ba: return; - case status::euf: return; - case status::learned: ch = 'a'; break; - case status::deleted: ch = 'd'; break; + switch (st.st) { + case status::st::asserted: return; + case status::st::redundant: ch = 'a'; break; + case status::st::deleted: ch = 'd'; break; default: UNREACHABLE(); break; } char buffer[10000]; @@ -217,10 +216,10 @@ namespace sat { declare(l); IF_VERBOSE(20, trace(verbose_stream(), 1, &l, st);); - if (st == status::learned) { + if (st.is_redundant()) { verify(1, &l); } - if (st == status::deleted) { + if (st.is_deleted()) { return; } if (m_check_unsat) { @@ -237,15 +236,15 @@ namespace sat { literal lits[2] = { l1, l2 }; IF_VERBOSE(20, trace(verbose_stream(), 2, lits, st);); - if (st == status::deleted) { + if (st.is_deleted()) { // noop // don't record binary as deleted. } else { - if (st == status::learned) { + if (st.is_redundant()) { verify(2, lits); } - clause* c = m_alloc.mk_clause(2, lits, st == status::learned); + clause* c = m_alloc.mk_clause(2, lits, st.is_redundant()); m_proof.push_back(c); m_status.push_back(st); if (!m_check_unsat) return; @@ -268,12 +267,12 @@ namespace sat { void drat::bool_def(bool_var v, unsigned n) { if (m_out) - (*m_out) << "c b " << v << " := " << n << " 0\n"; + (*m_out) << "b " << v << " " << n << " 0\n"; } void drat::def_begin(unsigned n, symbol const& name) { if (m_out) - (*m_out) << "c n " << n << " := " << name; + (*m_out) << "n " << n << " " << name; } void drat::def_add_arg(unsigned arg) { @@ -319,13 +318,13 @@ namespace sat { unsigned n = c.size(); IF_VERBOSE(20, trace(verbose_stream(), n, c.begin(), st);); - if (st == status::learned) { + if (st.is_redundant()) { verify(c); } m_status.push_back(st); m_proof.push_back(&c); - if (st == status::deleted) { + if (st.is_deleted()) { if (n > 0) del_watch(c, c[0]); if (n > 1) del_watch(c, c[1]); return; @@ -471,7 +470,7 @@ namespace sat { void drat::validate_propagation() const { for (unsigned i = 0; i < m_proof.size(); ++i) { status st = m_status[i]; - if (m_proof[i] && m_proof[i]->size() > 1 && st != status::deleted) { + if (m_proof[i] && m_proof[i]->size() > 1 && !st.is_deleted()) { clause& c = *m_proof[i]; unsigned num_undef = 0, num_true = 0; for (unsigned j = 0; j < c.size(); ++j) { @@ -494,14 +493,15 @@ namespace sat { SASSERT(lits.size() == n); for (unsigned i = 0; i < m_proof.size(); ++i) { status st = m_status[i]; - if (m_proof[i] && m_proof[i]->size() > 1 && st == status::asserted) { + if (m_proof[i] && m_proof[i]->size() > 1 && st.is_asserted()) { clause& c = *m_proof[i]; unsigned j = 0; for (; j < c.size() && c[j] != ~l; ++j) {} - if (j != c.size()) { + if (st.orig == status::orig::sat && j != c.size()) { lits.append(j, c.begin()); lits.append(c.size() - j - 1, c.begin() + j + 1); - if (!is_drup(lits.size(), lits.c_ptr())) return false; + if (!is_drup(lits.size(), lits.c_ptr())) + return false; lits.resize(n); } } @@ -562,7 +562,7 @@ namespace sat { clause& c = *m_proof[i]; status st = m_status[i]; if (match(n, lits, c)) { - if (st == status::deleted) { + if (st.is_deleted()) { num_del++; } else { @@ -601,7 +601,7 @@ namespace sat { } for (unsigned i = 0; i < m_proof.size(); ++i) { clause* c = m_proof[i]; - if (m_status[i] != status::deleted && c) { + if (!m_status[i].is_deleted() && c) { unsigned num_true = 0; unsigned num_undef = 0; for (unsigned j = 0; j < c->size(); ++j) { @@ -716,14 +716,16 @@ namespace sat { clauses.set_end(it2); } - drat::status drat::get_status(bool learned) const { - return learned || s.m_searching ? status::learned : status::asserted; + status drat::get_status(bool learned) const { + if (learned || s.m_searching) + return status::redundant(); + return status::asserted(); } void drat::add() { ++m_num_add; if (m_out) (*m_out) << "0\n"; - if (m_bout) bdump(0, nullptr, status::learned); + if (m_bout) bdump(0, nullptr, status::redundant()); if (m_check_unsat) { SASSERT(m_inconsistent); } @@ -735,53 +737,60 @@ namespace sat { if (m_bout) bdump(1, &l, st); if (m_check) append(l, st); } - void drat::add(literal l1, literal l2, bool learned) { - ++m_num_add; + void drat::add(literal l1, literal l2, status st) { + if (st.is_deleted()) + ++m_num_del; + else + ++m_num_add; literal ls[2] = {l1, l2}; - status st = get_status(learned); if (m_out) dump(2, ls, st); if (m_bout) bdump(2, ls, st); if (m_check) append(l1, l2, st); } - void drat::add(clause& c, bool learned) { - ++m_num_add; - status st = get_status(learned); + void drat::add(clause& c, status st) { + if (st.is_deleted()) + ++m_num_del; + else + ++m_num_add; if (m_out) dump(c.size(), c.begin(), st); if (m_bout) bdump(c.size(), c.begin(), st); if (m_check) { - clause* cl = m_alloc.mk_clause(c.size(), c.begin(), learned); - append(*cl, get_status(learned)); + clause* cl = m_alloc.mk_clause(c.size(), c.begin(), st.is_redundant()); + append(*cl, st); } } - void drat::add(literal_vector const& lits, status th) { - ++m_num_add; + void drat::add(literal_vector const& lits, status st) { + if (st.is_deleted()) + ++m_num_del; + else + ++m_num_add; if (m_check) { switch (lits.size()) { case 0: add(); break; - case 1: append(lits[0], th); break; + case 1: append(lits[0], st); break; default: { - clause* c = m_alloc.mk_clause(lits.size(), lits.c_ptr(), true); - append(*c, th); + clause* c = m_alloc.mk_clause(lits.size(), lits.c_ptr(), st.is_redundant()); + append(*c, st); break; } } } if (m_out) - dump(lits.size(), lits.c_ptr(), th); + dump(lits.size(), lits.c_ptr(), st); } void drat::add(literal_vector const& c) { ++m_num_add; - if (m_out) dump(c.size(), c.begin(), status::learned); - if (m_bout) bdump(c.size(), c.begin(), status::learned); + if (m_out) dump(c.size(), c.begin(), status::redundant()); + if (m_bout) bdump(c.size(), c.begin(), status::redundant()); if (m_check) { for (literal lit : c) declare(lit); switch (c.size()) { case 0: add(); break; - case 1: append(c[0], status::learned); break; + case 1: append(c[0], status::redundant()); break; default: { verify(c.size(), c.begin()); clause* cl = m_alloc.mk_clause(c.size(), c.c_ptr(), true); - append(*cl, status::ba); + append(*cl, status::redundant()); break; } } @@ -790,17 +799,17 @@ namespace sat { void drat::del(literal l) { ++m_num_del; - if (m_out) dump(1, &l, status::deleted); - if (m_bout) bdump(1, &l, status::deleted); - if (m_check_unsat) append(l, status::deleted); + if (m_out) dump(1, &l, status::deleted()); + if (m_bout) bdump(1, &l, status::deleted()); + if (m_check_unsat) append(l, status::deleted()); } void drat::del(literal l1, literal l2) { ++m_num_del; literal ls[2] = {l1, l2}; - if (m_out) dump(2, ls, status::deleted); - if (m_bout) bdump(2, ls, status::deleted); - if (m_check) append(l1, l2, status::deleted); + if (m_out) dump(2, ls, status::deleted()); + if (m_bout) bdump(2, ls, status::deleted()); + if (m_check) append(l1, l2, status::deleted()); } void drat::del(clause& c) { @@ -816,21 +825,21 @@ namespace sat { } #endif ++m_num_del; - if (m_out) dump(c.size(), c.begin(), status::deleted); - if (m_bout) bdump(c.size(), c.begin(), status::deleted); + if (m_out) dump(c.size(), c.begin(), status::deleted()); + if (m_bout) bdump(c.size(), c.begin(), status::deleted()); if (m_check) { clause* c1 = m_alloc.mk_clause(c.size(), c.begin(), c.is_learned()); - append(*c1, status::deleted); + append(*c1, status::deleted()); } } void drat::del(literal_vector const& c) { ++m_num_del; - if (m_out) dump(c.size(), c.begin(), status::deleted); - if (m_bout) bdump(c.size(), c.begin(), status::deleted); + if (m_out) dump(c.size(), c.begin(), status::deleted()); + if (m_bout) bdump(c.size(), c.begin(), status::deleted()); if (m_check) { clause* c1 = m_alloc.mk_clause(c.size(), c.begin(), true); - append(*c1, status::deleted); + append(*c1, status::deleted()); } } diff --git a/src/sat/sat_drat.h b/src/sat/sat_drat.h index e178ca15e..5ca47d010 100644 --- a/src/sat/sat_drat.h +++ b/src/sat/sat_drat.h @@ -19,17 +19,18 @@ Notes: For SMT extensions are as follows: - Input assertion (trusted modulo internalizer): - c a * 0 + Assertion (trusted modulo internalizer): + a [] * 0 + The optional theory id indicates the assertion is irredundant Bridge from ast-node to boolean variable: - c b := 0 + b 0 Definition of an ast node: - c n := * 0 + n * 0 Theory lemma - c * 0 + * 0 Available theories are: - euf The theory lemma should be a consequence of congruence closure. @@ -41,10 +42,10 @@ Notes: --*/ #pragma once +#include "sat_types.h" + namespace sat { class drat { - public: - enum status { asserted, learned, deleted, ba, euf }; private: struct watched_clause { clause* m_clause; @@ -100,8 +101,8 @@ namespace sat { void updt_config(); void add(); void add(literal l, bool learned); - void add(literal l1, literal l2, bool learned); - void add(clause& c, bool learned); + void add(literal l1, literal l2, status st); + void add(clause& c, status st); void add(literal_vector const& c, status st); void add(literal_vector const& c); // add learned clause diff --git a/src/sat/sat_elim_eqs.cpp b/src/sat/sat_elim_eqs.cpp index d666f19ad..baca88229 100644 --- a/src/sat/sat_elim_eqs.cpp +++ b/src/sat/sat_elim_eqs.cpp @@ -207,7 +207,7 @@ namespace sat { c.update_approx(); } if (m_solver.m_config.m_drat) { - m_solver.m_drat.add(c, true); + m_solver.m_drat.add(c, status::redundant()); drat_delete_clause(); } @@ -237,8 +237,8 @@ namespace sat { // cannot really eliminate v, since we have to notify extension of future assignments if (m_solver.m_config.m_drat && m_solver.m_config.m_drat_file.is_null()) { std::cout << "DRAT\n"; - m_solver.m_drat.add(~l, r, true); - m_solver.m_drat.add(l, ~r, true); + m_solver.m_drat.add(~l, r, sat::status::redundant()); + m_solver.m_drat.add(l, ~r, sat::status::redundant()); } m_solver.mk_bin_clause(~l, r, false); m_solver.mk_bin_clause(l, ~r, false); diff --git a/src/sat/sat_lookahead.cpp b/src/sat/sat_lookahead.cpp index a550102a9..4884f0797 100644 --- a/src/sat/sat_lookahead.cpp +++ b/src/sat/sat_lookahead.cpp @@ -2499,7 +2499,7 @@ namespace sat { uf.merge((~u).index(), (~v).index()); VERIFY(!m_s.was_eliminated(u.var())); VERIFY(!m_s.was_eliminated(v.var())); - m_s.mk_clause(~u, v, true); + m_s.mk_clause(~u, v, sat::status::redundant()); } else { candidates[k] = candidates[j]; diff --git a/src/sat/sat_parallel.cpp b/src/sat/sat_parallel.cpp index 802c9ce4a..03fff2b58 100644 --- a/src/sat/sat_parallel.cpp +++ b/src/sat/sat_parallel.cpp @@ -194,7 +194,7 @@ namespace sat { IF_VERBOSE(3, verbose_stream() << s.m_par_id << ": retrieve " << m_lits << "\n";); SASSERT(n >= 2); if (usable_clause) { - s.mk_clause_core(m_lits.size(), m_lits.c_ptr(), true); + s.mk_clause_core(m_lits.size(), m_lits.c_ptr(), sat::status::redundant()); } } } diff --git a/src/sat/sat_probing.cpp b/src/sat/sat_probing.cpp index ed790a856..c56fee1e0 100644 --- a/src/sat/sat_probing.cpp +++ b/src/sat/sat_probing.cpp @@ -55,7 +55,7 @@ namespace sat { for (unsigned i = old_tr_sz; i < tr_sz; i++) { entry.m_lits.push_back(s.m_trail[i]); if (s.m_config.m_drat) { - s.m_drat.add(~l, s.m_trail[i], true); + s.m_drat.add(~l, s.m_trail[i], status::redundant()); } } } @@ -71,8 +71,8 @@ namespace sat { for (literal lit : *implied_lits) { if (m_assigned.contains(lit)) { if (s.m_config.m_drat) { - s.m_drat.add(l, lit, true); - s.m_drat.add(~l, lit, true); + s.m_drat.add(l, lit, status::redundant()); + s.m_drat.add(~l, lit, status::redundant()); } s.assign_scoped(lit); m_num_assigned++; @@ -106,8 +106,8 @@ namespace sat { for (literal lit : m_to_assert) { if (s.m_config.m_drat) { - s.m_drat.add(l, lit, true); - s.m_drat.add(~l, lit, true); + s.m_drat.add(l, lit, status::redundant()); + s.m_drat.add(~l, lit, status::redundant()); } s.assign_scoped(lit); m_num_assigned++; diff --git a/src/sat/sat_simplifier.cpp b/src/sat/sat_simplifier.cpp index 6bf3a92f9..09dff101c 100644 --- a/src/sat/sat_simplifier.cpp +++ b/src/sat/sat_simplifier.cpp @@ -681,7 +681,7 @@ namespace sat { if (s.m_config.m_drat && c.contains(l)) { unsigned sz = c.size(); c.elim(l); - s.m_drat.add(c, true); + s.m_drat.add(c, status::redundant()); c.restore(sz); s.m_drat.del(c); c.shrink(sz-1); @@ -2005,7 +2005,7 @@ namespace sat { s.m_stats.m_mk_clause++; clause * new_c = s.alloc_clause(m_new_cls.size(), m_new_cls.c_ptr(), false); - if (s.m_config.m_drat) s.m_drat.add(*new_c, true); + if (s.m_config.m_drat) s.m_drat.add(*new_c, status::redundant()); s.m_clauses.push_back(new_c); m_use_list.insert(*new_c); diff --git a/src/sat/sat_solver.cpp b/src/sat/sat_solver.cpp index 6f54585ce..101a44968 100644 --- a/src/sat/sat_solver.cpp +++ b/src/sat/sat_solver.cpp @@ -213,7 +213,7 @@ namespace sat { if (c->glue() <= 2 || (c->size() <= 40 && c->glue() <= 8) || copy_learned) { buffer.reset(); for (literal l : *c) buffer.push_back(l); - clause* c1 = mk_clause_core(buffer.size(), buffer.c_ptr(), true); + clause* c1 = mk_clause_core(buffer.size(), buffer.c_ptr(), sat::status::redundant()); if (c1) { ++num_learned; c1->set_glue(c->glue()); @@ -305,7 +305,7 @@ namespace sat { } - clause* solver::mk_clause(unsigned num_lits, literal * lits, bool learned) { + clause* solver::mk_clause(unsigned num_lits, literal * lits, sat::status st) { m_model_is_current = false; DEBUG_CODE({ for (unsigned i = 0; i < num_lits; i++) { @@ -315,24 +315,24 @@ namespace sat { }); if (m_user_scope_literals.empty()) { - return mk_clause_core(num_lits, lits, learned); + return mk_clause_core(num_lits, lits, st); } else { m_aux_literals.reset(); m_aux_literals.append(num_lits, lits); m_aux_literals.append(m_user_scope_literals); - return mk_clause_core(m_aux_literals.size(), m_aux_literals.c_ptr(), learned); + return mk_clause_core(m_aux_literals.size(), m_aux_literals.c_ptr(), st); } } - clause* solver::mk_clause(literal l1, literal l2, bool learned) { + clause* solver::mk_clause(literal l1, literal l2, sat::status st) { literal ls[2] = { l1, l2 }; - return mk_clause(2, ls, learned); + return mk_clause(2, ls, st); } - clause* solver::mk_clause(literal l1, literal l2, literal l3, bool learned) { + clause* solver::mk_clause(literal l1, literal l2, literal l3, sat::status st) { literal ls[3] = { l1, l2, l3 }; - return mk_clause(3, ls, learned); + return mk_clause(3, ls, st); } void solver::del_clause(clause& c) { @@ -350,9 +350,10 @@ namespace sat { m_stats.m_del_clause++; } - clause * solver::mk_clause_core(unsigned num_lits, literal * lits, bool learned) { - TRACE("sat", tout << "mk_clause: " << mk_lits_pp(num_lits, lits) << (learned?" learned":" aux") << "\n";); - if (!learned) { + clause * solver::mk_clause_core(unsigned num_lits, literal * lits, sat::status st) { + bool redundant = st.is_redundant(); + TRACE("sat", tout << "mk_clause: " << mk_lits_pp(num_lits, lits) << (redundant?" learned":" aux") << "\n";); + if (!redundant) { unsigned old_sz = num_lits; bool keep = simplify_clause(num_lits, lits); TRACE("sat_mk_clause", tout << "mk_clause (after simp), keep: " << keep << "\n" << mk_lits_pp(num_lits, lits) << "\n";); @@ -360,17 +361,17 @@ namespace sat { return nullptr; // clause is equivalent to true. } // if an input clause is simplified, then log the simplified version as learned - if (!learned && old_sz > num_lits && m_config.m_drat) { + if (old_sz > num_lits && m_config.m_drat) { m_lemma.reset(); m_lemma.append(num_lits, lits); - m_drat.add(m_lemma); + m_drat.add(m_lemma, st); } ++m_stats.m_non_learned_generation; if (!m_searching) { m_mc.add_clause(num_lits, lits); } - } - + } + switch (num_lits) { case 0: set_conflict(); @@ -379,55 +380,56 @@ namespace sat { assign_unit(lits[0]); return nullptr; case 2: - mk_bin_clause(lits[0], lits[1], learned); - if (learned && m_par) m_par->share_clause(*this, lits[0], lits[1]); + mk_bin_clause(lits[0], lits[1], st); + if (redundant && m_par) m_par->share_clause(*this, lits[0], lits[1]); return nullptr; case 3: if (ENABLE_TERNARY) { - return mk_ter_clause(lits, learned); + return mk_ter_clause(lits, st); } default: - return mk_nary_clause(num_lits, lits, learned); + return mk_nary_clause(num_lits, lits, st); } } - void solver::mk_bin_clause(literal l1, literal l2, bool learned) { + void solver::mk_bin_clause(literal l1, literal l2, sat::status st) { + bool redundant = st.is_redundant(); m_touched[l1.var()] = m_touch_index; m_touched[l2.var()] = m_touch_index; - if (learned && find_binary_watch(get_wlist(~l1), ~l2) && value(l1) == l_undef) { + if (redundant && find_binary_watch(get_wlist(~l1), ~l2) && value(l1) == l_undef) { assign_unit(l1); return; } - if (learned && find_binary_watch(get_wlist(~l2), ~l1) && value(l2) == l_undef) { + if (redundant && find_binary_watch(get_wlist(~l2), ~l1) && value(l2) == l_undef) { assign_unit(l2); return; } - watched* w0 = learned ? find_binary_watch(get_wlist(~l1), l2) : nullptr; + watched* w0 = redundant ? find_binary_watch(get_wlist(~l1), l2) : nullptr; if (w0) { TRACE("sat", tout << "found binary " << l1 << " " << l2 << "\n";); - if (w0->is_learned() && !learned) { + if (w0->is_learned() && !redundant) { w0->set_learned(false); w0 = find_binary_watch(get_wlist(~l2), l1); VERIFY(w0); w0->set_learned(false); } - if (propagate_bin_clause(l1, l2) && !learned && !at_base_lvl() && !at_search_lvl()) { + if (propagate_bin_clause(l1, l2) && !redundant && !at_base_lvl() && !at_search_lvl()) { m_clauses_to_reinit.push_back(clause_wrapper(l1, l2)); } return; } if (m_config.m_drat) - m_drat.add(l1, l2, learned); + m_drat.add(l1, l2, st); if (propagate_bin_clause(l1, l2)) { if (at_base_lvl()) return; - if (!learned && !at_search_lvl()) + if (!redundant && !at_search_lvl()) m_clauses_to_reinit.push_back(clause_wrapper(l1, l2)); } m_stats.m_mk_bin_clause++; - get_wlist(~l1).push_back(watched(l2, learned)); - get_wlist(~l2).push_back(watched(l1, learned)); + get_wlist(~l1).push_back(watched(l2, redundant)); + get_wlist(~l2).push_back(watched(l1, redundant)); } bool solver::propagate_bin_clause(literal l1, literal l2) { @@ -451,13 +453,13 @@ namespace sat { } - clause * solver::mk_ter_clause(literal * lits, bool learned) { + clause * solver::mk_ter_clause(literal * lits, sat::status st) { VERIFY(ENABLE_TERNARY); m_stats.m_mk_ter_clause++; - clause * r = alloc_clause(3, lits, learned); + clause * r = alloc_clause(3, lits, st.is_redundant()); bool reinit = attach_ter_clause(*r); - if (reinit && !learned) push_reinit_stack(*r); - if (learned) + if (reinit && !st.is_redundant()) push_reinit_stack(*r); + if (st.is_redundant()) m_learned.push_back(r); else m_clauses.push_back(r); @@ -470,7 +472,7 @@ namespace sat { bool solver::attach_ter_clause(clause & c) { VERIFY(ENABLE_TERNARY); bool reinit = false; - if (m_config.m_drat) m_drat.add(c, c.is_learned()); + if (m_config.m_drat) m_drat.add(c, c.is_learned() ? status::redundant() : status::asserted()); TRACE("sat_verbose", tout << c << "\n";); SASSERT(!c.was_removed()); m_watches[(~c[0]).index()].push_back(watched(c[1], c[2])); @@ -496,20 +498,20 @@ namespace sat { return reinit; } - clause * solver::mk_nary_clause(unsigned num_lits, literal * lits, bool learned) { + clause * solver::mk_nary_clause(unsigned num_lits, literal * lits, sat::status st) { m_stats.m_mk_clause++; - clause * r = alloc_clause(num_lits, lits, learned); - SASSERT(!learned || r->is_learned()); + clause * r = alloc_clause(num_lits, lits, st.is_redundant()); + SASSERT(!st.is_learned() || r->is_learned()); bool reinit = attach_nary_clause(*r); - if (reinit && !learned) push_reinit_stack(*r); - if (learned) { + if (reinit && !st.is_redundant()) push_reinit_stack(*r); + if (st.is_redundant()) { m_learned.push_back(r); } else { m_clauses.push_back(r); } if (m_config.m_drat) { - m_drat.add(*r, learned); + m_drat.add(*r, st); } for (literal l : *r) { m_touched[l.var()] = m_touch_index; @@ -571,15 +573,15 @@ namespace sat { reinit = attach_nary_clause(c); } - void solver::set_learned(clause& c, bool learned) { - if (c.is_learned() != learned) - c.set_learned(learned); + void solver::set_learned(clause& c, bool redundant) { + if (c.is_learned() != redundant) + c.set_learned(redundant); } - void solver::set_learned1(literal l1, literal l2, bool learned) { + void solver::set_learned1(literal l1, literal l2, bool redundant) { for (watched& w : get_wlist(~l1)) { if (w.is_binary_clause() && l2 == w.get_literal() && !w.is_learned()) { - w.set_learned(learned); + w.set_learned(redundant); break; } } @@ -594,7 +596,7 @@ namespace sat { m_touched[l.var()] = m_touch_index; } if (m_config.m_drat) { - m_drat.add(c, true); + m_drat.add(c, status::redundant()); c.restore(old_sz); m_drat.del(c); c.shrink(new_sz); @@ -687,9 +689,9 @@ namespace sat { } - void solver::set_learned(literal l1, literal l2, bool learned) { - set_learned1(l1, l2, learned); - set_learned1(l2, l1, learned); + void solver::set_learned(literal l1, literal l2, bool redundant) { + set_learned1(l1, l2, redundant); + set_learned1(l2, l1, redundant); } /** @@ -809,9 +811,9 @@ namespace sat { return simplify_clause_core(num_lits, lits); } - void solver::detach_bin_clause(literal l1, literal l2, bool learned) { - get_wlist(~l1).erase(watched(l2, learned)); - get_wlist(~l2).erase(watched(l1, learned)); + void solver::detach_bin_clause(literal l1, literal l2, bool redundant) { + get_wlist(~l1).erase(watched(l2, redundant)); + get_wlist(~l2).erase(watched(l1, redundant)); if (m_config.m_drat) m_drat.del(l1, l2); } @@ -2727,7 +2729,7 @@ namespace sat { if (m_lemma.empty()) { pop_reinit(m_scope_lvl); - mk_clause_core(0, nullptr, true); + mk_clause_core(0, nullptr, sat::status::redundant()); return; } @@ -2777,7 +2779,7 @@ namespace sat { ++m_stats.m_backtracks; pop_reinit(m_scope_lvl - backtrack_lvl + 1); } - clause * lemma = mk_clause_core(m_lemma.size(), m_lemma.c_ptr(), true); + clause * lemma = mk_clause_core(m_lemma.size(), m_lemma.c_ptr(), sat::status::redundant()); if (lemma) { lemma->set_glue(glue); } @@ -3795,9 +3797,9 @@ namespace sat { } } - bool_var solver::max_var(bool learned, bool_var v) { + bool_var solver::max_var(bool redundant, bool_var v) { m_user_bin_clauses.reset(); - collect_bin_clauses(m_user_bin_clauses, learned, false); + collect_bin_clauses(m_user_bin_clauses, redundant, false); for (unsigned i = 0; i < m_user_bin_clauses.size(); ++i) { literal l1 = m_user_bin_clauses[i].first; literal l2 = m_user_bin_clauses[i].second; @@ -3977,8 +3979,8 @@ namespace sat { // Iterators // // ----------------------- - void solver::collect_bin_clauses(svector & r, bool learned, bool learned_only) const { - SASSERT(learned || !learned_only); + void solver::collect_bin_clauses(svector & r, bool redundant, bool learned_only) const { + SASSERT(redundant || !learned_only); unsigned sz = m_watches.size(); for (unsigned l_idx = 0; l_idx < sz; l_idx++) { literal l = to_literal(l_idx); @@ -3986,9 +3988,9 @@ namespace sat { for (watched const& w : m_watches[l_idx]) { if (!w.is_binary_clause()) continue; - if (!learned && w.is_learned()) + if (!redundant && w.is_learned()) continue; - else if (learned && learned_only && !w.is_learned()) + else if (redundant && learned_only && !w.is_learned()) continue; literal l2 = w.get_literal(); if (l.index() > l2.index()) @@ -4121,14 +4123,14 @@ namespace sat { return num_cls + m_clauses.size() + m_learned.size(); } - void solver::num_binary(unsigned& given, unsigned& learned) const { - given = learned = 0; + void solver::num_binary(unsigned& given, unsigned& redundant) const { + given = redundant = 0; unsigned l_idx = 0; for (auto const& wl : m_watches) { literal l = ~to_literal(l_idx++); for (auto const& w : wl) { if (w.is_binary_clause() && l.index() < w.get_literal().index()) { - if (w.is_learned()) ++learned; else ++given; + if (w.is_learned()) ++redundant; else ++given; } } } @@ -4288,10 +4290,10 @@ namespace sat { return false; } - void solver::simplify(bool learned) { + void solver::simplify(bool redundant) { if (!at_base_lvl() || inconsistent()) return; - m_simplifier(learned); + m_simplifier(redundant); m_simplifier.finalize(); if (m_ext) m_ext->clauses_modifed(); @@ -4921,10 +4923,10 @@ namespace sat { } void mk_stat::display(std::ostream & out) const { - unsigned given, learned; - m_solver.num_binary(given, learned); + unsigned given, redundant; + m_solver.num_binary(given, redundant); out << " " << std::setw(5) << m_solver.m_clauses.size() + given << "/" << given; - out << " " << std::setw(5) << (m_solver.m_learned.size() + learned - m_solver.m_num_frozen) << "/" << learned; + out << " " << std::setw(5) << (m_solver.m_learned.size() + redundant - m_solver.m_num_frozen) << "/" << redundant; out << " " << std::setw(3) << m_solver.init_trail_size(); out << " " << std::setw(7) << m_solver.m_stats.m_gc_clause << " "; out << " " << std::setw(7) << mem_stat(); diff --git a/src/sat/sat_solver.h b/src/sat/sat_solver.h index 07871d43c..2fd9cbbc1 100644 --- a/src/sat/sat_solver.h +++ b/src/sat/sat_solver.h @@ -249,15 +249,15 @@ namespace sat { // Variable & Clause creation // // ----------------------- - void add_clause(unsigned num_lits, literal * lits, bool learned) override { mk_clause(num_lits, lits, learned); } + void add_clause(unsigned num_lits, literal * lits, sat::status st) override { mk_clause(num_lits, lits, st); } bool_var add_var(bool ext) override { return mk_var(ext, true); } bool_var mk_var(bool ext = false, bool dvar = true); - clause* mk_clause(literal_vector const& lits, bool learned = false) { return mk_clause(lits.size(), lits.c_ptr(), learned); } - clause* mk_clause(unsigned num_lits, literal * lits, bool learned = false); - clause* mk_clause(literal l1, literal l2, bool learned = false); - clause* mk_clause(literal l1, literal l2, literal l3, bool learned = false); + clause* mk_clause(literal_vector const& lits, sat::status st = sat::status::asserted()) { return mk_clause(lits.size(), lits.c_ptr(), st); } + clause* mk_clause(unsigned num_lits, literal * lits, sat::status st = sat::status::asserted()); + clause* mk_clause(literal l1, literal l2, sat::status st = sat::status::asserted()); + clause* mk_clause(literal l1, literal l2, literal l3, sat::status st = sat::status::asserted()); random_gen& rand() { return m_rand; } @@ -271,15 +271,16 @@ namespace sat { bool should_defrag(); bool memory_pressure(); void del_clause(clause & c); - clause * mk_clause_core(unsigned num_lits, literal * lits, bool learned); + clause * mk_clause_core(unsigned num_lits, literal * lits, sat::status st); clause * mk_clause_core(literal_vector const& lits) { return mk_clause_core(lits.size(), lits.c_ptr()); } - clause * mk_clause_core(unsigned num_lits, literal * lits) { return mk_clause_core(num_lits, lits, false); } + clause * mk_clause_core(unsigned num_lits, literal * lits) { return mk_clause_core(num_lits, lits, sat::status::asserted()); } void mk_clause_core(literal l1, literal l2) { literal lits[2] = { l1, l2 }; mk_clause_core(2, lits); } - void mk_bin_clause(literal l1, literal l2, bool learned); + void mk_bin_clause(literal l1, literal l2, sat::status st); + void mk_bin_clause(literal l1, literal l2, bool learned) { mk_bin_clause(l1, l2, learned ? sat::status::redundant() : sat::status::asserted()); } bool propagate_bin_clause(literal l1, literal l2); - clause * mk_ter_clause(literal * lits, bool learned); + clause * mk_ter_clause(literal * lits, status st); bool attach_ter_clause(clause & c); - clause * mk_nary_clause(unsigned num_lits, literal * lits, bool learned); + clause * mk_nary_clause(unsigned num_lits, literal * lits, status st); bool attach_nary_clause(clause & c); void attach_clause(clause & c, bool & reinit); void attach_clause(clause & c) { bool reinit; attach_clause(c, reinit); } diff --git a/src/sat/sat_solver_core.h b/src/sat/sat_solver_core.h index 89164f857..372b1308e 100644 --- a/src/sat/sat_solver_core.h +++ b/src/sat/sat_solver_core.h @@ -55,14 +55,14 @@ namespace sat { virtual char const* get_reason_unknown() const { return "reason unavailable"; } // add clauses - virtual void add_clause(unsigned n, literal* lits, bool is_redundant) = 0; - void add_clause(literal l1, literal l2, bool is_redundant) { + virtual void add_clause(unsigned n, literal* lits, status st) = 0; + void add_clause(literal l1, literal l2, status st) { literal lits[2] = {l1, l2}; - add_clause(2, lits, is_redundant); + add_clause(2, lits, st); } - void add_clause(literal l1, literal l2, literal l3, bool is_redundant) { + void add_clause(literal l1, literal l2, literal l3, status st) { literal lits[3] = {l1, l2, l3}; - add_clause(3, lits, is_redundant); + add_clause(3, lits, st); } // create boolean variable, tagged as external (= true) or internal (can be eliminated). virtual bool_var add_var(bool ext) = 0; diff --git a/src/sat/sat_types.h b/src/sat/sat_types.h index 94ca8a85f..258875ea9 100644 --- a/src/sat/sat_types.h +++ b/src/sat/sat_types.h @@ -251,5 +251,31 @@ namespace sat { virtual double get_priority(bool_var v) const { return 0; } }; + + class status { + public: + enum class st { asserted, redundant, deleted } st; + enum class orig { sat, ba, euf } orig; + status(enum st s, enum orig o) : st(s), orig(o) {}; + status(status const& s) : st(s.st), orig(s.orig) {} + status(status&& s) noexcept { st = st::asserted; orig = orig::sat; std::swap(st, s.st); std::swap(orig, s.orig); } + static status redundant() { return status(status::st::redundant, status::orig::sat); } + static status asserted() { return status(status::st::asserted, status::orig::sat); } + static status deleted() { return status(status::st::deleted, status::orig::sat); } + + static status euf_learned() { return status(status::st::redundant, status::orig::euf); } + static status euf_asserted() { return status(status::st::asserted, status::orig::euf); } + + static status ba(bool redundant) { return redundant ? ba_redundant() : ba_asserted(); } + static status ba_redundant() { return status(status::st::redundant, status::orig::ba); } + static status ba_asserted() { return status(status::st::asserted, status::orig::ba); } + + bool is_redundant() const { return st::redundant == st; } + bool is_asserted() const { return st::asserted == st; } + bool is_deleted() const { return st::deleted == st; } + bool operator==(status const& s) const { return s.orig == orig && s.st == st; } + }; + + }; diff --git a/src/sat/smt/CMakeLists.txt b/src/sat/smt/CMakeLists.txt index e695de0b8..e9468b415 100644 --- a/src/sat/smt/CMakeLists.txt +++ b/src/sat/smt/CMakeLists.txt @@ -9,6 +9,7 @@ z3_add_component(sat_smt euf_model.cpp euf_proof.cpp euf_solver.cpp + sat_th.cpp COMPONENT_DEPENDENCIES sat ast diff --git a/src/sat/smt/ba_solver.cpp b/src/sat/smt/ba_solver.cpp index a92989aee..da6f9ae0e 100644 --- a/src/sat/smt/ba_solver.cpp +++ b/src/sat/smt/ba_solver.cpp @@ -19,6 +19,7 @@ Author: #include "util/mpz.h" #include "sat/sat_types.h" #include "sat/smt/ba_solver.h" +#include "sat/smt/euf_solver.h" #include "sat/sat_simplifier_params.hpp" #include "sat/sat_xor_finder.h" @@ -359,7 +360,7 @@ namespace sat { } if (p.k() == 1 && p.lit() == null_literal) { literal_vector lits(p.literals()); - s().mk_clause(lits.size(), lits.c_ptr(), p.learned()); + s().mk_clause(lits.size(), lits.c_ptr(), status::ba(p.learned())); IF_VERBOSE(100, display(verbose_stream() << "add clause: " << lits << "\n", p, true);); remove_constraint(p, "implies clause"); } @@ -411,7 +412,7 @@ namespace sat { if (k == 1 && p.lit() == null_literal) { literal_vector lits(sz, p.literals().c_ptr()); - s().mk_clause(sz, lits.c_ptr(), p.learned()); + s().mk_clause(sz, lits.c_ptr(), status::ba(p.learned())); remove_constraint(p, "is clause"); return; } @@ -794,7 +795,7 @@ namespace sat { else if (k == 1 && p.lit() == null_literal) { literal_vector lits(sz, p.literals().c_ptr()); - s().mk_clause(sz, lits.c_ptr(), p.learned()); + s().mk_clause(sz, lits.c_ptr(), status::ba(p.learned())); remove_constraint(p, "recompiled to clause"); return; } @@ -1597,7 +1598,7 @@ namespace sat { TRACE("ba", tout << m_lemma << "\n";); if (get_config().m_drat && m_solver) { - s().m_drat.add(m_lemma, sat::drat::status::ba); + s().m_drat.add(m_lemma, sat::status::ba_redundant()); } s().m_lemma.reset(); @@ -1719,8 +1720,12 @@ namespace sat { return p; } - ba_solver::ba_solver(ast_manager& m, sat_internalizer& si) - : m(m), si(si), m_pb(m), + ba_solver::ba_solver(euf::solver& ctx, euf::theory_id id) : + ba_solver(ctx.get_manager(), ctx.get_si(), id) {} + + ba_solver::ba_solver(ast_manager& m, sat::sat_internalizer& si, euf::theory_id id) + : euf::th_solver(m, id), + si(si), m_pb(m), m_solver(nullptr), m_lookahead(nullptr), m_constraint_id(0), m_ba(*this), m_sort(m_ba) { TRACE("ba", tout << this << "\n";); @@ -1745,7 +1750,7 @@ namespace sat { ba_solver::constraint* ba_solver::add_at_least(literal lit, literal_vector const& lits, unsigned k, bool learned) { if (k == 1 && lit == null_literal) { literal_vector _lits(lits); - s().mk_clause(_lits.size(), _lits.c_ptr(), learned); + s().mk_clause(_lits.size(), _lits.c_ptr(), status::ba(learned)); return nullptr; } if (!learned && clausify(lit, lits.size(), lits.c_ptr(), k)) { @@ -2135,7 +2140,7 @@ namespace sat { for (literal lit : r) lits.push_back(~lit); lits.push_back(l); - s().m_drat.add(lits, sat::drat::status::ba); + s().m_drat.add(lits, sat::status::ba_redundant()); } } @@ -2894,7 +2899,7 @@ namespace sat { if (k == 1 && c.lit() == null_literal) { literal_vector lits(sz, c.literals().c_ptr()); - s().mk_clause(sz, lits.c_ptr(), c.learned()); + s().mk_clause(sz, lits.c_ptr(), sat::status::ba(c.learned())); remove_constraint(c, "recompiled to clause"); return; } @@ -2902,27 +2907,27 @@ namespace sat { if (sz == 0) { if (c.lit() == null_literal) { if (k > 0) { - s().mk_clause(0, nullptr, true); + s().mk_clause(0, nullptr, status::ba_asserted()); } } else if (k > 0) { literal lit = ~c.lit(); - s().mk_clause(1, &lit, c.learned()); + s().mk_clause(1, &lit, status::ba(c.learned())); } else { literal lit = c.lit(); - s().mk_clause(1, &lit, c.learned()); + s().mk_clause(1, &lit, status::ba(c.learned())); } remove_constraint(c, "recompiled to clause"); return; } if (all_units && sz < k) { if (c.lit() == null_literal) { - s().mk_clause(0, nullptr, true); + s().mk_clause(0, nullptr, status::ba_redundant()); } else { literal lit = ~c.lit(); - s().mk_clause(1, &lit, c.learned()); + s().mk_clause(1, &lit, status::ba(c.learned())); } remove_constraint(c, "recompiled to clause"); return; @@ -3719,12 +3724,16 @@ namespace sat { } extension* ba_solver::copy(solver* s) { - return fresh(s, m, si); + return fresh(s, m, si, m_id); } - th_solver* ba_solver::fresh(solver* s, ast_manager& m, sat_internalizer& si) { - ba_solver* result = alloc(ba_solver, m, si); - result->set_solver(s); + euf::th_solver* ba_solver::fresh(solver* new_s, euf::solver& new_ctx) { + return fresh(new_s, new_ctx.get_manager(), new_ctx.get_si(), get_id()); + } + + euf::th_solver* ba_solver::fresh(solver* new_s, ast_manager& m, sat::sat_internalizer& si, euf::theory_id id) { + ba_solver* result = alloc(ba_solver, m, si, id); + result->set_solver(new_s); copy_constraints(result, m_constraints); return result; } diff --git a/src/sat/smt/ba_solver.h b/src/sat/smt/ba_solver.h index ecc06d3c2..8ff891cb5 100644 --- a/src/sat/smt/ba_solver.h +++ b/src/sat/smt/ba_solver.h @@ -35,7 +35,7 @@ namespace sat { class xor_finder; - class ba_solver : public th_solver { + class ba_solver : public euf::th_solver { friend class local_search; @@ -232,7 +232,6 @@ namespace sat { bool contains(literal l) const { for (auto wl : m_wlits) if (wl.second == l) return true; return false; } }; - ast_manager& m; sat_internalizer& si; pb_util m_pb; @@ -305,6 +304,9 @@ namespace sat { bool_vector m_root_vars; unsigned_vector m_weights; svector m_wlits; + + euf::th_solver* ba_solver::fresh(sat::solver* new_s, ast_manager& m, sat::sat_internalizer& si, euf::theory_id id); + bool subsumes(card& c1, card& c2, literal_vector& comp); bool subsumes(card& c1, clause& c2, bool& self); bool subsumed(card& c1, literal l1, literal l2); @@ -563,7 +565,8 @@ namespace sat { expr_ref get_xor(std::function& l2e, ba_solver::xr const& x); public: - ba_solver(ast_manager& m, sat_internalizer& si); + ba_solver(euf::solver& ctx, euf::theory_id id); + ba_solver(ast_manager& m, sat::sat_internalizer& si, euf::theory_id id); ~ba_solver() override; void set_solver(solver* s) override { m_solver = s; } void set_lookahead(lookahead* l) override { m_lookahead = l; } @@ -602,7 +605,7 @@ namespace sat { literal internalize(expr* e, bool sign, bool root, bool redundant) override; bool to_formulas(std::function& l2e, expr_ref_vector& fmls) override; - th_solver* fresh(solver* s, ast_manager& m, sat_internalizer& si) override; + euf::th_solver* fresh(solver* s, euf::solver& ctx) override; ptr_vector const & constraints() const { return m_constraints; } std::ostream& display(std::ostream& out, constraint const& c, bool values) const; diff --git a/src/sat/smt/euf_ackerman.cpp b/src/sat/smt/euf_ackerman.cpp index 66e58c123..6c8266459 100644 --- a/src/sat/smt/euf_ackerman.cpp +++ b/src/sat/smt/euf_ackerman.cpp @@ -194,7 +194,7 @@ namespace euf { } expr_ref eq(m.mk_eq(a, b), m); lits.push_back(s.internalize(eq, false, false, true)); - s.s().mk_clause(lits, true); + s.s().mk_clause(lits, sat::status::euf_learned()); } void ackerman::add_eq(expr* a, expr* b, expr* c) { @@ -205,6 +205,6 @@ namespace euf { lits[0] = s.internalize(eq1, true, false, true); lits[1] = s.internalize(eq2, true, false, true); lits[2] = s.internalize(eq3, false, false, true); - s.s().mk_clause(3, lits, true); + s.s().mk_clause(3, lits, sat::status::euf_learned()); } } diff --git a/src/sat/smt/euf_internalize.cpp b/src/sat/smt/euf_internalize.cpp index 727d3a806..5d615590a 100644 --- a/src/sat/smt/euf_internalize.cpp +++ b/src/sat/smt/euf_internalize.cpp @@ -22,11 +22,11 @@ Author: namespace euf { - sat::literal solver::internalize(expr* e, bool sign, bool root, bool learned) { - flet _is_learned(m_is_redundant, learned); + sat::literal solver::internalize(expr* e, bool sign, bool root, bool redundant) { + flet _is_learned(m_is_redundant, redundant); auto* ext = get_solver(e); if (ext) - return ext->internalize(e, sign, root, learned); + return ext->internalize(e, sign, root, redundant); IF_VERBOSE(110, verbose_stream() << "internalize: " << mk_pp(e, m) << "\n"); SASSERT(!si.is_bool_op(e)); sat::scoped_stack _sc(m_stack); @@ -103,8 +103,8 @@ namespace euf { if (lit.sign()) { sat::bool_var v = si.add_bool_var(n->get_owner()); sat::literal lit2 = literal(v, false); - s().mk_clause(~lit, lit2, false); - s().mk_clause(lit, ~lit2, false); + s().mk_clause(~lit, lit2, sat::status::euf_asserted()); + s().mk_clause(lit, ~lit2, sat::status::euf_asserted()); lit = lit2; } sat::bool_var v = lit.var(); @@ -132,6 +132,7 @@ namespace euf { if (sz <= 1) return; + sat::status st = m_is_redundant ? sat::status::euf_learned() : sat::status::euf_asserted(); static const unsigned distinct_max_args = 32; if (sz <= distinct_max_args) { sat::literal_vector lits; @@ -142,7 +143,7 @@ namespace euf { lits.push_back(lit); } } - s().mk_clause(lits, false); + s().mk_clause(lits, st); } else { // g(f(x_i)) = x_i @@ -160,13 +161,13 @@ namespace euf { expr_ref gapp(m.mk_app(g, fapp.get()), m); expr_ref eq(m.mk_eq(gapp, arg), m); sat::literal lit = internalize(eq, false, false, m_is_redundant); - s().add_clause(1, &lit, m_is_redundant); + s().add_clause(1, &lit, st); eqs.push_back(m.mk_eq(fapp, a)); } pb_util pb(m); expr_ref at_least2(pb.mk_at_least_k(eqs.size(), eqs.c_ptr(), 2), m); sat::literal lit = si.internalize(at_least2, m_is_redundant); - s().mk_clause(1, &lit, m_is_redundant); + s().mk_clause(1, &lit, st); } } @@ -174,8 +175,9 @@ namespace euf { SASSERT(m.is_distinct(e)); static const unsigned distinct_max_args = 32; unsigned sz = e->get_num_args(); + sat::status st = m_is_redundant ? sat::status::euf_learned() : sat::status::euf_asserted(); if (sz <= 1) { - s().mk_clause(0, nullptr, m_is_redundant); + s().mk_clause(0, nullptr, st); return; } if (sz <= distinct_max_args) { @@ -183,7 +185,7 @@ namespace euf { for (unsigned j = i + 1; j < sz; ++j) { expr_ref eq(m.mk_eq(args[i]->get_owner(), args[j]->get_owner()), m); sat::literal lit = internalize(eq, true, false, m_is_redundant); - s().add_clause(1, &lit, m_is_redundant); + s().add_clause(1, &lit, st); } } } @@ -200,13 +202,14 @@ namespace euf { n->mark_interpreted(); expr_ref eq(m.mk_eq(fapp, fresh), m); sat::literal lit = internalize(eq, false, false, m_is_redundant); - s().add_clause(1, &lit, m_is_redundant); + s().add_clause(1, &lit, st); } } } void solver::axiomatize_basic(enode* n) { expr* e = n->get_owner(); + sat::status st = m_is_redundant ? sat::status::euf_learned() : sat::status::euf_asserted(); if (m.is_ite(e)) { app* a = to_app(e); expr* c = a->get_arg(0); @@ -221,8 +224,8 @@ namespace euf { sat::literal lit_el = internalize(eq_el, false, false, m_is_redundant); literal lits1[2] = { literal(v, true), lit_th }; literal lits2[2] = { literal(v, false), lit_el }; - s().add_clause(2, lits1, m_is_redundant); - s().add_clause(2, lits2, m_is_redundant); + s().add_clause(2, lits1, st); + s().add_clause(2, lits2, st); } else if (m.is_distinct(e)) { expr_ref_vector eqs(m); @@ -238,8 +241,8 @@ namespace euf { sat::literal some_eq = si.internalize(fml, m_is_redundant); sat::literal lits1[2] = { ~dist, ~some_eq }; sat::literal lits2[2] = { dist, some_eq }; - s().add_clause(2, lits1, m_is_redundant); - s().add_clause(2, lits2, m_is_redundant); + s().add_clause(2, lits1, st); + s().add_clause(2, lits2, st); } } diff --git a/src/sat/smt/euf_model.cpp b/src/sat/smt/euf_model.cpp index 95ae940c5..a56d8f4cb 100644 --- a/src/sat/smt/euf_model.cpp +++ b/src/sat/smt/euf_model.cpp @@ -37,7 +37,7 @@ namespace euf { return true; if (f->get_family_id() == m.get_basic_family_id()) return false; - sat::th_model_builder* mb = get_solver(f); + euf::th_model_builder* mb = get_solver(f); return mb && mb->include_func_interp(f); } diff --git a/src/sat/smt/euf_proof.cpp b/src/sat/smt/euf_proof.cpp index 0ed4a965f..d50314eba 100644 --- a/src/sat/smt/euf_proof.cpp +++ b/src/sat/smt/euf_proof.cpp @@ -49,7 +49,7 @@ namespace euf { for (literal lit : r) lits.push_back(~lit); if (l != sat::null_literal) lits.push_back(l); - s().get_drat().add(lits, sat::drat::status::euf); + s().get_drat().add(lits, sat::status::euf_learned()); } } diff --git a/src/sat/smt/euf_solver.cpp b/src/sat/smt/euf_solver.cpp index 0680ea612..8b0003ece 100644 --- a/src/sat/smt/euf_solver.cpp +++ b/src/sat/smt/euf_solver.cpp @@ -32,7 +32,7 @@ namespace euf { /** * retrieve extension that is associated with Boolean variable. */ - sat::th_solver* solver::get_solver(sat::bool_var v) { + th_solver* solver::get_solver(sat::bool_var v) { if (v >= m_var2node.size()) return nullptr; euf::enode* n = m_var2node[v]; @@ -41,13 +41,13 @@ namespace euf { return get_solver(n->get_owner()); } - sat::th_solver* solver::get_solver(expr* e) { + th_solver* solver::get_solver(expr* e) { if (is_app(e)) return get_solver(to_app(e)->get_decl()); return nullptr; } - sat::th_solver* solver::get_solver(func_decl* f) { + th_solver* solver::get_solver(func_decl* f) { family_id fid = f->get_family_id(); if (fid == null_family_id) return nullptr; @@ -58,7 +58,7 @@ namespace euf { return nullptr; pb_util pb(m); if (pb.get_family_id() == fid) { - ext = alloc(sat::ba_solver, m, si); + ext = alloc(sat::ba_solver, *this, fid); } if (ext) { ext->set_solver(m_solver); @@ -71,7 +71,7 @@ namespace euf { return ext; } - void solver::add_solver(family_id fid, sat::th_solver* th) { + void solver::add_solver(family_id fid, th_solver* th) { m_solvers.push_back(th); m_id2solver.setx(fid, th, nullptr); } @@ -173,43 +173,52 @@ namespace euf { propagate(); } - void solver::propagate() { - m_egraph.propagate(); - unsigned lvl = s().scope_lvl(); + void solver::propagate() { + while (m_egraph.propagate() && !s().inconsistent()) { + if (m_egraph.inconsistent()) { + unsigned lvl = s().scope_lvl(); + s().set_conflict(sat::justification::mk_ext_justification(lvl, conflict_constraint().to_index())); + return; + } + propagate_literals(); + propagate_th_eqs(); + } + } - if (m_egraph.inconsistent()) { - s().set_conflict(sat::justification::mk_ext_justification(lvl, conflict_constraint().to_index())); - return; - } - for (euf::enode* eq : m_egraph.new_eqs()) { - bool_var v = m_expr2var.to_bool_var(eq->get_owner()); + void solver::propagate_literals() { + for (; m_egraph.has_literal() && !s().inconsistent() && !m_egraph.inconsistent(); m_egraph.next_literal()) { + euf::enode_bool_pair p = m_egraph.get_literal(); + euf::enode* n = p.first; + bool is_eq = p.second; + expr* e = n->get_owner(); expr* a = nullptr, *b = nullptr; - if (s().value(v) == l_false && m_ackerman && m.is_eq(eq->get_owner(), a, b)) - m_ackerman->cg_conflict_eh(a, b); - literal lit(v, false); - if (s().value(lit) == l_true) - continue; - s().assign(literal(v, false), sat::justification::mk_ext_justification(lvl, eq_constraint().to_index())); - if (s().inconsistent()) - return; - } - for (euf::enode* p : m_egraph.new_lits()) { - expr* e = p->get_owner(); - bool sign = m.is_false(p->get_root()->get_owner()); - SASSERT(m.is_bool(e)); - SASSERT(m.is_true(p->get_root()->get_owner()) || sign); bool_var v = m_expr2var.to_bool_var(e); - literal lit(v, sign); - if (s().value(lit) == l_true) - continue; + SASSERT(m.is_bool(e)); + size_t cnstr; + literal lit; + if (is_eq) { + VERIFY(m.is_eq(e, a, b)); + cnstr = eq_constraint().to_index(); + lit = literal(v, false); + } + else { + a = e, b = n->get_root()->get_owner(); + SASSERT(m.is_true(a) || m.is_false(b)); + cnstr = lit_constraint().to_index(); + lit = literal(v, m.is_false(b)); + } if (s().value(lit) == l_false && m_ackerman) - m_ackerman->cg_conflict_eh(p->get_owner(), p->get_root()->get_owner()); - s().assign(lit, sat::justification::mk_ext_justification(lvl, lit_constraint().to_index())); - if (s().inconsistent()) - return; + m_ackerman->cg_conflict_eh(a, b); + unsigned lvl = s().scope_lvl(); + if (s().value(lit) != l_true) + s().assign(lit, sat::justification::mk_ext_justification(lvl, cnstr)); } - for (euf::th_eq const& eq : m_egraph.new_th_eqs()) { - // m_id2solver[eq.m_id]->new_eq_eh(eq); + } + + void solver::propagate_th_eqs() { + for (; m_egraph.has_th_eq() && !s().inconsistent() && !m_egraph.inconsistent(); m_egraph.next_th_eq()) { + th_eq eq = m_egraph.get_th_eq(); + m_id2solver[eq.m_id]->new_eq_eh(eq); } } @@ -347,7 +356,7 @@ namespace euf { for (unsigned i = 0; i < m_id2solver.size(); ++i) { auto* e = m_id2solver[i]; if (e) - r->add_solver(i, e->fresh(s, *m_to_m, *m_to_si)); + r->add_solver(i, e->fresh(s, *r)); } return r; } diff --git a/src/sat/smt/euf_solver.h b/src/sat/smt/euf_solver.h index 27bc0c413..7dfb70e27 100644 --- a/src/sat/smt/euf_solver.h +++ b/src/sat/smt/euf_solver.h @@ -48,9 +48,10 @@ namespace euf { size_t to_index() const { return sat::constraint_base::mem2base(this); } }; - class solver : public sat::extension, public sat::th_internalizer, public sat::th_decompile { + class solver : public sat::extension, public th_internalizer, public th_decompile { typedef top_sort deps_t; friend class ackerman; + // friend class sat::ba_solver; struct stats { unsigned m_num_dynack; stats() { reset(); } @@ -87,8 +88,8 @@ namespace euf { unsigned m_num_scopes { 0 }; unsigned_vector m_var_trail; svector m_scopes; - scoped_ptr_vector m_solvers; - ptr_vector m_id2solver; + scoped_ptr_vector m_solvers; + ptr_vector m_id2solver; constraint* m_conflict { nullptr }; constraint* m_eq { nullptr }; @@ -110,10 +111,10 @@ namespace euf { euf::enode* mk_false(); // extensions - sat::th_solver* get_solver(func_decl* f); - sat::th_solver* get_solver(expr* e); - sat::th_solver* get_solver(sat::bool_var v); - void add_solver(family_id fid, sat::th_solver* th); + th_solver* get_solver(func_decl* f); + th_solver* get_solver(expr* e); + th_solver* get_solver(sat::bool_var v); + void add_solver(family_id fid, th_solver* th); void unhandled_function(func_decl* f); void init_ackerman(); @@ -126,6 +127,8 @@ namespace euf { // solving void propagate(); + void propagate_literals(); + void propagate_th_eqs(); void get_antecedents(literal l, constraint& j, literal_vector& r); void force_push(); void log_antecedents(std::ostream& out, literal l, literal_vector const& r); @@ -175,6 +178,10 @@ namespace euf { } }; + sat::sat_internalizer& get_si() { return si; } + ast_manager& get_manager() { return m; } + enode* get_enode(expr* e) { return m_egraph.find(e); } + void updt_params(params_ref const& p); void set_solver(sat::solver* s) override { m_solver = s; m_drat = s->get_config().m_drat; } void set_lookahead(sat::lookahead* s) override { m_lookahead = s; } @@ -214,7 +221,6 @@ namespace euf { sat::literal internalize(expr* e, bool sign, bool root, bool learned) override; void update_model(model_ref& mdl); - func_decl_ref_vector const& unhandled_functions() { return m_unhandled_functions; } - + func_decl_ref_vector const& unhandled_functions() { return m_unhandled_functions; } }; }; diff --git a/src/sat/smt/sat_th.h b/src/sat/smt/sat_th.h index 2f3bd56f2..4d1f93c8e 100644 --- a/src/sat/smt/sat_th.h +++ b/src/sat/smt/sat_th.h @@ -20,13 +20,15 @@ Author: #include "sat/smt/sat_smt.h" #include "ast/euf/euf_egraph.h" -namespace sat { +namespace euf { + + class solver; class th_internalizer { public: virtual ~th_internalizer() {} - virtual literal internalize(expr* e, bool sign, bool root, bool redundant) = 0; + virtual sat::literal internalize(expr* e, bool sign, bool root, bool redundant) = 0; }; class th_decompile { @@ -58,14 +60,49 @@ namespace sat { virtual bool include_func_interp(func_decl* f) const { return false; } }; - class th_solver : public extension, public th_model_builder, public th_decompile, public th_internalizer { + class th_solver : public sat::extension, public th_model_builder, public th_decompile, public th_internalizer { + protected: + ast_manager & m; + euf::theory_id m_id; public: - virtual ~th_solver() {} + th_solver(ast_manager& m, euf::theory_id id): m(m), m_id(id) {} - virtual th_solver* fresh(solver* s, ast_manager& m, sat_internalizer& si) = 0; + virtual th_solver* fresh(sat::solver* s, euf::solver& ctx) = 0; virtual void new_eq_eh(euf::th_eq const& eq) {} }; + class th_euf_solver : public th_solver { + protected: + solver & ctx; + euf::enode_vector m_var2enode; + unsigned_vector m_var2enode_lim; + public: + virtual ~th_euf_solver() {} + + th_euf_solver(euf::solver& ctx, euf::theory_id id); + + virtual euf::theory_var mk_var(enode * n) { + SASSERT(!is_attached_to_var(n)); + euf::theory_var v = m_var2enode.size(); + m_var2enode.push_back(n); + return v; + } + + enode* get_enode(theory_var v) const { return m_var2enode[v]; } + + euf::theory_var get_th_var(expr* e) const; + + euf::theory_var get_th_var(euf::enode* n) const { + return n->get_th_var(get_id()); + } + + bool is_attached_to_var(enode* n) const { + theory_var v = n->get_th_var(get_id()); + return v != null_theory_var && get_enode(v) == n; + } + + }; + } diff --git a/src/sat/tactic/goal2sat.cpp b/src/sat/tactic/goal2sat.cpp index 9d5d6a1e8..954933379 100644 --- a/src/sat/tactic/goal2sat.cpp +++ b/src/sat/tactic/goal2sat.cpp @@ -41,6 +41,7 @@ Notes: #include "sat/tactic/goal2sat.h" #include "sat/smt/ba_solver.h" #include "sat/smt/euf_solver.h" +#include "sat/smt/sat_th.h" #include "sat/sat_params.hpp" #include @@ -106,22 +107,22 @@ struct goal2sat::imp : public sat::sat_internalizer { void mk_clause(sat::literal l) { TRACE("goal2sat", tout << "mk_clause: " << l << "\n";); - m_solver.add_clause(1, &l, m_is_redundant); + m_solver.add_clause(1, &l, m_is_redundant ? sat::status::redundant() : sat::status::asserted()); } void mk_clause(sat::literal l1, sat::literal l2) { TRACE("goal2sat", tout << "mk_clause: " << l1 << " " << l2 << "\n";); - m_solver.add_clause(l1, l2, m_is_redundant); + m_solver.add_clause(l1, l2, m_is_redundant ? sat::status::redundant() : sat::status::asserted()); } void mk_clause(sat::literal l1, sat::literal l2, sat::literal l3) { TRACE("goal2sat", tout << "mk_clause: " << l1 << " " << l2 << " " << l3 << "\n";); - m_solver.add_clause(l1, l2, l3, m_is_redundant); + m_solver.add_clause(l1, l2, l3, m_is_redundant ? sat::status::redundant() : sat::status::asserted()); } void mk_clause(unsigned num, sat::literal * lits) { TRACE("goal2sat", tout << "mk_clause: "; for (unsigned i = 0; i < num; i++) tout << lits[i] << " "; tout << "\n";); - m_solver.add_clause(num, lits, m_is_redundant); + m_solver.add_clause(num, lits, m_is_redundant ? sat::status::redundant() : sat::status::asserted()); } sat::literal mk_true() { @@ -509,18 +510,17 @@ struct goal2sat::imp : public sat::sat_internalizer { void convert_ba(app* t, bool root, bool sign) { SASSERT(!m_euf); sat::extension* ext = m_solver.get_extension(); - sat::ba_solver* ba = nullptr; + euf::th_solver* th = nullptr; if (!ext) { - ba = alloc(sat::ba_solver, m, *this); - m_solver.set_extension(ba); - ba->push_scopes(m_solver.num_scopes()); + th = alloc(sat::ba_solver, m, *this, pb.get_family_id()); + m_solver.set_extension(th); + th->push_scopes(m_solver.num_scopes()); } else { - ba = dynamic_cast(ext); + th = dynamic_cast(ext); + SASSERT(th); } - if (!ba) - throw default_exception("cannot convert to pb"); - sat::literal lit = ba->internalize(t, sign, root, m_is_redundant); + auto lit = th->internalize(t, sign, root, m_is_redundant); if (root) m_result_stack.reset(); else