3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-22 16:45:31 +00:00

use abstract datatype for synth objectives

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
This commit is contained in:
Nikolaj Bjorner 2023-08-11 13:52:41 -07:00
parent 75894a10c1
commit e2e377cfd7
2 changed files with 39 additions and 32 deletions

View file

@ -36,7 +36,7 @@ namespace synth {
bool solver::contains_uncomputable(expr* e) {
auto is_output = [&](expr* e) {
return any_of(m_synth, [&](app* a) { return synth_output(a) == e; });
return any_of(m_synth, [&](synth_objective const& a) { return a.output() == e; });
};
return any_of(subterms::all(expr_ref(e, m)), [&](expr* a) { return (is_app(a) && m_uncomputable.contains(to_app(a)->get_decl())) || is_output(a); });
}
@ -51,11 +51,11 @@ namespace synth {
}
}
void solver::add_synth_objective(app* e) {
void solver::add_synth_objective(synth_objective const& e) {
ctx.push_vec(m_synth, e);
for (unsigned i = 1; i < e->get_num_args(); ++i) {
m_is_computable.reserve(e->get_arg(i)->get_id() + 1);
ctx.push(set_bitvector_trail(m_is_computable, e->get_arg(i)->get_id())); // TODO use enode roots instead and test if they are already set.
for (auto* arg : e) {
m_is_computable.reserve(arg->get_id() + 1);
ctx.push(set_bitvector_trail(m_is_computable, arg->get_id())); // TODO use enode roots instead and test if they are already set.
}
}
@ -88,7 +88,7 @@ namespace synth {
app* a = to_app(e);
expr* arg = nullptr;
if (util.is_synthesiz3(e))
add_synth_objective(a);
add_synth_objective(synth_objective(a));
if (util.is_grammar(e))
add_uncomputable(a);
if (util.is_specification(e, arg))
@ -97,8 +97,10 @@ namespace synth {
sat::check_result solver::check() {
// TODO: need to know if there are quantifiers to instantiate
if (m_solved.size() < m_synth.size())
if (m_solved.size() < m_synth.size()) {
IF_VERBOSE(2, ctx.display(verbose_stream()));
return sat::check_result::CR_DONE;
}
if (!compute_solutions())
return sat::check_result::CR_GIVEUP;
return sat::check_result::CR_CONTINUE;
@ -106,8 +108,8 @@ namespace synth {
// display current state (eg. current set of realizers)
std::ostream& solver::display(std::ostream& out) const {
for (auto * e : m_synth)
out << "synth objective " << mk_pp(e, m) << "\n";
for (auto const& e : m_synth)
out << "synth objective " << mk_pp(e.output(), m) << "\n";
return out;
}
@ -156,8 +158,8 @@ namespace synth {
if (m_is_solved)
return;
for (app* e : m_synth) {
euf::enode* n = expr2enode(synth_output(e));
for (auto const& e : m_synth) {
euf::enode* n = expr2enode(e.output());
if (is_computable(n) && !m_solved.contains(e))
ctx.push_vec(m_solved, e);
}
@ -200,9 +202,8 @@ namespace synth {
heap.insert(id);
};
for (auto* e : m_synth) {
for (unsigned i = 1; i < e->get_num_args(); ++i) {
expr* arg = e->get_arg(i);
for (auto const& e : m_synth) {
for (expr* arg : e) {
auto* narg = expr2enode(arg);
insert_repr(narg, arg);
}
@ -215,8 +216,6 @@ namespace synth {
while (!heap.empty()) {
auto* nn = nodes[heap.erase_min()];
for (auto* p : euf::enode_parents(nn)) {
if (has_rep(p))
continue;
if (is_uncomputable(p->get_decl()))
continue;
if (!all_of(euf::enode_args(p), [&](auto* ch) { return has_rep(ch); }))
@ -238,38 +237,36 @@ namespace synth {
return repr;
}
expr_ref solver::compute_solution(expr_ref_vector const& repr, app* e) {
auto* n = expr2enode(synth_output(e));
expr_ref solver::compute_solution(expr_ref_vector const& repr, synth_objective const& e) {
auto* n = expr2enode(e.output());
return expr_ref(repr.get(n->get_root_id(), nullptr), m);
}
expr_ref solver::compute_condition(expr_ref_vector const& repr) {
expr_ref result(m.mk_and(m_spec), m);
expr_safe_replace replace(m);
for (auto* e : m_synth)
replace.insert(synth_output(e), compute_solution(repr, e));
for (auto const& e : m_synth)
replace.insert(e.output(), compute_solution(repr, e));
replace(result);
th_rewriter rw(m);
rw(result);
return result;
}
sat::literal solver::synthesize(expr_ref_vector const& repr, app* e) {
if (e->get_num_args() == 0)
return sat::null_literal;
expr_ref sol = compute_solution(repr, e);
sat::literal solver::synthesize(expr_ref_vector const& repr, synth_objective const& synth_objective) {
expr_ref sol = compute_solution(repr, synth_objective);
if (!sol)
return sat::null_literal;
IF_VERBOSE(0, verbose_stream() << sol << "\n");
return eq_internalize(synth_output(e), sol);
return eq_internalize(synth_objective.output(), sol);
}
bool solver::compute_solutions() {
sat::literal_vector clause;
auto repr = compute_rep();
for (app* e : m_synth) {
for (synth_objective const& e : m_synth) {
auto lit = synthesize(repr, e);
if (lit == sat::null_literal)
return false;
@ -277,6 +274,7 @@ namespace synth {
}
add_clause(clause);
expr_ref cond = compute_condition(repr);
add_unit(~mk_literal(cond));
IF_VERBOSE(0, verbose_stream() << "if " << cond << "\n");
return true;
}

View file

@ -39,23 +39,32 @@ namespace synth {
euf::th_solver* clone(euf::solver& ctx) override;
private:
sat::literal synthesize(expr_ref_vector const& repr, app* e);
class synth_objective {
app* obj;
public:
synth_objective(app* obj): obj(obj) { VERIFY(obj->get_num_args() > 0); }
expr* output() const { return obj->get_arg(0); }
expr* const* begin() const { return obj->get_args() + 1; }
expr* const* end() const { return obj->get_args() + obj->get_num_args(); }
bool operator==(synth_objective const& o) const { return o.obj == obj; }
};
sat::literal synthesize(expr_ref_vector const& repr, synth_objective const& synth_objective);
void add_uncomputable(app* e);
void add_synth_objective(app* e);
void add_synth_objective(synth_objective const& e);
void add_specification(app* e, expr* arg);
bool contains_uncomputable(expr* e);
void on_merge_eh(euf::enode* root, euf::enode* other);
expr_ref compute_solution(expr_ref_vector const& repr, app* synth_objective);
expr* synth_output(expr* e) const { return to_app(e)->get_arg(0); }
expr_ref compute_solution(expr_ref_vector const& repr, synth_objective const& synth_objective);
expr_ref compute_condition(expr_ref_vector const& repr);
bool compute_solutions();
expr_ref_vector compute_rep();
bool_vector m_is_computable;
bool m_is_solved = false;
ptr_vector<app> m_solved;
svector<synth_objective> m_solved;
ptr_vector<app> m_synth;
svector<synth_objective> m_synth;
obj_hashtable<func_decl> m_uncomputable;
ptr_vector<expr> m_spec;