From 4375f54c458d5d5996d30a23918e6006667ba2c7 Mon Sep 17 00:00:00 2001
From: Nikolaj Bjorner <nbjorner@microsoft.com>
Date: Tue, 13 Mar 2018 13:31:27 -0700
Subject: [PATCH] adding lns

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
---
 src/opt/CMakeLists.txt  |   1 +
 src/opt/opt_context.cpp |  74 +++++++++++++++++++++-----
 src/opt/opt_context.h   |  18 ++++---
 src/opt/opt_lns.cpp     | 115 ++++++++++++++++++++++++++++++++++++++++
 src/opt/opt_lns.h       |  66 +++++++++++++++++++++++
 src/opt/opt_params.pyg  |   2 +-
 src/sat/sat_scc.cpp     |   2 +-
 7 files changed, 257 insertions(+), 21 deletions(-)
 create mode 100644 src/opt/opt_lns.cpp
 create mode 100644 src/opt/opt_lns.h

diff --git a/src/opt/CMakeLists.txt b/src/opt/CMakeLists.txt
index 28a14be2e..3f1d8253c 100644
--- a/src/opt/CMakeLists.txt
+++ b/src/opt/CMakeLists.txt
@@ -5,6 +5,7 @@ z3_add_component(opt
     mss.cpp
     opt_cmds.cpp
     opt_context.cpp
+    opt_lns.cpp
     opt_pareto.cpp
     opt_parse.cpp
     optsmt.cpp
diff --git a/src/opt/opt_context.cpp b/src/opt/opt_context.cpp
index 8cbcbed5e..1396ae9c4 100644
--- a/src/opt/opt_context.cpp
+++ b/src/opt/opt_context.cpp
@@ -145,9 +145,8 @@ namespace opt {
     }
 
     void context::reset_maxsmts() {
-        map_t::iterator it = m_maxsmts.begin(), end = m_maxsmts.end();
-        for (; it != end; ++it) {
-            dealloc(it->m_value);
+        for (auto& kv : m_maxsmts) {
+            dealloc(kv.m_value);
         }
         m_maxsmts.reset();
     }
@@ -255,6 +254,9 @@ namespace opt {
         if (m_pareto) {
             return execute_pareto();
         }
+        if (m_lns) {
+            return execute_lns();
+        }
         if (m_box_index != UINT_MAX) {
             return execute_box();
         }
@@ -271,10 +273,16 @@ namespace opt {
 #endif
         solver& s = get_solver();
         s.assert_expr(m_hard_constraints);
-        IF_VERBOSE(1, verbose_stream() << "(optimize:check-sat)\n";);
+        
+        opt_params optp(m_params);
+        symbol pri = optp.priority();
+        if (pri == symbol("lns")) {
+            return execute_lns();
+        }
+
+        IF_VERBOSE(1, verbose_stream() << "(optimize:check-sat)\n");
         lbool is_sat = s.check_sat(0,0);
-        TRACE("opt", tout << "initial search result: " << is_sat << "\n";
-              s.display(tout););
+        TRACE("opt", s.display(tout << "initial search result: " << is_sat << "\n");); 
         if (is_sat != l_false) {
             s.get_model(m_model);
             s.get_labels(m_labels);
@@ -286,7 +294,7 @@ namespace opt {
             TRACE("opt", tout << m_hard_constraints << "\n";);            
             return is_sat;
         }
-        IF_VERBOSE(1, verbose_stream() << "(optimize:sat)\n";);
+        IF_VERBOSE(1, verbose_stream() << "(optimize:sat)\n");
         TRACE("opt", model_smt2_pp(tout, m, *m_model, 0););
         m_optsmt.setup(*m_opt_solver.get());
         update_lower();
@@ -303,6 +311,9 @@ namespace opt {
             if (pri == symbol("pareto")) {
                 is_sat = execute_pareto();
             }
+            else if (pri == symbol("lns")) {
+                is_sat = execute_lns();
+            }
             else if (pri == symbol("box")) {
                 is_sat = execute_box();
             }
@@ -525,7 +536,12 @@ namespace opt {
     }
 
     void context::yield() {
-        m_pareto->get_model(m_model, m_labels);
+        if (m_pareto) {
+            m_pareto->get_model(m_model, m_labels);
+        }
+        else if (m_lns) {
+            m_lns->get_model(m_model, m_labels);
+        }
         update_bound(true);
         update_bound(false);
     }
@@ -536,7 +552,7 @@ namespace opt {
         }
         lbool is_sat = (*(m_pareto.get()))();
         if (is_sat != l_true) {
-            set_pareto(0);
+            set_pareto(nullptr);
         }
         if (is_sat == l_true) {
             yield();
@@ -544,6 +560,20 @@ namespace opt {
         return is_sat;
     }
 
+    lbool context::execute_lns() {
+        if (!m_lns) {
+            m_lns = alloc(lns, *this, m_solver.get());
+        }
+        lbool is_sat = (*(m_lns.get()))();
+        if (is_sat != l_true) {
+            m_lns = nullptr;
+        }
+        if (is_sat == l_true) {
+            yield();
+        }
+        return l_undef;
+    }
+
     std::string context::reason_unknown() const { 
         if (m.canceled()) {
             return Z3_CANCELED_MSG;
@@ -990,6 +1020,24 @@ namespace opt {
         }
     }
 
+    /**
+       \brief retrieve literals used by the neighborhood search feature.
+     */
+
+    void context::get_lns_literals(expr_ref_vector& lits) {
+        for (objective & obj : m_objectives) {
+            switch(obj.m_type) {
+            case O_MAXSMT: 
+                for (expr* f : obj.m_terms) {
+                    lits.push_back(f);
+                }
+                break;
+            default:
+                break;
+            }
+        }
+    }
+
     bool context::verify_model(unsigned index, model* md, rational const& _v) {
         rational r;
         app_ref term = m_objectives[index].m_term;
@@ -1352,7 +1400,8 @@ namespace opt {
     }
 
     void context::clear_state() {
-        set_pareto(0);
+        m_pareto = nullptr;
+        m_lns = nullptr;
         m_box_index = UINT_MAX;
         m_model.reset();
     }
@@ -1388,9 +1437,8 @@ namespace opt {
             m_solver->updt_params(m_params);
         }
         m_optsmt.updt_params(m_params);
-        map_t::iterator it = m_maxsmts.begin(), end = m_maxsmts.end();
-        for (; it != end; ++it) {
-            it->m_value->updt_params(m_params);
+        for (auto & kv : m_maxsmts) {
+            kv.m_value->updt_params(m_params);
         }
         opt_params _p(p);
         m_enable_sat = _p.enable_sat();
diff --git a/src/opt/opt_context.h b/src/opt/opt_context.h
index e4d1f8e2d..9f55d5ca1 100644
--- a/src/opt/opt_context.h
+++ b/src/opt/opt_context.h
@@ -19,16 +19,18 @@ Notes:
 #define OPT_CONTEXT_H_
 
 #include "ast/ast.h"
+#include "ast/arith_decl_plugin.h"
+#include "ast/bv_decl_plugin.h"
+#include "tactic/model_converter.h"
+#include "tactic/tactic.h"
+#include "qe/qsat.h"
 #include "opt/opt_solver.h"
 #include "opt/opt_pareto.h"
 #include "opt/optsmt.h"
+#include "opt/opt_lns.h"
 #include "opt/maxsmt.h"
-#include "tactic/model_converter.h"
-#include "tactic/tactic.h"
-#include "ast/arith_decl_plugin.h"
-#include "ast/bv_decl_plugin.h"
 #include "cmd_context/cmd_context.h"
-#include "qe/qsat.h"
+
 
 namespace opt {
 
@@ -145,6 +147,7 @@ namespace opt {
         ref<solver>         m_solver;
         ref<solver>         m_sat_solver;
         scoped_ptr<pareto_base>          m_pareto;
+        scoped_ptr<lns>      m_lns;
         scoped_ptr<qe::qmax> m_qmax;
         sref_vector<model>  m_box_models;
         unsigned            m_box_index;
@@ -231,6 +234,8 @@ namespace opt {
 
         virtual bool verify_model(unsigned id, model* mdl, rational const& v);
 
+        void get_lns_literals(expr_ref_vector& lits);
+
     private:
         lbool execute(objective const& obj, bool committed, bool scoped);
         lbool execute_min_max(unsigned index, bool committed, bool scoped, bool is_max);
@@ -238,6 +243,7 @@ namespace opt {
         lbool execute_lex();
         lbool execute_box();
         lbool execute_pareto();
+        lbool execute_lns();
         lbool adjust_unknown(lbool r);
         bool scoped_lex();
         expr_ref to_expr(inf_eps const& n);
@@ -282,7 +288,7 @@ namespace opt {
         void    setup_arith_solver();
         void    add_maxsmt(symbol const& id, unsigned index);
         void    set_simplify(tactic *simplify);
-        void    set_pareto(pareto_base* p);        
+        void    set_pareto(pareto_base* p);  
         void    clear_state();
 
         bool is_numeral(expr* e, rational& n) const;
diff --git a/src/opt/opt_lns.cpp b/src/opt/opt_lns.cpp
new file mode 100644
index 000000000..d8692ab59
--- /dev/null
+++ b/src/opt/opt_lns.cpp
@@ -0,0 +1,115 @@
+/*++
+Copyright (c) 2018 Microsoft Corporation
+
+Module Name:
+
+    opt_lns.cpp
+
+Abstract:
+
+    Large neighborhood search default implementation 
+    based on phase saving and assumptions
+
+Author:
+
+    Nikolaj Bjorner (nbjorner) 2018-3-13
+
+Notes:
+
+   
+--*/
+
+#include "opt/opt_lns.h"
+#include "opt/opt_context.h"
+
+namespace opt {
+   
+    lns::lns(context& ctx, solver* s):
+        m(ctx.get_manager()),
+        m_ctx(ctx),
+        m_solver(s),
+        m_models_trail(m)
+    {}
+
+    lns::~lns() {}
+
+    void lns::display(std::ostream & out) const {
+        for (auto const& q : m_queue) {
+            out << q.m_index << ": " << q.m_assignment << "\n";
+        }
+    }
+    
+    lbool lns::operator()() {
+
+        if (m_queue.empty()) {
+            expr_ref_vector lits(m);
+            m_ctx.get_lns_literals(lits);
+            m_queue.push_back(queue_elem(lits));
+            m_qhead = 0;
+        }
+
+        params_ref p;
+        p.set_uint("sat.inprocess.max", 3);
+        p.set_uint("smt.max_conflicts", 10000);
+        m_solver->updt_params(p);
+
+        while (m_qhead < m_queue.size()) {
+            unsigned index = m_queue[m_qhead].m_index;            
+            if (index > m_queue[m_qhead].m_assignment.size()) {
+                ++m_qhead;
+                continue;
+            }
+            IF_VERBOSE(2, verbose_stream() << "(opt.lns :queue " << m_qhead << " :index " << index << ")\n");
+            
+            // recalibrate state to an initial satisfying assignment
+            lbool is_sat = m_solver->check_sat(m_queue[m_qhead].m_assignment);
+            IF_VERBOSE(2, verbose_stream() << "(opt.lns :calibrate-status " << is_sat << ")\n");
+
+            expr_ref lit(m_queue[m_qhead].m_assignment[index].get(), m);
+            lit = mk_not(m, lit);
+            expr* lits[1] = { lit };
+            ++m_queue[m_qhead].m_index;
+            if (!m.limit().inc()) {
+                return l_undef;
+            }
+
+            // Update configuration for local search:
+            // p.set_uint("sat.local_search_threads", 2);
+            // p.set_uint("sat.unit_walk_threads", 1);
+            
+            is_sat = m_solver->check_sat(1, lits);
+            IF_VERBOSE(2, verbose_stream() << "(opt.lns :status " << is_sat << ")\n");
+            if (is_sat == l_true && add_assignment()) {
+                return l_true;
+            }
+        }        
+        return l_false;
+    }
+
+    bool lns::add_assignment() {
+        model_ref mdl;
+        m_solver->get_model(mdl);
+        m_ctx.fix_model(mdl);
+        expr_ref tmp(m);
+        expr_ref_vector fmls(m);
+        for (expr* f : m_queue[0].m_assignment) {
+            mdl->eval(f, tmp);
+            if (m.is_false(tmp)) {
+                fmls.push_back(mk_not(m, tmp));
+            }
+            else {
+                fmls.push_back(tmp);
+            }
+        }
+        tmp = mk_and(fmls);
+        if (m_models.contains(tmp)) {
+            return false;
+        }
+        else {
+            m_models.insert(tmp);
+            m_models_trail.push_back(tmp);
+            return true;
+        }
+    }
+}
+
diff --git a/src/opt/opt_lns.h b/src/opt/opt_lns.h
new file mode 100644
index 000000000..8033a8ea8
--- /dev/null
+++ b/src/opt/opt_lns.h
@@ -0,0 +1,66 @@
+/*++
+Copyright (c) 2018 Microsoft Corporation
+
+Module Name:
+
+    opt_lns.h
+
+Abstract:
+
+    Large neighborhood seearch
+
+Author:
+
+    Nikolaj Bjorner (nbjorner) 2018-3-13
+
+Notes:
+
+   
+--*/
+#ifndef OPT_LNS_H_
+#define OPT_LNS_H_
+
+#include "solver/solver.h"
+#include "model/model.h"
+
+namespace opt {
+
+    class context;
+
+    class lns {
+        struct queue_elem {
+            expr_ref_vector m_assignment;
+            unsigned        m_index;
+            queue_elem(expr_ref_vector& assign):
+                m_assignment(assign),
+                m_index(0)
+            {}
+        };
+        ast_manager&     m;
+        context&         m_ctx;
+        ref<solver>      m_solver;
+        model_ref        m_model;
+        svector<symbol>  m_labels;
+        vector<queue_elem> m_queue;
+        unsigned         m_qhead;
+        expr_ref_vector  m_models_trail;
+        obj_hashtable<expr> m_models;
+
+        bool add_assignment();
+    public:
+        lns(context& ctx, solver* s);
+
+        ~lns();
+
+        void display(std::ostream & out) const;
+
+        lbool operator()();
+
+        void get_model(model_ref& mdl, svector<symbol>& labels) {
+            mdl = m_model;
+            labels = m_labels;
+        }
+    };
+}
+
+#endif
diff --git a/src/opt/opt_params.pyg b/src/opt/opt_params.pyg
index cfcc5e47e..21845d38a 100644
--- a/src/opt/opt_params.pyg
+++ b/src/opt/opt_params.pyg
@@ -3,7 +3,7 @@ def_module_params('opt',
                   export=True,
                   params=(('optsmt_engine', SYMBOL, 'basic', "select optimization engine: 'basic', 'farkas', 'symba'"),
                           ('maxsat_engine', SYMBOL, 'maxres', "select engine for maxsat: 'core_maxsat', 'wmax', 'maxres', 'pd-maxres'"),
-                          ('priority', SYMBOL, 'lex', "select how to priortize objectives: 'lex' (lexicographic), 'pareto', or 'box'"),
+                          ('priority', SYMBOL, 'lex', "select how to priortize objectives: 'lex' (lexicographic), 'pareto', 'box', or 'lns' (large neighborhood search)"),
                           ('dump_benchmarks', BOOL, False, 'dump benchmarks for profiling'),
                           ('timeout', UINT, UINT_MAX, 'timeout (in milliseconds) (UINT_MAX and 0 mean no timeout)'),
                           ('rlimit', UINT, 0, 'resource limit (0 means no limit)'),
diff --git a/src/sat/sat_scc.cpp b/src/sat/sat_scc.cpp
index 5fed5e0f6..8a2029df2 100644
--- a/src/sat/sat_scc.cpp
+++ b/src/sat/sat_scc.cpp
@@ -222,7 +222,7 @@ namespace sat {
             }
         }
         TRACE("scc", for (unsigned i = 0; i < roots.size(); i++) { tout << i << " -> " << roots[i] << "\n"; }
-              tout << "to_elim: "; for (literal l : to_elim) tout << l << " "; tout << "\n";);
+              tout << "to_elim: "; for (unsigned v : to_elim) tout << v << " "; tout << "\n";);
         m_num_elim += to_elim.size();
         elim_eqs eliminator(m_solver);
         eliminator(roots, to_elim);