From 2682c2ef2b3f31f065cc54b83e91f6d42c60db2f Mon Sep 17 00:00:00 2001
From: Nikolaj Bjorner <nbjorner@microsoft.com>
Date: Sat, 13 Apr 2024 16:42:26 +0200
Subject: [PATCH] sls updates

- add SINGLE_THREAD mode
- add interface to retrieve "best" model so far
---
 src/ast/sls/bv_sls.cpp          | 18 +++++-
 src/ast/sls/bv_sls.h            |  8 +++
 src/sat/smt/intblast_solver.cpp |  2 +-
 src/sat/smt/sls_solver.cpp      | 98 +++++++++++++++++----------------
 src/sat/smt/sls_solver.h        | 68 +++++++++++++++++------
 5 files changed, 128 insertions(+), 66 deletions(-)

diff --git a/src/ast/sls/bv_sls.cpp b/src/ast/sls/bv_sls.cpp
index 9af87d3c5..c0972349b 100644
--- a/src/ast/sls/bv_sls.cpp
+++ b/src/ast/sls/bv_sls.cpp
@@ -61,6 +61,17 @@ namespace bv {
         }
     }
 
+
+    void sls::set_model() {
+        if (!m_set_model)
+            return;
+        if (m_repair_roots.size() >= m_min_repair_size)
+            return;
+        m_min_repair_size = m_repair_roots.size();
+        IF_VERBOSE(2, verbose_stream() << "(sls-update-model :num-unsat " << m_min_repair_size << ")\n");
+        m_set_model(*get_model());
+    }
+
     void sls::init_repair_goal(app* t) {
         m_eval.init_eval(t);
     }
@@ -94,6 +105,9 @@ namespace bv {
         if (m_to_repair.empty())
             return;
 
+        // refresh the best model so far to a callback
+        set_model();
+
         // add fresh units, if any
         bool new_assertion = false;
         while (m_get_unit) {
@@ -130,7 +144,7 @@ namespace bv {
             return m_rand() % 2 == 0;
         };
         m_eval.init_eval(m_terms.assertions(), eval);
-        init_repair();
+        init_repair();        
         // m_engine_init = false;
     }
 
@@ -295,10 +309,12 @@ namespace bv {
         model_ref mdl = alloc(model, m);         
         auto& terms = m_eval.sort_assertions(m_terms.assertions());
         for (expr* e : terms) {
+#if 0
             if (!m_eval.re_eval_is_correct(to_app(e))) {
                 verbose_stream() << "missed evaluation #" << e->get_id() << " " << mk_bounded_pp(e, m) << "\n";
                 m_eval.display_value(verbose_stream(), e) << "\n";
             }
+#endif
             if (!is_uninterp_const(e))
                 continue;
 
diff --git a/src/ast/sls/bv_sls.h b/src/ast/sls/bv_sls.h
index 690b618bf..987cebcdb 100644
--- a/src/ast/sls/bv_sls.h
+++ b/src/ast/sls/bv_sls.h
@@ -54,10 +54,13 @@ namespace bv {
         bool                m_engine_model = false;
         bool                m_engine_init = false;
         std::function<expr_ref()> m_get_unit;
+        std::function<void(model& mdl)> m_set_model;
+        unsigned            m_min_repair_size = UINT_MAX;
         
         std::pair<bool, app*> next_to_repair();
         
         void init_repair_goal(app* e);
+        void set_model();
         void try_repair_down(app* e);
         void try_repair_up(app* e);
         void set_repair_down(expr* e) { m_repair_down = e->get_id(); }
@@ -96,6 +99,11 @@ namespace bv {
         */
         void init_unit(std::function<expr_ref()> get_unit) { m_get_unit = get_unit; }
 
+        /**
+        * Add callback to set model
+        */
+        void set_model(std::function<void(model& mdl)> f) { m_set_model = f; }
+
         /**
         * Run (bounded) local search to find feasible assignments.
         */
diff --git a/src/sat/smt/intblast_solver.cpp b/src/sat/smt/intblast_solver.cpp
index f4491896b..459b26339 100644
--- a/src/sat/smt/intblast_solver.cpp
+++ b/src/sat/smt/intblast_solver.cpp
@@ -1069,7 +1069,7 @@ namespace intblast {
         if (e->get_family_id() != bv.get_family_id())
             return false;
         for (euf::enode* arg : euf::enode_args(n))
-            dep.add(n, arg->get_root());
+            dep.add(n, arg);
         return true;
     }
 
diff --git a/src/sat/smt/sls_solver.cpp b/src/sat/smt/sls_solver.cpp
index e12ff5ba7..a507619ee 100644
--- a/src/sat/smt/sls_solver.cpp
+++ b/src/sat/smt/sls_solver.cpp
@@ -22,31 +22,37 @@ Author:
 
 namespace sls {
 
+#ifdef SINGLE_THREAD
+
+    solver::solver(euf::solver& ctx) :
+        th_euf_solver(ctx, symbol("sls"), ctx.get_manager().mk_family_id("sls"))
+        {}
+
+#else
     solver::solver(euf::solver& ctx):
-        th_euf_solver(ctx, symbol("sls"), ctx.get_manager().mk_family_id("sls")),
-        m_units(m) {}
+        th_euf_solver(ctx, symbol("sls"), ctx.get_manager().mk_family_id("sls"))
+        {}
 
     solver::~solver() {
         finalize();
     }
 
     void solver::finalize() {
-        if (!m_completed && m_bvsls) {
-            m_bvsls->cancel();
+        if (!m_completed && m_sls) {
+            m_sls->cancel();
             m_thread.join();
-            m_bvsls->collect_statistics(m_st);
-            m_bvsls = nullptr;
+            m_sls->collect_statistics(m_st);
+            m_sls = nullptr;
+            m_shared = nullptr;
+            m_slsm = nullptr;
+            m_units = nullptr;
         }
     }
 
     sat::check_result solver::check() { 
-
         return sat::check_result::CR_DONE; 
     }
 
-    void solver::simplify() {    
-    }
-
     bool solver::unit_propagate() {
         force_push();
         sample_local_search();
@@ -66,10 +72,6 @@ namespace sls {
         return false;
     }
 
-    void solver::push_core() {
-
-    }
-
     void solver::pop_core(unsigned n) {
         for (; m_trail_lim < s().init_trail_size(); ++m_trail_lim) {
             auto lit = s().trail_literal(m_trail_lim);
@@ -77,60 +79,63 @@ namespace sls {
             if (is_unit(e)) {
                 // IF_VERBOSE(1, verbose_stream() << "add unit " << mk_pp(e, m) << "\n");
                 std::lock_guard<std::mutex> lock(m_mutex);
-                m_units.push_back(e);
+                ast_translation tr(m, *m_shared);
+                m_units->push_back(tr(e.get()));
                 m_has_units = true;
             }
         }
-    }
-        
-    void solver::init_search() {
-        init_local_search();
-    }
+    }       
 
-    void solver::init_local_search() {
-        if (m_bvsls) {
-            m_bvsls->cancel();
+    void solver::init_search() {
+        if (m_sls) {
+            m_sls->cancel();
             m_thread.join();
             m_result = l_undef;
             m_completed = false;
             m_has_units = false;
             m_model = nullptr;
-            m_units.reset();
+            m_units = nullptr;
         }
         // set up state for local search solver here
 
-        m_m = alloc(ast_manager, m);
-        ast_translation tr(m, *m_m);
+        m_shared = alloc(ast_manager);
+        m_slsm = alloc(ast_manager);
+        m_units = alloc(expr_ref_vector, *m_shared);
+        ast_translation tr(m, *m_slsm);
         
-        params_ref p;
         m_completed = false;
         m_result = l_undef;
         m_model = nullptr;
-        m_bvsls = alloc(bv::sls, *m_m, p);
+        m_sls = alloc(bv::sls, *m_slsm, s().params());
         
         for (expr* a : ctx.get_assertions())
-            m_bvsls->assert_expr(tr(a));
+            m_sls->assert_expr(tr(a));
 
         std::function<bool(expr*, unsigned)> eval = [&](expr* e, unsigned r) {
             return false;
         };
 
-        m_bvsls->init();
-        m_bvsls->init_eval(eval);
-        m_bvsls->updt_params(s().params());
-        m_bvsls->init_unit([&]() { 
+        m_sls->init();
+        m_sls->init_eval(eval);
+        m_sls->updt_params(s().params());
+        m_sls->init_unit([&]() { 
             if (!m_has_units)
-                return expr_ref(*m_m);
-            expr_ref e(m);
+                return expr_ref(*m_slsm);
+            expr_ref e(*m_slsm);
             {
                 std::lock_guard<std::mutex> lock(m_mutex);
-                if (m_units.empty())
-                    return expr_ref(*m_m);
-                e = m_units.back();
-                m_units.pop_back();
+                if (m_units->empty())
+                    return expr_ref(*m_slsm);
+                ast_translation tr(*m_shared, *m_slsm);
+                e = tr(m_units->back());
+                m_units->pop_back();
             }
-            ast_translation tr(m, *m_m);
-            return expr_ref(tr(e.get()), *m_m); 
+            return e;
+        });
+        m_sls->set_model([&](model& mdl) {
+            std::lock_guard<std::mutex> lock(m_mutex);
+            ast_translation tr(*m_shared, m);
+            m_model = mdl.translate(tr);
         });
                                      
         m_thread = std::thread([this]() { run_local_search(); });        
@@ -141,20 +146,21 @@ namespace sls {
             return;        
         m_thread.join();
         m_completed = false;
-        m_bvsls->collect_statistics(m_st);
+        m_sls->collect_statistics(m_st);
         if (m_result == l_true) {
             IF_VERBOSE(2, verbose_stream() << "(sat.sls :model-completed)\n";);
-            auto mdl = m_bvsls->get_model();
-            ast_translation tr(*m_m, m);
+            auto mdl = m_sls->get_model();
+            ast_translation tr(*m_slsm, m);
             m_model = mdl->translate(tr);
             s().set_canceled();
         }
-        m_bvsls = nullptr;
+        m_sls = nullptr;
     }
 
     void solver::run_local_search() {
-        m_result = (*m_bvsls)();
+        m_result = (*m_sls)();
         m_completed = true;
     }
 
+#endif
 }
diff --git a/src/sat/smt/sls_solver.h b/src/sat/smt/sls_solver.h
index 5a6c9950b..e1d8a95b5 100644
--- a/src/sat/smt/sls_solver.h
+++ b/src/sat/smt/sls_solver.h
@@ -16,13 +16,45 @@ Author:
 --*/
 #pragma once
 
-#include <thread>
-#include <mutex>
+
 #include "util/rlimit.h"
 #include "ast/sls/bv_sls.h"
 #include "sat/smt/sat_th.h"
 
 
+#ifdef SINGLE_THREAD
+
+
+namespace euf {
+    class solver;
+}
+
+namespace sls {
+
+    class solver : public euf::th_euf_solver {
+    public:
+        solver(euf::solver& ctx);
+            
+        sat::literal internalize(expr* e, bool sign, bool root) override { UNREACHABLE();  return sat::null_literal; }
+        void internalize(expr* e) override { UNREACHABLE(); }
+        th_solver* clone(euf::solver& ctx) override { return alloc(solver, ctx); }
+
+        model_ref get_model() { return model_ref(nullptr); }
+        bool unit_propagate() override { return false; }
+        void get_antecedents(sat::literal l, sat::ext_justification_idx idx, sat::literal_vector& r, bool probing) override { UNREACHABLE(); }
+        sat::check_result check() override { return sat::check_result::CR_DONE;}
+        std::ostream& display(std::ostream& out) const override { return out; }
+        std::ostream& display_justification(std::ostream& out, sat::ext_justification_idx idx) const override { UNREACHABLE(); return out; }
+        std::ostream& display_constraint(std::ostream& out, sat::ext_constraint_idx idx) const override { UNREACHABLE(); return out; }
+
+    };
+}
+
+#else
+
+#include <thread>
+#include <mutex>
+
 namespace euf {
     class solver;
 }
@@ -34,38 +66,36 @@ namespace sls {
         std::atomic<bool> m_completed, m_has_units;
         std::thread m_thread;
         std::mutex  m_mutex;
-        scoped_ptr<ast_manager> m_m;
-        scoped_ptr<bv::sls> m_bvsls;
+        // m is accessed by the main thread
+        // m_slsm is accessed by the sls thread
+        // m_shared is only accessed at synchronization points
+        scoped_ptr<ast_manager> m_shared, m_slsm;
+        scoped_ptr<bv::sls> m_sls;
+        scoped_ptr<expr_ref_vector> m_units;
         model_ref m_model;
         unsigned m_trail_lim = 0;
-        expr_ref_vector m_units;
         statistics m_st;
 
         void run_local_search();
-        void init_local_search();
         void sample_local_search();
-
         bool is_unit(expr*);
+
     public:
         solver(euf::solver& ctx);
         ~solver();
 
-        void simplify() override;
-        void init_search() override;
+        model_ref get_model() { return m_model; }
 
-        void push_core() override;
+        void init_search() override;
+        void push_core() override {}
         void pop_core(unsigned n) override;
+        th_solver* clone(euf::solver& ctx) override { return alloc(solver, ctx); }
+        void collect_statistics(statistics& st) const override { st.copy(m_st); }       
+        void finalize() override;
+        bool unit_propagate() override;
 
         sat::literal internalize(expr* e, bool sign, bool root) override { UNREACHABLE();  return sat::null_literal; }
         void internalize(expr* e) override { UNREACHABLE(); }
-        th_solver* clone(euf::solver& ctx) override { return alloc(solver, ctx); }
-        void collect_statistics(statistics& st) const override { st.copy(m_st); }
-
-        model_ref get_model() { return m_model;  }
-
-        void finalize() override;
-
-        bool unit_propagate() override;
         void get_antecedents(sat::literal l, sat::ext_justification_idx idx, sat::literal_vector & r, bool probing) override { UNREACHABLE(); }
         sat::check_result check() override;
         std::ostream & display(std::ostream & out) const override { return out; }
@@ -75,3 +105,5 @@ namespace sls {
     };
 
 }
+
+#endif
\ No newline at end of file