3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-22 00:26:38 +00:00

working on viable

This commit is contained in:
Nikolaj Bjorner 2023-12-09 13:10:47 -08:00
parent 94ba85bb12
commit 0c2ecf8b90
15 changed files with 530 additions and 3 deletions

View file

@ -1,5 +1,6 @@
z3_add_component(polysat
SOURCES
fixed_bits.cpp
polysat_assignment.cpp
polysat_constraints.cpp
polysat_core.cpp

View file

@ -0,0 +1,180 @@
/*++
Copyright (c) 2022 Microsoft Corporation
Module Name:
Extract fixed bits from constraints
Author:
Jakob Rath, Nikolaj Bjorner (nbjorner), Clemens Eisenhofer 2022-08-22
--*/
#include "sat/smt/polysat/fixed_bits.h"
#include "sat/smt/polysat/polysat_ule.h"
namespace polysat {
/**
* 2^k * x = 2^k * b
* ==> x[N-k-1:0] = b[N-k-1:0]
*/
bool get_eq_fixed_lsb(pdd const& p, fixed_bits& out) {
SASSERT(!p.is_val());
unsigned const N = p.power_of_2();
// Recognize p = 2^k * a * x - 2^k * b
if (!p.hi().is_val())
return false;
if (!p.lo().is_val())
return false;
// p = c * x - d
rational const c = p.hi().val();
rational const d = (-p.lo()).val();
SASSERT(!c.is_zero());
#if 1
// NOTE: ule_constraint::simplify removes odd factors of the leading term
unsigned k;
VERIFY(c.is_power_of_two(k));
if (d.parity(N) < k)
return false;
rational const b = machine_div2k(d, k);
out = fixed_bits(N - k - 1, 0, b);
SASSERT_EQ(d, b * rational::power_of_two(k));
SASSERT_EQ(p, (p.manager().mk_var(p.var()) - out.value) * rational::power_of_two(k));
return true;
#else
// branch if we want to support non-simplifed constraints (not recommended)
//
// 2^k * a * x = 2^k * b
// ==> x[N-k-1:0] = a^-1 * b[N-k-1:0]
// for odd a
unsigned k = c.parity(N);
if (d.parity(N) < k)
return false;
rational const a = machine_div2k(c, k);
SASSERT(a.is_odd());
SASSERT(a.is_one()); // TODO: ule-simplify will multiply with a_inv already, so we can drop the check here.
rational a_inv;
VERIFY(a.mult_inverse(N, a_inv));
rational const b = machine_div2k(d, k);
out.hi = N - k - 1;
out.lo = 0;
out.value = a_inv * b;
SASSERT_EQ(p, (p.manager().mk_var(p.var()) - out.value) * a * rational::power_of_two(k));
return true;
#endif
}
bool get_eq_fixed_bits(pdd const& p, fixed_bits& out) {
if (get_eq_fixed_lsb(p, out))
return true;
return false;
}
/**
* Constraint lhs <= rhs.
*
* -2^(k - 2) * x > 2^(k - 1)
* <=> 2 + x[1:0] > 2 (mod 4)
* ==> x[1:0] = 1
* -- TODO: Generalize [the obvious solution does not work]
*/
bool get_ule_fixed_lsb(pdd const& lhs, pdd const& rhs, bool is_positive, fixed_bits& out) {
return false;
}
/**
* Constraint lhs <= rhs.
*
* x <= 2^k - 1 ==> x[N-1:k] = 0
* x < 2^k ==> x[N-1:k] = 0
*/
bool get_ule_fixed_msb(pdd const& p, pdd const& q, bool is_positive, fixed_bits& out) {
SASSERT(!q.is_zero()); // equalities are handled elsewhere
unsigned const N = p.power_of_2();
pdd const& lhs = is_positive ? p : q;
pdd const& rhs = is_positive ? q : p;
bool const is_strict = !is_positive;
if (lhs.is_var() && rhs.is_val()) {
// x <= c
// find smallest k such that c <= 2^k - 1, i.e., c+1 <= 2^k
// ==> x <= 2^k - 1 ==> x[N-1:k] = 0
//
// x < c
// find smallest k such that c <= 2^k
// ==> x < 2^k ==> x[N-1:k] = 0
rational const c = is_strict ? rhs.val() : (rhs.val() + 1);
unsigned const k = c.next_power_of_two();
if (k < N) {
out.hi = N - 1;
out.lo = k;
out.value = 0;
return true;
}
}
return false;
}
// 2^(N-1) <= 2^(N-1-i) * x
bool get_ule_fixed_bit(pdd const& p, pdd const& q, bool is_positive, fixed_bits& out) {
return false;
}
bool get_ule_fixed_bits(pdd const& lhs, pdd const& rhs, bool is_positive, fixed_bits& out) {
SASSERT(ule_constraint::is_simplified(lhs, rhs));
if (rhs.is_zero())
return is_positive ? get_eq_fixed_bits(lhs, out) : false;
if (get_ule_fixed_msb(lhs, rhs, is_positive, out))
return true;
if (get_ule_fixed_lsb(lhs, rhs, is_positive, out))
return true;
if (get_ule_fixed_bit(lhs, rhs, is_positive, out))
return true;
return false;
}
bool get_fixed_bits(signed_constraint c, fixed_bits& out) {
SASSERT_EQ(c.vars().size(), 1); // this only makes sense for univariate constraints
if (c.is_ule())
return get_ule_fixed_bits(c.to_ule().lhs(), c.to_ule().rhs(), c.is_positive(), out);
// if (c->is_op())
// ; // TODO: x & constant = constant ==> bitmask ... but we have trouble recognizing that because we introduce a new variable for '&' before we see the equality.
return false;
}
/*
// 2^(k - d) * x = m * 2^(k - d)
// Special case [still seems to occur frequently]: -2^(k - 2) * x > 2^(k - 1) - TODO: Generalize [the obvious solution does not work] => lsb(x, 2) = 1
bool get_lsb(pdd lhs, pdd rhs, pdd& p, trailing_bits& info, bool pos) {
SASSERT(lhs.is_univariate() && lhs.degree() <= 1);
SASSERT(rhs.is_univariate() && rhs.degree() <= 1);
else { // inequality - check for special case
if (pos || lhs.power_of_2() < 3)
return false;
auto it = lhs.begin();
if (it == lhs.end())
return false;
if (it->vars.size() != 1)
return false;
rational coeff = it->coeff;
it++;
if (it != lhs.end())
return false;
if ((mod2k(-coeff, lhs.power_of_2())) != rational::power_of_two(lhs.power_of_2() - 2))
return false;
p = lhs.div(coeff);
SASSERT(p.is_var());
info.bits = 1;
info.length = 2;
info.positive = true; // this is a conjunction
return true;
}
}
*/
} // namespace polysat

View file

@ -0,0 +1,31 @@
/*++
Copyright (c) 2022 Microsoft Corporation
Module Name:
Extract fixed bits of variables from univariate constraints
Author:
Jakob Rath, Nikolaj Bjorner (nbjorner), Clemens Eisenhofer 2022-08-22
--*/
#pragma once
#include "sat/smt/polysat/polysat_types.h"
#include "sat/smt/polysat/polysat_constraints.h"
#include "util/vector.h"
namespace polysat {
using fixed_bits_vector = vector<fixed_bits>;
bool get_eq_fixed_lsb(pdd const& p, fixed_bits& out);
bool get_eq_fixed_bits(pdd const& p, fixed_bits& out);
bool get_ule_fixed_lsb(pdd const& lhs, pdd const& rhs, bool is_positive, fixed_bits& out);
bool get_ule_fixed_msb(pdd const& lhs, pdd const& rhs, bool is_positive, fixed_bits& out);
bool get_ule_fixed_bit(pdd const& lhs, pdd const& rhs, bool is_positive, fixed_bits& out);
bool get_ule_fixed_bits(pdd const& lhs, pdd const& rhs, bool is_positive, fixed_bits& out);
bool get_fixed_bits(signed_constraint c, fixed_bits& out);
}

View file

@ -282,6 +282,10 @@ namespace polysat {
s.get_bitvector_prefixes(v, out);
}
void core::get_fixed_bits(pvar v, svector<justified_fixed_bits>& fixed_bits) {
s.get_fixed_bits(v, fixed_bits);
}
bool core::inconsistent() const {
return s.inconsistent();
}

View file

@ -80,7 +80,9 @@ namespace polysat {
void propagate_unsat_core();
void get_bitvector_prefixes(pvar v, pvar_vector& out);
void get_fixed_bits(pvar v, svector<justified_fixed_bits>& fixed_bits);
bool inconsistent() const;
void add_watch(unsigned idx, unsigned var);

View file

@ -52,6 +52,39 @@ namespace polysat {
return out << "v" << d.eq().first << " == v" << d.eq().second << "@" << d.level();
}
struct trailing_bits {
unsigned length;
rational bits;
bool positive;
unsigned src_idx;
};
struct leading_bits {
unsigned length;
bool positive; // either all 0 or all 1
unsigned src_idx;
};
struct single_bit {
bool positive;
unsigned position;
unsigned src_idx;
};
struct fixed_bits {
unsigned hi = 0;
unsigned lo = 0;
rational value;
/// The constraint is equivalent to setting fixed bits on a variable.
// bool is_equivalent;
fixed_bits() = default;
fixed_bits(unsigned hi, unsigned lo, rational value) : hi(hi), lo(lo), value(value) {}
};
struct justified_fixed_bits : public fixed_bits, public dependency {};
using dependency_vector = vector<dependency>;
class signed_constraint;
@ -66,6 +99,7 @@ namespace polysat {
virtual trail_stack& trail() = 0;
virtual bool inconsistent() const = 0;
virtual void get_bitvector_prefixes(pvar v, pvar_vector& out) = 0;
virtual void get_fixed_bits(pvar v, svector<justified_fixed_bits>& fixed_bits) = 0;
};
}

View file

@ -20,6 +20,7 @@ Notes:
#include "util/log.h"
#include "sat/smt/polysat/polysat_viable.h"
#include "sat/smt/polysat/polysat_core.h"
#include "sat/smt/polysat/polysat_ule.h"
namespace polysat {
@ -420,6 +421,206 @@ namespace polysat {
return l_undef;
}
// returns true iff no conflict was encountered
bool viable::collect_bit_information(pvar v, bool add_conflict, fixed_bits_info& out_fbi) {
pdd p = c.var(v);
unsigned const v_sz = c.size(v);
out_fbi.reset(v_sz);
auto& [fixed, just_src, just_side_cond, just_slice] = out_fbi;
svector<justified_fixed_bits> fbs;
c.get_fixed_bits(v, fbs);
for (auto const& fb : fbs) {
LOG("slicing fixed bits: v" << v << "[" << fb.hi << ":" << fb.lo << "] = " << fb.value);
for (unsigned i = fb.lo; i <= fb.hi; ++i) {
SASSERT(out_fbi.just_src[i].empty()); // since we don't get overlapping ranges from collect_fixed.
SASSERT(out_fbi.just_side_cond[i].empty());
SASSERT(out_fbi.just_slicing[i].empty());
out_fbi.fixed[i] = to_lbool(fb.value.get_bit(i - fb.lo));
out_fbi.just_slicing[i].push_back(fb);
}
}
entry* e1 = m_equal_lin[v];
entry* e2 = m_units[v].get_entries(c.size(v)); // TODO: take other widths into account (will be done automatically by tracking fixed bits in the slicing egraph)
entry* first = e1;
if (!e1 && !e2)
return true;
#if 0
clause_builder builder(s, "bit check");
sat::literal_set added;
vector<std::pair<entry*, trailing_bits>> postponed;
auto add_literal = [&builder, &added](sat::literal lit) {
if (added.contains(lit))
return;
added.insert(lit);
builder.insert_eval(~lit);
};
auto add_literals = [&add_literal](sat::literal_vector const& lits) {
for (sat::literal lit : lits)
add_literal(lit);
};
auto add_entry = [&add_literal](entry* e) {
for (const auto& sc : e->side_cond)
add_literal(sc.blit());
for (const auto& src : e->src)
add_literal(src.blit());
};
auto add_slicing = [this, &add_literal](slicing::enode* n) {
s.m_slicing.explain_fixed(n, [&](sat::literal lit) {
add_literal(lit);
}, [&](pvar v) {
LOG("from slicing: v" << v);
add_literal(s.cs().eq(c.var(v), c.get_value(v)).blit());
});
};
auto add_bit_justification = [&add_literals, &add_slicing](fixed_bits_info const& fbi, unsigned i) {
add_literals(fbi.just_src[i]);
add_literals(fbi.just_side_cond[i]);
for (slicing::enode* n : fbi.just_slicing[i])
add_slicing(n);
};
if (e1) {
unsigned largest_lsb = 0;
do {
if (e1->src.size() != 1) {
// We just consider the ordinary constraints and not already contracted ones
e1 = e1->next();
continue;
}
signed_constraint& src = e1->src[0];
single_bit bit;
trailing_bits lsb;
if (src.is_ule() &&
simplify_clause::get_bit(s.subst(src.to_ule().lhs()), s.subst(src.to_ule().rhs()), p, bit, src.is_positive()) && p.is_var()) {
lbool prev = fixed[bit.position];
fixed[bit.position] = to_lbool(bit.positive);
//verbose_stream() << "Setting bit " << bit.position << " to " << bit.positive << " because of " << e->src << "\n";
if (prev != l_undef && fixed[bit.position] != prev) {
// LOG("Bit conflicting " << e1->src << " with " << just_src[bit.position][0]); // NOTE: just_src may be empty if the justification is by slicing
if (add_conflict) {
add_bit_justification(out_fbi, bit.position);
add_entry(e1);
s.set_conflict(*builder.build());
}
return false;
}
// just override; we prefer bit constraints over parity as those are easier for subsumption to remove
// do we just introduce a new justification here that subsumption will remove anyway?
// the only way it will not is if all bits are overwritten like this.
// but in that case we basically replace one parity constraint by multiple bit constraints?
// verbose_stream() << "Adding bit constraint: " << e->src[0] << " (" << bit.position << ")\n";
if (prev == l_undef) {
out_fbi.set_just(bit.position, e1);
}
}
else if (src.is_eq() &&
simplify_clause::get_lsb(s.subst(src.to_ule().lhs()), s.subst(src.to_ule().rhs()), p, lsb, src.is_positive()) && p.is_var()) {
if (src.is_positive()) {
for (unsigned i = 0; i < lsb.length; i++) {
lbool prev = fixed[i];
fixed[i] = to_lbool(lsb.bits.get_bit(i));
if (prev == l_undef) {
SASSERT(just_src[i].empty());
out_fbi.set_just(i, e1);
continue;
}
if (fixed[i] != prev) {
// LOG("Positive parity conflicting " << e1->src << " with " << just_src[i][0]); // NOTE: just_src may be empty if the justification is by slicing
if (add_conflict) {
add_bit_justification(out_fbi, i);
add_entry(e1);
s.set_conflict(*builder.build());
}
return false;
}
// Prefer justifications from larger masks (less premises)
// TODO: Check that we don't override justifications coming from bit constraints
if (largest_lsb < lsb.length)
out_fbi.set_just(i, e1);
}
largest_lsb = std::max(largest_lsb, lsb.length);
}
else
postponed.push_back({ e1, lsb });
}
e1 = e1->next();
} while (e1 != first);
}
// so far every bit is justified by a single constraint
SASSERT(all_of(just_src, [](auto const& vec) { return vec.size() <= 1; }));
// TODO: Incomplete - e.g., if we know the trailing bits are not 00 not 10 not 01 and not 11 we could also detect a conflict
// This would require partially clause solving (worth the effort?)
bool_vector removed(postponed.size(), false);
bool changed;
do { // fixed-point required?
changed = false;
for (unsigned j = 0; j < postponed.size(); j++) {
if (removed[j])
continue;
const auto& neg = postponed[j];
unsigned indet = 0;
unsigned last_indet = 0;
unsigned i = 0;
for (; i < neg.second.length; i++) {
if (fixed[i] != l_undef) {
if (fixed[i] != to_lbool(neg.second.bits.get_bit(i))) {
removed[j] = true;
break; // this is already satisfied
}
}
else {
indet++;
last_indet = i;
}
}
if (i == neg.second.length) {
if (indet == 0) {
// Already false
LOG("Found conflict with constraint " << neg.first->src);
if (add_conflict) {
for (unsigned k = 0; k < neg.second.length; k++)
add_bit_justification(out_fbi, k);
add_entry(neg.first);
s.set_conflict(*builder.build());
}
return false;
}
else if (indet == 1) {
// Simple BCP
SASSERT(just_src[last_indet].empty());
SASSERT(just_side_cond[last_indet].empty());
for (unsigned k = 0; k < neg.second.length; k++) {
if (k != last_indet) {
SASSERT(fixed[k] != l_undef);
out_fbi.push_from_bit(last_indet, k);
}
}
out_fbi.push_just(last_indet, neg.first);
fixed[last_indet] = neg.second.bits.get_bit(last_indet) ? l_false : l_true;
removed[j] = true;
LOG("Applying fast BCP on bit " << last_indet << " from constraint " << neg.first->src);
changed = true;
}
}
}
} while (changed);
#endif
return true;
}
/*
* Explain why the current variable is not viable or signleton.
@ -436,6 +637,8 @@ namespace polysat {
if (c.is_assigned(v))
return;
auto [sc, d] = c.m_constraint_trail[idx];
// fixme: constraint must be assigned a value l_true or l_false at this point.
// adjust sc to the truth value of the constraint when passed to forbidden intervals.
entry* ne = alloc_entry(v, idx);
if (!m_forbidden_intervals.get_interval(sc, v, *ne)) {

View file

@ -87,7 +87,7 @@ namespace polysat {
svector<lbool> fixed;
vector<vector<signed_constraint>> just_src;
vector<vector<signed_constraint>> just_side_cond;
vector<svector<pvar>> just_slicing;
vector<svector<justified_fixed_bits>> just_slicing;
bool is_empty() const {
SASSERT_EQ(fixed.empty(), just_src.empty());
@ -219,6 +219,8 @@ namespace polysat {
throw default_exception("nyi");
}
bool collect_bit_information(pvar v, bool add_conflict, fixed_bits_info& out_fbi);
public:
viable(core& c);

View file

@ -296,4 +296,24 @@ namespace polysat {
n->get_root()->unmark1();
}
void solver::get_fixed_bits(pvar pv, svector<justified_fixed_bits>& fixed_bits) {
theory_var v = m_pddvar2var[pv];
auto n = var2enode(v);
auto r = n->get_root();
unsigned lo, hi;
expr* e = nullptr;
for (auto p : euf::enode_parents(r)) {
if (!p->interpreted())
continue;
for (auto sib : euf::enode_class(p)) {
if (bv.is_extract(sib->get_expr(), lo, hi, e) && r == expr2enode(e)->get_root()) {
throw default_exception("nyi");
// TODO
// dependency d = dependency(p->get_th_var(get_id()), n->get_th_var(get_id()), s().scope_lvl());
// fixed_bits.push_back({ hi, lo, rational::zero(), null_dependency()});
}
}
}
}
}

View file

@ -134,6 +134,7 @@ namespace polysat {
trail_stack& trail() override;
bool inconsistent() const override;
void get_bitvector_prefixes(pvar v, pvar_vector& out) override;
void get_fixed_bits(pvar v, svector<justified_fixed_bits>& fixed_bits) override;
void add_lemma(vector<signed_constraint> const& lemma);

View file

@ -316,6 +316,12 @@ unsigned mpq_manager<SYNCH>::prev_power_of_two(mpq const & a) {
return prev_power_of_two(_tmp);
}
template<bool SYNCH>
unsigned mpq_manager<SYNCH>::next_power_of_two(mpq const & a) {
_scoped_numeral<mpz_manager<SYNCH> > _tmp(*this);
ceil(a, _tmp);
return next_power_of_two(_tmp);
}
template<bool SYNCH>
template<bool SUB>

View file

@ -848,6 +848,14 @@ public:
unsigned prev_power_of_two(mpz const & a) { return mpz_manager<SYNCH>::prev_power_of_two(a); }
unsigned prev_power_of_two(mpq const & a);
/**
\brief Return the smallest k s.t. a <= 2^k.
\remark Return 0 if a is not positive.
*/
unsigned next_power_of_two(mpz const & a) { return mpz_manager<SYNCH>::next_power_of_two(a); }
unsigned next_power_of_two(mpq const & a);
bool is_int_perfect_square(mpq const & a, mpq & r) {
SASSERT(is_int(a));
reset_denominator(r);

View file

@ -2288,6 +2288,19 @@ unsigned mpz_manager<SYNCH>::bitsize(mpz const & a) {
return mlog2(a) + 1;
}
template<bool SYNCH>
unsigned mpz_manager<SYNCH>::next_power_of_two(mpz const & a) {
if (is_nonpos(a))
return 0;
if (is_one(a))
return 0;
unsigned shift;
if (is_power_of_two(a, shift))
return shift;
else
return log2(a) + 1;
}
template<bool SYNCH>
bool mpz_manager<SYNCH>::is_perfect_square(mpz const & a, mpz & root) {
if (is_neg(a))

View file

@ -692,6 +692,13 @@ public:
\remark Return 0 if a is not positive.
*/
unsigned prev_power_of_two(mpz const & a) { return log2(a); }
/**
\brief Return the smallest k s.t. a <= 2^k.
\remark Return 0 if a is not positive.
*/
unsigned next_power_of_two(mpz const & a);
/**
\brief Return true if a^{1/n} is an integer, and store the result in a.

View file

@ -55,7 +55,7 @@ public:
explicit rational(double z) { UNREACHABLE(); }
explicit rational(char const * v) { m().set(m_val, v); }
explicit rational(unsigned const * v, unsigned sz) { m().set(m_val, sz, v); }
struct i64 {};
@ -489,6 +489,18 @@ public:
return get_num_digits(rational(10));
}
/**
* \brief Return the biggest k s.t. 2^k <= a.
* \remark Return 0 if a is not positive.
*/
unsigned prev_power_of_two() const { return m().prev_power_of_two(m_val); }
/**
* \brief Return the smallest k s.t. a <= 2^k.
* \remark Return 0 if a is not positive.
*/
unsigned next_power_of_two() const { return m().next_power_of_two(m_val); }
bool get_bit(unsigned index) const {
return m().get_bit(m_val, index);
}
@ -510,7 +522,6 @@ public:
return trailing_zeros();
}
static bool limit_denominator(rational &num, rational const& limit);
};
@ -659,3 +670,7 @@ inline rational gcd(rational const & r1, rational const & r2, rational & a, rati
rational::m().gcd(r1.m_val, r2.m_val, a.m_val, b.m_val, result.m_val);
return result;
}
inline void swap(rational& r1, rational& r2) {
r1.swap(r2);
}