From b1cdb3e4518d17e2c4dde22a55eb581f5926ba69 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Mon, 9 Sep 2019 11:28:25 +0200 Subject: [PATCH] add mbqi to smtfd. For Nuno, of course Signed-off-by: Nikolaj Bjorner --- src/smt/smt_model_checker.cpp | 2 +- src/tactic/fd_solver/smtfd_solver.cpp | 280 ++++++++++++++++++++++---- 2 files changed, 241 insertions(+), 41 deletions(-) diff --git a/src/smt/smt_model_checker.cpp b/src/smt/smt_model_checker.cpp index da3cbbef0..b65027076 100644 --- a/src/smt/smt_model_checker.cpp +++ b/src/smt/smt_model_checker.cpp @@ -156,7 +156,7 @@ namespace smt { TRACE("model_checker", tout << "q after applying interpretation:\n" << mk_ismt2_pp(tmp, m) << "\n";); ptr_buffer subst_args; unsigned num_decls = q->get_num_decls(); - subst_args.resize(num_decls, 0); + subst_args.resize(num_decls, nullptr); sks.resize(num_decls, nullptr); for (unsigned i = 0; i < num_decls; i++) { sort * s = q->get_decl_sort(num_decls - i - 1); diff --git a/src/tactic/fd_solver/smtfd_solver.cpp b/src/tactic/fd_solver/smtfd_solver.cpp index beccbe651..bffe1df46 100644 --- a/src/tactic/fd_solver/smtfd_solver.cpp +++ b/src/tactic/fd_solver/smtfd_solver.cpp @@ -54,7 +54,7 @@ else: table[f][v_args] := v1, t - for t in subterms(core) where t is select(A, args): + for t in subterms((core) where t is select(A, args): vA := M(abs(A)) v_args = M(abs(args)) v2, args2, t2 := table[vA][v_args] @@ -126,6 +126,7 @@ Note: #include "ast/for_each_expr.h" #include "ast/pb_decl_plugin.h" #include "ast/rewriter/th_rewriter.h" +#include "ast/rewriter/var_subst.h" #include "tactic/tactic_exception.h" #include "tactic/fd_solver/fd_solver.h" #include "solver/solver.h" @@ -336,15 +337,19 @@ namespace smtfd { class theory_plugin; class plugin_context { + smtfd_abs& m_abs; expr_ref_vector m_lemmas; unsigned m_max_lemmas; ptr_vector m_plugins; public: - plugin_context(ast_manager& m, unsigned max): + plugin_context(smtfd_abs& a, ast_manager& m, unsigned max): + m_abs(a), m_lemmas(m), m_max_lemmas(max) {} + smtfd_abs& get_abs() { return m_abs; } + void add(expr* f) { m_lemmas.push_back(f); } ast_manager& get_manager() { return m_lemmas.get_manager(); } @@ -412,9 +417,9 @@ namespace smtfd { } public: - theory_plugin(smtfd_abs& a, plugin_context& context, model_ref& mdl) : + theory_plugin(plugin_context& context, model_ref& mdl) : m(context.get_manager()), - m_abs(a), + m_abs(context.get_abs()), m_context(context), m_model(mdl), m_values(m), @@ -548,8 +553,8 @@ namespace smtfd { class basic_plugin : public theory_plugin { public: - basic_plugin(smtfd_abs& a, plugin_context& context, model_ref& mdl): - theory_plugin(a, context, mdl) + basic_plugin(plugin_context& context, model_ref& mdl): + theory_plugin(context, mdl) {} void check_term(expr* t, unsigned round) override { } bool term_covered(expr* t) override { return is_app(t) && to_app(t)->get_family_id() == m.get_basic_family_id(); } @@ -563,8 +568,8 @@ namespace smtfd { class pb_plugin : public theory_plugin { pb_util m_pb; public: - pb_plugin(smtfd_abs& a, plugin_context& context, model_ref& mdl): - theory_plugin(a, context, mdl), + pb_plugin(plugin_context& context, model_ref& mdl): + theory_plugin(context, mdl), m_pb(m) {} void check_term(expr* t, unsigned round) override { } @@ -579,8 +584,8 @@ namespace smtfd { class bv_plugin : public theory_plugin { bv_util m_butil; public: - bv_plugin(smtfd_abs& a, plugin_context& context, model_ref& mdl): - theory_plugin(a, context, mdl), + bv_plugin(plugin_context& context, model_ref& mdl): + theory_plugin(context, mdl), m_butil(m) {} void check_term(expr* t, unsigned round) override { } @@ -635,8 +640,8 @@ namespace smtfd { public: - uf_plugin(smtfd_abs& a, plugin_context& context, model_ref& mdl): - theory_plugin(a, context, mdl), + uf_plugin(plugin_context& context, model_ref& mdl): + theory_plugin(context, mdl), m_pinned(m) {} @@ -898,8 +903,8 @@ namespace smtfd { public: - a_plugin(smtfd_abs& a, plugin_context& context, model_ref& mdl): - theory_plugin(a, context, mdl), + a_plugin(plugin_context& context, model_ref& mdl): + theory_plugin(context, mdl), m_autil(m), m_rewriter(m) {} @@ -1026,13 +1031,171 @@ namespace smtfd { } } } - - }; + class mbqi { + ast_manager& m; + plugin_context& m_context; + obj_hashtable& m_enforced; + model_ref m_model; + ref<::solver> m_solver; + expr_ref_vector m_pinned; + obj_map m_val2term; + + expr* abs(expr* e) { return m_context.get_abs().abs(e); } + expr_ref eval_abs(expr* t) { return (*m_model)(abs(t)); } + + // !Ex P(x) => !P(t) + // Ax P(x) => P(t) + // l_true: new instance + // l_false: no new instance + // l_undef unresolved + lbool check_forall(quantifier* q) { + expr_ref tmp(m); + unsigned sz = q->get_num_decls(); + if (!m_model->eval_expr(q->get_expr(), tmp, true)) { + return l_undef; + } + expr_ref_vector vars(m), vals(m); + vars.resize(sz, nullptr); + vals.resize(sz, nullptr); + for (unsigned i = 0; i < sz; ++i) { + vars[sz - i - 1] = m.mk_fresh_const(q->get_decl_name(i), q->get_decl_sort(i)); + + // TBD: finite domain variables + } + var_subst subst(m); + expr_ref body = subst(tmp, vars.size(), vars.c_ptr()); + if (is_forall(q)) { + body = m.mk_not(body); + } + + m_solver->push(); + m_solver->assert_expr(body); + lbool r = m_solver->check_sat(0, nullptr); + model_ref mdl; + + if (r == l_true) { + expr_ref qq(q->get_expr(), m); + for (expr* t : subterms(qq)) { + if (is_ground(t)) { + expr_ref v = eval_abs(t); + m_pinned.push_back(v); + m_val2term.insert(v, t); + } + } + m_solver->get_model(mdl); + for (unsigned i = 0; i < sz; ++i) { + app* v = to_app(vars.get(i)); + func_decl* f = v->get_decl(); + expr_ref val(mdl->get_some_const_interp(f), m); + if (!val) { + r = l_undef; + break; + } + expr* t = nullptr; + if (m_val2term.find(val, t)) { + val = t; + } + vals[i] = val; + } + if (r == l_true) { + body = subst(q->get_expr(), vals.size(), vals.c_ptr()); + if (is_forall(q)) { + body = m.mk_implies(q, body); + } + else { + body = m.mk_implies(body, q); + } + body = abs(body); + m_context.add(body); + } + } + m_solver->pop(1); + return r; + } + + bool is_enforced(quantifier* q) { + return m_enforced.contains(q); + } + + lbool check_exists(quantifier* q) { + if (is_enforced(q)) { + return l_true; + } + expr_ref tmp(m); + expr_ref_vector vars(m); + unsigned sz = q->get_num_decls(); + vars.resize(sz, nullptr); + for (unsigned i = 0; i < sz; ++i) { + vars[sz - i - 1] = m.mk_fresh_const(q->get_decl_name(i), q->get_decl_sort(i)); + } + var_subst subst(m); + expr_ref body = subst(tmp, vars.size(), vars.c_ptr()); + if (is_exists(q)) { + body = m.mk_implies(q, body); + } + else { + body = m.mk_implies(body, q); + } + m_enforced.insert(q); + m_context.add(abs(body)); + return l_true; + } + + void init_val2term(expr_ref_vector const& core) { + for (expr* t : subterms(core)) { + if (!m.is_bool(t) && is_ground(t)) { + expr_ref v = eval_abs(t); + m_pinned.push_back(v); + m_val2term.insert(v, t); + } + } + } + + public: + + mbqi(::solver* s, plugin_context& c, obj_hashtable& enforced, model_ref& mdl): + m(s->get_manager()), + m_context(c), + m_enforced(enforced), + m_model(mdl), + m_solver(s), + m_pinned(m) + {} + + bool check_quantifiers(expr_ref_vector const& core) { + bool result = true; + init_val2term(core); + for (expr* c : core) { + lbool r = l_false; + if (is_forall(c)) { + r = check_forall(to_quantifier(c)); + } + else if (is_exists(c)) { + r = check_exists(to_quantifier(c)); + } + else if (m.is_not(c, c)) { + if (is_forall(c)) { + r = check_exists(to_quantifier(c)); + } + else if (is_exists(c)) { + r = check_forall(to_quantifier(c)); + } + } + if (r == l_undef) { + result = false; + } + } + return result; + } + }; + + struct stats { unsigned m_num_lemmas; unsigned m_num_rounds; + unsigned m_num_mbqi; stats() { memset(this, 0, sizeof(stats)); } }; @@ -1042,6 +1205,7 @@ namespace smtfd { ref<::solver> m_fd_sat_solver; ref<::solver> m_fd_core_solver; ref<::solver> m_smt_solver; + ref m_mbqi_solver; expr_ref_vector m_assertions; unsigned_vector m_assertions_lim; unsigned m_assertions_qhead; @@ -1053,6 +1217,7 @@ namespace smtfd { unsigned m_max_lemmas; stats m_stats; unsigned m_max_conflicts; + obj_hashtable m_enforced_quantifier; void set_delay_simplify() { params_ref p; @@ -1141,10 +1306,12 @@ namespace smtfd { return r; } + bool add_theory_lemmas(expr_ref_vector const& core) { - plugin_context context(m, m_max_lemmas); - a_plugin ap(m_abs, context, m_model); - uf_plugin uf(m_abs, context, m_model); + plugin_context context(m_abs, m, m_max_lemmas); + a_plugin ap(context, m_model); + uf_plugin uf(context, m_model); + unsigned max_rounds = std::max(ap.max_rounds(), uf.max_rounds()); for (unsigned round = 0; round < max_rounds; ++round) { for (expr* t : subterms(core)) { @@ -1165,22 +1332,48 @@ namespace smtfd { return !context.empty(); } - bool is_decided_sat(expr_ref_vector const& core) { - plugin_context context(m, m_max_lemmas); - uf_plugin uf(m_abs, context, m_model); - a_plugin ap(m_abs, context, m_model); - bv_plugin bv(m_abs, context, m_model); - basic_plugin bs(m_abs, context, m_model); - pb_plugin pb(m_abs, context, m_model); - + lbool is_decided_sat(expr_ref_vector const& core) { + plugin_context context(m_abs, m, m_max_lemmas); + uf_plugin uf(context, m_model); + a_plugin ap(context, m_model); + bv_plugin bv(context, m_model); + basic_plugin bs(context, m_model); + pb_plugin pb(context, m_model); + + bool has_q = false; + bool has_non_covered = false; for (expr* t : subterms(core)) { - if (!context.term_covered(t) || !context.sort_covered(m.get_sort(t))) { - return false; + if (is_forall(t) || is_exists(t)) { + has_q = true; + } + else if (!context.term_covered(t) || !context.sort_covered(m.get_sort(t))) { + has_non_covered = true; } } context.populate_model(m_model, core); + + if (!has_q) { + return has_non_covered ? l_false : l_true; + } + if (!m_mbqi_solver) { + m_mbqi_solver = alloc(solver, m, get_params()); + } + mbqi mb(m_mbqi_solver.get(), context, m_enforced_quantifier, m_model); + if (!mb.check_quantifiers(core) && context.empty()) { + return l_false; + } + for (expr* f : context) { + IF_VERBOSE(10, verbose_stream() << "lemma: " << expr_ref(rep(f), m) << "\n"); + assert_fd(f); + } + m_stats.m_num_mbqi += context.size(); - return true; + if (context.empty()) { + return has_non_covered ? l_false : l_true; + } + else { + return l_undef; + } } void init_assumptions(unsigned sz, expr* const* user_asms, expr_ref_vector& asms) { @@ -1340,13 +1533,18 @@ namespace smtfd { if (add_theory_lemmas(core)) { continue; } - if (r == l_undef) { - if (is_decided_sat(core)) { - return l_true; - } - m_max_conflicts = UINT_MAX; + if (r != l_undef) { + continue; + } + switch (is_decided_sat(core)) { + case l_true: + return l_true; + case l_undef: + break; + case l_false: + m_max_conflicts = UINT_MAX; + break; } - } return l_undef; } @@ -1365,18 +1563,20 @@ namespace smtfd { init(); m_smt_solver->collect_param_descrs(r); m_fd_sat_solver->collect_param_descrs(r); - m_fd_core_solver->collect_param_descrs(r); r.insert("max-lemmas", CPK_UINT, "maximal number of lemmas per round", "10"); } void set_produce_models(bool f) override { } void set_progress_callback(progress_callback * callback) override { } void collect_statistics(statistics & st) const override { - m_fd_sat_solver->collect_statistics(st); - m_fd_core_solver->collect_statistics(st); - m_smt_solver->collect_statistics(st); + if (m_fd_sat_solver) { + m_fd_sat_solver->collect_statistics(st); + m_fd_core_solver->collect_statistics(st); + m_smt_solver->collect_statistics(st); + } st.update("smtfd-num-lemmas", m_stats.m_num_lemmas); st.update("smtfd-num-rounds", m_stats.m_num_rounds); + st.update("smtfd-num-mbqi", m_stats.m_num_mbqi); } void get_unsat_core(expr_ref_vector & r) override { m_fd_sat_solver->get_unsat_core(r);