3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-06 17:44:08 +00:00

Added function to select the next variable to split on (User-Propagator) (#6096)

* Added function to select the next variable to split on

* Fixed typo

* Small fixes

* uint -> int
This commit is contained in:
Clemens Eisenhofer 2022-06-19 19:49:25 +02:00 committed by GitHub
parent f08e3d70a9
commit 2fa60aa43c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
17 changed files with 255 additions and 104 deletions

View file

@ -61,6 +61,7 @@ DOUBLE = 12
FLOAT = 13
CHAR = 14
CHAR_PTR = 15
LBOOL = 16
FIRST_FN_ID = 50
@ -74,25 +75,25 @@ def is_fn(ty):
Type2Str = { VOID : 'void', VOID_PTR : 'void*', INT : 'int', UINT : 'unsigned', INT64 : 'int64_t', UINT64 : 'uint64_t', DOUBLE : 'double',
FLOAT : 'float', STRING : 'Z3_string', STRING_PTR : 'Z3_string_ptr', BOOL : 'bool', SYMBOL : 'Z3_symbol',
PRINT_MODE : 'Z3_ast_print_mode', ERROR_CODE : 'Z3_error_code', CHAR: 'char', CHAR_PTR: 'Z3_char_ptr'
PRINT_MODE : 'Z3_ast_print_mode', ERROR_CODE : 'Z3_error_code', CHAR: 'char', CHAR_PTR: 'Z3_char_ptr', LBOOL : 'Z3_lbool'
}
Type2PyStr = { VOID_PTR : 'ctypes.c_void_p', INT : 'ctypes.c_int', UINT : 'ctypes.c_uint', INT64 : 'ctypes.c_longlong',
UINT64 : 'ctypes.c_ulonglong', DOUBLE : 'ctypes.c_double', FLOAT : 'ctypes.c_float',
STRING : 'ctypes.c_char_p', STRING_PTR : 'ctypes.POINTER(ctypes.c_char_p)', BOOL : 'ctypes.c_bool', SYMBOL : 'Symbol',
PRINT_MODE : 'ctypes.c_uint', ERROR_CODE : 'ctypes.c_uint', CHAR : 'ctypes.c_char', CHAR_PTR: 'ctypes.POINTER(ctypes.c_char)'
PRINT_MODE : 'ctypes.c_uint', ERROR_CODE : 'ctypes.c_uint', CHAR : 'ctypes.c_char', CHAR_PTR: 'ctypes.POINTER(ctypes.c_char)', LBOOL : 'ctypes.c_int'
}
# Mapping to .NET types
Type2Dotnet = { VOID : 'void', VOID_PTR : 'IntPtr', INT : 'int', UINT : 'uint', INT64 : 'Int64', UINT64 : 'UInt64', DOUBLE : 'double',
FLOAT : 'float', STRING : 'string', STRING_PTR : 'byte**', BOOL : 'byte', SYMBOL : 'IntPtr',
PRINT_MODE : 'uint', ERROR_CODE : 'uint', CHAR : 'char', CHAR_PTR : 'IntPtr' }
PRINT_MODE : 'uint', ERROR_CODE : 'uint', CHAR : 'char', CHAR_PTR : 'IntPtr', LBOOL : 'int' }
# Mapping to ML types
Type2ML = { VOID : 'unit', VOID_PTR : 'ptr', INT : 'int', UINT : 'int', INT64 : 'int', UINT64 : 'int', DOUBLE : 'float',
FLOAT : 'float', STRING : 'string', STRING_PTR : 'char**',
BOOL : 'bool', SYMBOL : 'z3_symbol', PRINT_MODE : 'int', ERROR_CODE : 'int', CHAR : 'char', CHAR_PTR : 'string' }
BOOL : 'bool', SYMBOL : 'z3_symbol', PRINT_MODE : 'int', ERROR_CODE : 'int', CHAR : 'char', CHAR_PTR : 'string', LBOOL : 'int' }
Closures = []
@ -522,11 +523,11 @@ def mk_dotnet_wrappers(dotnet):
Type2Java = { VOID : 'void', VOID_PTR : 'long', INT : 'int', UINT : 'int', INT64 : 'long', UINT64 : 'long', DOUBLE : 'double',
FLOAT : 'float', STRING : 'String', STRING_PTR : 'StringPtr',
BOOL : 'boolean', SYMBOL : 'long', PRINT_MODE : 'int', ERROR_CODE : 'int', CHAR : 'char', CHAR_PTR : 'long' }
BOOL : 'boolean', SYMBOL : 'long', PRINT_MODE : 'int', ERROR_CODE : 'int', CHAR : 'char', CHAR_PTR : 'long', LBOOL : 'int' }
Type2JavaW = { VOID : 'void', VOID_PTR : 'jlong', INT : 'jint', UINT : 'jint', INT64 : 'jlong', UINT64 : 'jlong', DOUBLE : 'jdouble',
FLOAT : 'jfloat', STRING : 'jstring', STRING_PTR : 'jobject',
BOOL : 'jboolean', SYMBOL : 'jlong', PRINT_MODE : 'jint', ERROR_CODE : 'jint', CHAR : 'jchar', CHAR_PTR : 'jlong'}
BOOL : 'jboolean', SYMBOL : 'jlong', PRINT_MODE : 'jint', ERROR_CODE : 'jint', CHAR : 'jchar', CHAR_PTR : 'jlong', LBOOL : 'jint'}
def type2java(ty):
global Type2Java
@ -1024,6 +1025,9 @@ def def_API(name, result, params):
elif ty == VOID_PTR:
log_c.write(" P(0);\n")
exe_c.write("in.get_obj_addr(%s)" % i)
elif ty == LBOOL:
log_c.write(" I(static_cast<signed>(a%s));\n" % i)
exe_c.write("static_cast<%s>(in.get_int(%s))" % (type2str(ty), i))
elif ty == PRINT_MODE or ty == ERROR_CODE:
log_c.write(" U(static_cast<unsigned>(a%s));\n" % i)
exe_c.write("static_cast<%s>(in.get_uint(%s))" % (type2str(ty), i))
@ -1298,7 +1302,7 @@ def ml_unwrap(t, ts, s):
return '(' + ts + ') String_val(' + s + ')'
elif t == BOOL or (type2str(t) == 'bool'):
return '(' + ts + ') Bool_val(' + s + ')'
elif t == INT or t == PRINT_MODE or t == ERROR_CODE:
elif t == INT or t == PRINT_MODE or t == ERROR_CODE or t == LBOOL:
return '(' + ts + ') Int_val(' + s + ')'
elif t == UINT:
return '(' + ts + ') Unsigned_int_val(' + s + ')'
@ -1319,7 +1323,7 @@ def ml_set_wrap(t, d, n):
return d + ' = Val_unit;'
elif t == BOOL or (type2str(t) == 'bool'):
return d + ' = Val_bool(' + n + ');'
elif t == INT or t == UINT or t == PRINT_MODE or t == ERROR_CODE:
elif t == INT or t == UINT or t == PRINT_MODE or t == ERROR_CODE or t == LBOOL:
return d + ' = Val_int(' + n + ');'
elif t == INT64 or t == UINT64:
return d + ' = Val_long(' + n + ');'
@ -1332,7 +1336,7 @@ def ml_set_wrap(t, d, n):
return '*(' + pts + '*)Data_custom_val(' + d + ') = ' + n + ';'
def ml_alloc_and_store(t, lhs, rhs):
if t == VOID or t == BOOL or t == INT or t == UINT or t == PRINT_MODE or t == ERROR_CODE or t == INT64 or t == UINT64 or t == DOUBLE or t == STRING or (type2str(t) == 'bool'):
if t == VOID or t == BOOL or t == INT or t == UINT or t == PRINT_MODE or t == ERROR_CODE or t == INT64 or t == UINT64 or t == DOUBLE or t == STRING or t == LBOOL or (type2str(t) == 'bool'):
return ml_set_wrap(t, lhs, rhs)
else:
pts = ml_plus_type(type2str(t))

View file

@ -981,6 +981,14 @@ extern "C" {
Z3_CATCH;
}
void Z3_API Z3_solver_next_split(Z3_context c, Z3_solver_callback cb, Z3_ast t, unsigned idx, Z3_lbool phase) {
Z3_TRY;
LOG_Z3_solver_next_split(c, cb, t, idx, phase);
RESET_ERROR_CODE();
reinterpret_cast<user_propagator::callback*>(cb)->next_split_cb(to_expr(t), idx, (lbool)phase);
Z3_CATCH;
}
Z3_func_decl Z3_API Z3_solver_propagate_declare(Z3_context c, Z3_symbol name, unsigned n, Z3_sort* domain, Z3_sort range) {
Z3_TRY;
LOG_Z3_solver_propagate_declare(c, name, n, domain, range);

View file

@ -4158,6 +4158,11 @@ namespace z3 {
virtual void decide(expr& /*val*/, unsigned& /*bit*/, Z3_lbool& /*is_pos*/) {}
void next_split(expr const & e, unsigned idx, Z3_lbool phase) {
assert(cb);
Z3_solver_next_split(ctx(), cb, e, idx, phase);
}
/**
\brief tracks \c e by a unique identifier that is returned by the call.

View file

@ -4871,7 +4871,7 @@ extern "C" {
/**
\brief Return \c Z3_L_TRUE if \c a is true, \c Z3_L_FALSE if it is false, and \c Z3_L_UNDEF otherwise.
def_API('Z3_get_bool_value', INT, (_in(CONTEXT), _in(AST)))
def_API('Z3_get_bool_value', LBOOL, (_in(CONTEXT), _in(AST)))
*/
Z3_lbool Z3_API Z3_get_bool_value(Z3_context c, Z3_ast a);
@ -6827,6 +6827,13 @@ extern "C" {
*/
void Z3_API Z3_solver_propagate_decide(Z3_context c, Z3_solver s, Z3_decide_eh decide_eh);
/**
Sets the next expression to split on
def_API('Z3_solver_next_split', VOID, (_in(CONTEXT), _in(SOLVER_CALLBACK), _in(AST), _in(UINT), _in(LBOOL)))
*/
void Z3_API Z3_solver_next_split(Z3_context c, Z3_solver_callback cb, Z3_ast t, unsigned idx, Z3_lbool phase);
/**
Create uninterpreted function declaration for the user propagator.
When expressions using the function are created by the solver invoke a callback
@ -6885,7 +6892,7 @@ extern "C" {
\sa Z3_solver_check_assumptions
def_API('Z3_solver_check', INT, (_in(CONTEXT), _in(SOLVER)))
def_API('Z3_solver_check', LBOOL, (_in(CONTEXT), _in(SOLVER)))
*/
Z3_lbool Z3_API Z3_solver_check(Z3_context c, Z3_solver s);
@ -6898,7 +6905,7 @@ extern "C" {
\sa Z3_solver_check
def_API('Z3_solver_check_assumptions', INT, (_in(CONTEXT), _in(SOLVER), _in(UINT), _in_array(2, AST)))
def_API('Z3_solver_check_assumptions', LBOOL, (_in(CONTEXT), _in(SOLVER), _in(UINT), _in_array(2, AST)))
*/
Z3_lbool Z3_API Z3_solver_check_assumptions(Z3_context c, Z3_solver s,
unsigned num_assumptions, Z3_ast const assumptions[]);
@ -6919,7 +6926,7 @@ extern "C" {
A side-effect of the function is a satisfiability check on the assertions on the solver that is passed in.
The function return \c Z3_L_FALSE if the current assertions are not satisfiable.
def_API('Z3_get_implied_equalities', INT, (_in(CONTEXT), _in(SOLVER), _in(UINT), _in_array(2, AST), _out_array(2, UINT)))
def_API('Z3_get_implied_equalities', LBOOL, (_in(CONTEXT), _in(SOLVER), _in(UINT), _in_array(2, AST), _out_array(2, UINT)))
*/
Z3_lbool Z3_API Z3_get_implied_equalities(Z3_context c,
Z3_solver s,
@ -6930,7 +6937,7 @@ extern "C" {
/**
\brief retrieve consequences from solver that determine values of the supplied function symbols.
def_API('Z3_solver_get_consequences', INT, (_in(CONTEXT), _in(SOLVER), _in(AST_VECTOR), _in(AST_VECTOR), _in(AST_VECTOR)))
def_API('Z3_solver_get_consequences', LBOOL, (_in(CONTEXT), _in(SOLVER), _in(AST_VECTOR), _in(AST_VECTOR), _in(AST_VECTOR)))
*/
Z3_lbool Z3_API Z3_solver_get_consequences(Z3_context c,

View file

@ -109,7 +109,7 @@ extern "C" {
- \c Z3_L_TRUE if the query is satisfiable. Obtain the answer by calling #Z3_fixedpoint_get_answer.
- \c Z3_L_UNDEF if the query was interrupted, timed out or otherwise failed.
def_API('Z3_fixedpoint_query', INT, (_in(CONTEXT), _in(FIXEDPOINT), _in(AST)))
def_API('Z3_fixedpoint_query', LBOOL, (_in(CONTEXT), _in(FIXEDPOINT), _in(AST)))
*/
Z3_lbool Z3_API Z3_fixedpoint_query(Z3_context c, Z3_fixedpoint d, Z3_ast query);
@ -123,7 +123,7 @@ extern "C" {
- \c Z3_L_TRUE if the query is satisfiable. Obtain the answer by calling #Z3_fixedpoint_get_answer.
- \c Z3_L_UNDEF if the query was interrupted, timed out or otherwise failed.
def_API('Z3_fixedpoint_query_relations', INT, (_in(CONTEXT), _in(FIXEDPOINT), _in(UINT), _in_array(2, FUNC_DECL)))
def_API('Z3_fixedpoint_query_relations', LBOOL, (_in(CONTEXT), _in(FIXEDPOINT), _in(UINT), _in_array(2, FUNC_DECL)))
*/
Z3_lbool Z3_API Z3_fixedpoint_query_relations(
Z3_context c, Z3_fixedpoint d,

View file

@ -151,7 +151,7 @@ extern "C" {
\sa Z3_optimize_get_statistics
\sa Z3_optimize_get_unsat_core
def_API('Z3_optimize_check', INT, (_in(CONTEXT), _in(OPTIMIZE), _in(UINT), _in_array(2, AST)))
def_API('Z3_optimize_check', LBOOL, (_in(CONTEXT), _in(OPTIMIZE), _in(UINT), _in_array(2, AST)))
*/
Z3_lbool Z3_API Z3_optimize_check(Z3_context c, Z3_optimize o, unsigned num_assumptions, Z3_ast const assumptions[]);

View file

@ -40,7 +40,7 @@ extern "C" {
- \c Z3_L_TRUE if the query is satisfiable. Obtain the answer by calling #Z3_fixedpoint_get_answer.
- \c Z3_L_UNDEF if the query was interrupted, timed out or otherwise failed.
def_API('Z3_fixedpoint_query_from_lvl', INT, (_in(CONTEXT), _in(FIXEDPOINT), _in(AST), _in(UINT)))
def_API('Z3_fixedpoint_query_from_lvl', LBOOL, (_in(CONTEXT), _in(FIXEDPOINT), _in(AST), _in(UINT)))
*/
Z3_lbool Z3_API Z3_fixedpoint_query_from_lvl (Z3_context c,Z3_fixedpoint d, Z3_ast query, unsigned lvl);

View file

@ -91,8 +91,10 @@ namespace sat {
virtual double get_reward(literal l, ext_constraint_idx idx, literal_occs_fun& occs) const { return 0; }
virtual void get_antecedents(literal l, ext_justification_idx idx, literal_vector & r, bool probing) = 0;
virtual bool is_extended_binary(ext_justification_idx idx, literal_vector & r) { return false; }
virtual void asserted(literal l) {};
virtual void set_eliminated(bool_var v) {};
virtual bool decide(bool_var& var, lbool& phase) { return false; }
virtual bool get_case_split(bool_var& var, lbool& phase) { return false; }
virtual void asserted(literal l) {}
virtual void set_eliminated(bool_var v) {}
virtual check_result check() = 0;
virtual lbool resolve_conflict() { return l_undef; } // stores result in sat::solver::m_lemma
virtual void push() = 0;

View file

@ -1661,49 +1661,66 @@ namespace sat {
return null_bool_var;
}
bool solver::decide() {
bool_var next = next_var();
if (next == null_bool_var)
return false;
push();
m_stats.m_decision++;
bool solver::guess(bool_var next) {
lbool lphase = m_ext ? m_ext->get_phase(next) : l_undef;
bool phase = lphase == l_true;
if (lphase == l_undef) {
switch (m_config.m_phase) {
if (lphase != l_undef)
return lphase == l_true;
switch (m_config.m_phase) {
case PS_ALWAYS_TRUE:
phase = true;
break;
return true;
case PS_ALWAYS_FALSE:
phase = false;
break;
return false;
case PS_BASIC_CACHING:
phase = m_phase[next];
break;
return m_phase[next];
case PS_FROZEN:
phase = m_best_phase[next];
break;
return m_best_phase[next];
case PS_SAT_CACHING:
if (m_search_state == s_unsat) {
phase = m_phase[next];
}
else {
phase = m_best_phase[next];
}
break;
if (m_search_state == s_unsat)
return m_phase[next];
return m_best_phase[next];
case PS_RANDOM:
phase = (m_rand() % 2) == 0;
break;
return (m_rand() % 2) == 0;
default:
UNREACHABLE();
phase = false;
break;
}
return false;
}
}
literal next_lit(next, !phase);
bool solver::decide() {
bool_var next;
lbool phase = l_undef;
bool is_pos;
bool used_queue = false;
if (!m_ext || !m_ext->get_case_split(next, phase)) {
used_queue = true;
next = next_var();
if (next == null_bool_var)
return false;
}
push();
m_stats.m_decision++;
if (phase == l_undef)
phase = guess(next) ? l_true: l_false;
literal next_lit(next, false);
if (m_ext && m_ext->decide(next, phase)) {
if (used_queue)
m_case_split_queue.unassign_var_eh(next);
next_lit = literal(next, false);
}
if (phase == l_undef)
is_pos = guess(next);
else
is_pos = phase == l_true;
if (!is_pos)
next_lit.neg();
TRACE("sat_decide", tout << scope_lvl() << ": next-case-split: " << next_lit << "\n";);
assign_scoped(next_lit);
return true;

View file

@ -541,6 +541,7 @@ namespace sat {
unsigned m_next_simplify { 0 };
bool m_simplify_enabled { true };
bool m_restart_enabled { true };
bool guess(bool_var next);
bool decide();
bool_var next_var();
lbool bounded_search();

View file

@ -56,6 +56,18 @@ namespace user_solver {
void solver::register_cb(expr* e) {
add_expr(e);
}
void solver::next_split_cb(expr* e, unsigned idx, lbool phase) {
if (e == nullptr) {
m_next_split_expr = nullptr;
return;
}
force_push();
ctx.internalize(e, false);
m_next_split_expr = e;
m_next_split_idx = idx;
m_next_split_phase = phase;
}
sat::check_result solver::check() {
if (!(bool)m_final_eh)
@ -72,6 +84,41 @@ namespace user_solver {
m_id2justification.setx(v, sat::literal_vector(num_lits, jlits), sat::literal_vector());
m_fixed_eh(m_user_context, this, var2expr(v), value);
}
bool solver::decide(sat::bool_var& var, lbool& phase) {
if (!m_decide_eh)
return false;
euf::enode* original_enode = bool_var2enode(var);
if (!is_attached_to_var(original_enode))
return false;
unsigned new_bit = 0; // ignored; currently no bv-support
expr* e = bool_var2expr(var);
m_decide_eh(m_user_context, this, &e, &new_bit, &phase);
euf::enode* new_enode = ctx.get_enode(e);
if (original_enode == new_enode)
return false;
var = new_enode->bool_var();
return true;
}
bool solver::get_case_split(sat::bool_var& var, lbool &phase){
if (!m_next_split_expr)
return false;
euf::enode* n = ctx.get_enode(m_next_split_expr);
var = n->bool_var();
phase = m_next_split_phase;
m_next_split_expr = nullptr;
return true;
}
void solver::asserted(sat::literal lit) {
if (!m_fixed_eh)

View file

@ -56,24 +56,28 @@ namespace user_solver {
void reset() { memset(this, 0, sizeof(*this)); }
};
void* m_user_context;
user_propagator::push_eh_t m_push_eh;
user_propagator::pop_eh_t m_pop_eh;
user_propagator::fresh_eh_t m_fresh_eh;
user_propagator::final_eh_t m_final_eh;
user_propagator::fixed_eh_t m_fixed_eh;
user_propagator::eq_eh_t m_eq_eh;
user_propagator::eq_eh_t m_diseq_eh;
user_propagator::created_eh_t m_created_eh;
void* m_user_context;
user_propagator::push_eh_t m_push_eh = nullptr;
user_propagator::pop_eh_t m_pop_eh = nullptr;
user_propagator::fresh_eh_t m_fresh_eh = nullptr;
user_propagator::final_eh_t m_final_eh = nullptr;
user_propagator::fixed_eh_t m_fixed_eh = nullptr;
user_propagator::eq_eh_t m_eq_eh = nullptr;
user_propagator::eq_eh_t m_diseq_eh = nullptr;
user_propagator::created_eh_t m_created_eh = nullptr;
user_propagator::decide_eh_t m_decide_eh = nullptr;
user_propagator::context_obj* m_api_context = nullptr;
unsigned m_qhead = 0;
vector<prop_info> m_prop;
unsigned_vector m_prop_lim;
vector<sat::literal_vector> m_id2justification;
sat::literal_vector m_lits;
euf::enode_pair_vector m_eqs;
unsigned_vector m_fixed_ids;
stats m_stats;
unsigned m_qhead = 0;
vector<prop_info> m_prop;
unsigned_vector m_prop_lim;
vector<sat::literal_vector> m_id2justification;
sat::literal_vector m_lits;
euf::enode_pair_vector m_eqs;
unsigned_vector m_fixed_ids;
stats m_stats;
expr* m_next_split_expr = nullptr;
unsigned m_next_split_idx;
lbool m_next_split_phase;
struct justification {
unsigned m_propagation_index { 0 };
@ -94,7 +98,7 @@ namespace user_solver {
void propagate_consequence(prop_info const& prop);
void propagate_new_fixed(prop_info const& prop);
void validate_propagation();
void validate_propagation();
bool visit(expr* e) override;
bool visited(expr* e) override;
@ -126,14 +130,19 @@ namespace user_solver {
void register_eq(user_propagator::eq_eh_t& eq_eh) { m_eq_eh = eq_eh; }
void register_diseq(user_propagator::eq_eh_t& diseq_eh) { m_diseq_eh = diseq_eh; }
void register_created(user_propagator::created_eh_t& created_eh) { m_created_eh = created_eh; }
void register_decide(user_propagator::decide_eh_t& decide_eh) { m_decide_eh = decide_eh; }
bool has_fixed() const { return (bool)m_fixed_eh; }
void propagate_cb(unsigned num_fixed, expr* const* fixed_ids, unsigned num_eqs, expr* const* lhs, expr* const* rhs, expr* conseq) override;
void register_cb(expr* e) override;
void next_split_cb(expr* e, unsigned idx, lbool phase) override;
void new_fixed_eh(euf::theory_var v, expr* value, unsigned num_lits, sat::literal const* jlits);
bool decide(sat::bool_var& var, lbool& phase) override;
bool get_case_split(sat::bool_var& var, lbool &phase) override;
void asserted(sat::literal lit) override;
sat::check_result check() override;
void push_core() override;

View file

@ -1848,24 +1848,30 @@ namespace smt {
}
}
bool_var var;
lbool phase = l_undef;
m_case_split_queue->next_case_split(var, phase);
bool is_pos;
bool used_queue = false;
if (!has_split_candidate(var, is_pos)) {
lbool phase = l_undef;
m_case_split_queue->next_case_split(var, phase);
used_queue = true;
if (var == null_bool_var)
return false;
if (var == null_bool_var) {
return false;
TRACE_CODE({
static unsigned counter = 0;
counter++;
if (counter % 100 == 0) {
TRACE("activity_profile",
for (unsigned i=0; i<get_num_bool_vars(); i++) {
tout << get_activity(i) << " ";
}
tout << "\n";);
}});
is_pos = guess(var, phase);
}
TRACE_CODE({
static unsigned counter = 0;
counter++;
if (counter % 100 == 0) {
TRACE("activity_profile",
for (unsigned i=0; i<get_num_bool_vars(); i++) {
tout << get_activity(i) << " ";
}
tout << "\n";);
}});
m_stats.m_num_decisions++;
push_scope();
@ -1873,13 +1879,13 @@ namespace smt {
TRACE("decide_detail", tout << mk_pp(bool_var2expr(var), m) << "\n";);
bool is_pos = guess(var, phase);
literal l(var, false);
bool_var original_choice = var;
if (decide_user_interference(var, is_pos)) {
m_case_split_queue->unassign_var_eh(original_choice);
if (used_queue)
m_case_split_queue->unassign_var_eh(original_choice);
l = literal(var, false);
}
@ -2905,8 +2911,14 @@ namespace smt {
return m_user_propagator && m_user_propagator->has_fixed() && n->get_th_var(m_user_propagator->get_family_id()) != null_theory_var;
}
bool context::has_split_candidate(bool_var& var, bool& is_pos) {
if (!m_user_propagator)
return false;
return m_user_propagator->get_case_split(var, is_pos);
}
bool context::decide_user_interference(bool_var& var, bool& is_pos) {
if (!m_user_propagator || !m_user_propagator->has_decide())
if (!m_user_propagator)
return false;
bool_var old = var;
m_user_propagator->decide(var, is_pos);

View file

@ -1754,6 +1754,8 @@ namespace smt {
bool watches_fixed(enode* n) const;
bool has_split_candidate(bool_var& var, bool& is_pos);
bool decide_user_interference(bool_var& var, bool& is_pos);
void assign_fixed(enode* n, expr* val, unsigned sz, literal const* explain);

View file

@ -102,6 +102,17 @@ void theory_user_propagator::register_cb(expr* e) {
add_expr(e, true);
}
void theory_user_propagator::next_split_cb(expr* e, unsigned idx, lbool phase) {
if (e == nullptr) { // clear
m_next_split_expr = nullptr;
return;
}
ensure_enode(e);
m_next_split_expr = e;
m_next_split_idx = idx;
m_next_split_phase = phase;
}
theory * theory_user_propagator::mk_fresh(context * new_ctx) {
auto* th = alloc(theory_user_propagator, *new_ctx);
void* ctx;
@ -156,8 +167,24 @@ void theory_user_propagator::new_fixed_eh(theory_var v, expr* value, unsigned nu
}
}
void theory_user_propagator::decide(bool_var& var, bool& is_pos) {
bool_var theory_user_propagator::enode_to_bool(enode* n, unsigned bit) {
if (n->is_bool()) {
// expression is a boolean
bool_var new_var = ctx.enode2bool_var(n);
if (ctx.get_assignment(new_var) == l_undef)
return new_var;
return null_bool_var;
}
// expression is a bit-vector
bv_util bv(m);
auto th_bv = (theory_bv*)ctx.get_theory(bv.get_fid());
return th_bv->get_first_unassigned(bit, n);
}
void theory_user_propagator::decide(bool_var& var, bool& is_pos) {
if (!m_decide_eh)
return;
const bool_var_data& d = ctx.get_bdata(var);
if (!d.is_enode() && !d.is_theory_atom())
@ -216,25 +243,28 @@ void theory_user_propagator::decide(bool_var& var, bool& is_pos) {
return;
}
if (new_enode->is_bool()) {
// expression was set to a boolean
bool_var new_var = ctx.enode2bool_var(new_enode);
if (ctx.get_assignment(new_var) == l_undef) {
var = new_var;
}
}
else {
// expression was set to a bit-vector
auto th_bv = (theory_bv*)ctx.get_theory(bv.get_fid());
bool_var new_var = th_bv->get_first_unassigned(new_bit, new_enode);
if (new_var != null_bool_var)
var = new_var;
}
// get unassigned variable from enode
var = enode_to_bool(new_enode, new_bit);
// in case the callback did not decide on a truth value -> let Z3 decide
is_pos = ctx.guess(var, phase);
}
bool theory_user_propagator::get_case_split(bool_var& var, bool& is_pos){
if (!m_next_split_expr)
return false;
enode* n = ctx.get_enode(m_next_split_expr);
var = enode_to_bool(n, m_next_split_idx);
if (var == null_bool_var)
return false;
is_pos = ctx.guess(var, m_next_split_phase);
m_next_split_expr = nullptr;
return true;
}
void theory_user_propagator::push_scope_eh() {
++m_num_scopes;
}

View file

@ -83,6 +83,9 @@ namespace smt {
expr_ref_vector m_to_add;
unsigned_vector m_to_add_lim;
unsigned m_to_add_qhead = 0;
expr* m_next_split_expr = nullptr;
unsigned m_next_split_idx;
lbool m_next_split_phase;
expr* var2expr(theory_var v) { return m_var2expr.get(v); }
theory_var expr2var(expr* e) { check_defined(e); return m_expr2var[e->get_id()]; }
@ -95,6 +98,8 @@ namespace smt {
void propagate_consequence(prop_info const& prop);
void propagate_new_fixed(prop_info const& prop);
bool_var enode_to_bool(enode* n, unsigned bit);
public:
theory_user_propagator(context& ctx);
@ -125,13 +130,14 @@ namespace smt {
void register_decide(user_propagator::decide_eh_t& decide_eh) { m_decide_eh = decide_eh; }
bool has_fixed() const { return (bool)m_fixed_eh; }
bool has_decide() const { return (bool)m_decide_eh; }
void propagate_cb(unsigned num_fixed, expr* const* fixed_ids, unsigned num_eqs, expr* const* lhs, expr* const* rhs, expr* conseq) override;
void register_cb(expr* e) override;
void next_split_cb(expr* e, unsigned idx, lbool phase) override;
void new_fixed_eh(theory_var v, expr* value, unsigned num_lits, literal const* jlits);
void decide(bool_var& var, bool& is_pos);
bool get_case_split(bool_var& var, bool& is_pos);
theory * mk_fresh(context * new_ctx) override;
bool internalize_atom(app* atom, bool gate_ctx) override;
@ -154,5 +160,5 @@ namespace smt {
bool can_propagate() override;
void propagate() override;
void display(std::ostream& out) const override {}
};
};
};

View file

@ -11,6 +11,7 @@ namespace user_propagator {
virtual ~callback() = default;
virtual void propagate_cb(unsigned num_fixed, expr* const* fixed_ids, unsigned num_eqs, expr* const* eq_lhs, expr* const* eq_rhs, expr* conseq) = 0;
virtual void register_cb(expr* e) = 0;
virtual void next_split_cb(expr* e, unsigned idx, lbool phase) = 0;
};
class context_obj {