3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-24 17:45:32 +00:00

Add more PDD utilities (div, pow) (#5180)

* Expose 'inv' on rationals to get reciprocal value

* Align parameter names with implementation

* Add cached operation that divides PDD by a constant

* Fix display for constant PDDs

* operator^ should probably call ^ instead of + (mk_xor instead of add)

* Add helper function 'pow' on PDDs
This commit is contained in:
Jakob Rath 2021-04-14 13:48:42 +02:00 committed by GitHub
parent 2f7069a8b7
commit 324d9ed461
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 146 additions and 20 deletions

View file

@ -202,7 +202,7 @@ namespace dd {
public:
struct mem_out {};
bdd_manager(unsigned nodes);
bdd_manager(unsigned num_vars);
~bdd_manager();
void set_max_num_nodes(unsigned n) { m_max_num_bdd_nodes = n; }

View file

@ -443,6 +443,97 @@ namespace dd {
return r;
}
/**
* Divide PDD by a constant value.
*
* IMPORTANT: Performs regular numerical division.
* For semantics 'mod2N_e', this means that 'c' must be an integer
* and all coefficients of 'a' must be divisible by 'c'.
*
* NOTE: Why do we not just use 'mul(a, inv(c))' instead?
* In case of semantics 'mod2N_e', an invariant is that all PDDs have integer coefficients.
* But such a multiplication would create nodes with non-integral coefficients.
*/
pdd pdd_manager::div(pdd const& a, rational const& c) {
if (m_semantics == free_e) {
// Don't cache separately for the free semantics;
// use 'mul' so we can share results for a/c and a*(1/c).
return mul(inv(c), a);
}
SASSERT(c.is_int());
bool first = true;
SASSERT(well_formed());
scoped_push _sp(*this);
while (true) {
try {
return pdd(div_rec(a.root, c, null_pdd), this);
}
catch (const mem_out &) {
try_gc();
if (!first) throw;
first = false;
}
}
SASSERT(well_formed());
return pdd(zero_pdd, this);
}
pdd_manager::PDD pdd_manager::div_rec(PDD a, rational const& c, PDD c_pdd) {
SASSERT(m_semantics != free_e);
SASSERT(c.is_int());
if (is_zero(a))
return zero_pdd;
if (is_val(a)) {
rational r = val(a) / c;
SASSERT(r.is_int());
return imk_val(r);
}
if (c_pdd == null_pdd)
c_pdd = imk_val(c);
op_entry* e1 = pop_entry(a, c_pdd, pdd_div_const_op);
op_entry const* e2 = m_op_cache.insert_if_not_there(e1);
if (check_result(e1, e2, a, c_pdd, pdd_div_const_op))
return e2->m_result;
push(div_rec(lo(a), c, c_pdd));
push(div_rec(hi(a), c, c_pdd));
PDD r = make_node(level(a), read(2), read(1));
pop(2);
e1->m_result = r;
return r;
}
pdd pdd_manager::pow(pdd const &p, unsigned j) {
return pdd(pow(p.root, j), this);
}
pdd_manager::PDD pdd_manager::pow(PDD p, unsigned j) {
if (j == 0)
return one_pdd;
else if (j == 1)
return p;
else if (is_zero(p))
return zero_pdd;
else if (is_one(p))
return one_pdd;
else if (is_val(p))
return imk_val(power(val(p), j));
else
return pow_rec(p, j);
}
pdd_manager::PDD pdd_manager::pow_rec(PDD p, unsigned j) {
SASSERT(j > 0);
if (j == 1)
return p;
// j even: pow(p,2*j') = pow(p*p,j')
// j odd: pow(p,2*j'+1) = p*pow(p*p,j')
PDD q = pow_rec(apply(p, p, pdd_mul_op), j / 2);
if (j & 1) {
q = apply(q, p, pdd_mul_op);
}
return q;
}
//
// produce polynomial where a is reduced by b.
// all monomials in a that are divisible by lm(b)
@ -754,6 +845,15 @@ namespace dd {
e->m_rest = rest.root;
}
/**
* Apply function f to all coefficients of the polynomial.
* The function should be of type
* rational const& -> rational
* rational const& -> unsigned
* and should always return integers.
*
* NOTE: the operation is not cached.
*/
template <class Fn>
pdd pdd_manager::map_coefficients(pdd const& p, Fn f) {
if (p.is_val()) {
@ -803,17 +903,9 @@ namespace dd {
unsigned const j = std::min(max_pow2_divisor(a), max_pow2_divisor(c));
SASSERT(j != UINT_MAX); // should only be possible when both l and m are 0
rational const pow2j = rational::power_of_two(j);
auto div_pow2j = [&pow2j](rational const& r) -> rational {
rational result = r / pow2j;
SASSERT(result.is_int());
return result;
};
pdd aa = map_coefficients(a, div_pow2j);
pdd cc = map_coefficients(c, div_pow2j);
pdd vv = one();
for (unsigned deg = l - m; deg-- > 0; ) {
vv *= mk_var(v);
}
pdd const aa = div(a, pow2j);
pdd const cc = div(c, pow2j);
pdd vv = pow(mk_var(v), l - m);
r = b * cc - aa * d * vv;
return true;
}
@ -1366,9 +1458,11 @@ namespace dd {
pow = 1;
v_prev = v;
}
out << "v" << v_prev;
if (pow > 1)
out << "^" << pow;
if (v_prev != UINT_MAX) {
out << "v" << v_prev;
if (pow > 1)
out << "^" << pow;
}
}
if (first) out << "0";
return out;

View file

@ -63,7 +63,8 @@ namespace dd {
pdd_mul_op = 5,
pdd_reduce_op = 6,
pdd_subst_val_op = 7,
pdd_no_op = 8
pdd_div_const_op = 8,
pdd_no_op = 9
};
struct node {
@ -213,6 +214,9 @@ namespace dd {
PDD apply(PDD arg1, PDD arg2, pdd_op op);
PDD apply_rec(PDD arg1, PDD arg2, pdd_op op);
PDD minus_rec(PDD p);
PDD div_rec(PDD p, rational const& c, PDD c_pdd);
PDD pow(PDD p, unsigned j);
PDD pow_rec(PDD p, unsigned j);
PDD reduce_on_match(PDD a, PDD b);
bool lm_occurs(PDD p, PDD q) const;
@ -297,7 +301,7 @@ namespace dd {
struct mem_out {};
pdd_manager(unsigned nodes, semantics s = free_e, unsigned power_of_2 = 0);
pdd_manager(unsigned num_vars, semantics s = free_e, unsigned power_of_2 = 0);
~pdd_manager();
semantics get_semantics() const { return m_semantics; }
@ -317,6 +321,7 @@ namespace dd {
pdd sub(pdd const& a, pdd const& b);
pdd mul(pdd const& a, pdd const& b);
pdd mul(rational const& c, pdd const& b);
pdd div(pdd const& a, rational const& c);
pdd mk_or(pdd const& p, pdd const& q);
pdd mk_xor(pdd const& p, pdd const& q);
pdd mk_xor(pdd const& p, unsigned q);
@ -325,6 +330,7 @@ namespace dd {
pdd subst_val(pdd const& a, vector<std::pair<unsigned, rational>> const& s);
pdd subst_val(pdd const& a, unsigned v, rational const& val);
bool resolve(unsigned v, pdd const& p, pdd const& q, pdd& r);
pdd pow(pdd const& p, unsigned j);
bool is_linear(PDD p) { return degree(p) == 1; }
bool is_linear(pdd const& p);
@ -399,6 +405,8 @@ namespace dd {
pdd operator+(rational const& other) const { return m.add(other, *this); }
pdd operator~() const { return m.mk_not(*this); }
pdd rev_sub(rational const& r) const { return m.sub(m.mk_val(r), *this); }
pdd div(rational const& other) const { return m.div(*this, other); }
pdd pow(unsigned j) const { return m.pow(*this, j); }
pdd reduce(pdd const& other) const { return m.reduce(*this, other); }
bool different_leading_term(pdd const& other) const { return m.different_leading_term(*this, other); }
void factor(unsigned v, unsigned degree, pdd& lc, pdd& rest) const { m.factor(*this, v, degree, lc, rest); }
@ -433,8 +441,8 @@ namespace dd {
inline pdd operator+(int x, pdd const& b) { return b + rational(x); }
inline pdd operator+(pdd const& b, int x) { return b + rational(x); }
inline pdd operator^(unsigned x, pdd const& b) { return b + x; }
inline pdd operator^(bool x, pdd const& b) { return b + x; }
inline pdd operator^(unsigned x, pdd const& b) { return b ^ x; }
inline pdd operator^(bool x, pdd const& b) { return b ^ x; }
inline pdd operator-(rational const& r, pdd const& b) { return b.rev_sub(r); }
inline pdd operator-(int x, pdd const& b) { return rational(x) - b; }

View file

@ -438,9 +438,10 @@ public :
SASSERT(m.zero().max_pow2_divisor() == UINT_MAX);
SASSERT(m.one().max_pow2_divisor() == 0);
pdd p = (1 << 20) * a * b + 1024 * b * b * b;
pdd p = (1 << 20)*a*b + 1024*b*b*b;
std::cout << p << " divided by 2^" << p.max_pow2_divisor() << "\n";
SASSERT(p.max_pow2_divisor() == 10);
SASSERT(p.div(rational::power_of_two(10)) == 1024*a*b + b*b*b);
SASSERT((p + p).max_pow2_divisor() == 11);
SASSERT((p * p).max_pow2_divisor() == 20);
SASSERT((p + 2*b).max_pow2_divisor() == 1);
@ -492,6 +493,22 @@ public :
SASSERT(r == -(2*a*a*b*b - 2*a*a*b - 3*a*b*b + a*b*b*b + 4*b));
}
static void pow() {
std::cout << "pow\n";
pdd_manager m(4, pdd_manager::mod2N_e, 5);
unsigned const va = 0;
unsigned const vb = 1;
pdd const a = m.mk_var(va);
pdd const b = m.mk_var(vb);
SASSERT(a.pow(0) == m.one());
SASSERT(a.pow(1) == a);
SASSERT(a.pow(2) == a*a);
SASSERT(a.pow(7) == a*a*a*a*a*a*a);
SASSERT((3*a*b).pow(3) == 27*a*a*a*b*b*b);
}
};
}
@ -510,4 +527,5 @@ void tst_pdd() {
dd::test::factor();
dd::test::max_pow2_divisor();
dd::test::binary_resolve();
dd::test::pow();
}

View file

@ -157,6 +157,12 @@ public:
friend inline rational numerator(rational const & r) { rational result; m().get_numerator(r.m_val, result.m_val); return result; }
friend inline rational denominator(rational const & r) { rational result; m().get_denominator(r.m_val, result.m_val); return result; }
friend inline rational inv(rational const & r) {
rational result;
m().inv(r.m_val, result.m_val);
return result;
}
rational & operator+=(rational const & r) {
m().add(m_val, r.m_val, m_val);