/*++ Copyright (c) 2017 Microsoft Corporation, Simon Cruanes Module Name: recfun_decl_plugin.cpp Abstract: Declaration and definition of (potentially recursive) functions Author: Simon Cruanes 2017-11 Revision History: --*/ #include #include #include "ast/expr_functors.h" #include "ast/recfun_decl_plugin.h" #include "ast/ast_pp.h" #include "ast/for_each_expr.h" #include "util/scoped_ptr_vector.h" #define TRACEFN(x) TRACE("recfun", tout << x << '\n';) #define VALIDATE_PARAM(m, _pred_) if (!(_pred_)) m.raise_exception("invalid parameter to recfun " #_pred_); namespace recfun { case_def::case_def( ast_manager &m, family_id fid, def * d, std::string & name, unsigned case_index, sort_ref_vector const & arg_sorts, expr_ref_vector const& guards, expr* rhs) : m_pred(m), m_guards(guards), m_rhs(expr_ref(rhs,m)), m_def(d) { parameter p(case_index); func_decl_info info(fid, OP_FUN_CASE_PRED, 1, &p); m_pred = m.mk_func_decl(symbol(name.c_str()), arg_sorts.size(), arg_sorts.data(), m.mk_bool_sort(), info); } def::def(ast_manager &m, family_id fid, symbol const & s, unsigned arity, sort* const * domain, sort* range, bool is_generated) : m(m), m_name(s), m_domain(m, arity, domain), m_range(range, m), m_vars(m), m_cases(), m_decl(m), m_rhs(m), m_fid(fid) { SASSERT(arity == get_arity()); parameter p(is_generated); func_decl_info info(fid, OP_FUN_DEFINED, 1, &p); m_decl = m.mk_func_decl(s, arity, domain, range, info); } def* def::copy(util& dst, ast_translation& tr) { SASSERT(&dst.m() == &tr.to()); sort_ref_vector domain(tr.to()); sort_ref range(tr(m_range.get()), tr.to()); for (auto* s : m_domain) domain.push_back(tr(s)); family_id fid = dst.get_family_id(); bool is_generated = m_decl->get_parameter(0).get_int() != 0; def* r = alloc(def, tr.to(), fid, m_name, domain.size(), domain.data(), range, is_generated); r->m_rhs = tr(m_rhs.get()); for (auto* v : m_vars) r->m_vars.push_back(tr(v)); for (auto const& c1 : m_cases) { r->m_cases.push_back(case_def(tr.to())); auto& c2 = r->m_cases.back(); c2.m_pred = tr(c1.m_pred.get()); c2.m_guards = tr(c1.m_guards); c2.m_rhs = tr(c1.m_rhs.get()); c2.m_def = r; c2.m_immediate = c1.m_immediate; } return r; } bool def::contains_def(util& u, expr * e) { struct def_find_p : public i_expr_pred { util& u; def_find_p(util& u): u(u) {} bool operator()(expr* a) override { return is_app(a) && u.is_defined(to_app(a)->get_decl()); } }; def_find_p p(u); check_pred cp(p, m, false); return cp(e); } // does `e` contain any `ite` construct? // subject to the then/else branch using a recursive call, // but the guard does not. bool def::contains_ite(util& u, expr * e) { struct ite_find_p : public i_expr_pred { ast_manager & m; def& d; util& u; ite_find_p(ast_manager & m, def& d, util& u) : m(m), d(d), u(u) {} bool operator()(expr * e) override { return m.is_ite(e) && !d.contains_def(u, to_app(e)->get_arg(0)) && d.contains_def(u, e); } }; // ignore ites under quantifiers. // this is redundant as the code // that unfolds ites uses quantifier-free portion. ite_find_p p(m, *this, u); check_pred cp(p, m, false); return cp(e); } /* * compilation of functions to a list of cases. * * We use a backtracking algorithm in a relatively functional style, * where the multiple states (corresponding to alternatives) are stored in * a region, and deallocated at the end */ // immutable list of choices of `ite` terms (mapping each one's condition to true/false) struct choice_lst { app * ite; bool sign; choice_lst const * next; // or null for the last one choice_lst(app * ite, bool sign, choice_lst const * next) : ite(ite), sign(sign), next(next) {} }; struct ite_lst { app * ite; // invariant: `is_ite(e)` ite_lst const * next; ite_lst(app * ite, ite_lst const * next) : ite(ite), next(next) {} }; // immutable stack of expressions to unfold struct unfold_lst { expr * e; unfold_lst const * next; // or null for last one }; // main state for one branch of the search tree. struct branch { choice_lst const * path; // choices made so far ite_lst const * to_split; // `ite` terms to make a choice on unfold_lst const * to_unfold; // terms yet to unfold branch(unfold_lst const * to_unfold): path(nullptr), to_split(nullptr), to_unfold(to_unfold) {} branch(choice_lst const * path, ite_lst const * to_split, unfold_lst const * to_unfold) : path(path), to_split(to_split), to_unfold(to_unfold) {} }; // state for computing cases from the RHS of a functions' definition class case_state { region m_reg; vector m_branches; public: case_state() : m_reg(), m_branches() {} bool empty() const { return m_branches.empty(); } branch pop_branch() { branch res = m_branches.back(); m_branches.pop_back(); return res; } void push_branch(branch const & b) { m_branches.push_back(b); } unfold_lst const * cons_unfold(expr * e, unfold_lst const * next) { return new (m_reg) unfold_lst{e, next}; } unfold_lst const * cons_unfold(expr * e1, expr * e2, unfold_lst const * next) { return cons_unfold(e1, cons_unfold(e2, next)); } unfold_lst const * mk_unfold_lst(expr * e) { return cons_unfold(e, nullptr); } ite_lst const * cons_ite(app * ite, ite_lst const * next) { return new (m_reg) ite_lst{ite, next}; } choice_lst const * cons_choice(app * ite, bool sign, choice_lst const * next) { return new (m_reg) choice_lst{ite, sign, next}; } }; //next) { app * ite = choices->ite; expr* c = nullptr, *th = nullptr, *el = nullptr; VERIFY(m.is_ite(ite, c, th, el)); // condition to add to the guard conditions.push_back(choices->sign ? c : m.mk_not(c)); // binding to add to the substitution expr_ref tgt(choices->sign ? th : el, m); tgt = subst(tgt); subst.insert(ite, tgt); } } void def::add_case(std::string & name, unsigned case_index, expr_ref_vector const& conditions, expr * rhs, bool is_imm) { case_def c(m, m_fid, this, name, case_index, get_domain(), conditions, rhs); c.set_is_immediate(is_imm); TRACEFN("add_case " << name << "\n" << mk_pp(rhs, m) << "\n:is_imm " << is_imm << "\n:guards " << conditions); m_cases.push_back(c); } // Compute a set of cases, given the RHS void def::compute_cases(util& u, replace& subst, is_immediate_pred & is_i, bool is_macro, unsigned n_vars, var *const * vars, expr* rhs) { VERIFY(m_cases.empty() && "cases cannot already be computed"); SASSERT(n_vars == m_domain.size()); TRACEFN("compute cases " << mk_pp(rhs, m)); unsigned case_idx = 0; std::string name("case-"); name.append(m_name.str()); m_vars.append(n_vars, vars); m_rhs = rhs; if (!is_macro) for (expr* e : subterms::all(m_rhs)) if (is_lambda(e)) throw default_exception("recursive definitions with lambdas are not supported"); expr_ref_vector conditions(m); // is the function a macro (unconditional body)? if (is_macro || n_vars == 0 || !contains_ite(u, rhs)) { // constant function or trivial control flow, only one (dummy) case add_case(name, 0, conditions, rhs); return; } // analyze control flow of `rhs`, accumulating guards and // rebuilding a `ite`-free RHS on the fly for each path in `rhs`. // Each such `ite`-free term is converted into a case_def and added to definition. case_state st; st.push_branch(branch(st.mk_unfold_lst(rhs))); while (! st.empty()) { branch b = st.pop_branch(); // first: unfold expressions, stopping when we meet subterms that are `ite` while (b.to_unfold != nullptr) { ptr_vector stack; stack.push_back(b.to_unfold->e); b.to_unfold = b.to_unfold->next; while (! stack.empty()) { expr * e = stack.back(); stack.pop_back(); expr* cond = nullptr, *th = nullptr, *el = nullptr; if (m.is_ite(e, cond, th, el) && contains_def(u, cond)) { // skip } else if (m.is_ite(e)) { // need to do a case split on `e`, forking the search space b.to_split = st.cons_ite(to_app(e), b.to_split); } else if (is_app(e)) { // explore arguments for (expr * arg : *to_app(e)) if (contains_ite(u, arg)) stack.push_back(arg); } } } if (b.to_split != nullptr) { // split one `ite`, which will lead to distinct (sets of) cases app * ite = b.to_split->ite; TRACEFN("split: " << mk_pp(ite, m)); expr* c = nullptr, *th = nullptr, *el = nullptr; VERIFY(m.is_ite(ite, c, th, el)); /* explore both positive choice and negative choice. * each contains a longer path, with `ite` mapping to `true` (resp. `false), * and must unfold the `then` (resp. `else`) branch. * We must also unfold the test itself, for it could contain * tests. */ branch b_pos(st.cons_choice(ite, true, b.path), b.to_split->next, st.cons_unfold(c, th, b.to_unfold)); branch b_neg(st.cons_choice(ite, false, b.path), b.to_split->next, st.cons_unfold(c, el, b.to_unfold)); st.push_branch(b_neg); st.push_branch(b_pos); } else { // leaf of the search tree conditions.reset(); subst.reset(); convert_path(m, b.path, conditions, subst); // substitute, to get rid of `ite` terms expr_ref case_rhs = subst(rhs); for (unsigned i = 0; i < conditions.size(); ++i) conditions[i] = subst(conditions.get(i)); // yield new case bool is_imm = is_i(case_rhs); add_case(name, case_idx++, conditions, case_rhs, is_imm); } } TRACEFN("done analyzing " << get_name()); } /* * Main manager for defined functions */ util::util(ast_manager & m) : m_manager(m), m_fid(m.get_family_id("recfun")), m_plugin(dynamic_cast(m.get_plugin(m_fid))) { } util::~util() { } def * util::decl_fun(symbol const& name, unsigned n, sort *const * domain, sort * range, bool is_generated) { return alloc(def, m(), m_fid, name, n, domain, range, is_generated); } void util::set_definition(replace& subst, promise_def & d, bool is_macro, unsigned n_vars, var * const * vars, expr * rhs) { expr_ref rhs1(rhs, m()); if (!is_macro) rhs1 = get_plugin().redirect_ite(subst, n_vars, vars, rhs); d.set_definition(subst, is_macro, n_vars, vars, rhs1); } app_ref util::mk_num_rounds_pred(unsigned d) { parameter p(d); func_decl_info info(m_fid, OP_NUM_ROUNDS, 1, &p); func_decl* decl = m().mk_const_decl(symbol("recfun-num-rounds"), m().mk_bool_sort(), info); return app_ref(m().mk_const(decl), m()); } // used to know which `app` are from this theory struct is_imm_pred : is_immediate_pred { util & u; is_imm_pred(util & u) : u(u) {} bool operator()(expr * rhs) override { // find an `app` that is an application of a defined function struct find : public i_expr_pred { util & u; find(util & u) : u(u) {} bool operator()(expr * e) override { return is_app(e) && u.is_defined(to_app(e)); } }; find f(u); check_pred cp(f, u.m()); bool contains_defined_fun = cp(rhs); return ! contains_defined_fun; } }; // set definition void promise_def::set_definition(replace& r, bool is_macro, unsigned n_vars, var * const * vars, expr * rhs) { SASSERT(n_vars == d->get_arity()); is_imm_pred is_i(*u); d->compute_cases(*u, r, is_i, is_macro, n_vars, vars, rhs); } namespace decl { plugin::plugin() : decl_plugin(), m_defs(), m_case_defs() {} plugin::~plugin() { finalize(); } void plugin::finalize() { for (auto& kv : m_defs) { dealloc(kv.m_value); } m_defs.reset(); // m_case_defs does not own its data, no need to deallocate m_case_defs.reset(); m_util = nullptr; // force deletion } util & plugin::u() const { SASSERT(m_manager); SASSERT(m_family_id != null_family_id); if (!m_util.get()) { m_util = alloc(util, *m_manager); } return *(m_util.get()); } promise_def plugin::mk_def(symbol const& name, unsigned n, sort *const * params, sort * range, bool is_generated) { def* d = u().decl_fun(name, n, params, range, is_generated); SASSERT(!m_defs.contains(d->get_decl())); m_defs.insert(d->get_decl(), d); return promise_def(&u(), d); } void plugin::inherit(decl_plugin* other, ast_translation& tr) { for (auto [k, v] : static_cast(other)->m_defs) { func_decl_ref f(tr(k), tr.to()); if (m_defs.contains(f)) continue; def* d = v->copy(u(), tr); m_defs.insert(f, d); for (case_def & c : d->get_cases()) m_case_defs.insert(c.get_decl(), &c); } } promise_def plugin::ensure_def(symbol const& name, unsigned n, sort *const * params, sort * range, bool is_generated) { def* d = u().decl_fun(name, n, params, range, is_generated); erase_def(d->get_decl()); m_defs.insert(d->get_decl(), d); return promise_def(&u(), d); } void plugin::erase_def(func_decl* f) { def* d = nullptr; if (m_defs.find(f, d)) { for (case_def & c : d->get_cases()) m_case_defs.erase(c.get_decl()); m_defs.erase(f); dealloc(d); } } void plugin::set_definition(replace& r, promise_def & d, bool is_macro, unsigned n_vars, var * const * vars, expr * rhs) { m_has_rec_defs |= !is_macro; u().set_definition(r, d, is_macro, n_vars, vars, rhs); for (case_def & c : d.get_def()->get_cases()) m_case_defs.insert(c.get_decl(), &c); } bool plugin::has_defs() const { return !m_case_defs.empty(); } def* plugin::mk_def(replace& subst, bool is_macro, symbol const& name, unsigned n, sort ** params, sort * range, unsigned n_vars, var ** vars, expr * rhs) { promise_def d = mk_def(name, n, params, range); SASSERT(! m_defs.contains(d.get_def()->get_decl())); set_definition(subst, d, is_macro, n_vars, vars, rhs); return d.get_def(); } // generic declaration of symbols func_decl * plugin::mk_func_decl(decl_kind k, unsigned num_parameters, parameter const * parameters, unsigned arity, sort * const * domain, sort * range) { UNREACHABLE(); return nullptr; } /** * \brief compute ite nesting depth scores with each sub-expression of e. * associate with each subterm of e its parent terms. * and for every term depth the set of terms with the same depth */ void plugin::compute_scores(expr* e, obj_map& scores) { u_map> by_depth; obj_map> parents; expr_ref tmp(e, m()); parents.insert(e, ptr_vector()); for (expr* t : subterms::ground(tmp)) { if (is_app(t)) { for (expr* arg : *to_app(t)) { parents.insert_if_not_there(arg, ptr_vector()).push_back(t); } } by_depth.insert_if_not_there(get_depth(t), ptr_vector()).push_back(t); } unsigned max_depth = get_depth(e); scores.insert(e, 0); // walk deepest terms first. for (unsigned i = max_depth; i > 0; --i) { if (!by_depth.contains(i)) continue; for (expr* t : by_depth[i]) { unsigned score = 0; for (expr* parent : parents[t]) { score += scores[parent]; } if (m().is_ite(t)) { score++; TRACEFN("score " << mk_pp(t, m()) << ": " << score); } scores.insert(t, score); } } } expr_ref plugin::redirect_ite(replace& subst, unsigned n, var * const* vars, expr * e) { expr_ref result(e, m()); while (true) { obj_map scores; compute_scores(result, scores); unsigned max_score = 0; expr* max_expr = nullptr; for (auto const& kv : scores) { if (m().is_ite(kv.m_key) && kv.m_value > max_score) { max_expr = kv.m_key; max_score = kv.m_value; } } if (max_score <= 4) break; ptr_vector domain; ptr_vector args; for (unsigned i = 0; i < n; ++i) { domain.push_back(vars[i]->get_sort()); args.push_back(vars[i]); } symbol fresh_name("fold-rec-" + std::to_string(m().mk_fresh_id())); auto pd = mk_def(fresh_name, n, domain.data(), max_expr->get_sort()); func_decl* f = pd.get_def()->get_decl(); expr_ref new_body(m().mk_app(f, n, args.data()), m()); set_definition(subst, pd, false, n, vars, max_expr); subst.reset(); subst.insert(max_expr, new_body); result = subst(result); TRACEFN("substituted\n" << mk_pp(max_expr, m()) << "\n->\n" << new_body << "\n-result->\n" << result); } return result; } } case_expansion::case_expansion(recfun::util& u, app * n) : m_lhs(n, u.m()), m_def(nullptr), m_args(u.m()) { SASSERT(u.is_defined(n)); func_decl * d = n->get_decl(); m_def = &u.get_def(d); m_args.append(n->get_num_args(), n->get_args()); } case_expansion::case_expansion(case_expansion const & from) : m_lhs(from.m_lhs), m_def(from.m_def), m_args(from.m_args) {} case_expansion::case_expansion(case_expansion && from) : m_lhs(from.m_lhs), m_def(from.m_def), m_args(std::move(from.m_args)) {} std::ostream& case_expansion::display(std::ostream & out) const { return out << "case_exp(" << m_lhs << ")"; } std::ostream& body_expansion::display(std::ostream & out) const { ast_manager& m = m_pred.m(); out << "body_exp(" << m_cdef->get_decl()->get_name(); for (auto* t : m_args) out << " " << mk_pp(t, m); return out << ")"; } }