From 5e749045e133ef6d0c977ba0b76ea6fa10d95cad Mon Sep 17 00:00:00 2001 From: Lev Nachmanson Date: Mon, 30 Sep 2019 13:50:18 -0700 Subject: [PATCH] pass simplify expession test Signed-off-by: Lev Nachmanson --- src/math/lp/nex.h | 28 +------------ src/math/lp/nex_creator.cpp | 84 ++++++++++++++++++++++++++++++++++--- src/math/lp/nex_creator.h | 4 +- src/test/lp/lp.cpp | 19 ++++++--- 4 files changed, 95 insertions(+), 40 deletions(-) diff --git a/src/math/lp/nex.h b/src/math/lp/nex.h index ce0a42ef1..71921e690 100644 --- a/src/math/lp/nex.h +++ b/src/math/lp/nex.h @@ -86,12 +86,7 @@ public: #endif bool virtual is_linear() const = 0; }; -#if Z3DEBUG -bool operator<(const nex& a , const nex& b); -inline bool operator ==(const nex& a , const nex& b) { - return ! (a < b || b < a) ; -} -#endif + std::ostream& operator<<(std::ostream& out, const nex&); class nex_var : public nex { @@ -228,7 +223,6 @@ public: } bool is_linear() const { - // SASSERT(is_simplified()); return get_degree() < 2; // todo: make it more efficient } @@ -374,25 +368,5 @@ inline bool less_than_nex_standard(const nex* a, const nex* b) { lt_on_vars lt = [](lpvar j, lpvar k) { return j < k; }; return less_than_nex(a, b, lt); } - -#if Z3DEBUG - -inline bool operator<(const ptr_vector&a , const ptr_vector& b) { - int r = (int)a.size() - (int)b.size(); - if (r) - return r < 0; - for (unsigned j = 0; j < a.size(); j++) { - if (*a[j] < *b[j]) { - return true; - } - if (*b[j] < *a[j]) { - return false; - } - } - return false; -} - -#endif - } diff --git a/src/math/lp/nex_creator.cpp b/src/math/lp/nex_creator.cpp index b6f141c52..31a60c184 100644 --- a/src/math/lp/nex_creator.cpp +++ b/src/math/lp/nex_creator.cpp @@ -116,12 +116,82 @@ void nex_creator::simplify_children_of_mul(vector & children) { TRACE("nla_cn_details", print_vector(children, tout);); } -bool nex_creator::less_than_on_mul(const nex_mul* a, const nex_mul* b) { - NOT_IMPLEMENTED_YET(); +bool nex_creator::less_than_on_mul(const nex_mul* a, const nex_mul* b, bool skip_scalar) { + // the scalar, if it is there, is at the beginning of the children() + TRACE("nla_cn_details", tout << "a = " << *a << ", b = " << *b << ", skip_scalar = " << skip_scalar << "\n";); + SASSERT(is_simplified(a) && is_simplified(b)); + unsigned a_deg = a->get_degree(); + unsigned b_deg = b->get_degree(); + if (a_deg > b_deg) + return true; + if (a_deg < b_deg) + return false; + auto it_a = a->children().begin(); + if (skip_scalar && it_a->e()->is_scalar()) + it_a ++; + auto it_b = b->children().begin(); + if (skip_scalar && it_b->e()->is_scalar()) + it_b ++; + auto a_end = a->children().end(); + auto b_end = b->children().end(); + unsigned a_pow, b_pow; + bool inside_a_p = false; // inside_a_p is true means we still compare the old position of it_a + bool inside_b_p = false; // inside_b_p is true means we still compare the old position of it_b + const nex* ae = nullptr; + const nex *be = nullptr; + if (it_a == a_end) { + return it_b != b_end; + } + if (it_b == b_end) + return false; + for (; ;) { + if (!inside_a_p) { + ae = it_a->e(); + a_pow = it_a->pow(); + } + if (!inside_b_p) { + be = it_b->e(); + b_pow = it_b->pow(); + } + + if (lt(ae, be, skip_scalar)) + return true; + if (lt(be, ae, skip_scalar)) + return false; + if (a_pow == b_pow) { + inside_a_p = inside_b_p = false; + it_a++; it_b++; + if (it_a == a_end) { + return it_b != b_end; + } else if (it_b == b_end) { + return true; + } + // no iterator reached the end + continue; + } + if (a_pow < b_pow) { + inside_a_p = false; + inside_b_p = true; + b_pow -= a_pow; + it_a++; + if (it_a == a_end) + return true; + } else { + inside_a_p = true; + inside_b_p = false; + SASSERT(b_pow < a_pow); + it_b++; + if (it_b == b_end) + return false; + } + } + return false; + } -bool nex_creator::sum_simplify_lt(const nex* a, const nex* b) { +bool nex_creator::lt(const nex* a, const nex* b, bool skip_scalar) { + TRACE("nla_cn_details", tout << "a = " << *a << ", b = " << *b << ", skip_scalar = " << skip_scalar << "\n";); int r = (int)(a->type()) - (int)(b->type()); if (r) { return r < 0; @@ -135,7 +205,7 @@ bool nex_creator::sum_simplify_lt(const nex* a, const nex* b) { return to_scalar(a)->value() < to_scalar(b)->value(); } case expr_type::MUL: { - return less_than_on_mul(to_mul(a), to_mul(b)); + return less_than_on_mul(to_mul(a), to_mul(b), skip_scalar); } case expr_type::SUM: { UNREACHABLE(); @@ -302,8 +372,10 @@ nex* nex_creator::create_child_from_nex_and_coeff(nex *e, // a + 3bc + 2bc => a + 5bc void nex_creator::sort_join_sum(ptr_vector & children) { std::map m([this](const nex *a , const nex *b) - { return sum_simplify_lt(a, b); }); + { return lt(a, b, true); }); + TRACE("nla_cn_details", print_vector_of_ptrs(children, tout);); fill_map_with_children(m, children); + TRACE("nla_cn_details", for (auto & p : m ) { tout << "(" << *p.first << ", " << p.second << ") ";}); children.clear(); for (auto& p : m) { children.push_back(create_child_from_nex_and_coeff(p.first, p.second)); @@ -357,6 +429,7 @@ bool is_zero_scalar(nex *e) { } void nex_creator::simplify_children_of_sum(ptr_vector & children) { + TRACE("nla_cn_details", print_vector_of_ptrs(children, tout);); ptr_vector to_promote; int skipped = 0; for(unsigned j = 0; j < children.size(); j++) { @@ -374,6 +447,7 @@ void nex_creator::simplify_children_of_sum(ptr_vector & children) { } } + TRACE("nla_cn_details", print_vector_of_ptrs(children, tout);); children.shrink(children.size() - to_promote.size() - skipped); for (nex *e : to_promote) { diff --git a/src/math/lp/nex_creator.h b/src/math/lp/nex_creator.h index adca07e4b..df4d62eb8 100644 --- a/src/math/lp/nex_creator.h +++ b/src/math/lp/nex_creator.h @@ -223,9 +223,9 @@ public: bool eat_scalar_pow(nex_scalar *& r, nex_pow& p); void simplify_children_of_mul(vector & children, lt_on_vars lt, std::function mk_scalar); - bool sum_simplify_lt(const nex* a, const nex* b); + bool lt(const nex* a, const nex* b, bool skip_scalar); - bool less_than_on_mul(const nex_mul* a, const nex_mul* b); + bool less_than_on_mul(const nex_mul* a, const nex_mul* b, bool skip_scalar); void fill_map_with_children(std::map & m, ptr_vector & children); }; } diff --git a/src/test/lp/lp.cpp b/src/test/lp/lp.cpp index 7faaa8da5..5e365ce7a 100644 --- a/src/test/lp/lp.cpp +++ b/src/test/lp/lp.cpp @@ -104,14 +104,21 @@ void test_simplify() { n->add_child_in_power(r.mk_scalar(rational(1, 3)), 2); TRACE("nla_cn", tout << "m = " << *m << "\n";); - nex * e = r.mk_sum(a, r.mk_sum(b, m)); + nex_sum * e = r.mk_sum(a, r.mk_sum(b, m)); TRACE("nla_cn", tout << "e = " << *e << "\n";); - e = r.simplify(e); + e = to_sum(r.simplify(e)); TRACE("nla_cn", tout << "simplified e = " << *e << "\n";); - nex * l = r.mk_sum(e, r.mk_mul(r.mk_scalar(rational(3)), r.clone(e))); - TRACE("nla_cn", tout << "sum l = " << *l << "\n";); - l = r.simplify(l); - TRACE("nla_cn", tout << "simplified sum l = " << *l << "\n";); + nex_sum * e_m = r.mk_sum(); + for (const nex* ex: to_sum(e)->children()) { + nex* ce = r.mk_mul(r.clone(ex), r.mk_scalar(rational(3))); + ce = r.simplify(ce); + TRACE("nla_cn", tout << "simplified ce = " << *ce << "\n";); + e_m->add_child(ce); + } + e->add_child(e_m); + TRACE("nla_cn", tout << "before simplify sum e = " << *e << "\n";); + e = to_sum(r.simplify(e)); + TRACE("nla_cn", tout << "simplified sum e = " << *e << "\n";); } void test_cn() {