mirror of
https://github.com/Z3Prover/z3
synced 2025-04-23 00:55:31 +00:00
inequality propagation
Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
This commit is contained in:
parent
a4696a1c27
commit
6a829f831d
7 changed files with 720 additions and 56 deletions
383
scripts/fixplex.py
Normal file
383
scripts/fixplex.py
Normal file
|
@ -0,0 +1,383 @@
|
|||
#
|
||||
# The following script synthesizes case analysis for bounds propagation with inequalities.
|
||||
# There are two versions of the script: non-strict and strict inequality v <= w, v < w,
|
||||
# respectively.
|
||||
#
|
||||
# It is used for code in src/math/polysat/fixplex_def.h
|
||||
#
|
||||
|
||||
from z3 import *
|
||||
|
||||
nb = 12
|
||||
v = BitVec("v", nb)
|
||||
w = BitVec("w", nb)
|
||||
i, j, k, l = BitVecs('lo(v) hi(v) lo(w) hi(w)', nb)
|
||||
|
||||
def in_bounds(x, i, j):
|
||||
return Or([And(ULT(i, j), ULE(i, x), ULT(x, j)),
|
||||
And(ULT(j, i), j != 0, ULE(i, x)),
|
||||
And(ULT(j, i), j != 0, ULT(x, j)),
|
||||
And(ULT(j, i), j == 0, ULE(i, x)),
|
||||
i == j])
|
||||
|
||||
def at_upper(x, i, j):
|
||||
return Or([i == j, x + 1 == j])
|
||||
|
||||
|
||||
s = Solver()
|
||||
s0 = Solver()
|
||||
s1 = Solver()
|
||||
s.add(in_bounds(v, i, j))
|
||||
s.add(in_bounds(w, k, l))
|
||||
s1.add(in_bounds(v, i, j))
|
||||
s1.add(in_bounds(w, k, l))
|
||||
|
||||
s.set("core.minimize", True)
|
||||
s1.set("core.minimize", True)
|
||||
|
||||
def add_def(name, p):
|
||||
b = Bool(name)
|
||||
s.add(b == p)
|
||||
s0.add(b == p)
|
||||
s1.add(b == p)
|
||||
return b
|
||||
|
||||
is_free_v = add_def("is_free(v)", i == j)
|
||||
is_free_w = add_def("is_free(w)", k == l)
|
||||
is_zero_lo_v = add_def("lo(v) == 0", i == 0)
|
||||
is_zero_lo_w = add_def("lo(w) == 0", k == 0)
|
||||
s.add(Implies(is_free_v, is_zero_lo_v))
|
||||
s.add(Implies(is_free_w, is_zero_lo_w))
|
||||
s0.add(Implies(is_free_v, is_zero_lo_v))
|
||||
s0.add(Implies(is_free_w, is_zero_lo_w))
|
||||
s1.add(Implies(is_free_v, is_zero_lo_v))
|
||||
s1.add(Implies(is_free_w, is_zero_lo_w))
|
||||
|
||||
preds = [add_def("lo(v) <= hi(v)", ULE(i, j)),
|
||||
add_def("lo(w) <= hi(w)", ULE(k, l)),
|
||||
add_def("hi(v) <= lo(w)", ULE(j, k)),
|
||||
add_def("lo(w) <= hi(v)", ULE(k, j)),
|
||||
add_def("lo(v) <= lo(w)", ULE(i, k)),
|
||||
add_def("lo(w) <= lo(v)", ULE(k, i)),
|
||||
add_def("hi(w) <= lo(v)", ULE(l, i)),
|
||||
add_def("lo(v) <= hi(w)", ULE(i, l)),
|
||||
add_def("hi(w) <= hi(v)", ULE(l, j)),
|
||||
add_def("hi(v) <= hi(w)", ULE(j, l)),
|
||||
is_zero_lo_v,
|
||||
add_def("hi(v) == 0", j == 0),
|
||||
is_zero_lo_w,
|
||||
add_def("hi(w) == 0", l == 0),
|
||||
add_def("hi(v) == 1", j == 1),
|
||||
add_def("hi(w) == 1", l == 1),
|
||||
add_def("is_fixed(v)", i + 1 == j),
|
||||
add_def("is_fixed(w)", k + 1 == l),
|
||||
add_def("lo(v) + 1 == hi(w)", i + 1 == l),
|
||||
add_def("lo(v) + 1 == 0", i + 1 == 0),
|
||||
is_free_v,
|
||||
is_free_w
|
||||
]
|
||||
|
||||
def is_tight(s, core, x, lo, hi):
|
||||
s.push()
|
||||
s.add(core)
|
||||
s.add(Not(in_bounds(x, lo, hi)))
|
||||
r = s.check()
|
||||
s.pop()
|
||||
if unsat != r:
|
||||
return False
|
||||
s.push()
|
||||
s.add(core)
|
||||
s.add(x == lo)
|
||||
r = s.check()
|
||||
s.pop()
|
||||
if sat != r:
|
||||
return False
|
||||
s.push()
|
||||
s.add(core)
|
||||
s.add(x + 1 == hi, hi != lo)
|
||||
r = s.check()
|
||||
s.pop()
|
||||
if sat != r:
|
||||
return False
|
||||
#print(core, x, lo, hi)
|
||||
#print(core)
|
||||
#print(Not(in_bounds(x, lo, hi)))
|
||||
#print(s)
|
||||
return True
|
||||
|
||||
def is_tighter(s, core, x, lo1, hi1, lo2, hi2):
|
||||
s.push()
|
||||
s.add(core)
|
||||
s.add(in_bounds(x, lo1, hi1))
|
||||
s.add(Not(in_bounds(x, lo2, hi2)))
|
||||
r = s.check()
|
||||
s.pop()
|
||||
return r == unsat
|
||||
|
||||
def core2deps(core):
|
||||
deps = set([])
|
||||
for c in core:
|
||||
sc = f"{c}"
|
||||
if "lo(v)" in sc:
|
||||
deps |= { "vlo" }
|
||||
if "lo(w)" in sc:
|
||||
deps |= { "wlo" }
|
||||
if "hi(v)" in sc:
|
||||
deps |= { "vhi" }
|
||||
if "hi(w)" in sc:
|
||||
deps |= { "whi" }
|
||||
if "fixed(v)" in sc:
|
||||
deps |= { "vlo", "vhi" }
|
||||
if "fixed(w)" in sc:
|
||||
deps |= { "wlo", "whi" }
|
||||
deps = list(deps)
|
||||
sorted(deps)
|
||||
return ", ".join(deps)
|
||||
|
||||
def core2pred(core):
|
||||
return " && ".join([f"!({c.arg(0)})" if is_not(c) else f"{c}" for c in core ])
|
||||
|
||||
|
||||
def propagate_bounds(core, x, lo, hi):
|
||||
deps = core2deps(core)
|
||||
sys.stdout.write("if (")
|
||||
sys.stdout.write(core2pred(core))
|
||||
sys.stdout.write(f" && !new_bound(i, {x}, {lo}, {hi}, {deps}))\n")
|
||||
sys.stdout.write(" return false;\n")
|
||||
sys.stdout.flush()
|
||||
|
||||
def propagate_conflict(core):
|
||||
deps = core2deps(core)
|
||||
sys.stdout.write("if (")
|
||||
sys.stdout.write(core2pred(core))
|
||||
sys.stdout.write(f")\n")
|
||||
sys.stdout.write(f" return conflict({deps}), false;\n")
|
||||
sys.stdout.flush()
|
||||
|
||||
lows = [BitVecVal(0, nb), l, k, i, j, k + 1, i + 1]
|
||||
highs = [BitVecVal(0, nb), l, k, i, j, l - 1, j - 1]
|
||||
|
||||
def find_new_bounds(s, core, x):
|
||||
bound = None
|
||||
for lo in lows:
|
||||
for hi in highs:
|
||||
if is_tight(s, core, x, lo, hi):
|
||||
if not bound:
|
||||
bound = (lo, hi)
|
||||
else:
|
||||
lo2, hi2 = bound
|
||||
if is_tighter(s, core, x, lo, hi, lo2, hi2):
|
||||
#print("tighter", lo, hi, lo2, hi2)
|
||||
bound = (lo, hi)
|
||||
|
||||
if bound:
|
||||
lo, hi = bound
|
||||
propagate_bounds(core, x, lo, hi)
|
||||
else:
|
||||
print("Could not find new bounds", x, lows, highs)
|
||||
|
||||
|
||||
|
||||
|
||||
num_tries = 0
|
||||
num_found = 0
|
||||
num_nodes = 0
|
||||
|
||||
# set_param(verbose=2)
|
||||
|
||||
def explore(s, s0, ps):
|
||||
global num_tries
|
||||
global num_found
|
||||
num_tries += 1
|
||||
r = s.check(ps)
|
||||
if r == unsat:
|
||||
core = s.unsat_core()
|
||||
propagate_conflict(core)
|
||||
s0.add(Not(And(core)))
|
||||
num_found += 1
|
||||
|
||||
return
|
||||
|
||||
found = False
|
||||
s.push()
|
||||
s.add(v == i)
|
||||
r = s.check(ps)
|
||||
if r == unsat:
|
||||
core = s.unsat_core()
|
||||
s0.add(Not(And(core)))
|
||||
found = True
|
||||
s.pop()
|
||||
if r == unsat:
|
||||
find_new_bounds(s, core, v)
|
||||
|
||||
s.push()
|
||||
s.add(w == k)
|
||||
r = s.check(ps)
|
||||
if r == unsat:
|
||||
core = s.unsat_core()
|
||||
s0.add(Not(And(core)))
|
||||
found = True
|
||||
s.pop()
|
||||
if r == unsat:
|
||||
find_new_bounds(s, core, w)
|
||||
|
||||
s.push()
|
||||
s.add(at_upper(v, i, j))
|
||||
r = s.check(ps)
|
||||
if r == unsat:
|
||||
core = s.unsat_core()
|
||||
s0.add(Not(And(core)))
|
||||
found = True
|
||||
s.pop()
|
||||
if r == unsat:
|
||||
find_new_bounds(s, core, v)
|
||||
|
||||
s.push()
|
||||
s.add(at_upper(w, k, l))
|
||||
r = s.check(ps)
|
||||
if r == unsat:
|
||||
core = s.unsat_core()
|
||||
s0.add(Not(And(core)))
|
||||
found = True
|
||||
s.pop()
|
||||
if r == unsat:
|
||||
find_new_bounds(s, core, w)
|
||||
|
||||
if found:
|
||||
num_found += 1
|
||||
|
||||
|
||||
def search(s, s0, trail, preds):
|
||||
global num_nodes
|
||||
num_nodes += 1
|
||||
r = s0.check(trail)
|
||||
if r == unsat:
|
||||
return
|
||||
if len(preds) == 0:
|
||||
explore(s, s0, trail)
|
||||
return
|
||||
hd = preds[0]
|
||||
tl = preds[1:]
|
||||
search(s, s0, trail + [hd], tl)
|
||||
search(s, s0, trail + [Not(hd)], tl)
|
||||
|
||||
def create_bounds(p):
|
||||
global num_tries
|
||||
global num_found
|
||||
global num_nodes
|
||||
num_tries = 0
|
||||
num_found = 0
|
||||
num_nodes = 0
|
||||
s0.push()
|
||||
s.push()
|
||||
s.add(p)
|
||||
search(s, s0, [], preds)
|
||||
s.pop()
|
||||
s0.pop()
|
||||
print("attempted predicates: ", num_tries, "predicates: ", num_found, "nodes: ", num_nodes)
|
||||
|
||||
def search_primal():
|
||||
print("strict")
|
||||
create_bounds(ULT(v, w))
|
||||
print("non-strict")
|
||||
create_bounds(ULE(v, w))
|
||||
|
||||
#search_primal()
|
||||
|
||||
def extract_predicates(s):
|
||||
for p in preds:
|
||||
r = s.check(p)
|
||||
if r == sat:
|
||||
yield p
|
||||
r = s.check(Not(p))
|
||||
if r == sat:
|
||||
yield Not(p)
|
||||
|
||||
def test_le(ineq, lov, hiv, low, hiw):
|
||||
if lov == hiv and lov > 0:
|
||||
return
|
||||
if low == hiw and low > 0:
|
||||
return
|
||||
s0.push()
|
||||
s0.add(i == lov)
|
||||
s0.add(j == hiv)
|
||||
s0.add(k == low)
|
||||
s0.add(l == hiw)
|
||||
r = s0.check()
|
||||
s0.pop()
|
||||
if r == unsat:
|
||||
return
|
||||
s.push()
|
||||
s.add(i == lov)
|
||||
s.add(j == hiv)
|
||||
s.add(k == low)
|
||||
s.add(l == hiw)
|
||||
r = s.check()
|
||||
assert r == sat
|
||||
|
||||
preds = list(extract_predicates(s))
|
||||
s.add(ineq)
|
||||
if r == unsat:
|
||||
print("core", preds)
|
||||
s.pop()
|
||||
return
|
||||
|
||||
def test_bound(x, p):
|
||||
s.push()
|
||||
s.add(p)
|
||||
r = s.check()
|
||||
s.pop()
|
||||
if r == unsat:
|
||||
s1.push()
|
||||
s1.add(p)
|
||||
s1.add(ineq)
|
||||
r = s1.check(preds)
|
||||
if r == unsat:
|
||||
core = [c for c in s1.unsat_core()]
|
||||
else:
|
||||
print("Did not find core for lower bound v")
|
||||
print(lov, hiv, low, hiw)
|
||||
print(s1)
|
||||
for p in preds:
|
||||
print(p)
|
||||
s1.pop()
|
||||
if r == unsat:
|
||||
s1.push()
|
||||
s1.add(ineq)
|
||||
r = s1.check(core)
|
||||
if r == unsat:
|
||||
propagate_conflict(core)
|
||||
else:
|
||||
find_new_bounds(s1, core, x)
|
||||
s1.pop()
|
||||
s0.add(Not(And(core)))
|
||||
|
||||
test_bound(v, v == i)
|
||||
test_bound(w, w == k)
|
||||
test_bound(v, at_upper(v, i, j))
|
||||
test_bound(w, at_upper(w, k, l))
|
||||
s.pop()
|
||||
|
||||
|
||||
bounds = [0, 1, 2, 3, 10, 2**nb - 3, 2**nb - 2, 2**nb - 1]
|
||||
|
||||
def search_dual(p):
|
||||
for i in bounds:
|
||||
for j in bounds:
|
||||
for k in bounds:
|
||||
for l in bounds:
|
||||
test_le(p, i, j, k, l)
|
||||
|
||||
|
||||
s0.push()
|
||||
s1.push()
|
||||
print("strict")
|
||||
search_dual(ULT(v, w))
|
||||
s0.pop()
|
||||
s1.pop()
|
||||
|
||||
print("non-strict")
|
||||
search_dual(ULE(v, w))
|
||||
|
||||
|
||||
|
Loading…
Add table
Add a link
Reference in a new issue