From 7ae90f0b20ef3dc24e0f220f9f30db1ca870dd0c Mon Sep 17 00:00:00 2001
From: "Christoph M. Wintersteiger" <cwinter@microsoft.com>
Date: Fri, 21 Dec 2012 07:18:46 +0000
Subject: [PATCH] More ML API: Fixes in native layer. Added symbols. Prepared
 code for automatic documentation.

Signed-off-by: Christoph M. Wintersteiger <cwinter@microsoft.com>
---
 examples/ml/ml_example.ml |   5 ++
 scripts/mk_util.py        |  89 +-------------------
 scripts/update_api.py     |  37 ++++++---
 src/api/ml/z3.ml          | 168 ++++++++++++++++++++++++++------------
 4 files changed, 148 insertions(+), 151 deletions(-)

diff --git a/examples/ml/ml_example.ml b/examples/ml/ml_example.ml
index c79b3cdaa..7046a92b0 100644
--- a/examples/ml/ml_example.ml
+++ b/examples/ml/ml_example.ml
@@ -4,6 +4,7 @@
 *)
 
 open Z3
+open Z3.Context
 
 exception ExampleException of string
 
@@ -15,6 +16,10 @@ let _ =
       Printf.printf "Running Z3 version %s\n" Version.to_string ;
       let cfg = [("model", "true"); ("proof", "false")] in
       let ctx = (new context cfg) in
+      let is = (mk_symbol_int ctx 42) in
+      let ss = (mk_symbol_string ctx "mySymbol") in
+      Printf.printf "int symbol: %s\n" (Symbol.to_string (is :> symbol));
+      Printf.printf "string symbol: %s\n" (Symbol.to_string (ss :> symbol));
       Printf.printf "Disposing...\n";
       ctx#dispose (* can do, but we'd rather let it go out of scope *) ;
     );
diff --git a/scripts/mk_util.py b/scripts/mk_util.py
index e44c120cb..5ca0d2299 100644
--- a/scripts/mk_util.py
+++ b/scripts/mk_util.py
@@ -2777,99 +2777,12 @@ def mk_z3consts_ml(api_files):
                             efile.write('  | %s \n' % k[3:]) # strip Z3_
                         efile.write('\n')
                         efile.write('(** Convert %s to int*)\n' % name[3:])
-                        efile.write('let int_of_%s x : int =\n' % (name[3:])) # strip Z3_
-                        efile.write('  match x with\n')
-                        for k, i in decls.iteritems():
-                            efile.write('  | %s -> %d\n' % (k[3:], i))
-                        efile.write('\n')
-                        efile.write('(** Convert int to %s*)\n' % name[3:])
-                        efile.write('let %s_of_int x : %s =\n' % (name[3:],name[3:])) # strip Z3_
-                        efile.write('  match x with\n')
-                        for k, i in decls.iteritems():
-                            efile.write('  | %d -> %s\n' % (i, k[3:]))
-                        # use Z3.Exception?
-                        efile.write('  | _ -> raise (Failure "undefined enum value")\n\n')
-                    mode = SEARCHING
-                else:
-                    if words[2] != '':
-                        if len(words[2]) > 1 and words[2][1] == 'x':
-                            idx = int(words[2], 16)
-                        else:
-                            idx = int(words[2])
-                    decls[words[1]] = idx
-                    idx = idx + 1
-            linenum = linenum + 1
-    if VERBOSE:
-        print "Generated '%s/z3enums.ml'" % ('%s' % gendir)
-    efile  = open('%s.mli' % os.path.join(gendir, "z3enums"), 'w')
-    efile.write('(* Automatically generated file *)\n\n')
-    efile.write('(** The enumeration types of Z3. *)\n\n')
-    efile.write('module Z3enums = struct\n')
-    for api_file in api_files:
-        api_file_c = ml.find_file(api_file, ml.name)
-        api_file   = os.path.join(api_file_c.src_dir, api_file)
-
-        api = open(api_file, 'r')
-
-        SEARCHING  = 0
-        FOUND_ENUM = 1
-        IN_ENUM    = 2
-
-        mode    = SEARCHING
-        decls   = {}
-        idx     = 0
-
-        linenum = 1
-        for line in api:
-            m1 = blank_pat.match(line)
-            m2 = comment_pat.match(line)
-            if m1 or m2:
-                # skip blank lines and comments
-                linenum = linenum + 1 
-            elif mode == SEARCHING:
-                m = typedef_pat.match(line)
-                if m:
-                    mode = FOUND_ENUM
-                m = typedef2_pat.match(line)
-                if m:
-                    mode = IN_ENUM
-                    decls = {}
-                    idx   = 0
-            elif mode == FOUND_ENUM:
-                m = openbrace_pat.match(line)
-                if m:
-                    mode  = IN_ENUM
-                    decls = {}
-                    idx   = 0
-                else:
-                    assert False, "Invalid %s, line: %s" % (api_file, linenum)
-            else:
-                assert mode == IN_ENUM
-                words = re.split('[^\-a-zA-Z0-9_]+', line)
-                m = closebrace_pat.match(line)
-                if m:
-                    name = words[1]
-                    if name not in DeprecatedEnums:
-                        efile.write('(** %s *)\n' % name[3:])
-                        efile.write('type %s =\n' % name[3:]) # strip Z3_
-                        for k, i in decls.iteritems():
-                            efile.write('  | %s \n' % k[3:]) # strip Z3_
-                        efile.write('\n')
-                        efile.write('(** Convert %s to int*)\n' % name[3:])
-                        efile.write('val int_of_%s : %s -> int\n' % (name[3:], name[3:])) # strip Z3_
-                        efile.write('(** Convert int to %s*)\n' % name[3:])
-                        efile.write('val %s_of_int : int -> %s\n' % (name[3:],name[3:])) # strip Z3_
-                        efile.write('\n')
-                        efile.write('\n(* %s *)\n' % name)
-                        efile.write('type %s =\n' % name[3:]) # strip Z3_
-                        for k, i in decls.iteritems():
-                            efile.write('    | %s \n' % k[3:]) # strip Z3_
-                        efile.write('\n')
                         efile.write('let %s2int x : int =\n' % (name[3:])) # strip Z3_
                         efile.write('  match x with\n')
                         for k, i in decls.iteritems():
                             efile.write('  | %s -> %d\n' % (k[3:], i))
                         efile.write('\n')
+                        efile.write('(** Convert int to %s*)\n' % name[3:])
                         efile.write('let int2%s x : %s =\n' % (name[3:],name[3:])) # strip Z3_
                         efile.write('  match x with\n')
                         for k, i in decls.iteritems():
diff --git a/scripts/update_api.py b/scripts/update_api.py
index d0ef2ec86..c8cd05c7e 100644
--- a/scripts/update_api.py
+++ b/scripts/update_api.py
@@ -1106,26 +1106,26 @@ def arrayparams(params):
     return op
 
 
-def ml_unwrap(t):
+def ml_unwrap(t, ts, s):
     if t == STRING:
-        return 'String_val'
-    elif t == BOOL or t == INT or PRINT_MODE or ERROR_CODE:
-        return 'Int_val'
+        return '(' + ts + ') String_val(' + s + ')'
+    elif t == BOOL or t == INT or t == PRINT_MODE or t == ERROR_CODE:
+        return '(' + ts + ') Int_val(' + s + ')'
     elif t == UINT:
-        return 'Unsigned_int_val'
+        return '(' + ts + ') Unsigned_int_val(' + s + ')'
     elif t == INT64:
-        return 'Long_val'
+        return '(' + ts + ') Long_val(' + s + ')'
     elif t == UINT64:
-        return 'Unsigned_long_val'
+        return '(' + ts + ') Unsigned_long_val(' + s + ')'
     elif t == DOUBLE:
-        return 'Double_val'
+        return '(' + ts + ') Double_val(' + s + ')'
     else:
-        return 'Data_custom_val'
+        return '* (' + ts + '*) Data_custom_val(' + s + ')'
 
 def ml_set_wrap(t, d, n):
     if t == VOID:
         return d + ' = Val_unit;'
-    elif t == BOOL or t == INT or t == UINT or PRINT_MODE or ERROR_CODE:
+    elif t == BOOL or t == INT or t == UINT or t == PRINT_MODE or t == ERROR_CODE:
         return d + ' = Val_int(' + n + ');'
     elif t == INT64 or t == UINT64:
         return d + ' = Val_long(' + n + ');'
@@ -1135,7 +1135,7 @@ def ml_set_wrap(t, d, n):
         return d + ' = caml_copy_string((const char*) ' + n + ');'
     else:
         ts = type2str(t)
-        return d + ' = caml_alloc_custom(0, sizeof(' + ts + '), 0, 1); memcpy( Data_custom_val(' + d + '), &' + n + ', sizeof(' + ts + '));'
+        return d + ' = caml_alloc_custom(&default_custom_ops, sizeof(' + ts + '), 0, 1); memcpy( Data_custom_val(' + d + '), &' + n + ', sizeof(' + ts + '));'
 
 def mk_ml():
     global Type2Str
@@ -1146,7 +1146,9 @@ def mk_ml():
     ml_wrapperf = os.path.join(ml_dir, 'z3native.c')
     ml_native   = open(ml_nativef, 'w')
     ml_native.write('(* Automatically generated file *)\n\n')
+    ml_native.write('(** The native (raw) interface to the dynamic Z3 library. *)\n\n')
     ml_native.write('open Z3enums\n\n')
+    ml_native.write('(**/**)\n')
     ml_native.write('type ptr\n')
     ml_native.write('and z3_symbol = ptr\n')
     for k, v in Type2Str.iteritems():
@@ -1226,6 +1228,7 @@ def mk_ml():
         else:
             ml_native.write('        res\n')
         ml_native.write('\n')
+    ml_native.write('(**/**)\n')
 
     # C interface
     ml_wrapper = open(ml_wrapperf, 'w')
@@ -1276,6 +1279,14 @@ def mk_ml():
     ml_wrapper.write('  CAMLxparam5(X6,X7,X8,X9,X10);                                     \\\n')
     ml_wrapper.write('  CAMLxparam3(X11,X12,X13);                                           \n')
     ml_wrapper.write('\n\n')
+    ml_wrapper.write('static struct custom_operations default_custom_ops = {\n')
+    ml_wrapper.write('  identifier: "default handling",\n')
+    ml_wrapper.write('  finalize:    custom_finalize_default,\n')
+    ml_wrapper.write('  compare:     custom_compare_default,\n')
+    ml_wrapper.write('  hash:        custom_hash_default,\n')
+    ml_wrapper.write('  serialize:   custom_serialize_default,\n')
+    ml_wrapper.write('  deserialize: custom_deserialize_default\n')
+    ml_wrapper.write('};\n\n')
     ml_wrapper.write('#ifdef __cplusplus\n')
     ml_wrapper.write('extern "C" {\n')
     ml_wrapper.write('#endif\n\n')
@@ -1342,10 +1353,10 @@ def mk_ml():
                 t = param_type(param)
                 ts = type2str(t)
                 ml_wrapper.write('  %s * _a%s = (%s*) malloc(sizeof(%s) * a%s);\n' % (ts, i, ts, ts, param_array_capacity_pos(param)))
-                ml_wrapper.write('  for (unsigned i = 0; i < a%s; i++) _a%s[i] = (%s) %s(Field(a%s, i));\n' % (param_array_capacity_pos(param), i, ts, ml_unwrap(t), i))
+                ml_wrapper.write('  for (unsigned i = 0; i < _a%s; i++) _a%s[i] = %s;\n' % (param_array_capacity_pos(param), i, ml_unwrap(t, ts, 'Field(a' + str(i) + ', i)')))
             elif k == IN:
                 t = param_type(param)
-                ml_wrapper.write('  %s _a%s = (%s) %s(a%s);\n' % (type2str(t), i, type2str(t), ml_unwrap(t), i))
+                ml_wrapper.write('  %s _a%s = %s;\n' % (type2str(t), i, ml_unwrap(t, type2str(t), 'a' + str(i))))
             elif k == OUT:
                 ml_wrapper.write('  %s _a%s;\n' % (type2str(param_type(param)), i))
             elif k == INOUT:
diff --git a/src/api/ml/z3.ml b/src/api/ml/z3.ml
index d1ba2b746..321b17c82 100644
--- a/src/api/ml/z3.ml
+++ b/src/api/ml/z3.ml
@@ -1,33 +1,17 @@
-(* 
+(**
+   The Z3 ML/Ocaml Interface.
+
    Copyright (C) 2012 Microsoft Corporation
-   Author: CM Wintersteiger (cwinter) 2012-12-17
+   @author CM Wintersteiger (cwinter) 2012-12-17
 *)
 
 open Z3enums
 open Z3native
 
-module Log = 
-struct
-  let m_is_open = false
-  (* CMW: "open" seems to be an invalid function name*)
-  let open_ fn = ((int2lbool (open_log fn)) == L_TRUE)
-  let close = close_log
-  let append s = append_log s
-end
+(**/**)
 
-module Version =
-struct
-  let major = let (x, _, _, _) = get_version in x
-  let minor = let (_, x, _, _) = get_version in x
-  let build = let (_, _, x, _) = get_version in x
-  let revision = let (_, _, _, x) = get_version in x
-  let to_string = 
-    let (mj, mn, bld, rev) = get_version in
-    string_of_int mj ^ "." ^
-      string_of_int mn ^ "." ^
-      string_of_int bld ^ "." ^
-      string_of_int rev ^ "."
-end
+(* Object definitions. These are internal and should be interacted 
+   with only via the corresponding functions from modules. *)
 
 class virtual idisposable = 
 object
@@ -61,7 +45,7 @@ object (self)
 
   method sub_one_ctx_obj = m_refCount <- m_refCount - 1
   method add_one_ctx_obj = m_refCount <- m_refCount + 1
-  method get_native = m_n_ctx
+  method gno = m_n_ctx
 end
 
 class virtual z3object ctx_init obj_init =
@@ -91,9 +75,9 @@ object (self)
       | None -> ()
     ); 
 
-  method get_native_object = m_n_obj
+  method gno = m_n_obj
 
-  method set_native_object x =
+  method sno x =
     (match x with
       | Some(x) -> self#incref x
       | None -> ()
@@ -105,11 +89,11 @@ object (self)
     m_n_obj <- x
 
   method get_context = m_ctx
-  method get_native_context = m_ctx#get_native
+  method gnc = m_ctx#gno
 
 (*
   method array_to_native a =
-    let f e = e#get_native_object in 
+    let f e = e#gno in 
     (Array.map f a) 
 
   method array_length a =
@@ -120,61 +104,145 @@ object (self)
 
 end
 
-class symbol ctx_init obj_init = 
+class symbol ctx obj = 
 object (self)
-  inherit z3object ctx_init obj_init
+  inherit z3object ctx obj
 
   method incref o = ()
   method decref o = ()
 end
 
-class int_symbol ctx_init obj_init  = 
+class int_symbol ctx = 
 object(self)
-  inherit symbol ctx_init obj_init
+  inherit symbol ctx None
+  method cnstr_obj obj = (self#sno obj) ; self
+  method cnstr_int i = (self#sno (Some (mk_int_symbol ctx#gno i))) ; self
 end
 
-class string_symbol ctx_init obj_init = 
+class string_symbol ctx = 
 object(self)
-  inherit symbol ctx_init obj_init
+  inherit symbol ctx None
+  method cnstr_obj obj = (self#sno obj) ; self
+  method cnstr_string name = (self#sno (Some (mk_string_symbol ctx#gno name))) ; self
 end
 
+(**/**)
+
+(** Interaction logging for Z3.
+    Note that this is a global, static log and if multiple Context 
+    objects are created, it logs the interaction with all of them. *)
+module Log = 
+struct
+  (** Open an interaction log file. 
+      @param filename the name of the file to open.
+      @return True if opening the log file succeeds, false otherwise.
+  *)
+  (* CMW: "open" seems to be a reserved keyword? *)
+  let open_ filename = ((int2lbool (open_log filename)) == L_TRUE)
+
+  (** Closes the interaction log. *)
+  let close = close_log
+
+  (** Appends a user-provided string to the interaction log.
+      @param s the string to append*)
+  let append s = append_log s
+end
+
+(** Version information. *)
+module Version =
+struct
+  (** The major version. *)
+  let major = let (x, _, _, _) = get_version in x
+
+  (** The minor version. *)
+  let minor = let (_, x, _, _) = get_version in x
+
+  (** The build version. *)
+  let build = let (_, _, x, _) = get_version in x
+
+  (** The revision. *)
+  let revision = let (_, _, _, x) = get_version in x
+
+  (** A string representation of the version information. *)
+  let to_string = 
+    let (mj, mn, bld, rev) = get_version in
+    string_of_int mj ^ "." ^
+      string_of_int mn ^ "." ^
+      string_of_int bld ^ "." ^
+      string_of_int rev ^ "."
+end
+
+(** Symbols are used to name several term and type constructors. *)
 module Symbol =
 struct
+(**/**)
   let create ctx obj =
     match obj with 
       | Some(x) -> (
-	match (int2symbol_kind (get_symbol_kind ctx#get_native x)) with
-	  | INT_SYMBOL -> (new int_symbol ctx obj :> symbol)
-          | STRING_SYMBOL -> (new string_symbol ctx obj :> symbol)
+	match (int2symbol_kind (get_symbol_kind ctx#gno x)) with
+	  | INT_SYMBOL -> (((new int_symbol ctx)#cnstr_obj obj) :> symbol)
+          | STRING_SYMBOL -> (((new string_symbol ctx)#cnstr_obj obj) :> symbol)
       )
       | None -> raise (Exception "Can't create null objects")
+(**/**)
 
-  let kind o = match o#m_n_obj with
-    | Some(x) -> (int2symbol_kind (get_symbol_kind o#get_native_context x))
+  (** The kind of the symbol (int or string) *)
+  let kind (o : symbol) = match o#gno with
+    | Some(x) -> (int2symbol_kind (get_symbol_kind o#gnc x))
     | _ -> raise (Exception "Underlying object lost")
 
-  let is_int_symbol o = match o#m_n_obj with
-    | Some(x) -> x#kind == INT_SYMBOL
+  (** Indicates whether the symbol is of Int kind *)
+  let is_int_symbol (o : symbol) = match o#gno with
+    | Some(x) -> (kind o) == INT_SYMBOL
     | _ -> false
 
-  let is_string_symbol o = match o#m_n_obj with
-    | Some(x) -> x#kind == STRING_SYMBOL
+  (** Indicates whether the symbol is of string kind. *)
+  let is_string_symbol (o : symbol) = match o#gno with
+    | Some(x) -> (kind o) == STRING_SYMBOL
     | _ -> false
 
-  let get_int o = match o#m_n_obj with
-    | Some(x) -> (get_symbol_int o#get_native_context x)
+  (** The int value of the symbol. *)
+  let get_int (o : int_symbol) = match o#gno with
+    | Some(x) -> (get_symbol_int o#gnc x)
     | None -> 0
 
-  let get_string o = match o#m_n_obj with
-    | Some(x) -> (get_symbol_string o#get_native_context x)
+  (** The string value of the symbol. *)
+  let get_string (o : string_symbol) = match o#gno with
+    | Some(x) -> (get_symbol_string o#gnc x)
     | None -> ""
 
-  let to_string o = match o#m_n_obj with
+  (** A string representation of the symbol. *)
+  let to_string (o : symbol) = match o#gno with
     | Some(x) -> 
       (
 	match (kind o) with
-	  | INT_SYMBOL -> (string_of_int (get_symbol_int o#get_native_context x))
-	  | STRING_SYMBOL -> (get_symbol_string o#get_native_context x)
+	  | INT_SYMBOL -> (string_of_int (get_symbol_int o#gnc x))
+	  | STRING_SYMBOL -> (get_symbol_string o#gnc x)
       )
     | None -> ""
 end
+
+(** The main interaction with Z3 happens via the Context. *)
+module Context =
+struct
+  (**
+     Creates a new symbol using an integer.
+     
+     Not all integers can be passed to this function.
+     The legal range of unsigned integers is 0 to 2^30-1.
+  *)
+  let mk_symbol_int ctx i = 
+    (new int_symbol ctx)#cnstr_int i
+    
+  (** Creates a new symbol using a string. *)
+  let mk_symbol_string ctx s =
+    (new string_symbol ctx)#cnstr_string s
+
+  (**
+     Create an array of symbols.
+  *)
+  let mk_symbols ctx names =
+    let f elem = mk_symbol_string ctx elem in
+    (Array.map f names)
+      
+end