diff --git a/src/sat/smt/synth_solver.cpp b/src/sat/smt/synth_solver.cpp index 20de66fb3..37677092c 100644 --- a/src/sat/smt/synth_solver.cpp +++ b/src/sat/smt/synth_solver.cpp @@ -12,6 +12,7 @@ Author: --*/ +#include "util/heap.h" #include "ast/for_each_expr.h" #include "ast/synth_decl_plugin.h" #include "sat/smt/synth_solver.h" @@ -208,16 +209,39 @@ namespace synth { auto set_rep = [&](euf::enode* n, expr* e) { repr.setx(n->get_root_id(), e); }; auto is_uncomputable = [&](func_decl* f) { return m_uncomputable.contains(f); }; - euf::enode_vector todo; + + struct rep_lt { + expr_ref_vector const& repr; + rep_lt(expr_ref_vector& repr): repr(repr) {} + bool operator()(int v1, int v2) const { + return get_depth(repr.get(v1)) < get_depth(repr.get(v2)); + }; + }; + rep_lt lt(repr); + heap heap(1000, lt); + euf::enode_vector nodes; + auto insert_repr = [&](euf::enode* n, expr* r) { + unsigned id = n->get_root_id(); + set_rep(n, r); + nodes.reserve(id + 1); + nodes[id] = n->get_root(); + heap.reserve(id + 1); + heap.insert(id); + }; + for (unsigned i = 1; i < e->get_num_args(); ++i) { expr * arg = e->get_arg(i); auto * narg = expr2enode(arg); - todo.push_back(narg); - set_rep(narg, arg); + insert_repr(narg, arg); } - for (unsigned i = 0; i < todo.size() && !has_rep(n); ++i) { - auto * nn = todo[i]; + // make sure we only insert non-input symbols. + for (auto* n : ctx.get_egraph().nodes()) { + if (n->num_args() == 0 && !contains_uncomputable(n->get_expr()) && !has_rep(n)) + insert_repr(n, n->get_expr()); + } + while (!heap.empty()) { + auto * nn = nodes[heap.erase_min()]; for (auto * p : euf::enode_parents(nn)) { if (has_rep(p)) continue; @@ -225,12 +249,18 @@ namespace synth { continue; if (!all_of(euf::enode_args(p), [&](auto * ch) { return has_rep(ch); })) continue; + expr* r = get_rep(p); + if (r) { + unsigned depth = get_depth(r); + if (any_of(euf::enode_args(p), [&](auto* ch) { return get_depth(get_rep(ch)) >= depth; })) + continue; + heap.erase(p->get_root_id()); + } ptr_buffer args; for (auto * ch : euf::enode_args(p)) args.push_back(get_rep(ch)); - app * papp = m.mk_app(p->get_decl(), args); - set_rep(p, papp); - todo.push_back(p); + expr_ref papp(m.mk_app(p->get_decl(), args), m); + insert_repr(p, papp); } } return expr_ref(get_rep(n), m);