From b87405cc9274b10f486de02e8f34c806fb579635 Mon Sep 17 00:00:00 2001
From: Nikolaj Bjorner <nbjorner@microsoft.com>
Date: Mon, 18 Jan 2021 16:51:34 -0800
Subject: [PATCH] tune user-pop

---
 src/sat/sat_gc.cpp     |  58 ++++++++++++++
 src/sat/sat_solver.cpp | 172 +++++++++++++----------------------------
 src/sat/sat_solver.h   |   8 +-
 3 files changed, 113 insertions(+), 125 deletions(-)

diff --git a/src/sat/sat_gc.cpp b/src/sat/sat_gc.cpp
index 96d0f495a..2179a324f 100644
--- a/src/sat/sat_gc.cpp
+++ b/src/sat/sat_gc.cpp
@@ -423,6 +423,64 @@ namespace sat {
         return true;
     }
 
+    void solver::gc_vars(bool_var max_var) {
+        init_visited();
+        m_aux_literals.reset();
+        auto gc_watch = [&](literal lit) {
+            auto& wl1 = get_wlist(lit);
+            for (auto w : get_wlist(lit)) {
+                if (w.is_binary_clause() && w.get_literal().var() < max_var && !is_visited(w.get_literal())) {
+                    m_aux_literals.push_back(w.get_literal());
+                    mark_visited(w.get_literal());
+                }
+            }
+            wl1.reset();
+        };
+        for (unsigned v = max_var; v < num_vars(); ++v) {
+            gc_watch(literal(v, false));
+            gc_watch(literal(v, true));
+        }
+
+        for (literal lit : m_aux_literals) {
+            auto& wl2 = get_wlist(~lit);
+            unsigned j = 0;
+            for (auto w2 : wl2) 
+                if (!w2.is_binary_clause() || w2.get_literal().var() < max_var)
+                    wl2[j++] = w2;
+            wl2.shrink(j);                        
+        }
+        m_aux_literals.reset();
+
+        auto gc_clauses = [&](ptr_vector<clause>& clauses) {
+            unsigned j = 0;
+            for (clause* c : clauses) {
+                bool should_remove = false;
+                for (auto lit : *c) 
+                    should_remove |= lit.var() >= max_var;
+                if (should_remove) {
+                    SASSERT(!c->on_reinit_stack());
+                    detach_clause(*c);
+                    del_clause(*c);
+                }
+                else {
+                    clauses[j++] = c;
+                }
+            }
+            clauses.shrink(j);
+        };
+        gc_clauses(m_learned);
+        gc_clauses(m_clauses);
+        
+        unsigned j = 0;
+        for (literal lit : m_trail) {
+            SASSERT(lvl(lit) == 0);
+            if (lit.var() < max_var)
+                m_trail[j++] = lit;
+        }
+        m_trail.shrink(j);
+        shrink_vars(max_var);
+    }
+
 #if 0
     void solver::gc_reinit_stack(unsigned num_scopes) {
         SASSERT (!at_base_lvl());
diff --git a/src/sat/sat_solver.cpp b/src/sat/sat_solver.cpp
index a280ee94d..ba1309d0c 100644
--- a/src/sat/sat_solver.cpp
+++ b/src/sat/sat_solver.cpp
@@ -558,22 +558,27 @@ namespace sat {
         m_watches[(~c[0]).index()].push_back(watched(c[1], c[2]));
         m_watches[(~c[1]).index()].push_back(watched(c[0], c[2]));
         m_watches[(~c[2]).index()].push_back(watched(c[0], c[1]));
-        if (!at_base_lvl()) {
-            if (value(c[1]) == l_false && value(c[2]) == l_false) {
-                m_stats.m_ter_propagate++;
-                assign(c[0], justification(std::max(lvl(c[1]), lvl(c[2])), c[1], c[2]));
-                reinit = !c.is_learned();
-            }
-            else if (value(c[0]) == l_false && value(c[2]) == l_false) {
-                m_stats.m_ter_propagate++;
-                assign(c[1], justification(std::max(lvl(c[0]), lvl(c[2])), c[0], c[2]));
-                reinit = !c.is_learned();
-            }
-            else if (value(c[0]) == l_false && value(c[1]) == l_false) {
-                m_stats.m_ter_propagate++;
-                assign(c[2], justification(std::max(lvl(c[0]), lvl(c[1])), c[0], c[1]));
-                reinit = !c.is_learned();
-            }
+        if (!at_base_lvl()) 
+            reinit = propagate_ter_clause(c);        
+        return reinit;
+    }
+
+    bool solver::propagate_ter_clause(clause& c) {
+        bool reinit = false;
+        if (value(c[1]) == l_false && value(c[2]) == l_false) {
+            m_stats.m_ter_propagate++;
+            assign(c[0], justification(std::max(lvl(c[1]), lvl(c[2])), c[1], c[2]));
+            reinit = !c.is_learned();
+        }
+        else if (value(c[0]) == l_false && value(c[2]) == l_false) {
+            m_stats.m_ter_propagate++;
+            assign(c[1], justification(std::max(lvl(c[0]), lvl(c[2])), c[0], c[2]));
+            reinit = !c.is_learned();
+        }
+        else if (value(c[0]) == l_false && value(c[1]) == l_false) {
+            m_stats.m_ter_propagate++;
+            assign(c[2], justification(std::max(lvl(c[0]), lvl(c[1])), c[0], c[1]));
+            reinit = !c.is_learned();
         }
         return reinit;
     }
@@ -3459,8 +3464,8 @@ namespace sat {
     }
 
     void solver::pop_vars(unsigned num_scopes) {
-        integrity_checker check(*this);
-        check.check_reinit_stack();
+        //integrity_checker check(*this);
+        //check.check_reinit_stack();
         m_vars_to_reinit.reset();
         unsigned old_num_vars = m_vars_lim.pop(num_scopes);
         if (old_num_vars == m_active_vars.size())
@@ -3471,7 +3476,7 @@ namespace sat {
 
         gc_reinit_stack(num_scopes);        
 
-        check.check_reinit_stack();
+        // check.check_reinit_stack();
         init_visited();
         unsigned old_sz = m_scopes[new_lvl].m_clauses_to_reinit_lim;
         for (unsigned i = m_clauses_to_reinit.size(); i-- > old_sz; ) {
@@ -3575,9 +3580,8 @@ namespace sat {
         SASSERT(old_sz <= m_trail.size());
         SASSERT(m_replay_assign.empty());
         unsigned i = m_trail.size();
-        while (i != old_sz) {
-            --i;
-            literal l                  = m_trail[i];
+        for (unsigned i = m_trail.size(); i-- > old_sz; ) {
+            literal l  = m_trail[i];
             bool_var v = l.var();
             if (lvl(v) <= new_lvl) {
                 m_replay_assign.push_back(l);
@@ -3618,6 +3622,15 @@ namespace sat {
             }
             else {
                 clause & c = *(cw.get_clause());
+                if (ENABLE_TERNARY && c.size() == 3) {
+                    if (!at_base_lvl() && propagate_ter_clause(c))
+                        m_clauses_to_reinit[j++] = cw;                
+                    else if (has_variables_to_reinit(c))
+                        m_clauses_to_reinit[j++] = cw;
+                    else 
+                        c.set_reinit_stack(false);
+                    continue;
+                }
                 detach_clause(c);
                 attach_clause(c, reinit);
                 if (!at_base_lvl() && reinit) 
@@ -3639,9 +3652,10 @@ namespace sat {
 
     void solver::user_push() {
         pop_to_base_level();
-        literal lit;
+        m_free_var_freeze.push_back(m_free_vars);
+        m_free_vars.reset(); // resetting free_vars forces new variables to be assigned above new_v
         bool_var new_v = mk_var(true, false);
-        lit = literal(new_v, false);
+        literal lit = literal(new_v, false);
         m_user_scope_literals.push_back(lit);
         m_cut_simplifier = nullptr; // for simplicity, wipe it out
         if (m_ext)
@@ -3649,108 +3663,26 @@ namespace sat {
         TRACE("sat", tout << "user_push: " << lit << "\n";);
     }
 
-    void solver::gc_lit(clause_vector &clauses, literal lit) {
-        unsigned j = 0;
-        for (unsigned i = 0; i < clauses.size(); ++i) {
-            clause & c = *(clauses[i]);
-            if (c.contains(lit) || c.contains(~lit)) {
-                detach_clause(c);
-                del_clause(c);
-            }
-            else {
-                clauses[j] = &c;
-                ++j;
-            }
-        }
-        clauses.shrink(j);
-    }
-
-    void solver::gc_bin(literal lit) {
-        bool_var v = lit.var();
-        for (watch_list& wlist : m_watches) {
-            watch_list::iterator it  = wlist.begin();
-            watch_list::iterator it2 = wlist.begin();
-            watch_list::iterator end = wlist.end();
-            for (; it != end; ++it) {
-                if (it->is_binary_clause() && it->get_literal().var() == v) {
-                    // skip
-                }
-                else {
-                    *it2 = *it;
-                    ++it2;
-                }
-            }
-            wlist.set_end(it2);
-        }
-    }
-
-    bool_var solver::max_var(bool redundant, bool_var v) {
-        m_user_bin_clauses.reset();
-        collect_bin_clauses(m_user_bin_clauses, redundant, false);
-        for (unsigned i = 0; i < m_user_bin_clauses.size(); ++i) {
-            literal l1 = m_user_bin_clauses[i].first;
-            literal l2 = m_user_bin_clauses[i].second;
-            if (l1.var() > v) v = l1.var();
-            if (l2.var() > v) v = l2.var();
-        }
-        return v;
-    }
-
-    bool_var solver::max_var(clause_vector& clauses, bool_var v) {
-        for (clause* cp : clauses) 
-            for (auto it = cp->begin(), end = cp->end(); it != end; ++it) {
-                if (it->var() > v) 
-                    v = it->var();
-            }
-        return v;
-    }
-
-    void solver::gc_var(bool_var v) {
-        bool_var w = max_var(m_learned, v);
-        w = max_var(m_clauses, w);
-        w = max_var(true, w);
-        w = max_var(false, w);
-        v = m_mc.max_var(w);
-        for (literal lit : m_trail) {
-            w = std::max(w, lit.var());
-        }
-        if (m_ext) {
-            w = m_ext->max_var(w);
-        }
-        v = w + 1;
-        
-        // v is an index of a variable that does not occur in solver state.
-        if (v < m_justification.size()) {
-            shrink_vars(v);
-        }
-    }
-
     void solver::user_pop(unsigned num_scopes) {
+        unsigned old_sz = m_user_scope_literals.size() - num_scopes;
+        bool_var max_var = m_user_scope_literals[old_sz].var();
+        m_user_scope_literals.shrink(old_sz);
+
         pop_to_base_level();
-        TRACE("sat", display(tout););
         if (m_ext)
             m_ext->user_pop(num_scopes);
-        while (num_scopes > 0) {
-            literal lit = m_user_scope_literals.back();
-            m_user_scope_literals.pop_back();
-            get_wlist(lit).reset();
-            get_wlist(~lit).reset();
 
-            gc_lit(m_learned, lit);
-            gc_lit(m_clauses, lit);
-            gc_bin(lit);
-            TRACE("sat", tout << "gc: " << lit << "\n"; display(tout););
-            --num_scopes;
-            for (unsigned i = 0; i < m_trail.size(); ++i) {
-                if (m_trail[i] == lit) {
-                    TRACE("sat", tout << m_trail << "\n";);
-                    unassign_vars(i, 0);
-                    break;
-                }
-            }
-            gc_var(lit.var());            
-        }
+        gc_vars(max_var);
+        TRACE("sat", display(tout););
+
         m_qhead = 0;
+        unsigned j = 0;
+        for (bool_var v : m_free_vars) 
+            if (v < max_var)
+                m_free_vars[j++] = v;
+        m_free_vars.shrink(j);
+        m_free_vars.append(m_free_var_freeze[old_sz]); 
+        m_free_var_freeze.shrink(old_sz);
         scoped_suspend_rlimit _sp(m_rlimit);
         propagate(false);
     }
diff --git a/src/sat/sat_solver.h b/src/sat/sat_solver.h
index 5ada5ef88..6725057c3 100644
--- a/src/sat/sat_solver.h
+++ b/src/sat/sat_solver.h
@@ -287,6 +287,7 @@ namespace sat {
         bool propagate_bin_clause(literal l1, literal l2);
         clause * mk_ter_clause(literal * lits, status st);
         bool attach_ter_clause(clause & c, status st);
+        bool propagate_ter_clause(clause& c);
         clause * mk_nary_clause(unsigned num_lits, literal * lits, status st);
         bool has_variables_to_reinit(clause const& c) const;
         bool has_variables_to_reinit(literal l1, literal l2) const;
@@ -651,14 +652,11 @@ namespace sat {
         void reinit_clauses(unsigned old_sz);
 
         literal_vector m_user_scope_literals;
+        vector<svector<bool_var>> m_free_var_freeze;
         literal_vector m_aux_literals;
         svector<bin_clause> m_user_bin_clauses;
-        void gc_lit(clause_vector& clauses, literal lit);
-        void gc_bin(literal lit);
-        void gc_var(bool_var v);
 
-        bool_var max_var(clause_vector& clauses, bool_var v);
-        bool_var max_var(bool learned, bool_var v);
+        void gc_vars(bool_var max_var);
 
         // -----------------------
         //