3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-06 09:34:08 +00:00

add mbqi to smtfd. For Nuno, of course

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
This commit is contained in:
Nikolaj Bjorner 2019-09-09 11:28:25 +02:00
parent c22a17f430
commit b1cdb3e451
2 changed files with 241 additions and 41 deletions

View file

@ -156,7 +156,7 @@ namespace smt {
TRACE("model_checker", tout << "q after applying interpretation:\n" << mk_ismt2_pp(tmp, m) << "\n";);
ptr_buffer<expr> subst_args;
unsigned num_decls = q->get_num_decls();
subst_args.resize(num_decls, 0);
subst_args.resize(num_decls, nullptr);
sks.resize(num_decls, nullptr);
for (unsigned i = 0; i < num_decls; i++) {
sort * s = q->get_decl_sort(num_decls - i - 1);

View file

@ -54,7 +54,7 @@
else:
table[f][v_args] := v1, t
for t in subterms(core) where t is select(A, args):
for t in subterms((core) where t is select(A, args):
vA := M(abs(A))
v_args = M(abs(args))
v2, args2, t2 := table[vA][v_args]
@ -126,6 +126,7 @@ Note:
#include "ast/for_each_expr.h"
#include "ast/pb_decl_plugin.h"
#include "ast/rewriter/th_rewriter.h"
#include "ast/rewriter/var_subst.h"
#include "tactic/tactic_exception.h"
#include "tactic/fd_solver/fd_solver.h"
#include "solver/solver.h"
@ -336,15 +337,19 @@ namespace smtfd {
class theory_plugin;
class plugin_context {
smtfd_abs& m_abs;
expr_ref_vector m_lemmas;
unsigned m_max_lemmas;
ptr_vector<theory_plugin> m_plugins;
public:
plugin_context(ast_manager& m, unsigned max):
plugin_context(smtfd_abs& a, ast_manager& m, unsigned max):
m_abs(a),
m_lemmas(m),
m_max_lemmas(max)
{}
smtfd_abs& get_abs() { return m_abs; }
void add(expr* f) { m_lemmas.push_back(f); }
ast_manager& get_manager() { return m_lemmas.get_manager(); }
@ -412,9 +417,9 @@ namespace smtfd {
}
public:
theory_plugin(smtfd_abs& a, plugin_context& context, model_ref& mdl) :
theory_plugin(plugin_context& context, model_ref& mdl) :
m(context.get_manager()),
m_abs(a),
m_abs(context.get_abs()),
m_context(context),
m_model(mdl),
m_values(m),
@ -548,8 +553,8 @@ namespace smtfd {
class basic_plugin : public theory_plugin {
public:
basic_plugin(smtfd_abs& a, plugin_context& context, model_ref& mdl):
theory_plugin(a, context, mdl)
basic_plugin(plugin_context& context, model_ref& mdl):
theory_plugin(context, mdl)
{}
void check_term(expr* t, unsigned round) override { }
bool term_covered(expr* t) override { return is_app(t) && to_app(t)->get_family_id() == m.get_basic_family_id(); }
@ -563,8 +568,8 @@ namespace smtfd {
class pb_plugin : public theory_plugin {
pb_util m_pb;
public:
pb_plugin(smtfd_abs& a, plugin_context& context, model_ref& mdl):
theory_plugin(a, context, mdl),
pb_plugin(plugin_context& context, model_ref& mdl):
theory_plugin(context, mdl),
m_pb(m)
{}
void check_term(expr* t, unsigned round) override { }
@ -579,8 +584,8 @@ namespace smtfd {
class bv_plugin : public theory_plugin {
bv_util m_butil;
public:
bv_plugin(smtfd_abs& a, plugin_context& context, model_ref& mdl):
theory_plugin(a, context, mdl),
bv_plugin(plugin_context& context, model_ref& mdl):
theory_plugin(context, mdl),
m_butil(m)
{}
void check_term(expr* t, unsigned round) override { }
@ -635,8 +640,8 @@ namespace smtfd {
public:
uf_plugin(smtfd_abs& a, plugin_context& context, model_ref& mdl):
theory_plugin(a, context, mdl),
uf_plugin(plugin_context& context, model_ref& mdl):
theory_plugin(context, mdl),
m_pinned(m)
{}
@ -898,8 +903,8 @@ namespace smtfd {
public:
a_plugin(smtfd_abs& a, plugin_context& context, model_ref& mdl):
theory_plugin(a, context, mdl),
a_plugin(plugin_context& context, model_ref& mdl):
theory_plugin(context, mdl),
m_autil(m),
m_rewriter(m)
{}
@ -1026,13 +1031,171 @@ namespace smtfd {
}
}
}
};
class mbqi {
ast_manager& m;
plugin_context& m_context;
obj_hashtable<quantifier>& m_enforced;
model_ref m_model;
ref<::solver> m_solver;
expr_ref_vector m_pinned;
obj_map<expr, expr*> m_val2term;
expr* abs(expr* e) { return m_context.get_abs().abs(e); }
expr_ref eval_abs(expr* t) { return (*m_model)(abs(t)); }
// !Ex P(x) => !P(t)
// Ax P(x) => P(t)
// l_true: new instance
// l_false: no new instance
// l_undef unresolved
lbool check_forall(quantifier* q) {
expr_ref tmp(m);
unsigned sz = q->get_num_decls();
if (!m_model->eval_expr(q->get_expr(), tmp, true)) {
return l_undef;
}
expr_ref_vector vars(m), vals(m);
vars.resize(sz, nullptr);
vals.resize(sz, nullptr);
for (unsigned i = 0; i < sz; ++i) {
vars[sz - i - 1] = m.mk_fresh_const(q->get_decl_name(i), q->get_decl_sort(i));
// TBD: finite domain variables
}
var_subst subst(m);
expr_ref body = subst(tmp, vars.size(), vars.c_ptr());
if (is_forall(q)) {
body = m.mk_not(body);
}
m_solver->push();
m_solver->assert_expr(body);
lbool r = m_solver->check_sat(0, nullptr);
model_ref mdl;
if (r == l_true) {
expr_ref qq(q->get_expr(), m);
for (expr* t : subterms(qq)) {
if (is_ground(t)) {
expr_ref v = eval_abs(t);
m_pinned.push_back(v);
m_val2term.insert(v, t);
}
}
m_solver->get_model(mdl);
for (unsigned i = 0; i < sz; ++i) {
app* v = to_app(vars.get(i));
func_decl* f = v->get_decl();
expr_ref val(mdl->get_some_const_interp(f), m);
if (!val) {
r = l_undef;
break;
}
expr* t = nullptr;
if (m_val2term.find(val, t)) {
val = t;
}
vals[i] = val;
}
if (r == l_true) {
body = subst(q->get_expr(), vals.size(), vals.c_ptr());
if (is_forall(q)) {
body = m.mk_implies(q, body);
}
else {
body = m.mk_implies(body, q);
}
body = abs(body);
m_context.add(body);
}
}
m_solver->pop(1);
return r;
}
bool is_enforced(quantifier* q) {
return m_enforced.contains(q);
}
lbool check_exists(quantifier* q) {
if (is_enforced(q)) {
return l_true;
}
expr_ref tmp(m);
expr_ref_vector vars(m);
unsigned sz = q->get_num_decls();
vars.resize(sz, nullptr);
for (unsigned i = 0; i < sz; ++i) {
vars[sz - i - 1] = m.mk_fresh_const(q->get_decl_name(i), q->get_decl_sort(i));
}
var_subst subst(m);
expr_ref body = subst(tmp, vars.size(), vars.c_ptr());
if (is_exists(q)) {
body = m.mk_implies(q, body);
}
else {
body = m.mk_implies(body, q);
}
m_enforced.insert(q);
m_context.add(abs(body));
return l_true;
}
void init_val2term(expr_ref_vector const& core) {
for (expr* t : subterms(core)) {
if (!m.is_bool(t) && is_ground(t)) {
expr_ref v = eval_abs(t);
m_pinned.push_back(v);
m_val2term.insert(v, t);
}
}
}
public:
mbqi(::solver* s, plugin_context& c, obj_hashtable<quantifier>& enforced, model_ref& mdl):
m(s->get_manager()),
m_context(c),
m_enforced(enforced),
m_model(mdl),
m_solver(s),
m_pinned(m)
{}
bool check_quantifiers(expr_ref_vector const& core) {
bool result = true;
init_val2term(core);
for (expr* c : core) {
lbool r = l_false;
if (is_forall(c)) {
r = check_forall(to_quantifier(c));
}
else if (is_exists(c)) {
r = check_exists(to_quantifier(c));
}
else if (m.is_not(c, c)) {
if (is_forall(c)) {
r = check_exists(to_quantifier(c));
}
else if (is_exists(c)) {
r = check_forall(to_quantifier(c));
}
}
if (r == l_undef) {
result = false;
}
}
return result;
}
};
struct stats {
unsigned m_num_lemmas;
unsigned m_num_rounds;
unsigned m_num_mbqi;
stats() { memset(this, 0, sizeof(stats)); }
};
@ -1042,6 +1205,7 @@ namespace smtfd {
ref<::solver> m_fd_sat_solver;
ref<::solver> m_fd_core_solver;
ref<::solver> m_smt_solver;
ref<solver> m_mbqi_solver;
expr_ref_vector m_assertions;
unsigned_vector m_assertions_lim;
unsigned m_assertions_qhead;
@ -1053,6 +1217,7 @@ namespace smtfd {
unsigned m_max_lemmas;
stats m_stats;
unsigned m_max_conflicts;
obj_hashtable<quantifier> m_enforced_quantifier;
void set_delay_simplify() {
params_ref p;
@ -1141,10 +1306,12 @@ namespace smtfd {
return r;
}
bool add_theory_lemmas(expr_ref_vector const& core) {
plugin_context context(m, m_max_lemmas);
a_plugin ap(m_abs, context, m_model);
uf_plugin uf(m_abs, context, m_model);
plugin_context context(m_abs, m, m_max_lemmas);
a_plugin ap(context, m_model);
uf_plugin uf(context, m_model);
unsigned max_rounds = std::max(ap.max_rounds(), uf.max_rounds());
for (unsigned round = 0; round < max_rounds; ++round) {
for (expr* t : subterms(core)) {
@ -1165,22 +1332,48 @@ namespace smtfd {
return !context.empty();
}
bool is_decided_sat(expr_ref_vector const& core) {
plugin_context context(m, m_max_lemmas);
uf_plugin uf(m_abs, context, m_model);
a_plugin ap(m_abs, context, m_model);
bv_plugin bv(m_abs, context, m_model);
basic_plugin bs(m_abs, context, m_model);
pb_plugin pb(m_abs, context, m_model);
lbool is_decided_sat(expr_ref_vector const& core) {
plugin_context context(m_abs, m, m_max_lemmas);
uf_plugin uf(context, m_model);
a_plugin ap(context, m_model);
bv_plugin bv(context, m_model);
basic_plugin bs(context, m_model);
pb_plugin pb(context, m_model);
bool has_q = false;
bool has_non_covered = false;
for (expr* t : subterms(core)) {
if (!context.term_covered(t) || !context.sort_covered(m.get_sort(t))) {
return false;
if (is_forall(t) || is_exists(t)) {
has_q = true;
}
else if (!context.term_covered(t) || !context.sort_covered(m.get_sort(t))) {
has_non_covered = true;
}
}
context.populate_model(m_model, core);
if (!has_q) {
return has_non_covered ? l_false : l_true;
}
if (!m_mbqi_solver) {
m_mbqi_solver = alloc(solver, m, get_params());
}
mbqi mb(m_mbqi_solver.get(), context, m_enforced_quantifier, m_model);
if (!mb.check_quantifiers(core) && context.empty()) {
return l_false;
}
for (expr* f : context) {
IF_VERBOSE(10, verbose_stream() << "lemma: " << expr_ref(rep(f), m) << "\n");
assert_fd(f);
}
m_stats.m_num_mbqi += context.size();
return true;
if (context.empty()) {
return has_non_covered ? l_false : l_true;
}
else {
return l_undef;
}
}
void init_assumptions(unsigned sz, expr* const* user_asms, expr_ref_vector& asms) {
@ -1340,13 +1533,18 @@ namespace smtfd {
if (add_theory_lemmas(core)) {
continue;
}
if (r == l_undef) {
if (is_decided_sat(core)) {
return l_true;
}
m_max_conflicts = UINT_MAX;
if (r != l_undef) {
continue;
}
switch (is_decided_sat(core)) {
case l_true:
return l_true;
case l_undef:
break;
case l_false:
m_max_conflicts = UINT_MAX;
break;
}
}
return l_undef;
}
@ -1365,18 +1563,20 @@ namespace smtfd {
init();
m_smt_solver->collect_param_descrs(r);
m_fd_sat_solver->collect_param_descrs(r);
m_fd_core_solver->collect_param_descrs(r);
r.insert("max-lemmas", CPK_UINT, "maximal number of lemmas per round", "10");
}
void set_produce_models(bool f) override { }
void set_progress_callback(progress_callback * callback) override { }
void collect_statistics(statistics & st) const override {
m_fd_sat_solver->collect_statistics(st);
m_fd_core_solver->collect_statistics(st);
m_smt_solver->collect_statistics(st);
if (m_fd_sat_solver) {
m_fd_sat_solver->collect_statistics(st);
m_fd_core_solver->collect_statistics(st);
m_smt_solver->collect_statistics(st);
}
st.update("smtfd-num-lemmas", m_stats.m_num_lemmas);
st.update("smtfd-num-rounds", m_stats.m_num_rounds);
st.update("smtfd-num-mbqi", m_stats.m_num_mbqi);
}
void get_unsat_core(expr_ref_vector & r) override {
m_fd_sat_solver->get_unsat_core(r);