diff --git a/src/sat/CMakeLists.txt b/src/sat/CMakeLists.txt index 649bbb22e..11e290a26 100644 --- a/src/sat/CMakeLists.txt +++ b/src/sat/CMakeLists.txt @@ -23,6 +23,7 @@ z3_add_component(sat sat_integrity_checker.cpp sat_local_search.cpp sat_lookahead.cpp + sat_lut_finder.cpp sat_model_converter.cpp sat_mus.cpp sat_parallel.cpp diff --git a/src/sat/sat_aig_simplifier.cpp b/src/sat/sat_aig_simplifier.cpp index 20d3735b4..6f1613c38 100644 --- a/src/sat/sat_aig_simplifier.cpp +++ b/src/sat/sat_aig_simplifier.cpp @@ -18,6 +18,7 @@ #include "sat/sat_aig_simplifier.h" #include "sat/sat_xor_finder.h" +#include "sat/sat_lut_finder.h" #include "sat/sat_elim_eqs.h" namespace sat { @@ -225,7 +226,23 @@ namespace sat { }; xor_finder xf(s); xf.set(on_xor); - xf(clauses); + xf(clauses); + +#if 0 + std::function on_lut = + [&,this](uint64_t l, bool_var_vector const& vars, bool_var v) { + m_stats.m_num_luts++; + }; + lut_finder lf(s); + lf.set(on_lut); + lf(clauses); + + + statistics st; + collect_statistics(st); + st.display(std::cout); + exit(0); +#endif } void aig_simplifier::aig2clauses() { @@ -675,6 +692,7 @@ namespace sat { st.update("sat-aig.ands", m_stats.m_num_ands); st.update("sat-aig.ites", m_stats.m_num_ites); st.update("sat-aig.xors", m_stats.m_num_xors); + st.update("sat-aig.luts", m_stats.m_num_luts); st.update("sat-aig.dc-reduce", m_stats.m_num_dont_care_reductions); } diff --git a/src/sat/sat_aig_simplifier.h b/src/sat/sat_aig_simplifier.h index 7f523fd2f..9189b08fd 100644 --- a/src/sat/sat_aig_simplifier.h +++ b/src/sat/sat_aig_simplifier.h @@ -27,7 +27,7 @@ namespace sat { class aig_simplifier { public: struct stats { - unsigned m_num_eqs, m_num_units, m_num_cuts, m_num_xors, m_num_ands, m_num_ites; + unsigned m_num_eqs, m_num_units, m_num_cuts, m_num_xors, m_num_ands, m_num_ites, m_num_luts; unsigned m_num_calls, m_num_dont_care_reductions, m_num_learned_implies; stats() { reset(); } void reset() { memset(this, 0, sizeof(*this)); } diff --git a/src/sat/sat_lut_finder.cpp b/src/sat/sat_lut_finder.cpp new file mode 100644 index 000000000..d14d2d8fe --- /dev/null +++ b/src/sat/sat_lut_finder.cpp @@ -0,0 +1,295 @@ +/*++ + Copyright (c) 2020 Microsoft Corporation + + Module Name: + + sat_lut_finder.cpp + + Abstract: + + lut finder + + Author: + + Nikolaj Bjorner 2020-01-02 + + Notes: + + + --*/ + +#include "sat/sat_lut_finder.h" +#include "sat/sat_solver.h" + +namespace sat { + + void lut_finder::operator()(clause_vector& clauses) { + m_removed_clauses.reset(); + unsigned max_size = m_max_lut_size; + // we better have enough bits in the combination mask to + // handle clauses up to max_size. + // max_size = 5 -> 32 bits + // max_size = 6 -> 64 bits + SASSERT(sizeof(m_combination)*8 >= (1ull << static_cast(max_size))); + init_clause_filter(); + for (unsigned i = 0; i <= 6; ++i) init_mask(i); + m_var_position.resize(s.num_vars()); + for (clause* cp : clauses) { + cp->unmark_used(); + } + for (; max_size > 2; --max_size) { + for (clause* cp : clauses) { + clause& c = *cp; + if (c.size() == max_size && !c.was_removed() && !c.is_learned() && !c.was_used()) { + check_lut(c); + } + } + } + m_clause_filters.clear(); + + for (clause* cp : clauses) cp->unmark_used(); + for (clause* cp : m_removed_clauses) cp->mark_used(); + std::function not_used = [](clause* cp) { return !cp->was_used(); }; + clauses.filter_update(not_used); + } + + void lut_finder::check_lut(clause& c) { + SASSERT(c.size() > 2); + unsigned filter = get_clause_filter(c); + s.init_visited(); + unsigned mask = 0, i = 0; + m_vars.reset(); + for (literal l : c) { + m_vars.push_back(l.var()); + m_var_position[l.var()] = i; + s.mark_visited(l.var()); + mask |= (l.sign() << (i++)); + } + m_clauses_to_remove.reset(); + m_clauses_to_remove.push_back(&c); + m_clause.resize(c.size()); + m_combination = 0; + m_num_combinations = 0; + set_combination(mask); + c.mark_used(); + for (literal l : c) { + for (auto const& cf : m_clause_filters[l.var()]) { + if ((filter == (filter | cf.m_filter)) && + !cf.m_clause->was_used() && + extract_lut(*cf.m_clause)) { + add_lut(); + return; + } + } + // TBD: replace by BIG + // loop over binary clauses in watch list + for (watched const & w : s.get_wlist(l)) { + if (w.is_binary_clause() && s.is_visited(w.get_literal().var()) && w.get_literal().index() < l.index()) { + if (extract_lut(~l, w.get_literal())) { + add_lut(); + return; + } + } + } + l.neg(); + for (watched const & w : s.get_wlist(l)) { + if (w.is_binary_clause() && s.is_visited(w.get_literal().var()) && w.get_literal().index() < l.index()) { + if (extract_lut(~l, w.get_literal())) { + add_lut(); + return; + } + } + } + } + } + + void lut_finder::add_lut() { + DEBUG_CODE(for (clause* cp : m_clauses_to_remove) VERIFY(cp->was_used());); + m_removed_clauses.append(m_clauses_to_remove); + bool_var v; + uint64_t lut = convert_combination(m_vars, v); + m_on_lut(lut, m_vars, v); + } + + bool lut_finder::extract_lut(literal l1, literal l2) { + SASSERT(s.is_visited(l1.var())); + SASSERT(s.is_visited(l2.var())); + m_missing.reset(); + unsigned mask = 0; + for (unsigned i = 0; i < m_vars.size(); ++i) { + if (m_vars[i] == l1.var()) { + mask |= (l1.sign() << i); + } + else if (m_vars[i] == l2.var()) { + mask |= (l2.sign() << i); + } + else { + m_missing.push_back(i); + } + } + return update_combinations(mask); + } + + bool lut_finder::extract_lut(clause& c2) { + for (literal l : c2) { + if (!s.is_visited(l.var())) return false; + } + if (c2.size() == m_vars.size()) { + m_clauses_to_remove.push_back(&c2); + c2.mark_used(); + } + // insert missing + unsigned mask = 0; + m_missing.reset(); + SASSERT(c2.size() <= m_vars.size()); + for (unsigned i = 0; i < m_vars.size(); ++i) { + m_clause[i] = null_literal; + } + for (literal l : c2) { + unsigned pos = m_var_position[l.var()]; + m_clause[pos] = l; + } + for (unsigned j = 0; j < m_vars.size(); ++j) { + literal lit = m_clause[j]; + if (lit == null_literal) { + m_missing.push_back(j); + } + else { + mask |= (m_clause[j].sign() << j); + } + } + return update_combinations(mask); + } + + bool lut_finder::update_combinations(unsigned mask) { + unsigned num_missing = m_missing.size(); + for (unsigned k = 0; k < (1ul << num_missing); ++k) { + unsigned mask2 = mask; + for (unsigned i = 0; i < num_missing; ++i) { + if ((k & (1 << i)) != 0) { + mask2 |= 1ul << m_missing[i]; + } + } + set_combination(mask2); + } + return lut_is_defined(m_vars.size()); + } + + bool lut_finder::lut_is_defined(unsigned sz) { + if (m_num_combinations < (1ull << (sz/2))) + return false; + for (unsigned i = 0; i < sz; ++i) { + if (lut_is_defined(i, sz)) + return true; + } + return false; + } + + /** + * \brief create the masks + * i = 0: 101010101010101 + * i = 1: 1100110011001100 + * i = 2: 1111000011110000 + * i = 3: 111111110000000011111111 + */ + + void lut_finder::init_mask(unsigned i) { + SASSERT(i <= 6); + uint64_t m = 0; + if (i == 6) { + m = ~((uint64_t)0); + } + else { + m = (1ull << (1u << i)) - 1; // i = 0: m = 1 + unsigned w = 1u << (i + 1); // i = 0: w = 2 + while (w < 64) { + m |= (m << w); // i = 0: m = 1 + 4 + w *= 2; + } + } + m_masks[i] = m; + } + + /** + * \brief check if all output combinations for variable i are defined. + */ + bool lut_finder::lut_is_defined(unsigned i, unsigned sz) { + uint64_t c = m_combination | (m_combination >> (1ull << (uint64_t)i)); + uint64_t m = m_masks[i]; + if (sz < 6) m &= ((1ull << sz) - 1); + return (c & m) == m; + } + + /** + * find variable where it is defined + * convert bit-mask to truth table for that variable. + * remove variable from vars, + * return truth table. + */ + + uint64_t lut_finder::convert_combination(bool_var_vector& vars, bool_var& v) { + SASSERT(lut_is_defined(vars.size())); + unsigned i = 0, j = 0; + for (; i < vars.size(); ++i) { + if (lut_is_defined(i, vars.size())) { + break; + } + } + SASSERT(i < vars.size()); + v = vars[i]; + vars.erase(v); + uint64_t r = 0; + unsigned stride_sz = (1u << i); + unsigned num_strides = (1u << vars.size()) / (stride_sz * 2); + + switch (i) { + case 0: + for (unsigned j = 0; j < (1u << vars.size()); ++j) { + if (0 == (m_combination & (1ull << 2*j))) { + r |= (1ull << j); + } + } + break; + case 1: + // (0, 2) (1, 3), (4, 6), (5, 7) + for (unsigned j = 0; j < (1u << vars.size()); ++j) { + + } + // TBD + break; + default: + // TBD + break; + } + return r; + } + + void lut_finder::init_clause_filter() { + m_clause_filters.reset(); + m_clause_filters.resize(s.num_vars()); + init_clause_filter(s.m_clauses); + init_clause_filter(s.m_learned); + } + + void lut_finder::init_clause_filter(clause_vector& clauses) { + for (clause* cp : clauses) { + clause& c = *cp; + if (c.size() <= m_max_lut_size && s.all_distinct(c)) { + clause_filter cf(get_clause_filter(c), cp); + for (literal l : c) { + m_clause_filters[l.var()].push_back(cf); + } + } + } + } + + unsigned lut_finder::get_clause_filter(clause const& c) { + unsigned filter = 0; + for (literal l : c) { + filter |= 1 << ((l.var() % 32)); + } + return filter; + } + + +} diff --git a/src/sat/sat_lut_finder.h b/src/sat/sat_lut_finder.h new file mode 100644 index 000000000..8af848136 --- /dev/null +++ b/src/sat/sat_lut_finder.h @@ -0,0 +1,83 @@ +/*++ + Copyright (c) 2020 Microsoft Corporation + + Module Name: + + sat_lut_finder.h + + Abstract: + + lut finder + + Author: + + Nikolaj Bjorner 2020-02-03 + + Notes: + + Find LUT with small input fan-ins + + --*/ + +#pragma once + +#include "util/params.h" +#include "util/statistics.h" +#include "sat/sat_clause.h" +#include "sat/sat_types.h" +#include "sat/sat_solver.h" + +namespace sat { + + class lut_finder { + solver& s; + struct clause_filter { + unsigned m_filter; + clause* m_clause; + clause_filter(unsigned f, clause* cp): + m_filter(f), m_clause(cp) {} + }; + typedef svector bool_vector; + unsigned m_max_lut_size; + vector> m_clause_filters; // index of clauses. + uint64_t m_combination; // bit-mask of parities that have been found + unsigned m_num_combinations; + clause_vector m_clauses_to_remove; // remove clauses that become luts + unsigned_vector m_var_position; // position of var in main clause + bool_var_vector m_vars; // reference to variables being tested for LUT + literal_vector m_clause; // reference clause with literals sorted according to main clause + unsigned_vector m_missing; // set of indices not occurring in clause. + uint64_t m_masks[7]; + clause_vector m_removed_clauses; + std::function const& vars, bool_var v)> m_on_lut; + + inline void set_combination(unsigned mask) { + if (!get_combination(mask)) { + m_combination |= (1ull << mask); + m_num_combinations++; + } + } + inline bool get_combination(unsigned mask) const { return (m_combination & (1ull << mask)) != 0; } + bool lut_is_defined(unsigned sz); + bool lut_is_defined(unsigned i, unsigned sz); + uint64_t convert_combination(bool_var_vector& vars, bool_var& v); + void check_lut(clause& c); + void add_lut(); + bool extract_lut(literal l1, literal l2); + bool extract_lut(clause& c2); + bool update_combinations(unsigned mask); + void init_mask(unsigned i); + void init_clause_filter(); + void init_clause_filter(clause_vector& clauses); + unsigned get_clause_filter(clause const& c); + + public: + lut_finder(solver& s) : s(s), m_max_lut_size(5) { } + ~lut_finder() {} + + void set(std::function& f) { m_on_lut = f; } + + unsigned max_lut_size() const { return m_max_lut_size; } + void operator()(clause_vector& clauses); + }; +} diff --git a/src/sat/sat_solver.h b/src/sat/sat_solver.h index 8aff22f2d..9996e886c 100644 --- a/src/sat/sat_solver.h +++ b/src/sat/sat_solver.h @@ -219,6 +219,7 @@ namespace sat { friend class scoped_detach; friend class xor_finder; friend class aig_finder; + friend class lut_finder; public: solver(params_ref const & p, reslimit& l); ~solver() override;