diff --git a/src/sat/smt/polysat_core.cpp b/src/sat/smt/polysat_core.cpp index 3de88d93b..d4b62e8b6 100644 --- a/src/sat/smt/polysat_core.cpp +++ b/src/sat/smt/polysat_core.cpp @@ -40,7 +40,7 @@ namespace polysat { public: mk_assign_var(pvar v, core& c) : m_var(v), c(c) {} void undo() { - c.m_justification[m_var] = dependency::null_dependency(); + c.m_justification[m_var] = null_dependency; c.m_assignment.pop(); } }; @@ -106,7 +106,7 @@ namespace polysat { unsigned v = m_vars.size(); m_vars.push_back(sz2pdd(sz).mk_var(v)); m_activity.push_back({ sz, 0 }); - m_justification.push_back(dependency::null_dependency()); + m_justification.push_back(null_dependency); m_watch.push_back({}); m_var_queue.mk_var_eh(v); s.ctx.push(mk_add_var(*this)); diff --git a/src/sat/smt/polysat_types.h b/src/sat/smt/polysat_types.h index 42f283cc7..c8c8324d7 100644 --- a/src/sat/smt/polysat_types.h +++ b/src/sat/smt/polysat_types.h @@ -17,13 +17,16 @@ namespace polysat { using pdd = dd::pdd; using pvar = unsigned; + using pvar_vector = unsigned_vector; + inline const pvar null_var = UINT_MAX; + + class dependency { unsigned m_index; unsigned m_level; public: dependency(sat::literal lit, unsigned level) : m_index(2 * lit.index()), m_level(level) {} dependency(unsigned var_idx, unsigned level) : m_index(1 + 2 * var_idx), m_level(level) {} - static dependency null_dependency() { return dependency(0, UINT_MAX); } bool is_null() const { return m_level == UINT_MAX; } bool is_literal() const { return m_index % 2 == 0; } sat::literal literal() const { SASSERT(is_literal()); return sat::to_literal(m_index / 2); } @@ -31,6 +34,8 @@ namespace polysat { unsigned level() const { return m_level; } }; + inline const dependency null_dependency = dependency(0, UINT_MAX); + inline std::ostream& operator<<(std::ostream& out, dependency d) { if (d.is_literal()) return out << d.literal() << "@" << d.level(); diff --git a/src/sat/smt/polysat_viable.cpp b/src/sat/smt/polysat_viable.cpp index 79689d01f..deb6d415c 100644 --- a/src/sat/smt/polysat_viable.cpp +++ b/src/sat/smt/polysat_viable.cpp @@ -17,11 +17,21 @@ Notes: #include "util/debug.h" +#include "util/log.h" #include "sat/smt/polysat_viable.h" #include "sat/smt/polysat_core.h" namespace polysat { + using dd::val_pp; + + viable::viable(core& c) : c(c), cs(c.cs()), m_forbidden_intervals(c) {} + + viable::~viable() { + for (auto* e : m_alloc) + dealloc(e); + } + std::ostream& operator<<(std::ostream& out, find_t f) { switch (f) { case find_t::empty: return out << "empty"; @@ -32,5 +42,191 @@ namespace polysat { } } + viable::entry* viable::alloc_entry(pvar var) { + if (m_alloc.empty()) + return alloc(entry); + auto* e = m_alloc.back(); + e->reset(); + e->var = var; + m_alloc.pop_back(); + return e; + } + + find_t viable::find_viable(pvar v, rational& out_val) { throw default_exception("nyi"); } + + /* + * Explain why the current variable is not viable or signleton. + */ + dependency_vector viable::explain() { throw default_exception("nyi"); } + + /* + * Register constraint at index 'idx' as unitary in v. + */ + void viable::add_unitary(pvar v, unsigned idx) { + if (c.is_assigned(v)) + return; + auto [sc, d] = c.m_constraint_trail[idx]; + + entry* ne = alloc_entry(v); + if (!m_forbidden_intervals.get_interval(sc, v, *ne)) { + m_alloc.push_back(ne); + return; + } + + if (ne->interval.is_currently_empty()) { + m_alloc.push_back(ne); + return; + } + + if (ne->coeff == 1) { + intersect(v, ne); + return; + } + else if (ne->coeff == -1) { + insert(ne, v, m_diseq_lin, entry_kind::diseq_e); + return; + } + else { + unsigned const w = c.size(v); + unsigned const k = ne->coeff.parity(w); + // unsigned const lo_parity = ne->interval.lo_val().parity(w); + // unsigned const hi_parity = ne->interval.hi_val().parity(w); + + display_one(std::cerr << "try to reduce entry: ", v, ne) << "\n"; + + if (k > 0 && ne->coeff.is_power_of_two()) { + // reduction of coeff gives us a unit entry + // + // 2^k a x \not\in [ lo ; hi [ + // + // new_lo = lo[w-1:k] if lo[k-1:0] = 0 + // lo[w-1:k] + 1 otherwise + // + // new_hi = hi[w-1:k] if hi[k-1:0] = 0 + // hi[w-1:k] + 1 otherwise + // + // Reference: Fig. 1 (dtrim) in BitvectorsMCSAT + // + pdd const& pdd_lo = ne->interval.lo(); + pdd const& pdd_hi = ne->interval.hi(); + rational const& lo = ne->interval.lo_val(); + rational const& hi = ne->interval.hi_val(); + + rational new_lo = machine_div2k(lo, k); + if (mod2k(lo, k).is_zero()) + ne->side_cond.push_back(cs.eq(pdd_lo * rational::power_of_two(w - k))); + else { + new_lo += 1; + ne->side_cond.push_back(~cs.eq(pdd_lo * rational::power_of_two(w - k))); + } + + rational new_hi = machine_div2k(hi, k); + if (mod2k(hi, k).is_zero()) + ne->side_cond.push_back(cs.eq(pdd_hi * rational::power_of_two(w - k))); + else { + new_hi += 1; + ne->side_cond.push_back(~cs.eq(pdd_hi * rational::power_of_two(w - k))); + } + + // we have to update also the pdd bounds accordingly, but it seems not worth introducing new variables for this eagerly + // new_lo = lo[:k] etc. + // TODO: for now just disable the FI-lemma if this case occurs + ne->valid_for_lemma = false; + + if (new_lo == new_hi) { + // empty or full + // if (ne->interval.currently_contains(rational::zero())) + NOT_IMPLEMENTED_YET(); + } + + ne->coeff = machine_div2k(ne->coeff, k); + ne->interval = eval_interval::proper(pdd_lo, new_lo, pdd_hi, new_hi); + ne->bit_width -= k; + display_one(std::cerr << "reduced entry: ", v, ne) << "\n"; + LOG("reduced entry to unit in bitwidth " << ne->bit_width); + return intersect(v, ne); + } + + // TODO: later, can reduce according to shared_parity + // unsigned const shared_parity = std::min(coeff_parity, std::min(lo_parity, hi_parity)); + + insert(ne, v, m_equal_lin, entry_kind::equal_e); + return; + } + + } + + void viable::intersect(pvar v, entry* e) { + + throw default_exception("nyi"); + } + + void viable::log() { + for (pvar v = 0; v < m_units.size(); ++v) + log(v); + } + + void viable::log(pvar v) { + throw default_exception("nyi"); + } + + void viable::insert(entry* e, pvar v, ptr_vector& entries, entry_kind k) { + throw default_exception("nyi"); + } + + std::ostream& viable::display_one(std::ostream& out, pvar v, entry const* e) const { + auto& m = c.var2pdd(v); + if (e->coeff == -1) { + // p*val + q > r*val + s if e->src.is_positive() + // p*val + q >= r*val + s if e->src.is_negative() + // Note that e->interval is meaningless in this case, + // we just use it to transport the values p,q,r,s + rational const& p = e->interval.lo_val(); + rational const& q_ = e->interval.lo().val(); + rational const& r = e->interval.hi_val(); + rational const& s_ = e->interval.hi().val(); + out << "[ "; + out << val_pp(m, p, true) << "*v" << v << " + " << val_pp(m, q_); + out << (e->src[0].is_positive() ? " > " : " >= "); + out << val_pp(m, r, true) << "*v" << v << " + " << val_pp(m, s_); + out << " ] "; + } + else if (e->coeff != 1) + out << e->coeff << " * v" << v << " " << e->interval << " "; + else + out << e->interval << " "; + if (e->side_cond.size() <= 5) + out << e->side_cond << " "; + else + out << e->side_cond.size() << " side-conditions "; + unsigned count = 0; + for (const auto& src : e->src) { + ++count; + out << src << "; "; + if (count > 10) { + out << " ..."; + break; + } + } + return out; + } + + std::ostream& viable::display_all(std::ostream& out, pvar v, entry const* e, char const* delimiter) const { + if (!e) + return out; + entry const* first = e; + unsigned count = 0; + do { + display_one(out, v, e) << delimiter; + e = e->next(); + ++count; + if (count > 10) { + out << " ..."; + break; + } + } + while (e != first); + return out; + } } diff --git a/src/sat/smt/polysat_viable.h b/src/sat/smt/polysat_viable.h index 2f87e79cc..322b30551 100644 --- a/src/sat/smt/polysat_viable.h +++ b/src/sat/smt/polysat_viable.h @@ -17,7 +17,12 @@ Author: #pragma once #include "util/rational.h" +#include "util/dlist.h" +#include "util/map.h" +#include "util/small_object_allocator.h" + #include "sat/smt/polysat_types.h" +#include "sat/smt/polysat_fi.h" namespace polysat { @@ -29,28 +34,88 @@ namespace polysat { }; class core; + class constraints; std::ostream& operator<<(std::ostream& out, find_t x); class viable { core& c; - public: - viable(core& c) : c(c) {} + constraints& cs; + forbidden_intervals m_forbidden_intervals; - /** + struct entry final : public dll_base, public fi_record { + /// whether the entry has been created by refinement (from constraints in 'fi_record::src') + bool refined = false; + /// whether the entry is part of the current set of intervals, or stashed away for backtracking + bool active = true; + bool valid_for_lemma = true; + pvar var = null_var; + + void reset() { + // dll_base::init(this); // we never did this in alloc_entry either + fi_record::reset(); + refined = false; + active = true; + valid_for_lemma = true; + var = null_var; + } + }; + + enum class entry_kind { unit_e, equal_e, diseq_e }; + + struct layer final { + entry* entries = nullptr; + unsigned bit_width = 0; + layer(unsigned bw) : bit_width(bw) {} + }; + + class layers final { + svector m_layers; + public: + svector const& get_layers() const { return m_layers; } + layer& ensure_layer(unsigned bit_width); + layer* get_layer(unsigned bit_width); + layer* get_layer(entry* e) { return get_layer(e->bit_width); } + layer const* get_layer(unsigned bit_width) const; + layer const* get_layer(entry* e) const { return get_layer(e->bit_width); } + entry* get_entries(unsigned bit_width) const { layer const* l = get_layer(bit_width); return l ? l->entries : nullptr; } + }; + + ptr_vector m_alloc; + vector m_units; // set of viable values based on unit multipliers, layered by bit-width in descending order + ptr_vector m_equal_lin; // entries that have non-unit multipliers, but are equal + ptr_vector m_diseq_lin; // entries that have distinct non-zero multipliers + + entry* alloc_entry(pvar v); + + std::ostream& display_one(std::ostream& out, pvar v, entry const* e) const; + std::ostream& display_all(std::ostream& out, pvar v, entry const* e, char const* delimiter = "") const; + void log(); + void log(pvar v); + + void insert(entry* e, pvar v, ptr_vector& entries, entry_kind k); + + void intersect(pvar v, entry* e); + + + public: + viable(core& c); + ~viable(); + + /** * Find a next viable value for variable. */ - find_t find_viable(pvar v, rational& out_val) { throw default_exception("nyi"); } + find_t find_viable(pvar v, rational& out_val); /* * Explain why the current variable is not viable or signleton. */ - dependency_vector explain() { throw default_exception("nyi"); } + dependency_vector explain(); /* * Register constraint at index 'idx' as unitary in v. */ - void add_unitary(pvar v, unsigned idx) { throw default_exception("nyi"); } + void add_unitary(pvar v, unsigned idx); }; diff --git a/src/util/rational.h b/src/util/rational.h index f47fddefe..4253bd4a7 100644 --- a/src/util/rational.h +++ b/src/util/rational.h @@ -501,6 +501,16 @@ public: return k; } + /** Number of trailing zeros in an N-bit representation */ + unsigned parity(unsigned num_bits) const { + SASSERT(!is_neg()); + SASSERT(*this < rational::power_of_two(num_bits)); + if (is_zero()) + return num_bits; + return trailing_zeros(); + } + + static bool limit_denominator(rational &num, rational const& limit); };