3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-06-21 05:13:39 +00:00

fix up pareto callback mechanism

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
This commit is contained in:
Nikolaj Bjorner 2014-05-13 12:48:17 -07:00
parent 1ea376e310
commit 03979fd580
8 changed files with 178 additions and 144 deletions

View file

@ -1359,6 +1359,7 @@ void cmd_context::check_sat(unsigned num_assumptions, expr * const * assumptions
lbool r; lbool r;
if (m_opt && !m_opt->empty()) { if (m_opt && !m_opt->empty()) {
bool was_pareto = false;
m_check_sat_result = get_opt(); m_check_sat_result = get_opt();
cancel_eh<opt_wrapper> eh(*get_opt()); cancel_eh<opt_wrapper> eh(*get_opt());
scoped_ctrl_c ctrlc(eh); scoped_ctrl_c ctrlc(eh);
@ -1368,6 +1369,11 @@ void cmd_context::check_sat(unsigned num_assumptions, expr * const * assumptions
get_opt()->set_hard_constraints(cnstr); get_opt()->set_hard_constraints(cnstr);
try { try {
r = get_opt()->optimize(); r = get_opt()->optimize();
while (r == l_true && get_opt()->is_pareto()) {
was_pareto = true;
get_opt()->display_assignment(regular_stream());
r = get_opt()->optimize();
}
} }
catch (z3_error & ex) { catch (z3_error & ex) {
throw ex; throw ex;
@ -1375,8 +1381,11 @@ void cmd_context::check_sat(unsigned num_assumptions, expr * const * assumptions
catch (z3_exception & ex) { catch (z3_exception & ex) {
throw cmd_exception(ex.msg()); throw cmd_exception(ex.msg());
} }
if (was_pareto && r == l_false) {
r = l_true;
}
get_opt()->set_status(r); get_opt()->set_status(r);
if (r != l_false) { if (r != l_false && !was_pareto) {
get_opt()->display_assignment(regular_stream()); get_opt()->display_assignment(regular_stream());
} }
} }

View file

@ -122,6 +122,7 @@ public:
virtual lbool optimize() = 0; virtual lbool optimize() = 0;
virtual void set_hard_constraints(ptr_vector<expr> & hard) = 0; virtual void set_hard_constraints(ptr_vector<expr> & hard) = 0;
virtual void display_assignment(std::ostream& out) = 0; virtual void display_assignment(std::ostream& out) = 0;
virtual bool is_pareto() = 0;
}; };
class cmd_context : public progress_callback, public tactic_manager, public ast_printer_context { class cmd_context : public progress_callback, public tactic_manager, public ast_printer_context {

View file

@ -266,6 +266,16 @@ public:
cmd_context::scoped_watch sw(ctx); cmd_context::scoped_watch sw(ctx);
try { try {
r = opt.optimize(); r = opt.optimize();
if (r == l_true && opt.is_pareto()) {
while (r == l_true) {
display_result(ctx);
r = opt.optimize();
}
if (p.get_bool("print_statistics", false)) {
display_statistics(ctx);
}
return;
}
} }
catch (z3_error& ex) { catch (z3_error& ex) {
ctx.regular_stream() << "(error: " << ex.msg() << "\")" << std::endl; ctx.regular_stream() << "(error: " << ex.msg() << "\")" << std::endl;

View file

@ -180,10 +180,13 @@ namespace opt {
} }
lbool context::optimize() { lbool context::optimize() {
if (m_pareto) {
return execute_pareto();
}
import_scoped_state(); import_scoped_state();
normalize(); normalize();
internalize(); internalize();
opt_solver& s = get_solver(); opt_solver& s = get_solver();
solver::scoped_push _sp(s); solver::scoped_push _sp(s);
for (unsigned i = 0; i < m_hard_constraints.size(); ++i) { for (unsigned i = 0; i < m_hard_constraints.size(); ++i) {
TRACE("opt", tout << "Hard constraint: " << mk_ismt2_pp(m_hard_constraints[i].get(), m) << std::endl;); TRACE("opt", tout << "Hard constraint: " << mk_ismt2_pp(m_hard_constraints[i].get(), m) << std::endl;);
@ -210,6 +213,7 @@ namespace opt {
opt_params optp(m_params); opt_params optp(m_params);
symbol pri = optp.priority(); symbol pri = optp.priority();
if (pri == symbol("pareto")) { if (pri == symbol("pareto")) {
_sp.disable_pop();
return execute_pareto(); return execute_pareto();
} }
else if (pri == symbol("box")) { else if (pri == symbol("box")) {
@ -282,115 +286,109 @@ namespace opt {
return r; return r;
} }
class context::pareto : public pareto_callback {
context& ctx; expr_ref context::mk_le(unsigned i, model_ref& mdl) {
ast_manager& m; objective const& obj = m_objectives[i];
expr_ref mk_ge(expr* t, expr* s) { expr_ref val(m), result(m), term(m);
expr_ref result(m); mk_term_val(mdl, obj, term, val);
if (ctx.m_bv.is_bv(t)) { switch (obj.m_type) {
result = ctx.m_bv.mk_ule(s, t); case O_MINIMIZE:
result = mk_ge(term, val);
break;
case O_MAXSMT:
result = mk_ge(term, val);
break;
case O_MAXIMIZE:
result = mk_ge(val, term);
break;
}
return result;
}
expr_ref context::mk_ge(unsigned i, model_ref& mdl) {
objective const& obj = m_objectives[i];
expr_ref val(m), result(m), term(m);
mk_term_val(mdl, obj, term, val);
switch (obj.m_type) {
case O_MINIMIZE:
result = mk_ge(val, term);
break;
case O_MAXSMT:
result = mk_ge(val, term);
break;
case O_MAXIMIZE:
result = mk_ge(term, val);
break;
}
return result;
}
expr_ref context::mk_gt(unsigned i, model_ref& mdl) {
expr_ref result = mk_le(i, mdl);
result = m.mk_not(result);
return result;
}
void context::mk_term_val(model_ref& mdl, objective const& obj, expr_ref& term, expr_ref& val) {
rational r;
switch (obj.m_type) {
case O_MINIMIZE:
case O_MAXIMIZE:
term = obj.m_term;
break;
case O_MAXSMT: {
unsigned sz = obj.m_terms.size();
expr_ref_vector sum(m);
expr_ref zero(m);
zero = m_arith.mk_numeral(rational(0), false);
for (unsigned i = 0; i < sz; ++i) {
expr* t = obj.m_terms[i];
rational const& w = obj.m_weights[i];
sum.push_back(m.mk_ite(t, m_arith.mk_numeral(w, false), zero));
}
if (sum.empty()) {
term = zero;
} }
else { else {
result = ctx.m_arith.mk_ge(t, s); term = m_arith.mk_add(sum.size(), sum.c_ptr());
} }
return result; break;
} }
public: }
pareto(context& ctx):ctx(ctx),m(ctx.m) {} VERIFY(mdl->eval(term, val) && is_numeral(val, r));
}
virtual void yield(model_ref& mdl) { expr_ref context::mk_ge(expr* t, expr* s) {
ctx.m_model = mdl; expr_ref result(m);
ctx.update_lower(true); if (m_bv.is_bv(t)) {
for (unsigned i = 0; i < ctx.m_objectives.size(); ++i) { result = m_bv.mk_ule(s, t);
objective const& obj = ctx.m_objectives[i]; }
switch(obj.m_type) { else {
case O_MINIMIZE: result = m_arith.mk_ge(t, s);
case O_MAXIMIZE: }
ctx.m_optsmt.update_upper(obj.m_index, ctx.m_optsmt.get_lower(obj.m_index), true); return result;
break; }
case O_MAXSMT: {
rational r = ctx.m_maxsmts.find(obj.m_id)->get_lower();
ctx.m_maxsmts.find(obj.m_id)->update_upper(r, true);
break;
}
}
}
IF_VERBOSE(1, ctx.display_assignment(verbose_stream());); void context::yield() {
} m_pareto->get_model(m_model);
virtual unsigned num_objectives() { update_lower(true);
return ctx.m_objectives.size(); for (unsigned i = 0; i < m_objectives.size(); ++i) {
} objective const& obj = m_objectives[i];
virtual expr_ref mk_le(unsigned i, model_ref& mdl) { switch(obj.m_type) {
objective const& obj = ctx.m_objectives[i];
expr_ref val(m), result(m), term(m);
mk_term_val(mdl, obj, term, val);
switch (obj.m_type) {
case O_MINIMIZE:
result = mk_ge(term, val);
break;
case O_MAXSMT:
result = mk_ge(term, val);
break;
case O_MAXIMIZE:
result = mk_ge(val, term);
break;
}
return result;
}
virtual expr_ref mk_ge(unsigned i, model_ref& mdl) {
objective const& obj = ctx.m_objectives[i];
expr_ref val(m), result(m), term(m);
mk_term_val(mdl, obj, term, val);
switch (obj.m_type) {
case O_MINIMIZE:
result = mk_ge(val, term);
break;
case O_MAXSMT:
result = mk_ge(val, term);
break;
case O_MAXIMIZE:
result = mk_ge(term, val);
break;
}
return result;
}
virtual expr_ref mk_gt(unsigned i, model_ref& mdl) {
expr_ref result = mk_le(i, mdl);
result = m.mk_not(result);
return result;
}
private:
void mk_term_val(model_ref& mdl, objective const& obj, expr_ref& term, expr_ref& val) {
rational r;
switch (obj.m_type) {
case O_MINIMIZE: case O_MINIMIZE:
case O_MAXIMIZE: case O_MAXIMIZE:
term = obj.m_term; m_optsmt.update_upper(obj.m_index, m_optsmt.get_lower(obj.m_index), true);
break; break;
case O_MAXSMT: { case O_MAXSMT: {
unsigned sz = obj.m_terms.size(); rational r = m_maxsmts.find(obj.m_id)->get_lower();
expr_ref_vector sum(m); m_maxsmts.find(obj.m_id)->update_upper(r, true);
expr_ref zero(m);
zero = ctx.m_arith.mk_numeral(rational(0), false);
for (unsigned i = 0; i < sz; ++i) {
expr* t = obj.m_terms[i];
rational const& w = obj.m_weights[i];
sum.push_back(m.mk_ite(t, ctx.m_arith.mk_numeral(w, false), zero));
}
if (sum.empty()) {
term = zero;
}
else {
term = ctx.m_arith.mk_add(sum.size(), sum.c_ptr());
}
break; break;
} }
} }
VERIFY(mdl->eval(term, val) && ctx.is_numeral(val, r));
} }
}
#if 0 #if 0
// use PB // use PB
@ -415,13 +413,22 @@ namespace opt {
} }
} }
#endif #endif
};
lbool context::execute_pareto() { lbool context::execute_pareto() {
pareto cb(*this); if (!m_pareto) {
m_pareto = alloc(gia_pareto, m, cb, m_solver.get(), m_params); m_pareto = alloc(gia_pareto, m, *this, m_solver.get(), m_params);
return (*(m_pareto.get()))(); }
// NB. stack reference cb is out of scope after return. lbool is_sat = (*(m_pareto.get()))();
if (is_sat != l_true) {
m_pareto = 0;
}
if (is_sat == l_true) {
yield();
}
else {
m_solver->pop(1);
}
return is_sat;
// NB. fix race condition for set_cancel // NB. fix race condition for set_cancel
} }

View file

@ -34,7 +34,7 @@ namespace opt {
class opt_solver; class opt_solver;
class context : public opt_wrapper { class context : public opt_wrapper, public pareto_callback {
struct free_func_visitor; struct free_func_visitor;
typedef map<symbol, maxsmt*, symbol_hash_proc, symbol_eq_proc> map_t; typedef map<symbol, maxsmt*, symbol_hash_proc, symbol_eq_proc> map_t;
typedef map<symbol, unsigned, symbol_hash_proc, symbol_eq_proc> map_id; typedef map<symbol, unsigned, symbol_hash_proc, symbol_eq_proc> map_id;
@ -145,6 +145,8 @@ namespace opt {
virtual std::string reason_unknown() const { return std::string("unknown"); } virtual std::string reason_unknown() const { return std::string("unknown"); }
virtual void display_assignment(std::ostream& out); virtual void display_assignment(std::ostream& out);
virtual bool is_pareto() { return m_pareto.get() != 0; }
void display(std::ostream& out); void display(std::ostream& out);
static void collect_param_descrs(param_descrs & r); static void collect_param_descrs(param_descrs & r);
void updt_params(params_ref& p); void updt_params(params_ref& p);
@ -155,6 +157,13 @@ namespace opt {
std::string to_string() const; std::string to_string() const;
virtual unsigned num_objectives() { return m_objectives.size(); }
virtual expr_ref mk_gt(unsigned i, model_ref& model);
virtual expr_ref mk_ge(unsigned i, model_ref& model);
virtual expr_ref mk_le(unsigned i, model_ref& model);
private: private:
void validate_feasibility(maxsmt& ms); void validate_feasibility(maxsmt& ms);
@ -199,7 +208,11 @@ namespace opt {
void validate_lex(); void validate_lex();
class pareto;
// pareto
void yield();
expr_ref mk_ge(expr* t, expr* s);
void mk_term_val(model_ref& mdl, objective const& obj, expr_ref& term, expr_ref& val);
}; };

View file

@ -20,6 +20,7 @@ Notes:
#include "opt_pareto.h" #include "opt_pareto.h"
#include "ast_pp.h" #include "ast_pp.h"
#include "model_smt2_pp.h"
namespace opt { namespace opt {
@ -27,43 +28,39 @@ namespace opt {
// GIA pareto algorithm // GIA pareto algorithm
lbool gia_pareto::operator()() { lbool gia_pareto::operator()() {
model_ref model;
expr_ref fml(m); expr_ref fml(m);
lbool is_sat = m_solver->check_sat(0, 0); lbool is_sat = m_solver->check_sat(0, 0);
while (is_sat == l_true) { if (is_sat == l_true) {
{ {
solver::scoped_push _s(*m_solver.get()); solver::scoped_push _s(*m_solver.get());
while (is_sat == l_true) { while (is_sat == l_true) {
if (m_cancel) { if (m_cancel) {
return l_undef; return l_undef;
} }
m_solver->get_model(model); m_solver->get_model(m_model);
IF_VERBOSE(1, model_smt2_pp(verbose_stream() << "new model:\n", m, *m_model, 0););
// TBD: we can also use local search to tune solution coordinate-wise. // TBD: we can also use local search to tune solution coordinate-wise.
mk_dominates(model); mk_dominates();
is_sat = m_solver->check_sat(0, 0); is_sat = m_solver->check_sat(0, 0);
} }
if (is_sat == l_undef) {
return l_undef;
}
is_sat = l_true;
} }
cb.yield(model); if (is_sat == l_undef) {
mk_not_dominated_by(model); return l_undef;
is_sat = m_solver->check_sat(0, 0); }
SASSERT(is_sat == l_false);
is_sat = l_true;
mk_not_dominated_by();
} }
if (is_sat == l_undef) { return is_sat;
return l_undef;
}
return l_true;
} }
void pareto_base::mk_dominates(model_ref& model) { void pareto_base::mk_dominates() {
unsigned sz = cb.num_objectives(); unsigned sz = cb.num_objectives();
expr_ref fml(m); expr_ref fml(m);
expr_ref_vector gt(m), fmls(m); expr_ref_vector gt(m), fmls(m);
for (unsigned i = 0; i < sz; ++i) { for (unsigned i = 0; i < sz; ++i) {
fmls.push_back(cb.mk_ge(i, model)); fmls.push_back(cb.mk_ge(i, m_model));
gt.push_back(cb.mk_gt(i, model)); gt.push_back(cb.mk_gt(i, m_model));
} }
fmls.push_back(m.mk_or(gt.size(), gt.c_ptr())); fmls.push_back(m.mk_or(gt.size(), gt.c_ptr()));
fml = m.mk_and(fmls.size(), fmls.c_ptr()); fml = m.mk_and(fmls.size(), fmls.c_ptr());
@ -71,12 +68,12 @@ namespace opt {
m_solver->assert_expr(fml); m_solver->assert_expr(fml);
} }
void pareto_base::mk_not_dominated_by(model_ref& model) { void pareto_base::mk_not_dominated_by() {
unsigned sz = cb.num_objectives(); unsigned sz = cb.num_objectives();
expr_ref fml(m); expr_ref fml(m);
expr_ref_vector le(m); expr_ref_vector le(m);
for (unsigned i = 0; i < sz; ++i) { for (unsigned i = 0; i < sz; ++i) {
le.push_back(cb.mk_le(i, model)); le.push_back(cb.mk_le(i, m_model));
} }
fml = m.mk_not(m.mk_and(le.size(), le.c_ptr())); fml = m.mk_not(m.mk_and(le.size(), le.c_ptr()));
IF_VERBOSE(10, verbose_stream() << "not dominated by: " << fml << "\n";); IF_VERBOSE(10, verbose_stream() << "not dominated by: " << fml << "\n";);
@ -87,25 +84,16 @@ namespace opt {
// OIA algorithm (without filtering) // OIA algorithm (without filtering)
lbool oia_pareto::operator()() { lbool oia_pareto::operator()() {
model_ref model;
solver::scoped_push _s(*m_solver.get()); solver::scoped_push _s(*m_solver.get());
lbool is_sat = m_solver->check_sat(0, 0); lbool is_sat = m_solver->check_sat(0, 0);
if (is_sat != l_true) {
return is_sat;
}
while (is_sat == l_true) {
if (m_cancel) {
return l_undef;
}
m_solver->get_model(model);
cb.yield(model);
mk_not_dominated_by(model);
is_sat = m_solver->check_sat(0, 0);
}
if (m_cancel) { if (m_cancel) {
return l_undef; is_sat = l_undef;
} }
return l_true; if (is_sat == l_true) {
m_solver->get_model(m_model);
mk_not_dominated_by();
}
return is_sat;
} }
} }

View file

@ -27,7 +27,6 @@ namespace opt {
class pareto_callback { class pareto_callback {
public: public:
virtual void yield(model_ref& model) = 0;
virtual unsigned num_objectives() = 0; virtual unsigned num_objectives() = 0;
virtual expr_ref mk_gt(unsigned i, model_ref& model) = 0; virtual expr_ref mk_gt(unsigned i, model_ref& model) = 0;
virtual expr_ref mk_ge(unsigned i, model_ref& model) = 0; virtual expr_ref mk_ge(unsigned i, model_ref& model) = 0;
@ -40,6 +39,7 @@ namespace opt {
volatile bool m_cancel; volatile bool m_cancel;
ref<solver> m_solver; ref<solver> m_solver;
params_ref m_params; params_ref m_params;
model_ref m_model;
public: public:
pareto_base( pareto_base(
ast_manager & m, ast_manager & m,
@ -72,11 +72,15 @@ namespace opt {
} }
virtual lbool operator()() = 0; virtual lbool operator()() = 0;
virtual void get_model(model_ref& mdl) {
mdl = m_model;
}
protected: protected:
void mk_dominates(model_ref& model); void mk_dominates();
void mk_not_dominated_by(model_ref& model); void mk_not_dominated_by();
}; };
class gia_pareto : public pareto_base { class gia_pareto : public pareto_base {
public: public:

View file

@ -133,9 +133,11 @@ public:
class scoped_push { class scoped_push {
solver& s; solver& s;
bool m_nopop;
public: public:
scoped_push(solver& s):s(s) { s.push(); } scoped_push(solver& s):s(s), m_nopop(false) { s.push(); }
~scoped_push() { s.pop(1); } ~scoped_push() { if (!m_nopop) s.pop(1); }
void disable_pop() { m_nopop = true; }
}; };
}; };