3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-22 16:45:31 +00:00

initial stab at mbi

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
This commit is contained in:
Nikolaj Bjorner 2023-08-15 11:12:55 -07:00
parent 4d48dba1e3
commit 1bd73d4635
4 changed files with 292 additions and 227 deletions

View file

@ -164,80 +164,25 @@ namespace qe {
}
// -------------------------------
// uflia_mbi
// uflia_project
struct uflia_mbi::is_atom_proc {
ast_manager& m;
expr_ref_vector& m_atoms;
obj_hashtable<expr>& m_atom_set;
/**
* \brief Order arithmetical variables:
* sort arithmetical terms, such that deepest terms are first.
*/
void uflia_project::order_avars(app_ref_vector& avars) {
is_atom_proc(expr_ref_vector& atoms, obj_hashtable<expr>& atom_set):
m(atoms.m()), m_atoms(atoms), m_atom_set(atom_set) {}
void operator()(app* a) {
if (m_atom_set.contains(a)) {
// continue
}
else if (m.is_eq(a) && !m.is_iff(a)) {
m_atoms.push_back(a);
m_atom_set.insert(a);
}
else if (m.is_bool(a) && a->get_family_id() != m.get_basic_family_id()) {
m_atoms.push_back(a);
m_atom_set.insert(a);
}
}
void operator()(expr*) {}
};
uflia_mbi::uflia_mbi(solver* s, solver* sNot):
mbi_plugin(s->get_manager()),
m_atoms(m),
m_fmls(m),
m_solver(s),
m_dual_solver(sNot) {
params_ref p;
p.set_bool("core.minimize", true);
m_solver->updt_params(p);
m_dual_solver->updt_params(p);
m_solver->get_assertions(m_fmls);
collect_atoms(m_fmls);
// sort avars based on depth
std::function<bool(app*, app*)> compare_depth =
[](app* x, app* y) {
return
(x->get_depth() > y->get_depth()) ||
(x->get_depth() == y->get_depth() && x->get_id() > y->get_id());
};
std::sort(avars.data(), avars.data() + avars.size(), compare_depth);
TRACE("qe", tout << "avars:" << avars << "\n";);
}
void uflia_mbi::collect_atoms(expr_ref_vector const& fmls) {
expr_fast_mark1 marks;
is_atom_proc proc(m_atoms, m_atom_set);
for (expr* e : fmls) {
quick_for_each_expr(proc, marks, e);
}
}
bool uflia_mbi::get_literals(model_ref& mdl, expr_ref_vector& lits) {
lits.reset();
IF_VERBOSE(10, verbose_stream() << "atoms: " << m_atoms << "\n");
for (expr* e : m_atoms) {
if (mdl->is_true(e))
lits.push_back(e);
else if (mdl->is_false(e))
lits.push_back(m.mk_not(e));
}
TRACE("qe", tout << "atoms from model: " << lits << "\n";);
solver_ref dual = m_dual_solver->translate(m, m_dual_solver->get_params());
dual->assert_expr(mk_not(mk_and(m_fmls)));
lbool r = dual->check_sat(lits);
TRACE("qe", dual->display(tout << "dual result " << r << "\n"););
if (l_false == r) {
// use the dual solver to find a 'small' implicant
lits.reset();
dual->get_unsat_core(lits);
return true;
}
else {
return false;
}
}
/**
* \brief A subterm is an arithmetic variable if:
* 1. it is not shared.
@ -246,7 +191,7 @@ namespace qe {
*
* The result is ordered using deepest term first.
*/
app_ref_vector uflia_mbi::get_arith_vars(expr_ref_vector const& lits) {
app_ref_vector uflia_project::get_arith_vars(expr_ref_vector const& lits) {
app_ref_vector avars(m);
bool_vector seen;
arith_util a(m);
@ -273,7 +218,7 @@ namespace qe {
For these cases we apply model refinement to the literals: non-shared
sub-expressions are replaced by model values.
*/
void uflia_mbi::fix_non_shared(model& mdl, expr_ref_vector& lits) {
void uflia_project::fix_non_shared(model& mdl, expr_ref_vector& lits) {
th_rewriter rewrite(m);
expr_ref_vector trail(m);
obj_map<expr, expr*> cache;
@ -325,7 +270,7 @@ namespace qe {
lits[i] = cache[lits.get(i)];
}
vector<mbp::def> uflia_mbi::arith_project(model_ref& mdl, app_ref_vector& avars, expr_ref_vector& lits) {
vector<mbp::def> uflia_project::arith_project(model_ref& mdl, app_ref_vector& avars, expr_ref_vector& lits) {
mbp::arith_project_plugin ap(m);
ap.set_check_purified(false);
vector<mbp::def> defs;
@ -336,6 +281,196 @@ namespace qe {
return defs;
}
void uflia_project::split_arith(expr_ref_vector const& lits,
expr_ref_vector& alits,
expr_ref_vector& uflits) {
arith_util a(m);
for (expr* lit : lits) {
expr* atom = lit, *x = nullptr, *y = nullptr;
m.is_not(lit, atom);
if (m.is_eq(atom, x, y)) {
if (a.is_int_real(x)) {
alits.push_back(lit);
}
uflits.push_back(lit);
}
else if (a.is_arith_expr(atom)) {
alits.push_back(lit);
}
else {
uflits.push_back(lit);
}
}
TRACE("qe",
tout << "alits: " << alits << "\n";
tout << "uflits: " << uflits << "\n";);
}
/**
\brief add difference certificates to formula.
*/
void uflia_project::add_dcert(model_ref& mdl, expr_ref_vector& lits) {
mbp::term_graph tg(m);
add_arith_dcert(*mdl.get(), lits);
func_decl_ref_vector shared(m_shared_trail);
tg.set_vars(shared, false);
lits.append(tg.dcert(*mdl.get(), lits));
TRACE("qe", tout << "project: " << lits << "\n";);
}
/**
Add disequalities between functions that appear in arithmetic context.
*/
void uflia_project::add_arith_dcert(model& mdl, expr_ref_vector& lits) {
obj_map<func_decl, ptr_vector<app>> apps;
arith_util a(m);
for (expr* e : subterms::ground(lits)) {
if (a.is_int_real(e) && is_uninterp(e) && to_app(e)->get_num_args() > 0) {
func_decl* f = to_app(e)->get_decl();
apps.insert_if_not_there(f, ptr_vector<app>()).push_back(to_app(e));
}
}
for (auto const& kv : apps) {
ptr_vector<app> const& es = kv.m_value;
expr_ref_vector values(m);
for (expr* e : kv.m_value) values.push_back(mdl(e));
for (unsigned i = 0; i < es.size(); ++i) {
expr* v1 = values.get(i);
for (unsigned j = i + 1; j < es.size(); ++j) {
expr* v2 = values.get(j);
if (v1 != v2) {
add_arith_dcert(mdl, lits, es[i], es[j]);
}
}
}
}
}
void uflia_project::add_arith_dcert(model& mdl, expr_ref_vector& lits, app* a, app* b) {
arith_util arith(m);
SASSERT(a->get_decl() == b->get_decl());
for (unsigned i = a->get_num_args(); i-- > 0; ) {
expr* arg1 = a->get_arg(i), *arg2 = b->get_arg(i);
if (arith.is_int_real(arg1) && mdl(arg1) != mdl(arg2)) {
lits.push_back(m.mk_not(m.mk_eq(arg1, arg2)));
return;
}
}
}
/**
* \brief project private symbols.
*/
void uflia_project::project_euf(model_ref& mdl, expr_ref_vector& lits) {
mbp::term_graph tg(m);
func_decl_ref_vector shared(m_shared_trail);
tg.set_vars(shared, false);
tg.add_lits(lits);
lits.reset();
lits.append(tg.project(*mdl.get()));
TRACE("qe", tout << "project: " << lits << "\n";);
}
vector<mbp::def> uflia_project::project_solve(model_ref& mdl, expr_ref_vector& lits) {
TRACE("qe", tout << "project literals: " << lits << "\n" << *mdl << "\n");
add_dcert(mdl, lits);
expr_ref_vector alits(m), uflits(m);
split_arith(lits, alits, uflits);
auto avars = get_arith_vars(lits);
vector<mbp::def> defs = arith_project(mdl, avars, alits);
for (auto const& d : defs) uflits.push_back(m.mk_eq(d.var, d.term));
TRACE("qe", tout << "uflits: " << uflits << "\n";);
project_euf(mdl, uflits);
lits.reset();
lits.append(alits);
lits.append(uflits);
IF_VERBOSE(10, verbose_stream() << "projection : " << lits << "\n");
TRACE("qe",
tout << "projection: " << lits << "\n";
tout << "avars: " << avars << "\n";
tout << "alits: " << lits << "\n";
tout << "uflits: " << uflits << "\n";);
return defs;
}
// -------------------------------
// uflia_mbi
struct uflia_mbi::is_atom_proc {
ast_manager& m;
expr_ref_vector& m_atoms;
obj_hashtable<expr>& m_atom_set;
is_atom_proc(expr_ref_vector& atoms, obj_hashtable<expr>& atom_set):
m(atoms.m()), m_atoms(atoms), m_atom_set(atom_set) {}
void operator()(app* a) {
if (m_atom_set.contains(a)) {
// continue
}
else if (m.is_eq(a) && !m.is_iff(a)) {
m_atoms.push_back(a);
m_atom_set.insert(a);
}
else if (m.is_bool(a) && a->get_family_id() != m.get_basic_family_id()) {
m_atoms.push_back(a);
m_atom_set.insert(a);
}
}
void operator()(expr*) {}
};
uflia_mbi::uflia_mbi(solver* s, solver* sNot):
uflia_project(s->get_manager()),
m_atoms(m),
m_fmls(m),
m_solver(s),
m_dual_solver(sNot) {
params_ref p;
p.set_bool("core.minimize", true);
m_solver->updt_params(p);
m_dual_solver->updt_params(p);
m_solver->get_assertions(m_fmls);
collect_atoms(m_fmls);
}
void uflia_mbi::collect_atoms(expr_ref_vector const& fmls) {
expr_fast_mark1 marks;
is_atom_proc proc(m_atoms, m_atom_set);
for (expr* e : fmls) {
quick_for_each_expr(proc, marks, e);
}
}
bool uflia_mbi::get_literals(model_ref& mdl, expr_ref_vector& lits) {
lits.reset();
IF_VERBOSE(10, verbose_stream() << "atoms: " << m_atoms << "\n");
for (expr* e : m_atoms) {
if (mdl->is_true(e))
lits.push_back(e);
else if (mdl->is_false(e))
lits.push_back(m.mk_not(e));
}
TRACE("qe", tout << "atoms from model: " << lits << "\n";);
solver_ref dual = m_dual_solver->translate(m, m_dual_solver->get_params());
dual->assert_expr(mk_not(mk_and(m_fmls)));
lbool r = dual->check_sat(lits);
TRACE("qe", dual->display(tout << "dual result " << r << "\n"););
if (l_false == r) {
// use the dual solver to find a 'small' implicant
lits.reset();
dual->get_unsat_core(lits);
return true;
}
else {
return false;
}
}
mbi_result uflia_mbi::operator()(expr_ref_vector& lits, model_ref& mdl) {
lbool r = m_solver->check_sat(lits);
@ -367,137 +502,13 @@ namespace qe {
\brief main projection routine
*/
void uflia_mbi::project(model_ref& mdl, expr_ref_vector& lits) {
TRACE("qe",
tout << "project literals: " << lits << "\n" << *mdl << "\n";
tout << m_solver->get_assertions() << "\n";);
add_dcert(mdl, lits);
expr_ref_vector alits(m), uflits(m);
split_arith(lits, alits, uflits);
auto avars = get_arith_vars(lits);
vector<mbp::def> defs = arith_project(mdl, avars, alits);
for (auto const& d : defs) uflits.push_back(m.mk_eq(d.var, d.term));
TRACE("qe", tout << "uflits: " << uflits << "\n";);
project_euf(mdl, uflits);
lits.reset();
lits.append(alits);
lits.append(uflits);
IF_VERBOSE(10, verbose_stream() << "projection : " << lits << "\n");
TRACE("qe",
tout << "projection: " << lits << "\n";
tout << "avars: " << avars << "\n";
tout << "alits: " << lits << "\n";
tout << "uflits: " << uflits << "\n";);
}
void uflia_mbi::split_arith(expr_ref_vector const& lits,
expr_ref_vector& alits,
expr_ref_vector& uflits) {
arith_util a(m);
for (expr* lit : lits) {
expr* atom = lit, *x = nullptr, *y = nullptr;
m.is_not(lit, atom);
if (m.is_eq(atom, x, y)) {
if (a.is_int_real(x)) {
alits.push_back(lit);
}
uflits.push_back(lit);
}
else if (a.is_arith_expr(atom)) {
alits.push_back(lit);
}
else {
uflits.push_back(lit);
}
}
TRACE("qe",
tout << "alits: " << alits << "\n";
tout << "uflits: " << uflits << "\n";);
project_solve(mdl, lits);
}
/**
\brief add difference certificates to formula.
*/
void uflia_mbi::add_dcert(model_ref& mdl, expr_ref_vector& lits) {
mbp::term_graph tg(m);
add_arith_dcert(*mdl.get(), lits);
func_decl_ref_vector shared(m_shared_trail);
tg.set_vars(shared, false);
lits.append(tg.dcert(*mdl.get(), lits));
TRACE("qe", tout << "project: " << lits << "\n";);
}
/**
Add disequalities between functions that appear in arithmetic context.
*/
void uflia_mbi::add_arith_dcert(model& mdl, expr_ref_vector& lits) {
obj_map<func_decl, ptr_vector<app>> apps;
arith_util a(m);
for (expr* e : subterms::ground(lits)) {
if (a.is_int_real(e) && is_uninterp(e) && to_app(e)->get_num_args() > 0) {
func_decl* f = to_app(e)->get_decl();
apps.insert_if_not_there(f, ptr_vector<app>()).push_back(to_app(e));
}
}
for (auto const& kv : apps) {
ptr_vector<app> const& es = kv.m_value;
expr_ref_vector values(m);
for (expr* e : kv.m_value) values.push_back(mdl(e));
for (unsigned i = 0; i < es.size(); ++i) {
expr* v1 = values.get(i);
for (unsigned j = i + 1; j < es.size(); ++j) {
expr* v2 = values.get(j);
if (v1 != v2) {
add_arith_dcert(mdl, lits, es[i], es[j]);
}
}
}
}
}
void uflia_mbi::add_arith_dcert(model& mdl, expr_ref_vector& lits, app* a, app* b) {
arith_util arith(m);
SASSERT(a->get_decl() == b->get_decl());
for (unsigned i = a->get_num_args(); i-- > 0; ) {
expr* arg1 = a->get_arg(i), *arg2 = b->get_arg(i);
if (arith.is_int_real(arg1) && mdl(arg1) != mdl(arg2)) {
lits.push_back(m.mk_not(m.mk_eq(arg1, arg2)));
return;
}
}
}
/**
* \brief project private symbols.
*/
void uflia_mbi::project_euf(model_ref& mdl, expr_ref_vector& lits) {
mbp::term_graph tg(m);
func_decl_ref_vector shared(m_shared_trail);
tg.set_vars(shared, false);
tg.add_lits(lits);
lits.reset();
lits.append(tg.project(*mdl.get()));
TRACE("qe", tout << "project: " << lits << "\n";);
}
/**
* \brief Order arithmetical variables:
* sort arithmetical terms, such that deepest terms are first.
*/
void uflia_mbi::order_avars(app_ref_vector& avars) {
// sort avars based on depth
std::function<bool(app*, app*)> compare_depth =
[](app* x, app* y) {
return
(x->get_depth() > y->get_depth()) ||
(x->get_depth() == y->get_depth() && x->get_id() > y->get_id());
};
std::sort(avars.data(), avars.data() + avars.size(), compare_depth);
TRACE("qe", tout << "avars:" << avars << "\n";);
}
void uflia_mbi::block(expr_ref_vector const& lits) {
expr_ref clause(mk_not(mk_and(lits)), m);

View file

@ -115,7 +115,29 @@ namespace qe {
void block(expr_ref_vector const& lits) override;
};
class uflia_mbi : public mbi_plugin {
class uflia_project : public mbi_plugin {
protected:
void order_avars(app_ref_vector& avars);
app_ref_vector get_arith_vars(expr_ref_vector const& lits);
void fix_non_shared(model& mdl, expr_ref_vector& lits);
vector<::mbp::def> arith_project(model_ref& mdl, app_ref_vector& avars, expr_ref_vector& lits);
void add_dcert(model_ref& mdl, expr_ref_vector& lits);
void add_arith_dcert(model& mdl, expr_ref_vector& lits);
void add_arith_dcert(model& mdl, expr_ref_vector& lits, app* a, app* b);
void project_euf(model_ref& mdl, expr_ref_vector& lits);
void split_arith(expr_ref_vector const& lits,
expr_ref_vector& alits,
expr_ref_vector& uflits);
public:
uflia_project(ast_manager& m): mbi_plugin(m) {}
vector<::mbp::def> project_solve(model_ref& mdl, expr_ref_vector& lits);
void block(expr_ref_vector const& lits) override {}
mbi_result operator()(expr_ref_vector& lits, model_ref& mdl) override { return mbi_result::mbi_undef; }
};
class uflia_mbi : public uflia_project {
expr_ref_vector m_atoms;
obj_hashtable<expr> m_atom_set;
expr_ref_vector m_fmls;
@ -125,18 +147,8 @@ namespace qe {
bool get_literals(model_ref& mdl, expr_ref_vector& lits);
void collect_atoms(expr_ref_vector const& fmls);
void order_avars(app_ref_vector& avars);
void add_dcert(model_ref& mdl, expr_ref_vector& lits);
void add_arith_dcert(model& mdl, expr_ref_vector& lits);
void add_arith_dcert(model& mdl, expr_ref_vector& lits, app* a, app* b);
app_ref_vector get_arith_vars(expr_ref_vector const& lits);
vector<::mbp::def> arith_project(model_ref& mdl, app_ref_vector& avars, expr_ref_vector& lits);
void project_euf(model_ref& mdl, expr_ref_vector& lits);
void split_arith(expr_ref_vector const& lits,
expr_ref_vector& alits,
expr_ref_vector& uflits);
void fix_non_shared(model& mdl, expr_ref_vector& lits);
public:
uflia_mbi(solver* s, solver* emptySolver);
mbi_result operator()(expr_ref_vector& lits, model_ref& mdl) override;

View file

@ -19,9 +19,7 @@ Author:
#include "ast/rewriter/th_rewriter.h"
#include "sat/smt/synth_solver.h"
#include "sat/smt/euf_solver.h"
#include "qe/mbp/mbp_term_graph.h"
#include "qe/mbp/mbp_arith.h"
#include "qe/mbp/mbp_arrays.h"
#include "qe/qe_mbi.h"
namespace synth {
@ -36,11 +34,12 @@ namespace synth {
solver::~solver() {}
bool solver::is_output(expr * e) const {
return any_of(m_synth, [&](synth_objective const& a) { return a.output() == e; });
}
bool solver::contains_uncomputable(expr* e) {
auto is_output = [&](expr* e) {
return any_of(m_synth, [&](synth_objective const& a) { return a.output() == e; });
};
return any_of(subterms::all(expr_ref(e, m)), [&](expr* a) { return (is_app(a) && m_uncomputable.contains(to_app(a)->get_decl())) || is_output(a); });
}
@ -353,12 +352,11 @@ namespace synth {
compute_rep();
for (synth_objective const& e : m_synth) {
auto lit = synthesize(e);
if (lit == sat::null_literal)
expr_ref sol = compute_solution(e);
if (!sol)
return false;
clause.push_back(~lit);
IF_VERBOSE(0, verbose_stream() << sol << "\n");
}
add_clause(clause);
expr_ref cond = compute_condition();
add_unit(~mk_literal(cond));
IF_VERBOSE(0, verbose_stream() << "if " << cond << "\n");
@ -412,13 +410,57 @@ namespace synth {
arith_util a(m);
if (!a.is_int_real(obj.output()))
return false;
model_ref mdl = alloc(model, m);
ctx.update_model(mdl, false);
verbose_stream() << "int-real-objective\n";
verbose_stream() << *mdl << "\n";
expr_ref_vector lits(m), core(m);
for (unsigned i = 0; i < s().trail_size(); ++i) {
sat::literal l = s().trail_literal(i);
if (!ctx.is_relevant(l))
continue;
expr_ref e = literal2expr(l);
if (e)
lits.push_back(e);
}
verbose_stream() << lits << "\n";
sat::no_drat_params no_drat_params;
ref<::solver> solver = mk_smt2_solver(m, no_drat_params, symbol::null);
solver->assert_expr(m.mk_not(m.mk_and(m_spec)));
lbool r = solver->check_sat(lits);
if (r != l_false)
return false;
solver->get_unsat_core(core);
verbose_stream() << "core " << core << "\n";
qe::uflia_project proj(m);
auto& egraph = ctx.get_egraph();
func_decl_ref_vector shared(m);
ast_mark visited;
for (auto* n : egraph.nodes())
if (is_app(n->get_expr()) && !is_output(n->get_expr()) && !m_uncomputable.contains(n->get_decl()) && !visited.is_marked(n->get_decl())) {
visited.mark(n->get_decl(), true);
shared.push_back(n->get_decl());
}
verbose_stream() << "shared " << shared << "\n";
proj.set_shared(shared);
auto defs = proj.project_solve(mdl, core);
for (auto const& d : defs) {
verbose_stream() << d.var << " := " << d.term << "\n";
if (d.var == obj.output()) {
obj.set_solution(d.term);
ctx.push(synth_objective::unset_solution(obj));
return true;
}
}
#if 0
// 1 retrieve a model
// 1.5 - difference cert?
// 1.6 - split arith?
// 2 retrieve literal dependencies
// 3 split_arith, arith_vars, rpoejct, project_euf,
// - retrieve literal dependencies
// - difference cert?
// - split arith?
// - split_arith, arith_vars, rpoejct, project_euf,
// produce projection
add_dcert(mdl, lits);

View file

@ -62,7 +62,7 @@ namespace synth {
};
bool is_output(expr* e) const;
sat::literal synthesize(synth_objective const& synth_objective);
void add_uncomputable(app* e);
void add_synth_objective(synth_objective const & e);