diff --git a/src/sat/smt/synth_solver.cpp b/src/sat/smt/synth_solver.cpp index dcc7a35b6..fcfcb18f6 100644 --- a/src/sat/smt/synth_solver.cpp +++ b/src/sat/smt/synth_solver.cpp @@ -33,7 +33,7 @@ namespace synth { bool solver::contains_uncomputable(expr* e) { auto is_output = [&](expr* e) { - return any_of(m_synth, [&](app* a) { return a->get_arg(0) == e; }); + return any_of(m_synth, [&](app* a) { return synth_output(a) == e; }); }; return any_of(subterms::all(expr_ref(e, m)), [&](expr* a) { return (is_app(a) && m_uncomputable.contains(to_app(a)->get_decl())) || is_output(a); }); } @@ -48,7 +48,7 @@ namespace synth { return sat::null_literal; IF_VERBOSE(0, verbose_stream() << sol << "\n"); - return eq_internalize(e->get_arg(0), sol); + return eq_internalize(synth_output(e), sol); } // block current model using realizer by E-graph (and arithmetic) @@ -89,10 +89,13 @@ namespace synth { // This assumes that each (assert (constraint (...)) is asserting exactly one app SASSERT((e->get_num_args() == 1) && (is_app(e->get_arg(0)))); app* arg = to_app(e->get_arg(0)); - internalize(arg); + sat::literal lit = ctx.mk_literal(arg); + sat::bool_var bv = ctx.get_si().add_bool_var(e); + sat::literal lit_e(bv, false); + ctx.attach_lit(lit_e, e); + add_clause(~lit_e, lit); m_spec.insert(arg); ctx.push(insert_obj_trail(m_spec, arg)); - // TODO: assert arg <=> e } // recognize synthesis objectives here. @@ -169,33 +172,39 @@ namespace synth { } } - if (m_blockers.size() > m_blockers_qhead) + if (m_is_solved) return; for (app* e : m_synth) { - // TODO: actually wait to call compute_solution until unit propagation! - euf::enode* n = expr2enode(e->get_arg(0)); - if (is_computable(n)) { - expr_ref sol = compute_solution(e); - IF_VERBOSE(0, verbose_stream() << "solution " << sol << "\n"); - ctx.push_vec(m_blockers, ~eq_internalize(sol, n->get_expr())); - } + euf::enode* n = expr2enode(synth_output(e)); + if (is_computable(n) && !m_solved.contains(e)) + ctx.push_vec(m_solved, e); } } bool solver::unit_propagate() { - if (m_blockers_qhead >= m_blockers.size()) + if (m_is_solved) return false; - IF_VERBOSE(2, verbose_stream() << "propagate " << m_blockers_qhead << " " << m_blockers << "\n"); - ctx.push(value_trail(m_blockers_qhead)); - while (m_blockers_qhead++ < m_blockers.size()) - add_unit(m_blockers[m_blockers_qhead-1]); + if (m_solved.size() < m_synth.size()) + return false; + IF_VERBOSE(2, verbose_stream() << "propagate\n"); + ctx.push(value_trail(m_is_solved)); + m_is_solved = true; + + sat::literal_vector clause; + for (app* e : m_synth) { + auto lit = synthesize(e); + if (lit == sat::null_literal) + return false; + clause.push_back(~lit); + } + add_clause(clause); return true; } expr_ref solver::compute_solution(app* e) { - auto * n = expr2enode(e->get_arg(0)); - expr_ref_vector repr(m); + auto * n = expr2enode(synth_output(e)); + expr_ref_vector repr(m); auto get_rep = [&](euf::enode* n) { return repr.get(n->get_root_id(), nullptr); }; auto has_rep = [&](euf::enode* n) { return !!get_rep(n); }; auto set_rep = [&](euf::enode* n, expr* e) { repr.setx(n->get_root_id(), e); }; diff --git a/src/sat/smt/synth_solver.h b/src/sat/smt/synth_solver.h index 79404d8b4..aeca4ae62 100644 --- a/src/sat/smt/synth_solver.h +++ b/src/sat/smt/synth_solver.h @@ -48,9 +48,12 @@ namespace synth { void on_merge_eh(euf::enode* root, euf::enode* other); expr_ref compute_solution(app* synth_objective); + + expr* synth_output(expr* e) const { return to_app(e)->get_arg(0); } + bool_vector m_is_computable; - unsigned m_blockers_qhead = 0; - sat::literal_vector m_blockers; + bool m_is_solved = false; + ptr_vector m_solved; ptr_vector m_synth; typedef obj_hashtable func_decl_set;