diff --git a/src/math/lp/nex.h b/src/math/lp/nex.h index 120428395..881d2d71b 100644 --- a/src/math/lp/nex.h +++ b/src/math/lp/nex.h @@ -202,6 +202,8 @@ public: return !r.is_one(); } + const nex_pow& operator[](unsigned j) const { return m_children[j]; } + nex_pow& operator[](unsigned j) { return m_children[j]; } const nex_pow* begin() const { return m_children.begin(); } const nex_pow* end() const { return m_children.end(); } nex_pow* begin() { return m_children.begin(); } @@ -210,7 +212,7 @@ public: void add_child_in_power(nex* e, int power) { m_children.push_back(nex_pow(e, power)); } bool contains(lpvar j) const { - for (const nex_pow& c : children()) { + for (const nex_pow& c : *this) { if (c.e()->contains(j)) return true; } @@ -225,7 +227,7 @@ public: void get_powers_from_mul(std::unordered_map & r) const { TRACE("nla_cn_details", tout << "powers of " << *this << "\n";); r.clear(); - for (const auto & c : children()) { + for (const auto & c : *this) { if (!c.e()->is_var()) { continue; } @@ -238,7 +240,7 @@ public: int get_degree() const { int degree = 0; - for (const auto& p : children()) { + for (const auto& p : *this) { degree += p.e()->get_degree() * p.pow(); } return degree; @@ -274,7 +276,7 @@ public: bool is_linear() const { TRACE("nex_details", tout << *this << "\n";); - for (auto e : children()) { + for (auto e : *this) { if (!e->is_linear()) return false; } @@ -286,7 +288,7 @@ public: bool is_a_linear_term() const { TRACE("nex_details", tout << *this << "\n";); unsigned number_of_non_scalars = 0; - for (auto e : children()) { + for (auto e : *this) { int d = e->get_degree(); if (d == 0) continue; if (d > 1) return false; @@ -324,12 +326,13 @@ public: int get_degree() const { int degree = 0; - for (auto e : children()) { + for (auto e : *this) { degree = std::max(degree, e->get_degree()); } return degree; } - + const nex* operator[](unsigned j) const { return m_children[j]; } + nex*& operator[](unsigned j) { return m_children[j]; } const ptr_vector::const_iterator begin() const { return m_children.begin(); } const ptr_vector::const_iterator end() const { return m_children.end(); } ptr_vector::iterator begin() { return m_children.begin(); } diff --git a/src/math/lp/nex_creator.cpp b/src/math/lp/nex_creator.cpp index 88e2fd5e7..078ab6b6f 100644 --- a/src/math/lp/nex_creator.cpp +++ b/src/math/lp/nex_creator.cpp @@ -30,7 +30,7 @@ nex * nex_creator::mk_div(const nex* a, lpvar j) { return mk_scalar(rational(1)); vector bv; bool seenj = false; - for (auto& p : to_mul(a)->children()) { + for (auto& p : *to_mul(a)) { const nex * c = p.e(); int pow = p.pow(); if (!seenj && c->contains(j)) { @@ -102,7 +102,7 @@ void nex_creator::simplify_children_of_mul(vector & children) { for (nex_pow & p : to_promote) { TRACE("nla_cn_details", tout << p << "\n";); - for (nex_pow& pp : to_mul(p.e())->children()) { + for (nex_pow& pp : *to_mul(p.e())) { TRACE("nla_cn_details", tout << pp << "\n";); if (!eat_scalar_pow(r, pp, p.pow())) children.push_back(nex_pow(pp.e(), pp.pow() * p.pow())); @@ -277,7 +277,7 @@ bool nex_creator::lt(const nex* a, const nex* b) const { bool nex_creator::is_sorted(const nex_mul* e) const { for (unsigned j = 0; j < e->size() - 1; j++) { - if (!(less_than_on_nex_pow(e->children()[j], e->children()[j+1]))) + if (!(less_than_on_nex_pow((*e)[j], (*e)[j+1]))) return false; } return true; @@ -290,7 +290,7 @@ bool nex_creator::mul_is_simplified(const nex_mul* e) const { if (e->size() == 1 && e->begin()->pow() == 1) return false; std::set s([this](const nex* a, const nex* b) {return lt(a, b); }); - for (const auto &p : e->children()) { + for (const auto &p : *e) { const nex* ee = p.e(); if (p.pow() == 0) return false; @@ -313,8 +313,8 @@ bool nex_creator::mul_is_simplified(const nex_mul* e) const { nex * nex_creator::simplify_mul(nex_mul *e) { TRACE("nla_cn_details", tout << *e << "\n";); simplify_children_of_mul(e->children()); - if (e->size() == 1 && e->children()[0].pow() == 1) - return e->children()[0].e(); + if (e->size() == 1 && (*e)[0].pow() == 1) + return (*e)[0].e(); TRACE("nla_cn_details", tout << *e << "\n";); SASSERT(is_simplified(e)); return e; @@ -323,7 +323,7 @@ nex * nex_creator::simplify_mul(nex_mul *e) { nex* nex_creator::simplify_sum(nex_sum *e) { TRACE("nla_cn_details", tout << "was e = " << *e << "\n";); simplify_children_of_sum(e->children()); - nex *r = e->size() == 1? e->children()[0]: e; + nex *r = e->size() == 1? (*e)[0]: e; TRACE("nla_cn_details", tout << "became r = " << *r << "\n";); return r; } @@ -332,7 +332,7 @@ bool nex_creator::sum_is_simplified(const nex_sum* e) const { TRACE("nla_cn_details", tout << ++ lp::lp_settings::ddd << std::endl;); if (e->size() < 2) return false; - for (nex * ee : e->children()) { + for (nex * ee : *e) { if (ee->is_sum()) return false; if (ee->is_scalar() && to_scalar(ee)->value().is_zero()) @@ -420,8 +420,8 @@ bool nex_creator::process_mul_in_simplify_sum(nex_mul* em, std::mappow() == 1); rational r = to_scalar(it->e())->value(); auto end = em->end(); - if (em->size() == 2 && em->children()[1].pow() == 1) { - found = register_in_join_map(map, em->children()[1].e(), r); + if (em->size() == 2 && (*em)[1].pow() == 1) { + found = register_in_join_map(map, (*em)[1].e(), r); } else { nex_mul * m = new nex_mul(); for (it++; it != end; it++) { @@ -538,7 +538,7 @@ bool have_no_scalars(const nex_mul* a) { nex * nex_creator::mk_div_sum_by_mul(const nex_sum* m, const nex_mul* b) { nex_sum * r = mk_sum(); - for (auto e : m->children()) { + for (auto e : *m) { r->add_child(mk_div_by_mul(e, b)); } TRACE("nla_cn_details", tout << *r << "\n";); @@ -557,7 +557,7 @@ nex * nex_creator::mk_div_by_mul(const nex* a, const nex_mul* b) { SASSERT(all_factors_are_elementary(am) && all_factors_are_elementary(b) && have_no_scalars(b)); b->get_powers_from_mul(m_powers); nex_mul* ret = new nex_mul(); - for (auto& p : am->children()) { + for (auto& p : *am) { TRACE("nla_cn_details", tout << "p = " << p << "\n";); const nex* e = p.e(); if (!e->is_var()) { @@ -634,7 +634,7 @@ void nex_creator::process_map_pair(nex *e, const rational& coeff, ptr_vectoris_mul()); - nex* first = to_mul(e)->children()[0].e(); + nex* first = (*to_mul(e))[0].e(); if (first->is_scalar()) { to_scalar(first)->value() = coeff; children.push_back(e); diff --git a/src/test/lp/lp.cpp b/src/test/lp/lp.cpp index f88d2cf4f..50794e1cd 100644 --- a/src/test/lp/lp.cpp +++ b/src/test/lp/lp.cpp @@ -83,8 +83,8 @@ void test_simplify() { [](unsigned) { return false; }, []() { return 1; } // for random ); - enable_trace("nla_cn"); - enable_trace("nla_cn_details"); + // enable_trace("nla_cn"); + // enable_trace("nla_cn_details"); // enable_trace("nla_cn_details_"); enable_trace("nla_test"); @@ -112,7 +112,7 @@ void test_simplify() { auto n = r.mk_mul(a); n->add_child_in_power(b, 7); n->add_child(r.mk_scalar(rational(3))); - n->add_child_in_power(r.mk_scalar(rational(4)), 2); + n->add_child_in_power(r.mk_scalar(rational(2)), 2); n->add_child(r.mk_scalar(rational(1))); TRACE("nla_test_", tout << "n = " << *n << "\n";); m->add_child_in_power(n, 3); @@ -136,6 +136,11 @@ void test_simplify() { TRACE("nla_test", tout << "before simplify sum e = " << *e << "\n";); e = to_sum(r.simplify(e)); TRACE("nla_test", tout << "simplified sum e = " << *e << "\n";); + + nex * pr = r.mk_mul(a, b, b); + TRACE("nla_test", tout << "before simplify pr = " << *pr << "\n";); + r.simplify(pr); + TRACE("nla_test", tout << "simplified sum e = " << *pr << "\n";); } void test_cn() {