diff --git a/src/math/lp/nla_expr.h b/src/math/lp/nla_expr.h index bdb7a5d07..31bc579af 100644 --- a/src/math/lp/nla_expr.h +++ b/src/math/lp/nla_expr.h @@ -67,6 +67,8 @@ class nla_expr { nla_expr& back() { return m_es.back(); } const nla_expr* begin() const { return m_es.begin(); } const nla_expr* end() const { return m_es.end(); } + typename std::vector::iterator begin() { return m_es.begin(); } + typename std::vector::iterator end() { return m_es.end(); } unsigned size() const { return m_es.size(); } void sort() { std::sort(m_order.begin(), m_order.end(), [this](unsigned i, unsigned j) { return m_es[i] < m_es[j]; }); @@ -127,13 +129,12 @@ public: } void simplify() { + if (is_simple()) return; + bool has_sum = false; if (is_sum()) { - bool has_sum = false; - for (unsigned j = 0; j < m_children.es().size(); j++) { - auto& e = m_children.es()[j]; + for (auto & e : m_children) { e.simplify(); - if (e.is_sum()) - has_sum = true; + has_sum |= e.is_sum(); } if (has_sum) { nla_expr n(expr_type::SUM); @@ -141,15 +142,12 @@ public: n += e; } *this = n; - } - + } } else if (is_mul()) { bool has_mul = false; - for (unsigned j = 0; j < m_children.es().size(); j++) { - auto& e = m_children.es()[j]; + for (auto & e : m_children) { e.simplify(); - if (e.is_mul()) - has_mul = true; + has_mul |= e.is_mul(); } if (has_mul) { nla_expr n(expr_type::MUL); @@ -209,7 +207,7 @@ public: switch(m_type) { case expr_type::SUM: case expr_type::MUL: - return m_children.size() <= 1; + return false; default: return true;