From b33f4445453d6af6cd94449086059b8fa538c699 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Wed, 6 Aug 2025 21:11:38 -0700 Subject: [PATCH] add an option to register callback on quantifier instantiation Suppose a user propagator encodes axioms using quantifiers and uses E-matching for instantiation. If it wants to implement a custom priority scheme or drop some instances based on internal checks it can register a callback with quantifier instantiation --- scripts/update_api.py | 1 + src/api/api_solver.cpp | 8 ++++++++ src/api/python/z3/z3.py | 21 ++++++++++++++++++++- src/api/z3_api.h | 12 ++++++++++++ src/sat/sat_solver/sat_smt_solver.cpp | 4 ++++ src/sat/smt/euf_solver.h | 4 ++++ src/smt/qi_queue.cpp | 5 +++++ src/smt/qi_queue.h | 5 +++++ src/smt/smt_context.h | 8 ++++++++ src/smt/smt_kernel.cpp | 4 ++++ src/smt/smt_kernel.h | 2 ++ src/smt/smt_quantifier.cpp | 8 ++++++++ src/smt/smt_quantifier.h | 3 +++ src/smt/smt_solver.cpp | 4 ++++ src/smt/tactic/smt_tactic_core.cpp | 4 ++++ src/smt/theory_user_propagator.cpp | 9 +++++++++ src/smt/theory_user_propagator.h | 1 + src/solver/combined_solver.cpp | 4 ++++ src/solver/simplifier_solver.cpp | 5 ++++- src/solver/slice_solver.cpp | 3 ++- src/solver/tactic2solver.cpp | 4 ++++ src/tactic/dependent_expr_state_tactic.h | 1 + src/tactic/tactical.cpp | 4 ++++ src/tactic/user_propagator_base.h | 5 +++++ 24 files changed, 126 insertions(+), 3 deletions(-) diff --git a/scripts/update_api.py b/scripts/update_api.py index 153052deb..ad6bb3658 100755 --- a/scripts/update_api.py +++ b/scripts/update_api.py @@ -1944,6 +1944,7 @@ Z3_eq_eh = ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_ Z3_created_eh = ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p) Z3_decide_eh = ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_uint, ctypes.c_int) +Z3_on_binding_eh = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p) _lib.Z3_solver_register_on_clause.restype = None _lib.Z3_solver_propagate_init.restype = None diff --git a/src/api/api_solver.cpp b/src/api/api_solver.cpp index 29c012b86..dad3bc126 100644 --- a/src/api/api_solver.cpp +++ b/src/api/api_solver.cpp @@ -1160,6 +1160,14 @@ extern "C" { Z3_CATCH; } + void Z3_API Z3_solver_propagate_on_binding(Z3_context c, Z3_solver s, Z3_on_binding_eh binding_eh) { + Z3_TRY; + RESET_ERROR_CODE(); + user_propagator::binding_eh_t c = (bool(*)(void*, user_propagator::callback*, expr*, expr*))binding_eh; + to_solver_ref(s)->user_propagate_register_on_binding(c); + Z3_CATCH; + } + bool Z3_API Z3_solver_next_split(Z3_context c, Z3_solver_callback cb, Z3_ast t, unsigned idx, Z3_lbool phase) { Z3_TRY; LOG_Z3_solver_next_split(c, cb, t, idx, phase); diff --git a/src/api/python/z3/z3.py b/src/api/python/z3/z3.py index f9bb51699..16cc45de2 100644 --- a/src/api/python/z3/z3.py +++ b/src/api/python/z3/z3.py @@ -11814,6 +11814,16 @@ def user_prop_decide(ctx, cb, t_ref, idx, phase): t = _to_expr_ref(to_Ast(t_ref), prop.ctx()) prop.decide(t, idx, phase) prop.cb = old_cb + +def user_prop_binding(ctx, cb, q_ref, inst_ref): + prop = _prop_closures.get(ctx) + old_cb = prop.cb + prop.cb = cb + q = _to_expr_ref(to_Ast(q_ref), prop.ctx()) + inst = _to_expr_ref(to_Ast(inst_ref), prop.ctx()) + r = prop.binding(q, inst) + prop.cb = old_cb + return r _user_prop_push = Z3_push_eh(user_prop_push) @@ -11825,6 +11835,7 @@ _user_prop_final = Z3_final_eh(user_prop_final) _user_prop_eq = Z3_eq_eh(user_prop_eq) _user_prop_diseq = Z3_eq_eh(user_prop_diseq) _user_prop_decide = Z3_decide_eh(user_prop_decide) +_user_prop_binding = Z3_on_binding_eh(user_prop_binding) def PropagateFunction(name, *sig): @@ -11873,6 +11884,7 @@ class UserPropagateBase: self.diseq = None self.decide = None self.created = None + self.binding = None if ctx: self.fresh_ctx = ctx if s: @@ -11936,7 +11948,14 @@ class UserPropagateBase: assert not self._ctx if self.solver: Z3_solver_propagate_decide(self.ctx_ref(), self.solver.solver, _user_prop_decide) - self.decide = decide + self.decide = decide + + def add_on_binding(self, binding): + assert not self.binding + assert not self._ctx + if self.solver: + Z3_solver_propagate_on_binding(self.ctx_ref(), self.solver.solver, _user_prop_binding) + self.binding = binding def push(self): raise Z3Exception("push needs to be overwritten") diff --git a/src/api/z3_api.h b/src/api/z3_api.h index 0179392e0..9de58e057 100644 --- a/src/api/z3_api.h +++ b/src/api/z3_api.h @@ -1440,6 +1440,7 @@ Z3_DECLARE_CLOSURE(Z3_eq_eh, void, (void* ctx, Z3_solver_callback cb, Z3_as Z3_DECLARE_CLOSURE(Z3_final_eh, void, (void* ctx, Z3_solver_callback cb)); Z3_DECLARE_CLOSURE(Z3_created_eh, void, (void* ctx, Z3_solver_callback cb, Z3_ast t)); Z3_DECLARE_CLOSURE(Z3_decide_eh, void, (void* ctx, Z3_solver_callback cb, Z3_ast t, unsigned idx, bool phase)); +Z3_DECLARE_CLOSURE(Z3_on_binding_eh, bool, (void* ctx, Z3_solver_callback cb, Z3_ast q, Z3_ast inst)); Z3_DECLARE_CLOSURE(Z3_on_clause_eh, void, (void* ctx, Z3_ast proof_hint, unsigned n, unsigned const* deps, Z3_ast_vector literals)); @@ -7225,6 +7226,17 @@ extern "C" { */ void Z3_API Z3_solver_propagate_decide(Z3_context c, Z3_solver s, Z3_decide_eh decide_eh); + + /** + \brief register a callback when the solver instantiates a quantifier. + If the callback returns false, the actual instantiation of the quantifier is blocked. + This allows the user propagator selectively prioritize instantiations without relying on default + or configured weights. + + def_API('Z3_solver_propagate_on_binding', VOID, (_in(CONTEXT), _in(SOLVER), _fnptr(Z3_on_binding_eh))) + */ + + void Z3_API Z3_solver_propagate_on_binding(Z3_context c, Z3_solver s, Z3_on_binding_eh on_binding_eh); /** Sets the next (registered) expression to split on. The function returns false and ignores the given expression in case the expression is already assigned internally diff --git a/src/sat/sat_solver/sat_smt_solver.cpp b/src/sat/sat_solver/sat_smt_solver.cpp index 6e036c8e3..8548749be 100644 --- a/src/sat/sat_solver/sat_smt_solver.cpp +++ b/src/sat/sat_solver/sat_smt_solver.cpp @@ -565,6 +565,10 @@ public: void user_propagate_register_diseq(user_propagator::eq_eh_t& diseq_eh) override { ensure_euf()->user_propagate_register_diseq(diseq_eh); } + + void user_propagate_register_on_binding(user_propagator::binding_eh_t& binding_eh) override { + ensure_euf()->user_propagate_register_on_binding(binding_eh); + } void user_propagate_register_expr(expr* e) override { ensure_euf()->user_propagate_register_expr(e); diff --git a/src/sat/smt/euf_solver.h b/src/sat/smt/euf_solver.h index 12ace1a24..69017679c 100644 --- a/src/sat/smt/euf_solver.h +++ b/src/sat/smt/euf_solver.h @@ -554,6 +554,10 @@ namespace euf { check_for_user_propagator(); m_user_propagator->register_decide(ceh); } + void user_propagate_register_on_binding(user_propagator::binding_eh_t& on_binding_eh) { + check_for_user_propagator(); + NOT_IMPLEMENTED_YET(); + } void user_propagate_register_expr(expr* e) { check_for_user_propagator(); m_user_propagator->add_expr(e); diff --git a/src/smt/qi_queue.cpp b/src/smt/qi_queue.cpp index b8835e8fe..10a8ab7c6 100644 --- a/src/smt/qi_queue.cpp +++ b/src/smt/qi_queue.cpp @@ -263,6 +263,11 @@ namespace smt { if (stat->get_num_instances() % m_params.m_qi_profile_freq == 0) { m_qm.display_stats(verbose_stream(), q); } + + if (m_on_binding && !m_on_binding(q, instance)) { + verbose_stream() << "qi_queue: on_binding returned false, skipping instance.\n"; + return; + } expr_ref lemma(m); if (m.is_or(s_instance)) { ptr_vector args; diff --git a/src/smt/qi_queue.h b/src/smt/qi_queue.h index 7265875ef..13878a158 100644 --- a/src/smt/qi_queue.h +++ b/src/smt/qi_queue.h @@ -28,6 +28,7 @@ Revision History: #include "params/qi_params.h" #include "ast/cost_evaluator.h" #include "util/statistics.h" +#include "tactic/user_propagator_base.h" namespace smt { class context; @@ -52,6 +53,7 @@ namespace smt { cached_var_subst m_subst; svector m_vals; double m_eager_cost_threshold = 0; + std::function m_on_binding; struct entry { fingerprint * m_qb; float m_cost; @@ -95,6 +97,9 @@ namespace smt { void reset(); void display_delayed_instances_stats(std::ostream & out) const; void collect_statistics(::statistics & st) const; + void register_on_binding(std::function & on_binding) { + m_on_binding = on_binding; + } }; }; diff --git a/src/smt/smt_context.h b/src/smt/smt_context.h index 4d8508def..2fbc1d705 100644 --- a/src/smt/smt_context.h +++ b/src/smt/smt_context.h @@ -1814,6 +1814,14 @@ namespace smt { m_user_propagator->register_decide(r); } + void user_propagate_register_on_binding(user_propagator::binding_eh_t& t) { + m_user_propagator->register_on_binding(t); + } + + void register_on_binding(std::function& f) { + m_qmanager->register_on_binding(f); + } + void user_propagate_initialize_value(expr* var, expr* value); bool watches_fixed(enode* n) const; diff --git a/src/smt/smt_kernel.cpp b/src/smt/smt_kernel.cpp index 85efd5620..e914dcbf8 100644 --- a/src/smt/smt_kernel.cpp +++ b/src/smt/smt_kernel.cpp @@ -307,6 +307,10 @@ namespace smt { void kernel::user_propagate_register_fixed(user_propagator::fixed_eh_t& fixed_eh) { m_imp->m_kernel.user_propagate_register_fixed(fixed_eh); } + + void kernel::user_propagate_register_on_binding(user_propagator::binding_eh_t& on_binding) { + m_imp->m_kernel.user_propagate_register_on_binding(on_binding); + } void kernel::user_propagate_register_final(user_propagator::final_eh_t& final_eh) { m_imp->m_kernel.user_propagate_register_final(final_eh); diff --git a/src/smt/smt_kernel.h b/src/smt/smt_kernel.h index 92dac74d5..98b677213 100644 --- a/src/smt/smt_kernel.h +++ b/src/smt/smt_kernel.h @@ -319,6 +319,8 @@ namespace smt { void user_propagate_register_diseq(user_propagator::eq_eh_t& diseq_eh); + void user_propagate_register_on_binding(user_propagator::binding_eh_t& binding_eh); + void user_propagate_register_expr(expr* e); void user_propagate_register_created(user_propagator::created_eh_t& r); diff --git a/src/smt/smt_quantifier.cpp b/src/smt/smt_quantifier.cpp index e6f156195..32c785d90 100644 --- a/src/smt/smt_quantifier.cpp +++ b/src/smt/smt_quantifier.cpp @@ -339,6 +339,10 @@ namespace smt { m_plugin->add_eq_eh(n1, n2); } + void register_on_binding(std::function& on_binding) { + m_qi_queue.register_on_binding(on_binding); + } + void relevant_eh(enode * n) { m_plugin->relevant_eh(n); } @@ -493,6 +497,10 @@ namespace smt { m_imp->add_eq_eh(n1, n2); } + void quantifier_manager::register_on_binding(std::function& on_binding) { + m_imp->register_on_binding(on_binding); + } + void quantifier_manager::relevant_eh(enode * n) { m_imp->relevant_eh(n); } diff --git a/src/smt/smt_quantifier.h b/src/smt/smt_quantifier.h index abb3cac7c..981647606 100644 --- a/src/smt/smt_quantifier.h +++ b/src/smt/smt_quantifier.h @@ -23,6 +23,7 @@ Revision History: #include "util/statistics.h" #include "util/params.h" #include "smt/smt_types.h" +#include "tactic/user_propagator_base.h" #include class proto_model; @@ -96,6 +97,8 @@ namespace smt { void collect_statistics(::statistics & st) const; void reset_statistics(); + void register_on_binding(std::function & f); + ptr_vector::const_iterator begin_quantifiers() const; ptr_vector::const_iterator end_quantifiers() const; ptr_vector::const_iterator begin() const { return begin_quantifiers(); } diff --git a/src/smt/smt_solver.cpp b/src/smt/smt_solver.cpp index f7737b8a6..05bcc00ba 100644 --- a/src/smt/smt_solver.cpp +++ b/src/smt/smt_solver.cpp @@ -244,6 +244,10 @@ namespace { m_context.user_propagate_register_expr(e); } + void user_propagate_register_on_binding(user_propagator::binding_eh_t& binding_eh) override { + m_context.user_propagate_register_on_binding(binding_eh); + } + void user_propagate_register_created(user_propagator::created_eh_t& c) override { m_context.user_propagate_register_created(c); } diff --git a/src/smt/tactic/smt_tactic_core.cpp b/src/smt/tactic/smt_tactic_core.cpp index bbaf99e4d..2c584d288 100644 --- a/src/smt/tactic/smt_tactic_core.cpp +++ b/src/smt/tactic/smt_tactic_core.cpp @@ -402,6 +402,10 @@ public: m_diseq_eh = diseq_eh; } + void user_propagate_register_on_binding(user_propagator::binding_eh_t& binding_eh) override { + m_ctx.load()->user_propagate_register_on_binding(binding_eh); + } + void user_propagate_register_expr(expr* e) override { m_vars.push_back(e); } diff --git a/src/smt/theory_user_propagator.cpp b/src/smt/theory_user_propagator.cpp index 0dd619f92..f8c2a35b8 100644 --- a/src/smt/theory_user_propagator.cpp +++ b/src/smt/theory_user_propagator.cpp @@ -110,6 +110,15 @@ void theory_user_propagator::register_cb(expr* e) { add_expr(e, true); } +void theory_user_propagator::register_on_binding(user_propagator::binding_eh_t& binding_eh) { + std::function on_binding = + [this, binding_eh](quantifier* q, expr* inst) { + return binding_eh(m_user_context, this, q, inst); + }; + ctx.register_on_binding(on_binding); + +} + bool theory_user_propagator::next_split_cb(expr* e, unsigned idx, lbool phase) { if (e == nullptr) { // clear m_next_split_var = nullptr; diff --git a/src/smt/theory_user_propagator.h b/src/smt/theory_user_propagator.h index c9409612e..5e8d3878c 100644 --- a/src/smt/theory_user_propagator.h +++ b/src/smt/theory_user_propagator.h @@ -132,6 +132,7 @@ namespace smt { void register_diseq(user_propagator::eq_eh_t& diseq_eh) { m_diseq_eh = diseq_eh; } void register_created(user_propagator::created_eh_t& created_eh) { m_created_eh = created_eh; } void register_decide(user_propagator::decide_eh_t& decide_eh) { m_decide_eh = decide_eh; } + void register_on_binding(user_propagator::binding_eh_t& binding_eh); bool has_fixed() const { return (bool)m_fixed_eh; } diff --git a/src/solver/combined_solver.cpp b/src/solver/combined_solver.cpp index e1c9931bb..aacb2b1cc 100644 --- a/src/solver/combined_solver.cpp +++ b/src/solver/combined_solver.cpp @@ -379,6 +379,10 @@ public: void user_propagate_register_diseq(user_propagator::eq_eh_t& diseq_eh) override { m_solver2->user_propagate_register_diseq(diseq_eh); } + + void user_propagate_register_on_binding(user_propagator::binding_eh_t& binding_eh) override { + m_solver2->user_propagate_register_on_binding(binding_eh); + } void user_propagate_register_expr(expr* e) override { m_solver2->user_propagate_register_expr(e); diff --git a/src/solver/simplifier_solver.cpp b/src/solver/simplifier_solver.cpp index 961f6c9e7..ea2a1b2ea 100644 --- a/src/solver/simplifier_solver.cpp +++ b/src/solver/simplifier_solver.cpp @@ -387,7 +387,10 @@ public: void user_propagate_register_fixed(user_propagator::fixed_eh_t& fixed_eh) override { s->user_propagate_register_fixed(fixed_eh); } void user_propagate_register_final(user_propagator::final_eh_t& final_eh) override { s->user_propagate_register_final(final_eh); } void user_propagate_register_eq(user_propagator::eq_eh_t& eq_eh) override { s->user_propagate_register_eq(eq_eh); } - void user_propagate_register_diseq(user_propagator::eq_eh_t& diseq_eh) override { s->user_propagate_register_diseq(diseq_eh); } + void user_propagate_register_diseq(user_propagator::eq_eh_t& diseq_eh) override { s->user_propagate_register_diseq(diseq_eh); } + void user_propagate_register_on_binding(user_propagator::binding_eh_t& binding_eh) override { + s->user_propagate_register_on_binding(binding_eh); + } void user_propagate_register_expr(expr* e) override { m_preprocess_state.freeze(e); s->user_propagate_register_expr(e); } void user_propagate_register_created(user_propagator::created_eh_t& r) override { s->user_propagate_register_created(r); } void user_propagate_register_decide(user_propagator::decide_eh_t& r) override { s->user_propagate_register_decide(r); } diff --git a/src/solver/slice_solver.cpp b/src/solver/slice_solver.cpp index 8310c47f4..ee95cfa94 100644 --- a/src/solver/slice_solver.cpp +++ b/src/solver/slice_solver.cpp @@ -415,7 +415,8 @@ public: void user_propagate_register_fixed(user_propagator::fixed_eh_t& fixed_eh) override { s->user_propagate_register_fixed(fixed_eh); } void user_propagate_register_final(user_propagator::final_eh_t& final_eh) override { s->user_propagate_register_final(final_eh); } void user_propagate_register_eq(user_propagator::eq_eh_t& eq_eh) override { s->user_propagate_register_eq(eq_eh); } - void user_propagate_register_diseq(user_propagator::eq_eh_t& diseq_eh) override { s->user_propagate_register_diseq(diseq_eh); } + void user_propagate_register_diseq(user_propagator::eq_eh_t& diseq_eh) override { s->user_propagate_register_diseq(diseq_eh); } + void user_propagate_register_on_binding(user_propagator::binding_eh_t& binding_eh) override { s->user_propagate_register_on_binding(binding_eh); } void user_propagate_register_expr(expr* e) override { s->user_propagate_register_expr(e); } void user_propagate_register_created(user_propagator::created_eh_t& r) override { s->user_propagate_register_created(r); } void user_propagate_register_decide(user_propagator::decide_eh_t& r) override { s->user_propagate_register_decide(r); } diff --git a/src/solver/tactic2solver.cpp b/src/solver/tactic2solver.cpp index 618d9c161..7c4542451 100644 --- a/src/solver/tactic2solver.cpp +++ b/src/solver/tactic2solver.cpp @@ -115,6 +115,10 @@ public: m_tactic->user_propagate_register_diseq(diseq_eh); } + void user_propagate_register_on_binding(user_propagator::binding_eh_t& binding_eh) override { + m_tactic->user_propagate_register_on_binding(binding_eh); + } + void user_propagate_register_expr(expr* e) override { m_tactic->user_propagate_register_expr(e); } diff --git a/src/tactic/dependent_expr_state_tactic.h b/src/tactic/dependent_expr_state_tactic.h index fc1cf2b4f..fe3a5f2fb 100644 --- a/src/tactic/dependent_expr_state_tactic.h +++ b/src/tactic/dependent_expr_state_tactic.h @@ -168,6 +168,7 @@ public: m_frozen.push_back(e); } + void user_propagate_clear() override { if (m_simp) { pop(1); diff --git a/src/tactic/tactical.cpp b/src/tactic/tactical.cpp index 764107520..71a260358 100644 --- a/src/tactic/tactical.cpp +++ b/src/tactic/tactical.cpp @@ -200,6 +200,10 @@ public: m_t2->user_propagate_register_diseq(diseq_eh); } + void user_propagate_register_on_binding(user_propagator::binding_eh_t& binding_eh) override { + m_t2->user_propagate_register_on_binding(binding_eh); + } + void user_propagate_register_expr(expr* e) override { m_t1->user_propagate_register_expr(e); m_t2->user_propagate_register_expr(e); diff --git a/src/tactic/user_propagator_base.h b/src/tactic/user_propagator_base.h index 968196f63..1b480fb04 100644 --- a/src/tactic/user_propagator_base.h +++ b/src/tactic/user_propagator_base.h @@ -28,6 +28,7 @@ namespace user_propagator { typedef std::function created_eh_t; typedef std::function decide_eh_t; typedef std::function on_clause_eh_t; + typedef std::function binding_eh_t; class plugin : public decl_plugin { public: @@ -92,6 +93,10 @@ namespace user_propagator { throw default_exception("user-propagators are only supported on the SMT solver"); } + virtual void user_propagate_register_on_binding(binding_eh_t& r) { + throw default_exception("user-propagators are only supported on the SMT solver"); + } + virtual void user_propagate_clear() { }