diff --git a/src/ast/euf/euf_egraph.h b/src/ast/euf/euf_egraph.h index 2feb9e79f..dda68767c 100644 --- a/src/ast/euf/euf_egraph.h +++ b/src/ast/euf/euf_egraph.h @@ -304,7 +304,7 @@ namespace euf { egraph const& g; enode* n; b_pp(egraph const& g, enode* n) : g(g), n(n) {} - std::ostream& display(std::ostream& out) const { return out << n->get_expr_id() << ": " << mk_bounded_pp(n->get_expr(), g.m); } + std::ostream& display(std::ostream& out) const { return n ? (out << n->get_expr_id() << ": " << mk_bounded_pp(n->get_expr(), g.m)) : out << "null"; } }; b_pp bpp(enode* n) const { return b_pp(*this, n); } std::ostream& display(std::ostream& out) const; diff --git a/src/sat/smt/q_ematch.cpp b/src/sat/smt/q_ematch.cpp index 0986fcfd1..c006bc9eb 100644 --- a/src/sat/smt/q_ematch.cpp +++ b/src/sat/smt/q_ematch.cpp @@ -142,58 +142,91 @@ namespace q { } } + struct ematch::remove_binding : public trail { + clause& c; + binding* b; + remove_binding(clause& c, binding* b): c(c), b(b) {} + void undo(euf::solver& ctx) override { + binding::remove_from(c.m_bindings, b); + } + }; + + struct ematch::insert_binding : public trail { + clause& c; + binding* b; + insert_binding(clause& c, binding* b): c(c), b(b) {} + void undo(euf::solver& ctx) override { + binding::push_to_front(c.m_bindings, b); + } + }; + ematch::binding* ematch::alloc_binding(unsigned n) { unsigned sz = sizeof(binding) + sizeof(euf::enode* const*)*n; void* mem = ctx.get_region().allocate(sz); return new (mem) binding(); } - void ematch::on_binding(quantifier* q, app* pat, euf::enode* const* _binding) { - clause& c = *m_clauses[m_q2clauses[q]]; - if (propagate(_binding, c)) - return; - unsigned n = q->get_num_decls(); - binding* b = alloc_binding(n); - b->m_propagated = false; + void ematch::clause::add_binding(ematch& em, euf::enode* const* _binding) { + unsigned n = m_q->get_num_decls(); + binding* b = em.alloc_binding(n); for (unsigned i = 0; i < n; ++i) b->m_nodes[i] = _binding[i]; - c.m_bindings.push_back(b); - ctx.push(push_back_vector>(c.m_bindings)); + + binding::push_to_front(m_bindings, b); + em.ctx.push(remove_binding(*this, b)); } - std::ostream& ematch::clause::display(std::ostream& out) const { + void ematch::on_binding(quantifier* q, app* pat, euf::enode* const* _binding) { + clause& c = *m_clauses[m_q2clauses[q]]; + if (!propagate(_binding, c)) + c.add_binding(*this, _binding); + } + + std::ostream& ematch::clause::display(euf::solver& ctx, std::ostream& out) const { out << "clause:\n"; for (auto const& lit : m_lits) - out << lit.lhs << (lit.sign ? " != " : " == ") << lit.rhs << "\n"; + out << mk_bounded_pp(lit.lhs, lit.lhs.m(), 2) + << (lit.sign ? " != " : " == ") + << mk_bounded_pp(lit.rhs, lit.rhs.m(), 2) << "\n"; + unsigned num_decls = m_q->get_num_decls(); + binding* b = m_bindings; + if (b) { + do { + for (unsigned i = 0; i < num_decls; ++i) + out << ctx.bpp(b->nodes()[i]) << " "; + out << "\n"; + b = b->next(); + } + while (b != m_bindings); + } return out; } bool ematch::propagate(euf::enode* const* binding, clause& c) { - TRACE("q", c.display(tout) << "\n";); + TRACE("q", c.display(ctx, tout) << "\n";); unsigned clause_idx = m_q2clauses[c.m_q]; scoped_mark_reset _sr(*this); unsigned idx = UINT_MAX; unsigned sz = c.m_lits.size(); + unsigned n = c.m_q->get_num_decls(); for (unsigned i = 0; i < sz; ++i) { lit l = c.m_lits[i]; m_indirect_nodes.reset(); - lbool cmp = compare(binding, l.lhs, l.rhs); + lbool cmp = compare(n, binding, l.lhs, l.rhs); switch (cmp) { case l_false: - if (l.sign) { - if (i > 0) - std::swap(c.m_lits[0], c.m_lits[i]); - return true; - } - break; + if (!l.sign) + break; + if (i > 0) + std::swap(c.m_lits[0], c.m_lits[i]); + return true; case l_true: - if (!l.sign) { - if (i > 0) - std::swap(c.m_lits[0], c.m_lits[i]); - return true; - } - break; + if (l.sign) + break; + if (i > 0) + std::swap(c.m_lits[0], c.m_lits[i]); + return true; case l_undef: TRACE("q", tout << l.lhs << " ~~ " << l.rhs << " is undef\n";); if (idx == 0) { @@ -207,19 +240,13 @@ namespace q { std::swap(c.m_lits[1], c.m_lits[i]); return false; } - else if (i > 0) { + else if (i > 0) std::swap(c.m_lits[0], c.m_lits[i]); - idx = 0; - } + idx = 0; break; } } - if (idx == UINT_MAX) { - std::cout << "clause is false\n"; - } - else { - std::cout << "unit propagate\n"; - } + TRACE("q", tout << "instantiate " << (idx == UINT_MAX ? "clause is false":"unit propagate") << "\n";); instantiate(binding, c); return true; } @@ -238,10 +265,12 @@ namespace q { m_qs.add_clause(ctx.mk_literal(q), ~ctx.mk_literal(result)); } - lbool ematch::compare(euf::enode* const* binding, expr* s, expr* t) { - TRACE("q", tout << mk_pp(s, m) << " ~~ " << mk_pp(t, m) << "\n";); - euf::enode* sn = eval(binding, s); - euf::enode* tn = eval(binding, t); + lbool ematch::compare(unsigned n, euf::enode* const* binding, expr* s, expr* t) { + euf::enode* sn = eval(n, binding, s); + euf::enode* tn = eval(n, binding, t); + TRACE("q", tout << mk_pp(s, m) << " ~~ " << mk_pp(t, m) << "\n"; + tout << ctx.bpp(sn) << " " << ctx.bpp(tn) << "\n";); + lbool c; if (sn && sn == tn) return l_true; @@ -250,21 +279,21 @@ namespace q { if (sn && tn) return l_undef; if (!sn && !tn) - return compare_rec(binding, s, t); + return compare_rec(n, binding, s, t); if (!sn && tn) for (euf::enode* t1 : euf::enode_class(tn)) - if (c = compare_rec(binding, s, t1->get_expr()), c != l_undef) + if (c = compare_rec(n, binding, s, t1->get_expr()), c != l_undef) return c; if (sn && !tn) for (euf::enode* s1 : euf::enode_class(sn)) - if (c = compare_rec(binding, t, s1->get_expr()), c != l_undef) + if (c = compare_rec(n, binding, t, s1->get_expr()), c != l_undef) return c; return l_undef; } // f(p1) = f(p2) if p1 = p2 // f(p1) != f(p2) if p1 != p2 and f is injective - lbool ematch::compare_rec(euf::enode* const* binding, expr* s, expr* t) { + lbool ematch::compare_rec(unsigned n, euf::enode* const* binding, expr* s, expr* t) { if (m.are_equal(s, t)) return l_true; if (m.are_distinct(s, t)) @@ -278,7 +307,7 @@ namespace q { bool is_injective = to_app(s)->get_decl()->is_injective(); bool has_undef = false; for (unsigned i = to_app(s)->get_num_args(); i-- > 0; ) { - switch (compare(binding, to_app(s)->get_arg(i), to_app(t)->get_arg(i))) { + switch (compare(n, binding, to_app(s)->get_arg(i), to_app(t)->get_arg(i))) { case l_true: break; case l_false: @@ -295,8 +324,7 @@ namespace q { return has_undef ? l_undef : l_true; } - euf::enode* ematch::eval(euf::enode* const* binding, expr* e) { - TRACE("q", tout << mk_pp(e, m) << "\n";); + euf::enode* ematch::eval(unsigned n, euf::enode* const* binding, expr* e) { if (is_ground(e)) return ctx.get_egraph().find(e)->get_root(); if (m_mark.is_marked(e)) @@ -319,7 +347,7 @@ namespace q { } if (is_var(t)) { m_mark.mark(t); - m_eval.setx(t->get_id(), binding[to_var(t)->get_idx()], nullptr); + m_eval.setx(t->get_id(), binding[n - 1 - to_var(t)->get_idx()], nullptr); todo.pop_back(); continue; } @@ -368,13 +396,18 @@ namespace q { for (; m_qhead < m_queue.size(); ++m_qhead) { unsigned idx = m_queue[m_qhead]; clause& c = *m_clauses[idx]; - for (auto& b : c.bindings()) { - if (!b->propagated() && propagate(b->m_nodes, c)) { - ctx.push(value_trail(b->m_propagated)); - b->set_propagated(true); - propagated = true; + binding* b = c.m_bindings; + if (!b) + continue; + do { + binding* next = b->next(); + if (propagate(b->m_nodes, c)) { + binding::remove_from(c.m_bindings, b); + ctx.push(insert_binding(c, b)); } + b = next; } + while (b != c.m_bindings); } m_clause_in_queue.reset(); m_node_in_queue.reset(); @@ -471,15 +504,42 @@ namespace q { } bool ematch::operator()() { - if (m_lazy_mam) - m_lazy_mam->propagate(); if (propagate()) return true; + if (m_lazy_mam) { + m_lazy_mam->propagate(); + if (propagate()) + return true; + } + // - // TODO: loop over pending bindings and instantiate them + // loop over pending bindings and instantiate them // - // NOT_IMPLEMENTED_YET(); - return true; + bool instantiated = false; + for (auto* c : m_clauses) { + binding* b = c->m_bindings; + if (!b) + continue; + instantiated = true; + do { + instantiate(b->m_nodes, *c); + b = b->next(); + } + while (b != c->m_bindings); + + while (b = c->m_bindings) { + binding::remove_from(c->m_bindings, b); + ctx.push(insert_binding(*c, b)); + } + } + TRACE("q", ctx.display(tout << "instantiated: " << instantiated << "\n");); + return instantiated; + } + + std::ostream& ematch::display(std::ostream& out) const { + for (auto const& c : m_clauses) + c->display(ctx, out); + return out; } } diff --git a/src/sat/smt/q_ematch.h b/src/sat/smt/q_ematch.h index a5af4e5b5..d1c22be0c 100644 --- a/src/sat/smt/q_ematch.h +++ b/src/sat/smt/q_ematch.h @@ -17,6 +17,7 @@ Author: #pragma once #include "util/nat_set.h" +#include "util/dlist.h" #include "solver/solver.h" #include "sat/smt/sat_th.h" #include "sat/smt/q_mam.h" @@ -50,15 +51,16 @@ namespace q { }; - struct binding { - bool m_propagated { false }; + struct remove_binding; + struct insert_binding; + + struct binding : public dll_base { euf::enode* m_nodes[0]; binding() {} - bool propagated() const { return m_propagated; } - void set_propagated(bool b) { m_propagated = b; } euf::enode* const* nodes() { return m_nodes; } + }; binding* alloc_binding(unsigned n); @@ -66,10 +68,10 @@ namespace q { struct clause { vector m_lits; quantifier* m_q; - ptr_vector m_bindings; + binding* m_bindings { nullptr }; - ptr_vector const& bindings() { return m_bindings; } - std::ostream& display(std::ostream& out) const; + void add_binding(ematch& em, euf::enode* const* b); + std::ostream& display(euf::solver& ctx, std::ostream& out) const; }; @@ -98,10 +100,10 @@ namespace q { void ensure_ground_enodes(clause const& c); // compare s, t modulo sign under binding - lbool compare(euf::enode* const* binding, expr* s, expr* t); - lbool compare_rec(euf::enode* const* binding, expr* s, expr* t); + lbool compare(unsigned n, euf::enode* const* binding, expr* s, expr* t); + lbool compare_rec(unsigned n, euf::enode* const* binding, expr* s, expr* t); euf::enode_vector m_eval, m_indirect_nodes; - euf::enode* eval(euf::enode* const* binding, expr* e); + euf::enode* eval(unsigned n, euf::enode* const* binding, expr* e); bool propagate(euf::enode* const* binding, clause& c); void instantiate(euf::enode* const* binding, clause& c); @@ -138,6 +140,8 @@ namespace q { // callback from mam void on_binding(quantifier* q, app* pat, euf::enode* const* binding); + std::ostream& display(std::ostream& out) const; + }; } diff --git a/src/sat/smt/q_solver.cpp b/src/sat/smt/q_solver.cpp index 2351c1aa3..0e8e9dbb9 100644 --- a/src/sat/smt/q_solver.cpp +++ b/src/sat/smt/q_solver.cpp @@ -58,9 +58,8 @@ namespace q { } sat::check_result solver::check() { - if (ctx.get_config().m_ematching) - if (!m_ematch()) - return sat::check_result::CR_CONTINUE; + if (ctx.get_config().m_ematching && m_ematch()) + return sat::check_result::CR_CONTINUE; if (ctx.get_config().m_mbqi) { switch (m_mbqi()) { @@ -73,6 +72,7 @@ namespace q { } std::ostream& solver::display(std::ostream& out) const { + m_ematch.display(out); return out; } diff --git a/src/smt/smt_setup.cpp b/src/smt/smt_setup.cpp index 0d683b1ff..e51048229 100644 --- a/src/smt/smt_setup.cpp +++ b/src/smt/smt_setup.cpp @@ -926,6 +926,7 @@ namespace smt { void setup::setup_str() { setup_arith(); m_context.register_plugin(alloc(theory_str, m_context, m_manager, m_params)); + setup_char(); } void setup::setup_seq() {