diff --git a/examples/c++/example.cpp b/examples/c++/example.cpp index 3089f5e2b..ab9c73209 100644 --- a/examples/c++/example.cpp +++ b/examples/c++/example.cpp @@ -1190,6 +1190,20 @@ void mk_model_example() { std::cout << m.eval(a + b < 2)<< std::endl; } +void recfun_example() { + std::cout << "recfun example\n"; + context c; + expr x = c.int_const("x"); + expr y = c.int_const("y"); + expr b = c.bool_const("b"); + sort I = c.int_sort(); + sort B = c.bool_sort(); + func_decl f = recfun("f", I, B, I); + expr_vector args(c); + args.push_back(x); args.push_back(b); + c.recdef(f, args, ite(b, x, f(x + 1, !b))); + prove(f(x,c.bool_val(false)) > x); +} int main() { @@ -1239,6 +1253,7 @@ int main() { consequence_example(); std::cout << "\n"; parse_example(); std::cout << "\n"; mk_model_example(); std::cout << "\n"; + recfun_example(); std::cout << "\n"; std::cout << "done\n"; } catch (exception & ex) { diff --git a/src/api/api_ast.cpp b/src/api/api_ast.cpp index a28315cda..cbe365c6c 100644 --- a/src/api/api_ast.cpp +++ b/src/api/api_ast.cpp @@ -38,6 +38,8 @@ Revision History: #include "util/cancel_eh.h" #include "util/scoped_timer.h" #include "ast/pp_params.hpp" +#include "ast/expr_abstract.h" + extern bool is_numeral_sort(Z3_context c, Z3_sort ty); @@ -110,6 +112,54 @@ extern "C" { Z3_CATCH_RETURN(nullptr); } + Z3_func_decl Z3_API Z3_mk_rec_func_decl(Z3_context c, Z3_symbol s, unsigned domain_size, Z3_sort const* domain, + Z3_sort range) { + Z3_TRY; + LOG_Z3_mk_rec_func_decl(c, s, domain_size, domain, range); + RESET_ERROR_CODE(); + // + recfun::promise_def def = + mk_c(c)->recfun().get_plugin().mk_def(to_symbol(s), + domain_size, + to_sorts(domain), + to_sort(range)); + func_decl* d = def.get_def()->get_decl(); + mk_c(c)->save_ast_trail(d); + RETURN_Z3(of_func_decl(d)); + Z3_CATCH_RETURN(nullptr); + } + + void Z3_API Z3_add_rec_def(Z3_context c, Z3_func_decl f, unsigned n, Z3_ast args[], Z3_ast body) { + Z3_TRY; + LOG_Z3_add_rec_def(c, f, n, args, body); + func_decl* d = to_func_decl(f); + ast_manager& m = mk_c(c)->m(); + recfun::decl::plugin& p = mk_c(c)->recfun().get_plugin(); + expr_ref abs_body(m); + expr_ref_vector _args(m); + var_ref_vector _vars(m); + for (unsigned i = 0; i < n; ++i) { + _args.push_back(to_expr(args[i])); + _vars.push_back(m.mk_var(n - i - 1, m.get_sort(_args.back()))); + if (m.get_sort(_args.back()) != d->get_domain(i)) { + SET_ERROR_CODE(Z3_INVALID_ARG, nullptr); + return; + } + } + expr_abstract(m, 0, n, _args.c_ptr(), to_expr(body), abs_body); + recfun::promise_def pd = p.get_promise_def(d); + if (!pd.get_def()) { + SET_ERROR_CODE(Z3_INVALID_ARG, nullptr); + return; + } + if (m.get_sort(abs_body) != d->get_range()) { + SET_ERROR_CODE(Z3_INVALID_ARG, nullptr); + return; + } + p.set_definition(pd, n, _vars.c_ptr(), abs_body); + Z3_CATCH; + } + Z3_ast Z3_API Z3_mk_app(Z3_context c, Z3_func_decl d, unsigned num_args, Z3_ast const * args) { Z3_TRY; LOG_Z3_mk_app(c, d, num_args, args); diff --git a/src/api/api_context.cpp b/src/api/api_context.cpp index c236ba3e8..9fe13a15f 100644 --- a/src/api/api_context.cpp +++ b/src/api/api_context.cpp @@ -79,6 +79,7 @@ namespace api { m_datalog_util(m()), m_fpa_util(m()), m_sutil(m()), + m_recfun(m()), m_last_result(m()), m_ast_trail(m()), m_pmanager(m_limit) { diff --git a/src/api/api_context.h b/src/api/api_context.h index a6f55d1aa..b04768710 100644 --- a/src/api/api_context.h +++ b/src/api/api_context.h @@ -29,6 +29,7 @@ Revision History: #include "ast/datatype_decl_plugin.h" #include "ast/dl_decl_plugin.h" #include "ast/fpa_decl_plugin.h" +#include "ast/recfun_decl_plugin.h" #include "smt/smt_kernel.h" #include "smt/params/smt_params.h" #include "util/event_handler.h" @@ -62,6 +63,7 @@ namespace api { datalog::dl_decl_util m_datalog_util; fpa_util m_fpa_util; seq_util m_sutil; + recfun_util m_recfun; // Support for old solver API smt_params m_fparams; @@ -128,6 +130,7 @@ namespace api { fpa_util & fpautil() { return m_fpa_util; } datatype_util& dtutil() { return m_dt_plugin->u(); } seq_util& sutil() { return m_sutil; } + recfun_util& recfun() { return m_recfun; } family_id get_basic_fid() const { return m_basic_fid; } family_id get_array_fid() const { return m_array_fid; } family_id get_arith_fid() const { return m_arith_fid; } diff --git a/src/api/c++/z3++.h b/src/api/c++/z3++.h index 07056746d..d360b6153 100644 --- a/src/api/c++/z3++.h +++ b/src/api/c++/z3++.h @@ -311,6 +311,13 @@ namespace z3 { func_decl function(char const * name, sort const & d1, sort const & d2, sort const & d3, sort const & d4, sort const & range); func_decl function(char const * name, sort const & d1, sort const & d2, sort const & d3, sort const & d4, sort const & d5, sort const & range); + func_decl recfun(symbol const & name, unsigned arity, sort const * domain, sort const & range); + func_decl recfun(char const * name, unsigned arity, sort const * domain, sort const & range); + func_decl recfun(char const * name, sort const & domain, sort const & range); + func_decl recfun(char const * name, sort const & d1, sort const & d2, sort const & range); + + void recdef(func_decl, expr_vector const& args, expr const& body); + expr constant(symbol const & name, sort const & s); expr constant(char const * name, sort const & s); expr bool_const(char const * name); @@ -2815,6 +2822,37 @@ namespace z3 { return func_decl(*this, f); } + inline func_decl context::recfun(symbol const & name, unsigned arity, sort const * domain, sort const & range) { + array args(arity); + for (unsigned i = 0; i < arity; i++) { + check_context(domain[i], range); + args[i] = domain[i]; + } + Z3_func_decl f = Z3_mk_rec_func_decl(m_ctx, name, arity, args.ptr(), range); + check_error(); + return func_decl(*this, f); + + } + + inline func_decl context::recfun(char const * name, unsigned arity, sort const * domain, sort const & range) { + return recfun(str_symbol(name), arity, domain, range); + } + + inline func_decl context::recfun(char const * name, sort const& d1, sort const & range) { + return recfun(str_symbol(name), 1, &d1, range); + } + + inline func_decl context::recfun(char const * name, sort const& d1, sort const& d2, sort const & range) { + sort dom[2] = { d1, d2 }; + return recfun(str_symbol(name), 2, dom, range); + } + + void context::recdef(func_decl f, expr_vector const& args, expr const& body) { + check_context(f, args); check_context(f, body); + array vars(args); + Z3_add_rec_def(f.ctx(), f, vars.size(), vars.ptr(), body); + } + inline expr context::constant(symbol const & name, sort const & s) { Z3_ast r = Z3_mk_const(m_ctx, name, s); check_error(); @@ -2976,6 +3014,19 @@ namespace z3 { return range.ctx().function(name.c_str(), domain, range); } + inline func_decl recfun(symbol const & name, unsigned arity, sort const * domain, sort const & range) { + return range.ctx().recfun(name, arity, domain, range); + } + inline func_decl recfun(char const * name, unsigned arity, sort const * domain, sort const & range) { + return range.ctx().recfun(name, arity, domain, range); + } + inline func_decl recfun(char const * name, sort const& d1, sort const & range) { + return range.ctx().recfun(name, d1, range); + } + inline func_decl recfun(char const * name, sort const& d1, sort const& d2, sort const & range) { + return range.ctx().recfun(name, d1, d2, range); + } + inline expr select(expr const & a, expr const & i) { check_context(a, i); Z3_ast r = Z3_mk_select(a.ctx(), a, i); diff --git a/src/api/z3_api.h b/src/api/z3_api.h index de446d9a8..a776bbe8e 100644 --- a/src/api/z3_api.h +++ b/src/api/z3_api.h @@ -2081,6 +2081,7 @@ extern "C" { Z3_sort range); + /** \brief Create a constant or function application. @@ -2140,6 +2141,48 @@ extern "C" { def_API('Z3_mk_fresh_const', AST, (_in(CONTEXT), _in(STRING), _in(SORT))) */ Z3_ast Z3_API Z3_mk_fresh_const(Z3_context c, Z3_string prefix, Z3_sort ty); + + + /** + \brief Declare a recursive function + + \param c logical context. + \param s name of the function. + \param domain_size number of arguments. It should be greater than 0. + \param domain array containing the sort of each argument. The array must contain domain_size elements. + \param range sort of the constant or the return sort of the function. + + After declaring recursive function, it should be associated with a recursive definition #Z3_mk_rec_def. + The function #Z3_mk_app can be used to create a constant or function + application. + + \sa Z3_mk_app + \sa Z3_mk_rec_def + + def_API('Z3_mk_rec_func_decl', FUNC_DECL, (_in(CONTEXT), _in(SYMBOL), _in(UINT), _in_array(2, SORT), _in(SORT))) + */ + Z3_func_decl Z3_API Z3_mk_rec_func_decl(Z3_context c, Z3_symbol s, + unsigned domain_size, Z3_sort const domain[], + Z3_sort range); + + /** + \brief Define the body of a recursive function. + + \param c logical context. + \param f function declaration. + \param n number of arguments to the function + \param args constants that are used as arguments to the recursive function in the definition. + \param body body of the recursive function + + After declaring a recursive function or a collection of mutually recursive functions, use + this function to provide the definition for the recursive function. + + \sa Z3_mk_rec_func_decl + + def_API('Z3_add_rec_def', VOID, (_in(CONTEXT), _in(FUNC_DECL), _in(UINT), _in_array(2, AST), _in(AST))) + */ + void Z3_API Z3_add_rec_def(Z3_context c, Z3_func_decl f, unsigned n, Z3_ast args[], Z3_ast body); + /*@}*/ /** @name Propositional Logic and Equality */ diff --git a/src/ast/recfun_decl_plugin.cpp b/src/ast/recfun_decl_plugin.cpp index 655462a1d..25f545ca4 100644 --- a/src/ast/recfun_decl_plugin.cpp +++ b/src/ast/recfun_decl_plugin.cpp @@ -312,8 +312,8 @@ namespace recfun { * Main manager for defined functions */ - util::util(ast_manager & m, family_id id) - : m_manager(m), m_fid(id), m_th_rw(m), + util::util(ast_manager & m) + : m_manager(m), m_fid(m.get_family_id("recfun")), m_th_rw(m), m_plugin(dynamic_cast(m.get_plugin(m_fid))) { } @@ -385,15 +385,15 @@ namespace recfun { SASSERT(m_manager); SASSERT(m_family_id != null_family_id); if (!m_util.get()) { - m_util = alloc(util, *m_manager, m_family_id); + m_util = alloc(util, *m_manager); } return *(m_util.get()); } promise_def plugin::mk_def(symbol const& name, unsigned n, sort *const * params, sort * range) { - SASSERT(! m_defs.contains(name)); def* d = u().decl_fun(name, n, params, range); - m_defs.insert(name, d); + SASSERT(! m_defs.contains(d->get_decl())); + m_defs.insert(d->get_decl(), d); return promise_def(&u(), d); } @@ -410,8 +410,8 @@ namespace recfun { def* plugin::mk_def(symbol const& name, unsigned n, sort ** params, sort * range, unsigned n_vars, var ** vars, expr * rhs) { - SASSERT(! m_defs.contains(name)); promise_def d = mk_def(name, n, params, range); + SASSERT(! m_defs.contains(d.get_def()->get_decl())); set_definition(d, n_vars, vars, rhs); return d.get_def(); } diff --git a/src/ast/recfun_decl_plugin.h b/src/ast/recfun_decl_plugin.h index d93ad50b8..b516ae2bd 100644 --- a/src/ast/recfun_decl_plugin.h +++ b/src/ast/recfun_decl_plugin.h @@ -140,7 +140,7 @@ namespace recfun { namespace decl { class plugin : public decl_plugin { - typedef map def_map; + typedef obj_map def_map; typedef obj_map case_def_map; mutable scoped_ptr m_util; @@ -173,14 +173,14 @@ namespace recfun { def* mk_def(symbol const& name, unsigned n, sort ** params, sort * range, unsigned n_vars, var ** vars, expr * rhs); - bool has_def(const symbol& s) const { return m_defs.contains(s); } + bool has_def(func_decl* f) const { return m_defs.contains(f); } bool has_defs() const; - def const& get_def(const symbol& s) const { return *(m_defs[s]); } - promise_def get_promise_def(const symbol &s) const { return promise_def(&u(), m_defs[s]); } - def& get_def(symbol const& s) { return *(m_defs[s]); } + def const& get_def(func_decl* f) const { return *(m_defs[f]); } + promise_def get_promise_def(func_decl* f) const { return promise_def(&u(), m_defs[f]); } + def& get_def(func_decl* f) { return *(m_defs[f]); } bool has_case_def(func_decl* f) const { return m_case_defs.contains(f); } case_def& get_case_def(func_decl* f) { SASSERT(has_case_def(f)); return *(m_case_defs[f]); } - bool is_declared(symbol const& s) const { return m_defs.contains(s); } + //bool is_declared(symbol const& s) const { return m_defs.contains(s); } }; } @@ -197,7 +197,7 @@ namespace recfun { void set_definition(promise_def & d, unsigned n_vars, var * const * vars, expr * rhs); public: - util(ast_manager &m, family_id); + util(ast_manager &m); ~util(); ast_manager & m() { return m_manager; } @@ -213,9 +213,10 @@ namespace recfun { //has_def(s)); - return m_plugin->get_def(s); + + def& get_def(func_decl* f) { + SASSERT(m_plugin->has_def(f)); + return m_plugin->get_def(f); } case_def& get_case_def(expr* e) { @@ -232,6 +233,8 @@ namespace recfun { } app_ref mk_depth_limit_pred(unsigned d); + + decl::plugin& get_plugin() { return *m_plugin; } }; } diff --git a/src/cmd_context/cmd_context.cpp b/src/cmd_context/cmd_context.cpp index 3031f8ed8..a4b949280 100644 --- a/src/cmd_context/cmd_context.cpp +++ b/src/cmd_context/cmd_context.cpp @@ -475,7 +475,6 @@ cmd_context::cmd_context(bool main_ctx, ast_manager * m, symbol const & l): m_manager(m), m_own_manager(m == nullptr), m_manager_initialized(false), - m_rec_fun_declared(false), m_pmanager(nullptr), m_sexpr_manager(nullptr), m_regular("stdout", std::cout), @@ -900,20 +899,15 @@ void cmd_context::model_del(func_decl* f) { } -recfun_decl_plugin * cmd_context::get_recfun_plugin() { - ast_manager & m = get_ast_manager(); - family_id id = m.get_family_id("recfun"); - recfun_decl_plugin* p = reinterpret_cast(m.get_plugin(id)); - SASSERT(p); - return p; +recfun_decl_plugin& cmd_context::get_recfun_plugin() { + recfun::util u(get_ast_manager()); + return u.get_plugin(); } recfun::promise_def cmd_context::decl_rec_fun(const symbol &name, unsigned int arity, sort *const *domain, sort *range) { SASSERT(logic_has_recfun()); - recfun_decl_plugin* p = get_recfun_plugin(); - recfun::promise_def def = p->mk_def(name, arity, domain, range); - return def; + return get_recfun_plugin().mk_def(name, arity, domain, range); } // insert a recursive function as a regular quantified axiom @@ -936,15 +930,6 @@ void cmd_context::insert_rec_fun_as_axiom(func_decl *f, expr_ref_vector const& b eq = m().mk_forall(ids.size(), f->get_domain(), ids.c_ptr(), eq, 0, m().rec_fun_qid(), symbol::null, 2, pats); } - // - // disable warning given the current way they are used - // (Z3 will here silently assume and not check the definitions to be well founded, - // and please use HSF for everything else). - // - if (false && !ids.empty() && !m_rec_fun_declared) { - warning_msg("recursive function definitions are assumed well-founded"); - m_rec_fun_declared = true; - } assert_expr(eq); } @@ -959,7 +944,7 @@ void cmd_context::insert_rec_fun(func_decl* f, expr_ref_vector const& binding, s TRACE("recfun", tout<< "define recfun " << f->get_name() << " = " << mk_pp(rhs, m()) << "\n";); - recfun_decl_plugin* p = get_recfun_plugin(); + recfun_decl_plugin& p = get_recfun_plugin(); var_ref_vector vars(m()); for (expr* b : binding) { @@ -967,8 +952,8 @@ void cmd_context::insert_rec_fun(func_decl* f, expr_ref_vector const& binding, s vars.push_back(to_var(b)); } - recfun::promise_def d = p->get_promise_def(f->get_name()); - p->set_definition(d, vars.size(), vars.c_ptr(), rhs); + recfun::promise_def d = p.get_promise_def(f); + p.set_definition(d, vars.size(), vars.c_ptr(), rhs); } func_decl * cmd_context::find_func_decl(symbol const & s) const { diff --git a/src/cmd_context/cmd_context.h b/src/cmd_context/cmd_context.h index acc406dfc..5e21a1bca 100644 --- a/src/cmd_context/cmd_context.h +++ b/src/cmd_context/cmd_context.h @@ -200,7 +200,6 @@ protected: ast_manager * m_manager; bool m_own_manager; bool m_manager_initialized; - bool m_rec_fun_declared; pdecl_manager * m_pmanager; sexpr_manager * m_sexpr_manager; check_logic m_check_logic; @@ -308,7 +307,7 @@ protected: void erase_macro(symbol const& s); bool macros_find(symbol const& s, unsigned n, expr*const* args, expr*& t) const; - recfun_decl_plugin * get_recfun_plugin(); + recfun_decl_plugin& get_recfun_plugin(); public: cmd_context(bool main_ctx = true, ast_manager * m = nullptr, symbol const & l = symbol::null); diff --git a/src/smt/theory_recfun.h b/src/smt/theory_recfun.h index b0d705ddc..56e738f21 100644 --- a/src/smt/theory_recfun.h +++ b/src/smt/theory_recfun.h @@ -42,8 +42,7 @@ namespace smt { m_lhs(n), m_def(nullptr), m_args() { SASSERT(u.is_defined(n)); func_decl * d = n->get_decl(); - const symbol& name = d->get_name(); - m_def = &u.get_def(name); + m_def = &u.get_def(d); m_args.append(n->get_num_args(), n->get_args()); } case_expansion(case_expansion const & from)