diff --git a/src/api/go/z3.go b/src/api/go/z3.go index b1a93d19e..4e982111e 100644 --- a/src/api/go/z3.go +++ b/src/api/go/z3.go @@ -240,6 +240,45 @@ func newExpr(ctx *Context, ptr C.Z3_ast) *Expr { return expr } +// intsToCs converts a []int slice to []C.int, returning the slice and +// a pointer to its first element (nil if empty). +func intsToCs(ints []int) ([]C.int, *C.int) { + if len(ints) == 0 { + return nil, nil + } + cInts := make([]C.int, len(ints)) + for i, v := range ints { + cInts[i] = C.int(v) + } + return cInts, &cInts[0] +} + +// exprsToASTs converts a []*Expr slice to []C.Z3_ast, returning the slice and +// a pointer to its first element (nil if empty). +func exprsToASTs(exprs []*Expr) ([]C.Z3_ast, *C.Z3_ast) { + if len(exprs) == 0 { + return nil, nil + } + cExprs := make([]C.Z3_ast, len(exprs)) + for i, e := range exprs { + cExprs[i] = e.ptr + } + return cExprs, &cExprs[0] +} + +// sortsToCSorts converts a []*Sort slice to []C.Z3_sort, returning the slice and +// a pointer to its first element (nil if empty). +func sortsToCSorts(sorts []*Sort) ([]C.Z3_sort, *C.Z3_sort) { + if len(sorts) == 0 { + return nil, nil + } + cSorts := make([]C.Z3_sort, len(sorts)) + for i, s := range sorts { + cSorts[i] = s.ptr + } + return cSorts, &cSorts[0] +} + // String returns the string representation of the expression. func (e *Expr) String() string { return C.GoString(C.Z3_ast_to_string(e.ctx.ptr, e.ptr)) @@ -368,11 +407,8 @@ func (c *Context) MkAnd(exprs ...*Expr) *Expr { if len(exprs) == 1 { return exprs[0] } - cExprs := make([]C.Z3_ast, len(exprs)) - for i, e := range exprs { - cExprs[i] = e.ptr - } - return newExpr(c, C.Z3_mk_and(c.ptr, C.uint(len(exprs)), &cExprs[0])) + _, cExprsPtr := exprsToASTs(exprs) + return newExpr(c, C.Z3_mk_and(c.ptr, C.uint(len(exprs)), cExprsPtr)) } // MkOr creates a disjunction. @@ -383,11 +419,8 @@ func (c *Context) MkOr(exprs ...*Expr) *Expr { if len(exprs) == 1 { return exprs[0] } - cExprs := make([]C.Z3_ast, len(exprs)) - for i, e := range exprs { - cExprs[i] = e.ptr - } - return newExpr(c, C.Z3_mk_or(c.ptr, C.uint(len(exprs)), &cExprs[0])) + _, cExprsPtr := exprsToASTs(exprs) + return newExpr(c, C.Z3_mk_or(c.ptr, C.uint(len(exprs)), cExprsPtr)) } // MkNot creates a negation. @@ -422,38 +455,21 @@ func (c *Context) MkDistinct(exprs ...*Expr) *Expr { if len(exprs) <= 1 { return c.MkTrue() } - cExprs := make([]C.Z3_ast, len(exprs)) - for i, e := range exprs { - cExprs[i] = e.ptr - } - return newExpr(c, C.Z3_mk_distinct(c.ptr, C.uint(len(exprs)), &cExprs[0])) + _, cExprsPtr := exprsToASTs(exprs) + return newExpr(c, C.Z3_mk_distinct(c.ptr, C.uint(len(exprs)), cExprsPtr)) } // Pseudo-Boolean / cardinality constraints // MkAtMost encodes p1 + p2 + ... + pn <= k. func (c *Context) MkAtMost(args []*Expr, k uint) *Expr { - cArgs := make([]C.Z3_ast, len(args)) - for i, a := range args { - cArgs[i] = a.ptr - } - var cArgsPtr *C.Z3_ast - if len(cArgs) > 0 { - cArgsPtr = &cArgs[0] - } + _, cArgsPtr := exprsToASTs(args) return newExpr(c, C.Z3_mk_atmost(c.ptr, C.uint(len(args)), cArgsPtr, C.uint(k))) } // MkAtLeast encodes p1 + p2 + ... + pn >= k. func (c *Context) MkAtLeast(args []*Expr, k uint) *Expr { - cArgs := make([]C.Z3_ast, len(args)) - for i, a := range args { - cArgs[i] = a.ptr - } - var cArgsPtr *C.Z3_ast - if len(cArgs) > 0 { - cArgsPtr = &cArgs[0] - } + _, cArgsPtr := exprsToASTs(args) return newExpr(c, C.Z3_mk_atleast(c.ptr, C.uint(len(args)), cArgsPtr, C.uint(k))) } @@ -462,20 +478,8 @@ func (c *Context) MkPBLe(args []*Expr, coeffs []int, k int) *Expr { if len(args) != len(coeffs) { panic("MkPBLe: args and coeffs must have the same length") } - cArgs := make([]C.Z3_ast, len(args)) - for i, a := range args { - cArgs[i] = a.ptr - } - cCoeffs := make([]C.int, len(coeffs)) - for i, v := range coeffs { - cCoeffs[i] = C.int(v) - } - var cArgsPtr *C.Z3_ast - var cCoeffsPtr *C.int - if len(cArgs) > 0 { - cArgsPtr = &cArgs[0] - cCoeffsPtr = &cCoeffs[0] - } + _, cArgsPtr := exprsToASTs(args) + _, cCoeffsPtr := intsToCs(coeffs) return newExpr(c, C.Z3_mk_pble(c.ptr, C.uint(len(args)), cArgsPtr, cCoeffsPtr, C.int(k))) } @@ -484,20 +488,8 @@ func (c *Context) MkPBGe(args []*Expr, coeffs []int, k int) *Expr { if len(args) != len(coeffs) { panic("MkPBGe: args and coeffs must have the same length") } - cArgs := make([]C.Z3_ast, len(args)) - for i, a := range args { - cArgs[i] = a.ptr - } - cCoeffs := make([]C.int, len(coeffs)) - for i, v := range coeffs { - cCoeffs[i] = C.int(v) - } - var cArgsPtr *C.Z3_ast - var cCoeffsPtr *C.int - if len(cArgs) > 0 { - cArgsPtr = &cArgs[0] - cCoeffsPtr = &cCoeffs[0] - } + _, cArgsPtr := exprsToASTs(args) + _, cCoeffsPtr := intsToCs(coeffs) return newExpr(c, C.Z3_mk_pbge(c.ptr, C.uint(len(args)), cArgsPtr, cCoeffsPtr, C.int(k))) } @@ -506,20 +498,8 @@ func (c *Context) MkPBEq(args []*Expr, coeffs []int, k int) *Expr { if len(args) != len(coeffs) { panic("MkPBEq: args and coeffs must have the same length") } - cArgs := make([]C.Z3_ast, len(args)) - for i, a := range args { - cArgs[i] = a.ptr - } - cCoeffs := make([]C.int, len(coeffs)) - for i, v := range coeffs { - cCoeffs[i] = C.int(v) - } - var cArgsPtr *C.Z3_ast - var cCoeffsPtr *C.int - if len(cArgs) > 0 { - cArgsPtr = &cArgs[0] - cCoeffsPtr = &cCoeffs[0] - } + _, cArgsPtr := exprsToASTs(args) + _, cCoeffsPtr := intsToCs(coeffs) return newExpr(c, C.Z3_mk_pbeq(c.ptr, C.uint(len(args)), cArgsPtr, cCoeffsPtr, C.int(k))) } @@ -569,54 +549,26 @@ func (f *FuncDecl) GetRange() *Sort { // MkFuncDecl creates a function declaration. func (c *Context) MkFuncDecl(name *Symbol, domain []*Sort, range_ *Sort) *FuncDecl { - cDomain := make([]C.Z3_sort, len(domain)) - for i, s := range domain { - cDomain[i] = s.ptr - } - var domainPtr *C.Z3_sort - if len(domain) > 0 { - domainPtr = &cDomain[0] - } + _, domainPtr := sortsToCSorts(domain) return newFuncDecl(c, C.Z3_mk_func_decl(c.ptr, name.ptr, C.uint(len(domain)), domainPtr, range_.ptr)) } // MkRecFuncDecl creates a recursive function declaration. // After creating, use AddRecDef to provide the function body. func (c *Context) MkRecFuncDecl(name *Symbol, domain []*Sort, range_ *Sort) *FuncDecl { - cDomain := make([]C.Z3_sort, len(domain)) - for i, s := range domain { - cDomain[i] = s.ptr - } - var domainPtr *C.Z3_sort - if len(domain) > 0 { - domainPtr = &cDomain[0] - } + _, domainPtr := sortsToCSorts(domain) return newFuncDecl(c, C.Z3_mk_rec_func_decl(c.ptr, name.ptr, C.uint(len(domain)), domainPtr, range_.ptr)) } // AddRecDef adds the definition (body) for a recursive function created with MkRecFuncDecl. func (c *Context) AddRecDef(f *FuncDecl, args []*Expr, body *Expr) { - cArgs := make([]C.Z3_ast, len(args)) - for i, a := range args { - cArgs[i] = a.ptr - } - var argsPtr *C.Z3_ast - if len(args) > 0 { - argsPtr = &cArgs[0] - } + _, argsPtr := exprsToASTs(args) C.Z3_add_rec_def(c.ptr, f.ptr, C.uint(len(args)), argsPtr, body.ptr) } // MkApp creates a function application. func (c *Context) MkApp(decl *FuncDecl, args ...*Expr) *Expr { - cArgs := make([]C.Z3_ast, len(args)) - for i, a := range args { - cArgs[i] = a.ptr - } - var argsPtr *C.Z3_ast - if len(args) > 0 { - argsPtr = &cArgs[0] - } + _, argsPtr := exprsToASTs(args) return newExpr(c, C.Z3_mk_app(c.ptr, decl.ptr, C.uint(len(args)), argsPtr)) } @@ -691,16 +643,8 @@ func (e *Expr) Substitute(from, to []*Expr) *Expr { // SubstituteVars replaces free variables in the expression with the expressions in to. // Variable with de-Bruijn index i is replaced with to[i]. func (e *Expr) SubstituteVars(to []*Expr) *Expr { - n := len(to) - cTo := make([]C.Z3_ast, n) - for i, t := range to { - cTo[i] = t.ptr - } - var toPtr *C.Z3_ast - if n > 0 { - toPtr = &cTo[0] - } - return newExpr(e.ctx, C.Z3_substitute_vars(e.ctx.ptr, e.ptr, C.uint(n), toPtr)) + _, toPtr := exprsToASTs(to) + return newExpr(e.ctx, C.Z3_substitute_vars(e.ctx.ptr, e.ptr, C.uint(len(to)), toPtr)) } // SubstituteFuns replaces every occurrence of from[i] applied to arguments @@ -816,12 +760,7 @@ func (q *Quantifier) String() string { // MkQuantifier creates a quantifier with patterns func (c *Context) MkQuantifier(isForall bool, weight int, sorts []*Sort, names []*Symbol, body *Expr, patterns []*Pattern) *Quantifier { - var forallInt C.bool - if isForall { - forallInt = true - } else { - forallInt = false - } + forallInt := C.bool(isForall) numBound := len(sorts) if numBound != len(names) { @@ -864,12 +803,7 @@ func (c *Context) MkQuantifier(isForall bool, weight int, sorts []*Sort, names [ // MkQuantifierConst creates a quantifier using constant bound variables func (c *Context) MkQuantifierConst(isForall bool, weight int, bound []*Expr, body *Expr, patterns []*Pattern) *Quantifier { - var forallInt C.bool - if isForall { - forallInt = true - } else { - forallInt = false - } + forallInt := C.bool(isForall) numBound := len(bound) var cBound []C.Z3_app