diff --git a/src/math/lp/cross_nested.h b/src/math/lp/cross_nested.h index d4940ee37..a9ab08e56 100644 --- a/src/math/lp/cross_nested.h +++ b/src/math/lp/cross_nested.h @@ -35,6 +35,7 @@ public: static nex* pop_back(vector& front) { nex* c = front.back(); + TRACE("nla_cn", tout << *c << "\n";); front.pop_back(); return c; } @@ -67,7 +68,7 @@ public: } return !f.children().empty(); } - + bool proceed_with_common_factor(nex* c, vector& front, const std::unordered_map & occurences) { TRACE("nla_cn", tout << "c=" << *c << "\n";); SASSERT(c->is_sum()); @@ -83,6 +84,11 @@ public: return true; } + static void push(vector& front, nex* e) { + TRACE("nla_cn", tout << *e << "\n";); + front.push_back(e); + } + static vector copy_front(const vector& front) { vector v; for (nex* n: front) @@ -134,17 +140,8 @@ public: if(front.empty()) { if (trivial_form) return; - auto e_s = m_e; - e_s.simplify(); - occurences = get_mult_occurences(e_s); - if (occurences.empty()) { - TRACE("nla_cn", tout << "got the cn form: e_s=" << e_s << "\n";); - SASSERT(!can_be_cross_nested_more(e_s)); - m_call_on_result(e_s); - } else { - cross_nested cn(e_s, m_call_on_result); - cn.run(); - } + TRACE("nla_cn", tout << "got the cn form: =" << m_e << "\n";); + m_call_on_result(m_e); } else { nex* c = pop_back(front); explore_of_expr_on_front_elem(c, front, trivial_form); @@ -159,7 +156,7 @@ public: // e is the global expression, c is the sub expressiond which is going to changed from sum to the cross nested form void explore_of_expr_on_sum_and_var(nex* c, lpvar j, vector front) { TRACE("nla_cn", tout << "m_e=" << m_e << "\nc=" << *c << "\nj = " << ch(j) << "\nfront="; print_vector_of_ptrs(front, tout) << "\n";); - if(split_with_var(*c, j, front)) + if (!split_with_var(*c, j, front)) return; TRACE("nla_cn", tout << "after split c=" << *c << "\nfront="; print_vector_of_ptrs(front, tout) << "\n";); SASSERT(front.size()); @@ -272,76 +269,45 @@ public: } // returns true if the recursion is done inside - bool update_front_with_split_with_non_empty_b(nex& e, lpvar j, vector & front, nex& a, nex& b) { + void update_front_with_split_with_non_empty_b(nex& e, lpvar j, vector & front, nex& a, nex& b) { nex f; - bool a_has_f = extract_common_factor(&a, f, get_mult_occurences(a)); SASSERT(a.is_sum()); - if (a_has_f) { - TRACE("nla_cn", tout << "a=" << a << ", f=" << f << "\n";); - a /= f; - auto j_f_a = nex::mul(nex::var(j), f, a); - j_f_a.simplify(); - e = nex::sum(j_f_a, b); - TRACE("nla_cn", tout << "j_f_a = " << j_f_a << ", e = " << e << "\n";); - e.simplify(); - auto occs = get_mult_occurences(e); - - if (occs.empty()) { - TRACE("nla_cn", tout << "occs are empty\n";); - auto& jfa = e.children()[0]; - front.push_back(jfa.children().end() - 1); - } else { - TRACE("nla_cn", tout << "recurse\n";); - nex copy_of_e = e; - vector copy_of_front = copy_front(front); - for(auto& p : occs) { - SASSERT(p.second.m_occs > 1); - auto& jfa = e.children()[0]; - front.push_back(jfa.children().end() - 1); - lpvar j = p.first; - explore_of_expr_on_sum_and_var(&e, j, front); - e = copy_of_e; - front.pop_back(); - restore_front(copy_of_front, front); - } - return true; - } - } else { - TRACE("nla_cn_details", tout << "b = " << b << "\n";); - e = nex::sum(nex::mul(nex::var(j), a), b); - front.push_back(&(e.children()[0].children()[1])); - TRACE("nla_cn", tout << "push to front " << e.children()[0].children()[1] << "\n";); - } + + TRACE("nla_cn_details", tout << "b = " << b << "\n";); + e = nex::sum(nex::mul(nex::var(j), a), b); + push(front, &(e.children()[0].children()[1])); // pushing 'a' + TRACE("nla_cn", tout << "push to front " << e.children()[0].children()[1] << "\n";); + if (b.is_sum()) { - front.push_back(&(e.children()[1])); + push(front, &(e.children()[1])); TRACE("nla_cn", tout << "push to front " << e.children()[1] << "\n";); } - return false; } - bool update_front_with_split(nex& e, lpvar j, vector & front, nex& a, nex& b) { + void update_front_with_split(nex& e, lpvar j, vector & front, nex& a, nex& b) { if (b.is_undef()) { SASSERT(b.children().size() == 0); e = nex(expr_type::MUL); e.add_child(nex::var(j)); e.add_child(a); if (a.size() > 1) { - front.push_back(&e.children().back()); + push(front, &e.children().back()); TRACE("nla_cn_details", tout << "push to front " << e.children().back() << "\n";); } - return false; } - return update_front_with_split_with_non_empty_b(e, j, front, a, b); + update_front_with_split_with_non_empty_b(e, j, front, a, b); } - // it returns true if the recursion is done inside + // it returns true if the recursion brings a cross-nested form bool split_with_var(nex& e, lpvar j, vector & front) { TRACE("nla_cn", tout << "e = " << e << ", j=" << ch(j) << "\n";); - if (!e.is_sum()) - return false; + if (!e.is_sum()) return false; nex a, b; - pre_split(e, j, a, b); - return update_front_with_split(e, j, front, a, b); + nex f; + if (extract_common_factor(&a, f, get_mult_occurences(a))) + return false; + update_front_with_split(e, j, front, a, b); + return true; } std::set get_vars_of_expr(const nex &e ) const { std::set r; diff --git a/src/test/lp/lp.cpp b/src/test/lp/lp.cpp index 1396a6167..b3d7603e0 100644 --- a/src/test/lp/lp.cpp +++ b/src/test/lp/lp.cpp @@ -80,17 +80,16 @@ void test_cn() { enable_trace("nla_cn"); // enable_trace("nla_cn_details"); nex a = nex::var(0), b = nex::var(1), c = nex::var(2), d = nex::var(3), e = nex::var(4); - // test_cn_on_expr(a*b + a*c + b*c); + //test_cn_on_expr(a*b + a*c + b*c); //TRACE("nla_cn", tout << "done\n";); test_cn_on_expr(a*a*d + a*b*c*d + a*a*c*c*d + a*d*d + e*a*e + e*a*c + e*d); TRACE("nla_cn", tout << "done\n";); - /* - test_cn_on_expr(a*b*d + a*b*c + c*b*d + a*c*d); - TRACE("nla_cn", tout << "done\n";); - test_cn_on_expr(a*b*b*d*d + a*b*b*c*d + c*b*b*d); - TRACE("nla_cn", tout << "done\n";); - test_cn_on_expr(a*b*d + a*b*c + c*b*d); - */ + test_cn_on_expr(a*b*d + a*b*c + c*b*d + a*c*d); + TRACE("nla_cn", tout << "done\n";); + test_cn_on_expr(a*b*b*d*d + a*b*b*c*d + c*b*b*d); + TRACE("nla_cn", tout << "done\n";); + test_cn_on_expr(a*b*d + a*b*c + c*b*d); + } } // end of namespace nla