3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-08-17 00:32:16 +00:00

more ematching

This commit is contained in:
Nikolaj Bjorner 2021-01-29 13:39:14 -08:00
parent 41a4d102f4
commit 4af9132f2e
12 changed files with 263 additions and 108 deletions

View file

@ -16,13 +16,16 @@ Author:
Todo:
- clausify
- propagate without instantiations, produce explanations for eval
- generations
- insert instantiations into priority queue
- cache instantiations and substitutions
- nested quantifiers
- non-cnf quantifiers (handled in q_solver)
Done:
- propagate without instantiations, produce explanations for eval
--*/
#include "ast/ast_util.h"
@ -46,7 +49,8 @@ namespace q {
ematch::ematch(euf::solver& ctx, solver& s):
ctx(ctx),
m_qs(s),
m(ctx.get_manager())
m(ctx.get_manager()),
m_infer_patterns(m, ctx.get_config())
{
std::function<void(euf::enode*, euf::enode*)> _on_merge =
[&](euf::enode* root, euf::enode* other) {
@ -73,30 +77,64 @@ namespace q {
}
}
void ematch::explain(clause& c, unsigned literal_idx, binding& b) {
ctx.get_egraph().begin_explain();
m_explain.reset();
unsigned n = c.m_q->get_num_decls();
sat::ext_justification_idx ematch::mk_justification(unsigned idx, clause& c, euf::enode* const* b) {
void* mem = ctx.get_region().allocate(justification::get_obj_size());
sat::constraint_base::initialize(mem, &m_qs);
bool sign = false;
expr* l = nullptr, *r = nullptr;
lit lit(expr_ref(l,m), expr_ref(r, m), sign);
if (idx != UINT_MAX)
lit = c[idx];
auto* constraint = new (sat::constraint_base::ptr2mem(mem)) justification(lit, c, b);
return constraint->to_index();
}
void ematch::get_antecedents(sat::literal l, sat::ext_justification_idx idx, sat::literal_vector& r, bool probing) {
auto& j = justification::from_index(idx);
clause& c = j.m_clause;
unsigned l_idx = 0;
for (; l_idx < c.size(); ++l_idx) {
if (c[l_idx].lhs == j.m_lhs && c[l_idx].rhs == j.m_rhs && c[l_idx].sign == j.m_sign)
break;
}
explain(c, l_idx, j.m_binding);
r.push_back(c.m_literal);
(void)probing; // ignored
}
std::ostream& ematch::display_constraint(std::ostream& out, sat::ext_constraint_idx idx) const {
auto& j = justification::from_index(idx);
auto& c = j.m_clause;
out << "ematch: ";
for (auto const& lit : c.m_lits)
lit.display(out) << " ";
unsigned num_decls = c.num_decls();
for (unsigned i = 0; i < num_decls; ++i)
out << ctx.bpp(j.m_binding[i]) << " ";
out << "-> ";
lit lit(expr_ref(j.m_lhs, m), expr_ref(j.m_rhs, m), j.m_sign);
if (j.m_lhs)
lit.display(out);
else
out << "false";
return out;
}
void ematch::explain(clause& c, unsigned literal_idx, euf::enode* const* b) {
unsigned n = c.num_decls();
for (unsigned i = c.size(); i-- > 0; ) {
if (i == literal_idx)
continue;
auto const& lit = c[i];
lit.sign;
lit.lhs;
lit.rhs;
if (lit.sign) {
SASSERT(l_true == compare(n, b.m_nodes, lit.lhs, lit.rhs));
explain_eq(n, b.m_nodes, lit.lhs, lit.rhs);
}
else {
SASSERT(l_false == compare(n, b.m_nodes, lit.lhs, lit.rhs));
explain_diseq(n, b.m_nodes, lit.lhs, lit.rhs);
}
if (lit.sign)
explain_eq(n, b, lit.lhs, lit.rhs);
else
explain_diseq(n, b, lit.lhs, lit.rhs);
}
ctx.get_egraph().end_explain();
}
void ematch::explain_eq(unsigned n, euf::enode* const* binding, expr* s, expr* t) {
SASSERT(l_true == compare(n, binding, s, t));
if (s == t)
return;
euf::enode* sn = eval(n, binding, s);
@ -111,28 +149,29 @@ namespace q {
std::swap(s, t);
}
if (sn && !tn) {
ctx.add_antecedent(sn, sn->get_root());
for (euf::enode* s1 : euf::enode_class(sn)) {
if (l_true == compare_rec(n, binding, t, s1->get_expr())) {
ctx.add_antecedent(sn, s1);
explain_eq(n, binding, t, s1->get_expr());
return;
}
}
UNREACHABLE();
}
SASSERT(is_app(s) && is_app(t) && to_app(s)->get_decl() == to_app(t)->get_decl());
SASSERT(is_app(s) && is_app(t));
SASSERT(to_app(s)->get_decl() == to_app(t)->get_decl());
for (unsigned i = to_app(s)->get_num_args(); i-- > 0; )
explain_eq(n, binding, to_app(s)->get_arg(i), to_app(t)->get_arg(i));
}
void ematch::explain_diseq(unsigned n, euf::enode* const* binding, expr* s, expr* t) {
SASSERT(l_false == compare(n, binding, s, t));
if (m.are_distinct(s, t))
return;
euf::enode* sn = eval(n, binding, s);
euf::enode* tn = eval(n, binding, t);
if (sn && tn) {
SASSERT(sn->get_root() == tn->get_root());
ctx.add_antecedent(sn, tn);
if (sn && tn && ctx.get_egraph().are_diseq(sn, tn)) {
ctx.add_diseq_antecedent(sn, tn);
return;
}
if (!sn && tn) {
@ -140,19 +179,22 @@ namespace q {
std::swap(s, t);
}
if (sn && !tn) {
ctx.add_antecedent(sn, sn->get_root());
for (euf::enode* s1 : euf::enode_class(sn)) {
if (l_false == compare_rec(n, binding, t, s1->get_expr())) {
ctx.add_antecedent(sn, s1);
explain_diseq(n, binding, t, s1->get_expr());
return;
}
}
UNREACHABLE();
}
SASSERT(is_app(s) && is_app(t) && to_app(s)->get_decl() == to_app(t)->get_decl());
for (unsigned i = to_app(s)->get_num_args(); i-- > 0; ) {
if (l_false == compare_rec(n, binding, to_app(s)->get_arg(i), to_app(t)->get_arg(i))) {
explain_eq(n, binding, to_app(s)->get_arg(i), to_app(t)->get_arg(i));
SASSERT(is_app(s) && is_app(t));
app* at = to_app(t);
app* as = to_app(s);
SASSERT(as->get_decl() == at->get_decl());
for (unsigned i = as->get_num_args(); i-- > 0; ) {
if (l_false == compare_rec(n, binding, as->get_arg(i), at->get_arg(i))) {
explain_eq(n, binding, as->get_arg(i), at->get_arg(i));
return;
}
}
@ -170,6 +212,7 @@ namespace q {
};
void ematch::on_merge(euf::enode* root, euf::enode* other) {
TRACE("q", tout << "on-merge " << ctx.bpp(root) << " " << ctx.bpp(other) << "\n";);
SASSERT(root->get_root() == other->get_root());
unsigned root_id = root->get_expr_id();
unsigned other_id = other->get_expr_id();
@ -252,9 +295,23 @@ namespace q {
return new (mem) binding();
}
std::ostream& ematch::lit::display(std::ostream& out) const {
ast_manager& m = lhs.m();
if (m.is_true(rhs) && !sign)
return out << lhs;
if (m.is_false(rhs) && !sign)
return out << "(not " << lhs << ")";
return
out << mk_bounded_pp(lhs, lhs.m(), 2)
<< (sign ? " != " : " == ")
<< mk_bounded_pp(rhs, rhs.m(), 2);
}
void ematch::clause::add_binding(ematch& em, euf::enode* const* _binding) {
unsigned n = m_q->get_num_decls();
unsigned n = num_decls();
binding* b = em.alloc_binding(n);
b->init(b);
for (unsigned i = 0; i < n; ++i)
b->m_nodes[i] = _binding[i];
binding::push_to_front(m_bindings, b);
@ -262,6 +319,7 @@ namespace q {
}
void ematch::on_binding(quantifier* q, app* pat, euf::enode* const* _binding) {
TRACE("q", tout << "on-binding " << mk_pp(q, m) << "\n";);
clause& c = *m_clauses[m_q2clauses[q]];
if (!propagate(_binding, c))
c.add_binding(*this, _binding);
@ -270,14 +328,11 @@ namespace q {
std::ostream& ematch::clause::display(euf::solver& ctx, std::ostream& out) const {
out << "clause:\n";
for (auto const& lit : m_lits)
out << mk_bounded_pp(lit.lhs, lit.lhs.m(), 2)
<< (lit.sign ? " != " : " == ")
<< mk_bounded_pp(lit.rhs, lit.rhs.m(), 2) << "\n";
unsigned num_decls = m_q->get_num_decls();
lit.display(out) << "\n";
binding* b = m_bindings;
if (b) {
do {
for (unsigned i = 0; i < num_decls; ++i)
for (unsigned i = 0; i < num_decls(); ++i)
out << ctx.bpp(b->nodes()[i]) << " ";
out << "\n";
b = b->next();
@ -294,19 +349,22 @@ namespace q {
unsigned idx = UINT_MAX;
unsigned sz = c.m_lits.size();
unsigned n = c.m_q->get_num_decls();
unsigned n = c.num_decls();
m_indirect_nodes.reset();
for (unsigned i = 0; i < sz; ++i) {
unsigned lim = m_indirect_nodes.size();
lit l = c[i];
m_indirect_nodes.reset();
lbool cmp = compare(n, binding, l.lhs, l.rhs);
switch (cmp) {
case l_false:
m_indirect_nodes.shrink(lim);
if (!l.sign)
break;
if (i > 0)
std::swap(c[0], c[i]);
return true;
case l_true:
m_indirect_nodes.shrink(lim);
if (l.sign)
break;
if (i > 0)
@ -319,7 +377,7 @@ namespace q {
// to watch
for (euf::enode* n : m_indirect_nodes)
add_watch(n, clause_idx);
for (unsigned j = c.m_q->get_num_decls(); j-- > 0; )
for (unsigned j = c.num_decls(); j-- > 0; )
add_watch(binding[j], clause_idx);
if (i > 1)
std::swap(c[1], c[i]);
@ -332,7 +390,16 @@ namespace q {
}
}
TRACE("q", tout << "instantiate " << (idx == UINT_MAX ? "clause is false":"unit propagate") << "\n";);
instantiate(binding, c);
#if 1
auto j_idx = mk_justification(idx, c, binding);
if (idx == UINT_MAX)
ctx.set_conflict(j_idx);
else
ctx.propagate(instantiate(c, binding, c[idx]), j_idx);
#else
instantiate(c, binding);
#endif
return true;
}
@ -340,14 +407,33 @@ namespace q {
void ematch::instantiate(euf::enode* const* binding, clause& c) {
expr_ref_vector _binding(m);
quantifier* q = c.m_q;
for (unsigned i = 0; i < q->get_num_decls(); ++i)
for (unsigned i = 0; i < c.num_decls(); ++i)
_binding.push_back(binding[i]->get_expr());
var_subst subst(m);
expr_ref result = subst(q->get_expr(), _binding);
if (is_forall(q))
m_qs.add_clause(~ctx.mk_literal(q), ctx.mk_literal(result));
else
m_qs.add_clause(ctx.mk_literal(q), ~ctx.mk_literal(result));
sat::literal result_l = ctx.mk_literal(result);
if (is_exists(q))
result_l.neg();
m_qs.add_clause(c.m_literal, result_l);
}
sat::literal ematch::instantiate(clause& c, euf::enode* const* binding, lit const& l) {
expr_ref_vector _binding(m);
quantifier* q = c.m_q;
for (unsigned i = 0; i < c.num_decls(); ++i)
_binding.push_back(binding[i]->get_expr());
var_subst subst(m);
if (m.is_true(l.rhs)) {
SASSERT(!l.sign);
return ctx.mk_literal(subst(l.lhs, _binding));
}
else if (m.is_false(l.rhs)) {
SASSERT(!l.sign);
return ~ctx.mk_literal(subst(l.lhs, _binding));
}
expr_ref fml(m.mk_eq(l.lhs, l.rhs), m);
fml = subst(fml, _binding);
return l.sign ? ~ctx.mk_literal(fml) : ctx.mk_literal(fml);
}
lbool ematch::compare(unsigned n, euf::enode* const* binding, expr* s, expr* t) {
@ -357,7 +443,7 @@ namespace q {
if (tn) tn = tn->get_root();
TRACE("q", tout << mk_pp(s, m) << " ~~ " << mk_pp(t, m) << "\n";
tout << ctx.bpp(sn) << " " << ctx.bpp(tn) << "\n";);
lbool c;
if (sn && sn == tn)
return l_true;
@ -367,14 +453,15 @@ namespace q {
return l_undef;
if (!sn && !tn)
return compare_rec(n, binding, s, t);
if (!sn && tn)
for (euf::enode* t1 : euf::enode_class(tn))
if (c = compare_rec(n, binding, s, t1->get_expr()), c != l_undef)
return c;
if (sn && !tn)
for (euf::enode* s1 : euf::enode_class(sn))
if (c = compare_rec(n, binding, t, s1->get_expr()), c != l_undef)
return c;
if (!tn && !sn)
return l_undef;
if (!tn && sn) {
std::swap(tn, sn);
std::swap(t, s);
}
for (euf::enode* t1 : euf::enode_class(tn))
if (c = compare_rec(n, binding, s, t1->get_expr()), c != l_undef)
return c;
return l_undef;
}
@ -480,6 +567,7 @@ namespace q {
return false;
bool propagated = false;
ctx.push(value_trail<euf::solver, unsigned>(m_qhead));
ptr_buffer<binding> to_remove;
for (; m_qhead < m_queue.size(); ++m_qhead) {
unsigned idx = m_queue[m_qhead];
clause& c = *m_clauses[idx];
@ -487,14 +575,17 @@ namespace q {
if (!b)
continue;
do {
binding* next = b->next();
if (propagate(b->m_nodes, c)) {
binding::remove_from(c.m_bindings, b);
ctx.push(insert_binding(c, b));
}
b = next;
if (propagate(b->m_nodes, c))
to_remove.push_back(b);
b = b->next();
}
while (b != c.m_bindings);
for (binding* b : to_remove) {
binding::remove_from(c.m_bindings, b);
ctx.push(insert_binding(c, b));
}
to_remove.reset();
}
m_clause_in_queue.reset();
m_node_in_queue.reset();
@ -504,16 +595,17 @@ namespace q {
/**
* basic clausifier, assumes q has been normalized.
*/
ematch::clause* ematch::clausify(quantifier* q) {
clause* cl = alloc(clause);
ematch::clause* ematch::clausify(quantifier* _q) {
clause* cl = alloc(clause, m);
cl->m_literal = ctx.mk_literal(_q);
quantifier_ref q(_q, m);
if (is_exists(q)) {
cl->m_literal.neg();
expr_ref body(mk_not(m, q->get_expr()), m);
q = m.update_quantifier(q, forall_k, body);
}
expr_ref_vector ors(m);
if (is_forall(q))
flatten_or(q->get_expr(), ors);
else {
flatten_and(q->get_expr(), ors);
for (unsigned i = 0; i < ors.size(); ++i)
ors[i] = mk_not(m, ors.get(i));
}
flatten_or(q->get_expr(), ors);
for (expr* arg : ors) {
bool sign = m.is_not(arg, arg);
expr* l, *r;
@ -524,7 +616,13 @@ namespace q {
}
cl->m_lits.push_back(lit(expr_ref(l, m), expr_ref(r, m), sign));
}
if (q->get_num_patterns() == 0) {
expr_ref tmp(m);
m_infer_patterns(q, tmp);
q = to_quantifier(tmp);
}
cl->m_q = q;
SASSERT(ctx.s().value(cl->m_literal) == l_true);
return cl;
}
@ -591,6 +689,7 @@ namespace q {
}
bool ematch::operator()() {
TRACE("q", m_mam->display(tout););
if (propagate())
return true;
if (m_lazy_mam) {