From 9266ab7ed19763f8500033b607d246840b1aebaf Mon Sep 17 00:00:00 2001
From: Lev Nachmanson <levnach@hotmail.com>
Date: Sat, 17 Aug 2019 21:21:18 -0700
Subject: [PATCH] rewrite horner scheme on top of nex_expr as a pointer

Signed-off-by: Lev Nachmanson <levnach@hotmail.com>
---
 src/math/lp/cross_nested.h | 72 +++++++++++++++++++-------------------
 src/math/lp/nla_expr.h     | 40 ++++++++++++---------
 src/test/lp/lp.cpp         |  5 +--
 3 files changed, 62 insertions(+), 55 deletions(-)

diff --git a/src/math/lp/cross_nested.h b/src/math/lp/cross_nested.h
index 787c931e0..2fa1dd41d 100644
--- a/src/math/lp/cross_nested.h
+++ b/src/math/lp/cross_nested.h
@@ -41,9 +41,8 @@ class cross_nested {
     bool                                  m_done;
     std::unordered_map<lpvar, occ>        m_occurences_map;
     std::unordered_map<lpvar, unsigned>   m_powers;
-    vector<nex*> m_allocated;
-    vector<nex*> m_b_vec;
-    vector<nex*> m_b_split_vec;
+    ptr_vector<nex> m_allocated;
+    ptr_vector<nex> m_b_split_vec;
 public:
     cross_nested(std::function<bool (const nex*)> call_on_result,
                  std::function<bool (unsigned)> var_is_fixed):
@@ -56,7 +55,7 @@ public:
         m_e = e;
         
         vector<nex**> front;
-        explore_expr_on_front_elem(m_e, front);
+        explore_expr_on_front_elem(&m_e, front);
     }
 
     static nex** pop_front(vector<nex**>& front) {
@@ -80,20 +79,21 @@ public:
         add_children(r, es ...);
     }
 
-    nex_sum* mk_sum(const vector<nex*>& v) {
+    nex_sum* mk_sum(const ptr_vector<nex>& v) {
         auto r = new nex_sum();
         m_allocated.push_back(r);
         r->children() = v;
         return r;
     }
 
-    nex_mul* mk_mul(const vector<nex*>& v) {
+    nex_mul* mk_mul(const ptr_vector<nex>& v) {
         auto r = new nex_mul();
         m_allocated.push_back(r);
         r->children() = v;
         return r;
     }
 
+    
     template <typename K, typename...Args>
     nex_sum* mk_sum(K e, Args... es) {
         auto r = new nex_sum();
@@ -134,27 +134,27 @@ public:
         SASSERT((a->is_mul() && a->contains(j)) || (a->is_var() && to_var(a)->var() == j));
         if (a->is_var())
             return mk_scalar(rational(1));
-        m_b_vec.clear();
+        ptr_vector<nex> bv;
         bool seenj = false;
         for (nex* c : to_mul(a)->children()) {
             if (!seenj) {
                 if (c->contains(j)) {
                     if (!c->is_var())                     
-                        m_b_vec.push_back(mk_div(c, j));
+                        bv.push_back(mk_div(c, j));
                     seenj = true;
                     continue;
                 } 
             }
-            m_b_vec.push_back(c);
+            bv.push_back(c);
         }
-        if (m_b_vec.size() > 1) { 
-            return mk_mul(m_b_vec);
+        if (bv.size() > 1) { 
+            return mk_mul(bv);
         }
-        if (m_b_vec.size() == 1) {
-            return m_b_vec[0];
+        if (bv.size() == 1) {
+            return bv[0];
         }
 
-        SASSERT(m_b_vec.size() == 0);
+        SASSERT(bv.size() == 0);
         return mk_scalar(rational(1));
     }
 
@@ -255,20 +255,20 @@ public:
         return false;
     }
 
-    bool proceed_with_common_factor(nex*& c, vector<nex**>& front, const vector<std::pair<lpvar, occ>> & occurences) {
-        TRACE("nla_cn", tout << "c=" << *c << "\n";);
-        nex* f = extract_common_factor(c, occurences);
+    bool proceed_with_common_factor(nex** c, vector<nex**>& front, const vector<std::pair<lpvar, occ>> & occurences) {
+        TRACE("nla_cn", tout << "c=" << **c << "\n";);
+        nex* f = extract_common_factor(*c, occurences);
         if (f == nullptr) {
             TRACE("nla_cn", tout << "no common factor\n"; );
             return false;
         }
         
-        nex* c_over_f = mk_div(c, f);
+        nex* c_over_f = mk_div(*c, f);
         to_sum(c_over_f)->simplify();
-        c = mk_mul(f, c_over_f);
-        TRACE("nla_cn", tout << "common factor=" << *f << ", c=" << *c << "\ne = " << *m_e << "\n";);
+        *c = mk_mul(f, c_over_f);
+        TRACE("nla_cn", tout << "common factor=" << *f << ", c=" << **c << "\ne = " << *m_e << "\n";);
         
-        explore_expr_on_front_elem(c_over_f, front);
+        explore_expr_on_front_elem(&(*((*c)->children_ptr()))[1],  front);
         return true;
     }
 
@@ -290,11 +290,11 @@ public:
             *(front[i]) = copy[i];
     }
     
-    void explore_expr_on_front_elem_occs(nex* &c, vector<nex**>& front, const vector<std::pair<lpvar, occ>> & occurences) {
+    void explore_expr_on_front_elem_occs(nex** c, vector<nex**>& front, const vector<std::pair<lpvar, occ>> & occurences) {
         if (proceed_with_common_factor(c, front, occurences))
             return;
         TRACE("nla_cn", tout << "save c=" << *c << "; front:"; print_front(front, tout) << "\n";);           
-        nex* copy_of_c = c;
+        nex* copy_of_c = *c;
         auto copy_of_front = copy_front(front);
         for(auto& p : occurences) {
             SASSERT(p.second.m_occs > 1);
@@ -308,11 +308,11 @@ public:
             explore_of_expr_on_sum_and_var(c, j, front);
             if (m_done)
                 return;
-            TRACE("nla_cn", tout << "before restore c=" << *c << ", m_e=" << *m_e << "\n";);
-            c = copy_of_c;
-            TRACE("nla_cn", tout << "after restore c=" << *c << ", m_e=" << *m_e << "\n";);
+            TRACE("nla_cn", tout << "before restore c=" << **c << ", m_e=" << *m_e << "\n";);
+            *c = copy_of_c;
+            TRACE("nla_cn", tout << "after restore c=" << **c << ", m_e=" << *m_e << "\n";);
             restore_front(copy_of_front, front);
-            TRACE("nla_cn", tout << "restore c=" << *c << "\n";);
+            TRACE("nla_cn", tout << "restore c=" << **c << "\n";);
             TRACE("nla_cn", tout << "m_e=" << *m_e << "\n";);   
         }
     }
@@ -328,9 +328,9 @@ public:
         return out;
     }
 
-    void explore_expr_on_front_elem(nex*& c, vector<nex**>& front) {
-        auto occurences = get_mult_occurences(to_sum(c));
-        TRACE("nla_cn", tout << "m_e=" << *m_e << "\nc=" << *c << ", c occurences=";
+    void explore_expr_on_front_elem(nex** c, vector<nex**>& front) {
+        auto occurences = get_mult_occurences(to_sum(*c));
+        TRACE("nla_cn", tout << "m_e=" << *m_e << "\nc=" << **c << ", c occurences=";
               dump_occurences(tout, occurences) << "; front:"; print_front(front, tout) << "\n";);
     
         if (occurences.empty()) {
@@ -338,7 +338,7 @@ public:
                 TRACE("nla_cn", tout << "got the cn form: =" << *m_e << "\n";);
                 m_done = m_call_on_result(m_e);
             } else {
-                nex* f = *pop_front(front);
+                nex** f = pop_front(front);
                 explore_expr_on_front_elem(f, front);     
             }
         } else {
@@ -360,14 +360,14 @@ public:
     }
     // c is the sub expressiond which is going to be changed from sum to the cross nested form
     // front will be explored more
-    void explore_of_expr_on_sum_and_var(nex*& c, lpvar j, vector<nex**> front) {
-        TRACE("nla_cn", tout << "m_e=" << *m_e << "\nc=" << *c << "\nj = " << ch(j) << "\nfront="; print_front(front, tout) << "\n";);
-        if (!split_with_var(c, j, front))
+    void explore_of_expr_on_sum_and_var(nex** c, lpvar j, vector<nex**> front) {
+        TRACE("nla_cn", tout << "m_e=" << *m_e << "\nc=" << **c << "\nj = " << ch(j) << "\nfront="; print_front(front, tout) << "\n";);
+        if (!split_with_var(*c, j, front))
             return;
-        TRACE("nla_cn", tout << "after split c=" << *c << "\nfront="; print_front(front, tout) << "\n";);
+        TRACE("nla_cn", tout << "after split c=" << **c << "\nfront="; print_front(front, tout) << "\n";);
         SASSERT(front.size());
         auto n = pop_front(front);
-        explore_expr_on_front_elem(*n, front);
+        explore_expr_on_front_elem(n, front);
     }
 
     void add_var_occs(lpvar j) {
diff --git a/src/math/lp/nla_expr.h b/src/math/lp/nla_expr.h
index 0d6319eff..dba3cef5c 100644
--- a/src/math/lp/nla_expr.h
+++ b/src/math/lp/nla_expr.h
@@ -71,7 +71,11 @@ public:
     virtual bool contains(lpvar j) const { return false; }
     virtual int get_degree() const = 0;
     virtual void simplify() {}
-    virtual const vector<nex*> * children_ptr() const {
+    virtual const ptr_vector<nex> * children_ptr() const {
+        UNREACHABLE();
+        return nullptr;
+    }
+    virtual ptr_vector<nex> * children_ptr() {
         UNREACHABLE();
         return nullptr;
     }
@@ -112,8 +116,8 @@ public:
 
 };
 
-static void promote_children_by_type(vector<nex*> * children, expr_type t) {
-    svector<nex*> to_promote;
+static void promote_children_by_type(ptr_vector<nex> * children, expr_type t) {
+    ptr_vector<nex> to_promote;
     for(unsigned j = 0; j < children->size(); j++) {
         nex* e = (*children)[j];
         e->simplify();
@@ -125,24 +129,25 @@ static void promote_children_by_type(vector<nex*> * children, expr_type t) {
                 (*children)[j - offset] = e;
             }
         }
-        for (nex *e : to_promote) {
-            for (nex *ee : *(e->children_ptr())) {
-                children->push_back(ee);
-            }
-        }
     }
-        
+    
+    for (nex *e : to_promote) {
+        for (nex *ee : *(e->children_ptr())) {
+            children->push_back(ee);
+        }
+    }    
 }
 
 class nex_mul : public nex {
-    vector<nex*> m_children;
+    ptr_vector<nex> m_children;
 public:
     nex_mul()  {}
     unsigned size() const { return m_children.size(); }
     expr_type type() const { return expr_type::MUL; }
-    vector<nex*>& children() { return m_children;}
-    const vector<nex*>& children() const { return m_children;}
-    const vector<nex*>* children_ptr() const { return &m_children;}
+    ptr_vector<nex>& children() { return m_children;}
+    const ptr_vector<nex>& children() const { return m_children;}
+    const ptr_vector<nex>* children_ptr() const { return &m_children;}
+    ptr_vector<nex>* children_ptr() { return &m_children;}
     
     std::ostream & print(std::ostream& out) const {
         bool first = true;
@@ -217,13 +222,14 @@ public:
 
 
 class nex_sum : public nex {
-    vector<nex*> m_children;
+    ptr_vector<nex> m_children;
 public:
     nex_sum()  {}
     expr_type type() const { return expr_type::SUM; }
-    vector<nex*>& children() { return m_children;}
-    const vector<nex*>& children() const { return m_children;}    
-    const vector<nex*>* children_ptr() const { return &m_children;}
+    ptr_vector<nex>& children() { return m_children;}
+    const ptr_vector<nex>& children() const { return m_children;}    
+    const ptr_vector<nex>* children_ptr() const { return &m_children;}
+    ptr_vector<nex>* children_ptr() { return &m_children;}
     unsigned size() const { return m_children.size(); }
 
     // we need a linear combination of at least two variables
diff --git a/src/test/lp/lp.cpp b/src/test/lp/lp.cpp
index e72f4f9a7..d73c59a04 100644
--- a/src/test/lp/lp.cpp
+++ b/src/test/lp/lp.cpp
@@ -101,8 +101,9 @@ void test_cn() {
     nex* eae = cn.mk_mul(e, a, e);
     nex* eac = cn.mk_mul(e, a, c);
     nex* ed = cn.mk_mul(e, d);
-    
-    test_cn_on_expr(cn.mk_sum(aad,  abcd, aaccd, add, eae, eac, ed), cn);
+    nex* _6aad = cn.mk_mul(cn.mk_scalar(rational(6)), a, a, d); 
+    //    test_cn_on_expr(cn.mk_sum(aad,  abcd, aaccd, add, eae, eac, ed), cn);
+    test_cn_on_expr(cn.mk_sum(_6aad, abcd, aaccd, add, eae, eac, ed), cn);
     // TRACE("nla_cn", tout << "done\n";);
     // test_cn_on_expr(a*b*d + a*b*c + c*b*d + a*c*d);
     // TRACE("nla_cn", tout << "done\n";);