3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-23 00:55:31 +00:00

re-adding saturation for inequalities

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
This commit is contained in:
Nikolaj Bjorner 2021-09-07 23:20:17 +02:00
parent e6e5621366
commit d8f0926620
6 changed files with 343 additions and 238 deletions

View file

@ -867,6 +867,12 @@ namespace dd {
e->m_rest = rest.root;
}
bool pdd_manager::factor(pdd const& p, unsigned v, unsigned degree, pdd& lc) {
pdd rest = lc;
factor(p, v, degree, lc, rest);
return rest.is_zero();
}
/**
* Apply function f to all coefficients of the polynomial.
* The function should be of type

View file

@ -271,6 +271,7 @@ namespace dd {
template <class Fn> pdd map_coefficients(pdd const& p, Fn f);
void factor(pdd const& p, unsigned v, unsigned degree, pdd& lc, pdd& rest);
bool factor(pdd const& p, unsigned v, unsigned degree, pdd& lc);
bool var_is_leaf(PDD p, unsigned v);
@ -414,6 +415,7 @@ namespace dd {
pdd reduce(pdd const& other) const { return m.reduce(*this, other); }
bool different_leading_term(pdd const& other) const { return m.different_leading_term(*this, other); }
void factor(unsigned v, unsigned degree, pdd& lc, pdd& rest) const { m.factor(*this, v, degree, lc, rest); }
bool factor(unsigned v, unsigned degree, pdd& lc) const { return m.factor(*this, v, degree, lc); }
bool resolve(unsigned v, pdd const& other, pdd& result) { return m.resolve(v, *this, other, result); }
pdd subst_val(vector<std::pair<unsigned, rational>> const& s) const { return m.subst_val(*this, s); }
@ -437,6 +439,8 @@ namespace dd {
pdd_iterator begin() const;
pdd_iterator end() const;
pdd_manager& manager() const { return m; }
};
inline pdd operator*(rational const& r, pdd const& b) { return b * r; }

View file

@ -176,127 +176,8 @@ namespace polysat {
// return lemma;
// }
// /// [x] zx > yx ==> Ω*(x,y) \/ z > y
// /// [x] yx <= zx ==> Ω*(x,y) \/ y <= z
// clause_ref conflict_explainer::by_ugt_x() {
// LOG_H3("Try zx > yx where x := v" << m_var);
// pdd const x = m_solver.var(m_var);
// unsigned const sz = m_solver.size(m_var);
// pdd const zero = m_solver.sz2pdd(sz).zero();
// // Find constraint of shape: yx <= zx
// for (auto* c1 : m_conflict.units()) {
// auto c = c1->as_inequality();
// if (c.lhs.degree(m_var) != 1)
// continue;
// if (c.rhs.degree(m_var) != 1)
// continue;
// pdd y = zero;
// pdd rest = zero;
// c.lhs.factor(m_var, 1, y, rest);
// if (!rest.is_zero())
// continue;
// pdd z = zero;
// c.rhs.factor(m_var, 1, z, rest);
// if (!rest.is_zero())
// continue;
// unsigned const lvl = c.src->level();
// clause_builder clause(m_solver);
// clause.push_literal(~c.src->blit());
// // Omega^*(x, y)
// if (!push_omega_mul(clause, lvl, sz, x, y))
// continue;
// constraint_literal y_le_z;
// if (c.is_strict)
// y_le_z = m_solver.m_constraints.ult(lvl, y, z);
// else
// y_le_z = m_solver.m_constraints.ule(lvl, y, z);
// LOG("z>y: " << show_deref(y_le_z));
// clause.push_new_constraint(std::move(y_le_z));
// return clause.build();
// }
// return nullptr;
// }
// /// [y] z' <= y /\ zx > yx ==> Ω*(x,y) \/ zx > z'x
// /// [y] z' <= y /\ yx <= zx ==> Ω*(x,y) \/ z'x <= zx
// clause_ref conflict_explainer::by_ugt_y() {
// LOG_H3("Try z' <= y && zx > yx where y := v" << m_var);
// pdd const y = m_solver.var(m_var);
// unsigned const sz = m_solver.size(m_var);
// pdd const zero = m_solver.sz2pdd(sz).zero();
// // Collect constraints of shape "_ <= y"
// vector<inequality> ds;
// for (auto* d1 : m_conflict.units()) {
// auto d = d1->as_inequality();
// // TODO: a*y where 'a' divides 'x' should also be easy to handle (assuming for now they're numbers)
// // TODO: also z' < y should follow the same pattern.
// if (d.rhs != y)
// continue;
// LOG("z' <= y candidate: " << show_deref(d.src));
// ds.push_back(std::move(d));
// }
// if (ds.empty())
// return nullptr;
// // Find constraint of shape: yx <= zx
// for (auto* c1 : m_conflict.units()) {
// auto c = c1->as_inequality();
// if (c.lhs.degree(m_var) != 1)
// continue;
// pdd x = zero;
// pdd rest = zero;
// c.lhs.factor(m_var, 1, x, rest);
// if (!rest.is_zero())
// continue;
// // TODO: in principle, 'x' could be any polynomial. However, we need to divide the lhs by x, and we don't have general polynomial division yet.
// // so for now we just allow the form 'value*variable'.
// // (extension to arbitrary monomials for 'x' should be fairly easy too)
// if (!x.is_unary())
// continue;
// unsigned x_var = x.var();
// rational x_coeff = x.hi().val();
// pdd xz = zero;
// if (!c.rhs.try_div(x_coeff, xz))
// continue;
// pdd z = zero;
// xz.factor(x_var, 1, z, rest);
// if (!rest.is_zero())
// continue;
// LOG("zx > yx: " << show_deref(c.src));
// // TODO: for now, we just try all ds
// for (auto const& d : ds) {
// unsigned const lvl = std::max(c.src->level(), d.src->level());
// pdd const& z_prime = d.lhs;
// clause_builder clause(m_solver);
// clause.push_literal(~c.src->blit());
// clause.push_literal(~d.src->blit());
// // Omega^*(x, y)
// if (!push_omega_mul(clause, lvl, sz, x, y))
// continue;
// // z'x <= zx
// constraint_literal zpx_le_zx;
// if (c.is_strict || d.is_strict)
// zpx_le_zx = m_solver.m_constraints.ult(lvl, z_prime*x, z*x);
// else
// zpx_le_zx = m_solver.m_constraints.ule(lvl, z_prime*x, z*x);
// LOG("zx>z'x: " << show_deref(zpx_le_zx));
// clause.push_new_constraint(std::move(zpx_le_zx));
// return clause.build();
// }
// }
// return nullptr;
// }
// /// [z] z <= y' /\ zx > yx ==> Ω*(x,y') \/ y'x > yx
// /// [z] z <= y' /\ yx <= zx ==> Ω*(x,y') \/ yx <= y'x
@ -374,120 +255,5 @@ namespace polysat {
// return nullptr;
// }
// /// [x] y <= ax /\ x <= z (non-overflow case)
// /// ==> Ω*(a, z) \/ y <= az
// clause_ref conflict_explainer::y_ule_ax_and_x_ule_z() {
// LOG_H3("Try y <= ax && x <= z where x := v" << m_var);
// pdd const x = m_solver.var(m_var);
// unsigned const sz = m_solver.size(m_var);
// pdd const zero = m_solver.sz2pdd(sz).zero();
// // Collect constraints of shape "x <= _"
// vector<inequality> ds;
// for (auto* d1 : m_conflict.units()) {
// inequality d = d1->as_inequality();
// if (d.lhs != x)
// continue;
// LOG("x <= z' candidate: " << show_deref(d.src));
// ds.push_back(std::move(d));
// }
// if (ds.empty())
// return nullptr;
// // Find constraint of shape: y <= ax
// for (auto* c1 : m_conflict.units()) {
// inequality c = c1->as_inequality();
// if (c.rhs.degree(m_var) != 1)
// continue;
// pdd a = zero;
// pdd rest = zero;
// c.rhs.factor(m_var, 1, a, rest);
// if (!rest.is_zero())
// continue;
// pdd const& y = c.lhs;
// LOG("y <= ax: " << show_deref(c1));
// // TODO: for now, we just try all of the other constraints in order
// for (auto const& d : ds) {
// unsigned const lvl = std::max(c1->level(), d.src->level());
// pdd const& z = d.rhs;
// clause_builder clause(m_solver);
// clause.push_literal(~c.src->blit());
// clause.push_literal(~d.src->blit());
// // Omega^*(a, z)
// if (!push_omega_mul(clause, lvl, sz, a, z))
// continue;
// // y'x > yx
// constraint_literal y_ule_az;
// if (c.is_strict || d.is_strict)
// y_ule_az = m_solver.m_constraints.ult(lvl, y, a*z);
// else
// y_ule_az = m_solver.m_constraints.ule(lvl, y, a*z);
// LOG("y<=az: " << show_deref(y_ule_az));
// clause.push_new_constraint(std::move(y_ule_az));
// return clause.build();
// }
// }
// return nullptr;
// }
// /// Add Ω*(x, y) to the clause.
// ///
// /// @param[in] p bit width
// bool conflict_explainer::push_omega_mul(clause_builder& clause, unsigned level, unsigned p, pdd const& x, pdd const& y) {
// LOG_H3("Omega^*(x, y)");
// LOG("x = " << x);
// LOG("y = " << y);
// auto& pddm = m_solver.sz2pdd(p);
// unsigned min_k = 0;
// unsigned max_k = p - 1;
// unsigned k = p/2;
// rational x_val;
// if (m_solver.try_eval(x, x_val)) {
// unsigned x_bits = x_val.bitsize();
// LOG("eval x: " << x << " := " << x_val << " (x_bits: " << x_bits << ")");
// SASSERT(x_val < rational::power_of_two(x_bits));
// min_k = x_bits;
// }
// rational y_val;
// if (m_solver.try_eval(y, y_val)) {
// unsigned y_bits = y_val.bitsize();
// LOG("eval y: " << y << " := " << y_val << " (y_bits: " << y_bits << ")");
// SASSERT(y_val < rational::power_of_two(y_bits));
// max_k = p - y_bits;
// }
// if (min_k > max_k) {
// // In this case, we cannot choose k such that both literals are false.
// // This means x*y overflows in the current model and the chosen rule is not applicable.
// // (or maybe we are in the case where we need the msb-encoding for overflow).
// return false;
// }
// SASSERT(min_k <= max_k); // if this assertion fails, we cannot choose k s.t. both literals are false
// // TODO: could also choose other value for k (but between the bounds)
// if (min_k == 0)
// k = max_k;
// else
// k = min_k;
// LOG("k = " << k << "; min_k = " << min_k << "; max_k = " << max_k << "; p = " << p);
// SASSERT(min_k <= k && k <= max_k);
// // x >= 2^k
// auto c1 = m_solver.m_constraints.ule(level, pddm.mk_val(rational::power_of_two(k)), x);
// // y > 2^{p-k}
// auto c2 = m_solver.m_constraints.ult(level, pddm.mk_val(rational::power_of_two(p-k)), y);
// clause.push_new_constraint(std::move(c1));
// clause.push_new_constraint(std::move(c2));
// return true;
// }
}

View file

@ -1,4 +1,4 @@
/*++
/*++
Copyright (c) 2021 Microsoft Corporation
Module Name:
@ -10,6 +10,15 @@ Author:
Nikolaj Bjorner (nbjorner) 2021-03-19
Jakob Rath 2021-04-6
TODO:
- currently saturation just removes premise or premises and adds new clauses
- this needs to be fixed as follows:
- per calculus it really adds a propagation to the stack
- then it adds the propagated literal to the core and removes the premise that needed to be simplified from core.
TODO:
- remove level information from created constraints.
-
--*/
#include "math/polysat/saturation.h"
#include "math/polysat/solver.h"
@ -73,4 +82,293 @@ namespace polysat {
return false;
}
bool inf_saturate::perform(pvar v, conflict_core& core) {
for (auto c1 : core) {
auto c = c1.as_inequality();
if (try_ugt_x(v, core, c))
return true;
if (try_ugt_y(v, core, c))
return true;
if (try_ugt_z(v, core, c))
return true;
if (try_y_l_ax_and_x_l_z(v, core, c))
return true;
}
return false;
}
/**
* Implement the inferences
* [x] zx > yx ==> Ω*(x,y) \/ z > y
* [x] yx <= zx ==> Ω*(x,y) \/ y <= z
*/
bool inf_saturate::try_ugt_x(pvar v, conflict_core& core, inequality const& c) {
LOG_H3("Try zx > yx where x := v" << v);
if (c.lhs.degree(v) != 1)
return false;
if (c.rhs.degree(v) != 1)
return false;
pdd const x = s().var(v);
pdd y = x;
if (!c.lhs.factor(v, 1, y))
return false;
pdd z = x;
if (!c.rhs.factor(v, 1, z))
return false;
unsigned const lvl = c.src->level();
// Omega^*(x, y)
if (!push_omega_mul(core, lvl, x, y))
return false;
push_l(core, lvl, c.is_strict, y, z);
// TODO
// requires signed constraint: core.remove(*c.src);
return true;
}
void inf_saturate::push_l(conflict_core& core, unsigned lvl, bool is_strict, pdd const& lhs, pdd const& rhs) {
if (is_strict)
core.insert(s().m_constraints.ult(lvl, lhs, rhs));
else
core.insert(s().m_constraints.ule(lvl, lhs, rhs));
}
/// Add Ω*(x, y) to the conflict state.
///
/// @param[in] p bit width
bool inf_saturate::push_omega_mul(conflict_core& core, unsigned level, pdd const& x, pdd const& y) {
LOG_H3("Omega^*(x, y)");
LOG("x = " << x);
LOG("y = " << y);
auto& pddm = x.manager();
unsigned p = pddm.power_of_2();
unsigned min_k = 0;
unsigned max_k = p - 1;
unsigned k = p / 2;
rational x_val;
if (s().try_eval(x, x_val)) {
unsigned x_bits = x_val.bitsize();
LOG("eval x: " << x << " := " << x_val << " (x_bits: " << x_bits << ")");
SASSERT(x_val < rational::power_of_two(x_bits));
min_k = x_bits;
}
rational y_val;
if (s().try_eval(y, y_val)) {
unsigned y_bits = y_val.bitsize();
LOG("eval y: " << y << " := " << y_val << " (y_bits: " << y_bits << ")");
SASSERT(y_val < rational::power_of_two(y_bits));
max_k = p - y_bits;
}
if (min_k > max_k) {
// In this case, we cannot choose k such that both literals are false.
// This means x*y overflows in the current model and the chosen rule is not applicable.
// (or maybe we are in the case where we need the msb-encoding for overflow).
return false;
}
SASSERT(min_k <= max_k); // if this assertion fails, we cannot choose k s.t. both literals are false
// TODO: could also choose other value for k (but between the bounds)
if (min_k == 0)
k = max_k;
else
k = min_k;
LOG("k = " << k << "; min_k = " << min_k << "; max_k = " << max_k << "; p = " << p);
SASSERT(min_k <= k && k <= max_k);
// x >= 2^k
auto c1 = s().m_constraints.ule(level, pddm.mk_val(rational::power_of_two(k)), x);
// y > 2^{p-k}
auto c2 = s().m_constraints.ult(level, pddm.mk_val(rational::power_of_two(p - k)), y);
core.insert(~c1);
core.insert(~c2);
return true;
}
/*
* Match [v] .. <= v
*/
bool inf_saturate::is_l_v(pvar v, inequality const& i) {
return i.rhs == s().var(v);
}
/*
* Match [v] v <= ...
*/
bool inf_saturate::is_g_v(pvar v, inequality const& i) {
return i.lhs == s().var(v);
}
/*
* Match [x] y <= a*x
*/
bool inf_saturate::is_y_l_ax(pvar x, inequality const& d, pdd& a, pdd& y) {
y = d.lhs;
return d.rhs.degree(x) == 1 && d.rhs.factor(x, 1, a);
}
/**
* Match [v] v*x <= z*x
*/
bool inf_saturate::is_Xy_l_XZ(pvar v, inequality const& c, pdd& x, pdd& z) {
if (c.lhs.degree(v) != 1)
return false;
if (!c.lhs.factor(v, 1, x))
return false;
// TODO: in principle, 'x' could be any polynomial. However, we need to divide the lhs by x, and we don't have general polynomial division yet.
// so for now we just allow the form 'value*variable'.
// (extension to arbitrary monomials for 'x' should be fairly easy too)
if (!x.is_unary())
return false;
unsigned x_var = x.var();
rational x_coeff = x.hi().val();
pdd xz = x;
if (!c.rhs.try_div(x_coeff, xz))
return false;
if (!xz.factor(x_var, 1, z))
return false;
LOG("zx > yx: " << show_deref(c.src));
return true;
}
/**
* Match [z] yx <= zx
*/
bool inf_saturate::is_YX_l_zX(pvar z, inequality const& c, pdd& x, pdd& y) {
if (c.rhs.degree(z) != 1)
return false;
if (!c.rhs.factor(z, 1, x))
return false;
// TODO: in principle, 'x' could be any polynomial. However, we need to divide the lhs by x, and we don't have general polynomial division yet.
// so for now we just allow the form 'value*variable'.
// (extension to arbitrary monomials for 'x' should be fairly easy too)
if (!x.is_unary())
return false;
unsigned x_var = x.var();
rational x_coeff = x.hi().val();
pdd xy = x;
return c.lhs.try_div(x_coeff, xy) && xy.factor(x_var, 1, y);
}
/// [y] z' <= y /\ zx > yx ==> Ω*(x,y) \/ zx > z'x
/// [y] z' <= y /\ yx <= zx ==> Ω*(x,y) \/ z'x <= zx
bool inf_saturate::try_ugt_y(pvar v, conflict_core& core, inequality const& le_y, inequality const& yx_l_zx, pdd const& x, pdd const& z) {
LOG_H3("Try z' <= y && zx > yx where y := v" << v);
pdd const y = s().var(v);
SASSERT(is_l_v(v, le_y));
// SASSERT(is_yx_l_zx(v, yx_l_zx, x, z));
unsigned const lvl = std::max(yx_l_zx.src->level(), le_y.src->level());
pdd const& z_prime = le_y.lhs;
// Omega^*(x, y)
if (!push_omega_mul(core, lvl, x, y))
return false;
// z'x <= zx
push_l(core, lvl, yx_l_zx.is_strict || le_y.is_strict, z_prime * x, z * x);
// TODO core.remove(*le_y.src);
// core.remove(*yx_l_zs.src);
return true;
}
bool inf_saturate::try_ugt_y(pvar v, conflict_core& core, inequality const& c) {
if (!is_l_v(v, c))
return false;
pdd x = s().var(v);
pdd z = x;
for (auto dd : core) {
auto d = dd.as_inequality();
if (is_Xy_l_XZ(v, d, x, z) && try_ugt_y(v, core, c, d, x, z))
return true;
}
return false;
}
/// [x] y <= ax /\ x <= z (non-overflow case)
/// ==> Ω*(a, z) \/ y <= az
bool inf_saturate::try_y_l_ax_and_x_l_z(pvar x, conflict_core& core, inequality const& c) {
if (!is_g_v(x, c))
return false;
pdd y = s().var(x);
pdd a = y;
for (auto dd : core) {
auto d = dd.as_inequality();
if (is_y_l_ax(x, d, a, y) && try_y_l_ax_and_x_l_z(x, core, c, d, a, y))
return true;
}
return false;
}
bool inf_saturate::try_y_l_ax_and_x_l_z(pvar x, conflict_core& core, inequality const& x_l_z, inequality const& y_l_ax, pdd const& a, pdd const& y) {
SASSERT(is_g_v(x, x_l_z));
// SASSERT(is_y_l_ax(x, y_l_ax, a, y));
LOG_H3("Try y <= ax && x <= z where x := v" << x);
pdd z = x_l_z.rhs;
unsigned const lvl = std::max(x_l_z.src->level(), y_l_ax.src->level());
if (!push_omega_mul(core, lvl, a, z))
return false;
push_l(core, lvl, x_l_z.is_strict || y_l_ax.is_strict, y, a * z);
// core.remove(*x_l_z.src);
// core.remove(*y_l_ax.src);
//
// TBD justify all propagations into the core with the corresponding lemma
//
return true;
}
/// [z] z <= y' /\ zx > yx ==> Ω*(x,y') \/ y'x > yx
/// [z] z <= y' /\ yx <= zx ==> Ω*(x,y') \/ yx <= y'x
bool inf_saturate::try_ugt_z(pvar z, conflict_core& core, inequality const& c) {
if (!is_g_v(z, c))
return false;
pdd y = s().var(z);
pdd x = y;
for (auto dd : core) {
auto d = dd.as_inequality();
if (is_YX_l_zX(z, d, x, y) && try_ugt_z(z, core, c, d, x, y))
return true;
}
return false;
}
bool inf_saturate::try_ugt_z(pvar z, conflict_core& core, inequality const& c, inequality const& d, pdd const& x, pdd const& y) {
LOG_H3("Try z <= y' && zx > yx where z := v" << z);
SASSERT(is_g_v(z, c));
// SASSERT(is_YX_l_zX(x, d, x, y));
unsigned const lvl = std::max(c.src->level(), d.src->level());
pdd const& y_prime = c.rhs;
// Omega^*(x, y')
if (!push_omega_mul(core, lvl, x, y_prime))
return false;
// yx <= y'x
push_l(core, lvl, c.is_strict || d.is_strict, y * x, y_prime * x);
return true;
}
}

View file

@ -40,11 +40,41 @@ namespace polysat {
bool perform(pvar v, conflict_core& core) override;
};
class inf_saturate : public inference_engine {
bool push_omega_mul(conflict_core& core, unsigned level, pdd const& x, pdd const& y);
void push_l(conflict_core& core, unsigned level, bool strict, pdd const& lhs, pdd const& rhs);
bool try_ugt_x(pvar v, conflict_core& core, inequality const& c);
bool try_ugt_y(pvar v, conflict_core& core, inequality const& c);
bool try_ugt_y(pvar v, conflict_core& core, inequality const& l_y, inequality const& yx_l_zx, pdd const& x, pdd const& z);
bool try_y_l_ax_and_x_l_z(pvar x, conflict_core& core, inequality const& c);
bool try_y_l_ax_and_x_l_z(pvar x, conflict_core& core, inequality const& x_l_z, inequality const& y_l_ax, pdd const& a, pdd const& y);
bool try_ugt_z(pvar z, conflict_core& core, inequality const& c);
bool try_ugt_z(pvar z, conflict_core& core, inequality const& x_l_z0, inequality const& yz_l_xz, pdd const& y, pdd const& x);
// c := lhs ~ v
// where ~ is < or <=
bool is_l_v(pvar v, inequality const& c);
// c := v ~ rhs
bool is_g_v(pvar v, inequality const& c);
// c := X*y ~ X*Z
bool is_Xy_l_XZ(pvar y, inequality const& c, pdd& x, pdd& z);
// c := Y ~ Ax
bool is_y_l_ax(pvar x, inequality const& d, pdd& a, pdd& y);
// c := Y*X ~ z*X
bool is_YX_l_zX(pvar z, inequality const& c, pdd& x, pdd& y);
public:
bool perform(pvar v, conflict_core& core) override;
};
// TODO: other rules
// clause_ref by_ugt_x();
// clause_ref by_ugt_y();
// clause_ref by_ugt_z();
// clause_ref y_ule_ax_and_x_ule_z();
/*
* TODO: we could resolve constraints in cjust[v] against each other to

View file

@ -63,6 +63,7 @@ namespace polysat {
friend class viable;
friend class assignment_pp;
friend class assignments_pp;
friend class inf_saturate;
typedef ptr_vector<constraint> constraints;
typedef vector<signed_constraint> signed_constraints;