From 18b491eee0bb9c403cfb3323ac272f0adbdcca44 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Wed, 3 Sep 2014 10:03:56 -0700 Subject: [PATCH] fixes to maxres/mss Signed-off-by: Nikolaj Bjorner --- src/opt/inc_sat_solver.cpp | 50 ++++++++--------- src/opt/maxres.cpp | 104 ++++++++++-------------------------- src/opt/mus.cpp | 1 + src/opt/opt_context.cpp | 6 ++- src/sat/sat_mus.cpp | 8 +++ src/sat/sat_mus.h | 1 + src/sat/sat_sls.cpp | 8 +-- src/sat/sat_sls.h | 2 +- src/sat/sat_solver.cpp | 2 +- src/sat/tactic/goal2sat.cpp | 8 +-- 10 files changed, 76 insertions(+), 114 deletions(-) diff --git a/src/opt/inc_sat_solver.cpp b/src/opt/inc_sat_solver.cpp index 9e041e0a3..11aa24d71 100644 --- a/src/opt/inc_sat_solver.cpp +++ b/src/opt/inc_sat_solver.cpp @@ -37,6 +37,7 @@ class inc_sat_solver : public solver { sat::solver m_solver; goal2sat m_goal2sat; params_ref m_params; + bool m_optimize_model; // parameter expr_ref_vector m_fmls; expr_ref_vector m_current_fmls; unsigned_vector m_fmls_lim; @@ -54,10 +55,12 @@ class inc_sat_solver : public solver { expr_ref_vector m_soft; vector m_weights; + typedef obj_map dep2asm_t; public: inc_sat_solver(ast_manager& m, params_ref const& p): - m(m), m_solver(p,0), m_params(p), + m(m), m_solver(p,0), + m_params(p), m_optimize_model(false), m_fmls(m), m_current_fmls(m), m_core(m), m_map(m), m_num_scopes(0), m_dep_core(m), @@ -165,6 +168,7 @@ public: m_params = p; m_params.set_bool("elim_vars", false); m_solver.updt_params(m_params); + m_optimize_model = m_params.get_bool("optimize_model", false); } virtual void collect_statistics(statistics & st) const { m_preprocess->collect_statistics(st); @@ -226,32 +230,26 @@ private: lbool r = internalize_formulas(); if (r != l_true) return r; r = internalize_assumptions(soft.size(), soft.c_ptr(), dep2asm); + if (r != l_true) return r; sat::literal_vector lits; svector weights; sat::literal lit; - - if (r == l_true) { - for (unsigned i = 0; i < soft.size(); ++i) { - weights.push_back(m_weights[i].get_double()); - expr* s = soft[i].get(); - bool is_neg = m.is_not(s, s); - if (!dep2asm.find(s, lit)) { - std::cout << "not found: " << mk_pp(s, m) << "\n"; - dep2asm_t::iterator it = dep2asm.begin(), end = dep2asm.end(); - for (; it != end; ++it) { - std::cout << mk_pp(it->m_key, m) << " " << it->m_value << "\n"; - } - UNREACHABLE(); - } - if (is_neg) { - lit.neg(); - } - lits.push_back(lit); + for (unsigned i = 0; i < soft.size(); ++i) { + weights.push_back(m_weights[i].get_double()); + expr* s = soft[i].get(); + if (!dep2asm.find(s, lit)) { + IF_VERBOSE(0, + verbose_stream() << "not found: " << mk_pp(s, m) << "\n"; + dep2asm_t::iterator it = dep2asm.begin(); + dep2asm_t::iterator end = dep2asm.end(); + for (; it != end; ++it) { + verbose_stream() << mk_pp(it->m_key, m) << " " << it->m_value << "\n"; + } + UNREACHABLE();); } - m_solver.initialize_soft(lits.size(), lits.c_ptr(), weights.c_ptr()); - m_params.set_bool("optimize_model", true); - m_solver.updt_params(m_params); + lits.push_back(lit); } + m_solver.initialize_soft(lits.size(), lits.c_ptr(), weights.c_ptr()); return r; } @@ -324,12 +322,8 @@ private: m_core.reset(); for (unsigned i = 0; i < core.size(); ++i) { expr* e; - if (asm2dep.find(core[i].index(), e)) { - if (core[i].sign()) { - e = m.mk_not(e); - } - m_core.push_back(e); - } + VERIFY (asm2dep.find(core[i].index(), e)); + m_core.push_back(e); } TRACE("opt", dep2asm_t::iterator it = dep2asm.begin(); diff --git a/src/opt/maxres.cpp b/src/opt/maxres.cpp index 7befdd235..f879c8934 100644 --- a/src/opt/maxres.cpp +++ b/src/opt/maxres.cpp @@ -72,14 +72,12 @@ public: enum strategy_t { s_mus, s_mus_mss, - s_mus_mss2, s_mss }; private: expr_ref_vector m_B; expr_ref_vector m_asms; obj_map m_asm2weight; - obj_map m_asm2value; ptr_vector m_new_core; mus m_mus; mss m_mss; @@ -186,44 +184,6 @@ public: return l_true; } - lbool mus_mss_solver() { - init(); - init_local(); - sls(); - exprs mcs; - vector cores; - while (m_lower < m_upper) { - TRACE("opt", - display_vec(tout, m_asms.size(), m_asms.c_ptr()); - s().display(tout); - tout << "\n"; - display(tout); - ); - lbool is_sat = try_improve_bound(cores, mcs); - if (m_cancel) { - return l_undef; - } - switch (is_sat) { - case l_undef: - return l_undef; - case l_false: - SASSERT(cores.empty() && mcs.empty()); - m_lower = m_upper; - return l_true; - case l_true: - SASSERT(cores.empty() || mcs.empty()); - for (unsigned i = 0; i < cores.size(); ++i) { - process_unsat(cores[i]); - } - if (cores.empty()) { - process_sat(mcs); - } - break; - } - } - m_lower = m_upper; - return l_true; - } lbool mss_solver() { init(); @@ -234,6 +194,9 @@ public: lbool is_sat = l_true; while (m_lower < m_upper && is_sat == l_true) { IF_VERBOSE(1, verbose_stream() << "(opt.maxres [" << m_lower << ":" << m_upper << "])\n";); + if (m_cancel) { + return l_undef; + } vector cores; exprs mss; model_ref mdl; @@ -241,6 +204,12 @@ public: mcs.reset(); s().get_model(mdl); update_assignment(mdl.get()); + + exprs cs; + get_current_correction_set(mdl.get(), cs); + process_sat(cs); + +#if 0 is_sat = get_mss(mdl.get(), cores, mss, mcs); switch (is_sat) { @@ -249,15 +218,12 @@ public: case l_false: m_lower = m_upper; return l_true; - case l_true: { + case l_true: process_sat(mcs); get_mss_model(); - break; - } - } - if (m_cancel) { - return l_undef; + break; } +#endif if (m_lower < m_upper) { is_sat = s().check_sat(0, 0); } @@ -294,7 +260,7 @@ public: Suppose correction set is huge. Do we really need it? */ - lbool mus_mss2_solver() { + lbool mus_mss_solver() { init(); init_local(); sls(); @@ -356,14 +322,8 @@ public: // obtained from the current best model. // - // - // TBD: throttle blocking on correction sets if they are too big. - // likewise, if the cores are too big, don't block the cores. - // - - exprs cs; - get_current_correction_set(cs); + get_current_correction_set(mdl.get(), cs); unsigned max_core = max_core_size(cores); if (!cs.empty() && cs.size() < max_core) { process_sat(cs); @@ -379,7 +339,6 @@ public: void found_optimum() { s().get_model(m_model); - m_asm2value.reset(); DEBUG_CODE( for (unsigned i = 0; i < m_asms.size(); ++i) { SASSERT(is_true(m_asms[i].get())); @@ -397,8 +356,6 @@ public: return mus_solver(); case s_mus_mss: return mus_mss_solver(); - case s_mus_mss2: - return mus_mss2_solver(); case s_mss: return mss_solver(); } @@ -462,14 +419,14 @@ public: return is_sat; } - void get_current_correction_set(exprs& cs) { + void get_current_correction_set(model* mdl, exprs& cs) { cs.reset(); + if (!mdl) return; for (unsigned i = 0; i < m_asms.size(); ++i) { - if (!is_true(m_asms[i].get())) { + if (!is_true(mdl, m_asms[i].get())) { cs.push_back(m_asms[i].get()); } } - IF_VERBOSE(2, verbose_stream() << "(opt.maxres correction set size: " << cs.size() << ")\n";); TRACE("opt", display_vec(tout << "new correction set: ", cs.size(), cs.c_ptr());); } @@ -554,7 +511,7 @@ public: IF_VERBOSE(1, verbose_stream() << "(opt.maxres [" << m_lower << ":" << m_upper << "])\n";); } - void get_mus_model(model_ref& mdl) { + bool get_mus_model(model_ref& mdl) { rational w(0); if (m_c.sat_enabled()) { // SAT solver core extracts some model @@ -567,6 +524,7 @@ public: if (mdl.get() && w < m_upper) { update_assignment(mdl.get()); } + return 0 != mdl.get(); } void get_mss_model() { @@ -682,17 +640,13 @@ public: s().assert_expr(fml); fml = m.mk_implies(dd, b_i); s().assert_expr(fml); - m_asm2value.insert(dd, is_true(d) && is_true(b_i)); d = dd; } else { - dd = m.mk_and(b_i, d); - m_asm2value.insert(dd, is_true(d) && is_true(b_i)); - m_trail.push_back(dd); - d = dd; + d = m.mk_and(b_i, d); + m_trail.push_back(d); } asum = mk_fresh_bool("a"); - m_asm2value.insert(asum, is_true(b_i1) || is_true(d)); cls = m.mk_or(b_i1, d); fml = m.mk_implies(asum, cls); new_assumption(asum, w); @@ -809,7 +763,6 @@ public: return l_undef; } - void update_assignment(model* mdl) { rational upper(0); expr_ref tmp(m); @@ -826,7 +779,6 @@ public: return; } m_model = mdl; - m_asm2value.reset(); for (unsigned i = 0; i < m_soft.size(); ++i) { m_assignment[i] = is_true(m_soft[i].get()); @@ -851,16 +803,16 @@ public: s().assert_expr(fml); } - bool is_true(expr* e) { - bool truth_value; - if (m_asm2value.find(e, truth_value)) { - return truth_value; - } + bool is_true(model* mdl, expr* e) { expr_ref tmp(m); - VERIFY(m_model->eval(e, tmp)); + VERIFY(mdl->eval(e, tmp)); return m.is_true(tmp); } + bool is_true(expr* e) { + return is_true(m_model.get(), e); + } + void remove_soft(exprs const& core, expr_ref_vector& asms) { for (unsigned i = 0; i < asms.size(); ++i) { if (core.contains(asms[i].get())) { @@ -933,7 +885,7 @@ opt::maxsmt_solver_base* opt::mk_maxres( opt::maxsmt_solver_base* opt::mk_mus_mss_maxres( context& c, weights_t& ws, expr_ref_vector const& soft) { - return alloc(maxres, c, ws, soft, maxres::s_mus_mss2); + return alloc(maxres, c, ws, soft, maxres::s_mus_mss); } opt::maxsmt_solver_base* opt::mk_mss_maxres( diff --git a/src/opt/mus.cpp b/src/opt/mus.cpp index 932d9ccf7..abb0b326e 100644 --- a/src/opt/mus.cpp +++ b/src/opt/mus.cpp @@ -70,6 +70,7 @@ struct mus::imp { lbool get_mus(unsigned_vector& mus) { // SASSERT: mus does not have duplicates. + m_model.reset(); unsigned_vector core; for (unsigned i = 0; i < m_cls2expr.size(); ++i) { core.push_back(i); diff --git a/src/opt/opt_context.cpp b/src/opt/opt_context.cpp index ffb1f815d..caa71eeaa 100644 --- a/src/opt/opt_context.cpp +++ b/src/opt/opt_context.cpp @@ -465,9 +465,13 @@ namespace opt { void context::enable_sls(expr_ref_vector const& soft, vector const& weights) { SASSERT(soft.size() == weights.size()); - if (m_enable_sls && m_sat_solver.get()) { + if (m_sat_solver.get()) { set_soft_inc_sat(m_sat_solver.get(), soft.size(), soft.c_ptr(), weights.c_ptr()); } + if (m_enable_sls && m_sat_solver.get()) { + m_params.set_bool("optimize_model", true); + m_sat_solver->updt_params(m_params); + } } struct context::is_bv { diff --git a/src/sat/sat_mus.cpp b/src/sat/sat_mus.cpp index 035423a19..4468f5b34 100644 --- a/src/sat/sat_mus.cpp +++ b/src/sat/sat_mus.cpp @@ -32,6 +32,7 @@ namespace sat { m_core.reset(); m_mus.reset(); m_model.reset(); + m_best_value = 0; } void mus::set_core() { @@ -96,8 +97,15 @@ namespace sat { if (!core.empty()) { // mr(); // TBD: measure } + double new_value = s.m_wsls.evaluate_model(s.m_model); if (m_model.empty()) { m_model.append(s.m_model); + m_best_value = new_value; + } + else if (m_best_value > new_value) { + m_model.reset(); + m_model.append(s.m_model); + m_best_value = new_value; } break; } diff --git a/src/sat/sat_mus.h b/src/sat/sat_mus.h index b68f4ee5c..eede15c63 100644 --- a/src/sat/sat_mus.h +++ b/src/sat/sat_mus.h @@ -25,6 +25,7 @@ namespace sat { literal_vector m_mus; bool m_is_active; model m_model; // model obtained during minimal unsat core + double m_best_value; solver& s; diff --git a/src/sat/sat_sls.cpp b/src/sat/sat_sls.cpp index 3a520742e..bd0f8855a 100644 --- a/src/sat/sat_sls.cpp +++ b/src/sat/sat_sls.cpp @@ -354,7 +354,7 @@ namespace sat { // // Initialize m_clause_weights, m_hscore, m_sscore. // - m_best_value = m_false.empty()?evaluate_model():-1.0; + m_best_value = m_false.empty()?evaluate_model(m_model):-1.0; m_best_model.reset(); m_clause_weights.reset(); m_hscore.reset(); @@ -382,7 +382,7 @@ namespace sat { for (; !m_cancel && m_best_value > 0 && i < m_max_tries; ++i) { wflip(); if (m_false.empty()) { - double val = evaluate_model(); + double val = evaluate_model(m_model); if (val < m_best_value || m_best_value < 0.0) { m_best_value = val; m_best_model.reset(); @@ -511,12 +511,12 @@ namespace sat { DEBUG_CODE(check_invariant();); } - double wsls::evaluate_model() { + double wsls::evaluate_model(model& mdl) { SASSERT(m_false.empty()); double result = 0.0; for (unsigned i = 0; i < m_soft.size(); ++i) { literal lit = m_soft[i]; - if (value_at(lit, m_model) != l_true) { + if (value_at(lit, mdl) != l_true) { result += m_weights[i]; } } diff --git a/src/sat/sat_sls.h b/src/sat/sat_sls.h index 8efc337cc..530d4be0c 100644 --- a/src/sat/sat_sls.h +++ b/src/sat/sat_sls.h @@ -99,12 +99,12 @@ namespace sat { void opt(unsigned sz, literal const* tabu, bool reuse_model); model const& get_model() { return m_best_model; } virtual void display(std::ostream& out) const; + double evaluate_model(model& mdl); private: void wflip(); void wflip(literal lit); void update_hard_weights(); bool pick_wflip(literal & lit); - double evaluate_model(); virtual void check_invariant(); void refresh_scores(bool_var v); int compute_hscore(bool_var v); diff --git a/src/sat/sat_solver.cpp b/src/sat/sat_solver.cpp index 1d3f4f68c..4b1bd32b1 100644 --- a/src/sat/sat_solver.cpp +++ b/src/sat/sat_solver.cpp @@ -1737,7 +1737,7 @@ namespace sat { m_mus(); // ignore return value on cancelation. m_model.reset(); - m_model.append(m_mus.get_model()); + m_model.append(m_mus.get_model()); } } diff --git a/src/sat/tactic/goal2sat.cpp b/src/sat/tactic/goal2sat.cpp index 38a22095a..7224de971 100644 --- a/src/sat/tactic/goal2sat.cpp +++ b/src/sat/tactic/goal2sat.cpp @@ -372,7 +372,9 @@ struct goal2sat::imp { SASSERT(m_result_stack.empty()); } - void insert_dep(expr* dep, bool sign) { + void insert_dep(expr* dep0, expr* dep, bool sign) { + SASSERT(sign || dep0 == dep); // !sign || (not dep0) == dep. + SASSERT(!sign || m.is_not(dep0)); expr_ref new_dep(m), fml(m); if (is_uninterp_const(dep)) { new_dep = dep; @@ -386,7 +388,7 @@ struct goal2sat::imp { } convert_atom(new_dep, false, false); sat::literal lit = m_result_stack.back(); - m_dep2asm.insert(dep, sign?(~lit):lit); + m_dep2asm.insert(dep0, sign?~lit:lit); m_result_stack.pop_back(); } @@ -411,7 +413,7 @@ struct goal2sat::imp { SASSERT(m.is_bool(d)); bool sign = m.is_not(d, d1); - insert_dep(d1, sign); + insert_dep(d, d1, sign); if (d == f) { goto skip_dep; }