From 6a829f831db8a321a1cd0ddfaf4bdf882aa67c82 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Sun, 8 Aug 2021 13:21:15 -0700 Subject: [PATCH] inequality propagation Signed-off-by: Nikolaj Bjorner --- scripts/fixplex.py | 383 +++++++++++++++++++++++++++ src/math/interval/mod_interval.h | 9 +- src/math/interval/mod_interval_def.h | 51 +++- src/math/polysat/fixplex.h | 2 + src/math/polysat/fixplex_def.h | 230 +++++++++++++--- src/test/fixplex.cpp | 59 ++++- src/test/mod_interval.cpp | 42 ++- 7 files changed, 720 insertions(+), 56 deletions(-) create mode 100644 scripts/fixplex.py diff --git a/scripts/fixplex.py b/scripts/fixplex.py new file mode 100644 index 000000000..3119f9321 --- /dev/null +++ b/scripts/fixplex.py @@ -0,0 +1,383 @@ +# +# The following script synthesizes case analysis for bounds propagation with inequalities. +# There are two versions of the script: non-strict and strict inequality v <= w, v < w, +# respectively. +# +# It is used for code in src/math/polysat/fixplex_def.h +# + +from z3 import * + +nb = 12 +v = BitVec("v", nb) +w = BitVec("w", nb) +i, j, k, l = BitVecs('lo(v) hi(v) lo(w) hi(w)', nb) + +def in_bounds(x, i, j): + return Or([And(ULT(i, j), ULE(i, x), ULT(x, j)), + And(ULT(j, i), j != 0, ULE(i, x)), + And(ULT(j, i), j != 0, ULT(x, j)), + And(ULT(j, i), j == 0, ULE(i, x)), + i == j]) + +def at_upper(x, i, j): + return Or([i == j, x + 1 == j]) + + +s = Solver() +s0 = Solver() +s1 = Solver() +s.add(in_bounds(v, i, j)) +s.add(in_bounds(w, k, l)) +s1.add(in_bounds(v, i, j)) +s1.add(in_bounds(w, k, l)) + +s.set("core.minimize", True) +s1.set("core.minimize", True) + +def add_def(name, p): + b = Bool(name) + s.add(b == p) + s0.add(b == p) + s1.add(b == p) + return b + +is_free_v = add_def("is_free(v)", i == j) +is_free_w = add_def("is_free(w)", k == l) +is_zero_lo_v = add_def("lo(v) == 0", i == 0) +is_zero_lo_w = add_def("lo(w) == 0", k == 0) +s.add(Implies(is_free_v, is_zero_lo_v)) +s.add(Implies(is_free_w, is_zero_lo_w)) +s0.add(Implies(is_free_v, is_zero_lo_v)) +s0.add(Implies(is_free_w, is_zero_lo_w)) +s1.add(Implies(is_free_v, is_zero_lo_v)) +s1.add(Implies(is_free_w, is_zero_lo_w)) + +preds = [add_def("lo(v) <= hi(v)", ULE(i, j)), + add_def("lo(w) <= hi(w)", ULE(k, l)), + add_def("hi(v) <= lo(w)", ULE(j, k)), + add_def("lo(w) <= hi(v)", ULE(k, j)), + add_def("lo(v) <= lo(w)", ULE(i, k)), + add_def("lo(w) <= lo(v)", ULE(k, i)), + add_def("hi(w) <= lo(v)", ULE(l, i)), + add_def("lo(v) <= hi(w)", ULE(i, l)), + add_def("hi(w) <= hi(v)", ULE(l, j)), + add_def("hi(v) <= hi(w)", ULE(j, l)), + is_zero_lo_v, + add_def("hi(v) == 0", j == 0), + is_zero_lo_w, + add_def("hi(w) == 0", l == 0), + add_def("hi(v) == 1", j == 1), + add_def("hi(w) == 1", l == 1), + add_def("is_fixed(v)", i + 1 == j), + add_def("is_fixed(w)", k + 1 == l), + add_def("lo(v) + 1 == hi(w)", i + 1 == l), + add_def("lo(v) + 1 == 0", i + 1 == 0), + is_free_v, + is_free_w + ] + +def is_tight(s, core, x, lo, hi): + s.push() + s.add(core) + s.add(Not(in_bounds(x, lo, hi))) + r = s.check() + s.pop() + if unsat != r: + return False + s.push() + s.add(core) + s.add(x == lo) + r = s.check() + s.pop() + if sat != r: + return False + s.push() + s.add(core) + s.add(x + 1 == hi, hi != lo) + r = s.check() + s.pop() + if sat != r: + return False + #print(core, x, lo, hi) + #print(core) + #print(Not(in_bounds(x, lo, hi))) + #print(s) + return True + +def is_tighter(s, core, x, lo1, hi1, lo2, hi2): + s.push() + s.add(core) + s.add(in_bounds(x, lo1, hi1)) + s.add(Not(in_bounds(x, lo2, hi2))) + r = s.check() + s.pop() + return r == unsat + +def core2deps(core): + deps = set([]) + for c in core: + sc = f"{c}" + if "lo(v)" in sc: + deps |= { "vlo" } + if "lo(w)" in sc: + deps |= { "wlo" } + if "hi(v)" in sc: + deps |= { "vhi" } + if "hi(w)" in sc: + deps |= { "whi" } + if "fixed(v)" in sc: + deps |= { "vlo", "vhi" } + if "fixed(w)" in sc: + deps |= { "wlo", "whi" } + deps = list(deps) + sorted(deps) + return ", ".join(deps) + +def core2pred(core): + return " && ".join([f"!({c.arg(0)})" if is_not(c) else f"{c}" for c in core ]) + + +def propagate_bounds(core, x, lo, hi): + deps = core2deps(core) + sys.stdout.write("if (") + sys.stdout.write(core2pred(core)) + sys.stdout.write(f" && !new_bound(i, {x}, {lo}, {hi}, {deps}))\n") + sys.stdout.write(" return false;\n") + sys.stdout.flush() + +def propagate_conflict(core): + deps = core2deps(core) + sys.stdout.write("if (") + sys.stdout.write(core2pred(core)) + sys.stdout.write(f")\n") + sys.stdout.write(f" return conflict({deps}), false;\n") + sys.stdout.flush() + +lows = [BitVecVal(0, nb), l, k, i, j, k + 1, i + 1] +highs = [BitVecVal(0, nb), l, k, i, j, l - 1, j - 1] + +def find_new_bounds(s, core, x): + bound = None + for lo in lows: + for hi in highs: + if is_tight(s, core, x, lo, hi): + if not bound: + bound = (lo, hi) + else: + lo2, hi2 = bound + if is_tighter(s, core, x, lo, hi, lo2, hi2): + #print("tighter", lo, hi, lo2, hi2) + bound = (lo, hi) + + if bound: + lo, hi = bound + propagate_bounds(core, x, lo, hi) + else: + print("Could not find new bounds", x, lows, highs) + + + + +num_tries = 0 +num_found = 0 +num_nodes = 0 + +# set_param(verbose=2) + +def explore(s, s0, ps): + global num_tries + global num_found + num_tries += 1 + r = s.check(ps) + if r == unsat: + core = s.unsat_core() + propagate_conflict(core) + s0.add(Not(And(core))) + num_found += 1 + + return + + found = False + s.push() + s.add(v == i) + r = s.check(ps) + if r == unsat: + core = s.unsat_core() + s0.add(Not(And(core))) + found = True + s.pop() + if r == unsat: + find_new_bounds(s, core, v) + + s.push() + s.add(w == k) + r = s.check(ps) + if r == unsat: + core = s.unsat_core() + s0.add(Not(And(core))) + found = True + s.pop() + if r == unsat: + find_new_bounds(s, core, w) + + s.push() + s.add(at_upper(v, i, j)) + r = s.check(ps) + if r == unsat: + core = s.unsat_core() + s0.add(Not(And(core))) + found = True + s.pop() + if r == unsat: + find_new_bounds(s, core, v) + + s.push() + s.add(at_upper(w, k, l)) + r = s.check(ps) + if r == unsat: + core = s.unsat_core() + s0.add(Not(And(core))) + found = True + s.pop() + if r == unsat: + find_new_bounds(s, core, w) + + if found: + num_found += 1 + + +def search(s, s0, trail, preds): + global num_nodes + num_nodes += 1 + r = s0.check(trail) + if r == unsat: + return + if len(preds) == 0: + explore(s, s0, trail) + return + hd = preds[0] + tl = preds[1:] + search(s, s0, trail + [hd], tl) + search(s, s0, trail + [Not(hd)], tl) + +def create_bounds(p): + global num_tries + global num_found + global num_nodes + num_tries = 0 + num_found = 0 + num_nodes = 0 + s0.push() + s.push() + s.add(p) + search(s, s0, [], preds) + s.pop() + s0.pop() + print("attempted predicates: ", num_tries, "predicates: ", num_found, "nodes: ", num_nodes) + +def search_primal(): + print("strict") + create_bounds(ULT(v, w)) + print("non-strict") + create_bounds(ULE(v, w)) + +#search_primal() + +def extract_predicates(s): + for p in preds: + r = s.check(p) + if r == sat: + yield p + r = s.check(Not(p)) + if r == sat: + yield Not(p) + +def test_le(ineq, lov, hiv, low, hiw): + if lov == hiv and lov > 0: + return + if low == hiw and low > 0: + return + s0.push() + s0.add(i == lov) + s0.add(j == hiv) + s0.add(k == low) + s0.add(l == hiw) + r = s0.check() + s0.pop() + if r == unsat: + return + s.push() + s.add(i == lov) + s.add(j == hiv) + s.add(k == low) + s.add(l == hiw) + r = s.check() + assert r == sat + + preds = list(extract_predicates(s)) + s.add(ineq) + if r == unsat: + print("core", preds) + s.pop() + return + + def test_bound(x, p): + s.push() + s.add(p) + r = s.check() + s.pop() + if r == unsat: + s1.push() + s1.add(p) + s1.add(ineq) + r = s1.check(preds) + if r == unsat: + core = [c for c in s1.unsat_core()] + else: + print("Did not find core for lower bound v") + print(lov, hiv, low, hiw) + print(s1) + for p in preds: + print(p) + s1.pop() + if r == unsat: + s1.push() + s1.add(ineq) + r = s1.check(core) + if r == unsat: + propagate_conflict(core) + else: + find_new_bounds(s1, core, x) + s1.pop() + s0.add(Not(And(core))) + + test_bound(v, v == i) + test_bound(w, w == k) + test_bound(v, at_upper(v, i, j)) + test_bound(w, at_upper(w, k, l)) + s.pop() + + +bounds = [0, 1, 2, 3, 10, 2**nb - 3, 2**nb - 2, 2**nb - 1] + +def search_dual(p): + for i in bounds: + for j in bounds: + for k in bounds: + for l in bounds: + test_le(p, i, j, k, l) + + +s0.push() +s1.push() +print("strict") +search_dual(ULT(v, w)) +s0.pop() +s1.pop() + +print("non-strict") +search_dual(ULE(v, w)) + + + diff --git a/src/math/interval/mod_interval.h b/src/math/interval/mod_interval.h index 899395fd0..ab4bdfb6c 100644 --- a/src/math/interval/mod_interval.h +++ b/src/math/interval/mod_interval.h @@ -27,10 +27,15 @@ struct pp { pp(Numeral const& n):n(n) {} }; + +inline std::ostream& operator<<(std::ostream& out, pp const& p) { + return out << (unsigned)p.n; +} + template inline std::ostream& operator<<(std::ostream& out, pp const& p) { - if ((0 - p.n) < p.n) - return out << "-" << (0 - p.n); + if ((Numeral)(0 - p.n) < p.n) + return out << "-" << (Numeral)(0 - p.n); return out << p.n; } diff --git a/src/math/interval/mod_interval_def.h b/src/math/interval/mod_interval_def.h index 4a02e41b2..6182c1b93 100644 --- a/src/math/interval/mod_interval_def.h +++ b/src/math/interval/mod_interval_def.h @@ -95,19 +95,46 @@ mod_interval mod_interval::operator&(mod_interval const& other return other; if (other.is_free() || is_empty()) return *this; - if (contains(other.lo)) - l = other.lo; - else if (other.contains(lo)) - l = lo; - else - return mod_interval::empty(); - if (contains(other.hi - 1)) - h = other.hi; - else if (other.contains(hi - 1)) - h = hi; - else - return mod_interval::empty(); + + if (lo < hi || hi == 0) { + if (other.lo < other.hi || other.hi == 0) { + if (hi != 0 && hi <= other.lo) + return mod_interval::empty(); + if (other.hi != 0 && other.hi <= lo) + return mod_interval::empty(); + l = lo >= other.lo ? lo : other.lo; + h = hi == 0 ? other.hi : (other.hi == 0 ? hi : (hi <= other.hi ? hi : other.hi)); + return mod_interval(l, h); + } + SASSERT(0 < other.hi && other.hi < other.lo); + if (other.lo <= lo) + return *this; + if (other.hi <= lo && lo < hi && hi <= other.lo) + return mod_interval::empty(); + if (lo <= other.hi && other.hi <= hi && hi <= other.lo) + return mod_interval(lo, other.hi); + if (hi == 0 && lo < other.hi) + return *this; + if (hi == 0 && other.hi <= lo) + return mod_interval(other.lo, hi); + if (other.hi <= lo && other.hi <= hi) + return mod_interval(other.lo, hi); + return *this; + } + SASSERT(hi < lo); + if (other.lo < other.hi || other.hi == 0) + return other & *this; + SASSERT(other.hi < other.lo); + SASSERT(hi != 0); + SASSERT(other.hi != 0); + if (lo <= other.hi) + return *this; + if (other.lo <= hi) + return other; + l = lo <= other.lo ? other.lo : lo; + h = hi >= other.hi ? other.hi : hi; return mod_interval(l, h); + } template diff --git a/src/math/polysat/fixplex.h b/src/math/polysat/fixplex.h index 8cbf49956..7b65f112f 100644 --- a/src/math/polysat/fixplex.h +++ b/src/math/polysat/fixplex.h @@ -253,6 +253,8 @@ namespace polysat { void eq_eh(var_t x, var_t y, row const& r1, row const& r2); lbool propagate_bounds(row const& r); bool propagate_bounds(ineq const& i); + bool propagate_strict_bounds(ineq const& i); + bool propagate_non_strict_bounds(ineq const& i); bool new_bound(row const& r, var_t x, mod_interval const& range); bool new_bound(ineq const& i, var_t x, numeral const& lo, numeral const& hi, u_dependency* a = nullptr, u_dependency* b = nullptr, u_dependency* c = nullptr, u_dependency* d = nullptr); void conflict(ineq const& i, u_dependency* a = nullptr, u_dependency* b = nullptr, u_dependency* c = nullptr, u_dependency* d = nullptr); diff --git a/src/math/polysat/fixplex_def.h b/src/math/polysat/fixplex_def.h index 3c61494ec..aa223926a 100644 --- a/src/math/polysat/fixplex_def.h +++ b/src/math/polysat/fixplex_def.h @@ -1174,51 +1174,205 @@ namespace polysat { } template - bool fixplex::propagate_bounds(ineq const& i) { - // v < w & lo(v) + 1 = 0 -> conflict - // v < w & lo(w) = 0 & hi(w) = 1 -> conflict - // v < w & hi(w) != 0 & lo(w) <= hi(w) & hi(w) - 1 <= lo(v) -> conflict - // v <= w & hi(w) != 0 & lo(w) <= hi(w) & hi(w) <= lo(v) -> conflict - // v < w & hi(w) != 0 & lo(w) <= hi(w) <= hi(v) -> hi(v) := hi(w) - 1 - // v < w & lo(w) <= lo(v) -> lo(w) := lo(v) + 1 - // v <= w & hi(v) > hi(w) -> hi(v) := hi(w) - // v <= w & lo(v) > lo(w) -> lo(w) := lo(w) + bool fixplex::propagate_strict_bounds(ineq const& i) { var_t v = i.v, w = i.w; bool s = i.strict; auto* vlo = m_vars[v].m_lo_dep, *vhi = m_vars[v].m_hi_dep; auto* wlo = m_vars[w].m_lo_dep, *whi = m_vars[w].m_hi_dep; - if (s && lo(v) + 1 == 0 && is_fixed(v)) - return conflict(i, vlo, vhi), false; - if (s && lo(w) == 0 && is_fixed(w)) - return conflict(i, wlo, whi), false; - if (s && hi(w) != 0 && lo(w) <= hi(w) && lo(v) <= hi(v) && hi(w) - 1 <= lo(v)) - return conflict(i, vlo, wlo, whi), false; - if (s && hi(v) == 0 && lo(w) < hi(w) && hi(w) - 1 <= lo(v)) - return conflict(i, vlo, vhi, wlo, whi), false; - if (!s && hi(w) != 0 && lo(w) <= hi(w) && hi(w) <= lo(v) && lo(v) <= hi(v)) - return conflict(i, vlo, vhi, wlo, whi), false; - if (!s && hi(w) != 0 && lo(w) <= hi(w) && hi(w) <= lo(v) && hi(v) == 0) - return conflict(i, vlo, vhi, wlo, whi), false; - if (s && hi(w) != 0 && lo(w) <= hi(w) && hi(w) <= hi(v) && !new_bound(i, v, lo(v), hi(w) - 1, wlo, vhi, whi)) - return false; - if (s && lo(w) <= lo(v) && !new_bound(i, w, lo(v) + 1, hi(w), vlo, wlo)) - return false; - if (s && hi(w) != 0 && hi(w) - 1 <= lo(v) && lo(v) <= hi(v) && hi(w) < lo(w) && !new_bound(i, w, lo(w), 0, wlo, whi, vlo, vhi)) - return false; - if (s && hi(w) == 1 && !is_fixed(w) && !new_bound(i, w, lo(w), 0, wlo, whi)) - return false; - if (!s && hi(v) > hi(w) && !new_bound(i, v, lo(v), hi(w), vhi, whi)) - return false; - if (!s && lo(v) > lo(w) && !new_bound(i, w, lo(v), hi(w), vlo, wlo)) - return false; - if (!s && hi(w) != 0 && hi(w) < lo(w) && hi(w) <= lo(v) && lo(v) <= hi(v) && !new_bound(i, w, lo(w), 0, vlo, vhi, wlo, whi)) - return false; - if (hi(w) != 0 && lo(w) <= hi(w) && hi(w) <= lo(v) && !new_bound(i, v, 0, hi(v), wlo, vlo, whi)) - return false; + if (lo(w) == 0 && !new_bound(i, w, lo(w) + 1, lo(w), wlo)) + return false; + if (hi(w) == 1 && !new_bound(i, w, lo(w), hi(w) - 1, whi)) + return false; + if (hi(w) <= hi(v) && lo(w) <= hi(w) && !(is_free(w)) && !new_bound(i, v, lo(v), hi(v) - 1, vhi, whi, wlo)) + return false; + if (hi(v) == 0 && lo(w) <= lo(v) && !new_bound(i, w, lo(v) + 1, hi(v), vhi, vlo, wlo)) + return false; + if (hi(v) == 0 && !(is_free(v)) && !new_bound(i, v, lo(v), hi(v) - 1, vhi)) + return false; + if (lo(w) <= lo(v) && lo(v) <= hi(v) && !new_bound(i, w, lo(v) + 1, lo(v), vlo, vhi, wlo)) + return false; + if (lo(v) + 1 == hi(w) && lo(v) <= hi(v) && !new_bound(i, w, lo(w), hi(w) - 1, vlo, vhi, whi)) + return false; + if (!(lo(v) <= hi(v)) && is_fixed(w) && lo(w) <= hi(v) && !new_bound(i, v, lo(v) + 1, hi(w) - 1, vlo, vhi, whi, wlo)) + return false; + if (lo(v) + 1 == hi(w) && lo(w) <= hi(w) && !new_bound(i, v, lo(v) + 1, hi(v), vlo, whi, wlo)) + return false; + if (is_fixed(v) && lo(v) <= hi(w) && hi(w) <= lo(v) && !(hi(v) == 1) && !new_bound(i, w, lo(v) + 1, hi(w) - 1, vlo, vhi, whi)) + return false; + if (!(hi(w) == 0) && hi(w) <= lo(v) && lo(v) <= hi(v) && !new_bound(i, w, lo(v) + 1, hi(w) - 1, vlo, vhi, whi)) + return false; + if (hi(w) <= lo(v) && lo(w) <= hi(w) && !(is_free(w)) && !new_bound(i, v, lo(v) + 1, hi(w) - 1, vlo, whi, wlo)) + return false; + if (lo(v) + 1 == hi(w) && hi(w) == 0 && !new_bound(i, v, lo(v) + 1, hi(v), vlo, whi)) + return false; + if (lo(v) + 1 == 0 && !new_bound(i, v, lo(v) + 1, hi(v), vlo)) + return false; + if (lo(w) < hi(w) && hi(w) <= lo(v) && !new_bound(i, v, 0, hi(v), vlo, vhi, whi, wlo)) + return false; + //return true; + + // manual patch + if (is_fixed(w) && lo(w) == 0) + return conflict(wlo, whi), false; + if (is_fixed(v) && hi(v) == 0) + return conflict(vlo, vhi), false; + if (!is_free(w) && (lo(w) <= hi(w) || hi(w) == 0) && (lo(v) < hi(v) || hi(v) == 0) && !new_bound(i, v, lo(v), hi(w) - 1, vlo, wlo, whi)) + return false; + if (!is_free(v) && (lo(w) <= hi(w) || hi(w) == 0) && (lo(v) < hi(v) || hi(v) == 0) && !new_bound(i, w, lo(v) + 1, hi(w), vlo, vhi, whi)) + return false; + if (lo(w) == 0 && !new_bound(i, w, 1, hi(w), wlo)) + return false; + if (lo(v) + 1 == 0 && !new_bound(i, v, 0, hi(v), vhi)) + return false; + if (lo(w) < hi(w) && (hi(w) <= hi(v) || hi(v) == 0) && !new_bound(i, v, lo(v), hi(w) - 1, vlo, vhi, wlo, whi)) + return false; + if (!is_fixed(w) && lo(v) + 1 == hi(w) && (lo(v) <= hi(v) || hi(v) == 0) && !new_bound(i, w, lo(w), hi(w) - 1, vlo, wlo, whi)) + return false; + if (lo(w) <= lo(v) && (lo(v) < hi(v) || lo(v) == 0) && !new_bound(i, w, lo(v) + 1, hi(w), vlo, vhi, wlo, whi)) + return false; + if (hi(w) <= lo(v) && (lo(v) < hi(v) || hi(v) == 0) && !new_bound(i, w, lo(w), 0, vlo, vhi, wlo, whi)) + return false; + if (lo(w) < hi(w) && hi(w) <= lo(v) && (lo(v) < hi(v) || hi(v) == 0)) + return conflict(vlo, vhi, wlo, whi), false; +// if (!is_free(w) && hi(v) < lo(v) && lo(w) != 0 && (lo(w) <= hi(w) || hi(w) == 0) && !new_bound(i, v, lo(w) - 1, hi(v), vlo, vhi, wlo, whi)) +// return false; + + + // automatically generated code + // see scripts/fixplex.py for script + + if (lo(w) == 0 && !new_bound(i, w, lo(w) + 1, lo(w), wlo)) + return false; + if (is_fixed(v) && hi(w) <= hi(v) && lo(w) <= hi(w) && !(is_free(w))) + return conflict(wlo, whi, vhi, vlo), false; + if (lo(w) <= lo(v) && lo(v) <= hi(v) && !new_bound(i, w, lo(v) + 1, lo(v), wlo, vhi, vlo)) + return false; + if (hi(w) <= hi(v) && lo(w) <= hi(w) && !(is_free(w)) && !new_bound(i, v, lo(v), hi(v) - 1, wlo, whi, vhi)) + return false; + if (hi(w) == 1 && !new_bound(i, w, lo(w), hi(w) - 1, whi)) + return false; + if (!(lo(v) == 0) && lo(v) <= hi(w) && hi(w) <= lo(v) && lo(v) <= hi(v) && !new_bound(i, w, lo(v) + 1, hi(w) - 1, whi, vhi, vlo)) + return false; + if (!(hi(w) == 0) && is_fixed(v) && hi(w) <= hi(v) && !new_bound(i, w, lo(v) + 1, hi(v) - 1, whi, vhi, vlo)) + return false; + if (!(lo(v) <= hi(w)) && !(hi(w) == 0) && lo(v) <= hi(v) && !new_bound(i, w, lo(v) + 1, hi(w) - 1, whi, vhi, vlo)) + return false; + if (!(lo(v) <= lo(w)) && is_fixed(w) && !new_bound(i, v, lo(v) + 1, hi(w) - 1, wlo, whi, vlo)) + return false; + if (hi(w) <= lo(v) && lo(w) <= hi(w) && !(is_free(w)) && !new_bound(i, v, lo(v) + 1, hi(w) - 1, wlo, whi, vlo)) + return false; + if (is_fixed(w) && hi(v) == 0 && lo(w) <= lo(v)) + return conflict(wlo, whi, vhi, vlo), false; + if (hi(v) == 0 && lo(w) <= lo(v) && !new_bound(i, w, lo(v) + 1, hi(v), wlo, vhi, vlo)) + return false; + if (hi(v) == 0 && !(is_free(v)) && !new_bound(i, v, lo(v), hi(v) - 1, vhi)) + return false; + if (is_fixed(w) && lo(w) <= lo(v) && !new_bound(i, v, lo(v) + 1, hi(w) - 1, wlo, whi, vlo)) + return false; return true; } + template + bool fixplex::propagate_non_strict_bounds(ineq const& i) { + var_t v = i.v, w = i.w; + bool s = i.strict; + auto* vlo = m_vars[v].m_lo_dep, *vhi = m_vars[v].m_hi_dep; + auto* wlo = m_vars[w].m_lo_dep, *whi = m_vars[w].m_hi_dep; + + // manual patch + if (lo(w) < lo(v) && (lo(v) < hi(v) || hi(v) == 0) && !new_bound(i, w, lo(v), hi(w), vlo, vhi, wlo, whi)) + return false; + if (!is_free(w) && (lo(w) <= hi(w) || hi(w) == 0) && (lo(v) < hi(v) || hi(v) == 0) && !new_bound(i, v, lo(v), hi(w), vlo, vhi, wlo, whi)) + return false; + if (!is_free(v) && (lo(w) <= hi(w) || hi(w) == 0) && (lo(v) < hi(v) || hi(v) == 0) && !new_bound(i, w, lo(v), hi(w), vlo, vhi, whi)) + return false; + if (hi(w) < lo(w) && hi(w) <= lo(v) && lo(v) < hi(v) && !new_bound(i, w, lo(w), 0, vlo, vhi, wlo, whi)) + return false; + if (lo(w) < hi(w) && hi(w) <= lo(v) && (lo(v) < hi(v) || hi(v) == 0)) + return conflict(vlo, vhi, wlo, whi), false; + + // automatically generated code. + // see scripts/fixplex.py for script + if (!(hi(w) <= lo(v)) && !(is_fixed(v)) && is_fixed(w) && hi(w) == 1 && !(hi(v) == 0) && !new_bound(i, v, 0, hi(w), vlo, wlo, vhi, whi)) + return false; + if (!(hi(v) <= lo(w)) && !(is_fixed(v)) && is_fixed(w) && lo(w) <= lo(v) && lo(v) <= lo(w) && !new_bound(i, v, 0, hi(w), vlo, wlo, vhi, whi)) + return false; + if (!(hi(v) <= hi(w)) && !(hi(w) <= lo(v)) && lo(w) <= lo(v) && !new_bound(i, v, 0, hi(w), wlo, vhi, vlo, whi)) + return false; + if (!(lo(w) <= lo(v)) && !(hi(v) <= hi(w)) && is_fixed(w) && lo(w) <= hi(w) && !new_bound(i, v, 0, hi(w), vlo, wlo, vhi, whi)) + return false; + if (!(lo(v) <= lo(w)) && hi(w) == 1 && lo(v) <= hi(w) && !new_bound(i, v, 0, hi(w), wlo, vlo, whi)) + return false; + if (is_fixed(w) && hi(w) <= lo(v) && lo(w) <= hi(w) && !new_bound(i, v, 0, hi(w), wlo, vlo, whi)) + return false; + if (!(lo(v) <= lo(w)) && lo(v) <= hi(w) && hi(w) <= lo(v) && !new_bound(i, v, 0, hi(w), wlo, vlo, whi)) + return false; + if (!(lo(v) <= hi(w)) && is_fixed(v) && lo(w) <= hi(w) && !new_bound(i, w, lo(v), 0, vhi, vlo, wlo, whi)) + return false; + if (!(is_fixed(w)) && !(hi(v) <= lo(w)) && is_fixed(v) && hi(v) <= hi(w) && hi(w) <= hi(v) && !new_bound(i, w, hi(w) - 1, hi(w), vlo, wlo, vhi, whi)) + return false; + if (!(lo(v) <= lo(w)) && !(hi(w) <= lo(v)) && hi(w) <= hi(v) && !new_bound(i, w, lo(v), hi(w), vlo, wlo, vhi, whi)) + return false; + if (!(lo(v) <= lo(w)) && is_fixed(v) && !new_bound(i, w, lo(v), 0, vhi, wlo, vlo)) + return false; + if (is_fixed(v) && hi(w) == 1 && hi(w) <= lo(v) && hi(v) <= lo(w) && !(hi(v) == 0) && !new_bound(i, w, lo(w), 0, vhi, vlo, wlo, whi)) + return false; + if (!(hi(v) == 1) && hi(w) == 1 && lo(v) <= hi(w) && hi(w) <= lo(v) && hi(v) <= lo(w) && lo(v) <= hi(v) && !new_bound(i, w, lo(w), 0, vhi, vlo, wlo, whi)) + return false; + if (!(hi(w) == 0) && is_fixed(v) && hi(w) <= lo(v) && hi(v) <= lo(w) && lo(v) <= hi(v) && !new_bound(i, w, lo(w), 0, vhi, vlo, wlo, whi)) + return false; + if (!(hi(v) <= hi(w)) && !(hi(w) == 0) && lo(v) <= hi(w) && hi(w) <= lo(v) && hi(v) <= lo(w) && !new_bound(i, w, lo(w), 0, vhi, vlo, wlo, whi)) + return false; + if (!(lo(v) <= hi(w)) && !(lo(w) <= lo(v)) && hi(w) == 1 && lo(w) <= hi(v) && !new_bound(i, w, lo(w), 0, vhi, wlo, vlo, whi)) + return false; + if (!(lo(v) <= hi(w)) && !(lo(w) <= lo(v)) && !(hi(w) == 0) && lo(w) <= hi(v) && !new_bound(i, w, lo(w), 0, vhi, wlo, vlo, whi)) + return false; + if (!(lo(w) <= hi(w)) && is_fixed(v) && hi(w) == 1 && lo(w) <= lo(v) && !new_bound(i, w, lo(w), 0, vlo, wlo, vhi, whi)) + return false; + if (!(lo(w) <= hi(w)) && !(hi(v) <= lo(w)) && hi(w) == 1 && lo(w) <= lo(v) && lo(v) <= lo(w) && !new_bound(i, w, lo(w), 0, vlo, wlo, vhi, whi)) + return false; + if (!(lo(w) <= hi(w)) && !(hi(w) == 0) && is_fixed(v) && lo(w) <= lo(v) && !new_bound(i, w, lo(w), 0, vlo, wlo, vhi, whi)) + return false; + if (!(lo(w) <= hi(w)) && !(hi(v) <= lo(w)) && !(hi(w) == 0) && lo(w) <= lo(v) && lo(v) <= lo(w) && !new_bound(i, w, lo(w), 0, vlo, wlo, vhi, whi)) + return false; + if (!(lo(w) <= hi(w)) && !(hi(v) == 1) && hi(w) == 1 && lo(v) <= hi(w) && hi(w) <= lo(v) && !new_bound(i, w, lo(w), 0, vlo, wlo, vhi, whi)) + return false; + if (!(lo(w) <= hi(w)) && !(hi(v) <= hi(w)) && !(hi(w) == 0) && lo(v) <= hi(w) && hi(w) <= lo(v) && !new_bound(i, w, lo(w), 0, vlo, wlo, vhi, whi)) + return false; + if (!(lo(v) <= hi(w)) && hi(v) == 0 && lo(w) <= hi(v) && !new_bound(i, w, lo(v), 0, vhi, vlo, wlo, whi)) + return false; + if (!(hi(w) == 1) && hi(v) == 1 && hi(w) <= lo(v) && lo(w) <= hi(v) && hi(v) <= lo(w) && lo(w) <= hi(w) && !new_bound(i, v, 0, lo(w), vhi, vlo, wlo, whi)) + return false; + if (!(hi(w) <= hi(v)) && hi(w) <= lo(v) && lo(w) <= hi(v) && !new_bound(i, v, 0, hi(w) - 1, vhi, vlo, wlo, whi)) + return false; + if (!(lo(v) <= lo(w)) && hi(v) == 0 && !new_bound(i, w, lo(v), 0, vhi, wlo, vlo)) + return false; + if (!(lo(v) <= lo(w)) && !(hi(w) == 0) && hi(v) == 0 && lo(w) <= hi(v) && !new_bound(i, v, lo(v), hi(w), vlo, wlo, vhi, whi)) + return false; + if (!(lo(v) <= hi(v)) && is_fixed(w) && hi(v) == 0 && lo(w) <= hi(w) && !new_bound(i, v, lo(v), hi(w), vhi, vlo, wlo, whi)) + return false; + if (!(lo(v) <= hi(v)) && !(hi(w) <= lo(v)) && hi(v) == 0 && lo(w) <= lo(v) && !new_bound(i, v, lo(w), hi(w), wlo, vhi, vlo, whi)) + return false; + if (!(hi(v) <= lo(w)) && hi(v) <= hi(w) && hi(w) <= lo(v) && !new_bound(i, v, 0, hi(w), vlo, wlo, vhi, whi)) + return false; + if (!(lo(w) <= hi(w)) && hi(w) == 1 && hi(v) == 0 && lo(w) <= lo(v) && !new_bound(i, w, lo(w), 0, vlo, wlo, vhi, whi)) + return false; + if (!(lo(v) <= hi(w)) && !(hi(w) == 0) && hi(v) == 0 && lo(v) <= lo(w) && !new_bound(i, w, lo(w), 0, wlo, vhi, vlo, whi)) + return false; + if (!(lo(w) <= lo(v)) && !(hi(w) == 0) && hi(v) == 0 && hi(w) <= lo(v) && !new_bound(i, w, lo(w), 0, vlo, wlo, vhi, whi)) + return false; + return true; + } + + template + bool fixplex::propagate_bounds(ineq const& i) { + if (i.strict) + return propagate_strict_bounds(i); + else + return propagate_non_strict_bounds(i); + } + template void fixplex::conflict(ineq const& i, u_dependency* a, u_dependency* b, u_dependency* c, u_dependency* d) { conflict(a, m_deps.mk_join(mk_leaf(i.dep), m_deps.mk_join(b, m_deps.mk_join(c, d)))); @@ -1246,7 +1400,9 @@ namespace polysat { bool fixplex::new_bound(ineq const& i, var_t x, numeral const& l, numeral const& h, u_dependency* a, u_dependency* b, u_dependency* c, u_dependency* d) { bool was_fixed = lo(x) + 1 == hi(x); u_dependency* dep = m_deps.mk_join(mk_leaf(i.dep), m_deps.mk_join(a, m_deps.mk_join(b, m_deps.mk_join(c, d)))); + // std::cout << "new bound " << x << " " << m_vars[x] << " " << mod_interval(l, h) << " -> "; update_bounds(x, l, h, dep); + // std::cout << m_vars[x] << "\n"; if (m_vars[x].is_empty()) return conflict(m_vars[x].m_lo_dep, m_vars[x].m_hi_dep), false; else if (!was_fixed && lo(x) + 1 == hi(x)) { diff --git a/src/test/fixplex.cpp b/src/test/fixplex.cpp index c24e1851d..adad74428 100644 --- a/src/test/fixplex.cpp +++ b/src/test/fixplex.cpp @@ -141,6 +141,7 @@ namespace polysat { } static void test_ineqs() { + unsigned num_bad = 0; var_t x = 0, y = 1; unsigned nb = 6; uint64_t bounds[6] = { 0, 1, 2, 10 , (uint64_t)-2, (uint64_t)-1 }; @@ -169,11 +170,25 @@ namespace polysat { solver.assert_expr(bv.mk_ule(I, x)); }; + auto add_not_bound = [&](expr* x, uint64_t i, uint64_t j) { + expr_ref I(bv.mk_numeral(i, 64), m); + expr_ref J(bv.mk_numeral(j, 64), m); + if (i < j) + solver.assert_expr(m.mk_not(m.mk_and(bv.mk_ule(I, x), mk_ult(x, J)))); + else if (i > j && j != 0) + solver.assert_expr(m.mk_not(m.mk_or(bv.mk_ule(I, x), mk_ult(x, J)))); + else if (i > j && j == 0) + solver.assert_expr(m.mk_not(bv.mk_ule(I, x))); + else + solver.assert_expr(m.mk_false()); + }; + auto test_le = [&](bool test_le, uint64_t i, uint64_t j, uint64_t k, uint64_t l) { if (i == j && i != 0) return; if (k == l && k != 0) return; + // std::cout << "test " << i << " " << j << " " << k << " " << l << "\n"; scoped_fp fp; fp.set_bounds(x, i, j, 1); fp.set_bounds(y, k, l, 2); @@ -197,13 +212,11 @@ namespace polysat { lbool feas2 = solver.check(); - if (feas == feas2) { solver.pop(1); return; } - if (feas2 == l_false && feas == l_true) { std::cout << "BUG!\n"; solver.pop(1); @@ -217,10 +230,16 @@ namespace polysat { for (unsigned c : fp.get_unsat_core()) std::cout << c << "\n"; std::cout << "\n"; + // TBD: check that core is sufficient and minimal break; case l_true: case l_undef: + if (feas2 == l_false) { + std::cout << "Missed conflict\n"; + std::cout << fp << "\n"; + break; + } // Check for missed bounds: solver.push(); solver.assert_expr(m.mk_eq(X, bv.mk_numeral(fp.lo(x), 64))); @@ -258,9 +277,40 @@ namespace polysat { solver.pop(1); } + // check that inferred bounds are implied: + solver.push(); + add_not_bound(X, fp.lo(x), fp.hi(x)); + if (l_false != solver.check()) { + std::cout << "Bound on x is not implied\n"; + scoped_fp fp1; + fp1.set_bounds(x, i, j, 1); + fp1.set_bounds(y, k, l, 2); + std::cout << fp1 << "\n"; + bad = true; + } + solver.pop(1); + + solver.push(); + add_not_bound(Y, fp.lo(y), fp.hi(y)); + if (l_false != solver.check()) { + std::cout << "Bound on y is not implied\n"; + scoped_fp fp1; + fp1.set_bounds(x, i, j, 1); + fp1.set_bounds(y, k, l, 2); + std::cout << fp1 << "\n"; + bad = true; + } + solver.pop(1); + if (bad) { - std::cout << fp << "\n"; std::cout << feas << " " << feas2 << "\n"; + std::cout << fp << "\n"; + std::cout << "original:\n"; + scoped_fp fp1; + fp1.set_bounds(x, i, j, 1); + fp1.set_bounds(y, k, l, 2); + std::cout << fp1 << "\n"; + ++num_bad; } break; @@ -287,6 +337,8 @@ namespace polysat { test_le(true, bounds[i], bounds[j], bounds[k], bounds[l]); test_le(false, bounds[i], bounds[j], bounds[k], bounds[l]); } + + std::cout << "number of failures: " << num_bad << "\n"; } } @@ -294,7 +346,6 @@ void tst_fixplex() { polysat::test_ineq1(); polysat::test_ineqs(); - return; polysat::test1(); polysat::test2(); diff --git a/src/test/mod_interval.cpp b/src/test/mod_interval.cpp index 69a39220c..db560aee9 100644 --- a/src/test/mod_interval.cpp +++ b/src/test/mod_interval.cpp @@ -90,12 +90,52 @@ static void test_interval2() { std::cout << " < 500: " << i << " -> " << i.intersect_ult(500) << "\n"; i = mod_interval(500, 10); std::cout << " < 501: " << i << " -> " << i.intersect_ult(501) << "\n"; +} + +static void test_interval_intersect(unsigned i, unsigned j, unsigned k, unsigned l) { + if (i == j && i != 0) + return; + if (k == l && k != 0) + return; + mod_interval x(i, j); + mod_interval y(k, l); + auto r = x & y; + bool x_not_y = false, y_not_x = false; + // check that & computes a join + // it contains all elements in x, y + // it contains no elements neither in x, y + // it does not contain two elements, one in x\y the other in y\x + for (i = 0; i < 256; ++i) { + uint8_t c = (uint8_t)i; + if ((x.contains(c) && y.contains(c)) && !r.contains(c)) { + std::cout << x << " & " << y << " = " << r << "\n"; + std::cout << i << " " << r.contains(c) << " " << x.contains(c) << " " << y.contains(c) << "\n"; + } + VERIFY(!(x.contains(c) && y.contains(c)) || r.contains(c)); + VERIFY(x.contains(c) || y.contains(c) || !r.contains(c)); + if (r.contains(c) && x.contains(c) && !y.contains(c)) + x_not_y = true; + if (r.contains(c) && !x.contains(c) && y.contains(c)) + y_not_x = true; + } + if (x_not_y && y_not_x) { + std::cout << x << " & " << y << " = " << r << "\n"; + VERIFY(!x_not_y || !y_not_x); + } +} - +static void test_interval_intersect() { + unsigned bounds[8] = { 0, 1, 2, 3, 252, 253, 254, 255 }; + for (unsigned i = 0; i < 8; ++i) + for (unsigned j = 0; j < 8; ++j) + for (unsigned k = 0; k < 8; ++k) + for (unsigned l = 0; l < 8; ++l) + test_interval_intersect(bounds[i], bounds[j], bounds[k], bounds[l]); } void tst_mod_interval() { + test_interval_intersect(); test_interval1(); test_interval2(); }