3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-06 17:44:08 +00:00

Java user propagator interface (#6733)

* Java API: user propagator interface

* Java API: improved user propagator interface

* Java API: Add UserPropagatorBase.java

* Remove redundant header file

* Initialize `JavaInfo` object and error handling

* Native.UserPropagatorBase implements AutoCloseable

* Add Override annotation
This commit is contained in:
ditto 2023-05-25 01:27:28 +08:00 committed by GitHub
parent 2c21072c99
commit 11264c38d8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 337 additions and 6 deletions

View file

@ -633,7 +633,72 @@ def mk_java(java_src, java_dir, package_name):
java_native.write(' }\n')
java_native.write(' }\n')
java_native.write(' }\n')
java_native.write("""
public static native long propagateInit(Object o, long ctx, long solver);
public static native void propagateRegisterCreated(Object o, long ctx, long solver);
public static native void propagateRegisterFixed(Object o, long ctx, long solver);
public static native void propagateRegisterEq(Object o, long ctx, long solver);
public static native void propagateRegisterDecide(Object o, long ctx, long solver);
public static native void propagateRegisterFinal(Object o, long ctx, long solver);
public static native void propagateConflict(Object o, long ctx, long solver, long javainfo, int num_fixed, long[] fixed, long num_eqs, long[] eq_lhs, long[] eq_rhs, long conseq);
public static native void propagateAdd(Object o, long ctx, long solver, long javainfo, long e);
public static native void propagateNextSplit(Object o, long ctx, long solver, long javainfo, long e, long idx, long phase);
public static native void propagateDestroy(Object o, long ctx, long solver, long javainfo);
public static abstract class UserPropagatorBase implements AutoCloseable {
protected long ctx;
protected long solver;
protected long javainfo;
public UserPropagatorBase(long _ctx, long _solver) {
ctx = _ctx;
solver = _solver;
javainfo = propagateInit(this, ctx, solver);
}
@Override
public void close() {
Native.propagateDestroy(this, ctx, solver, javainfo);
javainfo = 0;
solver = 0;
ctx = 0;
}
protected final void registerCreated() {
Native.propagateRegisterCreated(this, ctx, solver);
}
protected final void registerFixed() {
Native.propagateRegisterFixed(this, ctx, solver);
}
protected final void registerEq() {
Native.propagateRegisterEq(this, ctx, solver);
}
protected final void registerDecide() {
Native.propagateRegisterDecide(this, ctx, solver);
}
protected final void registerFinal() {
Native.propagateRegisterFinal(this, ctx, solver);
}
protected abstract void pushWrapper();
protected abstract void popWrapper(int number);
protected abstract void finWrapper();
protected abstract void eqWrapper(long lx, long ly);
protected abstract UserPropagatorBase freshWrapper(long lctx);
protected abstract void createdWrapper(long le);
protected abstract void fixedWrapper(long lvar, long lvalue);
}
""")
java_native.write('\n')
for name, result, params in _dotnet_decls:
java_native.write(' protected static native %s INTERNAL%s(' % (type2java(result), java_method_name(name)))

View file

@ -179,6 +179,7 @@ set(Z3_JAVA_JAR_SOURCE_FILES
Tactic.java
TupleSort.java
UninterpretedSort.java
UserPropagatorBase.java
Version.java
Z3Exception.java
Z3Object.java

View file

@ -452,6 +452,21 @@ public class Context implements AutoCloseable {
return new FuncDecl<>(this, name, domain, range);
}
public final <R extends Sort> FuncDecl<R> mkPropagateFunction(Symbol name, Sort[] domain, R range)
{
checkContextMatch(name);
checkContextMatch(domain);
checkContextMatch(range);
long f = Native.solverPropagateDeclare(
this.nCtx(),
name.getNativeObject(),
AST.arrayLength(domain),
AST.arrayToNative(domain),
range.getNativeObject());
return new FuncDecl<>(this, f);
}
/**
* Creates a new function declaration.
**/
@ -2018,11 +2033,11 @@ public class Context implements AutoCloseable {
{
StringBuilder buf = new StringBuilder();
for (int i = 0; i < s.length(); ++i) {
int code = s.codePointAt(i);
if (code <= 32 || 127 < code)
buf.append(String.format("\\u{%x}", code));
else
buf.append(s.charAt(i));
int code = s.codePointAt(i);
if (code <= 32 || 127 < code)
buf.append(String.format("\\u{%x}", code));
else
buf.append(s.charAt(i));
}
return (SeqExpr<CharSort>) Expr.create(this, Native.mkString(nCtx(), buf.toString()));
}
@ -2288,7 +2303,7 @@ public class Context implements AutoCloseable {
public final <R extends Sort> ReExpr<R> mkDiff(Expr<ReSort<R>> a, Expr<ReSort<R>> b)
{
checkContextMatch(a, b);
return (ReExpr<R>) Expr.create(this, Native.mkReDiff(nCtx(), a.getNativeObject(), b.getNativeObject()));
return (ReExpr<R>) Expr.create(this, Native.mkReDiff(nCtx(), a.getNativeObject(), b.getNativeObject()));
}

View file

@ -77,3 +77,156 @@ DLL_VIS JNIEXPORT void JNICALL Java_com_microsoft_z3_Native_setInternalErrorHand
Z3_set_error_handler((Z3_context)a0, Z3JavaErrorHandler);
}
#include <assert.h>
struct JavaInfo {
JNIEnv *jenv = nullptr;
jobject jobj = nullptr;
jmethodID push = nullptr;
jmethodID pop = nullptr;
jmethodID fresh = nullptr;
jmethodID created = nullptr;
jmethodID fixed = nullptr;
jmethodID eq = nullptr;
jmethodID final = nullptr;
Z3_solver_callback cb = nullptr;
};
struct ScopedCB {
JavaInfo *info;
ScopedCB(JavaInfo *_info, Z3_solver_callback cb): info(_info) {
info->cb = cb;
}
~ScopedCB() {
info->cb = nullptr;
}
};
static void push_eh(void* _p, Z3_solver_callback cb) {
JavaInfo *info = static_cast<JavaInfo*>(_p);
ScopedCB scoped(info, cb);
info->jenv->CallVoidMethod(info->jobj, info->push);
}
static void pop_eh(void* _p, Z3_solver_callback cb, unsigned int number) {
JavaInfo *info = static_cast<JavaInfo*>(_p);
ScopedCB scoped(info, cb);
info->jenv->CallVoidMethod(info->jobj, info->pop, number);
}
static void* fresh_eh(void* _p, Z3_context new_context) {
JavaInfo *info = static_cast<JavaInfo*>(_p);
return info->jenv->CallObjectMethod(info->jobj, info->fresh, (jlong)new_context);
}
static void created_eh(void* _p, Z3_solver_callback cb, Z3_ast _e) {
JavaInfo *info = static_cast<JavaInfo*>(_p);
ScopedCB scoped(info, cb);
info->jenv->CallVoidMethod(info->jobj, info->created, (jlong)_e);
}
static void fixed_eh(void* _p, Z3_solver_callback cb, Z3_ast _var, Z3_ast _value) {
JavaInfo *info = static_cast<JavaInfo*>(_p);
ScopedCB scoped(info, cb);
info->jenv->CallVoidMethod(info->jobj, info->fixed, (jlong)_var, (jlong)_value);
}
static void eq_eh(void* _p, Z3_solver_callback cb, Z3_ast _x, Z3_ast _y) {
JavaInfo *info = static_cast<JavaInfo*>(_p);
ScopedCB scoped(info, cb);
info->jenv->CallVoidMethod(info->jobj, info->eq, (jlong)_x, (jlong)_y);
}
static void final_eh(void* _p, Z3_solver_callback cb) {
JavaInfo *info = static_cast<JavaInfo*>(_p);
ScopedCB scoped(info, cb);
info->jenv->CallVoidMethod(info->jobj, info->final);
}
// TODO: implement decide
static void decide_eh(void* _p, Z3_solver_callback cb, Z3_ast* _val, unsigned* bit, Z3_lbool* is_pos) {
JavaInfo *info = static_cast<JavaInfo*>(_p);
ScopedCB scoped(info, cb);
}
DLL_VIS JNIEXPORT jlong JNICALL Java_com_microsoft_z3_Native_propagateInit(JNIEnv *jenv, jclass cls, jobject jobj, jlong ctx, jlong solver) {
JavaInfo *info = new JavaInfo;
info->jenv = jenv;
info->jobj = jenv->NewGlobalRef(jobj);
jclass jcls = jenv->GetObjectClass(info->jobj);
info->push = jenv->GetMethodID(jcls, "pushWrapper", "()V");
info->pop = jenv->GetMethodID(jcls, "popWrapper", "(I)V");
info->fresh = jenv->GetMethodID(jcls, "freshWrapper", "(J)Lcom/microsoft/z3/Native$UserPropagatorBase;");
info->created = jenv->GetMethodID(jcls, "createdWrapper", "(J)V");
info->fixed = jenv->GetMethodID(jcls, "fixedWrapper", "(JJ)V");
info->eq = jenv->GetMethodID(jcls, "eqWrapper", "(JJ)V");
info->final = jenv->GetMethodID(jcls, "finWrapper", "()V");
if (!info->push || !info->pop || !info->fresh || !info->created || !info->fixed || !info->eq || !info->final) {
assert(false);
}
Z3_solver_propagate_init((Z3_context)ctx, (Z3_solver)solver, info, push_eh, pop_eh, fresh_eh);
return (jlong)info;
}
DLL_VIS JNIEXPORT void JNICALL Java_com_microsoft_z3_Native_propagateDestroy(JNIEnv *jenv, jclass cls, jobject jobj, jlong ctx, jlong solver, jlong javainfo) {
JavaInfo *info = (JavaInfo*)javainfo;
info->jenv->DeleteGlobalRef(info->jobj);
delete info;
}
DLL_VIS JNIEXPORT void JNICALL Java_com_microsoft_z3_Native_propagateRegisterCreated(JNIEnv * jenv, jclass cls, jobject jobj, jlong ctx, jlong solver) {
Z3_solver_propagate_created((Z3_context)ctx, (Z3_solver)solver, created_eh);
}
DLL_VIS JNIEXPORT void JNICALL Java_com_microsoft_z3_Native_propagateRegisterFinal(JNIEnv * jenv, jclass cls, jobject jobj, jlong ctx, jlong solver) {
Z3_solver_propagate_final((Z3_context)ctx, (Z3_solver)solver, final_eh);
}
DLL_VIS JNIEXPORT void JNICALL Java_com_microsoft_z3_Native_propagateRegisterFixed(JNIEnv * jenv, jclass cls, jobject jobj, jlong ctx, jlong solver) {
Z3_solver_propagate_fixed((Z3_context)ctx, (Z3_solver)solver, fixed_eh);
}
DLL_VIS JNIEXPORT void JNICALL Java_com_microsoft_z3_Native_propagateRegisterEq(JNIEnv * jenv, jclass cls, jobject jobj, jlong ctx, jlong solver) {
Z3_solver_propagate_eq((Z3_context)ctx, (Z3_solver)solver, eq_eh);
}
DLL_VIS JNIEXPORT void JNICALL Java_com_microsoft_z3_Native_propagateRegisterDecide(JNIEnv * jenv, jclass cls, jobject jobj, jlong ctx, jlong solver) {
Z3_solver_propagate_decide((Z3_context)ctx, (Z3_solver)solver, decide_eh);
}
DLL_VIS JNIEXPORT void JNICALL Java_com_microsoft_z3_Native_propagateConflict(JNIEnv * jenv, jclass cls, jobject jobj, jlong ctx, jlong solver, jlong javainfo, jlong num_fixed, jlongArray fixed, jlong num_eqs, jlongArray eq_lhs, jlongArray eq_rhs, jlong conseq) {
JavaInfo *info = (JavaInfo*)javainfo;
GETLONGAELEMS(Z3_ast, fixed, _fixed);
GETLONGAELEMS(Z3_ast, eq_lhs, _eq_lhs);
GETLONGAELEMS(Z3_ast, eq_rhs, _eq_rhs);
Z3_solver_propagate_consequence((Z3_context)ctx, info->cb, num_fixed, _fixed, num_eqs, _eq_lhs, _eq_rhs, (Z3_ast)conseq);
RELEASELONGAELEMS(fixed, _fixed);
RELEASELONGAELEMS(eq_lhs, _eq_lhs);
RELEASELONGAELEMS(eq_rhs, _eq_rhs);
}
DLL_VIS JNIEXPORT void JNICALL Java_com_microsoft_z3_Native_propagateAdd(JNIEnv * jenv, jclass cls, jobject jobj, jlong ctx, jlong solver, jlong javainfo, jlong e) {
JavaInfo *info = (JavaInfo*)javainfo;
Z3_solver_callback cb = info->cb;
if (cb)
Z3_solver_propagate_register_cb((Z3_context)ctx, cb, (Z3_ast)e);
else if (solver)
Z3_solver_propagate_register((Z3_context)ctx, (Z3_solver)solver, (Z3_ast)e);
else {
assert(false);
}
}
DLL_VIS JNIEXPORT void JNICALL Java_com_microsoft_z3_Native_propagateNextSplit(JNIEnv * jenv, jclass cls, jobject jobj, jlong ctx, jlong solver, jlong javainfo, long e, long idx, long phase) {
JavaInfo *info = (JavaInfo*)javainfo;
Z3_solver_callback cb = info->cb;
Z3_solver_next_split((Z3_context)ctx, cb, (Z3_ast)e, idx, Z3_lbool(phase));
}

View file

@ -0,0 +1,97 @@
package com.microsoft.z3;
import com.microsoft.z3.Context;
import com.microsoft.z3.enumerations.Z3_lbool;
public abstract class UserPropagatorBase extends Native.UserPropagatorBase {
private Context ctx;
private Solver solver;
public UserPropagatorBase(Context _ctx, Solver _solver) {
super(_ctx.nCtx(), _solver.getNativeObject());
ctx = _ctx;
solver = _solver;
}
public final Context getCtx() {
return ctx;
}
public final Solver getSolver() {
return solver;
}
@Override
protected final void pushWrapper() {
push();
}
@Override
protected final void popWrapper(int number) {
pop(number);
}
@Override
protected final void finWrapper() {
fin();
}
@Override
protected final void eqWrapper(long lx, long ly) {
Expr x = new Expr(ctx, lx);
Expr y = new Expr(ctx, ly);
eq(x, y);
}
@Override
protected final UserPropagatorBase freshWrapper(long lctx) {
return fresh(new Context(lctx));
}
@Override
protected final void createdWrapper(long last) {
created(new Expr(ctx, last));
}
@Override
protected final void fixedWrapper(long lvar, long lvalue) {
Expr var = new Expr(ctx, lvar);
Expr value = new Expr(ctx, lvalue);
fixed(var, value);
}
public abstract void push();
public abstract void pop(int number);
public abstract UserPropagatorBase fresh(Context ctx);
public <R extends Sort> void created(Expr<R> ast) {}
public <R extends Sort> void fixed(Expr<R> var, Expr<R> value) {}
public <R extends Sort> void eq(Expr<R> x, Expr<R> y) {}
public void fin() {}
public final <R extends Sort> void add(Expr<R> expr) {
Native.propagateAdd(this, ctx.nCtx(), solver.getNativeObject(), javainfo, expr.getNativeObject());
}
public final <R extends Sort> void conflict(Expr<R>[] fixed) {
conflict(fixed, new Expr[0], new Expr[0]);
}
public final <R extends Sort> void conflict(Expr<R>[] fixed, Expr<R>[] lhs, Expr<R>[] rhs) {
AST conseq = ctx.mkBool(false);
Native.propagateConflict(
this, ctx.nCtx(), solver.getNativeObject(), javainfo,
fixed.length, AST.arrayToNative(fixed), lhs.length, AST.arrayToNative(lhs), AST.arrayToNative(rhs), conseq.getNativeObject());
}
public final <R extends Sort> void nextSplit(Expr<R> e, long idx, Z3_lbool phase) {
Native.propagateNextSplit(
this, ctx.nCtx(), solver.getNativeObject(), javainfo,
e.getNativeObject(), idx, phase.toInt());
}
}