diff --git a/src/opt/fu_malik.cpp b/src/opt/fu_malik.cpp index 07d256d9e..f242c1a37 100644 --- a/src/opt/fu_malik.cpp +++ b/src/opt/fu_malik.cpp @@ -37,13 +37,13 @@ namespace opt { class fu_malik { ast_manager& m; - solver& s; + ::solver& s; expr_ref_vector m_soft; expr_ref_vector m_aux; public: - fu_malik(ast_manager& m, solver& s, expr_ref_vector const& soft): + fu_malik(ast_manager& m, ::solver& s, expr_ref_vector const& soft): m(m), s(s), m_soft(soft), @@ -132,7 +132,7 @@ namespace opt { }; - lbool fu_malik_maxsat(solver& s, expr_ref_vector& soft_constraints) { + lbool fu_malik_maxsat(::solver& s, expr_ref_vector& soft_constraints) { ast_manager& m = soft_constraints.get_manager(); lbool is_sat = s.check_sat(0,0); if (!soft_constraints.empty() && is_sat == l_true) { diff --git a/src/opt/fu_malik.h b/src/opt/fu_malik.h index 6bfe37f63..e46032adf 100644 --- a/src/opt/fu_malik.h +++ b/src/opt/fu_malik.h @@ -28,7 +28,7 @@ namespace opt { that are still consistent with the solver state. */ - lbool fu_malik_maxsat(solver& s, expr_ref_vector& soft_constraints); + lbool fu_malik_maxsat(::solver& s, expr_ref_vector& soft_constraints); }; #endif diff --git a/src/opt/opt_context.cpp b/src/opt/opt_context.cpp index dbbcffa7c..b02d78466 100644 --- a/src/opt/opt_context.cpp +++ b/src/opt/opt_context.cpp @@ -18,11 +18,11 @@ Notes: #include "opt_context.h" -#include "smt_solver.h" #include "fu_malik.h" #include "weighted_maxsat.h" #include "optimize_objectives.h" #include "ast_pp.h" +#include "opt_solver.h" namespace opt { @@ -30,12 +30,14 @@ namespace opt { expr_ref_vector const& fmls = m_soft_constraints; - ref s; - symbol logic; - params_ref p; - p.set_bool("model", true); - p.set_bool("unsat_core", true); - s = mk_smt_solver(m, p, logic); + if (!m_solver) { + symbol logic; + params_ref p; + p.set_bool("model", true); + p.set_bool("unsat_core", true); + set_solver(alloc(opt_solver, m, p, logic)); + } + solver* s = m_solver.get(); for (unsigned i = 0; i < m_hard_constraints.size(); ++i) { s->assert_expr(m_hard_constraints[i].get()); @@ -64,7 +66,9 @@ namespace opt { for (unsigned i = 0; i < fmls_copy.size(); ++i) { s->assert_expr(fmls_copy[i].get()); } - is_sat = optimize_objectives(*s, m_objectives, m_is_max, values); + // SASSERT(instanceof(*s, opt_solver)); + // if (!instsanceof ...) { throw ... invalid usage ..} + is_sat = optimize_objectives(dynamic_cast(*s), m_objectives, m_is_max, values); std::cout << "is-sat: " << is_sat << "\n"; if (is_sat != l_true) { return; diff --git a/src/opt/opt_context.h b/src/opt/opt_context.h index 9e226389e..0555cdd83 100644 --- a/src/opt/opt_context.h +++ b/src/opt/opt_context.h @@ -19,6 +19,7 @@ Notes: #define _OPT_CONTEXT_H_ #include "ast.h" +#include "solver.h" namespace opt { @@ -30,6 +31,7 @@ namespace opt { expr_ref_vector m_objectives; svector m_is_max; + ref<::solver> m_solver; public: context(ast_manager& m): @@ -53,6 +55,10 @@ namespace opt { m_hard_constraints.push_back(f); } + void set_solver(::solver* s) { + m_solver = s; + } + void optimize(); private: @@ -63,4 +69,4 @@ namespace opt { } -#endif \ No newline at end of file +#endif diff --git a/src/opt/opt_solver.cpp b/src/opt/opt_solver.cpp new file mode 100644 index 000000000..4989993b4 --- /dev/null +++ b/src/opt/opt_solver.cpp @@ -0,0 +1,122 @@ +#include"reg_decl_plugins.h" +#include"opt_solver.h" +#include"smt_context.h" +#include"theory_arith.h" + +namespace opt { + + opt_solver::opt_solver(ast_manager & m, params_ref const & p, symbol const & l): + solver_na2as(m), + m_params(p), + m_context(m, m_params), + m_objective_enabled(false) { + m_logic = l; + if (m_logic != symbol::null) + m_context.set_logic(m_logic); + } + + opt_solver::~opt_solver() { + } + + void opt_solver::updt_params(params_ref const & p) { + m_params.updt_params(p); + m_context.updt_params(p); + } + + void opt_solver::collect_param_descrs(param_descrs & r) { + m_context.collect_param_descrs(r); + } + + void opt_solver::collect_statistics(statistics & st) const { + m_context.collect_statistics(st); + } + + void opt_solver::assert_expr(expr * t) { + m_context.assert_expr(t); + } + + void opt_solver::push_core() { + m_context.push(); + } + + void opt_solver::pop_core(unsigned n) { + m_context.pop(n); + } + +#define ACCESS_ARITHMETIC_CLASS(_code_) \ + smt::context& ctx = m_context.get_context(); \ + smt::theory_id arith_id = m_context.m().get_family_id("arith"); \ + smt::theory* arith_theory = ctx.get_theory(arith_id); \ + if (typeid(smt::theory_mi_arith) == typeid(*arith_theory)) { \ + smt::theory_mi_arith& th = dynamic_cast(*arith_theory); \ + _code_; \ + } \ + else if (typeid(smt::theory_i_arith) == typeid(*arith_theory)) { \ + smt::theory_i_arith& th = dynamic_cast(*arith_theory); \ + _code_; \ + } + + + lbool opt_solver::check_sat_core(unsigned num_assumptions, expr * const * assumptions) { + TRACE("opt_solver_na2as", tout << "smt_opt_solver::check_sat_core: " << num_assumptions << "\n";); + lbool r = m_context.check(num_assumptions, assumptions); + if (r == l_true &&& m_objective_enabled) { + ACCESS_ARITHMETIC_CLASS(th.min(m_objective_var);); + } + return r; + } + + void opt_solver::get_unsat_core(ptr_vector & r) { + unsigned sz = m_context.get_unsat_core_size(); + for (unsigned i = 0; i < sz; i++) + r.push_back(m_context.get_unsat_core_expr(i)); + } + + void opt_solver::get_model(model_ref & m) { + m_context.get_model(m); + } + + proof * opt_solver::get_proof() { + return m_context.get_proof(); + } + + std::string opt_solver::reason_unknown() const { + return m_context.last_failure_as_string(); + } + + void opt_solver::get_labels(svector & r) { + buffer tmp; + m_context.get_relevant_labels(0, tmp); + r.append(tmp.size(), tmp.c_ptr()); + } + + void opt_solver::set_cancel(bool f) { + m_context.set_cancel(f); + } + + void opt_solver::set_progress_callback(progress_callback * callback) { + m_callback = callback; + m_context.set_progress_callback(callback); + } + + unsigned opt_solver::get_num_assertions() const { + return m_context.size(); + } + + expr * opt_solver::get_assertion(unsigned idx) const { + SASSERT(idx < get_num_assertions()); + return m_context.get_formulas()[idx]; + } + + void opt_solver::display(std::ostream & out) const { + m_context.display(out); + } + + void opt_solver::set_objective(app* term) { + ACCESS_ARITHMETIC_CLASS(m_objective_var = th.set_objective(term);); + } + + void opt_solver::toggle_objective(bool enable) { + m_objective_enabled = enable; + } +} diff --git a/src/opt/opt_solver.h b/src/opt/opt_solver.h new file mode 100644 index 000000000..a6a48d3f9 --- /dev/null +++ b/src/opt/opt_solver.h @@ -0,0 +1,66 @@ +/*++ +Copyright (c) 2012 Microsoft Corporation + +Module Name: + + smt_solver.h + +Abstract: + + Wraps smt::kernel as a solver for the external API and cmd_context. + +Author: + + Leonardo (leonardo) 2012-10-21 + +Notes: + + Variant of smt_solver that exposes kernel object. +--*/ +#ifndef _OPT_SOLVER_H_ +#define _OPT_SOLVER_H_ + +#include"ast.h" +#include"params.h" +#include"solver_na2as.h" +#include"smt_kernel.h" +#include"smt_params.h" +#include"smt_types.h" + +namespace opt { + + class opt_solver : public solver_na2as { + smt_params m_params; + smt::kernel m_context; + progress_callback * m_callback; + symbol m_logic; + bool m_objective_enabled; + smt::theory_var m_objective_var; + public: + opt_solver(ast_manager & m, params_ref const & p, symbol const & l); + virtual ~opt_solver(); + + virtual void updt_params(params_ref const & p); + virtual void collect_param_descrs(param_descrs & r); + virtual void collect_statistics(statistics & st) const; + virtual void assert_expr(expr * t); + virtual void push_core(); + virtual void pop_core(unsigned n); + virtual lbool check_sat_core(unsigned num_assumptions, expr * const * assumptions); + virtual void get_unsat_core(ptr_vector & r); + virtual void get_model(model_ref & m); + virtual proof * get_proof(); + virtual std::string reason_unknown() const; + virtual void get_labels(svector & r); + virtual void set_cancel(bool f); + virtual void set_progress_callback(progress_callback * callback); + virtual unsigned get_num_assertions() const; + virtual expr * get_assertion(unsigned idx) const; + virtual void display(std::ostream & out) const; + + void set_objective(app* term); + void toggle_objective(bool enable); + }; +} + +#endif diff --git a/src/opt/optimize_objectives.cpp b/src/opt/optimize_objectives.cpp index 6f7f5c85e..0d865d7f9 100644 --- a/src/opt/optimize_objectives.cpp +++ b/src/opt/optimize_objectives.cpp @@ -17,15 +17,18 @@ Notes: --*/ +#ifndef _OPT_OBJECTIVE_H_ +#define _OPT_OBJECTIVE_H_ #include "optimize_objectives.h" +#include "opt_solver.h" namespace opt { /* Enumerate locally optimal assignments until fixedpoint. */ - lbool mathsat_style_opt(solver& s, + lbool mathsat_style_opt(opt_solver& s, expr_ref_vector& objectives, svector const& is_max, vector >& values) { lbool is_sat; @@ -52,9 +55,11 @@ namespace opt { Returns an optimal assignment to objective functions. */ - lbool optimize_objectives(solver& s, + lbool optimize_objectives(opt_solver& s, expr_ref_vector& objectives, svector const& is_max, vector >& values) { return mathsat_style_opt(s, objectives, is_max, values); } } + +#endif diff --git a/src/opt/optimize_objectives.h b/src/opt/optimize_objectives.h index a6c8e3602..f14cb37db 100644 --- a/src/opt/optimize_objectives.h +++ b/src/opt/optimize_objectives.h @@ -19,7 +19,7 @@ Notes: #ifndef _OPT_OBJECTIVES_H_ #define _OPT_OBJECTIVES_H_ -#include "solver.h" +#include "opt_solver.h" namespace opt { /** @@ -27,9 +27,9 @@ namespace opt { Returns an optimal assignment to objective functions. */ - lbool optimize_objectives(solver& s, + lbool optimize_objectives(opt_solver& s, expr_ref_vector& objectives, svector const& is_max, vector >& values); }; -#endif \ No newline at end of file +#endif diff --git a/src/smt/theory_arith.h b/src/smt/theory_arith.h index e7037f31a..d9c0efb72 100644 --- a/src/smt/theory_arith.h +++ b/src/smt/theory_arith.h @@ -985,6 +985,15 @@ namespace smt { // ----------------------------------- virtual bool get_value(enode * n, expr_ref & r); + // ----------------------------------- + // + // Optimization + // + // ----------------------------------- + + void min(theory_var v); + theory_var set_objective(app* term); + // ----------------------------------- // diff --git a/src/smt/theory_arith_aux.h b/src/smt/theory_arith_aux.h index 9f77934e5..e668997fc 100644 --- a/src/smt/theory_arith_aux.h +++ b/src/smt/theory_arith_aux.h @@ -948,6 +948,24 @@ namespace smt { return x_i; } + /** + \brief minimize the given variable. + TODO: max_min returns a bool. What does this do? + */ + template + void theory_arith::min(theory_var v) { + max_min(v, false); + } + + // set_objective(expr* term) internalizes the arithmetic term and creates + // a row for it if it is not already internalized. Then return the variable + // corresponding to the term. + // TODO handle case where internalize fails. e.g., check for this in a suitable way. + template + theory_var theory_arith::set_objective(app* term) { + return internalize_term_core(term); + } + /** \brief Maximize (Minimize) the given temporary row. Return true if succeeded.