From 79e7380ffc5dd36a3872c3c0061f113ea22d0c94 Mon Sep 17 00:00:00 2001 From: Clemens Eisenhofer Date: Tue, 3 Jan 2023 17:47:54 +0100 Subject: [PATCH] Pseudo-inverse op_constraint --- src/math/polysat/constraint_manager.cpp | 6 ++ src/math/polysat/constraint_manager.h | 1 + src/math/polysat/op_constraint.cpp | 80 +++++++++++++++++++++++ src/math/polysat/op_constraint.h | 5 +- src/math/polysat/saturation.cpp | 18 ++--- src/math/polysat/simplify_clause.cpp | 2 +- src/math/polysat/solver.h | 3 + src/math/polysat/variable_elimination.cpp | 68 +++++++++++-------- src/math/polysat/variable_elimination.h | 4 +- src/util/rational.cpp | 4 +- 10 files changed, 144 insertions(+), 47 deletions(-) diff --git a/src/math/polysat/constraint_manager.cpp b/src/math/polysat/constraint_manager.cpp index ae3e383ee..901a6a88a 100644 --- a/src/math/polysat/constraint_manager.cpp +++ b/src/math/polysat/constraint_manager.cpp @@ -509,4 +509,10 @@ namespace polysat { pdd constraint_manager::bnor(pdd const& p, pdd const& q) { return bnot(bor(p, q)); } + + pdd constraint_manager::pseudo_inv(pdd const& p) { + if (p.is_val()) + return p.manager().mk_val(p.val().pseudo_inverse(p.power_of_2())); + return mk_op_term(op_constraint::code::inv_op, p, p.manager().zero()); + } } diff --git a/src/math/polysat/constraint_manager.h b/src/math/polysat/constraint_manager.h index 2f8e8541d..dc484f92d 100644 --- a/src/math/polysat/constraint_manager.h +++ b/src/math/polysat/constraint_manager.h @@ -126,6 +126,7 @@ namespace polysat { pdd bxor(pdd const& p, pdd const& q); pdd bnand(pdd const& p, pdd const& q); pdd bnor(pdd const& p, pdd const& q); + pdd pseudo_inv(pdd const& p); constraint* const* begin() const { return m_constraints.data(); } constraint* const* end() const { return m_constraints.data() + m_constraints.size(); } diff --git a/src/math/polysat/op_constraint.cpp b/src/math/polysat/op_constraint.cpp index a54f37b5c..4e19c2d43 100644 --- a/src/math/polysat/op_constraint.cpp +++ b/src/math/polysat/op_constraint.cpp @@ -40,6 +40,8 @@ namespace polysat { if (p.index() > q.index()) std::swap(m_p, m_q); break; + case code::inv_op: + SASSERT(q.is_zero()); default: break; } @@ -61,6 +63,8 @@ namespace polysat { return eval_shl(p, q, r); case code::and_op: return eval_and(p, q, r); + case code::inv_op: + return eval_inv(p, r); default: return l_undef; } @@ -84,6 +88,8 @@ namespace polysat { return out << "<<"; case op_constraint::code::and_op: return out << "&"; + case op_constraint::code::inv_op: + return out << "inv"; default: UNREACHABLE(); return out; @@ -96,6 +102,9 @@ namespace polysat { } std::ostream& op_constraint::display(std::ostream& out, char const* eq) const { + if (m_op == code::inv_op) + return out << r() << " " << eq << " " << m_op << " " << p(); + return out << r() << " " << eq << " " << p() << " " << m_op << " " << q(); } @@ -161,6 +170,8 @@ namespace polysat { return lemma_shl(s, a); case code::and_op: return lemma_and(s, a); + case code::inv_op: + return lemma_inv(s, a); default: NOT_IMPLEMENTED_YET(); return {}; @@ -178,6 +189,8 @@ namespace polysat { // handle masking of high order bits activate_and(s); break; + case code::inv_op: + break; default: break; } @@ -571,6 +584,73 @@ namespace polysat { return true; } + /** + * Produce lemmas for constraint: r == inv p + * p = 0 => r = 0 + * r = 0 => p = 0 + * odd(r) -- for now we are looking for the smallest pseudo-inverse (there are 2^parity(p) of them) + * parity(p) >= k && p * r < 2^k => p * r >= 2^k + * parity(p) < k && p * r >= 2^k => p * r < 2^k + */ + clause_ref op_constraint::lemma_inv(solver& s, assignment const& a) { + auto& m = p().manager(); + auto pv = a.apply_to(p()); + auto rv = a.apply_to(r()); + + if (!pv.is_val() || !rv.is_val() || eval_inv(pv, rv) == l_true) + return {}; + + unsigned parity_pv = pv.val().trailing_zeros(); + unsigned parity_rv = rv.val().trailing_zeros(); + + signed_constraint const invc(this, true); + + // p = 0 => r = 0 + if (pv.is_zero()) + return s.mk_clause(~invc, ~s.eq(p()), s.eq(r()), true); + // r = 0 => p = 0 + if (rv.is_zero()) + return s.mk_clause(~invc, ~s.eq(r()), s.eq(p()), true); + // odd(r) + if (parity_rv != 0) + return s.mk_clause(~invc, s.odd(r()), true); + // parity(p) >= k && p * r < 2^k => p * r >= 2^k + // parity(p) < k && p * r >= 2^k => p * r < 2^k + rational prod = (p() * r()).val(); + SASSERT(prod != rational::power_of_two(parity_pv)); // Why did it evaluate to false in this case? + unsigned lower = 0, upper = p().power_of_2(); + // binary search for the parity + while (lower + 1 < upper) { + unsigned middle = (upper + lower) / 2; + LOG("Splitting on " << middle); + if (parity_pv >= middle) { + lower = middle; + LOG("Its in [" << lower << "; " << upper << ")"); + if (prod < rational::power_of_two(middle)) + return s.mk_clause(~invc, ~s.parity_at_least(p(), middle), s.uge(p() * r(), rational::power_of_two(middle)), false); + } + else { + upper = middle; + LOG("Its in [" << lower << "; " << upper << ")"); + if (prod >= rational::power_of_two(middle)) + return s.mk_clause(~invc, s.parity_at_least(p(), middle), s.ult(p() * r(), rational::power_of_two(middle)), false); + } + } + UNREACHABLE(); + return {}; + } + + /** Evaluate constraint: r == inv p */ + lbool op_constraint::eval_inv(pdd const& p, pdd const& r) { + if (!p.is_val() || !r.is_val()) + return l_undef; + + if (p.is_zero() || r.is_zero()) // the inverse of 0 is 0 (by arbitrary definition). Just to have some unique value + return p.is_zero() && r.is_zero() ? l_true : l_false; + + return p.val().pseudo_inverse(p.power_of_2()) == r.val() ? l_true : l_false; + } + void op_constraint::add_to_univariate_solver(pvar v, solver& s, univariate_solver& us, unsigned dep, bool is_positive) const { pdd pv = s.subst(p()); if (!pv.is_univariate_in(v)) diff --git a/src/math/polysat/op_constraint.h b/src/math/polysat/op_constraint.h index 8414436a6..4bddee287 100644 --- a/src/math/polysat/op_constraint.h +++ b/src/math/polysat/op_constraint.h @@ -26,7 +26,7 @@ namespace polysat { class op_constraint final : public constraint { public: - enum class code { lshr_op, ashr_op, shl_op, and_op }; + enum class code { lshr_op, ashr_op, shl_op, and_op, inv_op }; protected: friend class constraint_manager; @@ -51,6 +51,9 @@ namespace polysat { static lbool eval_and(pdd const& p, pdd const& q, pdd const& r); bool propagate_bits_and(solver& s, bool is_positive); + clause_ref lemma_inv(solver& s, assignment const& a); + static lbool eval_inv(pdd const& p, pdd const& r); + std::ostream& display(std::ostream& out, char const* eq) const; void activate(solver& s); diff --git a/src/math/polysat/saturation.cpp b/src/math/polysat/saturation.cpp index 03c882092..7b79923c0 100644 --- a/src/math/polysat/saturation.cpp +++ b/src/math/polysat/saturation.cpp @@ -1245,34 +1245,26 @@ namespace polysat { if (c->is_ule()) { // If both are equalities this boils down to polynomial superposition => Might generate the same lemma twice auto const& ule = c->to_ule(); - auto [lhs_new, changed_lhs, side_condition_lhs] = m_parity_tracker.eliminate_variable(*this, x, a, b, ule.lhs()); - auto [rhs_new, changed_rhs, side_condition_rhs] = m_parity_tracker.eliminate_variable(*this, x, a, b, ule.rhs()); + m_lemma.reset(); + auto [lhs_new, changed_lhs] = m_parity_tracker.eliminate_variable(*this, x, a, b, ule.lhs(), m_lemma); + auto [rhs_new, changed_rhs] = m_parity_tracker.eliminate_variable(*this, x, a, b, ule.rhs(), m_lemma); if (!changed_lhs && !changed_rhs) continue; // nothing changed - no reason for propagating lemmas - m_lemma.reset(); m_lemma.insert(~c); m_lemma.insert_eval(~s.eq(y)); - for (auto& sc_lhs : side_condition_lhs) // the "path to get the parities" - m_lemma.insert(sc_lhs); - for (auto& sc_rhs : side_condition_rhs) - m_lemma.insert(sc_rhs); if (propagate(x, core, a_l_b, c.is_positive() ? s.ule(lhs_new, rhs_new) : ~s.ule(lhs_new, rhs_new))) prop = true; } else if (c->is_umul_ovfl()) { auto const& ovf = c->to_umul_ovfl(); - auto [lhs_new, changed_lhs, side_condition_lhs] = m_parity_tracker.eliminate_variable(*this, x, a, b, ovf.p()); - auto [rhs_new, changed_rhs, side_condition_rhs] = m_parity_tracker.eliminate_variable(*this, x, a, b, ovf.q()); + auto [lhs_new, changed_lhs] = m_parity_tracker.eliminate_variable(*this, x, a, b, ovf.p(), m_lemma); + auto [rhs_new, changed_rhs] = m_parity_tracker.eliminate_variable(*this, x, a, b, ovf.q(), m_lemma); if (!changed_lhs && !changed_rhs) continue; m_lemma.reset(); m_lemma.insert(~c); m_lemma.insert_eval(~s.eq(y)); - for (auto& sc_lhs : side_condition_lhs) - m_lemma.insert(sc_lhs); - for (auto& sc_rhs : side_condition_rhs) - m_lemma.insert(sc_rhs); if (propagate(x, core, a_l_b, c.is_positive() ? s.umul_ovfl(lhs_new, rhs_new) : ~s.umul_ovfl(lhs_new, rhs_new))) prop = true; diff --git a/src/math/polysat/simplify_clause.cpp b/src/math/polysat/simplify_clause.cpp index e23db0788..29ff47c37 100644 --- a/src/math/polysat/simplify_clause.cpp +++ b/src/math/polysat/simplify_clause.cpp @@ -132,7 +132,7 @@ namespace polysat { auto const eq_it = std::find(cl.begin(), cl.end(), eq.blit()); if (eq_it == cl.end()) continue; - unsigned const eq_idx = std::distance(cl.begin(), eq_it); + unsigned eq_idx = (unsigned)std::distance(cl.begin(), eq_it); any_removed = true; should_remove[eq_idx] = true; if (c.is_positive()) { diff --git a/src/math/polysat/solver.h b/src/math/polysat/solver.h index c3a5e9729..c327c23e8 100644 --- a/src/math/polysat/solver.h +++ b/src/math/polysat/solver.h @@ -405,6 +405,9 @@ namespace polysat { /** Create expression for bit-wise nor of p, q. */ pdd bnor(pdd const& p, pdd const& q) { return m_constraints.bnor(p, q); } + /** Create expression for the smallest pseudo-inverse of p. */ + pdd pseudo_inv(pdd const& p) { return m_constraints.pseudo_inv(p); } + /** * Create polynomial constant. */ diff --git a/src/math/polysat/variable_elimination.cpp b/src/math/polysat/variable_elimination.cpp index 0a563f65e..b4381b44d 100644 --- a/src/math/polysat/variable_elimination.cpp +++ b/src/math/polysat/variable_elimination.cpp @@ -583,7 +583,7 @@ namespace polysat { return inv_pdd; } - pdd parity_tracker::get_odd(const pdd& p, unsigned parity, svector& path) { + pdd parity_tracker::get_odd(const pdd& p, unsigned parity, clause_builder& precondition) { LOG("Getting odd part of " << p); if (p.is_val()) { SASSERT(!p.val().is_zero()); @@ -618,14 +618,14 @@ namespace polysat { LOG("Splitting on " << middle << " with " << parity); if (parity >= middle) { lower = middle; - path.push_back(~c); + precondition.insert(~c); if (needs_propagate) m_builder.insert(~c); verbose_stream() << "Side-condition: " << ~c << "\n"; } else { upper = middle; - path.push_back(c); + precondition.insert(c); if (needs_propagate) m_builder.insert(c); verbose_stream() << "Side-condition: " << c << "\n"; @@ -643,40 +643,40 @@ namespace polysat { } // a * x + b = 0 (x not in a or b; i.e., the equation is linear in x) - // C[p, ...] resp., C[..., p] - std::tuple> parity_tracker::eliminate_variable(saturation& saturation, pvar x, const pdd& a, const pdd& b, const pdd& p) { + // C[x, ...] resp., C[..., x] + std::tuple parity_tracker::eliminate_variable(saturation& saturation, pvar x, const pdd& a, const pdd& b, const pdd& p, clause_builder& precondition) { unsigned p_degree = p.degree(x); if (p_degree == 0) - return { p, false, {} }; + return { p, false }; if (a.is_val() && a.val().is_odd()) { // just invert and plug it in rational a_inv; VERIFY(a.val().mult_inverse(a.power_of_2(), a_inv)); // this works as well if the degree of "p" is not 1: 3 x = a (mod 4) && x^2 <= b => (3a)^2 <= b - return { p.subst_pdd(x, -b * a_inv), true, {} }; + return { p.subst_pdd(x, -b * a_inv), true, }; } // from now on we require linear factors if (p_degree != 1) - return { p, false, {} }; // TODO: Maybe fallback to brute-force + return { p, false }; // TODO: Maybe fallback to brute-force pdd a1 = a.manager().zero(), b1 = a1, mul_fac = a1; p.factor(x, 1, a1, b1); lbool is_multiple = saturation.get_multiple(a1, a, mul_fac); if (is_multiple == l_false) - return { p, false, {} }; // there is no chance to invert + return { p, false }; // there is no chance to invert if (is_multiple == l_true) // we multiply with a factor to make them equal - return { b1 - b * mul_fac, true, {} }; + return { b1 - b * mul_fac, true }; -#if 1 - return { p, false, {} }; +#if 0 + return { p, false }; #else if (!a.is_monomial() || !a1.is_monomial()) - return { p , false, {} }; + return { p , false }; if (!a1.is_var() && !a1.is_val()) { - // TODO: Compromise: Maybe only monomials...? Does this make sense? + // TODO: Compromise: Maybe only monomials...? //return { p, false, {} }; LOG("Warning: Inverting " << a1 << " although it is not a single variable - might not be a good idea"); } @@ -685,37 +685,49 @@ namespace polysat { LOG("Warning: Inverting " << a << " although it is not a single variable - might not be a good idea"); } - // We don't know whether it will work. Use the parity of the assignment - #if 1 unsigned a_parity; if ((a_parity = saturation.min_parity(a)) != saturation.max_parity(a) || saturation.min_parity(a1) < a_parity) - return { p, false, {} }; // We need the parity of a and this has to be for sure less than the parity of a1 + return { p, false }; // We need the parity of a and this has to be for sure less than the parity of a1 + + if (b.is_zero()) + return { b1, true }; - svector precondition; +#if 0 pdd a_pi = get_pseudo_inverse(a, a_parity); +#else + pdd a_pi = s.pseudo_inv(a); + //precondition.insert(~s.eq(a_pi * a, rational::power_of_two(a_parity))); // TODO: This is unfortunately not a justification as the inverse might not be set yet (Can we make it to one?) + precondition.insert(~s.parity_at_most(a, a_parity)); +#endif + + pdd shift = a; if (a_parity > 0) { - pdd shift = s.lshr(a1, a1.manager().mk_val(a_parity)); - precondition.push_back(s.eq(rational::power_of_two(a_parity) * shift, a1)); // TODO: Or s.parity_at_least(a1, a_parity) but we want to reuse the variable introduced by the shift - return { a_pi * (-b) * shift + b1, true, {std::move(precondition)} }; + shift = s.lshr(a1, a1.manager().mk_val(a_parity)); + precondition.insert(~s.eq(rational::power_of_two(a_parity) * shift, a1)); // TODO: Or s.parity_at_least(a1, a_parity) but we want to reuse the variable introduced by the shift } - // Special case: If it is already odd we can directly use the pseudo inverse (as it is the inverse in this case!) - return { a_pi * (-b) * a + b1, true, {std::move(precondition)} }; + LOG("Forced elimination: " << a_pi * (-b) * shift + b1); + LOG("a: " << a); + LOG("a1: " << a1); + LOG("parity of a: " << a_parity); + LOG("pseudo inverse: " << a_pi); + LOG("-b: " << (-b)); + LOG("shifted a" << shift); + LOG("Forced elimination: " << a_pi * (-b) * shift + b1); + return { a_pi * (-b) * shift + b1, true }; #else unsigned a_parity; unsigned a1_parity; if ((a_parity = saturation.min_parity(a)) != saturation.max_parity(a) || (a1_parity = saturation.min_parity(a1)) != saturation.max_parity(a1)) - return { p, false, {} }; // We need the parity, but we failed to get it precisely + return { p, false }; // We need the parity, but we failed to get it precisely if (a_parity > a1_parity) { SASSERT(false); // get_multiple should have excluded this case already - return { p, false, {} }; + return { p, false }; } - svector precondition; - auto odd_a = get_odd(a, a_parity, precondition); auto odd_a1 = get_odd(a1, a1_parity, precondition); pdd inv_odd_a = get_inverse(odd_a); @@ -723,7 +735,7 @@ namespace polysat { LOG("Forced elimination: " << odd_a1 * inv_odd_a * rational::power_of_two(a1_parity - a_parity) * b + b1); verbose_stream() << "Forced elimination: " << odd_a1 * inv_odd_a * rational::power_of_two(a1_parity - a_parity) * (-b) + b1 << "\n"; verbose_stream() << "From: " << "eliminated v" << x << " with a = " << a << "; -b = " << -b << "; p = " << p << "\n"; - return { odd_a1 * inv_odd_a * rational::power_of_two(a1_parity - a_parity) * (-b) + b1, true, {std::move(precondition)} }; + return { odd_a1 * inv_odd_a * rational::power_of_two(a1_parity - a_parity) * (-b) + b1, true }; #endif #endif } diff --git a/src/math/polysat/variable_elimination.h b/src/math/polysat/variable_elimination.h index ee403b1b3..3a09f60a4 100644 --- a/src/math/polysat/variable_elimination.h +++ b/src/math/polysat/variable_elimination.h @@ -87,8 +87,8 @@ namespace polysat { pdd get_pseudo_inverse(const pdd& p, unsigned parity); pdd get_inverse(const pdd& p); - pdd get_odd(const pdd& p, unsigned parity, svector& pat); + pdd get_odd(const pdd& p, unsigned parity, clause_builder& precondition); - std::tuple> eliminate_variable(saturation& saturation, pvar x, const pdd& a, const pdd& b, const pdd& p); + std::tuple eliminate_variable(saturation& saturation, pvar x, const pdd& a, const pdd& b, const pdd& p, clause_builder& precondition); }; } diff --git a/src/util/rational.cpp b/src/util/rational.cpp index 531669dd4..54b40ac58 100644 --- a/src/util/rational.cpp +++ b/src/util/rational.cpp @@ -154,7 +154,7 @@ bool rational::mult_inverse(unsigned num_bits, rational & result) const { } /** - * Compute multiplicative pseudo-inverse modulo 2^num_bits: + * Compute the smallest multiplicative pseudo-inverse modulo 2^num_bits: * * mod(n * n.pseudo_inverse(bits), 2^bits) == 2^k, * where k is maximal such that 2^k divides n. @@ -167,7 +167,7 @@ rational rational::pseudo_inverse(unsigned num_bits) const { SASSERT(!n.is_zero()); // TODO: or we define pseudo-inverse of 0 as 0. unsigned const k = n.trailing_zeros(); rational const odd = machine_div2k(n, k); - VERIFY(odd.mult_inverse(num_bits, result)); + VERIFY(odd.mult_inverse(num_bits - k, result)); SASSERT_EQ(mod(n * result, rational::power_of_two(num_bits)), rational::power_of_two(k)); return result; }