diff --git a/doc/design_recfuns.md b/doc/design_recfuns.md index 9dbe5c874..89980931e 100644 --- a/doc/design_recfuns.md +++ b/doc/design_recfuns.md @@ -83,3 +83,11 @@ This addition to Z3 would bring many benefits compared to current alternatives ( recursive functions - makes `define-funs-rec` a first-class citizen of the language, usable to model user-defined theories or to analyze functional programs directly + +## Optimizations + +- maybe `C_f_i` literals should never be decided on + (they can always be propagated). + Even stronger: they should not be part of conflicts? + (i.e. tune conflict resolution to always resolve + these literals away, disregarding their level) diff --git a/src/ast/CMakeLists.txt b/src/ast/CMakeLists.txt index 4dcdd2a35..f452c07d5 100644 --- a/src/ast/CMakeLists.txt +++ b/src/ast/CMakeLists.txt @@ -36,6 +36,7 @@ z3_add_component(ast occurs.cpp pb_decl_plugin.cpp pp.cpp + recfun_decl_plugin.cpp reg_decl_plugins.cpp seq_decl_plugin.cpp shared_occs.cpp diff --git a/src/ast/ast_pp.h b/src/ast/ast_pp.h index 997b1a6e0..514aba32f 100644 --- a/src/ast/ast_pp.h +++ b/src/ast/ast_pp.h @@ -32,5 +32,27 @@ struct mk_pp : public mk_ismt2_pp { } }; +// +#include +#include "ast/expr_functors.h" +#include "ast/expr_substitution.h" +#include "ast/recfun_decl_plugin.h" +#include "ast/ast_pp.h" +#include "util/scoped_ptr_vector.h" + +#define DEBUG(x) do { auto& out = std::cout; out << "recfun: "; x; out << '\n' << std::flush; } while(0) +#define VALIDATE_PARAM(m, _pred_) if (!(_pred_)) m.raise_exception("invalid parameter to recfun " #_pred_); + + +namespace recfun { + case_pred::case_pred(ast_manager & m, family_id fid, std::string const & s, sort_ref_vector const & domain) + : m_name(), m_name_buf(s), m_domain(domain), m_decl(m) + { + m_name = symbol(m_name_buf.c_str()); + func_decl_info info(fid, OP_FUN_CASE_PRED); + m_decl = m.mk_func_decl(m_name, domain.size(), domain.c_ptr(), m.mk_bool_sort(), info); + } + + case_def::case_def(ast_manager &m, + family_id fid, + def * d, + std::string & name, + sort_ref_vector const & arg_sorts, + unsigned num_guards, expr ** guards, expr* rhs) + : m_pred(m, fid, name, arg_sorts), m_guards(m), m_rhs(expr_ref(rhs,m)), m_def(d) { + for (unsigned i=0; i m_branches; + + public: + case_state(ast_manager & m) : m_reg(), m_manager(m), m_branches() {} + + bool empty() const { return m_branches.empty(); } + ast_manager & m() const { return m_manager; } + region & reg() { return m_reg; } + + branch pop_branch() { + branch res = m_branches.back(); + m_branches.pop_back(); + return res; + } + + void push_branch(branch const & b) { m_branches.push_back(b); } + + + unfold_lst const * cons_unfold(expr * e, unfold_lst const * next) { + return new (reg()) unfold_lst{e, next}; + } + unfold_lst const * cons_unfold(expr * e1, expr * e2, unfold_lst const * next) { + return cons_unfold(e1, cons_unfold(e2, next)); + } + unfold_lst const * mk_unfold_lst(expr * e) { + return cons_unfold(e, nullptr); + } + + ite_lst const * cons_ite(app * ite, ite_lst const * next) { + return new (reg()) ite_lst{ite, next}; + } + + choice_lst const * cons_choice(app * ite, bool sign, choice_lst const * next) { + return new (reg()) choice_lst{ite, sign, next}; + } + }; + + //next) { + app * ite = choices->ite; + SASSERT(m.is_ite(ite)); + + // condition to add to the guard + expr * cond0 = ite->get_arg(0); + conditions.push_back(choices->sign ? cond0 : m.mk_not(cond0)); + + // binding to add to the substitution + subst.insert(ite, choices->sign ? ite->get_arg(1) : ite->get_arg(2)); + } + } + + // substitute `subst` in `e` + expr_ref replace_subst(th_rewriter & th_rw, ast_manager & m, + expr_substitution & subst, expr * e) { + th_rw.reset(); + th_rw.set_substitution(&subst); + expr_ref res(m); + th_rw(e, res); + return res; + } + + void def::add_case(std::string & name, unsigned n_conditions, expr ** conditions, expr * rhs, bool is_imm) { + case_def c(m(), m_fid, this, name, get_domain(), n_conditions, conditions, rhs); + c.set_is_immediate(is_imm); + TRACE("recfun", tout << "add_case " << name << " " << mk_pp(rhs, m()) + << " :is_imm " << is_imm + << " :guards " << mk_pp_vec(n_conditions, (ast**)conditions, m());); + DEBUG(out << "add_case " << name << " " << mk_pp(rhs, m()) + << " :is_imm " << is_imm + << " :guards " << mk_pp_vec(n_conditions, (ast**)conditions, m())); + m_cases.push_back(c); + } + + + // Compute a set of cases, given the RHS + void def::compute_cases(is_immediate_pred & is_i, th_rewriter & th_rw, + unsigned n_vars, var *const * vars, expr* rhs0) + { + if (m_cases.size() != 0) { + TRACE("recfun", tout << "bug: cases for " << m_name << " has cases already";); + UNREACHABLE(); + } + SASSERT(n_vars = m_domain.size()); + + DEBUG(out << "compute cases " << mk_pp(rhs0, m())); + + unsigned case_idx = 0; + std::string name; + + name.append("case_"); + name.append(m_name.bare_str()); + name.append("_"); + + for (unsigned i=0; i stack; + stack.push_back(b.to_unfold->e); + + b.to_unfold = b.to_unfold->next; + + while (! stack.empty()) { + expr * e = stack.back(); + stack.pop_back(); + + if (m().is_ite(e)) { + // need to do a case split on `e`, forking the search space + b.to_split = st.cons_ite(to_app(e), b.to_split); + } else if (is_app(e)) { + // explore arguments + app * a = to_app(e); + + for (unsigned i=0; i < a->get_num_args(); ++i) + stack.push_back(a->get_arg(i)); + } + } + } + + if (b.to_split != nullptr) { + // split one `ite`, which will lead to distinct (sets of) cases + app * ite = b.to_split->ite; + SASSERT(m().is_ite(ite)); + + /* explore both positive choice and negative choice. + * each contains a longer path, with `ite` mapping to `true` (resp. `false), + * and must unfold the `then` (resp. `else`) branch. + * We must also unfold the test itself, for it could contain + * tests. + */ + + branch b_pos(st.cons_choice(ite, true, b.path), + b.to_split->next, + st.cons_unfold(ite->get_arg(0), ite->get_arg(1), b.to_unfold)); + branch b_neg(st.cons_choice(ite, false, b.path), + b.to_split->next, + st.cons_unfold(ite->get_arg(0), ite->get_arg(2), b.to_unfold)); + + st.push_branch(b_neg); + st.push_branch(b_pos); + } + else { + // leaf of the search tree + + expr_ref_vector conditions_raw(m()); + expr_substitution subst(m()); + convert_path(m(), b.path, conditions_raw, subst); + + // substitute, to get rid of `ite` terms + expr_ref case_rhs = replace_subst(th_rw, m(), subst, rhs); + expr_ref_vector conditions(m()); + for (expr * g : conditions_raw) { + expr_ref g_subst(replace_subst(th_rw, m(), subst, g), m()); + conditions.push_back(g_subst); + } + + + unsigned old_name_len = name.size(); + { // TODO: optimize? this does many copies + std::ostringstream sout; + sout << ((unsigned long) case_idx); + name.append(sout.str()); + } + case_idx ++; + + // yield new case + bool is_imm = is_i(case_rhs); + add_case(name, conditions.size(), conditions.c_ptr(), case_rhs, is_imm); + name.resize(old_name_len); + } + } + + TRACE("recfun", tout << "done analysing " << get_name();); + } + + /* + * Main manager for defined functions + */ + + util::util(ast_manager & m, family_id id) + : m_manager(m), m_family_id(id), m_th_rw(m), m_plugin(0) { + m_plugin = dynamic_cast(m.get_plugin(m_family_id)); + } + + def * util::decl_fun(symbol const& name, unsigned n, sort *const * domain, sort * range) { + return alloc(def, m(), m_family_id, name, n, domain, range); + } + + void util::set_definition(promise_def & d, unsigned n_vars, var * const * vars, expr * rhs) { + d.set_definition(n_vars, vars, rhs); + } + + + // used to know which `app` are from this theory + struct is_imm_pred : is_immediate_pred { + util & u; + is_imm_pred(util & u) : u(u) {} + bool operator()(expr * rhs) { + // find an `app` that is an application of a defined function + struct find : public i_expr_pred { + util & u; + find(util & u) : u(u) {} + bool operator()(expr * e) override { + //return is_app(e) ? u.owns_app(to_app(e)) : false; + if (! is_app(e)) return false; + + app * a = to_app(e); + return u.is_defined(a); + } + }; + find f(u); + check_pred cp(f, u.m()); + bool contains_defined_fun = cp(rhs); + return ! contains_defined_fun; + } + }; + + // set definition + void promise_def::set_definition(unsigned n_vars, var * const * vars, expr * rhs) { + SASSERT(n_vars == d->get_arity()); + + is_imm_pred is_i(*u); + d->compute_cases(is_i, u->get_th_rewriter(), n_vars, vars, rhs); + } + + namespace decl { + plugin::plugin() : decl_plugin(), m_defs(), m_case_defs(), m_def_block() {} + plugin::~plugin() { finalize(); } + + void plugin::finalize() { + for (auto& kv : m_defs) { + dealloc(kv.m_value); + } + m_defs.reset(); + // m_case_defs does not own its data, no need to deallocate + m_case_defs.reset(); + m_util = 0; // force deletion + } + + util & plugin::u() const { + SASSERT(m_manager); + SASSERT(m_family_id != null_family_id); + if (m_util.get() == 0) { + m_util = alloc(util, *m_manager, m_family_id); + } + return *(m_util.get()); + } + + promise_def plugin::mk_def(symbol const& name, unsigned n, sort *const * params, sort * range) { + SASSERT(! m_defs.contains(name)); + def* d = u().decl_fun(name, n, params, range); + m_defs.insert(name, d); + return promise_def(&u(), d); + } + + void plugin::set_definition(promise_def & d, unsigned n_vars, var * const * vars, expr * rhs) { + u().set_definition(d, n_vars, vars, rhs); + for (case_def & c : d.get_def()->get_cases()) { + m_case_defs.insert(c.get_name(), &c); + } + } + + def* plugin::mk_def(symbol const& name, unsigned n, sort ** params, sort * range, + unsigned n_vars, var ** vars, expr * rhs) { + SASSERT(! m_defs.contains(name)); + promise_def d = mk_def(name, n, params, range); + set_definition(d, n_vars, vars, rhs); + return d.get_def(); + } + + func_decl * plugin::mk_fun_pred_decl(unsigned num_parameters, parameter const * parameters, + unsigned arity, sort * const * domain, sort * range) + { + VALIDATE_PARAM(m(), m().is_bool(range) && num_parameters == 1 && parameters[0].is_ast()); + func_decl_info info(m_family_id, OP_FUN_CASE_PRED, num_parameters, parameters); + info.m_private_parameters = true; + return m().mk_func_decl(symbol(parameters[0].get_symbol()), arity, domain, range, info); + } + + func_decl * plugin::mk_fun_defined_decl(decl_kind k, unsigned num_parameters, + parameter const * parameters, + unsigned arity, sort * const * domain, sort * range) + { + VALIDATE_PARAM(m(), num_parameters == 1 && parameters[0].is_ast()); + func_decl_info info(m_family_id, k, num_parameters, parameters); + info.m_private_parameters = true; + return m().mk_func_decl(symbol(parameters[0].get_symbol()), arity, + domain, range, info); + } + + // generic declaration of symbols + func_decl * plugin::mk_func_decl(decl_kind k, unsigned num_parameters, parameter const * parameters, + unsigned arity, sort * const * domain, sort * range) + { + switch(k) { + case OP_FUN_CASE_PRED: + return mk_fun_pred_decl(num_parameters, parameters, arity, domain, range); + case OP_FUN_DEFINED: + return mk_fun_defined_decl(k, num_parameters, parameters, arity, domain, range); + default: + UNREACHABLE(); return 0; + } + } + } +} \ No newline at end of file diff --git a/src/ast/recfun_decl_plugin.h b/src/ast/recfun_decl_plugin.h new file mode 100644 index 000000000..e9b75a409 --- /dev/null +++ b/src/ast/recfun_decl_plugin.h @@ -0,0 +1,246 @@ +/*++ +Module Name: + + recfun_decl_plugin.h + +Abstract: + + Declaration and definition of (potentially recursive) functions + +Author: + + Simon Cruanes 2017-11 + +Revision History: + +--*/ + +#pragma once + +#include "ast/ast.h" +#include "ast/rewriter/th_rewriter.h" + +namespace recfun { + class case_def; // cases; + + ast_manager & m_manager; + symbol m_name; // def_map; + typedef map case_def_map; + + mutable scoped_ptr m_util; + def_map m_defs; // function->def + case_def_map m_case_defs; // case_pred->def + svector m_def_block; + + ast_manager & m() { return *m_manager; } + public: + plugin(); + virtual ~plugin() override; + virtual void finalize() override; + + util & u() const; // build or return util + + virtual bool is_fully_interp(sort * s) const override { return false; } // might depend on unin sorts + + virtual decl_plugin * mk_fresh() override { return alloc(plugin); } + + virtual sort * mk_sort(decl_kind k, unsigned num_parameters, parameter const * parameters) override { UNREACHABLE(); return 0; } + + virtual func_decl * mk_func_decl(decl_kind k, unsigned num_parameters, parameter const * parameters, + unsigned arity, sort * const * domain, sort * range) override; + + promise_def mk_def(symbol const& name, unsigned n, sort *const * params, sort * range); + + void set_definition(promise_def & d, unsigned n_vars, var * const * vars, expr * rhs); + + def* mk_def(symbol const& name, unsigned n, sort ** params, sort * range, unsigned n_vars, var ** vars, expr * rhs); + + bool has_def(const symbol& s) const { return m_defs.contains(s); } + def const& get_def(const symbol& s) const { return *(m_defs[s]); } + promise_def get_promise_def(const symbol &s) const { return promise_def(&u(), m_defs[s]); } + def& get_def(symbol const& s) { return *(m_defs[s]); } + bool has_case_def(const symbol& s) const { return m_case_defs.contains(s); } + case_def& get_case_def(symbol const& s) { SASSERT(has_case_def(s)); return *(m_case_defs[s]); } + bool is_declared(symbol const& s) const { return m_defs.contains(s); } + private: + func_decl * mk_fun_pred_decl(unsigned num_parameters, parameter const * parameters, + unsigned arity, sort * const * domain, sort * range); + func_decl * mk_fun_defined_decl(decl_kind k, + unsigned num_parameters, parameter const * parameters, + unsigned arity, sort * const * domain, sort * range); + }; + } + + // Varus utils for recursive functions + class util { + friend class decl::plugin; + + ast_manager & m_manager; + family_id m_family_id; + th_rewriter m_th_rw; + decl::plugin * m_plugin; + + bool compute_is_immediate(expr * rhs); + void set_definition(promise_def & d, unsigned n_vars, var * const * vars, expr * rhs); + public: + util(ast_manager &m, family_id); + + ast_manager & m() { return m_manager; } + th_rewriter & get_th_rewriter() { return m_th_rw; } + bool is_case_pred(app * e) const { return is_app_of(e, m_family_id, OP_FUN_CASE_PRED); } + bool is_defined(app * e) const { return is_app_of(e, m_family_id, OP_FUN_DEFINED); } + bool owns_app(app * e) const { return e->get_family_id() == m_family_id; } + + //has_def(s)); + return m_plugin->get_def(s); + } + + case_def& get_case_def(symbol const & s) { + SASSERT(m_plugin->has_case_def(s)); + return m_plugin->get_case_def(s); + } + + app* mk_fun_defined(def const & d, unsigned n_args, expr * const * args) { + return m().mk_app(d.get_decl(), n_args, args); + } + app* mk_fun_defined(def const & d, ptr_vector const & args) { + return mk_fun_defined(d, args.size(), args.c_ptr()); + } + app* mk_case_pred(case_pred const & p, ptr_vector const & args) { + return m().mk_app(p.get_decl(), args.size(), args.c_ptr()); + } + }; +} + +typedef recfun::def recfun_def; +typedef recfun::case_def recfun_case_def; +typedef recfun::decl::plugin recfun_decl_plugin; +typedef recfun::util recfun_util; diff --git a/src/ast/reg_decl_plugins.cpp b/src/ast/reg_decl_plugins.cpp index 985ecee9e..f8abe81d8 100644 --- a/src/ast/reg_decl_plugins.cpp +++ b/src/ast/reg_decl_plugins.cpp @@ -22,6 +22,7 @@ Revision History: #include "ast/array_decl_plugin.h" #include "ast/bv_decl_plugin.h" #include "ast/datatype_decl_plugin.h" +#include "ast/recfun_decl_plugin.h" #include "ast/dl_decl_plugin.h" #include "ast/seq_decl_plugin.h" #include "ast/pb_decl_plugin.h" @@ -40,6 +41,9 @@ void reg_decl_plugins(ast_manager & m) { if (!m.get_plugin(m.mk_family_id(symbol("datatype")))) { m.register_plugin(symbol("datatype"), alloc(datatype_decl_plugin)); } + if (!m.get_plugin(m.mk_family_id(symbol("recfun")))) { + m.register_plugin(symbol("recfun"), alloc(recfun_decl_plugin)); + } if (!m.get_plugin(m.mk_family_id(symbol("datalog_relation")))) { m.register_plugin(symbol("datalog_relation"), alloc(datalog::dl_decl_plugin)); } diff --git a/src/cmd_context/cmd_context.cpp b/src/cmd_context/cmd_context.cpp index ec2eea032..75ed4e7eb 100644 --- a/src/cmd_context/cmd_context.cpp +++ b/src/cmd_context/cmd_context.cpp @@ -678,6 +678,8 @@ bool cmd_context::logic_has_datatype() const { return !has_logic() || smt_logics::logic_has_datatype(m_logic); } +bool cmd_context::logic_has_recfun() const { return true; } + void cmd_context::init_manager_core(bool new_manager) { SASSERT(m_manager != 0); SASSERT(m_pmanager != 0); @@ -690,6 +692,7 @@ void cmd_context::init_manager_core(bool new_manager) { register_plugin(symbol("bv"), alloc(bv_decl_plugin), logic_has_bv()); register_plugin(symbol("array"), alloc(array_decl_plugin), logic_has_array()); register_plugin(symbol("datatype"), alloc(datatype_decl_plugin), logic_has_datatype()); + register_plugin(symbol("recfun"), alloc(recfun_decl_plugin), logic_has_recfun()); register_plugin(symbol("seq"), alloc(seq_decl_plugin), logic_has_seq()); register_plugin(symbol("pb"), alloc(pb_decl_plugin), logic_has_pb()); register_plugin(symbol("fpa"), alloc(fpa_decl_plugin), logic_has_fpa()); @@ -705,6 +708,7 @@ void cmd_context::init_manager_core(bool new_manager) { load_plugin(symbol("bv"), logic_has_bv(), fids); load_plugin(symbol("array"), logic_has_array(), fids); load_plugin(symbol("datatype"), logic_has_datatype(), fids); + load_plugin(symbol("recfun"), logic_has_recfun(), fids); load_plugin(symbol("seq"), logic_has_seq(), fids); load_plugin(symbol("fpa"), logic_has_fpa(), fids); load_plugin(symbol("pb"), logic_has_pb(), fids); @@ -868,7 +872,24 @@ void cmd_context::insert(symbol const & s, object_ref * r) { m_object_refs.insert(s, r); } -void cmd_context::insert_rec_fun(func_decl* f, expr_ref_vector const& binding, svector const& ids, expr* e) { +recfun_decl_plugin * cmd_context::get_recfun_plugin() { + ast_manager & m = get_ast_manager(); + family_id id = m.get_family_id("recfun"); + recfun_decl_plugin* p = reinterpret_cast(m.get_plugin(id)); + SASSERT(p); + return p; +} + + +recfun::promise_def cmd_context::decl_rec_fun(const symbol &name, unsigned int arity, sort *const *domain, sort *range) { + SASSERT(logic_has_recfun()); + recfun_decl_plugin* p = get_recfun_plugin(); + recfun::promise_def def = p->mk_def(name, arity, domain, range); + return def; +} + +// insert a recursive function as a regular quantified axiom +void cmd_context::insert_rec_fun_as_axiom(func_decl *f, expr_ref_vector const& binding, svector const &ids, expr* e) { expr_ref eq(m()); app_ref lhs(m()); lhs = m().mk_app(f, binding.size(), binding.c_ptr()); @@ -899,6 +920,31 @@ void cmd_context::insert_rec_fun(func_decl* f, expr_ref_vector const& binding, s assert_expr(eq); } +// TODO: make this a parameter +#define USE_NATIVE_RECFUN 1 + +void cmd_context::insert_rec_fun(func_decl* f, expr_ref_vector const& binding, svector const& ids, expr* rhs) { + TRACE("recfun", tout<<"define recfun " << f->get_name() + <<" = " << mk_pp(rhs, m()) << "\n";); + + if (! USE_NATIVE_RECFUN) { + // just use an axiom + insert_rec_fun_as_axiom(f, binding, ids, rhs); + return; + } + + recfun_decl_plugin* p = get_recfun_plugin(); + + var_ref_vector vars(m()); + for (expr* b : binding) { + SASSERT(is_var(b)); + vars.push_back(to_var(b)); + } + + recfun::promise_def d = p->get_promise_def(f->get_name()); + p->set_definition(d, vars.size(), vars.c_ptr(), rhs); +} + func_decl * cmd_context::find_func_decl(symbol const & s) const { builtin_decl d; if (m_builtin_decls.find(s, d)) { diff --git a/src/cmd_context/cmd_context.h b/src/cmd_context/cmd_context.h index 99d6c8284..b41eab42d 100644 --- a/src/cmd_context/cmd_context.h +++ b/src/cmd_context/cmd_context.h @@ -29,6 +29,7 @@ Notes: #include "util/dictionary.h" #include "solver/solver.h" #include "ast/datatype_decl_plugin.h" +#include "ast/recfun_decl_plugin.h" #include "util/stopwatch.h" #include "util/cmd_context_types.h" #include "util/event_handler.h" @@ -289,6 +290,7 @@ protected: bool logic_has_array() const; bool logic_has_datatype() const; bool logic_has_fpa() const; + bool logic_has_recfun() const; void print_unsupported_msg() { regular_stream() << "unsupported" << std::endl; } void print_unsupported_info(symbol const& s, int line, int pos) { if (s != symbol::null) diagnostic_stream() << "; " << s << " line: " << line << " position: " << pos << std::endl;} @@ -304,6 +306,7 @@ protected: void erase_macro(symbol const& s); bool macros_find(symbol const& s, unsigned n, expr*const* args, expr*& t) const; + recfun_decl_plugin * get_recfun_plugin(); public: cmd_context(bool main_ctx = true, ast_manager * m = 0, symbol const & l = symbol::null); @@ -382,9 +385,11 @@ public: void insert_user_tactic(symbol const & s, sexpr * d); void insert_aux_pdecl(pdecl * p); void insert_rec_fun(func_decl* f, expr_ref_vector const& binding, svector const& ids, expr* e); + void insert_rec_fun_as_axiom(func_decl* f, expr_ref_vector const& binding, svector const& ids, expr* e); func_decl * find_func_decl(symbol const & s) const; func_decl * find_func_decl(symbol const & s, unsigned num_indices, unsigned const * indices, unsigned arity, sort * const * domain, sort * range) const; + recfun::promise_def decl_rec_fun(const symbol &name, unsigned int arity, sort *const *domain, sort *range); psort_decl * find_psort_decl(symbol const & s) const; cmd * find_cmd(symbol const & s) const; sexpr * find_user_tactic(symbol const & s) const; diff --git a/src/parsers/smt2/smt2parser.cpp b/src/parsers/smt2/smt2parser.cpp index a06438c73..58cc80b08 100644 --- a/src/parsers/smt2/smt2parser.cpp +++ b/src/parsers/smt2/smt2parser.cpp @@ -2278,7 +2278,7 @@ namespace smt2 { next(); } - void parse_rec_fun_decl(func_decl_ref& f, expr_ref_vector& bindings, svector& ids) { + recfun::promise_def parse_rec_fun_decl(func_decl_ref& f, expr_ref_vector& bindings, svector& ids) { SASSERT(m_num_bindings == 0); check_identifier("invalid function/constant definition, symbol expected"); symbol id = curr_id(); @@ -2289,7 +2289,8 @@ namespace smt2 { unsigned num_vars = parse_sorted_vars(); SASSERT(num_vars == m_num_bindings); parse_sort("Invalid recursive function definition"); - f = m().mk_func_decl(id, num_vars, sort_stack().c_ptr() + sort_spos, sort_stack().back()); + recfun::promise_def pdef = m_ctx.decl_rec_fun(id, num_vars, sort_stack().c_ptr() + sort_spos, sort_stack().back()); + f = pdef.get_def()->get_decl(); bindings.append(num_vars, expr_stack().c_ptr() + expr_spos); ids.append(num_vars, symbol_stack().c_ptr() + sym_spos); symbol_stack().shrink(sym_spos); @@ -2297,6 +2298,7 @@ namespace smt2 { expr_stack().shrink(expr_spos); m_env.end_scope(); m_num_bindings = 0; + return pdef; } void parse_rec_fun_bodies(func_decl_ref_vector const& decls, vector const& bindings, vector >const & ids) { diff --git a/src/smt/CMakeLists.txt b/src/smt/CMakeLists.txt index e102bd28b..0d98472b7 100644 --- a/src/smt/CMakeLists.txt +++ b/src/smt/CMakeLists.txt @@ -58,6 +58,7 @@ z3_add_component(smt theory_lra.cpp theory_opt.cpp theory_pb.cpp + theory_recfun.cpp theory_seq.cpp theory_str.cpp theory_utvpi.cpp diff --git a/src/smt/params/smt_params.cpp b/src/smt/params/smt_params.cpp index a8eb81a2e..ec91ec608 100644 --- a/src/smt/params/smt_params.cpp +++ b/src/smt/params/smt_params.cpp @@ -27,6 +27,7 @@ void smt_params::updt_local_params(params_ref const & _p) { m_random_seed = p.random_seed(); m_relevancy_lvl = p.relevancy(); m_ematching = p.ematching(); + m_recfun_max_depth = p.recfun_max_depth(); m_phase_selection = static_cast(p.phase_selection()); m_restart_strategy = static_cast(p.restart_strategy()); m_restart_factor = p.restart_factor(); diff --git a/src/smt/params/smt_params.h b/src/smt/params/smt_params.h index b01499c04..e46b89401 100644 --- a/src/smt/params/smt_params.h +++ b/src/smt/params/smt_params.h @@ -105,6 +105,9 @@ struct smt_params : public preprocessor_params, bool m_new_core2th_eq; bool m_ematching; + // TODO: move into its own file? + unsigned m_recfun_max_depth; + // ----------------------------------- // // Case split strategy @@ -258,6 +261,7 @@ struct smt_params : public preprocessor_params, m_display_features(false), m_new_core2th_eq(true), m_ematching(true), + m_recfun_max_depth(500), m_case_split_strategy(CS_ACTIVITY_DELAY_NEW), m_rel_case_split_order(0), m_lookahead_diseq(false), diff --git a/src/smt/params/smt_params_helper.pyg b/src/smt/params/smt_params_helper.pyg index 937aa6a2b..96923ec5e 100644 --- a/src/smt/params/smt_params_helper.pyg +++ b/src/smt/params/smt_params_helper.pyg @@ -83,5 +83,6 @@ def_module_params(module_name='smt', ('core.extend_patterns', BOOL, False, 'extend unsat core with literals that trigger (potential) quantifier instances'), ('core.extend_patterns.max_distance', UINT, UINT_MAX, 'limits the distance of a pattern-extended unsat core'), ('core.extend_nonlocal_patterns', BOOL, False, 'extend unsat cores with literals that have quantifiers with patterns that contain symbols which are not in the quantifier\'s body'), - ('lemma_gc_strategy', UINT, 0, 'lemma garbage collection strategy: 0 - fixed, 1 - geometric, 2 - at restart, 3 - none') + ('lemma_gc_strategy', UINT, 0, 'lemma garbage collection strategy: 0 - fixed, 1 - geometric, 2 - at restart, 3 - none'), + ('recfun.max_depth', UINT, 500, 'maximum depth of unrolling for recursive functions') )) diff --git a/src/smt/smt_setup.cpp b/src/smt/smt_setup.cpp index bbc91e4c6..e8d55e0d4 100644 --- a/src/smt/smt_setup.cpp +++ b/src/smt/smt_setup.cpp @@ -28,6 +28,7 @@ Revision History: #include "smt/theory_array_full.h" #include "smt/theory_bv.h" #include "smt/theory_datatype.h" +#include "smt/theory_recfun.h" #include "smt/theory_dummy.h" #include "smt/theory_dl.h" #include "smt/theory_seq_empty.h" @@ -217,6 +218,7 @@ namespace smt { void setup::setup_QF_DT() { setup_QF_UF(); setup_datatypes(); + setup_recfuns(); } void setup::setup_QF_BVRE() { @@ -845,6 +847,13 @@ namespace smt { m_context.register_plugin(alloc(theory_datatype, m_manager, m_params)); } + void setup::setup_recfuns() { + TRACE("recfun", tout << "registering theory recfun...\n";); + theory_recfun * th = alloc(theory_recfun, m_manager); + m_context.register_plugin(th); + th->setup_params(); + } + void setup::setup_dl() { m_context.register_plugin(mk_theory_dl(m_manager)); } @@ -898,6 +907,7 @@ namespace smt { setup_arrays(); setup_bv(); setup_datatypes(); + setup_recfuns(); setup_dl(); setup_seq_str(st); setup_card(); diff --git a/src/smt/smt_setup.h b/src/smt/smt_setup.h index 924c2caec..ce8f44047 100644 --- a/src/smt/smt_setup.h +++ b/src/smt/smt_setup.h @@ -93,6 +93,7 @@ namespace smt { void setup_unknown(static_features & st); void setup_arrays(); void setup_datatypes(); + void setup_recfuns(); void setup_bv(); void setup_arith(); void setup_dl(); diff --git a/src/smt/theory_recfun.cpp b/src/smt/theory_recfun.cpp new file mode 100644 index 000000000..db9f7d938 --- /dev/null +++ b/src/smt/theory_recfun.cpp @@ -0,0 +1,363 @@ + +#include "util/stats.h" +#include "ast/ast_util.h" +#include "smt/theory_recfun.h" +#include "smt/params/smt_params_helper.hpp" + +#define DEBUG(x) \ + do { \ + TRACE("recfun", tout << x << '\n';); \ + auto& out = std::cout; out << "recfun: "; out << x; out << '\n' << std::flush; } while(0) + +// NOTE: debug +struct pp_lits { + smt::context & ctx; + smt::literal *lits; + unsigned len; + pp_lits(smt::context & ctx, unsigned len, smt::literal *lits) : ctx(ctx), lits(lits), len(len) {} +}; + +std::ostream & operator<<(std::ostream & out, pp_lits const & pp) { + out << "clause{"; + bool first = true; + for (auto i = 0; i < pp.len; ++i) { + if (first) { first = false; } else { out << " ∨ "; } + pp.ctx.display_detailed_literal(out, pp.lits[i]); + } + return out << "}"; +} + +namespace smt { + + theory_recfun::theory_recfun(ast_manager & m) + : theory(m.mk_family_id("recfun")), m_util(0), m_trail(*this), + m_guards(), m_max_depth(0), m_q_case_expand(), m_q_body_expand() + { + recfun_decl_plugin * plugin = + reinterpret_cast(m.get_plugin(get_family_id())); + SASSERT(plugin); + m_util = & plugin->u(); + SASSERT(m_util); + } + + theory_recfun::~theory_recfun() { + reset_queues(); + for (auto & kv : m_guards) { + m().dec_ref(kv.m_key); + } + m_guards.reset(); + } + + char const * theory_recfun::get_name() const { return "recfun"; } + + void theory_recfun::setup_params() { + // obtain max depth via parameters + smt_params_helper p(get_context().get_params()); + set_max_depth(p.recfun_max_depth()); + } + + theory* theory_recfun::mk_fresh(context* new_ctx) { + return alloc(theory_recfun, new_ctx->get_manager()); + } + + bool theory_recfun::internalize_atom(app * atom, bool gate_ctx) { + context & ctx = get_context(); + if (! ctx.e_internalized(atom)) { + unsigned num_args = atom->get_num_args(); + for (unsigned i = 0; i < num_args; ++i) + ctx.internalize(atom->get_arg(i), false); + ctx.mk_enode(atom, false, true, false); + } + if (! ctx.b_internalized(atom)) { + bool_var v = ctx.mk_bool_var(atom); + ctx.set_var_theory(v, get_id()); + } + return true; + } + + bool theory_recfun::internalize_term(app * term) { + DEBUG("internalizing term: " << mk_pp(term, m())); + context & ctx = get_context(); + for (expr* e : *term) ctx.internalize(e, false); + // the internalization of the arguments may have triggered the internalization of term. + if (ctx.e_internalized(term)) + return true; + ctx.mk_enode(term, false, false, true); + return true; // the theory doesn't actually map terms to variables + } + + void theory_recfun::reset_queues() { + m_q_case_expand.reset(); + m_q_body_expand.reset(); + } + + void theory_recfun::reset_eh() { + m_trail.reset(); + reset_queues(); + } + + /* + * when `n` becomes relevant, if it's `f(t1…tn)` with `f` defined, + * then case-expand `n`. If it's a macro we can also immediately + * body-expand it. + */ + void theory_recfun::relevant_eh(app * n) { + SASSERT(get_context().relevancy()); + if (u().is_defined(n)) { + DEBUG("relevant_eh: (defined) " << mk_pp(n, m())); + + case_expansion e(u(), n); + push_case_expand(std::move(e)); + } + } + + void theory_recfun::push_scope_eh() { + theory::push_scope_eh(); + m_trail.push_scope(); + } + + void theory_recfun::pop_scope_eh(unsigned num_scopes) { + m_trail.pop_scope(num_scopes); + theory::pop_scope_eh(num_scopes); + reset_queues(); + } + + void theory_recfun::restart_eh() { + m_trail.reset(); + reset_queues(); + } + + bool theory_recfun::can_propagate() { + return ! (m_q_case_expand.empty() && m_q_body_expand.empty()); + } + + void theory_recfun::propagate() { + for (case_expansion & e : m_q_case_expand) { + if (e.m_def->is_fun_macro()) { + // body expand immediately + assert_macro_axiom(e); + } + else { + // case expand + SASSERT(e.m_def->is_fun_defined()); + assert_case_axioms(e); + } + } + m_q_case_expand.clear(); + + for (body_expansion & e : m_q_body_expand) { + assert_body_axiom(e); + } + m_q_body_expand.clear(); + } + + void theory_recfun::max_depth_conflict() { + DEBUG("max-depth conflict"); + // TODO: build clause from `m_guards` + /* + context & ctx = get_context(); + region & r = ctx.get_region(); + ctx.set_conflict( + */ + } + + // if `is_true` and `v = C_f_i(t1…tn)`, then body-expand i-th case of `f(t1…tn)` + void theory_recfun::assign_eh(bool_var v, bool is_true) { + expr* e = get_context().bool_var2expr(v); + DEBUG("assign_eh "<< mk_pp(e,m())); + if (!is_true) return; + if (!is_app(e)) return; + app* a = to_app(e); + if (u().is_case_pred(a)) { + DEBUG("assign_case_pred_true "<< mk_pp(e,m())); + // add to set of local assumptions, for depth-limit purpose + { + m_guards.insert(e, empty()); + m().inc_ref(e); + insert_ref_map trail_elt(m(), m_guards, e); + m_trail.push(trail_elt); + } + if (m_guards.size() > get_max_depth()) { + // too many body-expansions: depth-limit conflict + max_depth_conflict(); + } + else { + // body-expand + body_expansion b_e(u(), a); + push_body_expand(std::move(b_e)); + } + } + } + + // replace `vars` by `args` in `e` + expr_ref theory_recfun::apply_args(recfun::vars const & vars, + ptr_vector const & args, + expr * e) { + // check that var order is standard + SASSERT(vars.size() == 0 || vars[vars.size()-1]->get_idx() == 0); + var_subst subst(m(), true); + expr_ref new_body(m()); + subst(e, args.size(), args.c_ptr(), new_body); + get_context().get_rewriter()(new_body); // simplify + return new_body; + } + + app_ref theory_recfun::apply_pred(recfun::case_pred const & p, + ptr_vector const & args){ + app_ref res(u().mk_case_pred(p, args), m()); + return res; + } + + void theory_recfun::assert_macro_axiom(case_expansion & e) { + DEBUG("assert_macro_axiom " << pp_case_expansion(e,m())); + SASSERT(e.m_def->is_fun_macro()); + expr * lhs = e.m_lhs; + context & ctx = get_context(); + auto & vars = e.m_def->get_vars(); + // substitute `e.args` into the macro RHS + expr * rhs = apply_args(vars, e.m_args, e.m_def->get_macro_rhs()); + DEBUG("macro expansion yields" << mk_pp(rhs,m())); + // now build the axiom `lhs = rhs` + ctx.internalize(rhs, false); + DEBUG("adding axiom: " << mk_pp(lhs, m()) << " = " << mk_pp(rhs, m())); + if (m().proofs_enabled()) { + // add unit clause `lhs=rhs` + literal l(mk_eq(lhs, rhs, true)); + ctx.mark_as_relevant(l); + literal lits[1] = {l}; + ctx.mk_th_axiom(get_id(), 1, lits); + } + else { + //region r; + enode * e_lhs = ctx.get_enode(lhs); + enode * e_rhs = ctx.get_enode(rhs); + // TODO: proper justification + //justification * js = ctx.mk_justification( + ctx.assign_eq(e_lhs, e_rhs, eq_justification()); + } + } + + void theory_recfun::assert_case_axioms(case_expansion & e) { + DEBUG("assert_case_axioms "<< pp_case_expansion(e,m()) + << " with " << e.m_def->get_cases().size() << " cases"); + SASSERT(e.m_def->is_fun_defined()); + context & ctx = get_context(); + // add case-axioms for all case-paths + auto & vars = e.m_def->get_vars(); + for (recfun::case_def const & c : e.m_def->get_cases()) { + // applied predicate to `args` + app_ref pred_applied = apply_pred(c.get_pred(), e.m_args); + SASSERT(u().owns_app(pred_applied)); + // substitute arguments in `path` + expr_ref_vector path(m()); + for (auto & g : c.get_guards()) { + expr_ref g_applied = apply_args(vars, e.m_args, g); + path.push_back(g_applied); + } + // assert `p(args) <=> And(guards)` (with CNF on the fly) + ctx.internalize(pred_applied, false); + // FIXME: we need to be informed wen `pred_applied` is true!! + ctx.mark_as_relevant(ctx.get_bool_var(pred_applied)); + literal concl = ctx.get_literal(pred_applied); + { + // assert `guards=>p(args)` + literal_vector c; + c.push_back(concl); + for (expr* g : path) { + ctx.internalize(g, false); + c.push_back(~ ctx.get_literal(g)); + } + + //TRACE("recfun", tout << "assert_case_axioms " << pp_case_expansion(e) + // << " axiom " << mk_pp(*l) <<"\n";); + DEBUG("assert_case_axiom " << pp_lits(get_context(), path.size()+1, c.c_ptr())); + get_context().mk_th_axiom(get_id(), path.size()+1, c.c_ptr()); + } + { + // assert `p(args) => guards[i]` for each `i` + for (expr * _g : path) { + SASSERT(ctx.b_internalized(_g)); + literal g = ctx.get_literal(_g); + literal c[2] = {~ concl, g}; + + DEBUG("assert_case_axiom " << pp_lits(get_context(), 2, c)); + get_context().mk_th_axiom(get_id(), 2, c); + } + } + + // also body-expand paths that do not depend on any defined fun + if (c.is_immediate()) { + body_expansion be(c, e.m_args); + assert_body_axiom(be); + } + } + } + + void theory_recfun::assert_body_axiom(body_expansion & e) { + DEBUG("assert_body_axioms "<< pp_body_expansion(e,m())); + context & ctx = get_context(); + recfun::def & d = *e.m_cdef->get_def(); + auto & vars = d.get_vars(); + auto & args = e.m_args; + // check that var order is standard + SASSERT(vars.size() == 0 || vars[vars.size()-1]->get_idx() == 0); + expr_ref lhs(u().mk_fun_defined(d, args), m()); + // substitute `e.args` into the RHS of this particular case + expr_ref rhs = apply_args(vars, args, e.m_cdef->get_rhs()); + // substitute `e.args` into the guard of this particular case, to make + // the `condition` part of the clause `conds => lhs=rhs` + expr_ref_vector guards(m()); + for (auto & g : e.m_cdef->get_guards()) { + expr_ref new_guard = apply_args(vars, args, g); + guards.push_back(new_guard); + } + // now build the axiom `conds => lhs = rhs` + ctx.internalize(rhs, false); + for (auto& g : guards) ctx.internalize(g, false); + + // add unit clause `conds => lhs=rhs` + literal_vector clause; + for (auto& g : guards) { + ctx.internalize(g, false); + literal l = ~ ctx.get_literal(g); + ctx.mark_as_relevant(l); + clause.push_back(l); + } + literal l(mk_eq(lhs, rhs, true)); + ctx.mark_as_relevant(l); + clause.push_back(l); + DEBUG("assert_body_axiom " << pp_lits(get_context(), clause.size(), clause.c_ptr())); + ctx.mk_th_axiom(get_id(), clause.size(), clause.c_ptr()); + } + + final_check_status theory_recfun::final_check_eh() { + return FC_DONE; + } + + void theory_recfun::add_theory_assumptions(expr_ref_vector & assumptions) { + DEBUG("add_theory_assumptions"); + // TODO: add depth-limit assumptions? + } + + void theory_recfun::display(std::ostream & out) const { + out << "recfun{}"; + } + + void theory_recfun::collect_statistics(::statistics & st) const { + st.update("recfun macro expansion", m_stats.m_macro_expansions); + st.update("recfun case expansion", m_stats.m_case_expansions); + st.update("recfun body expansion", m_stats.m_body_expansions); + } + + std::ostream& operator<<(std::ostream & out, theory_recfun::pp_case_expansion const & e) { + return out << "case_exp(" << mk_pp(e.e.m_lhs, e.m) << ")"; + } + + std::ostream& operator<<(std::ostream & out, theory_recfun::pp_body_expansion const & e) { + out << "body_exp(" << e.e.m_cdef->get_name(); + for (auto* t : e.e.m_args) { + out << " " << mk_pp(t,e.m); + } + return out << ")"; + } +} diff --git a/src/smt/theory_recfun.h b/src/smt/theory_recfun.h new file mode 100644 index 000000000..5645a1893 --- /dev/null +++ b/src/smt/theory_recfun.h @@ -0,0 +1,155 @@ +/*++ +Copyright (c) 2006 Microsoft Corporation + +Module Name: + + theory_recfun.h + +Abstract: + + Theory responsible for unrolling recursive functions + +Author: + + Leonardo de Moura (leonardo) 2008-10-31. + +Revision History: + +--*/ +#ifndef THEORY_RECFUN_H_ +#define THEORY_RECFUN_H_ + +#include "smt/smt_theory.h" +#include "smt/smt_context.h" +#include "ast/ast_pp.h" +#include "ast/recfun_decl_plugin.h" + +namespace smt { + + class theory_recfun : public theory { + struct stats { + unsigned m_case_expansions, m_body_expansions, m_macro_expansions; + void reset() { memset(this, 0, sizeof(stats)); } + stats() { reset(); } + }; + + // one case-expansion of `f(t1…tn)` + struct case_expansion { + expr * m_lhs; // the term to expand + recfun_def * m_def; + ptr_vector m_args; + case_expansion(recfun_util& u, app * n) : m_lhs(n), m_def(0), m_args() + { + SASSERT(u.is_defined(n)); + func_decl * d = n->get_decl(); + const symbol& name = d->get_name(); + m_def = &u.get_def(name); + m_args.append(n->get_num_args(), n->get_args()); + } + case_expansion(case_expansion const & from) + : m_lhs(from.m_lhs), + m_def(from.m_def), + m_args(from.m_args) {} + case_expansion(case_expansion && from) + : m_lhs(from.m_lhs), + m_def(from.m_def), + m_args(std::move(from.m_args)) {} + }; + + struct pp_case_expansion { + case_expansion & e; + ast_manager & m; + pp_case_expansion(case_expansion & e, ast_manager & m) : e(e), m(m) {} + }; + + friend std::ostream& operator<<(std::ostream&, pp_case_expansion const &); + + // one body-expansion of `f(t1…tn)` using a `C_f_i(t1…tn)` + struct body_expansion { + recfun_case_def const * m_cdef; + ptr_vector m_args; + + body_expansion(recfun_util& u, app * n) : m_cdef(0), m_args() { + SASSERT(u.is_case_pred(n)); + func_decl * d = n->get_decl(); + const symbol& name = d->get_name(); + m_cdef = &u.get_case_def(name); + for (unsigned i = 0; i < n->get_num_args(); ++i) + m_args.push_back(n->get_arg(i)); + } + body_expansion(recfun_case_def const & d, ptr_vector & args) : m_cdef(&d), m_args(args) {} + body_expansion(body_expansion const & from): m_cdef(from.m_cdef), m_args(from.m_args) {} + body_expansion(body_expansion && from) : m_cdef(from.m_cdef), m_args(std::move(from.m_args)) {} + }; + + struct pp_body_expansion { + body_expansion & e; + ast_manager & m; + pp_body_expansion(body_expansion & e, ast_manager & m) : e(e), m(m) {} + }; + + friend std::ostream& operator<<(std::ostream&, pp_body_expansion const &); + + struct empty{}; + + typedef trail_stack th_trail_stack; + typedef obj_map guard_set; + + recfun_util * m_util; + stats m_stats; + th_trail_stack m_trail; + guard_set m_guards; // true case-preds + unsigned m_max_depth; // for fairness and termination + + vector m_q_case_expand; + vector m_q_body_expand; + + recfun_util & u() const { SASSERT(m_util); return *m_util; } + ast_manager & m() { return get_manager(); } + bool is_defined(app * f) const { return u().is_defined(f); } + bool is_case_pred(app * f) const { return u().is_case_pred(f); } + + bool is_defined(enode * e) const { return is_defined(e->get_owner()); } + bool is_case_pred(enode * e) const { return is_case_pred(e->get_owner()); } + + void reset_queues(); + expr_ref apply_args(recfun::vars const & vars, ptr_vector const & args, expr * e); //!< substitute variables by args + app_ref apply_pred(recfun::case_pred const & p, ptr_vector const & args); //0); m_max_depth = n; } + }; +} + +#endif diff --git a/src/util/scoped_ptr_vector.h b/src/util/scoped_ptr_vector.h index 0bd0fd47e..15b9b889c 100644 --- a/src/util/scoped_ptr_vector.h +++ b/src/util/scoped_ptr_vector.h @@ -31,6 +31,7 @@ public: void reset() { std::for_each(m_vector.begin(), m_vector.end(), delete_proc()); m_vector.reset(); } void push_back(T * ptr) { m_vector.push_back(ptr); } void pop_back() { SASSERT(!empty()); set(size()-1, 0); m_vector.pop_back(); } + T * back() const { return m_vector.back(); } T * operator[](unsigned idx) const { return m_vector[idx]; } void set(unsigned idx, T * ptr) { if (m_vector[idx] == ptr) @@ -51,6 +52,13 @@ public: push_back(0); } } + //!< swap last element with given pointer + void swap_back(scoped_ptr & ptr) { + SASSERT(!empty()); + T * tmp = ptr.detach(); + ptr = m_vector.back(); + m_vector[m_vector.size()-1] = tmp; + } }; #endif diff --git a/src/util/trail.h b/src/util/trail.h index bba71fb00..0503bcfa7 100644 --- a/src/util/trail.h +++ b/src/util/trail.h @@ -143,6 +143,17 @@ public: }; +template +class insert_ref_map : public trail { + Mgr& m; + M& m_map; + D m_obj; +public: + insert_ref_map(Mgr& m, M& t, D o) : m(m), m_map(t), m_obj(o) {} + virtual ~insert_ref_map() {} + virtual void undo(Ctx & ctx) { m_map.remove(m_obj); m.dec_ref(m_obj); } +}; + template class push_back_vector : public trail {