From 46f8159926c1edc43985a95831c1fabe4a049ba3 Mon Sep 17 00:00:00 2001 From: Lev Nachmanson Date: Sat, 20 Jul 2019 16:58:00 -0700 Subject: [PATCH] memory management in cross_nested is broken --- src/math/lp/cross_nested.h | 105 ++++++++++++++++++++++++++----------- src/test/lp/lp.cpp | 17 +++--- 2 files changed, 84 insertions(+), 38 deletions(-) diff --git a/src/math/lp/cross_nested.h b/src/math/lp/cross_nested.h index c880b39de..c3771f907 100644 --- a/src/math/lp/cross_nested.h +++ b/src/math/lp/cross_nested.h @@ -51,10 +51,10 @@ public: } }; - bool proceed_with_common_factor(nex* c, vector& front, const std::unordered_map & occurences) { - TRACE("nla_cn", tout << "c=" << *c << "\n";); + static bool extract_common_factor(nex* c, nex& f, const std::unordered_map & occurences) { SASSERT(c->is_sum()); - auto f = nex::mul(); + f.type() = expr_type::MUL; + SASSERT(f.children().empty()); unsigned size = c->children().size(); for(const auto & p : occurences) { if (p.second.m_occs == size) { @@ -64,7 +64,16 @@ public: } } } - if (f.children().size() == 0) return false; + return !f.children().empty(); + } + + bool proceed_with_common_factor(nex* c, vector& front, const std::unordered_map & occurences) { + TRACE("nla_cn", tout << "c=" << *c << "\n";); + SASSERT(c->is_sum()); + nex f; + if (!extract_common_factor(c, f, occurences)) + return false; + *c /= f; f.simplify(); * c = nex::mul(f, *c); @@ -89,7 +98,8 @@ public: TRACE("nla_cn", tout << "restore c=" << *c << ", m_e=" << m_e << "\n";); for (unsigned i = 0; i < front.size(); i++) *(front[i]) = copy_of_front[i]; - TRACE("nla_cn", tout << "restore c=" << *c << ", m_e=" << m_e << "\n";); + TRACE("nla_cn", tout << "restore c=" << *c << "\n";); + TRACE("nla_cn", tout << "m_e=" << m_e << "\n";); } } @@ -113,11 +123,18 @@ public: if(front.empty()) { if (trivial_form) return; - TRACE("nla_cn", tout << "got the cn form: m_e=" << m_e << "\n";); - SASSERT(!can_be_cross_nested_more(m_e)); - auto e_to_report = m_e; - e_to_report.simplify(); - m_call_on_result(e_to_report); + auto e_s = m_e; + e_s.simplify(); + occurences = get_mult_occurences(e_s); + if (occurences.empty()) { + TRACE("nla_cn", tout << "got the cn form: e_s=" << e_s << "\n";); + SASSERT(!can_be_cross_nested_more(e_s)); + m_call_on_result(e_s); + } else { + cross_nested cn(e_s, m_call_on_result); + cn.run(); + } + } else { nex* c = pop_back(front); cross_nested_of_expr_on_front_elem(c, front, trivial_form); @@ -135,8 +152,12 @@ public: split_with_var(*c, j, front); TRACE("nla_cn", tout << "after split c=" << *c << "\nfront="; print_vector_of_ptrs(front, tout) << "\n";); SASSERT(front.size()); - nex* n = pop_back(front); TRACE("nla_cn", tout << "n=" << *n <<"\n";); - cross_nested_of_expr_on_front_elem(n, front, false); // we got a non-trivial_form + if (can_be_cross_nested_more(*c)) { + cross_nested_of_expr_on_front_elem(c, front, false); + } else { + nex* n = pop_back(front); TRACE("nla_cn", tout << "n=" << *n <<"\n";); + cross_nested_of_expr_on_front_elem(n, front, false); // we got a non-trivial_form + } } static void process_var_occurences(lpvar j, std::unordered_map& occurences) { auto it = occurences.find(j); @@ -190,7 +211,7 @@ public: TRACE("nla_cn_details", tout << "e=" << e << "\noccs="; dump_occurences(tout, occurences) << "\n";); return occurences; } - bool can_be_cross_nested_more(const nex& s) const { + static bool can_be_cross_nested_more(const nex& s) { auto e = s; e.simplify(); TRACE("nla_cn_details", tout << "simplified " << e << "\n";); @@ -217,11 +238,8 @@ public: return false; } } - void split_with_var(nex& e, lpvar j, vector & front) { - TRACE("nla_cn_details", tout << "e = " << e << ", j = v" << j << "\n";); - if (!e.is_sum()) - return; - nex a, b; + // all factors of j go to a, the rest to b + static void pre_split(nex &e, lpvar j, nex &a, nex&b) { for (const nex & ce: e.children()) { if ((ce.is_mul() && ce.contains(j)) || (ce.is_var() && ce.var() == j)) { a.add_child(ce / j); @@ -232,14 +250,42 @@ public: a.type() = expr_type::SUM; TRACE("nla_cn_details", tout << "a = " << a << "\n";); SASSERT(a.children().size() >= 2); - + a.simplify(); + if (b.children().size() == 1) { nex t = b.children()[0]; b = t; } else if (b.children().size() > 1) { b.type() = expr_type::SUM; } + } + static void update_front_with_split_with_non_empty_b(nex& e, lpvar j, vector & front, nex& a, nex& b) { + nex f; + bool a_has_f = extract_common_factor(&a, f, get_mult_occurences(a)); + SASSERT(a.is_sum()); + if (a_has_f) { + TRACE("nla_cn", tout << "a=" << a << ", f=" << f << "\n";); + a /= f; + e = nex::sum(nex::mul(nex::var(j), f, a), b); + TRACE("nla_cn", tout << "a = " << a << ", e = " << e << "\n";); + auto& jfa = e.children()[0]; + SASSERT(jfa.size() == 3); + front.push_back(&(jfa.children()[2])); + front.push_back(&e); // e might have its own cross nested form + } else { + TRACE("nla_cn_details", tout << "b = " << b << "\n";); + e = nex::sum(nex::mul(nex::var(j), a), b); + front.push_back(&(e.children()[0].children()[1])); + TRACE("nla_cn_details", tout << "push to front " << e.children()[0].children()[1] << "\n";); + } + if (b.is_sum()) { + front.push_back(&(e.children()[1])); + TRACE("nla_cn_details", tout << "push to front " << e.children()[1] << "\n";); + } + } + + static void update_front_with_split(nex& e, lpvar j, vector & front, nex& a, nex& b) { if (b.is_undef()) { SASSERT(b.children().size() == 0); e = nex(expr_type::MUL); @@ -249,19 +295,18 @@ public: front.push_back(&e.children().back()); TRACE("nla_cn_details", tout << "push to front " << e.children().back() << "\n";); } - } else { - TRACE("nla_cn_details", tout << "b = " << b << "\n";); - e = nex::sum(nex::mul(nex::var(j), a), b); - if (a.is_sum()) { - front.push_back(&(e.children()[0].children()[1])); - TRACE("nla_cn_details", tout << "push to front " << e.children()[0].children()[1] << "\n";); - } - if (b.is_sum()) { - front.push_back(&(e.children()[1])); - TRACE("nla_cn_details", tout << "push to front " << e.children()[1] << "\n";); - } + update_front_with_split_with_non_empty_b(e, j, front, a, b); } + + } + static void split_with_var(nex& e, lpvar j, vector & front) { + TRACE("nla_cn_details", tout << "e = " << e << ", j = v" << j << "\n";); + if (!e.is_sum()) + return; + nex a, b; + pre_split(e, j, a, b); + update_front_with_split(e, j, front, a, b); } std::set get_vars_of_expr(const nex &e ) const { std::set r; diff --git a/src/test/lp/lp.cpp b/src/test/lp/lp.cpp index 8dd71f9ca..a99a851f9 100644 --- a/src/test/lp/lp.cpp +++ b/src/test/lp/lp.cpp @@ -79,17 +79,18 @@ void test_cn() { typedef horner::nex nex; enable_trace("nla_cn"); enable_trace("nla_cn_cn"); + enable_trace("nla_cn_details"); nex a = nex::var(0), b = nex::var(1), c = nex::var(2), d = nex::var(3), e = nex::var(4); - test_cn_on_expr(a*b + a*c + b*c); - TRACE("nla_cn", tout << "done\n";); - /* + // test_cn_on_expr(a*b + a*c + b*c); + //TRACE("nla_cn", tout << "done\n";); test_cn_on_expr(a*a*d + a*b*c*d + a*a*c*c*d + a*d*d + e*a*e + e*a*c + e*d); TRACE("nla_cn", tout << "done\n";); - test_cn_on_expr(a*b*d + a*b*c + c*b*d + a*c*d); - TRACE("nla_cn", tout << "done\n";); - test_cn_on_expr(a*b*b*d*d + a*b*b*c*d + c*b*b*d); - TRACE("nla_cn", tout << "done\n";); - test_cn_on_expr(a*b*d + a*b*c + c*b*d); + /* + test_cn_on_expr(a*b*d + a*b*c + c*b*d + a*c*d); + TRACE("nla_cn", tout << "done\n";); + test_cn_on_expr(a*b*b*d*d + a*b*b*c*d + c*b*b*d); + TRACE("nla_cn", tout << "done\n";); + test_cn_on_expr(a*b*d + a*b*c + c*b*d); */ }