mirror of
https://github.com/Z3Prover/z3
synced 2025-04-08 10:25:18 +00:00
first pass on extracting binary clauses, ensure that binary clauses used by simplifier are in scope of DRAT, add certification of units
Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
This commit is contained in:
parent
d77ac69015
commit
a12fca3105
|
@ -232,6 +232,7 @@ namespace sat {
|
|||
m_stats.m_num_cuts = m_aig_cuts.num_cuts();
|
||||
add_dont_cares(cuts);
|
||||
cuts2equiv(cuts);
|
||||
cuts2implies(cuts);
|
||||
}
|
||||
|
||||
void aig_simplifier::cuts2equiv(vector<cut_set> const& cuts) {
|
||||
|
@ -254,10 +255,10 @@ namespace sat {
|
|||
cut nc(c);
|
||||
nc.negate();
|
||||
if (m_config.m_enable_units && c.is_true()) {
|
||||
assign_unit(u);
|
||||
assign_unit(c, u);
|
||||
}
|
||||
else if (m_config.m_enable_units && c.is_false()) {
|
||||
assign_unit(~u);
|
||||
assign_unit(nc, ~u);
|
||||
}
|
||||
else if (cut2id.find(&c, j)) {
|
||||
literal v(j, false);
|
||||
|
@ -279,11 +280,12 @@ namespace sat {
|
|||
}
|
||||
}
|
||||
|
||||
void aig_simplifier::assign_unit(literal lit) {
|
||||
void aig_simplifier::assign_unit(cut const& c, literal lit) {
|
||||
if (s.value(lit) == l_undef) {
|
||||
// validate_unit(lit);
|
||||
IF_VERBOSE(2, verbose_stream() << "new unit " << lit << "\n");
|
||||
s.assign_unit(lit);
|
||||
certify_unit(lit, c);
|
||||
++m_stats.m_num_units;
|
||||
}
|
||||
}
|
||||
|
@ -329,6 +331,103 @@ namespace sat {
|
|||
}
|
||||
}
|
||||
|
||||
void aig_simplifier::cuts2implies(vector<cut_set> const& cuts) {
|
||||
if (!m_config.m_enable_implies) return;
|
||||
vector<vector<std::pair<unsigned, cut const*>>> var_tables;
|
||||
map<cut const*, unsigned, cut::dom_hash_proc, cut::dom_eq_proc> cut2tables;
|
||||
unsigned j = 0;
|
||||
big big(s.rand());
|
||||
big.init(s, true);
|
||||
for (auto const& cs : cuts) {
|
||||
for (auto const& c : cs) {
|
||||
if (c.is_false() || c.is_true())
|
||||
continue;
|
||||
if (!cut2tables.find(&c, j)) {
|
||||
j = var_tables.size();
|
||||
var_tables.push_back(vector<std::pair<unsigned, cut const*>>());
|
||||
cut2tables.insert(&c, j);
|
||||
}
|
||||
var_tables[j].push_back(std::make_pair(cs.var(), &c));
|
||||
}
|
||||
}
|
||||
for (unsigned i = 0; i < var_tables.size(); ++i) {
|
||||
auto const& vt = var_tables[i];
|
||||
for (unsigned j = 0; j < vt.size(); ++j) {
|
||||
literal u(vt[j].first, false);
|
||||
cut const& c1 = *vt[j].second;
|
||||
cut nc1(c1);
|
||||
uint64_t t1 = c1.table();
|
||||
uint64_t n1 = nc1.table();
|
||||
for (unsigned k = j + 1; k < vt.size(); ++k) {
|
||||
literal v(vt[k].first, false);
|
||||
cut const& c2 = *vt[k].second;
|
||||
uint64_t t2 = c2.table();
|
||||
uint64_t n2 = c2.ntable();
|
||||
//
|
||||
if (t1 == t2 || t1 == n2) {
|
||||
// already handled
|
||||
}
|
||||
else if ((t1 | t2) == t2) {
|
||||
learn_implies(big, c1, u, v);
|
||||
}
|
||||
else if ((t1 | n2) == n2) {
|
||||
learn_implies(big, c1, u, ~v);
|
||||
}
|
||||
else if ((n1 | t2) == t2) {
|
||||
learn_implies(big, nc1, ~u, v);
|
||||
}
|
||||
else if ((n1 | n2) == n2) {
|
||||
learn_implies(big, nc1, ~u, ~v);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void aig_simplifier::learn_implies(big& big, cut const& c, literal u, literal v) {
|
||||
bin_rel q, p(~u, v);
|
||||
if (m_bins.find(p, q) && q.op != none)
|
||||
return;
|
||||
if (big.connected(u, v))
|
||||
return;
|
||||
s.mk_clause(~u, v, true);
|
||||
m_bins.insert(p);
|
||||
certify_implies(u, v, c);
|
||||
track_binary(~u, v);
|
||||
}
|
||||
|
||||
void aig_simplifier::track_binary(bin_rel const& p) {
|
||||
if (s.m_config.m_drat) {
|
||||
literal u, v;
|
||||
p.to_binary(u, v);
|
||||
track_binary(u, v);
|
||||
}
|
||||
}
|
||||
|
||||
void aig_simplifier::untrack_binary(bin_rel const& p) {
|
||||
if (s.m_config.m_drat) {
|
||||
literal u, v;
|
||||
p.to_binary(u, v);
|
||||
untrack_binary(u, v);
|
||||
}
|
||||
}
|
||||
|
||||
void aig_simplifier::track_binary(literal u, literal v) {
|
||||
if (s.m_config.m_drat) {
|
||||
s.m_drat.add(u, v, true);
|
||||
}
|
||||
}
|
||||
|
||||
void aig_simplifier::untrack_binary(literal u, literal v) {
|
||||
if (s.m_config.m_drat) {
|
||||
s.m_drat.del(u, v);
|
||||
}
|
||||
}
|
||||
|
||||
void aig_simplifier::certify_unit(literal u, cut const& c) {
|
||||
certify_implies(~u, u, c);
|
||||
}
|
||||
|
||||
/**
|
||||
* Equilvalences modulo cuts are not necessarily DRAT derivable.
|
||||
* To ensure that there is a DRAT derivation we create all resolvents
|
||||
|
@ -337,36 +436,37 @@ namespace sat {
|
|||
* contain complementary literals.
|
||||
*/
|
||||
void aig_simplifier::certify_equivalence(literal u, literal v, cut const& c) {
|
||||
certify_implies(u, v, c);
|
||||
certify_implies(v, u, c);
|
||||
}
|
||||
|
||||
/**
|
||||
* certify that u implies v, where c is the cut for u.
|
||||
* Then every position in c where u is true, it has to be
|
||||
* the case that v is too.
|
||||
* Where u is false, v can have any value.
|
||||
* Thus, for every clause C or u', where u' is u or ~u,
|
||||
* it follows that C or ~u or v
|
||||
*/
|
||||
void aig_simplifier::certify_implies(literal u, literal v, cut const& c) {
|
||||
if (!s.m_config.m_drat) return;
|
||||
|
||||
vector<literal_vector> clauses;
|
||||
std::function<void(literal_vector const& clause)> on_clause =
|
||||
[&](literal_vector const& clause) { SASSERT(clause.back().var() == u.var()); clauses.push_back(clause); };
|
||||
[&,this](literal_vector const& clause) {
|
||||
SASSERT(clause.back().var() == u.var());
|
||||
clauses.push_back(clause);
|
||||
clauses.back().back() = ~u;
|
||||
if (~u != v) clauses.back().push_back(v);
|
||||
s.m_drat.add(clauses.back());
|
||||
};
|
||||
m_aig_cuts.cut2def(on_clause, c, u);
|
||||
|
||||
// create C or u or ~v for each clause C or u
|
||||
// create C or ~u or v for each clause C or ~u
|
||||
for (auto& clause : clauses) {
|
||||
literal w = clause.back();
|
||||
SASSERT(w.var() == u.var());
|
||||
clause.push_back(w == u ? ~v : v);
|
||||
s.m_drat.add(clause);
|
||||
}
|
||||
// create C or ~u or v for each clause
|
||||
unsigned i = 0, sz = clauses.size();
|
||||
for (; i < sz; ++i) {
|
||||
literal_vector clause(clauses[i]);
|
||||
clause[clause.size()-2] = ~clause[clause.size()-2];
|
||||
clause[clause.size()-1] = ~clause[clause.size()-1];
|
||||
clauses.push_back(clause);
|
||||
s.m_drat.add(clause);
|
||||
}
|
||||
|
||||
// create all resolvents over C. C is assumed to
|
||||
// contain all combinations of some set of literals.
|
||||
i = 0; sz = clauses.size();
|
||||
while (sz - i > 2) {
|
||||
SASSERT((sz & (sz - 1)) == 0);
|
||||
unsigned i = 0, sz = clauses.size();
|
||||
while (sz - i > 1) {
|
||||
SASSERT((sz & (sz - 1)) == 0 && "sz is a power of 2");
|
||||
for (; i < sz; ++i) {
|
||||
auto const& clause = clauses[i];
|
||||
if (clause[0].sign()) {
|
||||
|
@ -383,13 +483,12 @@ namespace sat {
|
|||
|
||||
// once we established equivalence, don't need auxiliary clauses for DRAT.
|
||||
for (auto const& clause : clauses) {
|
||||
if (clause.size() > 2) {
|
||||
if (clause.size() > 1) {
|
||||
s.m_drat.del(clause);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void aig_simplifier::add_dont_cares(vector<cut_set> const& cuts) {
|
||||
if (m_config.m_enable_dont_cares) {
|
||||
cuts2bins(cuts);
|
||||
|
@ -419,8 +518,12 @@ namespace sat {
|
|||
}
|
||||
// don't lose previous don't cares
|
||||
for (auto const& p : dcs) {
|
||||
if (m_bins.contains(p))
|
||||
if (m_bins.contains(p)) {
|
||||
m_bins.insert(p);
|
||||
}
|
||||
else {
|
||||
untrack_binary(p);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -446,6 +549,7 @@ namespace sat {
|
|||
else if (b.connected(~u, ~v)) {
|
||||
p.op = np;
|
||||
}
|
||||
track_binary(p);
|
||||
}
|
||||
IF_VERBOSE(2, {
|
||||
unsigned n = 0; for (auto const& p : m_bins) if (p.op != none) ++n;
|
||||
|
|
|
@ -36,41 +36,31 @@ namespace sat {
|
|||
bool m_validate;
|
||||
bool m_enable_units;
|
||||
bool m_enable_dont_cares;
|
||||
bool m_enable_implies;
|
||||
bool m_add_learned;
|
||||
config():
|
||||
m_validate(false),
|
||||
m_enable_units(false),
|
||||
m_enable_dont_cares(false),
|
||||
m_enable_implies(false),
|
||||
m_add_learned(true) {}
|
||||
};
|
||||
private:
|
||||
struct report;
|
||||
struct validator;
|
||||
|
||||
solver& s;
|
||||
stats m_stats;
|
||||
config m_config;
|
||||
aig_cuts m_aig_cuts;
|
||||
unsigned m_trail_size;
|
||||
literal_vector m_lits;
|
||||
validator* m_validator;
|
||||
|
||||
void clauses2aig();
|
||||
void aig2clauses();
|
||||
void cuts2equiv(vector<cut_set> const& cuts);
|
||||
void uf2equiv(union_find<> const& uf);
|
||||
void assign_unit(literal lit);
|
||||
void assign_equiv(cut const& c, literal u, literal v);
|
||||
void ensure_validator();
|
||||
void validate_unit(literal lit);
|
||||
void validate_eq(literal a, literal b);
|
||||
void certify_equivalence(literal u, literal v, cut const& c);
|
||||
|
||||
/**
|
||||
* collect pairs of literal combinations that are impossible
|
||||
* base on binary implication graph queries. Apply the masks
|
||||
* on cut sets so to allow detecting equivalences modulo
|
||||
* implications.
|
||||
*
|
||||
* The encoding is as follows:
|
||||
* a or b -> op = nn because (~a & ~b) is a don't care
|
||||
* ~a or b -> op = pn because (a & ~b) is a don't care
|
||||
* a or ~b -> op = np because (~a & b) is a don't care
|
||||
* ~a or ~b -> op = pp because (a & b) is a don't care
|
||||
*
|
||||
*/
|
||||
|
||||
enum op_code { pp, pn, np, nn, none };
|
||||
|
@ -81,6 +71,18 @@ namespace sat {
|
|||
bin_rel(unsigned _u, unsigned _v): u(_u), v(_v), op(none) {
|
||||
if (u > v) std::swap(u, v);
|
||||
}
|
||||
// convert binary clause into a bin-rel
|
||||
bin_rel(literal _u, literal _v): u(_u.var()), v(_v.var()), op(none) {
|
||||
if (_u.sign() && _v.sign()) op = pp;
|
||||
else if (_u.sign()) op = pn;
|
||||
else if (_v.sign()) op = np;
|
||||
else op = nn;
|
||||
if (u > v) {
|
||||
std::swap(u, v);
|
||||
if (op == np) op = pn;
|
||||
else if (op == pn) op = np;
|
||||
}
|
||||
}
|
||||
bin_rel(): u(UINT_MAX), v(UINT_MAX), op(none) {}
|
||||
|
||||
struct hash {
|
||||
|
@ -93,8 +95,46 @@ namespace sat {
|
|||
return a.u == b.u && a.v == b.v;
|
||||
}
|
||||
};
|
||||
void to_binary(literal& lu, literal& lv) const {
|
||||
switch (op) {
|
||||
case pp: lu = literal(u, true); lv = literal(v, true); break;
|
||||
case pn: lu = literal(u, true); lv = literal(v, false); break;
|
||||
case np: lu = literal(u, false); lv = literal(v, true); break;
|
||||
case nn: lu = literal(u, false); lv = literal(v, false); break;
|
||||
default: UNREACHABLE(); break;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
solver& s;
|
||||
stats m_stats;
|
||||
config m_config;
|
||||
aig_cuts m_aig_cuts;
|
||||
unsigned m_trail_size;
|
||||
literal_vector m_lits;
|
||||
validator* m_validator;
|
||||
hashtable<bin_rel, bin_rel::hash, bin_rel::eq> m_bins;
|
||||
|
||||
void clauses2aig();
|
||||
void aig2clauses();
|
||||
void cuts2equiv(vector<cut_set> const& cuts);
|
||||
void cuts2implies(vector<cut_set> const& cuts);
|
||||
void uf2equiv(union_find<> const& uf);
|
||||
void assign_unit(cut const& c, literal lit);
|
||||
void assign_equiv(cut const& c, literal u, literal v);
|
||||
void learn_implies(big& big, cut const& c, literal u, literal v);
|
||||
void ensure_validator();
|
||||
void validate_unit(literal lit);
|
||||
void validate_eq(literal a, literal b);
|
||||
void certify_unit(literal u, cut const& c);
|
||||
void certify_implies(literal u, literal v, cut const& c);
|
||||
void certify_equivalence(literal u, literal v, cut const& c);
|
||||
void track_binary(literal u, literal v);
|
||||
void untrack_binary(literal u, literal v);
|
||||
void track_binary(bin_rel const& p);
|
||||
void untrack_binary(bin_rel const& p);
|
||||
|
||||
|
||||
void add_dont_cares(vector<cut_set> const& cuts);
|
||||
void cuts2bins(vector<cut_set> const& cuts);
|
||||
|
|
|
@ -139,12 +139,7 @@ namespace sat {
|
|||
}
|
||||
|
||||
bool cut::operator==(cut const& other) const {
|
||||
if (m_size != other.m_size) return false;
|
||||
if (table() != other.table()) return false;
|
||||
for (unsigned i = 0; i < m_size; ++i) {
|
||||
if ((*this)[i] != other[i]) return false;
|
||||
}
|
||||
return true;
|
||||
return table() == other.table() && dom_eq(other);
|
||||
}
|
||||
|
||||
unsigned cut::hash() const {
|
||||
|
@ -152,6 +147,20 @@ namespace sat {
|
|||
[](cut const& c) { return (unsigned)c.table(); },
|
||||
[](cut const& c, unsigned i) { return c[i]; });
|
||||
}
|
||||
|
||||
unsigned cut::dom_hash() const {
|
||||
return get_composite_hash(*this, m_size,
|
||||
[](cut const& c) { return 3; },
|
||||
[](cut const& c, unsigned i) { return c[i]; });
|
||||
}
|
||||
|
||||
bool cut::dom_eq(cut const& other) const {
|
||||
if (m_size != other.m_size) return false;
|
||||
for (unsigned i = 0; i < m_size; ++i) {
|
||||
if ((*this)[i] != other[i]) return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
std::ostream& cut::display(std::ostream& out) const {
|
||||
out << "{";
|
||||
|
|
|
@ -71,6 +71,7 @@ namespace sat {
|
|||
void negate() { set_table(~m_table); }
|
||||
void set_table(uint64_t t) { m_table = t & table_mask(); }
|
||||
uint64_t table() const { return (m_table | m_dont_care) & table_mask(); }
|
||||
uint64_t ntable() const { return (~m_table | m_dont_care) & table_mask(); }
|
||||
|
||||
uint64_t dont_care() const { return m_dont_care; }
|
||||
void add_dont_care(uint64_t t) const { m_dont_care |= t; }
|
||||
|
@ -81,6 +82,8 @@ namespace sat {
|
|||
bool operator==(cut const& other) const;
|
||||
bool operator!=(cut const& other) const { return !(*this == other); }
|
||||
unsigned hash() const;
|
||||
unsigned dom_hash() const;
|
||||
bool dom_eq(cut const& other) const;
|
||||
struct eq_proc {
|
||||
bool operator()(cut const& a, cut const& b) const { return a == b; }
|
||||
bool operator()(cut const* a, cut const* b) const { return *a == *b; }
|
||||
|
@ -90,6 +93,16 @@ namespace sat {
|
|||
unsigned operator()(cut const* a) const { return a->hash(); }
|
||||
};
|
||||
|
||||
struct dom_eq_proc {
|
||||
bool operator()(cut const& a, cut const& b) const { return a.dom_eq(b); }
|
||||
bool operator()(cut const* a, cut const* b) const { return a->dom_eq(*b); }
|
||||
};
|
||||
|
||||
struct dom_hash_proc {
|
||||
unsigned operator()(cut const& a) const { return a.dom_hash(); }
|
||||
unsigned operator()(cut const* a) const { return a->dom_hash(); }
|
||||
};
|
||||
|
||||
unsigned operator[](unsigned idx) const {
|
||||
return (idx >= m_size) ? UINT_MAX : m_elems[idx];
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue