From 173fb9c2bdb70118d39dc3cb72d36cd5a3275224 Mon Sep 17 00:00:00 2001 From: Clemens Eisenhofer Date: Sat, 24 Dec 2022 16:37:53 +0100 Subject: [PATCH] Bit-Propagation for most operations (Backtracking missing) --- src/math/polysat/fixed_bits.cpp | 522 ++++++++++++++++++++++------ src/math/polysat/fixed_bits.h | 104 ++++-- src/math/polysat/op_constraint.cpp | 56 ++- src/math/polysat/ule_constraint.cpp | 108 ++++++ src/math/polysat/ule_constraint.h | 1 + src/test/polysat.cpp | 54 ++- src/util/tbv.h | 11 +- 7 files changed, 706 insertions(+), 150 deletions(-) diff --git a/src/math/polysat/fixed_bits.cpp b/src/math/polysat/fixed_bits.cpp index 0ce49d092..8b94c1a59 100644 --- a/src/math/polysat/fixed_bits.cpp +++ b/src/math/polysat/fixed_bits.cpp @@ -20,17 +20,30 @@ namespace polysat { return fixed.m_tbv_to_justification[{ p, idx }]; } - const tbv_ref& bit_justication::get_tbv(fixed_bits& fixed, const pdd& p) { + const tbv_ref* bit_justication::get_tbv(fixed_bits& fixed, const pdd& p) { return fixed.get_tbv(p); } - bool bit_justication::fix_value(fixed_bits& fixed, const pdd& p, tbv_ref& tbv, unsigned idx, tbit val, bit_justication* j) { - return fixed.fix_value(p, tbv, idx, val, j); + // returns: Is it consistent + bool bit_justication::fix_value_core(solver& s, fixed_bits& fixed, const pdd& p, tbv_ref& tbv, unsigned idx, tbit val, bit_justication** j) { + SASSERT(j && *j); + if (!fixed.fix_value(s, p, tbv, idx, val, *j) && (*j)->can_dealloc()) { + // TODO: Potential double deallocation + dealloc(*j); + *j = nullptr; + } + return fixed.m_consistent; + } + + bool bit_justication::fix_value_core(solver& s, fixed_bits& fixed, const pdd& p, tbv_ref& tbv, unsigned idx, tbit val, bit_justication* j) { + return fix_value_core(s, fixed, p, tbv, idx, val, &j); } void bit_justication_constraint::get_dependencies(fixed_bits& fixed, bit_dependencies& to_process) { - for (const auto& dep : this->m_dependencies) + for (const auto& dep : this->m_dependencies) { + LOG("Dependency: pdd: " << dep.pdd() << " idx: " << dep.idx()); to_process.push_back(dep); + } } bit_justication_constraint* bit_justication_constraint::mk_justify_at_least(constraint *c, const pdd& v, const tbv_ref& tbv, const rational& least) { @@ -69,38 +82,164 @@ namespace polysat { // r1 = (p0 q1 + p1 q0) + (p0 q0) / 2 = (p0 q1 + p1 q0) // r2 = (p0 q2 + p1 q1 + p2 q0) + (p0 q1 + p1 q0) / 2 + (p0 q0) / 4 = (p0 q2 + p1 q1 + p2 q0) + (p0 q1 + p1 q0) / 2 // r3 = (p0 q3 + p1 q2 + p2 q1 + p3 q0) + (p0 q2 + p1 q1 + p2 q0) / 2 + (p0 q1 + p1 q0) / 4 + (p0 q0) / 8 = (p0 q3 + p1 q2 + p2 q1 + p3 q0) + (p0 q2 + p1 q1 + p2 q0) / 2 - tbv_ref& bit_justication_mul::mul(fixed_bits& fixed, const pdd& p, const tbv_ref& in1, const tbv_ref& in2) { - auto m = in1.manager(); - tbv_ref& out = fixed.get_tbv(p); + void bit_justication_mul::propagate(solver& s, fixed_bits& fixed, const pdd& r, const pdd &p, const pdd &q) { + LOG_H2("Bit-Propagating: " << r << " = (" << p << ") * (" << q << ")"); + tbv_ref& p_tbv = *fixed.get_tbv(p); + tbv_ref& q_tbv = *fixed.get_tbv(q); + tbv_ref& r_tbv = *fixed.get_tbv(r); + LOG("p: " << p << " = " << p_tbv); + LOG("q: " << q << " = " << q_tbv); + LOG("r: " << r << " = " << r_tbv); - unsigned min_bit_value = 0; // The value of the current bit assuming all unknown bits are 0 - unsigned max_bit_value = 0; // The value of the current bit assuming all unknown bits are 1 - - // TODO: Check: Is the performance too worse? It is O(k^2) + auto& m = r_tbv.manager(); + // TODO: maybe propagate the bits only until the first "don't know" and as well for the leading "0"s [The bits in-between are rare and hard to compute] + unsigned min_val = 0; // The value of the current bit assuming all unknown bits are 0 + unsigned max_val = 0; // The value of the current bit assuming all unknown bits are 1 + unsigned highest_overflow_idx = -1; // The index which could result in the highest overflow (Used for backward propagation. Which previous bit-index could have the highest overflow to the current bit?) + unsigned highest_overflow_val = 0; // The respective value + bool highest_overflow_precise = false; // True if the highest overflow is still precise after all divisions by 2 (We can only use those for backward propagation. If it is not a power of 2 we don't know which values to set.) + + // Forward propagation + // Example 1: + // r4 = (0 q3 + 1 1 + 0 q1 + 0 q0) + (1 1 + 0 q1 + 1 1) / 2 + // min_val = 2 = 2 / 2 + 1; max_val = 2 = 2 / 2 + 1 and (0 q3 + 1 1 + 0 q1 + 0 q0) + (1 1 + 0 q1 + 1 1) / 2 = 2 we conclude r3 = 0 (and min_val = max_val := min_val / 2 + 2 / 2) + // + // Example 2: + // r4 = (0 q3 + 1 1 + 0 q1 + 0 q0) + (1 1 + 0 q1 + 1 q0) / 2 + // min_val = 1 = 1 + 1 / 2; max_val = 2 = 1 + 2 / 2. We cannot propagate to r4 as we don't know the value of the overflow + // + // Example 3: + // r4 = (0 q3 + p1 1 + 0 q1 + 0 q0) + (1 1 + 0 q1 + 1 1) / 2 + // min_val = 1 = 0 + 2 / 2; v = 2 = 1 + 2 / 2. We cannot propagate to r4 as we don't know the precise value + + // Backward propagation + // Example 1: + // 0 = r3 = (1 1 + 0 q2 + 1 q1 + p3 0) + (0 q2 + 1 1 + 1 1) / 2 + // highest_overflow_idx = 3 [meaning r3]; min_val = 2 = 1 + 2 / 2; max_val = 3 = 2 + 2 / 2. We can propagate q1 = 0 as min_val == max_val - 1 + // + // Example 2: + // 0 = r3 = (1 1 + 0 q2 + 0 q1 + p3 0) + (0 q2 + p1 1 + p2 1) / 2 + // highest_overflow_idx = 2; highest_overflow_precise = true; min_val = 1; max_val = 2. We can propagate p2 = p1 = 1 in r2 as min_val == max_val - 1 and we know that we can make all [highest_overflow_precise == true] undetermined products in r2 true + // + // Example 3: + // 0 = r3 = (1 1 + 0 q2 + 0 q1 + p3 0) + (1 q2 + 1 1 + p2 1) / 2 + // highest_overflow_idx = 2; highest_overflow_precise = false; min_val = 1; max_val = 2. We can not propagate p2 = 1 or q2 = 1 in r2 as we don't know which [highest_overflow_precise == false i.e., 3 is not divisible by 2] + // + // Example 4: + // 0 = r3 = (1 1 + 0 q2 + 0 q1 + p3 0) + (p0 q2 + p1 1 + 0 1) / 2 + // highest_overflow_idx = 2; highest_overflow_precise = true; min_val = 1; max_val = 2. We can propagate p1 = 1 but not p0 = 1 or q2 = 1 as we don't know which + // + // In all cases cases min_val == max_val after backward propagation [max_val = min_val if assigned to 0; min_val = max_val if assigned to 1] + + // TODO: Check: Is the performance too worse? It is O(k^3) in the worst case... for (unsigned i = 0; i < m.num_tbits(); i++) { + unsigned current_min_val = 0, current_max_val = 0; for (unsigned x = 0, y = i; x <= i; x++, y--) { - tbit bit1 = in1[x]; - tbit bit2 = in2[y]; - + tbit bit1 = p_tbv[x]; + tbit bit2 = q_tbv[y]; + if (bit1 == BIT_1 && bit2 == BIT_1) { - min_bit_value++; // we get two 1 - max_bit_value++; - } - else if (bit1 != BIT_0 && bit2 != BIT_0) { - max_bit_value++; // we could get two 1 + current_min_val++; // we get two 1 + current_max_val++; } + else if (bit1 != BIT_0 && bit2 != BIT_0) + current_max_val++; // we could get two 1 } - if (min_bit_value == max_bit_value) { + + if (max_val >= highest_overflow_val) { + highest_overflow_val = max_val; + highest_overflow_idx = i; + highest_overflow_precise = true; + } + min_val += current_min_val; + max_val += current_max_val; + + if (min_val == max_val) { // We know the value of this bit - if (!fix_value(fixed, p, out, i, min_bit_value & 1 ? BIT_1 : BIT_0, alloc(bit_justication_mul))) - return out; + // forward propagation + // this might add a conflict if the value is already set to another value + if (!fix_value_core(s, fixed, r, r_tbv, i, min_val & 1 ? BIT_1 : BIT_0, alloc(bit_justication_mul, i, p, q))) + return; } - // Subtract one; shift this to the next higher bit as "carry value" - min_bit_value >>= 1; - max_bit_value >>= 1; + else if (r_tbv[i] != BIT_z && min_val == max_val - 1) { + // backward propagation + // this cannot add a conflict. However, conflicts are already captured in the forward propagation case + tbit set; + if ((min_val & 1) == (r_tbv[i] == BIT_0 ? 0 : 1)) { + set = BIT_0; + max_val = min_val; + } + else { + set = BIT_1; + min_val = max_val; + } + SASSERT(set == BIT_0 || set == BIT_1); + SASSERT(highest_overflow_idx <= i); + if (highest_overflow_precise) { // Otherwise, we cannot set the elements in the previous ri but we at least know max_val == min_val (resp., vice-versa) + bit_justication_shared* j = nullptr; + unsigned_vector set_bits; +#define SHARED_JUSTIFICATION (j ? (j->inc_ref(), (bit_justication**)&j) : (j = alloc(bit_justication_shared, alloc(bit_justication_mul, i, p, q, r)), (bit_justication**)&j)) + + for (unsigned x = 0, y = i; x <= highest_overflow_idx; x++, y--) { + tbit bit1 = p_tbv[x]; + tbit bit2 = q_tbv[y]; + if (set == BIT_0 && bit1 != bit2) { + // Sets: (1, z), (z, 1), (0, 1), (1, 0) [the cases with two constants are used for minimizing decision levels] + // Does not set: (1, 1), (0, 0), (0, z), (z, 0) + // Also does not set: (z, z) [because we don't know which one. We only know that it has to be 0 => we can still set max_val = min_val] + if (bit1 == BIT_1) { + if (!fix_value_core(s, fixed, q, q_tbv, y, BIT_0, SHARED_JUSTIFICATION)) { + VERIFY(false); + } + set_bits.push_back(y << 1 | 1); + } + else if (bit2 == BIT_1) { + if (!fix_value_core(s, fixed, p, p_tbv, x, BIT_0, SHARED_JUSTIFICATION)) { + VERIFY(false); + } + set_bits.push_back(x << 1 | 0); + } + } + else if (set == BIT_1 && bit1 != BIT_0 && bit2 != BIT_0) { + // Sets: (1, z), (z, 1), (1, 1), (z, z) + // Does not set: (0, 0), (0, z), (z, 0), (0, 1), (1, 0) + if (bit1 == BIT_1) { + if (!fix_value_core(s, fixed, q, q_tbv, y, BIT_1, SHARED_JUSTIFICATION)) { + VERIFY(false); + } + set_bits.push_back(y << 1 | 1); + } + if (bit2 == BIT_1) { + if (!fix_value_core(s, fixed, p, p_tbv, x, BIT_1, SHARED_JUSTIFICATION)) { + VERIFY(false); + } + set_bits.push_back(x << 1 | 0); + } + if (bit1 == BIT_z && bit2 == BIT_z) { + if (!fix_value_core(s, fixed, p, p_tbv, i, BIT_1, SHARED_JUSTIFICATION) || + !fix_value_core(s, fixed, q, q_tbv, i, BIT_1, SHARED_JUSTIFICATION)) { + VERIFY(false); + } + set_bits.push_back(y << 1 | 1); + set_bits.push_back(x << 1 | 0); + } + } + } + + if (j) { + // the reference count might be higher than the number of elements in the vector + // some elements might not be relevant for the justification (e.g., because of decision-level) + ((bit_justication_mul*)j->get_justification())->m_bit_indexes = set_bits; + } + } + } + + // Subtract one; shift this to the next higher bit as "carry values" + min_val >>= 1; + max_val >>= 1; + highest_overflow_precise &= (highest_overflow_val & 1) == 0; + highest_overflow_val >>= 1; } - - return out; } // collect all bits that effect the given bit. These might be quite a lot @@ -127,91 +266,150 @@ namespace polysat { relevant_range = m_idx >= 2; else relevant_range = log2(m_idx - (log2(m_idx) + 1)); + + const tbv_ref& p_tbv = *get_tbv(fixed, *m_p); + const tbv_ref& q_tbv = *get_tbv(fixed, *m_q); - const tbv_ref& tbv1 = get_tbv(fixed, *m_c1); - const tbv_ref& tbv2 = get_tbv(fixed, *m_c2); - + if (m_r) + get_dependencies_forward(fixed, to_process, p_tbv, q_tbv, relevant_range); + else + get_dependencies_backward(fixed, to_process, p_tbv, q_tbv, relevant_range); + } + + void bit_justication_mul::get_dependencies_forward(fixed_bits &fixed, bit_dependencies &to_process, const tbv_ref& p_tbv, const tbv_ref& q_tbv, unsigned relevant_range) { for (unsigned i = m_idx - relevant_range; i <= m_idx; i++) { for (unsigned x = 0, y = i; x <= i; x++, y--) { - tbit bit1 = tbv1[x]; - tbit bit2 = tbv2[y]; + tbit bit1 = p_tbv[x]; + tbit bit2 = q_tbv[y]; if (bit1 == BIT_1 && bit2 == BIT_1) { - get_other_justification(fixed, *m_c1, x)->get_dependencies(fixed, to_process); - get_other_justification(fixed, *m_c2, x)->get_dependencies(fixed, to_process); + get_other_justification(fixed, *m_p, x)->get_dependencies(fixed, to_process); + get_other_justification(fixed, *m_q, y)->get_dependencies(fixed, to_process); } else if (bit1 == BIT_0) // TODO: Take the better one if both are zero - get_other_justification(fixed, *m_c1, x)->get_dependencies(fixed, to_process); + get_other_justification(fixed, *m_p, x)->get_dependencies(fixed, to_process); else if (bit2 == BIT_0) - get_other_justification(fixed, *m_c2, x)->get_dependencies(fixed, to_process); + get_other_justification(fixed, *m_q, y)->get_dependencies(fixed, to_process); else { // The bit is apparently not set because we cannot derive a truth-value. - // Why do we ask for an explanation + // Why do we ask for an explanation? VERIFY(false); } } } } + void bit_justication_mul::get_dependencies_backward(fixed_bits& fixed, bit_dependencies& to_process, const tbv_ref& p_tbv, const tbv_ref& q_tbv, unsigned relevant_range) { + SASSERT(!m_bit_indexes.empty()); // Who asked us for an explanation if there is nothing in the set? + unsigned set_idx = 0; + for (unsigned i = m_idx - relevant_range; i <= m_idx; i++) { + for (unsigned x = 0, y = i; x <= i; x++, y--) { + + unsigned i_p = x << 1 | 0; + unsigned i_q = y << 1 | 1; + + // the list is ordered in the same way we iterate now through it so we just look at the first elements + unsigned next1 = set_idx >= m_bit_indexes.size() ? -1 : m_bit_indexes[set_idx]; + unsigned next2 = set_idx + 1 >= m_bit_indexes.size() ? -1 : m_bit_indexes[set_idx + 1]; + + bool p_in_set = false, q_in_set =false; + + if (i_p == next1 || i_p == next2) { + set_idx++; + p_in_set = true; + } + else if (i_q == next1 || i_q == next2) { + set_idx++; + q_in_set = true; + } + + tbit bit1 = p_tbv[x]; + tbit bit2 = q_tbv[y]; + + // TODO: Check once more + + if (bit1 == BIT_1 && bit2 == BIT_1) { + if (!p_in_set) + get_other_justification(fixed, *m_p, x)->get_dependencies(fixed, to_process); + if (!q_in_set) + get_other_justification(fixed, *m_q, y)->get_dependencies(fixed, to_process); + } + else if (bit1 == BIT_0) { + if (!p_in_set) + get_other_justification(fixed, *m_p, x)->get_dependencies(fixed, to_process); + else if (!q_in_set) + get_other_justification(fixed, *m_q, y)->get_dependencies(fixed, to_process); + } + else if (bit2 == BIT_0 && !q_in_set) { + if (!q_in_set) + get_other_justification(fixed, *m_q, y)->get_dependencies(fixed, to_process); + else if (!p_in_set) + get_other_justification(fixed, *m_p, x)->get_dependencies(fixed, to_process); + } + else { + // unlike in the forward case this can happen + } + } + } + } + // similar to multiplying but far simpler/faster (only the direct predecessor might overflow) - tbv_ref& bit_justication_add::add(fixed_bits& fixed, const pdd& p, const tbv_ref& in1, const tbv_ref& in2) { - auto m = in1.manager(); - tbv_ref& out = fixed.get_tbv(p); + void bit_justication_add::propagate(solver& s, fixed_bits& fixed, const pdd& r, const pdd& p, const pdd& q) { + LOG_H2("Bit-Propagating: " << r << " = (" << p << ") + (" << q << ")"); + // TODO: Add backward propagation + tbv_ref& p_tbv = *fixed.get_tbv(p); + tbv_ref& q_tbv = *fixed.get_tbv(q); + tbv_ref& r_tbv = *fixed.get_tbv(r); + LOG("p: " << p << " = " << p_tbv); + LOG("q: " << q << " = " << q_tbv); + LOG("r: " << r << " = " << r_tbv); + + auto& m = r_tbv.manager(); unsigned min_bit_value = 0; unsigned max_bit_value = 0; for (unsigned i = 0; i < m.num_tbits(); i++) { - tbit bit1 = in1[i]; - tbit bit2 = in2[i]; - if (bit1 == BIT_1 && bit2 == BIT_1) { + tbit bit1 = p_tbv[i]; + tbit bit2 = q_tbv[i]; + if (bit1 == BIT_1) { min_bit_value++; max_bit_value++; } - else if (bit1 != BIT_0 && bit2 != BIT_0) { + else if (bit1 == BIT_z) + max_bit_value++; + + if (bit2 == BIT_1) { + min_bit_value++; max_bit_value++; } + else if (bit2 == BIT_z) + max_bit_value++; if (min_bit_value == max_bit_value) - if (!fix_value(fixed, p, out, i, min_bit_value & 1 ? BIT_1 : BIT_0, alloc(bit_justication_add))) - return out; + if (!fix_value_core(s, fixed, r, r_tbv, i, min_bit_value & 1 ? BIT_1 : BIT_0, alloc(bit_justication_add))) + return; min_bit_value >>= 1; max_bit_value >>= 1; } - - if (min_bit_value == max_bit_value) // Overflow to the first bit - fix_value(fixed, p, out, 0, min_bit_value & 1 ? BIT_1 : BIT_0, alloc(bit_justication_add)); - - return out; } void bit_justication_add::get_dependencies(fixed_bits& fixed, bit_dependencies& to_process) { - if (m_c1->power_of_2() > 1) { - if (m_idx == 0) { - get_other_justification(fixed, *m_c1, m_c1->power_of_2() - 1)->get_dependencies(fixed, to_process); - get_other_justification(fixed, *m_c2, m_c1->power_of_2() - 1)->get_dependencies(fixed, to_process); - DEBUG_CODE( - const tbv_ref& tbv1 = get_tbv(fixed, *m_c1); - const tbv_ref& tbv2 = get_tbv(fixed, *m_c2); - SASSERT(tbv1[m_c1->power_of_2() - 1] != BIT_z && tbv2[m_c1->power_of_2() - 1] != BIT_z); - ); - } - else { - get_other_justification(fixed, *m_c1, m_idx - 1)->get_dependencies(fixed, to_process); - get_other_justification(fixed, *m_c2, m_idx - 1)->get_dependencies(fixed, to_process); - DEBUG_CODE( - const tbv_ref& tbv1 = get_tbv(fixed, *m_c1); - const tbv_ref& tbv2 = get_tbv(fixed, *m_c2); - SASSERT(tbv1[m_idx - 1] != BIT_z && tbv2[m_idx - 1] != BIT_z); - ); - } + if (m_c1->power_of_2() > 1 && m_idx > 0) { + get_other_justification(fixed, *m_c1, m_idx - 1)->get_dependencies(fixed, to_process); + get_other_justification(fixed, *m_c2, m_idx - 1)->get_dependencies(fixed, to_process); + DEBUG_CODE( + const tbv_ref& tbv1 = *get_tbv(fixed, *m_c1); + const tbv_ref& tbv2 = *get_tbv(fixed, *m_c2); + SASSERT(tbv1[m_idx - 1] != BIT_z && tbv2[m_idx - 1] != BIT_z); + ); } get_other_justification(fixed, *m_c1, m_idx)->get_dependencies(fixed, to_process); get_other_justification(fixed, *m_c2, m_idx)->get_dependencies(fixed, to_process); DEBUG_CODE( - const tbv_ref& tbv1 = get_tbv(fixed, *m_c1); - const tbv_ref& tbv2 = get_tbv(fixed, *m_c2); + const tbv_ref& tbv1 = *get_tbv(fixed, *m_c1); + const tbv_ref& tbv2 = *get_tbv(fixed, *m_c2); SASSERT(tbv1[m_idx] != BIT_z && tbv2[m_idx] != BIT_z); ); } @@ -227,15 +425,16 @@ namespace polysat { return get_manager(v.power_of_2()); } - tbv_ref& fixed_bits::get_tbv(const pdd& v) { + tbv_ref* fixed_bits::get_tbv(const pdd& v) { + LOG("Looking for tbv for " << v); auto found = m_var_to_tbv.find_iterator(optional(v)); if (found == m_var_to_tbv.end()) { auto& manager = get_manager(v.power_of_2()); if (v.is_val()) - m_var_to_tbv[optional(v)] = optional(tbv_ref(manager, manager.allocate(v.val()))); + m_var_to_tbv.insert(optional(v), alloc(tbv_ref, manager, manager.allocate(v.val()))); else - m_var_to_tbv[optional(v)] = optional(tbv_ref(manager, manager.allocate())); - return *m_var_to_tbv[optional(v)]; + m_var_to_tbv.insert(optional(v), alloc(tbv_ref, manager, manager.allocate())); + return m_var_to_tbv[optional(v)]; } /*if (m_var_to_tbv.size() <= v) { m_var_to_tbv.reserve(v + 1); @@ -243,7 +442,7 @@ namespace polysat { m_var_to_tbv[v] = tbv_ref(manager, manager.allocate()); return *m_var_to_tbv[v]; }*/ - return *m_var_to_tbv[optional(v)]; + return found->m_value; /*auto& old_manager = m_var_to_tbv[optional(v)]->manager(); if (old_manager.num_tbits() >= v.power_of_2()) return *(m_var_to_tbv[optional(v)]); @@ -259,15 +458,15 @@ namespace polysat { clause_ref fixed_bits::get_explanation(solver& s, bit_justication* j1, bit_justication* j2) { bit_dependencies to_process; // TODO: Check that we do not process the same tuple multiples times (efficiency) - j1->get_dependencies(*this, to_process); - j2->get_dependencies(*this, to_process); + +#define GET_DEPENDENCY(X) do { (X)->get_dependencies(*this, to_process); if ((X)->can_dealloc()) { dealloc(X); } } while (false) clause_builder conflict(s); conflict.set_redundant(true); auto insert_constraint = [&conflict, &s](bit_justication* j) { constraint* constr; - if (j->has_constraint(constr)) + if (!j->has_constraint(constr)) return; SASSERT(constr); if (constr->has_bvar()) { @@ -280,14 +479,22 @@ namespace polysat { insert_constraint(j1); insert_constraint(j2); + GET_DEPENDENCY(j1); + GET_DEPENDENCY(j2); + // In principle, the dependencies should be acyclic so this should terminate. If there are cycles it is for sure a bug while (!to_process.empty()) { bit_dependency& curr = to_process.back(); - to_process.pop_back(); + if (curr.pdd().is_val()) { + to_process.pop_back(); + continue; // We don't need an explanation for bits of constants + } SASSERT(m_tbv_to_justification.contains(curr)); + bit_justication* j = m_tbv_to_justification[curr]; + to_process.pop_back(); insert_constraint(j); - j->get_dependencies(*this, to_process); + GET_DEPENDENCY(j); } return conflict.build(); @@ -295,50 +502,85 @@ namespace polysat { tbit fixed_bits::get_value(const pdd& p, unsigned idx) { SASSERT(p.is_var()); - return get_tbv(p)[idx]; + return (*get_tbv(p))[idx]; } - bool fixed_bits::fix_value(const pdd& p, tbv_ref& tbv, unsigned idx, tbit val, bit_justication* j) { + // True iff the justification changed? Alternatively: true if the justification was not used (can be deallocated). + bool fixed_bits::fix_value_core(const pdd& p, tbv_ref& tbv, unsigned idx, tbit val, bit_justication* j) { + LOG("Fixing bit " << idx << " in " << p << " (" << tbv << ")"); + constraint* c; + if (j->has_constraint(c)) { + LOG("justification constraint: " << *c); + } + SASSERT(val != BIT_x); // We don't use don't-cares - SASSERT(p.is_var()); if (val == BIT_z) - return true; + return false; tbit curr_val = tbv[idx]; if (val == curr_val) - return true; // TODO: Take the new justification if it has a lower decision level + return false; // TODO: Take the new justification if it has a lower decision level auto& m = tbv.manager(); if (curr_val == BIT_z) { m.set(*tbv, idx, val); - delete m_tbv_to_justification[{ p, idx }]; - m_tbv_to_justification[{ p, idx }] = j; + auto jstfc = m_tbv_to_justification.get({ p, idx }, nullptr); + if (jstfc && jstfc->can_dealloc()) + dealloc(jstfc); + m_tbv_to_justification.insert({ p, idx }, j); return true; } SASSERT((curr_val == BIT_1 && val == BIT_0) || (curr_val == BIT_0 && val == BIT_1)); SASSERT(m_tbv_to_justification.contains({ p, idx })); - return m_consistent = false; + m_consistent = false; + return false; } - bool fixed_bits::fix_value(solver& s, const pdd& p, unsigned idx, tbit val, bit_justication* j) { - tbv_ref& tbv = get_tbv(p); - if (fix_value(p, tbv, idx, val, j)) + bool fixed_bits::fix_value(solver& s, const pdd& p, tbv_ref& tbv, unsigned idx, tbit val, bit_justication* j) { + bool changed = fix_value_core(p, tbv, idx, val, j); + if (changed) return true; - clause_ref explanation = get_explanation(s, j, m_tbv_to_justification[{ p, idx }]); - s.set_conflict(*explanation); + + if (!m_consistent) { + clause_ref explanation = get_explanation(s, j, m_tbv_to_justification[{ p, idx }]); + s.set_conflict(*explanation); + } return false; } + + // return: consistent? + bool fixed_bits::fix_value(solver& s, const pdd& p, unsigned idx, tbit val, bit_justication* j) { + tbv_ref& tbv = *get_tbv(p); + bool changed = fix_value_core(p, tbv, idx, val, j); + if (changed) { // this implies consistency + propagate_to_subterm(s, p); + return true; + } + // TODO: Propagate equality if everything is set + if (!m_consistent) { + LOG("Adding conflict on bit " << idx << " on pdd " << p); + clause_ref explanation = get_explanation(s, j, m_tbv_to_justification[{ p, idx }]); + s.set_conflict(*explanation); + return false; // get_explanation will dealloc the justification + } + if (j->can_dealloc()) + dealloc(j); + return m_consistent; + } void fixed_bits::clear_value(const pdd& p, unsigned idx) { + // TODO: Use during backtracking SASSERT(p.is_var()); - tbv_ref& tbv = get_tbv(p); + tbv_ref& tbv = *get_tbv(p); auto& m = tbv.manager(); m.set(*tbv, idx, BIT_z); SASSERT(m_tbv_to_justification.contains({ p, idx })); - delete m_tbv_to_justification[{ p, idx }]; - m_tbv_to_justification[{ p, idx }] = nullptr; + auto& jstfc = m_tbv_to_justification[{ p, idx }]; + if (jstfc->can_dealloc()) + dealloc(jstfc); + jstfc = nullptr; } #define COUNT(DOWN, TO_COUNT) \ @@ -384,30 +626,94 @@ namespace polysat { return { least, most }; } - tbv_ref& fixed_bits::eval(solver& s, const pdd& p) { + tbv_ref* fixed_bits::eval(solver& s, const pdd& p) { + + if (p.is_val() || p.is_var()) + return get_tbv(p); + pdd zero = p.manager().zero(); pdd one = p.manager().one(); pdd sum = zero; - tbv_ref* prev_sum_tbv = &get_tbv(sum); for (const dd::pdd_monomial& n : p) { SASSERT(!n.coeff.is_zero()); pdd prod = p.manager().mk_val(n.coeff); - tbv_ref* prev_mul_tbv = &get_tbv(prod); + + for (pvar fac : n.vars) { + pdd fac_pdd = s.var(fac); + pdd pre_prod = prod; + prod *= fac_pdd; + + if (!pre_prod.is_val() || !pre_prod.val().is_one()) { + bit_justication_mul::propagate(s, *this, prod, pre_prod, fac_pdd); + if (!m_consistent) + return nullptr; + } + } + pdd pre_sum = sum; + sum += prod; + + if (!pre_sum.is_val() || !pre_sum.val().is_zero()) { + bit_justication_add::propagate(s, *this, sum, pre_sum, prod); + if (!m_consistent) + return nullptr; + } + } + return get_tbv(sum); + } + + //propagate to subterms of the polynomial/pdd + void fixed_bits::propagate_to_subterm(solver& s, const pdd& p) { + // we assume the tbv of p was already assigned and there was no conflict + if (p.is_var() || p.is_val()) + return; + + vector sum_subterms; + vector> prod_subterms; + pdd zero = p.manager().zero(); + pdd one = p.manager().one(); + + pdd sum = zero; + + for (const dd::pdd_monomial& n : p) { + SASSERT(!n.coeff.is_zero()); + pdd prod = p.manager().mk_val(n.coeff); + prod_subterms.push_back(vector()); + + // TODO: Maybe process the coefficient first as we have the most information there + // (however, we cannot really revert the order as we used the coefficient first for forward propagation) + if (n.coeff != 1) + prod_subterms.back().push_back(prod); for (pvar fac : n.vars) { pdd fac_pdd = s.var(fac); prod *= fac_pdd; - prev_mul_tbv = &bit_justication_mul::mul(*this, prod, *prev_mul_tbv, get_tbv(fac_pdd)); - if (!m_consistent) - return *prev_sum_tbv; + prod_subterms.back().push_back(prod); + prod_subterms.back().push_back(fac_pdd); } sum += prod; - prev_sum_tbv = &bit_justication_add::add(*this, sum, *prev_sum_tbv, *prev_mul_tbv); - if (!m_consistent) - return *prev_sum_tbv; + sum_subterms.push_back(sum); + sum_subterms.push_back(prod); + } + + SASSERT(sum_subterms[0] == sum_subterms[1] && sum_subterms.size() % 2 == 1); + SASSERT(2 * prod_subterms.size() == sum_subterms.size()); + + pdd current = p; + + for (unsigned i = sum_subterms.size() - 1; i > 1; i -= 2) { + pdd rhs = sum_subterms[i]; // a monomial for sure + pdd lhs = sum_subterms[i - 1]; + SASSERT(rhs.is_monomial()); + bit_justication_add::propagate(s, *this, current, lhs, rhs); + current = rhs; + auto& prod = prod_subterms[i / 2]; + for (unsigned j = prod.size() - 1; j > 1; j -= 2) { + bit_justication_mul::propagate(s, *this, current, prod[j], prod[j - 1]); + current = prod[j - 1]; + } + current = lhs; } - return *prev_sum_tbv; } } diff --git a/src/math/polysat/fixed_bits.h b/src/math/polysat/fixed_bits.h index 6e4e010b7..122eb2cc2 100644 --- a/src/math/polysat/fixed_bits.h +++ b/src/math/polysat/fixed_bits.h @@ -24,11 +24,13 @@ namespace polysat { class constraint; class fixed_bits; - struct bit_dependency { + class bit_dependency { optional m_pdd; unsigned m_bit_idx; - bit_dependency() : m_pdd(optional::undef()), m_bit_idx(0) {} + public: + + bit_dependency() : m_pdd(optional::undef()), m_bit_idx(0) {} bit_dependency(const bit_dependency& v) = default; bit_dependency(bit_dependency&& v) = default; @@ -38,6 +40,20 @@ namespace polysat { return m_pdd == other.m_pdd && m_bit_idx == other.m_bit_idx; } + bit_dependency& operator=(bit_dependency&& other) { + m_pdd = other.m_pdd; + m_bit_idx = other.m_bit_idx; + return *this; + } + + bit_dependency& operator=(bit_dependency& other) { + m_pdd = other.m_pdd; + m_bit_idx = other.m_bit_idx; + return *this; + } + + unsigned idx() const { return m_bit_idx; } + const pdd& pdd() const { return *m_pdd; } }; using bit_dependencies = vector; @@ -45,11 +61,42 @@ namespace polysat { class bit_justication { protected: static bit_justication* get_other_justification(const fixed_bits& fixed, const pdd& p, unsigned idx); - static const tbv_ref& get_tbv(fixed_bits& fixed, const pdd& p); - static bool fix_value(fixed_bits& fixed, const pdd& p, tbv_ref& tbv, unsigned idx, tbit val, bit_justication* j); + static const tbv_ref* get_tbv(fixed_bits& fixed, const pdd& p); + static bool fix_value_core(solver& s, fixed_bits& fixed, const pdd& p, tbv_ref& tbv, unsigned idx, tbit val, bit_justication** j); + static bool fix_value_core(solver& s, fixed_bits& fixed, const pdd& p, tbv_ref& tbv, unsigned idx, tbit val, bit_justication* j); public: + virtual bool can_dealloc() { return true; } virtual bool has_constraint(constraint*& constr) { return false; } - virtual void get_dependencies(fixed_bits& fixed, bit_dependencies& to_process) = 0; + virtual void get_dependencies(fixed_bits& fixed, bit_dependencies& to_process) = 0; // returns if element may be deallocated after call (usually true) + }; + + // if multiple bits are justified by the same justification + class bit_justication_shared : public bit_justication { + bit_justication* m_justification; + unsigned m_references = 0; + public: + bit_justication_shared() = default; + bit_justication_shared(bit_justication* j) : m_justification(j), m_references(1) {} + + bit_justication* get_justification() { return m_justification; } + + virtual bool can_dealloc() { + m_references = m_references == 0 ? 0 : m_references - 1; + if (m_references != 0) + return false; + if (m_justification->can_dealloc()) { + dealloc(m_justification); + m_justification = nullptr; + } + return true; + } + + virtual void get_dependencies(fixed_bits& fixed, bit_dependencies& to_process) override { + SASSERT(m_justification); + m_justification->get_dependencies(fixed, to_process); + } + + void inc_ref() { m_references++; } }; class bit_justication_constraint : public bit_justication { @@ -59,8 +106,9 @@ namespace polysat { // A pdd might occur multiple times if more bits of it are relevant bit_dependencies m_dependencies; - bit_justication_constraint(constraint* c) : m_constraint(c) { } - bit_justication_constraint(constraint* c, bit_dependencies&& dep) : m_constraint(c), m_dependencies(dep) { } + bit_justication_constraint(constraint* c) : m_constraint(c) {} + bit_justication_constraint(constraint* c, const bit_dependencies& dep) : m_constraint(c), m_dependencies(dep) {} + bit_justication_constraint(constraint* c, bit_dependencies&& dep) : m_constraint(c), m_dependencies(dep) {} public: @@ -74,16 +122,18 @@ namespace polysat { static bit_justication_constraint* mk_assignment(constraint* c) { return alloc(bit_justication_constraint, c ); } static bit_justication_constraint* mk_unary(constraint* c, bit_dependency v) { - bit_dependencies dep(1); + bit_dependencies dep; dep.push_back(std::move(v)); return alloc(bit_justication_constraint, c, std::move(dep)); } static bit_justication_constraint* mk_binary(constraint* c, bit_dependency v1, bit_dependency v2) { - bit_dependencies dep(2); + bit_dependencies dep; dep.push_back(std::move(v1)); dep.push_back(std::move(v2)); return alloc(bit_justication_constraint, c, std::move(dep)); } + static bit_justication_constraint* mk(constraint* c, const bit_dependencies& dep) { return alloc(bit_justication_constraint, c, dep); } + // gives the highest bits such that they already enforce a value of "tbv" that is at least "val" static bit_justication_constraint* mk_justify_at_least(constraint *c, const pdd& v, const tbv_ref& tbv, const rational& least); // similar to mk_justify_at_least: gives the highest bits such that they already enforce a value of "tbv" that is at most "val" @@ -97,15 +147,20 @@ namespace polysat { class bit_justication_mul : public bit_justication { unsigned m_idx; - optional m_c1, m_c2; + optional m_p, m_q, m_r; + unsigned_vector m_bit_indexes; public: bit_justication_mul() = default; - bit_justication_mul(unsigned idx, const pdd& c1, const pdd& c2) : m_idx(idx), m_c1(c1), m_c2(c2) {} + bit_justication_mul(unsigned idx, const pdd& p, const pdd& q) : m_idx(idx), m_p(p), m_q(q) {} + bit_justication_mul(unsigned idx, const pdd& p, const pdd& q, const pdd& r) : m_idx(idx), m_p(p), m_q(q), m_r(r) {} - static tbv_ref& mul(fixed_bits& fixed, const pdd& p, const tbv_ref& in1, const tbv_ref& in2); + // propagates from p, q => r (forward) and r, p/q => q/p (backward) + static void propagate(solver& s, fixed_bits& fixed, const pdd& r, const pdd &p, const pdd &q); void get_dependencies(fixed_bits& fixed, bit_dependencies& to_process) override; + void get_dependencies_forward(fixed_bits &fixed, bit_dependencies &to_process, const tbv_ref& p_tbv, const tbv_ref& q_tbv, unsigned relevant_range); + void get_dependencies_backward(fixed_bits& fixed, bit_dependencies& to_process, const tbv_ref& p_tbv, const tbv_ref& q_tbv, unsigned relevant_range); }; class bit_justication_add : public bit_justication { @@ -118,12 +173,12 @@ namespace polysat { bit_justication_add() = default; bit_justication_add(unsigned idx, const pdd& c1, const pdd& c2) : m_idx(idx), m_c1(c1), m_c2(c2) {} - static tbv_ref& add(fixed_bits& fixed, const pdd& p, const tbv_ref& in1, const tbv_ref& in2); + static void propagate(solver& s, fixed_bits& fixed, const pdd& r, const pdd& p, const pdd& q); void get_dependencies(fixed_bits& fixed, bit_dependencies& to_process) override; }; - class fixed_bits { + class fixed_bits final { friend bit_justication; @@ -137,19 +192,19 @@ namespace polysat { return args ? args->hash() : 0; } }; - using pdd_to_tbv_map = map, pdd_to_tbv_hash, pdd_to_tbv_eq>; + using pdd_to_tbv_map = map; using tbv_to_justification_key = bit_dependency; using tbv_to_justification_eq = default_eq; struct tbv_to_justification_hash { unsigned operator()(tbv_to_justification_key const& args) const { - return combine_hash((*args.m_pdd).hash(), args.m_bit_idx); + return combine_hash(args.pdd().hash(), args.idx()); } }; using tbv_to_justification_map = map; //vector> m_var_to_tbv; - pdd_to_tbv_map m_var_to_tbv; + pdd_to_tbv_map m_var_to_tbv; // TODO: free tbv_ref pointers tbv_to_justification_map m_tbv_to_justification; // the elements are pointers. Deallocate them before replacing them bool m_consistent = true; // in case evaluating results in a bit-conflict @@ -158,13 +213,22 @@ namespace polysat { tbv_manager& get_manager(unsigned sz); clause_ref get_explanation(solver& s, bit_justication* j1, bit_justication* j2); - bool fix_value(const pdd& p, tbv_ref& tbv, unsigned idx, tbit val, bit_justication* j); + bool fix_value_core(const pdd& p, tbv_ref& tbv, unsigned idx, tbit val, bit_justication* j); + bool fix_value(solver& s, const pdd& p, tbv_ref& tbv, unsigned idx, tbit val, bit_justication* j); + + void propagate_to_subterm(solver& s, const pdd& p); public: fixed_bits(solver& s) : m_solver(s) {} - tbv_ref& get_tbv(const pdd& p); + ~fixed_bits() { + for (auto& tbv : m_var_to_tbv) { + dealloc(tbv.m_value); + } + } + + tbv_ref* get_tbv(const pdd& p); // #count [min; max] static std::pair leading_zeros(const tbv_ref& ref); @@ -178,7 +242,7 @@ namespace polysat { bool fix_value(solver& s, const pdd& p, unsigned idx, tbit val, bit_justication* j); void clear_value(const pdd& p, unsigned idx); - tbv_ref& eval(solver& s, const pdd& p); + tbv_ref* eval(solver& s, const pdd& p); }; } diff --git a/src/math/polysat/op_constraint.cpp b/src/math/polysat/op_constraint.cpp index c29a97e87..a60eee3b0 100644 --- a/src/math/polysat/op_constraint.cpp +++ b/src/math/polysat/op_constraint.cpp @@ -113,8 +113,10 @@ namespace polysat { if (first) activate(s); +#if 0 if (!propagate_bits(s, is_positive)) return; // conflict +#endif if (clause_ref lemma = produce_lemma(s, s.assignment())) s.add_clause(*lemma); @@ -360,12 +362,12 @@ namespace polysat { bool op_constraint::propagate_bits_shl(solver& s, bool is_positive) { // TODO: Implement: negative case - tbv_ref& p_val = s.m_fixed_bits.eval(s, m_p); - tbv_ref& q_val = s.m_fixed_bits.eval(s, m_q); - tbv_ref& r_val = s.m_fixed_bits.eval(s, m_r); + tbv_ref* p_val = s.m_fixed_bits.eval(s, m_p); + tbv_ref* q_val = s.m_fixed_bits.eval(s, m_q); + tbv_ref* r_val = s.m_fixed_bits.eval(s, m_r); unsigned sz = m_p.power_of_2(); - auto [shift_min, shift_max] = s.m_fixed_bits.min_max(q_val); + auto [shift_min, shift_max] = fixed_bits::min_max(*q_val); unsigned shift_min_u, shift_max_u; @@ -388,23 +390,23 @@ namespace polysat { // TODO: Improve performance; we can reuse the justifications from the previous iteration if (shift_min_u > 0) { for (unsigned i = 0; i < shift_min_u; i++) { - if (!s.m_fixed_bits.fix_value(s, m_r, i, BIT_0, bit_justication_constraint::mk_justify_at_least(this, m_q, q_val, rational(i + 1)))) + if (!s.m_fixed_bits.fix_value(s, m_r, i, BIT_0, bit_justication_constraint::mk_justify_at_least(this, m_q, *q_val, rational(i + 1)))) return false; } } for (unsigned i = shift_min_u; i < sz; i++) { unsigned j = 0; - tbit val = p_val[i - shift_min_u]; + tbit val = (*p_val)[i - shift_min_u]; if (val == BIT_z) continue; for (; j < span; j++) { - if (p_val[i - shift_min_u + 1] != val) + if ((*p_val)[i - shift_min_u + 1] != val) break; } if (j == span) { // all elements we could shift there are equal. We can safely set this value // TODO: Relax. Sometimes we can reduce the span if further elements in q are set to the respective value - if (!s.m_fixed_bits.fix_value(s, m_r, i, val, bit_justication_constraint::mk_justify_between(this, m_q, q_val, shift_min, shift_max))) + if (!s.m_fixed_bits.fix_value(s, m_r, i, val, bit_justication_constraint::mk_justify_between(this, m_q, *q_val, shift_min, shift_max))) return false; } } @@ -524,20 +526,22 @@ namespace polysat { return l_undef; } - bool op_constraint::propagate_bits_and(solver& s, bool is_positive){ + bool op_constraint::propagate_bits_and(solver& s, bool is_positive) { // TODO: Implement: negative case - tbv_ref& p_val = s.m_fixed_bits.eval(s, m_p); - tbv_ref& q_val = s.m_fixed_bits.eval(s, m_q); - tbv_ref& r_val = s.m_fixed_bits.eval(s, m_r); + LOG_H2("Bit-Propagating: " << m_r << " = (" << m_p << ") & (" << m_q << ")"); + tbv_ref* p_val = s.m_fixed_bits.eval(s, m_p); + tbv_ref* q_val = s.m_fixed_bits.eval(s, m_q); + tbv_ref* r_val = s.m_fixed_bits.eval(s, m_r); + LOG("p: " << m_p << " = " << *p_val); + LOG("q: " << m_q << " = " << *q_val); + LOG("r: " << m_r << " = " << *r_val); unsigned sz = m_p.power_of_2(); for (unsigned i = 0; i < sz; i++) { - tbit bp = p_val[i]; - tbit bq = q_val[i]; - tbit br = r_val[i]; - - // TODO: Propagate from the result to the operands. e.g., 110... = xx1... & yyy... - // TODO: ==> x = 111..., y = 110... + tbit bp = (*p_val)[i]; + tbit bq = (*q_val)[i]; + tbit br = (*r_val)[i]; + if (bp == BIT_0 || bq == BIT_0) { // TODO: In case both are 0 use the one with the lower decision-level and not necessarily p if (!s.m_fixed_bits.fix_value(s, m_r, i, BIT_0, bit_justication_constraint::mk_unary(this, { bp == BIT_0 ? m_p : m_q, i }))) @@ -547,6 +551,22 @@ namespace polysat { if (!s.m_fixed_bits.fix_value(s, m_r, i, BIT_1, bit_justication_constraint::mk_binary(this, { m_p, i }, { m_q, i }))) return false; } + else if (br == BIT_1) { + if (!s.m_fixed_bits.fix_value(s, m_p, i, BIT_1, bit_justication_constraint::mk_unary(this, { m_r, i }))) + return false; + if (!s.m_fixed_bits.fix_value(s, m_q, i, BIT_1, bit_justication_constraint::mk_unary(this, { m_r, i }))) + return false; + } + else if (br == BIT_0) { + if (bp == BIT_1) { + if (!s.m_fixed_bits.fix_value(s, m_q, i, BIT_1, bit_justication_constraint::mk_binary(this, { m_p, i }, { m_r, i }))) + return false; + } + else if (bq == BIT_1) { + if (!s.m_fixed_bits.fix_value(s, m_p, i, BIT_1, bit_justication_constraint::mk_binary(this, { m_q, i }, { m_r, i }))) + return false; + } + } } return true; } diff --git a/src/math/polysat/ule_constraint.cpp b/src/math/polysat/ule_constraint.cpp index 57cf4b0e6..59d2675ea 100644 --- a/src/math/polysat/ule_constraint.cpp +++ b/src/math/polysat/ule_constraint.cpp @@ -209,6 +209,9 @@ namespace polysat { // p > 0 s.add_clause(~sc, s.ult(0, p), false); } +#if 0 + propagate_bits(s, is_positive); +#endif } // Evaluate lhs <= rhs @@ -236,6 +239,111 @@ namespace polysat { lbool ule_constraint::eval(assignment const& a) const { return eval(a.apply_to(lhs()), a.apply_to(rhs())); } + + bool ule_constraint::propagate_bits(solver& s, bool is_positive) { + if (is_eq() && is_positive) { + vector> e; + bool failed = false; + for (const auto& m : lhs()) { + if (e.size() > 1) { + failed = true; + break; + } + pdd p = lhs().manager().mk_val(m.coeff); + for (pvar v : m.vars) + p *= s.var(v); + e.push_back(optional(p)); + if (e.size() == 2 && (m.coeff < 0 || m.coeff >= rational::power_of_two(p.power_of_2() - 1))) + std::swap(e[0], e[1]); // try to keep it positive + } + if (!failed && !e.empty()) { + if (e.size() == 1) + e.push_back(optional(lhs().manager().mk_val(0))); + else + e[0] = optional(-*(e[0])); + SASSERT(e.size() == 2); + tbv_ref* lhs_val = s.m_fixed_bits.eval(s, *(e[0])); + tbv_ref* rhs_val = s.m_fixed_bits.eval(s, *(e[1])); + LOG("Bit-Propagating: " << *lhs_val << " = " << *rhs_val); + unsigned sz = lhs_val->num_tbits(); + for (unsigned i = 0; i < sz; i++) { + // we propagate in both directions to get the least decision level + if ((*lhs_val)[i] != BIT_z) { + if (!s.m_fixed_bits.fix_value(s, *(e[1]), i, (*lhs_val)[i], bit_justication_constraint::mk_unary(this, { *(e[0]), i }))) + return false; + } + if ((*rhs_val)[i] != BIT_z) { + if (!s.m_fixed_bits.fix_value(s, *(e[0]), i, (*rhs_val)[i], bit_justication_constraint::mk_unary(this, { *(e[1]), i }))) + return false; + } + } + return true; + } + } + + pdd lhs = is_positive ? m_lhs : m_rhs; + pdd rhs = is_positive ? m_rhs : m_lhs; + + tbv_ref* lhs_val = s.m_fixed_bits.eval(s, lhs); + tbv_ref* rhs_val = s.m_fixed_bits.eval(s, rhs); + unsigned sz = lhs_val->num_tbits(); + + LOG("Bit-Propagating: " << lhs << " (" << *lhs_val << ") " << (is_positive ? "<= " : "< ") << rhs << " (" << *rhs_val << ")"); + + // TODO: Propagate powers of 2 (lower bound) + bool conflict = false; + bit_dependencies dep; + static unsigned char action_lookup[] = { + // lhs <= rhs + // 0 .. break; could be still satisfied; + // 1 ... continue; there might still be a conflict [lhs is the justification; rhs is propagated 1]; + // 2 ... continue; --||-- [rhs is the justification; lhs is propagated 0]; + // 3 ... conflict; lhs is for sure greater than rhs; + // 4 ... invalid (should not happen) + 0, /*(z, z)*/ 0, /*(0, z)*/ 1, /*(1, z)*/ 4, /*(x, z)*/ + 2, /*(z, 0)*/ 2, /*(0, 0)*/ 3, /*(1, 0)*/ 4, /*(x, 0)*/ + 0, /*(z, 1)*/ 0, /*(0, 1)*/ 1, /*(1, 1)*/ 4, /*(x, 1)*/ + // for the positive case (vice-versa for negative case -> we swap lhs/rhs + special treatment for index 0) + }; + unsigned i = sz; + for (; i > (unsigned)!is_positive && !conflict; i--) { + tbit l = (*lhs_val)[i - 1]; + tbit r = (*rhs_val)[i - 1]; + + unsigned char action = action_lookup[l | (r << 2)]; + switch (action) { + case 0: + i = 0; + break; + case 1: + case 3: + dep.push_back({ lhs, i - 1 }); + LOG("Added dependency: pdd: " << lhs << " idx: " << i - 1); + conflict = !s.m_fixed_bits.fix_value(s, rhs, i - 1, BIT_1, bit_justication_constraint::mk(this, dep)); + SASSERT((action != 3) == conflict); + break; + case 2: + dep.push_back({ rhs, i - 1 }); + LOG("Added dependency: pdd: " << rhs << " idx: " << i - 1); + conflict = !s.m_fixed_bits.fix_value(s, lhs, i - 1, BIT_0, bit_justication_constraint::mk(this, dep)); + SASSERT(!conflict); + break; + default: + VERIFY(false); + } + } + if (!conflict && !is_positive && i == 1) { + // Special treatment for lhs < rhs (note: we swapped lhs <-> rhs so this is really a less and not a greater) + conflict = !s.m_fixed_bits.fix_value(s, lhs, 0, BIT_0, bit_justication_constraint::mk(this, dep)); + if (!conflict) + conflict = !s.m_fixed_bits.fix_value(s, rhs, 0, BIT_1, bit_justication_constraint::mk(this, dep)); + } + SASSERT( + is_positive && conflict == (fixed_bits::min_max(*lhs_val).first > fixed_bits::min_max(*rhs_val).second) || + !is_positive && conflict == (fixed_bits::min_max(*lhs_val).second <= fixed_bits::min_max(*rhs_val).first) + ); + return !conflict; + } unsigned ule_constraint::hash() const { return mk_mix(lhs().hash(), rhs().hash(), kind()); diff --git a/src/math/polysat/ule_constraint.h b/src/math/polysat/ule_constraint.h index 87dea1179..a798cae47 100644 --- a/src/math/polysat/ule_constraint.h +++ b/src/math/polysat/ule_constraint.h @@ -41,6 +41,7 @@ namespace polysat { lbool eval() const override; lbool eval(assignment const& a) const override; void narrow(solver& s, bool is_positive, bool first) override; + bool propagate_bits(solver& s, bool is_positive) override; unsigned hash() const override; bool operator==(constraint const& other) const override; bool is_eq() const override { return m_rhs.is_zero(); } diff --git a/src/test/polysat.cpp b/src/test/polysat.cpp index 9b91024de..d36a60b9f 100644 --- a/src/test/polysat.cpp +++ b/src/test/polysat.cpp @@ -1636,6 +1636,55 @@ namespace polysat { s.check(); s.expect_unsat(); } + + static void test_band6(unsigned bw = 32) { + scoped_solver s(concat(__func__, " bw=", bw)); + auto a = s.var(s.add_var(bw)); + auto x = s.var(s.add_var(bw)); + auto y = s.var(s.add_var(bw)); + + signed_constraint and1 = s.eq(a, s.band(x, x.manager().mk_val(rational::power_of_two((bw + 1) / 2) - 1))); + signed_constraint and2 = s.eq(a, s.band(x, x.manager().mk_val((rational::power_of_two((bw + 1) / 2) - 1) * rational::power_of_two(bw / 2)))); + s.add_clause(and1, false); + s.add_clause(and2, false); + s.add_clause(~s.eq(a, 0), false); + + s.check(); + s.expect_unsat(); + } + + static void test_band6_complex_term(unsigned bw = 32) { + scoped_solver s(concat(__func__, " bw=", bw)); + auto a = s.var(s.add_var(bw)); + auto x = s.var(s.add_var(bw)); + auto y = s.var(s.add_var(bw)); + + signed_constraint and1 = s.eq(a, s.band(x * y, x.manager().mk_val(rational::power_of_two((bw + 1) / 2) - 1))); + signed_constraint and2 = s.eq(a, s.band(x * y, x.manager().mk_val((rational::power_of_two((bw + 1) / 2) - 1) * rational::power_of_two(bw / 2)))); + s.add_clause(and1, false); + s.add_clause(and2, false); + s.add_clause(~s.eq(a, 0), false); + + s.check(); + s.expect_unsat(); + } + + static void test_band6_eq_order(unsigned bw = 32) { + scoped_solver s(concat(__func__, " bw=", bw)); + auto a = s.var(s.add_var(bw)); + auto x = s.var(s.add_var(bw)); + auto y = s.var(s.add_var(bw)); + + signed_constraint and1 = s.eq(a, s.band(x, x.manager().mk_val(rational::power_of_two((bw + 1) / 2) - 1))); + signed_constraint and2 = s.eq(a, s.band(x, y)); + s.add_clause(and1, false); + s.add_clause(and2, false); + s.add_clause(s.eq(x.manager().mk_val((rational::power_of_two((bw + 1) / 2) - 1) * rational::power_of_two(bw / 2)), y), false); + s.add_clause(~s.eq(a, 0), false); + + s.check(); + s.expect_unsat(); + } static void test_fi_zero() { scoped_solver s(__func__); @@ -1849,8 +1898,7 @@ static void STD_CALL polysat_on_ctrl_c(int) { void tst_polysat() { using namespace polysat; - - polysat::test_polysat::test_elim7(3); + polysat::test_polysat::test_band6(4); #if 0 // Enable this block to run a single unit test with detailed output. collect_test_records = false; @@ -1870,7 +1918,7 @@ void tst_polysat() { return; #endif - // If non-empty, only run tests whose name contains the run_filter + // If non-empty, only run tests whose name c9ontains the run_filter run_filter = ""; test_max_conflicts = 20; diff --git a/src/util/tbv.h b/src/util/tbv.h index c3588eec1..b4725a1a0 100644 --- a/src/util/tbv.h +++ b/src/util/tbv.h @@ -43,6 +43,7 @@ class tbv_manager { ptr_vector allocated_tbvs; public: tbv_manager(unsigned n): m(2*n) {} + tbv_manager(const tbv_manager& m) = delete; ~tbv_manager(); void reset(); tbv* allocate(); @@ -127,7 +128,7 @@ private: return (fixed_bit_vector::get(index) << 1) | (unsigned)fixed_bit_vector::get(index+1); } }; - + class tbv_ref { tbv_manager& mgr; tbv* d; @@ -152,4 +153,12 @@ public: unsigned num_tbits() const { return mgr.num_tbits(); } }; +inline std::ostream& operator<<(std::ostream& out, tbv_ref const& c) { + const char* names[] = { "z", "0", "1", "x" }; + for (unsigned i = c.num_tbits(); i > 0; i--) { + out << names[(unsigned)c[i - 1]]; + } + return out; +} +