diff --git a/src/api/go/propagator_callbacks.go b/src/api/go/propagator_callbacks.go index dfcf09583..8653f7fcb 100644 --- a/src/api/go/propagator_callbacks.go +++ b/src/api/go/propagator_callbacks.go @@ -10,15 +10,20 @@ import ( "unsafe" ) +// withCallback temporarily sets the callback context, calls fn, and restores the old context. +func (p *UserPropagator) withCallback(cb C.Z3_solver_callback, fn func()) { + old := p.cb + p.cb = cb + defer func() { p.cb = old }() + fn() +} + // goPushCb is exported to C as a callback for Z3_push_eh. // //export goPushCb func goPushCb(ctx C.uintptr_t, cb C.Z3_solver_callback) { p := cgo.Handle(ctx).Value().(*UserPropagator) - old := p.cb - p.cb = cb - defer func() { p.cb = old }() - p.iface.Push() + p.withCallback(cb, p.iface.Push) } // goPopCb is exported to C as a callback for Z3_pop_eh. @@ -26,10 +31,9 @@ func goPushCb(ctx C.uintptr_t, cb C.Z3_solver_callback) { //export goPopCb func goPopCb(ctx C.uintptr_t, cb C.Z3_solver_callback, numScopes C.uint) { p := cgo.Handle(ctx).Value().(*UserPropagator) - old := p.cb - p.cb = cb - defer func() { p.cb = old }() - p.iface.Pop(uint(numScopes)) + p.withCallback(cb, func() { + p.iface.Pop(uint(numScopes)) + }) } // goFreshCb is exported to C as a callback for Z3_fresh_eh. @@ -53,10 +57,9 @@ func goFreshCb(ctx C.uintptr_t, newContext C.Z3_context) C.uintptr_t { func goFixedCb(ctx C.uintptr_t, cb C.Z3_solver_callback, t C.Z3_ast, value C.Z3_ast) { p := cgo.Handle(ctx).Value().(*UserPropagator) if h, ok := p.iface.(FixedHandler); ok { - old := p.cb - p.cb = cb - defer func() { p.cb = old }() - h.Fixed(newExpr(p.ctx, t), newExpr(p.ctx, value)) + p.withCallback(cb, func() { + h.Fixed(newExpr(p.ctx, t), newExpr(p.ctx, value)) + }) } } @@ -66,10 +69,9 @@ func goFixedCb(ctx C.uintptr_t, cb C.Z3_solver_callback, t C.Z3_ast, value C.Z3_ func goEqCb(ctx C.uintptr_t, cb C.Z3_solver_callback, s C.Z3_ast, t C.Z3_ast) { p := cgo.Handle(ctx).Value().(*UserPropagator) if h, ok := p.iface.(EqHandler); ok { - old := p.cb - p.cb = cb - defer func() { p.cb = old }() - h.Eq(newExpr(p.ctx, s), newExpr(p.ctx, t)) + p.withCallback(cb, func() { + h.Eq(newExpr(p.ctx, s), newExpr(p.ctx, t)) + }) } } @@ -79,10 +81,9 @@ func goEqCb(ctx C.uintptr_t, cb C.Z3_solver_callback, s C.Z3_ast, t C.Z3_ast) { func goDiseqCb(ctx C.uintptr_t, cb C.Z3_solver_callback, s C.Z3_ast, t C.Z3_ast) { p := cgo.Handle(ctx).Value().(*UserPropagator) if h, ok := p.iface.(DiseqHandler); ok { - old := p.cb - p.cb = cb - defer func() { p.cb = old }() - h.Diseq(newExpr(p.ctx, s), newExpr(p.ctx, t)) + p.withCallback(cb, func() { + h.Diseq(newExpr(p.ctx, s), newExpr(p.ctx, t)) + }) } } @@ -92,10 +93,7 @@ func goDiseqCb(ctx C.uintptr_t, cb C.Z3_solver_callback, s C.Z3_ast, t C.Z3_ast) func goFinalCb(ctx C.uintptr_t, cb C.Z3_solver_callback) { p := cgo.Handle(ctx).Value().(*UserPropagator) if h, ok := p.iface.(FinalHandler); ok { - old := p.cb - p.cb = cb - defer func() { p.cb = old }() - h.Final() + p.withCallback(cb, h.Final) } } @@ -105,10 +103,9 @@ func goFinalCb(ctx C.uintptr_t, cb C.Z3_solver_callback) { func goCreatedCb(ctx C.uintptr_t, cb C.Z3_solver_callback, t C.Z3_ast) { p := cgo.Handle(ctx).Value().(*UserPropagator) if h, ok := p.iface.(CreatedHandler); ok { - old := p.cb - p.cb = cb - defer func() { p.cb = old }() - h.Created(newExpr(p.ctx, t)) + p.withCallback(cb, func() { + h.Created(newExpr(p.ctx, t)) + }) } } @@ -118,10 +115,9 @@ func goCreatedCb(ctx C.uintptr_t, cb C.Z3_solver_callback, t C.Z3_ast) { func goDecideCb(ctx C.uintptr_t, cb C.Z3_solver_callback, t C.Z3_ast, idx C.uint, phase C.bool) { p := cgo.Handle(ctx).Value().(*UserPropagator) if h, ok := p.iface.(DecideHandler); ok { - old := p.cb - p.cb = cb - defer func() { p.cb = old }() - h.Decide(newExpr(p.ctx, t), uint(idx), phase == C.bool(true)) + p.withCallback(cb, func() { + h.Decide(newExpr(p.ctx, t), uint(idx), phase == C.bool(true)) + }) } } @@ -130,13 +126,15 @@ func goDecideCb(ctx C.uintptr_t, cb C.Z3_solver_callback, t C.Z3_ast, idx C.uint //export goOnBindingCb func goOnBindingCb(ctx C.uintptr_t, cb C.Z3_solver_callback, q C.Z3_ast, inst C.Z3_ast) C.bool { p := cgo.Handle(ctx).Value().(*UserPropagator) + result := C.bool(true) // default: allow binding when handler is not implemented if h, ok := p.iface.(OnBindingHandler); ok { - old := p.cb - p.cb = cb - defer func() { p.cb = old }() - return C.bool(h.OnBinding(newExpr(p.ctx, q), newExpr(p.ctx, inst))) + p.withCallback(cb, func() { + if !h.OnBinding(newExpr(p.ctx, q), newExpr(p.ctx, inst)) { + result = C.bool(false) + } + }) } - return C.bool(true) + return result } // goOnClauseCb is exported to C as a callback for Z3_on_clause_eh.