From 8ed865e447fab4b7aa6d465c3253acb185f124f3 Mon Sep 17 00:00:00 2001 From: Lev Nachmanson Date: Tue, 16 Jul 2019 16:50:44 -0700 Subject: [PATCH] add compare() to nla_expr Signed-off-by: Lev Nachmanson --- src/math/lp/cross_nested.h | 12 +++++++++-- src/math/lp/horner.cpp | 9 ++++++--- src/math/lp/nla_expr.h | 41 ++++++++++++++++++++++++++++++++++++-- src/test/lp/lp.cpp | 7 ++++--- 4 files changed, 59 insertions(+), 10 deletions(-) diff --git a/src/math/lp/cross_nested.h b/src/math/lp/cross_nested.h index 07ea0c1a6..1ef4832f6 100644 --- a/src/math/lp/cross_nested.h +++ b/src/math/lp/cross_nested.h @@ -25,6 +25,7 @@ class cross_nested { typedef nla_expr nex; nex& m_e; std::function m_call_on_result; + std::set m_reported; public: cross_nested(nex &e, std::function call_on_result): m_e(e), m_call_on_result(call_on_result) {} @@ -48,7 +49,15 @@ public: if(front.empty()) { TRACE("nla_cn_cn", tout << "got the cn form: m_e=" << m_e << "\n";); SASSERT(!can_be_cross_nested_more(m_e)); - m_call_on_result(m_e); + auto e_to_report = m_e; + e_to_report.simplify(); + e_to_report.sort(); + if (m_reported.find(e_to_report) == m_reported.end()) { + m_reported.insert(e_to_report); + m_call_on_result(e_to_report); + } else { + TRACE("nla_cn", tout << "do not report " << e_to_report << "\n";); + } } else { nex* c = pop_back(front); cross_nested_of_expr_on_front_elem(c, front); @@ -67,7 +76,6 @@ public: *(front[i]) = copy_of_front[i]; } } - TRACE("nla_cn", tout << "exit\n";); } // e is the global expression, c is the sub expressiond which is going to changed from sum to the cross nested form void cross_nested_of_expr_on_sum_and_var(nex* c, lpvar j, vector front) { diff --git a/src/math/lp/horner.cpp b/src/math/lp/horner.cpp index 3b24ab5cb..e8ac3fa20 100644 --- a/src/math/lp/horner.cpp +++ b/src/math/lp/horner.cpp @@ -42,9 +42,12 @@ void horner::lemmas_on_expr(nex& e) { TRACE("nla_cn", tout << "e = " << e << "\n";); TRACE("nla_cn_cn", tout << "e = " << e << "\n";); cross_nested cn(e, [this](const nex& n) { - auto i = interval_of_expr(n); - m_intervals.check_interval_for_conflict_on_zero(i);} ); - cn.run(); + TRACE("nla_cn", tout << "callback n = " << n << "\n";); + auto i = interval_of_expr(n); + m_intervals.check_interval_for_conflict_on_zero(i);} ); + cn.run(); + TRACE("nla_cn", tout << "lemmas_on_expr done\n";); + } diff --git a/src/math/lp/nla_expr.h b/src/math/lp/nla_expr.h index 31bc579af..c14cfcf26 100644 --- a/src/math/lp/nla_expr.h +++ b/src/math/lp/nla_expr.h @@ -20,7 +20,7 @@ #include #include "math/lp/nla_defs.h" namespace nla { -enum class expr_type { SUM, MUL, VAR, SCALAR, UNDEF }; +enum class expr_type { VAR, SUM, MUL, SCALAR, UNDEF }; inline std::ostream & operator<<(std::ostream& out, expr_type t) { switch (t) { case expr_type::SUM: @@ -73,6 +73,21 @@ class nla_expr { void sort() { std::sort(m_order.begin(), m_order.end(), [this](unsigned i, unsigned j) { return m_es[i] < m_es[j]; }); } + bool operator<(const sorted_children& e) const { + return compare(e) == -1; + } + + int compare(const sorted_children& e) const { + unsigned m = std::min(size(), e.size()); + for (unsigned j = 0; j < m; j++) { + int r = m_es[m_order[j]].compare(e.m_es[e.m_order[j]]); + if (r == -1) + return true; + if (r == 1) + return false; + } + return size() < e.size(); + } }; // todo: use union @@ -295,6 +310,27 @@ public: } return false; } + + int compare(const nla_expr& e) const { + if (type() != (e.type())) + return (int)type() - (int)(e.type()); + SASSERT(type() == (e.type())); + + switch(m_type) { + case expr_type::SUM: + case expr_type::MUL: + return m_children.compare(e.m_children); + + case expr_type::VAR: + return m_j - e.m_j; + case expr_type::SCALAR: + return m_v < e.m_v? -1 : (m_v == e.m_v? 0 : 1); + default: + SASSERT(false); + return 0; + } + } + bool operator<(const nla_expr& e) const { if (type() != (e.type())) return (int)type() < (int)(e.type()); @@ -304,7 +340,8 @@ public: switch(m_type) { case expr_type::SUM: case expr_type::MUL: - return m_children.es() < e.m_children.es(); + return m_children < e.m_children; + case expr_type::VAR: return m_j < e.m_j; case expr_type::SCALAR: diff --git a/src/test/lp/lp.cpp b/src/test/lp/lp.cpp index 12c7f61ce..485dab113 100644 --- a/src/test/lp/lp.cpp +++ b/src/test/lp/lp.cpp @@ -84,11 +84,12 @@ void test_cn() { typedef horner::nex nex; enable_trace("nla_cn"); enable_trace("nla_cn_details"); + enable_trace("nla_cn_cn"); nex a = nex::var(0), b = nex::var(1), c = nex::var(2), d = nex::var(3), e = nex::var(4); - nex t = a*a*d + a*b*c*d + a*a*c*c*d + a*d*d + e*a*e + e*a*c + e*d; - test_cn_on_expr(t); - test_cn_on_expr(a*b*d + a*b*c); + test_cn_on_expr(a*b*d + a*b*c + c*b*d); + // nex t = a*a*d + a*b*c*d + a*a*c*c*d + a*d*d + e*a*e + e*a*c + e*d; + // test_cn_on_expr(t); } } // end of namespace nla