3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-23 00:55:31 +00:00

We need to use expr_ref when storing expressions across add calls

Without this, bench3 created a constraint 2^parity == x * parity which
should have been 2^parity == x * x_inv.
This commit is contained in:
Jakob Rath 2023-01-16 15:37:37 +01:00
parent b33911de13
commit 26e7d0d35a
2 changed files with 46 additions and 21 deletions

View file

@ -88,7 +88,7 @@ namespace polysat {
s->pop(n);
}
unsigned scope_level() override {
unsigned scope_level() const override {
return m_scope_level;
}
@ -96,6 +96,18 @@ namespace polysat {
return bv->mk_numeral(r, bit_width);
}
expr* mk_numeral(uint64_t u) const {
return bv->mk_numeral(u, bit_width);
}
rational get_offset(univariate const& p) const {
return p.empty() ? rational::zero() : p[0];
}
bool is_constant(univariate const& p) const {
return p.empty() || std::all_of(p.begin() + 1, p.end(), [](rational const& n) { return n.is_zero(); });
}
bool is_zero(univariate const& p) const {
for (auto n : p)
if (n != 0)
@ -181,9 +193,9 @@ namespace polysat {
template <typename lhs_t, typename rhs_t>
void add_ule_impl(lhs_t const& lhs, rhs_t const& rhs, bool sign, dep_t dep) {
if (is_zero(rhs))
if (is_zero(rhs))
add(m.mk_eq(mk_poly(lhs), mk_poly(rhs)), sign, dep);
else
else
add(bv->mk_ule(mk_poly(lhs), mk_poly(rhs)), sign, dep);
}
@ -232,19 +244,21 @@ namespace polysat {
}
void add_inv(univariate const& in, univariate const& out, bool sign, dep_t dep) override {
expr* input = mk_poly(in);
expr* output = mk_poly(out);
expr* parity = get_parity(in);
expr* one = bv->mk_numeral(1, bit_width);
add(m.mk_eq(bv->mk_bv_shl(one, parity), bv->mk_bv_mul(input, output)), false, null_dep);
add(bv->mk_ule(output, bv->mk_bv_sub(bv->mk_bv_shl(one, bv->mk_bv_sub(bv->mk_numeral(bit_width, bit_width), parity)), one)), false, null_dep); // TODO: Depending on whether we want all pseudo-inverses
// out == smallest_pseudo_inverse(in)
expr_ref v = mk_poly(in);
expr_ref v_inv = mk_poly(out);
expr_ref parity = mk_parity(v, in);
// 2^parity = v * v_inv
add(m.mk_eq(bv->mk_bv_shl(mk_numeral(1), parity), bv->mk_bv_mul(v, v_inv)), false, dep);
// v_inv <= 2^(N - parity) - 1
expr* v_inv_max = bv->mk_bv_sub(bv->mk_bv_shl(mk_numeral(1), bv->mk_bv_sub(mk_numeral(bit_width), parity)), mk_numeral(1));
add(bv->mk_ule(v_inv, v_inv_max), false, dep);
}
void add_ule_const(rational const& val, bool sign, dep_t dep) override {
if (val == 0)
add(m.mk_eq(x, mk_poly(val)), sign, dep);
else
else
add(bv->mk_ule(x, mk_poly(val)), sign, dep);
}
@ -255,14 +269,27 @@ namespace polysat {
void add_bit(unsigned idx, bool sign, dep_t dep) override {
add(bv->mk_bit2bool(x, idx), sign, dep);
}
expr* get_parity(univariate const& in) override {
expr* v = mk_poly(in);
expr* parity = m.mk_fresh_const("parity", bv->mk_sort(bit_width));
expr* parity_1 = bv->mk_bv_add(parity, bv->mk_numeral(1, bit_width));
uint64_t get_parity(rational const& r) const {
return r.is_zero() ? bit_width : r.trailing_zeros();
}
expr_ref mk_parity(expr* v, univariate const& v_coeff) {
expr_ref parity(m);
if (is_constant(v_coeff)) {
parity = mk_numeral(get_parity(get_offset(v_coeff)));
return parity;
}
parity = m.mk_fresh_const("parity", bv->mk_sort(bit_width));
expr* parity_1 = bv->mk_bv_add(parity, mk_numeral(1));
// if v = 0
// then parity = N
// else v = (v >> parity) << parity
// && v != (v >> parity+1) << parity+1
// TODO: what about: v[k:] = 0 && v[k+1:] != 0 ==> parity = k for each k?
add(m.mk_ite(
m.mk_eq(v, bv->mk_numeral(0, bit_width)),
m.mk_eq(parity, bv->mk_numeral(bit_width, bit_width)),
m.mk_eq(v, mk_numeral(0)),
m.mk_eq(parity, mk_numeral(bit_width)),
m.mk_and(
m.mk_eq(bv->mk_bv_shl(bv->mk_bv_lshr(v, parity), parity), v),
m.mk_not(m.mk_eq(bv->mk_bv_shl(bv->mk_bv_lshr(v, parity_1), parity_1), v))

View file

@ -38,7 +38,7 @@ namespace polysat {
virtual void push() = 0;
virtual void pop(unsigned n) = 0;
virtual unsigned scope_level() = 0;
virtual unsigned scope_level() const = 0;
virtual lbool check() = 0;
@ -117,8 +117,6 @@ namespace polysat {
virtual void add_bit(unsigned idx, bool sign, dep_t dep) = 0;
void add_bit0(unsigned idx, dep_t dep) { add_bit(idx, true, dep); }
void add_bit1(unsigned idx, dep_t dep) { add_bit(idx, false, dep); }
virtual expr* get_parity(univariate const& in) = 0;
virtual std::ostream& display(std::ostream& out) const = 0;
};