diff --git a/src/sat/sat_aig_cuts.cpp b/src/sat/sat_aig_cuts.cpp index 406f2f015..703ef00da 100644 --- a/src/sat/sat_aig_cuts.cpp +++ b/src/sat/sat_aig_cuts.cpp @@ -284,20 +284,11 @@ namespace sat { } } - void aig_cuts::add_node(literal head, bool_op op, unsigned sz, literal const* args) { - TRACE("aig_simplifier", tout << head << " == " << op << " " << literal_vector(sz, args) << "\n";); - unsigned v = head.var(); - reserve(v); - unsigned offset = m_literals.size(); - node n(head.sign(), op, sz, offset); - m_literals.append(sz, args); - if (op == and_op || op == xor_op) { - std::sort(m_literals.c_ptr() + offset, m_literals.c_ptr() + offset + sz); - } - for (unsigned i = 0; i < sz; ++i) { - reserve(args[i].var()); - if (m_aig[args[i].var()].empty()) { - add_var(args[i].var()); + void aig_cuts::add_node(bool_var v, node const& n) { + for (unsigned i = 0; i < n.size(); ++i) { + reserve(m_literals[i].var()); + if (m_aig[m_literals[i].var()].empty()) { + add_var(m_literals[i].var()); } } if (m_aig[v].empty() || n.is_const()) { @@ -309,21 +300,44 @@ namespace sat { augment_aig0(v, n, m_cuts[v]); } touch(v); - IF_VERBOSE(12, display(verbose_stream() << "add " << head.var() << " == ", n) << "\n"); - + IF_VERBOSE(12, display(verbose_stream() << "add " << v << " == ", n) << "\n"); } else if (m_aig[v][0].is_const() || !insert_aux(v, n)) { - m_literals.shrink(m_literals.size() - sz); + m_literals.shrink(m_literals.size() - n.size()); TRACE("aig_simplifier", tout << "duplicate\n";); } SASSERT(!m_aig[v].empty()); } + void aig_cuts::add_node(bool_var v, uint64_t lut, unsigned sz, bool_var const* args) { + TRACE("aig_simplifier", tout << v << " == " << lut << " " << bool_var_vector(sz, args) << "\n";); + reserve(v); + unsigned offset = m_literals.size(); + node n(lut, sz, offset); + for (unsigned i = 0; i < sz; ++i) { + m_literals.push_back(literal(args[i], false)); + } + add_node(v, n); + } + + void aig_cuts::add_node(literal head, bool_op op, unsigned sz, literal const* args) { + TRACE("aig_simplifier", tout << head << " == " << op << " " << literal_vector(sz, args) << "\n";); + unsigned v = head.var(); + reserve(v); + unsigned offset = m_literals.size(); + node n(head.sign(), op, sz, offset); + m_literals.append(sz, args); + if (op == and_op || op == xor_op) { + std::sort(m_literals.c_ptr() + offset, m_literals.c_ptr() + offset + sz); + } + add_node(v, n); + } + void aig_cuts::add_cut(bool_var v, uint64_t lut, bool_var_vector const& args) { // args can be assumed to be sorted DEBUG_CODE(for (unsigned i = 0; i + 1 < args.size(); ++i) VERIFY(args[i] < args[i+1]);); - reserve(v); - for (bool_var w : args) reserve(w); + add_var(v); + for (bool_var w : args) add_var(w); cut c; for (bool_var w : args) VERIFY(c.add(w)); c.set_table(lut); @@ -588,6 +602,7 @@ namespace sat { literal c, t, e; if (n.sign()) r.neg(); m_clause.reset(); + unsigned num_comb = 0; switch (n.op()) { case var_op: return; @@ -623,15 +638,15 @@ namespace sat { m_clause.push_back(r, c, ~e); on_clause(m_clause); return; - case xor_op: { + case xor_op: // r = a ^ b ^ c // <=> // ~r ^ a ^ b ^ c = 1 if (n.size() > 10) { throw default_exception("cannot handle large xors"); } - unsigned num_comp = (1 << n.size()); - for (unsigned i = 0; i < num_comp; ++i) { + num_comb = (1 << n.size()); + for (unsigned i = 0; i < num_comb; ++i) { bool parity = n.size() % 2 == 1; m_clause.reset(); for (unsigned j = 0; j < n.size(); ++j) { @@ -649,7 +664,21 @@ namespace sat { on_clause(m_clause); } return; - } + case lut_op: + // r = LUT(v0, v1, v2) + num_comb = (1 << n.size()); + for (unsigned i = 0; i < num_comb; ++i) { + m_clause.reset(); + for (unsigned j = 0; j < n.size(); ++j) { + literal lit = m_literals[n.offset() + j]; + if (0 != (i & (1 << j))) lit.neg(); + m_clause.push_back(lit); + } + m_clause.push_back(0 == (n.lut() & (1ull << i)) ? ~r : r); + TRACE("aig_simplifier", tout << n.lut() << " " << m_clause << "\n";); + on_clause(m_clause); + } + return; default: UNREACHABLE(); break; diff --git a/src/sat/sat_aig_cuts.h b/src/sat/sat_aig_cuts.h index c74923209..c7ca473a0 100644 --- a/src/sat/sat_aig_cuts.h +++ b/src/sat/sat_aig_cuts.h @@ -82,7 +82,7 @@ namespace sat { explicit node(bool sign, bool_op op, unsigned nc, unsigned o) : m_sign(sign), m_op(op), m_size(nc), m_offset(o) {} explicit node(uint64_t lut, unsigned nc, unsigned o): - m_sign(false), m_op(lut_op), m_size(nc), m_offset(o) {} + m_sign(false), m_op(lut_op), m_lut(lut), m_size(nc), m_offset(o) {} bool is_valid() const { return m_offset != UINT_MAX; } bool_op op() const { return m_op; } bool is_var() const { return m_op == var_op; } @@ -165,11 +165,13 @@ namespace sat { void validate_aig2(cut const& a, cut const& b, unsigned v, node const& n, cut const& c); void validate_aigN(unsigned v, node const& n, cut const& c); + void add_node(bool_var v, node const& n); public: aig_cuts(); void add_var(unsigned v); void add_node(literal head, bool_op op, unsigned sz, literal const* args); + void add_node(bool_var head, uint64_t lut, unsigned sz, bool_var const* args); void add_cut(bool_var v, uint64_t lut, bool_var_vector const& args); void set_root(bool_var v, literal r); diff --git a/src/sat/sat_aig_simplifier.cpp b/src/sat/sat_aig_simplifier.cpp index 57d71e803..69f256912 100644 --- a/src/sat/sat_aig_simplifier.cpp +++ b/src/sat/sat_aig_simplifier.cpp @@ -237,7 +237,8 @@ namespace sat { std::function on_lut = [&,this](uint64_t lut, bool_var_vector const& vars, bool_var v) { m_stats.m_xluts++; - m_aig_cuts.add_cut(v, lut, vars); + // m_aig_cuts.add_cut(v, lut, vars); + m_aig_cuts.add_node(v, lut, vars.size(), vars.c_ptr()); }; lut_finder lf(s); lf.set(on_lut); diff --git a/src/sat/sat_lut_finder.cpp b/src/sat/sat_lut_finder.cpp index 11783eba8..2ee06034e 100644 --- a/src/sat/sat_lut_finder.cpp +++ b/src/sat/sat_lut_finder.cpp @@ -113,12 +113,11 @@ namespace sat { m_removed_clauses.append(m_clauses_to_remove); bool_var v; uint64_t lut = convert_combination(m_vars, v); - IF_VERBOSE(12, - for (clause* cp : m_clauses_to_remove) { - verbose_stream() << *cp << "\n"; - verbose_stream() << v << ": " << m_vars << "\n"; - } - display_mask(verbose_stream(), lut, 1u << m_vars.size()) << "\n";); + TRACE("aig_simplifier", + for (clause* cp : m_clauses_to_remove) { + tout << *cp << "\n" << v << ": " << m_vars << "\n"; + } + display_mask(tout, lut, 1u << m_vars.size()) << "\n";); m_on_lut(lut, m_vars, v); }