3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-06 17:44:08 +00:00

Fix UP's decide callback (#6707)

* Query Boolean Assignment in the UP

* UP's decide ref arguments => next_split

* Fixed wrapper

* More fixes
This commit is contained in:
Clemens Eisenhofer 2023-06-02 09:52:54 +02:00 committed by GitHub
parent d59bf55539
commit 82667bd86b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
18 changed files with 174 additions and 169 deletions

View file

@ -642,7 +642,7 @@ def mk_java(java_src, java_dir, package_name):
public static native void propagateRegisterFinal(Object o, long ctx, long solver);
public static native void propagateConflict(Object o, long ctx, long solver, long javainfo, int num_fixed, long[] fixed, long num_eqs, long[] eq_lhs, long[] eq_rhs, long conseq);
public static native void propagateAdd(Object o, long ctx, long solver, long javainfo, long e);
public static native void propagateNextSplit(Object o, long ctx, long solver, long javainfo, long e, long idx, long phase);
public static native boolean propagateNextSplit(Object o, long ctx, long solver, long javainfo, long e, long idx, int phase);
public static native void propagateDestroy(Object o, long ctx, long solver, long javainfo);
public static abstract class UserPropagatorBase implements AutoCloseable {
@ -1929,7 +1929,7 @@ Z3_final_eh = ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.c_void_p)
Z3_eq_eh = ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p)
Z3_created_eh = ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p)
Z3_decide_eh = ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p)
Z3_decide_eh = ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_uint, ctypes.c_int)
_lib.Z3_solver_register_on_clause.restype = None
_lib.Z3_solver_propagate_init.restype = None

View file

@ -1114,17 +1114,17 @@ extern "C" {
void Z3_API Z3_solver_propagate_decide(Z3_context c, Z3_solver s, Z3_decide_eh decide_eh) {
Z3_TRY;
RESET_ERROR_CODE();
user_propagator::decide_eh_t c = (void(*)(void*, user_propagator::callback*, expr**, unsigned*, lbool*))decide_eh;
user_propagator::decide_eh_t c = (void(*)(void*, user_propagator::callback*, expr*, unsigned, bool))decide_eh;
to_solver_ref(s)->user_propagate_register_decide(c);
Z3_CATCH;
}
void Z3_API Z3_solver_next_split(Z3_context c, Z3_solver_callback cb, Z3_ast t, unsigned idx, Z3_lbool phase) {
bool Z3_API Z3_solver_next_split(Z3_context c, Z3_solver_callback cb, Z3_ast t, unsigned idx, Z3_lbool phase) {
Z3_TRY;
LOG_Z3_solver_next_split(c, cb, t, idx, phase);
RESET_ERROR_CODE();
reinterpret_cast<user_propagator::callback*>(cb)->next_split_cb(to_expr(t), idx, (lbool)phase);
Z3_CATCH;
return reinterpret_cast<user_propagator::callback*>(cb)->next_split_cb(to_expr(t), idx, (lbool)phase);
Z3_CATCH_RETURN(false);
}
Z3_func_decl Z3_API Z3_solver_propagate_declare(Z3_context c, Z3_symbol name, unsigned n, Z3_sort* domain, Z3_sort range) {

View file

@ -4240,7 +4240,7 @@ namespace z3 {
typedef std::function<void(void)> final_eh_t;
typedef std::function<void(expr const&, expr const&)> eq_eh_t;
typedef std::function<void(expr const&)> created_eh_t;
typedef std::function<void(expr&, unsigned&, Z3_lbool&)> decide_eh_t;
typedef std::function<void(expr, unsigned, bool)> decide_eh_t;
final_eh_t m_final_eh;
eq_eh_t m_eq_eh;
@ -4309,13 +4309,11 @@ namespace z3 {
p->m_created_eh(e);
}
static void decide_eh(void* _p, Z3_solver_callback cb, Z3_ast* _val, unsigned* bit, Z3_lbool* is_pos) {
static void decide_eh(void* _p, Z3_solver_callback cb, Z3_ast _val, unsigned bit, bool is_pos) {
user_propagator_base* p = static_cast<user_propagator_base*>(_p);
scoped_cb _cb(p, cb);
expr val(p->ctx(), *_val);
p->m_decide_eh(val, *bit, *is_pos);
// TBD: life time of val is within the scope of this callback.
*_val = val;
expr val(p->ctx(), _val);
p->m_decide_eh(val, bit, is_pos);
}
public:
@ -4435,7 +4433,7 @@ namespace z3 {
}
void register_decide() {
m_decide_eh = [this](expr& val, unsigned& bit, Z3_lbool& is_pos) {
m_decide_eh = [this](expr val, unsigned bit, bool is_pos) {
decide(val, bit, is_pos);
};
if (s) {
@ -4451,11 +4449,11 @@ namespace z3 {
virtual void created(expr const& /*e*/) {}
virtual void decide(expr& /*val*/, unsigned& /*bit*/, Z3_lbool& /*is_pos*/) {}
virtual void decide(expr const& /*val*/, unsigned /*bit*/, bool /*is_pos*/) {}
void next_split(expr const & e, unsigned idx, Z3_lbool phase) {
bool next_split(expr const& e, unsigned idx, Z3_lbool phase) {
assert(cb);
Z3_solver_next_split(ctx(), cb, e, idx, phase);
return Z3_solver_next_split(ctx(), cb, e, idx, phase);
}
/**

View file

@ -58,12 +58,12 @@ namespace Microsoft.Z3
public delegate void CreatedEh(Expr term);
/// <summary>
/// Delegate type for callback into solver's branching
/// Delegate type for callback into solver's branching. The values can be overriden by calling <see cref="NextSplit" />.
/// </summary>
/// <param name="term">A bit-vector or Boolean used for branching</param>
/// <param name="idx">If the term is a bit-vector, then an index into the bit-vector being branched on</param>
/// <param name="phase">Set phase to -1 (false) or 1 (true) to override solver's phase</param>
/// </summary>
public delegate void DecideEh(ref Expr term, ref uint idx, ref int phase);
/// <param name="phase">The tentative truth-value</param>
public delegate void DecideEh(Expr term, uint idx, bool phase);
// access managed objects through a static array.
// thread safety is ignored for now.
@ -168,16 +168,11 @@ namespace Microsoft.Z3
prop.Callback(() => prop.created_eh(t), cb);
}
static void _decide(voidp ctx, Z3_solver_callback cb, ref Z3_ast a, ref uint idx, ref int phase)
static void _decide(voidp ctx, Z3_solver_callback cb, Z3_ast a, uint idx, bool phase)
{
var prop = (UserPropagator)GCHandle.FromIntPtr(ctx).Target;
var t = Expr.Create(prop.ctx, a);
var u = t;
prop.callback = cb;
prop.decide_eh(ref t, ref idx, ref phase);
prop.callback = IntPtr.Zero;
if (u != t)
a = t.NativeObject;
using var t = Expr.Create(prop.ctx, a);
prop.Callback(() => prop.decide_eh(t, idx, phase), cb);
}
/// <summary>
@ -352,10 +347,17 @@ namespace Microsoft.Z3
/// <summary>
/// Set the next decision
/// <param name="e">A bit-vector or Boolean used for branching. Use <see langword="null" /> to clear</param>
/// <param name="idx">If the term is a bit-vector, then an index into the bit-vector being branched on</param>
/// <param name="phase">The tentative truth-value (-1/false, 1/true, 0/let Z3 decide)</param>
/// </summary>
public void NextSplit(Expr e, uint idx, int phase)
/// <returns>
/// <see langword="true" /> in case the value was successfully set;
/// <see langword="false" /> if the next split could not be set
/// </returns>
public bool NextSplit(Expr e, uint idx, int phase)
{
Native.Z3_solver_next_split(ctx.nCtx, this.callback, e.NativeObject, idx, phase);
return Native.Z3_solver_next_split(ctx.nCtx, this.callback, e?.NativeObject ?? IntPtr.Zero, idx, phase) != 0;
}
/// <summary>

View file

@ -91,6 +91,7 @@ struct JavaInfo {
jmethodID fixed = nullptr;
jmethodID eq = nullptr;
jmethodID final = nullptr;
jmethodID decide = nullptr;
Z3_solver_callback cb = nullptr;
};
@ -146,11 +147,10 @@ static void final_eh(void* _p, Z3_solver_callback cb) {
info->jenv->CallVoidMethod(info->jobj, info->final);
}
// TODO: implement decide
static void decide_eh(void* _p, Z3_solver_callback cb, Z3_ast* _val, unsigned* bit, Z3_lbool* is_pos) {
static void decide_eh(void* _p, Z3_solver_callback cb, Z3_ast _val, unsigned bit, bool is_pos) {
JavaInfo *info = static_cast<JavaInfo*>(_p);
ScopedCB scoped(info, cb);
info->jenv->CallVoidMethod(info->jobj, info->decide, (jlong)_val);
}
DLL_VIS JNIEXPORT jlong JNICALL Java_com_microsoft_z3_Native_propagateInit(JNIEnv *jenv, jclass cls, jobject jobj, jlong ctx, jlong solver) {
@ -166,8 +166,9 @@ DLL_VIS JNIEXPORT jlong JNICALL Java_com_microsoft_z3_Native_propagateInit(JNIEn
info->fixed = jenv->GetMethodID(jcls, "fixedWrapper", "(JJ)V");
info->eq = jenv->GetMethodID(jcls, "eqWrapper", "(JJ)V");
info->final = jenv->GetMethodID(jcls, "finWrapper", "()V");
if (!info->push || !info->pop || !info->fresh || !info->created || !info->fixed || !info->eq || !info->final) {
info->decide = jenv->GetMethodID(jcls, "decideWrapper", "(JII)V");
if (!info->push || !info->pop || !info->fresh || !info->created || !info->fixed || !info->eq || !info->final || !info->decide) {
assert(false);
}
@ -225,8 +226,8 @@ DLL_VIS JNIEXPORT void JNICALL Java_com_microsoft_z3_Native_propagateAdd(JNIEnv
}
}
DLL_VIS JNIEXPORT void JNICALL Java_com_microsoft_z3_Native_propagateNextSplit(JNIEnv * jenv, jclass cls, jobject jobj, jlong ctx, jlong solver, jlong javainfo, long e, long idx, long phase) {
DLL_VIS JNIEXPORT bool JNICALL Java_com_microsoft_z3_Native_propagateNextSplit(JNIEnv * jenv, jclass cls, jobject jobj, jlong ctx, jlong solver, jlong javainfo, long e, long idx, int phase) {
JavaInfo *info = (JavaInfo*)javainfo;
Z3_solver_callback cb = info->cb;
Z3_solver_next_split((Z3_context)ctx, cb, (Z3_ast)e, idx, Z3_lbool(phase));
return Z3_solver_next_split((Z3_context)ctx, cb, (Z3_ast)e, idx, Z3_lbool(phase));
}

View file

@ -89,9 +89,9 @@ public abstract class UserPropagatorBase extends Native.UserPropagatorBase {
fixed.length, AST.arrayToNative(fixed), lhs.length, AST.arrayToNative(lhs), AST.arrayToNative(rhs), conseq.getNativeObject());
}
public final <R extends Sort> void nextSplit(Expr<R> e, long idx, Z3_lbool phase) {
Native.propagateNextSplit(
public final <R extends Sort> boolean nextSplit(Expr<R> e, long idx, Z3_lbool phase) {
return Native.propagateNextSplit(
this, ctx.nCtx(), solver.getNativeObject(), javainfo,
e.getNativeObject(), idx, phase.toInt());
}
}
}

View file

@ -11529,16 +11529,11 @@ def user_prop_diseq(ctx, cb, x, y):
prop.diseq(x, y)
prop.cb = None
# TODO The decision callback is not fully implemented.
# It needs to handle the ast*, unsigned* idx, and Z3_lbool*
def user_prop_decide(ctx, cb, t_ref, idx_ref, phase_ref):
def user_prop_decide(ctx, cb, t, idx, phase):
prop = _prop_closures.get(ctx)
prop.cb = cb
t = _to_expr_ref(to_Ast(t_ref), prop.ctx())
t, idx, phase = prop.decide(t, idx, phase)
t_ref = t
idx_ref = idx
phase_ref = phase
prop.decide(t, idx, phase)
prop.cb = None
@ -11685,7 +11680,7 @@ class UserPropagateBase:
# split on. A phase of true = 1/false = -1/undef = 0 = let solver decide is the last argument.
#
def next_split(self, t, idx, phase):
Z3_solver_next_split(self.ctx_ref(), ctypes.c_void_p(self.cb), t.ast, idx, phase)
return Z3_solver_next_split(self.ctx_ref(), ctypes.c_void_p(self.cb), t.ast, idx, phase)
#
# Propagation can only be invoked as during a fixed or final callback.

View file

@ -1435,7 +1435,7 @@ Z3_DECLARE_CLOSURE(Z3_fixed_eh, void, (void* ctx, Z3_solver_callback cb, Z3_as
Z3_DECLARE_CLOSURE(Z3_eq_eh, void, (void* ctx, Z3_solver_callback cb, Z3_ast s, Z3_ast t));
Z3_DECLARE_CLOSURE(Z3_final_eh, void, (void* ctx, Z3_solver_callback cb));
Z3_DECLARE_CLOSURE(Z3_created_eh, void, (void* ctx, Z3_solver_callback cb, Z3_ast t));
Z3_DECLARE_CLOSURE(Z3_decide_eh, void, (void* ctx, Z3_solver_callback cb, Z3_ast* t, unsigned* idx, Z3_lbool* phase));
Z3_DECLARE_CLOSURE(Z3_decide_eh, void, (void* ctx, Z3_solver_callback cb, Z3_ast t, unsigned idx, bool phase));
Z3_DECLARE_CLOSURE(Z3_on_clause_eh, void, (void* ctx, Z3_ast proof_hint, Z3_ast_vector literals));
@ -7098,20 +7098,21 @@ extern "C" {
/**
\brief register a callback when the solver decides to split on a registered expression.
The callback may set the passed expression to another registered expression which will be selected instead.
In case the expression is a bitvector the bit to split on is determined by the bit argument and the
truth-value to try first is given by is_pos. In case the truth value is undefined the solver will decide.
The callback may change the arguments by providing other values by calling \ref Z3_solver_next_split
def_API('Z3_solver_propagate_decide', VOID, (_in(CONTEXT), _in(SOLVER), _fnptr(Z3_decide_eh)))
*/
void Z3_API Z3_solver_propagate_decide(Z3_context c, Z3_solver s, Z3_decide_eh decide_eh);
/**
Sets the next expression to split on
Sets the next (registered) expression to split on.
The function returns false and ignores the given expression in case the expression is already assigned internally
(due to relevancy propagation, this assignments might not have been reported yet by the fixed callback).
In case the function is called in the decide callback, it overrides the currently selected variable and phase.
def_API('Z3_solver_next_split', VOID, (_in(CONTEXT), _in(SOLVER_CALLBACK), _in(AST), _in(UINT), _in(LBOOL)))
def_API('Z3_solver_next_split', BOOL, (_in(CONTEXT), _in(SOLVER_CALLBACK), _in(AST), _in(UINT), _in(LBOOL)))
*/
void Z3_API Z3_solver_next_split(Z3_context c, Z3_solver_callback cb, Z3_ast t, unsigned idx, Z3_lbool phase);
bool Z3_API Z3_solver_next_split(Z3_context c, Z3_solver_callback cb, Z3_ast t, unsigned idx, Z3_lbool phase);
/**
Create uninterpreted function declaration for the user propagator.

View file

@ -128,7 +128,7 @@ namespace bv {
/**
\brief Find an unassigned bit for m_wpos[v], if such bit cannot be found invoke fixed_var_eh
*/
void solver::find_wpos(theory_var v) {
bool solver::find_wpos(theory_var v) {
literal_vector const& bits = m_bits[v];
unsigned sz = bits.size();
unsigned& wpos = m_wpos[v];
@ -137,11 +137,12 @@ namespace bv {
if (s().value(bits[idx]) == l_undef) {
wpos = idx;
TRACE("bv", tout << "moved wpos of v" << v << " to " << wpos << "\n";);
return;
return false;
}
}
TRACE("bv", tout << "v" << v << " is a fixed variable.\n";);
fixed_var_eh(v);
return true;
}
/**
@ -853,7 +854,17 @@ namespace bv {
values[n->get_root_id()] = bv.mk_numeral(val, m_bits[v].size());
}
trail_stack& solver::get_trail_stack() {
sat::bool_var solver::get_bit(unsigned bit, euf::enode *n) const {
theory_var v = n->get_th_var(get_id());
if (v == euf::null_theory_var)
return sat::null_bool_var;
auto &bits = m_bits[v];
if (bit >= bits.size())
return sat::null_bool_var;
return bits[bit].var();
}
trail_stack &solver::get_trail_stack() {
return ctx.get_trail_stack();
}

View file

@ -321,7 +321,7 @@ namespace bv {
// solving
theory_var find(theory_var v) const { return m_find.find(v); }
void find_wpos(theory_var v);
bool find_wpos(theory_var v);
void find_new_diseq_axioms(atom& a, theory_var v, unsigned idx);
void mk_new_diseq_axiom(theory_var v1, theory_var v2, unsigned idx);
bool get_fixed_value(theory_var v, numeral& result) const;
@ -334,7 +334,6 @@ namespace bv {
numeral const& power2(unsigned i) const;
sat::literal mk_true();
// invariants
bool check_zero_one_bits(theory_var v);
void check_missing_propagation() const;
@ -391,6 +390,7 @@ namespace bv {
euf::theory_var mk_var(euf::enode* n) override;
void apply_sort_cnstr(euf::enode * n, sort * s) override;
bool_var get_bit(unsigned bit, euf::enode* n) const;
void merge_eh(theory_var, theory_var, theory_var v1, theory_var v2);
void after_merge_eh(theory_var r1, theory_var r2, theory_var v1, theory_var v2) { SASSERT(check_zero_one_bits(r1)); }

View file

@ -15,8 +15,9 @@ Author:
--*/
#include "sat/smt/user_solver.h"
#include "sat/smt/bv_solver.h"
#include "sat/smt/euf_solver.h"
#include "sat/smt/user_solver.h"
namespace user_solver {
@ -39,7 +40,7 @@ namespace user_solver {
expr_ref r(m);
sat::literal_vector explain;
if (ctx.is_fixed(n, r, explain))
m_prop.push_back(prop_info(explain, v, r));
m_prop.push_back(prop_info(explain, v, r));
}
void solver::propagate_cb(
@ -56,17 +57,21 @@ namespace user_solver {
void solver::register_cb(expr* e) {
add_expr(e);
}
void solver::next_split_cb(expr* e, unsigned idx, lbool phase) {
bool solver::next_split_cb(expr* e, unsigned idx, lbool phase) {
if (e == nullptr) {
m_next_split_expr = nullptr;
return;
m_next_split_var = sat::null_bool_var;
return true;
}
force_push();
ctx.internalize(e);
m_next_split_expr = e;
m_next_split_idx = idx;
sat::bool_var var = enode_to_bool(ctx.get_enode(e), idx);
m_next_split_phase = phase;
if (var == sat::null_bool_var || s().value(var) != l_undef)
return false;
m_next_split_var = var;
m_next_split_phase = phase;
return true;
}
sat::check_result solver::check() {
@ -84,39 +89,41 @@ namespace user_solver {
m_id2justification.setx(v, sat::literal_vector(num_lits, jlits), sat::literal_vector());
m_fixed_eh(m_user_context, this, var2expr(v), value);
}
bool solver::decide(sat::bool_var& var, lbool& phase) {
if (!m_decide_eh)
return false;
euf::enode* original_enode = bool_var2enode(var);
if (!original_enode || !is_attached_to_var(original_enode))
return false;
unsigned new_bit = 0; // ignored; currently no bv-support
expr* e = original_enode->get_expr();
m_decide_eh(m_user_context, this, &e, &new_bit, &phase);
euf::enode* new_enode = ctx.get_enode(e);
if (original_enode == new_enode || new_enode->bool_var() == sat::null_bool_var)
m_decide_eh(m_user_context, this, e, new_bit, phase);
sat::bool_var new_var;
if (!get_case_split(new_var, phase) || new_var == var)
// The user did not interfere
return false;
var = new_enode->bool_var();
var = new_var;
// check if the new variable is unassigned
if (s().value(var) != l_undef)
throw default_exception("expression in \"decide\" is already assigned");
return true;
}
bool solver::get_case_split(sat::bool_var& var, lbool& phase){
if (!m_next_split_expr)
bool solver::get_case_split(sat::bool_var& var, lbool& phase) {
if (m_next_split_var == sat::null_bool_var)
return false;
euf::enode* n = ctx.get_enode(m_next_split_expr);
var = n->bool_var();
var = m_next_split_var;
phase = m_next_split_phase;
m_next_split_expr = nullptr;
m_next_split_var = sat::null_bool_var;
m_next_split_phase = l_undef;
return true;
}
@ -134,14 +141,14 @@ namespace user_solver {
m_id2justification.setx(v, lits, sat::literal_vector());
m_fixed_eh(m_user_context, this, var2expr(v), lit.sign() ? m.mk_false() : m.mk_true());
}
void solver::new_eq_eh(euf::th_eq const& eq) {
if (!m_eq_eh)
return;
force_push();
m_eq_eh(m_user_context, this, var2expr(eq.v1()), var2expr(eq.v2()));
}
void solver::new_diseq_eh(euf::th_eq const& de) {
if (!m_diseq_eh)
return;
@ -188,7 +195,7 @@ namespace user_solver {
propagate_consequence(prop);
else
propagate_new_fixed(prop);
}
}
return np < m_stats.m_num_propagations;
}
@ -208,7 +215,7 @@ namespace user_solver {
auto& j = justification::from_index(idx);
auto const& prop = m_prop[j.m_propagation_index];
for (unsigned id : prop.m_ids)
r.append(m_id2justification[id]);
r.append(m_id2justification[id]);
for (auto const& p : prop.m_eqs)
ctx.add_antecedent(probing, expr2enode(p.first), expr2enode(p.second));
}
@ -243,7 +250,7 @@ namespace user_solver {
}
std::ostream& solver::display_constraint(std::ostream& out, sat::ext_constraint_idx idx) const {
return display_justification(out, idx);
return display_justification(out, idx);
}
euf::th_solver* solver::clone(euf::solver& dst_ctx) {
@ -278,26 +285,35 @@ namespace user_solver {
return true;
}
m_stack.push_back(sat::eframe(e));
return false;
return false;
}
bool solver::visited(expr* e) {
euf::enode* n = expr2enode(e);
return n && n->is_attached_to(get_id());
return n && n->is_attached_to(get_id());
}
bool solver::post_visit(expr* e, bool sign, bool root) {
euf::enode* n = expr2enode(e);
SASSERT(!n || !n->is_attached_to(get_id()));
if (!n)
n = mk_enode(e, false);
if (!n)
n = mk_enode(e, false);
add_expr(e);
if (m_created_eh)
m_created_eh(m_user_context, this, e);
return true;
}
sat::bool_var solver::enode_to_bool(euf::enode* n, unsigned idx) {
if (n->bool_var() != sat::null_bool_var) {
// expression is a boolean
return n->bool_var();
}
// expression is a bit-vector
bv_util bv(m);
th_solver* th = ctx.fid2solver(bv.get_fid());
return ((bv::solver*) th)->get_bit(idx, n);
}
}

View file

@ -75,9 +75,8 @@ namespace user_solver {
euf::enode_pair_vector m_eqs;
unsigned_vector m_fixed_ids;
stats m_stats;
expr* m_next_split_expr = nullptr;
unsigned m_next_split_idx;
lbool m_next_split_phase;
sat::bool_var m_next_split_var = sat::null_bool_var;
lbool m_next_split_phase = l_undef;
struct justification {
unsigned m_propagation_index { 0 };
@ -104,6 +103,8 @@ namespace user_solver {
bool visited(expr* e) override;
bool post_visit(expr* e, bool sign, bool root) override;
sat::bool_var enode_to_bool(euf::enode* n, unsigned idx);
public:
solver(euf::solver& ctx);
@ -136,7 +137,7 @@ namespace user_solver {
void propagate_cb(unsigned num_fixed, expr* const* fixed_ids, unsigned num_eqs, expr* const* lhs, expr* const* rhs, expr* conseq) override;
void register_cb(expr* e) override;
void next_split_cb(expr* e, unsigned idx, lbool phase) override;
bool next_split_cb(expr* e, unsigned idx, lbool phase) override;
void new_fixed_eh(euf::theory_var v, expr* value, unsigned num_lits, sat::literal const* jlits);

View file

@ -2918,7 +2918,9 @@ namespace smt {
bool context::has_split_candidate(bool_var& var, bool& is_pos) {
if (!m_user_propagator)
return false;
return m_user_propagator->get_case_split(var, is_pos);
if (!m_user_propagator->get_case_split(var, is_pos))
return false;
return get_assignment(var) == l_undef;
}
bool context::decide_user_interference(bool_var& var, bool& is_pos) {

View file

@ -1889,21 +1889,14 @@ namespace smt {
return var_enode_pos(nullptr, UINT32_MAX);
}
bool_var theory_bv::get_first_unassigned(unsigned start_bit, enode* n) const {
bool_var theory_bv::get_bit(unsigned bit, enode* n) const {
theory_var v = n->get_th_var(get_family_id());
if (v == null_theory_var)
return null_bool_var;
auto& bits = m_bits[v];
unsigned sz = bits.size();
for (unsigned i = start_bit; i < sz; ++i) {
if (ctx.get_assignment(bits[i].var()) == l_undef)
return bits[i].var();
}
for (unsigned i = 0; i < start_bit; ++i) {
if (ctx.get_assignment(bits[i].var()) == l_undef)
return bits[i].var();
}
return null_bool_var;
if (bit >= bits.size())
return null_bool_var;
return bits[bit].var();
}
bool theory_bv::check_assignment(theory_var v) {

View file

@ -291,7 +291,7 @@ namespace smt {
bool is_fixed_propagated(theory_var v, expr_ref& val, literal_vector& explain) override;
var_enode_pos get_bv_with_theory(bool_var v, theory_id id) const;
bool_var get_first_unassigned(unsigned start_bit, enode* n) const;
bool_var get_bit(unsigned bit, enode* n) const;
bool check_assignment(theory_var v);
bool check_invariant();

View file

@ -107,15 +107,18 @@ void theory_user_propagator::register_cb(expr* e) {
add_expr(e, true);
}
void theory_user_propagator::next_split_cb(expr* e, unsigned idx, lbool phase) {
bool theory_user_propagator::next_split_cb(expr* e, unsigned idx, lbool phase) {
if (e == nullptr) { // clear
m_next_split_expr = nullptr;
return;
m_next_split_var = null_bool_var;
return true;
}
ensure_enode(e);
m_next_split_expr = e;
m_next_split_idx = idx;
bool_var b = enode_to_bool(ctx.get_enode(e), idx);
if (b == null_bool_var || ctx.get_assignment(b) != l_undef)
return false;
m_next_split_var = b;
m_next_split_phase = phase;
return true;
}
theory * theory_user_propagator::mk_fresh(context * new_ctx) {
@ -174,18 +177,15 @@ void theory_user_propagator::new_fixed_eh(theory_var v, expr* value, unsigned nu
}
}
bool_var theory_user_propagator::enode_to_bool(enode* n, unsigned bit) {
bool_var theory_user_propagator::enode_to_bool(enode* n, unsigned idx) {
if (n->is_bool()) {
// expression is a boolean
bool_var new_var = ctx.enode2bool_var(n);
if (ctx.get_assignment(new_var) == l_undef)
return new_var;
return null_bool_var;
return ctx.enode2bool_var(n);
}
// expression is a bit-vector
bv_util bv(m);
auto th_bv = (theory_bv*)ctx.get_theory(bv.get_fid());
return th_bv->get_first_unassigned(bit, n);
return th_bv->get_bit(idx, n);
}
void theory_user_propagator::decide(bool_var& var, bool& is_pos) {
@ -225,7 +225,7 @@ void theory_user_propagator::decide(bool_var& var, bool& is_pos) {
if (v == null_theory_var) {
// it is not a registered boolean value but it is a bitvector
auto registered_bv = ((theory_bv*)th)->get_bv_with_theory(var, get_family_id());
auto registered_bv = ((theory_bv *) th)->get_bv_with_theory(var, get_family_id());
if (!registered_bv.first)
// there is no registered bv associated with the bit
return;
@ -236,47 +236,33 @@ void theory_user_propagator::decide(bool_var& var, bool& is_pos) {
// call the registered callback
unsigned new_bit = original_bit;
lbool phase = is_pos ? l_true : l_false;
expr* e = var2expr(v);
m_decide_eh(m_user_context, this, &e, &new_bit, &phase);
enode* new_enode = ctx.get_enode(e);
// check if the callback changed something
if (original_enode == new_enode && (new_enode->is_bool() || original_bit == new_bit)) {
if (phase != l_undef)
// it only affected the truth value
is_pos = phase == l_true;
expr *e = var2expr(v);
m_decide_eh(m_user_context, this, e, new_bit, is_pos);
bool_var new_var;
if (!get_case_split(new_var, is_pos) || new_var == var)
// The user did not interfere
return;
}
var = new_var;
// get unassigned variable from enode
var = enode_to_bool(new_enode, new_bit);
if (var == null_bool_var)
// selected variable is already assigned
// check if the new variable is unassigned
if (ctx.get_assignment(var) != l_undef)
throw default_exception("expression in \"decide\" is already assigned");
// in case the callback did not decide on a truth value -> let Z3 decide
is_pos = ctx.guess(var, phase);
}
bool theory_user_propagator::get_case_split(bool_var& var, bool& is_pos){
if (!m_next_split_expr)
bool theory_user_propagator::get_case_split(bool_var& var, bool& is_pos) {
if (m_next_split_var == null_bool_var)
return false;
enode* n = ctx.get_enode(m_next_split_expr);
var = enode_to_bool(n, m_next_split_idx);
if (var == null_bool_var)
return false;
var = m_next_split_var;
is_pos = ctx.guess(var, m_next_split_phase);
m_next_split_expr = nullptr;
m_next_split_var = null_bool_var;
m_next_split_phase = l_undef;
return true;
}
void theory_user_propagator::push_scope_eh() {
void theory_user_propagator::push_scope_eh() {
++m_num_scopes;
}
@ -421,9 +407,9 @@ bool theory_user_propagator::internalize_term(app* term) {
return true;
}
void theory_user_propagator::collect_statistics(::statistics & st) const {
void theory_user_propagator::collect_statistics(::statistics& st) const {
st.update("user-propagations", m_stats.m_num_propagations);
st.update("user-watched", get_num_vars());
st.update("user-watched", get_num_vars());
}

View file

@ -83,9 +83,8 @@ namespace smt {
expr_ref_vector m_to_add;
unsigned_vector m_to_add_lim;
unsigned m_to_add_qhead = 0;
expr* m_next_split_expr = nullptr;
unsigned m_next_split_idx;
lbool m_next_split_phase;
bool_var m_next_split_var = null_bool_var;
lbool m_next_split_phase = l_undef;
expr* var2expr(theory_var v) { return m_var2expr.get(v); }
theory_var expr2var(expr* e) { check_defined(e); return m_expr2var[e->get_id()]; }
@ -133,7 +132,7 @@ namespace smt {
void propagate_cb(unsigned num_fixed, expr* const* fixed_ids, unsigned num_eqs, expr* const* lhs, expr* const* rhs, expr* conseq) override;
void register_cb(expr* e) override;
void next_split_cb(expr* e, unsigned idx, lbool phase) override;
bool next_split_cb(expr* e, unsigned idx, lbool phase) override;
void new_fixed_eh(theory_var v, expr* value, unsigned num_lits, literal const* jlits);
void decide(bool_var& var, bool& is_pos);

View file

@ -11,7 +11,7 @@ namespace user_propagator {
virtual ~callback() = default;
virtual void propagate_cb(unsigned num_fixed, expr* const* fixed_ids, unsigned num_eqs, expr* const* eq_lhs, expr* const* eq_rhs, expr* conseq) = 0;
virtual void register_cb(expr* e) = 0;
virtual void next_split_cb(expr* e, unsigned idx, lbool phase) = 0;
virtual bool next_split_cb(expr* e, unsigned idx, lbool phase) = 0;
};
class context_obj {
@ -26,7 +26,7 @@ namespace user_propagator {
typedef std::function<void(void*, callback*)> push_eh_t;
typedef std::function<void(void*, callback*, unsigned)> pop_eh_t;
typedef std::function<void(void*, callback*, expr*)> created_eh_t;
typedef std::function<void(void*, callback*, expr**, unsigned*, lbool*)> decide_eh_t;
typedef std::function<void(void*, callback*, expr*, unsigned, bool)> decide_eh_t;
typedef std::function<void(void*, expr*, unsigned, expr* const*)> on_clause_eh_t;
class plugin : public decl_plugin {