From e2e377cfd728cde963a19d4605e2f61da65e5855 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Fri, 11 Aug 2023 13:52:41 -0700 Subject: [PATCH] use abstract datatype for synth objectives Signed-off-by: Nikolaj Bjorner --- src/sat/smt/synth_solver.cpp | 50 +++++++++++++++++------------------- src/sat/smt/synth_solver.h | 21 ++++++++++----- 2 files changed, 39 insertions(+), 32 deletions(-) diff --git a/src/sat/smt/synth_solver.cpp b/src/sat/smt/synth_solver.cpp index 7d931fa4b..a7ee03ba7 100644 --- a/src/sat/smt/synth_solver.cpp +++ b/src/sat/smt/synth_solver.cpp @@ -36,7 +36,7 @@ namespace synth { bool solver::contains_uncomputable(expr* e) { auto is_output = [&](expr* e) { - return any_of(m_synth, [&](app* a) { return synth_output(a) == e; }); + return any_of(m_synth, [&](synth_objective const& a) { return a.output() == e; }); }; return any_of(subterms::all(expr_ref(e, m)), [&](expr* a) { return (is_app(a) && m_uncomputable.contains(to_app(a)->get_decl())) || is_output(a); }); } @@ -51,11 +51,11 @@ namespace synth { } } - void solver::add_synth_objective(app* e) { + void solver::add_synth_objective(synth_objective const& e) { ctx.push_vec(m_synth, e); - for (unsigned i = 1; i < e->get_num_args(); ++i) { - m_is_computable.reserve(e->get_arg(i)->get_id() + 1); - ctx.push(set_bitvector_trail(m_is_computable, e->get_arg(i)->get_id())); // TODO use enode roots instead and test if they are already set. + for (auto* arg : e) { + m_is_computable.reserve(arg->get_id() + 1); + ctx.push(set_bitvector_trail(m_is_computable, arg->get_id())); // TODO use enode roots instead and test if they are already set. } } @@ -88,7 +88,7 @@ namespace synth { app* a = to_app(e); expr* arg = nullptr; if (util.is_synthesiz3(e)) - add_synth_objective(a); + add_synth_objective(synth_objective(a)); if (util.is_grammar(e)) add_uncomputable(a); if (util.is_specification(e, arg)) @@ -97,8 +97,10 @@ namespace synth { sat::check_result solver::check() { // TODO: need to know if there are quantifiers to instantiate - if (m_solved.size() < m_synth.size()) + if (m_solved.size() < m_synth.size()) { + IF_VERBOSE(2, ctx.display(verbose_stream())); return sat::check_result::CR_DONE; + } if (!compute_solutions()) return sat::check_result::CR_GIVEUP; return sat::check_result::CR_CONTINUE; @@ -106,8 +108,8 @@ namespace synth { // display current state (eg. current set of realizers) std::ostream& solver::display(std::ostream& out) const { - for (auto * e : m_synth) - out << "synth objective " << mk_pp(e, m) << "\n"; + for (auto const& e : m_synth) + out << "synth objective " << mk_pp(e.output(), m) << "\n"; return out; } @@ -156,8 +158,8 @@ namespace synth { if (m_is_solved) return; - for (app* e : m_synth) { - euf::enode* n = expr2enode(synth_output(e)); + for (auto const& e : m_synth) { + euf::enode* n = expr2enode(e.output()); if (is_computable(n) && !m_solved.contains(e)) ctx.push_vec(m_solved, e); } @@ -200,9 +202,8 @@ namespace synth { heap.insert(id); }; - for (auto* e : m_synth) { - for (unsigned i = 1; i < e->get_num_args(); ++i) { - expr* arg = e->get_arg(i); + for (auto const& e : m_synth) { + for (expr* arg : e) { auto* narg = expr2enode(arg); insert_repr(narg, arg); } @@ -215,8 +216,6 @@ namespace synth { while (!heap.empty()) { auto* nn = nodes[heap.erase_min()]; for (auto* p : euf::enode_parents(nn)) { - if (has_rep(p)) - continue; if (is_uncomputable(p->get_decl())) continue; if (!all_of(euf::enode_args(p), [&](auto* ch) { return has_rep(ch); })) @@ -238,38 +237,36 @@ namespace synth { return repr; } - expr_ref solver::compute_solution(expr_ref_vector const& repr, app* e) { - auto* n = expr2enode(synth_output(e)); + expr_ref solver::compute_solution(expr_ref_vector const& repr, synth_objective const& e) { + auto* n = expr2enode(e.output()); return expr_ref(repr.get(n->get_root_id(), nullptr), m); } expr_ref solver::compute_condition(expr_ref_vector const& repr) { expr_ref result(m.mk_and(m_spec), m); expr_safe_replace replace(m); - for (auto* e : m_synth) - replace.insert(synth_output(e), compute_solution(repr, e)); + for (auto const& e : m_synth) + replace.insert(e.output(), compute_solution(repr, e)); replace(result); th_rewriter rw(m); rw(result); return result; } - sat::literal solver::synthesize(expr_ref_vector const& repr, app* e) { - if (e->get_num_args() == 0) - return sat::null_literal; - expr_ref sol = compute_solution(repr, e); + sat::literal solver::synthesize(expr_ref_vector const& repr, synth_objective const& synth_objective) { + expr_ref sol = compute_solution(repr, synth_objective); if (!sol) return sat::null_literal; IF_VERBOSE(0, verbose_stream() << sol << "\n"); - return eq_internalize(synth_output(e), sol); + return eq_internalize(synth_objective.output(), sol); } bool solver::compute_solutions() { sat::literal_vector clause; auto repr = compute_rep(); - for (app* e : m_synth) { + for (synth_objective const& e : m_synth) { auto lit = synthesize(repr, e); if (lit == sat::null_literal) return false; @@ -277,6 +274,7 @@ namespace synth { } add_clause(clause); expr_ref cond = compute_condition(repr); + add_unit(~mk_literal(cond)); IF_VERBOSE(0, verbose_stream() << "if " << cond << "\n"); return true; } diff --git a/src/sat/smt/synth_solver.h b/src/sat/smt/synth_solver.h index e881d70d0..a7d8dbf98 100644 --- a/src/sat/smt/synth_solver.h +++ b/src/sat/smt/synth_solver.h @@ -39,23 +39,32 @@ namespace synth { euf::th_solver* clone(euf::solver& ctx) override; private: - sat::literal synthesize(expr_ref_vector const& repr, app* e); + class synth_objective { + app* obj; + public: + synth_objective(app* obj): obj(obj) { VERIFY(obj->get_num_args() > 0); } + expr* output() const { return obj->get_arg(0); } + expr* const* begin() const { return obj->get_args() + 1; } + expr* const* end() const { return obj->get_args() + obj->get_num_args(); } + bool operator==(synth_objective const& o) const { return o.obj == obj; } + }; + + sat::literal synthesize(expr_ref_vector const& repr, synth_objective const& synth_objective); void add_uncomputable(app* e); - void add_synth_objective(app* e); + void add_synth_objective(synth_objective const& e); void add_specification(app* e, expr* arg); bool contains_uncomputable(expr* e); void on_merge_eh(euf::enode* root, euf::enode* other); - expr_ref compute_solution(expr_ref_vector const& repr, app* synth_objective); - expr* synth_output(expr* e) const { return to_app(e)->get_arg(0); } + expr_ref compute_solution(expr_ref_vector const& repr, synth_objective const& synth_objective); expr_ref compute_condition(expr_ref_vector const& repr); bool compute_solutions(); expr_ref_vector compute_rep(); bool_vector m_is_computable; bool m_is_solved = false; - ptr_vector m_solved; + svector m_solved; - ptr_vector m_synth; + svector m_synth; obj_hashtable m_uncomputable; ptr_vector m_spec;