3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2026-03-23 12:59:12 +00:00
Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
This commit is contained in:
Nikolaj Bjorner 2023-12-29 15:13:11 -08:00
parent 97225b7d8f
commit 03e012c1d8
14 changed files with 167 additions and 139 deletions

View file

@ -183,78 +183,56 @@ namespace polysat {
}
/*
* v in [lo, hi[:
* - hi >= v: forward(v) := hi
* - hi < v: l_false
* 2^k v in [lo, hi[:
* - hi > 2^k v: forward(v) := hi//2^k + 2^k(v//2^k)
* - hi <= 2^k v: forward(v) := hi//2^k + 2^k(v//2^k + 1) unless it overflows.
* w is a suffix of v of width w.width <= v.width with forbidden 2^l w not in [lo, hi[ and 2^l v[w.width-1:0] in [lo, hi[.
* - set k := l + v.width - w.width, lo' := 2^{v.width-w.width} lo, hi' := 2^{v.width-w.width} hi.
* In either case we are checking a constraint $v[u-1:0]\not\in[lo, hi[$
* where $u := w'-k-1$ and using it to compute $\forward(v)$.
* Thus, suppose $v[u-1:0] \in [lo, hi[$.
* - $lo < hi$: $\forward(v) := \forward(2^u v[w-1:w-u] + hi)$.
* - $lo > hi$, $v[w-1:w-u]+1 = 2^{w-u}$: $\forward(v) := \bot$
* - $lo > hi$, $v[w-1:w-u]+1 < 2^{w-u}$: $\forward(v) := \forward(2^u (v[w-1:w-u]+1) + hi)$
*/
lbool viable::next_viable_layer(pvar w, layer& layer, rational& val) {
if (!layer.entries)
return l_true;
unsigned v_width = m_num_bits;
unsigned w_width = c.size(w);
unsigned l = w_width - layer.bit_width;
SASSERT(v_width >= w_width);
SASSERT(layer.bit_width <= w_width);
unsigned w_width = c.size(w);
unsigned b_width = layer.bit_width;
SASSERT(b_width <= w_width);
SASSERT(w_width <= v_width);
bool is_zero = val.is_zero(), wrapped = false;
rational val1 = val;
rational const& p2l = rational::power_of_two(l);
rational const& p2w = rational::power_of_two(w_width);
while (true) {
if (l > 0)
val1 *= p2l;
if (w_width < v_width || l > 0)
val1 = mod(val1, p2w);
rational const& p2v = rational::power_of_two(v_width);
rational const& p2b = rational::power_of_two(b_width);
if (b_width < v_width)
val1 = mod(val1, p2b);
rational start = val1;
while (true) {
auto e = find_overlap(val1, layer.entries);
if (!e) {
if (l > 0)
val1 /= p2l;
break;
}
if (!e)
break;
// TODO check if admitted: layer.entries = e;
m_explain.push_back(e);
if (e->interval.is_full())
return l_false;
auto hi = e->interval.hi_val();
if (hi < val1) {
if (is_zero)
return l_false;
if (w_width == v_width && l == 0)
return l_false;
// start from 0 and find the next viable value within this layer.
val1 = 0;
if (hi < e->interval.lo_val())
wrapped = true;
}
if (wrapped && start <= hi)
return l_false;
val1 = hi;
SASSERT(val1 < p2w);
// p2l * x = val1 = hi
if (l > 0)
val1 = hi / p2l;
SASSERT(val1.is_int());
SASSERT(val1 < p2b);
}
SASSERT(val1 < p2w);
if (w_width < v_width) {
if (l > 0)
NOT_IMPLEMENTED_YET();
rational val2 = val;
SASSERT(val1 < p2b);
if (b_width < v_width) {
rational val2 = clear_lower_bits(val, b_width);
if (wrapped) {
val2 = mod(div(val2, p2w) + 1, p2w) * p2w;
if (val2 == 0)
val2 += p2b;
if (val2 >= p2v)
return l_false;
}
else
val2 = clear_lower_bits(val2, w_width);
val = val1 + val2;
}
else if (l > 0) {
NOT_IMPLEMENTED_YET();
}
else
val = val1;
@ -264,7 +242,6 @@ namespace polysat {
// Find interval that contains 'val', or, if no such interval exists, null.
viable::entry* viable::find_overlap(rational const& val, entry* entries) {
SASSERT(entries);
// display_all(std::cerr << "entries:\n\t", 0, entries, "\n\t");
entry* const first = entries;
entry* e = entries;
do {
@ -557,6 +534,10 @@ namespace polysat {
if (value == l_false)
sc = ~sc;
if (!sc.is_linear()) {
return true;
}
entry* ne = alloc_entry(v, idx);
if (!m_forbidden_intervals.get_interval(sc, v, *ne)) {
m_alloc.push_back(ne);
@ -573,72 +554,88 @@ namespace polysat {
if (ne->coeff == 1)
intersect(v, ne);
else if (ne->coeff == -1)
insert(ne, v, m_diseq_lin, entry_kind::diseq_e);
insert(ne, v, m_diseq_lin, entry_kind::diseq_e);
else if (!ne->coeff.is_power_of_two())
insert(ne, v, m_equal_lin, entry_kind::equal_e);
else if (ne->interval.is_full())
insert(ne, v, m_equal_lin, entry_kind::equal_e);
else {
unsigned const w = c.size(v);
unsigned const k = ne->coeff.parity(w);
// unsigned const lo_parity = ne->interval.lo_val().parity(w);
// unsigned const hi_parity = ne->interval.hi_val().parity(w);
SASSERT(k > 0);
IF_VERBOSE(1, display_one(verbose_stream() << "try to reduce entry: ", v, ne) << "\n");
IF_VERBOSE(3, display_one(verbose_stream() << "try to reduce entry: ", v, ne) << "\n");
if (k > 0 && ne->coeff.is_power_of_two()) {
// reduction of coeff gives us a unit entry
//
// 2^k a x \not\in [ lo ; hi [
//
// new_lo = lo[w-1:k] if lo[k-1:0] = 0
// lo[w-1:k] + 1 otherwise
//
// new_hi = hi[w-1:k] if hi[k-1:0] = 0
// hi[w-1:k] + 1 otherwise
//
// Reference: Fig. 1 (dtrim) in BitvectorsMCSAT
//
pdd const& pdd_lo = ne->interval.lo();
pdd const& pdd_hi = ne->interval.hi();
rational const& lo = ne->interval.lo_val();
rational const& hi = ne->interval.hi_val();
// reduction of coeff gives us a unit entry
//
// 2^k x \not\in [ lo ; hi [
//
// new_lo = lo[w-1:k] if lo[k-1:0] = 0
// lo[w-1:k] + 1 otherwise
//
// new_hi = hi[w-1:k] if hi[k-1:0] = 0
// hi[w-1:k] + 1 otherwise
//
// Reference: Fig. 1 (dtrim) in BitvectorsMCSAT
//
//
// Suppose new_lo = new_hi
// Then since ne is not full, then lo != hi
// Cases
// lo < hi, 2^k|lo, new_lo = lo / 2^k != new_hi = hi / 2^k
// lo < hi, not 2^k|lo, 2^k|hi, new_lo = lo / 2^k + 1, new_hi = hi / 2^k
// new_lo = new_hi -> empty
// k = 3, lo = 1, hi = 8, new_lo = 1, new_hi = 1
// lo < hi, not 2^k|lo, not 2^k|hi, new_lo = lo / 2^k + 1, new_hi = hi / 2^k + 1
// new_lo = new_hi -> empty
// k = 3, lo = 1, hi = 2, new_lo = 1 div 2^3 + 1 = 1, new_hi = 2 div 2^3 + 1 = 1
// lo > hi
rational new_lo = machine_div2k(lo, k);
if (mod2k(lo, k).is_zero())
ne->side_cond.push_back(cs.eq(pdd_lo * rational::power_of_two(w - k)));
else {
new_lo += 1;
ne->side_cond.push_back(~cs.eq(pdd_lo * rational::power_of_two(w - k)));
}
pdd const& pdd_lo = ne->interval.lo();
pdd const& pdd_hi = ne->interval.hi();
rational const& lo = ne->interval.lo_val();
rational const& hi = ne->interval.hi_val();
rational new_hi = machine_div2k(hi, k);
if (mod2k(hi, k).is_zero())
ne->side_cond.push_back(cs.eq(pdd_hi * rational::power_of_two(w - k)));
else {
new_hi += 1;
ne->side_cond.push_back(~cs.eq(pdd_hi * rational::power_of_two(w - k)));
}
// we have to update also the pdd bounds accordingly, but it seems not worth introducing new variables for this eagerly
// new_lo = lo[:k] etc.
// TODO: for now just disable the FI-lemma if this case occurs
ne->valid_for_lemma = false;
if (new_lo == new_hi) {
// empty or full
// if (ne->interval.currently_contains(rational::zero()))
NOT_IMPLEMENTED_YET();
}
ne->coeff = machine_div2k(ne->coeff, k);
ne->interval = eval_interval::proper(pdd_lo, new_lo, pdd_hi, new_hi);
ne->bit_width -= k;
display_one(std::cerr << "reduced entry: ", v, ne) << "\n";
LOG("reduced entry to unit in bitwidth " << ne->bit_width);
intersect(v, ne);
rational new_lo = machine_div2k(lo, k);
pdd lo_eq = pdd_lo * rational::power_of_two(w - k);
if (mod2k(lo, k).is_zero()) {
if (!lo_eq.is_zero())
ne->side_cond.push_back(cs.eq(lo_eq));
}
else {
new_lo += 1;
new_lo = machine_div2k(new_lo, k);
if (!lo_eq.is_val() || lo_eq.is_zero())
ne->side_cond.push_back(~cs.eq(lo_eq));
}
rational new_hi = machine_div2k(hi, k);
pdd hi_eq = pdd_hi * rational::power_of_two(w - k);
if (mod2k(hi, k).is_zero()) {
if (!hi_eq.is_zero())
ne->side_cond.push_back(cs.eq(hi_eq));
}
else {
new_hi += 1;
new_hi = machine_div2k(new_hi, k);
if (!hi_eq.is_val() || hi_eq.is_zero())
ne->side_cond.push_back(~cs.eq(hi_eq));
}
// we have to update also the pdd bounds accordingly, but it seems not worth introducing new variables for this eagerly
// new_lo = lo[:k] etc.
if (new_lo == new_hi) {
// empty
verbose_stream() << "always true " << "x*" << ne->coeff << " not in " << ne->interval << "\n";
m_alloc.push_back(ne);
return true;
}
// TODO: later, can reduce according to shared_parity
// unsigned const shared_parity = std::min(coeff_parity, std::min(lo_parity, hi_parity));
insert(ne, v, m_equal_lin, entry_kind::equal_e);
ne->coeff = 1;
ne->interval = eval_interval::proper(pdd_lo, new_lo, pdd_hi, new_hi);
ne->bit_width -= k;
intersect(v, ne);
}
if (ne->interval.is_full()) {
m_explain.reset();
@ -770,7 +767,6 @@ namespace polysat {
switch (k) {
case entry_kind::unit_e:
entry::remove_from(m_units[v].get_layer(e)->entries, e);
SASSERT(well_formed(m_units[v]));
break;
case entry_kind::equal_e:
entry::remove_from(m_equal_lin[v], e);
@ -782,6 +778,7 @@ namespace polysat {
UNREACHABLE();
break;
}
SASSERT(well_formed(m_units[v]));
m_alloc.push_back(e);
}
@ -879,19 +876,15 @@ namespace polysat {
std::ostream& viable::display(std::ostream& out) const {
for (unsigned v = 0; v < m_units.size(); ++v) {
bool first = true;
for (auto const& layer : m_units[v].get_layers()) {
if (!layer.entries)
continue;
if (first)
out << "v" << v << ": ";
first = false;
out << "v" << v << ": ";
if (layer.bit_width != c.size(v))
out << "width[" << layer.bit_width << "] ";
display_all(out, v, layer.entries, " ");
}
if (!first)
out << "\n";
}
}
return out;
}
@ -913,6 +906,7 @@ namespace polysat {
return true;
entry* first = e;
while (true) {
CTRACE("bv", !e->active, tout << "inactive entry v" << e->var << " " << e->interval << "\n"; display(tout));
if (!e->active)
return false;
@ -922,13 +916,17 @@ namespace polysat {
return false;
auto* n = e->next();
if (n != e && e->interval.currently_contains(n->interval))
if (n != e && e->interval.currently_contains(n->interval)) {
TRACE("bv", tout << "currently contains\n");
return false;
}
if (n == first)
break;
if (e->interval.lo_val() >= n->interval.lo_val())
if (e->interval.lo_val() >= n->interval.lo_val()) {
TRACE("bv", tout << "lo-val >= n->lo_val\n");
return false;
}
e = n;
}
return true;
@ -942,12 +940,18 @@ namespace polysat {
bool first = true;
unsigned prev_width = 0;
for (layer const& l : ls.get_layers()) {
if (!well_formed(l.entries))
if (!well_formed(l.entries)) {
TRACE("bv", tout << "entries are not well-formed\n");
return false;
if (!all_of(dll_elements(l.entries), [&l](entry const& e) { return e.bit_width == l.bit_width; }))
}
if (!all_of(dll_elements(l.entries), [&l](entry const& e) { return e.bit_width == l.bit_width; })) {
TRACE("bv", tout << "elements don't have same bit-width\n");
return false;
if (!first && prev_width >= l.bit_width)
}
if (!first && prev_width >= l.bit_width) {
TRACE("bv", tout << "previous width is " << prev_width << " vs " << l.bit_width << "\n");
return false;
}
first = false;
prev_width = l.bit_width;
}