diff --git a/src/ast/euf/euf_egraph.cpp b/src/ast/euf/euf_egraph.cpp index 8a3a4e002..652663b70 100644 --- a/src/ast/euf/euf_egraph.cpp +++ b/src/ast/euf/euf_egraph.cpp @@ -293,11 +293,12 @@ namespace euf { VERIFY(n->num_args() == 0 || !n->merge_enabled() || m_table.contains(n)); } - void egraph::set_value(enode* n, lbool value) { + void egraph::set_value(enode* n, lbool value, justification j) { if (n->value() == l_undef) { force_push(); TRACE("euf", tout << bpp(n) << " := " << value << "\n";); n->set_value(value); + n->set_justification(j); m_updates.push_back(update_record(n, update_record::value_assignment())); } } @@ -657,6 +658,7 @@ namespace euf { push_lca(n1->get_arg(1), n2->get_arg(0)); return; } + TRACE("euf_verbose", tout << bpp(n1) << " " << bpp(n2) << "\n"); for (unsigned i = 0; i < n1->num_args(); ++i) push_lca(n1->get_arg(i), n2->get_arg(i)); @@ -713,6 +715,15 @@ namespace euf { explain_todo(justifications); } + template + void egraph::explain_eq(ptr_vector& justifications, enode* a, enode* b, justification const& j) { + if (j.is_external()) + justifications.push_back(j.ext()); + else if (j.is_congruence()) + push_congruence(a, b, j.is_commutative()); + } + + template void egraph::explain_eq(ptr_vector& justifications, enode* a, enode* b) { SASSERT(a->get_root() == b->get_root()); @@ -746,11 +757,21 @@ namespace euf { void egraph::explain_todo(ptr_vector& justifications) { for (unsigned i = 0; i < m_todo.size(); ++i) { enode* n = m_todo[i]; - if (n->m_target && !n->is_marked1()) { + if (n->is_marked1()) + continue; + if (n->m_target) { n->mark1(); CTRACE("euf_verbose", m_display_justification, n->m_justification.display(tout << n->get_expr_id() << " = " << n->m_target->get_expr_id() << " ", m_display_justification) << "\n";); explain_eq(justifications, n, n->m_target, n->m_justification); } + else if (!n->is_marked1() && n->value() != l_undef) { + n->mark1(); + if (m.is_true(n->get_expr()) || m.is_false(n->get_expr())) + continue; + justification j = n->m_justification; + SASSERT(j.is_external()); + justifications.push_back(j.ext()); + } } } diff --git a/src/ast/euf/euf_egraph.h b/src/ast/euf/euf_egraph.h index a91dbf4a4..55f94f0f2 100644 --- a/src/ast/euf/euf_egraph.h +++ b/src/ast/euf/euf_egraph.h @@ -226,12 +226,8 @@ namespace euf { void erase_from_table(enode* p); template - void explain_eq(ptr_vector& justifications, enode* a, enode* b, justification const& j) { - if (j.is_external()) - justifications.push_back(j.ext()); - else if (j.is_congruence()) - push_congruence(a, b, j.is_commutative()); - } + void explain_eq(ptr_vector& justifications, enode* a, enode* b, justification const& j); + template void explain_todo(ptr_vector& justifications); @@ -295,7 +291,7 @@ namespace euf { void add_th_var(enode* n, theory_var v, theory_id id); void set_th_propagates_diseqs(theory_id id); void set_merge_enabled(enode* n, bool enable_merge); - void set_value(enode* n, lbool value); + void set_value(enode* n, lbool value, justification j); void set_bool_var(enode* n, unsigned v) { n->set_bool_var(v); } void set_relevant(enode* n); void set_default_relevant(bool b) { m_default_relevant = b; } diff --git a/src/ast/euf/euf_enode.h b/src/ast/euf/euf_enode.h index 850e183e8..dc98d95c8 100644 --- a/src/ast/euf/euf_enode.h +++ b/src/ast/euf/euf_enode.h @@ -133,6 +133,7 @@ namespace euf { void del_th_var(theory_id id) { m_th_vars.del_var(id); } void set_merge_enabled(bool m) { m_merge_enabled = m; } void set_value(lbool v) { m_value = v; } + void set_justification(justification j) { m_justification = j; } void set_is_equality() { m_is_equality = true; } void set_bool_var(sat::bool_var v) { m_bool_var = v; } diff --git a/src/sat/smt/euf_internalize.cpp b/src/sat/smt/euf_internalize.cpp index 6cc72eba5..116621dcc 100644 --- a/src/sat/smt/euf_internalize.cpp +++ b/src/sat/smt/euf_internalize.cpp @@ -190,8 +190,9 @@ namespace euf { m_egraph.set_bool_var(n, v); if (m.is_eq(e) || m.is_or(e) || m.is_and(e) || m.is_not(e)) m_egraph.set_merge_enabled(n, false); - if (s().value(lit) != l_undef) - m_egraph.set_value(n, s().value(lit)); + lbool val = s().value(lit); + if (val != l_undef) + m_egraph.set_value(n, val, justification::external(to_ptr(val == l_true ? lit : ~lit))); return lit; } diff --git a/src/sat/smt/euf_solver.cpp b/src/sat/smt/euf_solver.cpp index 3c7342136..2ef77ac6a 100644 --- a/src/sat/smt/euf_solver.cpp +++ b/src/sat/smt/euf_solver.cpp @@ -309,7 +309,7 @@ namespace euf { if (!n) return; bool sign = l.sign(); - m_egraph.set_value(n, sign ? l_false : l_true); + m_egraph.set_value(n, sign ? l_false : l_true, justification::external(to_ptr(l))); for (auto const& th : enode_th_vars(n)) m_id2solver[th.get_id()]->asserted(l);