diff --git a/src/sat/sat_aig_simplifier.cpp b/src/sat/sat_aig_simplifier.cpp index 703b73941..28752c8d1 100644 --- a/src/sat/sat_aig_simplifier.cpp +++ b/src/sat/sat_aig_simplifier.cpp @@ -232,6 +232,7 @@ namespace sat { m_stats.m_num_cuts = m_aig_cuts.num_cuts(); add_dont_cares(cuts); cuts2equiv(cuts); + cuts2implies(cuts); } void aig_simplifier::cuts2equiv(vector const& cuts) { @@ -254,10 +255,10 @@ namespace sat { cut nc(c); nc.negate(); if (m_config.m_enable_units && c.is_true()) { - assign_unit(u); + assign_unit(c, u); } else if (m_config.m_enable_units && c.is_false()) { - assign_unit(~u); + assign_unit(nc, ~u); } else if (cut2id.find(&c, j)) { literal v(j, false); @@ -279,11 +280,12 @@ namespace sat { } } - void aig_simplifier::assign_unit(literal lit) { + void aig_simplifier::assign_unit(cut const& c, literal lit) { if (s.value(lit) == l_undef) { // validate_unit(lit); IF_VERBOSE(2, verbose_stream() << "new unit " << lit << "\n"); s.assign_unit(lit); + certify_unit(lit, c); ++m_stats.m_num_units; } } @@ -329,6 +331,103 @@ namespace sat { } } + void aig_simplifier::cuts2implies(vector const& cuts) { + if (!m_config.m_enable_implies) return; + vector>> var_tables; + map cut2tables; + unsigned j = 0; + big big(s.rand()); + big.init(s, true); + for (auto const& cs : cuts) { + for (auto const& c : cs) { + if (c.is_false() || c.is_true()) + continue; + if (!cut2tables.find(&c, j)) { + j = var_tables.size(); + var_tables.push_back(vector>()); + cut2tables.insert(&c, j); + } + var_tables[j].push_back(std::make_pair(cs.var(), &c)); + } + } + for (unsigned i = 0; i < var_tables.size(); ++i) { + auto const& vt = var_tables[i]; + for (unsigned j = 0; j < vt.size(); ++j) { + literal u(vt[j].first, false); + cut const& c1 = *vt[j].second; + cut nc1(c1); + uint64_t t1 = c1.table(); + uint64_t n1 = nc1.table(); + for (unsigned k = j + 1; k < vt.size(); ++k) { + literal v(vt[k].first, false); + cut const& c2 = *vt[k].second; + uint64_t t2 = c2.table(); + uint64_t n2 = c2.ntable(); + // + if (t1 == t2 || t1 == n2) { + // already handled + } + else if ((t1 | t2) == t2) { + learn_implies(big, c1, u, v); + } + else if ((t1 | n2) == n2) { + learn_implies(big, c1, u, ~v); + } + else if ((n1 | t2) == t2) { + learn_implies(big, nc1, ~u, v); + } + else if ((n1 | n2) == n2) { + learn_implies(big, nc1, ~u, ~v); + } + } + } + } + } + + void aig_simplifier::learn_implies(big& big, cut const& c, literal u, literal v) { + bin_rel q, p(~u, v); + if (m_bins.find(p, q) && q.op != none) + return; + if (big.connected(u, v)) + return; + s.mk_clause(~u, v, true); + m_bins.insert(p); + certify_implies(u, v, c); + track_binary(~u, v); + } + + void aig_simplifier::track_binary(bin_rel const& p) { + if (s.m_config.m_drat) { + literal u, v; + p.to_binary(u, v); + track_binary(u, v); + } + } + + void aig_simplifier::untrack_binary(bin_rel const& p) { + if (s.m_config.m_drat) { + literal u, v; + p.to_binary(u, v); + untrack_binary(u, v); + } + } + + void aig_simplifier::track_binary(literal u, literal v) { + if (s.m_config.m_drat) { + s.m_drat.add(u, v, true); + } + } + + void aig_simplifier::untrack_binary(literal u, literal v) { + if (s.m_config.m_drat) { + s.m_drat.del(u, v); + } + } + + void aig_simplifier::certify_unit(literal u, cut const& c) { + certify_implies(~u, u, c); + } + /** * Equilvalences modulo cuts are not necessarily DRAT derivable. * To ensure that there is a DRAT derivation we create all resolvents @@ -337,36 +436,37 @@ namespace sat { * contain complementary literals. */ void aig_simplifier::certify_equivalence(literal u, literal v, cut const& c) { + certify_implies(u, v, c); + certify_implies(v, u, c); + } + + /** + * certify that u implies v, where c is the cut for u. + * Then every position in c where u is true, it has to be + * the case that v is too. + * Where u is false, v can have any value. + * Thus, for every clause C or u', where u' is u or ~u, + * it follows that C or ~u or v + */ + void aig_simplifier::certify_implies(literal u, literal v, cut const& c) { if (!s.m_config.m_drat) return; vector clauses; std::function on_clause = - [&](literal_vector const& clause) { SASSERT(clause.back().var() == u.var()); clauses.push_back(clause); }; + [&,this](literal_vector const& clause) { + SASSERT(clause.back().var() == u.var()); + clauses.push_back(clause); + clauses.back().back() = ~u; + if (~u != v) clauses.back().push_back(v); + s.m_drat.add(clauses.back()); + }; m_aig_cuts.cut2def(on_clause, c, u); - // create C or u or ~v for each clause C or u - // create C or ~u or v for each clause C or ~u - for (auto& clause : clauses) { - literal w = clause.back(); - SASSERT(w.var() == u.var()); - clause.push_back(w == u ? ~v : v); - s.m_drat.add(clause); - } - // create C or ~u or v for each clause - unsigned i = 0, sz = clauses.size(); - for (; i < sz; ++i) { - literal_vector clause(clauses[i]); - clause[clause.size()-2] = ~clause[clause.size()-2]; - clause[clause.size()-1] = ~clause[clause.size()-1]; - clauses.push_back(clause); - s.m_drat.add(clause); - } - // create all resolvents over C. C is assumed to // contain all combinations of some set of literals. - i = 0; sz = clauses.size(); - while (sz - i > 2) { - SASSERT((sz & (sz - 1)) == 0); + unsigned i = 0, sz = clauses.size(); + while (sz - i > 1) { + SASSERT((sz & (sz - 1)) == 0 && "sz is a power of 2"); for (; i < sz; ++i) { auto const& clause = clauses[i]; if (clause[0].sign()) { @@ -383,13 +483,12 @@ namespace sat { // once we established equivalence, don't need auxiliary clauses for DRAT. for (auto const& clause : clauses) { - if (clause.size() > 2) { + if (clause.size() > 1) { s.m_drat.del(clause); } - } + } } - void aig_simplifier::add_dont_cares(vector const& cuts) { if (m_config.m_enable_dont_cares) { cuts2bins(cuts); @@ -419,8 +518,12 @@ namespace sat { } // don't lose previous don't cares for (auto const& p : dcs) { - if (m_bins.contains(p)) + if (m_bins.contains(p)) { m_bins.insert(p); + } + else { + untrack_binary(p); + } } } @@ -446,6 +549,7 @@ namespace sat { else if (b.connected(~u, ~v)) { p.op = np; } + track_binary(p); } IF_VERBOSE(2, { unsigned n = 0; for (auto const& p : m_bins) if (p.op != none) ++n; diff --git a/src/sat/sat_aig_simplifier.h b/src/sat/sat_aig_simplifier.h index b6dc58263..bf562ebab 100644 --- a/src/sat/sat_aig_simplifier.h +++ b/src/sat/sat_aig_simplifier.h @@ -36,41 +36,31 @@ namespace sat { bool m_validate; bool m_enable_units; bool m_enable_dont_cares; + bool m_enable_implies; bool m_add_learned; config(): m_validate(false), m_enable_units(false), m_enable_dont_cares(false), + m_enable_implies(false), m_add_learned(true) {} }; private: struct report; struct validator; - solver& s; - stats m_stats; - config m_config; - aig_cuts m_aig_cuts; - unsigned m_trail_size; - literal_vector m_lits; - validator* m_validator; - - void clauses2aig(); - void aig2clauses(); - void cuts2equiv(vector const& cuts); - void uf2equiv(union_find<> const& uf); - void assign_unit(literal lit); - void assign_equiv(cut const& c, literal u, literal v); - void ensure_validator(); - void validate_unit(literal lit); - void validate_eq(literal a, literal b); - void certify_equivalence(literal u, literal v, cut const& c); - /** * collect pairs of literal combinations that are impossible * base on binary implication graph queries. Apply the masks * on cut sets so to allow detecting equivalences modulo * implications. + * + * The encoding is as follows: + * a or b -> op = nn because (~a & ~b) is a don't care + * ~a or b -> op = pn because (a & ~b) is a don't care + * a or ~b -> op = np because (~a & b) is a don't care + * ~a or ~b -> op = pp because (a & b) is a don't care + * */ enum op_code { pp, pn, np, nn, none }; @@ -81,6 +71,18 @@ namespace sat { bin_rel(unsigned _u, unsigned _v): u(_u), v(_v), op(none) { if (u > v) std::swap(u, v); } + // convert binary clause into a bin-rel + bin_rel(literal _u, literal _v): u(_u.var()), v(_v.var()), op(none) { + if (_u.sign() && _v.sign()) op = pp; + else if (_u.sign()) op = pn; + else if (_v.sign()) op = np; + else op = nn; + if (u > v) { + std::swap(u, v); + if (op == np) op = pn; + else if (op == pn) op = np; + } + } bin_rel(): u(UINT_MAX), v(UINT_MAX), op(none) {} struct hash { @@ -93,8 +95,46 @@ namespace sat { return a.u == b.u && a.v == b.v; } }; + void to_binary(literal& lu, literal& lv) const { + switch (op) { + case pp: lu = literal(u, true); lv = literal(v, true); break; + case pn: lu = literal(u, true); lv = literal(v, false); break; + case np: lu = literal(u, false); lv = literal(v, true); break; + case nn: lu = literal(u, false); lv = literal(v, false); break; + default: UNREACHABLE(); break; + } + } }; + + + solver& s; + stats m_stats; + config m_config; + aig_cuts m_aig_cuts; + unsigned m_trail_size; + literal_vector m_lits; + validator* m_validator; hashtable m_bins; + + void clauses2aig(); + void aig2clauses(); + void cuts2equiv(vector const& cuts); + void cuts2implies(vector const& cuts); + void uf2equiv(union_find<> const& uf); + void assign_unit(cut const& c, literal lit); + void assign_equiv(cut const& c, literal u, literal v); + void learn_implies(big& big, cut const& c, literal u, literal v); + void ensure_validator(); + void validate_unit(literal lit); + void validate_eq(literal a, literal b); + void certify_unit(literal u, cut const& c); + void certify_implies(literal u, literal v, cut const& c); + void certify_equivalence(literal u, literal v, cut const& c); + void track_binary(literal u, literal v); + void untrack_binary(literal u, literal v); + void track_binary(bin_rel const& p); + void untrack_binary(bin_rel const& p); + void add_dont_cares(vector const& cuts); void cuts2bins(vector const& cuts); diff --git a/src/sat/sat_cutset.cpp b/src/sat/sat_cutset.cpp index 4909198d2..b3b0f3be8 100644 --- a/src/sat/sat_cutset.cpp +++ b/src/sat/sat_cutset.cpp @@ -139,12 +139,7 @@ namespace sat { } bool cut::operator==(cut const& other) const { - if (m_size != other.m_size) return false; - if (table() != other.table()) return false; - for (unsigned i = 0; i < m_size; ++i) { - if ((*this)[i] != other[i]) return false; - } - return true; + return table() == other.table() && dom_eq(other); } unsigned cut::hash() const { @@ -152,6 +147,20 @@ namespace sat { [](cut const& c) { return (unsigned)c.table(); }, [](cut const& c, unsigned i) { return c[i]; }); } + + unsigned cut::dom_hash() const { + return get_composite_hash(*this, m_size, + [](cut const& c) { return 3; }, + [](cut const& c, unsigned i) { return c[i]; }); + } + + bool cut::dom_eq(cut const& other) const { + if (m_size != other.m_size) return false; + for (unsigned i = 0; i < m_size; ++i) { + if ((*this)[i] != other[i]) return false; + } + return true; + } std::ostream& cut::display(std::ostream& out) const { out << "{"; diff --git a/src/sat/sat_cutset.h b/src/sat/sat_cutset.h index ec189e958..530e7bd87 100644 --- a/src/sat/sat_cutset.h +++ b/src/sat/sat_cutset.h @@ -71,6 +71,7 @@ namespace sat { void negate() { set_table(~m_table); } void set_table(uint64_t t) { m_table = t & table_mask(); } uint64_t table() const { return (m_table | m_dont_care) & table_mask(); } + uint64_t ntable() const { return (~m_table | m_dont_care) & table_mask(); } uint64_t dont_care() const { return m_dont_care; } void add_dont_care(uint64_t t) const { m_dont_care |= t; } @@ -81,6 +82,8 @@ namespace sat { bool operator==(cut const& other) const; bool operator!=(cut const& other) const { return !(*this == other); } unsigned hash() const; + unsigned dom_hash() const; + bool dom_eq(cut const& other) const; struct eq_proc { bool operator()(cut const& a, cut const& b) const { return a == b; } bool operator()(cut const* a, cut const* b) const { return *a == *b; } @@ -90,6 +93,16 @@ namespace sat { unsigned operator()(cut const* a) const { return a->hash(); } }; + struct dom_eq_proc { + bool operator()(cut const& a, cut const& b) const { return a.dom_eq(b); } + bool operator()(cut const* a, cut const* b) const { return a->dom_eq(*b); } + }; + + struct dom_hash_proc { + unsigned operator()(cut const& a) const { return a.dom_hash(); } + unsigned operator()(cut const* a) const { return a->dom_hash(); } + }; + unsigned operator[](unsigned idx) const { return (idx >= m_size) ? UINT_MAX : m_elems[idx]; }