diff --git a/src/opt/maxlex.cpp b/src/opt/maxlex.cpp index 82eee6f5e..fa97359d0 100644 --- a/src/opt/maxlex.cpp +++ b/src/opt/maxlex.cpp @@ -186,7 +186,7 @@ namespace opt { public: maxlex(maxsat_context& c, unsigned id, vector& s): - maxsmt_solver_base(c, s), + maxsmt_solver_base(c, s, id), m(c.get_manager()), m_c(c) { // ensure that soft constraints are sorted with largest soft constraints first. diff --git a/src/opt/maxres.cpp b/src/opt/maxres.cpp index 0d1bf2394..135e0f80e 100644 --- a/src/opt/maxres.cpp +++ b/src/opt/maxres.cpp @@ -95,7 +95,6 @@ private: expr_ref_vector const& soft() override { return i.m_asms; } }; - unsigned m_index; stats m_stats; expr_ref_vector m_B; expr_ref_vector m_asms; @@ -132,8 +131,7 @@ public: maxres(maxsat_context& c, unsigned index, vector& soft, strategy_t st): - maxsmt_solver_base(c, soft), - m_index(index), + maxsmt_solver_base(c, soft, index), m_B(m), m_asms(m), m_defs(m), m_new_core(m), m_mus(c.get_solver()), diff --git a/src/opt/maxsmt.cpp b/src/opt/maxsmt.cpp index dab9ae445..a3d5f2f45 100644 --- a/src/opt/maxsmt.cpp +++ b/src/opt/maxsmt.cpp @@ -35,10 +35,10 @@ Notes: namespace opt { - maxsmt_solver_base::maxsmt_solver_base( - maxsat_context& c, vector& s): + maxsmt_solver_base::maxsmt_solver_base(maxsat_context& c, vector& s, unsigned index): m(c.get_manager()), m_c(c), + m_index(index), m_soft(s), m_assertions(m), m_trail(m) { @@ -91,18 +91,17 @@ namespace opt { m_upper += s.weight; } - return true; + // return true; preprocess pp(s()); rational lower(0); bool r = pp(m_soft, lower); - - if (lower != 0) - m_adjust_value->set_offset(lower + m_adjust_value->get_offset()); + m_c.add_offset(m_index, lower); + m_upper -= lower; TRACE("opt", - tout << "upper: " << m_upper << " assignments: "; + tout << "lower " << lower << " upper: " << m_upper << " assignments: "; for (soft& s : m_soft) tout << (s.is_true()?"T":"F"); tout << "\n";); return r; @@ -169,8 +168,8 @@ namespace opt { void maxsmt_solver_base::trace_bounds(char const * solver) { IF_VERBOSE(1, - rational l = (*m_adjust_value)(m_lower); - rational u = (*m_adjust_value)(m_upper); + rational l = m_c.adjust(m_index, m_lower); + rational u = m_c.adjust(m_index, m_upper); if (l > u) std::swap(l, u); verbose_stream() << "(opt." << solver << " [" << l << ":" << u << "])\n";); } @@ -196,10 +195,10 @@ namespace opt { m_msolver = mk_primal_dual_maxres(m_c, m_index, m_soft); } else if (maxsat_engine == symbol("wmax")) { - m_msolver = mk_wmax(m_c, m_soft); + m_msolver = mk_wmax(m_c, m_soft, m_index); } else if (maxsat_engine == symbol("sortmax")) { - m_msolver = mk_sortmax(m_c, m_soft); + m_msolver = mk_sortmax(m_c, m_soft, m_index); } else { auto str = maxsat_engine.str(); @@ -209,7 +208,6 @@ namespace opt { if (m_msolver) { m_msolver->updt_params(m_params); - m_msolver->set_adjust_value(*m_adjust_value); is_sat = l_undef; try { is_sat = (*m_msolver)(); @@ -233,13 +231,6 @@ namespace opt { return is_sat; } - void maxsmt::set_adjust_value(adjust_value& adj) { - m_adjust_value = &adj; - if (m_msolver) { - m_msolver->set_adjust_value(adj); - } - } - void maxsmt::reset_upper() { if (m_msolver) { m_msolver->reset_upper(); @@ -268,7 +259,7 @@ namespace opt { rational q = m_msolver->get_lower(); if (q > r) r = q; } - return (*m_adjust_value)(r); + return m_c.adjust(m_index, r); } rational maxsmt::get_upper() const { @@ -277,7 +268,7 @@ namespace opt { rational q = m_msolver->get_upper(); if (q < r) r = q; } - return (*m_adjust_value)(r); + return m_c.adjust(m_index, r); } void maxsmt::update_lower(rational const& r) { @@ -370,6 +361,7 @@ namespace opt { model_ref m_model; ref m_fm; symbol m_maxsat_engine; + vector m_offsets; public: solver_maxsat_context(params_ref& p, solver* s, model * m): m_params(p), @@ -394,6 +386,14 @@ namespace opt { bool verify_model(unsigned id, model* mdl, rational const& v) override { return true; }; void set_model(model_ref& _m) override { m_model = _m; } void model_updated(model* mdl) override { } // no-op + rational adjust(unsigned id, rational const& r) override { + m_offsets.reserve(id+1); + return r + m_offsets[id]; + } + void add_offset(unsigned id, rational const& r) override { + m_offsets.reserve(id+1); + m_offsets[id] += r; + } }; lbool maxsmt_wrapper::operator()(vector>& soft) { diff --git a/src/opt/maxsmt.h b/src/opt/maxsmt.h index 5a3dc4ca8..b0ae5eeb1 100644 --- a/src/opt/maxsmt.h +++ b/src/opt/maxsmt.h @@ -34,8 +34,6 @@ namespace opt { class maxsat_context; class maxsmt_solver { - protected: - adjust_value* m_adjust_value = nullptr; public: virtual ~maxsmt_solver() {} virtual lbool operator()() = 0; @@ -45,7 +43,6 @@ namespace opt { virtual void collect_statistics(statistics& st) const = 0; virtual void get_model(model_ref& mdl, svector& labels) = 0; virtual void updt_params(params_ref& p) = 0; - void set_adjust_value(adjust_value& adj) { m_adjust_value = &adj; } }; @@ -67,7 +64,8 @@ namespace opt { class maxsmt_solver_base : public maxsmt_solver { protected: ast_manager& m; - maxsat_context& m_c; + maxsat_context& m_c; + unsigned m_index; vector& m_soft; expr_ref_vector m_assertions; expr_ref_vector m_trail; @@ -78,7 +76,7 @@ namespace opt { params_ref m_params; // config public: - maxsmt_solver_base(maxsat_context& c, vector& soft); + maxsmt_solver_base(maxsat_context& c, vector& soft, unsigned index); ~maxsmt_solver_base() override {} rational get_lower() const override { return m_lower; } @@ -128,7 +126,6 @@ namespace opt { expr_ref_vector m_answer; rational m_lower; rational m_upper; - adjust_value* m_adjust_value = nullptr; model_ref m_model; svector m_labels; params_ref m_params; @@ -137,7 +134,6 @@ namespace opt { lbool operator()(); void updt_params(params_ref& p); void add(expr* f, rational const& w); - void set_adjust_value(adjust_value& adj); unsigned size() const { return m_soft.size(); } expr* operator[](unsigned idx) const { return m_soft[idx].s; } rational weight(unsigned idx) const { return m_soft[idx].weight; } diff --git a/src/opt/opt_context.cpp b/src/opt/opt_context.cpp index 8848b61c9..25982e89e 100644 --- a/src/opt/opt_context.cpp +++ b/src/opt/opt_context.cpp @@ -398,7 +398,7 @@ namespace opt { } void context::set_model(model_ref& m) { - m_model = m; + m_model = m; opt_params optp(m_params); if (optp.dump_models() && m) { model_ref md = m->copy(); @@ -930,7 +930,8 @@ namespace opt { bool context::is_maxsat(expr* fml, expr_ref_vector& terms, vector& weights, rational& offset, bool& neg, symbol& id, expr_ref& orig_term, unsigned& index) { - if (!is_app(fml)) return false; + if (!is_app(fml)) + return false; neg = false; orig_term = nullptr; index = 0; @@ -1105,8 +1106,7 @@ namespace opt { obj.m_weights.append(weights); obj.m_adjust_value.set_offset(offset); obj.m_adjust_value.set_negate(neg); - m_maxsmts.find(id)->set_adjust_value(obj.m_adjust_value); - TRACE("opt", tout << "maxsat: " << id << " offset:" << offset << "\n"; + TRACE("opt", tout << "maxsat: " << neg << " " << id << " offset: " << offset << "\n"; tout << terms << "\n";); } else if (is_maximize(fml, tr, orig_term, index)) { @@ -1158,7 +1158,14 @@ namespace opt { #endif } + rational context::adjust(unsigned id, rational const& v) { + return m_objectives[id].m_adjust_value(v); + } + void context::add_offset(unsigned id, rational const& o) { + m_objectives[id].m_adjust_value.add_offset(o); + } + bool context::verify_model(unsigned index, model* md, rational const& _v) { rational r; app_ref term = m_objectives[index].m_term; @@ -1341,24 +1348,21 @@ namespace opt { break; } case O_MAXSMT: { - bool ok = true; - for (unsigned j = 0; ok && j < obj.m_terms.size(); ++j) { + for (unsigned j = 0; j < obj.m_terms.size(); ++j) { val = (*m_model)(obj.m_terms[j]); TRACE("opt", tout << mk_pp(obj.m_terms[j], m) << " " << val << "\n";); - if (!m.is_true(val)) { + if (!m.is_true(val)) r += obj.m_weights[j]; - } } - if (ok) { - maxsmt& ms = *m_maxsmts.find(obj.m_id); - if (is_lower) { - ms.update_upper(r); - TRACE("opt", tout << "update upper from " << r << " to " << ms.get_upper() << "\n";); - } - else { - ms.update_lower(r); - TRACE("opt", tout << "update lower from " << r << " to " << ms.get_lower() << "\n";); - } + + maxsmt& ms = *m_maxsmts.find(obj.m_id); + if (is_lower) { + ms.update_upper(r); + TRACE("opt", tout << "update upper from " << r << " to " << ms.get_upper() << "\n";); + } + else { + ms.update_lower(r); + TRACE("opt", tout << "update lower from " << r << " to " << ms.get_lower() << "\n";); } break; } diff --git a/src/opt/opt_context.h b/src/opt/opt_context.h index 1e950174a..c02689a38 100644 --- a/src/opt/opt_context.h +++ b/src/opt/opt_context.h @@ -57,6 +57,8 @@ namespace opt { virtual smt::context& smt_context() = 0; // access SMT context for SMT based MaxSMT solver (wmax requires SMT core) virtual unsigned num_objectives() = 0; virtual bool verify_model(unsigned id, model* mdl, rational const& v) = 0; + virtual rational adjust(unsigned id, rational const& v) = 0; + virtual void add_offset(unsigned id, rational const& o) = 0; virtual void set_model(model_ref& _m) = 0; virtual void model_updated(model* mdl) = 0; }; @@ -93,7 +95,7 @@ namespace opt { app_ref m_term; // for maximize, minimize term expr_ref_vector m_terms; // for maxsmt vector m_weights; // for maxsmt - adjust_value m_adjust_value; + adjust_value m_adjust_value; symbol m_id; // for maxsmt unsigned m_index; // for maximize/minimize index @@ -269,11 +271,14 @@ namespace opt { void model_updated(model* mdl) override; + rational adjust(unsigned id, rational const& v) override; + + void add_offset(unsigned id, rational const& o) override; + void register_on_model(on_model_t& ctx, std::function& on_model) { m_on_model_ctx = ctx; m_on_model_eh = on_model; } - void collect_timer_stats(statistics& st) const { if (m_time != 0) diff --git a/src/opt/opt_solver.h b/src/opt/opt_solver.h index caac008fd..47fe86f94 100644 --- a/src/opt/opt_solver.h +++ b/src/opt/opt_solver.h @@ -48,6 +48,7 @@ namespace opt { void set_offset(rational const& o) { m_offset = o; } void set_negate(bool neg) { m_negate = neg; } rational const& get_offset() const { return m_offset; } + void add_offset(rational const& o) { if (m_negate) m_offset -= o; else m_offset += o; } bool get_negate() { return m_negate; } inf_eps operator()(inf_eps const& r) const { inf_eps result = r; diff --git a/src/opt/sortmax.cpp b/src/opt/sortmax.cpp index a6539e129..962369bf2 100644 --- a/src/opt/sortmax.cpp +++ b/src/opt/sortmax.cpp @@ -36,8 +36,8 @@ namespace opt { expr_ref_vector m_trail; func_decl_ref_vector m_fresh; ref m_filter; - sortmax(maxsat_context& c, vector& s): - maxsmt_solver_base(c, s), m_sort(*this), m_trail(m), m_fresh(m) {} + sortmax(maxsat_context& c, vector& s, unsigned index): + maxsmt_solver_base(c, s, index), m_sort(*this), m_trail(m), m_fresh(m) {} ~sortmax() override {} @@ -138,8 +138,8 @@ namespace opt { }; - maxsmt_solver_base* mk_sortmax(maxsat_context& c, vector& s) { - return alloc(sortmax, c, s); + maxsmt_solver_base* mk_sortmax(maxsat_context& c, vector& s, unsigned index) { + return alloc(sortmax, c, s, index); } } diff --git a/src/opt/wmax.cpp b/src/opt/wmax.cpp index 812c8f954..1fbd26cb8 100644 --- a/src/opt/wmax.cpp +++ b/src/opt/wmax.cpp @@ -44,8 +44,8 @@ namespace opt { } public: - wmax(maxsat_context& c, vector& s): - maxsmt_solver_base(c, s), + wmax(maxsat_context& c, vector& s, unsigned index): + maxsmt_solver_base(c, s, index), m_trail(m), m_defs(m) {} @@ -304,8 +304,8 @@ namespace opt { }; - maxsmt_solver_base* mk_wmax(maxsat_context& c, vector & s) { - return alloc(wmax, c, s); + maxsmt_solver_base* mk_wmax(maxsat_context& c, vector & s, unsigned index) { + return alloc(wmax, c, s, index); } } diff --git a/src/opt/wmax.h b/src/opt/wmax.h index 0a5167269..7f5b26ac6 100644 --- a/src/opt/wmax.h +++ b/src/opt/wmax.h @@ -22,8 +22,8 @@ Notes: #include "opt/maxsmt.h" namespace opt { - maxsmt_solver_base* mk_wmax(maxsat_context& c, vector& s); + maxsmt_solver_base* mk_wmax(maxsat_context& c, vector& s, unsigned index); - maxsmt_solver_base* mk_sortmax(maxsat_context& c, vector& s); + maxsmt_solver_base* mk_sortmax(maxsat_context& c, vector& s, unsigned index); }