3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-24 17:45:32 +00:00

Do a quick check for feasibility w.r.t. bits before using forbidden intervals

This commit is contained in:
Clemens Eisenhofer 2023-02-15 20:06:13 +01:00
parent e07c77e072
commit 5ddc727f91
6 changed files with 439 additions and 6 deletions

View file

@ -86,6 +86,10 @@ namespace polysat {
#endif
if (try_equal_body_subsumptions(cl))
simplified = true;
#if 0
if (try_bit_subsumptions(cl))
simplified = true;
#endif
return simplified;
}
@ -344,8 +348,245 @@ namespace polysat {
return true;
}
// decomposes into a plain constant and a part containing variables. e.g., 2*x*y + 3*z - 2 gets { 2*x*y + 3*z, -2 }
static std::pair<pdd, pdd> decompose_constant(const pdd& p) {
for (const auto& m : p) {
if (m.vars.empty())
return { p - m.coeff, p.manager().mk_val(m.coeff) };
}
return { p, p.manager().mk_val(0) };
}
// 2^(k - d) * x = m * 2^(k - d)
// TODO: Factor out constant factors from x and put them to the rhs
bool simplify_clause::get_trailing_mask(pdd lhs, pdd rhs, pdd& p, trailing_bits& mask, bool pos) {
auto lhs_decomp = decompose_constant(lhs);
auto rhs_decomp = decompose_constant(rhs);
lhs = lhs_decomp.first - rhs_decomp.first;
rhs = rhs_decomp.second - lhs_decomp.second;
SASSERT(rhs.is_val());
unsigned k = lhs.manager().power_of_2();
unsigned d = lhs.max_pow2_divisor();
unsigned span = k - d;
if (span == 0)
return false;
p = lhs.div(rational::power_of_two(d));
rational rhs_val = rhs.val();
mask.bits = rhs_val / rational::power_of_two(d);
if (!mask.bits.is_int())
return false;
mask.length = span;
mask.positive = pos;
return true;
}
// 2^(k - 1) <= 2^(k - i - 1) * x (original definition) // TODO: Have this as well
// 2^(k - i - 1) * x + 2^(k - 1) <= 2^(k - 1) - 1 (currently we test only for this form)
bool simplify_clause::get_bit(const pdd& lhs, const pdd& rhs, pdd& p, single_bit& bit, bool pos) {
if (!rhs.is_val())
return false;
rational rhs_val = rhs.val() + 1;
unsigned k = rhs.power_of_2();
if (rhs_val != rational::power_of_two(k - 1))
return false;
pdd rest = lhs - rhs_val;
unsigned d = rest.max_pow2_divisor();
bit.position = k - d - 1;
bit.positive = pos;
p = rest.div(rational::power_of_two(d));
return true;
}
// Compares with respect to "subsumption"
// -1: mask1 < mask2 (e.g., 101 < 0101)
// 0: incomparable
// 1: mask1 > mask2
// failure mask1 == mask2
static int compare(const trailing_bits& mask1, const trailing_bits& mask2) {
if (mask1.length == mask2.length) {
SASSERT(mask1.bits != mask2.bits); // otw. we would have already eliminated the duplicate constraint
return 0;
}
if (mask1.length < mask2.length) {
for (unsigned i = 0; i < mask1.length; i++)
if (mask1.bits.get_bit(i) != mask2.bits.get_bit(i))
return 0;
return -1;
}
SASSERT(mask1.length > mask2.length);
for (unsigned i = 0; i < mask2.length; i++)
if (mask1.bits.get_bit(i) != mask2.bits.get_bit(i))
return 0;
return 1;
}
/**
* Test simple subsumption between bit and parity constraints.
*
* let lsb(t, d) = m := 2^(k - d)*t = m * 2^(k - d) denotes that the last (least significant) d bits of t are the binary representation of m
* let bit(t, i) := 2^(k - 1) <= 2^(k - i - 1)*t
* TODO: 2^(k - 1 - d) <= 2^(k - i - 1)*t denotes that bits i-d...i are set to 0
*
* lsb(t, d) = m with log2(m) >= d => false
*
* parity(t) >= d denotes lsb(t, d) = 0
* parity(t) <= d denotes lsb(t, d + 1) != 0
*
* parity(t) >= d1 || parity(t) >= d2 with d1 < d2 implies parity(t) >= d1
* parity(t) <= d1 || parity(t) <= d2 with d1 < d2 implies parity(t) <= d2
*
* parity(t) >= d1 || !bit(t, d2) with d2 < d1 implies bit(t, d2)
* parity(t) <= d1 || bit(t, d2) with d2 < d1 implies parity(t) <= d1
*
* parity(t) >= d1 || parity(t) <= d2 with d1 <= d2 implies true
*
* More generally: parity can be replaced by lsb in case we check for subsumption between the bit-masks rather than comparing the parities (special case)
*/
bool simplify_clause::try_bit_subsumptions(clause& cl) {
struct pdd_info {
unsigned sz;
vector<trailing_bits> leading;
vector<single_bit> fixed_bits;
};
struct optional_pdd_hash {
unsigned operator()(optional<pdd> const& args) const {
return args->hash();
}
};
ptr_vector<pdd_info> info_list;
map<optional<pdd>, pdd_info*, optional_pdd_hash, default_eq<optional<pdd>>> info_table;
bool is_valid = false;
auto get_info = [&info_table, &info_list](const pdd& p) -> pdd_info& {
auto it = info_table.find_iterator(optional(p));
if (it != info_table.end())
return *it->m_value;
auto* info = alloc(pdd_info);
info->sz = p.manager().power_of_2();
info_list.push_back(info);
info_table.insert(optional(p), info);
return *info;
};
bool changed = false;
bool_vector removed(cl.size(), false);
for (unsigned i = 0; i < cl.size(); i++) {
signed_constraint c = s.lit2cnstr(cl[i]);
if (!c->is_ule())
continue;
trailing_bits mask;
single_bit bit;
pdd p = c->to_ule().lhs();
if ((c.is_eq() || c.is_diseq()) && get_trailing_mask(c->to_ule().lhs(), c->to_ule().rhs(), p, mask, c.is_positive())) {
if (mask.bits.bitsize() > mask.length) {
removed[i] = true; // skip this constraint. e.g., 2^(k-3)*x = 9*2^(k-3) is false as 9 >= 2^3
continue;
}
mask.src_idx = i;
get_info(p).leading.push_back(mask);
}
else if (c->is_ule() && get_bit(c->to_ule().lhs(), c->to_ule().rhs(), p, bit, c.is_positive())) {
bit.src_idx = i;
get_info(p).fixed_bits.push_back(bit);
}
}
for (const auto& entry : info_list) {
for (unsigned i = 0; i < entry->leading.size(); i++) {
auto& p1 = entry->leading[i];
// trailing vs. positive
for (unsigned j = i + 1; !removed[p1.src_idx] && j < entry->leading.size(); j++) {
auto& p2 = entry->leading[j];
if (!removed[p2.src_idx])
continue;
if (p1.positive == p2.positive) {
int cmp = compare(p1, p2);
if (cmp != 0) {
if ((cmp == -1) == p1.positive) {
LOG("Removed: " << s.lit2cnstr(cl[p2.src_idx]) << " because of " << s.lit2cnstr(cl[p1.src_idx]) << "\n");
removed[p2.src_idx] = true;
changed = true;
}
else if ((cmp == 1) == p1.positive) {
LOG("Removed: " << s.lit2cnstr(cl[p1.src_idx]) << " because of " << s.lit2cnstr(cl[p2.src_idx]) << "\n");
removed[p1.src_idx] = true;
changed = true;
}
}
}
else {
auto& pos = p1.positive ? p1 : p2;
auto& neg = p1.positive ? p2 : p1;
int cmp = compare(pos, neg);
if (cmp == -1) {
is_valid = true;
changed = true;
LOG("Tautology: " << s.lit2cnstr(cl[pos.src_idx]) << " and " << s.lit2cnstr(cl[neg.src_idx]) << "\n");
goto done;
}
}
}
// trailing vs. bit
for (unsigned j = 0; !removed[p1.src_idx] && j < entry->fixed_bits.size(); j++) {
auto& p2 = entry->fixed_bits[j];
if (removed[p2.src_idx])
continue;
if (p2.position >= p1.length)
continue;
if (p1.positive) {
if (p1.bits.get_bit(p2.position) == p2.positive) {
LOG("Removed: " << s.lit2cnstr(cl[p1.src_idx]) << " because of " << s.lit2cnstr(cl[p2.src_idx]) << " (bit)\n");
removed[p1.src_idx] = true;
changed = true;
}
}
else {
if (p1.bits.get_bit(p2.position) != p2.positive) {
LOG("Removed: " << s.lit2cnstr(cl[p2.src_idx]) << " (bit) because of " << s.lit2cnstr(cl[p1.src_idx]) << "\n");
removed[p2.src_idx] = true;
changed = true;
}
}
}
}
}
done:
for (auto entry : info_list)
dealloc(entry);
if (is_valid) {
SASSERT(!cl.empty());
cl.literals().clear();
cl.literals().push_back(s.eq(s.value(rational::zero(), 1)).blit()); // an obvious tautology
return true;
}
// Remove subsuming literals
if (!changed)
return false;
verbose_stream() << "Bit simplified\n";
unsigned cli = 0;
for (unsigned i = 0; i < cl.size(); ++i)
if (!removed[i])
cl[cli++] = cl[i];
cl.m_literals.shrink(cli);
return true;
}
#if 0
// All variables of clause 'cl' except 'z' are assigned.

View file

@ -17,6 +17,18 @@ Author:
namespace polysat {
class solver;
struct trailing_bits {
unsigned length;
rational bits;
bool positive;
unsigned src_idx;
};
struct single_bit {
bool positive;
unsigned position;
unsigned src_idx;
};
class simplify_clause {
@ -33,6 +45,7 @@ namespace polysat {
bool try_remove_equations(clause& cl);
bool try_recognize_bailout(clause& cl);
bool try_equal_body_subsumptions(clause& cl);
bool try_bit_subsumptions(clause& cl);
void prepare_subs_entry(subs_entry& entry, signed_constraint c);
@ -46,6 +59,9 @@ namespace polysat {
simplify_clause(solver& s);
bool apply(clause& cl);
static bool get_trailing_mask(pdd lhs, pdd rhs, pdd& p, trailing_bits& mask, bool pos);
static bool get_bit(const pdd& lhs, const pdd& rhs, pdd& p, single_bit& bit, bool pos);
};
}

View file

@ -749,16 +749,18 @@ namespace polysat {
if (a_parity != a_max_parity || (a_parity > 0 && saturation.min_parity(a1, explain_a_parity) < a_parity))
return { p, false }; // We need the parity of a and this has to be for sure less than the parity of a1
#if 0
pdd a_pi = get_pseudo_inverse(a, a_parity);
#else
pdd a_pi = s.pseudo_inv(a);
for (auto c : explain_a_parity)
precondition.insert_eval(~c);
if (b.is_zero())
return { b1, true };
#endif
pdd shift = a; // [nsb cr: should this be a1?]
pdd shift = a1;
if (a_parity > 0) {
shift = s.lshr(a1, a1.manager().mk_val(a_parity));

View file

@ -742,6 +742,143 @@ namespace {
while (e != first);
return true;
}
bool viable::quick_bit_check(pvar v) {
#if 0
return true;
#endif
auto* e = m_equal_lin[v];
auto* first = e;
if (!e)
return true;
pdd p = s.var(v);
clause_builder builder(s, "bit check");
svector<lbool> fixed(p.power_of_2(), l_undef);
vector<ptr_vector<entry>> justifications(p.power_of_2(), ptr_vector<entry>());
vector<std::pair<entry*, trailing_bits>> postponed;
auto add_entry = [&builder](entry* e) {
for (const auto& sc : e->side_cond) {
builder.insert_eval(~sc);
LOG("Side cond: " << sc);
}
builder.insert_eval(~e->src);
LOG("Adding to core: " << e->src);
};
auto add_entry_list = [add_entry](const ptr_vector<entry>& list) {
for (const auto& e : list)
add_entry(e);
};
do {
single_bit bit;
trailing_bits mask;
if (e->src->is_ule() &&
simplify_clause::get_bit(s.subst(e->src->to_ule().lhs()), s.subst(e->src->to_ule().rhs()), p, bit, e->src.is_positive()) && p.is_var()) {
lbool prev = fixed[bit.position];
fixed[bit.position] = bit.positive ? l_true : l_false;
//verbose_stream() << "Setting bit " << bit.position << " to " << bit.positive << " because of " << e->src << "\n";
if (prev != l_undef && fixed[bit.position] != prev) {
LOG("Bit conflicting " << e->src << " with " << justifications[bit.position][0]->src);
add_entry_list(justifications[bit.position]);
add_entry(e);
s.set_conflict(*builder.build());
return false;
}
// just override; we prefer bit constraints over parity as those are easier for subsumption to remove
justifications[bit.position].clear();
justifications[bit.position].push_back(e);
}
else if ((e->src->is_eq() || e->src.is_diseq()) &&
simplify_clause::get_trailing_mask(s.subst(e->src->to_ule().lhs()), s.subst(e->src->to_ule().rhs()), p, mask, e->src.is_positive()) && p.is_var()) {
if (e->src.is_positive()) {
for (unsigned i = 0; i < mask.length; i++) {
lbool prev = fixed[i];
fixed[i] = mask.bits.get_bit(i) ? l_true : l_false;
//verbose_stream() << "Setting bit " << i << " to " << mask.bits.get_bit(i) << " because of parity " << e->src << "\n";
if (prev != l_undef) {
if (fixed[i] != prev) {
LOG("Positive parity conflicting " << e->src << " with " << justifications[i][0]->src);
add_entry_list(justifications[i]);
add_entry(e);
s.set_conflict(*builder.build());
return false;
}
}
else {
SASSERT(justifications[i].empty());
justifications[i].push_back(e);
}
}
}
else
postponed.push_back({ e, mask });
}
e = e->next();
} while(e != first);
// TODO: Incomplete - e.g., if we know the trailing bits are not 00 not 10 not 01 and not 11 we could also detect a conflict
// This would require partially clause solving (worth the effort?)
bool_vector removed(postponed.size(), false);
bool changed;
do { // fixed-point required?
changed = false;
for (unsigned j = 0; j < postponed.size(); j++) {
if (removed[j])
continue;
const auto& neg = postponed[j];
unsigned indet = 0;
unsigned last_indet = 0;
unsigned i = 0;
for (; i < neg.second.length; i++) {
if (fixed[i] != l_undef) {
if (fixed[i] != (neg.second.bits.get_bit(i) ? l_true : l_false)) {
removed[j] = true;
break; // this is already satisfied
}
}
else {
indet++;
last_indet = i;
}
}
if (i == neg.second.length) {
if (indet == 0) {
// Already false
LOG("Found conflict with constraint " << neg.first->src);
for (unsigned k = 0; k < neg.second.length; k++)
add_entry_list(justifications[k]);
add_entry(neg.first);
s.set_conflict(*builder.build());
return false;
}
else if (indet == 1) {
// Simple BCP
auto& justification = justifications[last_indet];
SASSERT(justification.empty());
for (unsigned k = 0; k < neg.second.length; k++) {
if (k != last_indet) {
SASSERT(fixed[k] != l_undef);
for (const auto& just : justifications[k])
justification.push_back(just);
}
}
justification.push_back(neg.first);
fixed[last_indet] = neg.second.bits.get_bit(last_indet) ? l_false : l_true;
removed[j] = true;
//verbose_stream() << "Applying fast BCP on bit " << last_indet << " from constraint " << neg.first->src << "\n";
changed = true;
}
}
}
} while(changed);
return true;
}
bool viable::has_viable(pvar v) {
refined:
@ -997,7 +1134,7 @@ namespace {
LOG("Refinement budget exhausted! Fall back to univariate solver.");
return query_fallback<mode>(v, result);
}
lbool viable::query_find(pvar v, rational& lo, rational& hi) {
auto const& max_value = s.var2pdd(v).max_value();
lbool const refined = l_undef;
@ -1007,6 +1144,9 @@ namespace {
// For this reason, we start chasing the intervals from the start again.
lo = 0;
hi = max_value;
if (!quick_bit_check(v))
return l_false;
auto* e = m_units[v];
if (!e && !refine_viable(v, lo))
@ -1229,6 +1369,7 @@ namespace {
SASSERT(!core.vars().contains(v));
core.add_lemma("viable unsat core", core.build_lemma());
verbose_stream() << "unsat core " << core << "\n";
//exit(0);
return true;
}

View file

@ -72,6 +72,7 @@ namespace polysat {
class viable {
friend class test_fi;
friend class test_polysat;
solver& s;
forbidden_intervals m_forbidden_intervals;
@ -98,6 +99,8 @@ namespace polysat {
bool refine_equal_lin(pvar v, rational const& val);
bool refine_disequal_lin(pvar v, rational const& val);
bool quick_bit_check(pvar v);
std::ostream& display_one(std::ostream& out, pvar v, entry const* e) const;
std::ostream& display_all(std::ostream& out, pvar v, entry const* e, char const* delimiter = "") const;

View file

@ -668,6 +668,34 @@ namespace polysat {
VERIFY_EQ((*cl)[0], s.ule(p, q).blit());
}
// 2^1*x + 2^1 == 0 and 2^2*x == 0
static void test_fi_quickcheck1() {
scoped_solver s(__func__);
auto x = s.var(s.add_var(3));
signed_constraint c1 = s.eq(x * 2 + 2, 0);
signed_constraint c2 = s.eq(4 * x, 0);
s.add_clause(c1, false);
s.add_clause(c2, false);
s.m_viable.intersect(x.var(), c1);
s.m_viable.intersect(x.var(), c2);
VERIFY(!s.m_viable.quick_bit_check(x.var()));
}
// parity(x) >= 3 and bit_1(x)
static void test_fi_quickcheck2() {
scoped_solver s(__func__);
auto x = s.var(s.add_var(4));
signed_constraint c1 = s.parity_at_least(x, 3);
signed_constraint c2 = s.bit(x, 1);
s.add_clause(c1, false);
s.add_clause(c2, false);
s.m_viable.intersect(x.var(), c1);
s.m_viable.intersect(x.var(), c2);
VERIFY(!s.m_viable.quick_bit_check(x.var()));
}
// 8 * x + 3 == 0 or 8 * x + 5 == 0 is unsat
static void test_parity1() {
scoped_solver s(__func__);
@ -2037,7 +2065,9 @@ void tst_polysat() {
RUN(test_polysat::test_ineq_axiom6());
RUN(test_polysat::test_ineq_non_axiom1());
RUN(test_polysat::test_ineq_non_axiom4());
RUN(test_polysat::test_fi_quickcheck1());
RUN(test_polysat::test_fi_quickcheck2());
if (collect_test_records)
test_records.display(std::cout);
}