From 69a9701b5c88b5ec18e30838432523cf4ec381c9 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Fri, 11 Aug 2023 15:47:05 -0700 Subject: [PATCH] compute normalize sketch Signed-off-by: Nikolaj Bjorner --- src/sat/smt/synth_solver.cpp | 91 ++++++++++++++++++++++++++---------- src/sat/smt/synth_solver.h | 17 +++++-- 2 files changed, 80 insertions(+), 28 deletions(-) diff --git a/src/sat/smt/synth_solver.cpp b/src/sat/smt/synth_solver.cpp index a7ee03ba7..3760bdd17 100644 --- a/src/sat/smt/synth_solver.cpp +++ b/src/sat/smt/synth_solver.cpp @@ -23,7 +23,7 @@ Author: namespace synth { solver::solver(euf::solver& ctx): - th_euf_solver(ctx, symbol("synth"), ctx.get_manager().mk_family_id("synth")) { + th_euf_solver(ctx, symbol("synth"), ctx.get_manager().mk_family_id("synth")), m_rep(m) { std::function _on_merge = [&](euf::enode* root, euf::enode* other) { on_merge_eh(root, other); @@ -176,24 +176,21 @@ namespace synth { return compute_solutions(); } - expr_ref_vector solver::compute_rep() { - expr_ref_vector repr(m); - auto get_rep = [&](euf::enode* n) { return repr.get(n->get_root_id(), nullptr); }; - auto has_rep = [&](euf::enode* n) { return !!get_rep(n); }; - auto set_rep = [&](euf::enode* n, expr* e) { repr.setx(n->get_root_id(), e); }; + void solver::compute_rep() { + m_rep.reset(); auto is_uncomputable = [&](func_decl* f) { return m_uncomputable.contains(f); }; struct rep_lt { - expr_ref_vector const& repr; - rep_lt(expr_ref_vector& repr) : repr(repr) {} + expr_ref_vector const& m_rep; + rep_lt(expr_ref_vector& m_rep) : m_rep(m_rep) {} bool operator()(int v1, int v2) const { - return get_depth(repr.get(v1)) < get_depth(repr.get(v2)); + return get_depth(m_rep.get(v1)) < get_depth(m_rep.get(v2)); }; }; - rep_lt lt(repr); + rep_lt lt(m_rep); heap heap(1000, lt); euf::enode_vector nodes; - auto insert_repr = [&](euf::enode* n, expr* r) { + auto insert_m_rep = [&](euf::enode* n, expr* r) { unsigned id = n->get_root_id(); set_rep(n, r); nodes.reserve(id + 1); @@ -205,13 +202,13 @@ namespace synth { for (auto const& e : m_synth) { for (expr* arg : e) { auto* narg = expr2enode(arg); - insert_repr(narg, arg); + insert_m_rep(narg, arg); } } // make sure we only insert non-input symbols. for (auto* n : ctx.get_egraph().nodes()) { if (n->num_args() == 0 && !contains_uncomputable(n->get_expr()) && !has_rep(n)) - insert_repr(n, n->get_expr()); + insert_m_rep(n, n->get_expr()); } while (!heap.empty()) { auto* nn = nodes[heap.erase_min()]; @@ -231,30 +228,76 @@ namespace synth { for (auto* ch : euf::enode_args(p)) args.push_back(get_rep(ch)); expr_ref papp(m.mk_app(p->get_decl(), args), m); - insert_repr(p, papp); + insert_m_rep(p, papp); } } - return repr; } - expr_ref solver::compute_solution(expr_ref_vector const& repr, synth_objective const& e) { + expr_ref solver::compute_solution(synth_objective const& e) { auto* n = expr2enode(e.output()); - return expr_ref(repr.get(n->get_root_id(), nullptr), m); + return expr_ref(m_rep.get(n->get_root_id(), nullptr), m); } - expr_ref solver::compute_condition(expr_ref_vector const& repr) { + expr_ref solver::compute_condition() { expr_ref result(m.mk_and(m_spec), m); expr_safe_replace replace(m); for (auto const& e : m_synth) - replace.insert(e.output(), compute_solution(repr, e)); + replace.insert(e.output(), compute_solution(e)); replace(result); th_rewriter rw(m); rw(result); return result; } +#if 0 - sat::literal solver::synthesize(expr_ref_vector const& repr, synth_objective const& synth_objective) { - expr_ref sol = compute_solution(repr, synth_objective); + expr_ref solver::simplify_condition(expr* e) { + ptr_vector todo; + todo.push_back(e); + while (!todo.empty()) { + expr* a = todo.back(); + if (m_rep.get(a->get_id(), nullptr)) { + todo.pop_back(); + continue; + } + euf::enode* n = expr2enode(a); + if (n && has_rep(n)) { + todo.pop_back(); + continue; + } + if (!is_app(a)) { + m_rep.setx(n->get_id(), a); + todo.pop_back(); + continue; + } + + unsigned orig_size = todo.size(); + for (expr* arg : *to_app(a)) { + if (has_rep(arg)) + args.push_back(get_rep(arg)); + else + todo.push_back(arg); + } + if (todo.size() == orig_size) { + todo.pop_back(); + expr_ref new_a(m.mk_app(to_app(a)->get_decl(), args), m); + n = expr2enode(new_a); + if (has_rep(n)) + m_rep.setx(a->get_id(), get_rep(n)); + else + m_rep.setx(a->get_id(), new_a); + } + } + euf::enode* n = expr2enode(e); + if (n && has_rep(n)) + return expr_ref(get_rep(n), m); + SASSERT(m_rep.get(e->get_id(), nullptr)); + return expr_ref(m_rep.get(e->get_id()), m); + } +#endif + + + sat::literal solver::synthesize(synth_objective const& synth_objective) { + expr_ref sol = compute_solution(synth_objective); if (!sol) return sat::null_literal; @@ -264,16 +307,16 @@ namespace synth { bool solver::compute_solutions() { sat::literal_vector clause; - auto repr = compute_rep(); + compute_rep(); for (synth_objective const& e : m_synth) { - auto lit = synthesize(repr, e); + auto lit = synthesize(e); if (lit == sat::null_literal) return false; clause.push_back(~lit); } add_clause(clause); - expr_ref cond = compute_condition(repr); + expr_ref cond = compute_condition(); add_unit(~mk_literal(cond)); IF_VERBOSE(0, verbose_stream() << "if " << cond << "\n"); return true; diff --git a/src/sat/smt/synth_solver.h b/src/sat/smt/synth_solver.h index a7d8dbf98..0b0c99521 100644 --- a/src/sat/smt/synth_solver.h +++ b/src/sat/smt/synth_solver.h @@ -49,20 +49,29 @@ namespace synth { bool operator==(synth_objective const& o) const { return o.obj == obj; } }; - sat::literal synthesize(expr_ref_vector const& repr, synth_objective const& synth_objective); + + + sat::literal synthesize(synth_objective const& synth_objective); void add_uncomputable(app* e); void add_synth_objective(synth_objective const& e); void add_specification(app* e, expr* arg); bool contains_uncomputable(expr* e); void on_merge_eh(euf::enode* root, euf::enode* other); - expr_ref compute_solution(expr_ref_vector const& repr, synth_objective const& synth_objective); - expr_ref compute_condition(expr_ref_vector const& repr); + expr_ref compute_solution(synth_objective const& synth_objective); + expr_ref compute_condition(); bool compute_solutions(); - expr_ref_vector compute_rep(); + void compute_rep(); + + expr* get_rep(euf::enode* n) { return m_rep.get(n->get_root_id(), nullptr); }; + bool has_rep(euf::enode* n) { return !!get_rep(n); }; + void set_rep(euf::enode* n, expr* e) { m_rep.setx(n->get_root_id(), e); }; + + expr_ref simplify_condition(expr* e); bool_vector m_is_computable; bool m_is_solved = false; svector m_solved; + expr_ref_vector m_rep; svector m_synth; obj_hashtable m_uncomputable;