From 806690571e06f12f21239dd8d68d1786ec9cec4e Mon Sep 17 00:00:00 2001
From: Miguel Neves <t-mineve@microsoft.com>
Date: Tue, 17 Oct 2017 13:15:34 -0700
Subject: [PATCH] Lookahead clause size optimization. Fixed some missing
 propagations

---
 src/sat/sat_lookahead.cpp | 164 ++++++++++++++++++++++++++------------
 src/sat/sat_lookahead.h   |   6 +-
 2 files changed, 117 insertions(+), 53 deletions(-)

diff --git a/src/sat/sat_lookahead.cpp b/src/sat/sat_lookahead.cpp
index f5d0910ab..edb881aa4 100644
--- a/src/sat/sat_lookahead.cpp
+++ b/src/sat/sat_lookahead.cpp
@@ -155,7 +155,7 @@ namespace sat {
                 if (is_stamped(~w)) {
                     // u \/ v, ~v \/ w, u \/ ~w => u is unit
                     TRACE("sat", tout << "tc1: " << u << "\n";);
-                    assign(u);        
+                    propagated(u);
                     return false;
                 }
                 if (m_num_tc1 < m_config.m_tc1_limit) {
@@ -181,14 +181,14 @@ namespace sat {
         set_bstamps(~u);
         if (is_stamped(~v)) {         
             TRACE("sat", tout << "try_add_binary: " << u << "\n";);       
-            assign(u);        // u \/ ~v, u \/ v => u is a unit literal
+            propagated(u);        // u \/ ~v, u \/ v => u is a unit literal
         }
         else if (!is_stamped(v) && add_tc1(u, v)) {
             // u \/ v is not in index
             set_bstamps(~v);
             if (is_stamped(~u)) {
-                TRACE("sat", tout << "try_add_binary: " << v << "\n";);       
-                assign(v);    // v \/ ~u, u \/ v => v is a unit literal
+                TRACE("sat", tout << "try_add_binary: " << v << "\n";);
+                propagated(v);    // v \/ ~u, u \/ v => v is a unit literal
             }
             else if (add_tc1(v, u)) {
                 update_prefix(u);
@@ -407,6 +407,37 @@ namespace sat {
         return true;
     }
 
+    bool lookahead::missed_propagation() const {
+        for (literal l1 : m_trail) {
+            SASSERT(is_true(l1));
+            for (literal l2 : m_binary[l1.index()]) {
+                if (is_undef(l2)) return true;
+            }
+            unsigned sz = m_ternary_count[(~l1).index()];
+            for (binary b : m_ternary[(~l1).index()]) {
+                if (sz-- == 0) break;
+                if ((is_false(b.m_u) && is_undef(b.m_v)) || (is_false(b.m_v) && is_undef(b.m_u)))
+                    return true;
+            }
+        }
+        for (nary * n : m_nary_clauses) {
+            if (n->size() == 1 && !is_true(n->get_head())) {
+                for (literal lit : *n) {
+                    if (is_undef(lit)) return true;
+                }
+            }
+        }
+        return false;
+    }
+
+    bool lookahead::missed_conflict() const {
+        if (inconsistent()) return false;
+        for (nary * n : m_nary_clauses) {
+            if (n->size() == 0) return true;
+        }
+        return false;
+    }
+
     void lookahead::init_pre_selection(unsigned level) {
         switch (m_config.m_reward_type) {
         case ternary_reward: {
@@ -1098,11 +1129,19 @@ namespace sat {
     }
 
     void lookahead::lookahead_backtrack() {
-        while (!m_trail.empty() && is_undef(m_trail.back())) {
+        literal lit = null_literal;
+        while (!m_trail.empty() && is_undef((lit = m_trail.back()))) {
+            if (m_qhead == m_trail.size()) {
+                unsigned sz = m_nary_count[(~lit).index()];
+                for (nary* n : m_nary[(~lit).index()]) {
+                    if (sz-- == 0) break;
+                    n->inc_size();
+                }
+                --m_qhead;
+            }
             m_trail.pop_back();
         }
         SASSERT(m_trail_lim.empty() || m_trail.size() >= m_trail_lim.back());
-        m_qhead = std::min(m_qhead, m_trail.size());
     }
 
     // 
@@ -1137,14 +1176,15 @@ namespace sat {
     lbool lookahead::propagate_ternary(literal l1, literal l2) {
         if (is_fixed(l1)) {
             if (is_false(l1)) {
-                if (is_undef(l2)) {
-                    propagated(l2);
-                }
-                else if (is_false(l2)) {
+                if (is_false(l2)) {
                     TRACE("sat", tout << l1 << " " << l2 << " " << "\n";);
                     set_conflict();
+                    return l_false;
                 }
-                return l_false;
+                else if (is_undef(l2)) {
+                    propagated(l2);
+                }
+                return l_true;
             }
             else {
                 return l_true;
@@ -1298,10 +1338,11 @@ namespace sat {
         unsigned sz = m_nary_count[(~l).index()];
         literal lit;
         SASSERT(m_search_mode == lookahead_mode::searching);
-        for (nary * n : m_nary[(~l).index()]) {
+        for (nary* n : m_nary[(~l).index()]) {
             if (sz-- == 0) break;
             unsigned len = n->dec_size();
-            if (m_inconsistent) continue;
+            if (is_true(n->get_head())) continue;
+            if (inconsistent()) continue;
             if (len <= 1) continue; // already processed
             // find the two unassigned literals, if any
             if (len == 2) {
@@ -1357,53 +1398,35 @@ namespace sat {
     void lookahead::propagate_clauses_lookahead(literal l) {
         // clauses where l is negative
         unsigned sz = m_nary_count[(~l).index()];
-        literal lit;
         SASSERT(m_search_mode == lookahead_mode::lookahead1 ||
                 m_search_mode == lookahead_mode::lookahead2);
         
         for (nary* n : m_nary[(~l).index()]) {
             if (sz-- == 0) break;
-
-            if (is_true(n->get_head())) {
-                continue;
-            }
-            literal l1 = null_literal;
-            literal l2 = null_literal;
-            bool skip_clause = false;
-            unsigned nonfixed = 0;
-            for (literal lit : *n) {
-                if (!is_fixed(lit)) {
-                    ++nonfixed;
-                    if (l1 == null_literal) {
-                        l1 = lit;
-                    } 
-                    else if (l2 == null_literal) {
-                        l2 = lit;
+            unsigned nonfixed = n->dec_size();
+            if (is_true(n->get_head())) continue;
+            if (inconsistent()) continue;
+            if (nonfixed <= 1) {
+                bool found_conflict = true;
+                for (literal lit : *n) {
+                    if (!is_fixed(lit)) {
+                        propagated(lit);
+                        found_conflict = false;
+                        break;
                     }
-                    else if (m_search_mode == lookahead_mode::lookahead2) {
-                        skip_clause = true;
+                    else if (is_true(lit)) {
+                        n->set_head(lit);
+                        found_conflict = false;
                         break;
                     }
                 }
-                else if (is_true(lit)) {
-                    n->set_head(lit);
-                    skip_clause = true;
-                    break;
+                if (found_conflict) {
+                    set_conflict();
+                    continue;
                 }
             }
-            if (skip_clause) {
-                // skip, the clause 
-            }
-            else if (l1 == null_literal) {
-                set_conflict();
-                return;
-            }
-            else if (l2 == null_literal) {
-                propagated(l1);
-            }
-            else {
+            else if (m_search_mode == lookahead_mode::lookahead1) {
                 SASSERT(nonfixed >= 2);
-                SASSERT(m_search_mode == lookahead_mode::lookahead1);
                 switch (m_config.m_reward_type) {
                 case heule_schur_reward: {
                     double to_add = 0;
@@ -1418,9 +1441,35 @@ namespace sat {
                 case heule_unit_reward:
                     m_lookahead_reward += pow(0.5, nonfixed);
                     break;
+                case march_cu_reward:
+                    m_lookahead_reward += 3.3 * pow(0.5, nonfixed - 2);
+                    break;
                 case ternary_reward:
                     if (nonfixed == 2) {
-                        m_lookahead_reward += (*m_heur)[l1.index()] * (*m_heur)[l2.index()];
+                        literal l1 = null_literal;
+                        literal l2 = null_literal;
+                        for (literal lit : *n) {
+                            if (!is_fixed(lit)) {
+                                if (l1 == null_literal) {
+                                    l1 = lit;
+                                }
+                                else {
+                                    SASSERT(l2 != null_literal);
+                                    l2 = lit;
+                                    break;
+                                }
+                            }
+                        }
+                        if (l1 == null_literal) {
+                            set_conflict();
+                            continue;
+                        }
+                        else if (l2 == null_literal) {
+                            propagated(l1);
+                        }
+                        else {
+                            m_lookahead_reward += (*m_heur)[l1.index()] * (*m_heur)[l2.index()];
+                        }
                     }
                     else {
                         m_lookahead_reward += (double)0.001;            
@@ -1431,6 +1480,14 @@ namespace sat {
                 }
             }
         }
+        // clauses where l is positive:
+        sz = m_nary_count[l.index()];
+        for (nary* n : m_nary[l.index()]) {
+            if (sz-- == 0) break;
+            if (m_stamp[l.var()] > m_stamp[n->get_head().var()]) {
+                n->set_head(l);
+            }
+        }
     }
 
     void lookahead::remove_clause_at(literal l, nary& n) {
@@ -1567,9 +1624,9 @@ namespace sat {
     void lookahead::propagate_binary(literal l) {
         literal_vector const& lits = m_binary[l.index()];
         TRACE("sat", tout << l << " => " << lits << "\n";);
-        for (literal l : lits) {
+        for (literal lit : lits) {
             if (inconsistent()) break;
-            assign(l);
+            assign(lit);
         }
     }
 
@@ -1584,6 +1641,8 @@ namespace sat {
             propagate_clauses(m_trail[m_qhead++]);
         }
         SASSERT(m_qhead == m_trail.size() || (inconsistent() && m_qhead < m_trail.size()));
+        //SASSERT(!missed_conflict());
+        //SASSERT(inconsistent() || !missed_propagation());
         TRACE("sat_verbose", display(tout << scope_lvl() << " " << (inconsistent()?"unsat":"sat") << "\n"););
     }
 
@@ -1600,6 +1659,7 @@ namespace sat {
                 checkpoint();
                 literal lit = m_lookahead[i].m_lit;
                 if (lit == last_changed) {
+                    SASSERT(!change);
                     break;
                 }
                 if (scope_lvl() == 1) {
@@ -1812,6 +1872,7 @@ namespace sat {
         lookahead_backtrack();
         assign(l);
         propagate();
+        //SASSERT(!inconsistent());
         unsigned old_sz = m_trail.size();
         bool change = true;
         literal last_changed = null_literal;
@@ -1847,6 +1908,7 @@ namespace sat {
                     propagate();
                     change = true;
                     last_changed = lit;
+                    m_wstack.push_back(~lit);
                 }
             }
             base += 2 * m_lookahead.size();
diff --git a/src/sat/sat_lookahead.h b/src/sat/sat_lookahead.h
index c9ce7d946..d30485254 100644
--- a/src/sat/sat_lookahead.h
+++ b/src/sat/sat_lookahead.h
@@ -90,7 +90,7 @@ namespace sat {
                 m_min_cutoff = 30;
                 m_preselect = false;
                 m_level_cand = 600;
-                m_delta_rho = (double)0.25;
+                m_delta_rho = (double)0.85;
                 m_dl_max_iterations = 2;
                 m_tc1_limit = 10000000;
                 m_reward_type = ternary_reward;
@@ -340,6 +340,8 @@ namespace sat {
         std::ostream& display_candidates(std::ostream& out) const;
         bool is_unsat() const;
         bool is_sat() const;
+        bool missed_propagation() const;
+        bool missed_conflict() const;
         void init_pre_selection(unsigned level);
         void ensure_H(unsigned level);
         void h_scores(svector<double>& h, svector<double>& hp);
@@ -503,7 +505,7 @@ namespace sat {
         unsigned do_double(literal l, unsigned& base);
         unsigned double_look(literal l, unsigned& base);
         void set_conflict() { TRACE("sat", tout << "conflict\n";); m_inconsistent = true; }
-        bool inconsistent() { return m_inconsistent; }
+        bool inconsistent() const { return m_inconsistent; }
 
         unsigned scope_lvl() const { return m_trail_lim.size(); }