diff --git a/src/math/polysat/fixed_bits.cpp b/src/math/polysat/fixed_bits.cpp new file mode 100644 index 000000000..d49916c10 --- /dev/null +++ b/src/math/polysat/fixed_bits.cpp @@ -0,0 +1,144 @@ +/*++ +Copyright (c) 2022 Microsoft Corporation + +Module Name: + + Extract fixed bits from constraints + +Author: + + Jakob Rath, Nikolaj Bjorner (nbjorner), Clemens Eisenhofer 2022-08-22 + +--*/ + +#include "math/polysat/fixed_bits.h" +#include "math/polysat/ule_constraint.h" +#include "math/polysat/clause.h" + +namespace polysat { + + /** + * Constraint lhs <= rhs. + * + * 2^(k - d) * x = 2^(k - d) * c + * ==> x[|d|:0] = c[|d|:0] + * + * -2^(k - 2) * x > 2^(k - 1) + * <=> 2 + x[1:0] > 2 (mod 4) + * ==> x[1:0] = 1 + * -- TODO: Generalize [the obvious solution does not work] + */ + + /** + * 2^(k - d) * x = 2^(k - d) * c + * ==> x[|d|:0] = c[|d|:0] + */ + bool get_eq_fixed_lsb(pdd const& p, fixed_bits& out) { + if (!p.hi().is_val()) + return false; + // TODO: + return false; + } + + bool get_eq_fixed_bits(pdd const& p, fixed_bits& out) { + return get_eq_fixed_lsb(p, out); + } + + /** + * Constraint lhs <= rhs. + * + * -2^(k - 2) * x > 2^(k - 1) + * <=> 2 + x[1:0] > 2 (mod 4) + * ==> x[1:0] = 1 + * -- TODO: Generalize [the obvious solution does not work] + */ + bool get_ule_fixed_lsb(pdd const& lhs, pdd const& rhs, bool is_positive, fixed_bits& out) { + return false; + } + + bool get_ule_fixed_bits(pdd const& lhs, pdd const& rhs, bool is_positive, fixed_bits& out) { + return false; + } + + bool get_fixed_bits(signed_constraint c, fixed_bits& out) { + SASSERT_EQ(c->vars().size(), 1); // this only makes sense for univariate constraints + if (c->is_ule()) + return get_ule_fixed_bits(c->to_ule().lhs(), c->to_ule().rhs(), c.is_positive(), out); + // if (c->is_op()) + // ; // TODO: x & constant = constant ==> bitmask ... but we have trouble recognizing that because we introduce a new variable for '&' before we see the equality. + return false; + } + + + + +/* + // 2^(k - d) * x = m * 2^(k - d) + // Special case [still seems to occur frequently]: -2^(k - 2) * x > 2^(k - 1) - TODO: Generalize [the obvious solution does not work] => lsb(x, 2) = 1 + bool get_lsb(pdd lhs, pdd rhs, pdd& p, trailing_bits& info, bool pos) { + SASSERT(lhs.is_univariate() && lhs.degree() <= 1); + SASSERT(rhs.is_univariate() && rhs.degree() <= 1); + + if (rhs.is_zero()) { // equality + auto lhs_decomp = decouple_constant(lhs); + + lhs = lhs_decomp.first; + rhs = -lhs_decomp.second; + + SASSERT(rhs.is_val()); + + unsigned k = lhs.manager().power_of_2(); + unsigned d = lhs.max_pow2_divisor(); + unsigned span = k - d; + if (span == 0 || lhs.is_val()) + return false; + + p = lhs.div(rational::power_of_two(d)); + rational rhs_val = rhs.val(); + info.bits = rhs_val / rational::power_of_two(d); + if (!info.bits.is_int()) + return false; + + SASSERT(lhs.is_univariate() && lhs.degree() <= 1); + + auto it = p.begin(); + auto first = *it; + it++; + if (it == p.end()) { + // if the lhs contains only one monomial it is of the form: odd * x = mask. We can multiply by the inverse to get the mask for x + SASSERT(first.coeff.is_odd()); + rational inv; + VERIFY(first.coeff.mult_inverse(lhs.power_of_2(), inv)); + p *= inv; + info.bits = mod2k(info.bits * inv, span); + } + + info.length = span; + info.positive = pos; + return true; + } + else { // inequality - check for special case + if (pos || lhs.power_of_2() < 3) + return false; + auto it = lhs.begin(); + if (it == lhs.end()) + return false; + if (it->vars.size() != 1) + return false; + rational coeff = it->coeff; + it++; + if (it != lhs.end()) + return false; + if ((mod2k(-coeff, lhs.power_of_2())) != rational::power_of_two(lhs.power_of_2() - 2)) + return false; + p = lhs.div(coeff); + SASSERT(p.is_var()); + info.bits = 1; + info.length = 2; + info.positive = true; // this is a conjunction + return true; + } + } +*/ + +} // namespace polysat diff --git a/src/math/polysat/fixed_bits.h b/src/math/polysat/fixed_bits.h new file mode 100644 index 000000000..895ac8763 --- /dev/null +++ b/src/math/polysat/fixed_bits.h @@ -0,0 +1,41 @@ +/*++ +Copyright (c) 2022 Microsoft Corporation + +Module Name: + + Extract fixed bits from (univariate) constraints + +Author: + + Jakob Rath, Nikolaj Bjorner (nbjorner), Clemens Eisenhofer 2022-08-22 + +--*/ +#pragma once +#include "math/polysat/types.h" +#include "math/polysat/constraint.h" +#include "util/vector.h" + +namespace polysat { + + struct fixed_bits { + unsigned hi = 0; + unsigned lo = 0; + rational value; + + /// The constraint is equivalent to setting fixed bits on a variable. + // bool is_equivalent; + + fixed_bits() = default; + fixed_bits(unsigned hi, unsigned lo, rational value): hi(hi), lo(lo), value(value) {} + }; + + using fixed_bits_vector = vector; + + bool get_eq_fixed_lsb(pdd const& p, fixed_bits& out); + bool get_eq_fixed_bits(pdd const& p, fixed_bits& out); + + bool get_ule_fixed_lsb(pdd const& lhs, pdd const& rhs, bool is_positive, fixed_bits& out); + bool get_ule_fixed_bits(pdd const& lhs, pdd const& rhs, bool is_positive, fixed_bits& out); + bool get_fixed_bits(signed_constraint c, fixed_bits& out); + +} diff --git a/src/math/polysat/slicing.cpp b/src/math/polysat/slicing.cpp index bf53a1bf6..25463d71f 100644 --- a/src/math/polysat/slicing.cpp +++ b/src/math/polysat/slicing.cpp @@ -506,7 +506,9 @@ namespace polysat { } bool slicing::try_get_value(enode* s, rational& val) const { - return m_bv->is_numeral(s->get_expr(), val); + bool const ok = m_bv->is_numeral(s->get_expr(), val); + SASSERT_EQ(ok, is_value(s)); + return ok; } void slicing::explain_class(enode* x, enode* y, ptr_vector& out_deps) { @@ -1115,23 +1117,55 @@ namespace polysat { void slicing::add_constraint(signed_constraint c) { LOG(c); SASSERT(!is_conflict()); - if (!c->is_eq()) +#if 0 + if (!add_fixed_bits(c)) return; - pdd const& p = c->to_eq(); +#endif + if (c->is_eq()) + add_constraint_eq(c->to_eq(), c.blit()); + } + + bool slicing::add_fixed_bits(signed_constraint c) { + // TODO: what is missing here: + // - we don't prioritize constraints that set larger bit ranges + // e.g., c1 sets 3 lower bits, and c2 sets 5 lower bits. + // slicing may have both {c1,c2} in justifications while previously we always prefer c2. + // - (we could wait until propagate() to add fixed bits to the egraph. but that would only work on a single decision level.) + if (c->vars().size() != 1) + return true; + fixed_bits fb; + if (!get_fixed_bits(c, fb)) + return true; + pvar const x = c->vars()[0]; + return add_fixed_bits(x, fb.hi, fb.lo, fb.value, c.blit()); + } + + bool slicing::add_fixed_bits(pvar x, unsigned hi, unsigned lo, rational const& value, sat::literal lit) { + enode_vector& xs = m_tmp3; + SASSERT(xs.empty()); + mk_slice(var2slice(x), hi, lo, xs, false, false); + enode* const sval = mk_value_slice(value, hi - lo + 1); + // 'xs' will be cleared by 'merge'. + // NOTE: the 'nullptr' argument will be fixed by 'egraph_merge' + return merge(xs, sval, mk_var_dep(x, nullptr, lit)); + } + + bool slicing::add_constraint_eq(pdd const& p, sat::literal lit) { auto& m = p.manager(); for (auto& [a, x] : p.linear_monomials()) { if (a != 1 && a != m.max_value()) continue; pdd const body = a.is_one() ? (m.mk_var(x) - p) : (m.mk_var(x) + p); // c is either x = body or x != body, depending on polarity - if (!add_equation(x, body, c.blit())) { + if (!add_equation(x, body, lit)) { SASSERT(is_conflict()); - return; + return false; } // without this check, when p = x - y we would handle both x = y and y = x separately if (body.is_unary()) break; } + return true; } bool slicing::add_equation(pvar x, pdd const& body, sat::literal lit) { @@ -1288,28 +1322,29 @@ namespace polysat { SASSERT(all_of(m_egraph.nodes(), [](enode* n) { return !n->is_marked1(); })); } - void slicing::collect_fixed(pvar v, rational& mask, rational& value) { + void slicing::collect_fixed(pvar v, fixed_bits_vector& out, euf::enode_pair_vector& out_just) { enode_vector& base = m_tmp2; SASSERT(base.empty()); get_base(var2slice(v), base); - mask = 0; - value = 0; rational a; unsigned lo = 0; - for (auto it = base.rbegin(); it != base.rend(); ++it) { - enode* n = *it; + for (enode* n : base) { enode* r = n->get_root(); unsigned const w = width(n); + unsigned const hi = lo + w - 1; if (try_get_value(r, a)) { - rational const factor = rational::power_of_two(lo); - // TODO: probably better to return vector of {w, lo, a} instead - mask += (rational::power_of_two(w) - 1) * factor; - value += a * factor; + out.push_back({hi, lo, a}); + out_just.push_back({n, r}); } lo += w; } } + void slicing::explain_fixed(euf::enode_pair const& just, std::function const& on_lit, std::function const& on_var) { + auto [n, r] = just; + NOT_IMPLEMENTED_YET(); // TODO: like explain_value + } + std::ostream& slicing::display(std::ostream& out) const { enode_vector base; for (pvar v = 0; v < m_var2slice.size(); ++v) { diff --git a/src/math/polysat/slicing.h b/src/math/polysat/slicing.h index 15a4aacc1..97162a153 100644 --- a/src/math/polysat/slicing.h +++ b/src/math/polysat/slicing.h @@ -28,6 +28,7 @@ Notation: #include "ast/bv_decl_plugin.h" #include "math/polysat/types.h" #include "math/polysat/constraint.h" +#include "math/polysat/fixed_bits.h" #include namespace polysat { @@ -269,8 +270,11 @@ namespace polysat { pvar mk_concat(unsigned num_args, pvar const* args, pvar replay_var); void replay_concat(unsigned num_args, pvar const* args, pvar r); + bool add_constraint_eq(pdd const& p, sat::literal lit); bool add_equation(pvar x, pdd const& body, sat::literal lit); bool add_value(pvar v, rational const& value, sat::literal lit); + bool add_fixed_bits(signed_constraint c); + bool add_fixed_bits(pvar x, unsigned hi, unsigned lo, rational const& value, sat::literal lit); bool invariant() const; bool invariant_needs_congruence() const; @@ -335,7 +339,8 @@ namespace polysat { void collect_simple_overlaps(pvar v, pvar_vector& out); /** Collect fixed portions of the variable v */ - void collect_fixed(pvar v, rational& mask, rational& value); + void collect_fixed(pvar v, fixed_bits_vector& out, euf::enode_pair_vector& out_just); + void explain_fixed(euf::enode_pair const& just, std::function const& on_lit, std::function const& on_var); std::ostream& display(std::ostream& out) const; std::ostream& display_tree(std::ostream& out) const; diff --git a/src/math/polysat/viable.cpp b/src/math/polysat/viable.cpp index 2bf1697e5..0d8145e37 100644 --- a/src/math/polysat/viable.cpp +++ b/src/math/polysat/viable.cpp @@ -976,14 +976,17 @@ namespace { #if 0 // TODO: wip fixed_bits_vector fbs; - s.m_slicing.collect_fixed(v, fbs); + euf::enode_pair_vector fbs_just; + s.m_slicing.collect_fixed(v, fbs, fbs_just); - for (fixed_bits const& fb : fbs) { + for (unsigned idx = fbs.size(); idx-- > 0; ) { + fixed_bits const& fb = fbs[idx]; + euf::enode_pair const& just = fbs_just[idx]; for (unsigned i = fb.lo; i <= fb.hi; ++i) { SASSERT(out_fbi.just_src[i].empty()); // since we don't get overlapping ranges from collect_fixed. SASSERT(out_fbi.just_side_cond[i].empty()); out_fbi.fixed[i] = to_lbool(fb.value.get_bit(i - fb.lo)); - // TODO: out_fbi.just_src[i].push_back( + // TODO: s.m_slicing.explain_fixed( ... ); with out_fbi.just_src[i].push_back(...) } } #endif