diff --git a/src/ast/euf/euf_egraph.cpp b/src/ast/euf/euf_egraph.cpp index 020e50440..b05df0cff 100644 --- a/src/ast/euf/euf_egraph.cpp +++ b/src/ast/euf/euf_egraph.cpp @@ -16,6 +16,7 @@ Author: --*/ #include "ast/euf/euf_egraph.h" +#include "ast/ast_pp.h" namespace euf { @@ -59,6 +60,7 @@ namespace euf { } void egraph::reinsert_equality(enode* p) { + SASSERT(is_equality(p)); if (p->get_arg(0)->get_root() == p->get_arg(1)->get_root()) m_new_eqs.push_back(p); } @@ -162,10 +164,10 @@ namespace euf { if (r1 == r2) return; if (r1->interpreted() && r2->interpreted()) { - set_conflict(r1, r2, j); + set_conflict(n1, n2, j); return; } - if ((r1->class_size() > r2->class_size() && !r1->interpreted()) || r2->interpreted()) { + if ((r1->class_size() > r2->class_size() && !r2->interpreted()) || r1->interpreted()) { std::swap(r1, r2); std::swap(n1, n2); } @@ -237,41 +239,6 @@ namespace euf { SASSERT(n1->get_root()->m_target == nullptr); } - template - void egraph::explain(ptr_vector& justifications) { - SASSERT(m_inconsistent); - SASSERT(m_todo.empty()); - m_todo.push_back(m_n1); - m_todo.push_back(m_n2); - auto push_congruence = [&](enode* p, enode* q) { - SASSERT(p->get_decl() == q->get_decl()); - for (enode* arg : enode_args(p)) - m_todo.push_back(arg); - for (enode* arg : enode_args(q)) - m_todo.push_back(arg); - }; - auto explain_node = [&](enode* n) { - if (!n->m_target) - return; - if (n->is_marked1()) - return; - n->mark1(); - if (n->m_justification.is_external()) - justifications.push_back(n->m_justification.ext()); - else if (n->m_justification.is_congruence()) - push_congruence(n, n->m_target); - n = n->m_target; - if (!n->is_marked1()) - m_todo.push_back(n); - }; - if (m_justification.is_external()) - justifications.push_back(m_justification.ext()); - for (unsigned i = 0; i < m_todo.size(); ++i) - explain_node(m_todo[i]); - for (enode* n : m_todo) - n->unmark1(); - } - void egraph::invariant() { for (enode* n : m_nodes) n->invariant(); diff --git a/src/ast/euf/euf_egraph.h b/src/ast/euf/euf_egraph.h index 327023850..2f4bfed89 100644 --- a/src/ast/euf/euf_egraph.h +++ b/src/ast/euf/euf_egraph.h @@ -99,8 +99,44 @@ namespace euf { bool inconsistent() const { return m_inconsistent; } enode_vector const& new_eqs() const { return m_new_eqs; } template - void explain(ptr_vector& justifications); - + void explain(ptr_vector& justifications) { + SASSERT(m_inconsistent); + SASSERT(m_todo.empty()); + auto push_congruence = [&](enode* p, enode* q) { + SASSERT(p->get_decl() == q->get_decl()); + for (enode* arg : enode_args(p)) + m_todo.push_back(arg); + for (enode* arg : enode_args(q)) + m_todo.push_back(arg); + }; + auto explain_node = [&](enode* n) { + if (!n->m_target) + return; + if (n->is_marked1()) + return; + n->mark1(); + if (n->m_justification.is_external()) + justifications.push_back(n->m_justification.ext()); + else if (n->m_justification.is_congruence()) + push_congruence(n, n->m_target); + n = n->m_target; + if (!n->is_marked1()) + m_todo.push_back(n); + }; + m_todo.push_back(m_n1); + m_todo.push_back(m_n2); + if (m_justification.is_external()) + justifications.push_back(m_justification.ext()); + else if (m_justification.is_congruence()) + push_congruence(m_n1, m_n2); + for (unsigned i = 0; i < m_todo.size(); ++i) + explain_node(m_todo[i]); + for (enode* n : m_todo) + n->unmark1(); + m_todo.reset(); + } + + void invariant(); std::ostream& display(std::ostream& out) const; }; diff --git a/src/ast/euf/euf_enode.h b/src/ast/euf/euf_enode.h index 008c5953b..39034bd0e 100644 --- a/src/ast/euf/euf_enode.h +++ b/src/ast/euf/euf_enode.h @@ -63,6 +63,7 @@ namespace euf { n->m_root = n; n->m_commutative = num_args == 2 && is_app(f) && to_app(f)->get_decl()->is_commutative(); for (unsigned i = 0; i < num_args; ++i) { + SASSERT(to_app(f)->get_arg(i) == args[i]->get_owner()); n->m_args[i] = args[i]; } return n; diff --git a/src/smt/theory_seq.cpp b/src/smt/theory_seq.cpp index c1e182aa2..921e5b45f 100644 --- a/src/smt/theory_seq.cpp +++ b/src/smt/theory_seq.cpp @@ -1502,6 +1502,9 @@ bool theory_seq::internalize_term(app* term) { if (m.is_bool(term) && (m_util.str.is_in_re(term) || m_sk.is_skolem(term))) { + if (m_util.str.is_in_re(term)) { + mk_var(ensure_enode(term->get_arg(0))); + } bool_var bv = ctx.mk_bool_var(term); ctx.set_var_theory(bv, get_id()); ctx.mark_as_relevant(bv); diff --git a/src/test/egraph.cpp b/src/test/egraph.cpp index e0ad218ae..f311b5038 100644 --- a/src/test/egraph.cpp +++ b/src/test/egraph.cpp @@ -8,6 +8,7 @@ Copyright (c) 2020 Microsoft Corporation #include "ast/euf/euf_egraph.h" #include "ast/reg_decl_plugins.h" #include "ast/ast_pp.h" +#include "ast/arith_decl_plugin.h" static expr_ref mk_const(ast_manager& m, char const* name, sort* s) { return expr_ref(m.mk_const(symbol(name), s), m); @@ -85,13 +86,47 @@ static void test2() { static void test3() { ast_manager m; reg_decl_plugins(m); + arith_util a(m); euf::egraph g(m); - sort_ref S(m.mk_uninterpreted_sort(symbol("S")), m); - + sort_ref I(a.mk_int(), m); + expr_ref zero(a.mk_int(0), m); + expr_ref one(a.mk_int(1), m); + expr_ref x = mk_const(m, "x", I); + expr_ref y = mk_const(m, "y", I); + expr_ref z = mk_const(m, "z", I); + expr_ref u = mk_const(m, "u", I); + expr_ref fx = mk_app("f", x, I); + expr_ref fy = mk_app("f", y, I); + euf::enode* nx = g.mk(x, nullptr); + euf::enode* ny = g.mk(y, nullptr); + euf::enode* nz = g.mk(z, nullptr); + euf::enode* nu = g.mk(u, nullptr); + euf::enode* n0 = g.mk(zero, nullptr); + euf::enode* n1 = g.mk(one, nullptr); + euf::enode* nfx = g.mk(fx, &nx); + euf::enode* nfy = g.mk(fy, &ny); + int justifications[5] = { 1, 2, 3, 4, 5 }; + g.merge(nfx, n0, justifications + 0); + g.merge(nfy, n1, justifications + 1); + g.merge(nx, nz, justifications + 2); + g.merge(nx, nu, justifications + 3); + g.propagate(); + SASSERT(!g.inconsistent()); + g.merge(nx, ny, justifications + 4); + std::cout << g << "\n"; + g.propagate(); + std::cout << g << "\n"; + SASSERT(g.inconsistent()); + ptr_vector js; + g.explain(js); + for (int* j : js) { + std::cout << "conflict: " << *j << "\n"; + } } void tst_egraph() { enable_trace("euf"); + test3(); test1(); test2(); }