diff --git a/scripts/update_api.py b/scripts/update_api.py index d4d1ab0e0..b9da427df 100755 --- a/scripts/update_api.py +++ b/scripts/update_api.py @@ -42,6 +42,7 @@ IN_ARRAY = 3 OUT_ARRAY = 4 INOUT_ARRAY = 5 OUT_MANAGED_ARRAY = 6 +FN_PTR = 7 # Primitive Types VOID = 0 @@ -61,11 +62,16 @@ FLOAT = 13 CHAR = 14 CHAR_PTR = 15 +FIRST_FN_ID = 50 + FIRST_OBJ_ID = 100 def is_obj(ty): return ty >= FIRST_OBJ_ID +def is_fn(ty): + return FIRST_FN_ID <= ty and ty < FIRST_OBJ_ID + Type2Str = { VOID : 'void', VOID_PTR : 'void*', INT : 'int', UINT : 'unsigned', INT64 : 'int64_t', UINT64 : 'uint64_t', DOUBLE : 'double', FLOAT : 'float', STRING : 'Z3_string', STRING_PTR : 'Z3_string_ptr', BOOL : 'bool', SYMBOL : 'Z3_symbol', PRINT_MODE : 'Z3_ast_print_mode', ERROR_CODE : 'Z3_error_code', CHAR: 'char', CHAR_PTR: 'Z3_char_ptr' @@ -88,9 +94,12 @@ Type2ML = { VOID : 'unit', VOID_PTR : 'VOIDP', INT : 'int', UINT : 'int', INT64 FLOAT : 'float', STRING : 'string', STRING_PTR : 'char**', BOOL : 'bool', SYMBOL : 'z3_symbol', PRINT_MODE : 'int', ERROR_CODE : 'int', CHAR : 'char', CHAR_PTR : 'string' } +Closures = [] + class APITypes: def __init__(self): self.next_type_id = FIRST_OBJ_ID + self.next_fntype_id = FIRST_FN_ID def def_Type(self, var, c_type, py_type): """Process type definitions of the form def_Type(var, c_type, py_type) @@ -103,24 +112,42 @@ class APITypes: Type2Str[id] = c_type Type2PyStr[id] = py_type self.next_type_id += 1 + def def_Types(self, api_files): + global Closures pat1 = re.compile(" *def_Type\(\'(.*)\',[^\']*\'(.*)\',[^\']*\'(.*)\'\)[ \t]*") + pat2 = re.compile("Z3_DECLARE_CLOSURE\((.*),(.*), \((.*)\)\)") for api_file in api_files: with open(api_file, 'r') as api: for line in api: m = pat1.match(line) if m: self.def_Type(m.group(1), m.group(2), m.group(3)) + continue + m = pat2.match(line) + if m: + self.fun_Type(m.group(1)) + Closures += [(m.group(1), m.group(2), m.group(3))] + continue # # Populate object type entries in dotnet and ML bindings. # for k in Type2Str: v = Type2Str[k] - if is_obj(k): + if is_obj(k) or is_fn(k): Type2Dotnet[k] = v Type2ML[k] = v.lower() + def fun_Type(self, var): + """Process function type definitions""" + id = self.next_fntype_id + exec('%s = %s' % (var, id), globals()) + Type2Str[id] = var + Type2PyStr[id] = var + self.next_fntype_id += 1 + + def type2str(ty): global Type2Str return Type2Str[ty] @@ -147,6 +174,9 @@ def _in(ty): def _in_array(sz, ty): return (IN_ARRAY, ty, sz) +def _fnptr(ty): + return (FN_PTR, ty) + def _out(ty): return (OUT, ty) @@ -180,7 +210,7 @@ def param_array_size_pos(p): def param2str(p): if param_kind(p) == IN_ARRAY: return "%s const *" % type2str(param_type(p)) - elif param_kind(p) == OUT_ARRAY or param_kind(p) == IN_ARRAY or param_kind(p) == INOUT_ARRAY: + elif param_kind(p) == OUT_ARRAY or param_kind(p) == IN_ARRAY or param_kind(p) == INOUT_ARRAY or param_kind(p) == FN_PTR: return "%s*" % type2str(param_type(p)) elif param_kind(p) == OUT: return "%s*" % type2str(param_type(p)) @@ -374,11 +404,20 @@ def mk_dotnet(dotnet): v = Type2Str[k] if is_obj(k): dotnet.write(' using %s = System.IntPtr;\n' % v) + + dotnet.write(' using voidp = System.IntPtr;\n') dotnet.write('\n') dotnet.write(' public class Native\n') dotnet.write(' {\n\n') - dotnet.write(' [UnmanagedFunctionPointer(CallingConvention.Cdecl)]\n') - dotnet.write(' public delegate void Z3_error_handler(Z3_context c, Z3_error_code e);\n\n') + + for name, ret, sig in Closures: + sig = sig.replace("void*","voidp").replace("unsigned","uint") + ret = ret.replace("void*","voidp").replace("unsigned","uint") + if "*" in sig or "*" in ret: + continue + dotnet.write(' [UnmanagedFunctionPointer(CallingConvention.Cdecl)]\n') + dotnet.write(f" public delegate {ret} {name}({sig});\n") + dotnet.write(' public class LIB\n') dotnet.write(' {\n') dotnet.write(' const string Z3_DLL_NAME = \"libz3\";\n' @@ -1070,6 +1109,9 @@ def def_API(name, result, params): log_c.write(" }\n") log_c.write(" Ap(%s);\n" % sz_e) exe_c.write("reinterpret_cast<%s**>(in.get_obj_array(%s))" % (tstr, i)) + elif kind == FN_PTR: + log_c.write(" P(a%s);\n" % i) + exe_c.write("reinterpret_cast<%s>(in.get_obj(%s))" % (param2str(p), i)) else: error ("unsupported parameter for %s, %s" % (name, p)) i = i + 1 diff --git a/src/api/dotnet/CMakeLists.txt b/src/api/dotnet/CMakeLists.txt index e2055cd7a..1e9f598e7 100644 --- a/src/api/dotnet/CMakeLists.txt +++ b/src/api/dotnet/CMakeLists.txt @@ -113,6 +113,7 @@ set(Z3_DOTNET_ASSEMBLY_SOURCES_IN_SRC_TREE Tactic.cs TupleSort.cs UninterpretedSort.cs + UserPropagator.cs Version.cs Z3Exception.cs Z3Object.cs diff --git a/src/api/dotnet/UserPropagator.cs b/src/api/dotnet/UserPropagator.cs new file mode 100644 index 000000000..8e7831390 --- /dev/null +++ b/src/api/dotnet/UserPropagator.cs @@ -0,0 +1,222 @@ +/*++ +Copyright (c) 2012 Microsoft Corporation + +Module Name: + + UserPropagator.cs + +Abstract: + + User Propagator plugin + +Author: + + Nikolaj Bjorner (nbjorner) 2022-05-07 + +Notes: + +// Todo: fresh, created, declare user function, register_cb, decide, + +--*/ + +using System; +using System.Diagnostics; +using System.Linq; +using System.Collections.Generic; +using System.Runtime.InteropServices; + +namespace Microsoft.Z3 +{ + + using Z3_solver_callback = System.IntPtr; + using Z3_context = System.IntPtr; + using Z3_solver = System.IntPtr; + using voidp = System.IntPtr; + using Z3_ast = System.IntPtr; + + + /// + /// Propagator context for .Net + /// + public class UserPropagator + { + /// + /// Delegate type for fixed callback + /// + public delegate void FixedEh(Expr term, Expr value); + + /// + /// Delegate type for equality or disequality callback + /// + public delegate void EqEh(Expr term, Expr value); + + + Solver solver; + GCHandle gch; + Z3_solver_callback callback; + FixedEh fixed_eh; + Action final_eh; + EqEh eq_eh; + EqEh diseq_eh; + + + unsafe static void _push(voidp ctx, Z3_solver_callback cb) { + var gch = GCHandle.FromIntPtr(ctx); + var prop = (UserPropagator)gch.Target; + prop.callback = cb; + prop.Push(); + } + + unsafe static void _pop(voidp ctx, Z3_solver_callback cb, uint num_scopes) { + var gch = GCHandle.FromIntPtr(ctx); + var prop = (UserPropagator)gch.Target; + prop.callback = cb; + prop.Pop(num_scopes); + } + + unsafe static voidp _fresh(voidp ctx, Z3_context new_context) { + var gch = GCHandle.FromIntPtr(ctx); + var prop = (UserPropagator)gch.Target; + throw new Z3Exception("fresh is NYI"); + } + + unsafe static void _fixed(voidp ctx, Z3_solver_callback cb, Z3_ast _term, Z3_ast _value) { + var gch = GCHandle.FromIntPtr(ctx); + var prop = (UserPropagator)gch.Target; + var term = Expr.Create(prop.solver.Context, _term); + var value = Expr.Create(prop.solver.Context, _value); + prop.callback = cb; + prop.fixed_eh(term, value); + } + + unsafe static void _final(voidp ctx, Z3_solver_callback cb) { + var gch = GCHandle.FromIntPtr(ctx); + var prop = (UserPropagator)gch.Target; + prop.callback = cb; + prop.final_eh(); + } + + unsafe static void _eq(voidp ctx, Z3_solver_callback cb, Z3_ast a, Z3_ast b) { + var gch = GCHandle.FromIntPtr(ctx); + var prop = (UserPropagator)gch.Target; + var s = Expr.Create(prop.solver.Context, a); + var t = Expr.Create(prop.solver.Context, b); + prop.callback = cb; + prop.eq_eh(s, t); + } + + unsafe static void _diseq(voidp ctx, Z3_solver_callback cb, Z3_ast a, Z3_ast b) { + var gch = GCHandle.FromIntPtr(ctx); + var prop = (UserPropagator)gch.Target; + var s = Expr.Create(prop.solver.Context, a); + var t = Expr.Create(prop.solver.Context, b); + prop.callback = cb; + prop.diseq_eh(s, t); + } + + /// + /// Propagator constructor from a solver class. + /// + public UserPropagator(Solver s) + { + gch = GCHandle.Alloc(this); + solver = s; + var cb = GCHandle.ToIntPtr(gch); + Native.Z3_solver_propagate_init(solver.Context.nCtx, solver.NativeObject, cb, _push, _pop, _fresh); + } + + /// + /// Release provate memory. + /// + ~UserPropagator() + { + gch.Free(); + } + + /// + /// Virtual method for push. It must be overwritten by inherited class. + /// + public virtual void Push() { throw new Z3Exception("Push method should be overwritten"); } + + /// + /// Virtual method for pop. It must be overwritten by inherited class. + /// + public virtual void Pop(uint n) { throw new Z3Exception("Pop method should be overwritten"); } + + /// + /// Virtual method for fresh. It must be overwritten by inherited class. + /// + public virtual UserPropagator Fresh(Context ctx) { throw new Z3Exception("Fresh method should be overwritten"); } + + /// + /// Declare combination of assigned expressions a conflict + /// + void Conflict(params Expr[] terms) { + Propagate(terms, solver.Context.MkFalse()); + } + + /// + /// Propagate consequence + /// + void Propagate(Expr[] terms, Expr conseq) { + var nTerms = Z3Object.ArrayToNative(terms); + Native.Z3_solver_propagate_consequence(solver.Context.nCtx, this.callback, (uint)nTerms.Length, nTerms, 0u, null, null, conseq.NativeObject); + } + + + /// + /// Set fixed callback + /// + public FixedEh Fixed + { + set + { + this.fixed_eh = value; + Native.Z3_solver_propagate_fixed(solver.Context.nCtx, solver.NativeObject, _fixed); + } + } + + /// + /// Set final callback + /// + public Action Final + { + set + { + this.final_eh = value; + Native.Z3_solver_propagate_final(solver.Context.nCtx, solver.NativeObject, _final); + } + } + + /// + /// Set equality event callback + /// + public EqEh Eq + { + set + { + this.eq_eh = value; + Native.Z3_solver_propagate_eq(solver.Context.nCtx, solver.NativeObject, _eq); + } + } + + /// + /// Set disequality event callback + /// + public EqEh Diseq + { + set + { + this.diseq_eh = value; + Native.Z3_solver_propagate_diseq(solver.Context.nCtx, solver.NativeObject, _diseq); + } + } + + /// + /// Track assignments to a term + /// + public void Register(Expr term) { + Native.Z3_solver_propagate_register(solver.Context.nCtx, solver.NativeObject, term.NativeObject); + } + } +} diff --git a/src/api/z3_api.h b/src/api/z3_api.h index 13c4131bb..e7528da55 100644 --- a/src/api/z3_api.h +++ b/src/api/z3_api.h @@ -1444,7 +1444,7 @@ Z3_DECLARE_CLOSURE(Z3_fixed_eh, void, (void* ctx, Z3_solver_callback cb, Z3_as Z3_DECLARE_CLOSURE(Z3_eq_eh, void, (void* ctx, Z3_solver_callback cb, Z3_ast s, Z3_ast t)); Z3_DECLARE_CLOSURE(Z3_final_eh, void, (void* ctx, Z3_solver_callback cb)); Z3_DECLARE_CLOSURE(Z3_created_eh, void, (void* ctx, Z3_solver_callback cb, Z3_ast t)); -Z3_DECLARE_CLOSURE(Z3_decide_eh, void, (void* ctx, Z3_solver_callback cb, Z3_ast*, unsigned*, Z3_lbool*)); +Z3_DECLARE_CLOSURE(Z3_decide_eh, void, (void* ctx, Z3_solver_callback cb, Z3_ast* t, unsigned* idx, Z3_lbool* phase)); /** @@ -6733,6 +6733,8 @@ extern "C" { \param push_eh - a callback invoked when scopes are pushed \param pop_eh - a callback invoked when scopes are poped \param fresh_eh - a solver may spawn new solvers internally. This callback is used to produce a fresh user_context to be associated with fresh solvers. + + def_API('Z3_solver_propagate_init', VOID, (_in(CONTEXT), _in(SOLVER), _in(VOID_PTR), _fnptr(Z3_push_eh), _fnptr(Z3_pop_eh), _fnptr(Z3_fresh_eh))) */ void Z3_API Z3_solver_propagate_init( @@ -6748,6 +6750,8 @@ extern "C" { The supported expression types are - Booleans - Bit-vectors + + def_API('Z3_solver_propagate_fixed', VOID, (_in(CONTEXT), _in(SOLVER), _fnptr(Z3_fixed_eh))) */ void Z3_API Z3_solver_propagate_fixed(Z3_context c, Z3_solver s, Z3_fixed_eh fixed_eh); @@ -6764,22 +6768,30 @@ extern "C" { The callback context can only be accessed (for propagation and for dynamically registering expressions) within a callback. If the callback context gets used for propagation or conflicts, those propagations take effect and may trigger new decision variables to be set. + + def_API('Z3_solver_propagate_final', VOID, (_in(CONTEXT), _in(SOLVER), _fnptr(Z3_final_eh))) */ void Z3_API Z3_solver_propagate_final(Z3_context c, Z3_solver s, Z3_final_eh final_eh); /** \brief register a callback on expression equalities. + + def_API('Z3_solver_propagate_eq', VOID, (_in(CONTEXT), _in(SOLVER), _fnptr(Z3_eq_eh))) */ void Z3_API Z3_solver_propagate_eq(Z3_context c, Z3_solver s, Z3_eq_eh eq_eh); /** \brief register a callback on expression dis-equalities. + + def_API('Z3_solver_propagate_diseq', VOID, (_in(CONTEXT), _in(SOLVER), _fnptr(Z3_eq_eh))) */ void Z3_API Z3_solver_propagate_diseq(Z3_context c, Z3_solver s, Z3_eq_eh eq_eh); /** \brief register a callback when a new expression with a registered function is used by the solver The registered function appears at the top level and is created using \ref Z3_propagate_solver_declare. + + def_API('Z3_solver_propagate_created', VOID, (_in(CONTEXT), _in(SOLVER), _fnptr(Z3_created_eh))) */ void Z3_API Z3_solver_propagate_created(Z3_context c, Z3_solver s, Z3_created_eh created_eh); @@ -6788,6 +6800,7 @@ extern "C" { The callback may set the passed expression to another registered expression which will be selected instead. In case the expression is a bitvector the bit to split on is determined by the bit argument and the truth-value to try first is given by is_pos. In case the truth value is undefined the solver will decide. + */ void Z3_API Z3_solver_propagate_decide(Z3_context c, Z3_solver s, Z3_decide_eh decide_eh); diff --git a/src/muz/spacer/spacer_generalizers.cpp b/src/muz/spacer/spacer_generalizers.cpp index 9a63a4a7c..d3f083904 100644 --- a/src/muz/spacer/spacer_generalizers.cpp +++ b/src/muz/spacer/spacer_generalizers.cpp @@ -195,8 +195,10 @@ public: void operator()(app* a) { if (a->get_family_id() == null_family_id && m_au.is_array(a)) { - if (m_sort && m_sort != a->get_sort()) { return; } - if (!m_sort) { m_sort = a->get_sort(); } + if (m_sort && m_sort != a->get_sort()) + return; + if (!m_sort) + m_sort = a->get_sort(); m_symbs.insert(a->get_decl()); } } @@ -208,16 +210,10 @@ public: bool lemma_array_eq_generalizer::is_array_eq (ast_manager &m, expr* e) { expr *e1 = nullptr, *e2 = nullptr; - if (m.is_eq(e, e1, e2) && is_app(e1) && is_app(e2)) { - app *a1 = to_app(e1); - app *a2 = to_app(e2); - array_util au(m); - if (a1->get_family_id() == null_family_id && - a2->get_family_id() == null_family_id && - au.is_array(a1) && au.is_array(a2)) - return true; - } - return false; + array_util au(m); + return m.is_eq(e, e1, e2) && + is_uninterp(e1) && is_uninterp(e2) && + au.is_array(e1) && au.is_array(e2); } void lemma_array_eq_generalizer::operator() (lemma_ref &lemma) diff --git a/src/sat/sat_solver.h b/src/sat/sat_solver.h index 0f242bd4d..f7113609f 100644 --- a/src/sat/sat_solver.h +++ b/src/sat/sat_solver.h @@ -288,7 +288,7 @@ namespace sat { inline clause_allocator& cls_allocator() { return m_cls_allocator[m_cls_allocator_idx]; } inline clause_allocator const& cls_allocator() const { return m_cls_allocator[m_cls_allocator_idx]; } inline clause * alloc_clause(unsigned num_lits, literal const * lits, bool learned) { return cls_allocator().mk_clause(num_lits, lits, learned); } - inline void dealloc_clause(clause* c) { cls_allocator().del_clause(c); } + inline void dealloc_clause(clause* c) { cls_allocator().del_clause(c); } struct cmp_activity; void defrag_clauses(); bool should_defrag(); diff --git a/src/sat/smt/arith_solver.cpp b/src/sat/smt/arith_solver.cpp index e02a42979..a6eaeffbf 100644 --- a/src/sat/smt/arith_solver.cpp +++ b/src/sat/smt/arith_solver.cpp @@ -1190,6 +1190,10 @@ namespace arith { } void solver::assign(literal lit, literal_vector const& core, svector const& eqs, vector const& params) { + std::cout << "assign: "; + for (auto const& p : params) + std::cout << p << " "; + std::cout << "\n"; if (core.size() < small_lemma_size() && eqs.empty()) { m_core2.reset(); for (auto const& c : core)