diff --git a/src/api/api_ast.cpp b/src/api/api_ast.cpp index 9f9039378..7cd3b2dd9 100644 --- a/src/api/api_ast.cpp +++ b/src/api/api_ast.cpp @@ -874,6 +874,91 @@ extern "C" { Z3_CATCH_RETURN(nullptr); } + Z3_ast Z3_API Z3_substitute_funs(Z3_context c, + Z3_ast _a, + unsigned num_funs, + Z3_func_decl const _from[], + Z3_ast const _to[]) { + Z3_TRY; + LOG_Z3_substitute_funs(c, _a, num_funs, _from, _to); + RESET_ERROR_CODE(); + ast_manager & m = mk_c(c)->m(); + expr * a = to_expr(_a); + func_decl * const * from = to_func_decls(_from); + expr * const * to = to_exprs(num_funs, _to); + + expr * r = nullptr, *v, *w; + expr_ref_vector trail(m), args(m); + ptr_vector todo; + obj_map rep; + obj_map cache; + + for (unsigned i = 0; i < num_funs; i++) { + if (from[i]->get_range() != to[i]->get_sort()) { + SET_ERROR_CODE(Z3_SORT_ERROR, nullptr); + RETURN_Z3(of_expr(nullptr)); + } + rep.insert(from[i], to[i]); + } + + var_subst subst(m, false); + todo.push_back(a); + while (!todo.empty()) { + r = todo.back(); + if (cache.contains(r)) + todo.pop_back(); + else if (is_app(r)) { + args.reset(); + unsigned sz = todo.size(); + bool change = false; + for (expr* arg : *to_app(r)) { + if (cache.find(arg, v)) { + args.push_back(v); + change |= v != arg; + } + else { + todo.push_back(arg); + } + } + if (todo.size() == sz) { + if (rep.find(to_app(r)->get_decl(), w)) { + expr_ref new_v = subst(w, args); + v = new_v; + trail.push_back(v); + } + else if (change) { + v = m.mk_app(to_app(r)->get_decl(), args); + trail.push_back(v); + } + else + v = r; + cache.insert(r, v); + todo.pop_back(); + } + } + else if (is_var(r)) { + cache.insert(r, r); + todo.pop_back(); + } + else if (is_quantifier(r)) { + if (cache.find(to_quantifier(r)->get_expr(), v)) { + v = m.update_quantifier(to_quantifier(r), v); + trail.push_back(v); + cache.insert(r, v); + todo.pop_back(); + } + else + todo.push_back(to_quantifier(r)->get_expr()); + } + else + UNREACHABLE(); + } + r = cache[a]; + mk_c(c)->save_ast_trail(r); + RETURN_Z3(of_expr(r)); + Z3_CATCH_RETURN(nullptr); + } + Z3_ast Z3_API Z3_substitute_vars(Z3_context c, Z3_ast _a, unsigned num_exprs, diff --git a/src/api/python/z3/z3.py b/src/api/python/z3/z3.py index 5ad8ede50..99be0055e 100644 --- a/src/api/python/z3/z3.py +++ b/src/api/python/z3/z3.py @@ -8803,6 +8803,27 @@ def substitute_vars(t, *m): _to[i] = m[i].as_ast() return _to_expr_ref(Z3_substitute_vars(t.ctx.ref(), t.as_ast(), num, _to), t.ctx) +def substitute_funs(t, *m): + """Apply subistitution m on t, m is a list of pairs of a function and expression (from, to) + Every occurrence in to of the function from is replaced with the expression to. + The expression to can have free variables, that refer to the arguments of from. + For examples, see + """ + if isinstance(m, tuple): + m1 = _get_args(m) + if isinstance(m1, list) and all(isinstance(p, tuple) for p in m1): + m = m1 + if z3_debug(): + _z3_assert(is_expr(t), "Z3 expression expected") + _z3_assert(all([isinstance(p, tuple) and is_func_decl(p[0]) and is_expr(p[1]) for p in m]), "Z3 invalid substitution, funcion pairs expected.") + num = len(m) + _from = (FuncDecl * num)() + _to = (Ast * num)() + for i in range(num): + _from[i] = m[i][0].as_func_decl() + _to[i] = m[i][1].as_ast() + return _to_expr_ref(Z3_substitute_funs(t.ctx.ref(), t.as_ast(), num, _from, _to), t.ctx) + def Sum(*args): """Create the sum of the Z3 expressions. diff --git a/src/api/z3_api.h b/src/api/z3_api.h index b8bca8fab..740e304ad 100644 --- a/src/api/z3_api.h +++ b/src/api/z3_api.h @@ -5292,6 +5292,20 @@ extern "C" { unsigned num_exprs, Z3_ast const to[]); + /** + \brief Substitute funcions in \c from with new expressions in \c to. + + The expressions in \c to can have free variables. The free variable in \c to at index 0 + refers to the first argument of \c from, the free variable at index 1 corresponds to the second argument. + + def_API('Z3_substitute_funs', AST, (_in(CONTEXT), _in(AST), _in(UINT), _in_array(2, FUNC_DECL), _in_array(2, AST))) + */ + Z3_ast Z3_API Z3_substitute_funs(Z3_context c, + Z3_ast a, + unsigned num_funs, + Z3_func_decl const from[], + Z3_ast const to[]); + /** \brief Translate/Copy the AST \c a from context \c source to context \c target. AST \c a must have been created using context \c source.