diff --git a/src/math/lp/cross_nested.h b/src/math/lp/cross_nested.h index c3771f907..d4940ee37 100644 --- a/src/math/lp/cross_nested.h +++ b/src/math/lp/cross_nested.h @@ -30,7 +30,7 @@ public: void run() { vector front; - cross_nested_of_expr_on_front_elem(&m_e, front, true); // true for trivial form - no change + explore_of_expr_on_front_elem(&m_e, front, true); // true for trivial form - no change } static nex* pop_back(vector& front) { @@ -52,6 +52,7 @@ public: }; static bool extract_common_factor(nex* c, nex& f, const std::unordered_map & occurences) { + TRACE("nla_cn", tout << "c=" << *c << "\n";); SASSERT(c->is_sum()); f.type() = expr_type::MUL; SASSERT(f.children().empty()); @@ -78,26 +79,36 @@ public: f.simplify(); * c = nex::mul(f, *c); TRACE("nla_cn", tout << "common factor=" << f << ", c=" << *c << "\n";); - cross_nested_of_expr_on_front_elem(&(c->children()[1]), front, false); + explore_of_expr_on_front_elem(&(c->children()[1]), front, false); return true; } + + static vector copy_front(const vector& front) { + vector v; + for (nex* n: front) + v.push_back(*n); + return v; + } + + static void restore_front(const vector ©, vector& front) { + SASSERT(copy.size() == front.size()); + for (unsigned i = 0; i < front.size(); i++) + *(front[i]) = copy[i]; + } - void cross_nested_of_expr_on_front_elem_occs(nex* c, vector& front, const std::unordered_map & occurences) { + void explore_of_expr_on_front_elem_occs(nex* c, vector& front, const std::unordered_map & occurences) { if (proceed_with_common_factor(c, front, occurences)) return; - TRACE("nla_cn", tout << "save c=" << *c << "front:"; print_vector_of_ptrs(front, tout) << "\n";); + TRACE("nla_cn", tout << "save c=" << *c << "; front:"; print_vector_of_ptrs(front, tout) << "\n";); nex copy_of_c = *c; - vector copy_of_front; - for (nex* n: front) - copy_of_front.push_back(*n); + vector copy_of_front = copy_front(front); for(auto& p : occurences) { SASSERT(p.second.m_occs > 1); lpvar j = p.first; - cross_nested_of_expr_on_sum_and_var(c, j, front); + explore_of_expr_on_sum_and_var(c, j, front); *c = copy_of_c; - TRACE("nla_cn", tout << "restore c=" << *c << ", m_e=" << m_e << "\n";); - for (unsigned i = 0; i < front.size(); i++) - *(front[i]) = copy_of_front[i]; + TRACE("nla_cn", tout << "restore c=" << *c << ", m_e=" << m_e << "\n";); + restore_front(copy_of_front, front); TRACE("nla_cn", tout << "restore c=" << *c << "\n";); TRACE("nla_cn", tout << "m_e=" << m_e << "\n";); } @@ -113,7 +124,7 @@ public: return out; } - void cross_nested_of_expr_on_front_elem(nex* c, vector& front, bool trivial_form) { + void explore_of_expr_on_front_elem(nex* c, vector& front, bool trivial_form) { SASSERT(c->is_sum()); auto occurences = get_mult_occurences(*c); TRACE("nla_cn", tout << "m_e=" << m_e << "\nc=" << *c << ", c occurences="; @@ -134,30 +145,26 @@ public: cross_nested cn(e_s, m_call_on_result); cn.run(); } - } else { nex* c = pop_back(front); - cross_nested_of_expr_on_front_elem(c, front, trivial_form); + explore_of_expr_on_front_elem(c, front, trivial_form); } } else { - cross_nested_of_expr_on_front_elem_occs(c, front, occurences); + explore_of_expr_on_front_elem_occs(c, front, occurences); } } static char ch(unsigned j) { return (char)('a'+j); } // e is the global expression, c is the sub expressiond which is going to changed from sum to the cross nested form - void cross_nested_of_expr_on_sum_and_var(nex* c, lpvar j, vector front) { + void explore_of_expr_on_sum_and_var(nex* c, lpvar j, vector front) { TRACE("nla_cn", tout << "m_e=" << m_e << "\nc=" << *c << "\nj = " << ch(j) << "\nfront="; print_vector_of_ptrs(front, tout) << "\n";); - split_with_var(*c, j, front); + if(split_with_var(*c, j, front)) + return; TRACE("nla_cn", tout << "after split c=" << *c << "\nfront="; print_vector_of_ptrs(front, tout) << "\n";); SASSERT(front.size()); - if (can_be_cross_nested_more(*c)) { - cross_nested_of_expr_on_front_elem(c, front, false); - } else { - nex* n = pop_back(front); TRACE("nla_cn", tout << "n=" << *n <<"\n";); - cross_nested_of_expr_on_front_elem(n, front, false); // we got a non-trivial_form - } + nex* n = pop_back(front); TRACE("nla_cn", tout << "n=" << *n <<"\n";); + explore_of_expr_on_front_elem(n, front, false); // we got a non-trivial_form } static void process_var_occurences(lpvar j, std::unordered_map& occurences) { auto it = occurences.find(j); @@ -214,19 +221,23 @@ public: static bool can_be_cross_nested_more(const nex& s) { auto e = s; e.simplify(); - TRACE("nla_cn_details", tout << "simplified " << e << "\n";); + TRACE("nla_cn", tout << "simplified " << e << "\n";); switch (e.type()) { case expr_type::SCALAR: return false; case expr_type::SUM: - if ( !get_mult_occurences(e).empty()) + if ( !get_mult_occurences(e).empty()) { + TRACE("nla_cn", tout << "true for " << e << "\n";); return true; + } // fall through MUL case expr_type::MUL: { for (const auto & c: e.children()) { - if (can_be_cross_nested_more(c)) + if (can_be_cross_nested_more(c)) { + TRACE("nla_cn", tout << "true for " << e << "\n";); return true; + } } return false; } @@ -260,32 +271,55 @@ public: } } - static void update_front_with_split_with_non_empty_b(nex& e, lpvar j, vector & front, nex& a, nex& b) { + // returns true if the recursion is done inside + bool update_front_with_split_with_non_empty_b(nex& e, lpvar j, vector & front, nex& a, nex& b) { nex f; bool a_has_f = extract_common_factor(&a, f, get_mult_occurences(a)); SASSERT(a.is_sum()); if (a_has_f) { TRACE("nla_cn", tout << "a=" << a << ", f=" << f << "\n";); a /= f; - e = nex::sum(nex::mul(nex::var(j), f, a), b); - TRACE("nla_cn", tout << "a = " << a << ", e = " << e << "\n";); - auto& jfa = e.children()[0]; - SASSERT(jfa.size() == 3); - front.push_back(&(jfa.children()[2])); - front.push_back(&e); // e might have its own cross nested form + auto j_f_a = nex::mul(nex::var(j), f, a); + j_f_a.simplify(); + e = nex::sum(j_f_a, b); + TRACE("nla_cn", tout << "j_f_a = " << j_f_a << ", e = " << e << "\n";); + e.simplify(); + auto occs = get_mult_occurences(e); + + if (occs.empty()) { + TRACE("nla_cn", tout << "occs are empty\n";); + auto& jfa = e.children()[0]; + front.push_back(jfa.children().end() - 1); + } else { + TRACE("nla_cn", tout << "recurse\n";); + nex copy_of_e = e; + vector copy_of_front = copy_front(front); + for(auto& p : occs) { + SASSERT(p.second.m_occs > 1); + auto& jfa = e.children()[0]; + front.push_back(jfa.children().end() - 1); + lpvar j = p.first; + explore_of_expr_on_sum_and_var(&e, j, front); + e = copy_of_e; + front.pop_back(); + restore_front(copy_of_front, front); + } + return true; + } } else { TRACE("nla_cn_details", tout << "b = " << b << "\n";); e = nex::sum(nex::mul(nex::var(j), a), b); front.push_back(&(e.children()[0].children()[1])); - TRACE("nla_cn_details", tout << "push to front " << e.children()[0].children()[1] << "\n";); + TRACE("nla_cn", tout << "push to front " << e.children()[0].children()[1] << "\n";); } if (b.is_sum()) { front.push_back(&(e.children()[1])); - TRACE("nla_cn_details", tout << "push to front " << e.children()[1] << "\n";); + TRACE("nla_cn", tout << "push to front " << e.children()[1] << "\n";); } + return false; } - static void update_front_with_split(nex& e, lpvar j, vector & front, nex& a, nex& b) { + bool update_front_with_split(nex& e, lpvar j, vector & front, nex& a, nex& b) { if (b.is_undef()) { SASSERT(b.children().size() == 0); e = nex(expr_type::MUL); @@ -295,18 +329,19 @@ public: front.push_back(&e.children().back()); TRACE("nla_cn_details", tout << "push to front " << e.children().back() << "\n";); } - } else { - update_front_with_split_with_non_empty_b(e, j, front, a, b); + return false; } - + return update_front_with_split_with_non_empty_b(e, j, front, a, b); } - static void split_with_var(nex& e, lpvar j, vector & front) { - TRACE("nla_cn_details", tout << "e = " << e << ", j = v" << j << "\n";); + // it returns true if the recursion is done inside + bool split_with_var(nex& e, lpvar j, vector & front) { + TRACE("nla_cn", tout << "e = " << e << ", j=" << ch(j) << "\n";); if (!e.is_sum()) - return; + return false; nex a, b; + pre_split(e, j, a, b); - update_front_with_split(e, j, front, a, b); + return update_front_with_split(e, j, front, a, b); } std::set get_vars_of_expr(const nex &e ) const { std::set r; diff --git a/src/math/lp/nla_expr.h b/src/math/lp/nla_expr.h index 4ad7ba8f8..fa3936d91 100644 --- a/src/math/lp/nla_expr.h +++ b/src/math/lp/nla_expr.h @@ -105,7 +105,7 @@ public: for (auto &e : m_children) { n += e; } - *this = n; + m_children = n.m_children; } } else if (is_mul()) { bool has_mul = false; @@ -118,7 +118,7 @@ public: for (auto &e : m_children) { n *= e; } - *this = n; + m_children = n.m_children; } TRACE("nla_cn_details", tout << "simplified " << *this << "\n";); } @@ -320,7 +320,6 @@ public: for (; i < children().size(); i++, k++) { auto & e = children()[i]; if (!e.is_var()) { - SASSERT(e.is_scalar()); continue; } lpvar j = e.var(); diff --git a/src/test/lp/lp.cpp b/src/test/lp/lp.cpp index a99a851f9..1396a6167 100644 --- a/src/test/lp/lp.cpp +++ b/src/test/lp/lp.cpp @@ -78,8 +78,7 @@ void test_cn_on_expr(horner::nex t) { void test_cn() { typedef horner::nex nex; enable_trace("nla_cn"); - enable_trace("nla_cn_cn"); - enable_trace("nla_cn_details"); + // enable_trace("nla_cn_details"); nex a = nex::var(0), b = nex::var(1), c = nex::var(2), d = nex::var(3), e = nex::var(4); // test_cn_on_expr(a*b + a*c + b*c); //TRACE("nla_cn", tout << "done\n";);