From d2706bab649b966170189d25090284c21cce7186 Mon Sep 17 00:00:00 2001
From: Thomas Haas <tomy.haas@t-online.de>
Date: Thu, 18 Jan 2024 18:29:15 +0100
Subject: [PATCH] Fixes in Java's User Propagator (#7088)

* Fixed decide callback for Java user propagators

* Java User Prop:
- Added return value to conflict
- Added consequence method
- Added missing access modifier to decideWrapper

* Removed type parameters of expressions in UserPropagatorBase

* Renamed propagateConflict to propagateConsequence
---
 scripts/update_api.py                |  4 +++-
 src/api/java/NativeStatic.txt        | 13 ++++++------
 src/api/java/UserPropagatorBase.java | 31 +++++++++++++++++++---------
 3 files changed, 31 insertions(+), 17 deletions(-)

diff --git a/scripts/update_api.py b/scripts/update_api.py
index 7d3d8899f..79f144142 100755
--- a/scripts/update_api.py
+++ b/scripts/update_api.py
@@ -641,8 +641,8 @@ def mk_java(java_src, java_dir, package_name):
   public static native void propagateRegisterEq(Object o, long ctx, long solver);
   public static native void propagateRegisterDecide(Object o, long ctx, long solver);
   public static native void propagateRegisterFinal(Object o, long ctx, long solver);
-  public static native void propagateConflict(Object o, long ctx, long solver, long javainfo, int num_fixed, long[] fixed, long num_eqs, long[] eq_lhs, long[] eq_rhs, long conseq);
   public static native void propagateAdd(Object o, long ctx, long solver, long javainfo, long e);
+  public static native boolean propagateConsequence(Object o, long ctx, long solver, long javainfo, int num_fixed, long[] fixed, long num_eqs, long[] eq_lhs, long[] eq_rhs, long conseq);
   public static native boolean propagateNextSplit(Object o, long ctx, long solver, long javainfo, long e, long idx, int phase);
   public static native void propagateDestroy(Object o, long ctx, long solver, long javainfo);
 
@@ -698,6 +698,8 @@ def mk_java(java_src, java_dir, package_name):
     protected abstract void createdWrapper(long le);
 
     protected abstract void fixedWrapper(long lvar, long lvalue);
+
+    protected abstract void decideWrapper(long lvar, int bit, boolean is_pos);
   }
     """)
     java_native.write('\n')
diff --git a/src/api/java/NativeStatic.txt b/src/api/java/NativeStatic.txt
index 9507130fd..21d6ba075 100644
--- a/src/api/java/NativeStatic.txt
+++ b/src/api/java/NativeStatic.txt
@@ -150,7 +150,7 @@ static void final_eh(void* _p, Z3_solver_callback cb) {
 static void decide_eh(void* _p, Z3_solver_callback cb, Z3_ast _val, unsigned bit, bool is_pos) {
   JavaInfo *info = static_cast<JavaInfo*>(_p);
   ScopedCB scoped(info, cb);
-  info->jenv->CallVoidMethod(info->jobj, info->decide, (jlong)_val);
+  info->jenv->CallVoidMethod(info->jobj, info->decide, (jlong)_val, bit, is_pos);
 }
 
 DLL_VIS JNIEXPORT jlong JNICALL Java_com_microsoft_z3_Native_propagateInit(JNIEnv *jenv, jclass cls, jobject jobj, jlong ctx, jlong solver) {
@@ -166,7 +166,7 @@ DLL_VIS JNIEXPORT jlong JNICALL Java_com_microsoft_z3_Native_propagateInit(JNIEn
   info->fixed = jenv->GetMethodID(jcls, "fixedWrapper", "(JJ)V");
   info->eq = jenv->GetMethodID(jcls, "eqWrapper", "(JJ)V");
   info->final = jenv->GetMethodID(jcls, "finWrapper", "()V");
-  info->decide = jenv->GetMethodID(jcls, "decideWrapper", "(JII)V");
+  info->decide = jenv->GetMethodID(jcls, "decideWrapper", "(JIZ)V");
 
   if (!info->push || !info->pop || !info->fresh || !info->created || !info->fixed || !info->eq || !info->final || !info->decide) {
     assert(false);
@@ -203,15 +203,16 @@ DLL_VIS JNIEXPORT void JNICALL Java_com_microsoft_z3_Native_propagateRegisterDec
   Z3_solver_propagate_decide((Z3_context)ctx, (Z3_solver)solver, decide_eh);
 }
 
-DLL_VIS JNIEXPORT void JNICALL Java_com_microsoft_z3_Native_propagateConflict(JNIEnv * jenv, jclass cls, jobject jobj, jlong ctx, jlong solver, jlong javainfo, long num_fixed, jlongArray fixed, long num_eqs, jlongArray eq_lhs, jlongArray eq_rhs, jlong conseq) {
+DLL_VIS JNIEXPORT jboolean JNICALL Java_com_microsoft_z3_Native_propagateConsequence(JNIEnv * jenv, jclass cls, jobject jobj, jlong ctx, jlong solver, jlong javainfo, long num_fixed, jlongArray fixed, long num_eqs, jlongArray eq_lhs, jlongArray eq_rhs, jlong conseq) {
   JavaInfo *info = (JavaInfo*)javainfo;
   GETLONGAELEMS(Z3_ast, fixed, _fixed);
   GETLONGAELEMS(Z3_ast, eq_lhs, _eq_lhs);
   GETLONGAELEMS(Z3_ast, eq_rhs, _eq_rhs);
-  Z3_solver_propagate_consequence((Z3_context)ctx, info->cb, num_fixed, _fixed, num_eqs, _eq_lhs, _eq_rhs, (Z3_ast)conseq);
+  bool retval = Z3_solver_propagate_consequence((Z3_context)ctx, info->cb, num_fixed, _fixed, num_eqs, _eq_lhs, _eq_rhs, (Z3_ast)conseq);
   RELEASELONGAELEMS(fixed, _fixed);
   RELEASELONGAELEMS(eq_lhs, _eq_lhs);
   RELEASELONGAELEMS(eq_rhs, _eq_rhs);
+  return (jboolean) retval;
 }
 
 DLL_VIS JNIEXPORT void JNICALL Java_com_microsoft_z3_Native_propagateAdd(JNIEnv * jenv, jclass cls, jobject jobj, jlong ctx, jlong solver, jlong javainfo, jlong e) {
@@ -227,8 +228,8 @@ DLL_VIS JNIEXPORT void JNICALL Java_com_microsoft_z3_Native_propagateAdd(JNIEnv
 }
 
 
-DLL_VIS JNIEXPORT bool JNICALL Java_com_microsoft_z3_Native_propagateNextSplit(JNIEnv * jenv, jclass cls, jobject jobj, jlong ctx, jlong solver, jlong javainfo, jlong e, long idx, int phase) {
+DLL_VIS JNIEXPORT jboolean JNICALL Java_com_microsoft_z3_Native_propagateNextSplit(JNIEnv * jenv, jclass cls, jobject jobj, jlong ctx, jlong solver, jlong javainfo, jlong e, long idx, int phase) {
   JavaInfo *info = (JavaInfo*)javainfo;
   Z3_solver_callback cb = info->cb;
-  return Z3_solver_next_split((Z3_context)ctx, cb, (Z3_ast)e, idx, Z3_lbool(phase));
+  return (jboolean) Z3_solver_next_split((Z3_context)ctx, cb, (Z3_ast)e, idx, Z3_lbool(phase));
 }
diff --git a/src/api/java/UserPropagatorBase.java b/src/api/java/UserPropagatorBase.java
index 407d3d0da..46a61400d 100644
--- a/src/api/java/UserPropagatorBase.java
+++ b/src/api/java/UserPropagatorBase.java
@@ -60,36 +60,47 @@ public abstract class UserPropagatorBase extends Native.UserPropagatorBase {
         fixed(var, value);
     }
 
+    @Override
+    protected final void decideWrapper(long lvar, int bit, boolean is_pos) {
+        Expr var = new Expr(ctx, lvar);
+        decide(var, bit, is_pos);
+    }
+
     public abstract void push();
 
     public abstract void pop(int number);
 
     public abstract UserPropagatorBase fresh(Context ctx);
 
-    public <R extends Sort> void created(Expr<R> ast) {}
+    public void created(Expr<?> ast) {}
 
-    public <R extends Sort> void fixed(Expr<R> var, Expr<R> value) {}
+    public void fixed(Expr<?> var, Expr<?> value) {}
 
-    public <R extends Sort> void eq(Expr<R> x, Expr<R> y) {}
+    public void eq(Expr<?> x, Expr<?> y) {}
+
+    public void decide(Expr<?> var, int bit, boolean is_pos) {}
 
     public void fin() {}
 
-    public final <R extends Sort> void add(Expr<R> expr) {
+    public final void add(Expr<?> expr) {
         Native.propagateAdd(this, ctx.nCtx(), solver.getNativeObject(), javainfo, expr.getNativeObject());
     }
 
-    public final <R extends Sort> void conflict(Expr<R>[] fixed) {
-        conflict(fixed, new Expr[0], new Expr[0]);
+    public final boolean conflict(Expr<?>[] fixed) {
+        return conflict(fixed, new Expr[0], new Expr[0]);
     }
 
-    public final <R extends Sort> void conflict(Expr<R>[] fixed, Expr<R>[] lhs, Expr<R>[] rhs) {
-        AST conseq = ctx.mkBool(false);
-        Native.propagateConflict(
+    public final boolean conflict(Expr<?>[] fixed, Expr<?>[] lhs, Expr<?>[] rhs) {
+        return consequence(fixed, lhs, rhs, ctx.mkBool(false));
+    }
+
+    public final boolean consequence(Expr<?>[] fixed, Expr<?>[] lhs, Expr<?>[] rhs, Expr<?> conseq) {
+        return Native.propagateConsequence(
             this, ctx.nCtx(), solver.getNativeObject(), javainfo,
             fixed.length, AST.arrayToNative(fixed), lhs.length, AST.arrayToNative(lhs), AST.arrayToNative(rhs), conseq.getNativeObject());
     }
 
-    public final <R extends Sort> boolean nextSplit(Expr<R> e, long idx, Z3_lbool phase) {
+    public final boolean nextSplit(Expr<?> e, long idx, Z3_lbool phase) {
         return Native.propagateNextSplit(
             this, ctx.nCtx(), solver.getNativeObject(), javainfo,
             e.getNativeObject(), idx, phase.toInt());