mirror of
https://github.com/Z3Prover/z3
synced 2025-11-15 10:25:45 +00:00
Do a quick check for feasibility w.r.t. bits before using forbidden intervals
This commit is contained in:
parent
e07c77e072
commit
5ddc727f91
6 changed files with 439 additions and 6 deletions
|
|
@ -86,6 +86,10 @@ namespace polysat {
|
|||
#endif
|
||||
if (try_equal_body_subsumptions(cl))
|
||||
simplified = true;
|
||||
#if 0
|
||||
if (try_bit_subsumptions(cl))
|
||||
simplified = true;
|
||||
#endif
|
||||
return simplified;
|
||||
}
|
||||
|
||||
|
|
@ -344,8 +348,245 @@ namespace polysat {
|
|||
return true;
|
||||
}
|
||||
|
||||
// decomposes into a plain constant and a part containing variables. e.g., 2*x*y + 3*z - 2 gets { 2*x*y + 3*z, -2 }
|
||||
static std::pair<pdd, pdd> decompose_constant(const pdd& p) {
|
||||
for (const auto& m : p) {
|
||||
if (m.vars.empty())
|
||||
return { p - m.coeff, p.manager().mk_val(m.coeff) };
|
||||
}
|
||||
return { p, p.manager().mk_val(0) };
|
||||
}
|
||||
|
||||
// 2^(k - d) * x = m * 2^(k - d)
|
||||
// TODO: Factor out constant factors from x and put them to the rhs
|
||||
bool simplify_clause::get_trailing_mask(pdd lhs, pdd rhs, pdd& p, trailing_bits& mask, bool pos) {
|
||||
auto lhs_decomp = decompose_constant(lhs);
|
||||
auto rhs_decomp = decompose_constant(rhs);
|
||||
|
||||
lhs = lhs_decomp.first - rhs_decomp.first;
|
||||
rhs = rhs_decomp.second - lhs_decomp.second;
|
||||
|
||||
SASSERT(rhs.is_val());
|
||||
|
||||
unsigned k = lhs.manager().power_of_2();
|
||||
unsigned d = lhs.max_pow2_divisor();
|
||||
unsigned span = k - d;
|
||||
if (span == 0)
|
||||
return false;
|
||||
|
||||
p = lhs.div(rational::power_of_two(d));
|
||||
rational rhs_val = rhs.val();
|
||||
mask.bits = rhs_val / rational::power_of_two(d);
|
||||
if (!mask.bits.is_int())
|
||||
return false;
|
||||
mask.length = span;
|
||||
mask.positive = pos;
|
||||
return true;
|
||||
}
|
||||
|
||||
// 2^(k - 1) <= 2^(k - i - 1) * x (original definition) // TODO: Have this as well
|
||||
// 2^(k - i - 1) * x + 2^(k - 1) <= 2^(k - 1) - 1 (currently we test only for this form)
|
||||
bool simplify_clause::get_bit(const pdd& lhs, const pdd& rhs, pdd& p, single_bit& bit, bool pos) {
|
||||
if (!rhs.is_val())
|
||||
return false;
|
||||
|
||||
rational rhs_val = rhs.val() + 1;
|
||||
unsigned k = rhs.power_of_2();
|
||||
|
||||
if (rhs_val != rational::power_of_two(k - 1))
|
||||
return false;
|
||||
|
||||
pdd rest = lhs - rhs_val;
|
||||
unsigned d = rest.max_pow2_divisor();
|
||||
bit.position = k - d - 1;
|
||||
bit.positive = pos;
|
||||
p = rest.div(rational::power_of_two(d));
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
// Compares with respect to "subsumption"
|
||||
// -1: mask1 < mask2 (e.g., 101 < 0101)
|
||||
// 0: incomparable
|
||||
// 1: mask1 > mask2
|
||||
// failure mask1 == mask2
|
||||
static int compare(const trailing_bits& mask1, const trailing_bits& mask2) {
|
||||
if (mask1.length == mask2.length) {
|
||||
SASSERT(mask1.bits != mask2.bits); // otw. we would have already eliminated the duplicate constraint
|
||||
return 0;
|
||||
}
|
||||
if (mask1.length < mask2.length) {
|
||||
for (unsigned i = 0; i < mask1.length; i++)
|
||||
if (mask1.bits.get_bit(i) != mask2.bits.get_bit(i))
|
||||
return 0;
|
||||
return -1;
|
||||
}
|
||||
SASSERT(mask1.length > mask2.length);
|
||||
for (unsigned i = 0; i < mask2.length; i++)
|
||||
if (mask1.bits.get_bit(i) != mask2.bits.get_bit(i))
|
||||
return 0;
|
||||
return 1;
|
||||
}
|
||||
|
||||
/**
|
||||
* Test simple subsumption between bit and parity constraints.
|
||||
*
|
||||
* let lsb(t, d) = m := 2^(k - d)*t = m * 2^(k - d) denotes that the last (least significant) d bits of t are the binary representation of m
|
||||
* let bit(t, i) := 2^(k - 1) <= 2^(k - i - 1)*t
|
||||
* TODO: 2^(k - 1 - d) <= 2^(k - i - 1)*t denotes that bits i-d...i are set to 0
|
||||
*
|
||||
* lsb(t, d) = m with log2(m) >= d => false
|
||||
*
|
||||
* parity(t) >= d denotes lsb(t, d) = 0
|
||||
* parity(t) <= d denotes lsb(t, d + 1) != 0
|
||||
*
|
||||
* parity(t) >= d1 || parity(t) >= d2 with d1 < d2 implies parity(t) >= d1
|
||||
* parity(t) <= d1 || parity(t) <= d2 with d1 < d2 implies parity(t) <= d2
|
||||
*
|
||||
* parity(t) >= d1 || !bit(t, d2) with d2 < d1 implies bit(t, d2)
|
||||
* parity(t) <= d1 || bit(t, d2) with d2 < d1 implies parity(t) <= d1
|
||||
*
|
||||
* parity(t) >= d1 || parity(t) <= d2 with d1 <= d2 implies true
|
||||
*
|
||||
* More generally: parity can be replaced by lsb in case we check for subsumption between the bit-masks rather than comparing the parities (special case)
|
||||
*/
|
||||
bool simplify_clause::try_bit_subsumptions(clause& cl) {
|
||||
|
||||
struct pdd_info {
|
||||
unsigned sz;
|
||||
vector<trailing_bits> leading;
|
||||
vector<single_bit> fixed_bits;
|
||||
};
|
||||
|
||||
struct optional_pdd_hash {
|
||||
unsigned operator()(optional<pdd> const& args) const {
|
||||
return args->hash();
|
||||
}
|
||||
};
|
||||
|
||||
ptr_vector<pdd_info> info_list;
|
||||
map<optional<pdd>, pdd_info*, optional_pdd_hash, default_eq<optional<pdd>>> info_table;
|
||||
bool is_valid = false;
|
||||
|
||||
auto get_info = [&info_table, &info_list](const pdd& p) -> pdd_info& {
|
||||
auto it = info_table.find_iterator(optional(p));
|
||||
if (it != info_table.end())
|
||||
return *it->m_value;
|
||||
auto* info = alloc(pdd_info);
|
||||
info->sz = p.manager().power_of_2();
|
||||
info_list.push_back(info);
|
||||
info_table.insert(optional(p), info);
|
||||
return *info;
|
||||
};
|
||||
|
||||
bool changed = false;
|
||||
bool_vector removed(cl.size(), false);
|
||||
|
||||
for (unsigned i = 0; i < cl.size(); i++) {
|
||||
signed_constraint c = s.lit2cnstr(cl[i]);
|
||||
if (!c->is_ule())
|
||||
continue;
|
||||
trailing_bits mask;
|
||||
single_bit bit;
|
||||
pdd p = c->to_ule().lhs();
|
||||
if ((c.is_eq() || c.is_diseq()) && get_trailing_mask(c->to_ule().lhs(), c->to_ule().rhs(), p, mask, c.is_positive())) {
|
||||
if (mask.bits.bitsize() > mask.length) {
|
||||
removed[i] = true; // skip this constraint. e.g., 2^(k-3)*x = 9*2^(k-3) is false as 9 >= 2^3
|
||||
continue;
|
||||
}
|
||||
|
||||
mask.src_idx = i;
|
||||
get_info(p).leading.push_back(mask);
|
||||
}
|
||||
else if (c->is_ule() && get_bit(c->to_ule().lhs(), c->to_ule().rhs(), p, bit, c.is_positive())) {
|
||||
bit.src_idx = i;
|
||||
get_info(p).fixed_bits.push_back(bit);
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto& entry : info_list) {
|
||||
for (unsigned i = 0; i < entry->leading.size(); i++) {
|
||||
auto& p1 = entry->leading[i];
|
||||
// trailing vs. positive
|
||||
for (unsigned j = i + 1; !removed[p1.src_idx] && j < entry->leading.size(); j++) {
|
||||
auto& p2 = entry->leading[j];
|
||||
if (!removed[p2.src_idx])
|
||||
continue;
|
||||
|
||||
if (p1.positive == p2.positive) {
|
||||
int cmp = compare(p1, p2);
|
||||
if (cmp != 0) {
|
||||
if ((cmp == -1) == p1.positive) {
|
||||
LOG("Removed: " << s.lit2cnstr(cl[p2.src_idx]) << " because of " << s.lit2cnstr(cl[p1.src_idx]) << "\n");
|
||||
removed[p2.src_idx] = true;
|
||||
changed = true;
|
||||
}
|
||||
else if ((cmp == 1) == p1.positive) {
|
||||
LOG("Removed: " << s.lit2cnstr(cl[p1.src_idx]) << " because of " << s.lit2cnstr(cl[p2.src_idx]) << "\n");
|
||||
removed[p1.src_idx] = true;
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
else {
|
||||
auto& pos = p1.positive ? p1 : p2;
|
||||
auto& neg = p1.positive ? p2 : p1;
|
||||
|
||||
int cmp = compare(pos, neg);
|
||||
if (cmp == -1) {
|
||||
is_valid = true;
|
||||
changed = true;
|
||||
LOG("Tautology: " << s.lit2cnstr(cl[pos.src_idx]) << " and " << s.lit2cnstr(cl[neg.src_idx]) << "\n");
|
||||
goto done;
|
||||
}
|
||||
}
|
||||
}
|
||||
// trailing vs. bit
|
||||
for (unsigned j = 0; !removed[p1.src_idx] && j < entry->fixed_bits.size(); j++) {
|
||||
auto& p2 = entry->fixed_bits[j];
|
||||
if (removed[p2.src_idx])
|
||||
continue;
|
||||
|
||||
if (p2.position >= p1.length)
|
||||
continue;
|
||||
|
||||
if (p1.positive) {
|
||||
if (p1.bits.get_bit(p2.position) == p2.positive) {
|
||||
LOG("Removed: " << s.lit2cnstr(cl[p1.src_idx]) << " because of " << s.lit2cnstr(cl[p2.src_idx]) << " (bit)\n");
|
||||
removed[p1.src_idx] = true;
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
else {
|
||||
if (p1.bits.get_bit(p2.position) != p2.positive) {
|
||||
LOG("Removed: " << s.lit2cnstr(cl[p2.src_idx]) << " (bit) because of " << s.lit2cnstr(cl[p1.src_idx]) << "\n");
|
||||
removed[p2.src_idx] = true;
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
done:
|
||||
for (auto entry : info_list)
|
||||
dealloc(entry);
|
||||
if (is_valid) {
|
||||
SASSERT(!cl.empty());
|
||||
cl.literals().clear();
|
||||
cl.literals().push_back(s.eq(s.value(rational::zero(), 1)).blit()); // an obvious tautology
|
||||
return true;
|
||||
}
|
||||
// Remove subsuming literals
|
||||
if (!changed)
|
||||
return false;
|
||||
verbose_stream() << "Bit simplified\n";
|
||||
unsigned cli = 0;
|
||||
for (unsigned i = 0; i < cl.size(); ++i)
|
||||
if (!removed[i])
|
||||
cl[cli++] = cl[i];
|
||||
cl.m_literals.shrink(cli);
|
||||
return true;
|
||||
}
|
||||
|
||||
#if 0
|
||||
// All variables of clause 'cl' except 'z' are assigned.
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue