diff --git a/src/sat/sat_aig_cuts.cpp b/src/sat/sat_aig_cuts.cpp index 556264524..fba07cc28 100644 --- a/src/sat/sat_aig_cuts.cpp +++ b/src/sat/sat_aig_cuts.cpp @@ -24,10 +24,13 @@ namespace sat { m_config.m_max_cut_size = std::min(cut().max_cut_size, m_config.m_max_cut_size); m_cut_set1.init(m_region, m_config.m_max_cutset_size + 1); m_cut_set2.init(m_region, m_config.m_max_cutset_size + 1); + m_true = null_bool_var; } vector const& aig_cuts::get_cuts() { + flush_roots(); unsigned_vector node_ids = filter_valid_nodes(); + TRACE("aig_simplifier", display(tout);); augment(node_ids); TRACE("aig_simplifier", display(tout);); return m_cuts; @@ -38,15 +41,17 @@ namespace sat { cut_set& cs = m_cuts[id]; node const& n = m_aig[id]; SASSERT(n.is_valid()); + // cs.display(std::cout << "augment " << id << "\nbefore\n"); augment(n, cs); for (node const& n2 : m_aux_aig[id]) { augment(n2, cs); } + // cs.display(std::cout << "after\n"); } } void aig_cuts::augment(node const& n, cut_set& cs) { - unsigned nc = n.num_children(); + unsigned nc = n.is_var() ? 0 : n.num_children(); if (n.is_var()) { SASSERT(!n.sign()); } @@ -68,6 +73,7 @@ namespace sat { } bool aig_cuts::insert_cut(cut const& c, cut_set& cs) { + SASSERT(c.m_size > 0); while (cs.size() >= m_config.m_max_cutset_size) { // never evict the first entry, it is used for the starting point unsigned idx = 1 + (m_rand() % (cs.size() - 1)); @@ -110,15 +116,16 @@ namespace sat { void aig_cuts::augment_aig0(node const& n, cut_set& cs) { SASSERT(n.is_and()); - cut c; + SASSERT(m_true != null_bool_var); + cut c(m_true); cs.reset(); if (n.sign()) { - c.m_table = 0; // constant false + c.m_table = 0x0; // constant false } else { c.m_table = 0x3; // constant true } - cs.insert(c); + cs.push_back(c); } void aig_cuts::augment_aig1(node const& n, cut_set& cs) { @@ -126,7 +133,7 @@ namespace sat { literal lit = child(n, 0); unsigned round = 0; for (auto const& a : m_cuts[lit.var()]) { - cut c; + cut c(lit.var()); c.set_table(a.m_table); if (n.sign()) c.negate(); if (insert_cut(c, cs) && ++round >= m_config.m_max_insertions) @@ -219,28 +226,106 @@ namespace sat { unsigned offset = m_literals.size(); node n(head.sign(), op, sz, offset); m_literals.append(sz, args); + if (op == and_op || op == xor_op) { + std::sort(m_literals.c_ptr() + offset, m_literals.c_ptr() + offset + sz); + } for (unsigned i = 0; i < sz; ++i) { if (!m_aig[args[i].var()].is_valid()) { add_var(args[i].var()); } } - if (!m_aig[v].is_valid() || m_aig[v].is_var() || (sz == 0)) { + if (n.is_const() && m_true == null_bool_var) { + m_true = v; + } + if (!m_aig[v].is_valid() || m_aig[v].is_var() || n.is_const()) { m_aig[v] = n; init_cut_set(v); + if (n.is_const()) { + augment_aig0(n, m_cuts[v]); + } } - else if (eq(n, m_aig[v]) || !insert_aux(v, n)) { + else if (m_aig[v].is_const() || eq(n, m_aig[v]) || !insert_aux(v, n)) { m_literals.shrink(m_literals.size() - sz); TRACE("aig_simplifier", tout << "duplicate\n";); } + for (auto const& c : m_cuts[v]) SASSERT(c.m_size > 0); SASSERT(m_aig[v].is_valid()); } + void aig_cuts::set_root(bool_var v, literal r) { + IF_VERBOSE(2, verbose_stream() << "set-root " << v << " -> " << r << "\n"); + m_roots.push_back(std::make_pair(v, r)); + } + + void aig_cuts::flush_roots() { + if (m_roots.empty()) return; + literal_vector to_root; + for (unsigned i = 0; i < m_aig.size(); ++i) { + to_root.push_back(literal(i, false)); + } + for (unsigned i = m_roots.size(); i-- > 0; ) { + bool_var v = m_roots[i].first; + literal r = m_roots[i].second; + literal rr = to_root[r.var()]; + to_root[v] = r.sign() ? ~rr : rr; + } + for (unsigned i = 0; i < m_aig.size(); ++i) { + node& n = m_aig[i]; + // invalidate nodes that have been rooted + if (to_root[i] != literal(i, false)) { + m_aux_aig[i].reset(); + m_aig[i] = node(); + m_cuts[i].reset(); + } + else if (n.is_valid()) { + flush_roots(to_root, n); + for (node & n2 : m_aux_aig[i]) { + flush_roots(to_root, n2); + } + } + } + for (cut_set& cs : m_cuts) { + flush_roots(to_root, cs); + } + m_roots.reset(); + TRACE("aig_simplifier", display(tout);); + } + + void aig_cuts::flush_roots(literal_vector const& to_root, node& n) { + bool changed = false; + for (unsigned i = 0; i < n.num_children(); ++i) { + literal& lit = m_literals[n.offset() + i]; + if (to_root[lit.var()] != lit) { + changed = true; + lit = lit.sign() ? ~to_root[lit.var()] : to_root[lit.var()]; + } + } + if (changed) { + std::sort(m_literals.c_ptr() + n.offset(), m_literals.c_ptr() + n.offset() + n.num_children()); + } + } + + void aig_cuts::flush_roots(literal_vector const& to_root, cut_set& cs) { + unsigned j = 0; + for (cut& c : cs) { + bool has_stale = false; + for (unsigned v : c) { + has_stale |= (to_root[v] != literal(v, false)); + } + if (!has_stale) { + cs[j++] = c; + } + } + cs.shrink(j); + } + void aig_cuts::init_cut_set(unsigned id) { node const& n = m_aig[id]; SASSERT(n.is_valid()); auto& cut_set = m_cuts[id]; cut_set.init(m_region, m_config.m_max_cutset_size + 1); - cut_set.push_back(cut(id)); // TBD: if entry is a constant? + cut_set.reset(); + cut_set.push_back(cut(id)); m_aux_aig[id].reset(); } @@ -281,8 +366,9 @@ namespace sat { out << id << " == "; display(out, m_aig[id]) << "\n"; for (auto const& n : m_aux_aig[id]) { - display(out << " ", n) << "\n"; + display(out << " ", n) << "\n"; } + m_cuts[id].display(out); } return out; } @@ -290,10 +376,10 @@ namespace sat { std::ostream& aig_cuts::display(std::ostream& out, node const& n) const { if (n.sign()) out << "! "; switch (n.op()) { - case var_op: out << "var "; break; - case and_op: out << "and "; break; - case xor_op: out << "xor "; break; - case ite_op: out << "ite "; break; + case var_op: out << "var "; return out; + case and_op: out << "& "; break; + case xor_op: out << "^ "; break; + case ite_op: out << "? "; break; default: break; } for (unsigned i = 0; i < n.num_children(); ++i) { diff --git a/src/sat/sat_aig_cuts.h b/src/sat/sat_aig_cuts.h index 3f6bf49d9..41eb0f436 100644 --- a/src/sat/sat_aig_cuts.h +++ b/src/sat/sat_aig_cuts.h @@ -31,6 +31,16 @@ namespace sat { no_op }; + inline std::ostream& operator<<(std::ostream& out, bool_op op) { + switch (op) { + case var_op: return out << "v"; + case and_op: return out << "&"; + case ite_op: return out << "?"; + case xor_op: return out << "^"; + default: return out << ""; + } + } + class aig_cuts { struct config { @@ -50,14 +60,15 @@ namespace sat { public: node(): m_sign(false), m_op(no_op), m_num_children(UINT_MAX), m_offset(UINT_MAX) {} explicit node(unsigned v): m_sign(false), m_op(var_op), m_num_children(UINT_MAX), m_offset(v) {} - explicit node(bool negate, bool_op op, unsigned num_children, unsigned offset): - m_sign(negate), m_op(op), m_num_children(num_children), m_offset(offset) {} + explicit node(bool negate, bool_op op, unsigned nc, unsigned o): + m_sign(negate), m_op(op), m_num_children(nc), m_offset(o) {} bool is_valid() const { return m_offset != UINT_MAX; } bool_op op() const { return m_op; } bool is_var() const { return m_op == var_op; } bool is_and() const { return m_op == and_op; } bool is_xor() const { return m_op == xor_op; } bool is_ite() const { return m_op == ite_op; } + bool is_const() const { return is_and() && num_children() == 0; } unsigned var() const { SASSERT(is_var()); return m_offset; } bool sign() const { return m_sign; } unsigned num_children() const { SASSERT(!is_var()); return m_num_children; } @@ -71,6 +82,8 @@ namespace sat { region m_region; cut_set m_cut_set1, m_cut_set2; vector m_cuts; + bool_var m_true; + svector> m_roots; void reserve(unsigned v); bool insert_aux(unsigned v, node const& n); @@ -89,6 +102,10 @@ namespace sat { bool insert_cut(cut const& c, cut_set& cs); + void flush_roots(); + void flush_roots(literal_vector const& to_root, node& n); + void flush_roots(literal_vector const& to_root, cut_set& cs); + std::ostream& display(std::ostream& out, node const& n) const; literal child(node const& n, unsigned idx) const { SASSERT(!n.is_var()); SASSERT(idx < n.num_children()); return m_literals[n.offset() + idx]; } @@ -97,6 +114,7 @@ namespace sat { aig_cuts(); void add_var(unsigned v); void add_node(literal head, bool_op op, unsigned sz, literal const* args); + void set_root(bool_var v, literal r); vector const & get_cuts(); diff --git a/src/sat/sat_aig_simplifier.cpp b/src/sat/sat_aig_simplifier.cpp index 296e47d15..6044624f8 100644 --- a/src/sat/sat_aig_simplifier.cpp +++ b/src/sat/sat_aig_simplifier.cpp @@ -76,6 +76,10 @@ namespace sat { m_stats.m_num_xors++; } + void aig_simplifier::set_root(bool_var v, literal r) { + m_aig_cuts.set_root(v, r); + } + void aig_simplifier::operator()() { report _report(*this); TRACE("aig_simplifier", s.display(tout);); @@ -159,6 +163,7 @@ namespace sat { for (auto& cut : cuts[i]) { unsigned j = 0; if (cut2id.find(&cut, j)) { + if (i == j) std::cout << "dup: " << i << "\n"; VERIFY(i != j); literal v(i, false); literal w(j, false); diff --git a/src/sat/sat_aig_simplifier.h b/src/sat/sat_aig_simplifier.h index 58962941d..5d88033f7 100644 --- a/src/sat/sat_aig_simplifier.h +++ b/src/sat/sat_aig_simplifier.h @@ -51,6 +51,7 @@ namespace sat { void add_xor(literal head, unsigned sz, literal const* args); void add_ite(literal head, literal c, literal t, literal e); void add_iff(literal head, literal l1, literal l2); + void set_root(bool_var v, literal r); }; } diff --git a/src/sat/sat_cutset.cpp b/src/sat/sat_cutset.cpp index e23c2b866..227638ee0 100644 --- a/src/sat/sat_cutset.cpp +++ b/src/sat/sat_cutset.cpp @@ -33,6 +33,7 @@ namespace sat { */ bool cut_set::insert(cut const& c) { + SASSERT(c.m_size > 0); unsigned i = 0, j = 0; for (; i < size(); ++i) { cut const& a = (*this)[i]; @@ -61,6 +62,14 @@ namespace sat { } return true; } + + std::ostream& cut_set::display(std::ostream& out) const { + for (auto const& cut : *this) { + cut.display(out) << "\n"; + } + return out; + } + /** \brief shift table 'a' by adding elements from 'c'. @@ -116,10 +125,12 @@ namespace sat { } std::ostream& cut::display(std::ostream& out) const { + out << "{"; for (unsigned i = 0; i < m_size; ++i) { - out << (*this)[i] << " "; - } - out << "t: "; + out << (*this)[i]; + if (i + 1 < m_size) out << " "; + } + out << "} t: "; for (unsigned i = 0; i < (1u << m_size); ++i) { if (0 != (m_table & (1ull << i))) out << "1"; else out << "0"; } diff --git a/src/sat/sat_cutset.h b/src/sat/sat_cutset.h index 6da26a7d8..361e0c889 100644 --- a/src/sat/sat_cutset.h +++ b/src/sat/sat_cutset.h @@ -111,9 +111,11 @@ namespace sat { public: cut_set(): m_region(nullptr), m_size(0), m_max_size(0), m_cuts(nullptr) {} void init(region& r, unsigned sz) { + m_max_size = sz; + SASSERT(!m_region || m_cuts); + if (m_region) return; m_region = &r; m_cuts = new (r) cut[sz]; - m_max_size = sz; } bool insert(cut const& c); bool no_duplicates() const; @@ -121,7 +123,8 @@ namespace sat { cut * begin() const { return m_cuts; } cut * end() const { return m_cuts + m_size; } cut & back() { return m_cuts[m_size-1]; } - void push_back(cut const& c) { + void push_back(cut const& c) { + SASSERT(c.m_size > 0); if (m_size == m_max_size) { m_max_size *= 2; cut* new_cuts = new (*m_region) cut[m_max_size]; @@ -135,6 +138,7 @@ namespace sat { void shrink(unsigned j) { m_size = j; } void swap(cut_set& other) { std::swap(m_size, other.m_size); std::swap(m_cuts, other.m_cuts); std::swap(m_max_size, other.m_max_size); } void evict(unsigned idx) { m_cuts[idx] = m_cuts[--m_size]; } + std::ostream& display(std::ostream& out) const; }; } diff --git a/src/sat/sat_elim_eqs.cpp b/src/sat/sat_elim_eqs.cpp index 908304fb5..06968ec69 100644 --- a/src/sat/sat_elim_eqs.cpp +++ b/src/sat/sat_elim_eqs.cpp @@ -230,6 +230,7 @@ namespace sat { literal l(v, false); literal r = roots[v]; SASSERT(v != r.var()); + if (m_solver.m_aig_simplifier) m_solver.m_aig_simplifier->set_root(v, r); bool set_root = m_solver.set_root(l, r); bool root_ok = !m_solver.is_external(v) || set_root; if (m_solver.is_assumption(v) || (m_solver.is_external(v) && (m_solver.is_incremental() || !root_ok))) {