mirror of
https://github.com/Z3Prover/z3
synced 2025-04-22 08:35:31 +00:00
We need to use expr_ref
when storing expressions across add
calls
Without this, bench3 created a constraint 2^parity == x * parity which should have been 2^parity == x * x_inv.
This commit is contained in:
parent
b33911de13
commit
26e7d0d35a
2 changed files with 46 additions and 21 deletions
|
@ -88,7 +88,7 @@ namespace polysat {
|
|||
s->pop(n);
|
||||
}
|
||||
|
||||
unsigned scope_level() override {
|
||||
unsigned scope_level() const override {
|
||||
return m_scope_level;
|
||||
}
|
||||
|
||||
|
@ -96,6 +96,18 @@ namespace polysat {
|
|||
return bv->mk_numeral(r, bit_width);
|
||||
}
|
||||
|
||||
expr* mk_numeral(uint64_t u) const {
|
||||
return bv->mk_numeral(u, bit_width);
|
||||
}
|
||||
|
||||
rational get_offset(univariate const& p) const {
|
||||
return p.empty() ? rational::zero() : p[0];
|
||||
}
|
||||
|
||||
bool is_constant(univariate const& p) const {
|
||||
return p.empty() || std::all_of(p.begin() + 1, p.end(), [](rational const& n) { return n.is_zero(); });
|
||||
}
|
||||
|
||||
bool is_zero(univariate const& p) const {
|
||||
for (auto n : p)
|
||||
if (n != 0)
|
||||
|
@ -181,9 +193,9 @@ namespace polysat {
|
|||
|
||||
template <typename lhs_t, typename rhs_t>
|
||||
void add_ule_impl(lhs_t const& lhs, rhs_t const& rhs, bool sign, dep_t dep) {
|
||||
if (is_zero(rhs))
|
||||
if (is_zero(rhs))
|
||||
add(m.mk_eq(mk_poly(lhs), mk_poly(rhs)), sign, dep);
|
||||
else
|
||||
else
|
||||
add(bv->mk_ule(mk_poly(lhs), mk_poly(rhs)), sign, dep);
|
||||
}
|
||||
|
||||
|
@ -232,19 +244,21 @@ namespace polysat {
|
|||
}
|
||||
|
||||
void add_inv(univariate const& in, univariate const& out, bool sign, dep_t dep) override {
|
||||
expr* input = mk_poly(in);
|
||||
expr* output = mk_poly(out);
|
||||
expr* parity = get_parity(in);
|
||||
expr* one = bv->mk_numeral(1, bit_width);
|
||||
|
||||
add(m.mk_eq(bv->mk_bv_shl(one, parity), bv->mk_bv_mul(input, output)), false, null_dep);
|
||||
add(bv->mk_ule(output, bv->mk_bv_sub(bv->mk_bv_shl(one, bv->mk_bv_sub(bv->mk_numeral(bit_width, bit_width), parity)), one)), false, null_dep); // TODO: Depending on whether we want all pseudo-inverses
|
||||
// out == smallest_pseudo_inverse(in)
|
||||
expr_ref v = mk_poly(in);
|
||||
expr_ref v_inv = mk_poly(out);
|
||||
expr_ref parity = mk_parity(v, in);
|
||||
// 2^parity = v * v_inv
|
||||
add(m.mk_eq(bv->mk_bv_shl(mk_numeral(1), parity), bv->mk_bv_mul(v, v_inv)), false, dep);
|
||||
// v_inv <= 2^(N - parity) - 1
|
||||
expr* v_inv_max = bv->mk_bv_sub(bv->mk_bv_shl(mk_numeral(1), bv->mk_bv_sub(mk_numeral(bit_width), parity)), mk_numeral(1));
|
||||
add(bv->mk_ule(v_inv, v_inv_max), false, dep);
|
||||
}
|
||||
|
||||
void add_ule_const(rational const& val, bool sign, dep_t dep) override {
|
||||
if (val == 0)
|
||||
add(m.mk_eq(x, mk_poly(val)), sign, dep);
|
||||
else
|
||||
else
|
||||
add(bv->mk_ule(x, mk_poly(val)), sign, dep);
|
||||
}
|
||||
|
||||
|
@ -255,14 +269,27 @@ namespace polysat {
|
|||
void add_bit(unsigned idx, bool sign, dep_t dep) override {
|
||||
add(bv->mk_bit2bool(x, idx), sign, dep);
|
||||
}
|
||||
|
||||
expr* get_parity(univariate const& in) override {
|
||||
expr* v = mk_poly(in);
|
||||
expr* parity = m.mk_fresh_const("parity", bv->mk_sort(bit_width));
|
||||
expr* parity_1 = bv->mk_bv_add(parity, bv->mk_numeral(1, bit_width));
|
||||
|
||||
uint64_t get_parity(rational const& r) const {
|
||||
return r.is_zero() ? bit_width : r.trailing_zeros();
|
||||
}
|
||||
|
||||
expr_ref mk_parity(expr* v, univariate const& v_coeff) {
|
||||
expr_ref parity(m);
|
||||
if (is_constant(v_coeff)) {
|
||||
parity = mk_numeral(get_parity(get_offset(v_coeff)));
|
||||
return parity;
|
||||
}
|
||||
parity = m.mk_fresh_const("parity", bv->mk_sort(bit_width));
|
||||
expr* parity_1 = bv->mk_bv_add(parity, mk_numeral(1));
|
||||
// if v = 0
|
||||
// then parity = N
|
||||
// else v = (v >> parity) << parity
|
||||
// && v != (v >> parity+1) << parity+1
|
||||
// TODO: what about: v[k:] = 0 && v[k+1:] != 0 ==> parity = k for each k?
|
||||
add(m.mk_ite(
|
||||
m.mk_eq(v, bv->mk_numeral(0, bit_width)),
|
||||
m.mk_eq(parity, bv->mk_numeral(bit_width, bit_width)),
|
||||
m.mk_eq(v, mk_numeral(0)),
|
||||
m.mk_eq(parity, mk_numeral(bit_width)),
|
||||
m.mk_and(
|
||||
m.mk_eq(bv->mk_bv_shl(bv->mk_bv_lshr(v, parity), parity), v),
|
||||
m.mk_not(m.mk_eq(bv->mk_bv_shl(bv->mk_bv_lshr(v, parity_1), parity_1), v))
|
||||
|
|
|
@ -38,7 +38,7 @@ namespace polysat {
|
|||
|
||||
virtual void push() = 0;
|
||||
virtual void pop(unsigned n) = 0;
|
||||
virtual unsigned scope_level() = 0;
|
||||
virtual unsigned scope_level() const = 0;
|
||||
|
||||
virtual lbool check() = 0;
|
||||
|
||||
|
@ -117,8 +117,6 @@ namespace polysat {
|
|||
virtual void add_bit(unsigned idx, bool sign, dep_t dep) = 0;
|
||||
void add_bit0(unsigned idx, dep_t dep) { add_bit(idx, true, dep); }
|
||||
void add_bit1(unsigned idx, dep_t dep) { add_bit(idx, false, dep); }
|
||||
|
||||
virtual expr* get_parity(univariate const& in) = 0;
|
||||
|
||||
virtual std::ostream& display(std::ostream& out) const = 0;
|
||||
};
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue