diff --git a/scripts/update_api.py b/scripts/update_api.py index 23d044832..89a895e3b 100755 --- a/scripts/update_api.py +++ b/scripts/update_api.py @@ -642,7 +642,7 @@ def mk_java(java_src, java_dir, package_name): public static native void propagateRegisterFinal(Object o, long ctx, long solver); public static native void propagateConflict(Object o, long ctx, long solver, long javainfo, int num_fixed, long[] fixed, long num_eqs, long[] eq_lhs, long[] eq_rhs, long conseq); public static native void propagateAdd(Object o, long ctx, long solver, long javainfo, long e); - public static native void propagateNextSplit(Object o, long ctx, long solver, long javainfo, long e, long idx, long phase); + public static native boolean propagateNextSplit(Object o, long ctx, long solver, long javainfo, long e, long idx, int phase); public static native void propagateDestroy(Object o, long ctx, long solver, long javainfo); public static abstract class UserPropagatorBase implements AutoCloseable { @@ -765,7 +765,7 @@ def mk_java(java_src, java_dir, package_name): java_wrapper.write(line) for name, result, params in _dotnet_decls: java_wrapper.write('DLL_VIS JNIEXPORT %s JNICALL Java_%s_Native_INTERNAL%s(JNIEnv * jenv, jclass cls' % (type2javaw(result), pkg_str, java_method_name(name))) - i = 0 + i = 0 for param in params: java_wrapper.write(', ') java_wrapper.write('%s a%d' % (param2javaw(param), i)) @@ -1929,7 +1929,7 @@ Z3_final_eh = ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.c_void_p) Z3_eq_eh = ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p) Z3_created_eh = ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p) -Z3_decide_eh = ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p) +Z3_decide_eh = ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_uint, ctypes.c_int) _lib.Z3_solver_register_on_clause.restype = None _lib.Z3_solver_propagate_init.restype = None diff --git a/src/api/api_quant.cpp b/src/api/api_quant.cpp index 885cf6598..0da5c5a92 100644 --- a/src/api/api_quant.cpp +++ b/src/api/api_quant.cpp @@ -249,7 +249,10 @@ extern "C" { expr_abstract(mk_c(c)->m(), 0, num_bound, bound_asts.data(), pat, result); SASSERT(result.get()->get_kind() == AST_APP); pinned.push_back(result.get()); - SASSERT(mk_c(c)->m().is_pattern(result.get())); + if (!mk_c(c)->m().is_pattern(result.get())) { + SET_ERROR_CODE(Z3_INVALID_ARG, "invalid pattern"); + RETURN_Z3(nullptr); + } _patterns.push_back(of_pattern(result.get())); } svector _no_patterns; diff --git a/src/api/api_solver.cpp b/src/api/api_solver.cpp index 2c19d0d9e..08f864226 100644 --- a/src/api/api_solver.cpp +++ b/src/api/api_solver.cpp @@ -1114,17 +1114,17 @@ extern "C" { void Z3_API Z3_solver_propagate_decide(Z3_context c, Z3_solver s, Z3_decide_eh decide_eh) { Z3_TRY; RESET_ERROR_CODE(); - user_propagator::decide_eh_t c = (void(*)(void*, user_propagator::callback*, expr**, unsigned*, lbool*))decide_eh; + user_propagator::decide_eh_t c = (void(*)(void*, user_propagator::callback*, expr*, unsigned, bool))decide_eh; to_solver_ref(s)->user_propagate_register_decide(c); Z3_CATCH; } - void Z3_API Z3_solver_next_split(Z3_context c, Z3_solver_callback cb, Z3_ast t, unsigned idx, Z3_lbool phase) { + bool Z3_API Z3_solver_next_split(Z3_context c, Z3_solver_callback cb, Z3_ast t, unsigned idx, Z3_lbool phase) { Z3_TRY; LOG_Z3_solver_next_split(c, cb, t, idx, phase); RESET_ERROR_CODE(); - reinterpret_cast(cb)->next_split_cb(to_expr(t), idx, (lbool)phase); - Z3_CATCH; + return reinterpret_cast(cb)->next_split_cb(to_expr(t), idx, (lbool)phase); + Z3_CATCH_RETURN(false); } Z3_func_decl Z3_API Z3_solver_propagate_declare(Z3_context c, Z3_symbol name, unsigned n, Z3_sort* domain, Z3_sort range) { diff --git a/src/api/c++/z3++.h b/src/api/c++/z3++.h index 88b520147..88bbd2dcc 100644 --- a/src/api/c++/z3++.h +++ b/src/api/c++/z3++.h @@ -4240,7 +4240,7 @@ namespace z3 { typedef std::function final_eh_t; typedef std::function eq_eh_t; typedef std::function created_eh_t; - typedef std::function decide_eh_t; + typedef std::function decide_eh_t; final_eh_t m_final_eh; eq_eh_t m_eq_eh; @@ -4309,13 +4309,11 @@ namespace z3 { p->m_created_eh(e); } - static void decide_eh(void* _p, Z3_solver_callback cb, Z3_ast* _val, unsigned* bit, Z3_lbool* is_pos) { + static void decide_eh(void* _p, Z3_solver_callback cb, Z3_ast _val, unsigned bit, bool is_pos) { user_propagator_base* p = static_cast(_p); scoped_cb _cb(p, cb); - expr val(p->ctx(), *_val); - p->m_decide_eh(val, *bit, *is_pos); - // TBD: life time of val is within the scope of this callback. - *_val = val; + expr val(p->ctx(), _val); + p->m_decide_eh(val, bit, is_pos); } public: @@ -4435,7 +4433,7 @@ namespace z3 { } void register_decide() { - m_decide_eh = [this](expr& val, unsigned& bit, Z3_lbool& is_pos) { + m_decide_eh = [this](expr val, unsigned bit, bool is_pos) { decide(val, bit, is_pos); }; if (s) { @@ -4451,11 +4449,11 @@ namespace z3 { virtual void created(expr const& /*e*/) {} - virtual void decide(expr& /*val*/, unsigned& /*bit*/, Z3_lbool& /*is_pos*/) {} + virtual void decide(expr const& /*val*/, unsigned /*bit*/, bool /*is_pos*/) {} - void next_split(expr const & e, unsigned idx, Z3_lbool phase) { + bool next_split(expr const& e, unsigned idx, Z3_lbool phase) { assert(cb); - Z3_solver_next_split(ctx(), cb, e, idx, phase); + return Z3_solver_next_split(ctx(), cb, e, idx, phase); } /** diff --git a/src/api/dotnet/UserPropagator.cs b/src/api/dotnet/UserPropagator.cs index b9cd4dc39..af469ddff 100644 --- a/src/api/dotnet/UserPropagator.cs +++ b/src/api/dotnet/UserPropagator.cs @@ -58,12 +58,12 @@ namespace Microsoft.Z3 public delegate void CreatedEh(Expr term); /// - /// Delegate type for callback into solver's branching + /// Delegate type for callback into solver's branching. The values can be overriden by calling . + /// /// A bit-vector or Boolean used for branching /// If the term is a bit-vector, then an index into the bit-vector being branched on - /// Set phase to -1 (false) or 1 (true) to override solver's phase - /// - public delegate void DecideEh(ref Expr term, ref uint idx, ref int phase); + /// The tentative truth-value + public delegate void DecideEh(Expr term, uint idx, bool phase); // access managed objects through a static array. // thread safety is ignored for now. @@ -168,16 +168,11 @@ namespace Microsoft.Z3 prop.Callback(() => prop.created_eh(t), cb); } - static void _decide(voidp ctx, Z3_solver_callback cb, ref Z3_ast a, ref uint idx, ref int phase) + static void _decide(voidp ctx, Z3_solver_callback cb, Z3_ast a, uint idx, bool phase) { var prop = (UserPropagator)GCHandle.FromIntPtr(ctx).Target; - var t = Expr.Create(prop.ctx, a); - var u = t; - prop.callback = cb; - prop.decide_eh(ref t, ref idx, ref phase); - prop.callback = IntPtr.Zero; - if (u != t) - a = t.NativeObject; + using var t = Expr.Create(prop.ctx, a); + prop.Callback(() => prop.decide_eh(t, idx, phase), cb); } /// @@ -352,10 +347,17 @@ namespace Microsoft.Z3 /// /// Set the next decision + /// A bit-vector or Boolean used for branching. Use to clear + /// If the term is a bit-vector, then an index into the bit-vector being branched on + /// The tentative truth-value (-1/false, 1/true, 0/let Z3 decide) /// - public void NextSplit(Expr e, uint idx, int phase) + /// + /// in case the value was successfully set; + /// if the next split could not be set + /// + public bool NextSplit(Expr e, uint idx, int phase) { - Native.Z3_solver_next_split(ctx.nCtx, this.callback, e.NativeObject, idx, phase); + return Native.Z3_solver_next_split(ctx.nCtx, this.callback, e?.NativeObject ?? IntPtr.Zero, idx, phase) != 0; } /// diff --git a/src/api/java/NativeStatic.txt b/src/api/java/NativeStatic.txt index f68893e6b..f076a817c 100644 --- a/src/api/java/NativeStatic.txt +++ b/src/api/java/NativeStatic.txt @@ -91,6 +91,7 @@ struct JavaInfo { jmethodID fixed = nullptr; jmethodID eq = nullptr; jmethodID final = nullptr; + jmethodID decide = nullptr; Z3_solver_callback cb = nullptr; }; @@ -146,11 +147,10 @@ static void final_eh(void* _p, Z3_solver_callback cb) { info->jenv->CallVoidMethod(info->jobj, info->final); } -// TODO: implement decide -static void decide_eh(void* _p, Z3_solver_callback cb, Z3_ast* _val, unsigned* bit, Z3_lbool* is_pos) { +static void decide_eh(void* _p, Z3_solver_callback cb, Z3_ast _val, unsigned bit, bool is_pos) { JavaInfo *info = static_cast(_p); ScopedCB scoped(info, cb); - + info->jenv->CallVoidMethod(info->jobj, info->decide, (jlong)_val); } DLL_VIS JNIEXPORT jlong JNICALL Java_com_microsoft_z3_Native_propagateInit(JNIEnv *jenv, jclass cls, jobject jobj, jlong ctx, jlong solver) { @@ -166,8 +166,9 @@ DLL_VIS JNIEXPORT jlong JNICALL Java_com_microsoft_z3_Native_propagateInit(JNIEn info->fixed = jenv->GetMethodID(jcls, "fixedWrapper", "(JJ)V"); info->eq = jenv->GetMethodID(jcls, "eqWrapper", "(JJ)V"); info->final = jenv->GetMethodID(jcls, "finWrapper", "()V"); - - if (!info->push || !info->pop || !info->fresh || !info->created || !info->fixed || !info->eq || !info->final) { + info->decide = jenv->GetMethodID(jcls, "decideWrapper", "(JII)V"); + + if (!info->push || !info->pop || !info->fresh || !info->created || !info->fixed || !info->eq || !info->final || !info->decide) { assert(false); } @@ -202,7 +203,7 @@ DLL_VIS JNIEXPORT void JNICALL Java_com_microsoft_z3_Native_propagateRegisterDec Z3_solver_propagate_decide((Z3_context)ctx, (Z3_solver)solver, decide_eh); } -DLL_VIS JNIEXPORT void JNICALL Java_com_microsoft_z3_Native_propagateConflict(JNIEnv * jenv, jclass cls, jobject jobj, jlong ctx, jlong solver, jlong javainfo, jlong num_fixed, jlongArray fixed, jlong num_eqs, jlongArray eq_lhs, jlongArray eq_rhs, jlong conseq) { +DLL_VIS JNIEXPORT void JNICALL Java_com_microsoft_z3_Native_propagateConflict(JNIEnv * jenv, jclass cls, jobject jobj, jlong ctx, jlong solver, jlong javainfo, long num_fixed, jlongArray fixed, long num_eqs, jlongArray eq_lhs, jlongArray eq_rhs, jlong conseq) { JavaInfo *info = (JavaInfo*)javainfo; GETLONGAELEMS(Z3_ast, fixed, _fixed); GETLONGAELEMS(Z3_ast, eq_lhs, _eq_lhs); @@ -225,8 +226,8 @@ DLL_VIS JNIEXPORT void JNICALL Java_com_microsoft_z3_Native_propagateAdd(JNIEnv } } -DLL_VIS JNIEXPORT void JNICALL Java_com_microsoft_z3_Native_propagateNextSplit(JNIEnv * jenv, jclass cls, jobject jobj, jlong ctx, jlong solver, jlong javainfo, long e, long idx, long phase) { +DLL_VIS JNIEXPORT bool JNICALL Java_com_microsoft_z3_Native_propagateNextSplit(JNIEnv * jenv, jclass cls, jobject jobj, jlong ctx, jlong solver, jlong javainfo, jlong e, long idx, int phase) { JavaInfo *info = (JavaInfo*)javainfo; Z3_solver_callback cb = info->cb; - Z3_solver_next_split((Z3_context)ctx, cb, (Z3_ast)e, idx, Z3_lbool(phase)); + return Z3_solver_next_split((Z3_context)ctx, cb, (Z3_ast)e, idx, Z3_lbool(phase)); } diff --git a/src/api/java/UserPropagatorBase.java b/src/api/java/UserPropagatorBase.java index 90243ba68..407d3d0da 100644 --- a/src/api/java/UserPropagatorBase.java +++ b/src/api/java/UserPropagatorBase.java @@ -89,9 +89,9 @@ public abstract class UserPropagatorBase extends Native.UserPropagatorBase { fixed.length, AST.arrayToNative(fixed), lhs.length, AST.arrayToNative(lhs), AST.arrayToNative(rhs), conseq.getNativeObject()); } - public final void nextSplit(Expr e, long idx, Z3_lbool phase) { - Native.propagateNextSplit( + public final boolean nextSplit(Expr e, long idx, Z3_lbool phase) { + return Native.propagateNextSplit( this, ctx.nCtx(), solver.getNativeObject(), javainfo, e.getNativeObject(), idx, phase.toInt()); } -} \ No newline at end of file +} diff --git a/src/api/python/z3/z3.py b/src/api/python/z3/z3.py index 6b79dd1fe..32c7a5763 100644 --- a/src/api/python/z3/z3.py +++ b/src/api/python/z3/z3.py @@ -7136,6 +7136,13 @@ class Solver(Z3PPObject): """Import model converter from other into the current solver""" Z3_solver_import_model_converter(self.ctx.ref(), other.solver, self.solver) + def interrupt(self): + """Interrupt the execution of the solver object. + Remarks: This ensures that the interrupt applies only + to the given solver object and it applies only if it is running. + """ + Z3_solver_interrupt(self.ctx.ref(), self.solver) + def unsat_core(self): """Return a subset (as an AST vector) of the assumptions provided to the last check(). @@ -11522,16 +11529,11 @@ def user_prop_diseq(ctx, cb, x, y): prop.diseq(x, y) prop.cb = None -# TODO The decision callback is not fully implemented. -# It needs to handle the ast*, unsigned* idx, and Z3_lbool* -def user_prop_decide(ctx, cb, t_ref, idx_ref, phase_ref): +def user_prop_decide(ctx, cb, t, idx, phase): prop = _prop_closures.get(ctx) prop.cb = cb t = _to_expr_ref(to_Ast(t_ref), prop.ctx()) - t, idx, phase = prop.decide(t, idx, phase) - t_ref = t - idx_ref = idx - phase_ref = phase + prop.decide(t, idx, phase) prop.cb = None @@ -11678,7 +11680,7 @@ class UserPropagateBase: # split on. A phase of true = 1/false = -1/undef = 0 = let solver decide is the last argument. # def next_split(self, t, idx, phase): - Z3_solver_next_split(self.ctx_ref(), ctypes.c_void_p(self.cb), t.ast, idx, phase) + return Z3_solver_next_split(self.ctx_ref(), ctypes.c_void_p(self.cb), t.ast, idx, phase) # # Propagation can only be invoked as during a fixed or final callback. diff --git a/src/api/z3_api.h b/src/api/z3_api.h index 0582ffa37..54974d57b 100644 --- a/src/api/z3_api.h +++ b/src/api/z3_api.h @@ -1435,7 +1435,7 @@ Z3_DECLARE_CLOSURE(Z3_fixed_eh, void, (void* ctx, Z3_solver_callback cb, Z3_as Z3_DECLARE_CLOSURE(Z3_eq_eh, void, (void* ctx, Z3_solver_callback cb, Z3_ast s, Z3_ast t)); Z3_DECLARE_CLOSURE(Z3_final_eh, void, (void* ctx, Z3_solver_callback cb)); Z3_DECLARE_CLOSURE(Z3_created_eh, void, (void* ctx, Z3_solver_callback cb, Z3_ast t)); -Z3_DECLARE_CLOSURE(Z3_decide_eh, void, (void* ctx, Z3_solver_callback cb, Z3_ast* t, unsigned* idx, Z3_lbool* phase)); +Z3_DECLARE_CLOSURE(Z3_decide_eh, void, (void* ctx, Z3_solver_callback cb, Z3_ast t, unsigned idx, bool phase)); Z3_DECLARE_CLOSURE(Z3_on_clause_eh, void, (void* ctx, Z3_ast proof_hint, Z3_ast_vector literals)); @@ -7098,20 +7098,21 @@ extern "C" { /** \brief register a callback when the solver decides to split on a registered expression. - The callback may set the passed expression to another registered expression which will be selected instead. - In case the expression is a bitvector the bit to split on is determined by the bit argument and the - truth-value to try first is given by is_pos. In case the truth value is undefined the solver will decide. + The callback may change the arguments by providing other values by calling \ref Z3_solver_next_split def_API('Z3_solver_propagate_decide', VOID, (_in(CONTEXT), _in(SOLVER), _fnptr(Z3_decide_eh))) */ void Z3_API Z3_solver_propagate_decide(Z3_context c, Z3_solver s, Z3_decide_eh decide_eh); /** - Sets the next expression to split on + Sets the next (registered) expression to split on. + The function returns false and ignores the given expression in case the expression is already assigned internally + (due to relevancy propagation, this assignments might not have been reported yet by the fixed callback). + In case the function is called in the decide callback, it overrides the currently selected variable and phase. - def_API('Z3_solver_next_split', VOID, (_in(CONTEXT), _in(SOLVER_CALLBACK), _in(AST), _in(UINT), _in(LBOOL))) + def_API('Z3_solver_next_split', BOOL, (_in(CONTEXT), _in(SOLVER_CALLBACK), _in(AST), _in(UINT), _in(LBOOL))) */ - void Z3_API Z3_solver_next_split(Z3_context c, Z3_solver_callback cb, Z3_ast t, unsigned idx, Z3_lbool phase); + bool Z3_API Z3_solver_next_split(Z3_context c, Z3_solver_callback cb, Z3_ast t, unsigned idx, Z3_lbool phase); /** Create uninterpreted function declaration for the user propagator. diff --git a/src/ast/rewriter/bv_rewriter.cpp b/src/ast/rewriter/bv_rewriter.cpp index 82bbefce3..700d16ddf 100644 --- a/src/ast/rewriter/bv_rewriter.cpp +++ b/src/ast/rewriter/bv_rewriter.cpp @@ -3019,8 +3019,8 @@ br_status bv_rewriter::mk_bvumul_no_overflow(unsigned num, expr * const * args, br_status bv_rewriter::mk_bvneg_overflow(expr * const arg, expr_ref & result) { unsigned int sz = get_bv_size(arg); - auto maxUnsigned = mk_numeral(rational::power_of_two(sz)-1, sz); - result = m.mk_eq(arg, maxUnsigned); + auto minSigned = mk_numeral(rational::power_of_two(sz - 1), sz); // 0b1000...0 + result = m.mk_eq(arg, minSigned); return BR_REWRITE3; } @@ -3089,7 +3089,7 @@ br_status bv_rewriter::mk_bvssub_overflow(unsigned num, expr * const * args, exp SASSERT(num == 2); SASSERT(get_bv_size(args[0]) == get_bv_size(args[1])); auto sz = get_bv_size(args[0]); - auto minSigned = mk_numeral(-rational::power_of_two(sz-1), sz); + auto minSigned = mk_numeral(rational::power_of_two(sz-1), sz); expr_ref bvsaddo {m}; expr * args2[2] = { args[0], m_util.mk_bv_neg(args[1]) }; auto bvsaddo_stat = mk_bvsadd_overflow(2, args2, bvsaddo); @@ -3102,7 +3102,7 @@ br_status bv_rewriter::mk_bvsdiv_overflow(unsigned num, expr * const * args, exp SASSERT(num == 2); SASSERT(get_bv_size(args[0]) == get_bv_size(args[1])); auto sz = get_bv_size(args[1]); - auto minSigned = mk_numeral(-rational::power_of_two(sz-1), sz); + auto minSigned = mk_numeral(rational::power_of_two(sz-1), sz); auto minusOne = mk_numeral(rational::power_of_two(sz) - 1, sz); result = m.mk_and(m.mk_eq(args[0], minSigned), m.mk_eq(args[1], minusOne)); return BR_REWRITE_FULL; diff --git a/src/ast/rewriter/der.cpp b/src/ast/rewriter/der.cpp index 0e28cf6f6..1e2a19d72 100644 --- a/src/ast/rewriter/der.cpp +++ b/src/ast/rewriter/der.cpp @@ -176,9 +176,9 @@ void der::reduce1(quantifier * q, expr_ref & r, proof_ref & pr) { var * v = nullptr; expr_ref t(m); - if (is_forall(q) && is_var_diseq(e, num_decls, v, t) && !occurs(v, t)) + if (is_forall(q) && is_var_diseq(e, num_decls, v, t) && !has_quantifiers(t) && !occurs(v, t)) r = m.mk_false(); - else if (is_exists(q) && is_var_eq(e, num_decls, v, t) && !occurs(v, t)) + else if (is_exists(q) && is_var_eq(e, num_decls, v, t) && !has_quantifiers(t) && !occurs(v, t)) r = m.mk_true(); else { expr_ref_vector literals(m); diff --git a/src/ast/rewriter/th_rewriter.cpp b/src/ast/rewriter/th_rewriter.cpp index 9278ae5ae..d4e302a5d 100644 --- a/src/ast/rewriter/th_rewriter.cpp +++ b/src/ast/rewriter/th_rewriter.cpp @@ -74,6 +74,7 @@ struct th_rewriter_cfg : public default_rewriter_cfg { bool m_push_ite_bv = true; bool m_ignore_patterns_on_ground_qbody = true; bool m_rewrite_patterns = true; + bool m_enable_der = true; ast_manager & m() const { return m_b_rw.m(); } @@ -89,6 +90,7 @@ struct th_rewriter_cfg : public default_rewriter_cfg { m_push_ite_bv = p.push_ite_bv(); m_ignore_patterns_on_ground_qbody = p.ignore_patterns_on_ground_qbody(); m_rewrite_patterns = p.rewrite_patterns(); + m_enable_der = p.enable_der(); } void updt_params(params_ref const & p) { @@ -827,11 +829,12 @@ struct th_rewriter_cfg : public default_rewriter_cfg { expr_ref r(m()); bool der_change = false; - if (is_quantifier(result) && to_quantifier(result)->get_num_patterns() == 0) { + if (m_enable_der && is_quantifier(result) && to_quantifier(result)->get_num_patterns() == 0) { m_der(to_quantifier(result), r, p2); der_change = result.get() != r.get(); if (m().proofs_enabled() && der_change) - result_pr = m().mk_transitivity(result_pr, p2); + result_pr = m().mk_transitivity(result_pr, p2); + result = r; } diff --git a/src/ast/simplifiers/bound_simplifier.cpp b/src/ast/simplifiers/bound_simplifier.cpp index f5c986425..1a5d4c101 100644 --- a/src/ast/simplifiers/bound_simplifier.cpp +++ b/src/ast/simplifiers/bound_simplifier.cpp @@ -58,9 +58,9 @@ struct bound_simplifier::rw : public rewriter_tpl { br_status bound_simplifier::reduce_app(func_decl* f, unsigned num_args, expr* const* args, expr_ref& result, proof_ref& pr) { rational N, hi, lo; if (a.is_mod(f) && num_args == 2 && a.is_numeral(args[1], N)) { - expr* x = args[0]; auto& im = m_interval; scoped_dep_interval i(im); + expr* x = args[0]; get_bounds(x, i); if (im.upper_is_inf(i) || im.lower_is_inf(i)) return BR_FAILED; @@ -83,7 +83,55 @@ br_status bound_simplifier::reduce_app(func_decl* f, unsigned num_args, expr* co } IF_VERBOSE(2, verbose_stream() << "potentially missed simplification: " << mk_pp(x, m) << " " << lo << " " << hi << " not reduced\n"); } - return BR_FAILED; + + expr_ref_buffer new_args(m); + expr_ref new_arg(m); + bool change = false; + for (unsigned i = 0; i < num_args; ++i) { + expr* arg = args[i]; + change = reduce_arg(arg, new_arg) || change; + new_args.push_back(new_arg); + } + if (!change) + return BR_FAILED; + + result = m.mk_app(f, num_args, new_args.data()); + + return BR_DONE; +} + +bool bound_simplifier::reduce_arg(expr* arg, expr_ref& result) { + result = arg; + expr* x, *y; + rational N, lo, hi; + bool strict; + if ((a.is_le(arg, x, y) && a.is_numeral(y, N)) || + (a.is_ge(arg, y, x) && a.is_numeral(y, N))) { + + if (has_upper(x, hi, strict) && !strict && N >= hi) { + result = m.mk_true(); + return true; + } + if (has_lower(x, lo, strict) && !strict && N < lo) { + result = m.mk_false(); + return true; + } + return false; + } + + if ((a.is_le(arg, y, x) && a.is_numeral(y, N)) || + (a.is_ge(arg, x, y) && a.is_numeral(y, N))) { + if (has_lower(x, lo, strict) && !strict && N <= lo) { + result = m.mk_true(); + return true; + } + if (has_upper(x, hi, strict) && !strict && N > hi) { + result = m.mk_false(); + return true; + } + return false; + } + return false; } void bound_simplifier::reduce() { diff --git a/src/ast/simplifiers/bound_simplifier.h b/src/ast/simplifiers/bound_simplifier.h index 7950f418b..0e3fff239 100644 --- a/src/ast/simplifiers/bound_simplifier.h +++ b/src/ast/simplifiers/bound_simplifier.h @@ -77,8 +77,12 @@ class bound_simplifier : public dependent_expr_simplifier { return v; } + bool reduce_arg(expr* arg, expr_ref& result); + br_status reduce_app(func_decl* f, unsigned num_args, expr* const* args, expr_ref& result, proof_ref& pr); + + void assert_lower(expr* x, rational const& n, bool strict); void assert_upper(expr* x, rational const& n, bool strict); diff --git a/src/cmd_context/cmd_context.cpp b/src/cmd_context/cmd_context.cpp index 8d1e375d1..61def4c18 100644 --- a/src/cmd_context/cmd_context.cpp +++ b/src/cmd_context/cmd_context.cpp @@ -508,8 +508,12 @@ public: m_owner.m_func_decls.contains(s); } format_ns::format * pp_sort(sort * s) override { - return m_owner.pp(s); + auto * f = m_owner.try_pp(s); + if (f) + return f; + return smt2_pp_environment::pp_sort(s); } + format_ns::format * pp_fdecl(func_decl * f, unsigned & len) override { symbol s = f->get_name(); func_decls fs; @@ -2261,8 +2265,12 @@ bool cmd_context::is_model_available(model_ref& md) const { } format_ns::format * cmd_context::pp(sort * s) const { + return get_pp_env().pp_sort(s); +} + +format_ns::format* cmd_context::try_pp(sort* s) const { TRACE("cmd_context", tout << "pp(sort * s), s: " << mk_pp(s, m()) << "\n";); - return pm().pp(s); + return pm().pp(get_pp_env(), s); } cmd_context::pp_env & cmd_context::get_pp_env() const { diff --git a/src/cmd_context/cmd_context.h b/src/cmd_context/cmd_context.h index b034a9ffc..a4eb53237 100644 --- a/src/cmd_context/cmd_context.h +++ b/src/cmd_context/cmd_context.h @@ -538,6 +538,7 @@ public: } format_ns::format * pp(sort * s) const; + format_ns::format* try_pp(sort* s) const; void pp(sort * s, format_ns::format_ref & r) const override { r = pp(s); } void pp(func_decl * f, format_ns::format_ref & r) const override; void pp(expr * n, unsigned num_vars, char const * var_prefix, format_ns::format_ref & r, sbuffer & var_names) const override; diff --git a/src/cmd_context/pdecl.cpp b/src/cmd_context/pdecl.cpp index b8dd01aea..776a91a28 100644 --- a/src/cmd_context/pdecl.cpp +++ b/src/cmd_context/pdecl.cpp @@ -785,7 +785,7 @@ struct pdecl_manager::sort_info { virtual unsigned obj_size() const { return sizeof(sort_info); } virtual void finalize(pdecl_manager & m) { m.dec_ref(m_decl); } virtual void display(std::ostream & out, pdecl_manager const & m) const = 0; - virtual format * pp(pdecl_manager const & m) const = 0; + virtual format * pp(smt2_pp_environment& env, pdecl_manager const & m) const = 0; }; struct pdecl_manager::app_sort_info : public pdecl_manager::sort_info { @@ -817,14 +817,14 @@ struct pdecl_manager::app_sort_info : public pdecl_manager::sort_info { } } - format * pp(pdecl_manager const & m) const override { + format * pp(smt2_pp_environment& env, pdecl_manager const & m) const override { if (m_args.empty()) { return mk_string(m.m(), m_decl->get_name().str()); } else { ptr_buffer b; for (auto arg : m_args) - b.push_back(m.pp(arg)); + b.push_back(m.pp(env, arg)); return mk_seq1(m.m(), b.begin(), b.end(), f2f(), m_decl->get_name().str()); } } @@ -853,7 +853,7 @@ struct pdecl_manager::indexed_sort_info : public pdecl_manager::sort_info { } } - format * pp(pdecl_manager const & m) const override { + format * pp(smt2_pp_environment& env, pdecl_manager const & m) const override { if (m_indices.empty()) { return mk_string(m.m(), m_decl->get_name().str()); } @@ -1072,27 +1072,10 @@ void pdecl_manager::display(std::ostream & out, sort * s) const { out << s->get_name(); } -format * pdecl_manager::pp(sort * s) const { +format * pdecl_manager::pp(smt2_pp_environment& env, sort * s) const { sort_info * info = nullptr; - if (m_sort2info.find(s, info)) { - return info->pp(*this); - } - unsigned num_params = s->get_num_parameters(); - if (s->get_family_id() != null_family_id && num_params > 0) { - // Small hack to display FP and BitVec sorts that were not explicitly referenced by the user. - unsigned i = 0; - for (i = 0; i < num_params; i++) { - if (!s->get_parameter(i).is_int()) - break; - } - if (i == num_params) { - // all parameters are integer - ptr_buffer b; - b.push_back(mk_string(m(), s->get_name().str())); - for (unsigned i = 0; i < num_params; i++) - b.push_back(mk_unsigned(m(), s->get_parameter(i).get_int())); - return mk_seq1(m(), b.begin(), b.end(), f2f(), "_"); - } - } - return mk_string(m(), s->get_name().str()); + if (m_sort2info.find(s, info)) + return info->pp(env, *this); + else + return nullptr; } diff --git a/src/cmd_context/pdecl.h b/src/cmd_context/pdecl.h index a55f782f0..818c97eda 100644 --- a/src/cmd_context/pdecl.h +++ b/src/cmd_context/pdecl.h @@ -23,6 +23,7 @@ Revision History: #include "util/dictionary.h" #include "ast/format.h" #include "ast/datatype_decl_plugin.h" +#include "ast/ast_smt2_pp.h" class pdecl_manager; @@ -333,7 +334,7 @@ public: void save_info(sort * s, psort_decl * d, unsigned num_args, sort * const * args); void save_info(sort * s, psort_decl * d, unsigned num_indices, unsigned const * indices); void display(std::ostream & out, sort * s) const; - format_ns::format * pp(sort * s) const; + format_ns::format * pp(smt2_pp_environment& env, sort * s) const; }; diff --git a/src/math/lp/lar_solver.cpp b/src/math/lp/lar_solver.cpp index 0eb65e197..62712914c 100644 --- a/src/math/lp/lar_solver.cpp +++ b/src/math/lp/lar_solver.cpp @@ -1034,14 +1034,10 @@ namespace lp { r += p.coeff() * get_value(p.column()); return r; } - - impq lar_solver::get_tv_ivalue(tv const& t) const { - if (t.is_var()) - return get_column_value(t.column()); - impq r; - for (lar_term::ival p : get_term(t)) - r += p.coeff() * get_column_value(p.column()); - return r; + //fetches the cached value of the term or the variable by the given index + const impq& lar_solver::get_tv_ivalue(tv const& t) const { + unsigned j = t.is_var()? (unsigned)t.column(): this->map_term_index_to_column_index(t.index()); + return this->get_column_value(j); } void lar_solver::get_rid_of_inf_eps() { diff --git a/src/math/lp/lar_solver.h b/src/math/lp/lar_solver.h index 182ef0be3..5b421d70d 100644 --- a/src/math/lp/lar_solver.h +++ b/src/math/lp/lar_solver.h @@ -506,7 +506,7 @@ public: bool init_model() const; mpq get_value(column_index const& j) const; mpq get_tv_value(tv const& t) const; - impq get_tv_ivalue(tv const& t) const; + const impq & get_tv_ivalue(tv const& t) const; void get_model(std::unordered_map & variable_values) const; void get_rid_of_inf_eps(); void get_model_do_not_care_about_diff_vars(std::unordered_map & variable_values) const; diff --git a/src/math/lp/lp_primal_core_solver_def.h b/src/math/lp/lp_primal_core_solver_def.h index c3c545fdd..e18a5ef05 100644 --- a/src/math/lp/lp_primal_core_solver_def.h +++ b/src/math/lp/lp_primal_core_solver_def.h @@ -37,6 +37,7 @@ void lp_primal_core_solver::sort_non_basis() { unsigned ca = this->m_A.number_of_non_zeroes_in_column(a); unsigned cb = this->m_A.number_of_non_zeroes_in_column(b); if (ca == 0 && cb != 0) return false; + if (ca != 0 && cb == 0) return true; return ca < cb; }); diff --git a/src/params/rewriter_params.pyg b/src/params/rewriter_params.pyg index 290f7b1da..20490606c 100644 --- a/src/params/rewriter_params.pyg +++ b/src/params/rewriter_params.pyg @@ -8,6 +8,7 @@ def_module_params('rewriter', ("pull_cheap_ite", BOOL, False, "pull if-then-else terms when cheap."), ("bv_ineq_consistency_test_max", UINT, 0, "max size of conjunctions on which to perform consistency test based on inequalities on bitvectors."), ("cache_all", BOOL, False, "cache all intermediate results."), + ("enable_der", BOOL, True, "enable destructive equality resolution to quantifiers."), ("rewrite_patterns", BOOL, False, "rewrite patterns."), ("ignore_patterns_on_ground_qbody", BOOL, True, "ignores patterns on quantifiers that don't mention their bound variables."))) diff --git a/src/sat/smt/bv_solver.cpp b/src/sat/smt/bv_solver.cpp index c5795de0c..2c00e888b 100644 --- a/src/sat/smt/bv_solver.cpp +++ b/src/sat/smt/bv_solver.cpp @@ -129,7 +129,7 @@ namespace bv { /** \brief Find an unassigned bit for m_wpos[v], if such bit cannot be found invoke fixed_var_eh */ - void solver::find_wpos(theory_var v) { + bool solver::find_wpos(theory_var v) { literal_vector const& bits = m_bits[v]; unsigned sz = bits.size(); unsigned& wpos = m_wpos[v]; @@ -138,11 +138,12 @@ namespace bv { if (s().value(bits[idx]) == l_undef) { wpos = idx; TRACE("bv", tout << "moved wpos of v" << v << " to " << wpos << "\n";); - return; + return false; } } TRACE("bv", tout << "v" << v << " is a fixed variable.\n";); fixed_var_eh(v); + return true; } /** @@ -919,7 +920,17 @@ namespace bv { values[n->get_root_id()] = bv.mk_numeral(val, m_bits[v].size()); } - trail_stack& solver::get_trail_stack() { + sat::bool_var solver::get_bit(unsigned bit, euf::enode *n) const { + theory_var v = n->get_th_var(get_id()); + if (v == euf::null_theory_var) + return sat::null_bool_var; + auto &bits = m_bits[v]; + if (bit >= bits.size()) + return sat::null_bool_var; + return bits[bit].var(); + } + + trail_stack &solver::get_trail_stack() { return ctx.get_trail_stack(); } diff --git a/src/sat/smt/bv_solver.h b/src/sat/smt/bv_solver.h index 36634a0a1..6847390ea 100644 --- a/src/sat/smt/bv_solver.h +++ b/src/sat/smt/bv_solver.h @@ -382,7 +382,7 @@ namespace bv { // solving theory_var find(theory_var v) const { return m_find.find(v); } - void find_wpos(theory_var v); + bool find_wpos(theory_var v); void find_new_diseq_axioms(atom& a, theory_var v, unsigned idx); void mk_new_diseq_axiom(theory_var v1, theory_var v2, unsigned idx); bool get_fixed_value(theory_var v, numeral& result) const; @@ -395,7 +395,6 @@ namespace bv { numeral const& power2(unsigned i) const; sat::literal mk_true(); - // invariants bool check_zero_one_bits(theory_var v); void check_missing_propagation() const; @@ -452,6 +451,7 @@ namespace bv { euf::theory_var mk_var(euf::enode* n) override; void apply_sort_cnstr(euf::enode * n, sort * s) override; + bool_var get_bit(unsigned bit, euf::enode* n) const; void merge_eh(theory_var, theory_var, theory_var v1, theory_var v2); void after_merge_eh(theory_var r1, theory_var r2, theory_var v1, theory_var v2) { SASSERT(check_zero_one_bits(r1)); } diff --git a/src/sat/smt/user_solver.cpp b/src/sat/smt/user_solver.cpp index 34f2b10b4..1e8897b8c 100644 --- a/src/sat/smt/user_solver.cpp +++ b/src/sat/smt/user_solver.cpp @@ -15,8 +15,9 @@ Author: --*/ -#include "sat/smt/user_solver.h" +#include "sat/smt/bv_solver.h" #include "sat/smt/euf_solver.h" +#include "sat/smt/user_solver.h" namespace user_solver { @@ -39,7 +40,7 @@ namespace user_solver { expr_ref r(m); sat::literal_vector explain; if (ctx.is_fixed(n, r, explain)) - m_prop.push_back(prop_info(explain, v, r)); + m_prop.push_back(prop_info(explain, v, r)); } void solver::propagate_cb( @@ -56,17 +57,21 @@ namespace user_solver { void solver::register_cb(expr* e) { add_expr(e); } - - void solver::next_split_cb(expr* e, unsigned idx, lbool phase) { + + bool solver::next_split_cb(expr* e, unsigned idx, lbool phase) { if (e == nullptr) { - m_next_split_expr = nullptr; - return; + m_next_split_var = sat::null_bool_var; + return true; } force_push(); ctx.internalize(e); - m_next_split_expr = e; - m_next_split_idx = idx; + sat::bool_var var = enode_to_bool(ctx.get_enode(e), idx); m_next_split_phase = phase; + if (var == sat::null_bool_var || s().value(var) != l_undef) + return false; + m_next_split_var = var; + m_next_split_phase = phase; + return true; } sat::check_result solver::check() { @@ -84,39 +89,41 @@ namespace user_solver { m_id2justification.setx(v, sat::literal_vector(num_lits, jlits), sat::literal_vector()); m_fixed_eh(m_user_context, this, var2expr(v), value); } - + bool solver::decide(sat::bool_var& var, lbool& phase) { - + if (!m_decide_eh) return false; - + euf::enode* original_enode = bool_var2enode(var); - + if (!original_enode || !is_attached_to_var(original_enode)) return false; - + unsigned new_bit = 0; // ignored; currently no bv-support expr* e = original_enode->get_expr(); - - m_decide_eh(m_user_context, this, &e, &new_bit, &phase); - - euf::enode* new_enode = ctx.get_enode(e); - - if (original_enode == new_enode || new_enode->bool_var() == sat::null_bool_var) + + m_decide_eh(m_user_context, this, e, new_bit, phase); + sat::bool_var new_var; + if (!get_case_split(new_var, phase) || new_var == var) + // The user did not interfere return false; - - var = new_enode->bool_var(); + var = new_var; + + // check if the new variable is unassigned + if (s().value(var) != l_undef) + throw default_exception("expression in \"decide\" is already assigned"); return true; } - - bool solver::get_case_split(sat::bool_var& var, lbool& phase){ - if (!m_next_split_expr) + + bool solver::get_case_split(sat::bool_var& var, lbool& phase) { + if (m_next_split_var == sat::null_bool_var) return false; - - euf::enode* n = ctx.get_enode(m_next_split_expr); - var = n->bool_var(); + + var = m_next_split_var; phase = m_next_split_phase; - m_next_split_expr = nullptr; + m_next_split_var = sat::null_bool_var; + m_next_split_phase = l_undef; return true; } @@ -134,14 +141,14 @@ namespace user_solver { m_id2justification.setx(v, lits, sat::literal_vector()); m_fixed_eh(m_user_context, this, var2expr(v), lit.sign() ? m.mk_false() : m.mk_true()); } - + void solver::new_eq_eh(euf::th_eq const& eq) { if (!m_eq_eh) return; force_push(); m_eq_eh(m_user_context, this, var2expr(eq.v1()), var2expr(eq.v2())); } - + void solver::new_diseq_eh(euf::th_eq const& de) { if (!m_diseq_eh) return; @@ -188,7 +195,7 @@ namespace user_solver { propagate_consequence(prop); else propagate_new_fixed(prop); - } + } return np < m_stats.m_num_propagations; } @@ -208,7 +215,7 @@ namespace user_solver { auto& j = justification::from_index(idx); auto const& prop = m_prop[j.m_propagation_index]; for (unsigned id : prop.m_ids) - r.append(m_id2justification[id]); + r.append(m_id2justification[id]); for (auto const& p : prop.m_eqs) ctx.add_antecedent(probing, expr2enode(p.first), expr2enode(p.second)); } @@ -243,7 +250,7 @@ namespace user_solver { } std::ostream& solver::display_constraint(std::ostream& out, sat::ext_constraint_idx idx) const { - return display_justification(out, idx); + return display_justification(out, idx); } euf::th_solver* solver::clone(euf::solver& dst_ctx) { @@ -278,26 +285,35 @@ namespace user_solver { return true; } m_stack.push_back(sat::eframe(e)); - return false; + return false; } - + bool solver::visited(expr* e) { euf::enode* n = expr2enode(e); - return n && n->is_attached_to(get_id()); + return n && n->is_attached_to(get_id()); } - + bool solver::post_visit(expr* e, bool sign, bool root) { euf::enode* n = expr2enode(e); SASSERT(!n || !n->is_attached_to(get_id())); - if (!n) - n = mk_enode(e, false); + if (!n) + n = mk_enode(e, false); add_expr(e); if (m_created_eh) m_created_eh(m_user_context, this, e); return true; } - + sat::bool_var solver::enode_to_bool(euf::enode* n, unsigned idx) { + if (n->bool_var() != sat::null_bool_var) { + // expression is a boolean + return n->bool_var(); + } + // expression is a bit-vector + bv_util bv(m); + th_solver* th = ctx.fid2solver(bv.get_fid()); + return ((bv::solver*) th)->get_bit(idx, n); + } } diff --git a/src/sat/smt/user_solver.h b/src/sat/smt/user_solver.h index cb1c6fe94..bd1b703e0 100644 --- a/src/sat/smt/user_solver.h +++ b/src/sat/smt/user_solver.h @@ -75,9 +75,8 @@ namespace user_solver { euf::enode_pair_vector m_eqs; unsigned_vector m_fixed_ids; stats m_stats; - expr* m_next_split_expr = nullptr; - unsigned m_next_split_idx; - lbool m_next_split_phase; + sat::bool_var m_next_split_var = sat::null_bool_var; + lbool m_next_split_phase = l_undef; struct justification { unsigned m_propagation_index { 0 }; @@ -104,6 +103,8 @@ namespace user_solver { bool visited(expr* e) override; bool post_visit(expr* e, bool sign, bool root) override; + sat::bool_var enode_to_bool(euf::enode* n, unsigned idx); + public: solver(euf::solver& ctx); @@ -136,7 +137,7 @@ namespace user_solver { void propagate_cb(unsigned num_fixed, expr* const* fixed_ids, unsigned num_eqs, expr* const* lhs, expr* const* rhs, expr* conseq) override; void register_cb(expr* e) override; - void next_split_cb(expr* e, unsigned idx, lbool phase) override; + bool next_split_cb(expr* e, unsigned idx, lbool phase) override; void new_fixed_eh(euf::theory_var v, expr* value, unsigned num_lits, sat::literal const* jlits); diff --git a/src/smt/qi_queue.cpp b/src/smt/qi_queue.cpp index 582bcc664..781ed7c49 100644 --- a/src/smt/qi_queue.cpp +++ b/src/smt/qi_queue.cpp @@ -131,6 +131,8 @@ namespace smt { // max_top_generation and min_top_generation are not available for computing inc_gen set_values(q, nullptr, generation, 0, 0, cost); float r = m_evaluator(m_new_gen_function, m_vals.size(), m_vals.data()); + if (q->get_weight() > 0 || r > 0) + return static_cast(r); return std::max(generation + 1, static_cast(r)); } diff --git a/src/smt/smt_context.cpp b/src/smt/smt_context.cpp index 82f56ab76..4125b24de 100644 --- a/src/smt/smt_context.cpp +++ b/src/smt/smt_context.cpp @@ -2918,7 +2918,9 @@ namespace smt { bool context::has_split_candidate(bool_var& var, bool& is_pos) { if (!m_user_propagator) return false; - return m_user_propagator->get_case_split(var, is_pos); + if (!m_user_propagator->get_case_split(var, is_pos)) + return false; + return get_assignment(var) == l_undef; } bool context::decide_user_interference(bool_var& var, bool& is_pos) { diff --git a/src/smt/smt_context_pp.cpp b/src/smt/smt_context_pp.cpp index a6088fdf7..fe86c6811 100644 --- a/src/smt/smt_context_pp.cpp +++ b/src/smt/smt_context_pp.cpp @@ -635,7 +635,7 @@ namespace smt { literal_vector lits; const_cast(*m_conflict_resolution).justification2literals(j.get_justification(), lits); out << "justification " << j.get_justification()->get_from_theory() << ": "; - // display_literals_smt2(out, lits); + display_literals_smt2(out, lits); break; } default: diff --git a/src/smt/theory_array_base.cpp b/src/smt/theory_array_base.cpp index 6c2f4038f..b766451df 100644 --- a/src/smt/theory_array_base.cpp +++ b/src/smt/theory_array_base.cpp @@ -969,7 +969,6 @@ namespace smt { } model_value_proc * theory_array_base::mk_value(enode * n, model_generator & mg) { - SASSERT(ctx.is_relevant(n)); theory_var v = n->get_th_var(get_id()); SASSERT(v != null_theory_var); sort * s = n->get_expr()->get_sort(); diff --git a/src/smt/theory_bv.cpp b/src/smt/theory_bv.cpp index 7adab35f4..a44e1a1aa 100644 --- a/src/smt/theory_bv.cpp +++ b/src/smt/theory_bv.cpp @@ -1889,21 +1889,14 @@ namespace smt { return var_enode_pos(nullptr, UINT32_MAX); } - bool_var theory_bv::get_first_unassigned(unsigned start_bit, enode* n) const { + bool_var theory_bv::get_bit(unsigned bit, enode* n) const { theory_var v = n->get_th_var(get_family_id()); + if (v == null_theory_var) + return null_bool_var; auto& bits = m_bits[v]; - unsigned sz = bits.size(); - - for (unsigned i = start_bit; i < sz; ++i) { - if (ctx.get_assignment(bits[i].var()) == l_undef) - return bits[i].var(); - } - for (unsigned i = 0; i < start_bit; ++i) { - if (ctx.get_assignment(bits[i].var()) == l_undef) - return bits[i].var(); - } - - return null_bool_var; + if (bit >= bits.size()) + return null_bool_var; + return bits[bit].var(); } bool theory_bv::check_assignment(theory_var v) { diff --git a/src/smt/theory_bv.h b/src/smt/theory_bv.h index 73d659c68..10cf005e3 100644 --- a/src/smt/theory_bv.h +++ b/src/smt/theory_bv.h @@ -291,7 +291,7 @@ namespace smt { bool is_fixed_propagated(theory_var v, expr_ref& val, literal_vector& explain) override; var_enode_pos get_bv_with_theory(bool_var v, theory_id id) const; - bool_var get_first_unassigned(unsigned start_bit, enode* n) const; + bool_var get_bit(unsigned bit, enode* n) const; bool check_assignment(theory_var v); bool check_invariant(); diff --git a/src/smt/theory_special_relations.cpp b/src/smt/theory_special_relations.cpp index ddddfbc00..b6370f153 100644 --- a/src/smt/theory_special_relations.cpp +++ b/src/smt/theory_special_relations.cpp @@ -888,9 +888,14 @@ namespace smt { func_decl* memf, *nextf, *connectedf; + std::string member, next, connected_sym; + unsigned index = r.decl()->get_parameter(0).get_int(); + member = "member" + std::to_string(index); + next = "next" + std::to_string(index); + connected_sym = "connected" + std::to_string(index); { sort* dom[2] = { s, listS }; - recfun::promise_def mem = p.ensure_def(symbol("member"), 2, dom, m.mk_bool_sort(), true); + recfun::promise_def mem = p.ensure_def(symbol(member), 2, dom, m.mk_bool_sort(), true); memf = mem.get_def()->get_decl(); var_ref xV(m.mk_var(1, s), m); @@ -913,7 +918,7 @@ namespace smt { { sort* dom[5] = { s, s, listS, listS, tup }; - recfun::promise_def nxt = p.ensure_def(symbol("next"), 5, dom, tup, true); + recfun::promise_def nxt = p.ensure_def(symbol(next), 5, dom, tup, true); nextf = nxt.get_def()->get_decl(); expr_ref next_body(m); @@ -934,7 +939,7 @@ namespace smt { { sort* dom[3] = { listS, s, listS }; - recfun::promise_def connected = p.ensure_def(symbol("connected"), 3, dom, m.mk_bool_sort(), true); + recfun::promise_def connected = p.ensure_def(symbol(connected_sym), 3, dom, m.mk_bool_sort(), true); connectedf = connected.get_def()->get_decl(); var_ref AV(m.mk_var(2, listS), m); var_ref dstV(m.mk_var(1, s), m); diff --git a/src/smt/theory_user_propagator.cpp b/src/smt/theory_user_propagator.cpp index 8eeaf4382..2d5b4917d 100644 --- a/src/smt/theory_user_propagator.cpp +++ b/src/smt/theory_user_propagator.cpp @@ -107,15 +107,18 @@ void theory_user_propagator::register_cb(expr* e) { add_expr(e, true); } -void theory_user_propagator::next_split_cb(expr* e, unsigned idx, lbool phase) { +bool theory_user_propagator::next_split_cb(expr* e, unsigned idx, lbool phase) { if (e == nullptr) { // clear - m_next_split_expr = nullptr; - return; + m_next_split_var = null_bool_var; + return true; } ensure_enode(e); - m_next_split_expr = e; - m_next_split_idx = idx; + bool_var b = enode_to_bool(ctx.get_enode(e), idx); + if (b == null_bool_var || ctx.get_assignment(b) != l_undef) + return false; + m_next_split_var = b; m_next_split_phase = phase; + return true; } theory * theory_user_propagator::mk_fresh(context * new_ctx) { @@ -174,18 +177,15 @@ void theory_user_propagator::new_fixed_eh(theory_var v, expr* value, unsigned nu } } -bool_var theory_user_propagator::enode_to_bool(enode* n, unsigned bit) { +bool_var theory_user_propagator::enode_to_bool(enode* n, unsigned idx) { if (n->is_bool()) { // expression is a boolean - bool_var new_var = ctx.enode2bool_var(n); - if (ctx.get_assignment(new_var) == l_undef) - return new_var; - return null_bool_var; + return ctx.enode2bool_var(n); } // expression is a bit-vector bv_util bv(m); auto th_bv = (theory_bv*)ctx.get_theory(bv.get_fid()); - return th_bv->get_first_unassigned(bit, n); + return th_bv->get_bit(idx, n); } void theory_user_propagator::decide(bool_var& var, bool& is_pos) { @@ -225,7 +225,7 @@ void theory_user_propagator::decide(bool_var& var, bool& is_pos) { if (v == null_theory_var) { // it is not a registered boolean value but it is a bitvector - auto registered_bv = ((theory_bv*)th)->get_bv_with_theory(var, get_family_id()); + auto registered_bv = ((theory_bv *) th)->get_bv_with_theory(var, get_family_id()); if (!registered_bv.first) // there is no registered bv associated with the bit return; @@ -236,47 +236,33 @@ void theory_user_propagator::decide(bool_var& var, bool& is_pos) { // call the registered callback unsigned new_bit = original_bit; - lbool phase = is_pos ? l_true : l_false; - - expr* e = var2expr(v); - m_decide_eh(m_user_context, this, &e, &new_bit, &phase); - enode* new_enode = ctx.get_enode(e); - // check if the callback changed something - if (original_enode == new_enode && (new_enode->is_bool() || original_bit == new_bit)) { - if (phase != l_undef) - // it only affected the truth value - is_pos = phase == l_true; + expr *e = var2expr(v); + m_decide_eh(m_user_context, this, e, new_bit, is_pos); + + bool_var new_var; + if (!get_case_split(new_var, is_pos) || new_var == var) + // The user did not interfere return; - } + var = new_var; - // get unassigned variable from enode - var = enode_to_bool(new_enode, new_bit); - - if (var == null_bool_var) - // selected variable is already assigned + // check if the new variable is unassigned + if (ctx.get_assignment(var) != l_undef) throw default_exception("expression in \"decide\" is already assigned"); - - // in case the callback did not decide on a truth value -> let Z3 decide - is_pos = ctx.guess(var, phase); } -bool theory_user_propagator::get_case_split(bool_var& var, bool& is_pos){ - if (!m_next_split_expr) +bool theory_user_propagator::get_case_split(bool_var& var, bool& is_pos) { + if (m_next_split_var == null_bool_var) return false; - enode* n = ctx.get_enode(m_next_split_expr); - - var = enode_to_bool(n, m_next_split_idx); - - if (var == null_bool_var) - return false; - + + var = m_next_split_var; is_pos = ctx.guess(var, m_next_split_phase); - m_next_split_expr = nullptr; + m_next_split_var = null_bool_var; + m_next_split_phase = l_undef; return true; } -void theory_user_propagator::push_scope_eh() { +void theory_user_propagator::push_scope_eh() { ++m_num_scopes; } @@ -421,9 +407,9 @@ bool theory_user_propagator::internalize_term(app* term) { return true; } -void theory_user_propagator::collect_statistics(::statistics & st) const { +void theory_user_propagator::collect_statistics(::statistics& st) const { st.update("user-propagations", m_stats.m_num_propagations); - st.update("user-watched", get_num_vars()); + st.update("user-watched", get_num_vars()); } diff --git a/src/smt/theory_user_propagator.h b/src/smt/theory_user_propagator.h index 2ed1acbdf..5a6eafc0a 100644 --- a/src/smt/theory_user_propagator.h +++ b/src/smt/theory_user_propagator.h @@ -83,9 +83,8 @@ namespace smt { expr_ref_vector m_to_add; unsigned_vector m_to_add_lim; unsigned m_to_add_qhead = 0; - expr* m_next_split_expr = nullptr; - unsigned m_next_split_idx; - lbool m_next_split_phase; + bool_var m_next_split_var = null_bool_var; + lbool m_next_split_phase = l_undef; expr* var2expr(theory_var v) { return m_var2expr.get(v); } theory_var expr2var(expr* e) { check_defined(e); return m_expr2var[e->get_id()]; } @@ -133,7 +132,7 @@ namespace smt { void propagate_cb(unsigned num_fixed, expr* const* fixed_ids, unsigned num_eqs, expr* const* lhs, expr* const* rhs, expr* conseq) override; void register_cb(expr* e) override; - void next_split_cb(expr* e, unsigned idx, lbool phase) override; + bool next_split_cb(expr* e, unsigned idx, lbool phase) override; void new_fixed_eh(theory_var v, expr* value, unsigned num_lits, literal const* jlits); void decide(bool_var& var, bool& is_pos); diff --git a/src/tactic/user_propagator_base.h b/src/tactic/user_propagator_base.h index 68e55be75..d4dae5166 100644 --- a/src/tactic/user_propagator_base.h +++ b/src/tactic/user_propagator_base.h @@ -11,7 +11,7 @@ namespace user_propagator { virtual ~callback() = default; virtual void propagate_cb(unsigned num_fixed, expr* const* fixed_ids, unsigned num_eqs, expr* const* eq_lhs, expr* const* eq_rhs, expr* conseq) = 0; virtual void register_cb(expr* e) = 0; - virtual void next_split_cb(expr* e, unsigned idx, lbool phase) = 0; + virtual bool next_split_cb(expr* e, unsigned idx, lbool phase) = 0; }; class context_obj { @@ -26,7 +26,7 @@ namespace user_propagator { typedef std::function push_eh_t; typedef std::function pop_eh_t; typedef std::function created_eh_t; - typedef std::function decide_eh_t; + typedef std::function decide_eh_t; typedef std::function on_clause_eh_t; class plugin : public decl_plugin {