From 75894a10c116019b839063c534475176688ef76f Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Thu, 10 Aug 2023 18:32:47 -0700 Subject: [PATCH] adding conditions and smallest depth expressions Signed-off-by: Nikolaj Bjorner --- src/sat/smt/synth_solver.cpp | 134 ++++++++++++++++++++--------------- src/sat/smt/synth_solver.h | 18 ++--- 2 files changed, 84 insertions(+), 68 deletions(-) diff --git a/src/sat/smt/synth_solver.cpp b/src/sat/smt/synth_solver.cpp index 3224343af..7d931fa4b 100644 --- a/src/sat/smt/synth_solver.cpp +++ b/src/sat/smt/synth_solver.cpp @@ -15,6 +15,8 @@ Author: #include "util/heap.h" #include "ast/for_each_expr.h" #include "ast/synth_decl_plugin.h" +#include "ast/rewriter/expr_safe_replace.h" +#include "ast/rewriter/th_rewriter.h" #include "sat/smt/synth_solver.h" #include "sat/smt/euf_solver.h" @@ -39,35 +41,6 @@ namespace synth { return any_of(subterms::all(expr_ref(e, m)), [&](expr* a) { return (is_app(a) && m_uncomputable.contains(to_app(a)->get_decl())) || is_output(a); }); } - sat::literal solver::synthesize(app* e) { - - if (e->get_num_args() == 0) - return sat::null_literal; - - expr_ref sol = compute_solution(e); - if (!sol) - return sat::null_literal; - - IF_VERBOSE(0, verbose_stream() << sol << "\n"); - return eq_internalize(synth_output(e), sol); - } - - // block current model using realizer by E-graph (and arithmetic) - // - sat::check_result solver::check() { - sat::literal_vector clause; - for (app* e : m_synth) { - auto lit = synthesize(e); - if (lit == sat::null_literal) - return sat::check_result::CR_GIVEUP; - clause.push_back(~lit); - } - if (clause.empty()) - return sat::check_result::CR_DONE; - add_clause(clause); - return sat::check_result::CR_CONTINUE; - } - void solver::add_uncomputable(app* e) { for (expr* arg : *e) { if (is_app(arg)) { @@ -91,35 +64,44 @@ namespace synth { sat::literal lit = ctx.mk_literal(arg); sat::bool_var bv = ctx.get_si().add_bool_var(e); sat::literal lit_e(bv, false); - ctx.attach_lit(lit_e, e); + ctx.attach_lit(lit_e, e); add_clause(~lit_e, lit); - ctx.push_vec(m_spec, arg); + ctx.push_vec(m_spec, arg); } // recognize synthesis objectives here. sat::literal solver::internalize(expr* e, bool sign, bool root) { internalize(e); - sat::literal lit = ctx.expr2literal(e); - if (sign) - lit.neg(); + sat::literal lit = ctx.expr2literal(e); + if (sign) + lit.neg(); return lit; } // recognize synthesis objectives here and above void solver::internalize(expr* e) { SASSERT(is_app(e)); - sat::bool_var bv = ctx.get_si().add_bool_var(e); + sat::bool_var bv = ctx.get_si().add_bool_var(e); sat::literal lit(bv, false); - ctx.attach_lit(lit, e); + ctx.attach_lit(lit, e); synth::util util(m); - app* a = to_app(e); + app* a = to_app(e); expr* arg = nullptr; if (util.is_synthesiz3(e)) add_synth_objective(a); if (util.is_grammar(e)) - add_uncomputable(a); - if (util.is_specification(e, arg)) - add_specification(a, arg); + add_uncomputable(a); + if (util.is_specification(e, arg)) + add_specification(a, arg); + } + + sat::check_result solver::check() { + // TODO: need to know if there are quantifiers to instantiate + if (m_solved.size() < m_synth.size()) + return sat::check_result::CR_DONE; + if (!compute_solutions()) + return sat::check_result::CR_GIVEUP; + return sat::check_result::CR_CONTINUE; } // display current state (eg. current set of realizers) @@ -189,27 +171,16 @@ namespace synth { IF_VERBOSE(2, verbose_stream() << "propagate\n"); ctx.push(value_trail(m_is_solved)); m_is_solved = true; - - sat::literal_vector clause; - for (app* e : m_synth) { - auto lit = synthesize(e); - if (lit == sat::null_literal) - return false; - clause.push_back(~lit); - } - add_clause(clause); - return true; + return compute_solutions(); } - expr_ref solver::compute_solution(app* e) { - auto* n = expr2enode(synth_output(e)); + expr_ref_vector solver::compute_rep() { expr_ref_vector repr(m); auto get_rep = [&](euf::enode* n) { return repr.get(n->get_root_id(), nullptr); }; auto has_rep = [&](euf::enode* n) { return !!get_rep(n); }; auto set_rep = [&](euf::enode* n, expr* e) { repr.setx(n->get_root_id(), e); }; auto is_uncomputable = [&](func_decl* f) { return m_uncomputable.contains(f); }; - struct rep_lt { expr_ref_vector const& repr; rep_lt(expr_ref_vector& repr) : repr(repr) {} @@ -227,12 +198,14 @@ namespace synth { nodes[id] = n->get_root(); heap.reserve(id + 1); heap.insert(id); - }; + }; - for (unsigned i = 1; i < e->get_num_args(); ++i) { - expr* arg = e->get_arg(i); - auto* narg = expr2enode(arg); - insert_repr(narg, arg); + for (auto* e : m_synth) { + for (unsigned i = 1; i < e->get_num_args(); ++i) { + expr* arg = e->get_arg(i); + auto* narg = expr2enode(arg); + insert_repr(narg, arg); + } } // make sure we only insert non-input symbols. for (auto* n : ctx.get_egraph().nodes()) { @@ -262,7 +235,50 @@ namespace synth { insert_repr(p, papp); } } - return expr_ref(get_rep(n), m); + return repr; + } + + expr_ref solver::compute_solution(expr_ref_vector const& repr, app* e) { + auto* n = expr2enode(synth_output(e)); + return expr_ref(repr.get(n->get_root_id(), nullptr), m); + } + + expr_ref solver::compute_condition(expr_ref_vector const& repr) { + expr_ref result(m.mk_and(m_spec), m); + expr_safe_replace replace(m); + for (auto* e : m_synth) + replace.insert(synth_output(e), compute_solution(repr, e)); + replace(result); + th_rewriter rw(m); + rw(result); + return result; + } + + sat::literal solver::synthesize(expr_ref_vector const& repr, app* e) { + if (e->get_num_args() == 0) + return sat::null_literal; + expr_ref sol = compute_solution(repr, e); + if (!sol) + return sat::null_literal; + + IF_VERBOSE(0, verbose_stream() << sol << "\n"); + return eq_internalize(synth_output(e), sol); + } + + bool solver::compute_solutions() { + sat::literal_vector clause; + auto repr = compute_rep(); + + for (app* e : m_synth) { + auto lit = synthesize(repr, e); + if (lit == sat::null_literal) + return false; + clause.push_back(~lit); + } + add_clause(clause); + expr_ref cond = compute_condition(repr); + IF_VERBOSE(0, verbose_stream() << "if " << cond << "\n"); + return true; } } diff --git a/src/sat/smt/synth_solver.h b/src/sat/smt/synth_solver.h index 6c72b599d..e881d70d0 100644 --- a/src/sat/smt/synth_solver.h +++ b/src/sat/smt/synth_solver.h @@ -39,24 +39,24 @@ namespace synth { euf::th_solver* clone(euf::solver& ctx) override; private: - sat::literal synthesize(app* e); - void add_uncomputable(app* e); + sat::literal synthesize(expr_ref_vector const& repr, app* e); + void add_uncomputable(app* e); void add_synth_objective(app* e); - void add_specification(app* e, expr* arg); - bool contains_uncomputable(expr* e); - + void add_specification(app* e, expr* arg); + bool contains_uncomputable(expr* e); void on_merge_eh(euf::enode* root, euf::enode* other); - - expr_ref compute_solution(app* synth_objective); - + expr_ref compute_solution(expr_ref_vector const& repr, app* synth_objective); expr* synth_output(expr* e) const { return to_app(e)->get_arg(0); } + expr_ref compute_condition(expr_ref_vector const& repr); + bool compute_solutions(); + expr_ref_vector compute_rep(); bool_vector m_is_computable; bool m_is_solved = false; ptr_vector m_solved; ptr_vector m_synth; - obj_hashtable m_uncomputable; + obj_hashtable m_uncomputable; ptr_vector m_spec; };