3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-07-25 21:57:00 +00:00

Moved "easy part" of variable elimination to saturation.cpp

This commit is contained in:
Clemens Eisenhofer 2022-12-28 15:07:03 +01:00
parent b4f5225ab3
commit 658877365c
3 changed files with 110 additions and 29 deletions

View file

@ -837,7 +837,7 @@ namespace polysat {
* *
* odd(x) & even(y) => x + y != 0 * odd(x) & even(y) => x + y != 0
* *
* Special case rule: a*x + y = 0 => (odd(b) <=> odd(a) & odd(x)) * Special case rule: a*x + y = 0 => (odd(y) <=> odd(a) & odd(x))
* *
* General rule: * General rule:
* *
@ -845,10 +845,10 @@ namespace polysat {
* *
* using inequalities: * using inequalities:
* *
* parity(x) <= i, parity(a) <= j => parity(b) <= i + j * parity(x) <= i, parity(a) <= j => parity(y) <= i + j
* parity(x) >= i, parity(a) >= j => parity(b) >= i + j * parity(x) >= i, parity(a) >= j => parity(y) >= i + j
* parity(x) <= i, parity(b) >= j => parity(a) >= j - i * parity(x) <= i, parity(y) >= j => parity(a) >= j - i
* parity(x) >= i, parity(b) <= j => parity(a) <= j - i * parity(x) >= i, parity(y) <= j => parity(a) <= j - i
* symmetric rules for swapping x, a * symmetric rules for swapping x, a
* *
* min_parity(x) = number of trailing bits of x if x is a value * min_parity(x) = number of trailing bits of x if x is a value
@ -869,13 +869,16 @@ namespace polysat {
unsigned N = m.power_of_2(); unsigned N = m.power_of_2();
if (s.try_eval(p, val)) if (s.try_eval(p, val))
return val == 0 ? N : val.trailing_zeros(); return val == 0 ? N : val.trailing_zeros();
#if 0 if (!p.is_var() && p.is_monomial()) {
// TBD: factor p // it's just a product => sum them up
auto coeff = p.leading_coefficient(); dd::pdd_monomial monomial = *p.begin();
unsigned offset = coeff.trailing_zeros(); unsigned parity_sum = monomial.coeff.trailing_zeros();
verbose_stream() << "COEFF " << coeff << "\n"; for (pvar c : monomial.vars)
#endif parity_sum += min_parity(m.mk_var(c));
return std::min(N, parity_sum);
}
for (unsigned j = N; j > 0; --j) for (unsigned j = N; j > 0; --j)
if (is_forced_true(s.parity(p, j))) if (is_forced_true(s.parity(p, j)))
return j; return j;
@ -889,7 +892,14 @@ namespace polysat {
if (s.try_eval(p, val)) if (s.try_eval(p, val))
return val == 0 ? N : val.trailing_zeros(); return val == 0 ? N : val.trailing_zeros();
// TBD: factor p if (!p.is_var() && p.is_monomial()) {
// it's just a product => sum them up
dd::pdd_monomial monomial = *p.begin();
unsigned parity_sum = monomial.coeff.trailing_zeros();
for (pvar c : monomial.vars)
parity_sum += max_parity(m.mk_var(c));
return std::min(N, parity_sum);
}
for (unsigned j = 0; j < N; ++j) for (unsigned j = 0; j < N; ++j)
if (is_forced_true(s.parity_at_most(p, j))) if (is_forced_true(s.parity_at_most(p, j)))
@ -910,7 +920,7 @@ namespace polysat {
return false; return false;
if (a.is_one() && (-b).is_var()) // y == x if (a.is_one() && (-b).is_var()) // y == x
return false; return false;
if (a.is_one()) if (a.is_one()) // TODO: Sure this is correct?
return false; return false;
if (a.is_val() && b.is_zero()) if (a.is_val() && b.is_zero())
return false; return false;
@ -1123,21 +1133,82 @@ namespace polysat {
return false; return false;
} }
lbool saturation::get_multiple(const pdd& p1, const pdd& p2, pdd& out) {
LOG("Check if " << p2 << " can be multiplied with something to get " << p1);
if (p1.is_zero()) {
out = p1.manager().zero();
return l_true;
}
if (p2.is_one()) {
out = p1;
return l_true;
}
if (!p1.is_monomial() || !p2.is_monomial())
// TODO: Actually, this could work as well. (4a*d + 6b*c*d) is a multiple of (2a + 3b*c) although none of them is a monomial
return l_undef;
unsigned max_parity_p1 = max_parity(p1);
unsigned min_parity_p2 = min_parity(p2);
if (min_parity_p2 > max_parity_p1)
return l_false;
dd::pdd_monomial p1m = *p1.begin();
dd::pdd_monomial p2m = *p2.begin();
m_occ_cnt.reserve(s.m_vars.size(), (unsigned)0); // TODO: Are there duplicates in the list (e.g., v1 * v1)?)
for (const auto& v1 : p1m.vars) {
if (m_occ_cnt[v1] == 0)
m_occ.push_back(v1);
m_occ_cnt[v1]++;
}
for (const auto& v2 : p2m.vars) {
if (m_occ_cnt[v2] == 0) {
for (const auto& occ : m_occ)
m_occ_cnt[occ] = 0;
m_occ.clear();
return l_undef; // p2 contains more v2 than p1; we need more information (assignments)
}
m_occ_cnt[v2]--;
}
unsigned tz1 = p1m.coeff.trailing_zeros();
unsigned tz2 = p2m.coeff.trailing_zeros();
if (tz2 > tz1)
return l_undef;
rational odd = div(p2m.coeff, rational::power_of_two(tz2));
rational inv;
VERIFY(odd.mult_inverse(p1.power_of_2() - tz2, inv)); // we divided by the even part, so it has to be odd/invertible now
inv *= div(p1m.coeff, rational::power_of_two(tz2));
out = p1.manager().mk_val(inv);
for (const auto& occ : m_occ) {
for (unsigned i = 0; i < m_occ_cnt[occ]; i++)
out *= s.var(occ);
m_occ_cnt[occ] = 0;
}
m_occ.clear();
LOG("Found multiple: " << out);
return l_true;
}
bool saturation::try_factor_equality(pvar x, conflict& core, inequality const& a_l_b) { bool saturation::try_factor_equality(pvar x, conflict& core, inequality const& a_l_b) {
set_rule("[x] ax + b = 0 & C[x] => C[-inv(a)*b]"); set_rule("[x] ax + b = 0 & C[x] => C[-inv(a)*b]");
auto& m = s.var2pdd(x); auto& m = s.var2pdd(x);
pdd y = m.zero(); pdd y = m.zero();
pdd a = y, b = y, a1 = y, b1 = y; pdd a = y, b = y, a1 = y, b1 = y, mul_fac = y;
if (!is_AxB_eq_0(x, a_l_b, a, b, y)) if (!is_AxB_eq_0(x, a_l_b, a, b, y)) // TODO: Is the restriction to linear "x" too restrictive?
return false; return false;
bool is_invertible = a.is_val() && a.val().is_odd(); bool is_invertible = a.is_val() && a.val().is_odd();
if (is_invertible) { if (is_invertible) {
rational a_inv; rational a_inv;
VERIFY(a.val().mult_inverse(m.power_of_2(), a_inv)); VERIFY(a.val().mult_inverse(m.power_of_2(), a_inv));
b = -b*a_inv; b = -b * a_inv;
} }
bool change = false; bool change = false;
bool prop = false; bool prop = false;
auto replace = [&](pdd p) { auto replace = [&](pdd p) {
@ -1146,19 +1217,23 @@ namespace polysat {
return p; return p;
if (is_invertible) { if (is_invertible) {
change = true; change = true;
// this works as well if the degree of "p" is not 1: 3 x = a (mod 4) & x^2 <= b => (3a)^2 <= b
return p.subst_pdd(x, b); return p.subst_pdd(x, b);
} }
if (p_degree == 1) { if (p_degree != 1)
p.factor(x, 1, a1, b1); return p; // TODO: Maybe fallback to brute-force
if (a1 == a) {
change = true; p.factor(x, 1, a1, b1);
return b1 - b; lbool is_multiple = get_multiple(a1, a, mul_fac);
} if (is_multiple == l_false)
if (a1 == -a) { return p; // there is no chance to invert
change = true; if (is_multiple == l_true) {
return b1 + b; change = true;
} return b1 - b * mul_fac;
} }
// We don't know whether it will work. Brute-force the parity
// TODO: Brute force goes here
return p; return p;
}; };
@ -1166,6 +1241,7 @@ namespace polysat {
change = false; change = false;
if (c == a_l_b.as_signed_constraint()) if (c == a_l_b.as_signed_constraint())
continue; continue;
LOG("Trying to eliminate v" << x << " in " << c << " by using equation " << a_l_b.as_signed_constraint());
if (c->is_ule()) { if (c->is_ule()) {
auto const& ule = c->to_ule(); auto const& ule = c->to_ule();
auto p = replace(ule.lhs()); auto p = replace(ule.lhs());

View file

@ -25,6 +25,9 @@ namespace polysat {
solver& s; solver& s;
clause_builder m_lemma; clause_builder m_lemma;
char const* m_rule = nullptr; char const* m_rule = nullptr;
unsigned_vector m_occ;
unsigned_vector m_occ_cnt;
void set_rule(char const* r) { m_rule = r; } void set_rule(char const* r) { m_rule = r; }
@ -128,6 +131,8 @@ namespace polysat {
unsigned min_parity(pdd const& p); unsigned min_parity(pdd const& p);
unsigned max_parity(pdd const& p); unsigned max_parity(pdd const& p);
lbool get_multiple(const pdd& p1, const pdd& p2, pdd& out);
bool is_forced_eq(pdd const& p, rational const& val); bool is_forced_eq(pdd const& p, rational const& val);
bool is_forced_eq(pdd const& p, int i) { return is_forced_eq(p, rational(i)); } bool is_forced_eq(pdd const& p, int i) { return is_forced_eq(p, rational(i)); }

View file

@ -1063,7 +1063,7 @@ namespace polysat {
void solver::assign_eval(sat::literal lit) { void solver::assign_eval(sat::literal lit) {
signed_constraint const c = lit2cnstr(lit); signed_constraint const c = lit2cnstr(lit);
LOG_V(10, "Evaluate: " << lit_pp(*this ,lit)); LOG_V(10, "Evaluate: " << lit_pp(*this, lit));
// assertion is false // assertion is false
if (!c.is_currently_true(*this)) IF_VERBOSE(0, verbose_stream() << c << " is not currently true\n"); if (!c.is_currently_true(*this)) IF_VERBOSE(0, verbose_stream() << c << " is not currently true\n");
SASSERT(c.is_currently_true(*this)); SASSERT(c.is_currently_true(*this));