diff --git a/src/api/api_solver.cpp b/src/api/api_solver.cpp index c9eda8712..a0803516f 100644 --- a/src/api/api_solver.cpp +++ b/src/api/api_solver.cpp @@ -975,6 +975,14 @@ extern "C" { Z3_CATCH; } + void Z3_API Z3_solver_propagate_decide(Z3_context c, Z3_solver s, Z3_decide_eh decide_eh) { + Z3_TRY; + RESET_ERROR_CODE(); + user_propagator::decide_eh_t c = (void(*)(void*, user_propagator::callback*, expr*&, unsigned&, lbool&))decide_eh; + to_solver_ref(s)->user_propagate_register_decide(c); + Z3_CATCH; + } + Z3_func_decl Z3_API Z3_solver_propagate_declare(Z3_context c, Z3_symbol name, unsigned n, Z3_sort* domain, Z3_sort range) { Z3_TRY; LOG_Z3_solver_propagate_declare(c, name, n, domain, range); diff --git a/src/api/c++/z3++.h b/src/api/c++/z3++.h index 101fa04a3..57a87415d 100644 --- a/src/api/c++/z3++.h +++ b/src/api/c++/z3++.h @@ -3943,11 +3943,13 @@ namespace z3 { typedef std::function final_eh_t; typedef std::function eq_eh_t; typedef std::function created_eh_t; + typedef std::function decide_eh_t; final_eh_t m_final_eh; eq_eh_t m_eq_eh; fixed_eh_t m_fixed_eh; created_eh_t m_created_eh; + decide_eh_t m_decide_eh; solver* s; context* c; std::vector subcontexts; @@ -4009,8 +4011,15 @@ namespace z3 { expr e(p->ctx(), _e); p->m_created_eh(e); } - - + + static void decide_eh(void* _p, Z3_solver_callback cb, Z3_ast& _val, unsigned& bit, Z3_lbool& is_pos) { + user_propagator_base* p = static_cast(_p); + scoped_cb _cb(p, cb); + expr val(p->ctx(), _val); + p->m_decide_eh(val, bit, is_pos); + _val = val; + } + public: user_propagator_base(context& c) : s(nullptr), c(&c) {} @@ -4119,6 +4128,22 @@ namespace z3 { Z3_solver_propagate_created(ctx(), *s, created_eh); } } + + void register_decide(decide_eh_t& c) { + m_decide_eh = c; + if (s) { + Z3_solver_propagate_decide(ctx(), *s, decide_eh); + } + } + + void register_decide() { + m_decide_eh = [this](expr& val, unsigned& bit, Z3_lbool& is_pos) { + decide(val, bit, is_pos); + }; + if (s) { + Z3_solver_propagate_decide(ctx(), *s, decide_eh); + } + } virtual void fixed(expr const& /*id*/, expr const& /*e*/) { } @@ -4127,6 +4152,8 @@ namespace z3 { virtual void final() { } virtual void created(expr const& /*e*/) {} + + virtual void decide(expr& /*val*/, unsigned& /*bit*/, Z3_lbool& /*is_pos*/) {} /** \brief tracks \c e by a unique identifier that is returned by the call. diff --git a/src/api/z3_api.h b/src/api/z3_api.h index 1388d0ab1..1eb2164d5 100644 --- a/src/api/z3_api.h +++ b/src/api/z3_api.h @@ -1444,6 +1444,7 @@ Z3_DECLARE_CLOSURE(Z3_fixed_eh, void, (void* ctx, Z3_solver_callback cb, Z3_as Z3_DECLARE_CLOSURE(Z3_eq_eh, void, (void* ctx, Z3_solver_callback cb, Z3_ast s, Z3_ast t)); Z3_DECLARE_CLOSURE(Z3_final_eh, void, (void* ctx, Z3_solver_callback cb)); Z3_DECLARE_CLOSURE(Z3_created_eh, void, (void* ctx, Z3_solver_callback cb, Z3_ast t)); +Z3_DECLARE_CLOSURE(Z3_decide_eh, void, (void* ctx, Z3_solver_callback cb, Z3_ast&, unsigned&, Z3_lbool&)); /** @@ -6758,6 +6759,14 @@ extern "C" { * The registered function appears at the top level and is created using \ref Z3_propagate_solver_declare. */ void Z3_API Z3_solver_propagate_created(Z3_context c, Z3_solver s, Z3_created_eh created_eh); + + /** + * \brief register a callback when a the solver decides to split on a registered expression + * The callback may set passed expression to another registered expression which will be selected instead. + * In case the expression is a bitvector the bit to split on is determined by the bit argument and the + * truth-value to try first is given by is_pos + */ + void Z3_API Z3_solver_propagate_decide(Z3_context c, Z3_solver s, Z3_decide_eh decide_eh); /** Create uninterpreted function declaration for the user propagator. diff --git a/src/smt/smt_context.cpp b/src/smt/smt_context.cpp index a9b44ab3c..3e444bec7 100644 --- a/src/smt/smt_context.cpp +++ b/src/smt/smt_context.cpp @@ -1766,6 +1766,70 @@ namespace smt { m_bvar_inc *= INV_ACTIVITY_LIMIT; } + /** + \brief Returns a truth value for the given variable + */ + bool context::guess(bool_var var, lbool phase) { + if (is_quantifier(m_bool_var2expr[var])) { + // Overriding any decision on how to assign the quantifier. + // assigning a quantifier to false is equivalent to make it irrelevant. + phase = l_false; + } + literal l(var, false); + + if (phase != l_undef) + return phase == l_true; + + bool_var_data & d = m_bdata[var]; + if (d.try_true_first()) + return true; + switch (m_fparams.m_phase_selection) { + case PS_THEORY: + if (m_phase_cache_on && d.m_phase_available) { + return m_bdata[var].m_phase; + } + if (!m_phase_cache_on && d.is_theory_atom()) { + theory * th = m_theories.get_plugin(d.get_theory()); + lbool th_phase = th->get_phase(var); + if (th_phase != l_undef) { + return th_phase == l_true; + } + } + if (track_occs()) { + if (m_lit_occs[l.index()] == 0) { + return false; + } + if (m_lit_occs[(~l).index()] == 0) { + return true; + } + } + return m_phase_default; + case PS_CACHING: + case PS_CACHING_CONSERVATIVE: + case PS_CACHING_CONSERVATIVE2: + if (m_phase_cache_on && d.m_phase_available) { + TRACE("phase_selection", tout << "using cached value, is_pos: " << m_bdata[var].m_phase << ", var: p" << var << "\n";); + return m_bdata[var].m_phase; + } + else { + TRACE("phase_selection", tout << "setting to false\n";); + return m_phase_default; + } + case PS_ALWAYS_FALSE: + return false; + case PS_ALWAYS_TRUE: + return true; + case PS_RANDOM: + return m_random() % 2 == 0; + case PS_OCCURRENCE: { + return m_lit_occs[l.index()] > m_lit_occs[(~l).index()]; + } + default: + UNREACHABLE(); + return false; + } + } + /** \brief Execute next case split, return false if there are no more case splits to be performed. @@ -1807,81 +1871,15 @@ namespace smt { TRACE("decide", tout << "splitting, lvl: " << m_scope_lvl << "\n";); TRACE("decide_detail", tout << mk_pp(bool_var2expr(var), m) << "\n";); - - bool is_pos; - - if (is_quantifier(m_bool_var2expr[var])) { - // Overriding any decision on how to assign the quantifier. - // assigning a quantifier to false is equivalent to make it irrelevant. - phase = l_false; - } + + bool is_pos = guess(var, phase); literal l(var, false); - if (phase != l_undef) { - is_pos = phase == l_true; - } - else { - bool_var_data & d = m_bdata[var]; - if (d.try_true_first()) { - is_pos = true; - } - else { - switch (m_fparams.m_phase_selection) { - case PS_THEORY: - if (m_phase_cache_on && d.m_phase_available) { - is_pos = m_bdata[var].m_phase; - break; - } - if (!m_phase_cache_on && d.is_theory_atom()) { - theory * th = m_theories.get_plugin(d.get_theory()); - lbool th_phase = th->get_phase(var); - if (th_phase != l_undef) { - is_pos = th_phase == l_true; - break; - } - } - if (track_occs()) { - if (m_lit_occs[l.index()] == 0) { - is_pos = false; - break; - } - if (m_lit_occs[(~l).index()] == 0) { - is_pos = true; - break; - } - } - is_pos = m_phase_default; - break; - case PS_CACHING: - case PS_CACHING_CONSERVATIVE: - case PS_CACHING_CONSERVATIVE2: - if (m_phase_cache_on && d.m_phase_available) { - TRACE("phase_selection", tout << "using cached value, is_pos: " << m_bdata[var].m_phase << ", var: p" << var << "\n";); - is_pos = m_bdata[var].m_phase; - } - else { - TRACE("phase_selection", tout << "setting to false\n";); - is_pos = m_phase_default; - } - break; - case PS_ALWAYS_FALSE: - is_pos = false; - break; - case PS_ALWAYS_TRUE: - is_pos = true; - break; - case PS_RANDOM: - is_pos = (m_random() % 2 == 0); - break; - case PS_OCCURRENCE: { - is_pos = m_lit_occs[l.index()] > m_lit_occs[(~l).index()]; - break; - } - default: - is_pos = false; - UNREACHABLE(); - } - } + bool_var original_choice = var; + + if (decide_user_interference(var, is_pos)) { + m_case_split_queue->unassign_var_eh(original_choice); + l = literal(var, false); } if (!is_pos) l.neg(); @@ -1889,7 +1887,7 @@ namespace smt { assign(l, b_justification::mk_axiom(), true); return true; } - + /** \brief Update counter that is used to enable/disable phase caching. */ @@ -2906,6 +2904,14 @@ namespace smt { return m_user_propagator && m_user_propagator->has_fixed() && n->get_th_var(m_user_propagator->get_family_id()) != null_theory_var; } + bool context::decide_user_interference(bool_var& var, bool& is_pos) { + if (!m_user_propagator || !m_user_propagator->has_decide()) + return false; + bool_var old = var; + m_user_propagator->decide(var, is_pos); + return old != var; + } + void context::assign_fixed(enode* n, expr* val, unsigned sz, literal const* explain) { theory_var v = n->get_th_var(m_user_propagator->get_family_id()); m_user_propagator->new_fixed_eh(v, val, sz, explain); @@ -3042,7 +3048,8 @@ namespace smt { } } } - } else { + } + else { literal_vector new_case_split; for (unsigned i = 0; i < num_lits; ++i) { literal l = lits[i]; diff --git a/src/smt/smt_context.h b/src/smt/smt_context.h index 637c2171b..696a5cc39 100644 --- a/src/smt/smt_context.h +++ b/src/smt/smt_context.h @@ -1134,6 +1134,8 @@ namespace smt { enode * get_enode_eq_to(func_decl * f, unsigned num_args, enode * const * args); + bool guess(bool_var var, lbool phase); + protected: bool decide(); @@ -1738,8 +1740,16 @@ namespace smt { m_user_propagator->register_created(r); } + void user_propagate_register_decide(user_propagator::decide_eh_t& r) { + if (!m_user_propagator) + throw default_exception("user propagator must be initialized"); + m_user_propagator->register_decide(r); + } + bool watches_fixed(enode* n) const; + bool decide_user_interference(bool_var& var, bool& is_pos); + void assign_fixed(enode* n, expr* val, unsigned sz, literal const* explain); void assign_fixed(enode* n, expr* val, literal_vector const& explain) { diff --git a/src/smt/smt_kernel.cpp b/src/smt/smt_kernel.cpp index 2d082170c..8f442596c 100644 --- a/src/smt/smt_kernel.cpp +++ b/src/smt/smt_kernel.cpp @@ -284,4 +284,8 @@ namespace smt { m_imp->m_kernel.user_propagate_register_created(r); } + void kernel::user_propagate_register_decide(user_propagator::decide_eh_t& r) { + m_imp->m_kernel.user_propagate_register_decide(r); + } + }; diff --git a/src/smt/smt_kernel.h b/src/smt/smt_kernel.h index 068bd1b52..4fa840f5e 100644 --- a/src/smt/smt_kernel.h +++ b/src/smt/smt_kernel.h @@ -311,6 +311,8 @@ namespace smt { void user_propagate_register_created(user_propagator::created_eh_t& r); + void user_propagate_register_decide(user_propagator::decide_eh_t& r); + /** \brief Return a reference to smt::context. This breaks abstractions. diff --git a/src/smt/smt_solver.cpp b/src/smt/smt_solver.cpp index 344cf9e6f..5064ed7ef 100644 --- a/src/smt/smt_solver.cpp +++ b/src/smt/smt_solver.cpp @@ -244,6 +244,10 @@ namespace { m_context.user_propagate_register_created(c); } + void user_propagate_register_decide(user_propagator::decide_eh_t& c) override { + m_context.user_propagate_register_decide(c); + } + struct scoped_minimize_core { smt_solver& s; expr_ref_vector m_assumptions; diff --git a/src/smt/tactic/smt_tactic_core.cpp b/src/smt/tactic/smt_tactic_core.cpp index 072e1ed24..9c5fc1c8e 100644 --- a/src/smt/tactic/smt_tactic_core.cpp +++ b/src/smt/tactic/smt_tactic_core.cpp @@ -322,6 +322,7 @@ public: user_propagator::eq_eh_t m_eq_eh; user_propagator::eq_eh_t m_diseq_eh; user_propagator::created_eh_t m_created_eh; + user_propagator::decide_eh_t m_decide_eh; void user_propagate_delay_init() { @@ -333,6 +334,7 @@ public: if (m_eq_eh) m_ctx->user_propagate_register_eq(m_eq_eh); if (m_diseq_eh) m_ctx->user_propagate_register_diseq(m_diseq_eh); if (m_created_eh) m_ctx->user_propagate_register_created(m_created_eh); + if (m_decide_eh) m_ctx->user_propagate_register_decide(m_decide_eh); for (expr* v : m_vars) m_ctx->user_propagate_register_expr(v); diff --git a/src/smt/theory_bv.cpp b/src/smt/theory_bv.cpp index 341daedec..682f4d6f9 100644 --- a/src/smt/theory_bv.cpp +++ b/src/smt/theory_bv.cpp @@ -531,7 +531,6 @@ namespace smt { return true; } - bool theory_bv::get_fixed_value(theory_var v, numeral & result) const { result.reset(); unsigned i = 0; @@ -1821,6 +1820,39 @@ namespace smt { st.update("bv dynamic eqs", m_stats.m_num_eq_dynamic); } + theory_bv::var_enode_pos theory_bv::get_bv_with_theory(bool_var v, theory_id id) const { + atom* a = get_bv2a(v); + svector vec; + if (!a->is_bit()) + return var_enode_pos(nullptr, UINT32_MAX); + bit_atom * b = static_cast(a); + var_pos_occ * curr = b->m_occs; + while (curr) { + enode* n = get_enode(curr->m_var); + if (n->get_th_var(id) != null_theory_var) + return var_enode_pos(n, curr->m_idx); + curr = curr->m_next; + } + return var_enode_pos(nullptr, UINT32_MAX); + } + + bool_var theory_bv::get_first_unassigned(unsigned start_bit, enode* n) const { + theory_var v = n->get_th_var(get_family_id()); + auto& bits = m_bits[v]; + unsigned sz = bits.size(); + + for (unsigned i = start_bit; i < sz; ++i) { + if (ctx.get_assignment(bits[i].var()) != l_undef) + return bits[i].var(); + } + for (unsigned i = 0; i < start_bit; ++i) { + if (ctx.get_assignment(bits[i].var()) != l_undef) + return bits[i].var(); + } + + return null_bool_var; + } + bool theory_bv::check_assignment(theory_var v) { if (!is_root(v)) return true; diff --git a/src/smt/theory_bv.h b/src/smt/theory_bv.h index ebca3fa83..d73b7a008 100644 --- a/src/smt/theory_bv.h +++ b/src/smt/theory_bv.h @@ -260,6 +260,9 @@ namespace smt { smt_params const& params() const; public: + + typedef std::pair var_enode_pos; + theory_bv(context& ctx); ~theory_bv() override; @@ -284,6 +287,9 @@ namespace smt { bool get_fixed_value(app* x, numeral & result) const; bool is_fixed_propagated(theory_var v, expr_ref& val, literal_vector& explain) override; + var_enode_pos get_bv_with_theory(bool_var v, theory_id id) const; + bool_var get_first_unassigned(unsigned start_bit, enode* n) const; + bool check_assignment(theory_var v); bool check_invariant(); bool check_zero_one_bits(theory_var v); diff --git a/src/smt/theory_user_propagator.cpp b/src/smt/theory_user_propagator.cpp index f783f22fb..bf8722701 100644 --- a/src/smt/theory_user_propagator.cpp +++ b/src/smt/theory_user_propagator.cpp @@ -17,6 +17,7 @@ Author: #include "ast/ast_pp.h" +#include "smt/theory_bv.h" #include "smt/theory_user_propagator.h" #include "smt/smt_context.h" @@ -116,6 +117,7 @@ theory * theory_user_propagator::mk_fresh(context * new_ctx) { if ((bool)m_eq_eh) th->register_eq(m_eq_eh); if ((bool)m_diseq_eh) th->register_diseq(m_diseq_eh); if ((bool)m_created_eh) th->register_created(m_created_eh); + if ((bool)m_decide_eh) th->register_decide(m_decide_eh); return th; } @@ -154,6 +156,73 @@ void theory_user_propagator::new_fixed_eh(theory_var v, expr* value, unsigned nu } } +void theory_user_propagator::decide(bool_var& var, bool& is_pos) { + + const bool_var_data& d = ctx.get_bdata(var); + + if (!d.is_theory_atom()) + return; + + theory* th = ctx.get_theory(d.get_theory()); + + bv_util bv(m); + enode* original_enode = nullptr; + unsigned original_bit = 0; + + if (d.is_enode() && th->get_family_id() == get_family_id()) { + // variable is just a registered expression + original_enode = ctx.bool_var2enode(var); + } + else if (th->get_family_id() == bv.get_fid()) { + // it might be a registered bit-vector + auto registered_bv = ((theory_bv*)th)->get_bv_with_theory(var, get_family_id()); + if (!registered_bv.first) + // there is no registered bv associated with the bit + return; + original_enode = registered_bv.first; + original_bit = registered_bv.second; + } + else + return; + + // call the registered callback + unsigned new_bit = original_bit; + lbool phase = is_pos ? l_true : l_false; + + expr* e = var2expr(original_enode->get_th_var(get_family_id())); + m_decide_eh(m_user_context, this, e, new_bit, phase); + enode* new_enode = ctx.get_enode(e); + + // check if the callback changed something + if (original_enode == new_enode && (new_enode->is_bool() || original_bit == new_bit)) { + if (phase != l_undef) + // it only affected the truth value + is_pos = phase == l_true; + return; + } + + bool_var old_var = var; + if (new_enode->is_bool()) { + // expression was set to a boolean + bool_var new_var = ctx.enode2bool_var(new_enode); + if (ctx.get_assignment(new_var) == l_undef) { + var = new_var; + } + } + else { + // expression was set to a bit-vector + auto th_bv = (theory_bv*)ctx.get_theory(bv.get_fid()); + bool_var new_var = th_bv->get_first_unassigned(new_bit, new_enode); + + if (new_var != null_bool_var) { + var = new_var; + } + } + + // in case the callback did not decide on a truth value -> let Z3 decide + is_pos = ctx.guess(var, phase); +} + void theory_user_propagator::push_scope_eh() { ++m_num_scopes; } diff --git a/src/smt/theory_user_propagator.h b/src/smt/theory_user_propagator.h index 9b271e9c3..bf82883e4 100644 --- a/src/smt/theory_user_propagator.h +++ b/src/smt/theory_user_propagator.h @@ -56,7 +56,7 @@ namespace smt { void reset() { memset(this, 0, sizeof(*this)); } }; - void* m_user_context = nullptr; + void* m_user_context = nullptr; user_propagator::push_eh_t m_push_eh; user_propagator::pop_eh_t m_pop_eh; user_propagator::fresh_eh_t m_fresh_eh; @@ -65,6 +65,7 @@ namespace smt { user_propagator::eq_eh_t m_eq_eh; user_propagator::eq_eh_t m_diseq_eh; user_propagator::created_eh_t m_created_eh; + user_propagator::decide_eh_t m_decide_eh; user_propagator::context_obj* m_api_context = nullptr; unsigned m_qhead = 0; @@ -121,13 +122,16 @@ namespace smt { void register_eq(user_propagator::eq_eh_t& eq_eh) { m_eq_eh = eq_eh; } void register_diseq(user_propagator::eq_eh_t& diseq_eh) { m_diseq_eh = diseq_eh; } void register_created(user_propagator::created_eh_t& created_eh) { m_created_eh = created_eh; } + void register_decide(user_propagator::decide_eh_t& decide_eh) { m_decide_eh = decide_eh; } bool has_fixed() const { return (bool)m_fixed_eh; } + bool has_decide() const { return (bool)m_decide_eh; } void propagate_cb(unsigned num_fixed, expr* const* fixed_ids, unsigned num_eqs, expr* const* lhs, expr* const* rhs, expr* conseq) override; void register_cb(expr* e) override; void new_fixed_eh(theory_var v, expr* value, unsigned num_lits, literal const* jlits); + void decide(bool_var& var, bool& is_pos); theory * mk_fresh(context * new_ctx) override; bool internalize_atom(app* atom, bool gate_ctx) override; diff --git a/src/solver/tactic2solver.cpp b/src/solver/tactic2solver.cpp index e8a30a009..fe89d6533 100644 --- a/src/solver/tactic2solver.cpp +++ b/src/solver/tactic2solver.cpp @@ -116,6 +116,10 @@ public: m_tactic->user_propagate_register_created(created_eh); } + void user_propagate_register_decide(user_propagator::decide_eh_t& created_eh) override { + m_tactic->user_propagate_register_decide(created_eh); + } + void user_propagate_clear() override { if (m_tactic) m_tactic->user_propagate_clear(); diff --git a/src/tactic/tactical.cpp b/src/tactic/tactical.cpp index 9167650ad..67a0e3062 100644 --- a/src/tactic/tactical.cpp +++ b/src/tactic/tactical.cpp @@ -204,6 +204,10 @@ public: m_t2->user_propagate_register_created(created_eh); } + void user_propagate_register_decide(user_propagator::decide_eh_t& decide_eh) override { + m_t2->user_propagate_register_decide(decide_eh); + } + }; tactic * and_then(tactic * t1, tactic * t2) { diff --git a/src/tactic/user_propagator_base.h b/src/tactic/user_propagator_base.h index 02a027762..c67a073cd 100644 --- a/src/tactic/user_propagator_base.h +++ b/src/tactic/user_propagator_base.h @@ -2,6 +2,7 @@ #pragma once #include "ast/ast.h" +#include "util/lbool.h" namespace user_propagator { @@ -17,14 +18,14 @@ namespace user_propagator { virtual ~context_obj() = default; }; - typedef std::function final_eh_t; - typedef std::function fixed_eh_t; - typedef std::function eq_eh_t; - typedef std::function fresh_eh_t; - typedef std::function push_eh_t; - typedef std::function pop_eh_t; - typedef std::function created_eh_t; - + typedef std::function final_eh_t; + typedef std::function fixed_eh_t; + typedef std::function eq_eh_t; + typedef std::function fresh_eh_t; + typedef std::function push_eh_t; + typedef std::function pop_eh_t; + typedef std::function created_eh_t; + typedef std::function decide_eh_t; class plugin : public decl_plugin { public: @@ -85,6 +86,10 @@ namespace user_propagator { throw default_exception("user-propagators are only supported on the SMT solver"); } + virtual void user_propagate_register_decide(decide_eh_t& r) { + throw default_exception("user-propagators are only supported on the SMT solver"); + } + virtual void user_propagate_clear() { }