3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-23 09:05:31 +00:00

intblast with lazy expansion of shl, ashr, lshr

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
This commit is contained in:
Nikolaj Bjorner 2023-12-16 15:12:57 -08:00
parent 50e0fd3ba6
commit d0a59f3740
10 changed files with 321 additions and 83 deletions

View file

@ -205,58 +205,117 @@ namespace arith {
add_clause(dgez, neg);
}
bool solver::check_band_term(app* n) {
bool solver::check_bv_term(app* n) {
unsigned sz;
expr* x, * y;
expr* _x, * _y;
if (!ctx.is_relevant(expr2enode(n)))
return true;
VERIFY(a.is_band(n, sz, x, y));
expr_ref vx(m), vy(m),vn(m);
if (!get_value(expr2enode(x), vx) || !get_value(expr2enode(y), vy) || !get_value(expr2enode(n), vn)) {
rational valn, valx, valy;
bool is_int;
VERIFY(a.is_band(n, sz, _x, _y) || a.is_shl(n, sz, _x, _y) || a.is_ashr(n, sz, _x, _y) || a.is_lshr(n, sz, _x, _y));
if (!get_value(expr2enode(_x), vx) || !get_value(expr2enode(_y), vy) || !get_value(expr2enode(n), vn)) {
IF_VERBOSE(2, verbose_stream() << "could not get value of " << mk_pp(n, m) << "\n");
found_unsupported(n);
return true;
}
rational valn, valx, valy;
bool is_int;
if (!a.is_numeral(vn, valn, is_int) || !is_int || !a.is_numeral(vx, valx, is_int) || !is_int || !a.is_numeral(vy, valy, is_int) || !is_int) {
IF_VERBOSE(2, verbose_stream() << "could not get value of " << mk_pp(n, m) << "\n");
found_unsupported(n);
return true;
}
// verbose_stream() << "band: " << mk_pp(n, m) << " " << valn << " := " << valx << "&" << valy << "\n";
rational N = rational::power_of_two(sz);
valx = mod(valx, N);
valy = mod(valy, N);
expr_ref x(a.mk_mod(_x, a.mk_int(N)), m);
expr_ref y(a.mk_mod(_y, a.mk_int(N)), m);
SASSERT(0 <= valn && valn < N);
// x mod 2^{i + 1} >= 2^i means the i'th bit is 1.
auto bitof = [&](expr* x, unsigned i) {
expr_ref r(m);
r = a.mk_ge(a.mk_mod(x, a.mk_int(rational::power_of_two(i+1))), a.mk_int(rational::power_of_two(i)));
return mk_literal(r);
};
for (unsigned i = 0; i < sz; ++i) {
bool xb = valx.get_bit(i);
bool yb = valy.get_bit(i);
bool nb = valn.get_bit(i);
if (xb && yb && !nb)
add_clause(~bitof(x, i), ~bitof(y, i), bitof(n, i));
else if (nb && !xb)
add_clause(~bitof(n, i), bitof(x, i));
else if (nb && !yb)
add_clause(~bitof(n, i), bitof(y, i));
else
continue;
if (a.is_band(n)) {
IF_VERBOSE(2, verbose_stream() << "band: " << mk_bounded_pp(n, m) << " " << valn << " := " << valx << "&" << valy << "\n");
for (unsigned i = 0; i < sz; ++i) {
bool xb = valx.get_bit(i);
bool yb = valy.get_bit(i);
bool nb = valn.get_bit(i);
if (xb && yb && !nb)
add_clause(~bitof(x, i), ~bitof(y, i), bitof(n, i));
else if (nb && !xb)
add_clause(~bitof(n, i), bitof(x, i));
else if (nb && !yb)
add_clause(~bitof(n, i), bitof(y, i));
else
continue;
return false;
}
}
if (a.is_shl(n)) {
SASSERT(valy >= 0);
if (valy >= sz || valy == 0)
return true;
unsigned k = valy.get_unsigned();
sat::literal eq = eq_internalize(n, a.mk_mod(a.mk_mul(_x, a.mk_int(rational::power_of_two(k))), a.mk_int(N)));
if (s().value(eq) == l_true)
return true;
add_clause(~eq_internalize(y, a.mk_int(k)), eq);
IF_VERBOSE(2, verbose_stream() << "shl: " << mk_bounded_pp(n, m) << " " << valn << " := " << valx << " << " << valy << "\n");
return false;
}
if (a.is_lshr(n)) {
SASSERT(valy >= 0);
if (valy >= sz || valy == 0)
return true;
unsigned k = valy.get_unsigned();
sat::literal eq = eq_internalize(n, a.mk_idiv(x, a.mk_int(rational::power_of_two(k))));
if (s().value(eq) == l_true)
return true;
add_clause(~eq_internalize(y, a.mk_int(k)), eq);
IF_VERBOSE(2, verbose_stream() << "lshr: " << mk_bounded_pp(n, m) << " " << valn << " := " << valx << " >>l " << valy << "\n");
return false;
}
if (a.is_ashr(n)) {
SASSERT(valy >= 0);
if (valy >= sz || valy == 0)
return true;
unsigned k = valy.get_unsigned();
sat::literal signx = mk_literal(a.mk_ge(x, a.mk_int(N/2)));
sat::literal eq;
expr* xdiv2k;
switch (s().value(signx)) {
case l_true:
// x < 0 & y = k -> n = (x div 2^k - 2^{N-k}) mod 2^N
xdiv2k = a.mk_idiv(x, a.mk_int(rational::power_of_two(k)));
eq = eq_internalize(n, a.mk_mod(a.mk_add(xdiv2k, a.mk_int(-rational::power_of_two(sz - k))), a.mk_int(N)));
if (s().value(eq) == l_true)
return true;
break;
case l_false:
// x >= 0 & y = k -> n = x div 2^k
xdiv2k = a.mk_idiv(x, a.mk_int(rational::power_of_two(k)));
eq = eq_internalize(n, xdiv2k);
if (s().value(eq) == l_true)
return true;
break;
case l_undef:
ctx.mark_relevant(signx);
return false;
}
add_clause(~eq_internalize(y, a.mk_int(k)), ~signx, eq);
return false;
}
return true;
}
bool solver::check_band_terms() {
for (app* n : m_band_terms) {
if (!check_band_term(n)) {
++m_stats.m_band_axioms;
bool solver::check_bv_terms() {
for (app* n : m_bv_terms) {
if (!check_bv_term(n)) {
++m_stats.m_bv_axioms;
return false;
}
}
@ -268,15 +327,43 @@ namespace arith {
* x&y <= x
* x&y <= y
*/
void solver::mk_band_axiom(app* n) {
void solver::mk_bv_axiom(app* n) {
unsigned sz;
expr* x, * y;
VERIFY(a.is_band(n, sz, x, y));
expr* _x, * _y;
VERIFY(a.is_band(n, sz, _x, _y) || a.is_shl(n, sz, _x, _y) || a.is_ashr(n, sz, _x, _y) || a.is_lshr(n, sz, _x, _y));
rational N = rational::power_of_two(sz);
add_clause(mk_literal(a.mk_ge(n, a.mk_int(0))));
add_clause(mk_literal(a.mk_le(n, a.mk_int(N - 1))));
add_clause(mk_literal(a.mk_le(n, a.mk_mod(x, a.mk_int(N)))));
add_clause(mk_literal(a.mk_le(n, a.mk_mod(y, a.mk_int(N)))));
expr_ref x(a.mk_mod(_x, a.mk_int(N)), m);
expr_ref y(a.mk_mod(_y, a.mk_int(N)), m);
if (a.is_band(n)) {
add_clause(mk_literal(a.mk_ge(n, a.mk_int(0))));
add_clause(mk_literal(a.mk_le(n, a.mk_int(N - 1))));
add_clause(mk_literal(a.mk_le(n, x)));
add_clause(mk_literal(a.mk_le(n, y)));
}
else if (a.is_shl(n)) {
// y >= sz => n = 0
// y = 0 => n = x
add_clause(~mk_literal(a.mk_ge(y, a.mk_int(sz))), mk_literal(m.mk_eq(n, a.mk_int(0))));
add_clause(~mk_literal(a.mk_eq(y, a.mk_int(0))), mk_literal(m.mk_eq(n, x)));
}
else if (a.is_lshr(n)) {
// y >= sz => n = 0
// y = 0 => n = x
add_clause(~mk_literal(a.mk_ge(y, a.mk_int(sz))), mk_literal(m.mk_eq(n, a.mk_int(0))));
add_clause(~mk_literal(a.mk_eq(y, a.mk_int(0))), mk_literal(m.mk_eq(n, x)));
}
else if (a.is_ashr(n)) {
// y >= sz & x < 2^{sz-1} => n = 0
// y >= sz & x >= 2^{sz-1} => n = -1
// y = 0 => n = x
auto signx = mk_literal(a.mk_ge(x, a.mk_int(N/2)));
add_clause(~mk_literal(a.mk_ge(a.mk_mod(y, a.mk_int(N)), a.mk_int(sz))), signx, mk_literal(m.mk_eq(n, a.mk_int(0))));
add_clause(~mk_literal(a.mk_ge(a.mk_mod(y, a.mk_int(N)), a.mk_int(sz))), ~signx, mk_literal(m.mk_eq(n, a.mk_int(N-1))));
add_clause(~mk_literal(a.mk_eq(a.mk_mod(y, a.mk_int(N)), a.mk_int(0))), mk_literal(m.mk_eq(n, x)));
}
else
UNREACHABLE();
}
void solver::mk_bound_axioms(api_bound& b) {

View file

@ -252,10 +252,10 @@ namespace arith {
st.to_ensure_var().push_back(n1);
st.to_ensure_var().push_back(n2);
}
else if (a.is_band(n)) {
m_band_terms.push_back(to_app(n));
mk_band_axiom(to_app(n));
ctx.push(push_back_vector(m_band_terms));
else if (a.is_band(n) || a.is_shl(n) || a.is_ashr(n) || a.is_lshr(n)) {
m_bv_terms.push_back(to_app(n));
ctx.push(push_back_vector(m_bv_terms));
mk_bv_axiom(to_app(n));
ensure_arg_vars(to_app(n));
}
else if (!a.is_div0(n) && !a.is_mod0(n) && !a.is_idiv0(n) && !a.is_rem0(n) && !a.is_power0(n)) {

View file

@ -1053,7 +1053,7 @@ namespace arith {
if (!check_delayed_eqs())
return sat::check_result::CR_CONTINUE;
if (!int_undef && !check_band_terms())
if (!int_undef && !check_bv_terms())
return sat::check_result::CR_CONTINUE;
if (ctx.get_config().m_arith_ignore_int && int_undef)

View file

@ -214,7 +214,7 @@ namespace arith {
expr* m_not_handled = nullptr;
ptr_vector<app> m_underspecified;
ptr_vector<expr> m_idiv_terms;
ptr_vector<app> m_band_terms;
ptr_vector<app> m_bv_terms;
vector<ptr_vector<api_bound> > m_use_list; // bounds where variables are used.
// attributes for incremental version:
@ -318,7 +318,7 @@ namespace arith {
void mk_bound_axioms(api_bound& b);
void mk_bound_axiom(api_bound& b1, api_bound& b2);
void mk_power0_axioms(app* t, app* n);
void mk_band_axiom(app* n);
void mk_bv_axiom(app* n);
void flush_bound_axioms();
void add_farkas_clause(sat::literal l1, sat::literal l2);
@ -410,8 +410,8 @@ namespace arith {
bool check_delayed_eqs();
lbool check_lia();
lbool check_nla();
bool check_band_terms();
bool check_band_term(app* n);
bool check_bv_terms();
bool check_bv_term(app* n);
void add_lemmas();
void propagate_nla();
void add_equality(lpvar v, rational const& k, lp::explanation const& exp);

View file

@ -656,24 +656,58 @@ namespace intblast {
break;
}
case OP_BSHL: {
expr* x = arg(0), * y = umod(e, 1);
r = a.mk_int(0);
for (unsigned i = 0; i < bv.get_bv_size(e); ++i)
r = m.mk_ite(m.mk_eq(y, a.mk_int(i)), a.mk_mul(x, a.mk_int(rational::power_of_two(i))), r);
if (!a.is_numeral(arg(0)) && !a.is_numeral(arg(1)))
r = a.mk_shl(bv.get_bv_size(e), arg(0),arg(1));
else {
expr* x = arg(0), * y = umod(e, 1);
r = a.mk_int(0);
IF_VERBOSE(2, verbose_stream() << "shl " << mk_bounded_pp(e, m) << " " << bv.get_bv_size(e) << "\n");
for (unsigned i = 0; i < bv.get_bv_size(e); ++i)
r = m.mk_ite(m.mk_eq(y, a.mk_int(i)), a.mk_mul(x, a.mk_int(rational::power_of_two(i))), r);
}
break;
}
case OP_BNOT:
r = bnot(arg(0));
break;
case OP_BLSHR: {
expr* x = arg(0), * y = umod(e, 1);
r = a.mk_int(0);
for (unsigned i = 0; i < bv.get_bv_size(e); ++i)
r = m.mk_ite(m.mk_eq(y, a.mk_int(i)), a.mk_idiv(x, a.mk_int(rational::power_of_two(i))), r);
case OP_BLSHR:
if (!a.is_numeral(arg(0)) && !a.is_numeral(arg(1)))
r = a.mk_lshr(bv.get_bv_size(e), arg(0), arg(1));
else {
expr* x = arg(0), * y = umod(e, 1);
r = a.mk_int(0);
IF_VERBOSE(2, verbose_stream() << "lshr " << mk_bounded_pp(e, m) << " " << bv.get_bv_size(e) << "\n");
for (unsigned i = 0; i < bv.get_bv_size(e); ++i)
r = m.mk_ite(m.mk_eq(y, a.mk_int(i)), a.mk_idiv(x, a.mk_int(rational::power_of_two(i))), r);
}
break;
case OP_BASHR:
if (!a.is_numeral(arg(1)))
r = a.mk_ashr(bv.get_bv_size(e), arg(0), arg(1));
else {
//
// ashr(x, y)
// if y = k & x >= 0 -> x / 2^k
// if y = k & x < 0 -> (x / 2^k) - 2^{N-k}
//
unsigned sz = bv.get_bv_size(e);
rational N = bv_size(e);
expr* x = umod(e, 0), *y = umod(e, 1);
expr* signx = a.mk_ge(x, a.mk_int(N / 2));
r = m.mk_ite(signx, a.mk_int(- 1), a.mk_int(0));
IF_VERBOSE(1, verbose_stream() << "ashr " << mk_bounded_pp(e, m) << " " << bv.get_bv_size(e) << "\n");
for (unsigned i = 0; i < sz; ++i) {
expr* d = a.mk_idiv(x, a.mk_int(rational::power_of_two(i)));
r = m.mk_ite(m.mk_eq(y, a.mk_int(i)),
m.mk_ite(signx, a.mk_add(d, a.mk_int(- rational::power_of_two(sz-i))), d),
r);
}
}
break;
}
case OP_BOR: {
// p | q := (p + q) - band(p, q)
IF_VERBOSE(2, verbose_stream() << "bor " << mk_bounded_pp(e, m) << " " << bv.get_bv_size(e) << "\n");
r = arg(0);
for (unsigned i = 1; i < args.size(); ++i)
r = a.mk_sub(a.mk_add(r, arg(i)), a.mk_band(bv.get_bv_size(e), r, arg(i)));
@ -683,12 +717,14 @@ namespace intblast {
r = bnot(band(args));
break;
case OP_BAND:
IF_VERBOSE(2, verbose_stream() << "band " << mk_bounded_pp(e, m) << " " << bv.get_bv_size(e) << "\n");
r = band(args);
break;
case OP_BXNOR:
case OP_BXOR: {
// p ^ q := (p + q) - 2*band(p, q);
unsigned sz = bv.get_bv_size(e);
IF_VERBOSE(2, verbose_stream() << "bxor " << bv.get_bv_size(e) << "\n");
r = arg(0);
for (unsigned i = 1; i < args.size(); ++i) {
expr* q = arg(i);
@ -698,25 +734,6 @@ namespace intblast {
r = bnot(r);
break;
}
case OP_BASHR: {
//
// ashr(x, y)
// if y = k & x >= 0 -> x / 2^k
// if y = k & x < 0 -> (x / 2^k) - 1 + 2^{N-k}
//
unsigned sz = bv.get_bv_size(e);
rational N = bv_size(e);
expr* x = umod(e, 0), *y = umod(e, 1);
expr* signx = a.mk_ge(x, a.mk_int(N / 2));
r = m.mk_ite(signx, a.mk_int(- 1), a.mk_int(0));
for (unsigned i = 0; i < sz; ++i) {
expr* d = a.mk_idiv(x, a.mk_int(rational::power_of_two(i)));
r = m.mk_ite(m.mk_eq(y, a.mk_int(i)),
m.mk_ite(signx, a.mk_add(d, a.mk_int(- rational::power_of_two(sz-i))), d),
r);
}
break;
}
case OP_ZERO_EXT:
bv_expr = e->get_arg(0);
r = umod(bv_expr, 0);