From e2fbd05fe7f37ab0d5ddcf6fc729ed2ee1e6840b Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Tue, 27 Oct 2020 11:41:45 -0700 Subject: [PATCH] adding argument restriction to mbqi, fix tracking of m_src/m_dst for expr_safe_replace and avoid resetting the cache. --- src/ast/rewriter/expr_safe_replace.cpp | 9 ++- src/qe/mbp/mbp_plugin.h | 11 +++ src/sat/smt/q_mbi.cpp | 102 +++++++++++++++++++------ src/sat/smt/q_mbi.h | 10 ++- src/sat/smt/q_model_fixer.cpp | 79 +++++++++++-------- src/sat/smt/q_model_fixer.h | 6 ++ 6 files changed, 159 insertions(+), 58 deletions(-) diff --git a/src/ast/rewriter/expr_safe_replace.cpp b/src/ast/rewriter/expr_safe_replace.cpp index cd2c5f91f..9b67f5b5b 100644 --- a/src/ast/rewriter/expr_safe_replace.cpp +++ b/src/ast/rewriter/expr_safe_replace.cpp @@ -27,11 +27,11 @@ Revision History: void expr_safe_replace::insert(expr* src, expr* dst) { SASSERT(m.get_sort(src) == m.get_sort(dst)); + m_src.push_back(src); + m_dst.push_back(dst); #if ALIVE_OPT cache_insert(src, dst); #else - m_src.push_back(src); - m_dst.push_back(dst); m_subst.insert(src, dst); #endif } @@ -149,14 +149,17 @@ void expr_safe_replace::operator()(expr* e, expr_ref& res) { m_cache.reset(); m_todo.reset(); m_args.reset(); +#if !ALIVE_OPT m_refs.reset(); +#endif } void expr_safe_replace::reset() { m_src.reset(); m_dst.reset(); m_subst.reset(); - m_refs.reset(); + m_refs.finalize(); + m_cache.reset(); } void expr_safe_replace::apply_substitution(expr* s, expr* def, expr_ref& t) { diff --git a/src/qe/mbp/mbp_plugin.h b/src/qe/mbp/mbp_plugin.h index 2d84fb576..2b74f2740 100644 --- a/src/qe/mbp/mbp_plugin.h +++ b/src/qe/mbp/mbp_plugin.h @@ -114,6 +114,17 @@ namespace mbp { static void mark_rec(expr_mark& visited, expr* e); static void mark_rec(expr_mark& visited, expr_ref_vector const& es); + + /** + * mark sub-terms in e whether they contain a variable from vars. + */ + void mark_non_ground(app_ref_vector const& vars, expr* e) { + for (app* v : vars) + m_non_ground.mark(v); + mark_non_ground(e); + } + + bool is_non_ground(expr* t) const { return m_non_ground.is_marked(t); } }; } diff --git a/src/sat/smt/q_mbi.cpp b/src/sat/smt/q_mbi.cpp index 2fc1a9e23..defc287d3 100644 --- a/src/sat/smt/q_mbi.cpp +++ b/src/sat/smt/q_mbi.cpp @@ -17,6 +17,7 @@ Author: #include "ast/ast_trail.h" #include "ast/ast_util.h" +#include "ast/for_each_expr.h" #include "ast/rewriter/var_subst.h" #include "ast/rewriter/expr_safe_replace.h" #include "qe/mbp/mbp_arith.h" @@ -98,7 +99,6 @@ namespace q { } lbool mbqi::check_forall(quantifier* q) { - quantifier* q_flat = m_qs.flatten(q); auto* qb = specialize(q_flat); if (!qb) @@ -113,18 +113,33 @@ namespace q { return r; if (r == l_false) return l_true; - model_ref mdl0; + model_ref mdl0, mdl1; + expr_ref proj(m); m_solver->get_model(mdl0); - expr_ref proj = solver_project(*mdl0, *qb); - if (!proj) - return l_undef; sat::literal qlit = ctx.expr2literal(q); if (is_exists(q)) qlit.neg(); - ctx.get_rewriter()(proj); - TRACE("q", tout << proj << "\n";); - // TODO: add as top-level clause for relevancy - m_qs.add_clause(~qlit, ~ctx.mk_literal(proj)); + unsigned i = 0; + { + ::solver::scoped_push _sp(*m_solver); + restrict_domains(*mdl0, *qb); + for (; i < m_max_cex && l_true == m_solver->check_sat(0, nullptr); ++i) { + m_solver->get_model(mdl1); + proj = solver_project(*mdl1, *qb); + if (!proj) + break; + m_qs.add_clause(~qlit, ~ctx.mk_literal(proj)); + m_solver->assert_expr(m.mk_not(proj)); + } + } + if (i == 0) { + qb->domain_eqs.reset(); + proj = solver_project(*mdl0, *qb); + if (!proj) + return l_undef; + m_qs.add_clause(~qlit, ~ctx.mk_literal(proj)); + } + // TODO: add as top-level clause for relevancy return l_false; } @@ -146,6 +161,7 @@ namespace q { restrict_to_universe(vars.get(i), m_model->get_universe(s)); } expr_ref fml = subst(q->get_expr(), vars); + extract_var_args(q->get_expr(), *result); if (is_forall(q)) fml = m.mk_not(fml); flatten_and(fml, result->vbody); @@ -189,12 +205,13 @@ namespace q { for (app* v : qb.vars) m_model->register_decl(v->get_decl(), mdl(v)); TRACE("q", - tout << "Project\n"; - tout << *m_model << "\n"; - tout << qb.vbody << "\n"; - tout << "model of projection\n" << mdl << "\n";); + tout << "Project\n"; + tout << *m_model << "\n"; + tout << qb.vbody << "\n"; + tout << "model of projection\n" << mdl << "\n";); expr_ref_vector fmls(qb.vbody); app_ref_vector vars(qb.vars); + fmls.append(qb.domain_eqs); mbp::project_plugin proj(m); proj.purify(m_model_fixer, *m_model, vars, fmls); for (unsigned i = 0; i < vars.size(); ++i) { @@ -203,19 +220,60 @@ namespace q { if (p) (*p)(*m_model, vars, fmls); } - if (!vars.empty()) { - expr_safe_replace esubst(m); - for (app* v : vars) { - expr_ref val = assign_value(*m_model, v); - if (!val) - return expr_ref(m); - esubst.insert(v, val); - } - esubst(fmls); + expr_safe_replace esubst(m); + for (app* v : qb.vars) { + expr_ref val = assign_value(*m_model, v); + if (!val) + return expr_ref(m); + esubst.insert(v, val); } + esubst(fmls); return mk_and(fmls); } + /** + * Add disjunctions to m_solver that restrict the possible values of + * arguments to uninterpreted functions. The disjunctions added to the solver + * are specialized with respect to m_model. + * Add also disjunctions to the quantifier "domain_eqs", to track the constraints + * added to the solver. + */ + void mbqi::restrict_domains(model& mdl, q_body& qb) { + qb.domain_eqs.reset(); + var_subst subst(m); + for (auto p : qb.var_args) { + expr_ref bounds = m_model_fixer.restrict_arg(p.first, p.second); + if (m.is_true(bounds)) + continue; + expr_ref vbounds = subst(bounds, qb.vars); + expr_ref mbounds(m); + if (!m_model->eval_expr(bounds, mbounds, true)) + return; + mbounds = subst(mbounds, qb.vars); + std::cout << "restrict with bounds " << mbounds << " " << vbounds << "\n"; + m_solver->assert_expr(mbounds); + qb.domain_eqs.push_back(vbounds); + } + } + + /* + * Add domain restrictions for every non-ground arguments to uninterpreted functions. + */ + void mbqi::extract_var_args(expr* _t, q_body& qb) { + expr_ref t(_t, m); + for (expr* s : subterms(t)) { + if (is_ground(s)) + continue; + if (is_uninterp(s) && to_app(s)->get_num_args() > 0) { + app* a = to_app(s); + for (unsigned i = 0; i < a->get_num_args(); ++i) { + if (is_ground(a->get_arg(i))) + qb.var_args.push_back(std::make_pair(a, i)); + } + } + } + } + expr_ref mbqi::assign_value(model& mdl, app* v) { func_decl* f = v->get_decl(); expr_ref val(mdl.get_some_const_interp(f), m); diff --git a/src/sat/smt/q_mbi.h b/src/sat/smt/q_mbi.h index 52544e5aa..5c6067d51 100644 --- a/src/sat/smt/q_mbi.h +++ b/src/sat/smt/q_mbi.h @@ -32,9 +32,11 @@ namespace q { class mbqi { struct q_body { app_ref_vector vars; - expr_ref mbody; // body specialized with respect to model - expr_ref_vector vbody; // (negation of) body specialized with respect to vars - q_body(ast_manager& m) : vars(m), mbody(m), vbody(m) {} + expr_ref mbody; // body specialized with respect to model + expr_ref_vector vbody; // (negation of) body specialized with respect to vars + expr_ref_vector domain_eqs; // additional domain restrictions + svector> var_args; // (uninterpreted) functions in vbody that contain arguments with variables + q_body(ast_manager& m) : vars(m), mbody(m), vbody(m), domain_eqs(m) {} }; euf::solver& ctx; @@ -59,6 +61,8 @@ namespace q { expr_ref basic_project(model& mdl, quantifier* q, app_ref_vector& vars); expr_ref solver_project(model& mdl, q_body& qb); expr_ref assign_value(model& mdl, app* v); + void restrict_domains(model& mdl, q_body& qb); + void extract_var_args(expr* t, q_body& qb); void init_model(); void init_solver(); mbp::project_plugin* get_plugin(app* var); diff --git a/src/sat/smt/q_model_fixer.cpp b/src/sat/smt/q_model_fixer.cpp index b3678aa94..c5329b923 100644 --- a/src/sat/smt/q_model_fixer.cpp +++ b/src/sat/smt/q_model_fixer.cpp @@ -21,6 +21,7 @@ Notes: #include "ast/for_each_expr.h" +#include "ast/ast_util.h" #include "ast/arith_decl_plugin.h" #include "ast/bv_decl_plugin.h" #include "model/model_macro_solver.h" @@ -34,16 +35,16 @@ namespace q { template static bool lt(U const& u, expr* x, expr* y) { rational v1, v2; - if (u.is_numeral(x, v1) && u.is_numeral(y, v2)) + if (u.is_numeral(x, v1) && u.is_numeral(y, v2)) return v1 < v2; - else - return x->get_id() < y->get_id(); + else + return x->get_id() < y->get_id(); } class arith_projection : public projection_function { arith_util a; - public: - arith_projection(ast_manager& m): projection_function(m), a(m) {} + public: + arith_projection(ast_manager& m) : projection_function(m), a(m) {} ~arith_projection() override {} bool operator()(expr* e1, expr* e2) const override { return lt(a, e1, e2); } expr* mk_lt(expr* x, expr* y) override { return a.mk_lt(x, y); } @@ -51,11 +52,11 @@ namespace q { class ubv_projection : public projection_function { bv_util bvu; - public: - ubv_projection(ast_manager& m): projection_function(m), bvu(m) {} + public: + ubv_projection(ast_manager& m) : projection_function(m), bvu(m) {} ~ubv_projection() override {} bool operator()(expr* e1, expr* e2) const override { return lt(bvu, e1, e2); } - expr* mk_lt(expr* x, expr* y) override { return m.mk_not(bvu.mk_ule(y, x)); } + expr* mk_lt(expr* x, expr* y) override { return m.mk_not(bvu.mk_ule(y, x)); } }; model_fixer::model_fixer(euf::solver& ctx, q::solver& qs) : @@ -75,8 +76,8 @@ namespace q { m_dependencies.reset(); m_projection_data.reset(); m_projection_pinned.reset(); - ptr_vector residue; - + ptr_vector residue; + simple_macro_solver sms(m, *this); sms(mdl, univ, residue); @@ -115,12 +116,12 @@ namespace q { // ground values of its arguments. func_interp* fi = mdl.get_func_interp(f); - if (!fi) + if (!fi) return; if (fi->is_constant()) return; expr_ref_vector args(m); - for (unsigned i = 0; i < f->get_arity(); ++i) + for (unsigned i = 0; i < f->get_arity(); ++i) args.push_back(add_projection_function(mdl, f, i)); if (!fi->get_else() && fi->num_entries() > 0) { unsigned idx = ctx.s().rand()(fi->num_entries()); @@ -163,8 +164,8 @@ namespace q { lt _lt(proj); std::sort(values.c_ptr(), values.c_ptr() + values.size(), _lt); unsigned j = 0; - for (unsigned i = 0; i < values.size(); ++i) - if (i == 0 || values.get(i-1) != values.get(i)) + for (unsigned i = 0; i < values.size(); ++i) + if (i == 0 || values.get(i - 1) != values.get(i)) values[j++] = values.get(i); values.shrink(j); @@ -173,15 +174,15 @@ namespace q { unsigned sz = values.size(); expr_ref var(m.mk_var(0, srt), m); - expr_ref pi(values.get(sz-1), m); + expr_ref pi(values.get(sz - 1), m); for (unsigned i = sz - 1; i >= 1; i--) { expr* c = proj->mk_lt(var, values.get(i)); pi = m.mk_ite(c, values.get(i - 1), pi); } func_interp* rpi = alloc(func_interp, m, 1); rpi->set_else(pi); - func_decl * p = m.mk_fresh_func_decl(1, &srt, srt); - mdl.register_decl(p, rpi); + func_decl* p = m.mk_fresh_func_decl(1, &srt, srt); + mdl.register_decl(p, rpi); return expr_ref(m.mk_app(p, m.mk_var(idx, srt)), m); } @@ -209,24 +210,24 @@ namespace q { auto* info = (*this)(q); quantifier* flat_q = info->get_flat_q(); expr_ref body(flat_q->get_expr(), m); - for (expr* t : subterms(body)) + for (expr* t : subterms(body)) if (is_uninterp(t) && !to_app(t)->is_ground()) - fns.insert(to_app(t)->get_decl()); + fns.insert(to_app(t)->get_decl()); } } - expr* model_fixer::invert_app(app* t, expr* value) { + expr* model_fixer::invert_app(app* t, expr* value) { euf::enode* r = nullptr; TRACE("q", tout << "invert-app " << mk_pp(t, m) << " = " << mk_pp(value, m) << "\n"; - if (ctx.values2root().find(value, r)) - tout << "inverse " << mk_pp(r->get_expr(), m) << "\n";); + if (ctx.values2root().find(value, r)) + tout << "inverse " << mk_pp(r->get_expr(), m) << "\n";); if (ctx.values2root().find(value, r)) return r->get_expr(); - return value; + return value; } - void model_fixer::invert_arg(app* t, unsigned i, expr* value, expr_ref_vector& lits) { + void model_fixer::invert_arg(app* t, unsigned i, expr* value, expr_ref_vector& lits) { TRACE("q", tout << "invert-arg " << mk_pp(t, m) << " " << i << " " << mk_pp(value, m) << "\n";); auto const* md = get_projection_data(t->get_decl(), i); if (!md) @@ -236,9 +237,9 @@ namespace q { return; unsigned sz = md->values.size(); - if (sz <= 1) + if (sz <= 1) return; - + // // md->values are sorted // v1, v2, v3 @@ -246,8 +247,8 @@ namespace q { // v2 <= x < v3 => f(x) = f(v2), so t2 <= x < t3, where M(v3) = t3 // v3 <= x => f(x) = f(v3) // - auto is_lt = [&](expr* val) { - return (*proj)(value, val); + auto is_lt = [&](expr* val) { + return (*proj)(value, val); }; auto term = [&](unsigned j) { @@ -261,13 +262,31 @@ namespace q { return; } - for (unsigned j = 2; j < sz; ++j) + for (unsigned j = 2; j < sz; ++j) if (is_lt(md->values[j])) { lits.push_back(proj->mk_le(term(j - 1), arg)); lits.push_back(proj->mk_lt(arg, term(j))); return; } - lits.push_back(proj->mk_le(term(sz-1), arg)); + lits.push_back(proj->mk_le(term(sz - 1), arg)); + } + + /* + * restrict arg_i of t := f(...,arg_i,...) to be one of terms from the ground instantiations of f. + */ + expr_ref model_fixer::restrict_arg(app* t, unsigned i) { + TRACE("q", tout << "restrict-arg " << mk_pp(t, m) << " " << i << "\n";); + auto const* md = get_projection_data(t->get_decl(), i); + if (!md) + return expr_ref(m.mk_true(), m); + + expr* arg = t->get_arg(i); + expr_ref_vector eqs(m); + for (expr* v : md->values) + eqs.push_back(m.mk_eq(arg, md->v2t[v])); + if (eqs.empty()) + return expr_ref(m.mk_true(), m); + return mk_or(eqs); } } diff --git a/src/sat/smt/q_model_fixer.h b/src/sat/smt/q_model_fixer.h index aabf4ed24..be541bb4e 100644 --- a/src/sat/smt/q_model_fixer.h +++ b/src/sat/smt/q_model_fixer.h @@ -107,6 +107,12 @@ namespace q { quantifier_macro_info* operator()(quantifier* q) override; + /* + * Create a constraint that restricts the possible values of t to a finite set of values. + * Add value constraints to solver? + */ + expr_ref restrict_arg(app* t, unsigned i); + expr* invert_app(app* t, expr* value) override; void invert_arg(app* t, unsigned i, expr* value, expr_ref_vector& lits) override; };