3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-06 17:44:08 +00:00

seq rewriting fixes

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
This commit is contained in:
Nikolaj Bjorner 2018-11-22 10:48:49 -08:00
parent 7b2590c026
commit 498fa87993
4 changed files with 184 additions and 142 deletions

View file

@ -1,5 +1,9 @@
# RC2 algorithm
# basic version without optimizations
# basic version with some optimizations
# - process soft constraints in order of highest values first.
# - extract multiple cores, not just one
# - use built-in cardinality constraints, cheap core minimization.
#
# See also https://github.com/pysathq/pysat and papers in CP 2014, JSAT 2015.
from z3 import *
@ -7,7 +11,6 @@ from z3 import *
def tt(s, f):
return is_true(s.model().eval(f))
def add(Ws, f, w):
Ws[f] = w + (Ws[f] if f in Ws else 0)
@ -19,46 +22,106 @@ def sub(Ws, f, w):
del(Ws[f])
class RC2:
def __init__(self, s):
self.bounds = {}
self.names = {}
self.solver = s
self.solver.set("sat.cardinality.solver", True)
self.solver.set("sat.core.minimize", True)
self.solver.set("sat.core.minimize_partial", True)
def at_most(self, S, k):
fml = AtMost(S + [k])
fml = simplify(AtMost(S + [k]))
if fml in self.names:
return self.names[fml]
name = Bool("%s" % fml)
self.solver.add(Implies(name, fml))
self.bounds[name] = (S, k)
sel.names[fml] = name
return name
def print_cost(self):
print("max cost", self.max_cost, "min cost", self.min_cost)
def update_max_cost(self):
self.max_cost = min(self.max_cost, self.get_cost())
self.print_cost()
# sort W, and incrementally add elements of W
# in sorted order to prefer cores with high weight.
def check(self, Ws):
ws = sorted(list(Ws), lambda f,w : -w)
# print(ws)
i = 0
while i < len(ws):
j = i
# increment j until making 5% progress or exhausting equal weight entries
while (j < len(ws) and ws[j][1] == ws[i][1]) or (i > 0 and (i - j)*20 < len(ws)):
j += 1
i = j
r = self.solver.check(ws[j][0] for j in range(i))
if r == sat:
self.update_max_cost()
else:
return r
return sat
def get_cost(self):
return sum(self.Ws0[c] for c in self.Ws0 if not tt(self.solver, c))
# Retrieve independendent cores from Ws
def get_cores(self, Ws):
cores = []
while unsat == self.check(Ws):
core = list(self.solver.unsat_core())
print (self.solver.statistics())
if not core:
return unsat
w = min([Ws[c] for c in core])
for f in core:
sub(Ws, f, w)
cores += [(core, w)]
self.update_max_cost()
return cores
# Add new soft constraints to replace core
# with weight w. Allow to weaken at most
# one element of core. Elements that are
# cardinality constraints are weakened by
# increasing their bounds. Non-cardinality
# constraints are weakened to "true". They
# correspond to the constraint Not(s) <= 0,
# so weakening produces Not(s) <= 1, which
# is a tautology.
def update_bounds(self, Ws, core, w):
for f in core:
if f in self.bounds:
S, k = self.bounds[f]
if k + 1 < len(S):
add(Ws, self.at_most(S, k + 1), w)
add(Ws, self.at_most([mk_not(f) for f in core], 1), w)
# Ws are weighted soft constraints
# Whenever there is an unsatisfiable core over ws
# increase the limit of each soft constraint from a bound
# and create a soft constraint that limits the number of
# increased bounds to be at most one.
def maxsat(self, Ws):
cost = 0
Ws0 = Ws.copy()
while unsat == self.solver.check([f for f in Ws]):
core = list(self.solver.unsat_core())
print (self.solver.statistics())
print("Core", core)
if not core:
return unsat
w = min([Ws[c] for c in core])
cost += w
for f in core:
sub(Ws, f, w)
for f in core:
if f in self.bounds:
S, k = self.bounds[f]
if k + 1 < len(S):
add(Ws, self.at_most(S, k + 1), w)
add(Ws, self.at_most([mk_not(f) for f in core], 1), w)
return cost, { f for f in Ws0 if tt(self.solver, f) }
self.min_cost = 0
self.max_cost = sum(Ws[c] for c in Ws)
self.Ws0 = Ws.copy()
while True:
cores = self.get_cores(Ws)
if not cores:
break
if cores == unsat:
return unsat
for (core, w) in cores:
self.min_cost += w
self.print_cost()
self.update_bounds(Ws, core, w)
return sel.min_cost, { f for f in self.Ws0 if not tt(self.solver, f) }
def from_file(self, file):
opt = Optimize()
@ -71,12 +134,11 @@ class RC2:
add(Ws, f.arg(0), f.arg(2).as_long())
return self.maxsat(Ws)
def main(file):
s = SolverFor("QF_FD")
rc2 = RC2(s)
set_param(verbose=0)
cost, trues = rc2.from_file(file)
cost, falses = rc2.from_file(file)
print(cost)
print(s.statistics())

View file

@ -675,25 +675,17 @@ br_status seq_rewriter::mk_seq_contains(expr* a, expr* b, expr_ref& result) {
result = m().mk_bool_val(c.contains(d));
return BR_DONE;
}
expr* x = nullptr, *y, *z;
if (m_util.str.is_extract(b, x, y, z) && x == a) {
result = m().mk_true();
return BR_DONE;
}
// check if subsequence of a is in b.
expr_ref_vector as(m()), bs(m());
if (m_util.str.is_string(a, c)) {
for (unsigned i = 0; i < c.length(); ++i) {
as.push_back(m_util.str.mk_unit(m_util.str.mk_char(c, i)));
}
}
else {
m_util.str.get_concat(a, as);
}
if (m_util.str.is_string(b, d)) {
for (unsigned i = 0; i < d.length(); ++i) {
bs.push_back(m_util.str.mk_unit(m_util.str.mk_char(d, i)));
}
}
else {
m_util.str.get_concat(b, bs);
}
bool all_values = true;
m_util.str.get_concat_units(a, as);
m_util.str.get_concat_units(b, bs);
TRACE("seq", tout << mk_pp(a, m()) << " contains " << mk_pp(b, m()) << "\n";);
if (bs.empty()) {
@ -701,24 +693,21 @@ br_status seq_rewriter::mk_seq_contains(expr* a, expr* b, expr_ref& result) {
return BR_DONE;
}
for (unsigned i = 0; all_values && i < bs.size(); ++i) {
all_values = m().is_value(bs[i].get());
if (as.empty()) {
result = m().mk_eq(b, m_util.str.mk_empty(m().get_sort(b)));
return BR_REWRITE2;
}
bool found = false;
for (unsigned i = 0; !found && i < as.size(); ++i) {
all_values &= m().is_value(as[i].get());
if (bs.size() <= as.size() - i) {
unsigned j = 0;
for (; j < bs.size() && as[j+i].get() == bs[j].get(); ++j) {};
found = j == bs.size();
for (unsigned i = 0; bs.size() + i <= as.size(); ++i) {
unsigned j = 0;
for (; j < bs.size() && as.get(j+i) == bs.get(j); ++j) {};
if (j == bs.size()) {
result = m().mk_true();
return BR_DONE;
}
}
if (found) {
result = m().mk_true();
return BR_DONE;
}
if (all_values) {
std::function<bool(expr*)> is_value = [&](expr* e) { return m().is_value(e); };
if (bs.forall(is_value) && as.forall(is_value)) {
result = m().mk_false();
return BR_DONE;
}
@ -733,26 +722,12 @@ br_status seq_rewriter::mk_seq_contains(expr* a, expr* b, expr_ref& result) {
}
}
if (as.empty()) {
result = m().mk_eq(b, m_util.str.mk_empty(m().get_sort(b)));
return BR_REWRITE2;
}
if (bs.size() == 1 && m_util.str.is_string(bs[0].get(), c)) {
for (auto a_i : as) {
if (m_util.str.is_string(a_i, d) && d.contains(c)) {
result = m().mk_true();
return BR_DONE;
}
}
}
unsigned offs = 0;
unsigned sz = as.size();
expr* b0 = bs[0].get();
expr* bL = bs[bs.size()-1].get();
expr* b0 = bs.get(0);
expr* bL = bs.get(bs.size()-1);
for (; offs < as.size() && cannot_contain_prefix(as[offs].get(), b0); ++offs) {}
for (; sz > offs && cannot_contain_suffix(as[sz-1].get(), bL); --sz) {}
for (; sz > offs && cannot_contain_suffix(as.get(sz-1), bL); --sz) {}
if (offs == sz) {
result = m().mk_eq(b, m_util.str.mk_empty(m().get_sort(b)));
return BR_REWRITE2;
@ -763,24 +738,15 @@ br_status seq_rewriter::mk_seq_contains(expr* a, expr* b, expr_ref& result) {
return BR_REWRITE2;
}
expr* x, *y, *z;
if (m_util.str.is_extract(b, x, y, z) && x == a) {
result = m().mk_true();
return BR_DONE;
}
std::function<bool(expr*)> is_unit = [&](expr *e) { return m_util.str.is_unit(e); };
if (bs.forall(is_unit) && as.forall(is_unit)) {
if (as.size() < bs.size()) {
result = m().mk_false();
return BR_DONE;
}
expr_ref_vector ors(m());
for (unsigned i = 0; i < as.size() - bs.size() + 1; ++i) {
for (unsigned i = 0; i + bs.size() <= as.size(); ++i) {
expr_ref_vector ands(m());
for (unsigned j = 0; j < bs.size(); ++j) {
ands.push_back(m().mk_eq(as[i + j].get(), bs[j].get()));
ands.push_back(m().mk_eq(as.get(i + j), bs.get(j)));
}
ors.push_back(::mk_and(ands));
}
@ -797,46 +763,41 @@ br_status seq_rewriter::mk_seq_contains(expr* a, expr* b, expr_ref& result) {
br_status seq_rewriter::mk_seq_at(expr* a, expr* b, expr_ref& result) {
zstring c;
rational r;
if (m_autil.is_numeral(b, r)) {
if (r.is_neg()) {
result = m_util.str.mk_empty(m().get_sort(a));
return BR_DONE;
}
unsigned len = 0;
bool bounded = min_length(1, &a, len);
if (bounded && r >= rational(len)) {
result = m_util.str.mk_empty(m().get_sort(a));
return BR_DONE;
}
if (m_util.str.is_string(a, c)) {
if (r.is_unsigned() && r < rational(c.length())) {
result = m_util.str.mk_string(c.extract(r.get_unsigned(), 1));
}
else {
result = m_util.str.mk_empty(m().get_sort(a));
}
return BR_DONE;
}
if (r.is_unsigned()) {
len = r.get_unsigned();
expr_ref_vector as(m());
m_util.str.get_concat(a, as);
for (unsigned i = 0; i < as.size(); ++i) {
if (m_util.str.is_unit(as[i].get())) {
if (len == 0) {
result = as[i].get();
return BR_DONE;
}
--len;
}
else {
return BR_FAILED;
}
}
}
if (!m_autil.is_numeral(b, r)) {
return BR_FAILED;
}
return BR_FAILED;
if (r.is_neg()) {
result = m_util.str.mk_empty(m().get_sort(a));
return BR_DONE;
}
if (!r.is_unsigned()) {
return BR_FAILED;
}
unsigned len = r.get_unsigned();
expr_ref_vector as(m());
m_util.str.get_concat_units(a, as);
for (unsigned i = 0; i < as.size(); ++i) {
expr* a = as.get(i);
if (m_util.str.is_unit(a)) {
if (len == i) {
result = a;
return BR_REWRITE1;
}
}
else if (i > 0) {
SASSERT(len >= i);
result = m_util.str.mk_concat(as.size() - i, as.c_ptr() + i);
result = m().mk_app(m_util.get_family_id(), OP_SEQ_AT, result, m_autil.mk_int(len - i));
return BR_REWRITE2;
}
else {
return BR_FAILED;
}
}
result = m_util.str.mk_empty(m().get_sort(a));
return BR_DONE;
}
br_status seq_rewriter::mk_seq_index(expr* a, expr* b, expr* c, expr_ref& result) {

View file

@ -933,16 +933,16 @@ app* seq_util::mk_skolem(symbol const& name, unsigned n, expr* const* args, sort
return m.mk_app(f, n, args);
}
app* seq_util::str::mk_string(zstring const& s) { return u.seq.mk_string(s); }
app* seq_util::str::mk_string(zstring const& s) const {
return u.seq.mk_string(s);
}
app* seq_util::str::mk_char(zstring const& s, unsigned idx) {
app* seq_util::str::mk_char(zstring const& s, unsigned idx) const {
bv_util bvu(m);
return bvu.mk_numeral(s[idx], s.num_bits());
}
app* seq_util::str::mk_char(char ch) {
app* seq_util::str::mk_char(char ch) const {
zstring s(ch, zstring::ascii);
return mk_char(s, 0);
}
@ -969,6 +969,24 @@ void seq_util::str::get_concat(expr* e, expr_ref_vector& es) const {
}
}
void seq_util::str::get_concat_units(expr* e, expr_ref_vector& es) const {
expr* e1, *e2;
while (is_concat(e, e1, e2)) {
get_concat_units(e1, es);
e = e2;
}
zstring s;
if (is_string(e, s)) {
unsigned sz = s.length();
for (unsigned j = 0; j < sz; ++j) {
es.push_back(mk_unit(mk_char(s, j)));
}
}
else if (!is_empty(e)) {
es.push_back(e);
}
}
app* seq_util::re::mk_loop(expr* r, unsigned lo) {
parameter param(lo);
return m.mk_app(m_fid, OP_RE_LOOP, 1, &param, 1, &r);

View file

@ -232,26 +232,26 @@ public:
public:
str(seq_util& u): u(u), m(u.m), m_fid(u.m_fid) {}
sort* mk_seq(sort* s) { parameter param(s); return m.mk_sort(m_fid, SEQ_SORT, 1, &param); }
sort* mk_seq(sort* s) const { parameter param(s); return m.mk_sort(m_fid, SEQ_SORT, 1, &param); }
sort* mk_string_sort() const { return m.mk_sort(m_fid, _STRING_SORT, 0, nullptr); }
app* mk_empty(sort* s) const { return m.mk_const(m.mk_func_decl(m_fid, OP_SEQ_EMPTY, 0, nullptr, 0, (expr*const*)nullptr, s)); }
app* mk_string(zstring const& s);
app* mk_string(symbol const& s) { return u.seq.mk_string(s); }
app* mk_char(char ch);
app* mk_string(zstring const& s) const;
app* mk_string(symbol const& s) const { return u.seq.mk_string(s); }
app* mk_char(char ch) const;
app* mk_concat(expr* a, expr* b) const { expr* es[2] = { a, b }; return m.mk_app(m_fid, OP_SEQ_CONCAT, 2, es); }
app* mk_concat(expr* a, expr* b, expr* c) { return mk_concat(a, mk_concat(b, c)); }
app* mk_concat(expr* a, expr* b, expr* c) const { return mk_concat(a, mk_concat(b, c)); }
expr* mk_concat(unsigned n, expr* const* es) const { if (n == 1) return es[0]; SASSERT(n > 1); return m.mk_app(m_fid, OP_SEQ_CONCAT, n, es); }
expr* mk_concat(expr_ref_vector const& es) const { return mk_concat(es.size(), es.c_ptr()); }
app* mk_length(expr* a) const { return m.mk_app(m_fid, OP_SEQ_LENGTH, 1, &a); }
app* mk_substr(expr* a, expr* b, expr* c) { expr* es[3] = { a, b, c }; return m.mk_app(m_fid, OP_SEQ_EXTRACT, 3, es); }
app* mk_contains(expr* a, expr* b) { expr* es[2] = { a, b }; return m.mk_app(m_fid, OP_SEQ_CONTAINS, 2, es); }
app* mk_prefix(expr* a, expr* b) { expr* es[2] = { a, b }; return m.mk_app(m_fid, OP_SEQ_PREFIX, 2, es); }
app* mk_suffix(expr* a, expr* b) { expr* es[2] = { a, b }; return m.mk_app(m_fid, OP_SEQ_SUFFIX, 2, es); }
app* mk_index(expr* a, expr* b, expr* i) { expr* es[3] = { a, b, i}; return m.mk_app(m_fid, OP_SEQ_INDEX, 3, es); }
app* mk_unit(expr* u) { return m.mk_app(m_fid, OP_SEQ_UNIT, 1, &u); }
app* mk_char(zstring const& s, unsigned idx);
app* mk_itos(expr* i) { return m.mk_app(m_fid, OP_STRING_ITOS, 1, &i); }
app* mk_stoi(expr* s) { return m.mk_app(m_fid, OP_STRING_STOI, 1, &s); }
app* mk_substr(expr* a, expr* b, expr* c) const { expr* es[3] = { a, b, c }; return m.mk_app(m_fid, OP_SEQ_EXTRACT, 3, es); }
app* mk_contains(expr* a, expr* b) const { expr* es[2] = { a, b }; return m.mk_app(m_fid, OP_SEQ_CONTAINS, 2, es); }
app* mk_prefix(expr* a, expr* b) const { expr* es[2] = { a, b }; return m.mk_app(m_fid, OP_SEQ_PREFIX, 2, es); }
app* mk_suffix(expr* a, expr* b) const { expr* es[2] = { a, b }; return m.mk_app(m_fid, OP_SEQ_SUFFIX, 2, es); }
app* mk_index(expr* a, expr* b, expr* i) const { expr* es[3] = { a, b, i}; return m.mk_app(m_fid, OP_SEQ_INDEX, 3, es); }
app* mk_unit(expr* u) const { return m.mk_app(m_fid, OP_SEQ_UNIT, 1, &u); }
app* mk_char(zstring const& s, unsigned idx) const;
app* mk_itos(expr* i) const { return m.mk_app(m_fid, OP_STRING_ITOS, 1, &i); }
app* mk_stoi(expr* s) const { return m.mk_app(m_fid, OP_STRING_STOI, 1, &s); }
bool is_string(expr const * n) const { return is_app_of(n, m_fid, OP_STRING_CONST); }
@ -304,6 +304,7 @@ public:
MATCH_UNARY(is_unit);
void get_concat(expr* e, expr_ref_vector& es) const;
void get_concat_units(expr* e, expr_ref_vector& es) const;
expr* get_leftmost_concat(expr* e) const { expr* e1, *e2; while (is_concat(e, e1, e2)) e = e1; return e; }
expr* get_rightmost_concat(expr* e) const { expr* e1, *e2; while (is_concat(e, e1, e2)) e = e2; return e; }
};