diff --git a/src/math/lp/cross_nested.h b/src/math/lp/cross_nested.h index 1ef4832f6..eee3928f6 100644 --- a/src/math/lp/cross_nested.h +++ b/src/math/lp/cross_nested.h @@ -40,9 +40,62 @@ public: return c; } + struct occ { + unsigned m_occs; + unsigned m_power; + occ() : m_occs(0), m_power(0) {} + occ(unsigned k, unsigned p) : m_occs(k), m_power(p) {} + // use the "name injection rule here" + friend std::ostream& operator<<(std::ostream& out, const occ& c) { + out << "(occs:" << c.m_occs <<", pow:" << c.m_power << ")"; + return out; + } + }; + + bool proceed_with_common_factor(nex* c, vector& front, const std::unordered_map & occurences) { + TRACE("nla_cn", tout << "c=" << *c << "\n";); + SASSERT(c->is_sum()); + auto f = nex::mul(); + unsigned size = c->children().size(); + for(const auto & p : occurences) { + if (p.second.m_occs == size) { + unsigned pow = p.second.m_power; + while (pow --) { + f *= nex::var(p.first); + } + } + } + if (f.children().size() == 0) return false; + *c /= f; + f.simplify(); + * c = nex::mul(f, *c); + TRACE("nla_cn", tout << "common factor=" << f << ", c=" << *c << "\n";); + cross_nested_of_expr_on_front_elem(&(c->children()[1]), front); + return true; + } + + void cross_nested_of_expr_on_front_elem_occs(nex* c, vector& front, const std::unordered_map & occurences) { + if (proceed_with_common_factor(c, front, occurences)) + return; + TRACE("nla_cn", tout << "save c=" << *c << "front:"; print_vector_of_ptrs(front, tout) << "\n";); + nex copy_of_c = *c; + vector copy_of_front; + for (nex* n: front) + copy_of_front.push_back(*n); + for(auto& p : occurences) { + SASSERT(p.second.m_occs > 1); + lpvar j = p.first; + cross_nested_of_expr_on_sum_and_var(c, j, front); + *c = copy_of_c; + TRACE("nla_cn", tout << "restore c=" << *c << ", m_e=" << m_e << "\n";); + for (unsigned i = 0; i < front.size(); i++) + *(front[i]) = copy_of_front[i]; + } + } + void cross_nested_of_expr_on_front_elem(nex* c, vector& front) { SASSERT(c->is_sum()); - vector occurences = get_mult_occurences(*c); + auto occurences = get_mult_occurences(*c); TRACE("nla_cn", tout << "m_e=" << m_e << "\nc=" << *c << "\noccurences="; print_vector(occurences, tout) << "\nfront:"; print_vector_of_ptrs(front, tout) << "\n";); if (occurences.empty()) { @@ -63,18 +116,7 @@ public: cross_nested_of_expr_on_front_elem(c, front); } } else { - TRACE("nla_cn", tout << "save c=" << *c << "front:"; print_vector_of_ptrs(front, tout) << "\n";); - nex copy_of_c = *c; - vector copy_of_front; - for (nex* n: front) - copy_of_front.push_back(*n); - for(lpvar j : occurences) { - cross_nested_of_expr_on_sum_and_var(c, j, front); - *c = copy_of_c; - TRACE("nla_cn", tout << "restore c=" << *c << ", m_e=" << m_e << "\n";); - for (unsigned i = 0; i < front.size(); i++) - *(front[i]) = copy_of_front[i]; - } + cross_nested_of_expr_on_front_elem_occs(c, front, occurences); } } // e is the global expression, c is the sub expressiond which is going to changed from sum to the cross nested form @@ -87,60 +129,67 @@ public: cross_nested_of_expr_on_front_elem(n, front); } while (!front.empty()); } - void process_var_occurences(lpvar j, std::unordered_set& seen, std::unordered_map& occurences) const { - if (seen.find(j) != seen.end()) return; - seen.insert(j); + static void process_var_occurences(lpvar j, std::unordered_map& occurences) { auto it = occurences.find(j); - if (it == occurences.end()) - occurences[j] = 1; - else - it->second ++; + if (it != occurences.end()) { + it->second.m_occs++; + it->second.m_power = 1; + } else { + occurences.insert(std::make_pair(j, occ(1, 1))); + } } - void process_mul_occurences(const nex& e, std::unordered_set& seen, std::unordered_map& occurences) const { - SASSERT(e.type() == expr_type::MUL); - for (const auto & ce : e.children()) { - if (ce.type() == expr_type::VAR) { - process_var_occurences(ce.var(), seen, occurences); - } else if (ce.type() == expr_type::MUL){ - process_mul_occurences(ce, seen, occurences); - } + static void dump_occurences(std::ostream& out, const std::unordered_map& occurences) { + out << "{"; + for(const auto& p: occurences) { + const occ& o = p.second; + out << "(v" << p.first << "->" << o << ")"; + } + out << "}" << std::endl; + } + + + static void update_occurences_with_powers(std::unordered_map& occurences, + const std::unordered_map& powers) { + for (auto & p : powers) { + lpvar j = p.first; + unsigned jp = p.second; + auto it = occurences.find(j); + if (it == occurences.end()) { + occurences[j] = occ(1, jp); + } else { + it->second.m_occs++; + it->second.m_power = std::min(it->second.m_power, jp); + } } } - - // j -> the number of expressions j appears in as a multiplier - vector get_mult_occurences(const nex& e) const { - std::unordered_map occurences; + static void remove_singular_occurences(std::unordered_map& occurences) { + svector r; + for (const auto & p : occurences) { + if (p.second.m_occs <= 1) { + r.push_back(p.first); + } + } + for (lpvar j : r) + occurences.erase(j); + } + // j -> the number of expressions j appears in as a multiplier, get the max degree as well + static std::unordered_map get_mult_occurences(const nex& e) { + std::unordered_map occurences; SASSERT(e.type() == expr_type::SUM); for (const auto & ce : e.children()) { std::unordered_set seen; - if (ce.type() == expr_type::MUL) { - for (const auto & cce : ce.children()) { - if (cce.type() == expr_type::VAR) { - process_var_occurences(cce.var(), seen, occurences); - } else if (cce.type() == expr_type::MUL) { - process_mul_occurences(cce, seen, occurences); - } else { - continue; - } - } + if (ce.is_mul()) { + auto powers = ce.get_powers_from_mul(); + update_occurences_with_powers(occurences, powers); } else if (ce.type() == expr_type::VAR) { - process_var_occurences(ce.var(), seen, occurences); + process_var_occurences(ce.var(), occurences); } } - TRACE("nla_cn_details", - tout << "{"; - for(auto p: occurences) { - tout << "(v" << p.first << "->" << p.second << ")"; - } - tout << "}" << std::endl;); - vector r; - for(auto p: occurences) { - if (p.second > 1) - r.push_back(p.first); - } - return r; + remove_singular_occurences(occurences); + TRACE("nla_cn_details", dump_occurences(tout, occurences);); + return occurences; } bool can_be_cross_nested_more(const nex& e) const { switch (e.type()) { diff --git a/src/math/lp/nla_expr.h b/src/math/lp/nla_expr.h index c14cfcf26..3e598cd68 100644 --- a/src/math/lp/nla_expr.h +++ b/src/math/lp/nla_expr.h @@ -74,19 +74,23 @@ class nla_expr { std::sort(m_order.begin(), m_order.end(), [this](unsigned i, unsigned j) { return m_es[i] < m_es[j]; }); } bool operator<(const sorted_children& e) const { - return compare(e) == -1; + return compare(e) < 0; } int compare(const sorted_children& e) const { unsigned m = std::min(size(), e.size()); for (unsigned j = 0; j < m; j++) { int r = m_es[m_order[j]].compare(e.m_es[e.m_order[j]]); - if (r == -1) - return true; - if (r == 1) - return false; + TRACE("nla_cn_details", tout << "r=" << r << "\n";); + if (r) + return r; } - return size() < e.size(); + return static_cast(size()) - static_cast(e.size()); + } + void reset_order() { + m_order.clear(); + for( unsigned i = 0; i < m_es.size(); i++) + m_order.push_back(i); } }; @@ -100,6 +104,7 @@ public: bool is_var() const { return m_type == expr_type::VAR; } bool is_mul() const { return m_type == expr_type::MUL; } bool is_undef() const { return m_type == expr_type::UNDEF; } + bool is_scalar() const { return m_type == expr_type::SCALAR; } lpvar var() const { SASSERT(m_type == expr_type::VAR); return m_j; } expr_type type() const { return m_type; } expr_type& type() { return m_type; } @@ -281,6 +286,10 @@ public: return r; } + static nla_expr mul() { + return nla_expr(expr_type::MUL); + } + static nla_expr mul(const T& v, lpvar j) { if (v == 1) return var(j); @@ -312,9 +321,9 @@ public: } int compare(const nla_expr& e) const { - if (type() != (e.type())) + TRACE("nla_cn_details", tout << "this="<<*this<<", e=" << e << "\n";); + if (type() != e.type()) return (int)type() - (int)(e.type()); - SASSERT(type() == (e.type())); switch(m_type) { case expr_type::SUM: @@ -322,7 +331,7 @@ public: return m_children.compare(e.m_children); case expr_type::VAR: - return m_j - e.m_j; + return static_cast(m_j) - static_cast(e.m_j); case expr_type::SCALAR: return m_v < e.m_v? -1 : (m_v == e.m_v? 0 : 1); default: @@ -332,6 +341,7 @@ public: } bool operator<(const nla_expr& e) const { + TRACE("nla_cn_details", tout << "this=" << *this << ", e=" << e << "\n";); if (type() != (e.type())) return (int)type() < (int)(e.type()); @@ -366,6 +376,64 @@ public: return *this; } + std::unordered_map get_powers_from_mul() const { + SASSERT(is_mul()); + std::unordered_map r; + for (const auto & c : children()) { + if (!c.is_var()) { + continue; + } + lpvar j = c.var(); + auto it = r.find(j); + if (it == r.end()) { + r[j] = 1; + } else { + it->second++; + } + } + return r; + } + + + nla_expr& operator/=(const nla_expr& b) { + SASSERT(b.is_mul()); + if (is_sum()) { + for (auto & e : children()) { + e /= b; + } + return *this; + } + SASSERT(is_mul()); + auto powers = b.get_powers_from_mul(); + unsigned i = 0, k = 0; + for (; i < children().size(); i++, k++) { + auto & e = children()[i]; + if (!e.is_var()) { + SASSERT(e.is_scalar()); + continue; + } + lpvar j = e.var(); + auto it = powers.find(j); + if (it == powers.end()) { + if (i != k) + children()[k] = children()[i]; + } else { + it->second --; + if (it->second == 0) + powers.erase(it); + k--; + } + } + + while(k ++ < i) + children().pop_back(); + + s_children().reset_order(); + + return *this; + } + + nla_expr& operator+=(const nla_expr& b) { if (is_sum()) { if (b.is_sum()) {