From ff75f88c4f2222882453d81f5e4294cc6bde8d81 Mon Sep 17 00:00:00 2001
From: Nikolaj Bjorner <nbjorner@microsoft.com>
Date: Mon, 31 Oct 2016 22:25:58 +0100
Subject: [PATCH] fix memory abuse in internalization in inc-sat-solver

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
---
 src/ast/rewriter/pb2bv_rewriter.cpp           |  43 ++--
 src/ast/rewriter/pb2bv_rewriter.h             |   1 -
 src/sat/sat_solver/inc_sat_solver.cpp         |   5 +-
 .../portfolio/bounded_int2bv_solver.cpp       |   1 +
 src/tactic/portfolio/enum2bv_solver.cpp       |   1 -
 src/tactic/portfolio/pb2bv_solver.cpp         |   2 +-
 src/util/sorting_network.h                    | 202 +++++++++++++-----
 7 files changed, 177 insertions(+), 78 deletions(-)

diff --git a/src/ast/rewriter/pb2bv_rewriter.cpp b/src/ast/rewriter/pb2bv_rewriter.cpp
index cf5b67793..37c87cd5b 100644
--- a/src/ast/rewriter/pb2bv_rewriter.cpp
+++ b/src/ast/rewriter/pb2bv_rewriter.cpp
@@ -241,37 +241,36 @@ struct pb2bv_rewriter::imp {
             m_args(m)
         {}
 
-        br_status mk_app_core(func_decl * f, unsigned sz, expr * const* args, expr_ref & result) {
+        bool mk_app(bool full, func_decl * f, unsigned sz, expr * const* args, expr_ref & result) {
             if (f->get_family_id() == pb.get_family_id()) {
-                mk_pb(f, sz, args, result);
-                ++m_imp.m_num_translated;
-                return BR_DONE;
+                mk_pb(full, f, sz, args, result);
             }
             else if (au.is_le(f) && is_pb(args[0], args[1])) {
-                ++m_imp.m_num_translated;
                 result = mk_le_ge<l_true>(m_args.size(), m_args.c_ptr(), m_k);
-                return BR_DONE;
             }
             else if (au.is_lt(f) && is_pb(args[0], args[1])) {
-                ++m_imp.m_num_translated;
                 ++m_k;
                 result = mk_le_ge<l_true>(m_args.size(), m_args.c_ptr(), m_k);
-                return BR_DONE;
             }
             else if (au.is_ge(f) && is_pb(args[1], args[0])) {
-                ++m_imp.m_num_translated;
                 result = mk_le_ge<l_true>(m_args.size(), m_args.c_ptr(), m_k);
-                return BR_DONE;
             }
             else if (au.is_gt(f) && is_pb(args[1], args[0])) {
-                ++m_imp.m_num_translated;
                 ++m_k;
                 result = mk_le_ge<l_true>(m_args.size(), m_args.c_ptr(), m_k);
-                return BR_DONE;
             }
             else if (m.is_eq(f) && is_pb(args[0], args[1])) {
-                ++m_imp.m_num_translated;
                 result = mk_le_ge<l_undef>(m_args.size(), m_args.c_ptr(), m_k);
+            }
+            else {
+                return false;
+            }
+            ++m_imp.m_num_translated;
+            return true;
+        }
+
+        br_status mk_app_core(func_decl * f, unsigned sz, expr * const* args, expr_ref & result) {
+            if (mk_app(true, f, sz, args, result)) {
                 return BR_DONE;
             }
             else {
@@ -350,25 +349,25 @@ struct pb2bv_rewriter::imp {
             return false;
         }
 
-        void mk_pb(func_decl * f, unsigned sz, expr * const* args, expr_ref & result) {
+        void mk_pb(bool full, func_decl * f, unsigned sz, expr * const* args, expr_ref & result) {
             SASSERT(f->get_family_id() == pb.get_family_id());
             if (is_or(f)) {
                 result = m.mk_or(sz, args);
             }
             else if (pb.is_at_most_k(f) && pb.get_k(f).is_unsigned()) {
-                result = m_sort.le(true, pb.get_k(f).get_unsigned(), sz, args);
+                result = m_sort.le(full, pb.get_k(f).get_unsigned(), sz, args);
             }
             else if (pb.is_at_least_k(f) && pb.get_k(f).is_unsigned()) {
-                result = m_sort.ge(true, pb.get_k(f).get_unsigned(), sz, args);
+                result = m_sort.ge(full, pb.get_k(f).get_unsigned(), sz, args);
             }
             else if (pb.is_eq(f) && pb.get_k(f).is_unsigned() && pb.has_unit_coefficients(f)) {
-                result = m_sort.eq(pb.get_k(f).get_unsigned(), sz, args);
+                result = m_sort.eq(full, pb.get_k(f).get_unsigned(), sz, args);
             }
             else if (pb.is_le(f) && pb.get_k(f).is_unsigned() && pb.has_unit_coefficients(f)) {
-                result = m_sort.le(true, pb.get_k(f).get_unsigned(), sz, args);
+                result = m_sort.le(full, pb.get_k(f).get_unsigned(), sz, args);
             }
             else if (pb.is_ge(f) && pb.get_k(f).is_unsigned() && pb.has_unit_coefficients(f)) {
-                result = m_sort.ge(true, pb.get_k(f).get_unsigned(), sz, args);
+                result = m_sort.ge(full, pb.get_k(f).get_unsigned(), sz, args);
             }
             else {
                 result = mk_bv(f, sz, args);
@@ -433,9 +432,6 @@ struct pb2bv_rewriter::imp {
     void operator()(expr * e, expr_ref & result, proof_ref & result_proof) {
         m_rw(e, result, result_proof);
     }
-    void assert_expr(expr * e, expr_ref & result, proof_ref & result_proof) {
-        m_rw(e, result, result_proof);
-    }
     void push() {
         m_fresh_lim.push_back(m_fresh.size());
     }
@@ -472,9 +468,6 @@ unsigned pb2bv_rewriter::get_num_steps() const { return m_imp->get_num_steps();
 void pb2bv_rewriter::cleanup() { ast_manager& mgr = m(); params_ref p = m_imp->m_params; dealloc(m_imp); m_imp = alloc(imp, mgr, p);  }
 func_decl_ref_vector const& pb2bv_rewriter::fresh_constants() const { return m_imp->m_fresh; }
 void pb2bv_rewriter::operator()(expr * e, expr_ref & result, proof_ref & result_proof) { (*m_imp)(e, result, result_proof); }
-void pb2bv_rewriter::assert_expr(expr* e, expr_ref & result, proof_ref & result_proof) { 
-    m_imp->assert_expr(e, result, result_proof); 
-}
 void pb2bv_rewriter::push() { m_imp->push(); }
 void pb2bv_rewriter::pop(unsigned num_scopes) { m_imp->pop(num_scopes); }
 void pb2bv_rewriter::flush_side_constraints(expr_ref_vector& side_constraints) { m_imp->flush_side_constraints(side_constraints); } 
diff --git a/src/ast/rewriter/pb2bv_rewriter.h b/src/ast/rewriter/pb2bv_rewriter.h
index 569eaf07d..47d8361cb 100644
--- a/src/ast/rewriter/pb2bv_rewriter.h
+++ b/src/ast/rewriter/pb2bv_rewriter.h
@@ -36,7 +36,6 @@ public:
     void cleanup();
     func_decl_ref_vector const& fresh_constants() const;
     void operator()(expr * e, expr_ref & result, proof_ref & result_proof);
-    void assert_expr(expr* e, expr_ref & result, proof_ref & result_proof);
     void push();
     void pop(unsigned num_scopes);
     void flush_side_constraints(expr_ref_vector& side_constraints);
diff --git a/src/sat/sat_solver/inc_sat_solver.cpp b/src/sat/sat_solver/inc_sat_solver.cpp
index f5efed726..d1fae0156 100644
--- a/src/sat/sat_solver/inc_sat_solver.cpp
+++ b/src/sat/sat_solver/inc_sat_solver.cpp
@@ -34,6 +34,7 @@ Notes:
 #include "bit_blaster_model_converter.h"
 #include "ast_translation.h"
 #include "ast_util.h"
+#include "propagate_values_tactic.h"
 
 // incremental SAT solver.
 class inc_sat_solver : public solver {
@@ -341,6 +342,7 @@ public:
                      mk_max_bv_sharing_tactic(m),
                      mk_bit_blaster_tactic(m, m_bb_rewriter.get()),
                      //mk_aig_tactic(),
+                     //mk_propagate_values_tactic(m, simp2_p),
                      using_params(mk_simplify_tactic(m), simp2_p));
         while (m_bb_rewriter->get_num_scopes() < m_num_scopes) {
             m_bb_rewriter->push();
@@ -377,6 +379,7 @@ private:
         g = m_subgoals[0];
         expr_ref_vector atoms(m);
         TRACE("sat", g->display_with_dependencies(tout););
+        std::cout << "exprs: " << g->num_exprs() << "\n";
         m_goal2sat(*g, m_params, m_solver, m_map, dep2asm, true);
         m_goal2sat.get_interpreted_atoms(atoms);
         if (!atoms.empty()) {
@@ -520,7 +523,7 @@ private:
         }
         dep2asm_t dep2asm;
         goal_ref g = alloc(goal, m, true, false); // models, maybe cores are enabled
-        for (unsigned i = 0 ; i < m_fmls.size(); ++i) {
+        for (unsigned i = m_fmls_head ; i < m_fmls.size(); ++i) {
             g->assert_expr(m_fmls[i].get());
         }
         lbool res = internalize_goal(g, dep2asm);
diff --git a/src/tactic/portfolio/bounded_int2bv_solver.cpp b/src/tactic/portfolio/bounded_int2bv_solver.cpp
index 0b136dda7..b6c85c159 100644
--- a/src/tactic/portfolio/bounded_int2bv_solver.cpp
+++ b/src/tactic/portfolio/bounded_int2bv_solver.cpp
@@ -303,6 +303,7 @@ private:
             }
         }
         m_assertions.reset();
+        m_rewriter.reset();
     }
 };
 
diff --git a/src/tactic/portfolio/enum2bv_solver.cpp b/src/tactic/portfolio/enum2bv_solver.cpp
index e89f9d188..f3288d8d6 100644
--- a/src/tactic/portfolio/enum2bv_solver.cpp
+++ b/src/tactic/portfolio/enum2bv_solver.cpp
@@ -99,7 +99,6 @@ public:
     virtual lbool find_mutexes(expr_ref_vector const& vars, vector<expr_ref_vector>& mutexes) { return m_solver->find_mutexes(vars, mutexes); }
     
     virtual lbool get_consequences_core(expr_ref_vector const& asms, expr_ref_vector const& vars, expr_ref_vector& consequences) {
-
         datatype_util dt(m);
         bv_util bv(m);
         expr_ref_vector bvars(m), conseq(m), bounds(m);
diff --git a/src/tactic/portfolio/pb2bv_solver.cpp b/src/tactic/portfolio/pb2bv_solver.cpp
index d1826e61d..bfd533e8a 100644
--- a/src/tactic/portfolio/pb2bv_solver.cpp
+++ b/src/tactic/portfolio/pb2bv_solver.cpp
@@ -113,7 +113,7 @@ private:
         expr_ref fml(m);
         expr_ref_vector fmls(m);
         for (unsigned i = 0; i < m_assertions.size(); ++i) {
-            m_rewriter.assert_expr(m_assertions[i].get(), fml, proof);
+            m_rewriter(m_assertions[i].get(), fml, proof);
             m_solver->assert_expr(fml);
         }
         m_rewriter.flush_side_constraints(fmls);
diff --git a/src/util/sorting_network.h b/src/util/sorting_network.h
index 33c4f8b61..87d8bbf3f 100644
--- a/src/util/sorting_network.h
+++ b/src/util/sorting_network.h
@@ -24,7 +24,6 @@ Notes:
 #ifndef SORTING_NETWORK_H_
 #define SORTING_NETWORK_H_
 
-
     template <typename Ext>
     class sorting_network {
         typedef typename Ext::vector vect;
@@ -213,17 +212,17 @@ Notes:
             }
         }
 
-        literal eq(unsigned k, unsigned n, literal const* xs) {
+        literal eq(bool full, unsigned k, unsigned n, literal const* xs) {
             if (k > n) {
                 return ctx.mk_false();
             }
             SASSERT(k <= n);
             literal_vector in, out;
             if (dualize(k, n, xs, in)) {
-                return eq(k, n, in.c_ptr());
+                return eq(full, k, n, in.c_ptr());
             }
             else if (k == 1) {
-                return mk_exactly_1(true, n, xs);
+                return mk_exactly_1(full, n, xs);
             }
             else {
                 SASSERT(2*k <= n);
@@ -242,34 +241,64 @@ Notes:
         
     private:
 
-
-        literal mk_and(literal l1, literal l2) {
-            literal result = fresh();
-            add_clause(ctx.mk_not(result), l1);
-            add_clause(ctx.mk_not(result), l2);
-            add_clause(ctx.mk_not(l1), ctx.mk_not(l2), result);
-            return result;
-        }
-
-        void mk_implies_or(literal l, unsigned n, literal const* xs) {
+        void add_implies_or(literal l, unsigned n, literal const* xs) {
             literal_vector lits(n, xs);
             lits.push_back(ctx.mk_not(l));
             add_clause(lits);
         }
 
-        void mk_or_implies(literal l, unsigned n, literal const* xs) {
+        void add_or_implies(literal l, unsigned n, literal const* xs) {
             for (unsigned j = 0; j < n; ++j) {
                 add_clause(ctx.mk_not(xs[j]), l);
             }
         }
 
-        literal mk_or(literal_vector const& ors) {
-            if (ors.size() == 1) {
+        literal mk_or(unsigned n, literal const* ors) {
+            if (n == 1) {
                 return ors[0];
             }
             literal result = fresh();
-            mk_implies_or(result, ors.size(), ors.c_ptr());
-            mk_or_implies(result, ors.size(), ors.c_ptr());
+            add_implies_or(result, n, ors);
+            add_or_implies(result, n, ors);
+            return result;
+        }
+
+        literal mk_or(literal l1, literal l2) {
+            literal ors[2] = { l1, l2 };
+            return mk_or(2, ors);
+        }
+        literal mk_or(literal_vector const& ors) {
+            return mk_or(ors.size(), ors.c_ptr());
+        }
+
+        void add_implies_and(literal l, literal_vector const& xs) {
+            for (unsigned j = 0; j < xs.size(); ++j) {
+                add_clause(ctx.mk_not(l), xs[j]);
+            }
+        }
+
+        void add_and_implies(literal l, literal_vector const& xs) {
+            literal_vector lits;
+            for (unsigned j = 0; j < xs.size(); ++j) {
+                lits.push_back(ctx.mk_not(xs[j]));
+            }
+            lits.push_back(l);
+            add_clause(lits);
+        }
+
+        literal mk_and(literal l1, literal l2) {
+            literal_vector xs;
+            xs.push_back(l1); xs.push_back(l2);
+            return mk_and(xs);
+        }
+
+        literal mk_and(literal_vector const& ands) {
+            if (ands.size() == 1) {
+                return ands[0];
+            }
+            literal result = fresh();
+            add_implies_and(result, ands);
+            add_and_implies(result, ands);
             return result;
         }
 
@@ -281,18 +310,19 @@ Notes:
                 r1 = mk_and(r1, mk_or(ors));
             }
             else {
-                mk_implies_or(r1, ors.size(), ors.c_ptr());
+                add_implies_or(r1, ors.size(), ors.c_ptr());
             }
             return r1;
         }
 
+#if 1
         literal mk_at_most_1(bool full, unsigned n, literal const* xs, literal_vector& ors) {
             TRACE("pb", tout << (full?"full":"partial") << " ";
                   for (unsigned i = 0; i < n; ++i) tout << xs[i] << " ";
                   tout << "\n";);
 
-            if (false && !full && n >= 4) {
-                return mk_at_most_1_bimander(n, xs);
+            if (n >= 4 && false) {
+                return mk_at_most_1_bimander(full, n, xs, ors);
             }
             literal_vector in(n, xs);
             literal result = fresh();
@@ -301,17 +331,14 @@ Notes:
             ands.push_back(result);
             while (!in.empty()) {
                 ors.reset();
-                unsigned i = 0;
                 unsigned n = in.size();
                 if (n + 1 == inc_size) ++inc_size;
-                bool last = n <= inc_size;
-                for (; i + inc_size < n; i += inc_size) {                    
-                    mk_at_most_1_small(full, last, inc_size, in.c_ptr() + i, result, ands, ors);
+                for (unsigned i = 0; i < n; i += inc_size) {       
+                    unsigned inc = std::min(n - i, inc_size);
+                    mk_at_most_1_small(full, inc, in.c_ptr() + i, result, ands);
+                    ors.push_back(mk_or(inc, in.c_ptr() + i));
                 }
-                if (i < n) {
-                    mk_at_most_1_small(full, last, n - i, in.c_ptr() + i, result, ands, ors);
-                }
-                if (last) {
+                if (n <= inc_size) {
                     break;
                 }
                 in.reset();
@@ -322,19 +349,40 @@ Notes:
             }
             return result;
         }
+#else
+        literal mk_at_most_1(bool full, unsigned n, literal const* xs, literal_vector& ors) {
+            TRACE("pb", tout << (full?"full":"partial") << " ";
+                  for (unsigned i = 0; i < n; ++i) tout << xs[i] << " ";
+                  tout << "\n";);
 
-        void mk_at_most_1_small(bool full, bool last, unsigned n, literal const* xs, literal result, literal_vector& ands, literal_vector& ors) {
+            literal_vector in(n, xs);
+            unsigned inc_size = 4;
+            literal_vector ands;
+            while (!in.empty()) {
+                ors.reset();
+                unsigned i = 0;
+                unsigned n = in.size();
+                if (n + 1 == inc_size) ++inc_size;
+                for (; i < n; i += inc_size) {                    
+                    unsigned inc = std::min(inc_size, n - i);
+                    ands.push_back(mk_at_most_1_small(inc, in.c_ptr() + i));
+                    ors.push_back(mk_or(inc, in.c_ptr() + i));
+                }
+                if (n <= inc_size) {
+                    break;
+                }
+                in.reset();
+                in.append(ors);
+            }
+            return mk_and(ands);
+        }
+
+#endif
+        void mk_at_most_1_small(bool full, unsigned n, literal const* xs, literal result, literal_vector& ands) {
             SASSERT(n > 0);
             if (n == 1) {
-                ors.push_back(xs[0]);                
                 return;
             }
-            literal ex = fresh();
-            mk_or_implies(ex, n, xs);
-            if (full) {
-                mk_implies_or(ex, n, xs);
-            }
-            ors.push_back(ex);                
             
             // result => xs[0] + ... + xs[n-1] <= 1
             for (unsigned i = 0; i < n; ++i) {
@@ -358,19 +406,75 @@ Notes:
             }
         }
 
-        literal mk_at_most_1_bimander(unsigned n, literal const* xs) {
+        literal mk_at_most_1_small(unsigned n, literal const* xs) {
+            SASSERT(n > 0);
+            if (n == 1) {
+                return ctx.mk_true();
+            }
+
+            
+#if 0
+            literal result = fresh();
+
+            // result => xs[0] + ... + xs[n-1] <= 1
+            for (unsigned i = 0; i < n; ++i) {
+                for (unsigned j = i + 1; j < n; ++j) {
+                    add_clause(ctx.mk_not(result), ctx.mk_not(xs[i]), ctx.mk_not(xs[j]));
+                }
+            }            
+
+            // xs[0] + ... + xs[n-1] <= 1 => result
+            for (unsigned i = 0; i < n; ++i) {
+                literal_vector lits;
+                lits.push_back(result);
+                for (unsigned j = 0; j < n; ++j) {
+                    if (j != i) lits.push_back(xs[j]);
+                }
+                add_clause(lits);
+            }
+
+            return result;
+#endif
+#if 1
+            // r <=> and( or(!xi,!xj))
+            // 
+            literal_vector ands;
+            for (unsigned i = 0; i < n; ++i) {
+                for (unsigned j = i + 1; j < n; ++j) {
+                    ands.push_back(mk_or(ctx.mk_not(xs[i]), ctx.mk_not(xs[j])));
+                }                
+            }
+            return mk_and(ands);
+#else
+            // r <=> or (and !x_{j != i})
+
+            literal_vector ors;
+            for (unsigned i = 0; i < n; ++i) {
+                literal_vector ands;
+                for (unsigned j = 0; j < n; ++j) {
+                    if (j != i) {
+                        ands.push_back(ctx.mk_not(xs[j]));
+                    }
+                }                
+                ors.push_back(mk_and(ands));
+            }
+            return mk_or(ors);
+            
+#endif
+        }
+
+        
+        // 
+
+        literal mk_at_most_1_bimander(bool full, unsigned n, literal const* xs, literal_vector& ors) {
             literal_vector in(n, xs);
             literal result = fresh();
             unsigned inc_size = 2;
-            bool last = false;
-            bool full = false;
-            literal_vector ors, ands;
-            unsigned i = 0;
-            for (; i + inc_size < n; i += inc_size) {                    
-                mk_at_most_1_small(full, last, inc_size, in.c_ptr() + i, result, ands, ors);
-            }
-            if (i < n) {
-                mk_at_most_1_small(full, last, n - i, in.c_ptr() + i, result, ands, ors);
+            literal_vector ands;
+            for (unsigned i = 0; i < n; i += inc_size) {                    
+                unsigned inc = std::min(n - i, inc_size);
+                mk_at_most_1_small(full, inc, in.c_ptr() + i, result, ands);
+                ors.push_back(mk_or(inc, in.c_ptr() + i));
             }
             
             unsigned nbits = 0;
@@ -381,7 +485,7 @@ Notes:
             for (unsigned k = 0; k < nbits; ++k) {
                 bits.push_back(fresh());
             }
-            for (i = 0; i < ors.size(); ++i) {
+            for (unsigned i = 0; i < ors.size(); ++i) {
                 for (unsigned k = 0; k < nbits; ++k) {
                     bool bit_set = (i & (static_cast<unsigned>(1 << k))) != 0;
                     add_clause(ctx.mk_not(result), ctx.mk_not(ors[i]), bit_set ? bits[k] : ctx.mk_not(bits[k]));