3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-22 16:45:31 +00:00

Merge branch 'polysat' of https://github.com/z3prover/z3 into polysat

This commit is contained in:
Nikolaj Bjorner 2022-12-29 16:55:56 -08:00
commit ed76da1458
7 changed files with 235 additions and 65 deletions

View file

@ -30,7 +30,7 @@ TODO: when we check that 'x' is "unary":
namespace polysat {
saturation::saturation(solver& s) : s(s), m_lemma(s) {}
saturation::saturation(solver& s) : s(s), m_lemma(s), m_parity_tracker(s) {}
void saturation::log_lemma(pvar v, conflict& core) {
IF_VERBOSE(1, auto const& cl = core.lemmas().back();
@ -880,7 +880,7 @@ namespace polysat {
}
for (unsigned j = N; j > 0; --j)
if (is_forced_true(s.parity(p, j)))
if (is_forced_true(s.parity_at_least(p, j)))
return j;
return 0;
}
@ -971,7 +971,7 @@ namespace polysat {
auto at_least = [&](pdd const& p, unsigned k) {
VERIFY(k != 0);
return s.parity(p, k);
return s.parity_at_least(p, k);
};
@ -1020,7 +1020,7 @@ namespace polysat {
m_lemma.reset();
m_lemma.insert_eval(~s.eq(y));
m_lemma.insert_eval(~s.eq(b));
if (propagate(x, core, axb_l_y, ~s.parity(X, N - k)))
if (propagate(x, core, axb_l_y, ~s.parity_at_least(X, N - k)))
return true;
// TODO parity on a (without leading coefficient?)
}
@ -1135,7 +1135,7 @@ namespace polysat {
lbool saturation::get_multiple(const pdd& p1, const pdd& p2, pdd& out) {
LOG("Check if " << p2 << " can be multiplied with something to get " << p1);
if (p1.is_zero()) {
if (p1.is_zero()) { // TODO: use the evaluated parity (max_parity) instead?
out = p1.manager().zero();
return l_true;
}
@ -1202,40 +1202,8 @@ namespace polysat {
if (!is_AxB_eq_0(x, a_l_b, a, b, y)) // TODO: Is the restriction to linear "x" too restrictive?
return false;
bool is_invertible = a.is_val() && a.val().is_odd();
if (is_invertible) {
rational a_inv;
VERIFY(a.val().mult_inverse(m.power_of_2(), a_inv));
b = -b * a_inv;
}
bool change = false;
bool prop = false;
auto replace = [&](pdd p) {
unsigned p_degree = p.degree(x);
if (p_degree == 0)
return p;
if (is_invertible) {
change = true;
// this works as well if the degree of "p" is not 1: 3 x = a (mod 4) & x^2 <= b => (3a)^2 <= b
return p.subst_pdd(x, b);
}
if (p_degree != 1)
return p; // TODO: Maybe fallback to brute-force
p.factor(x, 1, a1, b1);
lbool is_multiple = get_multiple(a1, a, mul_fac);
if (is_multiple == l_false)
return p; // there is no chance to invert
if (is_multiple == l_true) {
change = true;
return b1 - b * mul_fac;
}
// We don't know whether it will work. Brute-force the parity
// TODO: Brute force goes here
return p;
};
for (auto c : core) {
change = false;
@ -1243,27 +1211,38 @@ namespace polysat {
continue;
LOG("Trying to eliminate v" << x << " in " << c << " by using equation " << a_l_b.as_signed_constraint());
if (c->is_ule()) {
// If both are equalities this boils down to polynomial superposition => Might generate the same lemma twice
auto const& ule = c->to_ule();
auto p = replace(ule.lhs());
auto q = replace(ule.rhs());
if (!change)
continue;
auto [lhs_new, changed_lhs, side_condition_lhs] = m_parity_tracker.eliminate_variable(*this, x, a, b, ule.lhs());
auto [rhs_new, changed_rhs, side_condition_rhs] = m_parity_tracker.eliminate_variable(*this, x, a, b, ule.rhs());
if (!changed_lhs && !changed_rhs)
continue; // nothing changed - no reason for propagating lemmas
m_lemma.reset();
m_lemma.insert(~c);
m_lemma.insert_eval(~s.eq(y));
if (propagate(x, core, a_l_b, c.is_positive() ? s.ule(p, q) : ~s.ule(p, q)))
for (auto& sc_lhs : side_condition_lhs) // TODO: Do we really need the path as a side-condition in case of parity elimination?
m_lemma.insert(sc_lhs);
for (auto& sc_rhs : side_condition_rhs)
m_lemma.insert(sc_rhs);
if (propagate(x, core, a_l_b, c.is_positive() ? s.ule(lhs_new, rhs_new) : ~s.ule(lhs_new, rhs_new)))
prop = true;
}
else if (c->is_umul_ovfl()) {
auto const& ovf = c->to_umul_ovfl();
auto p = replace(ovf.p());
auto q = replace(ovf.q());
if (!change)
auto [lhs_new, changed_lhs, side_condition_lhs] = m_parity_tracker.eliminate_variable(*this, x, a, b, ovf.p());
auto [rhs_new, changed_rhs, side_condition_rhs] = m_parity_tracker.eliminate_variable(*this, x, a, b, ovf.q());
if (!changed_lhs && !changed_rhs)
continue;
m_lemma.reset();
m_lemma.insert(~c);
m_lemma.insert_eval(~s.eq(y));
if (propagate(x, core, a_l_b, c.is_positive() ? s.umul_ovfl(p, q) : ~s.umul_ovfl(p, q)))
for (auto& sc_lhs : side_condition_lhs)
m_lemma.insert(sc_lhs);
for (auto& sc_rhs : side_condition_rhs)
m_lemma.insert(sc_rhs);
if (propagate(x, core, a_l_b, c.is_positive() ? s.umul_ovfl(lhs_new, rhs_new) : ~s.umul_ovfl(lhs_new, rhs_new)))
prop = true;
}
}

View file

@ -14,6 +14,7 @@ Author:
#pragma once
#include "math/polysat/clause_builder.h"
#include "math/polysat/conflict.h"
#include "math/polysat/variable_elimination.h"
namespace polysat {
@ -22,10 +23,13 @@ namespace polysat {
*/
class saturation {
friend class parity_tracker;
solver& s;
clause_builder m_lemma;
char const* m_rule = nullptr;
parity_tracker m_parity_tracker;
unsigned_vector m_occ;
unsigned_vector m_occ_cnt;

View file

@ -19,6 +19,7 @@ Author:
#include "math/polysat/solver.h"
#include "math/polysat/log.h"
#include "math/polysat/polysat_params.hpp"
#include "math/polysat/variable_elimination.h"
#include <variant>
// For development; to be removed once the linear solver works well enough
@ -238,12 +239,12 @@ namespace polysat {
LOG_H2("Propagate " << assignment_pp(*this, v, get_value(v)));
SASSERT(!m_locked_wlist);
DEBUG_CODE(m_locked_wlist = v;);
unsigned i = 0, j = 0;
for (; i < m_pwatch[v].size() && !is_conflict(); ++i)
if (!propagate(v, m_pwatch[v][i])) // propagate may change watch-list reference
m_pwatch[v][j++] = m_pwatch[v][i];
auto& wlist = m_pwatch[v];
unsigned i = 0, j = 0, sz = wlist.size();
for (; i < sz && !is_conflict(); ++i)
if (!propagate(v, wlist[i]))
wlist[j++] = wlist[i];
for (; i < sz; ++i)
for (; i < wlist.size(); ++i)
wlist[j++] = wlist[i];
wlist.shrink(j);
if (is_conflict())
@ -435,6 +436,7 @@ namespace polysat {
#if ENABLE_LINEAR_SOLVER
m_linear_solver.push();
#endif
m_fixed_bits.push();
}
void solver::pop_levels(unsigned num_levels) {
@ -448,6 +450,8 @@ namespace polysat {
#if ENABLE_LINEAR_SOLVER
m_linear_solver.pop(num_levels);
#endif
m_fixed_bits.pop();
while (num_levels > 0) {
switch (m_trail.back()) {
case trail_instr_t::qhead_i: {
@ -602,7 +606,7 @@ namespace polysat {
}
}
#endif
m_fixed_bits.push();
if (can_bdecide())
bdecide();
else
@ -833,7 +837,6 @@ namespace polysat {
continue;
}
if (j.is_decision()) {
m_fixed_bits.pop();
m_conflict.revert_pvar(v);
revert_decision(v);
return;
@ -862,7 +865,6 @@ namespace polysat {
}
SASSERT(!m_bvars.is_assumption(var)); // TODO: "assumption" is basically "propagated by unit clause" (or "at base level"); except we do not explicitly store the unit clause.
if (m_bvars.is_decision(var)) {
m_fixed_bits.pop();
revert_bool_decision(lit);
return;
}

View file

@ -137,6 +137,7 @@ namespace polysat {
friend class ex_polynomial_superposition;
friend class free_variable_elimination;
friend class saturation;
friend class parity_tracker;
friend class constraint_manager;
friend class scoped_solverv;
friend class test_polysat;
@ -422,9 +423,9 @@ namespace polysat {
signed_constraint eq(pdd const& p, rational const& q) { return eq(p - q); }
signed_constraint eq(pdd const& p, unsigned q) { return eq(p - q); }
signed_constraint odd(pdd const& p) { return ~even(p); }
signed_constraint even(pdd const& p) { return parity(p, 1); }
signed_constraint even(pdd const& p) { return parity_at_least(p, 1); }
/** parity(p) >= k */
signed_constraint parity(pdd const& p, unsigned k) { // TODO: rename to parity_at_least?
signed_constraint parity_at_least(pdd const& p, unsigned k) {
unsigned N = p.manager().power_of_2();
// parity(p) >= k
// <=> p * 2^(N - k) == 0
@ -449,7 +450,7 @@ namespace polysat {
return eq(p.manager().zero());
}
else
return ~parity(p, k + 1);
return ~parity_at_least(p, k + 1);
}
signed_constraint diseq(pdd const& p, rational const& q) { return diseq(p - q); }
signed_constraint diseq(pdd const& p, unsigned q) { return diseq(p - q); }

View file

@ -11,10 +11,11 @@ Author:
Jakob Rath 2021-04-06
--*/
#include "math/polysat/variable_elimination.h"
#include "math/polysat/conflict.h"
#include "math/polysat/clause_builder.h"
#include "math/polysat/saturation.h"
#include "math/polysat/solver.h"
#include "math/polysat/variable_elimination.h"
#include <algorithm>
namespace polysat {
@ -252,7 +253,7 @@ namespace polysat {
find_lemma(v, c, core);
}
}
void free_variable_elimination::find_lemma(pvar v, signed_constraint c, conflict& core) {
LOG_H3("Free Variable Elimination for v" << v << " using equation " << c);
pdd const& p = c.eq();
@ -380,7 +381,7 @@ namespace polysat {
LOG("pv_lhs: " << pv_lhs);
LOG("odd_fac_lhs: " << odd_fac_lhs);
LOG("power_diff_lhs: " << power_diff_lhs);
new_lhs = -rest * *fac_odd_inv * power_diff_lhs * odd_fac_lhs + rest_rhs;
new_lhs = -rest * *fac_odd_inv * power_diff_lhs * odd_fac_lhs + rest_lhs;
p1 = s.ule(get_dyadic_valuation(fac).first, get_dyadic_valuation(fac_lhs).first);
}
else {
@ -405,7 +406,7 @@ namespace polysat {
}
}
signed_constraint c_new = s.ule(new_lhs , new_rhs);
signed_constraint c_new = s.ule(new_lhs, new_rhs);
if (c_target.is_negative())
c_new.negate();
@ -524,5 +525,157 @@ namespace polysat {
LOG("Found multiple: " << out);
return is_multiple;
}
unsigned parity_tracker::get_id(const pdd& p) {
// SASSERT(p.is_var()); // For now
// pvar v = p.var();
unsigned id = m_pdd_to_id.get(optional(p), -1);
if (id == -1) {
id = m_pdd_to_id.size();
m_pdd_to_id.insert(optional(p), id);
}
return id;
}
pdd parity_tracker::get_inverse(const pdd &p) {
LOG("Getting inverse of " << p);
if (p.is_val()) {
SASSERT(p.val().is_odd());
rational iv;
VERIFY(p.val().mult_inverse(p.power_of_2(), iv));
return p.manager().mk_val(iv);
}
unsigned v = get_id(p);
if (m_inverse.size() > v && m_inverse[v] != -1)
return s.var(m_inverse[v]);
pvar inv = s.add_var(p.power_of_2());
pdd inv_pdd = p.manager().mk_var(inv);
m_inverse.setx(v, inv, -1);
s.add_clause(s.eq(inv_pdd * p, p.manager().one()), false);
return inv_pdd;
}
pdd parity_tracker::get_odd(const pdd& p, unsigned parity, svector<signed_constraint>& path) {
LOG("Getting odd part of " << p);
if (p.is_val()) {
SASSERT(!p.val().is_zero());
rational odd = machine_div(p.val(), rational::power_of_two(p.val().trailing_zeros()));
SASSERT(odd.is_odd());
return p.manager().mk_val(odd);
}
unsigned v = get_id(p);
pvar odd_v;
bool needs_propagate = true;
if (m_odd.size() > v && m_odd[v].initialized()) {
auto& tuple = *(m_odd[v]);
SASSERT(tuple.second.size() == p.power_of_2());
odd_v = tuple.first;
needs_propagate = !tuple.second[parity];
}
else {
odd_v = s.add_var(p.power_of_2());
m_odd.setx(v, optional<std::pair<pvar, bool_vector>>({ odd_v, bool_vector(p.power_of_2(), false) }), optional<std::pair<pvar, bool_vector>>::undef());
}
m_builder.reset();
m_builder.set_redundant(true);
unsigned lower = 0, upper = p.power_of_2();
// binary search for the parity (binary search instead of at_least_parity(p, parity) && at_most_parity(p, parity) for propagation if used with another parity
while (lower + 1 < upper) {
unsigned middle = (upper + lower) / 2;
signed_constraint c = s.parity_at_least(p, middle); // constraints are anyway cached and reused
LOG("Splitting on " << middle << " with " << parity);
if (parity >= middle) {
lower = middle;
path.push_back(~c);
if (needs_propagate)
m_builder.insert(~c);
}
else {
upper = middle;
path.push_back(c);
if (needs_propagate)
m_builder.insert(c);
}
LOG("Its in [" << lower << "; " << upper << ")");
}
if (!needs_propagate)
return s.var(odd_v);
(*m_odd[v]).second[parity] = true;
m_builder.insert(s.eq(rational::power_of_two(parity) * s.var(odd_v), p));
clause_ref c = m_builder.build();
s.add_clause(*c);
return s.var(odd_v);
}
// a * x + b = 0 (x not in a or b; i.e., the equation is linear in x)
// C[p, ...] resp., C[..., p]
std::tuple<pdd, bool, svector<signed_constraint>> parity_tracker::eliminate_variable(saturation& saturation, pvar x, const pdd& a, const pdd& b, const pdd& p) {
unsigned p_degree = p.degree(x);
if (p_degree == 0)
return { p, false, {} };
if (a.is_val() && a.val().is_odd()) { // just invert and plug it in
rational a_inv;
VERIFY(a.val().mult_inverse(a.power_of_2(), a_inv));
// this works as well if the degree of "p" is not 1: 3 x = a (mod 4) && x^2 <= b => (3a)^2 <= b
return { p.subst_pdd(x, -b * a_inv), true, {} };
}
// from now on we require linear factors
if (p_degree != 1)
return { p, false, {} }; // TODO: Maybe fallback to brute-force
pdd a1 = a.manager().zero(), b1 = a1, mul_fac = a1;
p.factor(x, 1, a1, b1);
lbool is_multiple = saturation.get_multiple(a1, a, mul_fac);
if (is_multiple == l_false)
return { p, false, {} }; // there is no chance to invert
if (is_multiple == l_true) // we multiply with a factor to make them equal
return { b1 - b * mul_fac, true, {} };
#if 1
return { p, false, {} };
#else
if (!a1.is_var() && !a1.is_val()) {
//return { p, false, {} };
LOG("Warning: Inverting " << a1 << " although it is not a single variable - might not be a good idea"); // TODO: Compromise: Maybe only monomials...?
}
if (!a.is_var() && !a.is_val()) {
//return { p, false, {} };
LOG("Warning: Inverting " << a << " although it is not a single variable - might not be a good idea");
}
if (!a.is_monomial() || !a1.is_monomial())
return { p , false, {} };
// We don't know whether it will work. Use the parity of the assignment
unsigned a_parity;
unsigned a1_parity;
if ((a_parity = saturation.min_parity(a)) != saturation.max_parity(a) || (a1_parity = saturation.min_parity(a1)) != saturation.max_parity(a1))
return { p, false, {} }; // We need the parity, but we failed to get it precisely
if (a_parity > a1_parity) {
SASSERT(false); // get_multiple should have excluded this case already
return { p, false, {} };
}
svector<signed_constraint> precondition;
auto odd_a = get_odd(a, a_parity, precondition);
auto odd_a1 = get_odd(a1, a1_parity, precondition);
pdd inv_odd_a = get_inverse(odd_a);
LOG("Forced elimination: " << odd_a1 * inv_odd_a * rational::power_of_two(a1_parity - a_parity) * b + b1);
verbose_stream() << "Forced elimination: " << odd_a1 * inv_odd_a * rational::power_of_two(a1_parity - a_parity) * b + b1 << "\n";
verbose_stream() << "From: " << "eliminated v" << x << " with a = " << a << "; b = " << b << "; p = " << p << "\n";
return { odd_a1 * inv_odd_a * rational::power_of_two(a1_parity - a_parity) * b + b1, true, {std::move(precondition)} };
#endif
}
}

View file

@ -15,6 +15,7 @@ Author:
#include "math/polysat/types.h"
#include "math/polysat/constraint.h"
#include "math/polysat/clause_builder.h"
namespace polysat {
@ -50,6 +51,36 @@ namespace polysat {
public:
free_variable_elimination(solver& s): s(s) {}
void find_lemma(conflict& core);
};
};
class saturation;
class parity_tracker {
solver& s;
clause_builder m_builder;
vector<optional<std::pair<pvar, bool_vector>>> m_odd;
unsigned_vector m_inverse;
struct optional_pdd_hash {
unsigned operator()(optional<pdd> const& args) const {
return args->hash();
}
};
using pdd_to_id = map<optional<pdd>, unsigned, optional_pdd_hash, default_eq<optional<pdd>>>;
pdd_to_id m_pdd_to_id; // if we want to use arbitrary pdds instead of pvars
unsigned get_id(const pdd& p);
public:
parity_tracker(solver& s) : s(s), m_builder(s) {}
pdd get_inverse(const pdd& p);
pdd get_odd(const pdd& p, unsigned parity, svector<signed_constraint>& pat);
std::tuple<pdd, bool, svector<signed_constraint>> eliminate_variable(saturation& saturation, pvar x, const pdd& a, const pdd& b, const pdd& p);
};
}

View file

@ -724,7 +724,7 @@ namespace polysat {
pdd x = s.var(s.add_var(bw));
pdd y = s.var(s.add_var(bw));
s.add_eq(x * y + 2);
s.add_clause({ s.parity(y, 4), s.parity(y, 8) }, false);
s.add_clause({ s.parity_at_least(y, 4), s.parity_at_least(y, 8) }, false);
s.check();
s.expect_unsat();
}