diff --git a/src/math/lp/cross_nested.h b/src/math/lp/cross_nested.h index 787c931e0..2fa1dd41d 100644 --- a/src/math/lp/cross_nested.h +++ b/src/math/lp/cross_nested.h @@ -41,9 +41,8 @@ class cross_nested { bool m_done; std::unordered_map m_occurences_map; std::unordered_map m_powers; - vector m_allocated; - vector m_b_vec; - vector m_b_split_vec; + ptr_vector m_allocated; + ptr_vector m_b_split_vec; public: cross_nested(std::function call_on_result, std::function var_is_fixed): @@ -56,7 +55,7 @@ public: m_e = e; vector front; - explore_expr_on_front_elem(m_e, front); + explore_expr_on_front_elem(&m_e, front); } static nex** pop_front(vector& front) { @@ -80,20 +79,21 @@ public: add_children(r, es ...); } - nex_sum* mk_sum(const vector& v) { + nex_sum* mk_sum(const ptr_vector& v) { auto r = new nex_sum(); m_allocated.push_back(r); r->children() = v; return r; } - nex_mul* mk_mul(const vector& v) { + nex_mul* mk_mul(const ptr_vector& v) { auto r = new nex_mul(); m_allocated.push_back(r); r->children() = v; return r; } + template nex_sum* mk_sum(K e, Args... es) { auto r = new nex_sum(); @@ -134,27 +134,27 @@ public: SASSERT((a->is_mul() && a->contains(j)) || (a->is_var() && to_var(a)->var() == j)); if (a->is_var()) return mk_scalar(rational(1)); - m_b_vec.clear(); + ptr_vector bv; bool seenj = false; for (nex* c : to_mul(a)->children()) { if (!seenj) { if (c->contains(j)) { if (!c->is_var()) - m_b_vec.push_back(mk_div(c, j)); + bv.push_back(mk_div(c, j)); seenj = true; continue; } } - m_b_vec.push_back(c); + bv.push_back(c); } - if (m_b_vec.size() > 1) { - return mk_mul(m_b_vec); + if (bv.size() > 1) { + return mk_mul(bv); } - if (m_b_vec.size() == 1) { - return m_b_vec[0]; + if (bv.size() == 1) { + return bv[0]; } - SASSERT(m_b_vec.size() == 0); + SASSERT(bv.size() == 0); return mk_scalar(rational(1)); } @@ -255,20 +255,20 @@ public: return false; } - bool proceed_with_common_factor(nex*& c, vector& front, const vector> & occurences) { - TRACE("nla_cn", tout << "c=" << *c << "\n";); - nex* f = extract_common_factor(c, occurences); + bool proceed_with_common_factor(nex** c, vector& front, const vector> & occurences) { + TRACE("nla_cn", tout << "c=" << **c << "\n";); + nex* f = extract_common_factor(*c, occurences); if (f == nullptr) { TRACE("nla_cn", tout << "no common factor\n"; ); return false; } - nex* c_over_f = mk_div(c, f); + nex* c_over_f = mk_div(*c, f); to_sum(c_over_f)->simplify(); - c = mk_mul(f, c_over_f); - TRACE("nla_cn", tout << "common factor=" << *f << ", c=" << *c << "\ne = " << *m_e << "\n";); + *c = mk_mul(f, c_over_f); + TRACE("nla_cn", tout << "common factor=" << *f << ", c=" << **c << "\ne = " << *m_e << "\n";); - explore_expr_on_front_elem(c_over_f, front); + explore_expr_on_front_elem(&(*((*c)->children_ptr()))[1], front); return true; } @@ -290,11 +290,11 @@ public: *(front[i]) = copy[i]; } - void explore_expr_on_front_elem_occs(nex* &c, vector& front, const vector> & occurences) { + void explore_expr_on_front_elem_occs(nex** c, vector& front, const vector> & occurences) { if (proceed_with_common_factor(c, front, occurences)) return; TRACE("nla_cn", tout << "save c=" << *c << "; front:"; print_front(front, tout) << "\n";); - nex* copy_of_c = c; + nex* copy_of_c = *c; auto copy_of_front = copy_front(front); for(auto& p : occurences) { SASSERT(p.second.m_occs > 1); @@ -308,11 +308,11 @@ public: explore_of_expr_on_sum_and_var(c, j, front); if (m_done) return; - TRACE("nla_cn", tout << "before restore c=" << *c << ", m_e=" << *m_e << "\n";); - c = copy_of_c; - TRACE("nla_cn", tout << "after restore c=" << *c << ", m_e=" << *m_e << "\n";); + TRACE("nla_cn", tout << "before restore c=" << **c << ", m_e=" << *m_e << "\n";); + *c = copy_of_c; + TRACE("nla_cn", tout << "after restore c=" << **c << ", m_e=" << *m_e << "\n";); restore_front(copy_of_front, front); - TRACE("nla_cn", tout << "restore c=" << *c << "\n";); + TRACE("nla_cn", tout << "restore c=" << **c << "\n";); TRACE("nla_cn", tout << "m_e=" << *m_e << "\n";); } } @@ -328,9 +328,9 @@ public: return out; } - void explore_expr_on_front_elem(nex*& c, vector& front) { - auto occurences = get_mult_occurences(to_sum(c)); - TRACE("nla_cn", tout << "m_e=" << *m_e << "\nc=" << *c << ", c occurences="; + void explore_expr_on_front_elem(nex** c, vector& front) { + auto occurences = get_mult_occurences(to_sum(*c)); + TRACE("nla_cn", tout << "m_e=" << *m_e << "\nc=" << **c << ", c occurences="; dump_occurences(tout, occurences) << "; front:"; print_front(front, tout) << "\n";); if (occurences.empty()) { @@ -338,7 +338,7 @@ public: TRACE("nla_cn", tout << "got the cn form: =" << *m_e << "\n";); m_done = m_call_on_result(m_e); } else { - nex* f = *pop_front(front); + nex** f = pop_front(front); explore_expr_on_front_elem(f, front); } } else { @@ -360,14 +360,14 @@ public: } // c is the sub expressiond which is going to be changed from sum to the cross nested form // front will be explored more - void explore_of_expr_on_sum_and_var(nex*& c, lpvar j, vector front) { - TRACE("nla_cn", tout << "m_e=" << *m_e << "\nc=" << *c << "\nj = " << ch(j) << "\nfront="; print_front(front, tout) << "\n";); - if (!split_with_var(c, j, front)) + void explore_of_expr_on_sum_and_var(nex** c, lpvar j, vector front) { + TRACE("nla_cn", tout << "m_e=" << *m_e << "\nc=" << **c << "\nj = " << ch(j) << "\nfront="; print_front(front, tout) << "\n";); + if (!split_with_var(*c, j, front)) return; - TRACE("nla_cn", tout << "after split c=" << *c << "\nfront="; print_front(front, tout) << "\n";); + TRACE("nla_cn", tout << "after split c=" << **c << "\nfront="; print_front(front, tout) << "\n";); SASSERT(front.size()); auto n = pop_front(front); - explore_expr_on_front_elem(*n, front); + explore_expr_on_front_elem(n, front); } void add_var_occs(lpvar j) { diff --git a/src/math/lp/nla_expr.h b/src/math/lp/nla_expr.h index 0d6319eff..dba3cef5c 100644 --- a/src/math/lp/nla_expr.h +++ b/src/math/lp/nla_expr.h @@ -71,7 +71,11 @@ public: virtual bool contains(lpvar j) const { return false; } virtual int get_degree() const = 0; virtual void simplify() {} - virtual const vector * children_ptr() const { + virtual const ptr_vector * children_ptr() const { + UNREACHABLE(); + return nullptr; + } + virtual ptr_vector * children_ptr() { UNREACHABLE(); return nullptr; } @@ -112,8 +116,8 @@ public: }; -static void promote_children_by_type(vector * children, expr_type t) { - svector to_promote; +static void promote_children_by_type(ptr_vector * children, expr_type t) { + ptr_vector to_promote; for(unsigned j = 0; j < children->size(); j++) { nex* e = (*children)[j]; e->simplify(); @@ -125,24 +129,25 @@ static void promote_children_by_type(vector * children, expr_type t) { (*children)[j - offset] = e; } } - for (nex *e : to_promote) { - for (nex *ee : *(e->children_ptr())) { - children->push_back(ee); - } - } } - + + for (nex *e : to_promote) { + for (nex *ee : *(e->children_ptr())) { + children->push_back(ee); + } + } } class nex_mul : public nex { - vector m_children; + ptr_vector m_children; public: nex_mul() {} unsigned size() const { return m_children.size(); } expr_type type() const { return expr_type::MUL; } - vector& children() { return m_children;} - const vector& children() const { return m_children;} - const vector* children_ptr() const { return &m_children;} + ptr_vector& children() { return m_children;} + const ptr_vector& children() const { return m_children;} + const ptr_vector* children_ptr() const { return &m_children;} + ptr_vector* children_ptr() { return &m_children;} std::ostream & print(std::ostream& out) const { bool first = true; @@ -217,13 +222,14 @@ public: class nex_sum : public nex { - vector m_children; + ptr_vector m_children; public: nex_sum() {} expr_type type() const { return expr_type::SUM; } - vector& children() { return m_children;} - const vector& children() const { return m_children;} - const vector* children_ptr() const { return &m_children;} + ptr_vector& children() { return m_children;} + const ptr_vector& children() const { return m_children;} + const ptr_vector* children_ptr() const { return &m_children;} + ptr_vector* children_ptr() { return &m_children;} unsigned size() const { return m_children.size(); } // we need a linear combination of at least two variables diff --git a/src/test/lp/lp.cpp b/src/test/lp/lp.cpp index e72f4f9a7..d73c59a04 100644 --- a/src/test/lp/lp.cpp +++ b/src/test/lp/lp.cpp @@ -101,8 +101,9 @@ void test_cn() { nex* eae = cn.mk_mul(e, a, e); nex* eac = cn.mk_mul(e, a, c); nex* ed = cn.mk_mul(e, d); - - test_cn_on_expr(cn.mk_sum(aad, abcd, aaccd, add, eae, eac, ed), cn); + nex* _6aad = cn.mk_mul(cn.mk_scalar(rational(6)), a, a, d); + // test_cn_on_expr(cn.mk_sum(aad, abcd, aaccd, add, eae, eac, ed), cn); + test_cn_on_expr(cn.mk_sum(_6aad, abcd, aaccd, add, eae, eac, ed), cn); // TRACE("nla_cn", tout << "done\n";); // test_cn_on_expr(a*b*d + a*b*c + c*b*d + a*c*d); // TRACE("nla_cn", tout << "done\n";);