From a7449494a9be1cc8cfcaa296c3aa853138d00ff1 Mon Sep 17 00:00:00 2001
From: Lev Nachmanson <levnach@hotmail.com>
Date: Fri, 16 Aug 2019 18:10:12 -0700
Subject: [PATCH] rewrite horner scheme on top of nex_expr as a pointer

Signed-off-by: Lev Nachmanson <levnach@hotmail.com>
---
 src/math/lp/cross_nested.h | 188 +++++++++++++++++++++++--------------
 src/math/lp/nla_expr.h     |  36 ++++++-
 src/test/lp/lp.cpp         |  13 ++-
 3 files changed, 162 insertions(+), 75 deletions(-)

diff --git a/src/math/lp/cross_nested.h b/src/math/lp/cross_nested.h
index cda1c98cb..50f70b6f4 100644
--- a/src/math/lp/cross_nested.h
+++ b/src/math/lp/cross_nested.h
@@ -35,7 +35,7 @@ class cross_nested {
     };
 
     // fields
-    nex_sum *                             m_e;
+    nex *                                 m_e;
     std::function<bool (const nex*)>      m_call_on_result;
     std::function<bool (unsigned)>        m_var_is_fixed;
     bool                                  m_done;
@@ -43,6 +43,7 @@ class cross_nested {
     std::unordered_map<lpvar, unsigned>   m_powers;
     vector<nex*> m_allocated;
     vector<nex*> m_b_vec;
+    vector<nex*> m_b_split_vec;
 public:
     cross_nested(std::function<bool (const nex*)> call_on_result,
                  std::function<bool (unsigned)> var_is_fixed):
@@ -51,16 +52,16 @@ public:
         m_done(false)
     {}
 
-    void run(nex_sum *e) {
+    void run(nex *e) {
         m_e = e;
         
-        vector<nex_sum*> front;
+        vector<nex**> front;
         explore_expr_on_front_elem(m_e, front);
     }
 
-    static nex_sum* pop_back(vector<nex_sum*>& front) {
-        nex_sum* c = front.back();
-        TRACE("nla_cn", tout << *c << "\n";);
+    static nex** pop_front(vector<nex**>& front) {
+        nex** c = front.back();
+        TRACE("nla_cn", tout <<  **c << "\n";);
         front.pop_back();
         return c;
     }
@@ -70,6 +71,14 @@ public:
         m_allocated.push_back(r);
         return r;
     }
+    template <typename T>
+    void add_children(T) { }
+    
+    template <typename T, typename K, typename ...Args>
+    void add_children(T r, K e, Args ...  es) {
+        r->add_child(e);
+        add_children(r, es ...);
+    }
 
     nex_sum* mk_sum(const vector<nex*>& v) {
         auto r = new nex_sum();
@@ -78,40 +87,41 @@ public:
         return r;
     }
 
-    nex_sum* mk_sum(nex *a, nex* b) {
-        auto r = new nex_sum();
+    nex_mul* mk_mul(const vector<nex*>& v) {
+        auto r = new nex_mul();
         m_allocated.push_back(r);
-        r->children().push_back(a);
-        r->children().push_back(b);
+        r->children() = v;
         return r;
     }
 
+    template <typename K, typename...Args>
+    nex_sum* mk_sum(K e, Args... es) {
+        auto r = new nex_sum();
+        m_allocated.push_back(r);
+        r->add_child(e);
+        add_children(r, es...);
+        return r;
+    }
     nex_var* mk_var(lpvar j) {
         auto r = new nex_var(j);
         m_allocated.push_back(r);
         return r;
     }
-
+    
     nex_mul* mk_mul() {
         auto r = new nex_mul();
         m_allocated.push_back(r);
         return r;
     }
 
-    nex_mul* mk_mul(nex * a, nex * b) {
+    template <typename K, typename...Args>
+    nex_mul* mk_mul(K e, Args... es) {
         auto r = new nex_mul();
         m_allocated.push_back(r);
-        r->add_child(a); r->add_child(b);
+        add_children(r, e, es...);
         return r;
     }
-
-    nex_mul* mk_mul(nex * a, nex * b, nex *c) {
-        auto r = new nex_mul();
-        m_allocated.push_back(r);
-        r->add_child(a); r->add_child(b); r->add_child(c);
-        return r;
-    }
-
+    
     nex_scalar* mk_scalar(const rational& v) {
         auto r = new nex_scalar(v);
         m_allocated.push_back(r);
@@ -120,8 +130,32 @@ public:
 
 
     nex * mk_div(const nex* a, lpvar j) {
-        SASSERT(false);
-        return nullptr;
+        TRACE("nla_cn_details", tout << "a=" << *a << ", v" << j << "\n";);
+        SASSERT((a->is_mul() && a->contains(j)) || (a->is_var() && to_var(a)->var() == j));
+        if (a->is_var())
+            return mk_scalar(rational(1));
+        m_b_vec.clear();
+        bool seenj = false;
+        for (nex* c : to_mul(a)->children()) {
+            if (!seenj) {
+                if (c->contains(j)) {
+                    if (!c->is_var())                     
+                        m_b_vec.push_back(mk_div(c, j));
+                    seenj = true;
+                    continue;
+                } 
+            }
+            m_b_vec.push_back(c);
+        }
+        if (m_b_vec.size() > 1) { 
+            return mk_mul(m_b_vec);
+        }
+        if (m_b_vec.size() == 1) {
+            return m_b_vec[0];
+        }
+
+        SASSERT(m_b_vec.size() == 0);
+        return mk_scalar(rational(1));
     }
 
     nex * mk_div(const nex* a, const nex* b) {
@@ -218,14 +252,14 @@ public:
         return false;
     }
 
-    bool proceed_with_common_factor(nex*& c, vector<nex_sum*>& front, const vector<std::pair<lpvar, occ>> & occurences) {
+    bool proceed_with_common_factor(nex*& c, vector<nex**>& front, const vector<std::pair<lpvar, occ>> & occurences) {
         TRACE("nla_cn", tout << "c=" << *c << "\n";);
         nex* f = extract_common_factor(c, occurences);
         if (f == nullptr)
             return false;
         
-        nex_sum* c_over_f = to_sum(mk_div(c, f));
-        c_over_f->simplify();
+        nex* c_over_f = mk_div(c, f);
+        to_sum(c_over_f)->simplify();
         c = mk_mul(f, c_over_f);
         TRACE("nla_cn", tout << "common factor=" << *f << ", c=" << *c << "\ne = " << *m_e << "\n";);
         
@@ -233,28 +267,28 @@ public:
         return true;
     }
 
-    static void push(vector<nex_sum*>& front, nex_sum* e) {
-        TRACE("nla_cn", tout << *e << "\n";);
+    static void push(vector<nex**>& front, nex** e) {
+        TRACE("nla_cn", tout << **e << "\n";);
         front.push_back(e);
     }
     
-    static vector<nex_sum*> copy_front(const vector<nex_sum*>& front) {
-        vector<nex_sum*> v;
-        for (nex_sum* n: front)
-            v.push_back(n);
+    static vector<nex*> copy_front(const vector<nex**>& front) {
+        vector<nex*> v;
+        for (nex** n: front)
+            v.push_back(*n);
         return v;
     }
 
-    static void restore_front(const vector<nex_sum*> &copy, vector<nex_sum*>& front) {
+    static void restore_front(const vector<nex*> &copy, vector<nex**>& front) {
         SASSERT(copy.size() == front.size());
         for (unsigned i = 0; i < front.size(); i++)
-            front[i] = copy[i];
+            *(front[i]) = copy[i];
     }
     
-    void explore_expr_on_front_elem_occs(nex* c, vector<nex_sum*>& front, const vector<std::pair<lpvar, occ>> & occurences) {
+    void explore_expr_on_front_elem_occs(nex* &c, vector<nex**>& front, const vector<std::pair<lpvar, occ>> & occurences) {
         if (proceed_with_common_factor(c, front, occurences))
             return;
-        TRACE("nla_cn", tout << "save c=" << *c << "; front:"; print_vector_of_ptrs(front, tout) << "\n";);           
+        TRACE("nla_cn", tout << "save c=" << *c << "; front:"; print_front(front, tout) << "\n";);           
         nex* copy_of_c = c;
         auto copy_of_front = copy_front(front);
         for(auto& p : occurences) {
@@ -269,11 +303,12 @@ public:
             explore_of_expr_on_sum_and_var(c, j, front);
             if (m_done)
                 return;
+            TRACE("nla_cn", tout << "before restore c=" << *c << ", m_e=" << *m_e << "\n";);
             c = copy_of_c;
-            TRACE("nla_cn", tout << "restore c=" << *c << ", m_e=" << m_e << "\n";);
+            TRACE("nla_cn", tout << "after restore c=" << *c << ", m_e=" << *m_e << "\n";);
             restore_front(copy_of_front, front);
             TRACE("nla_cn", tout << "restore c=" << *c << "\n";);
-            TRACE("nla_cn", tout << "m_e=" << m_e << "\n";);   
+            TRACE("nla_cn", tout << "m_e=" << *m_e << "\n";);   
         }
     }
 
@@ -288,18 +323,18 @@ public:
         return out;
     }
 
-    void explore_expr_on_front_elem(nex_sum* c, vector<nex_sum*>& front) {
-        auto occurences = get_mult_occurences(c);
-        TRACE("nla_cn", tout << "m_e=" << m_e << "\nc=" << *c << ", c occurences=";
-              dump_occurences(tout, occurences) << "; front:"; print_vector_of_ptrs(front, tout) << "\n";);
+    void explore_expr_on_front_elem(nex*& c, vector<nex**>& front) {
+        auto occurences = get_mult_occurences(to_sum(c));
+        TRACE("nla_cn", tout << "m_e=" << *m_e << "\nc=" << *c << ", c occurences=";
+              dump_occurences(tout, occurences) << "; front:"; print_front(front, tout) << "\n";);
     
         if (occurences.empty()) {
             if(front.empty()) {
-                TRACE("nla_cn", tout << "got the cn form: =" << m_e << "\n";);
+                TRACE("nla_cn", tout << "got the cn form: =" << *m_e << "\n";);
                 m_done = m_call_on_result(m_e);
             } else {
-                auto c = pop_back(front);
-                explore_expr_on_front_elem(c, front);     
+                nex* f = *pop_front(front);
+                explore_expr_on_front_elem(f, front);     
             }
         } else {
             explore_expr_on_front_elem_occs(c, front, occurences);
@@ -311,15 +346,23 @@ public:
         return s.str();
         //        return (char)('a'+j);
     }
-    // e is the global expression, c is the sub expressiond which is going to changed from sum to the cross nested form
-    void explore_of_expr_on_sum_and_var(nex* & c, lpvar j, vector<nex_sum*> front) {
-        TRACE("nla_cn", tout << "m_e=" << m_e << "\nc=" << *c << "\nj = " << ch(j) << "\nfront="; print_vector_of_ptrs(front, tout) << "\n";);
+
+    std::ostream& print_front(const vector<nex**>& front, std::ostream& out) const {
+        for (auto e : front) {
+            out << **e << "\n";
+        }
+        return out;
+    }
+    // c is the sub expressiond which is going to be changed from sum to the cross nested form
+    // front will be explored more
+    void explore_of_expr_on_sum_and_var(nex*& c, lpvar j, vector<nex**> front) {
+        TRACE("nla_cn", tout << "m_e=" << *m_e << "\nc=" << *c << "\nj = " << ch(j) << "\nfront="; print_front(front, tout) << "\n";);
         if (!split_with_var(c, j, front))
             return;
-        TRACE("nla_cn", tout << "after split c=" << *c << "\nfront="; print_vector_of_ptrs(front, tout) << "\n";);
+        TRACE("nla_cn", tout << "after split c=" << *c << "\nfront="; print_front(front, tout) << "\n";);
         SASSERT(front.size());
-        auto n = pop_back(front); TRACE("nla_cn", tout << "n=" << *n <<"\n";);
-        explore_expr_on_front_elem(n, front);
+        auto n = pop_front(front);
+        explore_expr_on_front_elem(*n, front);
     }
 
     void add_var_occs(lpvar j) {
@@ -378,7 +421,7 @@ public:
             }
         }
         remove_singular_occurences();
-        TRACE("nla_cn_details", tout << "e=" << e << "\noccs="; dump_occurences(tout, m_occurences_map) << "\n";);
+        TRACE("nla_cn_details", tout << "e=" << *e << "\noccs="; dump_occurences(tout, m_occurences_map) << "\n";);
         vector<std::pair<lpvar, occ>> ret;
         for (auto & p : m_occurences_map)
             ret.push_back(p);
@@ -405,54 +448,57 @@ public:
     void pre_split(nex_sum * e, lpvar j, nex_sum* & a, nex* & b) {
         
         a = mk_sum();
-        m_b_vec.clear();
+        m_b_split_vec.clear();
         for (nex * ce: e->children()) {
             if (is_divisible_by_var(ce, j)) {
                 a->add_child(mk_div(ce , j));
             } else {
-                m_b_vec.push_back(ce);
+                m_b_split_vec.push_back(ce);
+                TRACE("nla_cn_details", tout << "ce = " << *ce << "\n";);
+                
             }        
         }
         TRACE("nla_cn_details", tout << "a = " << *a << "\n";);
-        SASSERT(a->children().size() >= 2 && m_b_vec.size());
+        SASSERT(a->children().size() >= 2 && m_b_split_vec.size());
         a->simplify();
         
-        if (m_b_vec.size() == 1) {
-            b = m_b_vec[0];      
+        if (m_b_split_vec.size() == 1) {
+            b = m_b_split_vec[0];
+            TRACE("nla_cn_details", tout << "b = " << *b << "\n";);
         } else {
-            SASSERT(m_b_vec.size() > 1);
-            b = mk_sum(m_b_vec);        
-        } 
+            SASSERT(m_b_split_vec.size() > 1);
+            b = mk_sum(m_b_split_vec);
+            TRACE("nla_cn_details", tout << "b = " << *b << "\n";);
+        }
     }
 
-    void update_front_with_split_with_non_empty_b(nex* &e, lpvar j, vector<nex_sum*> & front, nex_sum* a, nex* b) {
+    void update_front_with_split_with_non_empty_b(nex* &e, lpvar j, vector<nex**> & front, nex* a, nex* b) {
 
         SASSERT(a->is_sum());
         
-        TRACE("nla_cn_details", tout << "b = " << b << "\n";);
+        TRACE("nla_cn_details", tout << "b = " << *b << "\n";);
         e = mk_sum(mk_mul(mk_var(j), a), b); // e = j*a + b
-        push(front, a); // pushing 'a'
-        TRACE("nla_cn", tout << "push to front " << *a << "\n";);
+        nex **ptr_to_a = &(to_mul(to_sum(e)->children()[0]))->children()[1];
+        push(front, ptr_to_a);
         
         if (b->is_sum()) {
-            push(front, to_sum(b));
-            TRACE("nla_cn", tout << "push to front " << *b << "\n";);
+            nex **ptr_to_a = &(to_sum(e)->children()[1]);
+            push(front, ptr_to_a);
         }
     }
     
-   void update_front_with_split(nex* & e, lpvar j, vector<nex_sum*> & front, nex_sum* a, nex* b) {
+   void update_front_with_split(nex* & e, lpvar j, vector<nex**> & front, nex* a, nex* b) {
         if (b == nullptr) {
             e = mk_mul(mk_var(j), a);
-            push(front, a);
-            TRACE("nla_cn_details", tout << "push to front " << *a << "\n";);
+            push(front, &(to_mul(e)->children()[1]));
         } else {
             update_front_with_split_with_non_empty_b(e, j, front, a, b);
         }
     }
     // it returns true if the recursion brings a cross-nested form
-    bool split_with_var(nex*& e, lpvar j, vector<nex_sum*> & front) {
+    bool split_with_var(nex*& e, lpvar j, vector<nex**> & front) {
         SASSERT(e->is_sum());
-        TRACE("nla_cn", tout << "e = " << e << ", j=" << ch(j) << "\n";);
+        TRACE("nla_cn", tout << "e = " << *e << ", j=" << ch(j) << "\n";);
         nex_sum* a; nex * b;
         pre_split(to_sum(e), j, a, b);
         /*
diff --git a/src/math/lp/nla_expr.h b/src/math/lp/nla_expr.h
index f35a4fc9a..0d6319eff 100644
--- a/src/math/lp/nla_expr.h
+++ b/src/math/lp/nla_expr.h
@@ -70,6 +70,11 @@ public:
     virtual ~nex() {}
     virtual bool contains(lpvar j) const { return false; }
     virtual int get_degree() const = 0;
+    virtual void simplify() {}
+    virtual const vector<nex*> * children_ptr() const {
+        UNREACHABLE();
+        return nullptr;
+    }
 };
 std::ostream& operator<<(std::ostream& out, const nex&);
 
@@ -107,6 +112,28 @@ public:
 
 };
 
+static void promote_children_by_type(vector<nex*> * children, expr_type t) {
+    svector<nex*> to_promote;
+    for(unsigned j = 0; j < children->size(); j++) {
+        nex* e = (*children)[j];
+        e->simplify();
+        if (e->type() == t) {
+            to_promote.push_back(e);
+        } else {
+            unsigned offset = to_promote.size();
+            if (offset) {
+                (*children)[j - offset] = e;
+            }
+        }
+        for (nex *e : to_promote) {
+            for (nex *ee : *(e->children_ptr())) {
+                children->push_back(ee);
+            }
+        }
+    }
+        
+}
+
 class nex_mul : public nex {
     vector<nex*> m_children;
 public:
@@ -115,6 +142,8 @@ public:
     expr_type type() const { return expr_type::MUL; }
     vector<nex*>& children() { return m_children;}
     const vector<nex*>& children() const { return m_children;}
+    const vector<nex*>* children_ptr() const { return &m_children;}
+    
     std::ostream & print(std::ostream& out) const {
         bool first = true;
         for (const nex* v : m_children) {            
@@ -180,9 +209,13 @@ public:
         return degree;
     }
     
+    void simplify() {
+        promote_children_by_type(&m_children, expr_type::MUL);
+    }
 
 };
 
+
 class nex_sum : public nex {
     vector<nex*> m_children;
 public:
@@ -190,6 +223,7 @@ public:
     expr_type type() const { return expr_type::SUM; }
     vector<nex*>& children() { return m_children;}
     const vector<nex*>& children() const { return m_children;}    
+    const vector<nex*>* children_ptr() const { return &m_children;}
     unsigned size() const { return m_children.size(); }
 
     // we need a linear combination of at least two variables
@@ -233,7 +267,7 @@ public:
     }
 
     void simplify() {
-        SASSERT(false);
+        promote_children_by_type(&m_children, expr_type::SUM);
     }
     
     int get_degree() const {
diff --git a/src/test/lp/lp.cpp b/src/test/lp/lp.cpp
index a0ff48209..e72f4f9a7 100644
--- a/src/test/lp/lp.cpp
+++ b/src/test/lp/lp.cpp
@@ -86,7 +86,6 @@ void test_cn() {
     nex_var* c = cn.mk_var(2);
     nex_var* d = cn.mk_var(3);
     nex_var* e = cn.mk_var(4);
-    nex_var* f = cn.mk_var(5);
     nex_var* g = cn.mk_var(6);
     nex* min_1 = cn.mk_scalar(rational(-1));
     // test_cn_on_expr(min_1*c*e + min_1*b*d + min_1*a*b + a*c);
@@ -94,8 +93,16 @@ void test_cn() {
     nex_mul* bcg = cn.mk_mul(b, c, g);
     bcg->add_child(min_1);
     nex_sum* t = cn.mk_sum(bcd, bcg);
-    test_cn_on_expr(t, cn);
-    //    test_cn_on_expr(a*a*d + a*b*c*d + a*a*c*c*d + a*d*d + e*a*e + e*a*c + e*d);
+    //    test_cn_on_expr(t, cn);
+    nex* aad = cn.mk_mul(a, a, d);
+    nex* abcd = cn.mk_mul(a, b, c, d);
+    nex* aaccd = cn.mk_mul(a, a, c, c, d);
+    nex* add = cn.mk_mul(a, d, d);
+    nex* eae = cn.mk_mul(e, a, e);
+    nex* eac = cn.mk_mul(e, a, c);
+    nex* ed = cn.mk_mul(e, d);
+    
+    test_cn_on_expr(cn.mk_sum(aad,  abcd, aaccd, add, eae, eac, ed), cn);
     // TRACE("nla_cn", tout << "done\n";);
     // test_cn_on_expr(a*b*d + a*b*c + c*b*d + a*c*d);
     // TRACE("nla_cn", tout << "done\n";);