diff --git a/src/math/dd/dd_pdd.cpp b/src/math/dd/dd_pdd.cpp index 455fc4039..8604af3fb 100644 --- a/src/math/dd/dd_pdd.cpp +++ b/src/math/dd/dd_pdd.cpp @@ -1311,6 +1311,78 @@ namespace dd { return m.mk_var(var())*h + l; } + std::pair pdd::var_factors() { + if (is_val()) + return { unsigned_vector(), *this }; + unsigned v = var(); + if (lo().is_val()) { + if (!lo().is_zero()) + return { unsigned_vector(), *this }; + auto [vars, p] = hi().var_factors(); + vars.push_back(v); + return {vars, p}; + } + auto [lo_vars, q] = lo().var_factors(); + if (lo_vars.empty()) + return { unsigned_vector(), *this }; + + unsigned_vector lo_and_hi; + auto merge = [&](unsigned_vector& lo_vars, unsigned_vector& hi_vars) { + unsigned ir = 0, jr = 0; + for (unsigned i = 0, j = 0; i < lo_vars.size() || j < hi_vars.size(); ) { + if (i == lo_vars.size()) { + hi_vars[jr++] = hi_vars[j++]; + continue; + } + if (j == hi_vars.size()) { + lo_vars[ir++] = lo_vars[i++]; + continue; + } + if (lo_vars[i] == hi_vars[j]) { + lo_and_hi.push_back(lo_vars[i]); + ++i; + ++j; + continue; + } + unsigned lvl_lo = m.m_var2level[lo_vars[i]]; + unsigned lvl_hi = m.m_var2level[hi_vars[j]]; + if (lvl_lo > lvl_hi) { + hi_vars[jr++] = hi_vars[j++]; + continue; + } + else { + lo_vars[ir++] = lo_vars[i++]; + continue; + } + } + lo_vars.shrink(ir); + hi_vars.shrink(jr); + }; + + auto mul = [&](unsigned_vector const& vars, pdd p) { + for (auto v : vars) + p *= m.mk_var(v); + return p; + }; + + auto [hi_vars, p] = hi().var_factors(); + if (lo_vars.back() == v) { + lo_vars.pop_back(); + merge(lo_vars, hi_vars); + lo_and_hi.push_back(v); + return { lo_and_hi, mul(lo_vars, q) + mul(hi_vars, p) }; + } + if (hi_vars.empty()) + return { unsigned_vector(), *this }; + + merge(lo_vars, hi_vars); + hi_vars.push_back(v); + if (lo_and_hi.empty()) + return { unsigned_vector(), *this }; + else + return { lo_and_hi, mul(lo_vars, q) + mul(hi_vars, p) }; + } + std::ostream& operator<<(std::ostream& out, pdd const& b) { return b.display(out); } diff --git a/src/math/dd/dd_pdd.h b/src/math/dd/dd_pdd.h index 13c6c0605..45d646c51 100644 --- a/src/math/dd/dd_pdd.h +++ b/src/math/dd/dd_pdd.h @@ -364,6 +364,11 @@ namespace dd { bool different_leading_term(pdd const& other) const { return m.different_leading_term(*this, other); } void factor(unsigned v, unsigned degree, pdd& lc, pdd& rest) { m.factor(*this, v, degree, lc, rest); } + /** + * \brief factor out variables + */ + std::pair var_factors(); + pdd subst_val(vector> const& s) const { return m.subst_val(*this, s); } pdd subst_val(unsigned v, rational const& val) const { return m.subst_val(*this, v, val); } diff --git a/src/test/pdd.cpp b/src/test/pdd.cpp index b0cbb657c..9139e499f 100644 --- a/src/test/pdd.cpp +++ b/src/test/pdd.cpp @@ -323,6 +323,62 @@ public : SASSERT(!(2*a*b + 3*b + 2).is_non_zero()); } + static void factors() { + pdd_manager m(3); + pdd v0 = m.mk_var(0); + pdd v1 = m.mk_var(1); + pdd v2 = m.mk_var(2); + pdd v3 = m.mk_var(3); + pdd v4 = m.mk_var(4); + pdd c1 = v0 * v1 * v2 + v2 * v0 + v1 + 1; + { + auto [vars, p] = c1.var_factors(); + VERIFY(p == c1 && vars.empty()); + } + { + auto q = c1 * v4; + auto [vars, p] = q.var_factors(); + std::cout << p << " " << vars << "\n"; + VERIFY(p == c1 && vars.size() == 1 && vars[0] == 4); + } + for (unsigned i = 0; i < 5; ++i) { + auto v = m.mk_var(i); + auto q = c1 * v; + std::cout << i << ": " << q << "\n"; + auto [vars, p] = q.var_factors(); + std::cout << p << " " << vars << "\n"; + VERIFY(p == c1 && vars.size() == 1 && vars[0] == i); + } + for (unsigned i = 0; i < 5; ++i) { + for (unsigned j = 0; j < 5; ++j) { + auto vi = m.mk_var(i); + auto vj = m.mk_var(j); + auto q = c1 * vi * vj; + auto [vars, p] = q.var_factors(); + std::cout << p << " " << vars << "\n"; + VERIFY(p == c1 && vars.size() == 2); + VERIFY(vars[0] == i || vars[1] == i); + VERIFY(vars[0] == j || vars[1] == j); + } + } + for (unsigned i = 0; i < 5; ++i) { + for (unsigned j = i; j < 5; ++j) { + for (unsigned k = j; k < 5; ++k) { + auto vi = m.mk_var(i); + auto vj = m.mk_var(j); + auto vk = m.mk_var(k); + auto q = c1 * vi * vj * vk; + auto [vars, p] = q.var_factors(); + std::cout << p << " " << vars << "\n"; + VERIFY(p == c1 && vars.size() == 3); + VERIFY(vars[0] == i || vars[1] == i || vars[2] == i); + VERIFY(vars[0] == j || vars[1] == j || vars[2] == j); + VERIFY(vars[0] == k || vars[1] == k || vars[2] == k); + } + } + } + } + }; } @@ -337,4 +393,5 @@ void tst_pdd() { dd::test::order(); dd::test::order_lm(); dd::test::mod4_operations(); + dd::test::factors(); }