diff --git a/src/sat/sat_aig_finder.cpp b/src/sat/sat_aig_finder.cpp index 2d1904857..588a27606 100644 --- a/src/sat/sat_aig_finder.cpp +++ b/src/sat/sat_aig_finder.cpp @@ -16,9 +16,12 @@ --*/ #include "sat/sat_aig_finder.h" +#include "sat/sat_solver.h" namespace sat { + aig_finder::aig_finder(solver& s): s(s), m_big(s.rand()) {} + void aig_finder::operator()(clause_vector& clauses) { m_big.init(s, true); find_aigs(clauses); diff --git a/src/sat/sat_aig_finder.h b/src/sat/sat_aig_finder.h index 5a0ccb254..ab29d3c1a 100644 --- a/src/sat/sat_aig_finder.h +++ b/src/sat/sat_aig_finder.h @@ -29,11 +29,12 @@ #include "util/statistics.h" #include "sat/sat_clause.h" #include "sat/sat_types.h" -#include "sat/sat_solver.h" #include "sat/sat_big.h" namespace sat { + class solver; + class aig_finder { solver& s; big m_big; @@ -50,7 +51,7 @@ namespace sat { void validate_clause(literal_vector const& clause, vector const& clauses); public: - aig_finder(solver& s) : s(s), m_big(s.rand()) {} + aig_finder(solver& s); ~aig_finder() {} void set(std::function& f) { m_on_aig = f; } void set(std::function& f) { m_on_if = f; } diff --git a/src/sat/sat_aig_simplifier.cpp b/src/sat/sat_aig_simplifier.cpp index c64d93b95..fac9c8810 100644 --- a/src/sat/sat_aig_simplifier.cpp +++ b/src/sat/sat_aig_simplifier.cpp @@ -25,9 +25,8 @@ namespace sat { struct aig_simplifier::report { aig_simplifier& s; - aig_cuts& c; stopwatch m_watch; - report(aig_simplifier& s, aig_cuts& c): s(s), c(c) { m_watch.start(); } + report(aig_simplifier& s): s(s) { m_watch.start(); } ~report() { IF_VERBOSE(2, verbose_stream() << "(sat.aig-simplifier" @@ -39,39 +38,53 @@ namespace sat { } }; + aig_simplifier::aig_simplifier(solver& s):s(s), m_aig_cuts(m_config.m_max_cut_size, m_config.m_max_cutset_size) { + } + + void aig_simplifier::add_and(literal head, unsigned sz, literal const* lits) { + m_aig_cuts.add_node(head, and_op, sz, lits); + } + + void aig_simplifier::add_or(literal head, unsigned sz, literal const* lits) { + m_aig_cuts.add_node(head, and_op, sz, lits); + } + + void aig_simplifier::add_xor(literal head, unsigned sz, literal const* lits) { + m_aig_cuts.add_node(head, xor_op, sz, lits); + } + + void aig_simplifier::add_ite(literal head, literal c, literal t, literal e) { + literal lits[3] = { c, t, e }; + m_aig_cuts.add_node(head, ite_op, 3, lits); + } + + void aig_simplifier::add_iff(literal head, literal l1, literal l2) { + literal lits[2] = { l1, ~l2 }; + m_aig_cuts.add_node(head, xor_op, 2, lits); + } + void aig_simplifier::operator()() { - aig_cuts aigc; - report _report(*this, aigc); + report _report(*this); TRACE("aig_simplifier", s.display(tout);); - clauses2aig(aigc); - aig2clauses(aigc); + clauses2aig(); + aig2clauses(); } /** \brief extract AIG definitions from clauses Ensure that they are sorted and variables have unique definitions. */ - void aig_simplifier::clauses2aig(aig_cuts& aigc) { - struct aig_def { - literal head; - bool_op op; - unsigned sz; - unsigned offset; - aig_def(literal h, bool_op op, unsigned sz, unsigned o): head(h), op(op), sz(sz), offset(o) {} - }; - svector aig_defs; + void aig_simplifier::clauses2aig() { literal_vector literals; std::function on_and = - [&,this](literal head, literal_vector const& ands) { - aig_defs.push_back(aig_def(head, and_op, ands.size(), literals.size())); - literals.append(ands); + [&,this](literal head, literal_vector const& ands) { + m_aig_cuts.add_node(head, and_op, ands.size(), ands.c_ptr()); m_stats.m_num_ands++; }; std::function on_ite = [&,this](literal head, literal c, literal t, literal e) { - aig_defs.push_back(aig_def(head, ite_op, 3, literals.size())); - literal args[3] = { c, t, e }; - literals.append(3, args); + literal args[3] = { c, t, e }; + m_aig_cuts.add_node(head, ite_op, 3, args); m_stats.m_num_ites++; }; aig_finder af(s); @@ -97,88 +110,25 @@ namespace sat { // ~head = t1 + t2 + .. literal head = ~xors[index]; unsigned sz = xors.size() - 1; - aig_defs.push_back(aig_def(head, xor_op, sz, literals.size())); for (unsigned i = xors.size(); i-- > 0; ) { if (i != index) literals.push_back(xors[i]); } + m_aig_cuts.add_node(head, xor_op, sz, literals.c_ptr()); + literals.reset(); m_stats.m_num_xors++; }; xor_finder xf(s); xf.set(on_xor); - xf(clauses); - - svector outs(s.num_vars(), false); - svector ins(s.num_vars(), false); - for (auto a : aig_defs) { - outs[a.head.var()] = true; - } - - for (auto a : aig_defs) { - for (unsigned i = 0; i < a.sz; ++i) { - unsigned v = literals[a.offset+i].var(); - if (!outs[v]) ins[v] = true; - } - } - - std::function force_var = [&] (aig_def a) { - for (unsigned i = 0; i < a.sz; ++i) { - unsigned v = literals[a.offset + i].var(); - if (!ins[v]) { - aigc.add_var(v); - ins[v] = true; - } - } - }; - std::function add_var = [&] (unsigned v) { - if (!outs[v] && ins[v]) { - aigc.add_var(v); - outs[v] = true; - } - }; - for (auto a : aig_defs) { - for (unsigned i = 0; i < a.sz; ++i) { - add_var(literals[a.offset+i].var()); - } - } - - while (true) { - unsigned j = 0; - for (auto a : aig_defs) { - bool visited = true; - for (unsigned i = 0; visited && i < a.sz; ++i) { - visited &= ins[literals[a.offset + i].var()]; - } - unsigned h = a.head.var(); - if (!ins[h] && visited) { - ins[h] = true; - aigc.add_node(a.head, a.op, a.sz, literals.c_ptr() + a.offset); - } - else if (!ins[h]) { - aig_defs[j++] = a; - } - else { - TRACE("aig_simplifier", tout << "skip " << a.head << " == .. \n";); - force_var(a); - } - } - if (j == 0) { - break; - } - if (j == aig_defs.size()) { - IF_VERBOSE(2, verbose_stream() << "break cycle " << j << "\n"); - force_var(aig_defs.back()); - } - aig_defs.shrink(j); - } + xf(clauses); } - void aig_simplifier::aig2clauses(aig_cuts& aigc) { - vector cuts = aigc.get_cuts(m_config.m_max_cut_size, m_config.m_max_cutset_size); + void aig_simplifier::aig2clauses() { + vector const& cuts = m_aig_cuts.get_cuts(); map cut2id; union_find_default_ctx ctx; - union_find<> uf(ctx); + union_find<> uf(ctx), uf2(ctx); for (unsigned i = 2*s.num_vars(); i--> 0; ) uf.mk_var(); auto add_eq = [&](literal l1, literal l2) { uf.merge(l1.index(), l2.index()); @@ -212,8 +162,31 @@ namespace sat { } } if (old_num_eqs < m_stats.m_num_eqs) { - elim_eqs elim(s); - elim(uf); + // extract equivalences over non-eliminated literals. + bool new_eq = false; + for (unsigned idx = 0; idx < uf.get_num_vars(); ++idx) { + if (!uf.is_root(idx) || 1 == uf.size(idx)) continue; + literal root = null_literal; + unsigned first = idx; + do { + literal lit = to_literal(idx); + if (!s.was_eliminated(lit)) { + if (root == null_literal) { + root = lit; + } + else { + uf2.merge(lit.index(), root.index()); + new_eq = true; + } + } + idx = uf.next(idx); + } + while (first != idx); + } + if (new_eq) { + elim_eqs elim(s); + elim(uf2); + } } } @@ -224,28 +197,18 @@ namespace sat { st.update("sat-aig.ites", m_stats.m_num_ites); st.update("sat-aig.xors", m_stats.m_num_xors); } - - vector aig_cuts::get_cuts(unsigned max_cut_size, unsigned max_cutset_size) { - unsigned_vector sorted = top_sort(); - vector cuts(m_aig.size()); + + aig_cuts::aig_cuts(unsigned max_cut_size, unsigned max_cutset_size) { m_max_cut_size = std::min(cut().max_cut_size, max_cut_size); m_max_cutset_size = max_cutset_size; + } + + vector const& aig_cuts::get_cuts() { + unsigned_vector node_ids = filter_valid_nodes(); m_cut_set1.init(m_region, m_max_cutset_size + 1); m_cut_set2.init(m_region, m_max_cutset_size + 1); - - unsigned j = 0; - for (unsigned id : sorted) { - node const& n = m_aig[id]; - if (n.is_valid()) { - auto& cut_set = cuts[id]; - cut_set.init(m_region, m_max_cutset_size + 1); - cut_set.push_back(cut(id)); - sorted[j++] = id; - } - } - sorted.shrink(j); - augment(sorted, cuts); - return cuts; + augment(node_ids, m_cuts); + return m_cuts; } void aig_cuts::augment(unsigned_vector const& ids, vector& cuts) { @@ -259,6 +222,12 @@ namespace sat { else if (n.is_ite()) { augment_ite(n, cut_set, cuts); } + else if (n.num_children() == 0) { + augment_aig0(n, cut_set, cuts); + } + else if (n.num_children() == 1) { + augment_aig1(n, cut_set, cuts); + } else if (n.num_children() == 2) { augment_aig2(n, cut_set, cuts); } @@ -299,6 +268,28 @@ namespace sat { } } + void aig_cuts::augment_aig0(node const& n, cut_set& cs, vector& cuts) { + SASSERT(n.is_and()); + cut c; + cs.reset(); + if (!n.sign()) { + c.m_table = 3; + } + cs.insert(c); + } + + void aig_cuts::augment_aig1(node const& n, cut_set& cs, vector& cuts) { + SASSERT(n.is_and()); + literal lit = child(n, 0); + for (auto const& a : cuts[lit.var()]) { + if (cs.size() >= m_max_cutset_size) break; + cut c; + c.set_table(a.m_table); + if (n.sign()) c.negate(); + cs.insert(c); + } + } + void aig_cuts::augment_aig2(node const& n, cut_set& cs, vector& cuts) { SASSERT(n.is_and() || n.is_xor()); literal l1 = child(n, 0); @@ -363,7 +354,11 @@ namespace sat { void aig_cuts::add_var(unsigned v) { m_aig.reserve(v + 1); - m_aig[v] = node(v); + m_cuts.reserve(v + 1); + if (!m_aig[v].is_valid()) { + m_aig[v] = node(v); + init_cut_set(v); + } SASSERT(m_aig[v].is_valid()); } @@ -371,49 +366,45 @@ namespace sat { TRACE("aig_simplifier", tout << head << " == " << op << " " << literal_vector(sz, args) << "\n";); unsigned v = head.var(); m_aig.reserve(v + 1); - m_aig[v] = node(head.sign(), op, sz, m_literals.size()); + unsigned offset = m_literals.size(); + node n(head.sign(), op, sz, offset); m_literals.append(sz, args); - DEBUG_CODE( - for (unsigned i = 0; i < sz; ++i) { - SASSERT(m_aig[args[i].var()].is_valid()); - }); + for (unsigned i = 0; i < sz; ++i) { + if (!m_aig[args[i].var()].is_valid()) { + add_var(args[i].var()); + } + } + if (!m_aig[v].is_valid() || m_aig[v].is_var()) { + m_aig[v] = n; + init_cut_set(v); + } + else { + insert_aux(v, n); + } SASSERT(m_aig[v].is_valid()); } - unsigned_vector aig_cuts::top_sort() { - unsigned_vector result; - svector visit; - visit.reserve(m_aig.size(), false); - unsigned_vector todo; + void aig_cuts::init_cut_set(unsigned id) { + node const& n = m_aig[id]; + SASSERT(n.is_valid()); + auto& cut_set = m_cuts[id]; + cut_set.init(m_region, m_max_cutset_size + 1); + cut_set.push_back(cut(id)); + } + + void aig_cuts::insert_aux(unsigned v, node const& n) { + // TBD: throttle and replacement strategy + m_aux_aig.reserve(v + 1); + m_aux_aig[v].push_back(n); + } + + unsigned_vector aig_cuts::filter_valid_nodes() { unsigned id = 0; + unsigned_vector result; for (node const& n : m_aig) { - if (n.is_valid()) todo.push_back(id); + if (n.is_valid()) result.push_back(id); ++id; } - while (!todo.empty()) { - unsigned id = todo.back(); - if (visit[id]) { - todo.pop_back(); - continue; - } - bool all_visit = true; - node const& n = m_aig[id]; - SASSERT(n.is_valid()); - if (!n.is_var()) { - for (unsigned i = 0; i < n.num_children(); ++i) { - bool_var v = child(n, i).var(); - if (!visit[v]) { - todo.push_back(v); - all_visit = false; - } - } - } - if (all_visit) { - visit[id] = true; - result.push_back(id); - todo.pop_back(); - } - } return result; } } diff --git a/src/sat/sat_aig_simplifier.h b/src/sat/sat_aig_simplifier.h index 3800f3ac8..72b04986a 100644 --- a/src/sat/sat_aig_simplifier.h +++ b/src/sat/sat_aig_simplifier.h @@ -61,18 +61,26 @@ namespace sat { unsigned m_max_cut_size; unsigned m_max_cutset_size; cut_set m_cut_set1, m_cut_set2; + vector m_cuts; - unsigned_vector top_sort(); + void insert_aux(unsigned v, node const& n); + void init_cut_set(unsigned id); + + unsigned_vector filter_valid_nodes(); void augment(unsigned_vector const& ids, vector& cuts); void augment_ite(node const& n, cut_set& cs, vector& cuts); + void augment_aig0(node const& n, cut_set& cs, vector& cuts); + void augment_aig1(node const& n, cut_set& cs, vector& cuts); void augment_aig2(node const& n, cut_set& cs, vector& cuts); void augment_aigN(node const& n, cut_set& cs, vector& cuts); public: + aig_cuts(unsigned max_cut_size, unsigned max_cutset_size); void add_var(unsigned v); void add_node(literal head, bool_op op, unsigned sz, literal const* args); + literal child(node const& n, unsigned idx) const { SASSERT(!n.is_var()); SASSERT(idx < n.num_children()); return m_literals[n.offset() + idx]; } - vector get_cuts(unsigned max_cut_size, unsigned max_cutset_size); + vector const & get_cuts(); }; class aig_simplifier { @@ -91,14 +99,22 @@ namespace sat { solver& s; stats m_stats; config m_config; + aig_cuts m_aig_cuts; + struct report; - void clauses2aig(aig_cuts& aigc); - void aig2clauses(aig_cuts& aigc); + void clauses2aig(); + void aig2clauses(); public: - aig_simplifier(solver& s) : s(s) {} + aig_simplifier(solver& s); ~aig_simplifier() {} void operator()(); void collect_statistics(statistics& st) const; + + void add_and(literal head, unsigned sz, literal const* args); + void add_or(literal head, unsigned sz, literal const* args); + void add_xor(literal head, unsigned sz, literal const* args); + void add_ite(literal head, literal c, literal t, literal e); + void add_iff(literal head, literal l1, literal l2); }; } diff --git a/src/sat/sat_solver.h b/src/sat/sat_solver.h index 98edec401..cdb1142c9 100644 --- a/src/sat/sat_solver.h +++ b/src/sat/sat_solver.h @@ -31,6 +31,7 @@ Revision History: #include "sat/sat_simplifier.h" #include "sat/sat_scc.h" #include "sat/sat_asymm_branch.h" +#include "sat/sat_aig_simplifier.h" #include "sat/sat_iff3_finder.h" #include "sat/sat_probing.h" #include "sat/sat_mus.h" @@ -89,6 +90,7 @@ namespace sat { config m_config; stats m_stats; scoped_ptr m_ext; + scoped_ptr m_aig_simplifier; parallel* m_par; drat m_drat; // DRAT for generating proofs clause_allocator m_cls_allocator[2]; @@ -398,6 +400,7 @@ namespace sat { bool is_incremental() const { return m_config.m_incremental; } extension* get_extension() const override { return m_ext.get(); } void set_extension(extension* e) override; + aig_simplifier* get_aig_simplifier() override { return m_aig_simplifier.get(); } bool set_root(literal l, literal r); void flush_roots(); typedef std::pair bin_clause; diff --git a/src/sat/sat_solver_core.h b/src/sat/sat_solver_core.h index b3c43ea6a..92be0b19f 100644 --- a/src/sat/sat_solver_core.h +++ b/src/sat/sat_solver_core.h @@ -23,7 +23,10 @@ Revision History: #include "sat/sat_types.h" namespace sat { - + + class aig_simplifier; + class extension; + class solver_core { protected: reslimit& m_rlimit; @@ -89,6 +92,8 @@ namespace sat { virtual extension* get_extension() const { return nullptr; } virtual void set_extension(extension* e) { if (e) throw default_exception("optional API not supported"); } + virtual aig_simplifier* get_aig_simplifier() { return nullptr; } + // The following methods are used when converting the state from the SAT solver back // to a set of assertions. diff --git a/src/sat/tactic/goal2sat.cpp b/src/sat/tactic/goal2sat.cpp index a3e6b857c..73f000193 100644 --- a/src/sat/tactic/goal2sat.cpp +++ b/src/sat/tactic/goal2sat.cpp @@ -35,6 +35,7 @@ Notes: #include "ast/for_each_expr.h" #include "sat/tactic/goal2sat.h" #include "sat/ba_solver.h" +#include "sat/sat_aig_simplifier.h" #include "model/model_evaluator.h" #include "model/model_v2_pp.h" #include "tactic/tactic.h" @@ -53,6 +54,7 @@ struct goal2sat::imp { ast_manager & m; pb_util pb; sat::ba_solver* m_ext; + sat::aig_simplifier* m_aig; svector m_frame_stack; svector m_result_stack; obj_map m_cache; @@ -73,6 +75,7 @@ struct goal2sat::imp { m(_m), pb(m), m_ext(nullptr), + m_aig(nullptr), m_solver(s), m_map(map), m_dep2asm(dep2asm), @@ -82,6 +85,7 @@ struct goal2sat::imp { m_is_lemma(false) { updt_params(p); m_true = sat::null_literal; + m_aig = s.get_aig_simplifier(); } void updt_params(params_ref const & p) { @@ -252,6 +256,9 @@ struct goal2sat::imp { sat::literal l(k, false); m_cache.insert(t, l); sat::literal * lits = m_result_stack.end() - num; + + if (m_aig) m_aig->add_or(l, num, lits); + for (unsigned i = 0; i < num; i++) { mk_clause(~lits[i], l); } @@ -290,8 +297,11 @@ struct goal2sat::imp { sat::bool_var k = m_solver.add_var(false); sat::literal l(k, false); m_cache.insert(t, l); - // l => /\ lits sat::literal * lits = m_result_stack.end() - num; + + if (m_aig) m_aig->add_and(l, num, lits); + + // l => /\ lits for (unsigned i = 0; i < num; i++) { mk_clause(~l, lits[i]); } @@ -341,6 +351,7 @@ struct goal2sat::imp { mk_clause(~t, ~e, l, false); mk_clause(t, e, ~l, false); } + if (m_aig) m_aig->add_ite(l, c, t, e); m_result_stack.shrink(sz-3); if (sign) l.neg(); @@ -374,6 +385,7 @@ struct goal2sat::imp { mk_clause(~l, ~l1, l2); mk_clause(l, l1, l2); mk_clause(l, ~l1, ~l2); + if (m_aig) m_aig->add_iff(l, l1, l2); m_result_stack.shrink(sz-2); if (sign) l.neg(); @@ -400,6 +412,7 @@ struct goal2sat::imp { } ensure_extension(); m_ext->add_xr(lits); + if (m_aig) m_aig->add_xor(~lits.back(), lits.size() - 1, lits.c_ptr() + 1); sat::literal lit(v, sign); if (root) { m_result_stack.reset(); @@ -634,7 +647,7 @@ struct goal2sat::imp { m_ext = alloc(sat::ba_solver); m_solver.set_extension(m_ext); } - } + } } void convert(app * t, bool root, bool sign) {