diff --git a/src/api/api_ast.cpp b/src/api/api_ast.cpp index bc29826ff..a28b3f20b 100644 --- a/src/api/api_ast.cpp +++ b/src/api/api_ast.cpp @@ -29,6 +29,7 @@ Revision History: #include "ast/ast_ll_pp.h" #include "ast/ast_smt_pp.h" #include "ast/ast_smt2_pp.h" +#include "ast/polymorphism_util.h" #include "ast/rewriter/th_rewriter.h" #include "ast/rewriter/var_subst.h" #include "ast/rewriter/expr_safe_replace.h" @@ -88,6 +89,16 @@ extern "C" { Z3_CATCH_RETURN(nullptr); } + Z3_sort Z3_API Z3_mk_type_variable(Z3_context c, Z3_symbol name) { + Z3_TRY; + LOG_Z3_mk_type_variable(c, name); + RESET_ERROR_CODE(); + sort* ty = mk_c(c)->m().mk_type_var(to_symbol(name)); + mk_c(c)->save_ast_trail(ty); + RETURN_Z3(of_sort(ty)); + Z3_CATCH_RETURN(nullptr); + } + bool Z3_API Z3_is_eq_ast(Z3_context c, Z3_ast s1, Z3_ast s2) { RESET_ERROR_CODE(); return s1 == s2; @@ -180,7 +191,20 @@ extern "C" { arg_list.push_back(to_expr(args[i])); } func_decl* _d = reinterpret_cast(d); - app* a = mk_c(c)->m().mk_app(_d, num_args, arg_list.data()); + ast_manager& m = mk_c(c)->m(); + if (_d->is_polymorphic()) { + polymorphism::util u(m); + polymorphism::substitution sub(m); + ptr_buffer domain; + for (unsigned i = 0; i < num_args; ++i) { + if (!sub.match(_d->get_domain(i), arg_list[i]->get_sort())) + SET_ERROR_CODE(Z3_INVALID_ARG, "failed to match argument of polymorphic function"); + domain.push_back(arg_list[i]->get_sort()); + } + sort_ref range = sub(_d->get_range()); + _d = m.instantiate_polymorphic(_d, num_args, domain.data(), range); + } + app* a = m.mk_app(_d, num_args, arg_list.data()); mk_c(c)->save_ast_trail(a); check_sorts(c, a); RETURN_Z3(of_ast(a)); @@ -728,6 +752,9 @@ extern "C" { else if (fid == mk_c(c)->get_char_fid() && k == CHAR_SORT) { return Z3_CHAR_SORT; } + else if (fid == poly_family_id) { + return Z3_TYPE_VAR; + } else { return Z3_UNKNOWN_SORT; } diff --git a/src/api/python/z3/z3.py b/src/api/python/z3/z3.py index ee47e7dd7..d1a856174 100644 --- a/src/api/python/z3/z3.py 
+++ b/src/api/python/z3/z3.py @@ -683,6 +683,8 @@ def _to_sort_ref(s, ctx): return SeqSortRef(s, ctx) elif k == Z3_CHAR_SORT: return CharSortRef(s, ctx) + elif k == Z3_TYPE_VAR: + return TypeVarRef(s, ctx) return SortRef(s, ctx) @@ -708,6 +710,26 @@ def DeclareSort(name, ctx=None): ctx = _get_ctx(ctx) return SortRef(Z3_mk_uninterpreted_sort(ctx.ref(), to_symbol(name, ctx)), ctx) +class TypeVarRef(SortRef): + """Type variable reference""" + + def subsort(self, other): + return True + + def cast(self, val): + return val + + +def DeclareTypeVar(name, ctx=None): + """Create a new type variable named `name`. + + If `ctx=None`, then the new sort is declared in the global Z3Py context. + + """ + ctx = _get_ctx(ctx) + return TypeVarRef(Z3_mk_type_variable(ctx.ref(), to_symbol(name, ctx)), ctx) + + ######################################### # # Function Declarations diff --git a/src/api/z3_api.h b/src/api/z3_api.h index 29eed7e26..a931bc523 100644 --- a/src/api/z3_api.h +++ b/src/api/z3_api.h @@ -151,6 +151,7 @@ typedef enum Z3_SEQ_SORT, Z3_RE_SORT, Z3_CHAR_SORT, + Z3_TYPE_VAR, Z3_UNKNOWN_SORT = 1000 } Z3_sort_kind; @@ -1883,6 +1884,17 @@ extern "C" { */ Z3_sort Z3_API Z3_mk_uninterpreted_sort(Z3_context c, Z3_symbol s); + /** + \brief Create a type variable. + + Functions using type variables can be applied to instantiations that match the signature + of the function. Assertions using type variables correspond to assertions over all possible + instantiations. + + def_API('Z3_mk_type_variable', SORT, (_in(CONTEXT), _in(SYMBOL))) + */ + Z3_sort Z3_API Z3_mk_type_variable(Z3_context c, Z3_symbol s); + /** \brief Create the Boolean type. 
diff --git a/src/ast/CMakeLists.txt b/src/ast/CMakeLists.txt index 8dd870964..9df3ff001 100644 --- a/src/ast/CMakeLists.txt +++ b/src/ast/CMakeLists.txt @@ -37,6 +37,8 @@ z3_add_component(ast num_occurs.cpp occurs.cpp pb_decl_plugin.cpp + polymorphism_inst.cpp + polymorphism_util.cpp pp.cpp quantifier_stat.cpp recfun_decl_plugin.cpp diff --git a/src/ast/ast.cpp b/src/ast/ast.cpp index aeccc7612..3bb435ed8 100644 --- a/src/ast/ast.cpp +++ b/src/ast/ast.cpp @@ -2049,8 +2049,8 @@ func_decl * ast_manager::mk_func_decl(symbol const & name, unsigned arity, sort } func_decl* new_node = new (mem) func_decl(name, arity, domain, range, info); new_node = register_node(new_node); - if (is_polymorphic_root) - m_poly_roots.insert(new_node, new_node); + if (is_polymorphic_root) + m_poly_roots.insert(new_node, new_node); return new_node; } @@ -2774,8 +2774,8 @@ bool ast_manager::has_type_var(unsigned n, sort* const* domain, sort* range) con func_decl* ast_manager::instantiate_polymorphic(func_decl* f, unsigned arity, sort * const* domain, sort * range) { SASSERT(f->is_polymorphic()); func_decl* g = mk_func_decl(f->get_name(), arity, domain, range, f->get_info()); - m_poly_roots.insert(f, g); - SASSERT(g->is_polymorphic()); + m_poly_roots.insert(g, f); + // SASSERT(g->is_polymorphic()); return g; } diff --git a/src/ast/ast_translation.cpp b/src/ast/ast_translation.cpp index 781593b38..e2369a35a 100644 --- a/src/ast/ast_translation.cpp +++ b/src/ast/ast_translation.cpp @@ -65,6 +65,13 @@ void ast_translation::collect_decl_extra_children(decl * d) { } void ast_translation::push_frame(ast * n) { + // ensure poly roots are pushed first. 
+ if (m_from_manager.has_type_vars() && n->get_kind() == AST_FUNC_DECL && to_func_decl(n)->is_polymorphic()) { + func_decl* g = m_from_manager.poly_root(to_func_decl(n)); + if (n != g && m_cache.contains(g)) { + m_frame_stack.push_back(frame(n, 0, m_extra_children_stack.size(), m_result_stack.size())); + } + } m_frame_stack.push_back(frame(n, 0, m_extra_children_stack.size(), m_result_stack.size())); switch (n->get_kind()) { case AST_SORT: @@ -153,6 +160,10 @@ void ast_translation::mk_func_decl(func_decl * f, frame & fr) { new_domain, new_range); } + else if (f->is_polymorphic() && m_from_manager.poly_root(f) != f) { + func_decl* fr = to_func_decl(m_cache[m_from_manager.poly_root(f)]); + new_f = m_to_manager.instantiate_polymorphic(fr, f->get_arity(), new_domain, new_range); + } else { buffer ps; copy_params(f, fr.m_rpos, ps); diff --git a/src/ast/occurs.cpp b/src/ast/occurs.cpp index a619dfdbd..4e0008373 100644 --- a/src/ast/occurs.cpp +++ b/src/ast/occurs.cpp @@ -87,7 +87,7 @@ bool occurs(func_decl * d, expr * n) { bool occurs(sort* s1, sort* s2) { sort_proc p(s1); try { - for_each_ast(p, s2); + for_each_ast(p, s2, true); } catch (const found&) { return true; diff --git a/src/ast/polymorphism_inst.cpp b/src/ast/polymorphism_inst.cpp new file mode 100644 index 000000000..42dd516f9 --- /dev/null +++ b/src/ast/polymorphism_inst.cpp @@ -0,0 +1,138 @@ +/*++ +Copyright (c) 2023 Microsoft Corporation + +Module Name: + + polymorphism_inst.cpp + +Abstract: + + Utilities for instantiating polymorphic assertions. 
+ +Author: + + Nikolaj Bjorner (nbjorner) 2023-7-8 + + +--*/ +#include "ast/polymorphism_inst.h" +#include "ast/ast_pp.h" + +namespace polymorphism { + + void inst::add(expr* e) { + if (!m.has_type_vars()) + return; + + if (m_from_instantiation.contains(e)) + return; + + instances inst; + u.collect_poly_instances(e, inst.m_poly_fns); + if (inst.m_poly_fns.empty()) + return; + if (m_instances.contains(e)) + return; + + add_instantiations(e, inst.m_poly_fns); + + if (!u.has_type_vars(e)) + return; + // insert e into the occurs list for polymorphic roots + ast_mark seen; + for (auto* f : inst.m_poly_fns) { + f = m.poly_root(f); + if (seen.is_marked(f)) + continue; + seen.mark(f, true); + if (!m_occurs.contains(f)) { + m_occurs.insert(f, ptr_vector()); + t.push(insert_map(m_occurs, f)); + } + auto& es = m_occurs.find(f); + es.push_back(e); + t.push(remove_back(m_occurs, f)); + } + m_assertions.push_back(e); + t.push(push_back_vector(m_assertions)); + u.collect_type_vars(e, inst.m_tvs); + inst.m_subst = alloc(substitutions); + inst.m_subst->insert(alloc(substitution, m)); + m_instances.insert(e, inst); + t.push(new_obj_trail(inst.m_subst)); + t.push(insert_map(m_instances, e)); + } + + void inst::collect_instantiations(expr* e) { + ptr_vector instances; + u.collect_poly_instances(e, instances); + add_instantiations(e, instances); + } + + void inst::add_instantiations(expr* e, ptr_vector const& instances) { + for (auto* f : instances) { + if (m_in_decl_queue.is_marked(f)) + continue; + m_in_decl_queue.mark(f, true); + m_decl_queue.push_back(f); + t.push(add_decl_queue(*this)); + } + } + + void inst::instantiate(vector& instances) { + unsigned num_decls = m_decl_queue.size(); + if (m_assertions_qhead < m_assertions.size()) { + t.push(value_trail(m_assertions_qhead)); + for (; m_assertions_qhead < m_assertions.size(); ++m_assertions_qhead) { + expr* e = m_assertions.get(m_assertions_qhead); + for (unsigned i = 0; i < num_decls; ++i) + instantiate(m_decl_queue.get(i), e, 
instances); + } + } + if (m_decl_qhead < num_decls) { + t.push(value_trail(m_decl_qhead)); + for (; m_decl_qhead < num_decls; ++m_decl_qhead) { + func_decl* p = m_decl_queue.get(m_decl_qhead); + for (expr* e : m_occurs[m.poly_root(p)]) + instantiate(p, e, instances); + } + } + } + + void inst::instantiate(func_decl* f1, expr* e, vector& instances) { + auto const& [tv, fns, substs] = m_instances[e]; + + for (auto* f2 : fns) { + substitution sub1(m), new_sub(m); + if (!u.unify(f1, f2, sub1)) + continue; + if (substs->contains(&sub1)) + continue; + substitutions new_substs; + for (auto* sub2 : *substs) { + if (!u.unify(sub1, *sub2, new_sub)) + continue; + if (substs->contains(&new_sub)) + continue; + if (new_substs.contains(&new_sub)) + continue; + expr_ref e_inst = new_sub(e); + if (!m_from_instantiation.contains(e_inst)) { + collect_instantiations(e_inst); + auto* new_sub1 = alloc(substitution, new_sub); + instances.push_back(instantiation(e, e_inst, new_sub1)); + new_substs.insert(new_sub1); + m_from_instantiation.insert(e_inst); + m.inc_ref(e_inst); + t.push(insert_ref_map(m, m_from_instantiation, e_inst)); + } + } + for (auto* sub2 : new_substs) { + SASSERT(!substs->contains(sub2)); + substs->insert(sub2); + t.push(new_obj_trail(sub2)); + t.push(insert_map(*substs, sub2)); + } + } + } +} diff --git a/src/ast/polymorphism_inst.h b/src/ast/polymorphism_inst.h new file mode 100644 index 000000000..1d171b314 --- /dev/null +++ b/src/ast/polymorphism_inst.h @@ -0,0 +1,91 @@ +/*++ +Copyright (c) 2023 Microsoft Corporation + +Module Name: + + polymorphism_inst.h + +Abstract: + + Utilities for instantiating polymorphic assertions. 
+ +Author: + + Nikolaj Bjorner (nbjorner) 2023-7-8 + + +--*/ +#pragma once + +#include "util/trail.h" +#include "ast/ast.h" +#include "ast/polymorphism_util.h" + +namespace polymorphism { + + struct instantiation { + expr* orig; + expr_ref inst; + substitution* sub; + instantiation(expr* orig, expr_ref& inst, substitution* s): + orig(orig), inst(inst), sub(s) {} + }; + + class inst { + ast_manager& m; + trail_stack& t; + util u; + + struct instances { + ptr_vector m_tvs; + ptr_vector m_poly_fns; + substitutions* m_subst = nullptr; + }; + + func_decl_ref_vector m_poly_roots; + obj_map> m_occurs; + obj_map m_instances; + func_decl_ref_vector m_decl_queue; + unsigned m_decl_qhead = 0; + ast_mark m_in_decl_queue; + expr_ref_vector m_assertions; + unsigned m_assertions_qhead = 0; + obj_hashtable m_from_instantiation; + + struct add_decl_queue : public trail { + inst& i; + add_decl_queue(inst& i): i(i) {} + void undo() override { + i.m_in_decl_queue.mark(i.m_decl_queue.back(), false); + i.m_decl_queue.pop_back(); + }; + }; + + struct remove_back : public trail { + obj_map>& occ; + func_decl* f; + remove_back(obj_map>& occ, func_decl* f): + occ(occ), f(f) {} + void undo() override { + occ.find(f).pop_back(); + } + }; + + void instantiate(func_decl* p, expr* e, vector& instances); + + void collect_instantiations(expr* e); + + void add_instantiations(expr* e, ptr_vector const& insts); + + public: + inst(ast_manager& m, trail_stack& t): + m(m), t(t), u(m), m_poly_roots(m), m_decl_queue(m), m_assertions(m) {} + + void add(expr* e); + + void instantiate(vector& instances); + + bool pending() const { return m_decl_qhead < m_decl_queue.size() || m_assertions_qhead < m_assertions.size(); } + + }; +} diff --git a/src/ast/polymorphism_util.cpp b/src/ast/polymorphism_util.cpp new file mode 100644 index 000000000..2fe271fc5 --- /dev/null +++ b/src/ast/polymorphism_util.cpp @@ -0,0 +1,349 @@ +/*++ +Copyright (c) 2023 Microsoft Corporation + +Module Name: + + polymorphism_util.cpp + 
+Abstract: + + Utilities for supporting polymorphic type signatures. + +Author: + + Nikolaj Bjorner (nbjorner) 2023-7-8 + +--*/ + +#include "ast/polymorphism_util.h" +#include "ast/for_each_ast.h" +#include "ast/occurs.h" +#include "ast/ast_pp.h" + +namespace polymorphism { + + sort_ref_vector substitution::operator()(sort_ref_vector const& s) { + sort_ref_vector r(m); + for (auto* srt : s) + r.push_back((*this)(srt)); + return r; + } + + sort_ref substitution::operator()(sort* s) { + if (!m.has_type_var(s)) + return sort_ref(s, m); + if (s->is_type_var()) { + if (m_sub.find(s, s)) + return (*this)(s); + return sort_ref(s, m); + } + unsigned n = s->get_num_parameters(); + vector ps; + for (unsigned i = 0; i < n; ++i) { + auto p = s->get_parameter(i); + if (p.is_ast() && is_sort(p.get_ast())) + ps.push_back(parameter((*this)(to_sort(p.get_ast())))); + else + ps.push_back(p); + } + sort_info si(s->get_family_id(), s->get_decl_kind(), n, ps.data(), s->private_parameters()); + return sort_ref(m.mk_sort(s->get_name(), si), m); + } + + expr_ref substitution::operator()(expr* e) { + ptr_vector todo; + expr_ref_vector result(m); + todo.push_back(e); + auto in_cache = [&](expr* a) { + return result.size() > a->get_id() && result.get(a->get_id()); + }; + ptr_buffer args; + sort_ref_buffer domain(m); + while (!todo.empty()) { + expr* a = todo.back(); + if (in_cache(a)) { + todo.pop_back(); + continue; + } + if (is_var(a)) { + if (m.has_type_var(a->get_sort())) + result.setx(a->get_id(), m.mk_var(to_var(a)->get_idx(), (*this)(a->get_sort()))); + else + result.setx(a->get_id(), a); + todo.pop_back(); + } + else if (is_quantifier(a)) { + quantifier* q = to_quantifier(a); + bool pending = false; + if (!in_cache(q->get_expr())) { + todo.push_back(q->get_expr()); + pending = true; + } + ptr_buffer patterns, no_patterns; + unsigned np = q->get_num_patterns(); + for (unsigned i = 0; i < np; ++i) { + if (!in_cache(q->get_pattern(i))) { + todo.push_back(q->get_pattern(i)); + pending = 
true; + } + else + patterns.push_back(result.get(q->get_pattern(i)->get_id())); + } + np = q->get_num_no_patterns(); + for (unsigned i = 0; i < np; ++i) { + if (!in_cache(q->get_no_pattern(i))) { + todo.push_back(q->get_no_pattern(i)); + pending = true; + } + else + no_patterns.push_back(result.get(q->get_no_pattern(i)->get_id())); + } + if (pending) + continue; + todo.pop_back(); + ptr_buffer sorts; + for (unsigned i = 0; i < q->get_num_decls(); ++i) + sorts.push_back((*this)(q->get_decl_sort(i))); + quantifier* q2 = + m.mk_quantifier(q->get_kind(), q->get_num_decls(), sorts.data(), q->get_decl_names(), result.get(q->get_expr()->get_id()), + q->get_weight(), + q->get_qid(), q->get_skid(), + q->get_num_patterns(), patterns.data(), q->get_num_no_patterns(), no_patterns.data() + ); + result.setx(q->get_id(), q2); + } + else if (is_app(a)) { + args.reset(); + unsigned n = todo.size(); + for (expr* arg : *to_app(a)) { + if (!in_cache(arg)) + todo.push_back(arg); + else + args.push_back(result.get(arg->get_id())); + } + if (n < todo.size()) + continue; + func_decl* f = to_app(a)->get_decl(); + if (f->is_polymorphic()) { + domain.reset(); + for (unsigned i = 0; i < f->get_arity(); ++i) + domain.push_back((*this)(f->get_domain(i))); + sort_ref range = (*this)(f->get_range()); + f = m.instantiate_polymorphic(f, f->get_arity(), domain.data(), range); + } + result.setx(a->get_id(), m.mk_app(f, args)); + todo.pop_back(); + } + } + return expr_ref(result.get(e->get_id()), m); + } + + bool substitution::unify(sort* s1, sort* s2) { + if (s1 == s2) + return true; + if (s1->is_type_var() && m_sub.find(s1, s1)) + return unify(s1, s2); + if (s2->is_type_var() && m_sub.find(s2, s2)) + return unify(s1, s2); + if (s2->is_type_var() && !s1->is_type_var()) + std::swap(s1, s2); + if (s1->is_type_var()) { + auto s22 = (*this)(s2); + if (occurs(s1, s22)) + return false; + m_trail.push_back(s22); + m_trail.push_back(s1); + m_sub.insert(s1, s22); + return true; + } + if (s1->get_family_id() 
!= s2->get_family_id()) + return false; + if (s1->get_decl_kind() != s2->get_decl_kind()) + return false; + if (s1->get_name() != s2->get_name()) + return false; + if (s1->get_num_parameters() != s2->get_num_parameters()) + return false; + for (unsigned i = s1->get_num_parameters(); i-- > 0;) { + auto p1 = s1->get_parameter(i); + auto p2 = s2->get_parameter(i); + if (p1.is_ast() && is_sort(p1.get_ast())) { + if (!p2.is_ast()) + return false; + if (!is_sort(p2.get_ast())) + return false; + if (!unify(to_sort(p1.get_ast()), to_sort(p2.get_ast()))) + return false; + continue; + } + if (p1 != p2) + return false; + } + return true; + } + + bool substitution::match(sort* s1, sort* s2) { + if (s1 == s2) + return true; + if (s1->is_type_var() && m_sub.find(s1, s1)) + return match(s1, s2); + if (s1->is_type_var()) { + m_trail.push_back(s2); + m_trail.push_back(s1); + m_sub.insert(s1, s2); + return true; + } + if (s1->get_family_id() != s2->get_family_id()) + return false; + if (s1->get_decl_kind() != s2->get_decl_kind()) + return false; + if (s1->get_name() != s2->get_name()) + return false; + if (s1->get_num_parameters() != s2->get_num_parameters()) + return false; + for (unsigned i = s1->get_num_parameters(); i-- > 0;) { + auto p1 = s1->get_parameter(i); + auto p2 = s2->get_parameter(i); + if (p1.is_ast() && is_sort(p1.get_ast())) { + if (!p2.is_ast()) + return false; + if (!is_sort(p2.get_ast())) + return false; + if (!match(to_sort(p1.get_ast()), to_sort(p2.get_ast()))) + return false; + continue; + } + if (p1 != p2) + return false; + } + return true; + } + + // util + bool util::unify(sort* s1, sort* s2, substitution& sub) { + return sub.unify(s1, s2); + } + + bool util::unify(func_decl* f1, func_decl* f2, substitution& sub) { + if (f1 == f2) + return true; + if (!f1->is_polymorphic() || !f2->is_polymorphic()) + return false; + if (m.poly_root(f1) != m.poly_root(f2)) + return false; + for (unsigned i = f1->get_arity(); i-- > 0; ) + if 
(!sub.unify(fresh(f1->get_domain(i)), f2->get_domain(i))) + return false; + return sub.unify(fresh(f1->get_range()), f2->get_range()); + } + + bool util::unify(substitution const& s1, substitution const& s2, + substitution& sub) { + sort* v2; + for (auto const& [k, v] : s1) + sub.insert(k, v); + for (auto const& [k, v] : s2) { + if (sub.find(k, v2)) { + if (!sub.unify(sub(v), v2)) + return false; + } + else + sub.insert(k, sub(v)); + } + return true; + } + + bool util::match(substitution& sub, sort* s1, sort* s_ground) { + return sub.match(s1, s_ground); + } + + /** + * Create fresh variables, but with caching. + * So "fresh" variables are not truly fresh globally. + * This can block some unifications and therefore block some instantiations of + * polymorphic assertions. A different caching scheme could be created to + * ensure that fresh variables are introduced at the right time, or use other + * tricks such as creating variable/offset pairs to distinguish name spaces without + * incurring costs. + */ + sort_ref util::fresh(sort* s) { + sort* s1; + if (m_fresh.find(s, s1)) + return sort_ref(s1, m); + + if (m.is_type_var(s)) { + s1 = m.mk_type_var(symbol("fresh!" 
+ std::to_string(m_counter))); + m_trail.push_back(s1); + m_trail.push_back(s); + m_fresh.insert(s, s1); + return sort_ref(s1, m); + } + vector params; + for (unsigned i = 0; i < s->get_num_parameters(); ++i) { + parameter p = s->get_parameter(i); + if (p.is_ast() && is_sort(p.get_ast())) + params.push_back(parameter(fresh(to_sort(p.get_ast())))); + else + params.push_back(p); + } + sort_info info(s->get_family_id(), s->get_decl_kind(), params.size(), params.data(), s->private_parameters()); + s1 = m.mk_sort(s->get_name(), info); + m_trail.push_back(s1); + m_trail.push_back(s); + m_fresh.insert(s, s1); + return sort_ref(s1, m); + } + + sort_ref_vector util::fresh(unsigned n, sort* const* s) { + sort_ref_vector r(m); + for (unsigned i = 0; i < n; ++i) + r.push_back(fresh(s[i])); + return r; + } + + void util::collect_poly_instances(expr* e, ptr_vector& instances) { + struct proc { + ast_manager& m; + ptr_vector& instances; + proc(ast_manager& m, ptr_vector& instances) : m(m), instances(instances) {} + void operator()(func_decl* f) { + if (f->is_polymorphic() && !m.is_eq(f) && !is_decl_of(f, pattern_family_id, OP_PATTERN)) + instances.push_back(f); + } + void operator()(ast* a) {} + }; + proc proc(m, instances); + for_each_ast(proc, e, false); + } + + bool util::has_type_vars(expr* e) { + struct proc { + ast_manager& m; + bool found = false; + proc(ast_manager& m) : m(m) {} + void operator()(sort* f) { + if (m.has_type_var(f)) + found = true; + } + void operator()(ast* a) {} + }; + proc proc(m); + for_each_ast(proc, e, false); + return proc.found; + } + + void util::collect_type_vars(expr* e, ptr_vector& tvs) { + struct proc { + ast_manager& m; + ptr_vector& tvs; + proc(ast_manager& m, ptr_vector& tvs) : m(m), tvs(tvs) {} + void operator()(sort* s) { + if (m.is_type_var(s)) + tvs.push_back(s); + } + void operator()(ast* a) {} + }; + proc proc(m, tvs); + for_each_ast(proc, e, true); + } +} diff --git a/src/ast/polymorphism_util.h b/src/ast/polymorphism_util.h new file 
mode 100644 index 000000000..3023d0338 --- /dev/null +++ b/src/ast/polymorphism_util.h @@ -0,0 +1,112 @@ +/*++ +Copyright (c) 2023 Microsoft Corporation + +Module Name: + + polymorphism_util.h + +Abstract: + + Utilities for supporting polymorphic type signatures. + +Author: + + Nikolaj Bjorner (nbjorner) 2023-7-8 + +--*/ +#pragma once + +#include "ast/ast.h" +#include "util/hashtable.h" + +namespace polymorphism { + + class substitution { + ast_manager& m; + obj_map m_sub; + sort_ref_vector m_trail; + public: + substitution(ast_manager& m): m(m), m_trail(m) {} + + sort_ref_vector operator()(sort_ref_vector const& s); + + sort_ref operator()(sort* s); + + expr_ref operator()(expr* e); + + bool unify(sort* s1, sort* s2); + + bool match(sort* s1, sort* s_ground); + + obj_map::iterator begin() const { return m_sub.begin(); } + obj_map::iterator end() const { return m_sub.end(); } + + void insert(sort* v, sort* t) { m_trail.push_back(v).push_back(t); m_sub.insert(v, t); } + + bool find(sort* v, sort*& t) const { return m_sub.find(v, t); } + + unsigned size() const { return m_sub.size(); } + + /** + * weak equality: strong equality considers applying substitutions recursively in range + * because substitutions may be in triangular form. 
+ */ + struct eq { + bool operator()(substitution const* s1, substitution const* s2) const { + if (s1->size() != s2->size()) + return false; + sort* v2; + for (auto const& [k, v] : *s1) { + if (!s2->find(k, v2)) + return false; + if (v != v2) + return false; + } + return true; + } + }; + + struct hash { + unsigned operator()(substitution const* s) const { + unsigned hash = 0xfabc1234 + s->size(); + for (auto const& [k, v] : *s) + hash ^= k->hash() + 2 * v->hash(); + return hash; + } + }; + }; + + typedef hashtable substitutions; + + class util { + ast_manager& m; + sort_ref_vector m_trail; + obj_map m_fresh; + unsigned m_counter = 0; + + sort_ref fresh(sort* s); + + sort_ref_vector fresh(unsigned n, sort* const* s); + + public: + util(ast_manager& m): m(m), m_trail(m) {} + + bool unify(sort* s1, sort* s2, substitution& sub); + + bool unify(func_decl* f1, func_decl* f2, substitution& sub); + + bool unify(substitution const& s1, substitution const& s2, + substitution& sub); + + bool match(substitution& sub, sort* s1, sort* s_ground); + + // collect instantiations of polymorphic functions + void collect_poly_instances(expr* e, ptr_vector& instances); + + // test if expression contains polymorphic variable. 
+ bool has_type_vars(expr* e); + + void collect_type_vars(expr* e, ptr_vector& tvs); + + }; +} diff --git a/src/cmd_context/cmd_context.cpp b/src/cmd_context/cmd_context.cpp index 6b42d9ee1..6f04799f6 100644 --- a/src/cmd_context/cmd_context.cpp +++ b/src/cmd_context/cmd_context.cpp @@ -42,6 +42,7 @@ Notes: #include "ast/for_each_expr.h" #include "ast/rewriter/th_rewriter.h" #include "ast/rewriter/recfun_replace.h" +#include "ast/polymorphism_util.h" #include "model/model_evaluator.h" #include "model/model_smt2_pp.h" #include "model/model_v2_pp.h" @@ -223,12 +224,48 @@ bool func_decls::check_signature(ast_manager& m, func_decl* f, unsigned arity, s return true; } -func_decl * func_decls::find(ast_manager& m, unsigned arity, sort * const * domain, sort * range) const { +bool func_decls::check_poly_signature(ast_manager& m, func_decl* f, unsigned arity, sort* const* domain, sort* range, func_decl*& g) { + polymorphism::substitution sub(m); + arith_util au(m); + sort_ref range_ref(range, m); + if (range != nullptr && !sub.match(f->get_range(), range)) + return false; + if (f->get_arity() != arity) + return false; + for (unsigned i = 0; i < arity; i++) + if (!sub.match(f->get_domain(i), domain[i])) + return false; + if (!range) + range_ref = sub(f->get_range()); + + recfun::util u(m); + auto& p = u.get_plugin(); + if (!u.has_def(f)) { + g = m.instantiate_polymorphic(f, arity, domain, range_ref); + return true; + } + // this is an instantiation of a recursive polymorphic function. + // create a self-contained polymorphic definition for the instantiation. 
+ auto def = u.get_def(f); + auto promise_def = p.mk_def(f->get_name(), arity, domain, range_ref, false); + recfun_replace replace(m); + expr_ref tt = sub(def.get_rhs()); + p.set_definition(replace, promise_def, def.is_macro(), def.get_vars().size(), def.get_vars().data(), tt); + g = promise_def.get_def()->get_decl(); + insert(m, g); + return true; +} + + +func_decl * func_decls::find(ast_manager& m, unsigned arity, sort * const * domain, sort * range) { bool coerced = false; + func_decl* g = nullptr; if (!more_than_one()) { func_decl* f = first(); if (check_signature(m, f, arity, domain, range, coerced)) - return f; + return f; + if (check_poly_signature(m, f, arity, domain, range, g)) + return g; return nullptr; } func_decl_set * fs = UNTAG(func_decl_set *, m_decls); @@ -241,10 +278,15 @@ func_decl * func_decls::find(ast_manager& m, unsigned arity, sort * const * doma return f; } } - return best_f; + if (best_f != nullptr) + return best_f; + for (func_decl* f : *fs) + if (check_poly_signature(m, f, arity, domain, range, g)) + return g; + return nullptr; } -func_decl * func_decls::find(ast_manager & m, unsigned num_args, expr * const * args, sort * range) const { +func_decl * func_decls::find(ast_manager & m, unsigned num_args, expr * const * args, sort * range) { if (!more_than_one()) first(); ptr_buffer sorts; @@ -376,12 +418,13 @@ void cmd_context::erase_macro(symbol const& s) { decls.erase_last(m()); } -bool cmd_context::macros_find(symbol const& s, unsigned n, expr*const* args, expr_ref_vector& coerced_args, expr*& t) const { +bool cmd_context::macros_find(symbol const& s, unsigned n, expr*const* args, expr_ref_vector& coerced_args, expr_ref& t) { macro_decls decls; if (!m_macros.find(s, decls)) return false; for (macro_decl const& d : decls) { - if (d.m_domain.size() != n) continue; + if (d.m_domain.size() != n) + continue; bool eq = true; coerced_args.reset(); for (unsigned i = 0; eq && i < n; ++i) { @@ -406,6 +449,26 @@ bool cmd_context::macros_find(symbol 
const& s, unsigned n, expr*const* args, exp return true; } } + for (macro_decl const& d : decls) { + if (d.m_domain.size() != n) + continue; + polymorphism::substitution sub(m()); + bool eq = true; + for (unsigned i = 0; eq && i < n; ++i) { + if (!sub.match(d.m_domain[i], args[i]->get_sort())) + eq = false; + } + if (eq) { + t = d.m_body; + t = sub(t); + verbose_stream() << "macro " << t << "\n"; + ptr_buffer domain; + for (unsigned i = 0; i < n; ++i) + domain.push_back(args[i]->get_sort()); + insert_macro(s, n, domain.data(), t); + return true; + } + } return false; } @@ -939,18 +1002,16 @@ void cmd_context::insert(cmd * c) { void cmd_context::insert_user_tactic(symbol const & s, sexpr * d) { sm().inc_ref(d); sexpr * old_d; - if (m_user_tactic_decls.find(s, old_d)) { - sm().dec_ref(old_d); - } + if (m_user_tactic_decls.find(s, old_d)) + sm().dec_ref(old_d); m_user_tactic_decls.insert(s, d); } void cmd_context::insert(symbol const & s, object_ref * r) { r->inc_ref(*this); object_ref * old_r = nullptr; - if (m_object_refs.find(s, old_r)) { - old_r->dec_ref(*this); - } + if (m_object_refs.find(s, old_r)) + old_r->dec_ref(*this); m_object_refs.insert(s, r); } @@ -1054,16 +1115,17 @@ static builtin_decl const & peek_builtin_decl(builtin_decl const & first, family } func_decl * cmd_context::find_func_decl(symbol const & s, unsigned num_indices, unsigned const * indices, - unsigned arity, sort * const * domain, sort * range) const { + unsigned arity, sort * const * domain, sort * range) { if (domain && contains_macro(s, arity, domain)) throw cmd_exception("invalid function declaration reference, named expressions (aka macros) cannot be referenced ", s); func_decl * f = nullptr; - func_decls fs; - if (num_indices == 0 && m_func_decls.find(s, fs)) + if (num_indices == 0 && m_func_decls.contains(s)) { + auto& fs = m_func_decls.find(s); f = fs.find(m(), arity, domain, range); - if (f) + } + if (f) return f; builtin_decl d; if ((arity == 0 || domain) && 
m_builtin_decls.find(s, d)) { @@ -1089,11 +1151,12 @@ func_decl * cmd_context::find_func_decl(symbol const & s, unsigned num_indices, throw cmd_exception("invalid function declaration reference, invalid builtin reference ", s); return f; } - if (num_indices > 0 && m_func_decls.find(s, fs)) + if (num_indices > 0 && m_func_decls.contains(s)) { + auto& fs = m_func_decls.find(s); f = fs.find(m(), arity, domain, range); - if (f) + } + if (f) return f; - throw cmd_exception("invalid function declaration reference, unknown indexed function ", s); } @@ -1125,7 +1188,7 @@ object_ref * cmd_context::find_object_ref(symbol const & s) const { #define CHECK_SORT(T) if (well_sorted_check_enabled()) m().check_sorts_core(T) -void cmd_context::mk_const(symbol const & s, expr_ref & result) const { +void cmd_context::mk_const(symbol const & s, expr_ref & result) { mk_app(s, 0, nullptr, 0, nullptr, nullptr, result); } @@ -1153,9 +1216,10 @@ bool cmd_context::try_mk_builtin_app(symbol const & s, unsigned num_args, expr * bool cmd_context::try_mk_declared_app(symbol const & s, unsigned num_args, expr * const * args, unsigned num_indices, parameter const * indices, sort * range, - func_decls& fs, expr_ref & result) const { - if (!m_func_decls.find(s, fs)) + expr_ref & result) { + if (!m_func_decls.contains(s)) return false; + func_decls& fs = m_func_decls.find(s); if (num_args == 0 && !range) { if (fs.more_than_one()) @@ -1180,8 +1244,8 @@ bool cmd_context::try_mk_declared_app(symbol const & s, unsigned num_args, expr bool cmd_context::try_mk_macro_app(symbol const & s, unsigned num_args, expr * const * args, unsigned num_indices, parameter const * indices, sort * range, - expr_ref & result) const { - expr* _t; + expr_ref & result) { + expr_ref _t(m()); expr_ref_vector coerced_args(m()); if (macros_find(s, num_args, args, coerced_args, _t)) { TRACE("macro_bug", tout << "well_sorted_check_enabled(): " << well_sorted_check_enabled() << "\n"; @@ -1256,19 +1320,21 @@ bool 
cmd_context::try_mk_pdecl_app(symbol const & s, unsigned num_args, expr * c void cmd_context::mk_app(symbol const & s, unsigned num_args, expr * const * args, unsigned num_indices, parameter const * indices, sort * range, - expr_ref & result) const { + expr_ref & result) { - func_decls fs; + if (try_mk_macro_app(s, num_args, args, num_indices, indices, range, result)) return; - if (try_mk_declared_app(s, num_args, args, num_indices, indices, range, fs, result)) - return; + if (try_mk_declared_app(s, num_args, args, num_indices, indices, range, result)) + return; if (try_mk_builtin_app(s, num_args, args, num_indices, indices, range, result)) return; if (!range && try_mk_pdecl_app(s, num_args, args, num_indices, indices, result)) return; + func_decls fs; + m_func_decls.find(s, fs); std::ostringstream buffer; buffer << "unknown constant " << s; if (num_args > 0) { diff --git a/src/cmd_context/cmd_context.h b/src/cmd_context/cmd_context.h index a4eb53237..c07d888c7 100644 --- a/src/cmd_context/cmd_context.h +++ b/src/cmd_context/cmd_context.h @@ -58,11 +58,12 @@ public: bool clash(func_decl * f) const; bool empty() const { return m_decls == nullptr; } func_decl * first() const; - func_decl * find(ast_manager & m, unsigned arity, sort * const * domain, sort * range) const; - func_decl * find(ast_manager & m, unsigned arity, expr * const * args, sort * range) const; + func_decl * find(ast_manager & m, unsigned arity, sort * const * domain, sort * range); + func_decl * find(ast_manager & m, unsigned arity, expr * const * args, sort * range); unsigned get_num_entries() const; func_decl * get_entry(unsigned inx); bool check_signature(ast_manager& m, func_decl* f, unsigned arityh, sort * const* domain, sort * range, bool& coerced) const; + bool check_poly_signature(ast_manager& m, func_decl* f, unsigned arity, sort* const* domain, sort* range, func_decl*& g); }; struct macro_decl { @@ -355,7 +356,7 @@ protected: bool contains_macro(symbol const& s, unsigned arity, sort 
*const* domain) const; void insert_macro(symbol const& s, unsigned arity, sort*const* domain, expr* t); void erase_macro(symbol const& s); - bool macros_find(symbol const& s, unsigned n, expr*const* args, expr_ref_vector& coerced_args, expr*& t) const; + bool macros_find(symbol const& s, unsigned n, expr*const* args, expr_ref_vector& coerced_args, expr_ref& t); recfun::decl::plugin& get_recfun_plugin(); @@ -449,22 +450,22 @@ public: void insert_rec_fun(func_decl* f, expr_ref_vector const& binding, svector const& ids, expr* e); func_decl * find_func_decl(symbol const & s) const; func_decl * find_func_decl(symbol const & s, unsigned num_indices, unsigned const * indices, - unsigned arity, sort * const * domain, sort * range) const; + unsigned arity, sort * const * domain, sort * range); recfun::promise_def decl_rec_fun(const symbol &name, unsigned int arity, sort *const *domain, sort *range); psort_decl * find_psort_decl(symbol const & s) const; cmd * find_cmd(symbol const & s) const; sexpr * find_user_tactic(symbol const & s) const; object_ref * find_object_ref(symbol const & s) const; - void mk_const(symbol const & s, expr_ref & result) const; + void mk_const(symbol const & s, expr_ref & result); void mk_app(symbol const & s, unsigned num_args, expr * const * args, unsigned num_indices, parameter const * indices, sort * range, - expr_ref & r) const; + expr_ref & r); bool try_mk_macro_app(symbol const & s, unsigned num_args, expr * const * args, unsigned num_indices, parameter const * indices, sort * range, - expr_ref & r) const; + expr_ref & r); bool try_mk_builtin_app(symbol const & s, unsigned num_args, expr * const * args, unsigned num_indices, parameter const * indices, sort * range, expr_ref & r) const; bool try_mk_declared_app(symbol const & s, unsigned num_args, expr * const * args, unsigned num_indices, parameter const * indices, sort * range, - func_decls& fs, expr_ref & result) const; + expr_ref & result); bool try_mk_pdecl_app(symbol const & s, unsigned 
num_args, expr * const * args, unsigned num_indices, parameter const * indices, expr_ref & r) const; void erase_cmd(symbol const & s); void erase_func_decl(symbol const & s); diff --git a/src/smt/smt_setup.cpp b/src/smt/smt_setup.cpp index 7699fc3a5..caeca9659 100644 --- a/src/smt/smt_setup.cpp +++ b/src/smt/smt_setup.cpp @@ -38,6 +38,7 @@ Revision History: #include "smt/theory_pb.h" #include "smt/theory_fpa.h" #include "smt/theory_str.h" +#include "smt/theory_polymorphism.h" namespace smt { @@ -788,6 +789,11 @@ namespace smt { m_context.register_plugin(alloc(smt::theory_special_relations, m_context, m_manager)); } + void setup::setup_polymorphism() { + if (m_manager.has_type_vars()) + m_context.register_plugin(alloc(theory_polymorphism, m_context)); + } + void setup::setup_unknown() { static_features st(m_manager); ptr_vector fmls; @@ -803,6 +809,7 @@ namespace smt { setup_seq_str(st); setup_fpa(); setup_special_relations(); + setup_polymorphism(); } void setup::setup_unknown(static_features & st) { @@ -819,6 +826,7 @@ namespace smt { setup_fpa(); setup_recfuns(); setup_special_relations(); + setup_polymorphism(); return; } diff --git a/src/smt/smt_setup.h b/src/smt/smt_setup.h index 2daa67085..bb4a81671 100644 --- a/src/smt/smt_setup.h +++ b/src/smt/smt_setup.h @@ -82,6 +82,7 @@ namespace smt { void setup_LRA(); void setup_CSP(); void setup_special_relations(); + void setup_polymorphism(); void setup_AUFLIA(bool simple_array = true); void setup_AUFLIA(static_features const & st); void setup_AUFLIRA(bool simple_array = true); diff --git a/src/smt/theory_polymorphism.h b/src/smt/theory_polymorphism.h new file mode 100644 index 000000000..4c64a0a9c --- /dev/null +++ b/src/smt/theory_polymorphism.h @@ -0,0 +1,105 @@ +/*++ +Copyright (c) 2023 Microsoft Corporation + +Module Name: + + theory_polymorphism.h + +Abstract: + + Plugin for handling polymorphism + The plugin instantiates polymorphic axioms based on occurrences of polymorphic functions in other axioms. 
+    It uses blocking literals to restart search when there are new axioms that can be instantiated.
+
+Author:
+
+    Nikolaj Bjorner (nbjorner) 2013-07-11
+
+--*/
+#pragma once
+
+#include "ast/polymorphism_inst.h"
+#include "smt/smt_theory.h"
+
+namespace smt {
+    // Theory plugin registered under poly_family_id: hands asserted formulas to a polymorphism::inst and asserts the instances it returns.
+    class theory_polymorphism : public theory {
+        trail_stack m_trail; // passed to m_inst at construction; scoped by push_scope_eh and pop_scope_eh
+        polymorphism::inst m_inst; // receives asserted formulas (add) and yields instances (instantiate)
+        expr_ref m_assumption; // fresh Boolean constant created by add_theory_assumptions
+        unsigned m_qhead = 0; // index of the next asserted formula not yet handed to m_inst
+        bool m_pending = true; // when true, propagate asks m_inst for instances
+
+        bool internalize_atom(app*, bool) override { return false; } // unconditionally returns false
+        bool internalize_term(app*) override { return false; } // unconditionally returns false
+        void new_eq_eh(theory_var, theory_var) override { } // equalities are ignored
+        void new_diseq_eh(theory_var, theory_var) override {} // disequalities are ignored
+        theory* mk_fresh(context* new_ctx) override { return alloc(theory_polymorphism, *new_ctx); } // new plugin instance bound to new_ctx
+        char const * get_name() const override { return "polymorphism"; }
+        void display(std::ostream& out) const override {} // prints nothing
+
+        void push_scope_eh() override { // opens a scope on m_trail
+            m_trail.push_scope();
+        }
+
+        void pop_scope_eh(unsigned n) override { // closes n scopes on m_trail
+            m_trail.pop_scope(n);
+        }
+
+        bool can_propagate() override { // true exactly when m_pending is set
+            return m_pending;
+        }
+
+        /**
+         * Assert instances of polymorphic axioms: when m_pending is set, clear it, ask m_inst for instances, add each instance formula to the context and internalize the new assertions.
+         */
+        void propagate() override {
+            if (!m_pending)
+                return;
+            m_pending = false; // cleared before querying, so one call handles one pending request
+            vector instances;
+            m_inst.instantiate(instances);
+            if (instances.empty())
+                return;
+            for (auto const& [orig, inst, sub] : instances) // only the instance formula is used here
+                ctx.add_asserted(inst);
+            ctx.internalize_assertions();
+        }
+
+        final_check_status final_check_eh() override { // always answers FC_DONE
+            if (m_inst.pending()) // m_inst reports outstanding work
+                ctx.assign(~mk_literal(m_assumption), nullptr); // assigns the negated assumption literal; presumably the blocking literal from the file header that leads to a restart (confirm)
+            return FC_DONE;
+        }
+
+        void add_theory_assumptions(expr_ref_vector & assumptions) override { // appends a fresh assumption and feeds newly asserted formulas to m_inst
+            if (m_qhead == ctx.get_num_asserted_formulas()) // every asserted formula was already consumed
+                return;
+            m_assumption = m.mk_fresh_const("poly", m.mk_bool_sort()); // new Boolean constant for this round
+            assumptions.push_back(m_assumption);
+            ctx.push_trail(value_trail(m_qhead)); // records m_qhead on the context trail; presumably restored on backtracking (confirm)
+            for (; m_qhead < ctx.get_num_asserted_formulas(); ++m_qhead)
+                m_inst.add(ctx.get_asserted_formula(m_qhead));
+            m_pending = true; // lets propagate query m_inst again
+        }
+
+        bool should_research(expr_ref_vector & assumptions) override { // true iff m_assumption occurs in the given vector
+            for (auto * a : assumptions)
+                if (a == m_assumption)
+                    return true;
+            return false;
+        }
+
+
+    public:
+        theory_polymorphism(context& ctx): // builds the plugin for the given context
+            theory(ctx, poly_family_id),
+            m_inst(ctx.get_manager(), m_trail), // m_inst shares the manager of ctx and uses m_trail
+            m_assumption(ctx.get_manager()) {}
+
+        void init_model(model_generator & mg) override { } // does nothing
+    };
+
+};
+
+