3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-08 10:25:18 +00:00

extract labels for optimal model. Fix to #325

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
This commit is contained in:
Nikolaj Bjorner 2015-11-18 14:53:08 -08:00
parent 9cba63c31f
commit c58e640563
9 changed files with 29 additions and 24 deletions

View file

@ -191,7 +191,7 @@ namespace opt {
m_msolver->set_adjust_value(m_adjust_value);
is_sat = (*m_msolver)();
if (is_sat != l_false) {
m_msolver->get_model(m_model);
m_msolver->get_model(m_model, m_labels);
}
}
@ -247,8 +247,9 @@ namespace opt {
m_upper = r;
}
void maxsmt::get_model(model_ref& mdl) {
void maxsmt::get_model(model_ref& mdl, svector<symbol>& labels) {
mdl = m_model.get();
labels = m_labels;
}
void maxsmt::commit_assignment() {

View file

@ -46,7 +46,7 @@ namespace opt {
virtual bool get_assignment(unsigned index) const = 0;
virtual void set_cancel(bool f) = 0;
virtual void collect_statistics(statistics& st) const = 0;
virtual void get_model(model_ref& mdl) = 0;
virtual void get_model(model_ref& mdl, svector<symbol>& labels) = 0;
virtual void updt_params(params_ref& p) = 0;
void set_adjust_value(adjust_value& adj) { m_adjust_value = adj; }
@ -67,6 +67,7 @@ namespace opt {
rational m_lower;
rational m_upper;
model_ref m_model;
svector<symbol> m_labels;
svector<bool> m_assignment; // truth assignment to soft constraints
params_ref m_params; // config
@ -79,9 +80,9 @@ namespace opt {
virtual bool get_assignment(unsigned index) const { return m_assignment[index]; }
virtual void set_cancel(bool f) { m_cancel = f; if (f) s().cancel(); else s().reset_cancel(); }
virtual void collect_statistics(statistics& st) const { }
virtual void get_model(model_ref& mdl) { mdl = m_model.get(); }
virtual void get_model(model_ref& mdl, svector<symbol>& labels) { mdl = m_model.get(); labels = m_labels;}
virtual void commit_assignment();
void set_model() { s().get_model(m_model); }
void set_model() { s().get_model(m_model); s().get_labels(m_labels); }
virtual void updt_params(params_ref& p);
solver& s();
void init();
@ -122,6 +123,7 @@ namespace opt {
rational m_upper;
adjust_value m_adjust_value;
model_ref m_model;
svector<symbol> m_labels;
params_ref m_params;
public:
maxsmt(maxsat_context& c);
@ -139,7 +141,7 @@ namespace opt {
rational get_upper() const;
void update_lower(rational const& r);
void update_upper(rational const& r);
void get_model(model_ref& mdl);
void get_model(model_ref& mdl, svector<symbol>& labels);
bool get_assignment(unsigned index) const;
void display_answer(std::ostream& out) const;
void collect_statistics(statistics& st) const;

View file

@ -168,9 +168,7 @@ namespace opt {
}
void context::get_labels(svector<symbol> & r) {
if (m_solver) {
m_solver->get_labels(r);
}
r.append(m_labels);
}
void context::set_hard_constraints(ptr_vector<expr>& fmls) {
@ -234,6 +232,7 @@ namespace opt {
TRACE("opt", tout << "initial search result: " << is_sat << "\n";);
if (is_sat != l_false) {
s.get_model(m_model);
s.get_labels(m_labels);
}
if (is_sat != l_true) {
return is_sat;
@ -282,11 +281,6 @@ namespace opt {
}
}
void context::set_model(model_ref& mdl) {
m_model = mdl;
fix_model(mdl);
}
void context::get_model(model_ref& mdl) {
mdl = m_model;
fix_model(mdl);
@ -295,7 +289,7 @@ namespace opt {
lbool context::execute_min_max(unsigned index, bool committed, bool scoped, bool is_max) {
if (scoped) get_solver().push();
lbool result = m_optsmt.lex(index, is_max);
if (result == l_true) m_optsmt.get_model(m_model);
if (result == l_true) m_optsmt.get_model(m_model, m_labels);
if (scoped) get_solver().pop(1);
if (result == l_true && committed) m_optsmt.commit_assignment(index);
return result;
@ -306,7 +300,7 @@ namespace opt {
maxsmt& ms = *m_maxsmts.find(id);
if (scoped) get_solver().push();
lbool result = ms();
if (result != l_false && (ms.get_model(tmp), tmp.get())) ms.get_model(m_model);
if (result != l_false && (ms.get_model(tmp, m_labels), tmp.get())) ms.get_model(m_model, m_labels);
if (scoped) get_solver().pop(1);
if (result == l_true && committed) ms.commit_assignment();
return result;
@ -459,7 +453,7 @@ namespace opt {
}
void context::yield() {
m_pareto->get_model(m_model);
m_pareto->get_model(m_model, m_labels);
update_bound(true);
update_bound(false);
}

View file

@ -162,6 +162,7 @@ namespace opt {
bool m_pp_neat;
symbol m_maxsat_engine;
symbol m_logic;
svector<symbol> m_labels;
public:
context(ast_manager& m);
virtual ~context();
@ -180,7 +181,6 @@ namespace opt {
virtual lbool optimize();
virtual bool print_model() const;
virtual void get_model(model_ref& _m);
virtual void set_model(model_ref& _m);
virtual void fix_model(model_ref& _m);
virtual void collect_statistics(statistics& stats) const;
virtual proof* get_proof() { return 0; }

View file

@ -38,6 +38,7 @@ namespace opt {
return l_undef;
}
m_solver->get_model(m_model);
m_solver->get_labels(m_labels);
IF_VERBOSE(1,
model_ref mdl(m_model);
cb.fix_model(mdl);
@ -96,6 +97,7 @@ namespace opt {
}
if (is_sat == l_true) {
m_solver->get_model(m_model);
m_solver->get_labels(m_labels);
mk_not_dominated_by();
}
return is_sat;

View file

@ -31,7 +31,6 @@ namespace opt {
virtual expr_ref mk_gt(unsigned i, model_ref& model) = 0;
virtual expr_ref mk_ge(unsigned i, model_ref& model) = 0;
virtual expr_ref mk_le(unsigned i, model_ref& model) = 0;
virtual void set_model(model_ref& m) = 0;
virtual void fix_model(model_ref& m) = 0;
};
class pareto_base {
@ -42,6 +41,7 @@ namespace opt {
ref<solver> m_solver;
params_ref m_params;
model_ref m_model;
svector<symbol> m_labels;
public:
pareto_base(
ast_manager & m,
@ -77,8 +77,9 @@ namespace opt {
}
virtual lbool operator()() = 0;
virtual void get_model(model_ref& mdl) {
virtual void get_model(model_ref& mdl, svector<symbol>& labels) {
mdl = m_model;
labels = m_labels;
}
protected:

View file

@ -51,6 +51,7 @@ namespace opt {
if (src[i] >= dst[i]) {
dst[i] = src[i];
m_models.set(i, m_s->get_model(i));
m_s->get_labels(m_labels);
m_lower_fmls[i] = fmls[i].get();
if (dst[i].is_pos() && !dst[i].is_finite()) { // review: likely done already.
m_lower_fmls[i] = m.mk_false();
@ -156,7 +157,8 @@ namespace opt {
if (is_sat == l_true) {
disj.reset();
m_s->maximize_objectives(disj);
m_s->get_model(m_model);
m_s->get_model(m_model);
m_s->get_labels(m_labels);
for (unsigned i = 0; i < ors.size(); ++i) {
expr_ref tmp(m);
m_model->eval(ors[i].get(), tmp);
@ -203,6 +205,7 @@ namespace opt {
expr_ref optsmt::update_lower() {
expr_ref_vector disj(m);
m_s->get_model(m_model);
m_s->get_labels(m_labels);
m_s->maximize_objectives(disj);
set_max(m_lower, m_s->get_objective_values(), disj);
TRACE("opt",
@ -331,6 +334,7 @@ namespace opt {
m_s->maximize_objective(obj_index, block);
m_s->get_model(m_model);
m_s->get_labels(m_labels);
inf_eps obj = m_s->saved_objective_value(obj_index);
if (obj > m_lower[obj_index]) {
m_lower[obj_index] = obj;
@ -405,8 +409,9 @@ namespace opt {
return m_upper[i];
}
void optsmt::get_model(model_ref& mdl) {
void optsmt::get_model(model_ref& mdl, svector<symbol> & labels) {
mdl = m_model.get();
labels = m_labels;
}
// force lower_bound(i) <= objective_value(i)

View file

@ -38,6 +38,7 @@ namespace opt {
svector<smt::theory_var> m_vars;
symbol m_optsmt_engine;
model_ref m_model;
svector<symbol> m_labels;
sref_vector<model> m_models;
public:
optsmt(ast_manager& m):
@ -60,7 +61,7 @@ namespace opt {
inf_eps get_lower(unsigned index) const;
inf_eps get_upper(unsigned index) const;
bool objective_is_model_valid(unsigned index) const;
void get_model(model_ref& mdl);
void get_model(model_ref& mdl, svector<symbol>& labels);
model* get_model(unsigned index) const { return m_models[index]; }
void update_lower(unsigned idx, inf_eps const& r);

View file

@ -250,7 +250,6 @@ public:
return "no reason given";
}
virtual void get_labels(svector<symbol> & r) {
UNREACHABLE();
}
virtual unsigned get_num_assertions() const {
return m_fmls.size();