3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-08 10:25:18 +00:00

first pass on extracting binary clauses, ensure that binary clauses used by simplifier are in scope of DRAT, add certification of units

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
This commit is contained in:
Nikolaj Bjorner 2020-01-14 09:08:40 -08:00
parent d77ac69015
commit a12fca3105
4 changed files with 220 additions and 54 deletions

View file

@ -232,6 +232,7 @@ namespace sat {
m_stats.m_num_cuts = m_aig_cuts.num_cuts();
add_dont_cares(cuts);
cuts2equiv(cuts);
cuts2implies(cuts);
}
void aig_simplifier::cuts2equiv(vector<cut_set> const& cuts) {
@ -254,10 +255,10 @@ namespace sat {
cut nc(c);
nc.negate();
if (m_config.m_enable_units && c.is_true()) {
assign_unit(u);
assign_unit(c, u);
}
else if (m_config.m_enable_units && c.is_false()) {
assign_unit(~u);
assign_unit(nc, ~u);
}
else if (cut2id.find(&c, j)) {
literal v(j, false);
@ -279,11 +280,12 @@ namespace sat {
}
}
void aig_simplifier::assign_unit(literal lit) {
void aig_simplifier::assign_unit(cut const& c, literal lit) {
if (s.value(lit) == l_undef) {
// validate_unit(lit);
IF_VERBOSE(2, verbose_stream() << "new unit " << lit << "\n");
s.assign_unit(lit);
certify_unit(lit, c);
++m_stats.m_num_units;
}
}
@ -329,6 +331,103 @@ namespace sat {
}
}
void aig_simplifier::cuts2implies(vector<cut_set> const& cuts) {
if (!m_config.m_enable_implies) return;
vector<vector<std::pair<unsigned, cut const*>>> var_tables;
map<cut const*, unsigned, cut::dom_hash_proc, cut::dom_eq_proc> cut2tables;
unsigned j = 0;
big big(s.rand());
big.init(s, true);
for (auto const& cs : cuts) {
for (auto const& c : cs) {
if (c.is_false() || c.is_true())
continue;
if (!cut2tables.find(&c, j)) {
j = var_tables.size();
var_tables.push_back(vector<std::pair<unsigned, cut const*>>());
cut2tables.insert(&c, j);
}
var_tables[j].push_back(std::make_pair(cs.var(), &c));
}
}
for (unsigned i = 0; i < var_tables.size(); ++i) {
auto const& vt = var_tables[i];
for (unsigned j = 0; j < vt.size(); ++j) {
literal u(vt[j].first, false);
cut const& c1 = *vt[j].second;
cut nc1(c1);
uint64_t t1 = c1.table();
uint64_t n1 = nc1.table();
for (unsigned k = j + 1; k < vt.size(); ++k) {
literal v(vt[k].first, false);
cut const& c2 = *vt[k].second;
uint64_t t2 = c2.table();
uint64_t n2 = c2.ntable();
//
if (t1 == t2 || t1 == n2) {
// already handled
}
else if ((t1 | t2) == t2) {
learn_implies(big, c1, u, v);
}
else if ((t1 | n2) == n2) {
learn_implies(big, c1, u, ~v);
}
else if ((n1 | t2) == t2) {
learn_implies(big, nc1, ~u, v);
}
else if ((n1 | n2) == n2) {
learn_implies(big, nc1, ~u, ~v);
}
}
}
}
}
void aig_simplifier::learn_implies(big& big, cut const& c, literal u, literal v) {
bin_rel q, p(~u, v);
if (m_bins.find(p, q) && q.op != none)
return;
if (big.connected(u, v))
return;
s.mk_clause(~u, v, true);
m_bins.insert(p);
certify_implies(u, v, c);
track_binary(~u, v);
}
void aig_simplifier::track_binary(bin_rel const& p) {
if (s.m_config.m_drat) {
literal u, v;
p.to_binary(u, v);
track_binary(u, v);
}
}
void aig_simplifier::untrack_binary(bin_rel const& p) {
if (s.m_config.m_drat) {
literal u, v;
p.to_binary(u, v);
untrack_binary(u, v);
}
}
void aig_simplifier::track_binary(literal u, literal v) {
if (s.m_config.m_drat) {
s.m_drat.add(u, v, true);
}
}
void aig_simplifier::untrack_binary(literal u, literal v) {
if (s.m_config.m_drat) {
s.m_drat.del(u, v);
}
}
void aig_simplifier::certify_unit(literal u, cut const& c) {
certify_implies(~u, u, c);
}
/**
* Equilvalences modulo cuts are not necessarily DRAT derivable.
* To ensure that there is a DRAT derivation we create all resolvents
@ -337,36 +436,37 @@ namespace sat {
* contain complementary literals.
*/
void aig_simplifier::certify_equivalence(literal u, literal v, cut const& c) {
certify_implies(u, v, c);
certify_implies(v, u, c);
}
/**
* certify that u implies v, where c is the cut for u.
* Then every position in c where u is true, it has to be
* the case that v is too.
* Where u is false, v can have any value.
* Thus, for every clause C or u', where u' is u or ~u,
* it follows that C or ~u or v
*/
void aig_simplifier::certify_implies(literal u, literal v, cut const& c) {
if (!s.m_config.m_drat) return;
vector<literal_vector> clauses;
std::function<void(literal_vector const& clause)> on_clause =
[&](literal_vector const& clause) { SASSERT(clause.back().var() == u.var()); clauses.push_back(clause); };
[&,this](literal_vector const& clause) {
SASSERT(clause.back().var() == u.var());
clauses.push_back(clause);
clauses.back().back() = ~u;
if (~u != v) clauses.back().push_back(v);
s.m_drat.add(clauses.back());
};
m_aig_cuts.cut2def(on_clause, c, u);
// create C or u or ~v for each clause C or u
// create C or ~u or v for each clause C or ~u
for (auto& clause : clauses) {
literal w = clause.back();
SASSERT(w.var() == u.var());
clause.push_back(w == u ? ~v : v);
s.m_drat.add(clause);
}
// create C or ~u or v for each clause
unsigned i = 0, sz = clauses.size();
for (; i < sz; ++i) {
literal_vector clause(clauses[i]);
clause[clause.size()-2] = ~clause[clause.size()-2];
clause[clause.size()-1] = ~clause[clause.size()-1];
clauses.push_back(clause);
s.m_drat.add(clause);
}
// create all resolvents over C. C is assumed to
// contain all combinations of some set of literals.
i = 0; sz = clauses.size();
while (sz - i > 2) {
SASSERT((sz & (sz - 1)) == 0);
unsigned i = 0, sz = clauses.size();
while (sz - i > 1) {
SASSERT((sz & (sz - 1)) == 0 && "sz is a power of 2");
for (; i < sz; ++i) {
auto const& clause = clauses[i];
if (clause[0].sign()) {
@ -383,13 +483,12 @@ namespace sat {
// once we established equivalence, don't need auxiliary clauses for DRAT.
for (auto const& clause : clauses) {
if (clause.size() > 2) {
if (clause.size() > 1) {
s.m_drat.del(clause);
}
}
}
}
void aig_simplifier::add_dont_cares(vector<cut_set> const& cuts) {
if (m_config.m_enable_dont_cares) {
cuts2bins(cuts);
@ -419,8 +518,12 @@ namespace sat {
}
// don't lose previous don't cares
for (auto const& p : dcs) {
if (m_bins.contains(p))
if (m_bins.contains(p)) {
m_bins.insert(p);
}
else {
untrack_binary(p);
}
}
}
@ -446,6 +549,7 @@ namespace sat {
else if (b.connected(~u, ~v)) {
p.op = np;
}
track_binary(p);
}
IF_VERBOSE(2, {
unsigned n = 0; for (auto const& p : m_bins) if (p.op != none) ++n;

View file

@ -36,41 +36,31 @@ namespace sat {
bool m_validate;
bool m_enable_units;
bool m_enable_dont_cares;
bool m_enable_implies;
bool m_add_learned;
config():
m_validate(false),
m_enable_units(false),
m_enable_dont_cares(false),
m_enable_implies(false),
m_add_learned(true) {}
};
private:
struct report;
struct validator;
solver& s;
stats m_stats;
config m_config;
aig_cuts m_aig_cuts;
unsigned m_trail_size;
literal_vector m_lits;
validator* m_validator;
void clauses2aig();
void aig2clauses();
void cuts2equiv(vector<cut_set> const& cuts);
void uf2equiv(union_find<> const& uf);
void assign_unit(literal lit);
void assign_equiv(cut const& c, literal u, literal v);
void ensure_validator();
void validate_unit(literal lit);
void validate_eq(literal a, literal b);
void certify_equivalence(literal u, literal v, cut const& c);
/**
* collect pairs of literal combinations that are impossible
* base on binary implication graph queries. Apply the masks
* on cut sets so to allow detecting equivalences modulo
* implications.
*
* The encoding is as follows:
* a or b -> op = nn because (~a & ~b) is a don't care
* ~a or b -> op = pn because (a & ~b) is a don't care
* a or ~b -> op = np because (~a & b) is a don't care
* ~a or ~b -> op = pp because (a & b) is a don't care
*
*/
enum op_code { pp, pn, np, nn, none };
@ -81,6 +71,18 @@ namespace sat {
bin_rel(unsigned _u, unsigned _v): u(_u), v(_v), op(none) {
if (u > v) std::swap(u, v);
}
// convert binary clause into a bin-rel
bin_rel(literal _u, literal _v): u(_u.var()), v(_v.var()), op(none) {
if (_u.sign() && _v.sign()) op = pp;
else if (_u.sign()) op = pn;
else if (_v.sign()) op = np;
else op = nn;
if (u > v) {
std::swap(u, v);
if (op == np) op = pn;
else if (op == pn) op = np;
}
}
bin_rel(): u(UINT_MAX), v(UINT_MAX), op(none) {}
struct hash {
@ -93,8 +95,46 @@ namespace sat {
return a.u == b.u && a.v == b.v;
}
};
void to_binary(literal& lu, literal& lv) const {
switch (op) {
case pp: lu = literal(u, true); lv = literal(v, true); break;
case pn: lu = literal(u, true); lv = literal(v, false); break;
case np: lu = literal(u, false); lv = literal(v, true); break;
case nn: lu = literal(u, false); lv = literal(v, false); break;
default: UNREACHABLE(); break;
}
}
};
solver& s;
stats m_stats;
config m_config;
aig_cuts m_aig_cuts;
unsigned m_trail_size;
literal_vector m_lits;
validator* m_validator;
hashtable<bin_rel, bin_rel::hash, bin_rel::eq> m_bins;
void clauses2aig();
void aig2clauses();
void cuts2equiv(vector<cut_set> const& cuts);
void cuts2implies(vector<cut_set> const& cuts);
void uf2equiv(union_find<> const& uf);
void assign_unit(cut const& c, literal lit);
void assign_equiv(cut const& c, literal u, literal v);
void learn_implies(big& big, cut const& c, literal u, literal v);
void ensure_validator();
void validate_unit(literal lit);
void validate_eq(literal a, literal b);
void certify_unit(literal u, cut const& c);
void certify_implies(literal u, literal v, cut const& c);
void certify_equivalence(literal u, literal v, cut const& c);
void track_binary(literal u, literal v);
void untrack_binary(literal u, literal v);
void track_binary(bin_rel const& p);
void untrack_binary(bin_rel const& p);
void add_dont_cares(vector<cut_set> const& cuts);
void cuts2bins(vector<cut_set> const& cuts);

View file

@ -139,12 +139,7 @@ namespace sat {
}
bool cut::operator==(cut const& other) const {
if (m_size != other.m_size) return false;
if (table() != other.table()) return false;
for (unsigned i = 0; i < m_size; ++i) {
if ((*this)[i] != other[i]) return false;
}
return true;
return table() == other.table() && dom_eq(other);
}
unsigned cut::hash() const {
@ -152,6 +147,20 @@ namespace sat {
[](cut const& c) { return (unsigned)c.table(); },
[](cut const& c, unsigned i) { return c[i]; });
}
unsigned cut::dom_hash() const {
return get_composite_hash(*this, m_size,
[](cut const& c) { return 3; },
[](cut const& c, unsigned i) { return c[i]; });
}
bool cut::dom_eq(cut const& other) const {
if (m_size != other.m_size) return false;
for (unsigned i = 0; i < m_size; ++i) {
if ((*this)[i] != other[i]) return false;
}
return true;
}
std::ostream& cut::display(std::ostream& out) const {
out << "{";

View file

@ -71,6 +71,7 @@ namespace sat {
void negate() { set_table(~m_table); }
void set_table(uint64_t t) { m_table = t & table_mask(); }
uint64_t table() const { return (m_table | m_dont_care) & table_mask(); }
uint64_t ntable() const { return (~m_table | m_dont_care) & table_mask(); }
uint64_t dont_care() const { return m_dont_care; }
void add_dont_care(uint64_t t) const { m_dont_care |= t; }
@ -81,6 +82,8 @@ namespace sat {
bool operator==(cut const& other) const;
bool operator!=(cut const& other) const { return !(*this == other); }
unsigned hash() const;
unsigned dom_hash() const;
bool dom_eq(cut const& other) const;
struct eq_proc {
bool operator()(cut const& a, cut const& b) const { return a == b; }
bool operator()(cut const* a, cut const* b) const { return *a == *b; }
@ -90,6 +93,16 @@ namespace sat {
unsigned operator()(cut const* a) const { return a->hash(); }
};
struct dom_eq_proc {
bool operator()(cut const& a, cut const& b) const { return a.dom_eq(b); }
bool operator()(cut const* a, cut const* b) const { return a->dom_eq(*b); }
};
struct dom_hash_proc {
unsigned operator()(cut const& a) const { return a.dom_hash(); }
unsigned operator()(cut const* a) const { return a->dom_hash(); }
};
unsigned operator[](unsigned idx) const {
return (idx >= m_size) ? UINT_MAX : m_elems[idx];
}