diff --git a/src/math/dd/dd_pdd.cpp b/src/math/dd/dd_pdd.cpp index 0c8aa520d..fdfa7ec41 100644 --- a/src/math/dd/dd_pdd.cpp +++ b/src/math/dd/dd_pdd.cpp @@ -50,6 +50,7 @@ namespace dd { void pdd_manager::reset(unsigned_vector const& level2var) { reset_op_cache(); + m_factor_cache.reset(); m_node_table.reset(); m_nodes.reset(); m_free_nodes.reset(); @@ -691,27 +692,36 @@ namespace dd { * factor p into lc*v^degree + rest * such that degree(rest, v) < degree * Initial implementation is very naive - * - memoize intermediary results */ void pdd_manager::factor(pdd const& p, unsigned v, unsigned degree, pdd& lc, pdd& rest) { unsigned level_v = m_var2level[v]; if (degree == 0) { lc = p; rest = zero(); + return; } - else if (level(p.root) < level_v) { + if (level(p.root) < level_v) { lc = zero(); rest = p; + return; } - else if (level(p.root) > level_v) { + // Memoize nontrivial cases + auto* et = m_factor_cache.insert_if_not_there2({p.root, v, degree}); + factor_entry& e = et->get_data(); + if (e.is_valid()) { + lc = pdd(e.m_lc, this); + rest = pdd(e.m_rest, this); + return; + } + if (level(p.root) > level_v) { pdd lc1 = zero(), rest1 = zero(); pdd vv = mk_var(p.var()); factor(p.hi(), v, degree, lc, rest); factor(p.lo(), v, degree, lc1, rest1); - lc += lc1; - rest += rest1; lc *= vv; rest *= vv; + lc += lc1; + rest += rest1; } else { unsigned d = 0; @@ -733,6 +743,8 @@ namespace dd { rest = p; } } + e.m_lc = lc.root; + e.m_rest = rest.root; } @@ -1169,6 +1181,8 @@ namespace dd { m_op_cache.insert(e); } + m_factor_cache.reset(); + m_node_table.reset(); // re-populate node cache for (unsigned i = m_nodes.size(); i-- > 2; ) { diff --git a/src/math/dd/dd_pdd.h b/src/math/dd/dd_pdd.h index 79ee3cb5b..1ae6ba02f 100644 --- a/src/math/dd/dd_pdd.h +++ b/src/math/dd/dd_pdd.h @@ -140,9 +140,44 @@ namespace dd { typedef ptr_hashtable op_table; + struct factor_entry { + factor_entry(PDD p, unsigned v, unsigned degree): + m_p(p), + m_v(v), + m_degree(degree), + m_lc(UINT_MAX), + m_rest(UINT_MAX) + {} + + factor_entry(): m_p(0), m_v(0), m_degree(0), m_lc(UINT_MAX), m_rest(UINT_MAX) {} + + PDD m_p; // input + unsigned m_v; // input + unsigned m_degree; // input + PDD m_lc; // output + PDD m_rest; // output + + bool is_valid() { return m_lc != UINT_MAX; } + + unsigned hash() const { return mk_mix(m_p, m_v, m_degree); } + }; + + struct hash_factor_entry { + unsigned operator()(factor_entry const& e) const { return e.hash(); } + }; + + struct eq_factor_entry { + bool operator()(factor_entry const& a, factor_entry const& b) const { + return a.m_p == b.m_p && a.m_v == b.m_v && a.m_degree == b.m_degree; + } + }; + + typedef hashtable factor_table; + svector m_nodes; vector m_values; op_table m_op_cache; + factor_table m_factor_cache; node_table m_node_table; mpq_table m_mpq_table; svector m_pdd_stack; @@ -361,7 +396,7 @@ namespace dd { pdd rev_sub(rational const& r) const { return m.sub(m.mk_val(r), *this); } pdd reduce(pdd const& other) const { return m.reduce(*this, other); } bool different_leading_term(pdd const& other) const { return m.different_leading_term(*this, other); } - void factor(unsigned v, unsigned degree, pdd& lc, pdd& rest) { m.factor(*this, v, degree, lc, rest); } + void factor(unsigned v, unsigned degree, pdd& lc, pdd& rest) const { m.factor(*this, v, degree, lc, rest); } pdd subst_val(vector> const& s) const { return m.subst_val(*this, s); } pdd subst_val(unsigned v, rational const& val) const { return m.subst_val(*this, v, val); } diff --git a/src/test/pdd.cpp b/src/test/pdd.cpp index 8142cb311..8079d5419 100644 --- a/src/test/pdd.cpp +++ b/src/test/pdd.cpp @@ -381,6 +381,52 @@ public : } } + static void factor() { + std::cout << "factor\n"; + pdd_manager m(4, pdd_manager::mod2N_e, 3); + + unsigned const va = 0; + unsigned const vb = 1; + unsigned const vc = 2; + unsigned const vd = 3; + pdd const a = m.mk_var(va); + pdd const b = m.mk_var(vb); + pdd const c = m.mk_var(vc); + pdd const d = m.mk_var(vd); + + auto test_one = [&m](pdd const& p, unsigned v, unsigned d) { + pdd lc = m.zero(); + pdd rest = m.zero(); + std::cout << "Factoring p = " << p << " by v" << v << "^" << d << "\n"; + p.factor(v, d, lc, rest); + std::cout << " lc = " << lc << "\n"; + std::cout << " rest = " << rest << "\n"; + pdd x = m.mk_var(v); + pdd x_pow_d = m.one(); + for (unsigned i = 0; i < d; ++i) { + x_pow_d *= x; + } + SASSERT( p == lc * x_pow_d + rest ); + SASSERT( d == 0 || rest.degree(v) < d ); + SASSERT( d != 0 || rest.is_zero() ); + }; + + auto test_multiple = [=](pdd const& p) { + for (auto v : {va, vb, vc, vd}) { + for (unsigned d = 0; d <= 5; ++d) { + test_one(p, v, d); + } + } + }; + + test_multiple( b ); + test_multiple( b*b*b ); + test_multiple( b + c ); + test_multiple( a*a*a*a*a + a*a*a*b + a*a*b*b + c ); + test_multiple( c*c*c*c*c + b*b*b*c + 3*b*c*c + a ); + test_multiple( (a + b) * (b + c) * (c + d) * (d + a) ); + } + }; } @@ -396,4 +442,5 @@ void tst_pdd() { dd::test::order_lm(); dd::test::mod4_operations(); dd::test::degree_of_variables(); + dd::test::factor(); }