diff --git a/src/sat/smt/euf_internalize.cpp b/src/sat/smt/euf_internalize.cpp index 020d29baa..f60c3597d 100644 --- a/src/sat/smt/euf_internalize.cpp +++ b/src/sat/smt/euf_internalize.cpp @@ -342,7 +342,9 @@ namespace euf { break; } } - + if (m.is_bool(n->get_expr()) && th_id != m.get_basic_family_id()) + return true; + for (enode* parent : euf::enode_parents(n)) { app* p = to_app(parent->get_expr()); family_id fid = p->get_family_id(); diff --git a/src/sat/smt/euf_solver.cpp b/src/sat/smt/euf_solver.cpp index 5d9b96958..1f0523dde 100644 --- a/src/sat/smt/euf_solver.cpp +++ b/src/sat/smt/euf_solver.cpp @@ -300,7 +300,7 @@ namespace euf { size_t* c = to_ptr(l); SASSERT(is_literal(c)); SASSERT(l == get_literal(c)); - if (n->value_conflict()) { + if (n->value_conflict()) { euf::enode* nb = sign ? mk_false() : mk_true(); euf::enode* r = n->get_root(); euf::enode* rb = sign ? mk_true() : mk_false(); @@ -458,6 +458,8 @@ namespace euf { give_up = true; unsigned num_nodes = m_egraph.num_nodes(); + if (merge_shared_bools()) + cont = true; for (auto* e : m_solvers) { if (!m.inc()) return sat::check_result::CR_GIVEUP; @@ -485,6 +487,23 @@ namespace euf { return sat::check_result::CR_DONE; } + bool solver::merge_shared_bools() { + bool merged = false; + for (euf::enode* n : m_egraph.nodes()) { + if (!is_shared(n) || !m.is_bool(n->get_expr())) + continue; + if (n->value() == l_true && !m.is_true(n->get_root()->get_expr())) { + m_egraph.merge(n, mk_true(), to_ptr(sat::literal(n->bool_var()))); + merged = true; + } + if (n->value() == l_false && !m.is_false(n->get_root()->get_expr())) { + m_egraph.merge(n, mk_false(), to_ptr(~sat::literal(n->bool_var()))); + merged = true; + } + } + return merged; + } + void solver::push() { si.push(); scope s; diff --git a/src/sat/smt/euf_solver.h b/src/sat/smt/euf_solver.h index d88a832ec..db928259c 100644 --- a/src/sat/smt/euf_solver.h +++ b/src/sat/smt/euf_solver.h @@ -165,6 +165,7 @@ namespace euf { bool is_self_propagated(th_eq const& e); void get_antecedents(literal l, constraint& j, literal_vector& r, bool probing); void new_diseq(enode* a, enode* b, literal lit); + bool merge_shared_bools(); // proofs void log_antecedents(std::ostream& out, literal l, literal_vector const& r); @@ -286,6 +287,7 @@ namespace euf { void propagate(literal lit, th_explain* p) { propagate(lit, p->to_index()); } bool propagate(enode* a, enode* b, th_explain* p) { return propagate(a, b, p->to_index()); } + size_t* to_justification(sat::literal l) { return to_ptr(l); } void set_conflict(th_explain* p) { set_conflict(p->to_index()); } bool set_root(literal l, literal r) override;