From 5b219aab76a7443687d4a36846eaa80e43dfe348 Mon Sep 17 00:00:00 2001
From: Nikolaj Bjorner <nbjorner@microsoft.com>
Date: Wed, 20 Jul 2022 20:32:00 -0700
Subject: [PATCH] add mutual recursive datatypes to c++ API #6179

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
---
 examples/c++/example.cpp | 30 ++++++++++++++++++++++++-
 src/api/c++/z3++.h       | 48 +++++++++++++++++++++++++++++++---------
 2 files changed, 67 insertions(+), 11 deletions(-)

diff --git a/examples/c++/example.cpp b/examples/c++/example.cpp
index 980243761..eb6d2c19b 100644
--- a/examples/c++/example.cpp
+++ b/examples/c++/example.cpp
@@ -975,6 +975,34 @@ void datatype_example() {
     cs.query(1, cons, is_cons, cons_acc);
     std::cout << nil << " " << is_nil << " " << nil_acc << "\n";
     std::cout << cons << " " << is_cons << " " << cons_acc << "\n";
+
+    symbol tree = ctx.str_symbol("tree");
+    symbol tlist = ctx.str_symbol("tree_list");
+    symbol accs1[2] = { ctx.str_symbol("left"), ctx.str_symbol("right") };
+    symbol accs2[2] = { ctx.str_symbol("hd"), ctx.str_symbol("tail") };
+    sort sorts1[2] = { ctx.datatype_sort(tlist), ctx.datatype_sort(tlist) };
+    sort sorts2[2] = { ctx.int_sort(), ctx.datatype_sort(tree) };
+    constructors cs1(ctx), cs2(ctx);
+    cs1.add(ctx.str_symbol("tnil"), ctx.str_symbol("is-tnil"), 0, nullptr, nullptr);
+    cs1.add(ctx.str_symbol("tnode"), ctx.str_symbol("is-tnode"), 2, accs1, sorts1);
+    constructor_list cl1(cs1);
+    cs2.add(ctx.str_symbol("lnil"), ctx.str_symbol("is-lnil"), 0, nullptr, nullptr);
+    cs2.add(ctx.str_symbol("lcons"), ctx.str_symbol("is-lcons"), 2, accs2, sorts2);
+    constructor_list cl2(cs2);
+    symbol names[2] = { tree, tlist };
+    constructor_list* cl[2] = { &cl1, &cl2 };
+    sort_vector dsorts = ctx.datatypes(2, names, cl);
+    std::cout << dsorts << "\n";
+    cs1.query(0, nil, is_nil, nil_acc);
+    cs1.query(1, cons, is_cons, cons_acc);
+    std::cout << nil << " " << is_nil << " " << nil_acc << "\n";
+    std::cout << cons << " " << is_cons << " " << cons_acc << "\n";
+
+    cs2.query(0, nil, is_nil, nil_acc);
+    cs2.query(1, cons, is_cons, cons_acc);
+    std::cout << nil << " " << is_nil << " " << nil_acc << "\n";
+    std::cout << cons << " " << is_cons << " " << cons_acc << "\n";
+
 }
 
 void expr_vector_example() {
@@ -1328,7 +1356,7 @@ void iterate_args() {
 
 int main() {
 
-    try {        
+    try {
         demorgan(); std::cout << "\n";
         find_model_example1(); std::cout << "\n";
         prove_example1(); std::cout << "\n";
diff --git a/src/api/c++/z3++.h b/src/api/c++/z3++.h
index 50da0fcab..5f630d814 100644
--- a/src/api/c++/z3++.h
+++ b/src/api/c++/z3++.h
@@ -326,6 +326,16 @@ namespace z3 {
          */
         sort datatype(symbol const& name, constructors const& cs);
 
+        /**
+           \brief Create a set of mutually recursive datatypes.
+           \c n - number of recursive datatypes
+           \c names - array of names of length n
+           \c cons - array of constructor lists of length n
+        */
+        sort_vector datatypes(unsigned n, symbol const* names,
+                              constructor_list *const* cons);
+                       
+
         /**
            \brief a reference to a recursively defined datatype.
            Expect that it gets defined as a \ref datatype.
@@ -3354,15 +3364,13 @@ namespace z3 {
         context& ctx;
         Z3_constructor_list clist;
     public:
-        constructor_list(context& ctx, unsigned n, Z3_constructor const* cons): ctx(ctx) {
-            clist = Z3_mk_constructor_list(ctx, n, cons);
-        }
-        ~constructor_list() {
-            Z3_del_constructor_list(ctx, clist);
-        }
+        constructor_list(constructors const& cs);
+        ~constructor_list() { Z3_del_constructor_list(ctx, clist); }
+        operator Z3_constructor_list() const { return clist; }
     };
     
     class constructors {
+        friend class constructor_list;
         context&       ctx;
         std::vector<Z3_constructor> cons;
         std::vector<unsigned> num_fields;
@@ -3386,15 +3394,12 @@ namespace z3 {
         Z3_constructor operator[](unsigned i) const { return cons[i]; }
 
         unsigned size() const { return (unsigned)cons.size(); }
-
-        constructor_list get_constructors() const {
-            return constructor_list(ctx, (unsigned)cons.size(), cons.data());
-        }
         
         void query(unsigned i, func_decl& constructor, func_decl& test, func_decl_vector& accs) {
             Z3_func_decl _constructor;
             Z3_func_decl _test;
             array<Z3_func_decl> accessors(num_fields[i]);
+            accs.resize(0);
             Z3_query_constructor(ctx,
                                  cons[i],
                                  num_fields[i],
@@ -3408,6 +3413,13 @@ namespace z3 {
                 accs.push_back(func_decl(ctx, accessors[j]));
         }
     };
+    
+    constructor_list::constructor_list(constructors const& cs): ctx(cs.ctx) {
+        array<Z3_constructor> cons(cs.size());
+        for (unsigned i = 0; i < cs.size(); ++i)
+            cons[i] = cs[i];
+        clist = Z3_mk_constructor_list(ctx, cs.size(), cons.ptr());
+    }
 
     inline sort context::datatype(symbol const& name, constructors const& cs) {
         array<Z3_constructor> _cs(cs.size());
@@ -3417,6 +3429,22 @@ namespace z3 {
         return sort(*this, s);
     }
 
+    inline sort_vector context::datatypes(
+        unsigned n, symbol const* names,
+        constructor_list *const* cons) {
+        sort_vector result(*this);
+        array<Z3_symbol> _names(n);
+        array<Z3_sort> _sorts(n);
+        array<Z3_constructor_list> _cons(n);
+        for (unsigned i = 0; i < n; ++i)
+            _names[i] = names[i], _cons[i] = *cons[i];
+        Z3_mk_datatypes(*this, n, _names.ptr(), _sorts.ptr(), _cons.ptr());
+        for (unsigned i = 0; i < n; ++i)
+            result.push_back(sort(*this, _sorts[i]));
+        return result;
+    }
+
+
     inline sort context::datatype_sort(symbol const& name) {
         Z3_sort s = Z3_mk_datatype_sort(*this, name);
         check_error();