From 03a5c88dedfc148308c04be596adc56ed251d674 Mon Sep 17 00:00:00 2001
From: "Christoph M. Wintersteiger" <cwinter@microsoft.com>
Date: Tue, 11 Dec 2012 15:07:33 +0000
Subject: [PATCH] More new ML API

Signed-off-by: Christoph M. Wintersteiger <cwinter@microsoft.com>
---
 .gitignore            |   3 +
 scripts/mk_util.py    |  89 ++++++++++++++++++++++++++++++
 scripts/update_api.py | 125 ++++++++++++++++++++++++------------------
 3 files changed, 164 insertions(+), 53 deletions(-)

diff --git a/.gitignore b/.gitignore
index 65a69bf51..3c4fb1218 100644
--- a/.gitignore
+++ b/.gitignore
@@ -67,3 +67,6 @@ src/api/java/enumerations/*.java
 *.bak
 doc/api
 doc/code
+src/api/ml/z3_native.c
+src/api/ml/z3_native.ml
+src/api/ml/z3_enums.ml
diff --git a/scripts/mk_util.py b/scripts/mk_util.py
index cb6b17cce..8543f1ffa 100644
--- a/scripts/mk_util.py
+++ b/scripts/mk_util.py
@@ -49,6 +49,7 @@ UTIL_COMPONENT='util'
 API_COMPONENT='api'
 DOTNET_COMPONENT='dotnet'
 JAVA_COMPONENT='java'
+ML_COMPONENT='ml'
 CPP_COMPONENT='cpp'
 #####################
 IS_WINDOWS=False
@@ -2230,6 +2231,8 @@ def mk_bindings(api_files):
         if is_java_enabled():
             check_java()
             mk_z3consts_java(api_files)
+        if is_ml_enabled():
+            mk_z3consts_ml(api_files)
         _execfile(os.path.join('scripts', 'update_api.py'), g) # HACK
         cp_z3py_to_build()
 
@@ -2503,6 +2506,92 @@ def mk_z3consts_java(api_files):
     if VERBOSE:
         print("Generated '%s'" % ('%s' % gendir))
 
+# Extract enumeration types from z3_api.h, and add ML definitions
+def mk_z3consts_ml(api_files):
+    blank_pat      = re.compile("^ *$")
+    comment_pat    = re.compile("^ *//.*$")
+    typedef_pat    = re.compile("typedef enum *")
+    typedef2_pat   = re.compile("typedef enum { *")
+    openbrace_pat  = re.compile("{ *")
+    closebrace_pat = re.compile("}.*;")
+
+    ml = get_component(ML_COMPONENT)
+
+    DeprecatedEnums = [ 'Z3_search_failure' ]
+    gendir = ml.src_dir
+    if not os.path.exists(gendir):
+        os.mkdir(gendir)
+
+    efile  = open('%s.ml' % os.path.join(gendir, "z3_enums"), 'w')
+    efile.write('(* Automatically generated file *)\n\n')
+    # efile.write('module z3_enums = struct\n\n');
+
+
+    for api_file in api_files:
+        api_file_c = ml.find_file(api_file, ml.name)
+        api_file   = os.path.join(api_file_c.src_dir, api_file)
+
+        api = open(api_file, 'r')
+
+        SEARCHING  = 0
+        FOUND_ENUM = 1
+        IN_ENUM    = 2
+
+        mode    = SEARCHING
+        decls   = {}
+        idx     = 0
+
+        linenum = 1
+        for line in api:
+            m1 = blank_pat.match(line)
+            m2 = comment_pat.match(line)
+            if m1 or m2:
+                # skip blank lines and comments
+                linenum = linenum + 1 
+            elif mode == SEARCHING:
+                m = typedef_pat.match(line)
+                if m:
+                    mode = FOUND_ENUM
+                m = typedef2_pat.match(line)
+                if m:
+                    mode = IN_ENUM
+                    decls = {}
+                    idx   = 0
+            elif mode == FOUND_ENUM:
+                m = openbrace_pat.match(line)
+                if m:
+                    mode  = IN_ENUM
+                    decls = {}
+                    idx   = 0
+                else:
+                    assert False, "Invalid %s, line: %s" % (api_file, linenum)
+            else:
+                assert mode == IN_ENUM
+                words = re.split('[^\-a-zA-Z0-9_]+', line)
+                m = closebrace_pat.match(line)
+                if m:
+                    name = words[1]
+                    if name not in DeprecatedEnums:
+                        efile.write('\n(* %s *)\n' % name)
+                        efile.write('type %s =\n' % name[3:]) # strip Z3_
+                        efile.write
+                        for k, i in decls.iteritems():
+                            efile.write('    | %s \n' % k[3:]) # strip Z3_
+                    mode = SEARCHING
+                else:
+                    if words[2] != '':
+                        if len(words[2]) > 1 and words[2][1] == 'x':
+                            idx = int(words[2], 16)
+                        else:
+                            idx = int(words[2])
+                    decls[words[1]] = idx
+                    idx = idx + 1
+            linenum = linenum + 1
+    efile.write('\n')
+    # efile.write'end\n');
+    if VERBOSE:
+        print "Generated '%s/z3_enums.ml'" % ('%s' % gendir)
+
 def mk_gui_str(id):
     return '4D2F40D8-E5F9-473B-B548-%012d' % id
 
diff --git a/scripts/update_api.py b/scripts/update_api.py
index 9ef0ce9a2..1e86dc0ef 100644
--- a/scripts/update_api.py
+++ b/scripts/update_api.py
@@ -156,9 +156,9 @@ Type2JavaW = { VOID : 'void', VOID_PTR : 'jlong', INT : 'jint', UINT : 'jint', I
                BOOL : 'jboolean', SYMBOL : 'jlong', PRINT_MODE : 'jint', ERROR_CODE : 'jint'}
 
 # Mapping to ML types
-Type2ML = { VOID : 'void', VOID_PTR : 'long', INT : 'int', UINT : 'int', INT64 : 'long', UINT64 : 'long', DOUBLE : 'double',
-              STRING : 'char*', STRING_PTR : 'char**', 
-              BOOL : 'boolean', SYMBOL : 'long', PRINT_MODE : 'int', ERROR_CODE : 'int' }
+Type2ML = { VOID : 'unit', VOID_PTR : 'long', INT : 'int', UINT : 'int', INT64 : 'long', UINT64 : 'long', DOUBLE : 'double',
+              STRING : 'string', STRING_PTR : 'char**', 
+              BOOL : 'lbool', SYMBOL : 'symbol', PRINT_MODE : 'ast_print_mode', ERROR_CODE : 'error_code' }
 
 next_type_id = FIRST_OBJ_ID
 
@@ -332,13 +332,13 @@ def param2ml(p):
         elif param_type(p) == INT64 or param_type(p) == UINT64 or param_type(p) >= FIRST_OBJ_ID:
             return "long*"
         elif param_type(p) == STRING:
-            return "char*"
+            return "string"
         else:
             print "ERROR: unreachable code"
             assert(False)
             exit(1)
     if k == IN_ARRAY or k == INOUT_ARRAY or k == OUT_ARRAY:
-        return "%s[]" % type2ml(param_type(p))
+        return "%s array" % type2ml(param_type(p))
     else:
         return type2ml(param_type(p))
 
@@ -1069,9 +1069,7 @@ def mk_bindings():
     exe_c.write("}\n")
 
 def ml_method_name(name):
-    result = ''
-    name = name[3:] # Remove Z3_
-    return result
+    return name[3:] # Remove Z3_
 
 def mk_ml():
     if not is_ml_enabled():
@@ -1080,10 +1078,57 @@ def mk_ml():
     ml_nativef  = os.path.join(ml_dir, 'z3_native.ml')
     ml_wrapperf = os.path.join(ml_dir, 'z3_native.c')
     ml_native   = open(ml_nativef, 'w')
-    ml_native.write('// Automatically generated file\n')
+    ml_native.write('(* Automatically generated file *)\n')
     ml_native.write('\n')
+    ml_native.write('module Z3 = struct\n\n')
+    ml_native.write('type context\n')
+    ml_native.write('and symbol\n')
+    ml_native.write('and ast\n')
+    ml_native.write('and sort = private ast\n')
+    ml_native.write('and func_decl = private ast\n')
+    ml_native.write('and app = private ast\n')
+    ml_native.write('and pattern = private ast\n')
+    ml_native.write('and params\n')
+    ml_native.write('and param_descrs\n')
+    ml_native.write('and model\n')
+    ml_native.write('and func_interp\n')
+    ml_native.write('and func_entry\n')
+    ml_native.write('and fixedpoint\n')
+    ml_native.write('and ast_vector\n')
+    ml_native.write('and ast_map\n')
+    ml_native.write('and goal\n')
+    ml_native.write('and tactic\n')
+    ml_native.write('and probe\n')
+    ml_native.write('and apply_result\n')
+    ml_native.write('and solver\n')
+    ml_native.write('and stats\n')
+    ml_native.write('\n')
+    ml_native.write('  exception Z3Exception of string\n\n')
     for name, result, params in _dotnet_decls:
-        ml_native.write(' external %s : (' % ml_method_name(name))
+        ml_native.write('  external native_%s : ' % ml_method_name(name))
+        i = 0;
+        for param in params:
+            ml_native.write('%s -> ' % param2ml(param))
+            i = i + 1
+        ml_native.write('%s\n' % (type2ml(result)))
+        ml_native.write('    = "Native_Z3_%s"\n\n' % ml_method_name(name))
+    # Exception wrappers
+    for name, result, params in _dotnet_decls:
+        ml_native.write(' let %s ' % ml_method_name(name))
+        first = True
+        i = 0;
+        for param in params:
+            if first:
+                first = False;
+            else:
+                ml_native.write(' ')
+            ml_native.write('a%d' % i)
+            i = i + 1
+        ml_native.write(' = \n')
+        ml_native.write('    ')
+        if result != VOID:
+            ml_native.write('let res = ')
+        ml_native.write('n_%s(' % (ml_method_name(name)))
         first = True
         i = 0;
         for param in params:
@@ -1091,50 +1136,24 @@ def mk_ml():
                 first = False
             else:
                 ml_native.write(', ')
-            ml_native.write('%s a%d' % (param2ml(param), i))
+            ml_native.write('a%d' % i)
             i = i + 1
-    ml_native.write('%s)\n' % (type2ml(result)))
-    # ml_native.write('    = "NATIVE_%s"' % ml_method_name(name))
-    # ml_native.write('\n\n')
-    # # Exception wrappers
-    # for name, result, params in _dotnet_decls:
-    #     java_native.write('  public static %s %s(' % (type2java(result), java_method_name(name)))
-    #     first = True
-    #     i = 0;
-    #     for param in params:
-    #         if first:
-    #             first = False
-    #         else:
-    #             java_native.write(', ')
-    #         java_native.write('%s a%d' % (param2java(param), i))
-    #         i = i + 1
-    #     java_native.write(')')
-    #     if len(params) > 0 and param_type(params[0]) == CONTEXT:
-    #         java_native.write(' throws Z3Exception')
-    #     java_native.write('\n')
-    #     java_native.write('  {\n')
-    #     java_native.write('      ')
-    #     if result != VOID:
-    #         java_native.write('%s res = ' % type2java(result))
-    #     java_native.write('INTERNAL%s(' % (java_method_name(name)))
-    #     first = True
-    #     i = 0;
-    #     for param in params:
-    #         if first:
-    #             first = False
-    #         else:
-    #             java_native.write(', ')
-    #         java_native.write('a%d' % i)
-    #         i = i + 1
-    #     java_native.write(');\n')
-    #     if len(params) > 0 and param_type(params[0]) == CONTEXT:
-    #         java_native.write('      Z3_error_code err = Z3_error_code.fromInt(INTERNALgetErrorCode(a0));\n')
-    #         java_native.write('      if (err != Z3_error_code.Z3_OK)\n')
-    #         java_native.write('          throw new Z3Exception(INTERNALgetErrorMsgEx(a0, err.toInt()));\n')
-    #     if result != VOID:
-    #         java_native.write('      return res;\n')
-    #     java_native.write('  }\n\n')
-    # java_native.write('}\n')
+        ml_native.write(')')
+        if result != VOID:
+            ml_native.write(' in\n')
+        else:
+            ml_native.write(';\n')
+        if len(params) > 0 and param_type(params[0]) == CONTEXT:
+            ml_native.write('    let err = error_code.fromInt(n_get_error_code(a0)) in \n')
+            ml_native.write('      if err <> Z3_enums.OK then\n')
+            ml_native.write('        raise (z3_exception n_get_error_msg_ex(a0, err.toInt()))\n')
+            ml_native.write('      else\n')
+        if result == VOID:
+            ml_native.write('        ()\n')
+        else:
+            ml_native.write('        res\n')
+        ml_native.write('\n')
+    ml_native.write('\nend\n')
     ml_wrapper = open(ml_wrapperf, 'w')
     ml_wrapper.write('// Automatically generated file\n\n')
     ml_wrapper.write('#include <stddef.h>\n')