From da124e42759ab4fb93e69f9bb49899d3d21fb8a3 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Tue, 28 Sep 2021 13:41:37 -0700 Subject: [PATCH] tune q-eval and q-ematch Signed-off-by: Nikolaj Bjorner --- src/ast/has_free_vars.cpp | 15 +++- src/ast/has_free_vars.h | 9 +++ src/ast/normal_forms/pull_quant.cpp | 14 ++-- src/sat/smt/q_eval.cpp | 2 +- src/sat/smt/q_eval.h | 2 + src/sat/smt/q_solver.cpp | 107 ++++++++++++++++++++-------- src/sat/smt/q_solver.h | 6 +- 7 files changed, 119 insertions(+), 36 deletions(-) diff --git a/src/ast/has_free_vars.cpp b/src/ast/has_free_vars.cpp index 09d6d7740..b1b74c7da 100644 --- a/src/ast/has_free_vars.cpp +++ b/src/ast/has_free_vars.cpp @@ -18,9 +18,10 @@ Revision History: --*/ #include "ast/ast.h" #include "ast/expr_delta_pair.h" +#include "ast/has_free_vars.h" #include "util/hashtable.h" -class contains_vars { +class contains_vars::imp { typedef hashtable, default_eq > cache; cache m_cache; svector m_todo; @@ -86,6 +87,18 @@ public: } }; +contains_vars::contains_vars() { + m_imp = alloc(imp); +} + +contains_vars::~contains_vars() { + dealloc(m_imp); +} + +bool contains_vars::operator()(expr* e) { + return (*m_imp)(e); +} + bool has_free_vars(expr * n) { contains_vars p; return p(n); diff --git a/src/ast/has_free_vars.h b/src/ast/has_free_vars.h index c1cbfa0e6..1f7cc62a3 100644 --- a/src/ast/has_free_vars.h +++ b/src/ast/has_free_vars.h @@ -20,6 +20,15 @@ Revision History: class expr; +class contains_vars { + class imp; + imp* m_imp; +public: + contains_vars(); + ~contains_vars(); + bool operator()(expr* n); +}; + bool has_free_vars(expr * n); diff --git a/src/ast/normal_forms/pull_quant.cpp b/src/ast/normal_forms/pull_quant.cpp index 8a94a3482..486f4e949 100644 --- a/src/ast/normal_forms/pull_quant.cpp +++ b/src/ast/normal_forms/pull_quant.cpp @@ -20,6 +20,7 @@ Notes: #include "ast/rewriter/var_subst.h" #include "ast/rewriter/rewriter_def.h" #include "ast/ast_pp.h" +#include "ast/ast_util.h" struct pull_quant::imp { @@ -50,7 +51,7 @@ struct pull_quant::imp { quantifier * q = to_quantifier(child); expr * body = q->get_expr(); quantifier_kind k = q->get_kind() == forall_k ? exists_k : forall_k; - result = m.update_quantifier(q, k, m.mk_not(body)); + result = m.update_quantifier(q, k, mk_not(m, body)); return true; } else { @@ -78,9 +79,8 @@ struct pull_quant::imp { qid = nested_q->get_qid(); } w = std::min(w, nested_q->get_weight()); - unsigned j = nested_q->get_num_decls(); - while (j > 0) { - --j; + + for (unsigned j = nested_q->get_num_decls(); j-- > 0; ) { var_sorts.push_back(nested_q->get_decl_sort(j)); symbol s = nested_q->get_decl_name(j); if (std::find(var_names.begin(), var_names.end(), s) != var_names.end()) @@ -254,6 +254,10 @@ struct pull_quant::imp { } br_status reduce_app(func_decl * f, unsigned num, expr * const * args, expr_ref & result, proof_ref & result_pr) { + if (m.is_not(f) && m.is_not(args[0])) { + result = to_app(args[0])->get_arg(0); + return BR_REWRITE1; + } if (!m.is_or(f) && !m.is_and(f) && !m.is_not(f)) return BR_FAILED; @@ -275,7 +279,7 @@ struct pull_quant::imp { proof_ref & result_pr) { if (is_exists(old_q)) { - result = m.mk_not(new_body); + result = mk_not(m, new_body); result = m.mk_not(m.update_quantifier(old_q, forall_k, result)); if (m.proofs_enabled()) m.mk_rewrite(old_q, result); diff --git a/src/sat/smt/q_eval.cpp b/src/sat/smt/q_eval.cpp index c506a52f7..a01bb9d9d 100644 --- a/src/sat/smt/q_eval.cpp +++ b/src/sat/smt/q_eval.cpp @@ -198,7 +198,7 @@ namespace q { todo.pop_back(); continue; } - if (is_ground(t) || (has_quantifiers(t) && !has_free_vars(t))) { + if (is_ground(t) || (has_quantifiers(t) && !m_contains_vars(t))) { m_eval.setx(t->get_id(), ctx.get_egraph().find(t), nullptr); if (!m_eval[t->get_id()]) return nullptr; diff --git a/src/sat/smt/q_eval.h b/src/sat/smt/q_eval.h index 76c219343..ef016407a 100644 --- a/src/sat/smt/q_eval.h +++ b/src/sat/smt/q_eval.h @@ -16,6 +16,7 @@ Author: --*/ #pragma once +#include "ast/has_free_vars.h" #include "sat/smt/q_clause.h" namespace euf { @@ -32,6 +33,7 @@ namespace q { euf::enode_vector m_indirect_nodes; bool m_freeze_swap = false; euf::enode_pair m_diseq_undef; + contains_vars m_contains_vars; struct scoped_mark_reset; diff --git a/src/sat/smt/q_solver.cpp b/src/sat/smt/q_solver.cpp index 4727dfefa..480853dbd 100644 --- a/src/sat/smt/q_solver.cpp +++ b/src/sat/smt/q_solver.cpp @@ -30,7 +30,8 @@ namespace q { th_euf_solver(ctx, ctx.get_manager().get_family_name(fid), fid), m_mbqi(ctx, *this), m_ematch(ctx, *this), - m_expanded(ctx.get_manager()) + m_expanded(ctx.get_manager()), + m_der(ctx.get_manager()) { } @@ -45,25 +46,22 @@ namespace q { add_clause(~l, lit); ctx.add_root(~l, lit); } - else { - auto const& exp = expand(q); - if (exp.size() > 1) { - for (expr* e : exp) { - sat::literal lit = ctx.internalize(e, l.sign(), false, false); - add_clause(~l, lit); - ctx.add_root(~l, lit); - } - } - else if (is_ground(q->get_expr())) { - auto lit = ctx.internalize(q->get_expr(), l.sign(), false, false); + else if (expand(q)) { + for (expr* e : m_expanded) { + sat::literal lit = ctx.internalize(e, l.sign(), false, false); add_clause(~l, lit); ctx.add_root(~l, lit); } - else { - ctx.push_vec(m_universal, l); - if (ctx.get_config().m_ematching) - m_ematch.add(q); - } + } + else if (is_ground(q->get_expr())) { + auto lit = ctx.internalize(q->get_expr(), l.sign(), false, false); + add_clause(~l, lit); + ctx.add_root(~l, lit); + } + else { + ctx.push_vec(m_universal, l); + if (ctx.get_config().m_ematching) + m_ematch.add(q); } m_stats.m_num_quantifier_asserts++; } @@ -223,8 +221,16 @@ namespace q { return val; } - expr_ref_vector const& solver::expand(quantifier* q) { + bool solver::expand(quantifier* q) { + expr_ref r(m); + proof_ref pr(m); + m_der(q, r, pr); m_expanded.reset(); + if (r != q) { + ctx.get_rewriter()(r); + m_expanded.push_back(r); + return true; + } if (is_forall(q)) flatten_and(q->get_expr(), m_expanded); else if (is_exists(q)) @@ -232,13 +238,31 @@ namespace q { else UNREACHABLE(); - expr* a, *b; - if (m_expanded.size() == 1 && m.is_iff(m_expanded.get(0), a, b)) { - expr_ref f1(m.mk_implies(a, b), m); - expr_ref f2(m.mk_implies(b, a), m); + if (m_expanded.size() == 1 && is_forall(q)) { m_expanded.reset(); - m_expanded.push_back(f1); - m_expanded.push_back(f2); + flatten_or(q->get_expr(), m_expanded); + expr_ref split1(m), split2(m), e1(m), e2(m); + unsigned idx = 0; + for (unsigned i = m_expanded.size(); i-- > 0; ) { + expr* arg = m_expanded.get(i); + if (split(arg, split1, split2)) { + if (e1) + return false; + e1 = split1; + e2 = split2; + idx = i; + } + } + if (!e1) + return false; + + m_expanded[idx] = e1; + e1 = mk_or(m_expanded); + m_expanded[idx] = e2; + e2 = mk_or(m_expanded); + m_expanded.reset(); + m_expanded.push_back(e1); + m_expanded.push_back(e2); } if (m_expanded.size() > 1) { for (unsigned i = m_expanded.size(); i-- > 0; ) { @@ -246,12 +270,39 @@ namespace q { ctx.get_rewriter()(tmp); m_expanded[i] = tmp; } + return true; } - else { - m_expanded.reset(); - m_expanded.push_back(q); + return false; + } + + bool solver::split(expr* arg, expr_ref& e1, expr_ref& e2) { + expr* x, * y, * z; + if (m.is_not(arg, x) && m.is_or(x, y, z) && is_literal(y) && is_literal(z)) { + e1 = mk_not(m, y); + e2 = mk_not(m, z); + return true; } - return m_expanded; + if (m.is_iff(arg, x, y) && is_literal(x) && is_literal(y)) { + e1 = m.mk_implies(x, y); + e2 = m.mk_implies(y, x); + return true; + } + if (m.is_and(arg, x, y) && is_literal(x) && is_literal(y)) { + e1 = x; + e2 = y; + return true; + } + if (m.is_not(arg, z) && m.is_iff(z, x, y) && is_literal(x) && is_literal(y)) { + e1 = m.mk_or(x, y); + e2 = m.mk_or(mk_not(m, x), mk_not(m, y)); + return true; + } + return false; + } + + bool solver::is_literal(expr* arg) { + m.is_not(arg, arg); + return !m.is_and(arg) && !m.is_or(arg) && !m.is_iff(arg) && !m.is_implies(arg); } void solver::get_antecedents(sat::literal l, sat::ext_justification_idx idx, sat::literal_vector& r, bool probing) { diff --git a/src/sat/smt/q_solver.h b/src/sat/smt/q_solver.h index 426104c87..934864669 100644 --- a/src/sat/smt/q_solver.h +++ b/src/sat/smt/q_solver.h @@ -18,6 +18,7 @@ Author: #include "util/obj_hashtable.h" #include "ast/ast_trail.h" +#include "ast/rewriter/der.h" #include "sat/smt/sat_th.h" #include "sat/smt/q_mbi.h" #include "sat/smt/q_ematch.h" @@ -47,6 +48,7 @@ namespace q { sat::literal_vector m_universal; obj_map m_unit_table; expr_ref_vector m_expanded; + der_rewriter m_der; sat::literal instantiate(quantifier* q, bool negate, std::function& mk_var); sat::literal skolemize(quantifier* q); @@ -54,7 +56,9 @@ namespace q { void init_units(); expr* get_unit(sort* s); - expr_ref_vector const& expand(quantifier* q); + bool expand(quantifier* q); + bool split(expr* arg, expr_ref& e1, expr_ref& e2); + bool is_literal(expr* arg); friend class ematch;