From 2d5b7497455359a7b431e9ef7a0c6a94bb4985de Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Fri, 21 Aug 2020 19:24:59 -0700 Subject: [PATCH] extend solver callbacks with methods Signed-off-by: Nikolaj Bjorner --- scripts/update_api.py | 42 +++++++++++++-- src/api/api_solver.cpp | 48 +++++++++++++++-- src/api/ml/z3native.ml.pre | 1 + src/api/python/z3/z3.py | 102 +++++++++++++++++++++++++++++------- src/api/z3_api.h | 37 +++++++++++-- src/smt/smt_context.cpp | 3 +- src/smt/smt_context.h | 25 ++++++++- src/smt/smt_kernel.cpp | 38 ++++++++++++-- src/smt/smt_kernel.h | 10 +++- src/smt/smt_solver.cpp | 19 ++++++- src/smt/user_propagator.cpp | 16 +++++- src/smt/user_propagator.h | 21 +++++--- src/solver/solver.h | 36 ++++++++++--- 13 files changed, 343 insertions(+), 55 deletions(-) diff --git a/scripts/update_api.py b/scripts/update_api.py index 00251765b..665b3f01e 100755 --- a/scripts/update_api.py +++ b/scripts/update_api.py @@ -337,9 +337,26 @@ def Z3_set_error_handler(ctx, hndlr, _elems=Elementaries(_lib.Z3_set_error_handl _elems.Check(ctx) return ceh -def Z3_solver_propagate_init(ctx, s, user_ctx, push_eh, pop_eh, fixed_eh, fresh_eh, _elems = Elementaries(_lib.Z3_solver_propagate_init)): - _elems.f(ctx, s, user_ctx, push_eh, pop_eh, fixed_eh, fresh_eh) - _elems.Check(ctx) +def Z3_solver_propagate_init(ctx, s, user_ctx, push_eh, pop_eh, fresh_eh, _elems = Elementaries(_lib.Z3_solver_propagate_init)): + _elems.f(ctx, s, user_ctx, push_eh, pop_eh, fresh_eh) + _elems.Check(ctx) + +def Z3_solver_propagate_final(ctx, s, final_eh, _elems = Elementaries(_lib.Z3_solver_propagate_final)): + _elems.f(ctx, s, final_eh) + _elems.Check(ctx) + +def Z3_solver_propagate_fixed(ctx, s, fixed_eh, _elems = Elementaries(_lib.Z3_solver_propagate_fixed)): + _elems.f(ctx, s, fixed_eh) + _elems.Check(ctx) + +def Z3_solver_propagate_eq(ctx, s, eq_eh, _elems = Elementaries(_lib.Z3_solver_propagate_eq)): + _elems.f(ctx, s, eq_eh) + _elems.Check(ctx) + +def Z3_solver_propagate_diseq(ctx, s, diseq_eh, _elems = Elementaries(_lib.Z3_solver_propagate_diseq)): + _elems.f(ctx, s, diseq_eh) + _elems.Check(ctx) + """) @@ -1826,11 +1843,26 @@ _lib.Z3_set_error_handler.argtypes = [ContextObj, _error_handler_type] push_eh_type = ctypes.CFUNCTYPE(None, ctypes.c_void_p) pop_eh_type = ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.c_uint) -fixed_eh_type = ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_uint, ctypes.c_void_p) fresh_eh_type = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.c_void_p) +fixed_eh_type = ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_uint, ctypes.c_void_p) +final_eh_type = ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.c_void_p) +eq_eh_type = ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_uint, ctypes.c_uint) + _lib.Z3_solver_propagate_init.restype = None -_lib.Z3_solver_propagate_init.argtypes = [ContextObj, SolverObj, ctypes.c_void_p, push_eh_type, pop_eh_type, fixed_eh_type, fresh_eh_type] +_lib.Z3_solver_propagate_init.argtypes = [ContextObj, SolverObj, ctypes.c_void_p, push_eh_type, pop_eh_type, fresh_eh_type] + +_lib.Z3_solver_propagate_final.restype = None +_lib.Z3_solver_propagate_final.argtypes = [ContextObj, SolverObj, final_eh_type] + +_lib.Z3_solver_propagate_fixed.restype = None +_lib.Z3_solver_propagate_fixed.argtypes = [ContextObj, SolverObj, fixed_eh_type] + +_lib.Z3_solver_propagate_eq.restype = None +_lib.Z3_solver_propagate_eq.argtypes = [ContextObj, SolverObj, eq_eh_type] + +_lib.Z3_solver_propagate_diseq.restype = None +_lib.Z3_solver_propagate_diseq.argtypes = [ContextObj, SolverObj, eq_eh_type] """ diff --git a/src/api/api_solver.cpp b/src/api/api_solver.cpp index 726271131..8c1dbe4e3 100644 --- a/src/api/api_solver.cpp +++ b/src/api/api_solver.cpp @@ -892,19 +892,61 @@ extern "C" { void* user_context, Z3_push_eh push_eh, Z3_pop_eh pop_eh, - Z3_fixed_eh fixed_eh, Z3_fresh_eh fresh_eh) { Z3_TRY; RESET_ERROR_CODE(); init_solver(c, s); std::function _push = push_eh; std::function _pop = pop_eh; - std::function _fixed = (void(*)(void*,solver::propagate_callback*,unsigned,expr*))fixed_eh; std::function _fresh = fresh_eh; - to_solver_ref(s)->user_propagate_init(user_context, _fixed, _push, _pop, _fresh); + to_solver_ref(s)->user_propagate_init(user_context, _push, _pop, _fresh); Z3_CATCH; } + void Z3_API Z3_solver_propagate_fixed( + Z3_context c, + Z3_solver s, + Z3_fixed_eh fixed_eh) { + Z3_TRY; + RESET_ERROR_CODE(); + solver::fixed_eh_t _fixed = (void(*)(void*,solver::propagate_callback*,unsigned,expr*))fixed_eh; + to_solver_ref(s)->user_propagate_register_fixed(_fixed); + Z3_CATCH; + } + + void Z3_API Z3_solver_propagate_final( + Z3_context c, + Z3_solver s, + Z3_final_eh final_eh) { + Z3_TRY; + RESET_ERROR_CODE(); + solver::final_eh_t _final = (bool(*)(void*,solver::propagate_callback*))final_eh; + to_solver_ref(s)->user_propagate_register_final(_final); + Z3_CATCH; + } + + void Z3_API Z3_solver_propagate_eq( + Z3_context c, + Z3_solver s, + Z3_eq_eh eq_eh) { + Z3_TRY; + RESET_ERROR_CODE(); + solver::eq_eh_t _eq = (void(*)(void*,solver::propagate_callback*,unsigned,unsigned))eq_eh; + to_solver_ref(s)->user_propagate_register_eq(_eq); + Z3_CATCH; + } + + void Z3_API Z3_solver_propagate_diseq( + Z3_context c, + Z3_solver s, + Z3_eq_eh diseq_eh) { + Z3_TRY; + RESET_ERROR_CODE(); + solver::eq_eh_t _diseq = (void(*)(void*,solver::propagate_callback*,unsigned,unsigned))diseq_eh; + to_solver_ref(s)->user_propagate_register_diseq(_diseq); + Z3_CATCH; + } + unsigned Z3_API Z3_solver_propagate_register(Z3_context c, Z3_solver s, Z3_ast e) { Z3_TRY; LOG_Z3_solver_propagate_register(c, s, e); diff --git a/src/api/ml/z3native.ml.pre b/src/api/ml/z3native.ml.pre index 87a069df9..8f76aa950 100644 --- a/src/api/ml/z3native.ml.pre +++ b/src/api/ml/z3native.ml.pre @@ -17,6 +17,7 @@ and literals = ptr and constructor = ptr and constructor_list = ptr and solver = ptr +and solver_callback = ptr and goal = ptr and tactic = ptr and params = ptr diff --git a/src/api/python/z3/z3.py b/src/api/python/z3/z3.py index 5e595b63b..43fed72ef 100644 --- a/src/api/python/z3/z3.py +++ b/src/api/python/z3/z3.py @@ -10505,59 +10505,123 @@ def TransitiveClosure(f): return FuncDeclRef(Z3_mk_transitive_closure(f.ctx_ref(), f.ast), f.ctx) -_user_propagate_bases = {} +class PropClosures: +# import thread + def __init__(self): + self.bases = {} +# self.lock = thread.Lock() + + def get(self, ctx): +# self.lock.acquire() + r = self.bases[ctx] +# self.lock.release() + return r + + def set(self, ctx, r): +# self.lock.acquire() + self.bases[ctx] = r +# self.lock.release() + + def insert(self, r): +# self.lock.acquire() + id = len(self.bases) + 3 + self.bases[id] = r +# self.lock.release() + return id + +_prop_closures = PropClosures() def user_prop_push(ctx): - _user_propagate_bases[ctx].push(); + _prop_closures.get(ctx).push(); def user_prop_pop(ctx, num_scopes): - _user_propagate_bases[ctx].pop(num_scopes) + _prop_closures.get(ctx).pop(num_scopes) + +def user_prop_fresh(ctx): + prop = _prop_closures.get(ctx) + new_prop = UsePropagateBase(None, prop.ctx) + _prop_closures.set(new_prop.id, new_prop.fresh()) + return ctypes.c_void_p(new_prop.id) def user_prop_fixed(ctx, cb, id, value): - prop = _user_propagate_bases[ctx] + prop = _prop_closures.get(ctx) prop.cb = cb prop.fixed(id, _to_expr_ref(ctypes.c_void_p(value), prop.ctx)) prop.cb = None -def user_prop_fresh(ctx): - prop = _user_propagate_bases[ctx] - new_prop = UsePropagateBase(None, prop.ctx) - _user_prop_bases[new_prop.id] = new_prop.fresh() - return ctypes.c_void_p(new_prop.id) - +def user_prop_final(ctx, cb): + prop = _prop_closures.get(ctx) + prop.cb = cb + prop.final() + prop.cb = None + +def user_prop_eq(ctx, cb, x, y): + prop = _prop_closures.get(ctx) + prop.cb = cb + prop.eq(x, y) + prop.cb = None + +def user_prop_diseq(ctx, cb, x, y): + prop = _prop_closures.get(ctx) + prop.cb = cb + prop.diseq(x, y) + prop.cb = None _user_prop_push = push_eh_type(user_prop_push) _user_prop_pop = pop_eh_type(user_prop_pop) -_user_prop_fixed = fixed_eh_type(user_prop_fixed) _user_prop_fresh = fresh_eh_type(user_prop_fresh) +_user_prop_fixed = fixed_eh_type(user_prop_fixed) +_user_prop_final = final_eh_type(user_prop_final) +_user_prop_eq = eq_eh_type(user_prop_eq) +_user_prop_diseq = eq_eh_type(user_prop_diseq) class UserPropagateBase: def __init__(self, s, ctx = None): - self.id = len(_user_propagate_bases) + 3 - self.solver = s + self.solver = s self.ctx = s.ctx if s is not None else ctx self.cb = None - _user_propagate_bases[self.id] = self + self.id = _prop_closures.insert(self) + self.fixed = None + self.final = None + self.eq = None + self.diseq = None if s: Z3_solver_propagate_init(s.ctx.ref(), s.solver, ctypes.c_void_p(self.id), _user_prop_push, _user_prop_pop, - _user_prop_fixed, _user_prop_fresh) + + def add_fixed(self, fixed): + assert not self.fixed + Z3_solver_propagate_fixed(self.ctx.ref(), self.solver.solver, _user_prop_fixed) + self.fixed = fixed + + def add_final(self, final): + assert not self.final + Z3_solver_propagate_final(self.ctx.ref(), self.solver.solver, _user_prop_final) + self.final = final + + def add_eq(self, eq): + assert not self.eq + Z3_solver_propagate_eq(self.ctx.ref(), self.solver.solver, _user_prop_eq) + self.eq = eq + + def add_diseq(self, diseq): + assert not self.diseq + Z3_solver_propagate_diseq(self.ctx.ref(), self.solver.solver, _user_prop_diseq) + self.diseq = diseq + def push(self): raise Z3Exception("push has not been overwritten") def pop(self, num_scopes): raise Z3Exception("pop has not been overwritten") - def fixed(self, id, e): - raise Z3Exception("fixed has not been overwritten") - - def fresh(self, prop_base): + def fresh(self): raise Z3Exception("fresh has not been overwritten") def add(self, e): diff --git a/src/api/z3_api.h b/src/api/z3_api.h index 921ce372d..0d15adba5 100644 --- a/src/api/z3_api.h +++ b/src/api/z3_api.h @@ -1420,8 +1420,10 @@ typedef void Z3_error_handler(Z3_context c, Z3_error_code e); */ typedef void Z3_push_eh(void* ctx); typedef void Z3_pop_eh(void* ctx, unsigned num_scopes); -typedef void Z3_fixed_eh(void* ctx, Z3_solver_callback cb, unsigned id, Z3_ast value); typedef void* Z3_fresh_eh(void* ctx); +typedef void Z3_fixed_eh(void* ctx, Z3_solver_callback cb, unsigned id, Z3_ast value); +typedef void Z3_eq_eh(void* ctx, Z3_solver_callback cb, unsigned x, unsigned y); +typedef void Z3_final_eh(void* ctx, Z3_solver_callback cb); /** \brief A Goal is essentially a set of formulas. @@ -6537,9 +6539,38 @@ extern "C" { void* user_context, Z3_push_eh push_eh, Z3_pop_eh pop_eh, - Z3_fixed_eh fixed_eh, Z3_fresh_eh fresh_eh); + /** + \brief register a callback for when an expression is bound to a fixed value. + The supported expression types are + - Booleans + - Bit-vectors + */ + + void Z3_API Z3_solver_propagate_fixed(Z3_context c, Z3_solver s, Z3_fixed_eh fixed_eh); + + /** + \brief register a callback on final check. + This provides freedom to the propagator to delay actions or implement a branch-and bound solver. + + The final_eh callback takes as argument the original user_context that was used + when calling \c Z3_solver_propagate_init, and it takes a callback context for propagations. + If may use the callback context to invoke the \c Z3_solver_propagate_consequence function. + If the callback context gets used, the solver continues. + */ + void Z3_API Z3_solver_propagate_final(Z3_context c, Z3_solver s, Z3_final_eh final_eh); + + /** + \brief register a callback on expression equalities. + */ + void Z3_API Z3_solver_propagate_eq(Z3_context c, Z3_solver s, Z3_eq_eh eq_eh); + + /** + \brief register a callback on expression dis-equalities. + */ + void Z3_API Z3_solver_propagate_diseq(Z3_context c, Z3_solver s, Z3_eq_eh eq_eh); + /** \brief register an expression to propagate on with the solver. Only expressions of type Bool and type Bit-Vector can be registered for propagation. @@ -6547,7 +6578,7 @@ extern "C" { def_API('Z3_solver_propagate_register', UINT, (_in(CONTEXT), _in(SOLVER), _in(AST))) */ - unsigned Z3_API Z3_solver_propagate_register(Z3_context c, Z3_solver s, Z3_ast e); + unsigned Z3_API Z3_solver_propagate_register(Z3_context c, Z3_solver s, Z3_ast e); /** \brief propagate a consequence based on fixed values. diff --git a/src/smt/smt_context.cpp b/src/smt/smt_context.cpp index 77bab98d9..a354e1507 100644 --- a/src/smt/smt_context.cpp +++ b/src/smt/smt_context.cpp @@ -2951,13 +2951,12 @@ namespace smt { void context::user_propagate_init( void* ctx, - std::function& fixed_eh, std::function& push_eh, std::function& pop_eh, std::function& fresh_eh) { setup_context(m_fparams.m_auto_config); m_user_propagator = alloc(user_propagator, *this); - m_user_propagator->add(ctx, fixed_eh, push_eh, pop_eh, fresh_eh); + m_user_propagator->add(ctx, push_eh, pop_eh, fresh_eh); for (unsigned i = m_scopes.size(); i-- > 0; ) m_user_propagator->push_scope_eh(); register_plugin(m_user_propagator); diff --git a/src/smt/smt_context.h b/src/smt/smt_context.h index 4bacac97b..8f6d0a6a6 100644 --- a/src/smt/smt_context.h +++ b/src/smt/smt_context.h @@ -1689,11 +1689,34 @@ namespace smt { */ void user_propagate_init( void* ctx, - std::function& fixed_eh, std::function& push_eh, std::function& pop_eh, std::function& fresh_eh); + void user_propagate_register_final(solver::final_eh_t& final_eh) { + if (!m_user_propagator) + throw default_exception("user propagator must be initialized"); + m_user_propagator->register_final(final_eh); + } + + void user_propagate_register_fixed(solver::fixed_eh_t& fixed_eh) { + if (!m_user_propagator) + throw default_exception("user propagator must be initialized"); + m_user_propagator->register_fixed(fixed_eh); + } + + void user_propagate_register_eq(solver::eq_eh_t& eq_eh) { + if (!m_user_propagator) + throw default_exception("user propagator must be initialized"); + m_user_propagator->register_eq(eq_eh); + } + + void user_propagate_register_diseq(solver::eq_eh_t& diseq_eh) { + if (!m_user_propagator) + throw default_exception("user propagator must be initialized"); + m_user_propagator->register_diseq(diseq_eh); + } + unsigned user_propagate_register(expr* e) { if (!m_user_propagator) throw default_exception("user propagator must be initialized"); diff --git a/src/smt/smt_kernel.cpp b/src/smt/smt_kernel.cpp index 567f0eef5..9d6963142 100644 --- a/src/smt/smt_kernel.cpp +++ b/src/smt/smt_kernel.cpp @@ -235,11 +235,26 @@ namespace smt { void user_propagate_init( void* ctx, - std::function& fixed_eh, std::function& push_eh, std::function& pop_eh, std::function& fresh_eh) { - m_kernel.user_propagate_init(ctx, fixed_eh, push_eh, pop_eh, fresh_eh); + m_kernel.user_propagate_init(ctx, push_eh, pop_eh, fresh_eh); + } + + void user_propagate_register_final(solver::final_eh_t& final_eh) { + m_kernel.user_propagate_register_final(final_eh); + } + + void user_propagate_register_fixed(solver::fixed_eh_t& fixed_eh) { + m_kernel.user_propagate_register_fixed(fixed_eh); + } + + void user_propagate_register_eq(solver::eq_eh_t& eq_eh) { + m_kernel.user_propagate_register_eq(eq_eh); + } + + void user_propagate_register_diseq(solver::eq_eh_t& diseq_eh) { + m_kernel.user_propagate_register_diseq(diseq_eh); } unsigned user_propagate_register(expr* e) { @@ -460,11 +475,26 @@ namespace smt { void kernel::user_propagate_init( void* ctx, - std::function& fixed_eh, std::function& push_eh, std::function& pop_eh, std::function& fresh_eh) { - m_imp->user_propagate_init(ctx, fixed_eh, push_eh, pop_eh, fresh_eh); + m_imp->user_propagate_init(ctx, push_eh, pop_eh, fresh_eh); + } + + void kernel::user_propagate_register_fixed(solver::fixed_eh_t& fixed_eh) { + m_imp->user_propagate_register_fixed(fixed_eh); + } + + void kernel::user_propagate_register_final(solver::final_eh_t& final_eh) { + m_imp->user_propagate_register_final(final_eh); + } + + void kernel::user_propagate_register_eq(solver::eq_eh_t& eq_eh) { + m_imp->user_propagate_register_eq(eq_eh); + } + + void kernel::user_propagate_register_diseq(solver::eq_eh_t& diseq_eh) { + m_imp->user_propagate_register_diseq(diseq_eh); } unsigned kernel::user_propagate_register(expr* e) { diff --git a/src/smt/smt_kernel.h b/src/smt/smt_kernel.h index 3eac0bc18..c3a6d180e 100644 --- a/src/smt/smt_kernel.h +++ b/src/smt/smt_kernel.h @@ -290,11 +290,19 @@ namespace smt { */ void user_propagate_init( void* ctx, - std::function& fixed_eh, std::function& push_eh, std::function& pop_eh, std::function& fresh_eh); + void user_propagate_register_fixed(solver::fixed_eh_t& fixed_eh); + + void user_propagate_register_final(solver::final_eh_t& final_eh); + + void user_propagate_register_eq(solver::eq_eh_t& eq_eh); + + void user_propagate_register_diseq(solver::eq_eh_t& diseq_eh); + + /** \brief register an expression to be tracked fro user propagation. */ diff --git a/src/smt/smt_solver.cpp b/src/smt/smt_solver.cpp index eddc0e0d3..42854b811 100644 --- a/src/smt/smt_solver.cpp +++ b/src/smt/smt_solver.cpp @@ -210,11 +210,26 @@ namespace { void user_propagate_init( void* ctx, - std::function& fixed_eh, std::function& push_eh, std::function& pop_eh, std::function& fresh_eh) override { - m_context.user_propagate_init(ctx, fixed_eh, push_eh, pop_eh, fresh_eh); + m_context.user_propagate_init(ctx, push_eh, pop_eh, fresh_eh); + } + + void user_propagate_register_fixed(solver::fixed_eh_t& fixed_eh) override { + m_context.user_propagate_register_fixed(fixed_eh); + } + + void user_propagate_register_final(solver::final_eh_t& final_eh) override { + m_context.user_propagate_register_final(final_eh); + } + + void user_propagate_register_eq(solver::eq_eh_t& eq_eh) override { + m_context.user_propagate_register_eq(eq_eh); + } + + void user_propagate_register_diseq(solver::eq_eh_t& diseq_eh) override { + m_context.user_propagate_register_diseq(diseq_eh); } unsigned user_propagate_register(expr* e) override { diff --git a/src/smt/user_propagator.cpp b/src/smt/user_propagator.cpp index 5c401ffff..2bd4b2844 100644 --- a/src/smt/user_propagator.cpp +++ b/src/smt/user_propagator.cpp @@ -56,11 +56,25 @@ void user_propagator::propagate(unsigned sz, unsigned const* ids, expr* conseq) theory * user_propagator::mk_fresh(context * new_ctx) { auto* th = alloc(user_propagator, *new_ctx); void* ctx = m_fresh_eh(m_user_context); - th->add(ctx, m_fixed_eh, m_push_eh, m_pop_eh, m_fresh_eh); + th->add(ctx, m_push_eh, m_pop_eh, m_fresh_eh); + if ((bool)m_fixed_eh) th->register_fixed(m_fixed_eh); + if ((bool)m_final_eh) th->register_final(m_final_eh); + if ((bool)m_eq_eh) th->register_eq(m_eq_eh); + if ((bool)m_diseq_eh) th->register_diseq(m_diseq_eh); return th; } +final_check_status user_propagator::final_check_eh() { + if (!(bool)m_final_eh) + return FC_DONE; + unsigned sz = m_prop.size(); + m_final_eh(m_user_context, this); + return sz == m_prop.size() ? FC_DONE : FC_CONTINUE; +} + void user_propagator::new_fixed_eh(theory_var v, expr* value, unsigned num_lits, literal const* jlits) { + if (!m_fixed_eh) + return; force_push(); m_id2justification.setx(v, literal_vector(num_lits, jlits), literal_vector()); m_fixed_eh(m_user_context, this, v, value); diff --git a/src/smt/user_propagator.h b/src/smt/user_propagator.h index db8450865..409891b86 100644 --- a/src/smt/user_propagator.h +++ b/src/smt/user_propagator.h @@ -30,10 +30,14 @@ Notes: namespace smt { class user_propagator : public theory, public solver::propagate_callback { void* m_user_context; - std::function m_fixed_eh; std::function m_push_eh; std::function m_pop_eh; std::function m_fresh_eh; + solver::final_eh_t m_final_eh; + solver::fixed_eh_t m_fixed_eh; + solver::eq_eh_t m_eq_eh; + solver::eq_eh_t m_diseq_eh; + struct prop_info { unsigned_vector m_ids; expr_ref m_conseq; @@ -61,12 +65,10 @@ namespace smt { */ void add( void* ctx, - std::function& fixed_eh, std::function& push_eh, std::function& pop_eh, std::function& fresh_eh) { m_user_context = ctx; - m_fixed_eh = fixed_eh; m_push_eh = push_eh; m_pop_eh = pop_eh; m_fresh_eh = fresh_eh; @@ -74,6 +76,11 @@ namespace smt { unsigned add_expr(expr* e); + void register_final(solver::final_eh_t& final_eh) { m_final_eh = final_eh; } + void register_fixed(solver::fixed_eh_t& fixed_eh) { m_fixed_eh = fixed_eh; } + void register_eq(solver::eq_eh_t& eq_eh) { m_eq_eh = eq_eh; } + void register_diseq(solver::eq_eh_t& diseq_eh) { m_diseq_eh = diseq_eh; } + void propagate(unsigned sz, unsigned const* ids, expr* conseq) override; void new_fixed_eh(theory_var v, expr* value, unsigned num_lits, literal const* jlits); @@ -81,11 +88,11 @@ namespace smt { theory * mk_fresh(context * new_ctx) override; bool internalize_atom(app * atom, bool gate_ctx) override { UNREACHABLE(); return false; } bool internalize_term(app * term) override { UNREACHABLE(); return false; } - void new_eq_eh(theory_var v1, theory_var v2) override { } - void new_diseq_eh(theory_var v1, theory_var v2) override { } - bool use_diseqs() const override { return false; } + void new_eq_eh(theory_var v1, theory_var v2) override { if (m_eq_eh) m_eq_eh(m_user_context, this, v1, v2); } + void new_diseq_eh(theory_var v1, theory_var v2) override { if (m_diseq_eh) m_diseq_eh(m_user_context, this, v1, v2); } + bool use_diseqs() const override { return ((bool)m_diseq_eh); } bool build_models() const override { return false; } - final_check_status final_check_eh() override { return FC_DONE; } + final_check_status final_check_eh() override; void reset_eh() override {} void assign_eh(bool_var v, bool is_true) override { } void init_search_eh() override {} diff --git a/src/solver/solver.h b/src/solver/solver.h index f5452029d..2fe014986 100644 --- a/src/solver/solver.h +++ b/src/solver/solver.h @@ -25,6 +25,7 @@ Notes: class solver; class model_converter; + class solver_factory { public: virtual ~solver_factory() {} @@ -238,21 +239,42 @@ public: virtual expr_ref get_implied_upper_bound(expr* e) = 0; - class propagate_callback { - public: - virtual void propagate(unsigned sz, unsigned const* ids, expr* conseq) = 0; - }; - virtual void user_propagate_init( void* ctx, - std::function& fixed_eh, std::function& push_eh, std::function& pop_eh, std::function& fresh_eh) { throw default_exception("user-propagators are only supported on the SMT solver"); } - virtual unsigned user_propagate_register(expr* e) { return 0; } + class propagate_callback { + public: + virtual void propagate(unsigned sz, unsigned const* ids, expr* conseq) = 0; + }; + + typedef std::function final_eh_t; + typedef std::function fixed_eh_t; + typedef std::function eq_eh_t; + + virtual void user_propagate_register_fixed(fixed_eh_t& fixed_eh) { + throw default_exception("user-propagators are only supported on the SMT solver"); + } + + virtual void user_propagate_register_final(final_eh_t& final_eh) { + throw default_exception("user-propagators are only supported on the SMT solver"); + } + + virtual void user_propagate_register_eq(eq_eh_t& eq_eh) { + throw default_exception("user-propagators are only supported on the SMT solver"); + } + + virtual void user_propagate_register_diseq(eq_eh_t& diseq_eh) { + throw default_exception("user-propagators are only supported on the SMT solver"); + } + + virtual unsigned user_propagate_register(expr* e) { + throw default_exception("user-propagators are only supported on the SMT solver"); + } /**