diff --git a/src/ast/ast.cpp b/src/ast/ast.cpp index 6e2fac4bd..a74f87142 100644 --- a/src/ast/ast.cpp +++ b/src/ast/ast.cpp @@ -1016,6 +1016,28 @@ func_decl * basic_decl_plugin::mk_ite_decl(sort * s) { return m_ite_decls[id]; } +sort* basic_decl_plugin::join(unsigned n, sort* const* srts) { + SASSERT(n > 0); + sort* s = srts[0]; + while (n > 1) { + ++srts; + --n; + s = join(s, *srts); + } + return s; +} + +sort* basic_decl_plugin::join(unsigned n, expr* const* es) { + SASSERT(n > 0); + sort* s = m_manager->get_sort(*es); + while (n > 1) { + ++es; + --n; + s = join(s, m_manager->get_sort(*es)); + } + return s; +} + sort* basic_decl_plugin::join(sort* s1, sort* s2) { if (s1 == s2) return s1; if (s1->get_family_id() == m_manager->m_arith_family_id && @@ -1045,8 +1067,8 @@ func_decl * basic_decl_plugin::mk_func_decl(decl_kind k, unsigned num_parameters case OP_XOR: return m_xor_decl; case OP_ITE: return arity == 3 ? mk_ite_decl(join(domain[1], domain[2])) : 0; // eq and oeq must have at least two arguments, they can have more since they are chainable - case OP_EQ: return arity >= 2 ? mk_eq_decl_core("=", OP_EQ, join(domain[0],domain[1]), m_eq_decls) : 0; - case OP_OEQ: return arity >= 2 ? mk_eq_decl_core("~", OP_OEQ, join(domain[0],domain[1]), m_oeq_decls) : 0; + case OP_EQ: return arity >= 2 ? mk_eq_decl_core("=", OP_EQ, join(arity, domain), m_eq_decls) : 0; + case OP_OEQ: return arity >= 2 ? mk_eq_decl_core("~", OP_OEQ, join(arity, domain), m_oeq_decls) : 0; case OP_DISTINCT: { func_decl_info info(m_family_id, OP_DISTINCT); info.set_pairwise(); @@ -1088,10 +1110,8 @@ func_decl * basic_decl_plugin::mk_func_decl(decl_kind k, unsigned num_parameters case OP_XOR: return m_xor_decl; case OP_ITE: return num_args == 3 ? mk_ite_decl(join(m_manager->get_sort(args[1]), m_manager->get_sort(args[2]))): 0; // eq and oeq must have at least two arguments, they can have more since they are chainable - case OP_EQ: return num_args >= 2 ? mk_eq_decl_core("=", OP_EQ, join(m_manager->get_sort(args[0]), - m_manager->get_sort(args[1])), m_eq_decls) : 0; - case OP_OEQ: return num_args >= 2 ? mk_eq_decl_core("~", OP_OEQ, join(m_manager->get_sort(args[0]), - m_manager->get_sort(args[1])), m_oeq_decls) : 0; + case OP_EQ: return num_args >= 2 ? mk_eq_decl_core("=", OP_EQ, join(num_args, args), m_eq_decls) : 0; + case OP_OEQ: return num_args >= 2 ? mk_eq_decl_core("~", OP_OEQ, join(num_args, args), m_oeq_decls) : 0; case OP_DISTINCT: return decl_plugin::mk_func_decl(k, num_parameters, parameters, num_args, args, range); default: diff --git a/src/ast/ast.h b/src/ast/ast.h index 216073125..214668536 100644 --- a/src/ast/ast.h +++ b/src/ast/ast.h @@ -1101,6 +1101,8 @@ protected: func_decl * mk_eq_decl_core(char const * name, decl_kind k, sort * s, ptr_vector & cache); func_decl * mk_ite_decl(sort * s); sort* join(sort* s1, sort* s2); + sort* join(unsigned n, sort*const* srts); + sort* join(unsigned n, expr*const* es); public: basic_decl_plugin();