From 25d45a3500cac1291198b00da0f83b147ef08914 Mon Sep 17 00:00:00 2001
From: Nikolaj Bjorner <nbjorner@microsoft.com>
Date: Tue, 28 Feb 2023 17:40:00 -0800
Subject: [PATCH] fixes and tests for arith-sls

---
 src/sat/sat_ddfw.cpp      |  28 +++---
 src/sat/sat_ddfw.h        |   5 +-
 src/sat/sat_solver.cpp    |  48 ++++-------
 src/sat/sat_solver.h      |  19 +---
 src/sat/sat_types.h       |  34 ++++++++
 src/sat/smt/arith_sls.cpp | 177 ++++++++++++++++++++++----------------
 src/sat/smt/arith_sls.h   |   7 +-
 7 files changed, 182 insertions(+), 136 deletions(-)

diff --git a/src/sat/sat_ddfw.cpp b/src/sat/sat_ddfw.cpp
index 5d80e5af4..ca274be51 100644
--- a/src/sat/sat_ddfw.cpp
+++ b/src/sat/sat_ddfw.cpp
@@ -62,13 +62,16 @@ namespace sat {
     void ddfw::check_with_plugin() {
         m_plugin->init_search();
         m_steps_since_progress = 0;
-        while (m_min_sz > 0 && m_steps_since_progress++ <= 150000) {
+        unsigned steps = 0;
+        while (m_min_sz > 0 && m_steps_since_progress++ <= 1500000) {
             if (should_reinit_weights()) do_reinit_weights();
+            else if (steps % 5000 == 0) shift_weights(), m_plugin->on_rescale();
+            else if (should_restart()) do_restart(), m_plugin->on_restart();
             else if (do_flip<true>());
             else if (do_literal_flip<true>());
-            else if (should_restart()) do_restart(), m_plugin->on_restart();
             else if (should_parallel_sync()) do_parallel_sync();
             else shift_weights(), m_plugin->on_rescale();
+            ++steps;
         }
         m_plugin->finish_search();
     }
@@ -135,7 +138,7 @@ namespace sat {
         if (sum_pos > 0) {
             double lim_pos = ((double) m_rand() / (1.0 + m_rand.max_value())) * sum_pos;                
             for (bool_var v : m_unsat_vars) {
-                r = uses_plugin ? plugin_reward(v) : reward(v);
+                r = uses_plugin && is_external(v) ? m_vars[v].m_last_reward : reward(v);
                 if (r > 0) {
                     lim_pos -= score(r);
                     if (lim_pos <= 0) 
@@ -472,9 +475,7 @@ namespace sat {
 
 
     void ddfw::save_best_values() {
-        if (m_unsat.empty()) 
-            save_model();
-        else if (m_unsat.size() < m_min_sz) {
+        if (m_unsat.size() < m_min_sz) {
             m_steps_since_progress = 0;
             if (m_unsat.size() < 50 || m_min_sz * 10 > m_unsat.size() * 11)
                 save_model();
@@ -489,13 +490,20 @@ namespace sat {
                 }
             }
         }
+
         unsigned h = value_hash();
+        unsigned occs = 0;
+        bool contains = m_models.find(h, occs);
         if (!m_models.contains(h)) {
-            for (unsigned v = 0; v < num_vars(); ++v) 
+            for (unsigned v = 0; v < num_vars(); ++v)
                 bias(v) += value(v) ? 1 : -1;
-            m_models.insert(h);
-            if (m_models.size() > m_config.m_max_num_models) 
-                m_models.erase(*m_models.begin());
+            if (m_models.size() > m_config.m_max_num_models)
+                m_models.erase(m_models.begin()->m_key);
+        }
+        m_models.insert(h, occs + 1);
+        if (occs > 100) {
+            m_restart_next = m_flips;            
+            m_models.erase(h);
         }
         m_min_sz = m_unsat.size();
     }
diff --git a/src/sat/sat_ddfw.h b/src/sat/sat_ddfw.h
index 8c4f9287f..988365285 100644
--- a/src/sat/sat_ddfw.h
+++ b/src/sat/sat_ddfw.h
@@ -98,6 +98,7 @@ namespace sat {
             var_info() {}
             bool     m_value = false;
             double   m_reward = 0;
+            double   m_last_reward = 0;
             unsigned m_make_count = 0;
             int      m_bias = 0;
             bool     m_external = false;
@@ -127,7 +128,7 @@ namespace sat {
         uint64_t         m_restart_next = 0, m_reinit_next = 0, m_parsync_next = 0;
         uint64_t         m_flips = 0, m_last_flips = 0, m_shifts = 0;
         unsigned         m_min_sz = 0, m_steps_since_progress = 0;
-        hashtable<unsigned, unsigned_hash, default_eq<unsigned>> m_models;
+        u_map<unsigned>  m_models;
         stopwatch        m_stopwatch;
 
         parallel*        m_par;
@@ -153,7 +154,7 @@ namespace sat {
 
         inline double reward(bool_var v) const { return m_vars[v].m_reward; }
 
-        inline double plugin_reward(bool_var v) const { return is_external(v) ? m_plugin->reward(v) : reward(v); }
+        inline double plugin_reward(bool_var v) { return is_external(v) ? (m_vars[v].m_last_reward = m_plugin->reward(v)) : reward(v); }
 
         void set_external(bool_var v) { m_vars[v].m_external = true; }
 
diff --git a/src/sat/sat_solver.cpp b/src/sat/sat_solver.cpp
index d8d68d262..5c1ed6dae 100644
--- a/src/sat/sat_solver.cpp
+++ b/src/sat/sat_solver.cpp
@@ -40,26 +40,6 @@ Revision History:
 
 namespace sat {
 
-    /**
-    * Special cases of kissat style general backoff calculation.
-    * The version here calculates
-    * limit := value*log(C)^2*n*log(n)
-    * (effort calculation in kissat is based on ticks not clauses)
-    *
-    * respectively
-    * limit := conflicts + value*log(C)^2*n*log(n)
-    */
-    void backoff::delta_effort(solver& s) {
-        count++;
-        unsigned d = value * count * log2(count + 1);
-        unsigned cl = log2(s.num_clauses() + 2);
-        limit = cl * cl * d;
-    }
-
-    void backoff::delta_conflicts(solver& s) {
-        delta_effort(s);
-        limit += s.m_conflicts_since_init;
-    }
 
     solver::solver(params_ref const & p, reslimit& l):
         solver_core(l),
@@ -1302,10 +1282,9 @@ namespace sat {
                 return l_undef;
             }
 
-            if (false && m_config.m_phase == PS_LOCAL_SEARCH && m_ext) {
-                IF_VERBOSE(0, verbose_stream() << "WARNING: local search with theories is in testing mode\n");
+            if (m_config.m_phase == PS_LOCAL_SEARCH && m_ext) {                
                 bounded_local_search();
-                exit(0);
+                // exit(0);
             }
 
             log_stats();
@@ -1367,7 +1346,7 @@ namespace sat {
 
     void solver::bounded_local_search() {
         if (m_ext) {
-            verbose_stream() << "bounded local search\n";
+            IF_VERBOSE(0, verbose_stream() << "WARNING: local search with theories is in testing mode\n");
             do_restart(true);
             lbool r = m_ext->local_search(m_best_phase);
             verbose_stream() << r << "\n";
@@ -1388,8 +1367,8 @@ namespace sat {
         m_local_search->set_seed(m_rand());
         scoped_rl.push_child(&(m_local_search->rlimit()));
 
-        m_backoffs.m_local_search.delta_effort(*this);
-        m_local_search->rlimit().push(m_backoffs.m_local_search.limit);
+        m_local_search_lim.inc(num_clauses());
+        m_local_search->rlimit().push(m_local_search_lim.limit);
 
         m_local_search->reinit(*this, m_best_phase);
         lbool r = m_local_search->check(_lits.size(), _lits.data(), nullptr);
@@ -1977,11 +1956,13 @@ namespace sat {
         m_search_sat_conflicts    = m_config.m_search_sat_conflicts;
         m_search_next_toggle      = m_search_unsat_conflicts;
         m_best_phase_size         = 0;
+
+        m_reorder.lo              = m_config.m_reorder_base;
+        m_rephase.base            = m_config.m_rephase_base;
         m_rephase_lim             = 0;
         m_rephase_inc             = 0;
-        m_reorder_lim             = m_config.m_reorder_base;
-        m_backoffs.m_local_search.value = 500;
-        m_reorder_inc             = 0;
+        m_local_search_lim.base   = 500;        
+
         m_conflicts_since_restart = 0;
         m_force_conflict_analysis = false;
         m_restart_threshold       = m_config.m_restart_initial;
@@ -2981,6 +2962,7 @@ namespace sat {
 
     bool solver::should_rephase() {
         return m_conflicts_since_init > m_rephase_lim;
+//        return m_rephase.should_apply(m_conflicts_since_init);
     }
 
     void solver::do_rephase() {
@@ -2994,7 +2976,7 @@ namespace sat {
         case PS_FROZEN:
             break;
         case PS_BASIC_CACHING:
-            switch (m_rephase_lim % 4) {
+            switch (m_rephase.count % 4) {
             case 0:
                 for (auto& p : m_phase) p = (m_rand() % 2) == 0;
                 break;
@@ -3031,10 +3013,11 @@ namespace sat {
         }
         m_rephase_inc += m_config.m_rephase_base;
         m_rephase_lim += m_rephase_inc;
+        m_rephase.inc(m_conflicts_since_init, num_clauses());
     }
 
     bool solver::should_reorder() {
-        return m_conflicts_since_init > m_reorder_lim;
+        return m_reorder.should_apply(m_conflicts_since_init);
     }
 
     void solver::do_reorder() {
@@ -3078,8 +3061,7 @@ namespace sat {
             update_activity(v, m_rand(10)/10.0);
         }
 #endif
-        m_reorder_inc += m_config.m_reorder_base;
-        m_reorder_lim += m_reorder_inc;
+        m_reorder.inc(m_conflicts_since_init, num_clauses());
     }
 
     void solver::updt_phase_counters() {
diff --git a/src/sat/sat_solver.h b/src/sat/sat_solver.h
index 703b36dd0..3a437855e 100644
--- a/src/sat/sat_solver.h
+++ b/src/sat/sat_solver.h
@@ -87,23 +87,10 @@ namespace sat {
     struct no_drat_params : public params_ref {
         no_drat_params() { set_bool("drat.disable", true); }
     };
-
-    struct backoff {
-        unsigned value = 1;
-        unsigned lo    = 0;
-        unsigned hi    = 0;
-        unsigned limit = 0;
-        unsigned count = 0;    
-        void delta_effort(solver& s);
-        void delta_conflicts(solver& s);
-    };
     
     class solver : public solver_core {
     public:
         struct abort_solver {};
-        struct backoffs {
-            backoff m_local_search;
-        };
     protected:
         enum search_state { s_sat, s_unsat };
 
@@ -172,11 +159,11 @@ namespace sat {
         unsigned                m_search_next_toggle;
         unsigned                m_phase_counter; 
         unsigned                m_best_phase_size;
-        backoffs                m_backoffs;
+        backoff                 m_local_search_lim;
         unsigned                m_rephase_lim;
         unsigned                m_rephase_inc;
-        unsigned                m_reorder_lim;
-        unsigned                m_reorder_inc;
+        backoff                 m_rephase;
+        backoff                 m_reorder;
         var_queue               m_case_split_queue;
         unsigned                m_qhead;
         unsigned                m_scope_lvl;
diff --git a/src/sat/sat_types.h b/src/sat/sat_types.h
index d5d457cb0..427b6fb70 100644
--- a/src/sat/sat_types.h
+++ b/src/sat/sat_types.h
@@ -136,6 +136,40 @@ namespace sat {
     std::ostream& operator<<(std::ostream& out, sat::status const& st);
     std::ostream& operator<<(std::ostream& out, sat::status_pp const& p);
 
+    /**
+     * Special cases of kissat style general backoff calculation.
+     * The version here calculates
+     * limit := value*log(C)^2*n*log(n)
+     * (effort calculation in kissat is based on ticks not clauses)
+     *
+     * respectively
+     * limit := conflicts + value*log(C)^2*n*log(n)
+     */
+    struct backoff {
+        unsigned base = 1;
+        unsigned lo = 0;
+        unsigned hi = UINT_MAX;
+        unsigned limit = 0;
+        unsigned count = 0;
+
+        bool should_apply(unsigned n) const { 
+            return limit <= n && lo <= n && n <= hi;
+        }
+
+        void inc(unsigned num_clauses) {
+            count++;
+            unsigned d = base * count * log2(count + 1);
+            unsigned cl = log2(num_clauses + 2);
+            limit = cl * cl * d;
+        }
+
+        void inc(unsigned num_conflicts, unsigned num_clauses) {
+            inc(num_clauses);
+            limit += num_conflicts;
+        }
+
+    };
+
 };
 
 
diff --git a/src/sat/smt/arith_sls.cpp b/src/sat/smt/arith_sls.cpp
index b32358a3c..4fe153289 100644
--- a/src/sat/smt/arith_sls.cpp
+++ b/src/sat/smt/arith_sls.cpp
@@ -29,42 +29,49 @@ namespace arith {
         m_terms.reset();
     }
 
-    void sls::log() {
-        IF_VERBOSE(2, verbose_stream() << "(sls :flips " << m_stats.m_num_flips << " :unsat " << unsat().size() << ")\n");
-    }
-
     void sls::save_best_values() {
         for (unsigned v = 0; v < s.get_num_vars(); ++v)
             m_vars[v].m_best_value = m_vars[v].m_value;
-
-        auto check_bool_var = [&](sat::bool_var bv) {
-            auto const* ineq = atom(bv);
-            if (!ineq)
-                return;
-            sat::literal lit(bv, !m_bool_search->get_value(bv));
-            int64_t d = dtt(lit.sign(), *ineq);
-            // verbose_stream() << "check " << lit << " " << *ineq << "\n";
-            if (is_true(lit) != (d == 0)) {
-                verbose_stream() << lit << " " << *ineq << "\n";
+        check_ineqs();   
+        if (unsat().size() == 1) {
+            auto idx = *unsat().begin();
+            verbose_stream() << idx << "\n";
+            auto const& c = *m_bool_search->m_clauses[idx].m_clause;
+            verbose_stream() << c << "\n";
+            for (auto lit : c) {
+                bool_var bv = lit.var();
+                ineq* i = atom(bv);
+                if (i)
+                    verbose_stream() << lit << ": " << *i << "\n";
             }
-            VERIFY(is_true(lit) == (d == 0));
-        };
-        for (unsigned v = 0; v < s.get_num_vars(); ++v) 
-            check_bool_var(v);        
+            verbose_stream() << "\n";
+        }
     }
 
     void sls::store_best_values() {
         // first compute assignment to terms
         // then update non-basic variables in tableau.
-        for (auto const& [t, v] : m_terms) {
+
+        if (!unsat().empty())
+            return;
+        
+        for (auto const& [t,v] : m_terms) {
             int64_t val = 0;
             lp::lar_term const& term = s.lp().get_term(t);
-            for (lp::lar_term::ival arg : term) {
+            for (lp::lar_term::ival const& arg : term) {
                 auto t2 = s.lp().column2tv(arg.column());
                 auto w = s.lp().local_to_external(t2.id());
                 val += to_numeral(arg.coeff()) * m_vars[w].m_best_value;
             }
-            update(v, val);
+            if (v == 52) {
+                verbose_stream() << "update v" << v << " := " << val << "\n";
+                for (lp::lar_term::ival const& arg : term) {
+                    auto t2 = s.lp().column2tv(arg.column());
+                    auto w = s.lp().local_to_external(t2.id());
+                    verbose_stream() << "v" << w << " := " << m_vars[w].m_best_value << " * " << to_numeral(arg.coeff()) << "\n";
+                }
+            }
+            m_vars[v].m_best_value = val;
         }
 
         for (unsigned v = 0; v < s.get_num_vars(); ++v) {
@@ -80,16 +87,15 @@ namespace arith {
                 rational new_value_(new_value, rational::i64());
                 lp::impq val(new_value_, rational::zero());
                 s.lp().set_value_for_nbasic_column(vj.index(), val);
-                // TODO - figure out why this leads to unsound (unsat).
             }
         }
 
         lbool r = s.make_feasible();
         VERIFY (!unsat().empty() || r == l_true);    
-        if (unsat().empty()) {
+#if 0
+        if (unsat().empty()) 
             s.m_num_conflicts = s.get_config().m_arith_propagation_threshold;
-        }
-        verbose_stream() << "has changed " << s.m_solver->has_changed_columns() << "\n";
+#endif   
 
         auto check_bool_var = [&](sat::bool_var bv) {
             auto* ineq = m_bool_vars.get(bv, nullptr);
@@ -105,10 +111,10 @@ namespace arith {
                 return;
             switch (b->get_bound_kind()) {
             case lp_api::lower_t:
-                verbose_stream() << bv << " " << bound << " <= " << s.get_value(v) << "\n";
+                verbose_stream() << "v" << v << " " << bound << " <= " << s.get_value(v) << " " << m_vars[v].m_best_value << "\n";
                 break;
             case lp_api::upper_t:
-                verbose_stream() << bv << " " << bound << " >= " << s.get_value(v) << "\n";
+                verbose_stream() << "v" << v << " " << bound << " >= " << s.get_value(v) << " " << m_vars[v].m_best_value << "\n";
                 break;
             }
             int64_t value = 0;
@@ -117,6 +123,12 @@ namespace arith {
             }
             ineq->m_args_value = value;
             verbose_stream() << *ineq << " dtt " << dtt(false, *ineq) << " phase " << s.get_phase(bv) << " model " << m_bool_search->get_model()[bv] << "\n";
+            for (auto const& [coeff, v] : ineq->m_args) 
+                verbose_stream() << "v" << v << " := " << m_vars[v].m_best_value << "\n";
+            s.display(verbose_stream());
+            display(verbose_stream());
+            UNREACHABLE();
+            exit(0);
         };
 
         if (unsat().empty()) {
@@ -200,16 +212,16 @@ namespace arith {
         return dtt(sign, ineq.m_args_value + coeff * (new_value - old_value), ineq);
     }
 
-    bool sls::cm(bool sign, ineq const& ineq, var_t v, int64_t& new_value) {
+    bool sls::cm(bool old_sign, ineq const& ineq, var_t v, int64_t& new_value) {
         for (auto const& [coeff, w] : ineq.m_args) 
             if (w == v)
-                return cm(sign, ineq, v, coeff, new_value);        
+                return cm(old_sign, ineq, v, coeff, new_value);        
         return false;
     }
 
-    bool sls::cm(bool new_sign, ineq const& ineq, var_t v, int64_t coeff, int64_t& new_value) {
-        SASSERT(ineq.is_true() == new_sign);
-        VERIFY(ineq.is_true() == new_sign);
+    bool sls::cm(bool old_sign, ineq const& ineq, var_t v, int64_t coeff, int64_t& new_value) {
+        SASSERT(ineq.is_true() != old_sign);
+        VERIFY(ineq.is_true() != old_sign);
         auto bound = ineq.m_bound;
         auto argsv = ineq.m_args_value;
         bool solved = false;
@@ -239,7 +251,7 @@ namespace arith {
             return true;
         };
 
-        if (new_sign) {
+        if (!old_sign) {
             switch (ineq.m_op) {                
             case ineq_kind::LE:
                 // args <= bound -> args > bound
@@ -300,10 +312,10 @@ namespace arith {
         int64_t new_value;
         auto v = ineq.m_var_to_flip;
         if (v == UINT_MAX) {
-            // verbose_stream() << "no var to flip\n";
+            IF_VERBOSE(1, verbose_stream() << "no var to flip\n");
             return false;
         }
-        if (!cm(!sign, ineq, v, new_value)) {
+        if (!cm(sign, ineq, v, new_value)) {
             verbose_stream() << "no critical move for " << v << "\n";
             return false;
         }
@@ -316,16 +328,16 @@ namespace arith {
     // TODO - use cached dts instead of computed dts
     // cached dts has to be updated when the score of literals are updated.
     // 
-    double sls::dscore(var_t v, int64_t new_value) const {
-        verbose_stream() << "dscore\n";
+    double sls::dscore(var_t v, int64_t new_value) const {        
         double score = 0;
-#if 0
         auto const& vi = m_vars[v];
-        verbose_stream() << "dscore " << v << "\n";
-        for (auto const& [coeff, lit] : vi.m_literals) 
-            for (auto cl : m_bool_search->get_use_list(lit))              
-                score += (compute_dts(cl) - dts(cl, v, new_value)) * m_bool_search->get_weight(cl);    
-#endif
+        for (auto const& [coeff, bv] : vi.m_bool_vars) {
+            sat::literal lit(bv, false);
+            for (auto cl : m_bool_search->get_use_list(lit))
+                score += (compute_dts(cl) - dts(cl, v, new_value)) * m_bool_search->get_weight(cl);
+            for (auto cl : m_bool_search->get_use_list(~lit))
+                score += (compute_dts(cl) - dts(cl, v, new_value)) * m_bool_search->get_weight(cl);
+        }
         return score;
     }
 
@@ -341,12 +353,12 @@ namespace arith {
         int64_t old_value = vi.m_value;
         for (auto const& [coeff, bv] : vi.m_bool_vars) {
             auto const& ineq = *atom(bv);            
-            bool sign = !m_bool_search->value(bv);
-            int64_t dtt_old = dtt(sign, ineq);
-            int64_t dtt_new = dtt(sign, ineq, coeff, old_value, new_value);
+            bool old_sign = sign(bv);
+            int64_t dtt_old = dtt(old_sign, ineq);
+            int64_t dtt_new = dtt(old_sign, ineq, coeff, old_value, new_value);
             if ((dtt_old == 0) == (dtt_new == 0))
                 continue;
-            sat::literal lit(bv, sign);
+            sat::literal lit(bv, old_sign);
             if (dtt_old == 0) 
                 // flip from true to false
                 lit.neg();
@@ -408,14 +420,14 @@ namespace arith {
         auto old_value = vi.m_value;
         for (auto const& [coeff, bv] : vi.m_bool_vars) {
             auto& ineq = *atom(bv);
-            bool sign = !m_bool_search->value(bv);
-            sat::literal lit(bv, sign);
+            bool old_sign = sign(bv);
+            sat::literal lit(bv, old_sign);
             SASSERT(is_true(lit));            
             ineq.m_args_value += coeff * (new_value - old_value);
-            int64_t dtt_new = dtt(sign, ineq);
+            int64_t dtt_new = dtt(old_sign, ineq);
             if (dtt_new != 0)
                 m_bool_search->flip(bv);                                                 
-            SASSERT(dtt(!m_bool_search->value(bv), ineq) == 0);
+            SASSERT(dtt(sign(bv), ineq) == 0);
         }
         vi.m_value = new_value;
     }
@@ -451,7 +463,7 @@ namespace arith {
     void sls::add_args(sat::bool_var bv, ineq& ineq, lp::tv t, theory_var v, int64_t sign) {
         if (t.is_term()) {
             lp::lar_term const& term = s.lp().get_term(t);
-
+            m_terms.push_back({t,v});
             for (lp::lar_term::ival arg : term) {
                 auto t2 = s.lp().column2tv(arg.column());
                 auto w = s.lp().local_to_external(t2.id());
@@ -479,6 +491,7 @@ namespace arith {
 
             auto& ineq = new_ineq(op, to_numeral(bound));
 
+
             add_args(bv, ineq, t, b->get_var(), should_minus ? -1 : 1);
             m_bool_vars.set(bv, &ineq);
             m_bool_search->set_external(bv);
@@ -516,7 +529,7 @@ namespace arith {
     }
 
     void sls::flip(sat::bool_var v)  {
-        sat::literal lit(v, m_bool_search->get_value(v));
+        sat::literal lit(v, !sign(v));
         SASSERT(!is_true(lit));
         auto const* ineq = atom(v);
         if (!ineq) 
@@ -524,7 +537,7 @@ namespace arith {
         if (!ineq)
             return;
         SASSERT(ineq->is_true() == lit.sign());
-        flip(!lit.sign(), *ineq);
+        flip(sign(v), *ineq);
     }
 
     double sls::reward(sat::bool_var v) {
@@ -535,21 +548,23 @@ namespace arith {
     }
 
     double sls::dtt_reward(sat::bool_var bv0) {
-        bool sign0 = !m_bool_search->get_value(bv0);
+        bool sign0 = sign(bv0);
         auto* ineq = atom(bv0);
         if (!ineq)
             return -1;
         int64_t new_value;      
         double max_result = -1;
         for (auto const & [coeff, x] : ineq->m_args) {
-            if (!cm(!sign0, *ineq, x, coeff, new_value))
+            if (!cm(sign0, *ineq, x, coeff, new_value))
                 continue;
             double result = 0;
             auto old_value = m_vars[x].m_value;
             for (auto const& [coeff, bv] : m_vars[x].m_bool_vars) {
-                bool sign = !m_bool_search->value(bv);
-                auto dtt_old = dtt(sign, *atom(bv));
-                auto dtt_new = dtt(sign, *atom(bv), coeff, old_value, new_value);
+                result += m_bool_search->reward(bv);
+                continue;
+                bool old_sign = sign(bv);
+                auto dtt_old = dtt(old_sign, *atom(bv));
+                auto dtt_new = dtt(old_sign, *atom(bv), coeff, old_value, new_value);
                 if ((dtt_new == 0) != (dtt_old == 0))
                     result += m_bool_search->reward(bv);
             }
@@ -563,17 +578,17 @@ namespace arith {
 
     double sls::dscore_reward(sat::bool_var bv) {
         m_dscore_mode = false;
-        bool sign = !m_bool_search->get_value(bv);
-        sat::literal litv(bv, sign);
+        bool old_sign = sign(bv);
+        sat::literal litv(bv, old_sign);
         auto* ineq = atom(bv);
         if (!ineq)
             return 0;
-        SASSERT(ineq->is_true() != sign);
+        SASSERT(ineq->is_true() != old_sign);
         int64_t new_value;
 
         for (auto const& [coeff, v] : ineq->m_args) {
             double result = 0;
-            if (cm(!sign, *ineq, v, coeff, new_value))
+            if (cm(old_sign, *ineq, v, coeff, new_value))
                 result = dscore(v, new_value);
             // just pick first positive, or pick a max?
             if (result > 0) {
@@ -586,7 +601,7 @@ namespace arith {
 
     // switch to dscore mode
     void sls::on_rescale()  {
-        // m_dscore_mode = true;
+        m_dscore_mode = true;
     }
 
     void sls::on_save_model() {
@@ -597,23 +612,39 @@ namespace arith {
         for (unsigned v = 0; v < s.s().num_vars(); ++v)
             init_bool_var_assignment(v);
 
-        verbose_stream() << "on-restart\n";
+        check_ineqs();
+    }
+
+    void sls::check_ineqs() {
+
         auto check_bool_var = [&](sat::bool_var bv) {
             auto const* ineq = atom(bv);
             if (!ineq)
                 return;
-            bool sign = !m_bool_search->get_value(bv);
-            int64_t d = dtt(sign, *ineq);
-            sat::literal lit(bv, sign);
-            // verbose_stream() << "check " << lit << " " << *ineq << "\n";
+            int64_t d = dtt(sign(bv), *ineq);
+            sat::literal lit(bv, sign(bv));
             if (is_true(lit) != (d == 0)) {
-                verbose_stream() << "restart " << bv << " " << *ineq << "\n";
+                verbose_stream() << "invalid assignment " << bv << " " << *ineq << "\n";
             }
             VERIFY(is_true(lit) == (d == 0));
         };
-        for (unsigned v = 0; v < s.get_num_vars(); ++v) 
+        for (unsigned v = 0; v < s.get_num_vars(); ++v)
             check_bool_var(v);
-        
-        verbose_stream() << "on-restart-done\n";
     }
+
+    std::ostream& sls::display(std::ostream& out) const {
+        for (bool_var bv = 0; bv < s.s().num_vars(); ++bv) {
+            auto const* ineq = atom(bv);            
+            if (!ineq)
+                continue;
+            out << bv << " " << *ineq << "\n";            
+        }
+        for (unsigned v = 0; v < s.get_num_vars(); ++v) {
+            if (s.is_bool(v))
+                continue;
+            out << "v" << v << " := " << m_vars[v].m_value << " " << m_vars[v].m_best_value << "\n";
+        }
+        return out;
+    }
+
 }
diff --git a/src/sat/smt/arith_sls.h b/src/sat/smt/arith_sls.h
index 3c9daaa51..af3a46234 100644
--- a/src/sat/smt/arith_sls.h
+++ b/src/sat/smt/arith_sls.h
@@ -119,12 +119,11 @@ namespace arith {
         sat::ddfw::clause_info& get_clause_info(unsigned idx) { return m_bool_search->get_clause_info(idx); }
         sat::ddfw::clause_info const& get_clause_info(unsigned idx) const { return m_bool_search->get_clause_info(idx); }
         bool is_true(sat::literal lit) { return lit.sign() != m_bool_search->get_value(lit.var()); }
+        bool sign(sat::bool_var v) const { return !m_bool_search->get_value(v); }
 
         void reset();
         ineq* atom(sat::bool_var bv) const { return m_bool_vars[bv]; }
 
-        void log();
-
         bool flip(bool sign, ineq const& ineq);
         int64_t dtt(bool sign, ineq const& ineq) const { return dtt(sign, ineq.m_args_value, ineq); }
         int64_t dtt(bool sign, int64_t args_value, ineq const& ineq) const;
@@ -151,6 +150,10 @@ namespace arith {
         int64_t value(var_t v) const { return m_vars[v].m_value; }
         int64_t to_numeral(rational const& r);
 
+        void check_ineqs();
+
+        std::ostream& display(std::ostream& out) const;
+
     public:
         sls(solver& s);
         ~sls() override {}