From 56bda59de91ed4dc559955f8f1a9e1b0dc833814 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Mon, 2 Jan 2023 15:01:05 -0800 Subject: [PATCH] bugfix in parity code, add try_infer_parity_equality per status notes Signed-off-by: Nikolaj Bjorner --- src/math/polysat/saturation.cpp | 180 ++++++++++++++++++++++---------- src/math/polysat/saturation.h | 8 +- 2 files changed, 128 insertions(+), 60 deletions(-) diff --git a/src/math/polysat/saturation.cpp b/src/math/polysat/saturation.cpp index 63b7f868c..fda4f69d9 100644 --- a/src/math/polysat/saturation.cpp +++ b/src/math/polysat/saturation.cpp @@ -77,6 +77,8 @@ namespace polysat { prop = true; if (try_add_mul_bound(v, core, i)) prop = true; + if (try_infer_parity_equality(v, core, i)) + prop = true; if (try_mul_eq_bound(v, core, i)) prop = true; if (try_ugt_x(v, core, i)) @@ -194,25 +196,14 @@ namespace polysat { } bool saturation::propagate(pvar v, conflict& core, signed_constraint const& crit, signed_constraint c) { + m_lemma.insert(~crit); + return propagate(v, core, c); + } + + bool saturation::propagate(pvar v, conflict& core, signed_constraint c) { if (is_forced_true(c)) return false; - // NSB - review is it enough to propagate a new literal even if it is not false? - // unit propagation does not require conflicts. - // it should just avoid redundant propagation on literals that are true - // - // Furthermore propagation cannot be used when the resolved variable comes from - // forbidden interval conflicts. The propagated literal effectively adds a new and simpler bound - // on the non-viable variable. This bound then enables tighter non-viability conflicts. - // Effectively c is forced false, but it is forced false within the context of constraints used for viability. - // - // The effective level of the propagation is the level of all the other literals. If their level is below the - // last decision level (conflict level) we expect the propagation to be useful. - // The current assumptions on how conflict lemmas are used do not accomodate propagation it seems. - // - - m_lemma.insert(~crit); - SASSERT(all_of(m_lemma, [this](sat::literal lit) { return is_forced_false(s.lit2cnstr(lit)); })); m_lemma.insert(c); @@ -744,8 +735,10 @@ namespace polysat { // a*x - a*y + b*z = 0 0 <= x < b/a, 0 <= y < b/a => z = 0 // and then => x = y // - // the general lemma is that the linear term a*p = 0 is such that a*p does not overflow + // a general lemma is that the linear term a*p = 0 is such that a*p does not overflow // and therefore p = 0 + // + // the rule would also be subsumed by equality rewriting modulo parity // // TBD: encode the general lemma instead of this special case. // @@ -863,14 +856,22 @@ namespace polysat { * */ - unsigned saturation::min_parity(pdd const& p) { + unsigned saturation::min_parity(pdd const& p, vector& explain) { rational val; auto& m = p.manager(); unsigned N = m.power_of_2(); - if (s.try_eval(p, val)) - return val == 0 ? N : val.trailing_zeros(); + if (p.is_val()) + return p.val() == 0 ? N : p.val().trailing_zeros(); + + if (s.try_eval(p, val)) { + unsigned k = val == 0 ? N : val.trailing_zeros(); + if (k > 0) + explain.push_back(s.parity_at_least(p, k)); + return k; + } unsigned min = 0; + unsigned sz = explain.size(); if (!p.is_var()) { // parity of a product => sum of parities // parity of sum => minimum of monomial's minimal parities @@ -878,26 +879,37 @@ namespace polysat { for (const auto& monomial : p) { unsigned parity_sum = monomial.coeff.trailing_zeros(); for (pvar c : monomial.vars) - parity_sum += min_parity(m.mk_var(c)); + parity_sum += min_parity(m.mk_var(c), explain); min = std::min(min, parity_sum); } } SASSERT(min <= N); for (unsigned j = N; j > min; --j) - if (is_forced_true(s.parity_at_least(p, j))) + if (is_forced_true(s.parity_at_least(p, j))) { + explain.shrink(sz); + explain.push_back(s.parity_at_least(p, j)); return j; + } return min; } - unsigned saturation::max_parity(pdd const& p) { + unsigned saturation::max_parity(pdd const& p, vector& explain) { auto& m = p.manager(); unsigned N = m.power_of_2(); rational val; - if (s.try_eval(p, val)) - return val == 0 ? N : val.trailing_zeros(); + if (p.is_val()) + return p.val() == 0 ? N : p.val().trailing_zeros(); + + if (s.try_eval(p, val)) { + unsigned k = val == 0 ? N : val.trailing_zeros(); + if (k != N) + explain.push_back(s.parity_at_most(p, k)); + return k; + } unsigned max = N; + unsigned sz = explain.size(); if (!p.is_var() && p.is_monomial()) { // it's just a product => sum them up // the case of a sum is harder as the lower bound (because of carry bits) @@ -905,12 +917,15 @@ namespace polysat { dd::pdd_monomial monomial = *p.begin(); max = monomial.coeff.trailing_zeros(); for (pvar c : monomial.vars) - max += max_parity(m.mk_var(c)); + max += max_parity(m.mk_var(c), explain); } for (unsigned j = 0; j < max; ++j) - if (is_forced_true(s.parity_at_most(p, j))) + if (is_forced_true(s.parity_at_most(p, j))) { + explain.shrink(sz); + explain.push_back(s.parity_at_most(p, j)); return j; + } return max; } @@ -932,32 +947,39 @@ namespace polysat { if (a.is_val() && b.is_zero()) return false; - auto propagate1 = [&](signed_constraint premise, signed_constraint conseq) { - if (is_forced_false(premise)) - return false; + auto propagate1 = [&](vector const& premise, signed_constraint conseq) { IF_VERBOSE(1, verbose_stream() << "propagate " << axb_l_y << " " << premise << " => " << conseq << "\n"); m_lemma.reset(); m_lemma.insert_eval(~s.eq(y)); - m_lemma.insert_eval(~premise); + for (auto const& c : premise) { + if (is_forced_false(c)) + return false; + m_lemma.insert_eval(~c); + } return propagate(x, core, axb_l_y, conseq); }; - auto propagate2 = [&](signed_constraint premise1, signed_constraint premise2, signed_constraint conseq) { - if (is_forced_false(premise1)) - return false; - if (is_forced_false(premise2)) - return false; + auto propagate2 = [&](vector const& premise1, vector const& premise2, signed_constraint conseq) { IF_VERBOSE(1, verbose_stream() << "propagate " << axb_l_y << " " << premise1 << " " << premise2 << " => " << conseq << "\n"); m_lemma.reset(); m_lemma.insert_eval(~s.eq(y)); - m_lemma.insert_eval(~premise1); - m_lemma.insert_eval(~premise2); + for (auto const& c : premise1) { + if (is_forced_false(c)) + return false; + m_lemma.insert_eval(~c); + } + for (auto const& c : premise2) { + if (is_forced_false(c)) + return false; + m_lemma.insert_eval(~c); + } return propagate(x, core, axb_l_y, conseq); }; - unsigned min_x = min_parity(X), max_x = max_parity(X); - unsigned min_b = min_parity(b), max_b = max_parity(b); - unsigned min_a = min_parity(a), max_a = max_parity(a); + vector at_least_x, at_most_x, at_least_b, at_most_b, at_least_a, at_most_a; + unsigned min_x = min_parity(X, at_least_x), max_x = max_parity(X, at_most_x); + unsigned min_b = min_parity(b, at_least_b), max_b = max_parity(b, at_most_b); + unsigned min_a = min_parity(a, at_least_a), max_a = max_parity(a, at_most_a); SASSERT(min_x <= max_x && max_x <= N); SASSERT(min_a <= max_a && max_a <= N); SASSERT(min_b <= max_b && max_b <= N); @@ -980,23 +1002,22 @@ namespace polysat { VERIFY(k != 0); return s.parity_at_least(p, k); }; - - if (!b.is_val() && max_b > max_a + max_x && propagate2(at_most(a, max_a), at_most(X, max_x), at_most(b, max_x + max_a))) + if (!b.is_val() && max_b > max_a + max_x && propagate2(at_most_a, at_most_x, at_most(b, max_x + max_a))) return true; - if (!b.is_val() && min_x > min_b && propagate1(at_least(X, min_x), at_least(b, min_x))) + if (!b.is_val() && min_x > min_b && propagate1(at_least_x, at_least(b, min_x))) return true; - if (!b.is_val() && min_a > min_b && propagate1(at_least(a, min_a), at_least(b, min_a))) + if (!b.is_val() && min_a > min_b && propagate1(at_least_a, at_least(b, min_a))) return true; - if (!b.is_val() && min_x > 0 && min_a > 0 && min_x + min_a > min_b && propagate2(at_least(a, min_a), at_least(X, min_x), at_least(b, min_a + min_x))) + if (!b.is_val() && min_x > 0 && min_a > 0 && min_x + min_a > min_b && propagate2(at_least_a, at_least_x, at_least(b, min_a + min_x))) return true; - if (!a.is_val() && max_x <= min_b && min_a < min_b - max_x && propagate2(at_most(X, max_x), at_least(b, min_b), at_least(a, min_b - max_x))) + if (!a.is_val() && max_x <= min_b && min_a < min_b - max_x && propagate2(at_most_x, at_least_b, at_least(a, min_b - max_x))) return true; - if (max_a <= min_b && min_x < min_b - max_a && propagate2(at_most(a, max_a), at_least(b, min_b), at_least(X, min_b - max_a))) + if (max_a <= min_b && min_x < min_b - max_a && propagate2(at_most_a, at_least_b, at_least(X, min_b - max_a))) return true; - if (max_b < N && !a.is_val() && min_x > 0 && min_x <= max_b && max_a > max_b - min_x && propagate2(at_least(X, min_x), at_most(b, max_b), at_most(a, max_b - min_x))) + if (max_b < N && !a.is_val() && min_x > 0 && min_x <= max_b && max_a > max_b - min_x && propagate2(at_least_x, at_most_b, at_most(a, max_b - min_x))) return true; - if (max_b < N && min_a > 0 && min_a <= max_b && max_x > max_b - min_a && propagate2(at_least(a, min_a), at_most(b, max_b), at_most(X, max_b - min_a))) + if (max_b < N && min_a > 0 && min_a <= max_b && max_x > max_b - min_a && propagate2(at_least_a, at_most_b, at_most(X, max_b - min_a))) return true; return false; @@ -1007,7 +1028,7 @@ namespace polysat { * 2^k*x != 0 => parity(x) < N - k * 2^k*x*y != 0 => parity(x) + parity(y) < N - k * - * 2^k*x + b != 0 & parity(x) < N - k => b != 0 + * 2^k*x + b != 0 & parity(x) >= N - k => b != 0 & 2^k*x = 0 (rewriting constraints modulo parity is more powerful and subusmes this) */ bool saturation::try_parity_diseq(pvar x, conflict& core, inequality const& axb_l_y) { set_rule("[x] p(x,y) != 0 => constraints on parity(x), parity(y)"); @@ -1026,7 +1047,7 @@ namespace polysat { unsigned k = coeff.trailing_zeros(); m_lemma.reset(); m_lemma.insert_eval(~s.eq(y)); - m_lemma.insert_eval(~s.eq(b)); + m_lemma.insert_eval(~s.eq(b)); if (propagate(x, core, axb_l_y, ~s.parity_at_least(X, N - k))) return true; // TODO parity on a (without leading coefficient?) @@ -1034,11 +1055,14 @@ namespace polysat { if (a.is_val()) { auto coeff = a.val(); unsigned k = coeff.trailing_zeros(); - unsigned p_x = max_parity(X); - if (k + p_x < N) { + vector at_least_x; + unsigned p_x = min_parity(X, at_least_x); + if (k + p_x >= N) { + // ax + b != 0 m_lemma.reset(); m_lemma.insert_eval(~s.eq(y)); - m_lemma.insert_eval(~s.parity_at_most(X, p_x)); + for (auto c : at_least_x) + m_lemma.insert_eval(~c); if (propagate(x, core, axb_l_y, ~s.eq(b))) return true; } @@ -1153,9 +1177,10 @@ namespace polysat { if (!p1.is_monomial() || !p2.is_monomial()) // TODO: Actually, this could work as well. (4a*d + 6b*c*d) is a multiple of (2a + 3b*c) although none of them is a monomial return l_undef; - - unsigned max_parity_p1 = max_parity(p1); - unsigned min_parity_p2 = min_parity(p2); + + vector maxp1, minp2; + unsigned max_parity_p1 = max_parity(p1, maxp1); + unsigned min_parity_p2 = min_parity(p2, minp2); if (min_parity_p2 > max_parity_p1) return l_false; @@ -1772,6 +1797,45 @@ namespace polysat { return false; } + /** + * p >= q & q*2^k = 0 & p < 2^{K-k} => q = 0 + * More generally + * p >= q + r & q*2^k = 0 & p < 2^{K-k} & r < 2^{K-k} => q = 0 & p >= r + * + * The parity constraint on q entails that the low K-k bits of q must be 0 + * and therefore q is either 0 or at or above 2^{K-k}. + * Since p is blow 2^{K-k} the only intersection between the viable + * intervals imposed by p and possible for q is 0. + * + */ + bool saturation::try_infer_parity_equality(pvar x, conflict& core, inequality const& a_l_b) { + return false; + set_rule("[x] p > q & 2^k*q = 0 & p < 2^{K-k} => q = 0"); + auto& m = s.var2pdd(x); + auto p = a_l_b.rhs(), q = a_l_b.lhs(); + if (q.is_val()) + return false; + if (p.is_val() && p.val() == 0) + return false; + rational p_val; + if (!s.try_eval(p, p_val)) + return false; + vector at_least_k; + unsigned k = min_parity(q, at_least_k); + unsigned N = m.power_of_2(); + if (k == N) + return false; + if (rational::power_of_two(k) > p_val) { + verbose_stream() << k << " " << p_val << " " << a_l_b << "\n"; + m_lemma.reset(); + for (auto const& c : at_least_k) + m_lemma.insert_eval(~c); + m_lemma.insert_eval(~s.ult(p, rational::power_of_two(k))); + return propagate(x, core, a_l_b, s.eq(q)); + } + return false; + } + /* * TODO diff --git a/src/math/polysat/saturation.h b/src/math/polysat/saturation.h index 89afcf303..6cf3d8948 100644 --- a/src/math/polysat/saturation.h +++ b/src/math/polysat/saturation.h @@ -41,6 +41,7 @@ namespace polysat { void log_lemma(pvar v, conflict& core); bool propagate(pvar v, conflict& core, signed_constraint const& crit1, signed_constraint c); bool propagate(pvar v, conflict& core, inequality const& crit1, signed_constraint c); + bool propagate(pvar v, conflict& core, signed_constraint c); bool add_conflict(pvar v, conflict& core, inequality const& crit1, signed_constraint c); bool add_conflict(pvar v, conflict& core, inequality const& crit1, inequality const& crit2, signed_constraint c); @@ -68,6 +69,7 @@ namespace polysat { bool try_add_overflow_bound(pvar x, conflict& core, inequality const& axb_l_y); bool try_add_mul_bound(pvar x, conflict& core, inequality const& axb_l_y); bool try_add_mul_bound2(pvar x, conflict& core, inequality const& axb_l_y); + bool try_infer_parity_equality(pvar x, conflict& core, inequality const& a_l_b); rational round(rational const& N, rational const& x); bool extract_linear_form(pdd const& q, pvar& y, rational& a, rational& b); @@ -142,8 +144,10 @@ namespace polysat { bool has_lower_bound(pvar x, conflict& core, rational& bound, vector& x_le_bound); // determine min/max parity of polynomial - unsigned min_parity(pdd const& p); - unsigned max_parity(pdd const& p); + unsigned min_parity(pdd const& p, vector& explain); + unsigned max_parity(pdd const& p, vector& explain); + unsigned min_parity(pdd const& p) { vector ex; return min_parity(p, ex); } + unsigned max_parity(pdd const& p) { vector ex; return max_parity(p, ex); } lbool get_multiple(const pdd& p1, const pdd& p2, pdd& out);