diff --git a/src/math/polysat/saturation.cpp b/src/math/polysat/saturation.cpp index 2735f9b53..1bdc752a8 100644 --- a/src/math/polysat/saturation.cpp +++ b/src/math/polysat/saturation.cpp @@ -30,7 +30,7 @@ TODO: when we check that 'x' is "unary": namespace polysat { - saturation::saturation(solver& s) : s(s), m_lemma(s) {} + saturation::saturation(solver& s) : s(s), m_lemma(s), m_parity_tracker(s) {} void saturation::log_lemma(pvar v, conflict& core) { IF_VERBOSE(1, auto const& cl = core.lemmas().back(); @@ -880,7 +880,7 @@ namespace polysat { } for (unsigned j = N; j > 0; --j) - if (is_forced_true(s.parity(p, j))) + if (is_forced_true(s.parity_at_least(p, j))) return j; return 0; } @@ -971,7 +971,7 @@ namespace polysat { auto at_least = [&](pdd const& p, unsigned k) { VERIFY(k != 0); - return s.parity(p, k); + return s.parity_at_least(p, k); }; @@ -1020,7 +1020,7 @@ namespace polysat { m_lemma.reset(); m_lemma.insert_eval(~s.eq(y)); m_lemma.insert_eval(~s.eq(b)); - if (propagate(x, core, axb_l_y, ~s.parity(X, N - k))) + if (propagate(x, core, axb_l_y, ~s.parity_at_least(X, N - k))) return true; // TODO parity on a (without leading coefficient?) } @@ -1135,7 +1135,7 @@ namespace polysat { lbool saturation::get_multiple(const pdd& p1, const pdd& p2, pdd& out) { LOG("Check if " << p2 << " can be multiplied with something to get " << p1); - if (p1.is_zero()) { + if (p1.is_zero()) { // TODO: use the evaluated parity (max_parity) instead? out = p1.manager().zero(); return l_true; } @@ -1202,40 +1202,8 @@ namespace polysat { if (!is_AxB_eq_0(x, a_l_b, a, b, y)) // TODO: Is the restriction to linear "x" too restrictive? return false; - bool is_invertible = a.is_val() && a.val().is_odd(); - if (is_invertible) { - rational a_inv; - VERIFY(a.val().mult_inverse(m.power_of_2(), a_inv)); - b = -b * a_inv; - } - bool change = false; bool prop = false; - auto replace = [&](pdd p) { - unsigned p_degree = p.degree(x); - if (p_degree == 0) - return p; - if (is_invertible) { - change = true; - // this works as well if the degree of "p" is not 1: 3 x = a (mod 4) & x^2 <= b => (3a)^2 <= b - return p.subst_pdd(x, b); - } - if (p_degree != 1) - return p; // TODO: Maybe fallback to brute-force - - p.factor(x, 1, a1, b1); - lbool is_multiple = get_multiple(a1, a, mul_fac); - if (is_multiple == l_false) - return p; // there is no chance to invert - if (is_multiple == l_true) { - change = true; - return b1 - b * mul_fac; - } - - // We don't know whether it will work. Brute-force the parity - // TODO: Brute force goes here - return p; - }; for (auto c : core) { change = false; @@ -1243,27 +1211,38 @@ namespace polysat { continue; LOG("Trying to eliminate v" << x << " in " << c << " by using equation " << a_l_b.as_signed_constraint()); if (c->is_ule()) { + // If both are equalities this boils down to polynomial superposition => Might generate the same lemma twice auto const& ule = c->to_ule(); - auto p = replace(ule.lhs()); - auto q = replace(ule.rhs()); - if (!change) - continue; + auto [lhs_new, changed_lhs, side_condition_lhs] = m_parity_tracker.eliminate_variable(*this, x, a, b, ule.lhs()); + auto [rhs_new, changed_rhs, side_condition_rhs] = m_parity_tracker.eliminate_variable(*this, x, a, b, ule.rhs()); + if (!changed_lhs && !changed_rhs) + continue; // nothing changed - no reason for propagating lemmas m_lemma.reset(); m_lemma.insert(~c); m_lemma.insert_eval(~s.eq(y)); - if (propagate(x, core, a_l_b, c.is_positive() ? s.ule(p, q) : ~s.ule(p, q))) + for (auto& sc_lhs : side_condition_lhs) // TODO: Do we really need the path as a side-condition in case of parity elimination? + m_lemma.insert(sc_lhs); + for (auto& sc_rhs : side_condition_rhs) + m_lemma.insert(sc_rhs); + + if (propagate(x, core, a_l_b, c.is_positive() ? s.ule(lhs_new, rhs_new) : ~s.ule(lhs_new, rhs_new))) prop = true; } else if (c->is_umul_ovfl()) { auto const& ovf = c->to_umul_ovfl(); - auto p = replace(ovf.p()); - auto q = replace(ovf.q()); - if (!change) + auto [lhs_new, changed_lhs, side_condition_lhs] = m_parity_tracker.eliminate_variable(*this, x, a, b, ovf.p()); + auto [rhs_new, changed_rhs, side_condition_rhs] = m_parity_tracker.eliminate_variable(*this, x, a, b, ovf.q()); + if (!changed_lhs && !changed_rhs) continue; m_lemma.reset(); m_lemma.insert(~c); m_lemma.insert_eval(~s.eq(y)); - if (propagate(x, core, a_l_b, c.is_positive() ? s.umul_ovfl(p, q) : ~s.umul_ovfl(p, q))) + for (auto& sc_lhs : side_condition_lhs) + m_lemma.insert(sc_lhs); + for (auto& sc_rhs : side_condition_rhs) + m_lemma.insert(sc_rhs); + + if (propagate(x, core, a_l_b, c.is_positive() ? s.umul_ovfl(lhs_new, rhs_new) : ~s.umul_ovfl(lhs_new, rhs_new))) prop = true; } } diff --git a/src/math/polysat/saturation.h b/src/math/polysat/saturation.h index 6f2ad19f3..fb874a0e4 100644 --- a/src/math/polysat/saturation.h +++ b/src/math/polysat/saturation.h @@ -14,6 +14,7 @@ Author: #pragma once #include "math/polysat/clause_builder.h" #include "math/polysat/conflict.h" +#include "math/polysat/variable_elimination.h" namespace polysat { @@ -22,10 +23,13 @@ namespace polysat { */ class saturation { + friend class parity_tracker; + solver& s; clause_builder m_lemma; char const* m_rule = nullptr; + parity_tracker m_parity_tracker; unsigned_vector m_occ; unsigned_vector m_occ_cnt; diff --git a/src/math/polysat/solver.cpp b/src/math/polysat/solver.cpp index 1d84bac4e..60f8f9c12 100644 --- a/src/math/polysat/solver.cpp +++ b/src/math/polysat/solver.cpp @@ -19,6 +19,7 @@ Author: #include "math/polysat/solver.h" #include "math/polysat/log.h" #include "math/polysat/polysat_params.hpp" +#include "math/polysat/variable_elimination.h" #include // For development; to be removed once the linear solver works well enough @@ -238,12 +239,12 @@ namespace polysat { LOG_H2("Propagate " << assignment_pp(*this, v, get_value(v))); SASSERT(!m_locked_wlist); DEBUG_CODE(m_locked_wlist = v;); + unsigned i = 0, j = 0; + for (; i < m_pwatch[v].size() && !is_conflict(); ++i) + if (!propagate(v, m_pwatch[v][i])) // propagate may change watch-list reference + m_pwatch[v][j++] = m_pwatch[v][i]; auto& wlist = m_pwatch[v]; - unsigned i = 0, j = 0, sz = wlist.size(); - for (; i < sz && !is_conflict(); ++i) - if (!propagate(v, wlist[i])) - wlist[j++] = wlist[i]; - for (; i < sz; ++i) + for (; i < wlist.size(); ++i) wlist[j++] = wlist[i]; wlist.shrink(j); if (is_conflict()) @@ -435,6 +436,7 @@ namespace polysat { #if ENABLE_LINEAR_SOLVER m_linear_solver.push(); #endif + m_fixed_bits.push(); } void solver::pop_levels(unsigned num_levels) { @@ -448,6 +450,8 @@ namespace polysat { #if ENABLE_LINEAR_SOLVER m_linear_solver.pop(num_levels); #endif + m_fixed_bits.pop(); + while (num_levels > 0) { switch (m_trail.back()) { case trail_instr_t::qhead_i: { @@ -602,7 +606,7 @@ namespace polysat { } } #endif - m_fixed_bits.push(); + if (can_bdecide()) bdecide(); else @@ -833,7 +837,6 @@ namespace polysat { continue; } if (j.is_decision()) { - m_fixed_bits.pop(); m_conflict.revert_pvar(v); revert_decision(v); return; @@ -862,7 +865,6 @@ namespace polysat { } SASSERT(!m_bvars.is_assumption(var)); // TODO: "assumption" is basically "propagated by unit clause" (or "at base level"); except we do not explicitly store the unit clause. if (m_bvars.is_decision(var)) { - m_fixed_bits.pop(); revert_bool_decision(lit); return; } diff --git a/src/math/polysat/solver.h b/src/math/polysat/solver.h index a9e384634..c3a5e9729 100644 --- a/src/math/polysat/solver.h +++ b/src/math/polysat/solver.h @@ -137,6 +137,7 @@ namespace polysat { friend class ex_polynomial_superposition; friend class free_variable_elimination; friend class saturation; + friend class parity_tracker; friend class constraint_manager; friend class scoped_solverv; friend class test_polysat; @@ -422,9 +423,9 @@ namespace polysat { signed_constraint eq(pdd const& p, rational const& q) { return eq(p - q); } signed_constraint eq(pdd const& p, unsigned q) { return eq(p - q); } signed_constraint odd(pdd const& p) { return ~even(p); } - signed_constraint even(pdd const& p) { return parity(p, 1); } + signed_constraint even(pdd const& p) { return parity_at_least(p, 1); } /** parity(p) >= k */ - signed_constraint parity(pdd const& p, unsigned k) { // TODO: rename to parity_at_least? + signed_constraint parity_at_least(pdd const& p, unsigned k) { unsigned N = p.manager().power_of_2(); // parity(p) >= k // <=> p * 2^(N - k) == 0 @@ -449,7 +450,7 @@ namespace polysat { return eq(p.manager().zero()); } else - return ~parity(p, k + 1); + return ~parity_at_least(p, k + 1); } signed_constraint diseq(pdd const& p, rational const& q) { return diseq(p - q); } signed_constraint diseq(pdd const& p, unsigned q) { return diseq(p - q); } diff --git a/src/math/polysat/variable_elimination.cpp b/src/math/polysat/variable_elimination.cpp index 5adbe11b5..c969437f8 100644 --- a/src/math/polysat/variable_elimination.cpp +++ b/src/math/polysat/variable_elimination.cpp @@ -11,10 +11,11 @@ Author: Jakob Rath 2021-04-06 --*/ -#include "math/polysat/variable_elimination.h" #include "math/polysat/conflict.h" #include "math/polysat/clause_builder.h" +#include "math/polysat/saturation.h" #include "math/polysat/solver.h" +#include "math/polysat/variable_elimination.h" #include namespace polysat { @@ -252,7 +253,7 @@ namespace polysat { find_lemma(v, c, core); } } - + void free_variable_elimination::find_lemma(pvar v, signed_constraint c, conflict& core) { LOG_H3("Free Variable Elimination for v" << v << " using equation " << c); pdd const& p = c.eq(); @@ -380,7 +381,7 @@ namespace polysat { LOG("pv_lhs: " << pv_lhs); LOG("odd_fac_lhs: " << odd_fac_lhs); LOG("power_diff_lhs: " << power_diff_lhs); - new_lhs = -rest * *fac_odd_inv * power_diff_lhs * odd_fac_lhs + rest_rhs; + new_lhs = -rest * *fac_odd_inv * power_diff_lhs * odd_fac_lhs + rest_lhs; p1 = s.ule(get_dyadic_valuation(fac).first, get_dyadic_valuation(fac_lhs).first); } else { @@ -405,7 +406,7 @@ namespace polysat { } } - signed_constraint c_new = s.ule(new_lhs , new_rhs); + signed_constraint c_new = s.ule(new_lhs, new_rhs); if (c_target.is_negative()) c_new.negate(); @@ -524,5 +525,157 @@ namespace polysat { LOG("Found multiple: " << out); return is_multiple; } - + + unsigned parity_tracker::get_id(const pdd& p) { + // SASSERT(p.is_var()); // For now + // pvar v = p.var(); + unsigned id = m_pdd_to_id.get(optional(p), -1); + if (id == -1) { + id = m_pdd_to_id.size(); + m_pdd_to_id.insert(optional(p), id); + } + return id; + } + + pdd parity_tracker::get_inverse(const pdd &p) { + LOG("Getting inverse of " << p); + if (p.is_val()) { + SASSERT(p.val().is_odd()); + rational iv; + VERIFY(p.val().mult_inverse(p.power_of_2(), iv)); + return p.manager().mk_val(iv); + } + unsigned v = get_id(p); + if (m_inverse.size() > v && m_inverse[v] != -1) + return s.var(m_inverse[v]); + + pvar inv = s.add_var(p.power_of_2()); + pdd inv_pdd = p.manager().mk_var(inv); + m_inverse.setx(v, inv, -1); + s.add_clause(s.eq(inv_pdd * p, p.manager().one()), false); + return inv_pdd; + } + + pdd parity_tracker::get_odd(const pdd& p, unsigned parity, svector& path) { + LOG("Getting odd part of " << p); + if (p.is_val()) { + SASSERT(!p.val().is_zero()); + rational odd = machine_div(p.val(), rational::power_of_two(p.val().trailing_zeros())); + SASSERT(odd.is_odd()); + return p.manager().mk_val(odd); + } + unsigned v = get_id(p); + pvar odd_v; + bool needs_propagate = true; + + if (m_odd.size() > v && m_odd[v].initialized()) { + auto& tuple = *(m_odd[v]); + SASSERT(tuple.second.size() == p.power_of_2()); + odd_v = tuple.first; + needs_propagate = !tuple.second[parity]; + } + else { + odd_v = s.add_var(p.power_of_2()); + m_odd.setx(v, optional>({ odd_v, bool_vector(p.power_of_2(), false) }), optional>::undef()); + } + + m_builder.reset(); + m_builder.set_redundant(true); + + unsigned lower = 0, upper = p.power_of_2(); + // binary search for the parity (binary search instead of at_least_parity(p, parity) && at_most_parity(p, parity) for propagation if used with another parity + while (lower + 1 < upper) { + unsigned middle = (upper + lower) / 2; + signed_constraint c = s.parity_at_least(p, middle); // constraints are anyway cached and reused + LOG("Splitting on " << middle << " with " << parity); + if (parity >= middle) { + lower = middle; + path.push_back(~c); + if (needs_propagate) + m_builder.insert(~c); + } + else { + upper = middle; + path.push_back(c); + if (needs_propagate) + m_builder.insert(c); + } + LOG("Its in [" << lower << "; " << upper << ")"); + } + if (!needs_propagate) + return s.var(odd_v); + + (*m_odd[v]).second[parity] = true; + m_builder.insert(s.eq(rational::power_of_two(parity) * s.var(odd_v), p)); + clause_ref c = m_builder.build(); + s.add_clause(*c); + return s.var(odd_v); + } + + // a * x + b = 0 (x not in a or b; i.e., the equation is linear in x) + // C[p, ...] resp., C[..., p] + std::tuple> parity_tracker::eliminate_variable(saturation& saturation, pvar x, const pdd& a, const pdd& b, const pdd& p) { + + unsigned p_degree = p.degree(x); + if (p_degree == 0) + return { p, false, {} }; + if (a.is_val() && a.val().is_odd()) { // just invert and plug it in + rational a_inv; + VERIFY(a.val().mult_inverse(a.power_of_2(), a_inv)); + // this works as well if the degree of "p" is not 1: 3 x = a (mod 4) && x^2 <= b => (3a)^2 <= b + return { p.subst_pdd(x, -b * a_inv), true, {} }; + } + // from now on we require linear factors + if (p_degree != 1) + return { p, false, {} }; // TODO: Maybe fallback to brute-force + + pdd a1 = a.manager().zero(), b1 = a1, mul_fac = a1; + + p.factor(x, 1, a1, b1); + lbool is_multiple = saturation.get_multiple(a1, a, mul_fac); + if (is_multiple == l_false) + return { p, false, {} }; // there is no chance to invert + if (is_multiple == l_true) // we multiply with a factor to make them equal + return { b1 - b * mul_fac, true, {} }; + + #if 1 + return { p, false, {} }; + #else + if (!a1.is_var() && !a1.is_val()) { + //return { p, false, {} }; + LOG("Warning: Inverting " << a1 << " although it is not a single variable - might not be a good idea"); // TODO: Compromise: Maybe only monomials...? + } + if (!a.is_var() && !a.is_val()) { + //return { p, false, {} }; + LOG("Warning: Inverting " << a << " although it is not a single variable - might not be a good idea"); + } + + if (!a.is_monomial() || !a1.is_monomial()) + return { p , false, {} }; + + // We don't know whether it will work. Use the parity of the assignment + + unsigned a_parity; + unsigned a1_parity; + + if ((a_parity = saturation.min_parity(a)) != saturation.max_parity(a) || (a1_parity = saturation.min_parity(a1)) != saturation.max_parity(a1)) + return { p, false, {} }; // We need the parity, but we failed to get it precisely + + if (a_parity > a1_parity) { + SASSERT(false); // get_multiple should have excluded this case already + return { p, false, {} }; + } + + svector precondition; + + auto odd_a = get_odd(a, a_parity, precondition); + auto odd_a1 = get_odd(a1, a1_parity, precondition); + pdd inv_odd_a = get_inverse(odd_a); + + LOG("Forced elimination: " << odd_a1 * inv_odd_a * rational::power_of_two(a1_parity - a_parity) * b + b1); + verbose_stream() << "Forced elimination: " << odd_a1 * inv_odd_a * rational::power_of_two(a1_parity - a_parity) * b + b1 << "\n"; + verbose_stream() << "From: " << "eliminated v" << x << " with a = " << a << "; b = " << b << "; p = " << p << "\n"; + return { odd_a1 * inv_odd_a * rational::power_of_two(a1_parity - a_parity) * b + b1, true, {std::move(precondition)} }; +#endif + } } diff --git a/src/math/polysat/variable_elimination.h b/src/math/polysat/variable_elimination.h index fa898e758..3b1bde455 100644 --- a/src/math/polysat/variable_elimination.h +++ b/src/math/polysat/variable_elimination.h @@ -15,6 +15,7 @@ Author: #include "math/polysat/types.h" #include "math/polysat/constraint.h" +#include "math/polysat/clause_builder.h" namespace polysat { @@ -50,6 +51,36 @@ namespace polysat { public: free_variable_elimination(solver& s): s(s) {} void find_lemma(conflict& core); -}; - + }; + + class saturation; + + class parity_tracker { + + solver& s; + clause_builder m_builder; + + vector>> m_odd; + unsigned_vector m_inverse; + + struct optional_pdd_hash { + unsigned operator()(optional const& args) const { + return args->hash(); + } + }; + using pdd_to_id = map, unsigned, optional_pdd_hash, default_eq>>; + + pdd_to_id m_pdd_to_id; // if we want to use arbitrary pdds instead of pvars + + unsigned get_id(const pdd& p); + + public: + + parity_tracker(solver& s) : s(s), m_builder(s) {} + + pdd get_inverse(const pdd& p); + pdd get_odd(const pdd& p, unsigned parity, svector& pat); + + std::tuple> eliminate_variable(saturation& saturation, pvar x, const pdd& a, const pdd& b, const pdd& p); + }; } diff --git a/src/test/polysat.cpp b/src/test/polysat.cpp index efcc806f2..4fc73abab 100644 --- a/src/test/polysat.cpp +++ b/src/test/polysat.cpp @@ -724,7 +724,7 @@ namespace polysat { pdd x = s.var(s.add_var(bw)); pdd y = s.var(s.add_var(bw)); s.add_eq(x * y + 2); - s.add_clause({ s.parity(y, 4), s.parity(y, 8) }, false); + s.add_clause({ s.parity_at_least(y, 4), s.parity_at_least(y, 8) }, false); s.check(); s.expect_unsat(); }