mirror of
https://github.com/Z3Prover/z3
synced 2025-07-19 19:02:02 +00:00
extract labels for optimal model. Fix to #325
Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
This commit is contained in:
parent
9cba63c31f
commit
c58e640563
9 changed files with 29 additions and 24 deletions
|
@ -191,7 +191,7 @@ namespace opt {
|
||||||
m_msolver->set_adjust_value(m_adjust_value);
|
m_msolver->set_adjust_value(m_adjust_value);
|
||||||
is_sat = (*m_msolver)();
|
is_sat = (*m_msolver)();
|
||||||
if (is_sat != l_false) {
|
if (is_sat != l_false) {
|
||||||
m_msolver->get_model(m_model);
|
m_msolver->get_model(m_model, m_labels);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -247,8 +247,9 @@ namespace opt {
|
||||||
m_upper = r;
|
m_upper = r;
|
||||||
}
|
}
|
||||||
|
|
||||||
void maxsmt::get_model(model_ref& mdl) {
|
void maxsmt::get_model(model_ref& mdl, svector<symbol>& labels) {
|
||||||
mdl = m_model.get();
|
mdl = m_model.get();
|
||||||
|
labels = m_labels;
|
||||||
}
|
}
|
||||||
|
|
||||||
void maxsmt::commit_assignment() {
|
void maxsmt::commit_assignment() {
|
||||||
|
|
|
@ -46,7 +46,7 @@ namespace opt {
|
||||||
virtual bool get_assignment(unsigned index) const = 0;
|
virtual bool get_assignment(unsigned index) const = 0;
|
||||||
virtual void set_cancel(bool f) = 0;
|
virtual void set_cancel(bool f) = 0;
|
||||||
virtual void collect_statistics(statistics& st) const = 0;
|
virtual void collect_statistics(statistics& st) const = 0;
|
||||||
virtual void get_model(model_ref& mdl) = 0;
|
virtual void get_model(model_ref& mdl, svector<symbol>& labels) = 0;
|
||||||
virtual void updt_params(params_ref& p) = 0;
|
virtual void updt_params(params_ref& p) = 0;
|
||||||
void set_adjust_value(adjust_value& adj) { m_adjust_value = adj; }
|
void set_adjust_value(adjust_value& adj) { m_adjust_value = adj; }
|
||||||
|
|
||||||
|
@ -67,6 +67,7 @@ namespace opt {
|
||||||
rational m_lower;
|
rational m_lower;
|
||||||
rational m_upper;
|
rational m_upper;
|
||||||
model_ref m_model;
|
model_ref m_model;
|
||||||
|
svector<symbol> m_labels;
|
||||||
svector<bool> m_assignment; // truth assignment to soft constraints
|
svector<bool> m_assignment; // truth assignment to soft constraints
|
||||||
params_ref m_params; // config
|
params_ref m_params; // config
|
||||||
|
|
||||||
|
@ -79,9 +80,9 @@ namespace opt {
|
||||||
virtual bool get_assignment(unsigned index) const { return m_assignment[index]; }
|
virtual bool get_assignment(unsigned index) const { return m_assignment[index]; }
|
||||||
virtual void set_cancel(bool f) { m_cancel = f; if (f) s().cancel(); else s().reset_cancel(); }
|
virtual void set_cancel(bool f) { m_cancel = f; if (f) s().cancel(); else s().reset_cancel(); }
|
||||||
virtual void collect_statistics(statistics& st) const { }
|
virtual void collect_statistics(statistics& st) const { }
|
||||||
virtual void get_model(model_ref& mdl) { mdl = m_model.get(); }
|
virtual void get_model(model_ref& mdl, svector<symbol>& labels) { mdl = m_model.get(); labels = m_labels;}
|
||||||
virtual void commit_assignment();
|
virtual void commit_assignment();
|
||||||
void set_model() { s().get_model(m_model); }
|
void set_model() { s().get_model(m_model); s().get_labels(m_labels); }
|
||||||
virtual void updt_params(params_ref& p);
|
virtual void updt_params(params_ref& p);
|
||||||
solver& s();
|
solver& s();
|
||||||
void init();
|
void init();
|
||||||
|
@ -122,6 +123,7 @@ namespace opt {
|
||||||
rational m_upper;
|
rational m_upper;
|
||||||
adjust_value m_adjust_value;
|
adjust_value m_adjust_value;
|
||||||
model_ref m_model;
|
model_ref m_model;
|
||||||
|
svector<symbol> m_labels;
|
||||||
params_ref m_params;
|
params_ref m_params;
|
||||||
public:
|
public:
|
||||||
maxsmt(maxsat_context& c);
|
maxsmt(maxsat_context& c);
|
||||||
|
@ -139,7 +141,7 @@ namespace opt {
|
||||||
rational get_upper() const;
|
rational get_upper() const;
|
||||||
void update_lower(rational const& r);
|
void update_lower(rational const& r);
|
||||||
void update_upper(rational const& r);
|
void update_upper(rational const& r);
|
||||||
void get_model(model_ref& mdl);
|
void get_model(model_ref& mdl, svector<symbol>& labels);
|
||||||
bool get_assignment(unsigned index) const;
|
bool get_assignment(unsigned index) const;
|
||||||
void display_answer(std::ostream& out) const;
|
void display_answer(std::ostream& out) const;
|
||||||
void collect_statistics(statistics& st) const;
|
void collect_statistics(statistics& st) const;
|
||||||
|
|
|
@ -168,9 +168,7 @@ namespace opt {
|
||||||
}
|
}
|
||||||
|
|
||||||
void context::get_labels(svector<symbol> & r) {
|
void context::get_labels(svector<symbol> & r) {
|
||||||
if (m_solver) {
|
r.append(m_labels);
|
||||||
m_solver->get_labels(r);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void context::set_hard_constraints(ptr_vector<expr>& fmls) {
|
void context::set_hard_constraints(ptr_vector<expr>& fmls) {
|
||||||
|
@ -234,6 +232,7 @@ namespace opt {
|
||||||
TRACE("opt", tout << "initial search result: " << is_sat << "\n";);
|
TRACE("opt", tout << "initial search result: " << is_sat << "\n";);
|
||||||
if (is_sat != l_false) {
|
if (is_sat != l_false) {
|
||||||
s.get_model(m_model);
|
s.get_model(m_model);
|
||||||
|
s.get_labels(m_labels);
|
||||||
}
|
}
|
||||||
if (is_sat != l_true) {
|
if (is_sat != l_true) {
|
||||||
return is_sat;
|
return is_sat;
|
||||||
|
@ -282,11 +281,6 @@ namespace opt {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void context::set_model(model_ref& mdl) {
|
|
||||||
m_model = mdl;
|
|
||||||
fix_model(mdl);
|
|
||||||
}
|
|
||||||
|
|
||||||
void context::get_model(model_ref& mdl) {
|
void context::get_model(model_ref& mdl) {
|
||||||
mdl = m_model;
|
mdl = m_model;
|
||||||
fix_model(mdl);
|
fix_model(mdl);
|
||||||
|
@ -295,7 +289,7 @@ namespace opt {
|
||||||
lbool context::execute_min_max(unsigned index, bool committed, bool scoped, bool is_max) {
|
lbool context::execute_min_max(unsigned index, bool committed, bool scoped, bool is_max) {
|
||||||
if (scoped) get_solver().push();
|
if (scoped) get_solver().push();
|
||||||
lbool result = m_optsmt.lex(index, is_max);
|
lbool result = m_optsmt.lex(index, is_max);
|
||||||
if (result == l_true) m_optsmt.get_model(m_model);
|
if (result == l_true) m_optsmt.get_model(m_model, m_labels);
|
||||||
if (scoped) get_solver().pop(1);
|
if (scoped) get_solver().pop(1);
|
||||||
if (result == l_true && committed) m_optsmt.commit_assignment(index);
|
if (result == l_true && committed) m_optsmt.commit_assignment(index);
|
||||||
return result;
|
return result;
|
||||||
|
@ -306,7 +300,7 @@ namespace opt {
|
||||||
maxsmt& ms = *m_maxsmts.find(id);
|
maxsmt& ms = *m_maxsmts.find(id);
|
||||||
if (scoped) get_solver().push();
|
if (scoped) get_solver().push();
|
||||||
lbool result = ms();
|
lbool result = ms();
|
||||||
if (result != l_false && (ms.get_model(tmp), tmp.get())) ms.get_model(m_model);
|
if (result != l_false && (ms.get_model(tmp, m_labels), tmp.get())) ms.get_model(m_model, m_labels);
|
||||||
if (scoped) get_solver().pop(1);
|
if (scoped) get_solver().pop(1);
|
||||||
if (result == l_true && committed) ms.commit_assignment();
|
if (result == l_true && committed) ms.commit_assignment();
|
||||||
return result;
|
return result;
|
||||||
|
@ -459,7 +453,7 @@ namespace opt {
|
||||||
}
|
}
|
||||||
|
|
||||||
void context::yield() {
|
void context::yield() {
|
||||||
m_pareto->get_model(m_model);
|
m_pareto->get_model(m_model, m_labels);
|
||||||
update_bound(true);
|
update_bound(true);
|
||||||
update_bound(false);
|
update_bound(false);
|
||||||
}
|
}
|
||||||
|
|
|
@ -162,6 +162,7 @@ namespace opt {
|
||||||
bool m_pp_neat;
|
bool m_pp_neat;
|
||||||
symbol m_maxsat_engine;
|
symbol m_maxsat_engine;
|
||||||
symbol m_logic;
|
symbol m_logic;
|
||||||
|
svector<symbol> m_labels;
|
||||||
public:
|
public:
|
||||||
context(ast_manager& m);
|
context(ast_manager& m);
|
||||||
virtual ~context();
|
virtual ~context();
|
||||||
|
@ -180,7 +181,6 @@ namespace opt {
|
||||||
virtual lbool optimize();
|
virtual lbool optimize();
|
||||||
virtual bool print_model() const;
|
virtual bool print_model() const;
|
||||||
virtual void get_model(model_ref& _m);
|
virtual void get_model(model_ref& _m);
|
||||||
virtual void set_model(model_ref& _m);
|
|
||||||
virtual void fix_model(model_ref& _m);
|
virtual void fix_model(model_ref& _m);
|
||||||
virtual void collect_statistics(statistics& stats) const;
|
virtual void collect_statistics(statistics& stats) const;
|
||||||
virtual proof* get_proof() { return 0; }
|
virtual proof* get_proof() { return 0; }
|
||||||
|
|
|
@ -38,6 +38,7 @@ namespace opt {
|
||||||
return l_undef;
|
return l_undef;
|
||||||
}
|
}
|
||||||
m_solver->get_model(m_model);
|
m_solver->get_model(m_model);
|
||||||
|
m_solver->get_labels(m_labels);
|
||||||
IF_VERBOSE(1,
|
IF_VERBOSE(1,
|
||||||
model_ref mdl(m_model);
|
model_ref mdl(m_model);
|
||||||
cb.fix_model(mdl);
|
cb.fix_model(mdl);
|
||||||
|
@ -96,6 +97,7 @@ namespace opt {
|
||||||
}
|
}
|
||||||
if (is_sat == l_true) {
|
if (is_sat == l_true) {
|
||||||
m_solver->get_model(m_model);
|
m_solver->get_model(m_model);
|
||||||
|
m_solver->get_labels(m_labels);
|
||||||
mk_not_dominated_by();
|
mk_not_dominated_by();
|
||||||
}
|
}
|
||||||
return is_sat;
|
return is_sat;
|
||||||
|
|
|
@ -31,7 +31,6 @@ namespace opt {
|
||||||
virtual expr_ref mk_gt(unsigned i, model_ref& model) = 0;
|
virtual expr_ref mk_gt(unsigned i, model_ref& model) = 0;
|
||||||
virtual expr_ref mk_ge(unsigned i, model_ref& model) = 0;
|
virtual expr_ref mk_ge(unsigned i, model_ref& model) = 0;
|
||||||
virtual expr_ref mk_le(unsigned i, model_ref& model) = 0;
|
virtual expr_ref mk_le(unsigned i, model_ref& model) = 0;
|
||||||
virtual void set_model(model_ref& m) = 0;
|
|
||||||
virtual void fix_model(model_ref& m) = 0;
|
virtual void fix_model(model_ref& m) = 0;
|
||||||
};
|
};
|
||||||
class pareto_base {
|
class pareto_base {
|
||||||
|
@ -42,6 +41,7 @@ namespace opt {
|
||||||
ref<solver> m_solver;
|
ref<solver> m_solver;
|
||||||
params_ref m_params;
|
params_ref m_params;
|
||||||
model_ref m_model;
|
model_ref m_model;
|
||||||
|
svector<symbol> m_labels;
|
||||||
public:
|
public:
|
||||||
pareto_base(
|
pareto_base(
|
||||||
ast_manager & m,
|
ast_manager & m,
|
||||||
|
@ -77,8 +77,9 @@ namespace opt {
|
||||||
}
|
}
|
||||||
virtual lbool operator()() = 0;
|
virtual lbool operator()() = 0;
|
||||||
|
|
||||||
virtual void get_model(model_ref& mdl) {
|
virtual void get_model(model_ref& mdl, svector<symbol>& labels) {
|
||||||
mdl = m_model;
|
mdl = m_model;
|
||||||
|
labels = m_labels;
|
||||||
}
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
|
|
|
@ -51,6 +51,7 @@ namespace opt {
|
||||||
if (src[i] >= dst[i]) {
|
if (src[i] >= dst[i]) {
|
||||||
dst[i] = src[i];
|
dst[i] = src[i];
|
||||||
m_models.set(i, m_s->get_model(i));
|
m_models.set(i, m_s->get_model(i));
|
||||||
|
m_s->get_labels(m_labels);
|
||||||
m_lower_fmls[i] = fmls[i].get();
|
m_lower_fmls[i] = fmls[i].get();
|
||||||
if (dst[i].is_pos() && !dst[i].is_finite()) { // review: likely done already.
|
if (dst[i].is_pos() && !dst[i].is_finite()) { // review: likely done already.
|
||||||
m_lower_fmls[i] = m.mk_false();
|
m_lower_fmls[i] = m.mk_false();
|
||||||
|
@ -156,7 +157,8 @@ namespace opt {
|
||||||
if (is_sat == l_true) {
|
if (is_sat == l_true) {
|
||||||
disj.reset();
|
disj.reset();
|
||||||
m_s->maximize_objectives(disj);
|
m_s->maximize_objectives(disj);
|
||||||
m_s->get_model(m_model);
|
m_s->get_model(m_model);
|
||||||
|
m_s->get_labels(m_labels);
|
||||||
for (unsigned i = 0; i < ors.size(); ++i) {
|
for (unsigned i = 0; i < ors.size(); ++i) {
|
||||||
expr_ref tmp(m);
|
expr_ref tmp(m);
|
||||||
m_model->eval(ors[i].get(), tmp);
|
m_model->eval(ors[i].get(), tmp);
|
||||||
|
@ -203,6 +205,7 @@ namespace opt {
|
||||||
expr_ref optsmt::update_lower() {
|
expr_ref optsmt::update_lower() {
|
||||||
expr_ref_vector disj(m);
|
expr_ref_vector disj(m);
|
||||||
m_s->get_model(m_model);
|
m_s->get_model(m_model);
|
||||||
|
m_s->get_labels(m_labels);
|
||||||
m_s->maximize_objectives(disj);
|
m_s->maximize_objectives(disj);
|
||||||
set_max(m_lower, m_s->get_objective_values(), disj);
|
set_max(m_lower, m_s->get_objective_values(), disj);
|
||||||
TRACE("opt",
|
TRACE("opt",
|
||||||
|
@ -331,6 +334,7 @@ namespace opt {
|
||||||
|
|
||||||
m_s->maximize_objective(obj_index, block);
|
m_s->maximize_objective(obj_index, block);
|
||||||
m_s->get_model(m_model);
|
m_s->get_model(m_model);
|
||||||
|
m_s->get_labels(m_labels);
|
||||||
inf_eps obj = m_s->saved_objective_value(obj_index);
|
inf_eps obj = m_s->saved_objective_value(obj_index);
|
||||||
if (obj > m_lower[obj_index]) {
|
if (obj > m_lower[obj_index]) {
|
||||||
m_lower[obj_index] = obj;
|
m_lower[obj_index] = obj;
|
||||||
|
@ -405,8 +409,9 @@ namespace opt {
|
||||||
return m_upper[i];
|
return m_upper[i];
|
||||||
}
|
}
|
||||||
|
|
||||||
void optsmt::get_model(model_ref& mdl) {
|
void optsmt::get_model(model_ref& mdl, svector<symbol> & labels) {
|
||||||
mdl = m_model.get();
|
mdl = m_model.get();
|
||||||
|
labels = m_labels;
|
||||||
}
|
}
|
||||||
|
|
||||||
// force lower_bound(i) <= objective_value(i)
|
// force lower_bound(i) <= objective_value(i)
|
||||||
|
|
|
@ -38,6 +38,7 @@ namespace opt {
|
||||||
svector<smt::theory_var> m_vars;
|
svector<smt::theory_var> m_vars;
|
||||||
symbol m_optsmt_engine;
|
symbol m_optsmt_engine;
|
||||||
model_ref m_model;
|
model_ref m_model;
|
||||||
|
svector<symbol> m_labels;
|
||||||
sref_vector<model> m_models;
|
sref_vector<model> m_models;
|
||||||
public:
|
public:
|
||||||
optsmt(ast_manager& m):
|
optsmt(ast_manager& m):
|
||||||
|
@ -60,7 +61,7 @@ namespace opt {
|
||||||
inf_eps get_lower(unsigned index) const;
|
inf_eps get_lower(unsigned index) const;
|
||||||
inf_eps get_upper(unsigned index) const;
|
inf_eps get_upper(unsigned index) const;
|
||||||
bool objective_is_model_valid(unsigned index) const;
|
bool objective_is_model_valid(unsigned index) const;
|
||||||
void get_model(model_ref& mdl);
|
void get_model(model_ref& mdl, svector<symbol>& labels);
|
||||||
model* get_model(unsigned index) const { return m_models[index]; }
|
model* get_model(unsigned index) const { return m_models[index]; }
|
||||||
|
|
||||||
void update_lower(unsigned idx, inf_eps const& r);
|
void update_lower(unsigned idx, inf_eps const& r);
|
||||||
|
|
|
@ -250,7 +250,6 @@ public:
|
||||||
return "no reason given";
|
return "no reason given";
|
||||||
}
|
}
|
||||||
virtual void get_labels(svector<symbol> & r) {
|
virtual void get_labels(svector<symbol> & r) {
|
||||||
UNREACHABLE();
|
|
||||||
}
|
}
|
||||||
virtual unsigned get_num_assertions() const {
|
virtual unsigned get_num_assertions() const {
|
||||||
return m_fmls.size();
|
return m_fmls.size();
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue