From a5d588ce0979c5a60b716e88e87249a65beb5074 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Tue, 5 Apr 2022 04:26:40 +0200 Subject: [PATCH] add example for #5933 --- examples/python/visitor.py | 76 +++++++++++++++++++++++++++++++++----- 1 file changed, 67 insertions(+), 9 deletions(-) diff --git a/examples/python/visitor.py b/examples/python/visitor.py index 504e2acc8..78ec824fb 100644 --- a/examples/python/visitor.py +++ b/examples/python/visitor.py @@ -17,13 +17,71 @@ def visitor(e, seen): yield e return -x, y = Ints('x y') -fml = x + x + y > 2 -seen = {} -for e in visitor(fml, seen): - if is_const(e) and e.decl().kind() == Z3_OP_UNINTERPRETED: - print("Variable", e) - else: - print(e) - +def modify(e, fn): + seen = {} + def visit(e): + if e in seen: + pass + elif fn(e) is not None: + seen[e] = fn(e) + elif is_and(e): + chs = [visit(ch) for ch in e.children()] + seen[e] = And(chs) + elif is_or(e): + chs = [visit(ch) for ch in e.children()] + seen[e] = Or(chs) + elif is_app(e): + chs = [visit(ch) for ch in e.children()] + seen[e] = e.decl()(chs) + elif is_quantifier(e): + # Note: does not work for Lambda that requires a separate case + body = visit(e.body()) + is_forall = e.is_forall() + num_pats = e.num_patterns() + pats = (Pattern * num_pats)() + for i in range(num_pats): + pats[i] = e.pattern(i).ast + + num_decls = e.num_vars() + sorts = (Sort * num_decls)() + names = (Symbol * num_decls)() + for i in range(num_decls): + sorts[i] = e.var_sort(i).ast + names[i] = to_symbol(e.var_name(i), e.ctx) + r = QuantifierRef(Z3_mk_quantifier(e.ctx_ref(), is_forall, e.weight(), num_pats, pats, num_decls, sorts, names, body.ast), e.ctx) + seen[e] = r + else: + seen[e] = e + return seen[e] + return visit(e) + +if __name__ == "__main__": + x, y = Ints('x y') + fml = x + x + y > 2 + seen = {} + for e in visitor(fml, seen): + if is_const(e) and e.decl().kind() == Z3_OP_UNINTERPRETED: + print("Variable", e) + else: + print(e) + + s = SolverFor("HORN") + inv = Function('inv', IntSort(), IntSort(), BoolSort()) + i, ip, j, jp = Ints('i ip j jp') + s.add(ForAll([i, j], Implies(i == 0, inv(i, j)))) + s.add(ForAll([i, ip, j, jp], Implies(And(inv(i, j), i < 10, ip == i + 1), inv(ip, jp)))) + s.add(ForAll([i, j], Implies(And(inv(i, j), i >= 10), i == 10))) + + a0, a1, a2 = Ints('a0 a1 a2') + b0, b1, b2 = Ints('b0 b1 b2') + x = Var(0, IntSort()) + y = Var(1, IntSort()) + template = And(a0 + a1*x + a2*y >= 0, b0 + b1*x + b2*y >= 0) + def update(e): + if is_app(e) and eq(e.decl(), inv): + return substitute_vars(template, (e.arg(0)), e.arg(1)) + return None + for f in s.assertions(): + f_new = modify(f, update) + print(f_new)