diff --git a/scripts/update_api.py b/scripts/update_api.py index 23d044832..863aa7394 100755 --- a/scripts/update_api.py +++ b/scripts/update_api.py @@ -642,7 +642,7 @@ def mk_java(java_src, java_dir, package_name): public static native void propagateRegisterFinal(Object o, long ctx, long solver); public static native void propagateConflict(Object o, long ctx, long solver, long javainfo, int num_fixed, long[] fixed, long num_eqs, long[] eq_lhs, long[] eq_rhs, long conseq); public static native void propagateAdd(Object o, long ctx, long solver, long javainfo, long e); - public static native void propagateNextSplit(Object o, long ctx, long solver, long javainfo, long e, long idx, long phase); + public static native boolean propagateNextSplit(Object o, long ctx, long solver, long javainfo, long e, long idx, int phase); public static native void propagateDestroy(Object o, long ctx, long solver, long javainfo); public static abstract class UserPropagatorBase implements AutoCloseable { @@ -1929,7 +1929,7 @@ Z3_final_eh = ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.c_void_p) Z3_eq_eh = ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p) Z3_created_eh = ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p) -Z3_decide_eh = ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p) +Z3_decide_eh = ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_uint, ctypes.c_int) _lib.Z3_solver_register_on_clause.restype = None _lib.Z3_solver_propagate_init.restype = None diff --git a/src/api/api_solver.cpp b/src/api/api_solver.cpp index 2c19d0d9e..08f864226 100644 --- a/src/api/api_solver.cpp +++ b/src/api/api_solver.cpp @@ -1114,17 +1114,17 @@ extern "C" { void Z3_API Z3_solver_propagate_decide(Z3_context c, Z3_solver s, Z3_decide_eh decide_eh) { Z3_TRY; RESET_ERROR_CODE(); - user_propagator::decide_eh_t c = (void(*)(void*, user_propagator::callback*, expr**, unsigned*, lbool*))decide_eh; + user_propagator::decide_eh_t c = (void(*)(void*, user_propagator::callback*, expr*, unsigned, bool))decide_eh; to_solver_ref(s)->user_propagate_register_decide(c); Z3_CATCH; } - void Z3_API Z3_solver_next_split(Z3_context c, Z3_solver_callback cb, Z3_ast t, unsigned idx, Z3_lbool phase) { + bool Z3_API Z3_solver_next_split(Z3_context c, Z3_solver_callback cb, Z3_ast t, unsigned idx, Z3_lbool phase) { Z3_TRY; LOG_Z3_solver_next_split(c, cb, t, idx, phase); RESET_ERROR_CODE(); - reinterpret_cast(cb)->next_split_cb(to_expr(t), idx, (lbool)phase); - Z3_CATCH; + return reinterpret_cast(cb)->next_split_cb(to_expr(t), idx, (lbool)phase); + Z3_CATCH_RETURN(false); } Z3_func_decl Z3_API Z3_solver_propagate_declare(Z3_context c, Z3_symbol name, unsigned n, Z3_sort* domain, Z3_sort range) { diff --git a/src/api/c++/z3++.h b/src/api/c++/z3++.h index 88b520147..88bbd2dcc 100644 --- a/src/api/c++/z3++.h +++ b/src/api/c++/z3++.h @@ -4240,7 +4240,7 @@ namespace z3 { typedef std::function final_eh_t; typedef std::function eq_eh_t; typedef std::function created_eh_t; - typedef std::function decide_eh_t; + typedef std::function decide_eh_t; final_eh_t m_final_eh; eq_eh_t m_eq_eh; @@ -4309,13 +4309,11 @@ namespace z3 { p->m_created_eh(e); } - static void decide_eh(void* _p, Z3_solver_callback cb, Z3_ast* _val, unsigned* bit, Z3_lbool* is_pos) { + static void decide_eh(void* _p, Z3_solver_callback cb, Z3_ast _val, unsigned bit, bool is_pos) { user_propagator_base* p = static_cast(_p); scoped_cb _cb(p, cb); - expr val(p->ctx(), *_val); - p->m_decide_eh(val, *bit, *is_pos); - // TBD: life time of val is within the scope of this callback. - *_val = val; + expr val(p->ctx(), _val); + p->m_decide_eh(val, bit, is_pos); } public: @@ -4435,7 +4433,7 @@ namespace z3 { } void register_decide() { - m_decide_eh = [this](expr& val, unsigned& bit, Z3_lbool& is_pos) { + m_decide_eh = [this](expr val, unsigned bit, bool is_pos) { decide(val, bit, is_pos); }; if (s) { @@ -4451,11 +4449,11 @@ namespace z3 { virtual void created(expr const& /*e*/) {} - virtual void decide(expr& /*val*/, unsigned& /*bit*/, Z3_lbool& /*is_pos*/) {} + virtual void decide(expr const& /*val*/, unsigned /*bit*/, bool /*is_pos*/) {} - void next_split(expr const & e, unsigned idx, Z3_lbool phase) { + bool next_split(expr const& e, unsigned idx, Z3_lbool phase) { assert(cb); - Z3_solver_next_split(ctx(), cb, e, idx, phase); + return Z3_solver_next_split(ctx(), cb, e, idx, phase); } /** diff --git a/src/api/dotnet/UserPropagator.cs b/src/api/dotnet/UserPropagator.cs index b9cd4dc39..af469ddff 100644 --- a/src/api/dotnet/UserPropagator.cs +++ b/src/api/dotnet/UserPropagator.cs @@ -58,12 +58,12 @@ namespace Microsoft.Z3 public delegate void CreatedEh(Expr term); /// - /// Delegate type for callback into solver's branching + /// Delegate type for callback into solver's branching. The values can be overriden by calling . + /// /// A bit-vector or Boolean used for branching /// If the term is a bit-vector, then an index into the bit-vector being branched on - /// Set phase to -1 (false) or 1 (true) to override solver's phase - /// - public delegate void DecideEh(ref Expr term, ref uint idx, ref int phase); + /// The tentative truth-value + public delegate void DecideEh(Expr term, uint idx, bool phase); // access managed objects through a static array. // thread safety is ignored for now. @@ -168,16 +168,11 @@ namespace Microsoft.Z3 prop.Callback(() => prop.created_eh(t), cb); } - static void _decide(voidp ctx, Z3_solver_callback cb, ref Z3_ast a, ref uint idx, ref int phase) + static void _decide(voidp ctx, Z3_solver_callback cb, Z3_ast a, uint idx, bool phase) { var prop = (UserPropagator)GCHandle.FromIntPtr(ctx).Target; - var t = Expr.Create(prop.ctx, a); - var u = t; - prop.callback = cb; - prop.decide_eh(ref t, ref idx, ref phase); - prop.callback = IntPtr.Zero; - if (u != t) - a = t.NativeObject; + using var t = Expr.Create(prop.ctx, a); + prop.Callback(() => prop.decide_eh(t, idx, phase), cb); } /// @@ -352,10 +347,17 @@ namespace Microsoft.Z3 /// /// Set the next decision + /// A bit-vector or Boolean used for branching. Use to clear + /// If the term is a bit-vector, then an index into the bit-vector being branched on + /// The tentative truth-value (-1/false, 1/true, 0/let Z3 decide) /// - public void NextSplit(Expr e, uint idx, int phase) + /// + /// in case the value was successfully set; + /// if the next split could not be set + /// + public bool NextSplit(Expr e, uint idx, int phase) { - Native.Z3_solver_next_split(ctx.nCtx, this.callback, e.NativeObject, idx, phase); + return Native.Z3_solver_next_split(ctx.nCtx, this.callback, e?.NativeObject ?? IntPtr.Zero, idx, phase) != 0; } /// diff --git a/src/api/java/NativeStatic.txt b/src/api/java/NativeStatic.txt index f68893e6b..6bd406dca 100644 --- a/src/api/java/NativeStatic.txt +++ b/src/api/java/NativeStatic.txt @@ -91,6 +91,7 @@ struct JavaInfo { jmethodID fixed = nullptr; jmethodID eq = nullptr; jmethodID final = nullptr; + jmethodID decide = nullptr; Z3_solver_callback cb = nullptr; }; @@ -146,11 +147,10 @@ static void final_eh(void* _p, Z3_solver_callback cb) { info->jenv->CallVoidMethod(info->jobj, info->final); } -// TODO: implement decide -static void decide_eh(void* _p, Z3_solver_callback cb, Z3_ast* _val, unsigned* bit, Z3_lbool* is_pos) { +static void decide_eh(void* _p, Z3_solver_callback cb, Z3_ast _val, unsigned bit, bool is_pos) { JavaInfo *info = static_cast(_p); ScopedCB scoped(info, cb); - + info->jenv->CallVoidMethod(info->jobj, info->decide, (jlong)_val); } DLL_VIS JNIEXPORT jlong JNICALL Java_com_microsoft_z3_Native_propagateInit(JNIEnv *jenv, jclass cls, jobject jobj, jlong ctx, jlong solver) { @@ -166,8 +166,9 @@ DLL_VIS JNIEXPORT jlong JNICALL Java_com_microsoft_z3_Native_propagateInit(JNIEn info->fixed = jenv->GetMethodID(jcls, "fixedWrapper", "(JJ)V"); info->eq = jenv->GetMethodID(jcls, "eqWrapper", "(JJ)V"); info->final = jenv->GetMethodID(jcls, "finWrapper", "()V"); - - if (!info->push || !info->pop || !info->fresh || !info->created || !info->fixed || !info->eq || !info->final) { + info->decide = jenv->GetMethodID(jcls, "decideWrapper", "(JII)V"); + + if (!info->push || !info->pop || !info->fresh || !info->created || !info->fixed || !info->eq || !info->final || !info->decide) { assert(false); } @@ -225,8 +226,8 @@ DLL_VIS JNIEXPORT void JNICALL Java_com_microsoft_z3_Native_propagateAdd(JNIEnv } } -DLL_VIS JNIEXPORT void JNICALL Java_com_microsoft_z3_Native_propagateNextSplit(JNIEnv * jenv, jclass cls, jobject jobj, jlong ctx, jlong solver, jlong javainfo, long e, long idx, long phase) { +DLL_VIS JNIEXPORT bool JNICALL Java_com_microsoft_z3_Native_propagateNextSplit(JNIEnv * jenv, jclass cls, jobject jobj, jlong ctx, jlong solver, jlong javainfo, long e, long idx, int phase) { JavaInfo *info = (JavaInfo*)javainfo; Z3_solver_callback cb = info->cb; - Z3_solver_next_split((Z3_context)ctx, cb, (Z3_ast)e, idx, Z3_lbool(phase)); + return Z3_solver_next_split((Z3_context)ctx, cb, (Z3_ast)e, idx, Z3_lbool(phase)); } diff --git a/src/api/java/UserPropagatorBase.java b/src/api/java/UserPropagatorBase.java index 90243ba68..407d3d0da 100644 --- a/src/api/java/UserPropagatorBase.java +++ b/src/api/java/UserPropagatorBase.java @@ -89,9 +89,9 @@ public abstract class UserPropagatorBase extends Native.UserPropagatorBase { fixed.length, AST.arrayToNative(fixed), lhs.length, AST.arrayToNative(lhs), AST.arrayToNative(rhs), conseq.getNativeObject()); } - public final void nextSplit(Expr e, long idx, Z3_lbool phase) { - Native.propagateNextSplit( + public final boolean nextSplit(Expr e, long idx, Z3_lbool phase) { + return Native.propagateNextSplit( this, ctx.nCtx(), solver.getNativeObject(), javainfo, e.getNativeObject(), idx, phase.toInt()); } -} \ No newline at end of file +} diff --git a/src/api/python/z3/z3.py b/src/api/python/z3/z3.py index 001776ec7..32c7a5763 100644 --- a/src/api/python/z3/z3.py +++ b/src/api/python/z3/z3.py @@ -11529,16 +11529,11 @@ def user_prop_diseq(ctx, cb, x, y): prop.diseq(x, y) prop.cb = None -# TODO The decision callback is not fully implemented. -# It needs to handle the ast*, unsigned* idx, and Z3_lbool* -def user_prop_decide(ctx, cb, t_ref, idx_ref, phase_ref): +def user_prop_decide(ctx, cb, t, idx, phase): prop = _prop_closures.get(ctx) prop.cb = cb t = _to_expr_ref(to_Ast(t_ref), prop.ctx()) - t, idx, phase = prop.decide(t, idx, phase) - t_ref = t - idx_ref = idx - phase_ref = phase + prop.decide(t, idx, phase) prop.cb = None @@ -11685,7 +11680,7 @@ class UserPropagateBase: # split on. A phase of true = 1/false = -1/undef = 0 = let solver decide is the last argument. # def next_split(self, t, idx, phase): - Z3_solver_next_split(self.ctx_ref(), ctypes.c_void_p(self.cb), t.ast, idx, phase) + return Z3_solver_next_split(self.ctx_ref(), ctypes.c_void_p(self.cb), t.ast, idx, phase) # # Propagation can only be invoked as during a fixed or final callback. diff --git a/src/api/z3_api.h b/src/api/z3_api.h index 0582ffa37..54974d57b 100644 --- a/src/api/z3_api.h +++ b/src/api/z3_api.h @@ -1435,7 +1435,7 @@ Z3_DECLARE_CLOSURE(Z3_fixed_eh, void, (void* ctx, Z3_solver_callback cb, Z3_as Z3_DECLARE_CLOSURE(Z3_eq_eh, void, (void* ctx, Z3_solver_callback cb, Z3_ast s, Z3_ast t)); Z3_DECLARE_CLOSURE(Z3_final_eh, void, (void* ctx, Z3_solver_callback cb)); Z3_DECLARE_CLOSURE(Z3_created_eh, void, (void* ctx, Z3_solver_callback cb, Z3_ast t)); -Z3_DECLARE_CLOSURE(Z3_decide_eh, void, (void* ctx, Z3_solver_callback cb, Z3_ast* t, unsigned* idx, Z3_lbool* phase)); +Z3_DECLARE_CLOSURE(Z3_decide_eh, void, (void* ctx, Z3_solver_callback cb, Z3_ast t, unsigned idx, bool phase)); Z3_DECLARE_CLOSURE(Z3_on_clause_eh, void, (void* ctx, Z3_ast proof_hint, Z3_ast_vector literals)); @@ -7098,20 +7098,21 @@ extern "C" { /** \brief register a callback when the solver decides to split on a registered expression. - The callback may set the passed expression to another registered expression which will be selected instead. - In case the expression is a bitvector the bit to split on is determined by the bit argument and the - truth-value to try first is given by is_pos. In case the truth value is undefined the solver will decide. + The callback may change the arguments by providing other values by calling \ref Z3_solver_next_split def_API('Z3_solver_propagate_decide', VOID, (_in(CONTEXT), _in(SOLVER), _fnptr(Z3_decide_eh))) */ void Z3_API Z3_solver_propagate_decide(Z3_context c, Z3_solver s, Z3_decide_eh decide_eh); /** - Sets the next expression to split on + Sets the next (registered) expression to split on. + The function returns false and ignores the given expression in case the expression is already assigned internally + (due to relevancy propagation, this assignments might not have been reported yet by the fixed callback). + In case the function is called in the decide callback, it overrides the currently selected variable and phase. - def_API('Z3_solver_next_split', VOID, (_in(CONTEXT), _in(SOLVER_CALLBACK), _in(AST), _in(UINT), _in(LBOOL))) + def_API('Z3_solver_next_split', BOOL, (_in(CONTEXT), _in(SOLVER_CALLBACK), _in(AST), _in(UINT), _in(LBOOL))) */ - void Z3_API Z3_solver_next_split(Z3_context c, Z3_solver_callback cb, Z3_ast t, unsigned idx, Z3_lbool phase); + bool Z3_API Z3_solver_next_split(Z3_context c, Z3_solver_callback cb, Z3_ast t, unsigned idx, Z3_lbool phase); /** Create uninterpreted function declaration for the user propagator. diff --git a/src/sat/smt/bv_solver.cpp b/src/sat/smt/bv_solver.cpp index a0bcea43b..f76e5ab70 100644 --- a/src/sat/smt/bv_solver.cpp +++ b/src/sat/smt/bv_solver.cpp @@ -128,7 +128,7 @@ namespace bv { /** \brief Find an unassigned bit for m_wpos[v], if such bit cannot be found invoke fixed_var_eh */ - void solver::find_wpos(theory_var v) { + bool solver::find_wpos(theory_var v) { literal_vector const& bits = m_bits[v]; unsigned sz = bits.size(); unsigned& wpos = m_wpos[v]; @@ -137,11 +137,12 @@ namespace bv { if (s().value(bits[idx]) == l_undef) { wpos = idx; TRACE("bv", tout << "moved wpos of v" << v << " to " << wpos << "\n";); - return; + return false; } } TRACE("bv", tout << "v" << v << " is a fixed variable.\n";); fixed_var_eh(v); + return true; } /** @@ -853,7 +854,17 @@ namespace bv { values[n->get_root_id()] = bv.mk_numeral(val, m_bits[v].size()); } - trail_stack& solver::get_trail_stack() { + sat::bool_var solver::get_bit(unsigned bit, euf::enode *n) const { + theory_var v = n->get_th_var(get_id()); + if (v == euf::null_theory_var) + return sat::null_bool_var; + auto &bits = m_bits[v]; + if (bit >= bits.size()) + return sat::null_bool_var; + return bits[bit].var(); + } + + trail_stack &solver::get_trail_stack() { return ctx.get_trail_stack(); } diff --git a/src/sat/smt/bv_solver.h b/src/sat/smt/bv_solver.h index dc9cd1456..91e485a9f 100644 --- a/src/sat/smt/bv_solver.h +++ b/src/sat/smt/bv_solver.h @@ -321,7 +321,7 @@ namespace bv { // solving theory_var find(theory_var v) const { return m_find.find(v); } - void find_wpos(theory_var v); + bool find_wpos(theory_var v); void find_new_diseq_axioms(atom& a, theory_var v, unsigned idx); void mk_new_diseq_axiom(theory_var v1, theory_var v2, unsigned idx); bool get_fixed_value(theory_var v, numeral& result) const; @@ -334,7 +334,6 @@ namespace bv { numeral const& power2(unsigned i) const; sat::literal mk_true(); - // invariants bool check_zero_one_bits(theory_var v); void check_missing_propagation() const; @@ -391,6 +390,7 @@ namespace bv { euf::theory_var mk_var(euf::enode* n) override; void apply_sort_cnstr(euf::enode * n, sort * s) override; + bool_var get_bit(unsigned bit, euf::enode* n) const; void merge_eh(theory_var, theory_var, theory_var v1, theory_var v2); void after_merge_eh(theory_var r1, theory_var r2, theory_var v1, theory_var v2) { SASSERT(check_zero_one_bits(r1)); } diff --git a/src/sat/smt/user_solver.cpp b/src/sat/smt/user_solver.cpp index 34f2b10b4..1e8897b8c 100644 --- a/src/sat/smt/user_solver.cpp +++ b/src/sat/smt/user_solver.cpp @@ -15,8 +15,9 @@ Author: --*/ -#include "sat/smt/user_solver.h" +#include "sat/smt/bv_solver.h" #include "sat/smt/euf_solver.h" +#include "sat/smt/user_solver.h" namespace user_solver { @@ -39,7 +40,7 @@ namespace user_solver { expr_ref r(m); sat::literal_vector explain; if (ctx.is_fixed(n, r, explain)) - m_prop.push_back(prop_info(explain, v, r)); + m_prop.push_back(prop_info(explain, v, r)); } void solver::propagate_cb( @@ -56,17 +57,21 @@ namespace user_solver { void solver::register_cb(expr* e) { add_expr(e); } - - void solver::next_split_cb(expr* e, unsigned idx, lbool phase) { + + bool solver::next_split_cb(expr* e, unsigned idx, lbool phase) { if (e == nullptr) { - m_next_split_expr = nullptr; - return; + m_next_split_var = sat::null_bool_var; + return true; } force_push(); ctx.internalize(e); - m_next_split_expr = e; - m_next_split_idx = idx; + sat::bool_var var = enode_to_bool(ctx.get_enode(e), idx); m_next_split_phase = phase; + if (var == sat::null_bool_var || s().value(var) != l_undef) + return false; + m_next_split_var = var; + m_next_split_phase = phase; + return true; } sat::check_result solver::check() { @@ -84,39 +89,41 @@ namespace user_solver { m_id2justification.setx(v, sat::literal_vector(num_lits, jlits), sat::literal_vector()); m_fixed_eh(m_user_context, this, var2expr(v), value); } - + bool solver::decide(sat::bool_var& var, lbool& phase) { - + if (!m_decide_eh) return false; - + euf::enode* original_enode = bool_var2enode(var); - + if (!original_enode || !is_attached_to_var(original_enode)) return false; - + unsigned new_bit = 0; // ignored; currently no bv-support expr* e = original_enode->get_expr(); - - m_decide_eh(m_user_context, this, &e, &new_bit, &phase); - - euf::enode* new_enode = ctx.get_enode(e); - - if (original_enode == new_enode || new_enode->bool_var() == sat::null_bool_var) + + m_decide_eh(m_user_context, this, e, new_bit, phase); + sat::bool_var new_var; + if (!get_case_split(new_var, phase) || new_var == var) + // The user did not interfere return false; - - var = new_enode->bool_var(); + var = new_var; + + // check if the new variable is unassigned + if (s().value(var) != l_undef) + throw default_exception("expression in \"decide\" is already assigned"); return true; } - - bool solver::get_case_split(sat::bool_var& var, lbool& phase){ - if (!m_next_split_expr) + + bool solver::get_case_split(sat::bool_var& var, lbool& phase) { + if (m_next_split_var == sat::null_bool_var) return false; - - euf::enode* n = ctx.get_enode(m_next_split_expr); - var = n->bool_var(); + + var = m_next_split_var; phase = m_next_split_phase; - m_next_split_expr = nullptr; + m_next_split_var = sat::null_bool_var; + m_next_split_phase = l_undef; return true; } @@ -134,14 +141,14 @@ namespace user_solver { m_id2justification.setx(v, lits, sat::literal_vector()); m_fixed_eh(m_user_context, this, var2expr(v), lit.sign() ? m.mk_false() : m.mk_true()); } - + void solver::new_eq_eh(euf::th_eq const& eq) { if (!m_eq_eh) return; force_push(); m_eq_eh(m_user_context, this, var2expr(eq.v1()), var2expr(eq.v2())); } - + void solver::new_diseq_eh(euf::th_eq const& de) { if (!m_diseq_eh) return; @@ -188,7 +195,7 @@ namespace user_solver { propagate_consequence(prop); else propagate_new_fixed(prop); - } + } return np < m_stats.m_num_propagations; } @@ -208,7 +215,7 @@ namespace user_solver { auto& j = justification::from_index(idx); auto const& prop = m_prop[j.m_propagation_index]; for (unsigned id : prop.m_ids) - r.append(m_id2justification[id]); + r.append(m_id2justification[id]); for (auto const& p : prop.m_eqs) ctx.add_antecedent(probing, expr2enode(p.first), expr2enode(p.second)); } @@ -243,7 +250,7 @@ namespace user_solver { } std::ostream& solver::display_constraint(std::ostream& out, sat::ext_constraint_idx idx) const { - return display_justification(out, idx); + return display_justification(out, idx); } euf::th_solver* solver::clone(euf::solver& dst_ctx) { @@ -278,26 +285,35 @@ namespace user_solver { return true; } m_stack.push_back(sat::eframe(e)); - return false; + return false; } - + bool solver::visited(expr* e) { euf::enode* n = expr2enode(e); - return n && n->is_attached_to(get_id()); + return n && n->is_attached_to(get_id()); } - + bool solver::post_visit(expr* e, bool sign, bool root) { euf::enode* n = expr2enode(e); SASSERT(!n || !n->is_attached_to(get_id())); - if (!n) - n = mk_enode(e, false); + if (!n) + n = mk_enode(e, false); add_expr(e); if (m_created_eh) m_created_eh(m_user_context, this, e); return true; } - + sat::bool_var solver::enode_to_bool(euf::enode* n, unsigned idx) { + if (n->bool_var() != sat::null_bool_var) { + // expression is a boolean + return n->bool_var(); + } + // expression is a bit-vector + bv_util bv(m); + th_solver* th = ctx.fid2solver(bv.get_fid()); + return ((bv::solver*) th)->get_bit(idx, n); + } } diff --git a/src/sat/smt/user_solver.h b/src/sat/smt/user_solver.h index cb1c6fe94..bd1b703e0 100644 --- a/src/sat/smt/user_solver.h +++ b/src/sat/smt/user_solver.h @@ -75,9 +75,8 @@ namespace user_solver { euf::enode_pair_vector m_eqs; unsigned_vector m_fixed_ids; stats m_stats; - expr* m_next_split_expr = nullptr; - unsigned m_next_split_idx; - lbool m_next_split_phase; + sat::bool_var m_next_split_var = sat::null_bool_var; + lbool m_next_split_phase = l_undef; struct justification { unsigned m_propagation_index { 0 }; @@ -104,6 +103,8 @@ namespace user_solver { bool visited(expr* e) override; bool post_visit(expr* e, bool sign, bool root) override; + sat::bool_var enode_to_bool(euf::enode* n, unsigned idx); + public: solver(euf::solver& ctx); @@ -136,7 +137,7 @@ namespace user_solver { void propagate_cb(unsigned num_fixed, expr* const* fixed_ids, unsigned num_eqs, expr* const* lhs, expr* const* rhs, expr* conseq) override; void register_cb(expr* e) override; - void next_split_cb(expr* e, unsigned idx, lbool phase) override; + bool next_split_cb(expr* e, unsigned idx, lbool phase) override; void new_fixed_eh(euf::theory_var v, expr* value, unsigned num_lits, sat::literal const* jlits); diff --git a/src/smt/smt_context.cpp b/src/smt/smt_context.cpp index 82f56ab76..4125b24de 100644 --- a/src/smt/smt_context.cpp +++ b/src/smt/smt_context.cpp @@ -2918,7 +2918,9 @@ namespace smt { bool context::has_split_candidate(bool_var& var, bool& is_pos) { if (!m_user_propagator) return false; - return m_user_propagator->get_case_split(var, is_pos); + if (!m_user_propagator->get_case_split(var, is_pos)) + return false; + return get_assignment(var) == l_undef; } bool context::decide_user_interference(bool_var& var, bool& is_pos) { diff --git a/src/smt/theory_bv.cpp b/src/smt/theory_bv.cpp index 7adab35f4..a44e1a1aa 100644 --- a/src/smt/theory_bv.cpp +++ b/src/smt/theory_bv.cpp @@ -1889,21 +1889,14 @@ namespace smt { return var_enode_pos(nullptr, UINT32_MAX); } - bool_var theory_bv::get_first_unassigned(unsigned start_bit, enode* n) const { + bool_var theory_bv::get_bit(unsigned bit, enode* n) const { theory_var v = n->get_th_var(get_family_id()); + if (v == null_theory_var) + return null_bool_var; auto& bits = m_bits[v]; - unsigned sz = bits.size(); - - for (unsigned i = start_bit; i < sz; ++i) { - if (ctx.get_assignment(bits[i].var()) == l_undef) - return bits[i].var(); - } - for (unsigned i = 0; i < start_bit; ++i) { - if (ctx.get_assignment(bits[i].var()) == l_undef) - return bits[i].var(); - } - - return null_bool_var; + if (bit >= bits.size()) + return null_bool_var; + return bits[bit].var(); } bool theory_bv::check_assignment(theory_var v) { diff --git a/src/smt/theory_bv.h b/src/smt/theory_bv.h index 73d659c68..10cf005e3 100644 --- a/src/smt/theory_bv.h +++ b/src/smt/theory_bv.h @@ -291,7 +291,7 @@ namespace smt { bool is_fixed_propagated(theory_var v, expr_ref& val, literal_vector& explain) override; var_enode_pos get_bv_with_theory(bool_var v, theory_id id) const; - bool_var get_first_unassigned(unsigned start_bit, enode* n) const; + bool_var get_bit(unsigned bit, enode* n) const; bool check_assignment(theory_var v); bool check_invariant(); diff --git a/src/smt/theory_user_propagator.cpp b/src/smt/theory_user_propagator.cpp index 8eeaf4382..2d5b4917d 100644 --- a/src/smt/theory_user_propagator.cpp +++ b/src/smt/theory_user_propagator.cpp @@ -107,15 +107,18 @@ void theory_user_propagator::register_cb(expr* e) { add_expr(e, true); } -void theory_user_propagator::next_split_cb(expr* e, unsigned idx, lbool phase) { +bool theory_user_propagator::next_split_cb(expr* e, unsigned idx, lbool phase) { if (e == nullptr) { // clear - m_next_split_expr = nullptr; - return; + m_next_split_var = null_bool_var; + return true; } ensure_enode(e); - m_next_split_expr = e; - m_next_split_idx = idx; + bool_var b = enode_to_bool(ctx.get_enode(e), idx); + if (b == null_bool_var || ctx.get_assignment(b) != l_undef) + return false; + m_next_split_var = b; m_next_split_phase = phase; + return true; } theory * theory_user_propagator::mk_fresh(context * new_ctx) { @@ -174,18 +177,15 @@ void theory_user_propagator::new_fixed_eh(theory_var v, expr* value, unsigned nu } } -bool_var theory_user_propagator::enode_to_bool(enode* n, unsigned bit) { +bool_var theory_user_propagator::enode_to_bool(enode* n, unsigned idx) { if (n->is_bool()) { // expression is a boolean - bool_var new_var = ctx.enode2bool_var(n); - if (ctx.get_assignment(new_var) == l_undef) - return new_var; - return null_bool_var; + return ctx.enode2bool_var(n); } // expression is a bit-vector bv_util bv(m); auto th_bv = (theory_bv*)ctx.get_theory(bv.get_fid()); - return th_bv->get_first_unassigned(bit, n); + return th_bv->get_bit(idx, n); } void theory_user_propagator::decide(bool_var& var, bool& is_pos) { @@ -225,7 +225,7 @@ void theory_user_propagator::decide(bool_var& var, bool& is_pos) { if (v == null_theory_var) { // it is not a registered boolean value but it is a bitvector - auto registered_bv = ((theory_bv*)th)->get_bv_with_theory(var, get_family_id()); + auto registered_bv = ((theory_bv *) th)->get_bv_with_theory(var, get_family_id()); if (!registered_bv.first) // there is no registered bv associated with the bit return; @@ -236,47 +236,33 @@ void theory_user_propagator::decide(bool_var& var, bool& is_pos) { // call the registered callback unsigned new_bit = original_bit; - lbool phase = is_pos ? l_true : l_false; - - expr* e = var2expr(v); - m_decide_eh(m_user_context, this, &e, &new_bit, &phase); - enode* new_enode = ctx.get_enode(e); - // check if the callback changed something - if (original_enode == new_enode && (new_enode->is_bool() || original_bit == new_bit)) { - if (phase != l_undef) - // it only affected the truth value - is_pos = phase == l_true; + expr *e = var2expr(v); + m_decide_eh(m_user_context, this, e, new_bit, is_pos); + + bool_var new_var; + if (!get_case_split(new_var, is_pos) || new_var == var) + // The user did not interfere return; - } + var = new_var; - // get unassigned variable from enode - var = enode_to_bool(new_enode, new_bit); - - if (var == null_bool_var) - // selected variable is already assigned + // check if the new variable is unassigned + if (ctx.get_assignment(var) != l_undef) throw default_exception("expression in \"decide\" is already assigned"); - - // in case the callback did not decide on a truth value -> let Z3 decide - is_pos = ctx.guess(var, phase); } -bool theory_user_propagator::get_case_split(bool_var& var, bool& is_pos){ - if (!m_next_split_expr) +bool theory_user_propagator::get_case_split(bool_var& var, bool& is_pos) { + if (m_next_split_var == null_bool_var) return false; - enode* n = ctx.get_enode(m_next_split_expr); - - var = enode_to_bool(n, m_next_split_idx); - - if (var == null_bool_var) - return false; - + + var = m_next_split_var; is_pos = ctx.guess(var, m_next_split_phase); - m_next_split_expr = nullptr; + m_next_split_var = null_bool_var; + m_next_split_phase = l_undef; return true; } -void theory_user_propagator::push_scope_eh() { +void theory_user_propagator::push_scope_eh() { ++m_num_scopes; } @@ -421,9 +407,9 @@ bool theory_user_propagator::internalize_term(app* term) { return true; } -void theory_user_propagator::collect_statistics(::statistics & st) const { +void theory_user_propagator::collect_statistics(::statistics& st) const { st.update("user-propagations", m_stats.m_num_propagations); - st.update("user-watched", get_num_vars()); + st.update("user-watched", get_num_vars()); } diff --git a/src/smt/theory_user_propagator.h b/src/smt/theory_user_propagator.h index 2ed1acbdf..5a6eafc0a 100644 --- a/src/smt/theory_user_propagator.h +++ b/src/smt/theory_user_propagator.h @@ -83,9 +83,8 @@ namespace smt { expr_ref_vector m_to_add; unsigned_vector m_to_add_lim; unsigned m_to_add_qhead = 0; - expr* m_next_split_expr = nullptr; - unsigned m_next_split_idx; - lbool m_next_split_phase; + bool_var m_next_split_var = null_bool_var; + lbool m_next_split_phase = l_undef; expr* var2expr(theory_var v) { return m_var2expr.get(v); } theory_var expr2var(expr* e) { check_defined(e); return m_expr2var[e->get_id()]; } @@ -133,7 +132,7 @@ namespace smt { void propagate_cb(unsigned num_fixed, expr* const* fixed_ids, unsigned num_eqs, expr* const* lhs, expr* const* rhs, expr* conseq) override; void register_cb(expr* e) override; - void next_split_cb(expr* e, unsigned idx, lbool phase) override; + bool next_split_cb(expr* e, unsigned idx, lbool phase) override; void new_fixed_eh(theory_var v, expr* value, unsigned num_lits, literal const* jlits); void decide(bool_var& var, bool& is_pos); diff --git a/src/tactic/user_propagator_base.h b/src/tactic/user_propagator_base.h index 68e55be75..d4dae5166 100644 --- a/src/tactic/user_propagator_base.h +++ b/src/tactic/user_propagator_base.h @@ -11,7 +11,7 @@ namespace user_propagator { virtual ~callback() = default; virtual void propagate_cb(unsigned num_fixed, expr* const* fixed_ids, unsigned num_eqs, expr* const* eq_lhs, expr* const* eq_rhs, expr* conseq) = 0; virtual void register_cb(expr* e) = 0; - virtual void next_split_cb(expr* e, unsigned idx, lbool phase) = 0; + virtual bool next_split_cb(expr* e, unsigned idx, lbool phase) = 0; }; class context_obj { @@ -26,7 +26,7 @@ namespace user_propagator { typedef std::function push_eh_t; typedef std::function pop_eh_t; typedef std::function created_eh_t; - typedef std::function decide_eh_t; + typedef std::function decide_eh_t; typedef std::function on_clause_eh_t; class plugin : public decl_plugin {