3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-22 16:45:31 +00:00

implement if-then-else BDD normal form for derivatives

(code compiles but is still buggy)
This commit is contained in:
calebstanford-msr 2020-06-05 15:01:42 -04:00
parent 460068eade
commit fdc6df1b17
3 changed files with 246 additions and 97 deletions

View file

@ -690,14 +690,6 @@ br_status seq_rewriter::mk_app_core(func_decl * f, unsigned num_args, expr * con
SASSERT(num_args == 1);
st = mk_str_stoi(args[0], result);
break;
// case OP_ITE:
// // Rewrite ITEs in the case of regexes
// SASSERT(num_args == 3);
// if (m_util.is_re(args[1])) {
// SASSERT(m_util.is_re(args[2]));
// st = mk_re_ite(args[0], args[1], args[2], result);
// }
// break;
case _OP_STRING_CONCAT:
case _OP_STRING_PREFIX:
case _OP_STRING_SUFFIX:
@ -2140,18 +2132,19 @@ expr_ref seq_rewriter::re_predicate(expr* cond, sort* seq_sort) {
}
expr_ref seq_rewriter::is_nullable_rec(expr* r) {
std::cout << "is_nullable_rec" << std::endl;
expr_ref result(m_op_cache.find(_OP_RE_IS_NULLABLE, r, nullptr), m());
std::cout << "n";
expr_ref result(m_op_cache.find(_OP_RE_IS_NULLABLE, r, nullptr, nullptr), m());
if (!result) {
std::cout << "(m) ";
result = is_nullable(r);
m_op_cache.insert(_OP_RE_IS_NULLABLE, r, nullptr, result);
m_op_cache.insert(_OP_RE_IS_NULLABLE, r, nullptr, nullptr, result);
} else {
std::cout << "(h) ";
}
return result;
}
expr_ref seq_rewriter::is_nullable(expr* r) {
std::cout << "is_nullable" << std::endl;
// std::cout << "call to is_nullable(" << expr_ref(r, m()) << ")" << std::endl;
SASSERT(m_util.is_re(r));
expr* r1 = nullptr, *r2 = nullptr, *cond = nullptr;
unsigned lo = 0, hi = 0;
@ -2303,7 +2296,6 @@ br_status seq_rewriter::mk_re_derivative(expr* ele, expr* r, expr_ref& result) {
expr *r1 = nullptr, *r2 = nullptr, *p = nullptr;
unsigned lo = 0, hi = 0;
if (re().is_concat(r, r1, r2)) {
std::cout << "is_nullable -- from concat" << std::endl;
expr_ref is_n = is_nullable(r1);
expr_ref dr1(re().mk_derivative(ele, r1), m());
expr_ref dr2(re().mk_derivative(ele, r2), m());
@ -2451,31 +2443,176 @@ br_status seq_rewriter::mk_re_derivative(expr* ele, expr* r, expr_ref& result) {
}
/*
Optimizations for ITEs of regexes, since they come up frequently
in calculating derivatives.
Combine two if-then-else expressions in BDD form.
lifting functions (lift_ites, lift_ites_throttled):
push all ite expressions to the top level.
Definition of BDD form:
if-then-elses are pushed outwards
and sorted by condition ID (cond->get_id()), from largest on
the outside to smallest on the inside.
Duplicate nested conditions are eliminated.
rewriting (mk_re_ite):
ite(not c, r1, r2) -> ite(c, r2, r1)
ite(c, ite(c, r1, r2), r3)) -> ite(c, r1, r3)
ite(c, r1, ite(c, r2, r3)) -> ite(c, r1, r3)
ite(c1, ite(c2, r1, r2), r3) where id of c1 < id of c2 ->
ite(c2, ite(c1, r1, r3), ite(c1, r2, r3))
ite(c1, r1, ite(c2, r2, r3)) where id of c1 < id of c2 ->
ite(c2, ite(c1, r1, r2), ite(c1, r1, r3))
Preconditions:
- EITHER k is a binary op code on REs (re.union, re.inter, etc.)
and cond is nullptr,
OR k is if-then-else (OP.ITE) and cond is the condition.
- a and b are in BDD form.
Postcondition: result is in BDD form.
if-then-elses are pushed outwards
and sorted by condition ID (cond->get_id()), from largest on
the outside to smallest on the inside.
Uses op cache (memoization) to avoid duplicating work for the same
pair of pointers.
*/
expr_ref seq_rewriter::lift_ites(expr* a, bool lift_over_union, bool lift_over_inter) {
expr_ref result(m());
result = a;
expr_ref seq_rewriter::combine_ites(decl_kind k, expr* a, expr* b, expr* cond) {
std::cout << "c";
expr_ref result(m_op_cache.find(k, a, b, cond), m());
if (result) {
std::cout << "(h) ";
return result;
}
std::cout << "(m) ";
SASSERT((k == OP_ITE) == (cond != nullptr));
expr *acond = nullptr, *a1 = nullptr, *a2 = nullptr,
*bcond = nullptr, *b1 = nullptr, *b2 = nullptr;
expr_ref result1(m()), result2(m());
if (k == OP_ITE) {
if (m().is_ite(a, acond, a1, a2) &&
cond->get_id() < acond->get_id()) {
// Push ITE inwards on first arg
result1 = combine_ites(k, a1, b, cond);
result2 = combine_ites(k, a2, b, cond);
result = combine_ites(k, result1, result2, acond);
}
else if (m().is_ite(a, acond, a1, a2) &&
cond == acond) {
// Collapse ITE on first arg
result = combine_ites(k, a1, b, cond);
}
else if (m().is_ite(b, bcond, b1, b2) &&
cond->get_id() < bcond->get_id()) {
// Push ITE inwards on second arg
result1 = combine_ites(k, a, b1, cond);
result2 = combine_ites(k, a, b2, cond);
result = combine_ites(k, result1, result2, bcond);
}
else if (m().is_ite(b, bcond, b1, b2) &&
cond == bcond) {
// Collapse ITE on second arg
result = combine_ites(k, a, b2, cond);
}
else {
// Apply ITE -- no simplification required
result = m().mk_ite(a, b, cond);
}
}
else if (m().is_ite(a, acond, a1, a2)) {
// Push binary op inwards on first arg
result1 = combine_ites(k, a1, b, nullptr);
result2 = combine_ites(k, a2, b, nullptr);
result = combine_ites(OP_ITE, result1, result2, acond);
}
else if (m().is_ite(b, bcond, b1, b2)) {
// Push binary op inwards on second arg
result1 = combine_ites(k, a, b1, nullptr);
result2 = combine_ites(k, a, b2, nullptr);
result = combine_ites(OP_ITE, result1, result2, bcond);
}
else {
// Apply binary op (a and b are free of ITE)
result = m().mk_app(get_fid(), k, a, b);
}
// Save result before returning
m_op_cache.insert(k, a, b, cond, result);
return result;
}
/*
Lift if-then-else expressions to the top level, enforcing a BDD form.
Postcondition: result is in BDD form.
- Alternatively, if lift_over_union and/or lift_over_inter is false,
then result is a disjunction and/or conjunciton of expressions in
BDD form. (Even in this case, ITE is still lifted at lower levels,
just not at the top level.)
- Note that the result may not be fully simplified (particularly the
nested expressions inside if-then-else). Simplification should be
called afterwards.
Cost: Causes potential blowup in the size of an expression (when
expanded out), but keeps the representation compact (subexpressions
are shared).
Used by: the regex solver in seq_regex.cpp when dealing with
derivatives of a regex by a symbolic character. Enables efficient
representation in unfolding string in regex constraints.
*/
expr_ref seq_rewriter::lift_ites(expr* r, bool lift_over_union, bool lift_over_inter) {
std::cout << "l ";
decl_kind k = to_app(r)->get_decl_kind();
family_id fid = get_fid();
expr *r1 = nullptr, *r2 = nullptr, *cond = nullptr, *ele = nullptr;
unsigned lo = 0, hi = 0;
expr_ref result(m()), result1(m()), result2(m());
if ((re().is_union(r, r1, r2) && !lift_over_union) ||
(re().is_intersection(r, r1, r2) && !lift_over_inter)) {
// Preserve unions and/or intersections
result1 = lift_ites(r1, lift_over_union, lift_over_inter);
result2 = lift_ites(r2, lift_over_union, lift_over_inter);
result = m().mk_app(fid, k, r1, r2);
}
else if (m().is_ite(r, cond, r1, r2) ||
re().is_concat(r, r1, r2) ||
re().is_union(r, r1, r2) ||
re().is_intersection(r, r1, r2) ||
re().is_diff(r, r1, r2)) {
// Use combine_ites on the subresults
// Stop preserving unions and intersections
result1 = lift_ites(r1, true, true);
result2 = lift_ites(r2, true, true);
result = combine_ites(k, r1, r2, cond);
}
else if (re().is_star(r, r1) ||
re().is_plus(r, r1) ||
re().is_opt(r, r1) ||
re().is_complement(r, r1) ||
re().is_reverse(r, r1)) {
// Stop preserving unions and intersections
result1 = lift_ites(r1, true, true);
result = m().mk_app(fid, k, r1);
}
else if (re().is_derivative(r, ele, r1)) {
result1 = lift_ites(r1, true, true);
result = m().mk_app(fid, k, ele, r1);
}
else if (re().is_loop(r, r1, lo)) {
result1 = lift_ites(r1, true, true);
result = re().mk_loop(result1, lo);
}
else if (re().is_loop(r, r1, lo, hi)) {
result1 = lift_ites(r1, true, true);
result = re().mk_loop(result1, lo, hi);
}
else {
// is_full_seq, is_empty, is_to_re, is_range, is_full_char, is_of_pred
result = r;
}
return result;
}
/*
Lift all ite expressions to the top level, but
a different "safe" version which is throttled to not
blowup the size of the expression.
Note: this function does not ensure the same BDD form that lift_ites
ensures.
*/
br_status seq_rewriter::lift_ites_throttled(func_decl* f, unsigned n, expr* const* args, expr_ref& result) {
expr* c = nullptr, *t = nullptr, *e = nullptr;
for (unsigned i = 0; i < n; ++i) {
if (m().is_ite(args[i], c, t, e) &&
for (unsigned i = 0; i < n; ++i) {
if (m().is_ite(args[i], c, t, e) &&
(get_depth(t) <= 2 || t->get_ref_count() == 1 ||
get_depth(e) <= 2 || e->get_ref_count() == 1)) {
ptr_buffer<expr> new_args;
@ -2491,40 +2628,50 @@ br_status seq_rewriter::lift_ites_throttled(func_decl* f, unsigned n, expr* cons
return BR_FAILED;
}
br_status seq_rewriter::mk_re_ite(expr* cond, expr* r1, expr* r2, expr_ref& result) {
VERIFY(m_util.is_re(r1));
VERIFY(m_util.is_re(r2));
expr *c = nullptr, *ra = nullptr, *rb = nullptr;
if (m().is_not(cond, c)) {
result = m().mk_ite(c, r2, r1);
return BR_REWRITE1;
}
if (m().is_ite(r1, c, ra, rb)) {
if (m().are_equal(c, cond)) {
result = m().mk_ite(cond, ra, r2);
return BR_REWRITE1;
}
if (cond->get_id() < c->get_id()) {
expr *result1 = m().mk_ite(cond, ra, r2);
expr *result2 = m().mk_ite(cond, rb, r2);
result = m().mk_ite(c, result1, result2);
return BR_REWRITE2;
}
}
if (m().is_ite(r2, c, ra, rb)) {
if (m().are_equal(c, cond)) {
result = m().mk_ite(cond, r1, rb);
return BR_REWRITE1;
}
if (cond->get_id() < c->get_id()) {
expr *result1 = m().mk_ite(cond, r1, ra);
expr* result2 = m().mk_ite(cond, r1, rb);
result = m().mk_ite(c, result1, result2);
return BR_REWRITE2;
}
}
return BR_DONE;
}
// /*
// Rewrite rules for ITEs of regexes.
// ite(not c, r1, r2) -> ite(c, r2, r1)
// ite(c, ite(c, r1, r2), r3)) -> ite(c, r1, r3)
// ite(c, r1, ite(c, r2, r3)) -> ite(c, r1, r3)
// ite(c1, ite(c2, r1, r2), r3) where id of c1 < id of c2 ->
// ite(c2, ite(c1, r1, r3), ite(c1, r2, r3))
// ite(c1, r1, ite(c2, r2, r3)) where id of c1 < id of c2 ->
// ite(c2, ite(c1, r1, r2), ite(c1, r1, r3))
// */
// br_status seq_rewriter::rewrite_re_ite(expr* cond, expr* r1, expr* r2, expr_ref& result) {
// VERIFY(m_util.is_re(r1));
// VERIFY(m_util.is_re(r2));
// expr *c = nullptr, *ra = nullptr, *rb = nullptr;
// if (m().is_not(cond, c)) {
// result = m().mk_ite(c, r2, r1);
// return BR_REWRITE1;
// }
// if (m().is_ite(r1, c, ra, rb)) {
// if (m().are_equal(c, cond)) {
// result = m().mk_ite(cond, ra, r2);
// return BR_REWRITE1;
// }
// if (cond->get_id() < c->get_id()) {
// expr *result1 = m().mk_ite(cond, ra, r2);
// expr *result2 = m().mk_ite(cond, rb, r2);
// result = m().mk_ite(c, result1, result2);
// return BR_REWRITE2;
// }
// }
// if (m().is_ite(r2, c, ra, rb)) {
// if (m().are_equal(c, cond)) {
// result = m().mk_ite(cond, r1, rb);
// return BR_REWRITE1;
// }
// if (cond->get_id() < c->get_id()) {
// expr *result1 = m().mk_ite(cond, r1, ra);
// expr* result2 = m().mk_ite(cond, r1, rb);
// result = m().mk_ite(c, result1, result2);
// return BR_REWRITE2;
// }
// }
// return BR_DONE;
// }
/*
* pattern match against all ++ "abc" ++ all ++ "def" ++ all regexes.
@ -2703,7 +2850,6 @@ br_status seq_rewriter::mk_str_in_regexp(expr* a, expr* b, expr_ref& result) {
return BR_REWRITE1;
}
if (str().is_empty(a)) {
std::cout << "is_nullable -- from str.in_re" << std::endl;
result = is_nullable(b);
if (str().is_in_re(result))
return BR_DONE;
@ -3963,18 +4109,20 @@ seq_rewriter::op_cache::op_cache(ast_manager& m):
m_trail(m)
{}
expr* seq_rewriter::op_cache::find(decl_kind op, expr* a, expr* b) {
op_entry e(op, a, b, nullptr);
expr* seq_rewriter::op_cache::find(decl_kind op, expr* a, expr* b, expr* c) {
op_entry e(op, a, b, c, nullptr);
m_table.find(e);
return e.r;
}
void seq_rewriter::op_cache::insert(decl_kind op, expr* a, expr* b, expr* r) {
void seq_rewriter::op_cache::insert(decl_kind op,
expr* a, expr* b, expr* c, expr* r) {
cleanup();
if (a) m_trail.push_back(a);
if (b) m_trail.push_back(b);
if (c) m_trail.push_back(c);
if (r) m_trail.push_back(r);
m_table.insert(op_entry(op, a, b, r));
m_table.insert(op_entry(op, a, b, c, r));
}
void seq_rewriter::op_cache::cleanup() {

View file

@ -118,20 +118,23 @@ class seq_rewriter {
class op_cache {
struct op_entry {
decl_kind k;
expr* a, *b, *r;
op_entry(decl_kind k, expr* a, expr* b, expr* r): k(k), a(a), b(b), r(r) {}
op_entry():k(0), a(nullptr), b(nullptr), r(nullptr) {}
expr* a, *b, *c, *r;
op_entry(decl_kind k, expr* a, expr* b, expr* c, expr* r):
k(k), a(a), b(b), c(c), r(r) {}
op_entry():k(0), a(nullptr), b(nullptr), c(nullptr), r(nullptr) {}
};
struct hash_entry {
unsigned operator()(op_entry const& e) const {
return mk_mix(e.k, e.a ? e.a->get_id() : 0, e.b ? e.b->get_id() : 0);
return combine_hash(e.k, mk_mix(e.a ? e.a->get_id() : 0,
e.b ? e.b->get_id() : 0,
e.c ? e.c->get_id() : 0));
}
};
struct eq_entry {
bool operator()(op_entry const& a, op_entry const& b) const {
return a.k == b.k && a.a == b.a && a.b == b.b;
return a.k == b.k && a.a == b.a && a.b == b.b && a.c == b.c;
}
};
@ -145,8 +148,8 @@ class seq_rewriter {
public:
op_cache(ast_manager& m);
expr* find(decl_kind op, expr* a, expr* b);
void insert(decl_kind op, expr* a, expr* b, expr* r);
expr* find(decl_kind op, expr* a, expr* b, expr* c);
void insert(decl_kind op, expr* a, expr* b, expr* c, expr* r);
};
seq_util m_util;
@ -219,15 +222,9 @@ class seq_rewriter {
br_status mk_re_reverse(expr* r, expr_ref& result);
br_status mk_re_derivative(expr* ele, expr* r, expr_ref& result);
// if-then-else rewriting support (for REs)
br_status mk_re_ite(expr* cond, expr* r1, expr* r2, expr_ref& result);
expr_ref lift_ites(expr* a, bool lift_over_union = true, bool lift_over_inter = true);
br_status lift_ites_throttled(func_decl* f, unsigned n, expr* const* args, expr_ref& result);
br_status reduce_re_eq(expr* a, expr* b, expr_ref& result);
br_status reduce_re_is_empty(expr* r, expr_ref& result);
bool non_overlap(expr_ref_vector const& p1, expr_ref_vector const& p2) const;
bool non_overlap(zstring const& p1, zstring const& p2) const;
bool rewrite_contains_pattern(expr* a, expr* b, expr_ref& result);
@ -271,6 +268,9 @@ class seq_rewriter {
void get_cofactors(expr* r, expr_ref_vector& conds, expr_ref_pair_vector& result);
void intersect(unsigned lo, unsigned hi, svector<std::pair<unsigned, unsigned>>& ranges);
expr_ref combine_ites(decl_kind k, expr* a, expr* b, expr* cond);
br_status lift_ites_throttled(func_decl* f, unsigned n, expr* const* args, expr_ref& result);
public:
seq_rewriter(ast_manager & m, params_ref const & p = params_ref()):
m_util(m), m_autil(m), m_re2aut(m), m_op_cache(m), m_es(m), m_lhs(m), m_rhs(m), m_coalesce_chars(true) {
@ -318,7 +318,7 @@ public:
expr_ref is_nullable(expr* r);
expr_ref is_nullable_rec(expr* r);
// utilities for cofactors: conditions that appear in if-then-else expressions
// utilities for cofactors of if-then-else expressions
bool has_cofactor(expr* r, expr_ref& cond, expr_ref& th, expr_ref& el);
void get_cofactors(expr* r, expr_ref_pair_vector& result) {
expr_ref_vector conds(m());
@ -329,6 +329,8 @@ public:
// special case optimization for conjunctions of equalities, disequalities and ranges.
void elim_condition(expr* elem, expr_ref& cond);
// if-then-else rewriting support (for REs)
expr_ref lift_ites(expr* r, bool lift_over_union = true, bool lift_over_inter = true);
};
#endif

View file

@ -101,7 +101,7 @@ namespace smt {
expr* e = ctx.bool_var2expr(lit.var());
VERIFY(str().is_in_re(e, s, r));
std::cout << "SEQ REGEX P_IN_RE" << std::endl;
std::cout << "PI ";
TRACE("seq", tout << "propagate " << mk_pp(e, m) << "\n";);
@ -146,7 +146,7 @@ namespace smt {
}
void seq_regex::propagate_accept(literal lit) {
std::cout << "SEQ REGEX P_ACCEPT" << std::endl;
std::cout << "PA ";
if (!propagate(lit))
m_to_propagate.push_back(lit);
}
@ -200,7 +200,7 @@ namespace smt {
TRACE("seq", tout << "propagate " << mk_pp(e, m) << "\n";);
std::cout << "SEQ REGEX P" << std::endl;
std::cout << "P ";
// << mk_pp(e, m) << std::endl;
if (block_unfolding(lit, idx))
@ -222,8 +222,7 @@ namespace smt {
case l_undef:
ctx.mark_as_relevant(len_s_le_i);
return false;
case l_true:
std::cout << "is_nullable -- from prop" << std::endl;
case l_true:
is_nullable = seq_rw().is_nullable(d);
rewrite(is_nullable);
conds.push_back(~len_s_le_i);
@ -234,12 +233,9 @@ namespace smt {
break;
}
std::cout << "...MK DERIVATIVE" << std::endl;
// (accept s i R) & len(s) > i => (accept s (+ i 1) D(nth(s, i), R)) or conds
expr_ref head = th.mk_nth(s, i);
d = re().mk_derivative(head, r);
rewrite(d);
d = derivative_wrapper(head, r);
literal acc_next = th.mk_literal(sk().mk_accept(s, a().mk_int(idx + 1), d));
conds.push_back(len_s_le_i);
@ -319,14 +315,17 @@ namespace smt {
}
/*
Memoized wrapper around the regex symbolic derivative.
Memoized(TODO) wrapper around the regex symbolic derivative.
Also ensures that the derivative is written in a normalized form
with optimizations for if-then-else expressions involving the head.
*/
expr_ref seq_regex::derivative_wrapper(expr* hd, expr* r) {
std::cout << "D ";
expr_ref result = expr_ref(re().mk_derivative(hd, r), m);
rewrite(result);
// TODO
// don't lift over unions
result = seq_rw().lift_ites(result); // false, true);
rewrite(result);
return result;
}
@ -362,7 +361,7 @@ namespace smt {
*
*/
void seq_regex::propagate_is_non_empty(literal lit) {
std::cout << "SEQ REGEX P_NE" << std::endl;
std::cout << "PN ";
expr* e = ctx.bool_var2expr(lit.var()), *r = nullptr, *u = nullptr;
VERIFY(sk().is_is_non_empty(e, r, u));
expr_ref is_nullable = seq_rw().is_nullable(r);
@ -403,7 +402,7 @@ namespace smt {
is_empty(r, u) is true if r is a member of u
*/
void seq_regex::propagate_is_empty(literal lit) {
std::cout << "SEQ REGEX P_E" << std::endl;
std::cout << "PE ";
expr* e = ctx.bool_var2expr(lit.var()), *r = nullptr, *u = nullptr;
VERIFY(sk().is_is_empty(e, r, u));
expr_ref is_nullable = seq_rw().is_nullable(r);