From 6e45057a649ad954ce6165431ee1beff44b479f4 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Mon, 14 Aug 2023 15:36:58 -0700 Subject: [PATCH] synth uninterpreted, take 1 Signed-off-by: Nikolaj Bjorner --- src/sat/smt/synth_solver.cpp | 73 +++++++++++++++++++++++++++++------- src/sat/smt/synth_solver.h | 23 +++++++++--- 2 files changed, 78 insertions(+), 18 deletions(-) diff --git a/src/sat/smt/synth_solver.cpp b/src/sat/smt/synth_solver.cpp index 7d078476f..599ba3b28 100644 --- a/src/sat/smt/synth_solver.cpp +++ b/src/sat/smt/synth_solver.cpp @@ -51,8 +51,9 @@ namespace synth { } } - void solver::add_synth_objective(synth_objective const& e) { - ctx.push_vec(m_synth, e); + void solver::add_synth_objective(synth_objective& e) { + m_synth.push_back(e); + ctx.push(push_back_vector(m_synth)); for (auto* arg : e) { m_is_computable.reserve(arg->get_id() + 1); ctx.push(set_bitvector_trail(m_is_computable, arg->get_id())); // TODO use enode roots instead and test if they are already set. @@ -88,7 +89,7 @@ namespace synth { app* a = to_app(e); expr* arg = nullptr; if (util.is_synthesiz3(e)) - add_synth_objective(synth_objective(a)); + add_synth_objective(synth_objective(m, a)); if (util.is_grammar(e)) add_uncomputable(a); if (util.is_specification(e, arg)) @@ -99,9 +100,20 @@ namespace synth { if (m_synth.empty()) return sat::check_result::CR_DONE; - SASSERT(m_solved.size() < m_synth.size()); + for (auto& s : m_synth) { + if (s.is_solved()) + continue; + if (m.is_uninterp(s.output()->get_sort()) && + synthesize_uninterpreted_sort(s)) + continue; + IF_VERBOSE(2, ctx.display(verbose_stream())); + return sat::check_result::CR_GIVEUP; + } + if (compute_solutions()) + return sat::check_result::CR_DONE; + IF_VERBOSE(2, ctx.display(verbose_stream())); - return sat::check_result::CR_GIVEUP; + return sat::check_result::CR_GIVEUP; } // display current state (eg. current set of realizers) @@ -163,17 +175,19 @@ namespace synth { if (m_is_solved) return; - for (auto const& e : m_synth) { + for (auto& e : m_synth) { euf::enode* n = expr2enode(e.output()); - if (is_computable(n) && !m_solved.contains(e)) - ctx.push_vec(m_solved, e); - } + if (is_computable(n) && !e.is_solved()) { + e.set_solution(nullptr); + ctx.push(synth_objective::unset_solution(e)); + } + } } bool solver::unit_propagate() { if (m_is_solved) return false; - if (m_solved.size() < m_synth.size()) + if (!all_of(m_synth, [&](synth_objective const& s) { return s.is_solved(); })) return false; IF_VERBOSE(2, verbose_stream() << "propagate\n"); ctx.push(value_trail(m_is_solved)); @@ -239,6 +253,8 @@ namespace synth { } expr_ref solver::compute_solution(synth_objective const& e) { + if (e.solution()) + return expr_ref(e.solution(), m); auto* n = expr2enode(e.output()); return expr_ref(m_rep.get(n->get_root_id(), nullptr), m); } @@ -252,9 +268,10 @@ namespace synth { th_rewriter rw(m); rw(result); IF_VERBOSE(2, ctx.display(verbose_stream())); - IF_VERBOSE(0, verbose_stream() << "simplifying: " << result << "\n"); + IF_VERBOSE(3, verbose_stream() << "simplifying: " << result << "\n"); result = simplify_condition(result.get()); - IF_VERBOSE(0, verbose_stream() << result << "\n"); + IF_VERBOSE(3, verbose_stream() << result << "\n"); + rw(result); return result; } @@ -278,6 +295,8 @@ namespace synth { continue; } + func_decl* f = to_app(a)->get_decl(); + ptr_buffer args; for (expr* arg : *to_app(a)) { n = expr2enode(arg); @@ -291,10 +310,16 @@ namespace synth { } if (args.size() == to_app(a)->get_num_args()) { todo.pop_back(); - expr_ref new_a(m.mk_app(to_app(a)->get_decl(), args), m); + expr_ref new_a(m.mk_app(f, args), m); n = expr2enode(new_a); if (n && has_rep(n)) m_rep.setx(a->get_id(), get_rep(n)); + else if (m_uncomputable.contains(f) && n->get_root()->value() != l_undef) { + if (n->get_root()->value() == l_true) + m_rep.setx(a->get_id(), m.mk_true()); + else + m_rep.setx(a->get_id(), m.mk_false()); + } else m_rep.setx(a->get_id(), new_a); } @@ -332,4 +357,26 @@ namespace synth { return true; } + bool solver::synthesize_uninterpreted_sort(synth_objective& obj) { + sort* srt = obj.output()->get_sort(); + euf::enode* r = expr2enode(obj.output()); + VERIFY(r); + if (!r) + return false; + for (auto* n : ctx.get_egraph().nodes()) { + if (n->get_sort() != srt || contains_uncomputable(n->get_expr())) + continue; + expr_ref eq(m.mk_eq(r->get_root()->get_expr(), n->get_root()->get_expr()), m); + euf::enode* eq_n = expr2enode(eq); + if (eq_n && eq_n->bool_var() != sat::null_bool_var && + s().value(eq_n->bool_var()) == l_false) + continue; + obj.set_solution(n->get_expr()); + ctx.push(synth_objective::unset_solution(obj)); + return true; + } + verbose_stream() << "synth-uninterp failed\n"; + return false; + } + } diff --git a/src/sat/smt/synth_solver.h b/src/sat/smt/synth_solver.h index ab2cf3208..53a834b84 100644 --- a/src/sat/smt/synth_solver.h +++ b/src/sat/smt/synth_solver.h @@ -41,19 +41,31 @@ namespace synth { private: class synth_objective { app* obj; + expr_ref m_solution; + bool m_is_solved = false; public: - synth_objective(app* obj): obj(obj) { VERIFY(obj->get_num_args() > 0); } + synth_objective(ast_manager& m, app* obj): obj(obj), m_solution(m) { VERIFY(obj->get_num_args() > 0); } expr* output() const { return obj->get_arg(0); } expr* const* begin() const { return obj->get_args() + 1; } expr* const* end() const { return obj->get_args() + obj->get_num_args(); } bool operator==(synth_objective const& o) const { return o.obj == obj; } + expr* solution() const { return m_solution; } + void set_solution(expr* s) { m_solution = s; m_is_solved = true; } + void clear_solution() { m_solution = nullptr; m_is_solved = false;} + bool is_solved() const { return m_is_solved; } + + struct unset_solution : public trail { + synth_objective& e; + unset_solution(synth_objective& e): e(e) {} + void undo() override { e.clear_solution(); } + }; }; sat::literal synthesize(synth_objective const& synth_objective); void add_uncomputable(app* e); - void add_synth_objective(synth_objective const& e); + void add_synth_objective(synth_objective & e); void add_specification(app* e, expr* arg); bool contains_uncomputable(expr* e); void on_merge_eh(euf::enode* root, euf::enode* other); @@ -62,6 +74,8 @@ namespace synth { bool compute_solutions(); void compute_rep(); + bool synthesize_uninterpreted_sort(synth_objective& obj); + expr* get_rep(euf::enode* n) { return m_rep.get(n->get_root_id(), nullptr); }; bool has_rep(euf::enode* n) { return !!get_rep(n); }; void set_rep(euf::enode* n, expr* e) { m_rep.setx(n->get_root_id(), e); }; @@ -70,10 +84,9 @@ namespace synth { bool_vector m_is_computable; bool m_is_solved = false; - svector m_solved; expr_ref_vector m_rep; - - svector m_synth; + + vector m_synth; obj_hashtable m_uncomputable; ptr_vector m_spec;