diff --git a/src/sat/sat_aig_cuts.cpp b/src/sat/sat_aig_cuts.cpp index 41ab68307..5068ef330 100644 --- a/src/sat/sat_aig_cuts.cpp +++ b/src/sat/sat_aig_cuts.cpp @@ -18,6 +18,7 @@ #include "util/trace.h" #include "sat/sat_aig_cuts.h" #include "sat/sat_solver.h" +#include "sat/sat_lut_finder.h" namespace sat { @@ -762,6 +763,34 @@ namespace sat { cut2def(on_clause, c, literal(v, true)); } + /** + * simplify a set of cuts by removing don't cares. + */ + void aig_cuts::simplify() { + uint64_t masks[7]; + for (unsigned i = 0; i <= 6; ++i) { + masks[i] = cut::effect_mask(i); + } + unsigned dont_cares = 0; + for (cut_set & cs : m_cuts) { + for (cut const& c : cs) { + uint64_t t = c.table(); + for (unsigned i = 0; i < std::min(6u, c.size()); ++i) { + uint64_t diff = masks[i] & (t ^ (t >> (1ull << i))); + if (diff == 0ull) { + cut d(c); + d.remove_elem(i); + cs.insert(m_on_cut_add, m_on_cut_del, d); + cs.evict(m_on_cut_del, c); + ++dont_cares; + break; + } + } + } + } + IF_VERBOSE(0, verbose_stream() << "#don't cares " << dont_cares << "\n"); + } + struct aig_cuts::validator { aig_cuts& t; params_ref p; @@ -821,7 +850,7 @@ namespace sat { cut2def(on_clause, c, literal(v, false)); node2def(on_clause, n, literal(v, true)); val.check(); - } + } std::ostream& aig_cuts::display(std::ostream& out) const { auto ids = filter_valid_nodes(); diff --git a/src/sat/sat_aig_cuts.h b/src/sat/sat_aig_cuts.h index 2cd68501d..9ae00f7d0 100644 --- a/src/sat/sat_aig_cuts.h +++ b/src/sat/sat_aig_cuts.h @@ -227,6 +227,8 @@ namespace sat { cut_eval simulate(unsigned num_rounds); + void simplify(); + std::ostream& display(std::ostream& out) const; }; diff --git a/src/sat/sat_config.cpp b/src/sat/sat_config.cpp index f0d167c89..b9b0eb016 100644 --- a/src/sat/sat_config.cpp +++ b/src/sat/sat_config.cpp @@ -107,6 +107,7 @@ namespace sat { m_cut_delay = p.cut_delay(); m_cut_lut = p.cut_lut(); m_cut_xor = p.cut_xor(); + m_cut_dont_cares = p.cut_dont_cares(); m_lookahead_simplify = p.lookahead_simplify(); m_lookahead_double = p.lookahead_double(); m_lookahead_simplify_bca = p.lookahead_simplify_bca(); diff --git a/src/sat/sat_config.h b/src/sat/sat_config.h index 889c9f2b6..7390dcf58 100644 --- a/src/sat/sat_config.h +++ b/src/sat/sat_config.h @@ -124,6 +124,7 @@ namespace sat { unsigned m_cut_delay; bool m_cut_lut; bool m_cut_xor; + bool m_cut_dont_cares; bool m_anf_simplify; unsigned m_anf_delay; bool m_anf_exlin; diff --git a/src/sat/sat_cut_simplifier.cpp b/src/sat/sat_cut_simplifier.cpp index 673a69cdc..262b11500 100644 --- a/src/sat/sat_cut_simplifier.cpp +++ b/src/sat/sat_cut_simplifier.cpp @@ -192,12 +192,13 @@ namespace sat { m_aig_cuts.add_node(head, ite_op, 3, args); m_stats.m_xites++; }; + aig_finder af(s); af.set(on_and); af.set(on_ite); clause_vector clauses(s.clauses()); if (m_config.m_learned2aig) clauses.append(s.learned()); - af(clauses); + af(clauses); std::function on_xor = [&,this](literal_vector const& xors) { @@ -231,13 +232,15 @@ namespace sat { xf.set(on_xor); xf(clauses); } + + std::function on_lut = + [&,this](uint64_t lut, bool_var_vector const& vars, bool_var v) { + m_stats.m_xluts++; + // m_aig_cuts.add_cut(v, lut, vars); + m_aig_cuts.add_node(v, lut, vars.size(), vars.c_ptr()); + }; + if (s.m_config.m_cut_lut) { - std::function on_lut = - [&,this](uint64_t lut, bool_var_vector const& vars, bool_var v) { - m_stats.m_xluts++; - // m_aig_cuts.add_cut(v, lut, vars); - m_aig_cuts.add_node(v, lut, vars.size(), vars.c_ptr()); - }; lut_finder lf(s); lf.set(on_lut); lf(clauses); @@ -567,11 +570,12 @@ namespace sat { } void cut_simplifier::add_dont_cares(vector const& cuts) { - if (!m_config.m_enable_dont_cares) + if (!s.m_config.m_cut_dont_cares) return; cuts2bins(cuts); bins2dont_cares(); dont_cares2cuts(cuts); + m_aig_cuts.simplify(); } /** diff --git a/src/sat/sat_cutset.cpp b/src/sat/sat_cutset.cpp index dc0ca51bd..9ff37012c 100644 --- a/src/sat/sat_cutset.cpp +++ b/src/sat/sat_cutset.cpp @@ -89,6 +89,15 @@ namespace sat { m_cuts[m_size++] = c; } + void cut_set::evict(on_update_t& on_del, cut const& c) { + for (unsigned i = 0; i < m_size; ++i) { + if (m_cuts[i] == c) { + evict(on_del, i); + break; + } + } + } + void cut_set::evict(on_update_t& on_del, unsigned idx) { if (m_var != UINT_MAX && on_del) on_del(m_var, m_cuts[idx]); m_cuts[idx] = m_cuts[--m_size]; @@ -166,6 +175,56 @@ namespace sat { return true; } + /** + * \brief create the masks + * i = 0: 101010101010101 + * i = 1: 1100110011001100 + * i = 2: 1111000011110000 + * i = 3: 111111110000000011111111 + */ + + uint64_t cut::effect_mask(unsigned i) { + SASSERT(i <= 6); + uint64_t m = 0; + if (i == 6) { + m = ~((uint64_t)0); + } + else { + m = (1ull << (1u << i)) - 1; // i = 0: m = 1 + unsigned w = 1u << (i + 1); // i = 0: w = 2 + while (w < 64) { + m |= (m << w); // i = 0: m = 1 + 4 + w *= 2; + } + } + return m; + } + + /** + remove element from cut as it is deemed a don't care + */ + void cut::remove_elem(unsigned i) { + for (unsigned j = i + 1; j < m_size; ++j) { + m_elems[j-1] = m_elems[j]; + } + --m_size; + uint64_t m = effect_mask(i); + uint64_t t = 0; + for (unsigned j = 0, offset = 0; j < 64; ++j) { + if (0 != (m & (1ull << j))) { + t |= ((m_table >> j) & 1u) << offset; + ++offset; + } + } + m_table = t; + m_dont_care = 0; + unsigned f = 0; + for (unsigned e : *this) { + f |= (1u << (e & 0x1F)); + } + m_filter = f; + } + /** sat-sweep evaluation. Given 64 bits worth of possible values per variable, find possible values for function table encoded by cut. diff --git a/src/sat/sat_cutset.h b/src/sat/sat_cutset.h index 9ec22e041..942eaf543 100644 --- a/src/sat/sat_cutset.h +++ b/src/sat/sat_cutset.h @@ -160,6 +160,10 @@ namespace sat { return true; } + void remove_elem(unsigned i); + + static uint64_t effect_mask(unsigned i); + std::ostream& display(std::ostream& out) const; static std::ostream& display_table(std::ostream& out, unsigned num_input, uint64_t table); @@ -196,6 +200,7 @@ namespace sat { std::swap(m_cuts, other.m_cuts); } void evict(on_update_t& on_del, unsigned idx); + void evict(on_update_t& on_del, cut const& c); std::ostream& display(std::ostream& out) const; }; diff --git a/src/sat/sat_lut_finder.cpp b/src/sat/sat_lut_finder.cpp index a450924a7..5459ab2a4 100644 --- a/src/sat/sat_lut_finder.cpp +++ b/src/sat/sat_lut_finder.cpp @@ -32,7 +32,9 @@ namespace sat { // max_size = 6 -> 64 bits SASSERT(sizeof(m_combination)*8 >= (1ull << static_cast(max_size))); init_clause_filter(); - for (unsigned i = 0; i <= 6; ++i) init_mask(i); + for (unsigned i = 0; i <= 6; ++i) { + m_masks[i] = cut::effect_mask(i); + } m_var_position.resize(s.num_vars()); for (clause* cp : clauses) { cp->unmark_used(); @@ -203,31 +205,6 @@ namespace sat { return false; } - /** - * \brief create the masks - * i = 0: 101010101010101 - * i = 1: 1100110011001100 - * i = 2: 1111000011110000 - * i = 3: 111111110000000011111111 - */ - - void lut_finder::init_mask(unsigned i) { - SASSERT(i <= 6); - uint64_t m = 0; - if (i == 6) { - m = ~((uint64_t)0); - } - else { - m = (1ull << (1u << i)) - 1; // i = 0: m = 1 - unsigned w = 1u << (i + 1); // i = 0: w = 2 - while (w < 64) { - m |= (m << w); // i = 0: m = 1 + 4 - w *= 2; - } - } - m_masks[i] = m; - } - /** * \brief check if all output combinations for variable i are defined. */ diff --git a/src/sat/sat_lut_finder.h b/src/sat/sat_lut_finder.h index 0107e0a6f..1ccd842aa 100644 --- a/src/sat/sat_lut_finder.h +++ b/src/sat/sat_lut_finder.h @@ -61,7 +61,6 @@ namespace sat { bool extract_lut(literal l1, literal l2); bool extract_lut(clause& c2); bool update_combinations(unsigned mask); - void init_mask(unsigned i); void init_clause_filter(); void init_clause_filter(clause_vector& clauses); unsigned get_clause_filter(clause const& c); @@ -75,5 +74,6 @@ namespace sat { unsigned max_lut_size() const { return m_max_lut_size; } void operator()(clause_vector& clauses); + }; } diff --git a/src/sat/sat_params.pyg b/src/sat/sat_params.pyg index 8978d7497..b63bca050 100644 --- a/src/sat/sat_params.pyg +++ b/src/sat/sat_params.pyg @@ -77,6 +77,7 @@ def_module_params('sat', ('cut.delay', UINT, 2, 'delay cut simplification by in-processing round'), ('cut.lut', BOOL, False, 'extract luts from clauses for cut simplification'), ('cut.xor', BOOL, False, 'extract xors from clauses for cut simplification'), + ('cut.dont_cares', BOOL, True, 'integrate dont cares with cuts'), ('lookahead.cube.cutoff', SYMBOL, 'depth', 'cutoff type used to create lookahead cubes: depth, freevars, psat, adaptive_freevars, adaptive_psat'), # - depth: the maximal cutoff is fixed to the value of lookahead.cube.depth. # So if the value is 10, at most 1024 cubes will be generated of length 10.