From f9beef19ce3e2b17f693c4117cbc345ed2b6f64e Mon Sep 17 00:00:00 2001
From: Lev Nachmanson <levnach@hotmail.com>
Date: Thu, 3 Oct 2019 14:31:45 -0700
Subject: [PATCH] fixes in nex expressions

Signed-off-by: Lev Nachmanson <levnach@hotmail.com>
---
 src/math/lp/cross_nested.h  | 45 ++++++++-----------------------------
 src/math/lp/nex.h           | 28 +++++++++++++++++++++++
 src/math/lp/nex_creator.cpp | 15 ++++++++++---
 src/test/lp/lp.cpp          | 17 +++++++++-----
 4 files changed, 60 insertions(+), 45 deletions(-)

diff --git a/src/math/lp/cross_nested.h b/src/math/lp/cross_nested.h
index 50d01a2ee..d09c92653 100644
--- a/src/math/lp/cross_nested.h
+++ b/src/math/lp/cross_nested.h
@@ -78,7 +78,7 @@ public:
     nex* extract_common_factor(nex* e) {
         nex_sum* c = to_sum(e);
         TRACE("nla_cn", tout << "c=" << *c << "\n"; tout << "occs:"; dump_occurences(tout, m_nex_creator.occurences_map()) << "\n";);
-        unsigned size = c->children().size();
+        unsigned size = c->size();
         bool have_factor = false;
         for(const auto & p : m_nex_creator.occurences_map()) {
             if (p.second.m_occs == size) {
@@ -131,7 +131,7 @@ public:
         nex_mul* cm; 
         *c = cm = m_nex_creator.mk_mul(f, c_over_f);
         TRACE("nla_cn", tout << "common factor=" << *f << ", c=" << **c << "\ne = " << *m_e << "\n";);
-        explore_expr_on_front_elem(cm->children()[1].ee(),  front);
+        explore_expr_on_front_elem((*cm)[1].ee(),  front);
         return true;
     }
 
@@ -193,7 +193,7 @@ public:
 
     void calc_occurences(nex_sum* e) {
         clear_maps();
-        for (const auto * ce : e->children()) {
+        for (const auto * ce : *e) {
             if (ce->is_mul()) {
                 to_mul(ce)->get_powers_from_mul(m_nex_creator.powers());
                 update_occurences_with_powers();
@@ -338,7 +338,7 @@ public:
     // The result is sorted by large number of occurences first
     vector<std::pair<lpvar, occ>> get_mult_occurences(const nex_sum* e) {
         clear_maps();
-        for (const auto * ce : e->children()) {
+        for (const auto * ce : *e) {
             if (ce->is_mul()) {
                 to_mul(ce)->get_powers_from_mul(m_nex_creator.powers());
                 update_occurences_with_powers();
@@ -375,7 +375,7 @@ public:
         TRACE("nla_cn_details", tout << "e = " << * e << ", j = " << m_nex_creator.ch(j) << std::endl;);
         a = m_nex_creator.mk_sum();
         m_b_split_vec.clear();
-        for (nex * ce: e->children()) {
+        for (nex * ce: *e) {
             if (is_divisible_by_var(ce, j)) {
                 a->add_child(m_nex_creator.mk_div(ce , j));
             } else {
@@ -385,7 +385,7 @@ public:
             }        
         }
         TRACE("nla_cn_details", tout << "a = " << *a << "\n";);
-        SASSERT(a->children().size() >= 2 && m_b_split_vec.size());
+        SASSERT(a->size() >= 2 && m_b_split_vec.size());
         a = to_sum(m_nex_creator.simplify_sum(a));
         
         if (m_b_split_vec.size() == 1) {
@@ -402,12 +402,12 @@ public:
         TRACE("nla_cn_details", tout << "b = " << *b << "\n";);
         e = m_nex_creator.mk_sum(m_nex_creator.mk_mul(m_nex_creator.mk_var(j), a), b); // e = j*a + b
         if (!a->is_linear()) {
-            nex **ptr_to_a = (to_mul(to_sum(e)->children()[0]))->children()[1].ee();
+            nex **ptr_to_a = ((*to_mul((*to_sum(e))[0])))[1].ee();
             push_to_front(front, ptr_to_a);
         }
         
         if (b->is_sum() && !to_sum(b)->is_linear()) {
-            nex **ptr_to_a = &(to_sum(e)->children()[1]);
+            nex **ptr_to_a = &((*to_sum(e))[1]);
             push_to_front(front, ptr_to_a);
         }
     }
@@ -416,7 +416,7 @@ public:
         if (b == nullptr) {
             e = m_nex_creator.mk_mul(m_nex_creator.mk_var(j), a);
             if (!to_sum(a)->is_linear())
-                push_to_front(front, to_mul(e)->children()[1].ee());
+                push_to_front(front, (*to_mul(e))[1].ee());
         } else {
             update_front_with_split_with_non_empty_b(e, j, front, a, b);
         }
@@ -442,33 +442,6 @@ public:
         return true;
     }
 
-    static std::unordered_set<lpvar> get_vars_of_expr(const nex *e ) {
-        std::unordered_set<lpvar> r;
-        switch (e->type()) {
-        case expr_type::SCALAR:
-            return r;
-        case expr_type::SUM:
-            {
-                for (auto c: to_sum(e)->children())
-                    for ( lpvar j : get_vars_of_expr(c))
-                        r.insert(j);
-            }
-        case expr_type::MUL:
-            {
-                for (auto &c: to_mul(e)->children())
-                    for ( lpvar j : get_vars_of_expr(c.e()))
-                        r.insert(j);
-            }
-            return r;
-        case expr_type::VAR:
-            r.insert(to_var(e)->var());
-            return r;
-        default:
-            TRACE("nla_cn_details", tout << e->type() << "\n";);
-            SASSERT(false);
-            return r;
-        }
-    }
     
     ~cross_nested() {
         m_nex_creator.clear();
diff --git a/src/math/lp/nex.h b/src/math/lp/nex.h
index 881d2d71b..397242a62 100644
--- a/src/math/lp/nex.h
+++ b/src/math/lp/nex.h
@@ -398,5 +398,33 @@ inline bool less_than_nex_standard(const nex* a, const nex* b) {
     lt_on_vars lt = [](lpvar j, lpvar k) { return j < k; };
     return less_than_nex(a, b, lt);
 }
+
+inline std::unordered_set<lpvar> get_vars_of_expr(const nex *e ) {
+        std::unordered_set<lpvar> r;
+        switch (e->type()) {
+        case expr_type::SCALAR:
+            return r;
+        case expr_type::SUM:
+            {
+                for (auto c: *to_sum(e))
+                    for ( lpvar j : get_vars_of_expr(c))
+                        r.insert(j);
+            }
+        case expr_type::MUL:
+            {
+                for (auto &c: *to_mul(e))
+                    for ( lpvar j : get_vars_of_expr(c.e()))
+                        r.insert(j);
+            }
+            return r;
+        case expr_type::VAR:
+            r.insert(to_var(e)->var());
+            return r;
+        default:
+            TRACE("nla_cn_details", tout << e->type() << "\n";);
+            SASSERT(false);
+            return r;
+        }
+    }
 }
     
diff --git a/src/math/lp/nex_creator.cpp b/src/math/lp/nex_creator.cpp
index 078ab6b6f..5916a9dae 100644
--- a/src/math/lp/nex_creator.cpp
+++ b/src/math/lp/nex_creator.cpp
@@ -329,13 +329,22 @@ nex* nex_creator::simplify_sum(nex_sum *e) {
 }
 
 bool nex_creator::sum_is_simplified(const nex_sum* e) const {
-    TRACE("nla_cn_details",  tout << ++ lp::lp_settings::ddd << std::endl;);
     
     if (e->size() < 2) return false;
+    bool scalar = false;
     for (nex * ee : *e) {
         if (ee->is_sum())
             return false;
-        if (ee->is_scalar() && to_scalar(ee)->value().is_zero())
+        if (ee->is_scalar()) {
+            if (scalar) {
+                return false;
+            }
+            if (to_scalar(ee)->value().is_zero()) {
+                return false;
+            }
+            scalar = true;
+        }
+        if (!is_simplified(ee))
             return false;
     }
     return true;
@@ -550,7 +559,7 @@ nex * nex_creator::mk_div_by_mul(const nex* a, const nex_mul* b) {
         return mk_div_sum_by_mul(to_sum(a), b);
     }
     if (a->is_var() || (a->is_mul() && to_mul(a)->size() == 1)) {
-        SASSERT(b->get_degree() == 1 && !b->has_a_coeff() && b->contains(to_var(a)->var()));        
+        SASSERT(b->get_degree() == 1 && !b->has_a_coeff() && get_vars_of_expr(a) == get_vars_of_expr(b));        
         return mk_scalar(rational(1));
     }
     const nex_mul* am = to_mul(a);
diff --git a/src/test/lp/lp.cpp b/src/test/lp/lp.cpp
index 50794e1cd..5fc25ebf7 100644
--- a/src/test/lp/lp.cpp
+++ b/src/test/lp/lp.cpp
@@ -69,7 +69,8 @@ void test_basic_lemma_for_mon_neutral_from_monomial_to_factors();
 void test_basic_lemma_for_mon_neutral_from_factors_to_monomial();
 
 void test_cn_on_expr(nex_sum *t, cross_nested& cn) {
-    TRACE("nla_cn", tout << "t=" << *t << '\n';);
+    t = to_sum(cn.get_nex_creator().simplify(t));
+    TRACE("nla_test", tout << "t=" << *t << '\n';);
     cn.run(t);
 }
 
@@ -146,13 +147,14 @@ void test_simplify() {
 void test_cn() {
     cross_nested cn(
         [](const nex* n) {
-                           TRACE("nla_test", tout << *n << "\n";);
-                           return false;
-                       } ,
+            TRACE("nla_test", tout <<"cn form = " <<  *n << "\n";);
+            return false;
+        } ,
         [](unsigned) { return false; },
         []{ return 1; });
     enable_trace("nla_test");
-    enable_trace("nla_test_details");
+    //    enable_trace("nla_cn");
+    //   enable_trace("nla_test_details");
     auto & cr = cn.get_nex_creator();
     cr.active_vars_weights().resize(20);
     for (unsigned j = 0; j < cr.active_vars_weights().size(); j++)
@@ -164,6 +166,10 @@ void test_cn() {
     nex_var* d = cr.mk_var(3);
     nex_var* e = cr.mk_var(4);
     nex_var* g = cr.mk_var(6);
+    nex_sum * a_p_ae_sq = cr.mk_sum(a, cr.mk_mul(a, e, e));
+    a_p_ae_sq = to_sum(cr.simplify(a_p_ae_sq));
+    test_cn_on_expr(a_p_ae_sq, cn);
+
     nex* min_1 = cr.mk_scalar(rational(-1));
     // test_cn_on_expr(min_1*c*e + min_1*b*d + min_1*a*b + a*c);
     nex* bcd = cr.mk_mul(b, c, d);
@@ -171,7 +177,6 @@ void test_cn() {
     bcg->add_child(min_1);
     nex_sum* t = cr.mk_sum(bcd, bcg);
     test_cn_on_expr(t, cn);
-    nex* aad = cr.mk_mul(a, a, d);
     nex* abcd = cr.mk_mul(a, b, c, d);
     nex* aaccd = cr.mk_mul(a, a, c, c, d);
     nex* add = cr.mk_mul(a, d, d);