3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-11-15 10:25:45 +00:00

Do a quick check for feasibility w.r.t. bits before using forbidden intervals

This commit is contained in:
Clemens Eisenhofer 2023-02-15 20:06:13 +01:00
parent e07c77e072
commit 5ddc727f91
6 changed files with 439 additions and 6 deletions

View file

@ -86,6 +86,10 @@ namespace polysat {
#endif
if (try_equal_body_subsumptions(cl))
simplified = true;
#if 0
if (try_bit_subsumptions(cl))
simplified = true;
#endif
return simplified;
}
@ -344,8 +348,245 @@ namespace polysat {
return true;
}
// decomposes into a plain constant and a part containing variables. e.g., 2*x*y + 3*z - 2 gets { 2*x*y + 3*z, -2 }
static std::pair<pdd, pdd> decompose_constant(const pdd& p) {
for (const auto& m : p) {
if (m.vars.empty())
return { p - m.coeff, p.manager().mk_val(m.coeff) };
}
return { p, p.manager().mk_val(0) };
}
// 2^(k - d) * x = m * 2^(k - d)
// TODO: Factor out constant factors from x and put them to the rhs
bool simplify_clause::get_trailing_mask(pdd lhs, pdd rhs, pdd& p, trailing_bits& mask, bool pos) {
auto lhs_decomp = decompose_constant(lhs);
auto rhs_decomp = decompose_constant(rhs);
lhs = lhs_decomp.first - rhs_decomp.first;
rhs = rhs_decomp.second - lhs_decomp.second;
SASSERT(rhs.is_val());
unsigned k = lhs.manager().power_of_2();
unsigned d = lhs.max_pow2_divisor();
unsigned span = k - d;
if (span == 0)
return false;
p = lhs.div(rational::power_of_two(d));
rational rhs_val = rhs.val();
mask.bits = rhs_val / rational::power_of_two(d);
if (!mask.bits.is_int())
return false;
mask.length = span;
mask.positive = pos;
return true;
}
// 2^(k - 1) <= 2^(k - i - 1) * x (original definition) // TODO: Have this as well
// 2^(k - i - 1) * x + 2^(k - 1) <= 2^(k - 1) - 1 (currently we test only for this form)
bool simplify_clause::get_bit(const pdd& lhs, const pdd& rhs, pdd& p, single_bit& bit, bool pos) {
if (!rhs.is_val())
return false;
rational rhs_val = rhs.val() + 1;
unsigned k = rhs.power_of_2();
if (rhs_val != rational::power_of_two(k - 1))
return false;
pdd rest = lhs - rhs_val;
unsigned d = rest.max_pow2_divisor();
bit.position = k - d - 1;
bit.positive = pos;
p = rest.div(rational::power_of_two(d));
return true;
}
// Compares with respect to "subsumption"
// -1: mask1 < mask2 (e.g., 101 < 0101)
// 0: incomparable
// 1: mask1 > mask2
// failure mask1 == mask2
static int compare(const trailing_bits& mask1, const trailing_bits& mask2) {
if (mask1.length == mask2.length) {
SASSERT(mask1.bits != mask2.bits); // otw. we would have already eliminated the duplicate constraint
return 0;
}
if (mask1.length < mask2.length) {
for (unsigned i = 0; i < mask1.length; i++)
if (mask1.bits.get_bit(i) != mask2.bits.get_bit(i))
return 0;
return -1;
}
SASSERT(mask1.length > mask2.length);
for (unsigned i = 0; i < mask2.length; i++)
if (mask1.bits.get_bit(i) != mask2.bits.get_bit(i))
return 0;
return 1;
}
/**
* Test simple subsumption between bit and parity constraints.
*
* let lsb(t, d) = m := 2^(k - d)*t = m * 2^(k - d) denotes that the last (least significant) d bits of t are the binary representation of m
* let bit(t, i) := 2^(k - 1) <= 2^(k - i - 1)*t
* TODO: 2^(k - 1 - d) <= 2^(k - i - 1)*t denotes that bits i-d...i are set to 0
*
* lsb(t, d) = m with log2(m) >= d => false
*
* parity(t) >= d denotes lsb(t, d) = 0
* parity(t) <= d denotes lsb(t, d + 1) != 0
*
* parity(t) >= d1 || parity(t) >= d2 with d1 < d2 implies parity(t) >= d1
* parity(t) <= d1 || parity(t) <= d2 with d1 < d2 implies parity(t) <= d2
*
* parity(t) >= d1 || !bit(t, d2) with d2 < d1 implies bit(t, d2)
* parity(t) <= d1 || bit(t, d2) with d2 < d1 implies parity(t) <= d1
*
* parity(t) >= d1 || parity(t) <= d2 with d1 <= d2 implies true
*
* More generally: parity can be replaced by lsb in case we check for subsumption between the bit-masks rather than comparing the parities (special case)
*/
bool simplify_clause::try_bit_subsumptions(clause& cl) {
struct pdd_info {
unsigned sz;
vector<trailing_bits> leading;
vector<single_bit> fixed_bits;
};
struct optional_pdd_hash {
unsigned operator()(optional<pdd> const& args) const {
return args->hash();
}
};
ptr_vector<pdd_info> info_list;
map<optional<pdd>, pdd_info*, optional_pdd_hash, default_eq<optional<pdd>>> info_table;
bool is_valid = false;
auto get_info = [&info_table, &info_list](const pdd& p) -> pdd_info& {
auto it = info_table.find_iterator(optional(p));
if (it != info_table.end())
return *it->m_value;
auto* info = alloc(pdd_info);
info->sz = p.manager().power_of_2();
info_list.push_back(info);
info_table.insert(optional(p), info);
return *info;
};
bool changed = false;
bool_vector removed(cl.size(), false);
for (unsigned i = 0; i < cl.size(); i++) {
signed_constraint c = s.lit2cnstr(cl[i]);
if (!c->is_ule())
continue;
trailing_bits mask;
single_bit bit;
pdd p = c->to_ule().lhs();
if ((c.is_eq() || c.is_diseq()) && get_trailing_mask(c->to_ule().lhs(), c->to_ule().rhs(), p, mask, c.is_positive())) {
if (mask.bits.bitsize() > mask.length) {
removed[i] = true; // skip this constraint. e.g., 2^(k-3)*x = 9*2^(k-3) is false as 9 >= 2^3
continue;
}
mask.src_idx = i;
get_info(p).leading.push_back(mask);
}
else if (c->is_ule() && get_bit(c->to_ule().lhs(), c->to_ule().rhs(), p, bit, c.is_positive())) {
bit.src_idx = i;
get_info(p).fixed_bits.push_back(bit);
}
}
for (const auto& entry : info_list) {
for (unsigned i = 0; i < entry->leading.size(); i++) {
auto& p1 = entry->leading[i];
// trailing vs. positive
for (unsigned j = i + 1; !removed[p1.src_idx] && j < entry->leading.size(); j++) {
auto& p2 = entry->leading[j];
if (!removed[p2.src_idx])
continue;
if (p1.positive == p2.positive) {
int cmp = compare(p1, p2);
if (cmp != 0) {
if ((cmp == -1) == p1.positive) {
LOG("Removed: " << s.lit2cnstr(cl[p2.src_idx]) << " because of " << s.lit2cnstr(cl[p1.src_idx]) << "\n");
removed[p2.src_idx] = true;
changed = true;
}
else if ((cmp == 1) == p1.positive) {
LOG("Removed: " << s.lit2cnstr(cl[p1.src_idx]) << " because of " << s.lit2cnstr(cl[p2.src_idx]) << "\n");
removed[p1.src_idx] = true;
changed = true;
}
}
}
else {
auto& pos = p1.positive ? p1 : p2;
auto& neg = p1.positive ? p2 : p1;
int cmp = compare(pos, neg);
if (cmp == -1) {
is_valid = true;
changed = true;
LOG("Tautology: " << s.lit2cnstr(cl[pos.src_idx]) << " and " << s.lit2cnstr(cl[neg.src_idx]) << "\n");
goto done;
}
}
}
// trailing vs. bit
for (unsigned j = 0; !removed[p1.src_idx] && j < entry->fixed_bits.size(); j++) {
auto& p2 = entry->fixed_bits[j];
if (removed[p2.src_idx])
continue;
if (p2.position >= p1.length)
continue;
if (p1.positive) {
if (p1.bits.get_bit(p2.position) == p2.positive) {
LOG("Removed: " << s.lit2cnstr(cl[p1.src_idx]) << " because of " << s.lit2cnstr(cl[p2.src_idx]) << " (bit)\n");
removed[p1.src_idx] = true;
changed = true;
}
}
else {
if (p1.bits.get_bit(p2.position) != p2.positive) {
LOG("Removed: " << s.lit2cnstr(cl[p2.src_idx]) << " (bit) because of " << s.lit2cnstr(cl[p1.src_idx]) << "\n");
removed[p2.src_idx] = true;
changed = true;
}
}
}
}
}
done:
for (auto entry : info_list)
dealloc(entry);
if (is_valid) {
SASSERT(!cl.empty());
cl.literals().clear();
cl.literals().push_back(s.eq(s.value(rational::zero(), 1)).blit()); // an obvious tautology
return true;
}
// Remove subsuming literals
if (!changed)
return false;
verbose_stream() << "Bit simplified\n";
unsigned cli = 0;
for (unsigned i = 0; i < cl.size(); ++i)
if (!removed[i])
cl[cli++] = cl[i];
cl.m_literals.shrink(cli);
return true;
}
#if 0
// All variables of clause 'cl' except 'z' are assigned.