3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-08 18:31:49 +00:00

update format and checker for implied-eq

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
This commit is contained in:
Nikolaj Bjorner 2023-07-27 13:21:45 -07:00
parent 249f0de80b
commit f0184c3fde
8 changed files with 192 additions and 94 deletions

View file

@ -612,7 +612,8 @@ class lp_bound_propagator {
constraint_index lc, uc;
lp().get_bound_constraint_witnesses_for_column(j, lc, uc);
ex.push_back(lc);
ex.push_back(uc);
if (lc != uc)
ex.push_back(uc);
}
vector<edge> connect_in_tree(const vertex* u, const vertex* v) const {

View file

@ -539,10 +539,12 @@ namespace arith {
if (x->get_root() == y->get_root())
return;
reset_evidence();
set_evidence(ci1);
set_evidence(ci2);
m_explanation.clear();
consume(rational::one(), ci1);
consume(rational::one(), ci2);
++m_stats.m_fixed_eqs;
auto* jst = euf::th_explain::propagate(*this, m_core, m_eqs, x, y);
auto* hint = explain_implied_eq(m_explanation, x, y);
auto* jst = euf::th_explain::propagate(*this, m_core, m_eqs, x, y, hint);
ctx.propagate(x, y, jst->to_index());
}

View file

@ -32,7 +32,7 @@ namespace arith {
}
arith_proof_hint* arith_proof_hint_builder::mk(euf::solver& s) {
return new (s.get_region()) arith_proof_hint(m_ty, m_num_le, m_lit_head, m_lit_tail, m_eq_head, m_eq_tail);
return new (s.get_region()) arith_proof_hint(m_ty, m_lit_head, m_lit_tail, m_eq_head, m_eq_tail);
}
std::ostream& solver::display(std::ostream& out) const {
@ -164,7 +164,6 @@ namespace arith {
return nullptr;
m_arith_hint.set_type(ctx, hint_type::implied_eq_h);
explain_assumptions(e);
m_arith_hint.set_num_le(1); // TODO
m_arith_hint.add_diseq(a, b);
return m_arith_hint.mk(ctx);
}
@ -173,13 +172,19 @@ namespace arith {
if (!ctx.use_drat())
return nullptr;
m_arith_hint.set_type(ctx, hint_type::implied_eq_h);
m_arith_hint.set_num_le(1);
m_arith_hint.add_lit(rational(1), le);
m_arith_hint.add_lit(rational(1), ge);
m_arith_hint.add_lit(rational(1), ~eq);
return m_arith_hint.mk(ctx);
}
/**
* The expected format is:
* 1. all equalities
* 2. all inequalities
* 3. optional disequalities (used for the steps that propagate equalities)
*/
expr* arith_proof_hint::get_hint(euf::solver& s) const {
ast_manager& m = s.get_manager();
family_id fid = m.get_family_id("arith");
@ -200,29 +205,39 @@ namespace arith {
break;
case hint_type::implied_eq_h:
name = "implied-eq";
args.push_back(arith.mk_int(m_num_le));
break;
default:
name = "unknown-arithmetic";
break;
}
rational lc(1);
for (unsigned i = m_lit_head; i < m_lit_tail; ++i)
lc = lcm(lc, denominator(a.m_arith_hint.lit(i).first));
for (unsigned i = m_eq_head; i < m_eq_tail; ++i) {
auto [x, y, is_eq] = a.m_arith_hint.eq(i);
auto push_eq = [&](bool is_eq, enode* x, enode* y) {
if (x->get_id() > y->get_id())
std::swap(x, y);
expr_ref eq(m.mk_eq(x->get_expr(), y->get_expr()), m);
if (!is_eq) eq = m.mk_not(eq);
args.push_back(arith.mk_int(1));
args.push_back(eq);
};
rational lc(1);
for (unsigned i = m_lit_head; i < m_lit_tail; ++i)
lc = lcm(lc, denominator(a.m_arith_hint.lit(i).first));
for (unsigned i = m_eq_head; i < m_eq_tail; ++i) {
auto [x, y, is_eq] = a.m_arith_hint.eq(i);
if (is_eq)
push_eq(is_eq, x, y);
}
for (unsigned i = m_lit_head; i < m_lit_tail; ++i) {
auto const& [coeff, lit] = a.m_arith_hint.lit(i);
args.push_back(arith.mk_int(abs(coeff*lc)));
args.push_back(s.literal2expr(lit));
}
for (unsigned i = m_eq_head; i < m_eq_tail; ++i) {
auto [x, y, is_eq] = a.m_arith_hint.eq(i);
if (!is_eq)
push_eq(is_eq, x, y);
}
return m.mk_app(symbol(name), args.size(), args.data(), m.mk_proof_sort());
}
}

View file

@ -713,10 +713,11 @@ namespace arith {
++m_stats.m_fixed_eqs;
reset_evidence();
set_evidence(ci1);
set_evidence(ci2);
set_evidence(ci3);
set_evidence(ci4);
m_explanation.clear();
consume(rational::one(), ci1);
consume(rational::one(), ci2);
consume(rational::one(), ci3);
consume(rational::one(), ci4);
enode* x = var2enode(v1);
enode* y = var2enode(v2);
auto* ex = explain_implied_eq(m_explanation, x, y);

View file

@ -57,10 +57,9 @@ namespace arith {
struct arith_proof_hint : public euf::th_proof_hint {
hint_type m_ty;
unsigned m_num_le;
unsigned m_lit_head, m_lit_tail, m_eq_head, m_eq_tail;
arith_proof_hint(hint_type t, unsigned num_le, unsigned lh, unsigned lt, unsigned eh, unsigned et):
m_ty(t), m_num_le(num_le), m_lit_head(lh), m_lit_tail(lt), m_eq_head(eh), m_eq_tail(et) {}
arith_proof_hint(hint_type t, unsigned lh, unsigned lt, unsigned eh, unsigned et):
m_ty(t), m_lit_head(lh), m_lit_tail(lt), m_eq_head(eh), m_eq_tail(et) {}
expr* get_hint(euf::solver& s) const override;
};
@ -68,7 +67,6 @@ namespace arith {
vector<std::pair<rational, literal>> m_literals;
svector<std::tuple<euf::enode*,euf::enode*,bool>> m_eqs;
hint_type m_ty;
unsigned m_num_le = 0;
unsigned m_lit_head = 0, m_lit_tail = 0, m_eq_head = 0, m_eq_tail = 0;
void reset() { m_lit_head = m_lit_tail; m_eq_head = m_eq_tail; }
void add(euf::enode* a, euf::enode* b, bool is_eq) {
@ -80,7 +78,6 @@ namespace arith {
}
public:
void set_type(euf::solver& ctx, hint_type ty);
void set_num_le(unsigned n) { m_num_le = n; }
void add_eq(euf::enode* a, euf::enode* b) { add(a, b, true); }
void add_diseq(euf::enode* a, euf::enode* b) { add(a, b, false); }
void add_lit(rational const& coeff, literal lit) {

View file

@ -35,6 +35,15 @@ The module assumes a limited repertoire of arithmetic proof rules.
namespace arith {
class theory_checker : public euf::theory_checker_plugin {
enum rule_type_t {
cut_t,
farkas_t,
implied_eq_t,
bound_t,
none_t
};
struct row {
obj_map<expr, rational> m_coeffs;
rational m_coeff;
@ -42,6 +51,9 @@ namespace arith {
m_coeffs.reset();
m_coeff = 0;
}
bool is_zero() const {
return m_coeffs.empty() && m_coeff == 0;
}
};
ast_manager& m;
@ -50,10 +62,24 @@ namespace arith {
bool m_strict = false;
row m_ineq;
row m_conseq;
vector<row> m_eqs;
symbol m_farkas;
symbol m_implied_eq;
symbol m_bound;
vector<row> m_eqs, m_ineqs;
symbol m_farkas = symbol("farkas");
symbol m_implied_eq = symbol("implied-eq");
symbol m_bound = symbol("bound");
symbol m_cut = symbol("cut");
rule_type_t rule_type(app* jst) const {
if (jst->get_name() == m_cut)
return cut_t;
if (jst->get_name() == m_bound)
return bound_t;
if (jst->get_name() == m_implied_eq)
return implied_eq_t;
if (jst->get_name() == m_farkas)
return farkas_t;
return none_t;
}
void add(row& r, expr* v, rational const& coeff) {
rational coeff1;
@ -90,10 +116,10 @@ namespace arith {
// X = lcm(a,b)/b, Y = -lcm(a,b)/a if v is integer
// X = 1/b, Y = -1/a if v is real
//
void resolve(expr* v, row& dst, rational const& A, row const& src) {
bool resolve(expr* v, row& dst, rational const& A, row const& src) {
rational B, x, y;
if (!dst.m_coeffs.find(v, B))
return;
return false;
if (a.is_int(v)) {
rational lc = lcm(abs(A), abs(B));
x = lc / abs(B);
@ -109,6 +135,7 @@ namespace arith {
y.neg();
mul(dst, x);
add(dst, src, y);
return true;
}
void cut(row& r) {
@ -197,6 +224,8 @@ namespace arith {
resolve(v, m_eqs[j], coeff, r);
resolve(v, m_ineq, coeff, r);
resolve(v, m_conseq, coeff, r);
for (auto& ineq : m_ineqs)
resolve(v, ineq, coeff, r);
}
return true;
}
@ -269,6 +298,81 @@ namespace arith {
return false;
}
/**
Check implied equality lemma:
inequalities & equalities => equality
We may assume the set of inequality assumptions we are given are all tight, non-strict and imply equalities.
In other words, given a set of inequalities a1x + b1 <= 0, ..., anx + bn <= 0
the equalities a1x + b1 = 0, ..., anx + bn = 0 are all consequences.
We use a weaker property: We derive implied equalities by applying exhaustive Fourier-Motzkin
elimination and then collect the tight 0 <= 0 inequalities that are derived.
Claim: the set of inequalities used to derive 0 <= 0 are all tight equalities.
*/
svector<std::pair<unsigned, unsigned>> m_deps;
unsigned_vector m_tight_inequalities;
uint_set m_ineqs_that_are_eqs;
bool check_implied_eq() {
if (!reduce_eq())
return true;
if (m_conseq.is_zero())
return true;
m_eqs.reset();
m_deps.reset();
unsigned orig_size = m_ineqs.size();
m_deps.reserve(orig_size);
for (unsigned i = 0; i < m_ineqs.size(); ++i) {
row& r = m_ineqs[i];
if (r.is_zero()) {
m_tight_inequalities.push_back(i);
continue;
}
auto const& [v, coeff] = *r.m_coeffs.begin();
unsigned sz = m_ineqs.size();
for (unsigned j = i + 1; j < sz; ++j) {
rational B;
row& r2 = m_ineqs[j];
if (!r2.m_coeffs.find(v, B) || (coeff > 0 && B > 0) || (coeff < 0 && B < 0))
continue;
row& r3 = fresh(m_ineqs);
add(r3, m_ineqs[j], rational::one());
resolve(v, r3, coeff, m_ineqs[i]);
m_deps.push_back({i, j});
}
SASSERT(m_deps.size() == m_ineqs.size());
}
m_ineqs_that_are_eqs.reset();
while (!m_tight_inequalities.empty()) {
unsigned j = m_tight_inequalities.back();
m_tight_inequalities.pop_back();
if (m_ineqs_that_are_eqs.contains(j))
continue;
m_ineqs_that_are_eqs.insert(j);
if (j < orig_size) {
m_eqs.push_back(m_ineqs[j]);
}
else {
auto [a, b] = m_deps[j];
m_tight_inequalities.push_back(a);
m_tight_inequalities.push_back(b);
}
}
m_ineqs.reset();
VERIFY (reduce_eq());
return m_conseq.is_zero();
}
std::ostream& display_row(std::ostream& out, row const& r) {
bool first = true;
for (auto const& [v, coeff] : r.m_coeffs) {
@ -306,22 +410,21 @@ namespace arith {
public:
theory_checker(ast_manager& m):
m(m),
a(m),
m_farkas("farkas"),
m_implied_eq("implied-eq"),
m_bound("bound") {}
a(m) {}
void reset() {
m_ineq.reset();
m_conseq.reset();
m_eqs.reset();
m_ineqs.reset();
m_strict = false;
}
bool add_ineq(rational const& coeff, expr* e, bool sign) {
return add_literal(m_ineq, abs(coeff), e, sign);
bool add_ineq(rule_type_t rt, rational const& coeff, expr* e, bool sign) {
row& r = rt == implied_eq_t ? fresh(m_ineqs) : m_ineq;
return add_literal(r, abs(coeff), e, sign);
}
bool add_conseq(rational const& coeff, expr* e, bool sign) {
return add_literal(m_conseq, abs(coeff), e, sign);
}
@ -332,11 +435,17 @@ namespace arith {
linearize(r, rational(-1), b);
}
bool check() {
if (m_conseq.m_coeffs.empty())
bool check(rule_type_t rt) {
switch (rt) {
case farkas_t:
return check_farkas();
else
case bound_t:
return check_bound();
case implied_eq_t:
return check_implied_eq();
default:
return check_bound();
}
}
std::ostream& display(std::ostream& out) {
@ -359,7 +468,7 @@ namespace arith {
/**
Add implied equality as an inequality
*/
bool add_implied_ineq(bool sign, app* jst) {
bool add_implied_diseq(bool sign, app* jst) {
unsigned n = jst->get_num_args();
if (n < 2)
return false;
@ -374,90 +483,57 @@ namespace arith {
return false;
if (!sign)
coeff.neg();
auto& r = m_ineq;
auto& r = m_conseq;
linearize(r, coeff, arg1);
linearize(r, -coeff, arg2);
m_strict = true;
return true;
}
bool check(app* jst) override {
reset();
bool is_bound = jst->get_name() == m_bound;
bool is_implied_eq = jst->get_name() == m_implied_eq;
bool is_farkas = jst->get_name() == m_farkas;
if (!is_farkas && !is_bound && !is_implied_eq) {
auto rt = rule_type(jst);
switch (rt) {
case cut_t:
return false;
case none_t:
IF_VERBOSE(0, verbose_stream() << "unhandled inference " << mk_pp(jst, m) << "\n");
return false;
default:
break;
}
bool even = true;
rational coeff;
expr* x, * y;
unsigned j = 0, num_le = 0;
unsigned j = 0;
for (expr* arg : *jst) {
if (even) {
if (!a.is_numeral(arg, coeff)) {
IF_VERBOSE(0, verbose_stream() << "not numeral " << mk_pp(jst, m) << "\n");
return false;
}
if (is_implied_eq) {
is_implied_eq = false;
if (!coeff.is_unsigned()) {
IF_VERBOSE(0, verbose_stream() << "not unsigned " << mk_pp(jst, m) << "\n");
return false;
}
num_le = coeff.get_unsigned();
if (!add_implied_ineq(false, jst)) {
IF_VERBOSE(0, display(verbose_stream() << "did not add implied eq"));
return false;
}
++j;
continue;
}
}
else {
bool sign = m.is_not(arg, arg);
if (a.is_le(arg) || a.is_lt(arg) || a.is_ge(arg) || a.is_gt(arg)) {
if (is_bound && j + 1 == jst->get_num_args())
if (rt == bound_t && j + 1 == jst->get_num_args())
add_conseq(coeff, arg, sign);
else if (num_le > 0) {
add_ineq(coeff, arg, sign);
--num_le;
if (num_le == 0) {
// we processed all the first inequalities,
// check that they imply one half of the implied equality.
if (!check()) {
// we might have added the wrong direction of the implied equality.
// so try the opposite inequality.
add_implied_ineq(true, jst);
add_implied_ineq(true, jst);
if (check()) {
reset();
add_implied_ineq(false, jst);
}
else {
IF_VERBOSE(0, display(verbose_stream() << "failed to check implied eq "));
return false;
}
}
else {
reset();
VERIFY(add_implied_ineq(true, jst));
}
}
}
else
add_ineq(coeff, arg, sign);
add_ineq(rt, coeff, arg, sign);
}
else if (m.is_eq(arg, x, y)) {
if (is_bound && j + 1 == jst->get_num_args())
if (rt == bound_t && j + 1 == jst->get_num_args())
add_conseq(coeff, arg, sign);
else if (sign)
return check(); // it should be an implied equality
else
else if (rt == implied_eq_t && j + 1 == jst->get_num_args())
return add_implied_diseq(sign, jst) && check(rt);
else if (!sign)
add_eq(x, y);
else {
IF_VERBOSE(0, verbose_stream() << "unexpected disequality in justification " << mk_pp(arg, m) << "\n");
return false;
}
}
else {
IF_VERBOSE(0, verbose_stream() << "not a recognized arithmetical relation " << mk_pp(arg, m) << "\n");
@ -467,13 +543,14 @@ namespace arith {
even = !even;
++j;
}
return check();
return check(rt);
}
void register_plugins(euf::theory_checker& pc) override {
pc.register_plugin(m_farkas, this);
pc.register_plugin(m_bound, this);
pc.register_plugin(m_implied_eq, this);
pc.register_plugin(m_cut, this);
}
};

View file

@ -465,6 +465,8 @@ namespace euf {
void solver::display_inferred(std::ostream& out, unsigned n, literal const* lits, expr* proof_hint) {
expr_ref hint(proof_hint, m);
if (!proof_hint)
verbose_stream() << hint << "\n";
if (!hint)
hint = m.mk_const(m_smt, m.mk_proof_sort());
visit_expr(out, hint);

View file

@ -240,6 +240,9 @@ namespace euf {
m_literals[i] = lits[i];
base_ptr += sizeof(literal) * n_lits;
m_eqs = reinterpret_cast<enode_pair*>(base_ptr);
if (!pma) {
verbose_stream() << "null\n";
}
for (i = 0; i < n_eqs; ++i) {
m_eqs[i] = eqs[i];
if (m_eqs[i].first->get_id() > m_eqs[i].second->get_id())