From 658877365c2532772b67793b1f234cc70e6e19cb Mon Sep 17 00:00:00 2001 From: Clemens Eisenhofer Date: Wed, 28 Dec 2022 15:07:03 +0100 Subject: [PATCH] Moved "easy part" of variable elimination to saturation.cpp --- src/math/polysat/saturation.cpp | 132 +++++++++++++++++++++++++------- src/math/polysat/saturation.h | 5 ++ src/math/polysat/solver.cpp | 2 +- 3 files changed, 110 insertions(+), 29 deletions(-) diff --git a/src/math/polysat/saturation.cpp b/src/math/polysat/saturation.cpp index 01ee93492..eedaff140 100644 --- a/src/math/polysat/saturation.cpp +++ b/src/math/polysat/saturation.cpp @@ -837,7 +837,7 @@ namespace polysat { * * odd(x) & even(y) => x + y != 0 * - * Special case rule: a*x + y = 0 => (odd(b) <=> odd(a) & odd(x)) + * Special case rule: a*x + y = 0 => (odd(y) <=> odd(a) & odd(x)) * * General rule: * @@ -845,10 +845,10 @@ namespace polysat { * * using inequalities: * - * parity(x) <= i, parity(a) <= j => parity(b) <= i + j - * parity(x) >= i, parity(a) >= j => parity(b) >= i + j - * parity(x) <= i, parity(b) >= j => parity(a) >= j - i - * parity(x) >= i, parity(b) <= j => parity(a) <= j - i + * parity(x) <= i, parity(a) <= j => parity(y) <= i + j + * parity(x) >= i, parity(a) >= j => parity(y) >= i + j + * parity(x) <= i, parity(y) >= j => parity(a) >= j - i + * parity(x) >= i, parity(y) <= j => parity(a) <= j - i * symmetric rules for swapping x, a * * min_parity(x) = number of trailing bits of x if x is a value @@ -869,13 +869,16 @@ namespace polysat { unsigned N = m.power_of_2(); if (s.try_eval(p, val)) return val == 0 ? N : val.trailing_zeros(); - -#if 0 - // TBD: factor p - auto coeff = p.leading_coefficient(); - unsigned offset = coeff.trailing_zeros(); - verbose_stream() << "COEFF " << coeff << "\n"; -#endif + + if (!p.is_var() && p.is_monomial()) { + // it's just a product => sum them up + dd::pdd_monomial monomial = *p.begin(); + unsigned parity_sum = monomial.coeff.trailing_zeros(); + for (pvar c : monomial.vars) + parity_sum += min_parity(m.mk_var(c)); + return std::min(N, parity_sum); + } + for (unsigned j = N; j > 0; --j) if (is_forced_true(s.parity(p, j))) return j; @@ -889,7 +892,14 @@ namespace polysat { if (s.try_eval(p, val)) return val == 0 ? N : val.trailing_zeros(); - // TBD: factor p + if (!p.is_var() && p.is_monomial()) { + // it's just a product => sum them up + dd::pdd_monomial monomial = *p.begin(); + unsigned parity_sum = monomial.coeff.trailing_zeros(); + for (pvar c : monomial.vars) + parity_sum += max_parity(m.mk_var(c)); + return std::min(N, parity_sum); + } for (unsigned j = 0; j < N; ++j) if (is_forced_true(s.parity_at_most(p, j))) @@ -910,7 +920,7 @@ namespace polysat { return false; if (a.is_one() && (-b).is_var()) // y == x return false; - if (a.is_one()) + if (a.is_one()) // TODO: Sure this is correct? return false; if (a.is_val() && b.is_zero()) return false; @@ -1123,21 +1133,82 @@ namespace polysat { return false; } - + lbool saturation::get_multiple(const pdd& p1, const pdd& p2, pdd& out) { + LOG("Check if " << p2 << " can be multiplied with something to get " << p1); + if (p1.is_zero()) { + out = p1.manager().zero(); + return l_true; + } + if (p2.is_one()) { + out = p1; + return l_true; + } + if (!p1.is_monomial() || !p2.is_monomial()) + // TODO: Actually, this could work as well. (4a*d + 6b*c*d) is a multiple of (2a + 3b*c) although none of them is a monomial + return l_undef; + + unsigned max_parity_p1 = max_parity(p1); + unsigned min_parity_p2 = min_parity(p2); + + if (min_parity_p2 > max_parity_p1) + return l_false; + + dd::pdd_monomial p1m = *p1.begin(); + dd::pdd_monomial p2m = *p2.begin(); + + m_occ_cnt.reserve(s.m_vars.size(), (unsigned)0); // TODO: Are there duplicates in the list (e.g., v1 * v1)?) + + for (const auto& v1 : p1m.vars) { + if (m_occ_cnt[v1] == 0) + m_occ.push_back(v1); + m_occ_cnt[v1]++; + } + for (const auto& v2 : p2m.vars) { + if (m_occ_cnt[v2] == 0) { + for (const auto& occ : m_occ) + m_occ_cnt[occ] = 0; + m_occ.clear(); + return l_undef; // p2 contains more v2 than p1; we need more information (assignments) + } + m_occ_cnt[v2]--; + } + + unsigned tz1 = p1m.coeff.trailing_zeros(); + unsigned tz2 = p2m.coeff.trailing_zeros(); + if (tz2 > tz1) + return l_undef; + + rational odd = div(p2m.coeff, rational::power_of_two(tz2)); + rational inv; + VERIFY(odd.mult_inverse(p1.power_of_2() - tz2, inv)); // we divided by the even part, so it has to be odd/invertible now + inv *= div(p1m.coeff, rational::power_of_two(tz2)); + + out = p1.manager().mk_val(inv); + for (const auto& occ : m_occ) { + for (unsigned i = 0; i < m_occ_cnt[occ]; i++) + out *= s.var(occ); + m_occ_cnt[occ] = 0; + } + m_occ.clear(); + LOG("Found multiple: " << out); + return l_true; + } bool saturation::try_factor_equality(pvar x, conflict& core, inequality const& a_l_b) { set_rule("[x] ax + b = 0 & C[x] => C[-inv(a)*b]"); auto& m = s.var2pdd(x); pdd y = m.zero(); - pdd a = y, b = y, a1 = y, b1 = y; - if (!is_AxB_eq_0(x, a_l_b, a, b, y)) + pdd a = y, b = y, a1 = y, b1 = y, mul_fac = y; + if (!is_AxB_eq_0(x, a_l_b, a, b, y)) // TODO: Is the restriction to linear "x" too restrictive? return false; + bool is_invertible = a.is_val() && a.val().is_odd(); if (is_invertible) { rational a_inv; VERIFY(a.val().mult_inverse(m.power_of_2(), a_inv)); - b = -b*a_inv; + b = -b * a_inv; } + bool change = false; bool prop = false; auto replace = [&](pdd p) { @@ -1146,19 +1217,23 @@ namespace polysat { return p; if (is_invertible) { change = true; + // this works as well if the degree of "p" is not 1: 3 x = a (mod 4) & x^2 <= b => (3a)^2 <= b return p.subst_pdd(x, b); } - if (p_degree == 1) { - p.factor(x, 1, a1, b1); - if (a1 == a) { - change = true; - return b1 - b; - } - if (a1 == -a) { - change = true; - return b1 + b; - } + if (p_degree != 1) + return p; // TODO: Maybe fallback to brute-force + + p.factor(x, 1, a1, b1); + lbool is_multiple = get_multiple(a1, a, mul_fac); + if (is_multiple == l_false) + return p; // there is no chance to invert + if (is_multiple == l_true) { + change = true; + return b1 - b * mul_fac; } + + // We don't know whether it will work. Brute-force the parity + // TODO: Brute force goes here return p; }; @@ -1166,6 +1241,7 @@ namespace polysat { change = false; if (c == a_l_b.as_signed_constraint()) continue; + LOG("Trying to eliminate v" << x << " in " << c << " by using equation " << a_l_b.as_signed_constraint()); if (c->is_ule()) { auto const& ule = c->to_ule(); auto p = replace(ule.lhs()); diff --git a/src/math/polysat/saturation.h b/src/math/polysat/saturation.h index 11d191ebb..4ce268021 100644 --- a/src/math/polysat/saturation.h +++ b/src/math/polysat/saturation.h @@ -25,6 +25,9 @@ namespace polysat { solver& s; clause_builder m_lemma; char const* m_rule = nullptr; + + unsigned_vector m_occ; + unsigned_vector m_occ_cnt; void set_rule(char const* r) { m_rule = r; } @@ -128,6 +131,8 @@ namespace polysat { unsigned min_parity(pdd const& p); unsigned max_parity(pdd const& p); + lbool get_multiple(const pdd& p1, const pdd& p2, pdd& out); + bool is_forced_eq(pdd const& p, rational const& val); bool is_forced_eq(pdd const& p, int i) { return is_forced_eq(p, rational(i)); } diff --git a/src/math/polysat/solver.cpp b/src/math/polysat/solver.cpp index 5260fa11c..1d84bac4e 100644 --- a/src/math/polysat/solver.cpp +++ b/src/math/polysat/solver.cpp @@ -1063,7 +1063,7 @@ namespace polysat { void solver::assign_eval(sat::literal lit) { signed_constraint const c = lit2cnstr(lit); - LOG_V(10, "Evaluate: " << lit_pp(*this ,lit)); + LOG_V(10, "Evaluate: " << lit_pp(*this, lit)); // assertion is false if (!c.is_currently_true(*this)) IF_VERBOSE(0, verbose_stream() << c << " is not currently true\n"); SASSERT(c.is_currently_true(*this));