From 11264c38d84bef9305df3dd34e5165140bec80f8 Mon Sep 17 00:00:00 2001 From: ditto <819045949@qq.com> Date: Thu, 25 May 2023 01:27:28 +0800 Subject: [PATCH] Java user propagator interface (#6733) * Java API: user propagator interface * Java API: improved user propagator interface * Java API: Add UserPropagatorBase.java * Remove redundant header file * Initialize `JavaInfo` object and error handling * Native.UserPropagatorBase implements AutoCloseable * Add Override annotation --- scripts/update_api.py | 65 ++++++++++++ src/api/java/CMakeLists.txt | 1 + src/api/java/Context.java | 27 +++-- src/api/java/NativeStatic.txt | 153 +++++++++++++++++++++++++++ src/api/java/UserPropagatorBase.java | 97 +++++++++++++++++ 5 files changed, 337 insertions(+), 6 deletions(-) create mode 100644 src/api/java/UserPropagatorBase.java diff --git a/scripts/update_api.py b/scripts/update_api.py index 4295b8961..23d044832 100755 --- a/scripts/update_api.py +++ b/scripts/update_api.py @@ -633,7 +633,72 @@ def mk_java(java_src, java_dir, package_name): java_native.write(' }\n') java_native.write(' }\n') java_native.write(' }\n') + java_native.write(""" + public static native long propagateInit(Object o, long ctx, long solver); + public static native void propagateRegisterCreated(Object o, long ctx, long solver); + public static native void propagateRegisterFixed(Object o, long ctx, long solver); + public static native void propagateRegisterEq(Object o, long ctx, long solver); + public static native void propagateRegisterDecide(Object o, long ctx, long solver); + public static native void propagateRegisterFinal(Object o, long ctx, long solver); + public static native void propagateConflict(Object o, long ctx, long solver, long javainfo, int num_fixed, long[] fixed, long num_eqs, long[] eq_lhs, long[] eq_rhs, long conseq); + public static native void propagateAdd(Object o, long ctx, long solver, long javainfo, long e); + public static native void propagateNextSplit(Object o, long ctx, long solver, long javainfo, long e, long idx, long phase); + public static native void propagateDestroy(Object o, long ctx, long solver, long javainfo); + public static abstract class UserPropagatorBase implements AutoCloseable { + protected long ctx; + protected long solver; + protected long javainfo; + + public UserPropagatorBase(long _ctx, long _solver) { + ctx = _ctx; + solver = _solver; + javainfo = propagateInit(this, ctx, solver); + } + + @Override + public void close() { + Native.propagateDestroy(this, ctx, solver, javainfo); + javainfo = 0; + solver = 0; + ctx = 0; + } + + protected final void registerCreated() { + Native.propagateRegisterCreated(this, ctx, solver); + } + + protected final void registerFixed() { + Native.propagateRegisterFixed(this, ctx, solver); + } + + protected final void registerEq() { + Native.propagateRegisterEq(this, ctx, solver); + } + + protected final void registerDecide() { + Native.propagateRegisterDecide(this, ctx, solver); + } + + protected final void registerFinal() { + Native.propagateRegisterFinal(this, ctx, solver); + } + + protected abstract void pushWrapper(); + + protected abstract void popWrapper(int number); + + protected abstract void finWrapper(); + + protected abstract void eqWrapper(long lx, long ly); + + protected abstract UserPropagatorBase freshWrapper(long lctx); + + protected abstract void createdWrapper(long le); + + protected abstract void fixedWrapper(long lvar, long lvalue); + } + """) java_native.write('\n') for name, result, params in _dotnet_decls: java_native.write(' protected static native %s INTERNAL%s(' % (type2java(result), java_method_name(name))) diff --git a/src/api/java/CMakeLists.txt b/src/api/java/CMakeLists.txt index 4b13a25b1..bd4338f7b 100644 --- a/src/api/java/CMakeLists.txt +++ b/src/api/java/CMakeLists.txt @@ -179,6 +179,7 @@ set(Z3_JAVA_JAR_SOURCE_FILES Tactic.java TupleSort.java UninterpretedSort.java + UserPropagatorBase.java Version.java Z3Exception.java Z3Object.java diff --git a/src/api/java/Context.java b/src/api/java/Context.java index 7aaef4801..b5b22405c 100644 --- a/src/api/java/Context.java +++ b/src/api/java/Context.java @@ -452,6 +452,21 @@ public class Context implements AutoCloseable { return new FuncDecl<>(this, name, domain, range); } + public final FuncDecl mkPropagateFunction(Symbol name, Sort[] domain, R range) + { + checkContextMatch(name); + checkContextMatch(domain); + checkContextMatch(range); + long f = Native.solverPropagateDeclare( + this.nCtx(), + name.getNativeObject(), + AST.arrayLength(domain), + AST.arrayToNative(domain), + range.getNativeObject()); + return new FuncDecl<>(this, f); + } + + /** * Creates a new function declaration. **/ @@ -2018,11 +2033,11 @@ public class Context implements AutoCloseable { { StringBuilder buf = new StringBuilder(); for (int i = 0; i < s.length(); ++i) { - int code = s.codePointAt(i); - if (code <= 32 || 127 < code) - buf.append(String.format("\\u{%x}", code)); - else - buf.append(s.charAt(i)); + int code = s.codePointAt(i); + if (code <= 32 || 127 < code) + buf.append(String.format("\\u{%x}", code)); + else + buf.append(s.charAt(i)); } return (SeqExpr) Expr.create(this, Native.mkString(nCtx(), buf.toString())); } @@ -2288,7 +2303,7 @@ public class Context implements AutoCloseable { public final ReExpr mkDiff(Expr> a, Expr> b) { checkContextMatch(a, b); - return (ReExpr) Expr.create(this, Native.mkReDiff(nCtx(), a.getNativeObject(), b.getNativeObject())); + return (ReExpr) Expr.create(this, Native.mkReDiff(nCtx(), a.getNativeObject(), b.getNativeObject())); } diff --git a/src/api/java/NativeStatic.txt b/src/api/java/NativeStatic.txt index 4693272d5..f68893e6b 100644 --- a/src/api/java/NativeStatic.txt +++ b/src/api/java/NativeStatic.txt @@ -77,3 +77,156 @@ DLL_VIS JNIEXPORT void JNICALL Java_com_microsoft_z3_Native_setInternalErrorHand Z3_set_error_handler((Z3_context)a0, Z3JavaErrorHandler); } + +#include + +struct JavaInfo { + JNIEnv *jenv = nullptr; + jobject jobj = nullptr; + + jmethodID push = nullptr; + jmethodID pop = nullptr; + jmethodID fresh = nullptr; + jmethodID created = nullptr; + jmethodID fixed = nullptr; + jmethodID eq = nullptr; + jmethodID final = nullptr; + + Z3_solver_callback cb = nullptr; +}; + +struct ScopedCB { + JavaInfo *info; + ScopedCB(JavaInfo *_info, Z3_solver_callback cb): info(_info) { + info->cb = cb; + } + ~ScopedCB() { + info->cb = nullptr; + } +}; + +static void push_eh(void* _p, Z3_solver_callback cb) { + JavaInfo *info = static_cast(_p); + ScopedCB scoped(info, cb); + info->jenv->CallVoidMethod(info->jobj, info->push); +} + +static void pop_eh(void* _p, Z3_solver_callback cb, unsigned int number) { + JavaInfo *info = static_cast(_p); + ScopedCB scoped(info, cb); + info->jenv->CallVoidMethod(info->jobj, info->pop, number); +} + +static void* fresh_eh(void* _p, Z3_context new_context) { + JavaInfo *info = static_cast(_p); + return info->jenv->CallObjectMethod(info->jobj, info->fresh, (jlong)new_context); +} + +static void created_eh(void* _p, Z3_solver_callback cb, Z3_ast _e) { + JavaInfo *info = static_cast(_p); + ScopedCB scoped(info, cb); + info->jenv->CallVoidMethod(info->jobj, info->created, (jlong)_e); +} + +static void fixed_eh(void* _p, Z3_solver_callback cb, Z3_ast _var, Z3_ast _value) { + JavaInfo *info = static_cast(_p); + ScopedCB scoped(info, cb); + info->jenv->CallVoidMethod(info->jobj, info->fixed, (jlong)_var, (jlong)_value); +} + +static void eq_eh(void* _p, Z3_solver_callback cb, Z3_ast _x, Z3_ast _y) { + JavaInfo *info = static_cast(_p); + ScopedCB scoped(info, cb); + info->jenv->CallVoidMethod(info->jobj, info->eq, (jlong)_x, (jlong)_y); +} + +static void final_eh(void* _p, Z3_solver_callback cb) { + JavaInfo *info = static_cast(_p); + ScopedCB scoped(info, cb); + info->jenv->CallVoidMethod(info->jobj, info->final); +} + +// TODO: implement decide +static void decide_eh(void* _p, Z3_solver_callback cb, Z3_ast* _val, unsigned* bit, Z3_lbool* is_pos) { + JavaInfo *info = static_cast(_p); + ScopedCB scoped(info, cb); + +} + +DLL_VIS JNIEXPORT jlong JNICALL Java_com_microsoft_z3_Native_propagateInit(JNIEnv *jenv, jclass cls, jobject jobj, jlong ctx, jlong solver) { + JavaInfo *info = new JavaInfo; + + info->jenv = jenv; + info->jobj = jenv->NewGlobalRef(jobj); + jclass jcls = jenv->GetObjectClass(info->jobj); + info->push = jenv->GetMethodID(jcls, "pushWrapper", "()V"); + info->pop = jenv->GetMethodID(jcls, "popWrapper", "(I)V"); + info->fresh = jenv->GetMethodID(jcls, "freshWrapper", "(J)Lcom/microsoft/z3/Native$UserPropagatorBase;"); + info->created = jenv->GetMethodID(jcls, "createdWrapper", "(J)V"); + info->fixed = jenv->GetMethodID(jcls, "fixedWrapper", "(JJ)V"); + info->eq = jenv->GetMethodID(jcls, "eqWrapper", "(JJ)V"); + info->final = jenv->GetMethodID(jcls, "finWrapper", "()V"); + + if (!info->push || !info->pop || !info->fresh || !info->created || !info->fixed || !info->eq || !info->final) { + assert(false); + } + + Z3_solver_propagate_init((Z3_context)ctx, (Z3_solver)solver, info, push_eh, pop_eh, fresh_eh); + + return (jlong)info; +} + +DLL_VIS JNIEXPORT void JNICALL Java_com_microsoft_z3_Native_propagateDestroy(JNIEnv *jenv, jclass cls, jobject jobj, jlong ctx, jlong solver, jlong javainfo) { + JavaInfo *info = (JavaInfo*)javainfo; + info->jenv->DeleteGlobalRef(info->jobj); + delete info; +} + +DLL_VIS JNIEXPORT void JNICALL Java_com_microsoft_z3_Native_propagateRegisterCreated(JNIEnv * jenv, jclass cls, jobject jobj, jlong ctx, jlong solver) { + Z3_solver_propagate_created((Z3_context)ctx, (Z3_solver)solver, created_eh); +} + +DLL_VIS JNIEXPORT void JNICALL Java_com_microsoft_z3_Native_propagateRegisterFinal(JNIEnv * jenv, jclass cls, jobject jobj, jlong ctx, jlong solver) { + Z3_solver_propagate_final((Z3_context)ctx, (Z3_solver)solver, final_eh); +} + +DLL_VIS JNIEXPORT void JNICALL Java_com_microsoft_z3_Native_propagateRegisterFixed(JNIEnv * jenv, jclass cls, jobject jobj, jlong ctx, jlong solver) { + Z3_solver_propagate_fixed((Z3_context)ctx, (Z3_solver)solver, fixed_eh); +} + +DLL_VIS JNIEXPORT void JNICALL Java_com_microsoft_z3_Native_propagateRegisterEq(JNIEnv * jenv, jclass cls, jobject jobj, jlong ctx, jlong solver) { + Z3_solver_propagate_eq((Z3_context)ctx, (Z3_solver)solver, eq_eh); +} + +DLL_VIS JNIEXPORT void JNICALL Java_com_microsoft_z3_Native_propagateRegisterDecide(JNIEnv * jenv, jclass cls, jobject jobj, jlong ctx, jlong solver) { + Z3_solver_propagate_decide((Z3_context)ctx, (Z3_solver)solver, decide_eh); +} + +DLL_VIS JNIEXPORT void JNICALL Java_com_microsoft_z3_Native_propagateConflict(JNIEnv * jenv, jclass cls, jobject jobj, jlong ctx, jlong solver, jlong javainfo, jlong num_fixed, jlongArray fixed, jlong num_eqs, jlongArray eq_lhs, jlongArray eq_rhs, jlong conseq) { + JavaInfo *info = (JavaInfo*)javainfo; + GETLONGAELEMS(Z3_ast, fixed, _fixed); + GETLONGAELEMS(Z3_ast, eq_lhs, _eq_lhs); + GETLONGAELEMS(Z3_ast, eq_rhs, _eq_rhs); + Z3_solver_propagate_consequence((Z3_context)ctx, info->cb, num_fixed, _fixed, num_eqs, _eq_lhs, _eq_rhs, (Z3_ast)conseq); + RELEASELONGAELEMS(fixed, _fixed); + RELEASELONGAELEMS(eq_lhs, _eq_lhs); + RELEASELONGAELEMS(eq_rhs, _eq_rhs); +} + +DLL_VIS JNIEXPORT void JNICALL Java_com_microsoft_z3_Native_propagateAdd(JNIEnv * jenv, jclass cls, jobject jobj, jlong ctx, jlong solver, jlong javainfo, jlong e) { + JavaInfo *info = (JavaInfo*)javainfo; + Z3_solver_callback cb = info->cb; + if (cb) + Z3_solver_propagate_register_cb((Z3_context)ctx, cb, (Z3_ast)e); + else if (solver) + Z3_solver_propagate_register((Z3_context)ctx, (Z3_solver)solver, (Z3_ast)e); + else { + assert(false); + } +} + +DLL_VIS JNIEXPORT void JNICALL Java_com_microsoft_z3_Native_propagateNextSplit(JNIEnv * jenv, jclass cls, jobject jobj, jlong ctx, jlong solver, jlong javainfo, long e, long idx, long phase) { + JavaInfo *info = (JavaInfo*)javainfo; + Z3_solver_callback cb = info->cb; + Z3_solver_next_split((Z3_context)ctx, cb, (Z3_ast)e, idx, Z3_lbool(phase)); +} diff --git a/src/api/java/UserPropagatorBase.java b/src/api/java/UserPropagatorBase.java new file mode 100644 index 000000000..90243ba68 --- /dev/null +++ b/src/api/java/UserPropagatorBase.java @@ -0,0 +1,97 @@ +package com.microsoft.z3; + +import com.microsoft.z3.Context; +import com.microsoft.z3.enumerations.Z3_lbool; + +public abstract class UserPropagatorBase extends Native.UserPropagatorBase { + private Context ctx; + private Solver solver; + + public UserPropagatorBase(Context _ctx, Solver _solver) { + super(_ctx.nCtx(), _solver.getNativeObject()); + ctx = _ctx; + solver = _solver; + } + + public final Context getCtx() { + return ctx; + } + + public final Solver getSolver() { + return solver; + } + + @Override + protected final void pushWrapper() { + push(); + } + + @Override + protected final void popWrapper(int number) { + pop(number); + } + + @Override + protected final void finWrapper() { + fin(); + } + + @Override + protected final void eqWrapper(long lx, long ly) { + Expr x = new Expr(ctx, lx); + Expr y = new Expr(ctx, ly); + eq(x, y); + } + + @Override + protected final UserPropagatorBase freshWrapper(long lctx) { + return fresh(new Context(lctx)); + } + + @Override + protected final void createdWrapper(long last) { + created(new Expr(ctx, last)); + } + + @Override + protected final void fixedWrapper(long lvar, long lvalue) { + Expr var = new Expr(ctx, lvar); + Expr value = new Expr(ctx, lvalue); + fixed(var, value); + } + + public abstract void push(); + + public abstract void pop(int number); + + public abstract UserPropagatorBase fresh(Context ctx); + + public void created(Expr ast) {} + + public void fixed(Expr var, Expr value) {} + + public void eq(Expr x, Expr y) {} + + public void fin() {} + + public final void add(Expr expr) { + Native.propagateAdd(this, ctx.nCtx(), solver.getNativeObject(), javainfo, expr.getNativeObject()); + } + + public final void conflict(Expr[] fixed) { + conflict(fixed, new Expr[0], new Expr[0]); + } + + public final void conflict(Expr[] fixed, Expr[] lhs, Expr[] rhs) { + AST conseq = ctx.mkBool(false); + Native.propagateConflict( + this, ctx.nCtx(), solver.getNativeObject(), javainfo, + fixed.length, AST.arrayToNative(fixed), lhs.length, AST.arrayToNative(lhs), AST.arrayToNative(rhs), conseq.getNativeObject()); + } + + public final void nextSplit(Expr e, long idx, Z3_lbool phase) { + Native.propagateNextSplit( + this, ctx.nCtx(), solver.getNativeObject(), javainfo, + e.getNativeObject(), idx, phase.toInt()); + } +} \ No newline at end of file