From 4e2cd2c8de51b501ec78a77e550dbbbe1c2be104 Mon Sep 17 00:00:00 2001 From: Lev Nachmanson Date: Wed, 25 Sep 2019 16:46:50 -0700 Subject: [PATCH] sort expressions by power Signed-off-by: Lev Nachmanson --- src/math/lp/cross_nested.h | 43 ++------------- src/math/lp/nex.cpp | 26 ++++++++- src/math/lp/nex.h | 107 +++++++++++++++++++++++++++++-------- src/math/lp/nex_creator.h | 75 ++++++++++++++++++++------ src/math/lp/nla_grobner.h | 2 +- src/test/lp/lp.cpp | 32 +++++++++-- 6 files changed, 202 insertions(+), 83 deletions(-) diff --git a/src/math/lp/cross_nested.h b/src/math/lp/cross_nested.h index 1cd530944..adf592820 100644 --- a/src/math/lp/cross_nested.h +++ b/src/math/lp/cross_nested.h @@ -35,7 +35,7 @@ class cross_nested { int m_reported; bool m_random_bit; nex_creator m_nex_creator; - std::function m_lt; + nex_lt m_lt; #ifdef Z3DEBUG nex* m_e_clone; @@ -47,18 +47,19 @@ public: cross_nested(std::function call_on_result, std::function var_is_fixed, std::function random, - std::function lt): + nex_lt lt): m_call_on_result(call_on_result), m_var_is_fixed(var_is_fixed), m_random(random), m_done(false), m_reported(0), - m_nex_creator(lt) {} + m_nex_creator(lt), + m_lt(lt) {} void run(nex *e) { TRACE("nla_cn", tout << *e << "\n";); - SASSERT(e->is_simplified()); + SASSERT(e->is_simplified(m_lt)); m_e = e; #ifdef Z3DEBUG // m_e_clone = clone(m_e); @@ -482,40 +483,6 @@ public: bool done() const { return m_done; } #if Z3DEBUG - nex *clone (const nex * a) { - switch (a->type()) { - case expr_type::VAR: { - auto v = to_var(a); - return m_nex_creator.mk_var(v->var()); - } - - case expr_type::SCALAR: { - auto v = to_scalar(a); - return m_nex_creator.mk_scalar(v->value()); - } - case expr_type::MUL: { - auto m = to_mul(a); - auto r = m_nex_creator.mk_mul(); - for (const auto& p : m->children()) { - r->add_child_in_power(clone(p.e()), p.pow()); - } - return r; - } - case expr_type::SUM: { - auto m = to_sum(a); - auto r = m_nex_creator.mk_sum(); - for (nex * e : m->children()) { - r->add_child(clone(e)); - } - return r; - } - default: - SASSERT(false); - break; - } - return nullptr; - } - nex * normalize_sum(nex_sum* a) { for (unsigned j = 0; j < a->size(); j ++) { a->children()[j] = normalize(a->children()[j]); diff --git a/src/math/lp/nex.cpp b/src/math/lp/nex.cpp index ac7e51ed3..8481f5ddc 100644 --- a/src/math/lp/nex.cpp +++ b/src/math/lp/nex.cpp @@ -18,6 +18,7 @@ --*/ #include "math/lp/nex.h" +#include namespace nla { @@ -32,9 +33,28 @@ bool ignored_child(nex* e, expr_type t) { return false; } +void mul_to_powers(vector& children, nex_lt lt) { + std::map m(lt); + for (auto & p : children) { + auto it = m.find(p.e()); + if (it == m.end()) { + m[p.e()] = p.pow(); + } else { + it->second+= p.pow(); + } + } + children.clear(); + for (auto & p : m) { + children.push_back(nex_pow(p.first, p.second)); + } -void promote_children_of_sum(ptr_vector & children,std::function lt ) { + std::sort(children.begin(), children.end(), [lt](const nex_pow& a, const nex_pow& b) { + return less_than(a, b, lt); + }); +} + +void promote_children_of_sum(ptr_vector & children, nex_lt lt ) { ptr_vector to_promote; int skipped = 0; for(unsigned j = 0; j < children.size(); j++) { @@ -63,7 +83,7 @@ void promote_children_of_sum(ptr_vector & children,std::function & children, std::function lt) { +void promote_children_of_mul(vector & children, nex_lt lt) { TRACE("nla_cn_details", print_vector(children, tout);); vector to_promote; int skipped = 0; @@ -92,6 +112,8 @@ void promote_children_of_mul(vector & children, std::function #include "math/lp/nla_defs.h" #include +#include namespace nla { -enum class expr_type { VAR, SUM, MUL, SCALAR, UNDEF }; +class nex; +typedef std::function nex_lt; + +typedef std::function lt_on_vars; + +enum class expr_type { SCALAR, VAR, SUM, MUL, UNDEF }; inline std::ostream & operator<<(std::ostream& out, expr_type t) { switch (t) { case expr_type::SUM: @@ -72,11 +78,14 @@ public: virtual bool contains(lpvar j) const { return false; } virtual int get_degree() const = 0; // simplifies the expression and also assigns the address of "this" to *e - virtual void simplify(nex** e, std::function lt) { *e = this; } + virtual void simplify(nex** e, nex_lt) { *e = this; } void simplify(nex** e) { return simplify(e, less_than_nex_standard); } - virtual bool is_simplified() const { + virtual bool is_simplified(nex_lt) const { return true; } + + virtual bool is_simplified() const { return is_simplified(less_than_nex_standard); } + #ifdef Z3DEBUG virtual void sort() {}; #endif @@ -130,9 +139,9 @@ const nex_scalar * to_scalar(const nex* a); class nex_sum; const nex_sum* to_sum(const nex*a); -void promote_children_of_sum(ptr_vector & children, std::function); +void promote_children_of_sum(ptr_vector & children, nex_lt); class nex_pow; -void promote_children_of_mul(vector & children, std::function lt); +void promote_children_of_mul(vector & children, nex_lt); class nex_pow { nex* m_e; @@ -145,11 +154,24 @@ public: nex ** ee() { return & m_e; } int pow() const { return m_power; } int& pow() { return m_power; } - std::string to_string() const { std::stringstream s; s << "(" << *e() << ", " << pow() << ")"; - return s.str(); } + std::string to_string() const { + std::stringstream s; + if (pow() == 1) { + s <<"(" << *e() << ")"; + } else { + s << "(" << *e() << "^" << pow() << ")"; + } + return s.str(); + } friend std::ostream& operator<<(std::ostream& out, const nex_pow & p) { out << p.to_string(); return out; } }; + +inline bool less_than(const nex_pow & a, const nex_pow& b, nex_lt lt) { + return (a.pow() < b.pow()) || (a.pow() == b.pow() && lt(a.e(), b.e())); +} + + class nex_mul : public nex { vector m_children; public: @@ -216,7 +238,7 @@ public: return degree; } // the second argument is the comparison less than operator - void simplify(nex **e, std::function lt) { + void simplify(nex **e, nex_lt lt) { TRACE("nla_cn_details", tout << *this << "\n";); TRACE("nla_cn_details", tout << "**e = " << **e << "\n";); *e = this; @@ -225,24 +247,43 @@ public: if (size() == 1 && m_children[0].pow() == 1) *e = m_children[0].e(); TRACE("nla_cn_details", tout << *this << "\n";); - SASSERT((*e)->is_simplified()); + SASSERT((*e)->is_simplified(lt)); } - virtual bool is_simplified() const { - if (size() < 2) - return false; - for (const auto &p : children()) { - const nex* e = p.e(); - if (e->is_mul()) - return false; - if (e->is_scalar() && to_scalar(e)->value().is_one()) + bool is_sorted(nex_lt lt) const { + for (unsigned j = 0; j < m_children.size() - 1; j++) { + if (!(less_than(m_children[j], m_children[j+1], lt))) return false; } return true; } + + virtual bool is_simplified(nex_lt lt) const { + if (size() == 1 && m_children.begin()->pow() == 1) + return false; + std::set s(lt); + for (const auto &p : children()) { + const nex* e = p.e(); + if (p.pow() == 0) + return false; + if (e->is_mul()) + return false; + if (e->is_scalar() && to_scalar(e)->value().is_one()) + return false; + + auto it = s.find(e); + if (it == s.end()) { + s.insert(e); + } else { + TRACE("nla_cn_details", tout << "not simplified " << *e << "\n";); + return false; + } + } + return is_sorted(lt); + } bool is_linear() const { - SASSERT(is_simplified()); + // SASSERT(is_simplified()); return get_degree() < 2; // todo: make it more efficient } @@ -320,7 +361,7 @@ public: return out; } - void simplify(nex **e, std::function lt ) { + void simplify(nex **e, nex_lt lt ) { *e = this; promote_children_of_sum(m_children, lt); if (size() == 1) @@ -398,8 +439,32 @@ inline std::ostream& operator<<(std::ostream& out, const nex& e ) { } -inline bool less_than_nex(const nex* a, const nex* b, std::function lt) { - NOT_IMPLEMENTED_YET(); +inline bool less_than_nex(const nex* a, const nex* b, lt_on_vars lt) { + int r = (int)(a->type()) - (int)(b->type()); + if (r) { + return r < 0; + } + // here a and b have the same type + switch (a->type()) { + case expr_type::VAR: { + return lt(to_var(a)->var() , to_var(b)->var()); + } + case expr_type::SCALAR: { + return to_scalar(a)->value() < to_scalar(b)->value(); + } + case expr_type::MUL: { + NOT_IMPLEMENTED_YET(); + return false; // to_mul(a)->children() < to_mul(b)->children(); + } + case expr_type::SUM: { + NOT_IMPLEMENTED_YET(); + return false; //to_sum(a)->children() < to_sum(b)->children(); + } + default: + SASSERT(false); + return false; + } + return false; } diff --git a/src/math/lp/nex_creator.h b/src/math/lp/nex_creator.h index 0f38c31b2..c50623a8b 100644 --- a/src/math/lp/nex_creator.h +++ b/src/math/lp/nex_creator.h @@ -42,9 +42,42 @@ class nex_creator { std::unordered_map m_occurences_map; std::unordered_map m_powers; // the "less than" operator on expressions - std::function m_lt; + nex_lt m_lt; public: - nex_creator(std::function lt) {} + nex * clone(const nex* a) { + switch (a->type()) { + case expr_type::VAR: { + auto v = to_var(a); + return mk_var(v->var()); + } + + case expr_type::SCALAR: { + auto v = to_scalar(a); + return mk_scalar(v->value()); + } + case expr_type::MUL: { + auto m = to_mul(a); + auto r = mk_mul(); + for (const auto& p : m->children()) { + r->add_child_in_power(clone(p.e()), p.pow()); + } + return r; + } + case expr_type::SUM: { + auto m = to_sum(a); + auto r = mk_sum(); + for (nex * e : m->children()) { + r->add_child(clone(e)); + } + return r; + } + default: + UNREACHABLE(); + break; + } + return nullptr; + } + nex_creator(nex_lt lt) : m_lt(lt) {} const std::unordered_map& occurences_map() const { return m_occurences_map; } std::unordered_map& occurences_map() { return m_occurences_map; } const std::unordered_map & powers() const { return m_powers; } @@ -129,37 +162,45 @@ public: return r; } - nex * mk_div(const nex* a, lpvar j) { + SASSERT(a->is_simplified(m_lt)); TRACE("nla_cn_details", tout << "a=" << *a << ", v" << j << "\n";); - NOT_IMPLEMENTED_YET(); - return nullptr; - /* SASSERT((a->is_mul() && a->contains(j)) || (a->is_var() && to_var(a)->var() == j)); if (a->is_var()) return mk_scalar(rational(1)); - ptr_vector bv; + vector bv; bool seenj = false; - for (nex* c : to_mul(a)->children()) { + for (auto& p : to_mul(a)->children()) { + const nex * c = p.e(); + int pow = p.pow(); if (!seenj) { if (c->contains(j)) { - if (!c->is_var()) - bv.push_back(mk_div(c, j)); + if (!c->is_var()) { + bv.push_back(nex_pow(mk_div(c, j))); + if (pow != 1) { + bv.push_back(nex_pow(clone(c), pow)); + } + } else { + SASSERT(to_var(c)->var() == j); + if (p.pow() != 1) { + bv.push_back(nex_pow(mk_var(j), pow - 1)); + } + } seenj = true; - continue; } + } else { + bv.push_back(nex_pow(clone(c))); } - bv.push_back(c); } if (bv.size() > 1) { return mk_mul(bv); } - if (bv.size() == 1) { - return bv[0]; + if (bv.size() == 1 && bv.begin()->pow() == 1) { + return bv.begin()->e(); } - - SASSERT(bv.size() == 0); - return mk_scalar(rational(1));*/ + if (bv.size() == 0) + return mk_scalar(rational(1)); + return mk_mul(bv); } nex * mk_div(const nex* a, const nex* b) { diff --git a/src/math/lp/nla_grobner.h b/src/math/lp/nla_grobner.h index 784b2efe9..c426ef13f 100644 --- a/src/math/lp/nla_grobner.h +++ b/src/math/lp/nla_grobner.h @@ -102,7 +102,7 @@ class nla_grobner : common { ci_value_manager m_val_manager; ci_dependency_manager m_dep_manager; nex_creator m_nex_creator; - std::function m_lt; + nex_lt m_lt; public: nla_grobner(core *core); void grobner_lemmas(); diff --git a/src/test/lp/lp.cpp b/src/test/lp/lp.cpp index 6a669d39f..129918020 100644 --- a/src/test/lp/lp.cpp +++ b/src/test/lp/lp.cpp @@ -73,7 +73,21 @@ void test_cn_on_expr(nex_sum *t, cross_nested& cn) { cn.run(t); } -void test_simplify(cross_nested& cn, nex_var* a, nex_var* b, nex_var* c) { +void test_simplify() { + cross_nested cn( + [](const nex* n) { + TRACE("nla_cn_test", tout << *n << "\n";); + return false; + } , + [](unsigned) { return false; }, + []{ return 1; }, + less_than_nex_standard + ); + enable_trace("nla_cn"); + enable_trace("nla_cn_details"); + nex_var* a = cn.get_nex_creator().mk_var(0); + nex_var* b = cn.get_nex_creator().mk_var(1); + nex_var* c = cn.get_nex_creator().mk_var(2); auto & r = cn.get_nex_creator(); auto m = r.mk_mul(); m->add_child_in_power(c, 2); TRACE("nla_cn", tout << "m = " << *m << "\n";); @@ -108,7 +122,6 @@ void test_cn() { nex_var* e = cn.get_nex_creator().mk_var(4); nex_var* g = cn.get_nex_creator().mk_var(6); nex* min_1 = cn.get_nex_creator().mk_scalar(rational(-1)); - test_simplify(cn, a, b, c); // test_cn_on_expr(min_1*c*e + min_1*b*d + min_1*a*b + a*c); nex* bcd = cn.get_nex_creator().mk_mul(b, c, d); nex_mul* bcg = cn.get_nex_creator().mk_mul(b, c, g); @@ -124,11 +137,13 @@ void test_cn() { nex* ed = cn.get_nex_creator().mk_mul(e, d); nex* _6aad = cn.get_nex_creator().mk_mul(cn.get_nex_creator().mk_scalar(rational(6)), a, a, d); #ifdef Z3DEBUG - nex * clone = cn.clone(cn.get_nex_creator().mk_sum(_6aad, abcd, aaccd, add, eae, eac, ed)); + nex * clone = cn.get_nex_creator().clone(cn.get_nex_creator().mk_sum(_6aad, abcd, aaccd, add, eae, eac, ed)); + clone->simplify(&clone); + SASSERT(clone->is_simplified()); TRACE("nla_cn", tout << "clone = " << *clone << "\n";); #endif // test_cn_on_expr(cn.get_nex_creator().mk_sum(aad, abcd, aaccd, add, eae, eac, ed), cn); - test_cn_on_expr(cn.get_nex_creator().mk_sum(_6aad, abcd, aaccd, add, eae, eac, ed), cn); + test_cn_on_expr(to_sum(clone), cn); // TRACE("nla_cn", tout << "done\n";); // test_cn_on_expr(a*b*d + a*b*c + c*b*d + a*c*d); // TRACE("nla_cn", tout << "done\n";); @@ -1981,6 +1996,7 @@ void test_replace_column() { void setup_args_parser(argument_parser & parser) { parser.add_option_with_help_string("-nla_cn", "test cross nornmal form"); + parser.add_option_with_help_string("-nla_sim", "test nex simplify"); parser.add_option_with_help_string("-nla_blfmz_mf", "test_basic_lemma_for_mon_zero_from_factor_to_monomial"); parser.add_option_with_help_string("-nla_blfmz_fm", "test_basic_lemma_for_mon_zero_from_monomials_to_factor"); parser.add_option_with_help_string("-nla_order", "test nla_solver order lemma"); @@ -3684,6 +3700,14 @@ void test_lp_local(int argn, char**argv) { return finalize(0); } + if (args_parser.option_is_used("-nla_sim")) { +#ifdef Z3DEBUG + nla::test_simplify(); +#endif + return finalize(0); + } + + if (args_parser.option_is_used("-nla_order")) { #ifdef Z3DEBUG test_nla_order_lemma();