diff --git a/src/ast/simplifiers/euf_completion.cpp b/src/ast/simplifiers/euf_completion.cpp index 7d5e484aa..67886f3bd 100644 --- a/src/ast/simplifiers/euf_completion.cpp +++ b/src/ast/simplifiers/euf_completion.cpp @@ -28,7 +28,7 @@ Algorithm for extracting canonical form from an E-graph: * Each f(t) = g(s) in E: * add f(canon(t)) = canon(f(t)), g(canon(s)) = canon(g(s)) where canon(f(t)) = canon(g(s)) by construction. - + * Each other g(t) in E: * add g(canon(t)) to E. * Note that canon(g(t)) = true because g(t) = true is added to congruence closure of E. @@ -37,16 +37,11 @@ Algorithm for extracting canonical form from an E-graph: Conditional saturation: - forall X . Body => Head -- propagate when (all assertions in) Body is merged with True -- Possible efficient approaches: - - use on_merge? - - or bit set in nodes with Body? - - register Boolean reduction rules to EUF? - - register function "body_of" and monitor merges based on function? +- propagate when (all assertions in) Body is merged with True +- insert expressions from Body into a watch list. + When elements of the watch list are merged by true/false + trigger rep-propagation with respect to body. -Delayed solver invocation -- So far default code for checking rules - - EUF check should be on demand, see note on conditional saturation Mam optimization? match(p, t, S) = suppose all variables in p are bound in S, check equality using canonization of p[S], otherwise prune instances from S. @@ -59,10 +54,11 @@ Mam optimization? #include "ast/rewriter/var_subst.h" #include "ast/simplifiers/euf_completion.h" #include "ast/shared_occs.h" +#include "params/tactic_params.hpp" namespace euf { - completion::completion(ast_manager& m, dependent_expr_state& fmls): + completion::completion(ast_manager& m, dependent_expr_state& fmls) : dependent_expr_simplifier(m, fmls), m_egraph(m), m_mam(mam::mk(*this, *this)), @@ -75,16 +71,17 @@ namespace euf { m_rewriter.set_order_eq(true); m_rewriter.set_flat_and_or(false); - std::function _on_merge = - [&](euf::enode* root, euf::enode* other) { - m_mam->on_merge(root, other); - }; - - std::function _on_make = + std::function _on_merge = + [&](euf::enode* root, euf::enode* other) { + m_mam->on_merge(root, other); + watch_rule(root, other); + }; + + std::function _on_make = [&](euf::enode* n) { m_mam->add_node(n, false); - }; - + }; + m_egraph.set_on_merge(_on_merge); m_egraph.set_on_make(_on_make); } @@ -92,9 +89,76 @@ namespace euf { completion::~completion() { } + bool completion::should_stop() { + return + !m.inc() || + m_egraph.inconsistent() || + m_fmls.inconsistent() || + resource_limits_exceeded(); + } + + void completion::updt_params(params_ref const& p) { + tactic_params tp(p); + m_max_instantiations = tp.completion_max_instantiations(); + } + + struct completion::push_watch_rule : public trail { + vector>& m_rules; + unsigned idx; + push_watch_rule(vector>& r, unsigned i) : m_rules(r), idx(i) {} + void undo() override { + m_rules[idx].pop_back(); + } + }; + + void completion::push() { + if (m_side_condition_solver) + m_side_condition_solver->push(); + m_egraph.push(); + dependent_expr_simplifier::push(); + } + + void completion::pop(unsigned n) { + clear_propagation_queue(); + dependent_expr_simplifier::pop(n); + m_egraph.pop(n); + if (m_side_condition_solver) + m_side_condition_solver->pop(n); + } + + void completion::clear_propagation_queue() { + for (auto r : m_propagation_queue) + r->m_in_queue = false; + m_propagation_queue.reset(); + } + + void completion::watch_rule(enode* root, enode* other) { + auto oid = other->get_id(); + if (oid >= m_rule_watch.size()) + return; + if (m_rule_watch[oid].empty()) + return; + auto is_true_or_false = m.is_true(root->get_expr()) || m.is_false(root->get_expr()); + if (is_true_or_false) { + for (auto r : m_rule_watch[oid]) + if (!r->m_in_queue) + r->m_in_queue = true, + m_propagation_queue.push_back(r); + } + else { + // root is not true or false, use root to watch rules + auto rid = root->get_id(); + m_rule_watch.reserve(rid + 1); + for (auto r : m_rule_watch[oid]) { + m_rule_watch[rid].push_back(r); + get_trail().push(push_watch_rule(m_rule_watch, rid)); + } + } + } + void completion::reduce() { m_has_new_eq = true; - for (unsigned rounds = 0; m_has_new_eq && rounds <= 3 && !m_fmls.inconsistent(); ++rounds) { + for (unsigned rounds = 0; m_has_new_eq && rounds <= 3 && !should_stop(); ++rounds) { ++m_epoch; m_has_new_eq = false; add_egraph(); @@ -113,23 +177,24 @@ namespace euf { add_constraint(f, d); } m_should_propagate = true; - while (m_should_propagate && m.inc() && !m_egraph.inconsistent()) { + while (m_should_propagate && !should_stop()) { m_should_propagate = false; m_egraph.propagate(); m_mam->propagate(); + propagate_rules(); IF_VERBOSE(11, verbose_stream() << "propagate " << m_stats.m_num_instances << "\n"); if (!m_should_propagate) - check_rules(); + propagate_all_rules(); } } void completion::add_constraint(expr* f, expr_dependency* d) { if (m_egraph.inconsistent()) return; - auto add_children = [&](enode* n) { + auto add_children = [&](enode* n) { for (auto* ch : enode_args(n)) m_nodes_to_canonize.push_back(ch); - }; + }; expr* x, * y; if (m.is_eq(f, x, y)) { enode* a = mk_enode(x); @@ -160,7 +225,7 @@ namespace euf { if (!get_dependency(q)) { m_q2dep.insert(q, d); get_trail().push(insert_obj_map(m_q2dep, q)); - } + } } add_rule(f, d); } @@ -174,9 +239,9 @@ namespace euf { d = m.mk_join(d, explain_eq(n, n->get_root())); return l_true; } - if (m.is_false(n->get_root()->get_expr())) + if (m.is_false(n->get_root()->get_expr())) return l_false; - + expr* g = nullptr; if (m.is_not(f, g)) { n = mk_enode(g); @@ -184,6 +249,8 @@ namespace euf { d = m.mk_join(d, explain_eq(n, n->get_root())); return l_true; } + if (m.is_true(n->get_root()->get_expr())) + return l_false; } if (m_side_condition_solver) { expr_dependency* sd = nullptr; @@ -203,7 +270,7 @@ namespace euf { expr_ref_vector body(m); expr_ref head(y, m); body.push_back(x); - flatten_and(body); + flatten_and(body); unsigned j = 0; for (auto f : body) { switch (eval_cond(f, d)) { @@ -217,51 +284,66 @@ namespace euf { } } body.shrink(j); - if (body.empty()) - add_constraint(head, d); + if (body.empty()) + add_constraint(head, d); else { - m_rules.push_back(alloc(ground_rule, body, head, d)); + // create a new rule. + // add all (one is actually enough) parts of the body to watch list. + auto r = alloc(conditional_rule, body, head, d); + m_rules.push_back(r); + get_trail().push(new_obj_trail(r)); get_trail().push(push_back_vector(m_rules)); - } - } - - void completion::check_rules() { - for (auto& r : m_rules) { - if (!r->m_active) - continue; - switch (check_rule(*r)) { - case l_true: - get_trail().push(value_trail(r->m_active)); - r->m_active = false; - break; // remove rule, it is activated - case l_false: - get_trail().push(value_trail(r->m_active)); - r->m_active = false; - break; // remove rule, premise is false - case l_undef: - break; + for (auto f : body) { + auto n = m_egraph.find(f)->get_root(); + if (m.is_not(n->get_expr())) + n = n->get_arg(0)->get_root(); + m_rule_watch.reserve(n->get_id() + 1); + m_rule_watch[n->get_id()].push_back(r); + get_trail().push(push_watch_rule(m_rule_watch, n->get_id())); } } } - lbool completion::check_rule(ground_rule& r) { + void completion::propagate_all_rules() { + for (auto* r : m_rules) + if (!r->m_in_queue) + r->m_in_queue = true, + m_propagation_queue.push_back(r); + propagate_rules(); + } + + void completion::propagate_rules() { + for (unsigned i = 0; i < m_propagation_queue.size() && !should_stop(); ++i) { + auto r = m_propagation_queue[i]; + r->m_in_queue = false; + propagate_rule(*r); + } + clear_propagation_queue(); + } + + void completion::propagate_rule(conditional_rule& r) { + if (!r.m_active) + return; for (auto* f : r.m_body) { switch (eval_cond(f, r.m_dep)) { case l_true: break; case l_false: - return l_false; + get_trail().push(value_trail(r.m_active)); + r.m_active = false; + return; default: break; - } + } } if (r.m_body.empty()) { add_constraint(r.m_head, r.m_dep); - return l_true; + get_trail().push(value_trail(r.m_active)); + r.m_active = false; } - return l_undef; } + // callback when mam finds a binding void completion::on_binding(quantifier* q, app* pat, enode* const* binding, unsigned mg, unsigned ming, unsigned mx) { if (m_egraph.inconsistent()) return; @@ -272,6 +354,7 @@ namespace euf { expr_ref r = subst(q->get_expr(), _binding); IF_VERBOSE(12, verbose_stream() << "add " << r << "\n"); add_constraint(r, get_dependency(q)); + propagate_rules(); m_should_propagate = true; ++m_stats.m_num_instances; } @@ -285,7 +368,7 @@ namespace euf { } unsigned sz = qtail(); for (unsigned i = qhead(); i < sz; ++i) { - auto [f, p, d] = m_fmls[i](); + auto [f, p, d] = m_fmls[i](); expr_dependency_ref dep(d, m); expr_ref g = canonize_fml(f, dep); if (g != f) { @@ -346,7 +429,7 @@ namespace euf { n = m_egraph.find(arg); if (n) m_args.push_back(n); - else + else m_todo.push_back(arg); } if (sz == m_todo.size()) { @@ -361,7 +444,7 @@ namespace euf { auto is_nullary = [&](expr* e) { return is_app(e) && to_app(e)->get_num_args() == 0; - }; + }; expr* x, * y; if (m.is_eq(f, x, y)) { expr_ref x1 = canonize(x, d); @@ -379,10 +462,10 @@ namespace euf { if (x == y) return expr_ref(m.mk_true(), m); - if (x == x1 && y == y1) + if (x == x1 && y == y1) return m_rewriter.mk_eq(x, y); - if (is_nullary(x) && is_nullary(y)) + if (is_nullary(x) && is_nullary(y)) return mk_and(m_rewriter.mk_eq(x, x1), m_rewriter.mk_eq(y, x1)); if (x == x1 && is_nullary(x)) @@ -390,13 +473,13 @@ namespace euf { if (y == y1 && is_nullary(y)) return m_rewriter.mk_eq(x1, y1); - + if (is_nullary(x)) return mk_and(m_rewriter.mk_eq(x, x1), m_rewriter.mk_eq(y1, x1)); - + if (is_nullary(y)) return mk_and(m_rewriter.mk_eq(y, y1), m_rewriter.mk_eq(x1, y1)); - + if (x1 == y1) return expr_ref(m.mk_true(), m); else { @@ -438,8 +521,8 @@ namespace euf { } if (m.is_eq(f)) return m_rewriter.mk_eq(m_eargs.get(0), m_eargs.get(1)); - if (!change) - return expr_ref(f, m); + if (!change) + return expr_ref(f, m); else return expr_ref(m_rewriter.mk_app(to_app(f)->get_decl(), m_eargs.size(), m_eargs.data()), m); } @@ -448,11 +531,11 @@ namespace euf { enode* n = m_egraph.find(f); enode* r = n->get_root(); d = m.mk_join(d, explain_eq(n, r)); - d = m.mk_join(d, m_deps.get(r->get_id(), nullptr)); + d = m.mk_join(d, m_deps.get(r->get_id(), nullptr)); SASSERT(m_canonical.get(r->get_id())); return m_canonical.get(r->get_id()); } - + expr* completion::get_canonical(enode* n) { if (m_epochs.get(n->get_id(), 0) == m_epoch) return m_canonical.get(n->get_id()); @@ -466,10 +549,10 @@ namespace euf { unsigned idx; expr_ref old_value; public: - vtrail(expr_ref_vector& c, unsigned idx) : + vtrail(expr_ref_vector& c, unsigned idx) : c(c), idx(idx), old_value(c.get(idx), c.m()) { } - + void undo() override { c[idx] = old_value; old_value = nullptr; @@ -507,24 +590,24 @@ namespace euf { } void completion::collect_statistics(statistics& st) const { - st.update("euf-completion-rewrites", m_stats.m_num_rewrites); - st.update("euf-completion-instances", m_stats.m_num_instances); + st.update("euf-completion-rewrites", m_stats.m_num_rewrites); + st.update("euf-completion-instances", m_stats.m_num_instances); } bool completion::is_gt(expr* lhs, expr* rhs) const { - if (lhs == rhs) + if (lhs == rhs) return false; // values are always less in ordering than non-values. bool v1 = m.is_value(lhs); bool v2 = m.is_value(rhs); - if (!v1 && v2) + if (!v1 && v2) return true; - if (v1 && !v2) + if (v1 && !v2) return false; - - if (get_depth(lhs) > get_depth(rhs)) + + if (get_depth(lhs) > get_depth(rhs)) return true; - if (get_depth(lhs) < get_depth(rhs)) + if (get_depth(lhs) < get_depth(rhs)) return false; // slow path @@ -534,16 +617,16 @@ namespace euf { return true; if (n1 < n2) return false; - + if (is_app(lhs) && is_app(rhs)) { app* l = to_app(lhs); app* r = to_app(rhs); - if (l->get_decl()->get_id() != r->get_decl()->get_id()) + if (l->get_decl()->get_id() != r->get_decl()->get_id()) return l->get_decl()->get_id() > r->get_decl()->get_id(); - if (l->get_num_args() != r->get_num_args()) + if (l->get_num_args() != r->get_num_args()) return l->get_num_args() > r->get_num_args(); - for (unsigned i = 0; i < l->get_num_args(); ++i) - if (l->get_arg(i) != r->get_arg(i)) + for (unsigned i = 0; i < l->get_num_args(); ++i) + if (l->get_arg(i) != r->get_arg(i)) return is_gt(l->get_arg(i), r->get_arg(i)); UNREACHABLE(); } @@ -569,14 +652,14 @@ namespace euf { n->mark1(); roots.push_back(n); enode* rep = nullptr; - for (enode* k : enode_class(n)) + for (enode* k : enode_class(n)) if (!rep || m.is_value(k->get_expr()) || is_gt(rep->get_expr(), k->get_expr())) - rep = k; + rep = k; // IF_VERBOSE(0, verbose_stream() << m_egraph.bpp(n) << " ->\n" << m_egraph.bpp(rep) << "\n";); m_reps.setx(n->get_id(), rep, nullptr); TRACE(euf_completion, tout << "rep " << m_egraph.bpp(n) << " -> " << m_egraph.bpp(rep) << "\n"; - for (enode* k : enode_class(n)) tout << m_egraph.bpp(k) << "\n";); + for (enode* k : enode_class(n)) tout << m_egraph.bpp(k) << "\n";); m_todo.push_back(n->get_expr()); for (enode* arg : enode_args(n)) { arg = arg->get_root(); @@ -602,7 +685,7 @@ namespace euf { enode* n = m_egraph.find(e); SASSERT(n->is_root()); enode* rep = m_reps[n->get_id()]; - if (get_canonical(n)) + if (get_canonical(n)) m_todo.pop_back(); else if (get_depth(rep->get_expr()) == 0 || !is_app(rep->get_expr())) { set_canonical(n, rep->get_expr()); @@ -626,15 +709,14 @@ namespace euf { } if (sz == m_todo.size()) { m_todo.pop_back(); - if (new_arg) - new_expr = m_rewriter.mk_app(to_app(rep->get_expr())->get_decl(), m_eargs.size(), m_eargs.data()); + if (new_arg) + new_expr = m_rewriter.mk_app(to_app(rep->get_expr())->get_decl(), m_eargs.size(), m_eargs.data()); else - new_expr = rep->get_expr(); + new_expr = rep->get_expr(); set_canonical(n, new_expr); m_deps.setx(n->get_id(), d); } } } - } - + } } \ No newline at end of file diff --git a/src/ast/simplifiers/euf_completion.h b/src/ast/simplifiers/euf_completion.h index fa96ca5f3..fb0b866e5 100644 --- a/src/ast/simplifiers/euf_completion.h +++ b/src/ast/simplifiers/euf_completion.h @@ -7,7 +7,10 @@ Module Name: Abstract: - Ground completion for equalities + Completion for (conditional) equalities. + This transforms expressions into a normal form by perorming equality saturation modulo + ground equations and E-matching on quantified axioms. + It supports conditional equations in terms of implications. Author: @@ -27,11 +30,17 @@ namespace euf { class side_condition_solver { public: + struct solution { + expr* var; + expr_ref term; + expr_ref guard; + }; virtual ~side_condition_solver() = default; virtual void add_constraint(expr* f, expr_dependency* d) = 0; virtual bool is_true(expr* f, expr_dependency*& d) = 0; virtual void push() = 0; virtual void pop(unsigned n) = 0; + virtual void solve_for(vector& sol) = 0; }; class completion : public dependent_expr_simplifier, public on_binding_callback, public mam_solver { @@ -42,12 +51,13 @@ namespace euf { void reset() { memset(this, 0, sizeof(*this)); } }; - struct ground_rule { + struct conditional_rule { expr_ref_vector m_body; expr_ref m_head; expr_dependency* m_dep; bool m_active = true; - ground_rule(expr_ref_vector& b, expr_ref& h, expr_dependency* d) : + bool m_in_queue = false; + conditional_rule(expr_ref_vector& b, expr_ref& h, expr_dependency* d) : m_body(b), m_head(h), m_dep(d) {} }; @@ -64,9 +74,11 @@ namespace euf { th_rewriter m_rewriter; stats m_stats; scoped_ptr m_side_condition_solver; - ptr_vector m_rules; + ptr_vector m_rules; bool m_has_new_eq = false; bool m_should_propagate = false; + unsigned m_max_instantiations = std::numeric_limits::max(); + vector> m_rule_watch; enode* mk_enode(expr* e); bool is_new_eq(expr* a, expr* b); @@ -87,32 +99,38 @@ namespace euf { lbool eval_cond(expr* f, expr_dependency*& d); - lbool check_rule(ground_rule& rule); - void check_rules(); + + bool should_stop(); + void add_rule(expr* f, expr_dependency* d); + void watch_rule(enode* root, enode* other); + void propagate_rule(conditional_rule& r); + void propagate_rules(); + void propagate_all_rules(); + void clear_propagation_queue(); + ptr_vector m_propagation_queue; + struct push_watch_rule; bool is_gt(expr* a, expr* b) const; public: completion(ast_manager& m, dependent_expr_state& fmls); ~completion() override; char const* name() const override { return "euf-completion"; } - void push() override { if (m_side_condition_solver) m_side_condition_solver->push(); m_egraph.push(); dependent_expr_simplifier::push(); } - void pop(unsigned n) override { dependent_expr_simplifier::pop(n); m_egraph.pop(n); if (m_side_condition_solver) m_side_condition_solver->pop(1); - } + void push() override; + void pop(unsigned n) override; void reduce() override; void collect_statistics(statistics& st) const override; void reset_statistics() override { m_stats.reset(); } + void updt_params(params_ref const& p) override; trail_stack& get_trail() override { return m_trail;} region& get_region() override { return m_trail.get_region(); } egraph& get_egraph() override { return m_egraph; } bool is_relevant(enode* n) const override { return true; } - bool resource_limits_exceeded() const override { return false; } + bool resource_limits_exceeded() const override { return m_stats.m_num_instances > m_max_instantiations; } ast_manager& get_manager() override { return m; } void on_binding(quantifier* q, app* pat, enode* const* binding, unsigned mg, unsigned ming, unsigned mx) override; - void set_solver(side_condition_solver* s) { m_side_condition_solver = s; } - }; } diff --git a/src/tactic/portfolio/euf_completion_tactic.cpp b/src/tactic/portfolio/euf_completion_tactic.cpp index aafccd69f..8fda726ac 100644 --- a/src/tactic/portfolio/euf_completion_tactic.cpp +++ b/src/tactic/portfolio/euf_completion_tactic.cpp @@ -18,6 +18,7 @@ Author: #include "tactic/tactic.h" #include "tactic/portfolio/euf_completion_tactic.h" #include "solver/solver.h" +#include "smt/smt_solver.h" class euf_side_condition_solver : public euf::side_condition_solver { ast_manager& m; @@ -25,55 +26,81 @@ class euf_side_condition_solver : public euf::side_condition_solver { scoped_ptr m_solver; expr_ref_vector m_deps; obj_map m_e2d; + expr_ref_vector m_fmls; + obj_hashtable m_seen; + trail_stack m_trail; + void init_solver() { if (m_solver.get()) return; m_params.set_uint("smt.max_conflicts", 100); - scoped_ptr f = mk_smt_strategic_solver_factory(); + scoped_ptr f = mk_smt_solver_factory(); m_solver = (*f)(m, m_params, false, false, true, symbol::null); } + public: - euf_side_condition_solver(ast_manager& m, params_ref const& p) : m(m), m_params(p), m_deps(m) {} + + euf_side_condition_solver(ast_manager& m, params_ref const& p) : + m(m), m_params(p), m_deps(m), m_fmls(m) {} void push() override { init_solver(); m_solver->push(); + m_trail.pop_scope(1); } void pop(unsigned n) override { + m_trail.push_scope(); SASSERT(m_solver.get()); m_solver->pop(n); } void add_constraint(expr* f, expr_dependency* d) override { + if (m_seen.contains(f)) + return; + m_seen.insert(f); + m_trail.push(insert_obj_trail(m_seen, f)); if (!is_ground(f)) return; + if (m.is_implies(f)) + return; init_solver(); - expr* e_dep = nullptr; if (d) { - e_dep = m.mk_fresh_const("dep", m.mk_bool_sort()); + expr* e_dep = m.mk_fresh_const("dep", m.mk_bool_sort()); m_deps.push_back(e_dep); m_e2d.insert(e_dep, d); + m_trail.push(insert_obj_map(m_e2d, e_dep)); + m_solver->assert_expr(f, e_dep); } - m_solver->assert_expr(f, e_dep); + else + m_solver->assert_expr(f); } bool is_true(expr* f, expr_dependency*& d) override { d = nullptr; - m_solver->push(); - expr_ref_vector fmls(m); - fmls.push_back(m.mk_not(f)); + solver::scoped_push _sp(*m_solver); + m_fmls.reset(); + m_fmls.push_back(m.mk_not(f)); expr_ref nf(m.mk_not(f), m); - lbool r = m_solver->check_sat(fmls); + lbool r = m_solver->check_sat(m_fmls); if (r == l_false) { expr_ref_vector core(m); m_solver->get_unsat_core(core); for (auto c : core) d = m.mk_join(d, m_e2d[c]); } - m_solver->pop(1); return r == l_false; } + + void solve_for(vector& sol) override { + vector ss; + for (auto [v, t, g] : sol) + ss.push_back({ v, t, g }); + sol.reset(); + m_solver->solve_for(ss); + for (auto [v, t, g] : ss) + sol.push_back({ v, t, g }); + } }; dependent_expr_simplifier* mk_euf_completion_simplifier(ast_manager& m, dependent_expr_state& s, params_ref const& p) {