diff --git a/scripts/update_api.py b/scripts/update_api.py index 6c932d8ca..45899ddef 100755 --- a/scripts/update_api.py +++ b/scripts/update_api.py @@ -61,6 +61,7 @@ DOUBLE = 12 FLOAT = 13 CHAR = 14 CHAR_PTR = 15 +LBOOL = 16 FIRST_FN_ID = 50 @@ -74,25 +75,25 @@ def is_fn(ty): Type2Str = { VOID : 'void', VOID_PTR : 'void*', INT : 'int', UINT : 'unsigned', INT64 : 'int64_t', UINT64 : 'uint64_t', DOUBLE : 'double', FLOAT : 'float', STRING : 'Z3_string', STRING_PTR : 'Z3_string_ptr', BOOL : 'bool', SYMBOL : 'Z3_symbol', - PRINT_MODE : 'Z3_ast_print_mode', ERROR_CODE : 'Z3_error_code', CHAR: 'char', CHAR_PTR: 'Z3_char_ptr' + PRINT_MODE : 'Z3_ast_print_mode', ERROR_CODE : 'Z3_error_code', CHAR: 'char', CHAR_PTR: 'Z3_char_ptr', LBOOL : 'Z3_lbool' } Type2PyStr = { VOID_PTR : 'ctypes.c_void_p', INT : 'ctypes.c_int', UINT : 'ctypes.c_uint', INT64 : 'ctypes.c_longlong', UINT64 : 'ctypes.c_ulonglong', DOUBLE : 'ctypes.c_double', FLOAT : 'ctypes.c_float', STRING : 'ctypes.c_char_p', STRING_PTR : 'ctypes.POINTER(ctypes.c_char_p)', BOOL : 'ctypes.c_bool', SYMBOL : 'Symbol', - PRINT_MODE : 'ctypes.c_uint', ERROR_CODE : 'ctypes.c_uint', CHAR : 'ctypes.c_char', CHAR_PTR: 'ctypes.POINTER(ctypes.c_char)' + PRINT_MODE : 'ctypes.c_uint', ERROR_CODE : 'ctypes.c_uint', CHAR : 'ctypes.c_char', CHAR_PTR: 'ctypes.POINTER(ctypes.c_char)', LBOOL : 'ctypes.c_int' } # Mapping to .NET types Type2Dotnet = { VOID : 'void', VOID_PTR : 'IntPtr', INT : 'int', UINT : 'uint', INT64 : 'Int64', UINT64 : 'UInt64', DOUBLE : 'double', FLOAT : 'float', STRING : 'string', STRING_PTR : 'byte**', BOOL : 'byte', SYMBOL : 'IntPtr', - PRINT_MODE : 'uint', ERROR_CODE : 'uint', CHAR : 'char', CHAR_PTR : 'IntPtr' } + PRINT_MODE : 'uint', ERROR_CODE : 'uint', CHAR : 'char', CHAR_PTR : 'IntPtr', LBOOL : 'int' } # Mapping to ML types Type2ML = { VOID : 'unit', VOID_PTR : 'ptr', INT : 'int', UINT : 'int', INT64 : 'int', UINT64 : 'int', DOUBLE : 'float', FLOAT : 'float', STRING : 'string', STRING_PTR : 'char**', - BOOL : 'bool', SYMBOL : 'z3_symbol', PRINT_MODE : 'int', ERROR_CODE : 'int', CHAR : 'char', CHAR_PTR : 'string' } + BOOL : 'bool', SYMBOL : 'z3_symbol', PRINT_MODE : 'int', ERROR_CODE : 'int', CHAR : 'char', CHAR_PTR : 'string', LBOOL : 'int' } Closures = [] @@ -522,11 +523,11 @@ def mk_dotnet_wrappers(dotnet): Type2Java = { VOID : 'void', VOID_PTR : 'long', INT : 'int', UINT : 'int', INT64 : 'long', UINT64 : 'long', DOUBLE : 'double', FLOAT : 'float', STRING : 'String', STRING_PTR : 'StringPtr', - BOOL : 'boolean', SYMBOL : 'long', PRINT_MODE : 'int', ERROR_CODE : 'int', CHAR : 'char', CHAR_PTR : 'long' } + BOOL : 'boolean', SYMBOL : 'long', PRINT_MODE : 'int', ERROR_CODE : 'int', CHAR : 'char', CHAR_PTR : 'long', LBOOL : 'int' } Type2JavaW = { VOID : 'void', VOID_PTR : 'jlong', INT : 'jint', UINT : 'jint', INT64 : 'jlong', UINT64 : 'jlong', DOUBLE : 'jdouble', FLOAT : 'jfloat', STRING : 'jstring', STRING_PTR : 'jobject', - BOOL : 'jboolean', SYMBOL : 'jlong', PRINT_MODE : 'jint', ERROR_CODE : 'jint', CHAR : 'jchar', CHAR_PTR : 'jlong'} + BOOL : 'jboolean', SYMBOL : 'jlong', PRINT_MODE : 'jint', ERROR_CODE : 'jint', CHAR : 'jchar', CHAR_PTR : 'jlong', LBOOL : 'jint'} def type2java(ty): global Type2Java @@ -1024,6 +1025,9 @@ def def_API(name, result, params): elif ty == VOID_PTR: log_c.write(" P(0);\n") exe_c.write("in.get_obj_addr(%s)" % i) + elif ty == LBOOL: + log_c.write(" I(static_cast(a%s));\n" % i) + exe_c.write("static_cast<%s>(in.get_int(%s))" % (type2str(ty), i)) elif ty == PRINT_MODE or ty == ERROR_CODE: log_c.write(" U(static_cast(a%s));\n" % i) exe_c.write("static_cast<%s>(in.get_uint(%s))" % (type2str(ty), i)) @@ -1298,7 +1302,7 @@ def ml_unwrap(t, ts, s): return '(' + ts + ') String_val(' + s + ')' elif t == BOOL or (type2str(t) == 'bool'): return '(' + ts + ') Bool_val(' + s + ')' - elif t == INT or t == PRINT_MODE or t == ERROR_CODE: + elif t == INT or t == PRINT_MODE or t == ERROR_CODE or t == LBOOL: return '(' + ts + ') Int_val(' + s + ')' elif t == UINT: return '(' + ts + ') Unsigned_int_val(' + s + ')' @@ -1319,7 +1323,7 @@ def ml_set_wrap(t, d, n): return d + ' = Val_unit;' elif t == BOOL or (type2str(t) == 'bool'): return d + ' = Val_bool(' + n + ');' - elif t == INT or t == UINT or t == PRINT_MODE or t == ERROR_CODE: + elif t == INT or t == UINT or t == PRINT_MODE or t == ERROR_CODE or t == LBOOL: return d + ' = Val_int(' + n + ');' elif t == INT64 or t == UINT64: return d + ' = Val_long(' + n + ');' @@ -1332,7 +1336,7 @@ def ml_set_wrap(t, d, n): return '*(' + pts + '*)Data_custom_val(' + d + ') = ' + n + ';' def ml_alloc_and_store(t, lhs, rhs): - if t == VOID or t == BOOL or t == INT or t == UINT or t == PRINT_MODE or t == ERROR_CODE or t == INT64 or t == UINT64 or t == DOUBLE or t == STRING or (type2str(t) == 'bool'): + if t == VOID or t == BOOL or t == INT or t == UINT or t == PRINT_MODE or t == ERROR_CODE or t == INT64 or t == UINT64 or t == DOUBLE or t == STRING or t == LBOOL or (type2str(t) == 'bool'): return ml_set_wrap(t, lhs, rhs) else: pts = ml_plus_type(type2str(t)) diff --git a/src/api/api_solver.cpp b/src/api/api_solver.cpp index 3a28da0fd..8e5ebcf2a 100644 --- a/src/api/api_solver.cpp +++ b/src/api/api_solver.cpp @@ -981,6 +981,14 @@ extern "C" { Z3_CATCH; } + void Z3_API Z3_solver_next_split(Z3_context c, Z3_solver_callback cb, Z3_ast t, unsigned idx, Z3_lbool phase) { + Z3_TRY; + LOG_Z3_solver_next_split(c, cb, t, idx, phase); + RESET_ERROR_CODE(); + reinterpret_cast(cb)->next_split_cb(to_expr(t), idx, (lbool)phase); + Z3_CATCH; + } + Z3_func_decl Z3_API Z3_solver_propagate_declare(Z3_context c, Z3_symbol name, unsigned n, Z3_sort* domain, Z3_sort range) { Z3_TRY; LOG_Z3_solver_propagate_declare(c, name, n, domain, range); diff --git a/src/api/c++/z3++.h b/src/api/c++/z3++.h index 10116d751..fc0601fe5 100644 --- a/src/api/c++/z3++.h +++ b/src/api/c++/z3++.h @@ -4158,6 +4158,11 @@ namespace z3 { virtual void decide(expr& /*val*/, unsigned& /*bit*/, Z3_lbool& /*is_pos*/) {} + void next_split(expr const & e, unsigned idx, Z3_lbool phase) { + assert(cb); + Z3_solver_next_split(ctx(), cb, e, idx, phase); + } + /** \brief tracks \c e by a unique identifier that is returned by the call. diff --git a/src/api/z3_api.h b/src/api/z3_api.h index 740e304ad..8c60a09d1 100644 --- a/src/api/z3_api.h +++ b/src/api/z3_api.h @@ -4871,7 +4871,7 @@ extern "C" { /** \brief Return \c Z3_L_TRUE if \c a is true, \c Z3_L_FALSE if it is false, and \c Z3_L_UNDEF otherwise. - def_API('Z3_get_bool_value', INT, (_in(CONTEXT), _in(AST))) + def_API('Z3_get_bool_value', LBOOL, (_in(CONTEXT), _in(AST))) */ Z3_lbool Z3_API Z3_get_bool_value(Z3_context c, Z3_ast a); @@ -6827,6 +6827,13 @@ extern "C" { */ void Z3_API Z3_solver_propagate_decide(Z3_context c, Z3_solver s, Z3_decide_eh decide_eh); + /** + Sets the next expression to split on + + def_API('Z3_solver_next_split', VOID, (_in(CONTEXT), _in(SOLVER_CALLBACK), _in(AST), _in(UINT), _in(LBOOL))) + */ + void Z3_API Z3_solver_next_split(Z3_context c, Z3_solver_callback cb, Z3_ast t, unsigned idx, Z3_lbool phase); + /** Create uninterpreted function declaration for the user propagator. When expressions using the function are created by the solver invoke a callback @@ -6885,7 +6892,7 @@ extern "C" { \sa Z3_solver_check_assumptions - def_API('Z3_solver_check', INT, (_in(CONTEXT), _in(SOLVER))) + def_API('Z3_solver_check', LBOOL, (_in(CONTEXT), _in(SOLVER))) */ Z3_lbool Z3_API Z3_solver_check(Z3_context c, Z3_solver s); @@ -6898,7 +6905,7 @@ extern "C" { \sa Z3_solver_check - def_API('Z3_solver_check_assumptions', INT, (_in(CONTEXT), _in(SOLVER), _in(UINT), _in_array(2, AST))) + def_API('Z3_solver_check_assumptions', LBOOL, (_in(CONTEXT), _in(SOLVER), _in(UINT), _in_array(2, AST))) */ Z3_lbool Z3_API Z3_solver_check_assumptions(Z3_context c, Z3_solver s, unsigned num_assumptions, Z3_ast const assumptions[]); @@ -6919,7 +6926,7 @@ extern "C" { A side-effect of the function is a satisfiability check on the assertions on the solver that is passed in. The function return \c Z3_L_FALSE if the current assertions are not satisfiable. - def_API('Z3_get_implied_equalities', INT, (_in(CONTEXT), _in(SOLVER), _in(UINT), _in_array(2, AST), _out_array(2, UINT))) + def_API('Z3_get_implied_equalities', LBOOL, (_in(CONTEXT), _in(SOLVER), _in(UINT), _in_array(2, AST), _out_array(2, UINT))) */ Z3_lbool Z3_API Z3_get_implied_equalities(Z3_context c, Z3_solver s, @@ -6930,7 +6937,7 @@ extern "C" { /** \brief retrieve consequences from solver that determine values of the supplied function symbols. - def_API('Z3_solver_get_consequences', INT, (_in(CONTEXT), _in(SOLVER), _in(AST_VECTOR), _in(AST_VECTOR), _in(AST_VECTOR))) + def_API('Z3_solver_get_consequences', LBOOL, (_in(CONTEXT), _in(SOLVER), _in(AST_VECTOR), _in(AST_VECTOR), _in(AST_VECTOR))) */ Z3_lbool Z3_API Z3_solver_get_consequences(Z3_context c, diff --git a/src/api/z3_fixedpoint.h b/src/api/z3_fixedpoint.h index 5eadaaf46..6d4737d1b 100644 --- a/src/api/z3_fixedpoint.h +++ b/src/api/z3_fixedpoint.h @@ -109,7 +109,7 @@ extern "C" { - \c Z3_L_TRUE if the query is satisfiable. Obtain the answer by calling #Z3_fixedpoint_get_answer. - \c Z3_L_UNDEF if the query was interrupted, timed out or otherwise failed. - def_API('Z3_fixedpoint_query', INT, (_in(CONTEXT), _in(FIXEDPOINT), _in(AST))) + def_API('Z3_fixedpoint_query', LBOOL, (_in(CONTEXT), _in(FIXEDPOINT), _in(AST))) */ Z3_lbool Z3_API Z3_fixedpoint_query(Z3_context c, Z3_fixedpoint d, Z3_ast query); @@ -123,7 +123,7 @@ extern "C" { - \c Z3_L_TRUE if the query is satisfiable. Obtain the answer by calling #Z3_fixedpoint_get_answer. - \c Z3_L_UNDEF if the query was interrupted, timed out or otherwise failed. - def_API('Z3_fixedpoint_query_relations', INT, (_in(CONTEXT), _in(FIXEDPOINT), _in(UINT), _in_array(2, FUNC_DECL))) + def_API('Z3_fixedpoint_query_relations', LBOOL, (_in(CONTEXT), _in(FIXEDPOINT), _in(UINT), _in_array(2, FUNC_DECL))) */ Z3_lbool Z3_API Z3_fixedpoint_query_relations( Z3_context c, Z3_fixedpoint d, diff --git a/src/api/z3_optimization.h b/src/api/z3_optimization.h index 889db94ea..8bf0e9da5 100644 --- a/src/api/z3_optimization.h +++ b/src/api/z3_optimization.h @@ -151,7 +151,7 @@ extern "C" { \sa Z3_optimize_get_statistics \sa Z3_optimize_get_unsat_core - def_API('Z3_optimize_check', INT, (_in(CONTEXT), _in(OPTIMIZE), _in(UINT), _in_array(2, AST))) + def_API('Z3_optimize_check', LBOOL, (_in(CONTEXT), _in(OPTIMIZE), _in(UINT), _in_array(2, AST))) */ Z3_lbool Z3_API Z3_optimize_check(Z3_context c, Z3_optimize o, unsigned num_assumptions, Z3_ast const assumptions[]); diff --git a/src/api/z3_spacer.h b/src/api/z3_spacer.h index dd1028433..1f7a7ef34 100644 --- a/src/api/z3_spacer.h +++ b/src/api/z3_spacer.h @@ -40,7 +40,7 @@ extern "C" { - \c Z3_L_TRUE if the query is satisfiable. Obtain the answer by calling #Z3_fixedpoint_get_answer. - \c Z3_L_UNDEF if the query was interrupted, timed out or otherwise failed. - def_API('Z3_fixedpoint_query_from_lvl', INT, (_in(CONTEXT), _in(FIXEDPOINT), _in(AST), _in(UINT))) + def_API('Z3_fixedpoint_query_from_lvl', LBOOL, (_in(CONTEXT), _in(FIXEDPOINT), _in(AST), _in(UINT))) */ Z3_lbool Z3_API Z3_fixedpoint_query_from_lvl (Z3_context c,Z3_fixedpoint d, Z3_ast query, unsigned lvl); diff --git a/src/sat/sat_extension.h b/src/sat/sat_extension.h index 147bc90cc..1bb37b7d7 100644 --- a/src/sat/sat_extension.h +++ b/src/sat/sat_extension.h @@ -91,8 +91,10 @@ namespace sat { virtual double get_reward(literal l, ext_constraint_idx idx, literal_occs_fun& occs) const { return 0; } virtual void get_antecedents(literal l, ext_justification_idx idx, literal_vector & r, bool probing) = 0; virtual bool is_extended_binary(ext_justification_idx idx, literal_vector & r) { return false; } - virtual void asserted(literal l) {}; - virtual void set_eliminated(bool_var v) {}; + virtual bool decide(bool_var& var, lbool& phase) { return false; } + virtual bool get_case_split(bool_var& var, lbool& phase) { return false; } + virtual void asserted(literal l) {} + virtual void set_eliminated(bool_var v) {} virtual check_result check() = 0; virtual lbool resolve_conflict() { return l_undef; } // stores result in sat::solver::m_lemma virtual void push() = 0; diff --git a/src/sat/sat_solver.cpp b/src/sat/sat_solver.cpp index f8f187fc2..729b3c1f4 100644 --- a/src/sat/sat_solver.cpp +++ b/src/sat/sat_solver.cpp @@ -1661,49 +1661,66 @@ namespace sat { return null_bool_var; } - - bool solver::decide() { - bool_var next = next_var(); - if (next == null_bool_var) - return false; - push(); - m_stats.m_decision++; + + bool solver::guess(bool_var next) { lbool lphase = m_ext ? m_ext->get_phase(next) : l_undef; - bool phase = lphase == l_true; - if (lphase == l_undef) { - switch (m_config.m_phase) { + if (lphase != l_undef) + return lphase == l_true; + switch (m_config.m_phase) { case PS_ALWAYS_TRUE: - phase = true; - break; + return true; case PS_ALWAYS_FALSE: - phase = false; - break; + return false; case PS_BASIC_CACHING: - phase = m_phase[next]; - break; + return m_phase[next]; case PS_FROZEN: - phase = m_best_phase[next]; - break; + return m_best_phase[next]; case PS_SAT_CACHING: - if (m_search_state == s_unsat) { - phase = m_phase[next]; - } - else { - phase = m_best_phase[next]; - } - break; + if (m_search_state == s_unsat) + return m_phase[next]; + return m_best_phase[next]; case PS_RANDOM: - phase = (m_rand() % 2) == 0; - break; + return (m_rand() % 2) == 0; default: UNREACHABLE(); - phase = false; - break; - } + return false; } + } - literal next_lit(next, !phase); + bool solver::decide() { + bool_var next; + lbool phase = l_undef; + bool is_pos; + bool used_queue = false; + if (!m_ext || !m_ext->get_case_split(next, phase)) { + used_queue = true; + next = next_var(); + if (next == null_bool_var) + return false; + } + push(); + m_stats.m_decision++; + + if (phase == l_undef) + phase = guess(next) ? l_true: l_false; + + literal next_lit(next, false); + + if (m_ext && m_ext->decide(next, phase)) { + if (used_queue) + m_case_split_queue.unassign_var_eh(next); + next_lit = literal(next, false); + } + + if (phase == l_undef) + is_pos = guess(next); + else + is_pos = phase == l_true; + + if (!is_pos) + next_lit.neg(); + TRACE("sat_decide", tout << scope_lvl() << ": next-case-split: " << next_lit << "\n";); assign_scoped(next_lit); return true; diff --git a/src/sat/sat_solver.h b/src/sat/sat_solver.h index f7113609f..e5b13af98 100644 --- a/src/sat/sat_solver.h +++ b/src/sat/sat_solver.h @@ -541,6 +541,7 @@ namespace sat { unsigned m_next_simplify { 0 }; bool m_simplify_enabled { true }; bool m_restart_enabled { true }; + bool guess(bool_var next); bool decide(); bool_var next_var(); lbool bounded_search(); diff --git a/src/sat/smt/user_solver.cpp b/src/sat/smt/user_solver.cpp index d24af253e..9e2ea3eab 100644 --- a/src/sat/smt/user_solver.cpp +++ b/src/sat/smt/user_solver.cpp @@ -56,6 +56,18 @@ namespace user_solver { void solver::register_cb(expr* e) { add_expr(e); } + + void solver::next_split_cb(expr* e, unsigned idx, lbool phase) { + if (e == nullptr) { + m_next_split_expr = nullptr; + return; + } + force_push(); + ctx.internalize(e, false); + m_next_split_expr = e; + m_next_split_idx = idx; + m_next_split_phase = phase; + } sat::check_result solver::check() { if (!(bool)m_final_eh) @@ -72,6 +84,41 @@ namespace user_solver { m_id2justification.setx(v, sat::literal_vector(num_lits, jlits), sat::literal_vector()); m_fixed_eh(m_user_context, this, var2expr(v), value); } + + bool solver::decide(sat::bool_var& var, lbool& phase) { + + if (!m_decide_eh) + return false; + + euf::enode* original_enode = bool_var2enode(var); + + if (!is_attached_to_var(original_enode)) + return false; + + unsigned new_bit = 0; // ignored; currently no bv-support + expr* e = bool_var2expr(var); + + m_decide_eh(m_user_context, this, &e, &new_bit, &phase); + + euf::enode* new_enode = ctx.get_enode(e); + + if (original_enode == new_enode) + return false; + + var = new_enode->bool_var(); + return true; + } + + bool solver::get_case_split(sat::bool_var& var, lbool &phase){ + if (!m_next_split_expr) + return false; + + euf::enode* n = ctx.get_enode(m_next_split_expr); + var = n->bool_var(); + phase = m_next_split_phase; + m_next_split_expr = nullptr; + return true; + } void solver::asserted(sat::literal lit) { if (!m_fixed_eh) diff --git a/src/sat/smt/user_solver.h b/src/sat/smt/user_solver.h index 13948db81..951b97fb6 100644 --- a/src/sat/smt/user_solver.h +++ b/src/sat/smt/user_solver.h @@ -56,24 +56,28 @@ namespace user_solver { void reset() { memset(this, 0, sizeof(*this)); } }; - void* m_user_context; - user_propagator::push_eh_t m_push_eh; - user_propagator::pop_eh_t m_pop_eh; - user_propagator::fresh_eh_t m_fresh_eh; - user_propagator::final_eh_t m_final_eh; - user_propagator::fixed_eh_t m_fixed_eh; - user_propagator::eq_eh_t m_eq_eh; - user_propagator::eq_eh_t m_diseq_eh; - user_propagator::created_eh_t m_created_eh; + void* m_user_context; + user_propagator::push_eh_t m_push_eh = nullptr; + user_propagator::pop_eh_t m_pop_eh = nullptr; + user_propagator::fresh_eh_t m_fresh_eh = nullptr; + user_propagator::final_eh_t m_final_eh = nullptr; + user_propagator::fixed_eh_t m_fixed_eh = nullptr; + user_propagator::eq_eh_t m_eq_eh = nullptr; + user_propagator::eq_eh_t m_diseq_eh = nullptr; + user_propagator::created_eh_t m_created_eh = nullptr; + user_propagator::decide_eh_t m_decide_eh = nullptr; user_propagator::context_obj* m_api_context = nullptr; - unsigned m_qhead = 0; - vector m_prop; - unsigned_vector m_prop_lim; - vector m_id2justification; - sat::literal_vector m_lits; - euf::enode_pair_vector m_eqs; - unsigned_vector m_fixed_ids; - stats m_stats; + unsigned m_qhead = 0; + vector m_prop; + unsigned_vector m_prop_lim; + vector m_id2justification; + sat::literal_vector m_lits; + euf::enode_pair_vector m_eqs; + unsigned_vector m_fixed_ids; + stats m_stats; + expr* m_next_split_expr = nullptr; + unsigned m_next_split_idx; + lbool m_next_split_phase; struct justification { unsigned m_propagation_index { 0 }; @@ -94,7 +98,7 @@ namespace user_solver { void propagate_consequence(prop_info const& prop); void propagate_new_fixed(prop_info const& prop); - void validate_propagation(); + void validate_propagation(); bool visit(expr* e) override; bool visited(expr* e) override; @@ -126,14 +130,19 @@ namespace user_solver { void register_eq(user_propagator::eq_eh_t& eq_eh) { m_eq_eh = eq_eh; } void register_diseq(user_propagator::eq_eh_t& diseq_eh) { m_diseq_eh = diseq_eh; } void register_created(user_propagator::created_eh_t& created_eh) { m_created_eh = created_eh; } + void register_decide(user_propagator::decide_eh_t& decide_eh) { m_decide_eh = decide_eh; } bool has_fixed() const { return (bool)m_fixed_eh; } void propagate_cb(unsigned num_fixed, expr* const* fixed_ids, unsigned num_eqs, expr* const* lhs, expr* const* rhs, expr* conseq) override; void register_cb(expr* e) override; + void next_split_cb(expr* e, unsigned idx, lbool phase) override; void new_fixed_eh(euf::theory_var v, expr* value, unsigned num_lits, sat::literal const* jlits); + bool decide(sat::bool_var& var, lbool& phase) override; + bool get_case_split(sat::bool_var& var, lbool &phase) override; + void asserted(sat::literal lit) override; sat::check_result check() override; void push_core() override; diff --git a/src/smt/smt_context.cpp b/src/smt/smt_context.cpp index 1b6088c1f..6652804ea 100644 --- a/src/smt/smt_context.cpp +++ b/src/smt/smt_context.cpp @@ -1848,24 +1848,30 @@ namespace smt { } } bool_var var; - lbool phase = l_undef; - m_case_split_queue->next_case_split(var, phase); + bool is_pos; + bool used_queue = false; + + if (!has_split_candidate(var, is_pos)) { + lbool phase = l_undef; + m_case_split_queue->next_case_split(var, phase); + used_queue = true; + if (var == null_bool_var) + return false; - if (var == null_bool_var) { - return false; + TRACE_CODE({ + static unsigned counter = 0; + counter++; + if (counter % 100 == 0) { + TRACE("activity_profile", + for (unsigned i=0; iunassign_var_eh(original_choice); + if (used_queue) + m_case_split_queue->unassign_var_eh(original_choice); l = literal(var, false); } @@ -2905,8 +2911,14 @@ namespace smt { return m_user_propagator && m_user_propagator->has_fixed() && n->get_th_var(m_user_propagator->get_family_id()) != null_theory_var; } + bool context::has_split_candidate(bool_var& var, bool& is_pos) { + if (!m_user_propagator) + return false; + return m_user_propagator->get_case_split(var, is_pos); + } + bool context::decide_user_interference(bool_var& var, bool& is_pos) { - if (!m_user_propagator || !m_user_propagator->has_decide()) + if (!m_user_propagator) return false; bool_var old = var; m_user_propagator->decide(var, is_pos); diff --git a/src/smt/smt_context.h b/src/smt/smt_context.h index e2fc3a35f..8d98fc60e 100644 --- a/src/smt/smt_context.h +++ b/src/smt/smt_context.h @@ -1754,6 +1754,8 @@ namespace smt { bool watches_fixed(enode* n) const; + bool has_split_candidate(bool_var& var, bool& is_pos); + bool decide_user_interference(bool_var& var, bool& is_pos); void assign_fixed(enode* n, expr* val, unsigned sz, literal const* explain); diff --git a/src/smt/theory_user_propagator.cpp b/src/smt/theory_user_propagator.cpp index 04bb4b248..780023fab 100644 --- a/src/smt/theory_user_propagator.cpp +++ b/src/smt/theory_user_propagator.cpp @@ -102,6 +102,17 @@ void theory_user_propagator::register_cb(expr* e) { add_expr(e, true); } +void theory_user_propagator::next_split_cb(expr* e, unsigned idx, lbool phase) { + if (e == nullptr) { // clear + m_next_split_expr = nullptr; + return; + } + ensure_enode(e); + m_next_split_expr = e; + m_next_split_idx = idx; + m_next_split_phase = phase; +} + theory * theory_user_propagator::mk_fresh(context * new_ctx) { auto* th = alloc(theory_user_propagator, *new_ctx); void* ctx; @@ -156,8 +167,24 @@ void theory_user_propagator::new_fixed_eh(theory_var v, expr* value, unsigned nu } } -void theory_user_propagator::decide(bool_var& var, bool& is_pos) { +bool_var theory_user_propagator::enode_to_bool(enode* n, unsigned bit) { + if (n->is_bool()) { + // expression is a boolean + bool_var new_var = ctx.enode2bool_var(n); + if (ctx.get_assignment(new_var) == l_undef) + return new_var; + return null_bool_var; + } + // expression is a bit-vector + bv_util bv(m); + auto th_bv = (theory_bv*)ctx.get_theory(bv.get_fid()); + return th_bv->get_first_unassigned(bit, n); +} +void theory_user_propagator::decide(bool_var& var, bool& is_pos) { + if (!m_decide_eh) + return; + const bool_var_data& d = ctx.get_bdata(var); if (!d.is_enode() && !d.is_theory_atom()) @@ -216,25 +243,28 @@ void theory_user_propagator::decide(bool_var& var, bool& is_pos) { return; } - if (new_enode->is_bool()) { - // expression was set to a boolean - bool_var new_var = ctx.enode2bool_var(new_enode); - if (ctx.get_assignment(new_var) == l_undef) { - var = new_var; - } - } - else { - // expression was set to a bit-vector - auto th_bv = (theory_bv*)ctx.get_theory(bv.get_fid()); - bool_var new_var = th_bv->get_first_unassigned(new_bit, new_enode); - if (new_var != null_bool_var) - var = new_var; - } + // get unassigned variable from enode + var = enode_to_bool(new_enode, new_bit); // in case the callback did not decide on a truth value -> let Z3 decide is_pos = ctx.guess(var, phase); } +bool theory_user_propagator::get_case_split(bool_var& var, bool& is_pos){ + if (!m_next_split_expr) + return false; + enode* n = ctx.get_enode(m_next_split_expr); + + var = enode_to_bool(n, m_next_split_idx); + + if (var == null_bool_var) + return false; + + is_pos = ctx.guess(var, m_next_split_phase); + m_next_split_expr = nullptr; + return true; +} + void theory_user_propagator::push_scope_eh() { ++m_num_scopes; } diff --git a/src/smt/theory_user_propagator.h b/src/smt/theory_user_propagator.h index bf82883e4..ba9900848 100644 --- a/src/smt/theory_user_propagator.h +++ b/src/smt/theory_user_propagator.h @@ -83,6 +83,9 @@ namespace smt { expr_ref_vector m_to_add; unsigned_vector m_to_add_lim; unsigned m_to_add_qhead = 0; + expr* m_next_split_expr = nullptr; + unsigned m_next_split_idx; + lbool m_next_split_phase; expr* var2expr(theory_var v) { return m_var2expr.get(v); } theory_var expr2var(expr* e) { check_defined(e); return m_expr2var[e->get_id()]; } @@ -95,6 +98,8 @@ namespace smt { void propagate_consequence(prop_info const& prop); void propagate_new_fixed(prop_info const& prop); + + bool_var enode_to_bool(enode* n, unsigned bit); public: theory_user_propagator(context& ctx); @@ -125,13 +130,14 @@ namespace smt { void register_decide(user_propagator::decide_eh_t& decide_eh) { m_decide_eh = decide_eh; } bool has_fixed() const { return (bool)m_fixed_eh; } - bool has_decide() const { return (bool)m_decide_eh; } void propagate_cb(unsigned num_fixed, expr* const* fixed_ids, unsigned num_eqs, expr* const* lhs, expr* const* rhs, expr* conseq) override; void register_cb(expr* e) override; + void next_split_cb(expr* e, unsigned idx, lbool phase) override; void new_fixed_eh(theory_var v, expr* value, unsigned num_lits, literal const* jlits); void decide(bool_var& var, bool& is_pos); + bool get_case_split(bool_var& var, bool& is_pos); theory * mk_fresh(context * new_ctx) override; bool internalize_atom(app* atom, bool gate_ctx) override; @@ -154,5 +160,5 @@ namespace smt { bool can_propagate() override; void propagate() override; void display(std::ostream& out) const override {} - }; +}; }; diff --git a/src/tactic/user_propagator_base.h b/src/tactic/user_propagator_base.h index 3f4af0329..46b5eda8a 100644 --- a/src/tactic/user_propagator_base.h +++ b/src/tactic/user_propagator_base.h @@ -11,6 +11,7 @@ namespace user_propagator { virtual ~callback() = default; virtual void propagate_cb(unsigned num_fixed, expr* const* fixed_ids, unsigned num_eqs, expr* const* eq_lhs, expr* const* eq_rhs, expr* conseq) = 0; virtual void register_cb(expr* e) = 0; + virtual void next_split_cb(expr* e, unsigned idx, lbool phase) = 0; }; class context_obj {