diff --git a/src/math/lp/cross_nested.h b/src/math/lp/cross_nested.h index 959664de8..c880b39de 100644 --- a/src/math/lp/cross_nested.h +++ b/src/math/lp/cross_nested.h @@ -30,7 +30,7 @@ public: void run() { vector front; - cross_nested_of_expr_on_front_elem(&m_e, front); + cross_nested_of_expr_on_front_elem(&m_e, front, true); // true for trivial form - no change } static nex* pop_back(vector& front) { @@ -69,7 +69,7 @@ public: f.simplify(); * c = nex::mul(f, *c); TRACE("nla_cn", tout << "common factor=" << f << ", c=" << *c << "\n";); - cross_nested_of_expr_on_front_elem(&(c->children()[1]), front); + cross_nested_of_expr_on_front_elem(&(c->children()[1]), front, false); return true; } @@ -89,39 +89,54 @@ public: TRACE("nla_cn", tout << "restore c=" << *c << ", m_e=" << m_e << "\n";); for (unsigned i = 0; i < front.size(); i++) *(front[i]) = copy_of_front[i]; + TRACE("nla_cn", tout << "restore c=" << *c << ", m_e=" << m_e << "\n";); } } - - void cross_nested_of_expr_on_front_elem(nex* c, vector& front) { + + static std::ostream& dump_occurences(std::ostream& out, const std::unordered_map& occurences) { + out << "{"; + for(const auto& p: occurences) { + const occ& o = p.second; + out << "(" << (char)('a' + p.first) << "->" << o << ")"; + } + out << "}" << std::endl; + return out; + } + + void cross_nested_of_expr_on_front_elem(nex* c, vector& front, bool trivial_form) { SASSERT(c->is_sum()); auto occurences = get_mult_occurences(*c); - TRACE("nla_cn", tout << "m_e=" << m_e << "\nc=" << *c << "\noccurences="; print_vector(occurences, tout) << "\nfront:"; print_vector_of_ptrs(front, tout) << "\n";); + TRACE("nla_cn", tout << "m_e=" << m_e << "\nc=" << *c << ", c occurences="; + dump_occurences(tout, occurences) << "; front:"; print_vector_of_ptrs(front, tout) << "\ntrivial_form=" << trivial_form << "\n";); if (occurences.empty()) { if(front.empty()) { - TRACE("nla_cn_cn", tout << "got the cn form: m_e=" << m_e << "\n";); + if (trivial_form) + return; + TRACE("nla_cn", tout << "got the cn form: m_e=" << m_e << "\n";); SASSERT(!can_be_cross_nested_more(m_e)); auto e_to_report = m_e; e_to_report.simplify(); - e_to_report.sort(); - m_call_on_result(e_to_report); + m_call_on_result(e_to_report); } else { nex* c = pop_back(front); - cross_nested_of_expr_on_front_elem(c, front); + cross_nested_of_expr_on_front_elem(c, front, trivial_form); } } else { cross_nested_of_expr_on_front_elem_occs(c, front, occurences); } } + static char ch(unsigned j) { + return (char)('a'+j); + } // e is the global expression, c is the sub expressiond which is going to changed from sum to the cross nested form void cross_nested_of_expr_on_sum_and_var(nex* c, lpvar j, vector front) { - TRACE("nla_cn", tout << "m_e=" << m_e << "\nc=" << *c << "\nj = v" << j << "\nfront="; print_vector_of_ptrs(front, tout) << "\n";); + TRACE("nla_cn", tout << "m_e=" << m_e << "\nc=" << *c << "\nj = " << ch(j) << "\nfront="; print_vector_of_ptrs(front, tout) << "\n";); split_with_var(*c, j, front); - TRACE("nla_cn", tout << "after split c=" << *c << "\nfront="; print_vector_of_ptrs(front, tout) << "\n";); - do { - nex* n = pop_back(front); - cross_nested_of_expr_on_front_elem(n, front); - } while (!front.empty()); + TRACE("nla_cn", tout << "after split c=" << *c << "\nfront="; print_vector_of_ptrs(front, tout) << "\n";); + SASSERT(front.size()); + nex* n = pop_back(front); TRACE("nla_cn", tout << "n=" << *n <<"\n";); + cross_nested_of_expr_on_front_elem(n, front, false); // we got a non-trivial_form } static void process_var_occurences(lpvar j, std::unordered_map& occurences) { auto it = occurences.find(j); @@ -131,17 +146,7 @@ public: } else { occurences.insert(std::make_pair(j, occ(1, 1))); } - } - - static void dump_occurences(std::ostream& out, const std::unordered_map& occurences) { - out << "{"; - for(const auto& p: occurences) { - const occ& o = p.second; - out << "(v" << p.first << "->" << o << ")"; - } - out << "}" << std::endl; - } - + } static void update_occurences_with_powers(std::unordered_map& occurences, const std::unordered_map& powers) { @@ -156,6 +161,7 @@ public: it->second.m_power = std::min(it->second.m_power, jp); } } + TRACE("nla_cn_details", tout << "occs="; dump_occurences(tout, occurences) << "\n";); } static void remove_singular_occurences(std::unordered_map& occurences) { @@ -173,7 +179,6 @@ public: std::unordered_map occurences; SASSERT(e.type() == expr_type::SUM); for (const auto & ce : e.children()) { - std::unordered_set seen; if (ce.is_mul()) { auto powers = ce.get_powers_from_mul(); update_occurences_with_powers(occurences, powers); @@ -182,16 +187,20 @@ public: } } remove_singular_occurences(occurences); - TRACE("nla_cn_details", dump_occurences(tout, occurences);); + TRACE("nla_cn_details", tout << "e=" << e << "\noccs="; dump_occurences(tout, occurences) << "\n";); return occurences; } - bool can_be_cross_nested_more(const nex& e) const { + bool can_be_cross_nested_more(const nex& s) const { + auto e = s; + e.simplify(); + TRACE("nla_cn_details", tout << "simplified " << e << "\n";); switch (e.type()) { case expr_type::SCALAR: return false; - case expr_type::SUM: { - return !get_mult_occurences(e).empty(); - } + case expr_type::SUM: + if ( !get_mult_occurences(e).empty()) + return true; + // fall through MUL case expr_type::MUL: { for (const auto & c: e.children()) { diff --git a/src/math/lp/horner.cpp b/src/math/lp/horner.cpp index e8ac3fa20..e084dfa5f 100644 --- a/src/math/lp/horner.cpp +++ b/src/math/lp/horner.cpp @@ -40,14 +40,11 @@ bool horner::row_is_interesting(const T& row) const { void horner::lemmas_on_expr(nex& e) { TRACE("nla_cn", tout << "e = " << e << "\n";); - TRACE("nla_cn_cn", tout << "e = " << e << "\n";); cross_nested cn(e, [this](const nex& n) { TRACE("nla_cn", tout << "callback n = " << n << "\n";); auto i = interval_of_expr(n); m_intervals.check_interval_for_conflict_on_zero(i);} ); cn.run(); - TRACE("nla_cn", tout << "lemmas_on_expr done\n";); - } @@ -89,8 +86,8 @@ template nex horner::create_sum_from_row(const T& row) { TRACE("nla_cn", tout << "row="; m_core->print_term(row, tout) << "\n";); SASSERT(row.size() > 1); nex e(expr_type::SUM); - for (const auto &p : row) { - e.add_child(nex::mul(p.coeff(), nexvar(p.var()))); + for (const auto &p : row) { + e.add_child(nex::scalar(p.coeff())* nexvar(p.var())); } return e; } @@ -126,8 +123,9 @@ interv horner::interval_of_expr(const nex& e) { return interv(); } } - -interv horner::interval_of_mul(const std::vector& es) { +template +interv horner::interval_of_mul(const V& es) { + SASSERT(es.size()); interv a = interval_of_expr(es[0]); // std::cout << "a" << std::endl; TRACE("nla_cn_details", tout << "es[0]= "<< es[0] << std::endl << "a = "; m_intervals.display(tout, a); tout << "\n";); @@ -153,7 +151,8 @@ interv horner::interval_of_mul(const std::vector& es) { return a; } -interv horner::interval_of_sum(const std::vector& es) { +template +interv horner::interval_of_sum(const V& es) { interv a = interval_of_expr(es[0]); TRACE("nla_cn_details", tout << "es[0]= " << es[0] << "\n"; m_intervals.display(tout, a) << "\n";); if (m_intervals.is_inf(a)) { diff --git a/src/math/lp/horner.h b/src/math/lp/horner.h index 0eb5c88a2..a74978c5e 100644 --- a/src/math/lp/horner.h +++ b/src/math/lp/horner.h @@ -41,8 +41,10 @@ public: intervals::interval interval_of_expr(const nex& e); nex nexvar(lpvar j) const; - intervals::interval interval_of_sum(const std::vector&); - intervals::interval interval_of_mul(const std::vector&); + template // V is a vector of expressions + intervals::interval interval_of_sum(const V&); + template // V is a vector of expressions + intervals::interval interval_of_mul(const V&); void set_interval_for_scalar(intervals::interval&, const rational&); void set_var_interval(lpvar j, intervals::interval&); std::set get_vars_of_expr(const nex &) const; diff --git a/src/math/lp/nla_expr.h b/src/math/lp/nla_expr.h index 146214c1e..4ad7ba8f8 100644 --- a/src/math/lp/nla_expr.h +++ b/src/math/lp/nla_expr.h @@ -49,56 +49,11 @@ inline std::ostream & operator<<(std::ostream& out, expr_type t) { // This class is needed in horner calculation with intervals template class nla_expr { - - class sorted_children { - std::vector m_es; - // m_order will be sorted according to the non-decreasing order of m_es - svector m_order; - public: - const std::vector& es() const { return m_es; } - std::vector& es() { return m_es; } - void push_back(const nla_expr& e) { - SASSERT(m_es.size() == m_order.size()); - m_order.push_back(m_es.size()); - m_es.push_back(e); - } - const svector& order() const { return m_order; } - const nla_expr& back() const { return m_es.back(); } - nla_expr& back() { return m_es.back(); } - const nla_expr* begin() const { return m_es.begin(); } - const nla_expr* end() const { return m_es.end(); } - typename std::vector::iterator begin() { return m_es.begin(); } - typename std::vector::iterator end() { return m_es.end(); } - unsigned size() const { return m_es.size(); } - void sort() { - std::sort(m_order.begin(), m_order.end(), [this](unsigned i, unsigned j) { return m_es[i] < m_es[j]; }); - } - bool operator<(const sorted_children& e) const { - return compare(e) < 0; - } - - int compare(const sorted_children& e) const { - unsigned m = std::min(size(), e.size()); - for (unsigned j = 0; j < m; j++) { - int r = m_es[m_order[j]].compare(e.m_es[e.m_order[j]]); - TRACE("nla_cn_details", tout << "r=" << r << "\n";); - if (r) - return r; - } - return static_cast(size()) - static_cast(e.size()); - } - void reset_order() { - m_order.clear(); - for( unsigned i = 0; i < m_es.size(); i++) - m_order.push_back(i); - } - }; - // todo: use union expr_type m_type; lpvar m_j; T m_v; // for the scalar - sorted_children m_children; + vector m_children; public: bool is_sum() const { return m_type == expr_type::SUM; } bool is_var() const { return m_type == expr_type::VAR; } @@ -108,16 +63,13 @@ public: lpvar var() const { SASSERT(m_type == expr_type::VAR); return m_j; } expr_type type() const { return m_type; } expr_type& type() { return m_type; } - const std::vector& children() const { return m_children.es(); } - std::vector& children() { return m_children.es(); } - const sorted_children & s_children() const { return m_children; } - sorted_children & s_children() { return m_children; } + const vector& children() const { return m_children; } + vector& children() { return m_children; } const T& value() const { SASSERT(m_type == expr_type::SCALAR); return m_v; } std::string str() const { std::stringstream ss; ss << *this; return ss.str(); } std::ostream & print_sum(std::ostream& out) const { bool first = true; - for (unsigned j : m_children.order()) { - const nla_expr& v = m_children.es()[j]; + for (const nla_expr& v : m_children) { std::string s = v.str(); if (first) { first = false; @@ -139,14 +91,6 @@ public: } return out; } - void sort () { - if (is_sum() || is_mul()) { - for (auto & e : m_children.es()) { - e.sort(); - } - m_children.sort(); - } - } void simplify() { if (is_simple()) return; @@ -158,7 +102,7 @@ public: } if (has_sum) { nla_expr n(expr_type::SUM); - for (auto &e : m_children.es()) { + for (auto &e : m_children) { n += e; } *this = n; @@ -171,19 +115,18 @@ public: } if (has_mul) { nla_expr n(expr_type::MUL); - for (auto &e : m_children.es()) { + for (auto &e : m_children) { n *= e; } *this = n; } - TRACE("nla_cn", tout << "simplified " << *this << "\n";); + TRACE("nla_cn_details", tout << "simplified " << *this << "\n";); } } std::ostream & print_mul(std::ostream& out) const { bool first = true; - for (unsigned j : m_children.order()) { - const nla_expr& v = m_children.es()[j]; + for (const nla_expr& v : m_children) { std::string s = v.str(); if (first) { first = false; @@ -212,7 +155,7 @@ public: case expr_type::MUL: return print_mul(out); case expr_type::VAR: - out << 'v' << m_j; + out << (char)('a'+ m_j); return out; case expr_type::SCALAR: out << m_v; @@ -320,48 +263,6 @@ public: return false; } - int compare(const nla_expr& e) const { - TRACE("nla_cn_details", tout << "this="<<*this<<", e=" << e << "\n";); - if (type() != e.type()) - return (int)type() - (int)(e.type()); - - switch(m_type) { - case expr_type::SUM: - case expr_type::MUL: - return m_children.compare(e.m_children); - - case expr_type::VAR: - return static_cast(m_j) - static_cast(e.m_j); - case expr_type::SCALAR: - return m_v < e.m_v? -1 : (m_v == e.m_v? 0 : 1); - default: - SASSERT(false); - return 0; - } - } - - bool operator<(const nla_expr& e) const { - TRACE("nla_cn_details", tout << "this=" << *this << ", e=" << e << "\n";); - if (type() != (e.type())) - return (int)type() < (int)(e.type()); - - SASSERT(type() == (e.type())); - - switch(m_type) { - case expr_type::SUM: - case expr_type::MUL: - return m_children < e.m_children; - - case expr_type::VAR: - return m_j < e.m_j; - case expr_type::SCALAR: - return m_v < e.m_v; - default: - SASSERT(false); - return false; - } - } - nla_expr& operator*=(const nla_expr& b) { if (is_mul()) { if (b.is_mul()) { @@ -391,6 +292,7 @@ public: it->second++; } } + TRACE("nla_cn_details", tout << "powers of " << *this << "\n"; print_vector(r, tout)<< "\n";); return r; } @@ -436,8 +338,6 @@ public: while(k ++ < i) children().pop_back(); - - s_children().reset_order(); return *this; } @@ -458,11 +358,54 @@ public: } }; + +/* +nla_expr operator/=(const nla_expr &a, const nla_expr& b) { + TRACE("nla_cn_details", tout << a <<" / " << b << "\n";); + if (b.is_var()) { + return a / b.var(); + } + SASSERT(b.is_mul()); + if (a.is_sum()) { + auto r = nex::sum(); + for (auto & e : a.children()) { + r.add_child(e/b); + } + return r; + } + if (is_var()) { + return scalar(T(1)); + return *this; + } + SASSERT(a.is_mul()); + auto powers = b.get_powers_from_mul(); + auto r=nex::mul(); + for (unsigned i = 0; i < a.children().size(); i++, k++) { + auto & e = children()[i]; + if (!e.is_var()) { + SASSERT(e.is_scalar()); + r.add_child(e); + continue; + } + lpvar j = e.var(); + auto it = powers.find(j); + if (it == powers.end()) { + r.add_child(e); + } else { + it->second --; // finish h + if (it->second == 0) + powers.erase(it); + } + } + + return r; + } +*/ template nla_expr operator+(const nla_expr& a, const nla_expr& b) { if (a.is_sum()) { nla_expr r(expr_type::SUM); - r.s_children() = a.s_children(); + r.children() = a.children(); if (b.is_sum()) { for (auto& e: b.children()) r.add_child(e); @@ -473,7 +416,7 @@ nla_expr operator+(const nla_expr& a, const nla_expr& b) { } if (b.is_sum()) { nla_expr r(expr_type::SUM); - r.s_children() = b.s_children(); + r.children() = b.children(); r.add_child(a); return r; } @@ -482,9 +425,13 @@ nla_expr operator+(const nla_expr& a, const nla_expr& b) { template nla_expr operator*(const nla_expr& a, const nla_expr& b) { + if (a.is_scalar() && a.value() == T(1)) + return b; + if (b.is_scalar() && b.value() == T(1)) + return a; if (a.is_mul()) { nla_expr r(expr_type::MUL); - r.s_children() = a.s_children(); + r.children() = a.children(); if (b.is_mul()) { for (auto& e: b.children()) r.add_child(e); @@ -495,7 +442,7 @@ nla_expr operator*(const nla_expr& a, const nla_expr& b) { } if (b.is_mul()) { nla_expr r(expr_type::MUL); - r.s_children() = b.s_children(); + r.children() = b.children(); r.add_child(a); return r; } diff --git a/src/math/lp/nla_intervals.cpp b/src/math/lp/nla_intervals.cpp index 3d23ed6b8..045a52eac 100644 --- a/src/math/lp/nla_intervals.cpp +++ b/src/math/lp/nla_intervals.cpp @@ -47,7 +47,7 @@ bool intervals::check_interval_for_conflict_on_zero_upper(const interval & i) { svector expl; m_dep_manager.linearize(i.m_upper_dep, expl); _().current_expl().add_expl(expl); - TRACE("nla_cn_lemmas", print_lemma(tout);); + TRACE("nla_cn", print_lemma(tout);); return true; } diff --git a/src/test/lp/lp.cpp b/src/test/lp/lp.cpp index 6d1a8e10d..8dd71f9ca 100644 --- a/src/test/lp/lp.cpp +++ b/src/test/lp/lp.cpp @@ -70,7 +70,7 @@ void test_basic_lemma_for_mon_neutral_from_factors_to_monomial(); void test_cn_on_expr(horner::nex t) { TRACE("nla_cn", tout << "t=" << t << '\n';); cross_nested cn(t, [](const horner::nex& n) { - TRACE("nla_cn", tout << n << "\n";); + TRACE("nla_cn_test", tout << n << "\n";); } ); cn.run(); } @@ -78,16 +78,19 @@ void test_cn_on_expr(horner::nex t) { void test_cn() { typedef horner::nex nex; enable_trace("nla_cn"); - enable_trace("nla_cn_details"); enable_trace("nla_cn_cn"); nex a = nex::var(0), b = nex::var(1), c = nex::var(2), d = nex::var(3), e = nex::var(4); - + test_cn_on_expr(a*b + a*c + b*c); + TRACE("nla_cn", tout << "done\n";); + /* + test_cn_on_expr(a*a*d + a*b*c*d + a*a*c*c*d + a*d*d + e*a*e + e*a*c + e*d); + TRACE("nla_cn", tout << "done\n";); test_cn_on_expr(a*b*d + a*b*c + c*b*d + a*c*d); + TRACE("nla_cn", tout << "done\n";); test_cn_on_expr(a*b*b*d*d + a*b*b*c*d + c*b*b*d); TRACE("nla_cn", tout << "done\n";); test_cn_on_expr(a*b*d + a*b*c + c*b*d); - nex t = a*a*d + a*b*c*d + a*a*c*c*d + a*d*d + e*a*e + e*a*c + e*d; - test_cn_on_expr(t); + */ } } // end of namespace nla