3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2026-03-21 04:15:51 +00:00

Fixed couple of regex problems [there are still others]

This commit is contained in:
CEisenhofer 2026-03-18 14:28:53 +01:00
parent b1bae695e6
commit ab53889c10
11 changed files with 392 additions and 382 deletions

View file

@ -513,34 +513,75 @@ namespace euf {
return n;
}
snode* sgraph::brzozowski_deriv(snode* re, snode* elem) {
snode* sgraph::brzozowski_deriv(snode* re, snode* elem, snode* allowed_range) {
expr* re_expr = re->get_expr();
expr* elem_expr = elem->get_expr();
if (!re_expr || !elem_expr)
return nullptr;
// unwrap str.unit to get the character expression
expr* ch = nullptr;
if (m_seq.str.is_unit(elem_expr, ch))
elem_expr = ch;
// If elem is a regex predicate (e.g., re.allchar from compute_minterms),
// extract a representative character for the derivative.
sort* seq_sort = nullptr, *ele_sort = nullptr;
if (m_seq.is_re(re_expr, seq_sort) && m_seq.is_seq(seq_sort, ele_sort)) {
if (ele_sort != elem_expr->get_sort()) {
// If an explicit allowed_range is provided (which is a regex minterm),
// we extract a representative character (like 'lo') from it,
// and evaluate the derivative with respect to that representative character.
// This avoids generating massive 'ite' structures for symbolic variables.
sort* seq_sort = nullptr, *ele_sort = nullptr; if (m_seq.is_re(re_expr, seq_sort) && m_seq.is_seq(seq_sort, ele_sort)) {
// Just take one element - they are anyway all assumed to produce the same result
auto extract_rep = [&](expr* e) -> expr* {
expr* lo = nullptr, *hi = nullptr;
if (m_seq.re.is_full_char(elem_expr)) {
// re.allchar represents the entire alphabet; computing a derivative
// w.r.t. a single character would be imprecise and could incorrectly
// report fail. Return nullptr to prevent incorrect pruning.
return nullptr;
expr* r1 = nullptr, *r2 = nullptr;
while (e) {
if (m_seq.re.is_full_char(e))
return m_seq.mk_char(0);
if (m_seq.re.is_range(e, lo, hi) && lo) {
expr* lo_ch = nullptr;
zstring zs;
if (m_seq.str.is_unit(lo, lo_ch))
return lo_ch;
if (m_seq.str.is_string(lo, zs) && zs.length() > 0)
return m_seq.str.mk_char(zs[0]);
return lo;
}
if (m_seq.re.is_union(e, r1, r2))
e = r1;
else
return nullptr;
}
else if (m_seq.re.is_range(elem_expr, lo, hi) && lo)
elem_expr = lo;
return nullptr;
};
if (allowed_range && allowed_range->get_expr()) {
expr* range_expr = allowed_range->get_expr();
if (m_seq.re.is_full_char(range_expr)) {
// For full char, keep symbolic
}
else {
expr* rep = extract_rep(range_expr);
if (rep)
elem_expr = rep;
}
}
else if (ele_sort != elem_expr->get_sort()) {
expr* rep = extract_rep(elem_expr);
if (rep)
elem_expr = rep;
else
return nullptr;
}
}
SASSERT(elem_expr);
if (elem_expr->get_sort() != ele_sort) {
std::cout << "SORT MISMATCH before mk_derivative\n"
<< " elem_expr: " << mk_pp(elem_expr, m) << "\n"
<< " elem_sort: " << mk_pp(elem_expr->get_sort(), m) << "\n"
<< " ele_sort: " << mk_pp(ele_sort, m) << "\n"
<< " re_expr: " << mk_pp(re_expr, m) << std::endl;
}
expr_ref result = m_rewriter.mk_derivative(elem_expr, re_expr);
if (!result)
return nullptr;
@ -557,8 +598,35 @@ namespace euf {
preds.push_back(e);
return;
}
if (m_seq.re.is_to_re(e))
if (m_seq.re.is_to_re(e)) {
expr* s = nullptr;
if (m_seq.re.is_to_re(e, s)) {
zstring zs;
expr* ch_expr = nullptr;
if (m_seq.str.is_string(s, zs) && zs.length() > 0) {
unsigned c = zs[0];
ch_expr = m_seq.str.mk_char(c);
}
else if (m_seq.str.is_unit(s, ch_expr)) {
// ch_expr correctly extracted
}
if (ch_expr) {
expr_ref unit_str(m_seq.str.mk_unit(ch_expr), m);
expr_ref re_char(m_seq.re.mk_to_re(unit_str), m);
bool dup = false;
for (expr* p : preds) {
if (p == re_char) {
dup = true;
break;
}
}
if (!dup)
preds.push_back(re_char);
}
}
return;
}
if (m_seq.re.is_full_char(e))
return;
if (m_seq.re.is_full_seq(e))
@ -566,41 +634,82 @@ namespace euf {
if (m_seq.re.is_empty(e))
return;
// recurse into compound regex operators
for (unsigned i = 0; i < re->num_args(); ++i)
for (unsigned i = 0; i < re->num_args(); ++i) {
collect_re_predicates(re->arg(i), preds);
}
}
void sgraph::compute_minterms(snode* re, snode_vector& minterms) {
// extract character predicates from the regex
expr_ref_vector preds(m);
collect_re_predicates(re, preds);
unsigned max_c = m_seq.max_char();
if (preds.empty()) {
// no predicates means the whole alphabet is one minterm
// represented by full_char
expr_ref fc(m_seq.re.mk_full_char(m_str_sort), m);
minterms.push_back(mk(fc));
return;
}
// generate minterms as conjunctions/negations of predicates
// for n predicates, there are up to 2^n minterms
unsigned n = preds.size();
// cap at reasonable size to prevent exponential blowup
if (n > 20)
n = 20;
for (unsigned mask = 0; mask < (1u << n); ++mask) {
expr_ref_vector conj(m);
for (unsigned i = 0; i < n; ++i) {
if (mask & (1u << i))
conj.push_back(preds.get(i));
else
conj.push_back(m_seq.re.mk_complement(preds.get(i)));
std::vector<char_set> classes;
classes.push_back(char_set::full(max_c));
for (expr* p : preds) {
char_set p_set;
expr* lo = nullptr, *hi = nullptr;
if (m_seq.re.is_range(p, lo, hi)) {
unsigned vlo = 0, vhi = 0;
if (m_seq.is_const_char(lo, vlo) && m_seq.is_const_char(hi, vhi)) {
if (vlo <= vhi)
p_set = char_set(char_range(vlo, vhi + 1));
}
}
SASSERT(!conj.empty());
// intersect all terms
expr_ref mt(conj.get(0), m);
for (unsigned i = 1; i < conj.size(); ++i)
mt = m_seq.re.mk_inter(mt, conj.get(i));
minterms.push_back(mk(mt));
else if (m_seq.re.is_to_re(p)) {
expr* str_arg = nullptr;
expr* ch_expr = nullptr;
unsigned char_val = 0;
if (m_seq.re.is_to_re(p, str_arg) &&
m_seq.str.is_unit(str_arg, ch_expr) &&
m_seq.is_const_char(ch_expr, char_val)) {
p_set.add(char_val);
}
}
else if (m_seq.re.is_full_char(p))
p_set = char_set::full(max_c);
else
continue;
if (p_set.is_empty() || p_set.is_full(max_c))
continue;
std::vector<char_set> next_classes;
char_set p_comp = p_set.complement(max_c);
for (auto const& c : classes) {
char_set in_c = c.intersect_with(p_set);
char_set out_c = c.intersect_with(p_comp);
if (!in_c.is_empty())
next_classes.push_back(in_c);
if (!out_c.is_empty())
next_classes.push_back(out_c);
}
classes = std::move(next_classes);
}
for (auto const& c : classes) {
expr_ref class_expr(m);
for (auto const& r : c.ranges()) {
zstring z_lo(r.m_lo);
zstring z_hi(r.m_hi - 1);
expr_ref c_lo(m_seq.str.mk_string(z_lo), m);
expr_ref c_hi(m_seq.str.mk_string(z_hi), m);
expr_ref range_expr(m_seq.re.mk_range(c_lo, c_hi), m);
if (!class_expr)
class_expr = range_expr;
else
class_expr = m_seq.re.mk_union(class_expr, range_expr);
}
if (class_expr)
minterms.push_back(mk(class_expr));
}
}
@ -660,3 +769,9 @@ namespace euf {
}