From 3bb05b5e011176b29fe8a8edabf39b84d7c6ad3f Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Wed, 19 Feb 2020 18:36:28 -0800 Subject: [PATCH] fix lut augment Signed-off-by: Nikolaj Bjorner --- src/sat/sat_aig_cuts.cpp | 61 ++++++++++++++++++++++++++-------- src/sat/sat_aig_cuts.h | 6 ++++ src/sat/sat_cut_simplifier.cpp | 5 +++ src/sat/sat_cutset.cpp | 28 ++++++++++++---- src/sat/sat_cutset.h | 6 ++++ 5 files changed, 86 insertions(+), 20 deletions(-) diff --git a/src/sat/sat_aig_cuts.cpp b/src/sat/sat_aig_cuts.cpp index 75d1f6a17..a099d671b 100644 --- a/src/sat/sat_aig_cuts.cpp +++ b/src/sat/sat_aig_cuts.cpp @@ -105,19 +105,24 @@ namespace sat { literal l1 = child(n, 0); VERIFY(&cs != &m_cuts[l1.var()]); for (auto const& a : m_cuts[l1.var()]) { - m_tables[0] = &a; - cut b(a); - augment_lut_rec(v, n, b, 1, cs); + if (a.size() > 0) { + m_tables[0] = &a; + m_lits[0] = l1; + cut b(a); + augment_lut_rec(v, n, b, 1, cs); + } } } void aig_cuts::augment_lut_rec(unsigned v, node const& n, cut& a, unsigned idx, cut_set& cs) { if (idx < n.size()) { - VERIFY(&cs != &m_cuts[child(n, idx).var()]); - for (auto const& b : m_cuts[child(n, idx).var()]) { + literal lit = child(n, idx); + VERIFY(&cs != &m_cuts[lit.var()]); + for (auto const& b : m_cuts[lit.var()]) { cut ab; - if (ab.merge(a, b)) { + if (b.size() > 0 && ab.merge(a, b)) { m_tables[idx] = &b; + m_lits[idx] = lit; augment_lut_rec(v, n, ab, idx + 1, cs); } } @@ -134,10 +139,11 @@ namespace sat { // when computing the output at position j, // the i'th bit to index into n.lut() is // based on the j'th output bit in lut[i] + // m_lits[i].sign() tracks if output bit is negated for (unsigned i = n.size(); i-- > 0; ) { - w |= ((m_luts[i] >> j) & 0x1) << i; + w |= (((m_luts[i] >> j) ^ (uint64_t)m_lits[i].sign()) & 1u) << i; } - r |= ((n.lut() >> w) & 0x1) << j; + r |= ((n.lut() >> w) & 1u) << j; } a.set_table(r); insert_cut(v, a, cs); @@ -152,13 +158,16 @@ namespace sat { VERIFY(&cs != &m_cuts[l2.var()]); VERIFY(&cs != &m_cuts[l3.var()]); for (auto const& a : m_cuts[l1.var()]) { + if (a.size() == 0) continue; for (auto const& b : m_cuts[l2.var()]) { + if (b.size() == 0) continue; cut ab; if (!ab.merge(a, b)) { continue; } for (auto const& c : m_cuts[l3.var()]) { cut abc; + if (c.size() == 0) continue; if (!abc.merge(ab, c)) { continue; } @@ -180,7 +189,7 @@ namespace sat { IF_VERBOSE(4, display(verbose_stream() << "augment_unit " << v << " ", n) << "\n"); SASSERT(n.is_and() && n.size() == 0); reset(cs); - cut c; + cut c; c.set_table(n.sign() ? 0x0 : 0x1); push_back(cs, c); } @@ -191,9 +200,11 @@ namespace sat { literal lit = child(n, 0); VERIFY(&cs != &m_cuts[lit.var()]); for (auto const& a : m_cuts[lit.var()]) { - cut c(a); - if (n.sign()) c.negate(); - if (!insert_cut(v, c, cs)) return; + if (a.size() > 0) { + cut c(a); + if (n.sign()) c.negate(); + if (!insert_cut(v, c, cs)) return; + } } } @@ -208,7 +219,9 @@ namespace sat { VERIFY(&cs != &m_cuts[l1.var()]); VERIFY(&cs != &m_cuts[l2.var()]); for (auto const& a : m_cuts[l1.var()]) { + if (a.size() == 0) continue; for (auto const& b : m_cuts[l2.var()]) { + if (b.size() == 0) continue; cut c; if (!c.merge(a, b)) { continue; @@ -232,6 +245,7 @@ namespace sat { SASSERT(n.is_and() || n.is_xor()); literal lit = child(n, 0); for (auto const& a : m_cuts[lit.var()]) { + if (a.size() == 0) continue; cut b(a); if (lit.sign()) { b.negate(); @@ -244,6 +258,7 @@ namespace sat { m_insertions = 0; for (auto const& a : m_cut_set1) { for (auto const& b : m_cuts[lit.var()]) { + if (b.size() == 0) continue; cut c; if (!c.merge(a, b)) { continue; @@ -321,7 +336,7 @@ namespace sat { } void aig_cuts::add_node(bool_var v, uint64_t lut, unsigned sz, bool_var const* args) { - TRACE("aig_simplifier", tout << v << " == " << lut << " " << bool_var_vector(sz, args) << "\n";); + TRACE("aig_simplifier", tout << v << " == " << cut::table2string(sz, lut) << " " << bool_var_vector(sz, args) << "\n";); reserve(v); unsigned offset = m_literals.size(); node n(lut, sz, offset); @@ -372,7 +387,6 @@ namespace sat { literal r = m_roots[i].second; literal rr = to_root[r.var()]; to_root[v] = r.sign() ? ~rr : rr; - // if (rr != r) std::cout << v << " -> " << to_root[v] << "\n"; } for (unsigned i = 0; i < m_aig.size(); ++i) { // invalidate nodes that have been rooted @@ -427,6 +441,25 @@ namespace sat { } } + void aig_cuts::flush_units() { + return; + // TBD: remove unit literals from cuts + for (unsigned i = 0; i < m_cuts.size(); ++i) { + + } + } + + void aig_cuts::flush_units(cut_set& cs) { + + } + + lbool aig_cuts::get_value(bool_var v) const { + if (m_aig[v].size() == 1 && m_aig[v][0].is_const()) { + return m_aig[v][0].sign() ? l_false : l_true; + } + return l_undef; + } + void aig_cuts::init_cut_set(unsigned id) { SASSERT(m_aig[id].size() == 1); SASSERT(m_aig[id][0].is_valid()); diff --git a/src/sat/sat_aig_cuts.h b/src/sat/sat_aig_cuts.h index 62e6c42e7..5e9d46f27 100644 --- a/src/sat/sat_aig_cuts.h +++ b/src/sat/sat_aig_cuts.h @@ -115,6 +115,7 @@ namespace sat { literal_vector m_clause; cut const* m_tables[6]; uint64_t m_luts[6]; + literal m_lits[6]; bool is_touched(bool_var v, node const& n); bool is_touched(literal lit) const { return is_touched(lit.var()); } @@ -144,7 +145,10 @@ namespace sat { bool flush_roots(bool_var var, literal_vector const& to_root, node& n); void flush_roots(literal_vector const& to_root, cut_set& cs); + void flush_units(cut_set& cs); + cut_val eval(node const& n, cut_eval const& env) const; + lbool get_value(bool_var v) const; std::ostream& display(std::ostream& out, node const& n) const; @@ -182,6 +186,8 @@ namespace sat { void inc_max_cutset_size(unsigned v) { m_max_cutset_size[v] += 10; touch(v); } unsigned max_cutset_size(unsigned v) const { return v == UINT_MAX ? m_config.m_max_cutset_size : m_max_cutset_size[v]; } + void flush_units(); + vector const & operator()(); unsigned num_cuts() const { return m_num_cuts; } diff --git a/src/sat/sat_cut_simplifier.cpp b/src/sat/sat_cut_simplifier.cpp index 935ca0b58..ec7a9baf7 100644 --- a/src/sat/sat_cut_simplifier.cpp +++ b/src/sat/sat_cut_simplifier.cpp @@ -177,10 +177,15 @@ namespace sat { void cut_simplifier::clauses2aig() { // update units + bool has_units = false; for (; m_config.m_enable_units && m_trail_size < s.init_trail_size(); ++m_trail_size) { + has_units = true; literal lit = s.trail_literal(m_trail_size); m_aig_cuts.add_node(lit, and_op, 0, 0); } + if (has_units) { + m_aig_cuts.flush_units(); + } std::function on_and = [&,this](literal head, literal_vector const& ands) { diff --git a/src/sat/sat_cutset.cpp b/src/sat/sat_cutset.cpp index 1acc34739..dc0ca51bd 100644 --- a/src/sat/sat_cutset.cpp +++ b/src/sat/sat_cutset.cpp @@ -12,6 +12,7 @@ --*/ +#include #include "util/hashtable.h" #include "sat/sat_cutset.h" #include "sat/sat_cutset_compute_shift.h" @@ -95,11 +96,14 @@ namespace sat { void cut_set::init(region& r, unsigned max_sz, unsigned v) { m_var = v; - m_max_size = max_sz; + m_size = 0; SASSERT(!m_region || m_cuts); - if (m_region) return; - m_region = &r; - m_cuts = new (r) cut[max_sz]; + VERIFY(!m_region || m_max_size > 0); + if (!m_region) { + m_max_size = max_sz; + m_region = &r; + m_cuts = new (r) cut[max_sz]; + } } /** @@ -192,10 +196,22 @@ namespace sat { if (i + 1 < m_size) out << " "; } out << "} "; - for (unsigned i = 0; i < (1u << m_size); ++i) { - if (0 != (table() & (1ull << i))) out << "1"; else out << "0"; + display_table(out, m_size, table()); + return out; + } + + std::ostream& cut::display_table(std::ostream& out, unsigned num_input, uint64_t table) { + for (unsigned i = 0; i < (1u << num_input); ++i) { + if (0 != (table & (1ull << i))) out << "1"; else out << "0"; } return out; } + std::string cut::table2string(unsigned num_input, uint64_t table) { + std::ostringstream strm; + display_table(strm, num_input, table); + return strm.str(); + } + + } diff --git a/src/sat/sat_cutset.h b/src/sat/sat_cutset.h index af7f4e15d..a1231baf1 100644 --- a/src/sat/sat_cutset.h +++ b/src/sat/sat_cutset.h @@ -63,6 +63,8 @@ namespace sat { unsigned size() const { return m_size; } + unsigned filter() const { return m_filter; } + static unsigned max_cut_size() { return 5; } unsigned const* begin() const { return m_elems; } @@ -158,6 +160,10 @@ namespace sat { } std::ostream& display(std::ostream& out) const; + + static std::ostream& display_table(std::ostream& out, unsigned num_input, uint64_t table); + + static std::string table2string(unsigned num_input, uint64_t table); }; class cut_set {