3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-06-21 13:23:39 +00:00

Add more PDD utilities (div, pow) (#5180)

* Expose 'inv' on rationals to get reciprocal value

* Align parameter names with implementation

* Add cached operation that divides PDD by a constant

* Fix display for constant PDDs

* operator^ should probably call ^ instead of + (mk_xor instead of add)

* Add helper function 'pow' on PDDs
This commit is contained in:
Jakob Rath 2021-04-14 13:48:42 +02:00 committed by GitHub
parent 2f7069a8b7
commit 324d9ed461
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 146 additions and 20 deletions

View file

@ -202,7 +202,7 @@ namespace dd {
public: public:
struct mem_out {}; struct mem_out {};
bdd_manager(unsigned nodes); bdd_manager(unsigned num_vars);
~bdd_manager(); ~bdd_manager();
void set_max_num_nodes(unsigned n) { m_max_num_bdd_nodes = n; } void set_max_num_nodes(unsigned n) { m_max_num_bdd_nodes = n; }

View file

@ -443,6 +443,97 @@ namespace dd {
return r; return r;
} }
/**
* Divide PDD by a constant value.
*
* IMPORTANT: Performs regular numerical division.
* For semantics 'mod2N_e', this means that 'c' must be an integer
* and all coefficients of 'a' must be divisible by 'c'.
*
* NOTE: Why do we not just use 'mul(a, inv(c))' instead?
* In case of semantics 'mod2N_e', an invariant is that all PDDs have integer coefficients.
* But such a multiplication would create nodes with non-integral coefficients.
*/
pdd pdd_manager::div(pdd const& a, rational const& c) {
if (m_semantics == free_e) {
// Don't cache separately for the free semantics;
// use 'mul' so we can share results for a/c and a*(1/c).
return mul(inv(c), a);
}
SASSERT(c.is_int());
bool first = true;
SASSERT(well_formed());
scoped_push _sp(*this);
while (true) {
try {
return pdd(div_rec(a.root, c, null_pdd), this);
}
catch (const mem_out &) {
try_gc();
if (!first) throw;
first = false;
}
}
SASSERT(well_formed());
return pdd(zero_pdd, this);
}
pdd_manager::PDD pdd_manager::div_rec(PDD a, rational const& c, PDD c_pdd) {
SASSERT(m_semantics != free_e);
SASSERT(c.is_int());
if (is_zero(a))
return zero_pdd;
if (is_val(a)) {
rational r = val(a) / c;
SASSERT(r.is_int());
return imk_val(r);
}
if (c_pdd == null_pdd)
c_pdd = imk_val(c);
op_entry* e1 = pop_entry(a, c_pdd, pdd_div_const_op);
op_entry const* e2 = m_op_cache.insert_if_not_there(e1);
if (check_result(e1, e2, a, c_pdd, pdd_div_const_op))
return e2->m_result;
push(div_rec(lo(a), c, c_pdd));
push(div_rec(hi(a), c, c_pdd));
PDD r = make_node(level(a), read(2), read(1));
pop(2);
e1->m_result = r;
return r;
}
pdd pdd_manager::pow(pdd const &p, unsigned j) {
return pdd(pow(p.root, j), this);
}
pdd_manager::PDD pdd_manager::pow(PDD p, unsigned j) {
if (j == 0)
return one_pdd;
else if (j == 1)
return p;
else if (is_zero(p))
return zero_pdd;
else if (is_one(p))
return one_pdd;
else if (is_val(p))
return imk_val(power(val(p), j));
else
return pow_rec(p, j);
}
pdd_manager::PDD pdd_manager::pow_rec(PDD p, unsigned j) {
SASSERT(j > 0);
if (j == 1)
return p;
// j even: pow(p,2*j') = pow(p*p,j')
// j odd: pow(p,2*j'+1) = p*pow(p*p,j')
PDD q = pow_rec(apply(p, p, pdd_mul_op), j / 2);
if (j & 1) {
q = apply(q, p, pdd_mul_op);
}
return q;
}
// //
// produce polynomial where a is reduced by b. // produce polynomial where a is reduced by b.
// all monomials in a that are divisible by lm(b) // all monomials in a that are divisible by lm(b)
@ -754,6 +845,15 @@ namespace dd {
e->m_rest = rest.root; e->m_rest = rest.root;
} }
/**
* Apply function f to all coefficients of the polynomial.
* The function should be of type
* rational const& -> rational
* rational const& -> unsigned
* and should always return integers.
*
* NOTE: the operation is not cached.
*/
template <class Fn> template <class Fn>
pdd pdd_manager::map_coefficients(pdd const& p, Fn f) { pdd pdd_manager::map_coefficients(pdd const& p, Fn f) {
if (p.is_val()) { if (p.is_val()) {
@ -803,17 +903,9 @@ namespace dd {
unsigned const j = std::min(max_pow2_divisor(a), max_pow2_divisor(c)); unsigned const j = std::min(max_pow2_divisor(a), max_pow2_divisor(c));
SASSERT(j != UINT_MAX); // should only be possible when both l and m are 0 SASSERT(j != UINT_MAX); // should only be possible when both l and m are 0
rational const pow2j = rational::power_of_two(j); rational const pow2j = rational::power_of_two(j);
auto div_pow2j = [&pow2j](rational const& r) -> rational { pdd const aa = div(a, pow2j);
rational result = r / pow2j; pdd const cc = div(c, pow2j);
SASSERT(result.is_int()); pdd vv = pow(mk_var(v), l - m);
return result;
};
pdd aa = map_coefficients(a, div_pow2j);
pdd cc = map_coefficients(c, div_pow2j);
pdd vv = one();
for (unsigned deg = l - m; deg-- > 0; ) {
vv *= mk_var(v);
}
r = b * cc - aa * d * vv; r = b * cc - aa * d * vv;
return true; return true;
} }
@ -1366,9 +1458,11 @@ namespace dd {
pow = 1; pow = 1;
v_prev = v; v_prev = v;
} }
out << "v" << v_prev; if (v_prev != UINT_MAX) {
if (pow > 1) out << "v" << v_prev;
out << "^" << pow; if (pow > 1)
out << "^" << pow;
}
} }
if (first) out << "0"; if (first) out << "0";
return out; return out;

View file

@ -63,7 +63,8 @@ namespace dd {
pdd_mul_op = 5, pdd_mul_op = 5,
pdd_reduce_op = 6, pdd_reduce_op = 6,
pdd_subst_val_op = 7, pdd_subst_val_op = 7,
pdd_no_op = 8 pdd_div_const_op = 8,
pdd_no_op = 9
}; };
struct node { struct node {
@ -213,6 +214,9 @@ namespace dd {
PDD apply(PDD arg1, PDD arg2, pdd_op op); PDD apply(PDD arg1, PDD arg2, pdd_op op);
PDD apply_rec(PDD arg1, PDD arg2, pdd_op op); PDD apply_rec(PDD arg1, PDD arg2, pdd_op op);
PDD minus_rec(PDD p); PDD minus_rec(PDD p);
PDD div_rec(PDD p, rational const& c, PDD c_pdd);
PDD pow(PDD p, unsigned j);
PDD pow_rec(PDD p, unsigned j);
PDD reduce_on_match(PDD a, PDD b); PDD reduce_on_match(PDD a, PDD b);
bool lm_occurs(PDD p, PDD q) const; bool lm_occurs(PDD p, PDD q) const;
@ -297,7 +301,7 @@ namespace dd {
struct mem_out {}; struct mem_out {};
pdd_manager(unsigned nodes, semantics s = free_e, unsigned power_of_2 = 0); pdd_manager(unsigned num_vars, semantics s = free_e, unsigned power_of_2 = 0);
~pdd_manager(); ~pdd_manager();
semantics get_semantics() const { return m_semantics; } semantics get_semantics() const { return m_semantics; }
@ -317,6 +321,7 @@ namespace dd {
pdd sub(pdd const& a, pdd const& b); pdd sub(pdd const& a, pdd const& b);
pdd mul(pdd const& a, pdd const& b); pdd mul(pdd const& a, pdd const& b);
pdd mul(rational const& c, pdd const& b); pdd mul(rational const& c, pdd const& b);
pdd div(pdd const& a, rational const& c);
pdd mk_or(pdd const& p, pdd const& q); pdd mk_or(pdd const& p, pdd const& q);
pdd mk_xor(pdd const& p, pdd const& q); pdd mk_xor(pdd const& p, pdd const& q);
pdd mk_xor(pdd const& p, unsigned q); pdd mk_xor(pdd const& p, unsigned q);
@ -325,6 +330,7 @@ namespace dd {
pdd subst_val(pdd const& a, vector<std::pair<unsigned, rational>> const& s); pdd subst_val(pdd const& a, vector<std::pair<unsigned, rational>> const& s);
pdd subst_val(pdd const& a, unsigned v, rational const& val); pdd subst_val(pdd const& a, unsigned v, rational const& val);
bool resolve(unsigned v, pdd const& p, pdd const& q, pdd& r); bool resolve(unsigned v, pdd const& p, pdd const& q, pdd& r);
pdd pow(pdd const& p, unsigned j);
bool is_linear(PDD p) { return degree(p) == 1; } bool is_linear(PDD p) { return degree(p) == 1; }
bool is_linear(pdd const& p); bool is_linear(pdd const& p);
@ -399,6 +405,8 @@ namespace dd {
pdd operator+(rational const& other) const { return m.add(other, *this); } pdd operator+(rational const& other) const { return m.add(other, *this); }
pdd operator~() const { return m.mk_not(*this); } pdd operator~() const { return m.mk_not(*this); }
pdd rev_sub(rational const& r) const { return m.sub(m.mk_val(r), *this); } pdd rev_sub(rational const& r) const { return m.sub(m.mk_val(r), *this); }
pdd div(rational const& other) const { return m.div(*this, other); }
pdd pow(unsigned j) const { return m.pow(*this, j); }
pdd reduce(pdd const& other) const { return m.reduce(*this, other); } pdd reduce(pdd const& other) const { return m.reduce(*this, other); }
bool different_leading_term(pdd const& other) const { return m.different_leading_term(*this, other); } bool different_leading_term(pdd const& other) const { return m.different_leading_term(*this, other); }
void factor(unsigned v, unsigned degree, pdd& lc, pdd& rest) const { m.factor(*this, v, degree, lc, rest); } void factor(unsigned v, unsigned degree, pdd& lc, pdd& rest) const { m.factor(*this, v, degree, lc, rest); }
@ -433,8 +441,8 @@ namespace dd {
inline pdd operator+(int x, pdd const& b) { return b + rational(x); } inline pdd operator+(int x, pdd const& b) { return b + rational(x); }
inline pdd operator+(pdd const& b, int x) { return b + rational(x); } inline pdd operator+(pdd const& b, int x) { return b + rational(x); }
inline pdd operator^(unsigned x, pdd const& b) { return b + x; } inline pdd operator^(unsigned x, pdd const& b) { return b ^ x; }
inline pdd operator^(bool x, pdd const& b) { return b + x; } inline pdd operator^(bool x, pdd const& b) { return b ^ x; }
inline pdd operator-(rational const& r, pdd const& b) { return b.rev_sub(r); } inline pdd operator-(rational const& r, pdd const& b) { return b.rev_sub(r); }
inline pdd operator-(int x, pdd const& b) { return rational(x) - b; } inline pdd operator-(int x, pdd const& b) { return rational(x) - b; }

View file

@ -438,9 +438,10 @@ public :
SASSERT(m.zero().max_pow2_divisor() == UINT_MAX); SASSERT(m.zero().max_pow2_divisor() == UINT_MAX);
SASSERT(m.one().max_pow2_divisor() == 0); SASSERT(m.one().max_pow2_divisor() == 0);
pdd p = (1 << 20) * a * b + 1024 * b * b * b; pdd p = (1 << 20)*a*b + 1024*b*b*b;
std::cout << p << " divided by 2^" << p.max_pow2_divisor() << "\n"; std::cout << p << " divided by 2^" << p.max_pow2_divisor() << "\n";
SASSERT(p.max_pow2_divisor() == 10); SASSERT(p.max_pow2_divisor() == 10);
SASSERT(p.div(rational::power_of_two(10)) == 1024*a*b + b*b*b);
SASSERT((p + p).max_pow2_divisor() == 11); SASSERT((p + p).max_pow2_divisor() == 11);
SASSERT((p * p).max_pow2_divisor() == 20); SASSERT((p * p).max_pow2_divisor() == 20);
SASSERT((p + 2*b).max_pow2_divisor() == 1); SASSERT((p + 2*b).max_pow2_divisor() == 1);
@ -492,6 +493,22 @@ public :
SASSERT(r == -(2*a*a*b*b - 2*a*a*b - 3*a*b*b + a*b*b*b + 4*b)); SASSERT(r == -(2*a*a*b*b - 2*a*a*b - 3*a*b*b + a*b*b*b + 4*b));
} }
static void pow() {
std::cout << "pow\n";
pdd_manager m(4, pdd_manager::mod2N_e, 5);
unsigned const va = 0;
unsigned const vb = 1;
pdd const a = m.mk_var(va);
pdd const b = m.mk_var(vb);
SASSERT(a.pow(0) == m.one());
SASSERT(a.pow(1) == a);
SASSERT(a.pow(2) == a*a);
SASSERT(a.pow(7) == a*a*a*a*a*a*a);
SASSERT((3*a*b).pow(3) == 27*a*a*a*b*b*b);
}
}; };
} }
@ -510,4 +527,5 @@ void tst_pdd() {
dd::test::factor(); dd::test::factor();
dd::test::max_pow2_divisor(); dd::test::max_pow2_divisor();
dd::test::binary_resolve(); dd::test::binary_resolve();
dd::test::pow();
} }

View file

@ -157,6 +157,12 @@ public:
friend inline rational numerator(rational const & r) { rational result; m().get_numerator(r.m_val, result.m_val); return result; } friend inline rational numerator(rational const & r) { rational result; m().get_numerator(r.m_val, result.m_val); return result; }
friend inline rational denominator(rational const & r) { rational result; m().get_denominator(r.m_val, result.m_val); return result; } friend inline rational denominator(rational const & r) { rational result; m().get_denominator(r.m_val, result.m_val); return result; }
friend inline rational inv(rational const & r) {
rational result;
m().inv(r.m_val, result.m_val);
return result;
}
rational & operator+=(rational const & r) { rational & operator+=(rational const & r) {
m().add(m_val, r.m_val, m_val); m().add(m_val, r.m_val, m_val);