diff --git a/src/ast/finite_set_decl_plugin.cpp b/src/ast/finite_set_decl_plugin.cpp index 1e54c65ce..985cc4c8b 100644 --- a/src/ast/finite_set_decl_plugin.cpp +++ b/src/ast/finite_set_decl_plugin.cpp @@ -28,39 +28,10 @@ finite_set_decl_plugin::finite_set_decl_plugin(): } finite_set_decl_plugin::~finite_set_decl_plugin() { - for (psig* s : m_sigs) + for (polymorphism::psig* s : m_sigs) dealloc(s); } -void finite_set_decl_plugin::match(psig& sig, unsigned dsz, sort *const* dom, sort* range, sort_ref& range_out) { - ast_manager& m = *m_manager; - - if (dsz != sig.m_dom.size()) { - std::ostringstream strm; - strm << "Incorrect number of arguments to '" << sig.m_name << "' "; - strm << "expected " << sig.m_dom.size() << " given " << dsz; - m.raise_exception(strm.str()); - } - - polymorphism::substitution sub(m); - bool is_match = true; - for (unsigned i = 0; is_match && i < dsz; ++i) { - SASSERT(dom[i]); - is_match = sub.match(sig.m_dom.get(i), dom[i]); - } - if (range && is_match) { - is_match = sub.match(sig.m_range, range); - } - if (!is_match) { - std::ostringstream strm; - strm << "Sort mismatch for function '" << sig.m_name << "'"; - m.raise_exception(strm.str()); - } - - // Apply substitution to get the range - range_out = sub(sig.m_range); -} - void finite_set_decl_plugin::init() { if (m_init) return; ast_manager& m = *m_manager; @@ -87,17 +58,17 @@ void finite_set_decl_plugin::init() { sort* intintT[2] = { intT, intT }; m_sigs.resize(LAST_FINITE_SET_OP); - m_sigs[OP_FINITE_SET_EMPTY] = alloc(psig, m, "set.empty", 1, 0, nullptr, setA); - m_sigs[OP_FINITE_SET_SINGLETON] = alloc(psig, m, "set.singleton", 1, 1, &A, setA); - m_sigs[OP_FINITE_SET_UNION] = alloc(psig, m, "set.union", 1, 2, setAsetA, setA); - m_sigs[OP_FINITE_SET_INTERSECT] = alloc(psig, m, "set.intersect", 1, 2, setAsetA, setA); - m_sigs[OP_FINITE_SET_DIFFERENCE] = alloc(psig, m, "set.difference", 1, 2, setAsetA, setA); - m_sigs[OP_FINITE_SET_IN] = alloc(psig, m, "set.in", 1, 2, AsetA, boolT); - m_sigs[OP_FINITE_SET_SIZE] = alloc(psig, m, "set.size", 1, 1, &setA, intT); - m_sigs[OP_FINITE_SET_SUBSET] = alloc(psig, m, "set.subset", 1, 2, setAsetA, boolT); - m_sigs[OP_FINITE_SET_MAP] = alloc(psig, m, "set.map", 2, 2, arrABsetA, setB); - m_sigs[OP_FINITE_SET_SELECT] = alloc(psig, m, "set.select", 1, 2, arrABoolsetA, setA); - m_sigs[OP_FINITE_SET_RANGE] = alloc(psig, m, "set.range", 0, 2, intintT, setInt); + m_sigs[OP_FINITE_SET_EMPTY] = alloc(polymorphism::psig, m, "set.empty", 1, 0, nullptr, setA); + m_sigs[OP_FINITE_SET_SINGLETON] = alloc(polymorphism::psig, m, "set.singleton", 1, 1, &A, setA); + m_sigs[OP_FINITE_SET_UNION] = alloc(polymorphism::psig, m, "set.union", 1, 2, setAsetA, setA); + m_sigs[OP_FINITE_SET_INTERSECT] = alloc(polymorphism::psig, m, "set.intersect", 1, 2, setAsetA, setA); + m_sigs[OP_FINITE_SET_DIFFERENCE] = alloc(polymorphism::psig, m, "set.difference", 1, 2, setAsetA, setA); + m_sigs[OP_FINITE_SET_IN] = alloc(polymorphism::psig, m, "set.in", 1, 2, AsetA, boolT); + m_sigs[OP_FINITE_SET_SIZE] = alloc(polymorphism::psig, m, "set.size", 1, 1, &setA, intT); + m_sigs[OP_FINITE_SET_SUBSET] = alloc(polymorphism::psig, m, "set.subset", 1, 2, setAsetA, boolT); + m_sigs[OP_FINITE_SET_MAP] = alloc(polymorphism::psig, m, "set.map", 2, 2, arrABsetA, setB); + m_sigs[OP_FINITE_SET_SELECT] = alloc(polymorphism::psig, m, "set.select", 1, 2, arrABoolsetA, setA); + m_sigs[OP_FINITE_SET_RANGE] = alloc(polymorphism::psig, m, "set.range", 0, 2, intintT, setInt); } sort * finite_set_decl_plugin::mk_sort(decl_kind k, unsigned num_parameters, parameter const * parameters) { @@ -141,8 +112,9 @@ func_decl * finite_set_decl_plugin::mk_empty(sort* element_sort) { func_decl * finite_set_decl_plugin::mk_finite_set_op(decl_kind k, unsigned arity, sort * const * domain, sort* range) { ast_manager& m = *m_manager; + polymorphism::util poly_util(m); sort_ref rng(m); - match(*m_sigs[k], arity, domain, range, rng); + poly_util.match(*m_sigs[k], arity, domain, range, rng); return m.mk_func_decl(m_sigs[k]->m_name, arity, domain, rng, func_decl_info(m_family_id, k)); } diff --git a/src/ast/finite_set_decl_plugin.h b/src/ast/finite_set_decl_plugin.h index 8f7437766..63a2a74a2 100644 --- a/src/ast/finite_set_decl_plugin.h +++ b/src/ast/finite_set_decl_plugin.h @@ -28,6 +28,7 @@ Operators: #pragma once #include "ast/ast.h" +#include "ast/polymorphism_util.h" enum finite_set_sort_kind { FINITE_SET_SORT @@ -49,26 +50,10 @@ enum finite_set_op_kind { }; class finite_set_decl_plugin : public decl_plugin { - struct psig { - symbol m_name; - unsigned m_num_params; - sort_ref_vector m_dom; - sort_ref m_range; - psig(ast_manager& m, char const* name, unsigned n, unsigned dsz, sort* const* dom, sort* rng): - m_name(name), - m_num_params(n), - m_dom(m), - m_range(rng, m) - { - m_dom.append(dsz, dom); - } - }; - - ptr_vector m_sigs; - bool m_init; + ptr_vector m_sigs; + bool m_init; void init(); - void match(psig& sig, unsigned dsz, sort *const* dom, sort* range, sort_ref& range_out); func_decl * mk_empty(sort* element_sort); func_decl * mk_finite_set_op(decl_kind k, unsigned arity, sort * const * domain, sort* range); sort * get_element_sort(sort* finite_set_sort) const; @@ -82,7 +67,7 @@ public: } void finalize() override { - for (psig* s : m_sigs) + for (polymorphism::psig* s : m_sigs) dealloc(s); m_sigs.reset(); } diff --git a/src/ast/polymorphism_util.cpp b/src/ast/polymorphism_util.cpp index 3431dd034..62dd4e949 100644 --- a/src/ast/polymorphism_util.cpp +++ b/src/ast/polymorphism_util.cpp @@ -350,4 +350,31 @@ namespace polymorphism { proc proc(m, tvs); for_each_ast(proc, e, true); } + + void util::match(psig& sig, unsigned dsz, sort* const* dom, sort* range, sort_ref& range_out) { + if (dsz != sig.m_dom.size()) { + std::ostringstream strm; + strm << "Incorrect number of arguments to '" << sig.m_name << "' "; + strm << "expected " << sig.m_dom.size() << " given " << dsz; + m.raise_exception(strm.str()); + } + + substitution sub(m); + bool is_match = true; + for (unsigned i = 0; is_match && i < dsz; ++i) { + SASSERT(dom[i]); + is_match = sub.match(sig.m_dom.get(i), dom[i]); + } + if (range && is_match) { + is_match = sub.match(sig.m_range, range); + } + if (!is_match) { + std::ostringstream strm; + strm << "Sort mismatch for function '" << sig.m_name << "'"; + m.raise_exception(strm.str()); + } + + // Apply substitution to get the range + range_out = sub(sig.m_range); + } } diff --git a/src/ast/polymorphism_util.h b/src/ast/polymorphism_util.h index 3023d0338..d7591b7fb 100644 --- a/src/ast/polymorphism_util.h +++ b/src/ast/polymorphism_util.h @@ -77,6 +77,24 @@ namespace polymorphism { }; typedef hashtable substitutions; + + /** + * Polymorphic signature for operators + */ + struct psig { + symbol m_name; + unsigned m_num_params; + sort_ref_vector m_dom; + sort_ref m_range; + psig(ast_manager& m, char const* name, unsigned n, unsigned dsz, sort* const* dom, sort* rng): + m_name(name), + m_num_params(n), + m_dom(m), + m_range(rng, m) + { + m_dom.append(dsz, dom); + } + }; class util { ast_manager& m; @@ -99,6 +117,13 @@ namespace polymorphism { substitution& sub); bool match(substitution& sub, sort* s1, sort* s_ground); + + /** + * Match a polymorphic signature against concrete argument sorts. + * Raises exception if arity mismatch or type mismatch. + * Returns the instantiated range sort via range_out. + */ + void match(psig& sig, unsigned dsz, sort* const* dom, sort* range, sort_ref& range_out); // collect instantiations of polymorphic functions void collect_poly_instances(expr* e, ptr_vector& instances);