3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-10 19:27:06 +00:00

add don't care option

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
This commit is contained in:
Nikolaj Bjorner 2020-01-12 17:00:05 -08:00
parent e0a41a18c3
commit 9f964be3f4
7 changed files with 201 additions and 98 deletions

View file

@ -43,11 +43,11 @@ namespace sat {
if (m_aig[id].empty()) {
continue;
}
IF_VERBOSE(3, m_cuts[id].display(verbose_stream() << "augment " << id << "\nbefore\n"));
IF_VERBOSE(10, m_cuts[id].display(verbose_stream() << "augment " << id << "\nbefore\n"));
for (node const& n : m_aig[id]) {
augment(id, n);
}
IF_VERBOSE(3, m_cuts[id].display(verbose_stream() << "after\n"));
IF_VERBOSE(10, m_cuts[id].display(verbose_stream() << "after\n"));
}
}
@ -82,7 +82,7 @@ namespace sat {
}
bool aig_cuts::insert_cut(unsigned v, cut const& c, cut_set& cs) {
if (!cs.insert(&m_on_cut_add, &m_on_cut_del, c)) {
if (!cs.insert(m_on_cut_add, m_on_cut_del, c)) {
return true;
}
m_num_cuts++;
@ -98,7 +98,7 @@ namespace sat {
}
void aig_cuts::augment_ite(unsigned v, node const& n, cut_set& cs) {
IF_VERBOSE(2, display(verbose_stream() << "augment_ite " << v << " ", n) << "\n");
IF_VERBOSE(4, display(verbose_stream() << "augment_ite " << v << " ", n) << "\n");
literal l1 = child(n, 0);
literal l2 = child(n, 1);
literal l3 = child(n, 2);
@ -172,7 +172,7 @@ namespace sat {
void aig_cuts::augment_aigN(unsigned v, node const& n, cut_set& cs) {
IF_VERBOSE(4, display(verbose_stream() << "augment_aigN " << v << " ", n) << "\n");
m_cut_set1.reset(nullptr);
m_cut_set1.reset(m_on_cut_del);
SASSERT(n.is_and() || n.is_xor());
literal lit = child(n, 0);
for (auto const& a : m_cuts[lit.var()]) {
@ -180,10 +180,10 @@ namespace sat {
if (lit.sign()) {
b.negate();
}
m_cut_set1.push_back(nullptr, b);
m_cut_set1.push_back(m_on_cut_add, b);
}
for (unsigned i = 1; i < n.size(); ++i) {
m_cut_set2.reset(nullptr);
m_cut_set2.reset(m_on_cut_del);
lit = child(n, i);
m_insertions = 0;
for (auto const& a : m_cut_set1) {
@ -212,6 +212,12 @@ namespace sat {
}
}
void aig_cuts::replace(unsigned v, cut const& src, cut const& dst) {
m_cuts[v].replace(m_on_cut_add, m_on_cut_del, src, dst);
touch(v);
}
bool aig_cuts::is_touched(node const& n) {
for (unsigned i = 0; i < n.size(); ++i) {
literal lit = m_literals[n.offset() + i];

View file

@ -138,10 +138,10 @@ namespace sat {
void on_node_add(unsigned v, node const& n);
void on_node_del(unsigned v, node const& n);
void evict(cut_set& cs, unsigned idx) { cs.evict(&m_on_cut_del, idx); }
void reset(cut_set& cs) { cs.reset(&m_on_cut_del); }
void push_back(cut_set& cs, cut const& c) { cs.push_back(&m_on_cut_add, c); }
void shrink(cut_set& cs, unsigned j) { cs.shrink(&m_on_cut_del, j); }
void evict(cut_set& cs, unsigned idx) { cs.evict(m_on_cut_del, idx); }
void reset(cut_set& cs) { cs.reset(m_on_cut_del); }
void push_back(cut_set& cs, cut const& c) { cs.push_back(m_on_cut_add, c); }
void shrink(cut_set& cs, unsigned j) { cs.shrink(m_on_cut_del, j); }
void cut2clauses(on_clause_t& on_clause, unsigned v, cut const& c);
void node2def(on_clause_t& on_clause, node const& n, literal r);
@ -166,6 +166,8 @@ namespace sat {
void cut2def(on_clause_t& on_clause, cut const& c, literal r);
void replace(unsigned v, cut const& src, cut const& dst);
std::ostream& display(std::ostream& out) const;

View file

@ -68,8 +68,9 @@ namespace sat {
for (literal lit : clause) m_assumptions.push_back(~lit);
lbool r = s.check(clause.size(), m_assumptions.c_ptr());
if (r != l_false) {
std::cout << "not validated: " << clause << "\n";
s.display(std::cout);
IF_VERBOSE(0,
verbose_stream() << "not validated: " << clause << "\n";
s.display(verbose_stream()););
std::string line;
std::getline(std::cin, line);
}
@ -78,7 +79,6 @@ namespace sat {
void aig_simplifier::ensure_validator() {
if (!m_validator) {
std::cout << "init validator\n";
params_ref p;
p.set_bool("aig", false);
p.set_bool("drat.check_unsat", false);
@ -92,15 +92,9 @@ namespace sat {
s(_s),
m_trail_size(0),
m_validator(nullptr) {
if (false) {
ensure_validator();
std::function<void(literal_vector const& clause)> _on_add =
[this](literal_vector const& clause) {
std::cout << "add " << clause << "\n"; m_validator->validate(clause);
};
m_aig_cuts.set_on_clause_add(_on_add);
}
else if (s.get_config().m_drat) {
m_config.m_enable_dont_cares = true;
m_config.m_enable_units = true;
if (s.get_config().m_drat) {
std::function<void(literal_vector const& clause)> _on_add =
[this](literal_vector const& clause) { s.m_drat.add(clause); };
std::function<void(literal_vector const& clause)> _on_del =
@ -108,6 +102,15 @@ namespace sat {
m_aig_cuts.set_on_clause_add(_on_add);
m_aig_cuts.set_on_clause_del(_on_del);
}
else if (m_config.m_validate) {
ensure_validator();
std::function<void(literal_vector const& clause)> _on_add =
[this](literal_vector const& clause) {
m_validator->validate(clause);
};
m_aig_cuts.set_on_clause_add(_on_add);
}
}
aig_simplifier::~aig_simplifier() {
@ -158,7 +161,7 @@ namespace sat {
++m_stats.m_num_calls;
do {
n = m_stats.m_num_eqs + m_stats.m_num_units;
if (m_config.m_full || true) clauses2aig();
clauses2aig();
aig2clauses();
++i;
}
@ -172,7 +175,7 @@ namespace sat {
void aig_simplifier::clauses2aig() {
// update units
for (; m_config.m_full && m_trail_size < s.init_trail_size(); ++m_trail_size) {
for (; m_config.m_enable_units && m_trail_size < s.init_trail_size(); ++m_trail_size) {
literal lit = s.trail_literal(m_trail_size);
m_aig_cuts.add_node(lit, and_op, 0, 0);
}
@ -192,7 +195,7 @@ namespace sat {
af.set(on_and);
af.set(on_ite);
clause_vector clauses(s.clauses());
if (m_config.m_full || true) clauses.append(s.learned());
if (m_config.m_add_learned) clauses.append(s.learned());
af(clauses);
std::function<void (literal_vector const&)> on_xor =
@ -229,6 +232,8 @@ namespace sat {
vector<cut_set> const& cuts = m_aig_cuts();
m_stats.m_num_cuts = m_aig_cuts.num_cuts();
add_dont_cares(cuts);
map<cut const*, unsigned, cut::hash_proc, cut::eq_proc> cut2id;
union_find_default_ctx ctx;
@ -242,20 +247,20 @@ namespace sat {
for (unsigned i = cuts.size(); i-- > 0; ) {
for (auto& c : cuts[i]) {
unsigned j = 0;
if (m_config.m_full && c.is_true()) {
if (m_config.m_enable_units && c.is_true()) {
if (s.value(i) == l_undef) {
literal lit(i, false);
validate_unit(lit);
// validate_unit(lit);
IF_VERBOSE(2, verbose_stream() << "new unit " << lit << "\n");
s.assign_unit(lit);
++m_stats.m_num_units;
}
break;
}
if (m_config.m_full && c.is_false()) {
if (m_config.m_enable_units && c.is_false()) {
if (s.value(i) == l_undef) {
literal lit(i, true);
validate_unit(lit);
// validate_unit(lit);
IF_VERBOSE(2, verbose_stream() << "new unit " << lit << "\n");
s.assign_unit(lit);
++m_stats.m_num_units;
@ -266,7 +271,7 @@ namespace sat {
VERIFY(i != j);
literal u(i, false);
literal v(j, false);
IF_VERBOSE(0,
IF_VERBOSE(10,
verbose_stream() << u << " " << c << "\n";
verbose_stream() << v << ": ";
for (cut const& d : cuts[v.var()]) verbose_stream() << d << "\n";);
@ -278,20 +283,19 @@ namespace sat {
new_eq = true;
break;
}
if (true || m_config.m_full) {
cut nc(c);
nc.negate();
if (cut2id.find(&nc, j)) {
VERIFY(i != j); // maybe possible with don't cares
literal u(i, false);
literal v(j, true);
certify_equivalence(u, v, c);
// validate_eq(u, v);
add_eq(u, v);
TRACE("aig_simplifier", tout << u << " == " << v << "\n";);
new_eq = true;
break;
}
cut nc(c);
nc.negate();
if (cut2id.find(&nc, j)) {
if (i == j) continue;
literal u(i, false);
literal v(j, true);
certify_equivalence(u, v, c);
// validate_eq(u, v);
add_eq(u, v);
TRACE("aig_simplifier", tout << u << " == " << v << "\n";);
new_eq = true;
break;
}
cut2id.insert(&c, i);
}
@ -389,72 +393,122 @@ namespace sat {
}
}
void aig_simplifier::add_dont_cares(vector<cut_set> const& cuts) {
if (m_config.m_enable_dont_cares) {
cuts2pairs(cuts);
pairs2dont_cares();
dont_cares2cuts(cuts);
}
}
/**
* collect pairs of variables that occur in cut sets.
*/
void aig_simplifier::collect_pairs(vector<cut_set> const& cuts) {
void aig_simplifier::cuts2pairs(vector<cut_set> const& cuts) {
svector<var_pair> dcs;
for (auto const& p : m_pairs) {
if (p.op != none)
dcs.push_back(p);
}
m_pairs.reset();
for (unsigned k = cuts.size(); k-- > 0; ) {
for (auto const& c : cuts[k]) {
for (auto const& cs : cuts) {
for (auto const& c : cs) {
for (unsigned i = c.size(); i-- > 0; ) {
for (unsigned j = i; j-- > 0; ) {
m_pairs.insert(var_pair(c[i],c[j]));
m_pairs.insert(var_pair(c[j],c[i]));
}
}
}
}
// don't lose previous don't cares
for (auto const& p : dcs) {
if (m_pairs.contains(p))
m_pairs.insert(p);
}
}
/**
* compute masks for pairs.
*/
void aig_simplifier::add_masks_to_pairs() {
void aig_simplifier::pairs2dont_cares() {
big b(s.rand());
b.init(s, true);
for (auto& p : m_pairs) {
if (p.op != none) continue;
literal u(p.u, false), v(p.v, false);
// u -> v, then u & ~v is impossible
if (b.connected(u, v)) {
add_mask(u, ~v, p);
p.op = pn;
}
else if (b.connected(u, ~v)) {
add_mask(u, v, p);
p.op = pp;
}
else if (b.connected(~u, v)) {
add_mask(~u, ~v, p);
p.op = nn;
}
else if (b.connected(~u, ~v)) {
add_mask(~u, v, p);
}
else {
memset(p.masks, 0xFF, var_pair::size());
p.op = np;
}
}
IF_VERBOSE(2, {
unsigned n = 0; for (auto const& p : m_pairs) if (p.op != none) ++n;
verbose_stream() << n << " / " << m_pairs.size() << " don't cares\n";
});
}
/*
* compute masks for each possible occurrence of u, v within 2-6 elements.
* combinaions relative to u.sign(), v.sign() are impossible.
*/
void aig_simplifier::add_mask(literal u, literal v, var_pair& p) {
unsigned offset = 0;
bool su = u.sign(), sv = v.sign();
for (unsigned k = 2; k <= 6; ++k) {
for (unsigned i = 0; i < k; ++i) {
for (unsigned j = i + 1; j < k; ++j) {
// convert su, sv, k, i, j into a mask for 2^k bits.
// for outputs
p.masks[offset++] = 0;
void aig_simplifier::dont_cares2cuts(vector<cut_set> const& cuts) {
struct rep {
cut src, dst; unsigned v;
rep(cut const& s, cut const& d, unsigned v):src(s), dst(d), v(v) {}
rep():v(UINT_MAX) {}
};
vector<rep> to_replace;
cut d;
for (auto const& cs : cuts) {
for (auto const& c : cs) {
if (rewrite_cut(c, d)) {
to_replace.push_back(rep(c, d, cs.var()));
}
}
}
for (auto const& p : to_replace) {
m_aig_cuts.replace(p.v, p.src, p.dst);
}
m_stats.m_num_dont_care_reductions += to_replace.size();
}
/*
* compute masks for position i, j and op-code p.op
*/
uint64_t aig_simplifier::op2dont_care(unsigned i, unsigned j, var_pair const& p) {
SASSERT(i < j && j < 6);
if (p.op == none) return 0ull;
// first position of mask is offset into output bits contributed by i and j
bool i_is_0 = (p.op == np || p.op == nn);
bool j_is_0 = (p.op == pn || p.op == nn);
uint64_t first = (i_is_0 ? 0 : (1 << i)) + (j_is_0 ? 0 : (1 << j));
uint64_t inc = 1ull << (j + 1);
uint64_t r = 1ull << first;
while (inc < 64ull) { r |= (r << inc); inc *= 2; }
return r;
}
/**
* apply obtained masks to cut sets.
* apply obtained dont_cares to cut sets.
*/
void aig_simplifier::apply_masks() {
bool aig_simplifier::rewrite_cut(cut const& c, cut& d) {
bool init = false;
for (unsigned i = 0; i < c.size(); ++i) {
for (unsigned j = i + 1; j < c.size(); ++j) {
var_pair p(c[i], c[j]);
if (m_pairs.find(p, p) && p.op != none) {
if (!init) { d = c; init = true; }
d.set_table(d.m_table | op2dont_care(i, j, p));
}
}
}
return init && d.m_table != c.m_table;
}
void aig_simplifier::collect_statistics(statistics& st) const {
@ -463,6 +517,7 @@ namespace sat {
st.update("sat-aig.ands", m_stats.m_num_ands);
st.update("sat-aig.ites", m_stats.m_num_ites);
st.update("sat-aig.xors", m_stats.m_num_xors);
st.update("sat-aig.dc-reduce", m_stats.m_num_dont_care_reductions);
}
void aig_simplifier::validate_unit(literal lit) {

View file

@ -27,13 +27,20 @@ namespace sat {
public:
struct stats {
unsigned m_num_eqs, m_num_units, m_num_cuts, m_num_xors, m_num_ands, m_num_ites;
unsigned m_num_calls;
unsigned m_num_calls, m_num_dont_care_reductions;
stats() { reset(); }
void reset() { memset(this, 0, sizeof(*this)); }
};
struct config {
bool m_full;
config():m_full(false) {}
bool m_validate;
bool m_enable_units;
bool m_enable_dont_cares;
bool m_add_learned;
config():
m_validate(false),
m_enable_units(false),
m_enable_dont_cares(false),
m_add_learned(true) {}
};
private:
struct report;
@ -60,14 +67,16 @@ namespace sat {
* Apply the masks on cut sets so to allow detecting
* equivalences modulo implications.
*/
enum op_code { pp, pn, np, nn, none };
struct var_pair {
unsigned u, v;
uint64_t masks[35];
static unsigned size() { return sizeof(uint64_t)*35; }
var_pair(unsigned u, unsigned v): u(u), v(v) {
op_code op;
var_pair(unsigned _u, unsigned _v): u(_u), v(_v), op(none) {
if (u > v) std::swap(u, v);
}
var_pair(): u(UINT_MAX), v(UINT_MAX) {}
var_pair(): u(UINT_MAX), v(UINT_MAX), op(none) {}
struct hash {
unsigned operator()(var_pair const& p) const {
@ -82,10 +91,13 @@ namespace sat {
};
hashtable<var_pair, var_pair::hash, var_pair::eq> m_pairs;
void collect_pairs(vector<cut_set> const& cuts);
void add_mask(literal u, literal v, var_pair& p);
void add_masks_to_pairs();
void apply_masks();
void add_dont_cares(vector<cut_set> const& cuts);
void cuts2pairs(vector<cut_set> const& cuts);
void pairs2dont_cares();
void dont_cares2cuts(vector<cut_set> const& cuts);
bool rewrite_cut(cut const& c, cut& r);
uint64_t op2dont_care(unsigned i, unsigned j, var_pair const& p);
public:
aig_simplifier(solver& s);
~aig_simplifier();

View file

@ -31,7 +31,7 @@ namespace sat {
- pre-allocate fixed array instead of vector for cut_set to avoid overhead for memory allocation.
*/
bool cut_set::insert(on_update_t* on_add, on_update_t* on_del, cut const& c) {
bool cut_set::insert(on_update_t& on_add, on_update_t& on_del, cut const& c) {
unsigned i = 0, j = 0, k = m_size;
for (; i < k; ++i) {
cut const& a = (*this)[i];
@ -42,8 +42,11 @@ namespace sat {
std::swap(m_cuts[i--], m_cuts[--k]);
}
}
shrink(on_del, i);
// for DRAT make sure to add new element before removing old cuts
// the new cut may need to be justified relative to the old cut
push_back(on_add, c);
std::swap(m_cuts[i++], m_cuts[m_size-1]);
shrink(on_del, i);
return true;
}
@ -64,16 +67,16 @@ namespace sat {
}
void cut_set::shrink(on_update_t* on_del, unsigned j) {
if (m_var != UINT_MAX && on_del && *on_del) {
void cut_set::shrink(on_update_t& on_del, unsigned j) {
if (m_var != UINT_MAX && on_del) {
for (unsigned i = j; i < m_size; ++i) {
(*on_del)(m_var, m_cuts[i]);
on_del(m_var, m_cuts[i]);
}
}
m_size = j;
}
void cut_set::push_back(on_update_t* on_add, cut const& c) {
void cut_set::push_back(on_update_t& on_add, cut const& c) {
SASSERT(m_max_size > 0);
if (m_size == m_max_size) {
m_max_size *= 2;
@ -81,10 +84,26 @@ namespace sat {
memcpy(new_cuts, m_cuts, sizeof(cut)*m_size);
m_cuts = new_cuts;
}
if (m_var != UINT_MAX && on_add && *on_add) (*on_add)(m_var, c);
if (m_var != UINT_MAX && on_add) on_add(m_var, c);
m_cuts[m_size++] = c;
}
void cut_set::replace(on_update_t& on_add, on_update_t& on_del, cut const& src, cut const& dst) {
SASSERT(src != dst);
insert(on_add, on_del, dst);
for (unsigned i = 0; i < size(); ++i) {
if (src == (*this)[i]) {
evict(on_del, i);
break;
}
}
}
void cut_set::evict(on_update_t& on_del, unsigned idx) {
if (m_var != UINT_MAX && on_del) on_del(m_var, m_cuts[idx]);
m_cuts[idx] = m_cuts[--m_size];
}
void cut_set::init(region& r, unsigned max_sz, unsigned v) {
m_var = v;
m_max_size = max_sz;

View file

@ -137,18 +137,27 @@ namespace sat {
cut_set(): m_var(UINT_MAX), m_region(nullptr), m_size(0), m_max_size(0), m_cuts(nullptr) {}
void init(region& r, unsigned max_sz, unsigned v);
bool insert(on_update_t* on_add, on_update_t* on_del, cut const& c);
bool insert(on_update_t& on_add, on_update_t& on_del, cut const& c);
bool no_duplicates() const;
unsigned var() const { return m_var; }
unsigned size() const { return m_size; }
cut const * begin() const { return m_cuts; }
cut const * end() const { return m_cuts + m_size; }
cut const & back() { return m_cuts[m_size-1]; }
void push_back(on_update_t* on_add, cut const& c);
void reset(on_update_t* on_del) { shrink(on_del, 0); }
void push_back(on_update_t& on_add, cut const& c);
void reset(on_update_t& on_del) { shrink(on_del, 0); }
cut const & operator[](unsigned idx) { return m_cuts[idx]; }
void shrink(on_update_t* on_del, unsigned j);
void swap(cut_set& other) { std::swap(m_size, other.m_size); std::swap(m_cuts, other.m_cuts); std::swap(m_max_size, other.m_max_size); }
void evict(on_update_t* on_del, unsigned idx) { if (m_var != UINT_MAX && on_del && *on_del) (*on_del)(m_var, m_cuts[idx]); m_cuts[idx] = m_cuts[--m_size]; }
void shrink(on_update_t& on_del, unsigned j);
void swap(cut_set& other) {
std::swap(m_var, other.m_var);
std::swap(m_size, other.m_size);
std::swap(m_max_size, other.m_max_size);
std::swap(m_cuts, other.m_cuts);
}
void evict(on_update_t& on_del, unsigned idx);
void replace(on_update_t& on_add, on_update_t& on_del, cut const& src, cut const& dst);
std::ostream& display(std::ostream& out) const;
};

View file

@ -122,7 +122,7 @@ symbol::symbol(char const * d) {
}
symbol & symbol::operator=(char const * d) {
m_data = g_symbol_tables->get_str(d);
m_data = d ? g_symbol_tables->get_str(d) : nullptr;
return *this;
}