diff --git a/src/ast/simplifiers/euf_completion.cpp b/src/ast/simplifiers/euf_completion.cpp index f3c40aa6d..1a32a1bd5 100644 --- a/src/ast/simplifiers/euf_completion.cpp +++ b/src/ast/simplifiers/euf_completion.cpp @@ -253,10 +253,14 @@ namespace euf { auto n = m_egraph.find(t); if (!n) return; - ptr_vector args; + expr_ref_vector args(m); + expr_mark visited; for (auto s : enode_class(n)) { expr_ref r(s->get_expr(), m); m_rewriter(r); + if (visited.is_marked(r)) + continue; + visited.mark(r); args.push_back(r); } expr_ref cong(m); @@ -288,8 +292,10 @@ namespace euf { propagate_rules(); propagate_closures(); IF_VERBOSE(11, verbose_stream() << "propagate " << m_stats.m_num_instances << "\n"); + if (!should_stop()) + propagate_arithmetic(); if (!m_should_propagate && !should_stop()) - propagate_all_rules(); + propagate_all_rules(); } TRACE(euf, m_egraph.display(tout)); } @@ -310,16 +316,14 @@ namespace euf { for (auto* ch : enode_args(n)) m_nodes_to_canonize.push_back(ch); }; - expr* x = nullptr, * y = nullptr; + expr* x = nullptr, * y = nullptr, * nf = nullptr; if (m.is_eq(f, x, y)) { - expr_ref x1(x, m); expr_ref y1(y, m); - m_rewriter(x1); m_rewriter(y1); - add_quantifiers(x1); + add_quantifiers(x); add_quantifiers(y1); - enode* a = mk_enode(x1); + enode* a = mk_enode(x); enode* b = mk_enode(y1); if (a->get_root() == b->get_root()) @@ -331,42 +335,28 @@ namespace euf { m_egraph.merge(a, b, to_ptr(push_pr_dep(pr, d))); m_egraph.propagate(); m_should_propagate = true; - -#if 0 - auto a1 = mk_enode(x); - auto b1 = mk_enode(y); - - if (a->get_root() != a1->get_root()) { - add_children(a1);; - m_egraph.merge(a, a1, nullptr); - m_egraph.propagate(); - } - - if (b->get_root() != b1->get_root()) { - add_children(b1); - m_egraph.merge(b, b1, nullptr); - m_egraph.propagate(); - } -#endif if (m_side_condition_solver && a->get_root() != b->get_root()) m_side_condition_solver->add_constraint(f, pr, d); IF_VERBOSE(1, verbose_stream() << "eq: " << a->get_root_id() << " " << b->get_root_id() << " " - << x1 << " == " << y1 << "\n"); + << mk_pp(x, m) << " == " << y1 << "\n"); } - else if (m.is_not(f, f)) { - enode* n = mk_enode(f); + else if (m.is_not(f, nf)) { + expr_ref f1(nf, m); + m_rewriter(f1); + enode* n = mk_enode(f1); if (m.is_false(n->get_root()->get_expr())) return; - add_quantifiers(f); + add_quantifiers(f1); + auto n_false = mk_enode(m.mk_false()); auto j = to_ptr(push_pr_dep(pr, d)); - m_egraph.new_diseq(n, j); + m_egraph.merge(n, n_false, j); m_egraph.propagate(); add_children(n); m_should_propagate = true; if (m_side_condition_solver) m_side_condition_solver->add_constraint(f, pr, d); - IF_VERBOSE(1, verbose_stream() << "not: " << mk_pp(f, m) << "\n"); + IF_VERBOSE(1, verbose_stream() << "not: " << nf << "\n"); } else { enode* n = mk_enode(f); @@ -631,6 +621,88 @@ namespace euf { } } + // + // extract shared arithmetic terms T + // extract shared variables V + // add t = rewriter(t) to E-graph + // solve for V by solver producing theta + // add theta to E-graph + // add theta to canonize (?) + // + void completion::propagate_arithmetic() { + ptr_vector shared_terms, shared_vars; + expr_mark visited; + arith_util a(m); + bool merged = false; + for (auto n : m_egraph.nodes()) { + expr* e = n->get_expr(); + if (!is_app(e)) + continue; + app* t = to_app(e); + bool is_arith = a.is_arith_expr(t); + for (auto arg : *t) { + bool is_arith_arg = a.is_arith_expr(arg); + if (is_arith_arg == is_arith) + continue; + if (visited.is_marked(arg)) + continue; + visited.mark(arg); + if (is_arith_arg) + shared_terms.push_back(arg); + else + shared_vars.push_back(arg); + } + } + for (auto t : shared_terms) { + auto tn = m_egraph.find(t); + + if (!tn) + continue; + expr_ref r(t, m); + m_rewriter(r); + if (r == t) + continue; + auto n = m_egraph.find(t); + auto t_root = tn->get_root(); + if (n && n->get_root() == t_root) + continue; + + if (!n) + n = mk_enode(r); + TRACE(euf_completion, tout << "propagate-arith: " << mk_pp(t, m) << " -> " << r << "\n"); + + m_egraph.merge(tn, n, nullptr); + merged = true; + } + visited.reset(); + for (auto v : shared_vars) { + if (visited.is_marked(v)) + continue; + visited.mark(v); + vector sol; + expr_ref term(m), guard(m); + sol.push_back({ v, term, guard }); + m_side_condition_solver->solve_for(sol); + for (auto [v, t, g] : sol) { + if (!t) + continue; + visited.mark(v); + auto a = mk_enode(v); + auto b = mk_enode(t); + if (a->get_root() == b->get_root()) + continue; + TRACE(euf_completion, tout << "propagate-arith: " << m_egraph.bpp(a) << " -> " << m_egraph.bpp(b) << "\n"); + IF_VERBOSE(1, verbose_stream() << "propagate-arith: " << m_egraph.bpp(a) << " -> " << m_egraph.bpp(b) << "\n"); + m_egraph.merge(a, b, nullptr); // TODO guard justifies reason. + merged = true; + } + } + if (merged) { + m_egraph.propagate(); + m_should_propagate = true; + } + } + void completion::propagate_closures() { for (auto [q, clos] : m_closures) { expr* body = clos.second; diff --git a/src/ast/simplifiers/euf_completion.h b/src/ast/simplifiers/euf_completion.h index 8d1a936c7..02366ee7d 100644 --- a/src/ast/simplifiers/euf_completion.h +++ b/src/ast/simplifiers/euf_completion.h @@ -187,6 +187,7 @@ namespace euf { expr_ref get_canonical(quantifier* q, proof_ref& pr, expr_dependency_ref& d); obj_map, expr*>> m_closures; + void propagate_arithmetic(); expr_dependency* explain_eq(enode* a, enode* b); proof_ref prove_eq(enode* a, enode* b); proof_ref prove_conflict();