diff --git a/src/math/lp/cross_nested.h b/src/math/lp/cross_nested.h
index babb455dd..4dc166fb9 100644
--- a/src/math/lp/cross_nested.h
+++ b/src/math/lp/cross_nested.h
@@ -270,7 +270,7 @@ public:
         }
         
         nex* c_over_f = mk_div(*c, f);
-        to_sum(c_over_f)->simplify();
+        to_sum(c_over_f)->simplify(&c_over_f);
         *c = mk_mul(f, c_over_f);
         TRACE("nla_cn", tout << "common factor=" << *f << ", c=" << **c << "\ne = " << *m_e << "\n";);
         
@@ -463,8 +463,7 @@ public:
             || (ce->is_var() && to_var(ce)->var() == j);
     }
     // all factors of j go to a, the rest to b
-    void pre_split(nex_sum * e, lpvar j, nex_sum* & a, nex* & b) {
-        
+    void pre_split(nex_sum * e, lpvar j, nex_sum*& a, nex*& b) {        
         a = mk_sum();
         m_b_split_vec.clear();
         for (nex * ce: e->children()) {
@@ -478,7 +477,8 @@ public:
         }
         TRACE("nla_cn_details", tout << "a = " << *a << "\n";);
         SASSERT(a->children().size() >= 2 && m_b_split_vec.size());
-        a->simplify();
+        nex* f;
+        a->simplify(&f); 
         
         if (m_b_split_vec.size() == 1) {
             b = m_b_split_vec[0];
@@ -608,11 +608,13 @@ public:
         for (unsigned j = 0; j < a->size(); j ++) {
             a->children()[j] = normalize(a->children()[j]);            
         }
-        a->simplify();
-        return a;
+        nex *r;
+        a->simplify(&r);
+        return r;
     }
 
     nex * normalize_mul(nex_mul* a) {
+        TRACE("nla_cn", tout << *a << "\n";);
         int sum_j = -1;
         for (unsigned j = 0; j < a->size(); j ++) {
             a->children()[j] = normalize(a->children()[j]);
@@ -620,28 +622,36 @@ public:
                 sum_j = j;
         }
 
-        if (sum_j == -1)
-            return a;
+        if (sum_j == -1) {
+            nex * r;
+            a->simplify(&r);
+            SASSERT(r->is_simplified());
+            return r;
+        }
         
         nex_sum *r = mk_sum();
         nex_sum *as = to_sum(a->children()[sum_j]);
         for (unsigned k = 0; k < as->size(); k++) {
             nex_mul *b = mk_mul(as->children()[k]);
-            r->add_child(b);
             for (unsigned j = 0; j < a->size(); j ++) {
                 if ((int)j != sum_j)
                     b->add_child(a->children()[j]);
             }
-            b->simplify();
+            nex *e;
+            b->simplify(&e);
+            r->add_child(e);
         }
-        TRACE("nla_cn", tout << *r << "\n";); 
-        return normalize_sum(r);
+        TRACE("nla_cn", tout << *r << "\n";);
+        nex *rs = normalize_sum(r);
+        SASSERT(rs->is_simplified());
+        return rs;
+
     }
 
     
     
     nex * normalize(nex* a) {
-        if (a->is_simple())
+        if (a->is_elementary())
             return a;
         nex *r;
         if (a->is_mul()) {
diff --git a/src/math/lp/nex.h b/src/math/lp/nex.h
index 67da8168f..9f588806d 100644
--- a/src/math/lp/nex.h
+++ b/src/math/lp/nex.h
@@ -49,7 +49,7 @@ public:
     virtual expr_type type() const = 0;
     virtual std::ostream& print(std::ostream&) const = 0;
     nex() {}
-    bool is_simple() const {
+    bool is_elementary() const {
         switch(type()) {
         case expr_type::SUM:
         case expr_type::MUL:
@@ -67,7 +67,10 @@ public:
     virtual ~nex() {}
     virtual bool contains(lpvar j) const { return false; }
     virtual int get_degree() const = 0;
-    virtual void simplify() {}
+    virtual void simplify(nex** ) = 0;
+    virtual bool is_simplified() const {
+        return true;
+    }
     virtual const ptr_vector<nex> * children_ptr() const {
         UNREACHABLE();
         return nullptr;
@@ -103,6 +106,7 @@ public:
 
     bool contains(lpvar j) const { return j == m_j; }
     int get_degree() const { return 1; }
+    virtual void simplify(nex** e) { *e = this; }
 };
 
 class nex_scalar : public nex {
@@ -119,29 +123,48 @@ public:
     }
     
     int get_degree() const { return 0; }
+    virtual void simplify(nex** e) { *e = this; }
 
 };
 
+const nex_scalar * to_scalar(const nex* a);
+
+static bool ignored_child(nex* e, expr_type t) {
+    switch(t) {
+    case expr_type::MUL:
+        return e->is_scalar() && to_scalar(e)->value().is_one();
+    case expr_type::SUM:        
+        return e->is_scalar() && to_scalar(e)->value().is_zero();
+    default: return false;
+    }
+    return false;
+}
+
 static void promote_children_by_type(ptr_vector<nex> * children, expr_type t) {
     ptr_vector<nex> to_promote;
+    int skipped = 0;
     for(unsigned j = 0; j < children->size(); j++) {
-        nex* e = (*children)[j];
-        e->simplify();
-        if (e->type() == t) {
-            to_promote.push_back(e);
+        nex** e = &(*children)[j];
+        (*e)->simplify(e);
+        if ((*e)->type() == t) {
+            to_promote.push_back(*e);
+        } else if (ignored_child(*e, t)) {
+            skipped ++;
+            continue;
         } else {
-            unsigned offset = to_promote.size();
+            unsigned offset = to_promote.size() + skipped;
             if (offset) {
-                (*children)[j - offset] = e;
+                (*children)[j - offset] = *e;
             }
         }
     }
-
-    children->shrink(children->size() - to_promote.size());
+    
+    children->shrink(children->size() - to_promote.size() - skipped);
     
     for (nex *e : to_promote) {
         for (nex *ee : *(e->children_ptr())) {
-            children->push_back(ee);
+            if (!ignored_child(ee, t))
+                children->push_back(ee);            
         }
     }    
 }
@@ -163,12 +186,12 @@ public:
             std::string s = v->str();
             if (first) {
                 first = false;
-                if (v->is_simple())
+                if (v->is_elementary())
                     out << s;
                 else 
                     out << "(" << s << ")";                            
             } else {
-                if (v->is_simple()) {
+                if (v->is_elementary()) {
                     if (s[0] == '-') {
                         out << "*(" << s << ")";
                     } else {
@@ -222,12 +245,29 @@ public:
         return degree;
     }
     
-    void simplify() {
+    void simplify(nex **e) {
+        *e = this;
         TRACE("nla_cn_details", tout << *this << "\n";);
         promote_children_by_type(&m_children, expr_type::MUL);
+        if (size() == 1) 
+            *e = m_children[0];
         TRACE("nla_cn_details", tout << *this << "\n";);
+        SASSERT((*e)->is_simplified());
     }
-    #ifdef Z3DEBUG
+
+    virtual bool is_simplified() const {
+        if (size() < 2)
+            return false;
+        for (nex * e : children()) {
+            if (e->is_mul()) 
+                return false;
+            if (e->is_scalar() && to_scalar(e)->value().is_one())
+                return false;
+        }
+        return true;
+    }
+
+#ifdef Z3DEBUG
     virtual void sort() {
         for (nex * c : m_children) {
             c->sort();
@@ -271,12 +311,12 @@ public:
             std::string s = v->str();
             if (first) {
                 first = false;
-                if (v->is_simple())
+                if (v->is_elementary())
                     out << s;
                 else 
                     out << "(" << s << ")";                            
             } else {
-                if (v->is_simple()) {
+                if (v->is_elementary()) {
                     if (s[0] == '-') {
                         out << s;
                     } else {
@@ -290,8 +330,21 @@ public:
         return out;
     }
 
-    void simplify() {
+    void simplify(nex **e) {
+        *e = this;
         promote_children_by_type(&m_children, expr_type::SUM);
+        if (size() == 1)
+            *e = m_children[0];
+    }
+    virtual bool is_simplified() const {
+        if (size() < 2) return false;
+        for (nex * e : children()) {
+            if (e->is_sum())
+                return false;
+            if (e->is_scalar() && to_scalar(e)->value().is_zero())
+                return false;
+        }
+        return true;
     }
     
     int get_degree() const {
@@ -331,6 +384,11 @@ inline const nex_var* to_var(const nex*a)  {
     return static_cast<const nex_var*>(a);
 }
 
+inline const nex_scalar* to_scalar(const nex*a)  {
+    SASSERT(a->is_scalar());
+    return static_cast<const nex_scalar*>(a);
+}
+
 inline const nex_mul* to_mul(const nex*a) {
     SASSERT(a->is_mul());
     return static_cast<const nex_mul*>(a);
@@ -341,11 +399,6 @@ inline nex_mul* to_mul(nex*a) {
     return static_cast<nex_mul*>(a);
 }
 
-inline const nex_scalar * to_scalar(const nex* a) {
-    SASSERT(a->is_scalar());
-    return static_cast<const nex_scalar*>(a);
-}
-
 inline std::ostream& operator<<(std::ostream& out, const nex& e ) {
     return e.print(out);
 }
diff --git a/src/test/lp/lp.cpp b/src/test/lp/lp.cpp
index f0b9ef763..0433232d5 100644
--- a/src/test/lp/lp.cpp
+++ b/src/test/lp/lp.cpp
@@ -102,8 +102,10 @@ void test_cn() {
     nex* eac = cn.mk_mul(e, a, c);
     nex* ed = cn.mk_mul(e, d);
     nex* _6aad = cn.mk_mul(cn.mk_scalar(rational(6)), a, a, d);
+#ifdef Z3DEBUG
     nex * clone = cn.clone(cn.mk_sum(_6aad, abcd, aaccd, add, eae, eac, ed));
     TRACE("nla_cn", tout << "clone = " << *clone << "\n";);
+#endif
     //    test_cn_on_expr(cn.mk_sum(aad,  abcd, aaccd, add, eae, eac, ed), cn);
     test_cn_on_expr(cn.mk_sum(_6aad, abcd, aaccd, add, eae, eac, ed), cn);
     // TRACE("nla_cn", tout << "done\n";);