From 03979fd58024e734677f6577aa6b04692d38db22 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Tue, 13 May 2014 12:48:17 -0700 Subject: [PATCH] fix up pareto callback mechanism Signed-off-by: Nikolaj Bjorner --- src/cmd_context/cmd_context.cpp | 11 +- src/cmd_context/cmd_context.h | 1 + src/opt/opt_cmds.cpp | 10 ++ src/opt/opt_context.cpp | 209 +++++++++++++++++--------------- src/opt/opt_context.h | 17 ++- src/opt/opt_pareto.cpp | 58 ++++----- src/opt/opt_pareto.h | 10 +- src/solver/solver.h | 6 +- 8 files changed, 178 insertions(+), 144 deletions(-) diff --git a/src/cmd_context/cmd_context.cpp b/src/cmd_context/cmd_context.cpp index ca9e90430..4373ee33a 100644 --- a/src/cmd_context/cmd_context.cpp +++ b/src/cmd_context/cmd_context.cpp @@ -1359,6 +1359,7 @@ void cmd_context::check_sat(unsigned num_assumptions, expr * const * assumptions lbool r; if (m_opt && !m_opt->empty()) { + bool was_pareto = false; m_check_sat_result = get_opt(); cancel_eh eh(*get_opt()); scoped_ctrl_c ctrlc(eh); @@ -1368,6 +1369,11 @@ void cmd_context::check_sat(unsigned num_assumptions, expr * const * assumptions get_opt()->set_hard_constraints(cnstr); try { r = get_opt()->optimize(); + while (r == l_true && get_opt()->is_pareto()) { + was_pareto = true; + get_opt()->display_assignment(regular_stream()); + r = get_opt()->optimize(); + } } catch (z3_error & ex) { throw ex; @@ -1375,8 +1381,11 @@ void cmd_context::check_sat(unsigned num_assumptions, expr * const * assumptions catch (z3_exception & ex) { throw cmd_exception(ex.msg()); } + if (was_pareto && r == l_false) { + r = l_true; + } get_opt()->set_status(r); - if (r != l_false) { + if (r != l_false && !was_pareto) { get_opt()->display_assignment(regular_stream()); } } diff --git a/src/cmd_context/cmd_context.h b/src/cmd_context/cmd_context.h index 8f7fc3228..0dfea0441 100644 --- a/src/cmd_context/cmd_context.h +++ b/src/cmd_context/cmd_context.h @@ -122,6 +122,7 @@ public: virtual lbool optimize() = 0; virtual void set_hard_constraints(ptr_vector & hard) = 0; virtual void display_assignment(std::ostream& out) = 0; + virtual bool is_pareto() = 0; }; class cmd_context : public progress_callback, public tactic_manager, public ast_printer_context { diff --git a/src/opt/opt_cmds.cpp b/src/opt/opt_cmds.cpp index 857a3c585..42d749cb6 100644 --- a/src/opt/opt_cmds.cpp +++ b/src/opt/opt_cmds.cpp @@ -266,6 +266,16 @@ public: cmd_context::scoped_watch sw(ctx); try { r = opt.optimize(); + if (r == l_true && opt.is_pareto()) { + while (r == l_true) { + display_result(ctx); + r = opt.optimize(); + } + if (p.get_bool("print_statistics", false)) { + display_statistics(ctx); + } + return; + } } catch (z3_error& ex) { ctx.regular_stream() << "(error: " << ex.msg() << "\")" << std::endl; diff --git a/src/opt/opt_context.cpp b/src/opt/opt_context.cpp index 3a88ff904..fda90006c 100644 --- a/src/opt/opt_context.cpp +++ b/src/opt/opt_context.cpp @@ -180,10 +180,13 @@ namespace opt { } lbool context::optimize() { + if (m_pareto) { + return execute_pareto(); + } import_scoped_state(); normalize(); internalize(); - opt_solver& s = get_solver(); + opt_solver& s = get_solver(); solver::scoped_push _sp(s); for (unsigned i = 0; i < m_hard_constraints.size(); ++i) { TRACE("opt", tout << "Hard constraint: " << mk_ismt2_pp(m_hard_constraints[i].get(), m) << std::endl;); @@ -210,6 +213,7 @@ namespace opt { opt_params optp(m_params); symbol pri = optp.priority(); if (pri == symbol("pareto")) { + _sp.disable_pop(); return execute_pareto(); } else if (pri == symbol("box")) { @@ -282,115 +286,109 @@ namespace opt { return r; } - class context::pareto : public pareto_callback { - context& ctx; - ast_manager& m; - expr_ref mk_ge(expr* t, expr* s) { - expr_ref result(m); - if (ctx.m_bv.is_bv(t)) { - result = ctx.m_bv.mk_ule(s, t); + + expr_ref context::mk_le(unsigned i, model_ref& mdl) { + objective const& obj = m_objectives[i]; + expr_ref val(m), result(m), term(m); + mk_term_val(mdl, obj, term, val); + switch (obj.m_type) { + case O_MINIMIZE: + result = mk_ge(term, val); + break; + case O_MAXSMT: + result = mk_ge(term, val); + break; + case O_MAXIMIZE: + result = mk_ge(val, term); + break; + } + return result; + } + + expr_ref context::mk_ge(unsigned i, model_ref& mdl) { + objective const& obj = m_objectives[i]; + expr_ref val(m), result(m), term(m); + mk_term_val(mdl, obj, term, val); + switch (obj.m_type) { + case O_MINIMIZE: + result = mk_ge(val, term); + break; + case O_MAXSMT: + result = mk_ge(val, term); + break; + case O_MAXIMIZE: + result = mk_ge(term, val); + break; + } + return result; + } + + expr_ref context::mk_gt(unsigned i, model_ref& mdl) { + expr_ref result = mk_le(i, mdl); + result = m.mk_not(result); + return result; + } + + void context::mk_term_val(model_ref& mdl, objective const& obj, expr_ref& term, expr_ref& val) { + rational r; + switch (obj.m_type) { + case O_MINIMIZE: + case O_MAXIMIZE: + term = obj.m_term; + break; + case O_MAXSMT: { + unsigned sz = obj.m_terms.size(); + expr_ref_vector sum(m); + expr_ref zero(m); + zero = m_arith.mk_numeral(rational(0), false); + for (unsigned i = 0; i < sz; ++i) { + expr* t = obj.m_terms[i]; + rational const& w = obj.m_weights[i]; + sum.push_back(m.mk_ite(t, m_arith.mk_numeral(w, false), zero)); + } + if (sum.empty()) { + term = zero; } else { - result = ctx.m_arith.mk_ge(t, s); - } - return result; + term = m_arith.mk_add(sum.size(), sum.c_ptr()); + } + break; } - public: - pareto(context& ctx):ctx(ctx),m(ctx.m) {} + } + VERIFY(mdl->eval(term, val) && is_numeral(val, r)); + } - virtual void yield(model_ref& mdl) { - ctx.m_model = mdl; - ctx.update_lower(true); - for (unsigned i = 0; i < ctx.m_objectives.size(); ++i) { - objective const& obj = ctx.m_objectives[i]; - switch(obj.m_type) { - case O_MINIMIZE: - case O_MAXIMIZE: - ctx.m_optsmt.update_upper(obj.m_index, ctx.m_optsmt.get_lower(obj.m_index), true); - break; - case O_MAXSMT: { - rational r = ctx.m_maxsmts.find(obj.m_id)->get_lower(); - ctx.m_maxsmts.find(obj.m_id)->update_upper(r, true); - break; - } - } - } + expr_ref context::mk_ge(expr* t, expr* s) { + expr_ref result(m); + if (m_bv.is_bv(t)) { + result = m_bv.mk_ule(s, t); + } + else { + result = m_arith.mk_ge(t, s); + } + return result; + } - IF_VERBOSE(1, ctx.display_assignment(verbose_stream());); - } - virtual unsigned num_objectives() { - return ctx.m_objectives.size(); - } - virtual expr_ref mk_le(unsigned i, model_ref& mdl) { - objective const& obj = ctx.m_objectives[i]; - expr_ref val(m), result(m), term(m); - mk_term_val(mdl, obj, term, val); - switch (obj.m_type) { - case O_MINIMIZE: - result = mk_ge(term, val); - break; - case O_MAXSMT: - result = mk_ge(term, val); - break; - case O_MAXIMIZE: - result = mk_ge(val, term); - break; - } - return result; - } - virtual expr_ref mk_ge(unsigned i, model_ref& mdl) { - objective const& obj = ctx.m_objectives[i]; - expr_ref val(m), result(m), term(m); - mk_term_val(mdl, obj, term, val); - switch (obj.m_type) { - case O_MINIMIZE: - result = mk_ge(val, term); - break; - case O_MAXSMT: - result = mk_ge(val, term); - break; - case O_MAXIMIZE: - result = mk_ge(term, val); - break; - } - return result; - } - - virtual expr_ref mk_gt(unsigned i, model_ref& mdl) { - expr_ref result = mk_le(i, mdl); - result = m.mk_not(result); - return result; - } - private: - void mk_term_val(model_ref& mdl, objective const& obj, expr_ref& term, expr_ref& val) { - rational r; - switch (obj.m_type) { + void context::yield() { + m_pareto->get_model(m_model); + update_lower(true); + for (unsigned i = 0; i < m_objectives.size(); ++i) { + objective const& obj = m_objectives[i]; + switch(obj.m_type) { case O_MINIMIZE: case O_MAXIMIZE: - term = obj.m_term; + m_optsmt.update_upper(obj.m_index, m_optsmt.get_lower(obj.m_index), true); break; case O_MAXSMT: { - unsigned sz = obj.m_terms.size(); - expr_ref_vector sum(m); - expr_ref zero(m); - zero = ctx.m_arith.mk_numeral(rational(0), false); - for (unsigned i = 0; i < sz; ++i) { - expr* t = obj.m_terms[i]; - rational const& w = obj.m_weights[i]; - sum.push_back(m.mk_ite(t, ctx.m_arith.mk_numeral(w, false), zero)); - } - if (sum.empty()) { - term = zero; - } - else { - term = ctx.m_arith.mk_add(sum.size(), sum.c_ptr()); - } + rational r = m_maxsmts.find(obj.m_id)->get_lower(); + m_maxsmts.find(obj.m_id)->update_upper(r, true); break; } } - VERIFY(mdl->eval(term, val) && ctx.is_numeral(val, r)); } + } + #if 0 // use PB @@ -415,13 +413,22 @@ namespace opt { } } #endif - }; lbool context::execute_pareto() { - pareto cb(*this); - m_pareto = alloc(gia_pareto, m, cb, m_solver.get(), m_params); - return (*(m_pareto.get()))(); - // NB. stack reference cb is out of scope after return. + if (!m_pareto) { + m_pareto = alloc(gia_pareto, m, *this, m_solver.get(), m_params); + } + lbool is_sat = (*(m_pareto.get()))(); + if (is_sat != l_true) { + m_pareto = 0; + } + if (is_sat == l_true) { + yield(); + } + else { + m_solver->pop(1); + } + return is_sat; // NB. fix race condition for set_cancel } diff --git a/src/opt/opt_context.h b/src/opt/opt_context.h index a09872e85..18d6d3745 100644 --- a/src/opt/opt_context.h +++ b/src/opt/opt_context.h @@ -34,7 +34,7 @@ namespace opt { class opt_solver; - class context : public opt_wrapper { + class context : public opt_wrapper, public pareto_callback { struct free_func_visitor; typedef map map_t; typedef map map_id; @@ -145,6 +145,8 @@ namespace opt { virtual std::string reason_unknown() const { return std::string("unknown"); } virtual void display_assignment(std::ostream& out); + virtual bool is_pareto() { return m_pareto.get() != 0; } + void display(std::ostream& out); static void collect_param_descrs(param_descrs & r); void updt_params(params_ref& p); @@ -155,6 +157,13 @@ namespace opt { std::string to_string() const; + + virtual unsigned num_objectives() { return m_objectives.size(); } + virtual expr_ref mk_gt(unsigned i, model_ref& model); + virtual expr_ref mk_ge(unsigned i, model_ref& model); + virtual expr_ref mk_le(unsigned i, model_ref& model); + + private: void validate_feasibility(maxsmt& ms); @@ -199,7 +208,11 @@ namespace opt { void validate_lex(); - class pareto; + + // pareto + void yield(); + expr_ref mk_ge(expr* t, expr* s); + void mk_term_val(model_ref& mdl, objective const& obj, expr_ref& term, expr_ref& val); }; diff --git a/src/opt/opt_pareto.cpp b/src/opt/opt_pareto.cpp index d901da0e9..2e0286945 100644 --- a/src/opt/opt_pareto.cpp +++ b/src/opt/opt_pareto.cpp @@ -20,6 +20,7 @@ Notes: #include "opt_pareto.h" #include "ast_pp.h" +#include "model_smt2_pp.h" namespace opt { @@ -27,43 +28,39 @@ namespace opt { // GIA pareto algorithm lbool gia_pareto::operator()() { - model_ref model; expr_ref fml(m); lbool is_sat = m_solver->check_sat(0, 0); - while (is_sat == l_true) { + if (is_sat == l_true) { { solver::scoped_push _s(*m_solver.get()); while (is_sat == l_true) { if (m_cancel) { return l_undef; } - m_solver->get_model(model); + m_solver->get_model(m_model); + IF_VERBOSE(1, model_smt2_pp(verbose_stream() << "new model:\n", m, *m_model, 0);); // TBD: we can also use local search to tune solution coordinate-wise. - mk_dominates(model); + mk_dominates(); is_sat = m_solver->check_sat(0, 0); } - if (is_sat == l_undef) { - return l_undef; - } - is_sat = l_true; } - cb.yield(model); - mk_not_dominated_by(model); - is_sat = m_solver->check_sat(0, 0); + if (is_sat == l_undef) { + return l_undef; + } + SASSERT(is_sat == l_false); + is_sat = l_true; + mk_not_dominated_by(); } - if (is_sat == l_undef) { - return l_undef; - } - return l_true; + return is_sat; } - void pareto_base::mk_dominates(model_ref& model) { + void pareto_base::mk_dominates() { unsigned sz = cb.num_objectives(); expr_ref fml(m); expr_ref_vector gt(m), fmls(m); for (unsigned i = 0; i < sz; ++i) { - fmls.push_back(cb.mk_ge(i, model)); - gt.push_back(cb.mk_gt(i, model)); + fmls.push_back(cb.mk_ge(i, m_model)); + gt.push_back(cb.mk_gt(i, m_model)); } fmls.push_back(m.mk_or(gt.size(), gt.c_ptr())); fml = m.mk_and(fmls.size(), fmls.c_ptr()); @@ -71,12 +68,12 @@ namespace opt { m_solver->assert_expr(fml); } - void pareto_base::mk_not_dominated_by(model_ref& model) { + void pareto_base::mk_not_dominated_by() { unsigned sz = cb.num_objectives(); expr_ref fml(m); expr_ref_vector le(m); for (unsigned i = 0; i < sz; ++i) { - le.push_back(cb.mk_le(i, model)); + le.push_back(cb.mk_le(i, m_model)); } fml = m.mk_not(m.mk_and(le.size(), le.c_ptr())); IF_VERBOSE(10, verbose_stream() << "not dominated by: " << fml << "\n";); @@ -87,25 +84,16 @@ namespace opt { // OIA algorithm (without filtering) lbool oia_pareto::operator()() { - model_ref model; solver::scoped_push _s(*m_solver.get()); lbool is_sat = m_solver->check_sat(0, 0); - if (is_sat != l_true) { - return is_sat; - } - while (is_sat == l_true) { - if (m_cancel) { - return l_undef; - } - m_solver->get_model(model); - cb.yield(model); - mk_not_dominated_by(model); - is_sat = m_solver->check_sat(0, 0); - } if (m_cancel) { - return l_undef; + is_sat = l_undef; } - return l_true; + if (is_sat == l_true) { + m_solver->get_model(m_model); + mk_not_dominated_by(); + } + return is_sat; } } diff --git a/src/opt/opt_pareto.h b/src/opt/opt_pareto.h index fa0243807..8a2378ae9 100644 --- a/src/opt/opt_pareto.h +++ b/src/opt/opt_pareto.h @@ -27,7 +27,6 @@ namespace opt { class pareto_callback { public: - virtual void yield(model_ref& model) = 0; virtual unsigned num_objectives() = 0; virtual expr_ref mk_gt(unsigned i, model_ref& model) = 0; virtual expr_ref mk_ge(unsigned i, model_ref& model) = 0; @@ -40,6 +39,7 @@ namespace opt { volatile bool m_cancel; ref m_solver; params_ref m_params; + model_ref m_model; public: pareto_base( ast_manager & m, @@ -72,11 +72,15 @@ namespace opt { } virtual lbool operator()() = 0; + virtual void get_model(model_ref& mdl) { + mdl = m_model; + } + protected: - void mk_dominates(model_ref& model); + void mk_dominates(); - void mk_not_dominated_by(model_ref& model); + void mk_not_dominated_by(); }; class gia_pareto : public pareto_base { public: diff --git a/src/solver/solver.h b/src/solver/solver.h index bd1a26adb..e0d3d30e2 100644 --- a/src/solver/solver.h +++ b/src/solver/solver.h @@ -133,9 +133,11 @@ public: class scoped_push { solver& s; + bool m_nopop; public: - scoped_push(solver& s):s(s) { s.push(); } - ~scoped_push() { s.pop(1); } + scoped_push(solver& s):s(s), m_nopop(false) { s.push(); } + ~scoped_push() { if (!m_nopop) s.pop(1); } + void disable_pop() { m_nopop = true; } }; };