From 3f5ed8ff1109754a2972da626dbd1d6e0c7e3265 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Sat, 19 Apr 2014 20:27:39 -0700 Subject: [PATCH] coallesce common code Signed-off-by: Nikolaj Bjorner --- src/opt/opt_params.pyg | 1 + src/opt/opt_sls_solver.h | 73 +++++++++++++++---- src/opt/pb_sls.cpp | 2 +- src/opt/pb_sls.h | 2 +- src/opt/weighted_maxsat.cpp | 140 ++++++++++-------------------------- 5 files changed, 97 insertions(+), 121 deletions(-) diff --git a/src/opt/opt_params.pyg b/src/opt/opt_params.pyg index eb0a02752..201a1601b 100644 --- a/src/opt/opt_params.pyg +++ b/src/opt/opt_params.pyg @@ -12,6 +12,7 @@ def_module_params('opt', ('wmaxsat_engine', SYMBOL, 'wmax', "weighted maxsat engine: 'wmax', 'pbmax', 'bcd2', 'wpm2', 'bvsls', 'sls'"), ('enable_sls', BOOL, False, 'enable SLS tuning during weighted maxsast'), ('enable_sat', BOOL, False, 'enable the new SAT core for propositional constraints'), + ('sls_engine', SYMBOL, 'bv', "SLS engine. Either 'bv' or 'pb'"), ('elim_01', BOOL, True, 'eliminate 01 variables'), ('pb.compile_equality', BOOL, False, 'compile arithmetical equalities into pseudo-Boolean equality (instead of two inequalites)') diff --git a/src/opt/opt_sls_solver.h b/src/opt/opt_sls_solver.h index 542728cc4..63810e974 100644 --- a/src/opt/opt_sls_solver.h +++ b/src/opt/opt_sls_solver.h @@ -22,23 +22,27 @@ Notes: #include "solver_na2as.h" #include "card2bv_tactic.h" +#include "pb_sls.h" namespace opt { class sls_solver : public solver_na2as { ast_manager& m; ref m_solver; - scoped_ptr m_sls; + scoped_ptr m_bvsls; + scoped_ptr m_pbsls; pb::card_pb_rewriter m_pb2bv; model_ref m_model; expr_ref m_objective; params_ref m_params; + symbol m_engine; public: sls_solver(ast_manager & m, solver* s, expr* to_maximize, params_ref const& p): solver_na2as(m), m(m), m_solver(s), - m_sls(0), + m_bvsls(0), + m_pbsls(0), m_pb2bv(m), m_objective(to_maximize, m) { @@ -48,13 +52,16 @@ namespace opt { virtual void updt_params(params_ref & p) { m_solver->updt_params(p); m_params.copy(p); + opt_params _p(p); + m_engine = _p.sls_engine(); } virtual void collect_param_descrs(param_descrs & r) { m_solver->collect_param_descrs(r); } virtual void collect_statistics(statistics & st) const { m_solver->collect_statistics(st); - // TBD: m_sls->get_stats(); + if (m_bvsls) m_bvsls->collect_statistics(st); + if (m_pbsls) m_pbsls->collect_statistics(st); } virtual void assert_expr(expr * t) { m_solver->assert_expr(t); @@ -79,8 +86,11 @@ namespace opt { m_pb2bv.set_cancel(f); #pragma omp critical (this) { - if (m_sls) { - m_sls->set_cancel(f); + if (m_bvsls) { + m_bvsls->set_cancel(f); + } + if (m_pbsls) { + m_pbsls->set_cancel(f); } } } @@ -95,7 +105,7 @@ namespace opt { } virtual void display(std::ostream & out) const { m_solver->display(out); - // if (m_sls) m_sls->display(out); + // if (m_bvsls) m_bvsls->display(out); } protected: @@ -105,15 +115,11 @@ namespace opt { lbool r = m_solver->check_sat(num_assumptions, assumptions); if (r == l_true) { m_solver->get_model(m_model); - #pragma omp critical (this) - { - m_sls = alloc(bvsls_opt_engine, m, m_params); + if (m_engine == symbol("pb")) { + } - assertions2sls(); - opt_result or = m_sls->optimize(m_objective, m_model, true); - SASSERT(or.is_sat == l_true || or.is_sat == l_undef); - if (or.is_sat == l_true) { - m_sls->get_model(m_model); + else { + bvsls_opt(); } } return r; @@ -124,6 +130,7 @@ namespace opt { virtual void pop_core(unsigned n) { m_solver->pop(n); } + private: void assertions2sls() { expr_ref tmp(m); @@ -141,7 +148,43 @@ namespace opt { SASSERT(result.size() == 1); goal* r = result[0]; for (unsigned i = 0; i < r->size(); ++i) { - m_sls->assert_expr(r->form(i)); + m_bvsls->assert_expr(r->form(i)); + } + } + + void pbsls_opt() { + #pragma omp critical (this) + { + m_pbsls = alloc(smt::pb_sls, m); + } + m_pbsls->set_model(m_model); + m_pbsls->updt_params(m_params); + for (unsigned i = 0; i < m_solver->get_num_assertions(); ++i) { + m_pbsls->add(m_solver->get_assertion(i)); + } +#if 0 + TBD: + for (unsigned i = 0; i < m_num_soft; ++i) { + m_pbsls->add(m_soft[i].get(), m_weights[i].get()); + } +#endif + + lbool is_sat = (*m_pbsls.get())(); + if (is_sat == l_true) { + m_bvsls->get_model(m_model); + } + } + + void bvsls_opt() { + #pragma omp critical (this) + { + m_bvsls = alloc(bvsls_opt_engine, m, m_params); + } + assertions2sls(); + opt_result or = m_bvsls->optimize(m_objective, m_model, true); + SASSERT(or.is_sat == l_true || or.is_sat == l_undef); + if (or.is_sat == l_true) { + m_bvsls->get_model(m_model); } } diff --git a/src/opt/pb_sls.cpp b/src/opt/pb_sls.cpp index c7c0c1032..d000350e5 100644 --- a/src/opt/pb_sls.cpp +++ b/src/opt/pb_sls.cpp @@ -233,7 +233,7 @@ namespace smt { } } - void collect_statistics(statistics& st) const { + void collect_statistics(::statistics& st) const { } void updt_params(params_ref& p) { diff --git a/src/opt/pb_sls.h b/src/opt/pb_sls.h index a65ed83bf..c8d2c60c2 100644 --- a/src/opt/pb_sls.h +++ b/src/opt/pb_sls.h @@ -39,7 +39,7 @@ namespace smt { void set_model(model_ref& mdl); lbool operator()(); void set_cancel(bool f); - void collect_statistics(statistics& st) const; + void collect_statistics(::statistics& st) const; void get_model(model_ref& mdl); void updt_params(params_ref& p); }; diff --git a/src/opt/weighted_maxsat.cpp b/src/opt/weighted_maxsat.cpp index ea3515ec7..d9a5eae11 100644 --- a/src/opt/weighted_maxsat.cpp +++ b/src/opt/weighted_maxsat.cpp @@ -562,10 +562,15 @@ namespace opt { } }; + // ---------------------------------- + // incrementally add pseudo-boolean + // lower bounds. + class pbmax : public maxsmt_solver_base { + bool m_use_aux; public: - pbmax(solver* s, ast_manager& m): - maxsmt_solver_base(s, m) {} + pbmax(solver* s, ast_manager& m, bool use_aux): + maxsmt_solver_base(s, m), m_use_aux(use_aux) {} virtual ~pbmax() {} @@ -579,29 +584,41 @@ namespace opt { ); pb_util u(m); expr_ref fml(m), val(m); + app_ref b(m); expr_ref_vector nsoft(m); init(); - for (unsigned i = 0; i < m_soft.size(); ++i) { - nsoft.push_back(mk_not(m_soft[i].get())); + if (m_use_aux) { + s().push(); + } + for (unsigned i = 0; i < m_soft.size(); ++i) { + if (m_use_aux) { + b = m.mk_fresh_const("b", m.mk_bool_sort()); + m_mc->insert(b->get_decl()); + fml = m.mk_or(m_soft[i].get(), b); + s().assert_expr(fml); + nsoft.push_back(b); + } + else { + nsoft.push_back(mk_not(m_soft[i].get())); + } } - solver::scoped_push _s1(s()); lbool is_sat = l_true; bool was_sat = false; fml = m.mk_true(); while (l_true == is_sat) { TRACE("opt", s().display(tout<<"looping\n");); - solver::scoped_push _s2(s()); + solver::scoped_push _scope2(s()); s().assert_expr(fml); - is_sat = simplify_and_check_sat(); + is_sat = s().check_sat(0,0); if (m_cancel) { is_sat = l_undef; } if (is_sat == l_true) { - m_upper = rational::zero(); + m_upper.reset(); for (unsigned i = 0; i < m_soft.size(); ++i) { - VERIFY(m_model->eval(m_soft[i].get(), val)); + VERIFY(m_model->eval(nsoft[i].get(), val)); TRACE("opt", tout << "eval " << mk_pp(m_soft[i].get(), m) << " " << val << "\n";); - m_assignment[i] = m.is_true(val); + m_assignment[i] = !m.is_true(val); if (!m_assignment[i]) { m_upper += m_weights[i]; } @@ -616,46 +633,12 @@ namespace opt { is_sat = l_true; m_lower = m_upper; } + if (m_use_aux) { + s().pop(1); + } TRACE("opt", tout << "lower: " << m_lower << "\n";); return is_sat; } - - private: - lbool simplify_and_check_sat() { - lbool is_sat = l_true; - tactic_ref tac = mk_simplify_tactic(m); - // TBD: make tac attribute for cancelation. - proof_converter_ref pc; - expr_dependency_ref core(m); - model_converter_ref mc; - goal_ref_buffer result; - goal_ref g(alloc(goal, m, true, false)); - for (unsigned i = 0; i < s().get_num_assertions(); ++i) { - g->assert_expr(s().get_assertion(i)); - } - (*tac)(g, result, mc, pc, core); - if (result.empty()) { - is_sat = l_false; - } - else { - SASSERT(result.size() == 1); - goal_ref r = result[0]; - solver::scoped_push _s(s()); - for (unsigned i = 0; i < r->size(); ++i) { - s().assert_expr(r->form(i)); - } - is_sat = s().check_sat(0, 0); - if (l_true == is_sat && !m_cancel) { - s().get_model(m_model); - if (mc && m_model) (*mc)(m_model, 0); - IF_VERBOSE(2, - g->display(verbose_stream() << "goal:\n"); - r->display(verbose_stream() << "reduced:\n"); - model_smt2_pp(verbose_stream(), m, *m_model, 0);); - } - } - return is_sat; - } }; // ------------------------------------------------------ @@ -998,6 +981,9 @@ namespace opt { } }; + // ---------------------------------------------------------- + // weighted max-sat using a custom theory solver for max-sat. + // NB. it is quite similar to pseudo-Boolean propagation. class wmax : public maxsmt_solver_wbase { @@ -1046,60 +1032,6 @@ namespace opt { } }; - class pwmax : public maxsmt_solver_base { - public: - pwmax(solver* s, ast_manager& m): maxsmt_solver_base(s, m) {} - virtual ~pwmax() {} - lbool operator()() { - enable_bvsat(); - enable_sls(); - pb_util u(m); - expr_ref fml(m), val(m); - app_ref b(m); - expr_ref_vector nsoft(m); - solver::scoped_push __s(s()); - init(); - for (unsigned i = 0; i < m_soft.size(); ++i) { - b = m.mk_fresh_const("b", m.mk_bool_sort()); - m_mc->insert(b->get_decl()); - fml = m.mk_or(m_soft[i].get(), b); - s().assert_expr(fml); - nsoft.push_back(b); - } - lbool is_sat = l_true; - bool was_sat = false; - fml = m.mk_true(); - while (l_true == is_sat) { - solver::scoped_push _s(s()); - s().assert_expr(fml); - is_sat = s().check_sat(0,0); - if (m_cancel) { - is_sat = l_undef; - } - if (is_sat == l_true) { - s().get_model(m_model); - m_upper = rational::zero(); - for (unsigned i = 0; i < m_soft.size(); ++i) { - VERIFY(m_model->eval(nsoft[i].get(), val)); - m_assignment[i] = !m.is_true(val); - if (!m_assignment[i]) { - m_upper += m_weights[i]; - } - } - IF_VERBOSE(1, verbose_stream() << "(wmaxsat.pb with upper bound: " << m_upper << ")\n";); - fml = m.mk_not(u.mk_ge(nsoft.size(), m_weights.c_ptr(), nsoft.c_ptr(), m_upper)); - was_sat = true; - } - } - if (is_sat == l_false && was_sat) { - is_sat = l_true; - m_lower = m_upper; - } - return is_sat; - } - - }; - struct wmaxsmt::imp { ast_manager& m; ref s; // solver state that contains hard constraints @@ -1125,13 +1057,13 @@ namespace opt { return *m_maxsmt; } if (m_engine == symbol("pwmax")) { - m_maxsmt = alloc(pwmax, s.get(), m); + m_maxsmt = alloc(pbmax, s.get(), m, true); } else if (m_engine == symbol("pbmax")) { - m_maxsmt = alloc(pbmax, s.get(), m); + m_maxsmt = alloc(pbmax, s.get(), m, false); } else if (m_engine == symbol("wpm2")) { - maxsmt_solver_base* s2 = alloc(pbmax, s.get(), m); + maxsmt_solver_base* s2 = alloc(pbmax, s.get(), m, false); m_maxsmt = alloc(wpm2, s.get(), m, s2); } else if (m_engine == symbol("bcd2")) {