diff --git a/src/math/lp/cross_nested.h b/src/math/lp/cross_nested.h index 87fe6a512..9f7da6b3c 100644 --- a/src/math/lp/cross_nested.h +++ b/src/math/lp/cross_nested.h @@ -590,6 +590,52 @@ public: } return nullptr; } + + nex * normalize_sum(nex_sum* a) { + for (unsigned j = 0; j < a->size(); j ++) { + a->children()[j] = normalize(a->children()[j]); + } + a->simplify(); + return a; + } + + nex * normalize_mul(nex_mul* a) { + int sum_j = -1; + for (unsigned j = 0; j < a->size(); j ++) { + a->children()[j] = normalize(a->children()[j]); + if (a->children()[j]->is_sum()) + sum_j = j; + } + + if (sum_j == -1) + return a; + + nex_sum *r = mk_sum(); + nex_sum *as = to_sum(a->children()[sum_j]); + for (unsigned k = 0; k < as->size(); k++) { + nex_mul *b = mk_mul(as->children()[k]); + r->add_child(b); + for (unsigned j = 0; j < a->size(); j ++) { + if ((int)j != sum_j) + b->add_child(a->children()[j]); + } + } + return normalize_sum(r); + } + + + + nex * normalize(nex* a) { + if (a->is_simple()) + return a; + nex *r; + if (a->is_mul()) { + r = normalize_mul(to_mul(a)); + } else { + r = normalize_sum(to_sum(a)); + } + r->sort(); + } #endif }; diff --git a/src/math/lp/nex.h b/src/math/lp/nex.h index dba3cef5c..7ee473687 100644 --- a/src/math/lp/nex.h +++ b/src/math/lp/nex.h @@ -35,9 +35,6 @@ inline std::ostream & operator<<(std::ostream& out, expr_type t) { case expr_type::SCALAR: out << "SCALAR"; break; - case expr_type::UNDEF: - out << "UNDEF"; - break; default: out << "NN"; break; @@ -79,7 +76,13 @@ public: UNREACHABLE(); return nullptr; } + #ifdef Z3DEBUG + virtual void sort() {}; + #endif }; +#if Z3DEBUG +bool operator<(const nex& a , const nex& b); +#endif std::ostream& operator<<(std::ostream& out, const nex&); class nex_var : public nex { @@ -217,6 +220,14 @@ public: void simplify() { promote_children_by_type(&m_children, expr_type::MUL); } + #ifdef Z3DEBUG + virtual void sort() { + for (nex * c : m_children) { + c->sort(); + } + std::sort(m_children.begin(), m_children.end(), [](const nex* a, const nex* b) { return *a < *b; }); + } + #endif }; @@ -285,6 +296,16 @@ public: } void add_child(nex* e) { m_children.push_back(e); } +#ifdef Z3DEBUG + virtual void sort() { + for (nex * c : m_children) { + c->sort(); + } + + + std::sort(m_children.begin(), m_children.end(), [](const nex* a, const nex* b) { return *a < *b; }); + } +#endif }; inline const nex_sum* to_sum(const nex*a) { @@ -322,6 +343,47 @@ inline std::ostream& operator<<(std::ostream& out, const nex& e ) { return e.print(out); } +#if Z3DEBUG +inline bool operator<(const ptr_vector&a , const ptr_vector& b) { + int r = (int)a.size() - (int)b.size(); + if (r) + return r < 0; + for (unsigned j = 0; j < a.size(); j++) { + if (*a[j] < *b[j]) { + return true; + } + if (*b[j] < *a[j]) { + return false; + } + } + return false; +} + +inline bool operator<(const nex& a , const nex& b) { + int r = (int)(a.type()) - (int)(b.type()); + ptr_vector ch; + if (r) { + return r < 0; + } + switch (a.type()) { + case expr_type::VAR: { + return to_var(&a)->var() < to_var(&b)->var(); + } + case expr_type::SCALAR: { + return to_scalar(&a)->value() < to_scalar(&b)->value(); + } + case expr_type::MUL: { + return to_mul(&a)->children() < to_mul(&b)->children(); + } + case expr_type::SUM: { + return to_mul(&a)->children() < to_mul(&b)->children(); + } + default: + SASSERT(false); + return false; + } +} +#endif }