diff --git a/src/opt/opt_cmds.cpp b/src/opt/opt_cmds.cpp index 35652918b..803d38295 100644 --- a/src/opt/opt_cmds.cpp +++ b/src/opt/opt_cmds.cpp @@ -29,7 +29,7 @@ Notes: #include "scoped_ctrl_c.h" #include "scoped_timer.h" #include "parametric_cmd.h" - +#include "objective_ast.h" class opt_context { cmd_context& ctx; @@ -248,10 +248,161 @@ private: } }; +static expr* sexpr2expr(cmd_context & ctx, sexpr * s) { + NOT_IMPLEMENTED_YET(); + return 0; +} + +static opt::objective* sexpr2objective(cmd_context & ctx, sexpr * s) { + if (s->is_symbol()) + throw cmd_exception("invalid objective, more arguments expected ", s->get_symbol(), s->get_line(), s->get_pos()); + if (s->is_composite()) { + sexpr * head = s->get_child(0); + if (!head->is_symbol()) + throw cmd_exception("invalid objective, symbol expected", s->get_line(), s->get_pos()); + symbol const & cmd_name = head->get_symbol(); + if (cmd_name == "maximize" || cmd_name == "minimize") { + if (s->get_num_children() != 2) + throw cmd_exception("invalid objective, wrong number of arguments ", s->get_line(), s->get_pos()); + sexpr * arg = s->get_child(1); + expr_ref term(sexpr2expr(ctx, arg), ctx.m()); + if (cmd_name == "maximize") + return opt::objective::mk_max(term); + else + return opt::objective::mk_min(term); + } + else if (cmd_name == "maxsat") { + if (s->get_num_children() != 2) + throw cmd_exception("invalid objective, wrong number of arguments ", s->get_line(), s->get_pos()); + sexpr * arg = s->get_child(1); + if (!arg->is_symbol()) + throw cmd_exception("invalid objective, symbol expected", s->get_line(), s->get_pos()); + symbol const & id = arg->get_symbol(); + // TODO: check whether id is declared via assert-weighted + return opt::objective::mk_maxsat(id); + } + else if (cmd_name == "lex" || cmd_name == "box" || cmd_name == "pareto") { + if (s->get_num_children() <= 2) + throw cmd_exception("invalid objective, wrong number of arguments ", s->get_line(), s->get_pos()); + unsigned num_children = s->get_num_children(); + ptr_vector args; + for (unsigned i = 1; i < num_children; i++) + args.push_back(sexpr2objective(ctx, s->get_child(i))); + if (cmd_name == "lex") + return opt::objective::mk_lex(args.size(), args.c_ptr()); + else if (cmd_name == "box") + return opt::objective::mk_box(args.size(), args.c_ptr()); + else + return opt::objective::mk_pareto(args.size(), args.c_ptr()); + } + else { + throw cmd_exception("invalid objective, unexpected input", s->get_line(), s->get_pos()); + } + } + return 0; +} + +class execute_cmd : public parametric_cmd { +protected: + sexpr * m_objective; + opt_context& m_opt_ctx; +public: + execute_cmd(opt_context& opt_ctx): + parametric_cmd("optimize"), + m_opt_ctx(opt_ctx) + {} + + virtual void init_pdescrs(cmd_context & ctx, param_descrs & p) { + insert_timeout(p); + insert_max_memory(p); + p.insert("print_statistics", CPK_BOOL, "(default: false) print statistics."); + opt::context::collect_param_descrs(p); + } + + virtual char const * get_main_descr() const { return "check sat modulo objective function";} + virtual char const * get_usage() const { return "( )*"; } + virtual void prepare(cmd_context & ctx) { + parametric_cmd::prepare(ctx); + m_objective = 0; + } + virtual void failure_cleanup(cmd_context & ctx) { + reset(ctx); + } + + virtual cmd_arg_kind next_arg_kind(cmd_context & ctx) const { + if (m_objective == 0) return CPK_SEXPR; + return parametric_cmd::next_arg_kind(ctx); + } + + virtual void set_next_arg(cmd_context & ctx, sexpr * arg) { + m_objective = arg; + } + + virtual void execute(cmd_context & ctx) { + params_ref p = ctx.params().merge_default_params(ps()); + opt::context& opt = m_opt_ctx(); + opt.updt_params(p); + unsigned timeout = p.get_uint("timeout", UINT_MAX); + + ptr_vector::const_iterator it = ctx.begin_assertions(); + ptr_vector::const_iterator end = ctx.end_assertions(); + for (; it != end; ++it) { + opt.add_hard_constraint(*it); + } + lbool r = l_undef; + cancel_eh eh(opt); + { + scoped_ctrl_c ctrlc(eh); + scoped_timer timer(timeout, &eh); + cmd_context::scoped_watch sw(ctx); + try { + opt::objective * o = sexpr2objective(ctx, m_objective); + r = opt.optimize(*o); + dealloc(o); + } + catch (z3_error& ex) { + ctx.regular_stream() << "(error: " << ex.msg() << "\")" << std::endl; + } + catch (z3_exception& ex) { + ctx.regular_stream() << "(error: " << ex.msg() << "\")" << std::endl; + } + } + switch(r) { + case l_true: + ctx.regular_stream() << "sat\n"; + opt.display_assignment(ctx.regular_stream()); + break; + case l_false: + ctx.regular_stream() << "unsat\n"; + break; + case l_undef: + ctx.regular_stream() << "unknown\n"; + opt.display_range_assignment(ctx.regular_stream()); + break; + } + if (p.get_bool("print_statistics", false)) { + display_statistics(ctx); + } + } +private: + + void display_statistics(cmd_context& ctx) { + statistics stats; + unsigned long long max_mem = memory::get_max_used_memory(); + unsigned long long mem = memory::get_allocation_size(); + stats.update("time", ctx.get_seconds()); + stats.update("memory", static_cast(mem)/static_cast(1024*1024)); + stats.update("max memory", static_cast(max_mem)/static_cast(1024*1024)); + m_opt_ctx().collect_statistics(stats); + stats.display_smt2(ctx.regular_stream()); + } +}; + void install_opt_cmds(cmd_context & ctx) { opt_context* opt_ctx = alloc(opt_context, ctx); ctx.insert(alloc(assert_weighted_cmd, ctx, *opt_ctx)); - ctx.insert(alloc(min_maximize_cmd, ctx, *opt_ctx, true)); - ctx.insert(alloc(min_maximize_cmd, ctx, *opt_ctx, false)); - ctx.insert(alloc(optimize_cmd, *opt_ctx)); + ctx.insert(alloc(execute_cmd, *opt_ctx)); + //ctx.insert(alloc(min_maximize_cmd, ctx, *opt_ctx, true)); + //ctx.insert(alloc(min_maximize_cmd, ctx, *opt_ctx, false)); + //ctx.insert(alloc(optimize_cmd, *opt_ctx)); } diff --git a/src/opt/opt_context.cpp b/src/opt/opt_context.cpp index 17d829ab9..13680c16a 100644 --- a/src/opt/opt_context.cpp +++ b/src/opt/opt_context.cpp @@ -111,6 +111,17 @@ namespace opt { return execute_lex(obj); } + lbool context::optimize(objective & objective) { + opt_solver& s = *m_solver.get(); + solver::scoped_push _sp(s); + + for (unsigned i = 0; i < m_hard_constraints.size(); ++i) { + s.assert_expr(m_hard_constraints[i].get()); + } + + return execute(objective, false); + } + lbool context::optimize() { // Construct objectives ptr_vector objectives; @@ -133,14 +144,7 @@ namespace opt { objective = objective::mk_box(objectives.size(), objectives.c_ptr()); } - opt_solver& s = *m_solver.get(); - solver::scoped_push _sp(s); - - for (unsigned i = 0; i < m_hard_constraints.size(); ++i) { - s.assert_expr(m_hard_constraints[i].get()); - } - - lbool result = execute(*objective, false); + lbool result = optimize(*objective); dealloc(objective); return result; } diff --git a/src/opt/opt_context.h b/src/opt/opt_context.h index 14ef2df0b..68cf61d80 100644 --- a/src/opt/opt_context.h +++ b/src/opt/opt_context.h @@ -59,7 +59,9 @@ namespace opt { lbool execute_box(compound_objective & obj); lbool execute_pareto(compound_objective & obj); + lbool optimize(objective & objective); lbool optimize(); + void set_cancel(bool f); void reset_cancel() { set_cancel(false); } void cancel() { set_cancel(true); }