From 3929e002a59cac029353fa4ec7a7f89d31c9e330 Mon Sep 17 00:00:00 2001
From: Lev Nachmanson <levnach@hotmail.com>
Date: Thu, 3 Oct 2019 11:49:48 -0700
Subject: [PATCH] simplify nex_creator

Signed-off-by: Lev Nachmanson <levnach@hotmail.com>
---
 src/math/lp/nex.h           | 17 ++++++++++-------
 src/math/lp/nex_creator.cpp | 26 +++++++++++++-------------
 src/test/lp/lp.cpp          | 11 ++++++++---
 3 files changed, 31 insertions(+), 23 deletions(-)

diff --git a/src/math/lp/nex.h b/src/math/lp/nex.h
index 120428395..881d2d71b 100644
--- a/src/math/lp/nex.h
+++ b/src/math/lp/nex.h
@@ -202,6 +202,8 @@ public:
         return !r.is_one();
     }
 
+    const nex_pow& operator[](unsigned j) const { return m_children[j]; }
+    nex_pow& operator[](unsigned j) { return m_children[j]; }
     const nex_pow* begin() const { return m_children.begin(); }
     const nex_pow* end() const { return m_children.end(); }
     nex_pow* begin() { return m_children.begin(); }
@@ -210,7 +212,7 @@ public:
     void add_child_in_power(nex* e, int power) { m_children.push_back(nex_pow(e, power)); }
 
     bool contains(lpvar j) const {
-        for (const nex_pow& c : children()) {
+        for (const nex_pow& c : *this) {
             if (c.e()->contains(j))
                 return true;
         }
@@ -225,7 +227,7 @@ public:
     void get_powers_from_mul(std::unordered_map<lpvar, unsigned> & r) const {
         TRACE("nla_cn_details", tout << "powers of " << *this << "\n";);
         r.clear();
-        for (const auto & c : children()) {
+        for (const auto & c : *this) {
             if (!c.e()->is_var()) {
                 continue;
             }
@@ -238,7 +240,7 @@ public:
 
     int get_degree() const {
         int degree = 0;       
-        for (const auto& p : children()) {
+        for (const auto& p : *this) {
             degree +=  p.e()->get_degree() * p.pow();
         }
         return degree;
@@ -274,7 +276,7 @@ public:
 
     bool is_linear() const {
         TRACE("nex_details", tout << *this << "\n";);
-        for (auto  e : children()) {
+        for (auto  e : *this) {
             if (!e->is_linear())
                 return false;
         }
@@ -286,7 +288,7 @@ public:
     bool is_a_linear_term() const {
         TRACE("nex_details", tout << *this << "\n";);
         unsigned number_of_non_scalars = 0;
-        for (auto  e : children()) {
+        for (auto  e : *this) {
             int d = e->get_degree();
             if (d == 0) continue;
             if (d > 1) return false;
@@ -324,12 +326,13 @@ public:
 
     int get_degree() const {
         int degree = 0;       
-        for (auto  e : children()) {
+        for (auto  e : *this) {
             degree = std::max(degree, e->get_degree());
         }
         return degree;
     }
-    
+    const nex* operator[](unsigned j) const { return m_children[j]; }
+    nex*& operator[](unsigned j) { return m_children[j]; }
     const ptr_vector<nex>::const_iterator begin() const { return m_children.begin(); }
     const ptr_vector<nex>::const_iterator end() const { return m_children.end(); }
     ptr_vector<nex>::iterator begin() { return m_children.begin(); }
diff --git a/src/math/lp/nex_creator.cpp b/src/math/lp/nex_creator.cpp
index 88e2fd5e7..078ab6b6f 100644
--- a/src/math/lp/nex_creator.cpp
+++ b/src/math/lp/nex_creator.cpp
@@ -30,7 +30,7 @@ nex * nex_creator::mk_div(const nex* a, lpvar j) {
         return mk_scalar(rational(1));
     vector<nex_pow> bv; 
     bool seenj = false;
-    for (auto& p : to_mul(a)->children()) {
+    for (auto& p : *to_mul(a)) {
         const nex * c = p.e();
         int pow = p.pow();
         if (!seenj && c->contains(j)) {
@@ -102,7 +102,7 @@ void nex_creator::simplify_children_of_mul(vector<nex_pow> & children) {
 
     for (nex_pow & p : to_promote) {
         TRACE("nla_cn_details", tout << p << "\n";);
-        for (nex_pow& pp : to_mul(p.e())->children()) {
+        for (nex_pow& pp : *to_mul(p.e())) {
             TRACE("nla_cn_details", tout << pp << "\n";);
             if (!eat_scalar_pow(r, pp, p.pow()))
                 children.push_back(nex_pow(pp.e(), pp.pow() * p.pow()));            
@@ -277,7 +277,7 @@ bool nex_creator::lt(const nex* a, const nex* b) const {
 
 bool nex_creator::is_sorted(const nex_mul* e) const {
     for (unsigned j = 0; j < e->size() - 1; j++) {
-        if (!(less_than_on_nex_pow(e->children()[j], e->children()[j+1])))
+        if (!(less_than_on_nex_pow((*e)[j], (*e)[j+1])))
             return false;
     }
     return true;
@@ -290,7 +290,7 @@ bool nex_creator::mul_is_simplified(const nex_mul* e) const {
     if (e->size() == 1 && e->begin()->pow() == 1)
         return false;
     std::set<const nex*, nex_lt> s([this](const nex* a, const nex* b) {return lt(a, b); });
-    for (const auto &p : e->children()) {
+    for (const auto &p : *e) {
         const nex* ee = p.e();
         if (p.pow() == 0)
             return false;
@@ -313,8 +313,8 @@ bool nex_creator::mul_is_simplified(const nex_mul* e) const {
 nex * nex_creator::simplify_mul(nex_mul *e) {
     TRACE("nla_cn_details", tout << *e << "\n";);
     simplify_children_of_mul(e->children());
-    if (e->size() == 1 && e->children()[0].pow() == 1) 
-        return e->children()[0].e();
+    if (e->size() == 1 && (*e)[0].pow() == 1) 
+        return (*e)[0].e();
     TRACE("nla_cn_details", tout << *e << "\n";);
     SASSERT(is_simplified(e));
     return e;
@@ -323,7 +323,7 @@ nex * nex_creator::simplify_mul(nex_mul *e) {
 nex* nex_creator::simplify_sum(nex_sum *e) {
     TRACE("nla_cn_details", tout << "was e = " << *e << "\n";);
     simplify_children_of_sum(e->children());
-    nex *r = e->size() == 1? e->children()[0]: e;
+    nex *r = e->size() == 1? (*e)[0]: e;
     TRACE("nla_cn_details", tout << "became r = " << *r << "\n";);    
     return r;
 }
@@ -332,7 +332,7 @@ bool nex_creator::sum_is_simplified(const nex_sum* e) const {
     TRACE("nla_cn_details",  tout << ++ lp::lp_settings::ddd << std::endl;);
     
     if (e->size() < 2) return false;
-    for (nex * ee : e->children()) {
+    for (nex * ee : *e) {
         if (ee->is_sum())
             return false;
         if (ee->is_scalar() && to_scalar(ee)->value().is_zero())
@@ -420,8 +420,8 @@ bool nex_creator::process_mul_in_simplify_sum(nex_mul* em, std::map<nex*, ration
         SASSERT(it->pow() == 1);
         rational r = to_scalar(it->e())->value();              
         auto end = em->end();
-        if (em->size() == 2 && em->children()[1].pow() == 1) {
-            found = register_in_join_map(map, em->children()[1].e(), r);
+        if (em->size() == 2 && (*em)[1].pow() == 1) {
+            found = register_in_join_map(map, (*em)[1].e(), r);
         } else {
             nex_mul * m = new nex_mul();
             for (it++; it != end; it++) {
@@ -538,7 +538,7 @@ bool have_no_scalars(const nex_mul* a) {
 
 nex * nex_creator::mk_div_sum_by_mul(const nex_sum* m, const nex_mul* b) {
     nex_sum * r = mk_sum();
-    for (auto e : m->children()) {
+    for (auto e : *m) {
         r->add_child(mk_div_by_mul(e, b));
     }
     TRACE("nla_cn_details", tout << *r << "\n";);
@@ -557,7 +557,7 @@ nex * nex_creator::mk_div_by_mul(const nex* a, const nex_mul* b) {
     SASSERT(all_factors_are_elementary(am) && all_factors_are_elementary(b) && have_no_scalars(b));
     b->get_powers_from_mul(m_powers);
     nex_mul* ret = new nex_mul();
-    for (auto& p : am->children()) {
+    for (auto& p : *am) {
         TRACE("nla_cn_details", tout << "p = " << p << "\n";);
         const nex* e = p.e();
         if (!e->is_var()) {
@@ -634,7 +634,7 @@ void nex_creator::process_map_pair(nex *e, const rational& coeff, ptr_vector<nex
                 children.push_back(mk_mul(mk_scalar(coeff), e));
             } else {
                 SASSERT(e->is_mul());
-                nex* first = to_mul(e)->children()[0].e();
+                nex* first = (*to_mul(e))[0].e();
                 if (first->is_scalar()) {
                     to_scalar(first)->value() = coeff;
                     children.push_back(e);
diff --git a/src/test/lp/lp.cpp b/src/test/lp/lp.cpp
index f88d2cf4f..50794e1cd 100644
--- a/src/test/lp/lp.cpp
+++ b/src/test/lp/lp.cpp
@@ -83,8 +83,8 @@ void test_simplify() {
         [](unsigned) { return false; },
         []() { return 1; } // for random
                     );
-    enable_trace("nla_cn");
-    enable_trace("nla_cn_details");
+    // enable_trace("nla_cn");
+    // enable_trace("nla_cn_details");
     //    enable_trace("nla_cn_details_");
     enable_trace("nla_test");
     
@@ -112,7 +112,7 @@ void test_simplify() {
     auto n = r.mk_mul(a);
     n->add_child_in_power(b, 7);
     n->add_child(r.mk_scalar(rational(3)));
-    n->add_child_in_power(r.mk_scalar(rational(4)), 2);
+    n->add_child_in_power(r.mk_scalar(rational(2)), 2);
     n->add_child(r.mk_scalar(rational(1)));
     TRACE("nla_test_", tout << "n = " << *n << "\n";); 
     m->add_child_in_power(n, 3);
@@ -136,6 +136,11 @@ void test_simplify() {
     TRACE("nla_test", tout << "before simplify sum e = " << *e << "\n";);
     e = to_sum(r.simplify(e));
     TRACE("nla_test", tout << "simplified sum e = " << *e << "\n";);
+
+    nex * pr = r.mk_mul(a, b, b);
+    TRACE("nla_test", tout << "before simplify pr = " << *pr << "\n";);
+    r.simplify(pr);
+    TRACE("nla_test", tout << "simplified sum e = " << *pr << "\n";);
 }
 
 void test_cn() {