3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-24 01:25:31 +00:00

Detect more equations in refine_equal_lin

This commit is contained in:
Jakob Rath 2022-12-21 12:21:22 +01:00
parent 8da9850d45
commit 109ab0be40
3 changed files with 93 additions and 11 deletions

View file

@ -1724,7 +1724,7 @@ namespace dd {
unsigned pow;
if (val.is_power_of_two(pow) && pow > 10)
return out << "2^" << pow;
for (int offset : {-1, 1})
for (int offset : {-2, -1, 1, 2})
if (val < m.max_value() && (val - offset).is_power_of_two(pow) && pow > 10)
return out << lparen() << "2^" << pow << (offset >= 0 ? "+" : "") << offset << rparen();
rational neg_val = mod(-val, m.two_to_N());

30
src/math/polysat/number.h Normal file
View file

@ -0,0 +1,30 @@
/*++
Copyright (c) 2021 Microsoft Corporation
Module Name:
polysat numbers
Author:
Nikolaj Bjorner (nbjorner) 2021-03-19
Jakob Rath 2021-04-06
--*/
#pragma once
#include "math/polysat/types.h"
namespace polysat {
inline unsigned get_parity(rational const& val, unsigned num_bits) {
if (val.is_zero())
return num_bits;
return val.trailing_zeros();
};
/** Return val with the lower k bits set to zero. */
inline rational clear_lower_bits(rational const& val, unsigned k) {
return val - mod(val, rational::power_of_two(k));
}
}

View file

@ -24,8 +24,11 @@ TODO: improve management of the fallback univariate solvers:
- set resource limit of univariate solver
TODO: plan to fix the FI "pumping":
1. simple looping detection and bitblasting fallback.
1. simple looping detection and bitblasting fallback. -- done
2. intervals at multiple bit widths
- for equations, this will give us exact solutions for all coefficients
- for inequalities, a coefficient 2^k*a means that intervals are periodic because the upper k bits of x are irrelevant;
storing the interval for x[K-k:0] would take care of this.
--*/
@ -33,6 +36,7 @@ TODO: plan to fix the FI "pumping":
#include "util/debug.h"
#include "math/polysat/viable.h"
#include "math/polysat/solver.h"
#include "math/polysat/number.h"
#include "math/polysat/univariate/univariate_solver.h"
namespace polysat {
@ -305,8 +309,10 @@ namespace polysat {
if (!e)
return true;
entry const* first = e;
rational const& max_value = s.var2pdd(v).max_value();
rational mod_value = max_value + 1;
auto& m = s.var2pdd(v);
unsigned const N = m.power_of_2();
rational const& max_value = m.max_value();
rational const& mod_value = m.two_to_N();
// Rotate the 'first' entry, to prevent getting stuck in a refinement loop
// with an early entry when a later entry could give a better interval.
@ -351,19 +357,66 @@ namespace polysat {
rational coeff_val = mod(e->coeff * val, mod_value);
if (e->interval.currently_contains(coeff_val)) {
LOG("refine-equal-lin for src: " << lit_pp(s, e->src));
LOG("forbidden interval v" << v << " " << num_pp(s, v, val) << " " << num_pp(s, v, e->coeff, true) << " * " << e->interval);
if (e->interval.lo_val().is_one() && e->interval.hi_val().is_zero() && e->coeff.is_odd()) {
rational lo(1);
rational hi(0);
LOG("refine-equal-lin: " << " [" << lo << ", " << hi << "[");
pdd lop = s.var2pdd(v).mk_val(lo);
pdd hip = s.var2pdd(v).mk_val(hi);
entry* ne = alloc_entry();
ne->refined = e;
ne->src = e->src;
ne->side_cond = e->side_cond;
ne->coeff = 1;
ne->interval = eval_interval::proper(lop, lo, hip, hi);
ne->interval = eval_interval::proper(m.mk_val(lo), lo, m.mk_val(hi), hi);
intersect(v, ne);
return false;
}
if (mod(e->interval.hi_val() + 1, mod_value) == e->interval.lo_val()) {
// We have an equation: a * v == b
rational const a = e->coeff;
rational const b = e->interval.hi_val();
LOG("refine-equal-lin: equation detected: " << dd::val_pp(m, a, true) << " * v" << v << " == " << dd::val_pp(m, b, false));
unsigned const parity_a = get_parity(a, N);
unsigned const parity_b = get_parity(b, N);
if (parity_a > parity_b) {
// No solution
LOG("refined: no solution due to parity");
entry* ne = alloc_entry();
ne->refined = e;
ne->src = e->src;
ne->side_cond = e->side_cond;
ne->coeff = 1;
ne->interval = eval_interval::full();
intersect(v, ne);
return false;
}
// 2^k * v == a_inv * b
// 2^k solutions because only the lower N-k bits of v are fixed.
//
// Solutions are of the form v_i = v0 + 2^(N-k) * i for i in { 0, 1, ..., 2^k - 1 }.
// Forbidden intervals: [v_i + 1; v_{i+1}[ == [ v_i + 1; v_i + 2^(N-k) [
// We need the interval that covers val:
// v_i + 1 <= val < v_i + 2^(N-k)
rational const a_inv = a.pseudo_inverse(N);
unsigned const N_minus_k = N - parity_a;
rational const two_to_N_minus_k = rational::power_of_two(N_minus_k);
rational const v0 = mod(a_inv * b, two_to_N_minus_k);
SASSERT(mod(val, two_to_N_minus_k) != v0); // val is not a solution
rational const vi = v0 + clear_lower_bits(mod(val - v0, mod_value), N_minus_k);
rational const lo = vi + 1;
rational const hi = mod(vi + two_to_N_minus_k, mod_value);
LOG("refined to [" << num_pp(s, v, lo) << ", " << num_pp(s, v, hi) << "[");
SASSERT_EQ(mod(a * (lo - 1), mod_value), b); // lo-1 is a solution
SASSERT_EQ(mod(a * hi, mod_value), b); // hi is a solution
entry* ne = alloc_entry();
ne->refined = e;
ne->src = e->src;
ne->side_cond = e->side_cond;
ne->coeff = 1;
ne->interval = eval_interval::proper(m.mk_val(lo), lo, m.mk_val(hi), hi);
SASSERT(ne->interval.currently_contains(val));
intersect(v, ne);
return false;
}
@ -388,13 +441,11 @@ namespace polysat {
lo = val - lambda_l;
increase_hi(hi);
}
LOG("forbidden interval v" << v << " " << num_pp(s, v, val) << " " << num_pp(s, v, e->coeff, true) << " * " << e->interval << " [" << num_pp(s, v, lo) << ", " << num_pp(s, v, hi) << "[");
LOG("refined to [" << num_pp(s, v, lo) << ", " << num_pp(s, v, hi) << "[");
SASSERT(hi <= mod_value);
bool full = (lo == 0 && hi == mod_value);
if (hi == mod_value)
hi = 0;
pdd lop = s.var2pdd(v).mk_val(lo);
pdd hip = s.var2pdd(v).mk_val(hi);
entry* ne = alloc_entry();
ne->refined = e;
ne->src = e->src;
@ -403,7 +454,7 @@ namespace polysat {
if (full)
ne->interval = eval_interval::full();
else
ne->interval = eval_interval::proper(lop, lo, hip, hi);
ne->interval = eval_interval::proper(m.mk_val(lo), lo, m.mk_val(hi), hi);
intersect(v, ne);
return false;
}
@ -739,6 +790,7 @@ namespace polysat {
// First step: only query the looping constraints and see if they alone are already UNSAT.
// The constraints which caused the refinement loop will be reached from m_units.
LOG_H3("Checking looping univariate constraints for v" << v << "...");
LOG("Assignment: " << assignments_pp(s));
entry const* first = m_units[v];
entry const* e = first;
do {