diff --git a/src/sat/sat_aig_cuts.cpp b/src/sat/sat_aig_cuts.cpp index 15bb3f268..c1bf18ac7 100644 --- a/src/sat/sat_aig_cuts.cpp +++ b/src/sat/sat_aig_cuts.cpp @@ -43,11 +43,11 @@ namespace sat { if (m_aig[id].empty()) { continue; } - IF_VERBOSE(3, m_cuts[id].display(verbose_stream() << "augment " << id << "\nbefore\n")); + IF_VERBOSE(10, m_cuts[id].display(verbose_stream() << "augment " << id << "\nbefore\n")); for (node const& n : m_aig[id]) { augment(id, n); } - IF_VERBOSE(3, m_cuts[id].display(verbose_stream() << "after\n")); + IF_VERBOSE(10, m_cuts[id].display(verbose_stream() << "after\n")); } } @@ -82,7 +82,7 @@ namespace sat { } bool aig_cuts::insert_cut(unsigned v, cut const& c, cut_set& cs) { - if (!cs.insert(&m_on_cut_add, &m_on_cut_del, c)) { + if (!cs.insert(m_on_cut_add, m_on_cut_del, c)) { return true; } m_num_cuts++; @@ -98,7 +98,7 @@ namespace sat { } void aig_cuts::augment_ite(unsigned v, node const& n, cut_set& cs) { - IF_VERBOSE(2, display(verbose_stream() << "augment_ite " << v << " ", n) << "\n"); + IF_VERBOSE(4, display(verbose_stream() << "augment_ite " << v << " ", n) << "\n"); literal l1 = child(n, 0); literal l2 = child(n, 1); literal l3 = child(n, 2); @@ -172,7 +172,7 @@ namespace sat { void aig_cuts::augment_aigN(unsigned v, node const& n, cut_set& cs) { IF_VERBOSE(4, display(verbose_stream() << "augment_aigN " << v << " ", n) << "\n"); - m_cut_set1.reset(nullptr); + m_cut_set1.reset(m_on_cut_del); SASSERT(n.is_and() || n.is_xor()); literal lit = child(n, 0); for (auto const& a : m_cuts[lit.var()]) { @@ -180,10 +180,10 @@ namespace sat { if (lit.sign()) { b.negate(); } - m_cut_set1.push_back(nullptr, b); + m_cut_set1.push_back(m_on_cut_add, b); } for (unsigned i = 1; i < n.size(); ++i) { - m_cut_set2.reset(nullptr); + m_cut_set2.reset(m_on_cut_del); lit = child(n, i); m_insertions = 0; for (auto const& a : m_cut_set1) { @@ -212,6 +212,12 @@ namespace sat { } } + void aig_cuts::replace(unsigned v, cut const& src, cut const& dst) { + m_cuts[v].replace(m_on_cut_add, m_on_cut_del, src, dst); + touch(v); + } + + bool aig_cuts::is_touched(node const& n) { for (unsigned i = 0; i < n.size(); ++i) { literal lit = m_literals[n.offset() + i]; diff --git a/src/sat/sat_aig_cuts.h b/src/sat/sat_aig_cuts.h index 2fd1fcfc0..feefb7ec8 100644 --- a/src/sat/sat_aig_cuts.h +++ b/src/sat/sat_aig_cuts.h @@ -138,10 +138,10 @@ namespace sat { void on_node_add(unsigned v, node const& n); void on_node_del(unsigned v, node const& n); - void evict(cut_set& cs, unsigned idx) { cs.evict(&m_on_cut_del, idx); } - void reset(cut_set& cs) { cs.reset(&m_on_cut_del); } - void push_back(cut_set& cs, cut const& c) { cs.push_back(&m_on_cut_add, c); } - void shrink(cut_set& cs, unsigned j) { cs.shrink(&m_on_cut_del, j); } + void evict(cut_set& cs, unsigned idx) { cs.evict(m_on_cut_del, idx); } + void reset(cut_set& cs) { cs.reset(m_on_cut_del); } + void push_back(cut_set& cs, cut const& c) { cs.push_back(m_on_cut_add, c); } + void shrink(cut_set& cs, unsigned j) { cs.shrink(m_on_cut_del, j); } void cut2clauses(on_clause_t& on_clause, unsigned v, cut const& c); void node2def(on_clause_t& on_clause, node const& n, literal r); @@ -166,6 +166,8 @@ namespace sat { void cut2def(on_clause_t& on_clause, cut const& c, literal r); + void replace(unsigned v, cut const& src, cut const& dst); + std::ostream& display(std::ostream& out) const; diff --git a/src/sat/sat_aig_simplifier.cpp b/src/sat/sat_aig_simplifier.cpp index edac66071..a94e131dd 100644 --- a/src/sat/sat_aig_simplifier.cpp +++ b/src/sat/sat_aig_simplifier.cpp @@ -68,8 +68,9 @@ namespace sat { for (literal lit : clause) m_assumptions.push_back(~lit); lbool r = s.check(clause.size(), m_assumptions.c_ptr()); if (r != l_false) { - std::cout << "not validated: " << clause << "\n"; - s.display(std::cout); + IF_VERBOSE(0, + verbose_stream() << "not validated: " << clause << "\n"; + s.display(verbose_stream());); std::string line; std::getline(std::cin, line); } @@ -78,7 +79,6 @@ namespace sat { void aig_simplifier::ensure_validator() { if (!m_validator) { - std::cout << "init validator\n"; params_ref p; p.set_bool("aig", false); p.set_bool("drat.check_unsat", false); @@ -92,15 +92,9 @@ namespace sat { s(_s), m_trail_size(0), m_validator(nullptr) { - if (false) { - ensure_validator(); - std::function _on_add = - [this](literal_vector const& clause) { - std::cout << "add " << clause << "\n"; m_validator->validate(clause); - }; - m_aig_cuts.set_on_clause_add(_on_add); - } - else if (s.get_config().m_drat) { + m_config.m_enable_dont_cares = true; + m_config.m_enable_units = true; + if (s.get_config().m_drat) { std::function _on_add = [this](literal_vector const& clause) { s.m_drat.add(clause); }; std::function _on_del = @@ -108,6 +102,15 @@ namespace sat { m_aig_cuts.set_on_clause_add(_on_add); m_aig_cuts.set_on_clause_del(_on_del); } + else if (m_config.m_validate) { + ensure_validator(); + std::function _on_add = + [this](literal_vector const& clause) { + m_validator->validate(clause); + }; + m_aig_cuts.set_on_clause_add(_on_add); + } + } aig_simplifier::~aig_simplifier() { @@ -158,7 +161,7 @@ namespace sat { ++m_stats.m_num_calls; do { n = m_stats.m_num_eqs + m_stats.m_num_units; - if (m_config.m_full || true) clauses2aig(); + clauses2aig(); aig2clauses(); ++i; } @@ -172,7 +175,7 @@ namespace sat { void aig_simplifier::clauses2aig() { // update units - for (; m_config.m_full && m_trail_size < s.init_trail_size(); ++m_trail_size) { + for (; m_config.m_enable_units && m_trail_size < s.init_trail_size(); ++m_trail_size) { literal lit = s.trail_literal(m_trail_size); m_aig_cuts.add_node(lit, and_op, 0, 0); } @@ -192,7 +195,7 @@ namespace sat { af.set(on_and); af.set(on_ite); clause_vector clauses(s.clauses()); - if (m_config.m_full || true) clauses.append(s.learned()); + if (m_config.m_add_learned) clauses.append(s.learned()); af(clauses); std::function on_xor = @@ -229,6 +232,8 @@ namespace sat { vector const& cuts = m_aig_cuts(); m_stats.m_num_cuts = m_aig_cuts.num_cuts(); + add_dont_cares(cuts); + map cut2id; union_find_default_ctx ctx; @@ -242,20 +247,20 @@ namespace sat { for (unsigned i = cuts.size(); i-- > 0; ) { for (auto& c : cuts[i]) { unsigned j = 0; - if (m_config.m_full && c.is_true()) { + if (m_config.m_enable_units && c.is_true()) { if (s.value(i) == l_undef) { literal lit(i, false); - validate_unit(lit); + // validate_unit(lit); IF_VERBOSE(2, verbose_stream() << "new unit " << lit << "\n"); s.assign_unit(lit); ++m_stats.m_num_units; } break; } - if (m_config.m_full && c.is_false()) { + if (m_config.m_enable_units && c.is_false()) { if (s.value(i) == l_undef) { literal lit(i, true); - validate_unit(lit); + // validate_unit(lit); IF_VERBOSE(2, verbose_stream() << "new unit " << lit << "\n"); s.assign_unit(lit); ++m_stats.m_num_units; @@ -266,7 +271,7 @@ namespace sat { VERIFY(i != j); literal u(i, false); literal v(j, false); - IF_VERBOSE(0, + IF_VERBOSE(10, verbose_stream() << u << " " << c << "\n"; verbose_stream() << v << ": "; for (cut const& d : cuts[v.var()]) verbose_stream() << d << "\n";); @@ -278,20 +283,19 @@ namespace sat { new_eq = true; break; } - if (true || m_config.m_full) { - cut nc(c); - nc.negate(); - if (cut2id.find(&nc, j)) { - VERIFY(i != j); // maybe possible with don't cares - literal u(i, false); - literal v(j, true); - certify_equivalence(u, v, c); - // validate_eq(u, v); - add_eq(u, v); - TRACE("aig_simplifier", tout << u << " == " << v << "\n";); - new_eq = true; - break; - } + + cut nc(c); + nc.negate(); + if (cut2id.find(&nc, j)) { + if (i == j) continue; + literal u(i, false); + literal v(j, true); + certify_equivalence(u, v, c); + // validate_eq(u, v); + add_eq(u, v); + TRACE("aig_simplifier", tout << u << " == " << v << "\n";); + new_eq = true; + break; } cut2id.insert(&c, i); } @@ -389,72 +393,122 @@ namespace sat { } } + + void aig_simplifier::add_dont_cares(vector const& cuts) { + if (m_config.m_enable_dont_cares) { + cuts2pairs(cuts); + pairs2dont_cares(); + dont_cares2cuts(cuts); + } + } + /** * collect pairs of variables that occur in cut sets. */ - void aig_simplifier::collect_pairs(vector const& cuts) { + void aig_simplifier::cuts2pairs(vector const& cuts) { + svector dcs; + for (auto const& p : m_pairs) { + if (p.op != none) + dcs.push_back(p); + } m_pairs.reset(); - for (unsigned k = cuts.size(); k-- > 0; ) { - for (auto const& c : cuts[k]) { + for (auto const& cs : cuts) { + for (auto const& c : cs) { for (unsigned i = c.size(); i-- > 0; ) { for (unsigned j = i; j-- > 0; ) { - m_pairs.insert(var_pair(c[i],c[j])); + m_pairs.insert(var_pair(c[j],c[i])); } } } } + // don't lose previous don't cares + for (auto const& p : dcs) { + if (m_pairs.contains(p)) + m_pairs.insert(p); + } } /** * compute masks for pairs. */ - void aig_simplifier::add_masks_to_pairs() { + void aig_simplifier::pairs2dont_cares() { big b(s.rand()); b.init(s, true); for (auto& p : m_pairs) { + if (p.op != none) continue; literal u(p.u, false), v(p.v, false); // u -> v, then u & ~v is impossible if (b.connected(u, v)) { - add_mask(u, ~v, p); + p.op = pn; } else if (b.connected(u, ~v)) { - add_mask(u, v, p); + p.op = pp; } else if (b.connected(~u, v)) { - add_mask(~u, ~v, p); + p.op = nn; } else if (b.connected(~u, ~v)) { - add_mask(~u, v, p); - } - else { - memset(p.masks, 0xFF, var_pair::size()); + p.op = np; } } + IF_VERBOSE(2, { + unsigned n = 0; for (auto const& p : m_pairs) if (p.op != none) ++n; + verbose_stream() << n << " / " << m_pairs.size() << " don't cares\n"; + }); } - /* - * compute masks for each possible occurrence of u, v within 2-6 elements. - * combinaions relative to u.sign(), v.sign() are impossible. - */ - void aig_simplifier::add_mask(literal u, literal v, var_pair& p) { - unsigned offset = 0; - bool su = u.sign(), sv = v.sign(); - for (unsigned k = 2; k <= 6; ++k) { - for (unsigned i = 0; i < k; ++i) { - for (unsigned j = i + 1; j < k; ++j) { - // convert su, sv, k, i, j into a mask for 2^k bits. - // for outputs - p.masks[offset++] = 0; + void aig_simplifier::dont_cares2cuts(vector const& cuts) { + struct rep { + cut src, dst; unsigned v; + rep(cut const& s, cut const& d, unsigned v):src(s), dst(d), v(v) {} + rep():v(UINT_MAX) {} + }; + vector to_replace; + cut d; + for (auto const& cs : cuts) { + for (auto const& c : cs) { + if (rewrite_cut(c, d)) { + to_replace.push_back(rep(c, d, cs.var())); } } } + for (auto const& p : to_replace) { + m_aig_cuts.replace(p.v, p.src, p.dst); + } + m_stats.m_num_dont_care_reductions += to_replace.size(); + } + + /* + * compute masks for position i, j and op-code p.op + */ + uint64_t aig_simplifier::op2dont_care(unsigned i, unsigned j, var_pair const& p) { + SASSERT(i < j && j < 6); + if (p.op == none) return 0ull; + // first position of mask is offset into output bits contributed by i and j + bool i_is_0 = (p.op == np || p.op == nn); + bool j_is_0 = (p.op == pn || p.op == nn); + uint64_t first = (i_is_0 ? 0 : (1 << i)) + (j_is_0 ? 0 : (1 << j)); + uint64_t inc = 1ull << (j + 1); + uint64_t r = 1ull << first; + while (inc < 64ull) { r |= (r << inc); inc *= 2; } + return r; } /** - * apply obtained masks to cut sets. + * apply obtained dont_cares to cut sets. */ - void aig_simplifier::apply_masks() { - + bool aig_simplifier::rewrite_cut(cut const& c, cut& d) { + bool init = false; + for (unsigned i = 0; i < c.size(); ++i) { + for (unsigned j = i + 1; j < c.size(); ++j) { + var_pair p(c[i], c[j]); + if (m_pairs.find(p, p) && p.op != none) { + if (!init) { d = c; init = true; } + d.set_table(d.m_table | op2dont_care(i, j, p)); + } + } + } + return init && d.m_table != c.m_table; } void aig_simplifier::collect_statistics(statistics& st) const { @@ -463,6 +517,7 @@ namespace sat { st.update("sat-aig.ands", m_stats.m_num_ands); st.update("sat-aig.ites", m_stats.m_num_ites); st.update("sat-aig.xors", m_stats.m_num_xors); + st.update("sat-aig.dc-reduce", m_stats.m_num_dont_care_reductions); } void aig_simplifier::validate_unit(literal lit) { diff --git a/src/sat/sat_aig_simplifier.h b/src/sat/sat_aig_simplifier.h index 0ace074b9..ba5d40561 100644 --- a/src/sat/sat_aig_simplifier.h +++ b/src/sat/sat_aig_simplifier.h @@ -27,13 +27,20 @@ namespace sat { public: struct stats { unsigned m_num_eqs, m_num_units, m_num_cuts, m_num_xors, m_num_ands, m_num_ites; - unsigned m_num_calls; + unsigned m_num_calls, m_num_dont_care_reductions; stats() { reset(); } void reset() { memset(this, 0, sizeof(*this)); } }; struct config { - bool m_full; - config():m_full(false) {} + bool m_validate; + bool m_enable_units; + bool m_enable_dont_cares; + bool m_add_learned; + config(): + m_validate(false), + m_enable_units(false), + m_enable_dont_cares(false), + m_add_learned(true) {} }; private: struct report; @@ -60,14 +67,16 @@ namespace sat { * Apply the masks on cut sets so to allow detecting * equivalences modulo implications. */ + + enum op_code { pp, pn, np, nn, none }; + struct var_pair { unsigned u, v; - uint64_t masks[35]; - static unsigned size() { return sizeof(uint64_t)*35; } - var_pair(unsigned u, unsigned v): u(u), v(v) { + op_code op; + var_pair(unsigned _u, unsigned _v): u(_u), v(_v), op(none) { if (u > v) std::swap(u, v); } - var_pair(): u(UINT_MAX), v(UINT_MAX) {} + var_pair(): u(UINT_MAX), v(UINT_MAX), op(none) {} struct hash { unsigned operator()(var_pair const& p) const { @@ -82,10 +91,13 @@ namespace sat { }; hashtable m_pairs; - void collect_pairs(vector const& cuts); - void add_mask(literal u, literal v, var_pair& p); - void add_masks_to_pairs(); - void apply_masks(); + void add_dont_cares(vector const& cuts); + void cuts2pairs(vector const& cuts); + void pairs2dont_cares(); + void dont_cares2cuts(vector const& cuts); + bool rewrite_cut(cut const& c, cut& r); + uint64_t op2dont_care(unsigned i, unsigned j, var_pair const& p); + public: aig_simplifier(solver& s); ~aig_simplifier(); diff --git a/src/sat/sat_cutset.cpp b/src/sat/sat_cutset.cpp index ca1f4690b..3b56b02ab 100644 --- a/src/sat/sat_cutset.cpp +++ b/src/sat/sat_cutset.cpp @@ -31,7 +31,7 @@ namespace sat { - pre-allocate fixed array instead of vector for cut_set to avoid overhead for memory allocation. */ - bool cut_set::insert(on_update_t* on_add, on_update_t* on_del, cut const& c) { + bool cut_set::insert(on_update_t& on_add, on_update_t& on_del, cut const& c) { unsigned i = 0, j = 0, k = m_size; for (; i < k; ++i) { cut const& a = (*this)[i]; @@ -42,8 +42,11 @@ namespace sat { std::swap(m_cuts[i--], m_cuts[--k]); } } - shrink(on_del, i); + // for DRAT make sure to add new element before removing old cuts + // the new cut may need to be justified relative to the old cut push_back(on_add, c); + std::swap(m_cuts[i++], m_cuts[m_size-1]); + shrink(on_del, i); return true; } @@ -64,16 +67,16 @@ namespace sat { } - void cut_set::shrink(on_update_t* on_del, unsigned j) { - if (m_var != UINT_MAX && on_del && *on_del) { + void cut_set::shrink(on_update_t& on_del, unsigned j) { + if (m_var != UINT_MAX && on_del) { for (unsigned i = j; i < m_size; ++i) { - (*on_del)(m_var, m_cuts[i]); + on_del(m_var, m_cuts[i]); } } m_size = j; } - void cut_set::push_back(on_update_t* on_add, cut const& c) { + void cut_set::push_back(on_update_t& on_add, cut const& c) { SASSERT(m_max_size > 0); if (m_size == m_max_size) { m_max_size *= 2; @@ -81,10 +84,26 @@ namespace sat { memcpy(new_cuts, m_cuts, sizeof(cut)*m_size); m_cuts = new_cuts; } - if (m_var != UINT_MAX && on_add && *on_add) (*on_add)(m_var, c); + if (m_var != UINT_MAX && on_add) on_add(m_var, c); m_cuts[m_size++] = c; } + void cut_set::replace(on_update_t& on_add, on_update_t& on_del, cut const& src, cut const& dst) { + SASSERT(src != dst); + insert(on_add, on_del, dst); + for (unsigned i = 0; i < size(); ++i) { + if (src == (*this)[i]) { + evict(on_del, i); + break; + } + } + } + + void cut_set::evict(on_update_t& on_del, unsigned idx) { + if (m_var != UINT_MAX && on_del) on_del(m_var, m_cuts[idx]); + m_cuts[idx] = m_cuts[--m_size]; + } + void cut_set::init(region& r, unsigned max_sz, unsigned v) { m_var = v; m_max_size = max_sz; diff --git a/src/sat/sat_cutset.h b/src/sat/sat_cutset.h index ca496ed39..dc6f0c4b7 100644 --- a/src/sat/sat_cutset.h +++ b/src/sat/sat_cutset.h @@ -137,18 +137,27 @@ namespace sat { cut_set(): m_var(UINT_MAX), m_region(nullptr), m_size(0), m_max_size(0), m_cuts(nullptr) {} void init(region& r, unsigned max_sz, unsigned v); - bool insert(on_update_t* on_add, on_update_t* on_del, cut const& c); + bool insert(on_update_t& on_add, on_update_t& on_del, cut const& c); bool no_duplicates() const; + unsigned var() const { return m_var; } unsigned size() const { return m_size; } cut const * begin() const { return m_cuts; } cut const * end() const { return m_cuts + m_size; } cut const & back() { return m_cuts[m_size-1]; } - void push_back(on_update_t* on_add, cut const& c); - void reset(on_update_t* on_del) { shrink(on_del, 0); } + void push_back(on_update_t& on_add, cut const& c); + void reset(on_update_t& on_del) { shrink(on_del, 0); } cut const & operator[](unsigned idx) { return m_cuts[idx]; } - void shrink(on_update_t* on_del, unsigned j); - void swap(cut_set& other) { std::swap(m_size, other.m_size); std::swap(m_cuts, other.m_cuts); std::swap(m_max_size, other.m_max_size); } - void evict(on_update_t* on_del, unsigned idx) { if (m_var != UINT_MAX && on_del && *on_del) (*on_del)(m_var, m_cuts[idx]); m_cuts[idx] = m_cuts[--m_size]; } + void shrink(on_update_t& on_del, unsigned j); + void swap(cut_set& other) { + std::swap(m_var, other.m_var); + std::swap(m_size, other.m_size); + std::swap(m_max_size, other.m_max_size); + std::swap(m_cuts, other.m_cuts); + } + void evict(on_update_t& on_del, unsigned idx); + + void replace(on_update_t& on_add, on_update_t& on_del, cut const& src, cut const& dst); + std::ostream& display(std::ostream& out) const; }; diff --git a/src/util/symbol.cpp b/src/util/symbol.cpp index febe79dbb..042d6292d 100644 --- a/src/util/symbol.cpp +++ b/src/util/symbol.cpp @@ -122,7 +122,7 @@ symbol::symbol(char const * d) { } symbol & symbol::operator=(char const * d) { - m_data = g_symbol_tables->get_str(d); + m_data = d ? g_symbol_tables->get_str(d) : nullptr; return *this; }