diff --git a/src/math/lp/cross_nested.h b/src/math/lp/cross_nested.h
index f1de18ff4..cda1c98cb 100644
--- a/src/math/lp/cross_nested.h
+++ b/src/math/lp/cross_nested.h
@@ -22,7 +22,6 @@
 #include "math/lp/nla_expr.h"
 namespace nla {
 class cross_nested {
-    typedef nla_expr<rational> nex;
     struct occ {
         unsigned m_occs;
         unsigned m_power;
@@ -36,61 +35,179 @@ class cross_nested {
     };
 
     // fields
-    nex& m_e;
-    std::function<bool (const nex&)>      m_call_on_result;
+    nex_sum *                             m_e;
+    std::function<bool (const nex*)>      m_call_on_result;
     std::function<bool (unsigned)>        m_var_is_fixed;
     bool                                  m_done;
     std::unordered_map<lpvar, occ>        m_occurences_map;
     std::unordered_map<lpvar, unsigned>   m_powers;
-
+    vector<nex*> m_allocated;
+    vector<nex*> m_b_vec;
 public:
-    cross_nested(nex &e,
-                 std::function<bool (const nex&)> call_on_result,
+    cross_nested(std::function<bool (const nex*)> call_on_result,
                  std::function<bool (unsigned)> var_is_fixed):
-        m_e(e),
         m_call_on_result(call_on_result),
         m_var_is_fixed(var_is_fixed),
         m_done(false)
     {}
 
-    void run() {
-        vector<nex*> front;
-        explore_expr_on_front_elem(&m_e, front); // true for trivial form - no change
+    void run(nex_sum *e) {
+        m_e = e;
+        
+        vector<nex_sum*> front;
+        explore_expr_on_front_elem(m_e, front);
     }
 
-    static nex* pop_back(vector<nex*>& front) {
-        nex* c = front.back();
+    static nex_sum* pop_back(vector<nex_sum*>& front) {
+        nex_sum* c = front.back();
         TRACE("nla_cn", tout << *c << "\n";);
         front.pop_back();
         return c;
     }
 
-    static bool extract_common_factor(nex* c, nex& f, const vector<std::pair<lpvar, occ>> & occurences) {
+    nex_sum* mk_sum() {
+        auto r = new nex_sum();
+        m_allocated.push_back(r);
+        return r;
+    }
+
+    nex_sum* mk_sum(const vector<nex*>& v) {
+        auto r = new nex_sum();
+        m_allocated.push_back(r);
+        r->children() = v;
+        return r;
+    }
+
+    nex_sum* mk_sum(nex *a, nex* b) {
+        auto r = new nex_sum();
+        m_allocated.push_back(r);
+        r->children().push_back(a);
+        r->children().push_back(b);
+        return r;
+    }
+
+    nex_var* mk_var(lpvar j) {
+        auto r = new nex_var(j);
+        m_allocated.push_back(r);
+        return r;
+    }
+
+    nex_mul* mk_mul() {
+        auto r = new nex_mul();
+        m_allocated.push_back(r);
+        return r;
+    }
+
+    nex_mul* mk_mul(nex * a, nex * b) {
+        auto r = new nex_mul();
+        m_allocated.push_back(r);
+        r->add_child(a); r->add_child(b);
+        return r;
+    }
+
+    nex_mul* mk_mul(nex * a, nex * b, nex *c) {
+        auto r = new nex_mul();
+        m_allocated.push_back(r);
+        r->add_child(a); r->add_child(b); r->add_child(c);
+        return r;
+    }
+
+    nex_scalar* mk_scalar(const rational& v) {
+        auto r = new nex_scalar(v);
+        m_allocated.push_back(r);
+        return r;
+    }
+
+
+    nex * mk_div(const nex* a, lpvar j) {
+        SASSERT(false);
+        return nullptr;
+    }
+
+    nex * mk_div(const nex* a, const nex* b) {
+        TRACE("nla_cn_details", tout << *a <<" / " << *b << "\n";);
+        if (b->is_var()) {
+            return mk_div(a, to_var(b)->var());
+        }
+        SASSERT(b->is_mul());
+        const nex_mul *bm = to_mul(b);
+        if (a->is_sum()) {
+            nex_sum * r = mk_sum();
+            const nex_sum * m = to_sum(a);
+            for (auto e : m->children()) {
+                r->add_child(mk_div(e, bm));
+            }
+            TRACE("nla_cn_details", tout << *r << "\n";);
+            return r;
+        }
+        if (a->is_var() || (a->is_mul() && to_mul(a)->children().size() == 1)) {
+            return mk_scalar(rational(1));
+        }
+        SASSERT(a->is_mul());
+        const nex_mul* am = to_mul(a);
+        bm->get_powers_from_mul(m_powers);
+        nex_mul* ret = new nex_mul();
+        for (auto e : am->children()) {
+            TRACE("nla_cn_details", tout << "e=" << *e << "\n";);
+            
+            if (!e->is_var()) {
+                SASSERT(e->is_scalar());
+                ret->add_child(e);
+                TRACE("nla_cn_details", tout << "continue\n";);
+                continue;
+            }
+            SASSERT(e->is_var());
+            lpvar j = to_var(e)->var();
+            auto it = m_powers.find(j);
+            if (it == m_powers.end()) {
+                 ret->add_child(e);
+            } else {
+                it->second --;
+                if (it->second == 0)
+                    m_powers.erase(it);
+            }
+            TRACE("nla_cn_details", tout << *ret << "\n";);            
+        }
+        SASSERT(m_powers.size() == 0);
+        if (ret->children().size() == 0) {
+            delete ret;
+            TRACE("nla_cn_details", tout << "return 1\n";);
+            return mk_scalar(rational(1));
+        }
+        m_allocated.push_back(ret);
+        TRACE("nla_cn_details", tout << *ret << "\n";);        
+        return ret;
+    }
+
+    nex* extract_common_factor(nex* e, const vector<std::pair<lpvar, occ>> & occurences) {
+        nex_sum* c = to_sum(e);
         TRACE("nla_cn", tout << "c=" << *c << "\n";);
-        SASSERT(c->is_sum());
-        f.type() = expr_type::MUL;
-        SASSERT(f.children().empty());
         unsigned size = c->children().size();
         for(const auto & p : occurences) {
+            if (p.second.m_occs < size) {
+                return nullptr;
+            }
+        }
+        nex_mul* f = mk_mul();
+        for(const auto & p : occurences) { // randomize here: todo
             if (p.second.m_occs == size) {
                 unsigned pow = p.second.m_power;
                 while (pow --) {
-                    f *= nex::var(p.first);
+                    f->add_child(mk_var(p.first));
                 }
             }
         }
-        return !f.children().empty();
+        return f;
     }
 
-    static bool has_common_factor(const nex& c) {
-        TRACE("nla_cn", tout << "c=" << c << "\n";);
-        SASSERT(c.is_sum());
-        auto & ch = c.children();
+    static bool has_common_factor(const nex_sum* c) {
+        TRACE("nla_cn", tout << "c=" << *c << "\n";);
+        auto & ch = c->children();
         auto common_vars = get_vars_of_expr(ch[0]);
         for (lpvar j : common_vars) {
             bool divides_the_rest = true;
             for(unsigned i = 1; i < ch.size() && divides_the_rest; i++) {
-                if (!ch[i].contains(j))
+                if (!ch[i]->contains(j))
                     divides_the_rest = false;
             }
             if (divides_the_rest) {
@@ -101,45 +218,45 @@ public:
         return false;
     }
 
-    bool proceed_with_common_factor(nex* c, vector<nex*>& front, const vector<std::pair<lpvar, occ>> & occurences) {
+    bool proceed_with_common_factor(nex*& c, vector<nex_sum*>& front, const vector<std::pair<lpvar, occ>> & occurences) {
         TRACE("nla_cn", tout << "c=" << *c << "\n";);
-        SASSERT(c->is_sum());
-        nex f;
-        if (!extract_common_factor(c, f, occurences))
+        nex* f = extract_common_factor(c, occurences);
+        if (f == nullptr)
             return false;
         
-        *c /= f;
-        f.simplify();
-        * c = nex::mul(f, *c);
-        TRACE("nla_cn", tout << "common factor=" << f << ", c=" << *c << "\n";);
-        explore_expr_on_front_elem(&(c->children()[1]), front);
+        nex_sum* c_over_f = to_sum(mk_div(c, f));
+        c_over_f->simplify();
+        c = mk_mul(f, c_over_f);
+        TRACE("nla_cn", tout << "common factor=" << *f << ", c=" << *c << "\ne = " << *m_e << "\n";);
+        
+        explore_expr_on_front_elem(c_over_f, front);
         return true;
     }
 
-    static void push(vector<nex*>& front, nex* e) {
+    static void push(vector<nex_sum*>& front, nex_sum* e) {
         TRACE("nla_cn", tout << *e << "\n";);
         front.push_back(e);
     }
     
-    static vector<nex> copy_front(const vector<nex*>& front) {
-       vector<nex> v;
-        for (nex* n: front)
-            v.push_back(*n);
+    static vector<nex_sum*> copy_front(const vector<nex_sum*>& front) {
+        vector<nex_sum*> v;
+        for (nex_sum* n: front)
+            v.push_back(n);
         return v;
     }
 
-    static void restore_front(const vector<nex> &copy, vector<nex*>& front) {
+    static void restore_front(const vector<nex_sum*> &copy, vector<nex_sum*>& front) {
         SASSERT(copy.size() == front.size());
         for (unsigned i = 0; i < front.size(); i++)
-            *(front[i]) = copy[i];
+            front[i] = copy[i];
     }
     
-    void explore_expr_on_front_elem_occs(nex* c, vector<nex*>& front, const vector<std::pair<lpvar, occ>> & occurences) {
+    void explore_expr_on_front_elem_occs(nex* c, vector<nex_sum*>& front, const vector<std::pair<lpvar, occ>> & occurences) {
         if (proceed_with_common_factor(c, front, occurences))
             return;
         TRACE("nla_cn", tout << "save c=" << *c << "; front:"; print_vector_of_ptrs(front, tout) << "\n";);           
-        nex copy_of_c = *c;
-        vector<nex> copy_of_front = copy_front(front);
+        nex* copy_of_c = c;
+        auto copy_of_front = copy_front(front);
         for(auto& p : occurences) {
             SASSERT(p.second.m_occs > 1);
             lpvar j = p.first;
@@ -152,7 +269,7 @@ public:
             explore_of_expr_on_sum_and_var(c, j, front);
             if (m_done)
                 return;
-            *c = copy_of_c;
+            c = copy_of_c;
             TRACE("nla_cn", tout << "restore c=" << *c << ", m_e=" << m_e << "\n";);
             restore_front(copy_of_front, front);
             TRACE("nla_cn", tout << "restore c=" << *c << "\n";);
@@ -171,9 +288,8 @@ public:
         return out;
     }
 
-    void explore_expr_on_front_elem(nex* c, vector<nex*>& front) {
-        SASSERT(c->is_sum());
-        auto occurences = get_mult_occurences(*c);
+    void explore_expr_on_front_elem(nex_sum* c, vector<nex_sum*>& front) {
+        auto occurences = get_mult_occurences(c);
         TRACE("nla_cn", tout << "m_e=" << m_e << "\nc=" << *c << ", c occurences=";
               dump_occurences(tout, occurences) << "; front:"; print_vector_of_ptrs(front, tout) << "\n";);
     
@@ -182,7 +298,7 @@ public:
                 TRACE("nla_cn", tout << "got the cn form: =" << m_e << "\n";);
                 m_done = m_call_on_result(m_e);
             } else {
-                nex* c = pop_back(front);
+                auto c = pop_back(front);
                 explore_expr_on_front_elem(c, front);     
             }
         } else {
@@ -196,17 +312,17 @@ public:
         //        return (char)('a'+j);
     }
     // e is the global expression, c is the sub expressiond which is going to changed from sum to the cross nested form
-    void explore_of_expr_on_sum_and_var(nex* c, lpvar j, vector<nex*> front) {
+    void explore_of_expr_on_sum_and_var(nex* & c, lpvar j, vector<nex_sum*> front) {
         TRACE("nla_cn", tout << "m_e=" << m_e << "\nc=" << *c << "\nj = " << ch(j) << "\nfront="; print_vector_of_ptrs(front, tout) << "\n";);
-        if (!split_with_var(*c, j, front))
+        if (!split_with_var(c, j, front))
             return;
         TRACE("nla_cn", tout << "after split c=" << *c << "\nfront="; print_vector_of_ptrs(front, tout) << "\n";);
         SASSERT(front.size());
-        nex* n = pop_back(front); TRACE("nla_cn", tout << "n=" << *n <<"\n";);
+        auto n = pop_back(front); TRACE("nla_cn", tout << "n=" << *n <<"\n";);
         explore_expr_on_front_elem(n, front);
     }
 
-    void process_var_occurences(lpvar j) {
+    void add_var_occs(lpvar j) {
         auto it = m_occurences_map.find(j);
         if (it != m_occurences_map.end()) {
             it->second.m_occs++;
@@ -251,15 +367,14 @@ public:
     
     // j -> the number of expressions j appears in as a multiplier
     // The result is sorted by large number of occurences first
-    vector<std::pair<lpvar, occ>> get_mult_occurences(const nex& e) {
+    vector<std::pair<lpvar, occ>> get_mult_occurences(const nex_sum* e) {
         clear_maps();
-        SASSERT(e.type() == expr_type::SUM);
-        for (const auto & ce : e.children()) {
-            if (ce.is_mul()) {
-                auto powers = ce.get_powers_from_mul();
+        for (const auto * ce : e->children()) {
+            if (ce->is_mul()) {
+                to_mul(ce)->get_powers_from_mul(m_powers);
                 update_occurences_with_powers();
-            } else if (ce.type() ==  expr_type::VAR) {
-                process_var_occurences(ce.var());
+            } else if (ce->is_var()) {
+                add_var_occs(to_var(ce)->var());
             }
         }
         remove_singular_occurences();
@@ -281,63 +396,65 @@ public:
                                           });
         return ret;
     }
+
+    static bool is_divisible_by_var(nex* ce, lpvar j) {
+        return (ce->is_mul() && to_mul(ce)->contains(j))
+            || (ce->is_var() && to_var(ce)->var() == j);
+    }
     // all factors of j go to a, the rest to b
-    static void pre_split(nex &e, lpvar j, nex &a, nex&b) {
-        for (const nex & ce: e.children()) {
-            if ((ce.is_mul() && ce.contains(j)) || (ce.is_var() && ce.var() == j)) {
-                a.add_child(ce / j);
+    void pre_split(nex_sum * e, lpvar j, nex_sum* & a, nex* & b) {
+        
+        a = mk_sum();
+        m_b_vec.clear();
+        for (nex * ce: e->children()) {
+            if (is_divisible_by_var(ce, j)) {
+                a->add_child(mk_div(ce , j));
             } else {
-                b.add_child(ce);
+                m_b_vec.push_back(ce);
             }        
         }
-        a.type() = expr_type::SUM;
-        TRACE("nla_cn_details", tout << "a = " << a << "\n";);
-        SASSERT(a.children().size() >= 2);
-        a.simplify();
+        TRACE("nla_cn_details", tout << "a = " << *a << "\n";);
+        SASSERT(a->children().size() >= 2 && m_b_vec.size());
+        a->simplify();
         
-        if (b.children().size() == 1) {
-            nex t = b.children()[0];
-            b = t;      
-        } else if (b.children().size() > 1) {
-            b.type() = expr_type::SUM;        
-        }
+        if (m_b_vec.size() == 1) {
+            b = m_b_vec[0];      
+        } else {
+            SASSERT(m_b_vec.size() > 1);
+            b = mk_sum(m_b_vec);        
+        } 
     }
 
-    // returns true if the recursion is done inside
-    void update_front_with_split_with_non_empty_b(nex& e, lpvar j, vector<nex*> & front, nex& a, nex& b) {
-        nex f;
-        SASSERT(a.is_sum());
+    void update_front_with_split_with_non_empty_b(nex* &e, lpvar j, vector<nex_sum*> & front, nex_sum* a, nex* b) {
+
+        SASSERT(a->is_sum());
         
         TRACE("nla_cn_details", tout << "b = " << b << "\n";);
-        e = nex::sum(nex::mul(nex::var(j), a), b);
-        push(front, &(e.children()[0].children()[1])); // pushing 'a'
-        TRACE("nla_cn", tout << "push to front " << e.children()[0].children()[1] << "\n";);
+        e = mk_sum(mk_mul(mk_var(j), a), b); // e = j*a + b
+        push(front, a); // pushing 'a'
+        TRACE("nla_cn", tout << "push to front " << *a << "\n";);
         
-        if (b.is_sum()) {
-            push(front, &(e.children()[1]));
-            TRACE("nla_cn", tout << "push to front " << e.children()[1] << "\n";);
+        if (b->is_sum()) {
+            push(front, to_sum(b));
+            TRACE("nla_cn", tout << "push to front " << *b << "\n";);
         }
     }
     
-   void update_front_with_split(nex& e, lpvar j, vector<nex*> & front, nex& a, nex& b) {
-        if (b.is_undef()) {
-            SASSERT(b.children().size() == 0);
-            e = nex(expr_type::MUL);        
-            e.add_child(nex::var(j));
-            e.add_child(a);
-            if (a.size() > 1) {
-                push(front, &e.children().back());
-                TRACE("nla_cn_details", tout << "push to front " << e.children().back() << "\n";);
-            }
+   void update_front_with_split(nex* & e, lpvar j, vector<nex_sum*> & front, nex_sum* a, nex* b) {
+        if (b == nullptr) {
+            e = mk_mul(mk_var(j), a);
+            push(front, a);
+            TRACE("nla_cn_details", tout << "push to front " << *a << "\n";);
+        } else {
+            update_front_with_split_with_non_empty_b(e, j, front, a, b);
         }
-        update_front_with_split_with_non_empty_b(e, j, front, a, b);
     }
     // it returns true if the recursion brings a cross-nested form
-    bool split_with_var(nex& e, lpvar j, vector<nex*> & front) {
+    bool split_with_var(nex*& e, lpvar j, vector<nex_sum*> & front) {
+        SASSERT(e->is_sum());
         TRACE("nla_cn", tout << "e = " << e << ", j=" << ch(j) << "\n";);
-        if (!e.is_sum()) return false;
-        nex a, b;
-        pre_split(e, j, a, b);
+        nex_sum* a; nex * b;
+        pre_split(to_sum(e), j, a, b);
         /*
           When we have e without a non-trivial common factor then
           there is a variable j such that e = jP + Q, where Q has all members
@@ -352,28 +469,42 @@ public:
         update_front_with_split(e, j, front, a, b);
         return true;
     }
-    static std::unordered_set<lpvar> get_vars_of_expr(const nex &e ) {
+
+    static std::unordered_set<lpvar> get_vars_of_expr(const nex *e ) {
         std::unordered_set<lpvar> r;
-        switch (e.type()) {
+        switch (e->type()) {
         case expr_type::SCALAR:
             return r;
         case expr_type::SUM:
+            {
+                for (auto c: to_sum(e)->children())
+                    for ( lpvar j : get_vars_of_expr(c))
+                        r.insert(j);
+            }
         case expr_type::MUL:
             {
-                for (const auto & c: e.children())
+                for (auto c: to_mul(e)->children())
                     for ( lpvar j : get_vars_of_expr(c))
                         r.insert(j);
             }
             return r;
         case expr_type::VAR:
-            r.insert(e.var());
+            r.insert(to_var(e)->var());
             return r;
         default:
-            TRACE("nla_cn_details", tout << e.type() << "\n";);
+            TRACE("nla_cn_details", tout << e->type() << "\n";);
             SASSERT(false);
             return r;
         }
     }
+    
+    ~cross_nested() {
+        for (auto e: m_allocated)
+            delete e;
+        m_allocated.clear();
+    }
 
+    bool done() const { return m_done; }
+    
 };
 }
diff --git a/src/math/lp/horner.cpp b/src/math/lp/horner.cpp
index 0d87a728b..5c842dd78 100644
--- a/src/math/lp/horner.cpp
+++ b/src/math/lp/horner.cpp
@@ -63,31 +63,31 @@ bool horner::row_is_interesting(const T& row) const {
     return false;
 }
 
-bool horner::lemmas_on_expr(nex& e) {
+bool horner::lemmas_on_expr(nex_sum* e, cross_nested& cn) {
     TRACE("nla_horner", tout << "e = " << e << "\n";);
-    bool conflict = false;
-    cross_nested cn(e, [this, & conflict](const nex& n) {
+    cn.run(e);
+    return cn.done();
+}
+
+
+template <typename T> 
+bool horner::lemmas_on_row(const T& row) {
+    cross_nested cn([this](const nex* n) {
                            TRACE("nla_horner", tout << "cross-nested n = " << n << "\n";);
                            auto i = interval_of_expr(n);
                            TRACE("nla_horner", tout << "callback n = " << n << "\ni="; m_intervals.display(tout, i) << "\n";);
                            
-                           conflict = m_intervals.check_interval_for_conflict_on_zero(i);
+                           bool conflict = m_intervals.check_interval_for_conflict_on_zero(i);
                            c().lp_settings().st().m_cross_nested_forms++;
                            m_intervals.reset(); // clean the memory allocated by the interval bound dependencies
                            return conflict;
                        },
         [this](unsigned j) {  return c().var_is_fixed(j); }
         );
-    cn.run();
-    return conflict;
-}
 
-
-template <typename T> 
-bool horner::lemmas_on_row(const T& row) {
     SASSERT (row_is_interesting(row));
-    nex e = create_sum_from_row(row);
-    return lemmas_on_expr(e);
+    nex_sum* e = create_sum_from_row(row, cn);
+    return lemmas_on_expr(e, cn);
 }
 
 void horner::horner_lemmas() {
@@ -120,27 +120,28 @@ void horner::horner_lemmas() {
     }
 }
 
-typedef nla_expr<rational> nex;
-
-nex horner::nexvar(lpvar j) const {
+nex * horner::nexvar(lpvar j, cross_nested& cn) const {
     // todo: consider deepen the recursion
     if (!c().is_monomial_var(j))
-        return nex::var(j);
+        return cn.mk_var(j);
     const monomial& m = c().emons()[j];
-    nex e(expr_type::MUL);
+    nex_mul * e = cn.mk_mul();
     for (lpvar k : m.vars()) {
-        e.add_child(nex::var(k));
+        e->add_child(cn.mk_var(k));
         CTRACE("nla_horner", c().is_monomial_var(k), c().print_var(k, tout) << "\n";);
     }
     return e;
 }
 
-template <typename T> nex horner::create_sum_from_row(const T& row) {
+template <typename T> nex_sum* horner::create_sum_from_row(const T& row, cross_nested& cn) {
     TRACE("nla_horner", tout << "row="; m_core->print_term(row, tout) << "\n";);
     SASSERT(row.size() > 1);
-    nex e(expr_type::SUM);
-    for (const auto &p : row) {        
-        e.add_child(nex::scalar(p.coeff())* nexvar(p.var()));
+    nex_sum *e = cn.mk_sum();
+    for (const auto &p : row) {
+        if (p.coeff().is_one())
+            e->add_child(nexvar(p.var(), cn));
+        else
+            e->add_child(cn.mk_mul(cn.mk_scalar(p.coeff()), nexvar(p.var(), cn)));
     }
     return e;
 }
@@ -155,28 +156,28 @@ void horner::set_interval_for_scalar(interv& a, const rational& v) {
     m_intervals.set_upper_is_inf(a, false);
 }
 
-interv horner::interval_of_expr(const nex& e) {
+interv horner::interval_of_expr(const nex* e) {
     interv a;
-    switch (e.type()) {
+    switch (e->type()) {
     case expr_type::SCALAR:
-        set_interval_for_scalar(a, e.value());
+        set_interval_for_scalar(a, to_scalar(e)->value());
         return a;
     case expr_type::SUM:
-        return interval_of_sum(e);
+        return interval_of_sum(to_sum(e));
     case expr_type::MUL:
-        return interval_of_mul(e);
+        return interval_of_mul(to_mul(e));
     case expr_type::VAR:
-        set_var_interval(e.var(), a);
+        set_var_interval(to_var(e)->var(), a);
         return a;
     default:
-        TRACE("nla_horner_details", tout << e.type() << "\n";);
+        TRACE("nla_horner_details", tout << e->type() << "\n";);
         SASSERT(false);
         return interv();
     }
 }
-interv horner::interval_of_mul(const nex& e) {
-    SASSERT(e.is_mul());
-    auto & es = e.children();
+interv horner::interval_of_mul(const nex_mul* e) {
+    SASSERT(e->is_mul());
+    auto & es = to_mul(e)->children();
     interv a = interval_of_expr(es[0]);
     if (m_intervals.is_zero(a)) {
         m_intervals.set_zero_interval_deps_for_mult(a);
@@ -208,25 +209,25 @@ interv horner::interval_of_mul(const nex& e) {
     return a;
 }
 
-void horner::add_mul_to_vector(const nex& e, vector<std::pair<rational, lpvar>> &v) {
+void horner::add_mul_to_vector(const nex_mul* e, vector<std::pair<rational, lpvar>> &v) {
     TRACE("nla_horner_details", tout << e << "\n";);
-    SASSERT(e.is_mul() && e.size() > 0);
-    if (e.size() == 1) {
-        add_linear_to_vector(*(e.children().begin()), v);
+    SASSERT(e->size() > 0);
+    if (e->size() == 1) {
+        add_linear_to_vector(*(e->children().begin()), v);
         return;
     }
     rational r;
     lpvar j = -1;
-    for (const nex & c : e.children()) {
-        switch (c.type()) {
+    for (const nex * c : e->children()) {
+        switch (c->type()) {
         case expr_type::SCALAR:
-            r = c.value();
+            r = to_scalar(c)->value();
             break;
         case expr_type::VAR:
-            j = c.var();
+            j = to_var(c)->var();
             break;
         default:
-            TRACE("nla_horner_details", tout << e.type() << "\n";);
+            TRACE("nla_horner_details", tout << e->type() << "\n";);
             SASSERT(false);
         }
     }
@@ -234,30 +235,30 @@ void horner::add_mul_to_vector(const nex& e, vector<std::pair<rational, lpvar>>
     v.push_back(std::make_pair(r, j));
 }
 
-void horner::add_linear_to_vector(const nex& e, vector<std::pair<rational, lpvar>> &v) {
+void horner::add_linear_to_vector(const nex* e, vector<std::pair<rational, lpvar>> &v) {
     TRACE("nla_horner_details", tout << e << "\n";);
-    switch (e.type()) {
+    switch (e->type()) {
     case expr_type::MUL:
-        add_mul_to_vector(e, v);
+        add_mul_to_vector(to_mul(e), v);
         break; 
     case expr_type::VAR:
-        v.push_back(std::make_pair(rational(1), e.var()));
+        v.push_back(std::make_pair(rational(1), to_var(e)->var()));
         break;
     default:
-        SASSERT(!e.is_sum());
+        SASSERT(!e->is_sum());
         // noop
     }
 }
 // e = a * can_t + b
-lp::lar_term horner::expression_to_normalized_term(nex& e, rational& a, rational& b) {
+lp::lar_term horner::expression_to_normalized_term(const nex_sum* e, rational& a, rational& b) {
     TRACE("nla_horner_details", tout << e << "\n";);
     lpvar smallest_j;
     vector<std::pair<rational, lpvar>> v;
     b = rational(0);
     unsigned a_index;
-    for (const nex& c : e.children()) {
-        if (c.is_scalar()) {
-            b += c.value();
+    for (const nex* c : e->children()) {
+        if (c->is_scalar()) {
+            b += to_scalar(c)->value();
         } else {
             add_linear_to_vector(c, v);
             if (v.empty())
@@ -295,9 +296,10 @@ lp::lar_term horner::expression_to_normalized_term(nex& e, rational& a, rational
 
 // we should have in the case of found a*m_terms[k] + b = e,
 // where m_terms[k] corresponds to the returned lpvar
-lpvar horner::find_term_column(const nex& e, rational& a, rational& b) const {
-    nex n = e;
-    lp::lar_term norm_t = expression_to_normalized_term(n, a, b);
+lpvar horner::find_term_column(const nex* e, rational& a, rational& b) const {
+    if (!e->is_sum())
+        return -1;
+    lp::lar_term norm_t = expression_to_normalized_term(to_sum(e), a, b);
     std::pair<rational, lpvar> a_j;
     if (c().m_lar_solver.fetch_normalized_term_column(norm_t, a_j)) {
         a /= a_j.first;
@@ -306,8 +308,8 @@ lpvar horner::find_term_column(const nex& e, rational& a, rational& b) const {
     return -1;
 }
 
-interv horner::interval_of_sum_no_terms(const nex& e) {
-    auto & es = e.children(); 
+interv horner::interval_of_sum_no_terms(const nex_sum* e) {
+    auto & es = e->children(); 
     interv a = interval_of_expr(es[0]);
     if (m_intervals.is_inf(a)) {
         TRACE("nla_horner_details",  tout << "e=" << e << "\n";
@@ -340,10 +342,9 @@ interv horner::interval_of_sum_no_terms(const nex& e) {
     return a;
 }
 
-bool horner::interval_from_term(const nex& e, interv & i) const {
+bool horner::interval_from_term(const nex* e, interv & i) const {
     rational a, b;
-    nex n = e;
-    lpvar j = find_term_column(n, a, b);
+    lpvar j = find_term_column(e, a, b);
     if (j + 1 == 0)
         return false;
 
@@ -361,11 +362,10 @@ bool horner::interval_from_term(const nex& e, interv & i) const {
 }
 
 
-interv horner::interval_of_sum(const nex& e) {
+interv horner::interval_of_sum(const nex_sum* e) {
     TRACE("nla_horner_details", tout << "e=" << e << "\n";);
-    SASSERT(e.is_sum());
     interv i_e = interval_of_sum_no_terms(e);
-    if (e.sum_is_a_linear_term()) {
+    if (e->is_a_linear_term()) {
         interv i_from_term ;
         if (interval_from_term(e, i_from_term)) {
             interv r = m_intervals.intersect(i_e, i_from_term);
diff --git a/src/math/lp/horner.h b/src/math/lp/horner.h
index f5944b24c..cd9d60c20 100644
--- a/src/math/lp/horner.h
+++ b/src/math/lp/horner.h
@@ -22,6 +22,7 @@
 #include "math/lp/nla_common.h"
 #include "math/lp/nla_intervals.h"
 #include "math/lp/nla_expr.h"
+#include "math/lp/cross_nested.h"
 
 namespace nla {
 class core;
@@ -30,31 +31,30 @@ class core;
 class horner : common {
     intervals m_intervals;
 public:
-    typedef nla_expr<rational> nex;
     typedef intervals::interval interv;
     horner(core *core);
     void horner_lemmas();
     template <typename T> // T has an iterator of (coeff(), var())
     bool lemmas_on_row(const T&);
     template <typename T>  bool row_is_interesting(const T&) const;
-    template <typename T> nex create_sum_from_row(const T&);
-    intervals::interval interval_of_expr(const nex& e);
+    template <typename T>
+    nex_sum* create_sum_from_row(const T&, cross_nested&);
+    intervals::interval interval_of_expr(const nex* e);
     
-    nex nexvar(lpvar j) const;
-    intervals::interval interval_of_sum(const nex&);
-    intervals::interval interval_of_sum_no_terms(const nex&);
-    intervals::interval interval_of_mul(const nex&);
+    nex* nexvar(lpvar j, cross_nested& cn) const;
+    intervals::interval interval_of_sum(const nex_sum*);
+    intervals::interval interval_of_sum_no_terms(const nex_sum*);
+    intervals::interval interval_of_mul(const nex_mul*);
     void set_interval_for_scalar(intervals::interval&, const rational&);
     void set_var_interval(lpvar j, intervals::interval&) const;
-    bool lemmas_on_expr(nex &);
+    bool lemmas_on_expr(nex_sum* , cross_nested&);
     
     template <typename T> // T has an iterator of (coeff(), var())
     bool row_has_monomial_to_refine(const T&) const;
-    lpvar find_term_column(const nex& e, rational& a, rational& b) const;
-    static lp::lar_term expression_to_normalized_term(nex&, rational& a, rational & b);
-    static void add_linear_to_vector(const nex&, vector<std::pair<rational, lpvar>> &);
-    static void add_mul_to_vector(const nex&, vector<std::pair<rational, lpvar>> &);
-    bool is_tighter(const interv&, const interv&) const;
-    bool interval_from_term(const nex& e, interv&) const;
+    lpvar find_term_column(const nex* e, rational& a, rational& b) const;
+    static lp::lar_term expression_to_normalized_term(const nex_sum*, rational& a, rational & b);
+    static void add_linear_to_vector(const nex*, vector<std::pair<rational, lpvar>> &);
+    static void add_mul_to_vector(const nex_mul*, vector<std::pair<rational, lpvar>> &);
+    bool interval_from_term(const nex* e, interv&) const;
 }; // end of horner
 }
diff --git a/src/math/lp/nla_core.cpp b/src/math/lp/nla_core.cpp
index 5a69b0cae..70883c5c1 100644
--- a/src/math/lp/nla_core.cpp
+++ b/src/math/lp/nla_core.cpp
@@ -1346,41 +1346,41 @@ lbool core::test_check(
     return check(l);
 }
 
-nla_expr<rational> core::mk_expr(lpvar j)  const {
-    return nla_expr<rational>::var(j);
-}
+// nla_expr<rational> core::mk_expr(lpvar j)  const {
+//     return nla_expr<rational>::var(j);
+// }
 
-nla_expr<rational> core::mk_expr(const rational &a, lpvar j)  const {
-    if (a == 1)
-        return mk_expr(j);
-    nla_expr<rational> r(expr_type::MUL);
-    r.add_child(nla_expr<rational>::scalar(a));
-    r.add_child(nla_expr<rational>::var(j));
-    return r;            
-}
+// nla_expr<rational> core::mk_expr(const rational &a, lpvar j)  const {
+//     if (a == 1)
+//         return mk_expr(j);
+//     nla_expr<rational> r(expr_type::MUL);
+//     r.add_child(nla_expr<rational>::scalar(a));
+//     r.add_child(nla_expr<rational>::var(j));
+//     return r;            
+// }
 
-nla_expr<rational> core::mk_expr(const rational &a, const svector<lpvar>& vs) const {
-    nla_expr<rational> r(expr_type::MUL);
-    r.add_child(nla_expr<rational>::scalar(a));
-    for (lpvar j : vs)
-        r.add_child(nla_expr<rational>::var(j));
-    return r;            
-}
-nla_expr<rational> core::mk_expr(const lp::lar_term& t) const {
-    auto coeffs = t.coeffs_as_vector();
-    if (coeffs.size() == 1) {
-        return mk_expr(coeffs[0].first, coeffs[0].second);
-    }
-    nla_expr<rational> r(expr_type::SUM);
-    for (const auto & p : coeffs) {
-        lpvar j = p.second;
-        if (is_monomial_var(j))
-            r.add_child(mk_expr(p.first, m_emons[j].vars()));
-        else
-            r.add_child(mk_expr(p.first, j));
-    }
-    return r;
-}
+// nla_expr<rational> core::mk_expr(const rational &a, const svector<lpvar>& vs) const {
+//     nla_expr<rational> r(expr_type::MUL);
+//     r.add_child(nla_expr<rational>::scalar(a));
+//     for (lpvar j : vs)
+//         r.add_child(nla_expr<rational>::var(j));
+//     return r;            
+// }
+// nla_expr<rational> core::mk_expr(const lp::lar_term& t) const {
+//     auto coeffs = t.coeffs_as_vector();
+//     if (coeffs.size() == 1) {
+//         return mk_expr(coeffs[0].first, coeffs[0].second);
+//     }
+//     nla_expr<rational> r(expr_type::SUM);
+//     for (const auto & p : coeffs) {
+//         lpvar j = p.second;
+//         if (is_monomial_var(j))
+//             r.add_child(mk_expr(p.first, m_emons[j].vars()));
+//         else
+//             r.add_child(mk_expr(p.first, j));
+//     }
+//     return r;
+// }
 
 std::ostream& core::print_terms(std::ostream& out) const {
     for (unsigned i=0; i< m_lar_solver.m_terms.size(); i++) {
diff --git a/src/math/lp/nla_core.h b/src/math/lp/nla_core.h
index 5fee3abc0..e865f11cb 100644
--- a/src/math/lp/nla_core.h
+++ b/src/math/lp/nla_core.h
@@ -350,11 +350,6 @@ public:
     lpvar map_to_root(lpvar) const;
     std::ostream& print_terms(std::ostream&) const;
     std::ostream& print_term( const lp::lar_term&, std::ostream&) const;
-    nla_expr<rational> mk_expr(lpvar j) const;
-    nla_expr<rational> mk_expr(const rational &a, lpvar j) const;
-
-    nla_expr<rational> mk_expr(const rational &a, const svector<lpvar>& vs) const;
-    nla_expr<rational> mk_expr(const lp::lar_term& t) const;
 };  // end of core
 
 struct pp_mon {
diff --git a/src/math/lp/nla_expr.h b/src/math/lp/nla_expr.h
index 83be7f5eb..f35a4fc9a 100644
--- a/src/math/lp/nla_expr.h
+++ b/src/math/lp/nla_expr.h
@@ -47,38 +47,178 @@ inline std::ostream & operator<<(std::ostream& out, expr_type t) {
 
 
 // This class is needed in horner calculation with intervals
-template <typename T>
-class nla_expr {
-    // todo: use union
-    expr_type       m_type;
-    lpvar           m_j;
-    T             m_v; // for the scalar
-    vector<nla_expr>  m_children;
+class nex {
 public:
-    bool is_sum() const { return m_type == expr_type::SUM; }
-    bool is_var() const { return m_type == expr_type::VAR; }
-    bool is_mul() const { return m_type == expr_type::MUL; }
-    bool is_undef() const { return m_type == expr_type::UNDEF; }
-    bool is_scalar() const { return m_type == expr_type::SCALAR; }
-    lpvar var() const { SASSERT(m_type == expr_type::VAR); return m_j; }
-    expr_type type() const { return m_type; }
-    expr_type& type() { return m_type; }
-    const vector<nla_expr>& children() const { return m_children; }
-    vector<nla_expr>& children() { return m_children; }
-    const T& value() const { SASSERT(m_type == expr_type::SCALAR); return m_v; }
-    std::string str() const { std::stringstream ss; ss << *this; return ss.str(); }
-    std::ostream & print_sum(std::ostream& out) const {
+    virtual expr_type type() const = 0;
+    virtual std::ostream& print(std::ostream&) const = 0;
+    nex() {}
+    bool is_simple() const {
+        switch(type()) {
+        case expr_type::SUM:
+        case expr_type::MUL:
+            return false;        
+        default:
+            return true;
+         }
+    }
+
+    bool is_sum() const { return type() == expr_type::SUM; }
+    bool is_mul() const { return type() == expr_type::MUL; }
+    bool is_var() const { return type() == expr_type::VAR; }
+    bool is_scalar() const { return type() == expr_type::SCALAR; }
+    std::string str() const { std::stringstream ss; print(ss); return ss.str(); }
+    virtual ~nex() {}
+    virtual bool contains(lpvar j) const { return false; }
+    virtual int get_degree() const = 0;
+};
+std::ostream& operator<<(std::ostream& out, const nex&);
+
+class nex_var : public nex {
+    lpvar m_j;
+public:
+    nex_var(lpvar j) : m_j(j) {}
+    nex_var() {}
+    expr_type type() const { return expr_type::VAR; }
+    lpvar var() const {  return m_j; }
+    lpvar& var() {  return m_j; } // the setter
+    std::ostream & print(std::ostream& out) const {
+        out << 'v' <<  m_j;
+        return out;
+    }    
+
+    bool contains(lpvar j) const { return j == m_j; }
+    int get_degree() const { return 1; }
+};
+
+class nex_scalar : public nex {
+    rational m_v;
+public:
+    nex_scalar(const rational& v) : m_v(v) {}
+    nex_scalar() {}
+    expr_type type() const { return expr_type::SCALAR; }
+    const rational& value() const {  return m_v; }
+    rational& value() {  return m_v; } // the setter
+    std::ostream& print(std::ostream& out) const {
+        out << m_v;
+        return out;
+    }
+    
+    int get_degree() const { return 0; }
+
+};
+
+class nex_mul : public nex {
+    vector<nex*> m_children;
+public:
+    nex_mul()  {}
+    unsigned size() const { return m_children.size(); }
+    expr_type type() const { return expr_type::MUL; }
+    vector<nex*>& children() { return m_children;}
+    const vector<nex*>& children() const { return m_children;}
+    std::ostream & print(std::ostream& out) const {
         bool first = true;
-        for (const nla_expr& v : m_children) {            
-            std::string s = v.str();
+        for (const nex* v : m_children) {            
+            std::string s = v->str();
             if (first) {
                 first = false;
-                if (v.is_simple())
-                    out << v;
+                if (v->is_simple())
+                    out << s;
                 else 
                     out << "(" << s << ")";                            
             } else {
-                if (v.is_simple()) {
+                if (v->is_simple()) {
+                    if (s[0] == '-') {
+                        out << "*(" << s << ")";
+                    } else {
+                        out << "*" << s;
+                    }
+                } else {
+                    out << "*(" << s << ")";
+                }
+            }
+        }
+        return out;
+    }
+
+    void add_child(nex* e) { m_children.push_back(e); }
+
+    bool contains(lpvar j) const {
+        for (const nex* c : children()) {
+            if (c->contains(j))
+                return true;
+        }
+        return false;
+    }
+
+    static const nex_var* to_var(const nex*a) {
+        SASSERT(a->is_var());
+        return static_cast<const nex_var*>(a);
+    }
+
+    void get_powers_from_mul(std::unordered_map<lpvar, unsigned> & r) const {
+        r.clear();
+        for (const auto & c : children()) {
+            if (!c->is_var()) {
+                continue;
+            }
+            lpvar j = to_var(c)->var();
+            auto it = r.find(j);
+            if (it == r.end()) {
+                r[j] = 1;
+            } else {
+                it->second++;
+            }
+        }
+        TRACE("nla_cn_details", tout << "powers of " << *this << "\n"; print_vector(r, tout)<< "\n";);
+    }
+
+    int get_degree() const {
+        int degree = 0;       
+        for (auto  e : children()) {
+            degree +=  e->get_degree();
+        }
+        return degree;
+    }
+    
+
+};
+
+class nex_sum : public nex {
+    vector<nex*> m_children;
+public:
+    nex_sum()  {}
+    expr_type type() const { return expr_type::SUM; }
+    vector<nex*>& children() { return m_children;}
+    const vector<nex*>& children() const { return m_children;}    
+    unsigned size() const { return m_children.size(); }
+
+    // we need a linear combination of at least two variables
+    bool is_a_linear_term() const {
+        TRACE("nex_details", tout << *this << "\n";);
+        unsigned number_of_non_scalars = 0;
+        for (auto  e : children()) {
+            int d = e->get_degree();
+            if (d == 0) continue;
+            if (d > 1) return false;
+            
+            number_of_non_scalars++;
+        }
+        TRACE("nex_details", tout << (number_of_non_scalars > 1?"linear":"non-linear") << "\n";); 
+        return number_of_non_scalars > 1;
+    }
+    
+    std::ostream & print(std::ostream& out) const {
+        bool first = true;
+        for (const nex* v : m_children) {            
+            std::string s = v->str();
+            if (first) {
+                first = false;
+                if (v->is_simple())
+                    out << s;
+                else 
+                    out << "(" << s << ")";                            
+            } else {
+                if (v->is_simple()) {
                     if (s[0] == '-') {
                         out << s;
                     } else {
@@ -93,457 +233,52 @@ public:
     }
 
     void simplify() {
-        if (is_simple()) return;
-        bool has_sum = false;
-        if (is_sum()) {
-            for (auto & e : m_children) {
-                e.simplify();
-                has_sum |= e.is_sum();
-            }
-            if (has_sum) {
-                nla_expr n(expr_type::SUM);
-                for (auto &e : m_children) {
-                    n += e;
-                }
-                m_children = n.m_children;
-            }        
-        } else if (is_mul()) {
-            bool has_mul = false;
-            for (auto & e : m_children) {
-                e.simplify();
-                has_mul |= e.is_mul();
-            }
-            if (has_mul) {
-                nla_expr n(expr_type::MUL);
-                for (auto &e : m_children) {
-                    n *= e;
-                }
-                m_children = n.m_children;
-            }
-            TRACE("nla_cn_details", tout << "simplified " << *this << "\n";);
-        }
-    }
-
-    std::ostream & print_mul(std::ostream& out) const {
-        bool first = true;
-        for (const nla_expr& v : m_children) {            
-            std::string s = v.str();
-            if (first) {
-                first = false;
-                if (v.is_simple())
-                    out << s;
-                else 
-                    out << "(" << s << ")";                            
-            } else {
-                if (v.is_simple()) {
-                    if (s[0] == '-') {
-                        out << "*(" << s << ")";
-                    } else {
-                        out << "*" << s;
-                    }
-                } else {
-                    out << "*(" << s << ")";
-                }
-            }
-        }
-        return out;
-    }
-    std::ostream & print(std::ostream& out) const {
-        switch(m_type) {
-        case expr_type::SUM:
-            return print_sum(out);
-        case expr_type::MUL:
-            return print_mul(out);
-        case expr_type::VAR:
-            out << 'v' <<  m_j;
-            return out;
-        case expr_type::SCALAR:
-            out << m_v;
-            return out;
-        default:
-            out << "undef";
-            return out;
-        }
-    }
-
-    bool is_simple() const {
-        switch(m_type) {
-        case expr_type::SUM:
-        case expr_type::MUL:
-            return false;
-        
-        default:
-            return true;
-        }
-    }
-
-    unsigned size() const {
-        switch(m_type) {
-        case expr_type::SUM:
-        case expr_type::MUL:
-            return m_children.size();
-        
-        default:
-            return 1;
-        }
-    }
-    nla_expr(expr_type t): m_type(t) {}
-    nla_expr(): m_type(expr_type::UNDEF) {}
-    
-    void add_child(const nla_expr& e) {
-        m_children.push_back(e);
-    }
-
-    void add_child(const T& k) {
-        m_children.push_back(scalar(k));
-    }
-
-    void add_children() { }
-
-    template <typename K, typename ...Args>
-    void add_children(K e, Args ...  es) {
-        add_child(e);
-        add_children(es ...);
-    }
-
-    template <typename K, typename ... Args>
-    static nla_expr sum(K e, Args ... es) {
-        nla_expr r(expr_type::SUM);
-        r.add_children(e, es...);
-        return r;
-    }
-
-    template <typename K, typename ... Args>
-    static nla_expr mul(K e, Args ... es) {
-        nla_expr r(expr_type::MUL);
-        r.add_children(e, es...);
-        return r;
-    }
-
-    static nla_expr mul(const T& v, nla_expr & w) {
-        if (v == 1)
-            return w;
-        nla_expr r(expr_type::MUL);
-        r.add_child(scalar(v));
-        r.add_child(w);
-        return r;
-    }
-
-    static nla_expr mul() {
-        return nla_expr(expr_type::MUL);
-    }
-
-    static nla_expr mul(const T& v, lpvar j) {
-        if (v == 1)
-            return var(j);
-        return mul(scalar(v), var(j));
-    }
-
-    static nla_expr scalar(const T& v)  {
-        nla_expr r(expr_type::SCALAR);
-        r.m_v = v;
-        return r;
-    }
-
-    static nla_expr var(lpvar j)  {
-        nla_expr r(expr_type::VAR);
-        r.m_j = j;
-        return r;
-    }
-
-    bool contains(lpvar j) const {
-        if (is_var())
-            return m_j == j;
-        if (is_mul()) {
-            for (const nla_expr<T>& c : children()) {
-                if (c.contains(j))
-                    return true;
-            }
-        }
-        return false;
-    }
-
-    nla_expr& operator*=(const nla_expr& b) {
-        if (is_mul()) {
-            if (b.is_mul()) {
-                for (auto& e: b.children())
-                    add_child(e);
-            } else {
-                add_child(b);
-            }
-            return *this;
-        }
-        SASSERT(false); // not impl
-        return *this;
-    }
-
-    std::unordered_map<lpvar, int> get_powers_from_mul() const {
-        SASSERT(is_mul());
-        std::unordered_map<lpvar, int> r;
-        for (const auto & c : children()) {
-            if (!c.is_var()) {
-                continue;
-            }
-            lpvar j = c.var();
-            auto it = r.find(j);
-            if (it == r.end()) {
-                r[j] = 1;
-            } else {
-                it->second++;
-            }
-        }
-        TRACE("nla_cn_details", tout << "powers of " << *this << "\n"; print_vector(r, tout)<< "\n";);
-        return r;
-    }
-
-    friend nla_expr operator-(const nla_expr& a, const nla_expr&b) {
-        return a + scalar(T(-1))*b;
+        SASSERT(false);
     }
     
-    nla_expr& operator/=(const nla_expr& b) {
-        TRACE("nla_cn_details", tout << *this <<" / " << b << "\n";);
-        if (b.is_var()) {
-            *this = (*this) / b.var();
-            TRACE("nla_cn_details", tout << *this << "\n";);
-            return *this;
-        }
-        SASSERT(b.is_mul());
-        if (is_sum()) {
-            for (auto & e : children()) {
-                e /= b;
-            }
-            TRACE("nla_cn_details", tout << *this << "\n";);
-            return *this;
-        }
-        if (is_var() || children().size() == 1) {
-            *this = scalar(T(1));
-            TRACE("nla_cn_details", tout << *this << "\n";);
-            return *this;
-        }
-        SASSERT(is_mul());
-        auto powers = b.get_powers_from_mul();
-        unsigned i = 0, k = 0;
-        for (; i < children().size(); i++, k++) {
-            auto & e = children()[i];
-            TRACE("nla_cn_details", tout << "e=" << e << ",i=" <<i<< ",k=" << k<< "\n";);
-            
-            if (!e.is_var()) {
-                SASSERT(e.is_scalar());
-                if (i != k)
-                    children()[k] = children()[i];
-                
-                TRACE("nla_cn_details", tout << "continue\n";);
-                continue;
-
-            }
-            lpvar j = e.var();
-            auto it = powers.find(j);
-            if (it == powers.end()) {
-                if (i != k)
-                    children()[k] = children()[i];
-            } else {
-                it->second --;
-                if (it->second == 0)
-                    powers.erase(it);
-                k--;
-            }
-            TRACE("nla_cn_details", tout << *this << "\n";);
-            
-        }
-        SASSERT(powers.size() == 0);
-        while(k ++ < i)
-            children().pop_back();
-
-        if (children().size() == 0)
-            *this = scalar(T(1));
-        TRACE("nla_cn_details", tout << *this << "\n";);
-        
-        return *this;
-    }
-        
-    
-    nla_expr& operator+=(const nla_expr& b) {
-        if (is_sum()) {
-            if (b.is_sum()) {
-                for (auto& e: b.children())
-                    add_child(e);
-            } else {
-                add_child(b);
-            }
-            return *this;
-        }
-        SASSERT(false); // not impl
-        return *this;
-    }
-
-    // we need a linear combination of at least two variables
-    bool sum_is_a_linear_term() const {
-        SASSERT(is_sum());
-        TRACE("nla_expr_details", tout << *this << "\n";);
-        unsigned number_of_non_scalars = 0;
-        for (auto & e : children()) {
-            int d = e.get_degree();
-            if (d == 0) continue;
-            if (d > 1) return false;
-            
-            number_of_non_scalars++;
-        }
-        TRACE("nla_expr_details", tout << (number_of_non_scalars > 1?"linear":"non-linear") << "\n";); 
-        return number_of_non_scalars > 1;
-    }
-
     int get_degree() const {
-        switch (type()) {
-        case expr_type::SUM: {
-            int degree = 0;       
-            for (auto & e : children()) {
-                degree = std::max(degree, e.get_degree());
-            }
-            return degree;
+        int degree = 0;       
+        for (auto  e : children()) {
+            degree = std::max(degree, e->get_degree());
         }
-
-        case expr_type::MUL: {
-            int degree = 0;       
-            for (auto & e : children()) {
-                degree += e.get_degree();
-            }
-            return degree;
-        }            
-        case expr_type::VAR:
-            return 1;
-        case expr_type::SCALAR:
-            return 0;
-        case expr_type::UNDEF:
-        default:
-            UNREACHABLE();         
-            break;
-        }
-        return 0;
-    }    
+        return degree;
+    }
+    
+    void add_child(nex* e) { m_children.push_back(e); }
 };
 
-/*
-nla_expr operator/=(const nla_expr &a, const nla_expr& b) {
-        TRACE("nla_cn_details", tout << a <<" / " << b << "\n";);
-        if (b.is_var()) {
-            return a / b.var();
-        }
-        SASSERT(b.is_mul());
-        if (a.is_sum()) {
-            auto r = nex::sum();
-            for (auto & e : a.children()) {
-                r.add_child(e/b);
-            }
-            return r;
-        }
-        if (is_var()) {
-            return scalar(T(1));
-            return *this;
-        }
-        SASSERT(a.is_mul());
-        auto powers = b.get_powers_from_mul();
-        auto r=nex::mul();
-        for (unsigned i = 0; i < a.children().size(); i++, k++) {
-            auto & e = children()[i];
-            if (!e.is_var()) {
-                SASSERT(e.is_scalar());
-                r.add_child(e);
-                continue;
-            }
-            lpvar j = e.var();
-            auto it = powers.find(j);
-            if (it == powers.end()) {
-                r.add_child(e);
-            } else {
-                it->second --; // finish h
-                if (it->second == 0)
-                    powers.erase(it);
-            }            
-        }
-
-        return r;
-    }
-*/
-template <typename T> 
-nla_expr<T> operator+(const nla_expr<T>& a, const nla_expr<T>& b) {
-    if (a.is_sum()) {
-        nla_expr<T> r(expr_type::SUM);
-        r.children() = a.children();
-        if (b.is_sum()) {
-            for (auto& e: b.children())
-                r.add_child(e);
-        } else {
-            r.add_child(b);
-        }
-        return r;
-    }
-    if (b.is_sum()) {
-        nla_expr<T> r(expr_type::SUM);
-        r.children() = b.children();
-        r.add_child(a);
-        return r;
-    }
-    return nla_expr<T>::sum(a, b);
+inline const nex_sum* to_sum(const nex*a) {
+    SASSERT(a->is_sum());
+    return static_cast<const nex_sum*>(a);
 }
 
-template <typename T> 
-nla_expr<T> operator*(const nla_expr<T>& a, const nla_expr<T>& b) {
-    if (a.is_scalar() && a.value() == T(1))
-        return b;
-    if (b.is_scalar() && b.value() == T(1))
-        return a;
-    if (a.is_mul()) {
-        nla_expr<T> r(expr_type::MUL);
-        r.children() = a.children();
-        if (b.is_mul()) {
-            for (auto& e: b.children())
-                r.add_child(e);
-        } else {
-            r.add_child(b);
-        }
-        return r;
-    }
-    if (b.is_mul()) {
-        nla_expr<T> r(expr_type::MUL);
-        r.children() = b.children();
-        r.add_child(a);
-        return r;
-    }
-    return nla_expr<T>::mul(a, b);
+inline nex_sum* to_sum(nex * a) {
+    SASSERT(a->is_sum());
+    return static_cast<nex_sum*>(a);
 }
 
-
-template <typename T> 
-nla_expr<T> operator/(const nla_expr<T>& a, lpvar j) {
-    TRACE("nla_cn_details", tout << "a=" << a << ", v" << j << "\n";);
-    SASSERT((a.is_mul() && a.contains(j)) || (a.is_var() && a.var() == j));
-    if (a.is_var())
-        return nla_expr<T>::scalar(T(1));
-    nla_expr<T> b;
-    bool seenj = false;
-    for (const nla_expr<T>& c : a.children()) {
-        if (!seenj) {
-            if (c.contains(j)) {
-                if (!c.is_var())                     
-                    b.add_child(c / j);
-                seenj = true;
-                continue;
-            } 
-        }
-        b.add_child(c);
-    }
-    if (b.children().size() > 1) {
-        b.type() = expr_type::MUL;
-    } else if (b.children().size() == 1) {
-        auto t = b.children()[0];
-        b = t;        
-    } else {
-        b = nla_expr<T>::scalar(T(1));
-    }
-    return b;
+    
+inline const nex_var* to_var(const nex*a)  {
+    SASSERT(a->is_var());
+    return static_cast<const nex_var*>(a);
 }
-template <typename T>
-std::ostream& operator<<(std::ostream& out, const nla_expr<T>& e ) {
+
+inline const nex_mul* to_mul(const nex*a) {
+    SASSERT(a->is_mul());
+    return static_cast<const nex_mul*>(a);
+}
+
+inline nex_mul* to_mul(nex*a) {
+    SASSERT(a->is_mul());
+    return static_cast<nex_mul*>(a);
+}
+
+inline const nex_scalar * to_scalar(const nex* a) {
+    SASSERT(a->is_scalar());
+    return static_cast<const nex_scalar*>(a);
+}
+
+inline std::ostream& operator<<(std::ostream& out, const nex& e ) {
     return e.print(out);
 }
 
diff --git a/src/test/lp/lp.cpp b/src/test/lp/lp.cpp
index e6c65a2d1..a0ff48209 100644
--- a/src/test/lp/lp.cpp
+++ b/src/test/lp/lp.cpp
@@ -68,26 +68,33 @@ void test_basic_lemma_for_mon_zero_from_factors_to_monomial();
 void test_basic_lemma_for_mon_neutral_from_monomial_to_factors();
 void test_basic_lemma_for_mon_neutral_from_factors_to_monomial();
 
-void test_cn_on_expr(horner::nex t) {
-    TRACE("nla_cn", tout << "t=" << t << '\n';);
-    cross_nested cn(t, [](const horner::nex& n) {
-                           TRACE("nla_cn_test", tout << n << "\n";);
-                           return false;
-                       } ,
-        [](unsigned) { return false; });
-    cn.run();
+void test_cn_on_expr(nex_sum *t, cross_nested& cn) {
+    TRACE("nla_cn", tout << "t=" << *t << '\n';);
+    cn.run(t);
 }
 
 void test_cn() {
-    typedef horner::nex nex;
+    cross_nested cn([](const nex* n) {
+                           TRACE("nla_cn_test", tout << *n << "\n";);
+                           return false;
+                       } ,
+        [](unsigned) { return false; });
     enable_trace("nla_cn");
     enable_trace("nla_cn_details");
-    nex a = nex::var(0), b = nex::var(1), c = nex::var(2), d = nex::var(3), e = nex::var(4), f = nex::var(5), g = nex::var(6);
-    nex min_1 = nex::scalar(rational(-1));
+    nex_var* a = cn.mk_var(0);
+    nex_var* b = cn.mk_var(1);
+    nex_var* c = cn.mk_var(2);
+    nex_var* d = cn.mk_var(3);
+    nex_var* e = cn.mk_var(4);
+    nex_var* f = cn.mk_var(5);
+    nex_var* g = cn.mk_var(6);
+    nex* min_1 = cn.mk_scalar(rational(-1));
     // test_cn_on_expr(min_1*c*e + min_1*b*d + min_1*a*b + a*c);
-    TRACE("nla_cn", tout << "done\n";);
-   
-    test_cn_on_expr(b*c*d -  b*c*g);
+    nex* bcd = cn.mk_mul(b, c, d);
+    nex_mul* bcg = cn.mk_mul(b, c, g);
+    bcg->add_child(min_1);
+    nex_sum* t = cn.mk_sum(bcd, bcg);
+    test_cn_on_expr(t, cn);
     //    test_cn_on_expr(a*a*d + a*b*c*d + a*a*c*c*d + a*d*d + e*a*e + e*a*c + e*d);
     // TRACE("nla_cn", tout << "done\n";);
     // test_cn_on_expr(a*b*d + a*b*c + c*b*d + a*c*d);