mirror of
https://github.com/Z3Prover/z3
synced 2025-04-08 10:25:18 +00:00
extract labels for optimal model. Fix to #325
Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
This commit is contained in:
parent
9cba63c31f
commit
c58e640563
|
@ -191,7 +191,7 @@ namespace opt {
|
|||
m_msolver->set_adjust_value(m_adjust_value);
|
||||
is_sat = (*m_msolver)();
|
||||
if (is_sat != l_false) {
|
||||
m_msolver->get_model(m_model);
|
||||
m_msolver->get_model(m_model, m_labels);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -247,8 +247,9 @@ namespace opt {
|
|||
m_upper = r;
|
||||
}
|
||||
|
||||
void maxsmt::get_model(model_ref& mdl) {
|
||||
void maxsmt::get_model(model_ref& mdl, svector<symbol>& labels) {
|
||||
mdl = m_model.get();
|
||||
labels = m_labels;
|
||||
}
|
||||
|
||||
void maxsmt::commit_assignment() {
|
||||
|
|
|
@ -46,7 +46,7 @@ namespace opt {
|
|||
virtual bool get_assignment(unsigned index) const = 0;
|
||||
virtual void set_cancel(bool f) = 0;
|
||||
virtual void collect_statistics(statistics& st) const = 0;
|
||||
virtual void get_model(model_ref& mdl) = 0;
|
||||
virtual void get_model(model_ref& mdl, svector<symbol>& labels) = 0;
|
||||
virtual void updt_params(params_ref& p) = 0;
|
||||
void set_adjust_value(adjust_value& adj) { m_adjust_value = adj; }
|
||||
|
||||
|
@ -67,6 +67,7 @@ namespace opt {
|
|||
rational m_lower;
|
||||
rational m_upper;
|
||||
model_ref m_model;
|
||||
svector<symbol> m_labels;
|
||||
svector<bool> m_assignment; // truth assignment to soft constraints
|
||||
params_ref m_params; // config
|
||||
|
||||
|
@ -79,9 +80,9 @@ namespace opt {
|
|||
virtual bool get_assignment(unsigned index) const { return m_assignment[index]; }
|
||||
virtual void set_cancel(bool f) { m_cancel = f; if (f) s().cancel(); else s().reset_cancel(); }
|
||||
virtual void collect_statistics(statistics& st) const { }
|
||||
virtual void get_model(model_ref& mdl) { mdl = m_model.get(); }
|
||||
virtual void get_model(model_ref& mdl, svector<symbol>& labels) { mdl = m_model.get(); labels = m_labels;}
|
||||
virtual void commit_assignment();
|
||||
void set_model() { s().get_model(m_model); }
|
||||
void set_model() { s().get_model(m_model); s().get_labels(m_labels); }
|
||||
virtual void updt_params(params_ref& p);
|
||||
solver& s();
|
||||
void init();
|
||||
|
@ -122,6 +123,7 @@ namespace opt {
|
|||
rational m_upper;
|
||||
adjust_value m_adjust_value;
|
||||
model_ref m_model;
|
||||
svector<symbol> m_labels;
|
||||
params_ref m_params;
|
||||
public:
|
||||
maxsmt(maxsat_context& c);
|
||||
|
@ -139,7 +141,7 @@ namespace opt {
|
|||
rational get_upper() const;
|
||||
void update_lower(rational const& r);
|
||||
void update_upper(rational const& r);
|
||||
void get_model(model_ref& mdl);
|
||||
void get_model(model_ref& mdl, svector<symbol>& labels);
|
||||
bool get_assignment(unsigned index) const;
|
||||
void display_answer(std::ostream& out) const;
|
||||
void collect_statistics(statistics& st) const;
|
||||
|
|
|
@ -168,9 +168,7 @@ namespace opt {
|
|||
}
|
||||
|
||||
void context::get_labels(svector<symbol> & r) {
|
||||
if (m_solver) {
|
||||
m_solver->get_labels(r);
|
||||
}
|
||||
r.append(m_labels);
|
||||
}
|
||||
|
||||
void context::set_hard_constraints(ptr_vector<expr>& fmls) {
|
||||
|
@ -234,6 +232,7 @@ namespace opt {
|
|||
TRACE("opt", tout << "initial search result: " << is_sat << "\n";);
|
||||
if (is_sat != l_false) {
|
||||
s.get_model(m_model);
|
||||
s.get_labels(m_labels);
|
||||
}
|
||||
if (is_sat != l_true) {
|
||||
return is_sat;
|
||||
|
@ -282,11 +281,6 @@ namespace opt {
|
|||
}
|
||||
}
|
||||
|
||||
void context::set_model(model_ref& mdl) {
|
||||
m_model = mdl;
|
||||
fix_model(mdl);
|
||||
}
|
||||
|
||||
void context::get_model(model_ref& mdl) {
|
||||
mdl = m_model;
|
||||
fix_model(mdl);
|
||||
|
@ -295,7 +289,7 @@ namespace opt {
|
|||
lbool context::execute_min_max(unsigned index, bool committed, bool scoped, bool is_max) {
|
||||
if (scoped) get_solver().push();
|
||||
lbool result = m_optsmt.lex(index, is_max);
|
||||
if (result == l_true) m_optsmt.get_model(m_model);
|
||||
if (result == l_true) m_optsmt.get_model(m_model, m_labels);
|
||||
if (scoped) get_solver().pop(1);
|
||||
if (result == l_true && committed) m_optsmt.commit_assignment(index);
|
||||
return result;
|
||||
|
@ -306,7 +300,7 @@ namespace opt {
|
|||
maxsmt& ms = *m_maxsmts.find(id);
|
||||
if (scoped) get_solver().push();
|
||||
lbool result = ms();
|
||||
if (result != l_false && (ms.get_model(tmp), tmp.get())) ms.get_model(m_model);
|
||||
if (result != l_false && (ms.get_model(tmp, m_labels), tmp.get())) ms.get_model(m_model, m_labels);
|
||||
if (scoped) get_solver().pop(1);
|
||||
if (result == l_true && committed) ms.commit_assignment();
|
||||
return result;
|
||||
|
@ -459,7 +453,7 @@ namespace opt {
|
|||
}
|
||||
|
||||
void context::yield() {
|
||||
m_pareto->get_model(m_model);
|
||||
m_pareto->get_model(m_model, m_labels);
|
||||
update_bound(true);
|
||||
update_bound(false);
|
||||
}
|
||||
|
|
|
@ -162,6 +162,7 @@ namespace opt {
|
|||
bool m_pp_neat;
|
||||
symbol m_maxsat_engine;
|
||||
symbol m_logic;
|
||||
svector<symbol> m_labels;
|
||||
public:
|
||||
context(ast_manager& m);
|
||||
virtual ~context();
|
||||
|
@ -180,7 +181,6 @@ namespace opt {
|
|||
virtual lbool optimize();
|
||||
virtual bool print_model() const;
|
||||
virtual void get_model(model_ref& _m);
|
||||
virtual void set_model(model_ref& _m);
|
||||
virtual void fix_model(model_ref& _m);
|
||||
virtual void collect_statistics(statistics& stats) const;
|
||||
virtual proof* get_proof() { return 0; }
|
||||
|
|
|
@ -38,6 +38,7 @@ namespace opt {
|
|||
return l_undef;
|
||||
}
|
||||
m_solver->get_model(m_model);
|
||||
m_solver->get_labels(m_labels);
|
||||
IF_VERBOSE(1,
|
||||
model_ref mdl(m_model);
|
||||
cb.fix_model(mdl);
|
||||
|
@ -96,6 +97,7 @@ namespace opt {
|
|||
}
|
||||
if (is_sat == l_true) {
|
||||
m_solver->get_model(m_model);
|
||||
m_solver->get_labels(m_labels);
|
||||
mk_not_dominated_by();
|
||||
}
|
||||
return is_sat;
|
||||
|
|
|
@ -31,7 +31,6 @@ namespace opt {
|
|||
virtual expr_ref mk_gt(unsigned i, model_ref& model) = 0;
|
||||
virtual expr_ref mk_ge(unsigned i, model_ref& model) = 0;
|
||||
virtual expr_ref mk_le(unsigned i, model_ref& model) = 0;
|
||||
virtual void set_model(model_ref& m) = 0;
|
||||
virtual void fix_model(model_ref& m) = 0;
|
||||
};
|
||||
class pareto_base {
|
||||
|
@ -42,6 +41,7 @@ namespace opt {
|
|||
ref<solver> m_solver;
|
||||
params_ref m_params;
|
||||
model_ref m_model;
|
||||
svector<symbol> m_labels;
|
||||
public:
|
||||
pareto_base(
|
||||
ast_manager & m,
|
||||
|
@ -77,8 +77,9 @@ namespace opt {
|
|||
}
|
||||
virtual lbool operator()() = 0;
|
||||
|
||||
virtual void get_model(model_ref& mdl) {
|
||||
virtual void get_model(model_ref& mdl, svector<symbol>& labels) {
|
||||
mdl = m_model;
|
||||
labels = m_labels;
|
||||
}
|
||||
|
||||
protected:
|
||||
|
|
|
@ -51,6 +51,7 @@ namespace opt {
|
|||
if (src[i] >= dst[i]) {
|
||||
dst[i] = src[i];
|
||||
m_models.set(i, m_s->get_model(i));
|
||||
m_s->get_labels(m_labels);
|
||||
m_lower_fmls[i] = fmls[i].get();
|
||||
if (dst[i].is_pos() && !dst[i].is_finite()) { // review: likely done already.
|
||||
m_lower_fmls[i] = m.mk_false();
|
||||
|
@ -156,7 +157,8 @@ namespace opt {
|
|||
if (is_sat == l_true) {
|
||||
disj.reset();
|
||||
m_s->maximize_objectives(disj);
|
||||
m_s->get_model(m_model);
|
||||
m_s->get_model(m_model);
|
||||
m_s->get_labels(m_labels);
|
||||
for (unsigned i = 0; i < ors.size(); ++i) {
|
||||
expr_ref tmp(m);
|
||||
m_model->eval(ors[i].get(), tmp);
|
||||
|
@ -203,6 +205,7 @@ namespace opt {
|
|||
expr_ref optsmt::update_lower() {
|
||||
expr_ref_vector disj(m);
|
||||
m_s->get_model(m_model);
|
||||
m_s->get_labels(m_labels);
|
||||
m_s->maximize_objectives(disj);
|
||||
set_max(m_lower, m_s->get_objective_values(), disj);
|
||||
TRACE("opt",
|
||||
|
@ -331,6 +334,7 @@ namespace opt {
|
|||
|
||||
m_s->maximize_objective(obj_index, block);
|
||||
m_s->get_model(m_model);
|
||||
m_s->get_labels(m_labels);
|
||||
inf_eps obj = m_s->saved_objective_value(obj_index);
|
||||
if (obj > m_lower[obj_index]) {
|
||||
m_lower[obj_index] = obj;
|
||||
|
@ -405,8 +409,9 @@ namespace opt {
|
|||
return m_upper[i];
|
||||
}
|
||||
|
||||
void optsmt::get_model(model_ref& mdl) {
|
||||
void optsmt::get_model(model_ref& mdl, svector<symbol> & labels) {
|
||||
mdl = m_model.get();
|
||||
labels = m_labels;
|
||||
}
|
||||
|
||||
// force lower_bound(i) <= objective_value(i)
|
||||
|
|
|
@ -38,6 +38,7 @@ namespace opt {
|
|||
svector<smt::theory_var> m_vars;
|
||||
symbol m_optsmt_engine;
|
||||
model_ref m_model;
|
||||
svector<symbol> m_labels;
|
||||
sref_vector<model> m_models;
|
||||
public:
|
||||
optsmt(ast_manager& m):
|
||||
|
@ -60,7 +61,7 @@ namespace opt {
|
|||
inf_eps get_lower(unsigned index) const;
|
||||
inf_eps get_upper(unsigned index) const;
|
||||
bool objective_is_model_valid(unsigned index) const;
|
||||
void get_model(model_ref& mdl);
|
||||
void get_model(model_ref& mdl, svector<symbol>& labels);
|
||||
model* get_model(unsigned index) const { return m_models[index]; }
|
||||
|
||||
void update_lower(unsigned idx, inf_eps const& r);
|
||||
|
|
|
@ -250,7 +250,6 @@ public:
|
|||
return "no reason given";
|
||||
}
|
||||
virtual void get_labels(svector<symbol> & r) {
|
||||
UNREACHABLE();
|
||||
}
|
||||
virtual unsigned get_num_assertions() const {
|
||||
return m_fmls.size();
|
||||
|
|
Loading…
Reference in a new issue