diff --git a/src/math/lp/nex.h b/src/math/lp/nex.h index 7e973baad..f6579439d 100644 --- a/src/math/lp/nex.h +++ b/src/math/lp/nex.h @@ -49,10 +49,14 @@ inline std::ostream & operator<<(std::ostream& out, expr_type t) { return out; } -class nex; +// forward definitions +class nex; class nex_scalar; -class nex_pow; // forward definitions +class nex_pow; +class nex_mul; +class nex_var; +class nex_sum; // This is the class of non-linear expressions @@ -72,6 +76,18 @@ public: return true; } } + nex_mul& to_mul(); + nex_mul const& to_mul() const; + + nex_sum& to_sum(); + nex_sum const& to_sum() const; + + nex_var& to_var(); + nex_var const& to_var() const; + + nex_scalar& to_scalar(); + nex_scalar const& to_scalar() const; + virtual unsigned number_of_child_powers() const { return 0; } virtual const nex* get_child_exp(unsigned) const { return this; } virtual unsigned get_child_pow(unsigned) const { return 1; } @@ -130,9 +146,6 @@ public: bool is_linear() const { return true; } }; -const nex_scalar * to_scalar(const nex* a); -class nex_sum; - class nex_pow { nex* m_e; int m_power; @@ -230,7 +243,7 @@ public: void add_child(nex* e) { if (e->is_scalar()) { - m_coeff *= to_scalar(e)->value(); + m_coeff *= e->to_scalar().value(); return; } add_child_in_power(e, 1); @@ -250,7 +263,7 @@ public: void add_child_in_power(nex* e, int power) { if (e->is_scalar()) { - m_coeff *= (to_scalar(e)->value()).expt(power); + m_coeff *= (e->to_scalar().value()).expt(power); } else { m_children.push_back(nex_pow(e, power)); @@ -265,11 +278,6 @@ public: return false; } - static const nex_var* to_var(const nex*a) { - SASSERT(a->is_var()); - return static_cast(a); - } - void get_powers_from_mul(std::unordered_map & r) const { TRACE("nla_cn_details", tout << "powers of " << *this << "\n";); r.clear(); @@ -277,7 +285,7 @@ public: if (!c.e()->is_var()) { continue; } - lpvar j = to_var(c.e())->var(); + lpvar j = c.e()->to_var().var(); SASSERT(r.find(j) == r.end()); r[j] = c.pow(); } @@ -397,40 +405,24 @@ public: #endif }; -inline const nex_sum* to_sum(const nex* a) { - SASSERT(a->is_sum()); - return static_cast(a); -} +inline nex_sum& nex::to_sum() { SASSERT(is_sum()); return *static_cast(this); } +inline nex_sum const& nex::to_sum() const { SASSERT(is_sum()); return *static_cast(this); } +inline nex_var& nex::to_var() { SASSERT(is_var()); return *static_cast(this); } +inline nex_var const& nex::to_var() const { SASSERT(is_var()); return *static_cast(this); } +inline nex_mul& nex::to_mul() { SASSERT(is_mul()); return *static_cast(this); } +inline nex_mul const& nex::to_mul() const { SASSERT(is_mul()); return *static_cast(this); } +inline nex_scalar& nex::to_scalar() { SASSERT(is_scalar()); return *static_cast(this); } +inline nex_scalar const& nex::to_scalar() const { SASSERT(is_scalar()); return *static_cast(this); } -inline nex_sum* to_sum(nex * a) { - SASSERT(a->is_sum()); - return static_cast(a); -} +inline const nex_sum* to_sum(const nex* a) { return &(a->to_sum()); } +inline nex_sum* to_sum(nex * a) { return &(a->to_sum()); } +inline const nex_var* to_var(const nex* a) { return &(a->to_var()); } +inline nex_var* to_var(nex * a) { return &(a->to_var()); } +inline const nex_scalar* to_scalar(const nex* a) { return &(a->to_scalar()); } +inline nex_scalar* to_scalar(nex * a) { return &(a->to_scalar()); } +inline const nex_mul* to_mul(const nex* a) { return &(a->to_mul()); } +inline nex_mul* to_mul(nex * a) { return &(a->to_mul()); } -inline const nex_var* to_var(const nex*a) { - SASSERT(a->is_var()); - return static_cast(a); -} - -inline const nex_scalar* to_scalar(const nex*a) { - SASSERT(a->is_scalar()); - return static_cast(a); -} - -inline nex_scalar* to_scalar(nex*a) { - SASSERT(a->is_scalar()); - return static_cast(a); -} - -inline const nex_mul* to_mul(const nex*a) { - SASSERT(a->is_mul()); - return static_cast(a); -} - -inline nex_mul* to_mul(nex*a) { - SASSERT(a->is_mul()); - return static_cast(a); -} inline std::ostream& operator<<(std::ostream& out, const nex& e ) { return e.print(out); diff --git a/src/math/lp/nex_creator.cpp b/src/math/lp/nex_creator.cpp index bf83bf047..c32ae72fd 100644 --- a/src/math/lp/nex_creator.cpp +++ b/src/math/lp/nex_creator.cpp @@ -643,28 +643,23 @@ void nex_creator::sort_join_sum(ptr_vector & children) { void nex_creator::simplify_children_of_sum(ptr_vector & children) { TRACE("grobner_d", print_vector_of_ptrs(children, tout);); ptr_vector to_promote; - int skipped = 0; - for (unsigned j = 0; j < children.size(); j++) { - nex* e = children[j] = simplify(children[j]); + bool skipped = false; + unsigned j = 0; + for (nex* e : children) { + e = simplify(e); if (e->is_sum()) { - to_promote.push_back(e); - } else if (is_zero_scalar(e)) { - skipped ++; + to_promote.push_back(e); } + else if (is_zero_scalar(e) || (e->is_mul() && e->to_mul().coeff().is_zero())) { + skipped = true; continue; - } else if (e->is_mul() && to_mul(e)->coeff().is_zero() ) { - skipped ++; - continue; - }else { - unsigned offset = to_promote.size() + skipped; - if (offset) { - children[j - offset] = e; - } + } + else { + if (skipped) + children[j++] = e; } } - + children.shrink(j); TRACE("grobner_d", print_vector_of_ptrs(children, tout);); - children.shrink(children.size() - to_promote.size() - skipped); - for (nex *e : to_promote) { for (nex *ee : *(to_sum(e)->children_ptr())) { if (!is_zero_scalar(ee))