diff --git a/src/opt/CMakeLists.txt b/src/opt/CMakeLists.txt index 28a14be2e..3f1d8253c 100644 --- a/src/opt/CMakeLists.txt +++ b/src/opt/CMakeLists.txt @@ -5,6 +5,7 @@ z3_add_component(opt mss.cpp opt_cmds.cpp opt_context.cpp + opt_lns.cpp opt_pareto.cpp opt_parse.cpp optsmt.cpp diff --git a/src/opt/opt_context.cpp b/src/opt/opt_context.cpp index 8cbcbed5e..1396ae9c4 100644 --- a/src/opt/opt_context.cpp +++ b/src/opt/opt_context.cpp @@ -145,9 +145,8 @@ namespace opt { } void context::reset_maxsmts() { - map_t::iterator it = m_maxsmts.begin(), end = m_maxsmts.end(); - for (; it != end; ++it) { - dealloc(it->m_value); + for (auto& kv : m_maxsmts) { + dealloc(kv.m_value); } m_maxsmts.reset(); } @@ -255,6 +254,9 @@ namespace opt { if (m_pareto) { return execute_pareto(); } + if (m_lns) { + return execute_lns(); + } if (m_box_index != UINT_MAX) { return execute_box(); } @@ -271,10 +273,16 @@ namespace opt { #endif solver& s = get_solver(); s.assert_expr(m_hard_constraints); - IF_VERBOSE(1, verbose_stream() << "(optimize:check-sat)\n";); + + opt_params optp(m_params); + symbol pri = optp.priority(); + if (pri == symbol("lns")) { + return execute_lns(); + } + + IF_VERBOSE(1, verbose_stream() << "(optimize:check-sat)\n"); lbool is_sat = s.check_sat(0,0); - TRACE("opt", tout << "initial search result: " << is_sat << "\n"; - s.display(tout);); + TRACE("opt", s.display(tout << "initial search result: " << is_sat << "\n");); if (is_sat != l_false) { s.get_model(m_model); s.get_labels(m_labels); @@ -286,7 +294,7 @@ namespace opt { TRACE("opt", tout << m_hard_constraints << "\n";); return is_sat; } - IF_VERBOSE(1, verbose_stream() << "(optimize:sat)\n";); + IF_VERBOSE(1, verbose_stream() << "(optimize:sat)\n"); TRACE("opt", model_smt2_pp(tout, m, *m_model, 0);); m_optsmt.setup(*m_opt_solver.get()); update_lower(); @@ -303,6 +311,9 @@ namespace opt { if (pri == symbol("pareto")) { is_sat = execute_pareto(); } + else if (pri == symbol("lns")) { + is_sat = execute_lns(); + } else if (pri == symbol("box")) { is_sat = execute_box(); } @@ -525,7 +536,12 @@ namespace opt { } void context::yield() { - m_pareto->get_model(m_model, m_labels); + if (m_pareto) { + m_pareto->get_model(m_model, m_labels); + } + else if (m_lns) { + m_lns->get_model(m_model, m_labels); + } update_bound(true); update_bound(false); } @@ -536,7 +552,7 @@ namespace opt { } lbool is_sat = (*(m_pareto.get()))(); if (is_sat != l_true) { - set_pareto(0); + set_pareto(nullptr); } if (is_sat == l_true) { yield(); @@ -544,6 +560,20 @@ namespace opt { return is_sat; } + lbool context::execute_lns() { + if (!m_lns) { + m_lns = alloc(lns, *this, m_solver.get()); + } + lbool is_sat = (*(m_lns.get()))(); + if (is_sat != l_true) { + m_lns = nullptr; + } + if (is_sat == l_true) { + yield(); + } + return l_undef; + } + std::string context::reason_unknown() const { if (m.canceled()) { return Z3_CANCELED_MSG; @@ -990,6 +1020,24 @@ namespace opt { } } + /** + \brief retrieve literals used by the neighborhood search feature. + */ + + void context::get_lns_literals(expr_ref_vector& lits) { + for (objective & obj : m_objectives) { + switch(obj.m_type) { + case O_MAXSMT: + for (expr* f : obj.m_terms) { + lits.push_back(f); + } + break; + default: + break; + } + } + } + bool context::verify_model(unsigned index, model* md, rational const& _v) { rational r; app_ref term = m_objectives[index].m_term; @@ -1352,7 +1400,8 @@ namespace opt { } void context::clear_state() { - set_pareto(0); + m_pareto = nullptr; + m_lns = nullptr; m_box_index = UINT_MAX; m_model.reset(); } @@ -1388,9 +1437,8 @@ namespace opt { m_solver->updt_params(m_params); } m_optsmt.updt_params(m_params); - map_t::iterator it = m_maxsmts.begin(), end = m_maxsmts.end(); - for (; it != end; ++it) { - it->m_value->updt_params(m_params); + for (auto & kv : m_maxsmts) { + kv.m_value->updt_params(m_params); } opt_params _p(p); m_enable_sat = _p.enable_sat(); diff --git a/src/opt/opt_context.h b/src/opt/opt_context.h index e4d1f8e2d..9f55d5ca1 100644 --- a/src/opt/opt_context.h +++ b/src/opt/opt_context.h @@ -19,16 +19,18 @@ Notes: #define OPT_CONTEXT_H_ #include "ast/ast.h" +#include "ast/arith_decl_plugin.h" +#include "ast/bv_decl_plugin.h" +#include "tactic/model_converter.h" +#include "tactic/tactic.h" +#include "qe/qsat.h" #include "opt/opt_solver.h" #include "opt/opt_pareto.h" #include "opt/optsmt.h" +#include "opt/opt_lns.h" #include "opt/maxsmt.h" -#include "tactic/model_converter.h" -#include "tactic/tactic.h" -#include "ast/arith_decl_plugin.h" -#include "ast/bv_decl_plugin.h" #include "cmd_context/cmd_context.h" -#include "qe/qsat.h" + namespace opt { @@ -145,6 +147,7 @@ namespace opt { ref m_solver; ref m_sat_solver; scoped_ptr m_pareto; + scoped_ptr m_lns; scoped_ptr m_qmax; sref_vector m_box_models; unsigned m_box_index; @@ -231,6 +234,8 @@ namespace opt { virtual bool verify_model(unsigned id, model* mdl, rational const& v); + void get_lns_literals(expr_ref_vector& lits); + private: lbool execute(objective const& obj, bool committed, bool scoped); lbool execute_min_max(unsigned index, bool committed, bool scoped, bool is_max); @@ -238,6 +243,7 @@ namespace opt { lbool execute_lex(); lbool execute_box(); lbool execute_pareto(); + lbool execute_lns(); lbool adjust_unknown(lbool r); bool scoped_lex(); expr_ref to_expr(inf_eps const& n); @@ -282,7 +288,7 @@ namespace opt { void setup_arith_solver(); void add_maxsmt(symbol const& id, unsigned index); void set_simplify(tactic *simplify); - void set_pareto(pareto_base* p); + void set_pareto(pareto_base* p); void clear_state(); bool is_numeral(expr* e, rational& n) const; diff --git a/src/opt/opt_lns.cpp b/src/opt/opt_lns.cpp new file mode 100644 index 000000000..d8692ab59 --- /dev/null +++ b/src/opt/opt_lns.cpp @@ -0,0 +1,115 @@ +/*++ +Copyright (c) 2018 Microsoft Corporation + +Module Name: + + opt_lns.cpp + +Abstract: + + Large neighborhood search default implementation + based on phase saving and assumptions + +Author: + + Nikolaj Bjorner (nbjorner) 2018-3-13 + +Notes: + + +--*/ + +#include "opt/opt_lns.h" +#include "opt/opt_context.h" + +namespace opt { + + lns::lns(context& ctx, solver* s): + m(ctx.get_manager()), + m_ctx(ctx), + m_solver(s), + m_models_trail(m) + {} + + lns::~lns() {} + + void lns::display(std::ostream & out) const { + for (auto const& q : m_queue) { + out << q.m_index << ": " << q.m_assignment << "\n"; + } + } + + lbool lns::operator()() { + + if (m_queue.empty()) { + expr_ref_vector lits(m); + m_ctx.get_lns_literals(lits); + m_queue.push_back(queue_elem(lits)); + m_qhead = 0; + } + + params_ref p; + p.set_uint("sat.inprocess.max", 3); + p.set_uint("smt.max_conflicts", 10000); + m_solver->updt_params(p); + + while (m_qhead < m_queue.size()) { + unsigned index = m_queue[m_qhead].m_index; + if (index > m_queue[m_qhead].m_assignment.size()) { + ++m_qhead; + continue; + } + IF_VERBOSE(2, verbose_stream() << "(opt.lns :queue " << m_qhead << " :index " << index << ")\n"); + + // recalibrate state to an initial satisfying assignment + lbool is_sat = m_solver->check_sat(m_queue[m_qhead].m_assignment); + IF_VERBOSE(2, verbose_stream() << "(opt.lns :calibrate-status " << is_sat << ")\n"); + + expr_ref lit(m_queue[m_qhead].m_assignment[index].get(), m); + lit = mk_not(m, lit); + expr* lits[1] = { lit }; + ++m_queue[m_qhead].m_index; + if (!m.limit().inc()) { + return l_undef; + } + + // Update configuration for local search: + // p.set_uint("sat.local_search_threads", 2); + // p.set_uint("sat.unit_walk_threads", 1); + + is_sat = m_solver->check_sat(1, lits); + IF_VERBOSE(2, verbose_stream() << "(opt.lns :status " << is_sat << ")\n"); + if (is_sat == l_true && add_assignment()) { + return l_true; + } + } + return l_false; + } + + bool lns::add_assignment() { + model_ref mdl; + m_solver->get_model(mdl); + m_ctx.fix_model(mdl); + expr_ref tmp(m); + expr_ref_vector fmls(m); + for (expr* f : m_queue[0].m_assignment) { + mdl->eval(f, tmp); + if (m.is_false(tmp)) { + fmls.push_back(mk_not(m, tmp)); + } + else { + fmls.push_back(tmp); + } + } + tmp = mk_and(fmls); + if (m_models.contains(tmp)) { + return false; + } + else { + m_models.insert(tmp); + m_models_trail.push_back(tmp); + return true; + } + } +} + diff --git a/src/opt/opt_lns.h b/src/opt/opt_lns.h new file mode 100644 index 000000000..8033a8ea8 --- /dev/null +++ b/src/opt/opt_lns.h @@ -0,0 +1,66 @@ +/*++ +Copyright (c) 2018 Microsoft Corporation + +Module Name: + + opt_lns.h + +Abstract: + + Large neighborhood seearch + +Author: + + Nikolaj Bjorner (nbjorner) 2018-3-13 + +Notes: + + +--*/ +#ifndef OPT_LNS_H_ +#define OPT_LNS_H_ + +#include "solver/solver.h" +#include "model/model.h" + +namespace opt { + + class context; + + class lns { + struct queue_elem { + expr_ref_vector m_assignment; + unsigned m_index; + queue_elem(expr_ref_vector& assign): + m_assignment(assign), + m_index(0) + {} + }; + ast_manager& m; + context& m_ctx; + ref m_solver; + model_ref m_model; + svector m_labels; + vector m_queue; + unsigned m_qhead; + expr_ref_vector m_models_trail; + obj_hashtable m_models; + + bool add_assignment(); + public: + lns(context& ctx, solver* s); + + ~lns(); + + void display(std::ostream & out) const; + + lbool operator()(); + + void get_model(model_ref& mdl, svector& labels) { + mdl = m_model; + labels = m_labels; + } + }; +} + +#endif diff --git a/src/opt/opt_params.pyg b/src/opt/opt_params.pyg index cfcc5e47e..21845d38a 100644 --- a/src/opt/opt_params.pyg +++ b/src/opt/opt_params.pyg @@ -3,7 +3,7 @@ def_module_params('opt', export=True, params=(('optsmt_engine', SYMBOL, 'basic', "select optimization engine: 'basic', 'farkas', 'symba'"), ('maxsat_engine', SYMBOL, 'maxres', "select engine for maxsat: 'core_maxsat', 'wmax', 'maxres', 'pd-maxres'"), - ('priority', SYMBOL, 'lex', "select how to priortize objectives: 'lex' (lexicographic), 'pareto', or 'box'"), + ('priority', SYMBOL, 'lex', "select how to priortize objectives: 'lex' (lexicographic), 'pareto', 'box', or 'lns' (large neighborhood search)"), ('dump_benchmarks', BOOL, False, 'dump benchmarks for profiling'), ('timeout', UINT, UINT_MAX, 'timeout (in milliseconds) (UINT_MAX and 0 mean no timeout)'), ('rlimit', UINT, 0, 'resource limit (0 means no limit)'), diff --git a/src/sat/sat_scc.cpp b/src/sat/sat_scc.cpp index 5fed5e0f6..8a2029df2 100644 --- a/src/sat/sat_scc.cpp +++ b/src/sat/sat_scc.cpp @@ -222,7 +222,7 @@ namespace sat { } } TRACE("scc", for (unsigned i = 0; i < roots.size(); i++) { tout << i << " -> " << roots[i] << "\n"; } - tout << "to_elim: "; for (literal l : to_elim) tout << l << " "; tout << "\n";); + tout << "to_elim: "; for (unsigned v : to_elim) tout << v << " "; tout << "\n";); m_num_elim += to_elim.size(); elim_eqs eliminator(m_solver); eliminator(roots, to_elim);