diff --git a/src/math/lp/cross_nested.h b/src/math/lp/cross_nested.h index cda1c98cb..50f70b6f4 100644 --- a/src/math/lp/cross_nested.h +++ b/src/math/lp/cross_nested.h @@ -35,7 +35,7 @@ class cross_nested { }; // fields - nex_sum * m_e; + nex * m_e; std::function m_call_on_result; std::function m_var_is_fixed; bool m_done; @@ -43,6 +43,7 @@ class cross_nested { std::unordered_map m_powers; vector m_allocated; vector m_b_vec; + vector m_b_split_vec; public: cross_nested(std::function call_on_result, std::function var_is_fixed): @@ -51,16 +52,16 @@ public: m_done(false) {} - void run(nex_sum *e) { + void run(nex *e) { m_e = e; - vector front; + vector front; explore_expr_on_front_elem(m_e, front); } - static nex_sum* pop_back(vector& front) { - nex_sum* c = front.back(); - TRACE("nla_cn", tout << *c << "\n";); + static nex** pop_front(vector& front) { + nex** c = front.back(); + TRACE("nla_cn", tout << **c << "\n";); front.pop_back(); return c; } @@ -70,6 +71,14 @@ public: m_allocated.push_back(r); return r; } + template + void add_children(T) { } + + template + void add_children(T r, K e, Args ... es) { + r->add_child(e); + add_children(r, es ...); + } nex_sum* mk_sum(const vector& v) { auto r = new nex_sum(); @@ -78,40 +87,41 @@ public: return r; } - nex_sum* mk_sum(nex *a, nex* b) { - auto r = new nex_sum(); + nex_mul* mk_mul(const vector& v) { + auto r = new nex_mul(); m_allocated.push_back(r); - r->children().push_back(a); - r->children().push_back(b); + r->children() = v; return r; } + template + nex_sum* mk_sum(K e, Args... es) { + auto r = new nex_sum(); + m_allocated.push_back(r); + r->add_child(e); + add_children(r, es...); + return r; + } nex_var* mk_var(lpvar j) { auto r = new nex_var(j); m_allocated.push_back(r); return r; } - + nex_mul* mk_mul() { auto r = new nex_mul(); m_allocated.push_back(r); return r; } - nex_mul* mk_mul(nex * a, nex * b) { + template + nex_mul* mk_mul(K e, Args... es) { auto r = new nex_mul(); m_allocated.push_back(r); - r->add_child(a); r->add_child(b); + add_children(r, e, es...); return r; } - - nex_mul* mk_mul(nex * a, nex * b, nex *c) { - auto r = new nex_mul(); - m_allocated.push_back(r); - r->add_child(a); r->add_child(b); r->add_child(c); - return r; - } - + nex_scalar* mk_scalar(const rational& v) { auto r = new nex_scalar(v); m_allocated.push_back(r); @@ -120,8 +130,32 @@ public: nex * mk_div(const nex* a, lpvar j) { - SASSERT(false); - return nullptr; + TRACE("nla_cn_details", tout << "a=" << *a << ", v" << j << "\n";); + SASSERT((a->is_mul() && a->contains(j)) || (a->is_var() && to_var(a)->var() == j)); + if (a->is_var()) + return mk_scalar(rational(1)); + m_b_vec.clear(); + bool seenj = false; + for (nex* c : to_mul(a)->children()) { + if (!seenj) { + if (c->contains(j)) { + if (!c->is_var()) + m_b_vec.push_back(mk_div(c, j)); + seenj = true; + continue; + } + } + m_b_vec.push_back(c); + } + if (m_b_vec.size() > 1) { + return mk_mul(m_b_vec); + } + if (m_b_vec.size() == 1) { + return m_b_vec[0]; + } + + SASSERT(m_b_vec.size() == 0); + return mk_scalar(rational(1)); } nex * mk_div(const nex* a, const nex* b) { @@ -218,14 +252,14 @@ public: return false; } - bool proceed_with_common_factor(nex*& c, vector& front, const vector> & occurences) { + bool proceed_with_common_factor(nex*& c, vector& front, const vector> & occurences) { TRACE("nla_cn", tout << "c=" << *c << "\n";); nex* f = extract_common_factor(c, occurences); if (f == nullptr) return false; - nex_sum* c_over_f = to_sum(mk_div(c, f)); - c_over_f->simplify(); + nex* c_over_f = mk_div(c, f); + to_sum(c_over_f)->simplify(); c = mk_mul(f, c_over_f); TRACE("nla_cn", tout << "common factor=" << *f << ", c=" << *c << "\ne = " << *m_e << "\n";); @@ -233,28 +267,28 @@ public: return true; } - static void push(vector& front, nex_sum* e) { - TRACE("nla_cn", tout << *e << "\n";); + static void push(vector& front, nex** e) { + TRACE("nla_cn", tout << **e << "\n";); front.push_back(e); } - static vector copy_front(const vector& front) { - vector v; - for (nex_sum* n: front) - v.push_back(n); + static vector copy_front(const vector& front) { + vector v; + for (nex** n: front) + v.push_back(*n); return v; } - static void restore_front(const vector ©, vector& front) { + static void restore_front(const vector ©, vector& front) { SASSERT(copy.size() == front.size()); for (unsigned i = 0; i < front.size(); i++) - front[i] = copy[i]; + *(front[i]) = copy[i]; } - void explore_expr_on_front_elem_occs(nex* c, vector& front, const vector> & occurences) { + void explore_expr_on_front_elem_occs(nex* &c, vector& front, const vector> & occurences) { if (proceed_with_common_factor(c, front, occurences)) return; - TRACE("nla_cn", tout << "save c=" << *c << "; front:"; print_vector_of_ptrs(front, tout) << "\n";); + TRACE("nla_cn", tout << "save c=" << *c << "; front:"; print_front(front, tout) << "\n";); nex* copy_of_c = c; auto copy_of_front = copy_front(front); for(auto& p : occurences) { @@ -269,11 +303,12 @@ public: explore_of_expr_on_sum_and_var(c, j, front); if (m_done) return; + TRACE("nla_cn", tout << "before restore c=" << *c << ", m_e=" << *m_e << "\n";); c = copy_of_c; - TRACE("nla_cn", tout << "restore c=" << *c << ", m_e=" << m_e << "\n";); + TRACE("nla_cn", tout << "after restore c=" << *c << ", m_e=" << *m_e << "\n";); restore_front(copy_of_front, front); TRACE("nla_cn", tout << "restore c=" << *c << "\n";); - TRACE("nla_cn", tout << "m_e=" << m_e << "\n";); + TRACE("nla_cn", tout << "m_e=" << *m_e << "\n";); } } @@ -288,18 +323,18 @@ public: return out; } - void explore_expr_on_front_elem(nex_sum* c, vector& front) { - auto occurences = get_mult_occurences(c); - TRACE("nla_cn", tout << "m_e=" << m_e << "\nc=" << *c << ", c occurences="; - dump_occurences(tout, occurences) << "; front:"; print_vector_of_ptrs(front, tout) << "\n";); + void explore_expr_on_front_elem(nex*& c, vector& front) { + auto occurences = get_mult_occurences(to_sum(c)); + TRACE("nla_cn", tout << "m_e=" << *m_e << "\nc=" << *c << ", c occurences="; + dump_occurences(tout, occurences) << "; front:"; print_front(front, tout) << "\n";); if (occurences.empty()) { if(front.empty()) { - TRACE("nla_cn", tout << "got the cn form: =" << m_e << "\n";); + TRACE("nla_cn", tout << "got the cn form: =" << *m_e << "\n";); m_done = m_call_on_result(m_e); } else { - auto c = pop_back(front); - explore_expr_on_front_elem(c, front); + nex* f = *pop_front(front); + explore_expr_on_front_elem(f, front); } } else { explore_expr_on_front_elem_occs(c, front, occurences); @@ -311,15 +346,23 @@ public: return s.str(); // return (char)('a'+j); } - // e is the global expression, c is the sub expressiond which is going to changed from sum to the cross nested form - void explore_of_expr_on_sum_and_var(nex* & c, lpvar j, vector front) { - TRACE("nla_cn", tout << "m_e=" << m_e << "\nc=" << *c << "\nj = " << ch(j) << "\nfront="; print_vector_of_ptrs(front, tout) << "\n";); + + std::ostream& print_front(const vector& front, std::ostream& out) const { + for (auto e : front) { + out << **e << "\n"; + } + return out; + } + // c is the sub expressiond which is going to be changed from sum to the cross nested form + // front will be explored more + void explore_of_expr_on_sum_and_var(nex*& c, lpvar j, vector front) { + TRACE("nla_cn", tout << "m_e=" << *m_e << "\nc=" << *c << "\nj = " << ch(j) << "\nfront="; print_front(front, tout) << "\n";); if (!split_with_var(c, j, front)) return; - TRACE("nla_cn", tout << "after split c=" << *c << "\nfront="; print_vector_of_ptrs(front, tout) << "\n";); + TRACE("nla_cn", tout << "after split c=" << *c << "\nfront="; print_front(front, tout) << "\n";); SASSERT(front.size()); - auto n = pop_back(front); TRACE("nla_cn", tout << "n=" << *n <<"\n";); - explore_expr_on_front_elem(n, front); + auto n = pop_front(front); + explore_expr_on_front_elem(*n, front); } void add_var_occs(lpvar j) { @@ -378,7 +421,7 @@ public: } } remove_singular_occurences(); - TRACE("nla_cn_details", tout << "e=" << e << "\noccs="; dump_occurences(tout, m_occurences_map) << "\n";); + TRACE("nla_cn_details", tout << "e=" << *e << "\noccs="; dump_occurences(tout, m_occurences_map) << "\n";); vector> ret; for (auto & p : m_occurences_map) ret.push_back(p); @@ -405,54 +448,57 @@ public: void pre_split(nex_sum * e, lpvar j, nex_sum* & a, nex* & b) { a = mk_sum(); - m_b_vec.clear(); + m_b_split_vec.clear(); for (nex * ce: e->children()) { if (is_divisible_by_var(ce, j)) { a->add_child(mk_div(ce , j)); } else { - m_b_vec.push_back(ce); + m_b_split_vec.push_back(ce); + TRACE("nla_cn_details", tout << "ce = " << *ce << "\n";); + } } TRACE("nla_cn_details", tout << "a = " << *a << "\n";); - SASSERT(a->children().size() >= 2 && m_b_vec.size()); + SASSERT(a->children().size() >= 2 && m_b_split_vec.size()); a->simplify(); - if (m_b_vec.size() == 1) { - b = m_b_vec[0]; + if (m_b_split_vec.size() == 1) { + b = m_b_split_vec[0]; + TRACE("nla_cn_details", tout << "b = " << *b << "\n";); } else { - SASSERT(m_b_vec.size() > 1); - b = mk_sum(m_b_vec); - } + SASSERT(m_b_split_vec.size() > 1); + b = mk_sum(m_b_split_vec); + TRACE("nla_cn_details", tout << "b = " << *b << "\n";); + } } - void update_front_with_split_with_non_empty_b(nex* &e, lpvar j, vector & front, nex_sum* a, nex* b) { + void update_front_with_split_with_non_empty_b(nex* &e, lpvar j, vector & front, nex* a, nex* b) { SASSERT(a->is_sum()); - TRACE("nla_cn_details", tout << "b = " << b << "\n";); + TRACE("nla_cn_details", tout << "b = " << *b << "\n";); e = mk_sum(mk_mul(mk_var(j), a), b); // e = j*a + b - push(front, a); // pushing 'a' - TRACE("nla_cn", tout << "push to front " << *a << "\n";); + nex **ptr_to_a = &(to_mul(to_sum(e)->children()[0]))->children()[1]; + push(front, ptr_to_a); if (b->is_sum()) { - push(front, to_sum(b)); - TRACE("nla_cn", tout << "push to front " << *b << "\n";); + nex **ptr_to_a = &(to_sum(e)->children()[1]); + push(front, ptr_to_a); } } - void update_front_with_split(nex* & e, lpvar j, vector & front, nex_sum* a, nex* b) { + void update_front_with_split(nex* & e, lpvar j, vector & front, nex* a, nex* b) { if (b == nullptr) { e = mk_mul(mk_var(j), a); - push(front, a); - TRACE("nla_cn_details", tout << "push to front " << *a << "\n";); + push(front, &(to_mul(e)->children()[1])); } else { update_front_with_split_with_non_empty_b(e, j, front, a, b); } } // it returns true if the recursion brings a cross-nested form - bool split_with_var(nex*& e, lpvar j, vector & front) { + bool split_with_var(nex*& e, lpvar j, vector & front) { SASSERT(e->is_sum()); - TRACE("nla_cn", tout << "e = " << e << ", j=" << ch(j) << "\n";); + TRACE("nla_cn", tout << "e = " << *e << ", j=" << ch(j) << "\n";); nex_sum* a; nex * b; pre_split(to_sum(e), j, a, b); /* diff --git a/src/math/lp/nla_expr.h b/src/math/lp/nla_expr.h index f35a4fc9a..0d6319eff 100644 --- a/src/math/lp/nla_expr.h +++ b/src/math/lp/nla_expr.h @@ -70,6 +70,11 @@ public: virtual ~nex() {} virtual bool contains(lpvar j) const { return false; } virtual int get_degree() const = 0; + virtual void simplify() {} + virtual const vector * children_ptr() const { + UNREACHABLE(); + return nullptr; + } }; std::ostream& operator<<(std::ostream& out, const nex&); @@ -107,6 +112,28 @@ public: }; +static void promote_children_by_type(vector * children, expr_type t) { + svector to_promote; + for(unsigned j = 0; j < children->size(); j++) { + nex* e = (*children)[j]; + e->simplify(); + if (e->type() == t) { + to_promote.push_back(e); + } else { + unsigned offset = to_promote.size(); + if (offset) { + (*children)[j - offset] = e; + } + } + for (nex *e : to_promote) { + for (nex *ee : *(e->children_ptr())) { + children->push_back(ee); + } + } + } + +} + class nex_mul : public nex { vector m_children; public: @@ -115,6 +142,8 @@ public: expr_type type() const { return expr_type::MUL; } vector& children() { return m_children;} const vector& children() const { return m_children;} + const vector* children_ptr() const { return &m_children;} + std::ostream & print(std::ostream& out) const { bool first = true; for (const nex* v : m_children) { @@ -180,9 +209,13 @@ public: return degree; } + void simplify() { + promote_children_by_type(&m_children, expr_type::MUL); + } }; + class nex_sum : public nex { vector m_children; public: @@ -190,6 +223,7 @@ public: expr_type type() const { return expr_type::SUM; } vector& children() { return m_children;} const vector& children() const { return m_children;} + const vector* children_ptr() const { return &m_children;} unsigned size() const { return m_children.size(); } // we need a linear combination of at least two variables @@ -233,7 +267,7 @@ public: } void simplify() { - SASSERT(false); + promote_children_by_type(&m_children, expr_type::SUM); } int get_degree() const { diff --git a/src/test/lp/lp.cpp b/src/test/lp/lp.cpp index a0ff48209..e72f4f9a7 100644 --- a/src/test/lp/lp.cpp +++ b/src/test/lp/lp.cpp @@ -86,7 +86,6 @@ void test_cn() { nex_var* c = cn.mk_var(2); nex_var* d = cn.mk_var(3); nex_var* e = cn.mk_var(4); - nex_var* f = cn.mk_var(5); nex_var* g = cn.mk_var(6); nex* min_1 = cn.mk_scalar(rational(-1)); // test_cn_on_expr(min_1*c*e + min_1*b*d + min_1*a*b + a*c); @@ -94,8 +93,16 @@ void test_cn() { nex_mul* bcg = cn.mk_mul(b, c, g); bcg->add_child(min_1); nex_sum* t = cn.mk_sum(bcd, bcg); - test_cn_on_expr(t, cn); - // test_cn_on_expr(a*a*d + a*b*c*d + a*a*c*c*d + a*d*d + e*a*e + e*a*c + e*d); + // test_cn_on_expr(t, cn); + nex* aad = cn.mk_mul(a, a, d); + nex* abcd = cn.mk_mul(a, b, c, d); + nex* aaccd = cn.mk_mul(a, a, c, c, d); + nex* add = cn.mk_mul(a, d, d); + nex* eae = cn.mk_mul(e, a, e); + nex* eac = cn.mk_mul(e, a, c); + nex* ed = cn.mk_mul(e, d); + + test_cn_on_expr(cn.mk_sum(aad, abcd, aaccd, add, eae, eac, ed), cn); // TRACE("nla_cn", tout << "done\n";); // test_cn_on_expr(a*b*d + a*b*c + c*b*d + a*c*d); // TRACE("nla_cn", tout << "done\n";);