3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2026-06-19 15:16:29 +00:00

tune and fix derive

This commit is contained in:
Nikolaj Bjorner 2026-06-09 11:10:07 -07:00
parent 143e5b9ffd
commit 758aff4f1e
2 changed files with 119 additions and 49 deletions

View file

@ -708,29 +708,116 @@ namespace seq {
return false;
}
void derive::flatten_union(expr* e, expr_ref_vector& args) {
expr* a, *b;
if (re().is_union(e, a, b)) {
flatten_union(a, args);
flatten_union(b, args);
} else {
args.push_back(e);
// Ordering key for an operand of a union chain.
// When e is a complement of some expression c, the key is c's id;
// otherwise the key is e's own id. Consequently a non-complement
// expression c and complement(c) receive the same key.
unsigned derive::union_id(expr* e) {
    expr* inner = nullptr;
    if (re().is_complement(e, inner))
        return inner->get_id();
    return e->get_id();
}
// True when one argument is syntactically the complement of the other:
// either a is complement(b) or b is complement(a). The comparison is by
// pointer identity of the complemented sub-expression.
bool derive::are_complements(expr* a, expr* b) {
    expr* inner = nullptr;
    return (re().is_complement(a, inner) && inner == b)
        || (re().is_complement(b, inner) && inner == a);
}
// Merge two right-associated union chains into one right-associated chain.
//
// Both r1 and r2 are treated as chains a1 | (a2 | (... | an)) whose operands
// are expected to be ordered by union_id (the merge step below relies on
// that ordering but does not verify it).
//
// Steps:
//   1. Shortcut the trivial cases (identical inputs, empty, full, complements).
//   2. Unfold both chains into a single vector of operands.
//   3. Two-way merge by union_id, dropping pointer-equal duplicates and
//      returning the full language as soon as a complementary pair meets.
//   4. Neighbor subsumption: drop an operand that is_subset reports as
//      contained in its immediate neighbor.
//   5. Rebuild a right-associated union from the surviving operands.
expr_ref derive::merge_union(expr* r1, expr* r2) {
    // Hold references to both inputs for the duration of the call.
    expr_ref keep_r1(r1, m), keep_r2(r2, m);

    // Step 1: trivial cases that need no traversal.
    if (r1 == r2) return expr_ref(r1, m);
    if (re().is_empty(r1)) return expr_ref(r2, m);
    if (re().is_empty(r2)) return expr_ref(r1, m);
    if (re().is_full_seq(r1)) return expr_ref(r1, m);
    if (re().is_full_seq(r2)) return expr_ref(r2, m);
    if (are_complements(r1, r2)) return expr_ref(re().mk_full_seq(r1->get_sort()), m);

    // Step 2: unfold both chains into one vector. Operands of r1 occupy
    // [0, left_end), operands of r2 occupy [left_end, right_end).
    expr_ref_vector operands(m);
    auto append_chain = [&](expr* chain) {
        expr* head = nullptr, *tail = nullptr;
        for (; re().is_union(chain, head, tail); chain = tail)
            operands.push_back(head);
        operands.push_back(chain);
    };
    append_chain(r1);
    unsigned const left_end = operands.size();
    append_chain(r2);
    unsigned const right_end = operands.size();

    // Step 3: two-way merge ordered by union_id.
    expr_ref_vector merged(m);
    unsigned li = 0, ri = left_end;
    while (li < left_end && ri < right_end) {
        expr* lhs = operands.get(li);
        expr* rhs = operands.get(ri);
        if (lhs == rhs) {
            // Same operand on both sides: keep a single copy.
            merged.push_back(lhs);
            ++li; ++ri;
            continue;
        }
        if (are_complements(lhs, rhs)) {
            // x | complement(x) is the full language.
            return expr_ref(re().mk_full_seq(r1->get_sort()), m);
        }
        unsigned const lkey = union_id(lhs);
        unsigned const rkey = union_id(rhs);
        if (lkey < rkey) {
            merged.push_back(lhs);
            ++li;
        }
        else if (rkey < lkey) {
            merged.push_back(rhs);
            ++ri;
        }
        else {
            // Equal keys for distinct operands: keep only the larger one when
            // is_subset can relate them, otherwise keep both.
            // NOTE(review): after the equality and complement checks above,
            // equal keys look reachable only if expression ids are not
            // unique -- confirm whether this branch can actually fire.
            if (is_subset(lhs, rhs))
                merged.push_back(rhs);
            else if (is_subset(rhs, lhs))
                merged.push_back(lhs);
            else {
                merged.push_back(lhs);
                merged.push_back(rhs);
            }
            ++li; ++ri;
        }
    }
    // Copy whatever is left of either chain.
    for (; li < left_end; ++li) merged.push_back(operands.get(li));
    for (; ri < right_end; ++ri) merged.push_back(operands.get(ri));

    // Step 4: neighbor subsumption. Catches cases such as
    // loop(0,k)·star ⊆ loop(0,k+1)·star, whose union_ids differ.
    // Only adjacent positions are compared; once an operand is dropped it is
    // not used as the left side of a later comparison, and the operand after
    // it is not compared against the surviving predecessor.
    unsigned const n = merged.size();
    svector<bool> dropped(n, false);
    for (unsigned k = 1; k < n; ++k) {
        if (dropped[k - 1])
            continue;
        expr* prev = merged.get(k - 1);
        expr* next = merged.get(k);
        if (is_subset(prev, next))
            dropped[k - 1] = true;
        else if (is_subset(next, prev))
            dropped[k] = true;
    }

    // Step 5: rebuild the chain from the back so it is right-associated.
    expr_ref chain(m);
    for (unsigned k = n; k > 0; ) {
        --k;
        if (dropped[k])
            continue;
        expr* elem = merged.get(k);
        if (chain)
            chain = expr_ref(re().mk_union(elem, chain.get()), m);
        else
            chain = expr_ref(elem, m);
    }
    if (!chain)
        return expr_ref(re().mk_empty(r1->get_sort()), m);
    return chain;
}
expr_ref derive::mk_union_core(expr* a, expr* b) {
// Canonical order: smaller id first
if (a->get_id() > b->get_id())
std::swap(a, b);
// Subsumption covers: a==b, empty(a), empty(b), full(a), full(b), complement absorption, etc.
if (is_subset(a, b)) return expr_ref(b, m);
if (is_subset(b, a)) return expr_ref(a, m);
// ITE handling with path pruning
auto union_op = [&](expr* x, expr* y) { return mk_union(x, y); };
expr_ref r = hoist_ite(a, b, union_op);
if (r) return r;
// ITE handling with path pruning (before merge, since ITEs aren't part of sorted chains)
expr *c1, *t1, *e1, *c2, *t2, *e2;
if (m.is_ite(a, c1, t1, e1) || m.is_ite(b, c2, t2, e2)) {
// Canonical order for non-ITE cases handled by merge below
auto union_op = [&](expr* x, expr* y) { return mk_union(x, y); };
expr_ref r = hoist_ite(a, b, union_op);
if (r) return r;
}
// Prefix factoring: a·x ∪ a·y = a·(x ∪ y)
expr *a1, *a2, *b1, *b2;
@ -739,35 +826,8 @@ namespace seq {
return mk_deriv_concat(a1, tail);
}
// ACI normalization: flatten, sort by id, deduplicate/subsume
expr_ref_vector args(m);
flatten_union(a, args);
flatten_union(b, args);
std::sort(args.data(), args.data() + args.size(), [](expr* x, expr* y) { return x->get_id() < y->get_id(); });
// Remove subsumed elements: if args[i] ⊆ args[j], drop args[i]
unsigned j = 0;
for (unsigned i = 0; i < args.size(); ++i) {
bool subsumed = false;
for (unsigned k = 0; k < j; ++k) {
if (is_subset(args.get(i), args.get(k))) { subsumed = true; break; }
}
if (!subsumed) {
// Check if new element subsumes any previously kept
unsigned new_j = 0;
for (unsigned k = 0; k < j; ++k) {
if (!is_subset(args.get(k), args.get(i)))
args[new_j++] = args.get(k);
}
args[new_j++] = args.get(i);
j = new_j;
}
}
if (j == 0)
return expr_ref(re().mk_empty(a->get_sort()), m);
expr_ref result(args.get(0), m);
for (unsigned i = 1; i < j; ++i)
result = expr_ref(re().mk_union(result, args.get(i)), m);
return result;
// Merge-based normalization: merge two sorted right-associated union chains
return merge_union(a, b);
}
expr_ref derive::mk_inter(expr* a, expr* b) {
@ -809,6 +869,14 @@ namespace seq {
expr_ref r = hoist_ite(a, b, inter_op);
if (r) return r;
// TODO: Distribution of intersection over union
// Disabled pending performance analysis
// expr *u1, *u2;
// if (re().is_union(a, u1, u2))
// return mk_union(mk_inter(u1, b), mk_inter(u2, b));
// if (re().is_union(b, u1, u2))
// return mk_union(mk_inter(a, u1), mk_inter(a, u2));
// Base case: build raw intersection
return expr_ref(re().mk_inter(a, b), m);
}