diff --git a/src/ast/euf/euf_egraph.cpp b/src/ast/euf/euf_egraph.cpp index 666a7fad6..1448ee8d7 100644 --- a/src/ast/euf/euf_egraph.cpp +++ b/src/ast/euf/euf_egraph.cpp @@ -36,8 +36,10 @@ namespace euf { } m_expr2enode.setx(f->get_id(), n, nullptr); push_node(n); - for (unsigned i = 0; i < num_args; ++i) - set_merge_enabled(args[i], true); + for (unsigned i = 0; i < num_args; ++i) { + set_cgc_enabled(args[i], true); + set_merge_tf_enabled(args[i], true); + } return n; } @@ -78,9 +80,8 @@ namespace euf { void egraph::reinsert_equality(enode* p) { SASSERT(p->is_equality()); - if (p->value() != l_true && p->get_arg(0)->get_root() == p->get_arg(1)->get_root()) { - add_literal(p, true); - } + if (p->value() != l_true && p->get_arg(0)->get_root() == p->get_arg(1)->get_root()) + add_literal(p, nullptr); } void egraph::force_push() { @@ -116,18 +117,16 @@ namespace euf { m_on_make(n); if (num_args == 0) return n; - if (m.is_eq(f)) { + if (m.is_eq(f) && !m.is_iff(f)) { n->set_is_equality(); - update_children(n); reinsert_equality(n); } - else { - auto [n2, comm] = insert_table(n); - if (n2 == n) - update_children(n); - else - merge(n, n2, justification::congruence(comm, m_congruence_timestamp++)); - } + auto [n2, comm] = insert_table(n); + if (n2 == n) + update_children(n); + else + merge(n, n2, justification::congruence(comm, m_congruence_timestamp++)); + return n; } @@ -158,11 +157,11 @@ namespace euf { ++m_stats.m_num_th_diseqs; } - void egraph::add_literal(enode* n, bool is_eq) { + void egraph::add_literal(enode* n, enode* ante) { TRACE("euf_verbose", tout << "lit: " << n->get_expr_id() << "\n";); - m_new_lits.push_back(enode_bool_pair(n, is_eq)); + m_new_lits.push_back(enode_pair(n, ante)); m_updates.push_back(update_record(update_record::new_lit())); - if (is_eq) ++m_stats.m_num_eqs; else ++m_stats.m_num_lits; + if (!ante) ++m_stats.m_num_eqs; else ++m_stats.m_num_lits; } void egraph::new_diseq(enode* n) { @@ -173,7 +172,7 @@ namespace euf { enode* r2 = arg2->get_root(); TRACE("euf", tout << "new-diseq: " << bpp(r1) << " " << bpp(r2) << ": " << r1->has_th_vars() << " " << r2->has_th_vars() << "\n";); if (r1 == r2) { - add_literal(n, true); + add_literal(n, nullptr); return; } if (!r1->has_th_vars()) @@ -264,10 +263,26 @@ namespace euf { root->del_th_var(tid); } - void egraph::set_merge_enabled(enode* n, bool enable_merge) { - if (enable_merge != n->merge_enabled()) { - toggle_merge_enabled(n, false); - m_updates.push_back(update_record(n, update_record::toggle_merge())); + void egraph::set_merge_tf_enabled(enode* n, bool enable_merge_tf) { + if (!m.is_bool(n->get_sort())) + return; + if (enable_merge_tf != n->merge_tf()) { + n->set_merge_tf(enable_merge_tf); + m_updates.push_back(update_record(n, update_record::toggle_merge_tf())); + if (enable_merge_tf && n->value() != l_undef && !m.is_value(n->get_root()->get_expr())) { + expr* b = n->value() == l_true ? m.mk_true() : m.mk_false(); + enode* tf = find(b); + if (!tf) + tf = mk(b, 0, 0, nullptr); + add_literal(n, tf); + } + } + } + + void egraph::set_cgc_enabled(enode* n, bool enable_merge) { + if (enable_merge != n->cgc_enabled()) { + toggle_cgc_enabled(n, false); + m_updates.push_back(update_record(n, update_record::toggle_cgc())); } } @@ -278,9 +293,9 @@ namespace euf { m_updates.push_back(update_record(n, update_record::set_relevant())); } - void egraph::toggle_merge_enabled(enode* n, bool backtracking) { - bool enable_merge = !n->merge_enabled(); - n->set_merge_enabled(enable_merge); + void egraph::toggle_cgc_enabled(enode* n, bool backtracking) { + bool enable_merge = !n->cgc_enabled(); + n->set_cgc_enabled(enable_merge); if (n->num_args() > 0) { if (enable_merge) { auto [n2, comm] = insert_table(n); @@ -290,7 +305,7 @@ namespace euf { else if (n->is_cgr()) erase_from_table(n); } - VERIFY(n->num_args() == 0 || !n->merge_enabled() || m_table.contains(n)); + VERIFY(n->num_args() == 0 || !n->cgc_enabled() || m_table.contains(n)); } void egraph::set_value(enode* n, lbool value, justification j) { @@ -300,6 +315,8 @@ namespace euf { n->set_value(value); n->m_lit_justification = j; m_updates.push_back(update_record(n, update_record::value_assignment())); + if (n->is_equality() && n->value() == l_false) + new_diseq(n); } } @@ -352,8 +369,11 @@ namespace euf { case update_record::tag_t::is_add_node: undo_node(); break; - case update_record::tag_t::is_toggle_merge: - toggle_merge_enabled(p.r1, true); + case update_record::tag_t::is_toggle_cgc: + toggle_cgc_enabled(p.r1, true); + break; + case update_record::tag_t::is_toggle_merge_tf: + p.r1->set_merge_tf(!p.r1->merge_tf()); break; case update_record::tag_t::is_set_parent: undo_eq(p.r1, p.n1, p.r2_num_parents); @@ -419,7 +439,7 @@ namespace euf { void egraph::merge(enode* n1, enode* n2, justification j) { - if (!n1->merge_enabled() && !n2->merge_enabled()) + if (!n1->cgc_enabled() && !n2->cgc_enabled()) return; SASSERT(n1->get_sort() == n2->get_sort()); enode* r1 = n1->get_root(); @@ -436,6 +456,7 @@ namespace euf { set_conflict(n1, n2, j); return; } + if (r1->value() != r2->value() && r1->value() != l_undef && r2->value() != l_undef) { SASSERT(m.is_bool(r1->get_expr())); set_conflict(n1, n2, j); @@ -448,9 +469,11 @@ namespace euf { } if (j.is_congruence() && (m.is_false(r2->get_expr()) || m.is_true(r2->get_expr()))) - add_literal(n1, false); - if (n1->is_equality() && n1->value() == l_false) - new_diseq(n1); + add_literal(n1, r2); + if (r2->value() != l_undef && n1->value() == l_undef) + add_literal(n1, r2); + else if (r1->value() != l_undef && n2->value() == l_undef) + add_literal(n2, r1); remove_parents(r1); push_eq(r1, n1, r2->num_parents()); merge_justification(n1, n2, j); @@ -468,7 +491,7 @@ namespace euf { for (enode* p : enode_parents(r)) { if (p->is_marked1()) continue; - if (p->merge_enabled()) { + if (p->cgc_enabled()) { if (!p->is_cgr()) continue; SASSERT(m_table.contains_ptr(p)); @@ -486,8 +509,8 @@ namespace euf { if (!p->is_marked1()) continue; p->unmark1(); - TRACE("euf", tout << "reinsert " << bpp(r1) << " " << bpp(r2) << " " << bpp(p) << " " << p->merge_enabled() << "\n";); - if (p->merge_enabled()) { + TRACE("euf", tout << "reinsert " << bpp(r1) << " " << bpp(r2) << " " << bpp(p) << " " << p->cgc_enabled() << "\n";); + if (p->cgc_enabled()) { auto [p_other, comm] = insert_table(p); SASSERT(m_table.contains_ptr(p) == (p_other == p)); TRACE("euf", tout << "other " << bpp(p_other) << "\n";); @@ -531,9 +554,9 @@ namespace euf { for (auto it = begin; it != end; ++it) { enode* p = *it; TRACE("euf", tout << "erase " << bpp(p) << "\n";); - SASSERT(!p->merge_enabled() || m_table.contains_ptr(p)); - SASSERT(!p->merge_enabled() || p->is_cgr()); - if (p->merge_enabled()) + SASSERT(!p->cgc_enabled() || m_table.contains_ptr(p)); + SASSERT(!p->cgc_enabled() || p->is_cgr()); + if (p->cgc_enabled()) erase_from_table(p); } @@ -541,7 +564,7 @@ namespace euf { c->m_root = r1; for (enode* p : enode_parents(r1)) - if (p->merge_enabled() && (p->is_cgr() || !p->congruent(p->m_cg))) + if (p->cgc_enabled() && (p->is_cgr() || !p->congruent(p->m_cg))) insert_table(p); r2->m_parents.shrink(r2_num_parents); unmerge_justification(n1); @@ -783,7 +806,7 @@ namespace euf { for (enode* n : m_nodes) n->invariant(*this); for (enode* n : m_nodes) - if (n->merge_enabled() && n->num_args() > 0 && (!m_table.find(n) || n->get_root() != m_table.find(n)->get_root())) { + if (n->cgc_enabled() && n->num_args() > 0 && (!m_table.find(n) || n->get_root() != m_table.find(n)->get_root())) { CTRACE("euf", !m_table.find(n), tout << "node is not in table\n";); CTRACE("euf", m_table.find(n), tout << "root " << bpp(n->get_root()) << " table root " << bpp(m_table.find(n)->get_root()) << "\n";); TRACE("euf", display(tout << bpp(n) << " is not closed under congruence\n");); @@ -818,7 +841,7 @@ namespace euf { } }; if (n->bool_var() != sat::null_bool_var) - out << "[b" << n->bool_var() << " := " << value_of() << (n->merge_tf() ? "" : " no merge") << "] "; + out << "[b" << n->bool_var() << " := " << value_of() << (n->cgc_enabled() ? "" : " no-cgc") << (n->merge_tf()? " merge-tf" : "") << "] "; if (n->has_th_vars()) { out << "[t"; for (auto const& v : enode_th_vars(n)) @@ -873,7 +896,8 @@ namespace euf { n2->set_value(n1->value()); n2->m_bool_var = n1->m_bool_var; n2->m_commutative = n1->m_commutative; - n2->m_merge_enabled = n1->m_merge_enabled; + n2->m_cgc_enabled = n1->m_cgc_enabled; + n2->m_merge_tf_enabled = n1->m_merge_tf_enabled; n2->m_is_equality = n1->m_is_equality; } for (unsigned i = 0; i < src.m_nodes.size(); ++i) { diff --git a/src/ast/euf/euf_egraph.h b/src/ast/euf/euf_egraph.h index d6bcb9cd3..1686be384 100644 --- a/src/ast/euf/euf_egraph.h +++ b/src/ast/euf/euf_egraph.h @@ -101,7 +101,8 @@ namespace euf { void reset() { memset(this, 0, sizeof(*this)); } }; struct update_record { - struct toggle_merge {}; + struct toggle_cgc {}; + struct toggle_merge_tf {}; struct add_th_var {}; struct replace_th_var {}; struct new_lit {}; @@ -114,7 +115,7 @@ namespace euf { struct lbl_set {}; struct update_children {}; struct set_relevant {}; - enum class tag_t { is_set_parent, is_add_node, is_toggle_merge, is_update_children, + enum class tag_t { is_set_parent, is_add_node, is_toggle_cgc, is_toggle_merge_tf, is_update_children, is_add_th_var, is_replace_th_var, is_new_lit, is_new_th_eq, is_lbl_hash, is_new_th_eq_qhead, is_new_lits_qhead, is_inconsistent, is_value_assignment, is_lbl_set, is_set_relevant }; @@ -136,8 +137,10 @@ namespace euf { tag(tag_t::is_set_parent), r1(r1), n1(n1), r2_num_parents(r2_num_parents) {} update_record(enode* n) : tag(tag_t::is_add_node), r1(n), n1(nullptr), r2_num_parents(UINT_MAX) {} - update_record(enode* n, toggle_merge) : - tag(tag_t::is_toggle_merge), r1(n), n1(nullptr), r2_num_parents(UINT_MAX) {} + update_record(enode* n, toggle_cgc) : + tag(tag_t::is_toggle_cgc), r1(n), n1(nullptr), r2_num_parents(UINT_MAX) {} + update_record(enode* n, toggle_merge_tf) : + tag(tag_t::is_toggle_merge_tf), r1(n), n1(nullptr), r2_num_parents(UINT_MAX) {} update_record(enode* n, unsigned id, add_th_var) : tag(tag_t::is_add_th_var), r1(n), n1(nullptr), r2_num_parents(id) {} update_record(enode* n, theory_id id, theory_var v, replace_th_var) : @@ -186,7 +189,7 @@ namespace euf { justification m_justification; unsigned m_new_lits_qhead = 0; unsigned m_new_th_eqs_qhead = 0; - svector m_new_lits; + svector m_new_lits; svector m_new_th_eqs; bool_vector m_th_propagates_diseqs; enode_vector m_todo; @@ -210,7 +213,7 @@ namespace euf { void add_th_diseqs(theory_id id, theory_var v1, enode* r); bool th_propagates_diseqs(theory_id id) const; - void add_literal(enode* n, bool is_eq); + void add_literal(enode* n, enode* ante); void undo_eq(enode* r1, enode* n1, unsigned r2_num_parents); void undo_add_th_var(enode* n, theory_id id); enode* mk_enode(expr* f, unsigned generation, unsigned num_args, enode * const* args); @@ -229,7 +232,7 @@ namespace euf { void push_to_lca(enode* a, enode* lca); void push_congruence(enode* n1, enode* n2, bool commutative); void push_todo(enode* n); - void toggle_merge_enabled(enode* n, bool backtracking); + void toggle_cgc_enabled(enode* n, bool backtracking); enode_bool_pair insert_table(enode* p); void erase_from_table(enode* p); @@ -289,7 +292,7 @@ namespace euf { void add_th_diseq(theory_id id, theory_var v1, theory_var v2, expr* eq); bool has_literal() const { return m_new_lits_qhead < m_new_lits.size(); } bool has_th_eq() const { return m_new_th_eqs_qhead < m_new_th_eqs.size(); } - enode_bool_pair get_literal() const { return m_new_lits[m_new_lits_qhead]; } + enode_pair get_literal() const { return m_new_lits[m_new_lits_qhead]; } th_eq get_th_eq() const { return m_new_th_eqs[m_new_th_eqs_qhead]; } void next_literal() { force_push(); SASSERT(m_new_lits_qhead < m_new_lits.size()); m_new_lits_qhead++; } void next_th_eq() { force_push(); SASSERT(m_new_th_eqs_qhead < m_new_th_eqs.size()); m_new_th_eqs_qhead++; } @@ -299,7 +302,9 @@ namespace euf { void add_th_var(enode* n, theory_var v, theory_id id); void set_th_propagates_diseqs(theory_id id); - void set_merge_enabled(enode* n, bool enable_merge); + void set_cgc_enabled(enode* n, bool enable_cgc); + void set_merge_tf_enabled(enode* n, bool enable_merge_tf); + void set_value(enode* n, lbool value, justification j); void set_bool_var(enode* n, unsigned v) { n->set_bool_var(v); } void set_relevant(enode* n); diff --git a/src/ast/euf/euf_enode.cpp b/src/ast/euf/euf_enode.cpp index 038325790..08df9f493 100644 --- a/src/ast/euf/euf_enode.cpp +++ b/src/ast/euf/euf_enode.cpp @@ -36,7 +36,7 @@ namespace euf { if (is_root()) { VERIFY(!m_target); for (enode* p : enode_parents(this)) { - if (!p->merge_enabled()) + if (!p->cgc_enabled()) continue; bool found = false; for (enode* arg : enode_args(p)) { @@ -49,7 +49,7 @@ namespace euf { if (c == this) continue; for (enode* p : enode_parents(c)) { - if (!p->merge_enabled()) + if (!p->cgc_enabled()) continue; bool found = false; for (enode* q : enode_parents(this)) { diff --git a/src/ast/euf/euf_enode.h b/src/ast/euf/euf_enode.h index 18a1a86af..d9ae45074 100644 --- a/src/ast/euf/euf_enode.h +++ b/src/ast/euf/euf_enode.h @@ -48,7 +48,8 @@ namespace euf { bool m_mark3 = false; bool m_commutative = false; bool m_interpreted = false; - bool m_merge_enabled = true; + bool m_cgc_enabled = true; + bool m_merge_tf_enabled = false; bool m_is_equality = false; // Does the expression represent an equality bool m_is_relevant = false; lbool m_value = l_undef; // Assignment by SAT solver for Boolean node @@ -91,7 +92,7 @@ namespace euf { n->m_generation = generation, n->m_commutative = num_args == 2 && is_app(f) && to_app(f)->get_decl()->is_commutative(); n->m_num_args = num_args; - n->m_merge_enabled = true; + n->m_cgc_enabled = true; for (unsigned i = 0; i < num_args; ++i) { SASSERT(to_app(f)->get_arg(i) == args[i]->get_expr()); n->m_args[i] = args[i]; @@ -107,7 +108,7 @@ namespace euf { n->m_root = n; n->m_commutative = true; n->m_num_args = 2; - n->m_merge_enabled = true; + n->m_cgc_enabled = true; for (unsigned i = 0; i < num_args; ++i) n->m_args[i] = nullptr; return n; @@ -121,7 +122,7 @@ namespace euf { n->m_root = n; n->m_commutative = true; n->m_num_args = 2; - n->m_merge_enabled = true; + n->m_cgc_enabled = true; for (unsigned i = 0; i < num_args; ++i) n->m_args[i] = nullptr; return n; @@ -132,7 +133,8 @@ namespace euf { void add_th_var(theory_var v, theory_id id, region & r) { m_th_vars.add_var(v, id, r); } void replace_th_var(theory_var v, theory_id id) { m_th_vars.replace(v, id); } void del_th_var(theory_id id) { m_th_vars.del_var(id); } - void set_merge_enabled(bool m) { m_merge_enabled = m; } + void set_cgc_enabled(bool m) { m_cgc_enabled = m; } + void set_merge_tf(bool m) { m_merge_tf_enabled = m; } void set_value(lbool v) { m_value = v; } void set_justification(justification j) { m_justification = j; } void set_is_equality() { m_is_equality = true; } @@ -152,14 +154,13 @@ namespace euf { bool is_relevant() const { return m_is_relevant; } void set_relevant(bool b) { m_is_relevant = b; } lbool value() const { return m_value; } - bool value_conflict() const { return value() != l_undef && get_root()->value() != l_undef && value() != get_root()->value(); } sat::bool_var bool_var() const { return m_bool_var; } bool is_cgr() const { return this == m_cg; } enode* get_cg() const { return m_cg; } bool commutative() const { return m_commutative; } void mark_interpreted() { SASSERT(num_args() == 0); m_interpreted = true; } - bool merge_enabled() const { return m_merge_enabled; } - bool merge_tf() const { return merge_enabled() && (class_size() > 1 || num_parents() > 0 || num_args() > 0); } + bool cgc_enabled() const { return m_cgc_enabled; } + bool merge_tf() const { return m_merge_tf_enabled && (class_size() > 1 || num_parents() > 0 || num_args() > 0); } enode* get_arg(unsigned i) const { SASSERT(i < num_args()); return m_args[i]; } unsigned hash() const { return m_expr->get_id(); } diff --git a/src/ast/rewriter/th_rewriter.cpp b/src/ast/rewriter/th_rewriter.cpp index 4c7d4dc49..9bc8c22cf 100644 --- a/src/ast/rewriter/th_rewriter.cpp +++ b/src/ast/rewriter/th_rewriter.cpp @@ -707,6 +707,8 @@ struct th_rewriter_cfg : public default_rewriter_cfg { expr_ref mk_eq(expr* a, expr* b) { expr_ref result(m()); + if (a->get_id() > b->get_id()) + std::swap(a, b); if (BR_FAILED == reduce_eq(a, b, result)) result = m().mk_eq(a, b); return result; diff --git a/src/ast/simplifiers/euf_completion.cpp b/src/ast/simplifiers/euf_completion.cpp index ac1ecc47a..1ca980d07 100644 --- a/src/ast/simplifiers/euf_completion.cpp +++ b/src/ast/simplifiers/euf_completion.cpp @@ -211,8 +211,8 @@ namespace euf { if (x == y) return expr_ref(m.mk_true(), m); - if (x == x1 && y == y1) - return expr_ref(f, m); + if (x == x1 && y == y1) + return m_rewriter.mk_eq(x, y); if (is_nullary(x) && is_nullary(y)) return mk_and(m_rewriter.mk_eq(x, x1), m_rewriter.mk_eq(y, x1)); @@ -268,7 +268,8 @@ namespace euf { m_eargs.push_back(get_canonical(arg, d)); change |= arg != m_eargs.back(); } - + if (m.is_eq(f)) + return m_rewriter.mk_eq(m_eargs.get(0), m_eargs.get(1)); if (!change) return expr_ref(f, m); else diff --git a/src/sat/smt/arith_internalize.cpp b/src/sat/smt/arith_internalize.cpp index 04a3ae4ef..d35f79954 100644 --- a/src/sat/smt/arith_internalize.cpp +++ b/src/sat/smt/arith_internalize.cpp @@ -372,7 +372,7 @@ namespace arith { enode* n = ctx.get_enode(atom); theory_var w = mk_var(n); ctx.attach_th_var(n, this, w); - ctx.get_egraph().set_merge_enabled(n, false); + ctx.get_egraph().set_cgc_enabled(n, false); if (is_int(v) && !r.is_int()) r = (k == lp_api::upper_t) ? floor(r) : ceil(r); api_bound* b = mk_var_bound(lit, v, k, r); diff --git a/src/sat/smt/euf_internalize.cpp b/src/sat/smt/euf_internalize.cpp index 9f747090a..3aaafa36c 100644 --- a/src/sat/smt/euf_internalize.cpp +++ b/src/sat/smt/euf_internalize.cpp @@ -74,10 +74,8 @@ namespace euf { } if (auto* ext = expr2solver(e)) return ext->internalize(e, sign, root); - if (!visit_rec(m, e, sign, root)) { - TRACE("euf", tout << "visit-rec\n";); + if (!visit_rec(m, e, sign, root)) return sat::null_literal; - } SASSERT(get_enode(e)); if (m.is_bool(e)) return literal(si.to_bool_var(e), sign); @@ -119,7 +117,7 @@ namespace euf { SASSERT(!get_enode(e)); if (auto* s = expr2solver(e)) s->internalize(e); - else + else attach_node(mk_enode(e, num, m_args.data())); return true; } @@ -188,6 +186,7 @@ namespace euf { return lit; } + set_bool_var2expr(v, e); enode* n = m_egraph.find(e); if (!n) @@ -195,8 +194,8 @@ namespace euf { CTRACE("euf", n->bool_var() != sat::null_bool_var && n->bool_var() != v, display(tout << bpp(n) << " " << n->bool_var() << " vs " << v << "\n")); SASSERT(n->bool_var() == sat::null_bool_var || n->bool_var() == v); m_egraph.set_bool_var(n, v); - if (m.is_eq(e) || m.is_or(e) || m.is_and(e) || m.is_not(e)) - m_egraph.set_merge_enabled(n, false); + if (si.is_bool_op(e)) + m_egraph.set_cgc_enabled(n, false); lbool val = s().value(lit); if (val != l_undef) m_egraph.set_value(n, val, justification::external(to_ptr(val == l_true ? lit : ~lit))); @@ -349,15 +348,6 @@ namespace euf { else if (m.is_eq(e, th, el) && !m.is_iff(e)) { sat::literal lit1 = expr2literal(e); s().set_phase(lit1); - expr_ref e2(m.mk_eq(el, th), m); - enode* n2 = m_egraph.find(e2); - if (n2) { - sat::literal lit2 = expr2literal(e2); - add_root(~lit1, lit2); - add_root(lit1, ~lit2); - s().add_clause(~lit1, lit2, mk_distinct_status(~lit1, lit2)); - s().add_clause(lit1, ~lit2, mk_distinct_status(lit1, ~lit2)); - } } } @@ -476,26 +466,15 @@ namespace euf { return n; } - euf::enode* solver::mk_enode(expr* e, unsigned n, enode* const* args) { - euf::enode* r = m_egraph.mk(e, m_generation, n, args); - for (unsigned i = 0; i < n; ++i) - ensure_merged_tf(args[i]); - return r; - } + euf::enode* solver::mk_enode(expr* e, unsigned num, enode* const* args) { - void solver::ensure_merged_tf(euf::enode* n) { - switch (n->value()) { - case l_undef: - break; - case l_true: - if (n->get_root() != mk_true()) - m_egraph.merge(n, mk_true(), to_ptr(sat::literal(n->bool_var()))); - break; - case l_false: - if (n->get_root() != mk_false()) - m_egraph.merge(n, mk_false(), to_ptr(~sat::literal(n->bool_var()))); - break; - } + if (si.is_bool_op(e)) + num = 0; + + enode* n = m_egraph.mk(e, m_generation, num, args); + if (si.is_bool_op(e)) + m_egraph.set_cgc_enabled(n, false); + return n; } } diff --git a/src/sat/smt/euf_solver.cpp b/src/sat/smt/euf_solver.cpp index 7b02509ee..9b346543f 100644 --- a/src/sat/smt/euf_solver.cpp +++ b/src/sat/smt/euf_solver.cpp @@ -316,13 +316,23 @@ namespace euf { SASSERT(!l.sign()); m_egraph.explain_eq(m_explain, cc, n->get_arg(0), n->get_arg(1)); break; - case constraint::kind_t::lit: + case constraint::kind_t::lit: { e = m_bool_var2expr[l.var()]; n = m_egraph.find(e); + enode* ante = j.node(); SASSERT(n); SASSERT(m.is_bool(n->get_expr())); - m_egraph.explain_eq(m_explain, cc, n, (l.sign() ? mk_false() : mk_true())); + SASSERT(ante->get_root() == n->get_root()); + m_egraph.explain_eq(m_explain, cc, n, ante); + if (!m.is_true(ante->get_expr()) && !m.is_false(ante->get_expr())) { + bool_var v = ante->bool_var(); + lbool val = ante->value(); + SASSERT(val != l_undef); + literal ante(v, val == l_false); + m_explain.push_back(to_ptr(ante)); + } break; + } default: IF_VERBOSE(0, verbose_stream() << (unsigned)j.kind() << "\n"); UNREACHABLE(); @@ -345,24 +355,30 @@ namespace euf { euf::enode* n = m_egraph.find(e); if (!n) return; - bool sign = l.sign(); - m_egraph.set_value(n, sign ? l_false : l_true, justification::external(to_ptr(l))); + bool sign = l.sign(); + lbool old_value = n->value(); + lbool new_value = sign ? l_false : l_true; + m_egraph.set_value(n, new_value, justification::external(to_ptr(l))); + if (old_value == l_undef && n->cgc_enabled()) { + for (enode* k : enode_class(n)) { + if (k->bool_var() == sat::null_bool_var) + continue; + if (k->value() == new_value) + continue; + auto& c = lit_constraint(n); + propagate(literal(k->bool_var(), sign), c.to_index()); + if (k->value() == l_undef) + m_egraph.set_value(k, new_value, justification::external(to_ptr(l))); + else + return; + } + } for (auto const& th : enode_th_vars(n)) m_id2solver[th.get_id()]->asserted(l); size_t* c = to_ptr(l); SASSERT(is_literal(c)); SASSERT(l == get_literal(c)); - if (n->value_conflict()) { - euf::enode* nb = sign ? mk_false() : mk_true(); - euf::enode* r = n->get_root(); - euf::enode* rb = sign ? mk_true() : mk_false(); - sat::literal rl(r->bool_var(), r->value() == l_false); - m_egraph.merge(n, nb, c); - m_egraph.merge(r, rb, to_ptr(rl)); - SASSERT(m_egraph.inconsistent()); - return; - } if (n->merge_tf()) { euf::enode* nb = sign ? mk_false() : mk_true(); m_egraph.merge(n, nb, c); @@ -374,9 +390,17 @@ namespace euf { m_egraph.new_diseq(n); else m_egraph.merge(n->get_arg(0), n->get_arg(1), c); - } + } } + constraint& solver::lit_constraint(enode* n) { + void* mem = get_region().allocate(sat::constraint_base::obj_size(sizeof(constraint))); + auto* c = new (sat::constraint_base::ptr2mem(mem)) constraint(n); + sat::constraint_base::initialize(mem, this); + return *c; + } + + bool solver::unit_propagate() { bool propagated = false; @@ -412,37 +436,44 @@ namespace euf { void solver::propagate_literals() { for (; m_egraph.has_literal() && !s().inconsistent() && !m_egraph.inconsistent(); m_egraph.next_literal()) { - auto [n, is_eq] = m_egraph.get_literal(); + auto [n, ante] = m_egraph.get_literal(); expr* e = n->get_expr(); expr* a = nullptr, *b = nullptr; bool_var v = n->bool_var(); SASSERT(m.is_bool(e)); size_t cnstr; - literal lit; - if (is_eq) { + literal lit; + if (!ante) { VERIFY(m.is_eq(e, a, b)); cnstr = eq_constraint().to_index(); lit = literal(v, false); } else { - lbool val = n->get_root()->value(); - if (val == l_undef && m.is_false(n->get_root()->get_expr())) - val = l_false; - if (val == l_undef && m.is_true(n->get_root()->get_expr())) - val = l_true; - a = e; - b = (val == l_true) ? m.mk_true() : m.mk_false(); - SASSERT(val != l_undef); - cnstr = lit_constraint().to_index(); + // + // There are the following three cases for propagation of literals + // + // 1. n == ante is true from equallity, ante = true/false + // 2. n == ante is true from equality, value(ante) != l_undef + // 3. value(n) != l_undef, ante = true/false, merge_tf is set on n + // + lbool val = ante->value(); + if (val == l_undef) { + SASSERT(m.is_value(ante->get_expr())); + val = m.is_true(ante->get_expr()) ? l_true : l_false; + } + auto& c = lit_constraint(ante); + cnstr = c.to_index(); lit = literal(v, val == l_false); } unsigned lvl = s().scope_lvl(); - CTRACE("euf", s().value(lit) != l_true, tout << lit << " " << s().value(lit) << "@" << lvl << " " << is_eq << " " << mk_bounded_pp(a, m) << " = " << mk_bounded_pp(b, m) << "\n";); - if (s().value(lit) == l_false && m_ackerman) + CTRACE("euf", s().value(lit) != l_true, tout << lit << " " << s().value(lit) << "@" << lvl << " " << mk_bounded_pp(a, m) << " = " << mk_bounded_pp(b, m) << "\n";); + if (s().value(lit) == l_false && m_ackerman && a && b) m_ackerman->cg_conflict_eh(a, b); switch (s().value(lit)) { case l_true: + if (n->merge_tf() && !m.is_value(n->get_root()->get_expr())) + m_egraph.merge(n, ante, to_ptr(lit)); break; case l_undef: case l_false: @@ -889,7 +920,7 @@ namespace euf { if (m.is_eq(e) && !m.is_iff(e)) ok = false; euf::enode* n = get_enode(e); - if (n && n->merge_enabled()) + if (n && n->cgc_enabled()) ok = false; (void)ok; @@ -938,7 +969,7 @@ namespace euf { case constraint::kind_t::eq: return out << "euf equality propagation"; case constraint::kind_t::lit: - return out << "euf literal propagation"; + return out << "euf literal propagation " << m_egraph.bpp(c.node()) ; default: UNREACHABLE(); return out; diff --git a/src/sat/smt/euf_solver.h b/src/sat/smt/euf_solver.h index f935ad9ff..5b09e6a46 100644 --- a/src/sat/smt/euf_solver.h +++ b/src/sat/smt/euf_solver.h @@ -45,9 +45,12 @@ namespace euf { enum class kind_t { conflict, eq, lit }; private: kind_t m_kind; + enode* m_node = nullptr; public: constraint(kind_t k) : m_kind(k) {} + constraint(enode* n): m_kind(kind_t::lit), m_node(n) {} kind_t kind() const { return m_kind; } + enode* node() const { SASSERT(kind() == kind_t::lit); return m_node; } static constraint& from_idx(size_t z) { return *reinterpret_cast(sat::constraint_base::idx2mem(z)); } @@ -171,7 +174,6 @@ namespace euf { void add_not_distinct_axiom(app* e, euf::enode* const* args); void axiomatize_basic(enode* n); bool internalize_root(app* e, bool sign, ptr_vector const& args); - void ensure_merged_tf(euf::enode* n); euf::enode* mk_true(); euf::enode* mk_false(); @@ -250,7 +252,7 @@ namespace euf { constraint& mk_constraint(constraint*& c, constraint::kind_t k); constraint& conflict_constraint() { return mk_constraint(m_conflict, constraint::kind_t::conflict); } constraint& eq_constraint() { return mk_constraint(m_eq, constraint::kind_t::eq); } - constraint& lit_constraint() { return mk_constraint(m_lit, constraint::kind_t::lit); } + constraint& lit_constraint(enode* n); // user propagator void check_for_user_propagator() {