diff --git a/src/api/api_ast.cpp b/src/api/api_ast.cpp index cbe365c6c..bf5a79bdf 100644 --- a/src/api/api_ast.cpp +++ b/src/api/api_ast.cpp @@ -33,6 +33,7 @@ Revision History: #include "ast/rewriter/th_rewriter.h" #include "ast/rewriter/var_subst.h" #include "ast/rewriter/expr_safe_replace.h" +#include "ast/rewriter/recfun_replace.h" #include "ast/pp.h" #include "util/scoped_ctrl_c.h" #include "util/cancel_eh.h" @@ -156,7 +157,8 @@ extern "C" { SET_ERROR_CODE(Z3_INVALID_ARG, nullptr); return; } - p.set_definition(pd, n, _vars.c_ptr(), abs_body); + recfun_replace replace(m); + p.set_definition(replace, pd, n, _vars.c_ptr(), abs_body); Z3_CATCH; } diff --git a/src/ast/recfun_decl_plugin.cpp b/src/ast/recfun_decl_plugin.cpp index 7fe7ae00b..1ac44067f 100644 --- a/src/ast/recfun_decl_plugin.cpp +++ b/src/ast/recfun_decl_plugin.cpp @@ -21,7 +21,6 @@ Revision History: #include #include #include "ast/expr_functors.h" -#include "ast/expr_substitution.h" #include "ast/recfun_decl_plugin.h" #include "ast/ast_pp.h" #include "util/scoped_ptr_vector.h" @@ -162,7 +161,7 @@ namespace recfun { static void convert_path(ast_manager & m, choice_lst const * choices, expr_ref_vector & conditions /* out */, - expr_substitution & subst /* out */) + replace & subst /* out */) { for (; choices != nullptr; choices = choices->next) { app * ite = choices->ite; @@ -177,15 +176,6 @@ namespace recfun { } } - // substitute `subst` in `e` - static expr_ref replace_subst(th_rewriter & th_rw, ast_manager & m, - expr_substitution & subst, expr * e) { - th_rw.reset(); - th_rw.set_substitution(&subst); - expr_ref res(m); - th_rw(e, res); - return res; - } void def::add_case(std::string & name, unsigned case_index, expr_ref_vector const& conditions, expr * rhs, bool is_imm) { case_def c(m, m_fid, this, name, case_index, get_domain(), conditions, rhs); @@ -198,7 +188,8 @@ namespace recfun { // Compute a set of cases, given the RHS - void def::compute_cases(is_immediate_pred & is_i, th_rewriter & th_rw, + void def::compute_cases(replace& subst, + is_immediate_pred & is_i, unsigned n_vars, var *const * vars, expr* rhs) { VERIFY(m_cases.empty() && "cases cannot already be computed"); @@ -291,13 +282,13 @@ namespace recfun { // leaf of the search tree conditions.reset(); - expr_substitution subst(m); + subst.reset(); convert_path(m, b.path, conditions, subst); // substitute, to get rid of `ite` terms - expr_ref case_rhs = replace_subst(th_rw, m, subst, rhs); + expr_ref case_rhs = subst(rhs); for (unsigned i = 0; i < conditions.size(); ++i) { - conditions[i] = replace_subst(th_rw, m, subst, conditions.get(i)); + conditions[i] = subst(conditions.get(i)); } // yield new case @@ -314,7 +305,7 @@ namespace recfun { */ util::util(ast_manager & m) - : m_manager(m), m_fid(m.get_family_id("recfun")), m_th_rw(m), + : m_manager(m), m_fid(m.get_family_id("recfun")), m_plugin(dynamic_cast(m.get_plugin(m_fid))) { } @@ -325,8 +316,8 @@ namespace recfun { return alloc(def, m(), m_fid, name, n, domain, range); } - void util::set_definition(promise_def & d, unsigned n_vars, var * const * vars, expr * rhs) { - d.set_definition(n_vars, vars, rhs); + void util::set_definition(replace& subst, promise_def & d, unsigned n_vars, var * const * vars, expr * rhs) { + d.set_definition(subst, n_vars, vars, rhs); } app_ref util::mk_depth_limit_pred(unsigned d) { @@ -361,11 +352,11 @@ namespace recfun { }; // set definition - void promise_def::set_definition(unsigned n_vars, var * const * vars, expr * rhs) { + void promise_def::set_definition(replace& r, unsigned n_vars, var * const * vars, expr * rhs) { SASSERT(n_vars == d->get_arity()); is_imm_pred is_i(*u); - d->compute_cases(is_i, u->get_th_rewriter(), n_vars, vars, rhs); + d->compute_cases(r, is_i, n_vars, vars, rhs); } namespace decl { @@ -398,8 +389,8 @@ namespace recfun { return promise_def(&u(), d); } - void plugin::set_definition(promise_def & d, unsigned n_vars, var * const * vars, expr * rhs) { - u().set_definition(d, n_vars, vars, rhs); + void plugin::set_definition(replace& r, promise_def & d, unsigned n_vars, var * const * vars, expr * rhs) { + u().set_definition(r, d, n_vars, vars, rhs); for (case_def & c : d.get_def()->get_cases()) { m_case_defs.insert(c.get_decl(), &c); } @@ -409,11 +400,12 @@ namespace recfun { return !m_case_defs.empty(); } - def* plugin::mk_def(symbol const& name, unsigned n, sort ** params, sort * range, + def* plugin::mk_def(replace& subst, + symbol const& name, unsigned n, sort ** params, sort * range, unsigned n_vars, var ** vars, expr * rhs) { promise_def d = mk_def(name, n, params, range); SASSERT(! m_defs.contains(d.get_def()->get_decl())); - set_definition(d, n_vars, vars, rhs); + set_definition(subst, d, n_vars, vars, rhs); return d.get_def(); } diff --git a/src/ast/recfun_decl_plugin.h b/src/ast/recfun_decl_plugin.h index 347689a37..0247335e8 100644 --- a/src/ast/recfun_decl_plugin.h +++ b/src/ast/recfun_decl_plugin.h @@ -20,7 +20,6 @@ Revision History: #pragma once #include "ast/ast.h" -#include "ast/rewriter/th_rewriter.h" #include "util/obj_hashtable.h" namespace recfun { @@ -46,6 +45,13 @@ namespace recfun { typedef var_ref_vector vars; + class replace { + public: + virtual void reset() = 0; + virtual void insert(expr* d, expr* r) = 0; + virtual expr_ref operator()(expr* e) = 0; + }; + class case_def { friend class def; func_decl_ref m_pred; //