diff --git a/src/opt/opt_context.cpp b/src/opt/opt_context.cpp index 93f7e1acd..69ec91d85 100644 --- a/src/opt/opt_context.cpp +++ b/src/opt/opt_context.cpp @@ -37,38 +37,24 @@ namespace opt { { m_params.set_bool("model", true); m_params.set_bool("unsat_core", true); + m_solver = alloc(opt_solver, m, m_params, symbol()); } void context::optimize() { - if (!m_solver) { - symbol logic; - set_solver(alloc(opt_solver, m, m_params, logic)); - } - - // really just works for opt_solver now. - solver* s = m_solver.get(); - opt_solver::scoped_push _sp(*s); + opt_solver& s = *m_solver.get(); + opt_solver::scoped_push _sp(s); for (unsigned i = 0; i < m_hard_constraints.size(); ++i) { - s->assert_expr(m_hard_constraints[i].get()); + s.assert_expr(m_hard_constraints[i].get()); } - lbool is_sat; - - is_sat = m_maxsmt(*s); - - expr_ref_vector ans = m_maxsmt.get_assignment(); - for (unsigned i = 0; i < ans.size(); ++i) { - s->assert_expr(ans[i].get()); - } - + lbool is_sat = m_maxsmt(s); if (is_sat == l_true) { - is_sat = m_optsmt(opt_solver::to_opt(*s)); + is_sat = m_optsmt(s); } } - void context::set_cancel(bool f) { if (m_solver) { m_solver->set_cancel(f); @@ -93,6 +79,7 @@ namespace opt { m_solver->updt_params(m_params); } m_optsmt.updt_params(m_params); + m_maxsmt.updt_params(m_params); } diff --git a/src/opt/opt_context.h b/src/opt/opt_context.h index df2ca8248..22fd5fd7a 100644 --- a/src/opt/opt_context.h +++ b/src/opt/opt_context.h @@ -26,7 +26,7 @@ Notes: #define _OPT_CONTEXT_H_ #include "ast.h" -#include "solver.h" +#include "opt_solver.h" #include "optsmt.h" #include "maxsmt.h" @@ -37,7 +37,7 @@ namespace opt { class context { ast_manager& m; expr_ref_vector m_hard_constraints; - ref m_solver; + ref m_solver; params_ref m_params; optsmt m_optsmt; maxsmt m_maxsmt; @@ -47,7 +47,6 @@ namespace opt { void add_soft_constraint(expr* f, rational const& w) { m_maxsmt.add(f, w); } void add_objective(app* t, bool is_max) { m_optsmt.add(t, is_max); } void add_hard_constraint(expr* f) { m_hard_constraints.push_back(f); } - void set_solver(solver* s) { m_solver = s; } void optimize(); void set_cancel(bool f); void reset_cancel() { set_cancel(false); } @@ -55,12 +54,6 @@ namespace opt { void collect_statistics(statistics& stats); static void collect_param_descrs(param_descrs & r); void updt_params(params_ref& p); - - private: - bool is_maxsat_problem() const; - - opt_solver& get_opt_solver(solver& s); - }; } diff --git a/src/opt/optsmt.cpp b/src/opt/optsmt.cpp index 9c7268913..a6dbd30ab 100644 --- a/src/opt/optsmt.cpp +++ b/src/opt/optsmt.cpp @@ -132,9 +132,7 @@ namespace opt { lbool optsmt::update_upper() { smt::theory_opt& opt = s->get_optimizer(); - SASSERT(typeid(smt::theory_inf_arith) == typeid(opt)); - smt::theory_inf_arith& th = dynamic_cast(opt); expr_ref bound(m); @@ -246,13 +244,28 @@ namespace opt { return is_sat; } - inf_eps optsmt::get_value(unsigned index) const { - if (m_is_max[index]) { - return m_lower[index]; - } - else { - return -m_lower[index]; - } + inf_eps optsmt::get_value(unsigned i) const { + return m_is_max[i]?m_lower[i]:-m_lower[i]; + } + + inf_eps optsmt::get_lower(unsigned i) const { + return m_is_max[i]?m_lower[i]:-m_upper[i]; + } + + inf_eps optsmt::get_upper(unsigned i) const { + return m_is_max[i]?m_upper[i]:-m_lower[i]; + } + + // force lower_bound(i) <= objective_value(i) + void optsmt::commit_assignment(unsigned i) { + smt::theory_var v = m_vars[i]; + + // TBD: this should be a method on all optimization solvers. + smt::theory_opt& opt = s->get_optimizer(); + SASSERT(typeid(smt::theory_inf_arith) == typeid(opt)); + smt::theory_inf_arith& th = dynamic_cast(opt); + + s->assert_expr(th.block_upper_bound(v, get_lower(i))); } void optsmt::display(std::ostream& out) const { diff --git a/src/opt/optsmt.h b/src/opt/optsmt.h index 297f924ae..b892ef38f 100644 --- a/src/opt/optsmt.h +++ b/src/opt/optsmt.h @@ -44,6 +44,8 @@ namespace opt { void add(app* t, bool is_max); + void commit_assignment(unsigned i); + void set_cancel(bool f); void updt_params(params_ref& p); @@ -51,7 +53,9 @@ namespace opt { void display(std::ostream& out) const; inf_eps get_value(unsigned index) const; - + inf_eps get_lower(unsigned index) const; + inf_eps get_upper(unsigned index) const; + private: lbool basic_opt();