diff --git a/src/math/interval/mod_interval.h b/src/math/interval/mod_interval.h index eaa7e5c98..0e3554a9b 100644 --- a/src/math/interval/mod_interval.h +++ b/src/math/interval/mod_interval.h @@ -48,12 +48,23 @@ public: mod_interval(Numeral const& l, Numeral const& h): lo(l), hi(h) {} static mod_interval free() { return mod_interval(0, 0); } static mod_interval empty() { mod_interval i(0, 0); i.emp = true; return i; } + bool is_free() const { return !emp && lo == hi; } bool is_empty() const { return emp; } + bool is_singleton() const { return !is_empty() && (lo + 1 == hi || (hi == 0 && is_max(lo))); } + bool contains(Numeral const& n) const; + virtual bool is_max(Numeral const& n) const { return n + 1 == 0; } + void set_free() { lo = hi = 0; emp = false; } void set_bounds(Numeral const& l, Numeral const& h) { lo = l; hi = h; } void set_empty() { emp = true; } - bool contains(Numeral const& n) const; + + void intersect_ule(Numeral const& h); + void intersect_uge(Numeral const& l); + void intersect_ult(Numeral const& h); + void intersect_ugt(Numeral const& l); + void intersect_fixed(Numeral const& n); + void intersect_diff(Numeral const& n); mod_interval operator&(mod_interval const& other) const; mod_interval operator+(mod_interval const& other) const; mod_interval operator-(mod_interval const& other) const; diff --git a/src/math/interval/mod_interval_def.h b/src/math/interval/mod_interval_def.h index 253dd7390..c41054ba3 100644 --- a/src/math/interval/mod_interval_def.h +++ b/src/math/interval/mod_interval_def.h @@ -120,3 +120,89 @@ Numeral mod_interval::closest_value(Numeral const& n) const { return lo; return hi - 1; } + +// TBD: correctness and completeness for wrap-around semantics needs to be checked/fixed + +template +void mod_interval::intersect_ule(Numeral const& h) { + if (is_empty()) + return; + if (is_max(h)) + return; + else if (is_free()) + lo = 0, hi = h + 1; + else if (hi > lo && lo > h) + set_empty(); + else if (hi != 0 || h + 1 < hi) + hi = h + 1; +} + +template +void mod_interval::intersect_uge(Numeral const& l) { + if (is_empty()) + return; + if (lo < hi && hi <= l) + set_empty(); + else if (is_free()) + lo = l, hi = 0; + else if (lo < hi && lo < l) + lo = l; +} + +template +void mod_interval::intersect_ult(Numeral const& h) { + if (is_empty()) + return; + if (h == 0) + set_empty(); + else if (is_free()) + lo = 0, hi = h; + else if (hi > lo && lo >= h) + set_empty(); + else if (hi > lo && h < hi) + hi = h; +} + +template +void mod_interval::intersect_ugt(Numeral const& l) { + if (is_empty()) + return; + if (is_max(l)) + set_empty(); + else if (is_free()) + lo = l + 1, hi = 0; + else if (lo > l) + return; + else if (lo < hi && hi <= l) + set_empty(); + else if (lo < hi) + lo = l + 1; +} + +template +void mod_interval::intersect_fixed(Numeral const& a) { + if (is_empty()) + return; + if (!contains(a)) + set_empty(); + else if (is_max(a)) + lo = a, hi = 0; + else + lo = a, hi = a + 1; +} + +template +void mod_interval::intersect_diff(Numeral const& a) { + if (!contains(a) || is_empty()) + return; + if (a == lo && a + 1 == hi) + set_empty(); + else if (a == lo && hi == 0 && is_max(a)) + set_empty(); + else if (a == lo && !is_max(a)) + lo = a + 1; + else if (a + 1 == hi) + hi = a; + else if (hi == 0 && is_max(a)) + hi = a; +} diff --git a/src/math/polysat/viable.cpp b/src/math/polysat/viable.cpp index 584adb17c..40c79975d 100644 --- a/src/math/polysat/viable.cpp +++ b/src/math/polysat/viable.cpp @@ -44,139 +44,78 @@ namespace polysat { return a + 1 == rational::power_of_two(m_num_bits); } - bool viable_set::is_singleton() const { - return !is_empty() && (lo + 1 == hi || (hi == 0 && is_max(lo))); - } - void viable_set::intersect_eq(rational const& a, bool is_positive) { - if (is_empty()) - return; - if (is_positive) { - if (!contains(a)) - set_empty(); - else if (is_max(a)) - lo = a, hi = 0; - else - lo = a, hi = a + 1; - } - else { - if (!contains(a)) - return; - if (a == lo && a + 1 == hi) - set_empty(); - else if (a == lo && hi == 0 && is_max(a)) - set_empty(); - else if (a == lo && !is_max(a)) - lo = a + 1; - else if (a + 1 == hi) - hi = a; - else if (hi == 0 && is_max(a)) - hi = a; - else - std::cout << "unhandled diseq " << lo << " " << a << " " << hi << "\n"; - } + if (is_positive) + intersect_fixed(a); + else + intersect_diff(a); } bool viable_set::intersect_eq(rational const& a, rational const& b, bool is_positive) { - if (a.is_odd()) { - if (b == 0) - intersect_eq(b, is_positive); - else { - rational a_inv; - VERIFY(a.mult_inverse(m_num_bits, a_inv)); - intersect_eq(mod(a_inv * -b, p2()), is_positive); - } - return true; + if (!a.is_odd()) { + std::function eval = [&](rational const& x) { + return is_positive == (mod(a * x + b, p2()) == 0); + }; + return narrow(eval); } + if (b == 0) + intersect_eq(b, is_positive); else { - return false; + rational a_inv; + VERIFY(a.mult_inverse(m_num_bits, a_inv)); + intersect_eq(mod(a_inv * -b, p2()), is_positive); } + return true; } - void viable_set::intersect_eq(rational const& a, rational const& b, bool is_positive, unsigned& budget) { - std::function eval = [&](rational const& x) { - return is_positive == (mod(a * x + b, p2()) == 0); - }; - narrow(eval, budget); - } - - - bool viable_set::intersect_ule(rational const& a, rational const& b, rational const& c, rational const& d, bool is_positive) { + bool viable_set::intersect_le(rational const& a, rational const& b, rational const& c, rational const& d, bool is_positive) { // x <= 0 if (a.is_odd() && b == 0 && c == 0 && d == 0) intersect_eq(b, is_positive); else if (a == 1 && b == 0 && c == 0) { - // x <= d + // x <= d or x > d if (is_positive) - set_hi(d); - // x > d - else if (is_max(d)) - set_empty(); - else - set_lo(d + 1); + intersect_ule(d); + else + intersect_ugt(d); } else if (a == 0 && c == 1 && d == 0) { - // x >= b + // x >= b or x < b if (is_positive) - set_lo(b); - else if (b == 0) - set_empty(); + intersect_uge(b); else - set_hi(b - 1); + intersect_ult(b); + } + // TBD: can also handle wrap-around semantics (for signed comparison) + else { + std::function eval = [&](rational const& x) { + return is_positive == mod(a * x + b, p2()) <= mod(c * x + d, p2()); + }; + return narrow(eval); } - else - return false; - return true; } - void viable_set::narrow(std::function& eval, unsigned& budget) { - while (budget > 0 && !eval(lo) && !is_max(lo) && !is_empty()) { + rational viable_set::prev(rational const& p) const { + if (p > 0) + return p - 1; + else + return rational::power_of_two(m_num_bits) - 1; + } + + bool viable_set::narrow(std::function& eval) { + unsigned budget = 10; + while (budget > 0 && !is_empty() && !eval(lo)) { --budget; - lo += 1; - set_lo(lo); + intersect_diff(lo); } - while (budget > 0 && hi > 0 && !eval(hi - 1) && !is_empty()) { + while (budget > 0 && !is_empty() && !eval(prev(hi))) { --budget; - hi = hi - 1; - set_hi(hi); + intersect_diff(prev(hi)); } + return 0 < budget; } - void viable_set::intersect_ule(rational const& a, rational const& b, rational const& c, rational const& d, bool is_positive, unsigned& budget) { - std::function eval = [&](rational const& x) { - return is_positive == mod(a * x + b, p2()) <= mod(c * x + d, p2()); - }; - narrow(eval, budget); - } - - void viable_set::set_hi(rational const& d) { - if (is_max(d)) - return; - else if (is_free()) - lo = 0, hi = d + 1; - else if (lo > d) - set_empty(); - else if (hi != 0 || d + 1 < hi) - hi = d + 1; - else if (d + 1 == hi) - return; - else - std::cout << "set hi " << d << " " << *this << "\n"; - } - - void viable_set::set_lo(rational const& b) { - if (hi != 0 && hi <= b) - set_empty(); - else if (is_free()) - lo = b, hi = 0; - else if (lo < b) - lo = b; - else if (lo == b) - return; - else - std::cout << "set lo " << b << " " << *this << "\n"; - } #endif viable::viable(solver& s): @@ -184,6 +123,18 @@ namespace polysat { m_bdd(1000) {} + viable::~viable() { +#if NEW_VIABLE + ptr_vector entries; + for (auto* e : m_constraint_cache) + entries.push_back(e); + m_constraint_cache.reset(); + for (auto* e : entries) + dealloc(e); +#endif + } + + void viable::push_viable(pvar v) { s.m_trail.push_back(trail_instr_t::viable_i); m_viable_trail.push_back(std::make_pair(v, m_viable[v])); @@ -200,15 +151,8 @@ namespace polysat { void viable::intersect_eq(rational const& a, pvar v, rational const& b, bool is_positive) { #if NEW_VIABLE push_viable(v); - if (!m_viable[v].intersect_eq(a, b, is_positive)) { - IF_VERBOSE(10, verbose_stream() << "could not intersect v" << v << " " << m_viable[v] << "\n"); - unsigned budget = 10; - m_viable[v].intersect_eq(a, b, is_positive, budget); - if (budget == 0) { - std::cout << "budget used\n"; - // then narrow the range using BDDs - } - } + if (!m_viable[v].intersect_eq(a, b, is_positive)) + intersect_eq_bdd(v, a, b, is_positive); if (m_viable[v].is_empty()) s.set_conflict(v); #else @@ -239,54 +183,9 @@ namespace polysat { void viable::intersect_ule(pvar v, rational const& a, rational const& b, rational const& c, rational const& d, bool is_positive) { #if NEW_VIABLE - // - // TODO This code needs to be partitioned into self-contained pieces. - // push_viable(v); - if (!m_viable[v].intersect_ule(a, b, c, d, is_positive)) { - unsigned budget = 10; - m_viable[v].intersect_ule(a, b, c, d, is_positive, budget); - if (budget == 0) { - std::cout << "miss: " << a << " " << b << " " << c << " " << d << " " << is_positive << "\n"; - unsigned sz = var2bits(v).num_bits(); - bdd le = m_bdd.mk_true(); - ineq_entry entry0(sz, a, b, c, d, le); - ineq_entry* other = nullptr; - if (!m_ineq_cache.find(&entry0, other)) { - std::cout << "ADD-to-cache\n"; - bddv const& x = var2bits(v).var(); - le = ((a * x) + b) <= ((c * x) + d); - other = alloc(ineq_entry, sz, a, b, c, d, le); - m_ineq_cache.insert(other); - } - bdd gt = is_positive ? !other->repr : other->repr; - other->m_activity++; - - // - // instead of using activity for GC, use the Move-To-Front approach - // see sat/smt/bv_ackerman.h or sat/smt/euf_ackerman.h - // where hash table entries use a dll_base. - // - - // le(lo) is false: find min x >= lo, such that le(x) is false, le(x+1) is true - // le(hi) is false: find max x =< hi, such that le(x) is false, le(x-1) is true - - rational bound = m_viable[v].lo; - if (var2bits(v).sup(gt, bound)) { - m_viable[v].set_lo(bound); - m_viable[v].set_ne(bound); - } - bound = m_viable[v].hi; - if (bound != 0) { - bound = bound - 1; - if (var2bits(v).inf(gt, bound)) { - std::cout << "TODO: new upper bound " << bound << "\n"; - } - } - - } - - } + if (!m_viable[v].intersect_le(a, b, c, d, is_positive)) + intersect_ule_bdd(v, a, b, c, d, is_positive); if (m_viable[v].is_empty()) s.set_conflict(v); #else @@ -305,6 +204,77 @@ namespace polysat { #endif } +#if NEW_VIABLE + + viable::cached_constraint& viable::cache_constraint(pvar v, cached_constraint& entry0, std::function& mk_constraint) { + cached_constraint* other = nullptr; + if (!m_constraint_cache.find(&entry0, other)) { + gc_cached_constraints(); + other = alloc(cached_constraint, entry0); + other->repr = mk_constraint(); + m_constraint_cache.insert(other); + } + other->m_activity++; + return *other; + } + + void viable::gc_cached_constraints() { + // + // TODO: instead of using activity for GC, use the Move-To-Front approach + // see sat/smt/bv_ackerman.h or sat/smt/euf_ackerman.h + // where hash table entries use a dll_base. + // + unsigned max_entries = 10000; + if (m_constraint_cache.size() > max_entries) { + ptr_vector entries; + for (auto* e : m_constraint_cache) + entries.push_back(e); + std::stable_sort(entries.begin(), entries.end(), [&](cached_constraint* a, cached_constraint* b) { return a->m_activity < b->m_activity; }); + for (unsigned i = 0; i < max_entries/2; ++i) { + m_constraint_cache.remove(entries[i]); + dealloc(entries[i]); + } + } + } + + void viable::narrow(pvar v, bdd const& is_false) { + rational bound = m_viable[v].lo; + if (var2bits(v).sup(is_false, bound)) + m_viable[v].intersect_ugt(bound); + bound = m_viable[v].prev(m_viable[v].hi); + if (var2bits(v).inf(is_false, bound)) + m_viable[v].intersect_ult(bound); + } + + void viable::intersect_ule_bdd(pvar v, rational const& a, rational const& b, rational const& c, rational const& d, bool is_positive) { + unsigned sz = var2bits(v).num_bits(); + std::function le = [&]() { + bddv const& x = var2bits(v).var(); + return ((a * x) + b) <= ((c * x) + d); + }; + cached_constraint entry0(sz, a, b, c, d, m_bdd.mk_true()); + cached_constraint& entry = cache_constraint(v, entry0, le); + + // le(lo) is false: find min x >= lo, such that le(x) is false, le(x+1) is true + // le(hi) is false: find max x =< hi, such that le(x) is false, le(x-1) is true + bdd gt = is_positive ? !entry.repr : entry.repr; + narrow(v, gt); + } + + void viable::intersect_eq_bdd(pvar v, rational const& a, rational const& b, bool is_positive) { + unsigned sz = var2bits(v).num_bits(); + std::function eq = [&]() { + bddv const& x = var2bits(v).var(); + return ((a * x) + b) == rational(0); + }; + cached_constraint entry0(sz, a, b, m_bdd.mk_true()); + cached_constraint& entry = cache_constraint(v, entry0, eq); + + bdd ne = is_positive ? !entry.repr : entry.repr; + narrow(v, ne); + } +#endif + bool viable::has_viable(pvar v) { #if NEW_VIABLE return !m_viable[v].is_empty(); @@ -325,7 +295,7 @@ namespace polysat { #if NEW_VIABLE push_viable(v); IF_VERBOSE(10, verbose_stream() << " v" << v << " != " << val << "\n"); - m_viable[v].set_ne(val); + m_viable[v].intersect_diff(val); if (m_viable[v].is_empty()) s.set_conflict(v); #else diff --git a/src/math/polysat/viable.h b/src/math/polysat/viable.h index 7c17c8510..ddec88412 100644 --- a/src/math/polysat/viable.h +++ b/src/math/polysat/viable.h @@ -14,12 +14,11 @@ Author: Notes: NEW_VIABLE uses cheaper book-keeping, but is partial. - The implementation of NEW_VIABLE is atm incomplete and ad-hoc. --*/ #pragma once -#define NEW_VIABLE 0 +#define NEW_VIABLE 1 #include @@ -44,52 +43,57 @@ namespace polysat { class viable_set : public mod_interval { unsigned m_num_bits; rational p2() const { return rational::power_of_two(m_num_bits); } - bool is_max(rational const& a) const; + bool is_max(rational const& a) const override; void intersect_eq(rational const& a, bool is_positive); - void narrow(std::function& eval, unsigned& budget); + bool narrow(std::function& eval); public: viable_set(unsigned num_bits): m_num_bits(num_bits) {} - bool is_singleton() const; dd::find_t find_hint(rational const& c, rational& val) const; - void set_ne(rational const& a) { intersect_eq(a, false); } - void set_lo(rational const& lo); - void set_hi(rational const& hi); bool intersect_eq(rational const& a, rational const& b, bool is_positive); - void intersect_eq(rational const& a, rational const& b, bool is_positive, unsigned& budget); - bool intersect_ule(rational const& a, rational const& b, rational const& c, rational const& d, bool is_positive); - void intersect_ule(rational const& a, rational const& b, rational const& c, rational const& d, bool is_positive, unsigned& budget); + bool intersect_le(rational const& a, rational const& b, rational const& c, rational const& d, bool is_positive); + rational prev(rational const& p) const; }; #endif class viable { - solver& s; typedef dd::bdd bdd; typedef dd::fdd fdd; + solver& s; dd::bdd_manager m_bdd; scoped_ptr_vector m_bits; #if NEW_VIABLE - struct ineq_entry { + struct cached_constraint { + enum op_code { is_ule, is_eq }; + op_code m_op; unsigned m_num_bits; rational a, b, c, d; bdd repr; unsigned m_activity = 0; - ineq_entry(unsigned n, rational const& a, rational const& b, rational const& c, rational const& d, bdd& f) : - m_num_bits(n), a(a), b(b), c(c), d(d), repr(f) {} + cached_constraint(unsigned n, rational const& a, rational const& b, rational const& c, rational const& d, bdd& f) : + m_op(op_code::is_ule), m_num_bits(n), a(a), b(b), c(c), d(d), repr(f) {} + cached_constraint(unsigned n, rational const& a, rational const& b, bdd& f) : + m_op(op_code::is_eq), m_num_bits(n), a(a), b(b), repr(f) {} struct hash { - unsigned operator()(ineq_entry const* e) const { - return mk_mix(e->a.hash(), e->b.hash(), mk_mix(e->c.hash(), e->d.hash(), e->m_num_bits)); + unsigned operator()(cached_constraint const* e) const { + return mk_mix(e->a.hash(), e->b.hash(), mk_mix(e->c.hash(), e->d.hash(), e->m_num_bits)) + e->m_op; } }; struct eq { - bool operator()(ineq_entry const* x, ineq_entry const* y) const { - return x->a == y->a && x->b == y->b && x->c == y->c && x->d == y->d && x->m_num_bits == y->m_num_bits; + bool operator()(cached_constraint const* x, cached_constraint const* y) const { + return x->m_op == y->m_op && x->a == y->a && x->b == y->b && x->c == y->c && x->d == y->d && x->m_num_bits == y->m_num_bits; } }; }; vector m_viable; vector> m_viable_trail; - hashtable m_ineq_cache; + hashtable m_constraint_cache; + + void intersect_ule_bdd(pvar v, rational const& a, rational const& b, rational const& c, rational const& d, bool is_positive); + void intersect_eq_bdd(pvar v, rational const& a, rational const& b, bool is_positive); + cached_constraint& cache_constraint(pvar v, cached_constraint& entry0, std::function& mk_constraint); + void gc_cached_constraints(); + void narrow(pvar v, bdd const& is_false); #else @@ -110,6 +114,8 @@ namespace polysat { public: viable(solver& s); + ~viable(); + void push(unsigned num_bits) { #if NEW_VIABLE m_viable.push_back(viable_set(num_bits));