From c58e640563424747ecf086339aeb2f6437b46567 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Wed, 18 Nov 2015 14:53:08 -0800 Subject: [PATCH] extract labels for optimal model. Fix to #325 Signed-off-by: Nikolaj Bjorner --- src/opt/maxsmt.cpp | 5 +++-- src/opt/maxsmt.h | 10 ++++++---- src/opt/opt_context.cpp | 16 +++++----------- src/opt/opt_context.h | 2 +- src/opt/opt_pareto.cpp | 2 ++ src/opt/opt_pareto.h | 5 +++-- src/opt/optsmt.cpp | 9 +++++++-- src/opt/optsmt.h | 3 ++- src/sat/sat_solver/inc_sat_solver.cpp | 1 - 9 files changed, 29 insertions(+), 24 deletions(-) diff --git a/src/opt/maxsmt.cpp b/src/opt/maxsmt.cpp index 1bb521923..cb598617c 100644 --- a/src/opt/maxsmt.cpp +++ b/src/opt/maxsmt.cpp @@ -191,7 +191,7 @@ namespace opt { m_msolver->set_adjust_value(m_adjust_value); is_sat = (*m_msolver)(); if (is_sat != l_false) { - m_msolver->get_model(m_model); + m_msolver->get_model(m_model, m_labels); } } @@ -247,8 +247,9 @@ namespace opt { m_upper = r; } - void maxsmt::get_model(model_ref& mdl) { + void maxsmt::get_model(model_ref& mdl, svector& labels) { mdl = m_model.get(); + labels = m_labels; } void maxsmt::commit_assignment() { diff --git a/src/opt/maxsmt.h b/src/opt/maxsmt.h index 8d290a640..7dbf763e2 100644 --- a/src/opt/maxsmt.h +++ b/src/opt/maxsmt.h @@ -46,7 +46,7 @@ namespace opt { virtual bool get_assignment(unsigned index) const = 0; virtual void set_cancel(bool f) = 0; virtual void collect_statistics(statistics& st) const = 0; - virtual void get_model(model_ref& mdl) = 0; + virtual void get_model(model_ref& mdl, svector& labels) = 0; virtual void updt_params(params_ref& p) = 0; void set_adjust_value(adjust_value& adj) { m_adjust_value = adj; } @@ -67,6 +67,7 @@ namespace opt { rational m_lower; rational m_upper; model_ref m_model; + svector m_labels; svector m_assignment; // truth assignment to soft constraints params_ref m_params; // config @@ -79,9 +80,9 @@ namespace opt { virtual bool get_assignment(unsigned index) const { return m_assignment[index]; } virtual void set_cancel(bool f) { m_cancel = f; if (f) s().cancel(); else s().reset_cancel(); } virtual void collect_statistics(statistics& st) const { } - virtual void get_model(model_ref& mdl) { mdl = m_model.get(); } + virtual void get_model(model_ref& mdl, svector& labels) { mdl = m_model.get(); labels = m_labels;} virtual void commit_assignment(); - void set_model() { s().get_model(m_model); } + void set_model() { s().get_model(m_model); s().get_labels(m_labels); } virtual void updt_params(params_ref& p); solver& s(); void init(); @@ -122,6 +123,7 @@ namespace opt { rational m_upper; adjust_value m_adjust_value; model_ref m_model; + svector m_labels; params_ref m_params; public: maxsmt(maxsat_context& c); @@ -139,7 +141,7 @@ namespace opt { rational get_upper() const; void update_lower(rational const& r); void update_upper(rational const& r); - void get_model(model_ref& mdl); + void get_model(model_ref& mdl, svector& labels); bool get_assignment(unsigned index) const; void display_answer(std::ostream& out) const; void collect_statistics(statistics& st) const; diff --git a/src/opt/opt_context.cpp b/src/opt/opt_context.cpp index 265e3b7c8..1b33c2cf1 100644 --- a/src/opt/opt_context.cpp +++ b/src/opt/opt_context.cpp @@ -168,9 +168,7 @@ namespace opt { } void context::get_labels(svector & r) { - if (m_solver) { - m_solver->get_labels(r); - } + r.append(m_labels); } void context::set_hard_constraints(ptr_vector& fmls) { @@ -234,6 +232,7 @@ namespace opt { TRACE("opt", tout << "initial search result: " << is_sat << "\n";); if (is_sat != l_false) { s.get_model(m_model); + s.get_labels(m_labels); } if (is_sat != l_true) { return is_sat; @@ -282,11 +281,6 @@ namespace opt { } } - void context::set_model(model_ref& mdl) { - m_model = mdl; - fix_model(mdl); - } - void context::get_model(model_ref& mdl) { mdl = m_model; fix_model(mdl); @@ -295,7 +289,7 @@ namespace opt { lbool context::execute_min_max(unsigned index, bool committed, bool scoped, bool is_max) { if (scoped) get_solver().push(); lbool result = m_optsmt.lex(index, is_max); - if (result == l_true) m_optsmt.get_model(m_model); + if (result == l_true) m_optsmt.get_model(m_model, m_labels); if (scoped) get_solver().pop(1); if (result == l_true && committed) m_optsmt.commit_assignment(index); return result; @@ -306,7 +300,7 @@ namespace opt { maxsmt& ms = *m_maxsmts.find(id); if (scoped) get_solver().push(); lbool result = ms(); - if (result != l_false && (ms.get_model(tmp), tmp.get())) ms.get_model(m_model); + if (result != l_false && (ms.get_model(tmp, m_labels), tmp.get())) ms.get_model(m_model, m_labels); if (scoped) get_solver().pop(1); if (result == l_true && committed) ms.commit_assignment(); return result; @@ -459,7 +453,7 @@ namespace opt { } void context::yield() { - m_pareto->get_model(m_model); + m_pareto->get_model(m_model, m_labels); update_bound(true); update_bound(false); } diff --git a/src/opt/opt_context.h b/src/opt/opt_context.h index b364fe63e..dc180e8e6 100644 --- a/src/opt/opt_context.h +++ b/src/opt/opt_context.h @@ -162,6 +162,7 @@ namespace opt { bool m_pp_neat; symbol m_maxsat_engine; symbol m_logic; + svector m_labels; public: context(ast_manager& m); virtual ~context(); @@ -180,7 +181,6 @@ namespace opt { virtual lbool optimize(); virtual bool print_model() const; virtual void get_model(model_ref& _m); - virtual void set_model(model_ref& _m); virtual void fix_model(model_ref& _m); virtual void collect_statistics(statistics& stats) const; virtual proof* get_proof() { return 0; } diff --git a/src/opt/opt_pareto.cpp b/src/opt/opt_pareto.cpp index 5e1d4b269..dea744a72 100644 --- a/src/opt/opt_pareto.cpp +++ b/src/opt/opt_pareto.cpp @@ -38,6 +38,7 @@ namespace opt { return l_undef; } m_solver->get_model(m_model); + m_solver->get_labels(m_labels); IF_VERBOSE(1, model_ref mdl(m_model); cb.fix_model(mdl); @@ -96,6 +97,7 @@ namespace opt { } if (is_sat == l_true) { m_solver->get_model(m_model); + m_solver->get_labels(m_labels); mk_not_dominated_by(); } return is_sat; diff --git a/src/opt/opt_pareto.h b/src/opt/opt_pareto.h index 88a3ce9c1..122f29c55 100644 --- a/src/opt/opt_pareto.h +++ b/src/opt/opt_pareto.h @@ -31,7 +31,6 @@ namespace opt { virtual expr_ref mk_gt(unsigned i, model_ref& model) = 0; virtual expr_ref mk_ge(unsigned i, model_ref& model) = 0; virtual expr_ref mk_le(unsigned i, model_ref& model) = 0; - virtual void set_model(model_ref& m) = 0; virtual void fix_model(model_ref& m) = 0; }; class pareto_base { @@ -42,6 +41,7 @@ namespace opt { ref m_solver; params_ref m_params; model_ref m_model; + svector m_labels; public: pareto_base( ast_manager & m, @@ -77,8 +77,9 @@ namespace opt { } virtual lbool operator()() = 0; - virtual void get_model(model_ref& mdl) { + virtual void get_model(model_ref& mdl, svector& labels) { mdl = m_model; + labels = m_labels; } protected: diff --git a/src/opt/optsmt.cpp b/src/opt/optsmt.cpp index 02effa337..c9ee61ebc 100644 --- a/src/opt/optsmt.cpp +++ b/src/opt/optsmt.cpp @@ -51,6 +51,7 @@ namespace opt { if (src[i] >= dst[i]) { dst[i] = src[i]; m_models.set(i, m_s->get_model(i)); + m_s->get_labels(m_labels); m_lower_fmls[i] = fmls[i].get(); if (dst[i].is_pos() && !dst[i].is_finite()) { // review: likely done already. m_lower_fmls[i] = m.mk_false(); @@ -156,7 +157,8 @@ namespace opt { if (is_sat == l_true) { disj.reset(); m_s->maximize_objectives(disj); - m_s->get_model(m_model); + m_s->get_model(m_model); + m_s->get_labels(m_labels); for (unsigned i = 0; i < ors.size(); ++i) { expr_ref tmp(m); m_model->eval(ors[i].get(), tmp); @@ -203,6 +205,7 @@ namespace opt { expr_ref optsmt::update_lower() { expr_ref_vector disj(m); m_s->get_model(m_model); + m_s->get_labels(m_labels); m_s->maximize_objectives(disj); set_max(m_lower, m_s->get_objective_values(), disj); TRACE("opt", @@ -331,6 +334,7 @@ namespace opt { m_s->maximize_objective(obj_index, block); m_s->get_model(m_model); + m_s->get_labels(m_labels); inf_eps obj = m_s->saved_objective_value(obj_index); if (obj > m_lower[obj_index]) { m_lower[obj_index] = obj; @@ -405,8 +409,9 @@ namespace opt { return m_upper[i]; } - void optsmt::get_model(model_ref& mdl) { + void optsmt::get_model(model_ref& mdl, svector & labels) { mdl = m_model.get(); + labels = m_labels; } // force lower_bound(i) <= objective_value(i) diff --git a/src/opt/optsmt.h b/src/opt/optsmt.h index 10ea9f11c..f4efa25f9 100644 --- a/src/opt/optsmt.h +++ b/src/opt/optsmt.h @@ -38,6 +38,7 @@ namespace opt { svector m_vars; symbol m_optsmt_engine; model_ref m_model; + svector m_labels; sref_vector m_models; public: optsmt(ast_manager& m): @@ -60,7 +61,7 @@ namespace opt { inf_eps get_lower(unsigned index) const; inf_eps get_upper(unsigned index) const; bool objective_is_model_valid(unsigned index) const; - void get_model(model_ref& mdl); + void get_model(model_ref& mdl, svector& labels); model* get_model(unsigned index) const { return m_models[index]; } void update_lower(unsigned idx, inf_eps const& r); diff --git a/src/sat/sat_solver/inc_sat_solver.cpp b/src/sat/sat_solver/inc_sat_solver.cpp index b0d771ba8..43b2c2e3f 100644 --- a/src/sat/sat_solver/inc_sat_solver.cpp +++ b/src/sat/sat_solver/inc_sat_solver.cpp @@ -250,7 +250,6 @@ public: return "no reason given"; } virtual void get_labels(svector & r) { - UNREACHABLE(); } virtual unsigned get_num_assertions() const { return m_fmls.size();