From 40a4326ad4a52cbe33f68a4a8514c51e333491a5 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Thu, 2 Jan 2020 18:02:53 -0800 Subject: [PATCH] add anf Signed-off-by: Nikolaj Bjorner --- scripts/mk_project.py | 10 +- src/math/dd/dd_pdd.cpp | 31 +++ src/math/dd/dd_pdd.h | 8 +- src/math/grobner/pdd_solver.h | 6 +- src/math/simplex/bit_matrix.cpp | 25 ++ src/math/simplex/bit_matrix.h | 1 + src/sat/CMakeLists.txt | 3 + src/sat/ba_solver.cpp | 418 ++------------------------------ src/sat/ba_solver.h | 52 +--- src/sat/sat_anf_simplifier.cpp | 274 +++++++++++++++++++++ src/sat/sat_anf_simplifier.h | 75 ++++++ src/sat/sat_config.cpp | 1 + src/sat/sat_config.h | 1 + src/sat/sat_params.pyg | 1 + src/sat/sat_solver.cpp | 51 +++- src/sat/sat_solver.h | 20 +- src/sat/sat_solver_core.h | 4 +- src/sat/sat_xor_util.cpp | 234 ++++++++++++++++++ src/sat/sat_xor_util.h | 76 ++++++ src/sat/tactic/goal2sat.cpp | 34 +-- 20 files changed, 860 insertions(+), 465 deletions(-) create mode 100644 src/sat/sat_anf_simplifier.cpp create mode 100644 src/sat/sat_anf_simplifier.h create mode 100644 src/sat/sat_xor_util.cpp create mode 100644 src/sat/sat_xor_util.h diff --git a/scripts/mk_project.py b/scripts/mk_project.py index e926862a6..35ddde70c 100644 --- a/scripts/mk_project.py +++ b/scripts/mk_project.py @@ -16,16 +16,17 @@ def init_project_def(): add_lib('util', [], includes2install = ['z3_version.h']) add_lib('polynomial', ['util'], 'math/polynomial') add_lib('dd', ['util'], 'math/dd') - add_lib('sat', ['util','dd']) - add_lib('nlsat', ['polynomial', 'sat']) - add_lib('lp', ['util','nlsat'], 'util/lp') - add_lib('hilbert', ['util'], 'math/hilbert') add_lib('simplex', ['util'], 'math/simplex') + add_lib('hilbert', ['util'], 'math/hilbert') add_lib('automata', ['util'], 'math/automata') add_lib('interval', ['util'], 'math/interval') add_lib('realclosure', ['interval'], 'math/realclosure') add_lib('subpaving', ['interval'], 'math/subpaving') add_lib('ast', ['util', 'polynomial']) + add_lib('grobner', ['ast', 'dd', 'simplex'], 'math/grobner') + add_lib('sat', ['util','dd', 'grobner']) + add_lib('nlsat', ['polynomial', 'sat']) + add_lib('lp', ['util','nlsat'], 'util/lp') add_lib('rewriter', ['ast', 'polynomial', 'automata'], 'ast/rewriter') add_lib('macros', ['rewriter'], 'ast/macros') add_lib('normal_forms', ['rewriter'], 'ast/normal_forms') @@ -33,7 +34,6 @@ def init_project_def(): add_lib('tactic', ['ast', 'model']) add_lib('substitution', ['ast', 'rewriter'], 'ast/substitution') add_lib('parser_util', ['ast'], 'parsers/util') - add_lib('grobner', ['ast', 'dd', 'simplex'], 'math/grobner') add_lib('euclid', ['util'], 'math/euclid') add_lib('proofs', ['rewriter', 'util'], 'ast/proofs') add_lib('solver', ['model', 'tactic', 'proofs']) diff --git a/src/math/dd/dd_pdd.cpp b/src/math/dd/dd_pdd.cpp index 1a3717a3e..f3375308b 100644 --- a/src/math/dd/dd_pdd.cpp +++ b/src/math/dd/dd_pdd.cpp @@ -105,6 +105,19 @@ namespace dd { pdd pdd_manager::mk_xor(pdd const& p, pdd const& q) { return (p*q*2) - p - q; } pdd pdd_manager::mk_not(pdd const& p) { return 1 - p; } + pdd pdd_manager::subst_val(pdd const& p, vector> const& _s) { + typedef std::pair pr; + vector s(_s); + std::function compare_level = + [&](pr const& a, pr const& b) { return m_var2level[a.first] < m_var2level[b.first]; }; + std::sort(s.begin(), s.end(), compare_level); + pdd r(one()); + for (auto const& q : s) { + r = (r*mk_var(q.first)) + q.second; + } + return pdd(apply(p.root, r.root, pdd_subst_val_op), this); + } + pdd_manager::PDD pdd_manager::apply(PDD arg1, PDD arg2, pdd_op op) { bool first = true; SASSERT(well_formed()); @@ -166,6 +179,14 @@ namespace dd { if (is_val(p)) return p; if (!is_val(q) && level(p) < level(q)) return p; break; + case pdd_subst_val_op: + while (!is_val(q) && !is_val(p)) { + if (level(p) == level(q)) break; + if (level(p) < level(q)) q = lo(q); + else p = lo(p); + } + if (is_val(p) || is_val(q)) return p; + break; default: UNREACHABLE(); break; @@ -298,6 +319,16 @@ namespace dd { npop = 0; } break; + case pdd_subst_val_op: + SASSERT(!is_val(p)); + SASSERT(!is_val(q)); + SASSERT(level_p = level_q); + push(apply_rec(lo(p), hi(q), pdd_subst_val_op)); // lo := subst(lo(p), s) + push(apply_rec(hi(p), hi(q), pdd_subst_val_op)); // hi := subst(hi(p), s) + push(apply_rec(lo(q), read(1), pdd_mul_op)); // hi := hi*s[var(p)] + r = apply_rec(read(1), read(3), pdd_add_op); // r := hi + lo := subst(lo(p),s) + s[var(p)]*subst(hi(p),s) + npop = 3; + break; default: UNREACHABLE(); } diff --git a/src/math/dd/dd_pdd.h b/src/math/dd/dd_pdd.h index 9a7c337e9..ee03c1ba9 100644 --- a/src/math/dd/dd_pdd.h +++ b/src/math/dd/dd_pdd.h @@ -62,7 +62,8 @@ namespace dd { pdd_minus_op = 4, pdd_mul_op = 5, pdd_reduce_op = 6, - pdd_no_op = 7 + pdd_subst_val_op = 7, + pdd_no_op = 8 }; struct node { @@ -264,6 +265,7 @@ namespace dd { pdd mk_xor(pdd const& p, pdd const& q); pdd mk_not(pdd const& p); pdd reduce(pdd const& a, pdd const& b); + pdd subst_val(pdd const& a, vector> const& s); bool is_linear(PDD p); bool is_linear(pdd const& p); @@ -330,11 +332,15 @@ namespace dd { pdd reduce(pdd const& other) const { return m.reduce(*this, other); } bool different_leading_term(pdd const& other) const { return m.different_leading_term(*this, other); } + pdd subst_val(vector> const& s) const { return m.subst_val(*this, s); } + std::ostream& display(std::ostream& out) const { return m.display(out, *this); } bool operator==(pdd const& other) const { return root == other.root; } bool operator!=(pdd const& other) const { return root != other.root; } bool operator<(pdd const& other) const { return m.lt(*this, other); } + + unsigned dag_size() const { return m.dag_size(*this); } double tree_size() const { return m.tree_size(*this); } unsigned degree() const { return m.degree(*this); } diff --git a/src/math/grobner/pdd_solver.h b/src/math/grobner/pdd_solver.h index b0f2a6a30..262e82329 100644 --- a/src/math/grobner/pdd_solver.h +++ b/src/math/grobner/pdd_solver.h @@ -105,8 +105,10 @@ public: solver(reslimit& lim, pdd_manager& m); ~solver(); - void operator=(print_dep_t& pd) { m_print_dep = pd; } - void operator=(config const& c) { m_config = c; } + pdd_manager& get_manager() { return m; } + + void set(print_dep_t& pd) { m_print_dep = pd; } + void set(config const& c) { m_config = c; } void reset(); void add(pdd const& p) { add(p, nullptr); } diff --git a/src/math/simplex/bit_matrix.cpp b/src/math/simplex/bit_matrix.cpp index b45f5eb8e..73ff0619e 100644 --- a/src/math/simplex/bit_matrix.cpp +++ b/src/math/simplex/bit_matrix.cpp @@ -91,3 +91,28 @@ std::ostream& bit_matrix::display(std::ostream& out) { return out; } +/* + produce a sequence of bits forming a Gray code. + - All 2^n bit-sequences are covered. + - The Hamming distance between two entries it one. + */ +unsigned_vector bit_matrix::gray(unsigned n) { + SASSERT(n < 32); + if (n == 0) { + return unsigned_vector(); + } + else if (n == 1) { + unsigned_vector v; + v.push_back(0); + v.push_back(1); + return v; + } + else { + auto v = gray(n-1); + auto w = v; + w.reverse(); + for (auto & u : v) u |= (1 << (n-1)); + v.append(w); + return v; + } +} diff --git a/src/math/simplex/bit_matrix.h b/src/math/simplex/bit_matrix.h index 0d81b7249..273f9ddb1 100644 --- a/src/math/simplex/bit_matrix.h +++ b/src/math/simplex/bit_matrix.h @@ -104,6 +104,7 @@ public: private: void basic_solve(); + unsigned_vector gray(unsigned n); }; inline std::ostream& operator<<(std::ostream& out, bit_matrix& m) { return m.display(out); } diff --git a/src/sat/CMakeLists.txt b/src/sat/CMakeLists.txt index f3cd5973b..3a4a98599 100644 --- a/src/sat/CMakeLists.txt +++ b/src/sat/CMakeLists.txt @@ -2,6 +2,7 @@ z3_add_component(sat SOURCES ba_solver.cpp dimacs.cpp + sat_anf_simplifier.cpp sat_asymm_branch.cpp sat_big.cpp sat_binspr.cpp @@ -28,9 +29,11 @@ z3_add_component(sat sat_solver.cpp sat_unit_walk.cpp sat_watched.cpp + sat_xor_util.cpp COMPONENT_DEPENDENCIES util dd + grobner PYG_FILES sat_asymm_branch_params.pyg sat_params.pyg diff --git a/src/sat/ba_solver.cpp b/src/sat/ba_solver.cpp index 9acfca49d..bfbde31dd 100644 --- a/src/sat/ba_solver.cpp +++ b/src/sat/ba_solver.cpp @@ -22,6 +22,7 @@ Revision History: #include "sat/sat_types.h" #include "util/mpz.h" #include "sat/sat_simplifier_params.hpp" +#include "sat/sat_xor_util.h" namespace sat { @@ -1102,16 +1103,6 @@ namespace sat { m_active_vars.reset(); } - void ba_solver::init_visited() { - m_visited_ts++; - if (m_visited_ts == 0) { - m_visited_ts = 1; - m_visited.reset(); - } - while (m_visited.size() < 2*s().num_vars()) { - m_visited.push_back(0); - } - } static bool _debug_conflict = false; static literal _debug_consequent = null_literal; @@ -1878,7 +1869,6 @@ namespace sat { m_constraint_id(0), m_ba(*this), m_sort(m_ba) { TRACE("ba", tout << this << "\n";); m_num_propagations_since_pop = 0; - m_max_xor_size = 5; } ba_solver::~ba_solver() { @@ -1991,14 +1981,11 @@ namespace sat { } bool ba_solver::all_distinct(literal_vector const& lits) { - init_visited(); - for (literal l : lits) { - if (is_visited(l.var())) { - return false; - } - mark_visited(l.var()); - } - return true; + return s().all_distinct(lits); + } + + bool ba_solver::all_distinct(clause const& c) { + return s().all_distinct(c); } bool ba_solver::all_distinct(xr const& x) { @@ -2012,17 +1999,6 @@ namespace sat { return true; } - bool ba_solver::all_distinct(clause const& c) { - init_visited(); - for (literal l : c) { - if (is_visited(l.var())) { - return false; - } - mark_visited(l.var()); - } - return true; - } - literal ba_solver::add_xor_def(literal_vector& lits, bool learned) { unsigned sz = lits.size(); SASSERT (sz > 1); @@ -2934,10 +2910,10 @@ namespace sat { void ba_solver::pre_simplify() { VERIFY(s().at_base_lvl()); - barbet_init_parity(); m_constraint_removed = false; - for (unsigned sz = m_constraints.size(), i = 0; i < sz; ++i) pre_simplify(*m_constraints[i]); - for (unsigned sz = m_learned.size(), i = 0; i < sz; ++i) pre_simplify(*m_learned[i]); + xor_util xu(s()); + for (unsigned sz = m_constraints.size(), i = 0; i < sz; ++i) pre_simplify(xu, *m_constraints[i]); + for (unsigned sz = m_learned.size(), i = 0; i < sz; ++i) pre_simplify(xu, *m_learned[i]); bool change = m_constraint_removed; cleanup_constraints(); if (change) { @@ -2948,10 +2924,10 @@ namespace sat { } } - void ba_solver::pre_simplify(constraint& c) { - if (c.is_xr() && c.size() <= m_max_xor_size) { + void ba_solver::pre_simplify(xor_util& xu, constraint& c) { + if (c.is_xr() && c.size() <= xu.max_xor_size()) { unsigned sz = c.size(); - literal_vector& lits = m_barbet_clause; + literal_vector lits; bool parity = false; xr const& x = c.to_xr(); for (literal lit : x) { @@ -2960,7 +2936,7 @@ namespace sat { // IF_VERBOSE(0, verbose_stream() << "blast: " << c << "\n"); for (unsigned i = 0; i < (1ul << sz); ++i) { - if (m_barbet_parity[sz][i] == parity) { + if (xu.parity(sz, i) == parity) { lits.reset(); for (unsigned j = 0; j < sz; ++j) { lits.push_back(literal(x[j].var(), (0 != (i & (1 << j))))); @@ -3178,7 +3154,6 @@ namespace sat { if (m_roots.empty()) return; reserve_roots(); // validate(); - m_visited.resize(s().num_vars()*2, false); m_constraint_removed = false; for (unsigned sz = m_constraints.size(), i = 0; i < sz; ++i) flush_roots(*m_constraints[i]); @@ -3515,7 +3490,6 @@ namespace sat { - blocked literals */ void ba_solver::init_use_lists() { - m_visited.resize(s().num_vars()*2, false); m_clause_use_list.init(s().num_vars()); m_cnstr_use_list.reset(); m_cnstr_use_list.resize(2*s().num_vars()); @@ -3779,7 +3753,6 @@ namespace sat { for (watched w : get_wlist(~lit)) { if (w.is_binary_clause()) unique = false; } -#if 1 if (!unique) continue; xr const& x1 = c1.to_xr(); xr const& x2 = c2.to_xr(); @@ -3820,370 +3793,27 @@ namespace sat { c2.set_removed(); add_xr(lits, !c1.learned() && !c2.learned()); m_constraint_removed = true; -#endif } } } } void ba_solver::extract_xor() { - if (!s().get_config().m_xor_solver) { - return; - } - barbet_extract_xor(); - return; - - for (clause* cp : s().m_clauses) { - clause& c = *cp; - if (c.was_removed() || c.size() <= 3 || !all_distinct(c)) continue; - init_visited(); - for (literal l : c) mark_visited(l); - literal l0 = c[0]; - literal l1 = c[1]; - if (extract_xor(c, l0) || - extract_xor(c, l1) || - extract_xor(c, ~l0)) { - m_simplify_change = true; - } - } - // extract xor from ternary clauses - unsigned sz = s().num_vars(); - m_ternary.reset(); - m_ternary.reserve(sz); - extract_ternary(s().m_clauses); - extract_ternary(s().m_learned); - for (unsigned v = 0; v < sz; ++v) { - ptr_vector& cs = m_ternary[v]; - for (unsigned i = 0; i < cs.size() && !cs[i]->is_learned(); ++i) { - clause& c = *cs[i]; - if (c.was_removed()) continue; - init_visited(); - for (literal l : c) mark_visited(l); - for (unsigned j = i + 1; j < cs.size(); ++j) { - if (extract_xor(c, *cs[j])) { - m_simplify_change = true; - break; - } - } - } - } - m_ternary.clear(); - } - - void ba_solver::extract_ternary(clause_vector const& clauses) { - for (clause* cp : clauses) { - clause& c = *cp; - if (!c.was_removed() && c.size() == 3 && all_distinct(c)) { - bool_var v = std::min(c[0].var(), std::min(c[1].var(), c[2].var())); - m_ternary[v].push_back(cp); - } - } - } - - void ba_solver::barbet_extract_xor() { - unsigned max_size = m_max_xor_size; - // we better have enough bits in the combination mask to - // handle clauses up to max_size. - // max_size = 5 -> 32 bits - // max_size = 6 -> 64 bits - SASSERT(sizeof(m_barbet_combination)*8 <= (1ull << static_cast(max_size))); - init_clause_filter(); - barbet_init_parity(); - m_barbet_var_position.resize(s().num_vars()); - for (clause* cp : s().m_clauses) { - cp->unmark_used(); - } - for (; max_size > 2; --max_size) { - for (clause* cp : s().m_clauses) { - clause& c = *cp; - if (c.size() == max_size && !c.was_removed() && !c.is_learned() && !c.was_used()) { - barbet_extract_xor(c); - } - } - } - m_clause_filters.clear(); - } - - void ba_solver::barbet_extract_xor(clause& c) { - SASSERT(c.size() > 2); - unsigned filter = get_clause_filter(c); - init_visited(); - bool parity = false; - unsigned mask = 0, i = 0; - for (literal l : c) { - m_barbet_var_position[l.var()] = i; - mark_visited(l.var()); - parity ^= l.sign(); - mask |= (l.sign() << (i++)); - } - m_barbet_clauses_to_remove.reset(); - m_barbet_clauses_to_remove.push_back(&c); - m_barbet_clause.resize(c.size()); - m_barbet_combination = 0; - barbet_set_combination(mask); - c.mark_used(); - for (literal l : c) { - for (auto const& cf : m_clause_filters[l.var()]) { - if ((filter == (filter | cf.m_filter)) && - !cf.m_clause->was_used() && - barbet_extract_xor(parity, c, *cf.m_clause)) { - barbet_add_xor(parity, c); - return; - } - } - // loop over binary clauses in watch list - for (watched const & w : get_wlist(l)) { - if (w.is_binary_clause() && is_visited(w.get_literal().var()) && w.get_literal().index() < l.index()) { - if (barbet_extract_xor(parity, c, ~l, w.get_literal())) { - barbet_add_xor(parity, c); - return; - } - } - } - l.neg(); - for (watched const & w : get_wlist(l)) { - if (w.is_binary_clause() && is_visited(w.get_literal().var()) && w.get_literal().index() < l.index()) { - if (barbet_extract_xor(parity, c, ~l, w.get_literal())) { - barbet_add_xor(parity, c); - return; - } - } - } - } - } - - void ba_solver::barbet_add_xor(bool parity, clause& c) { - for (clause* cp : m_barbet_clauses_to_remove) { + xor_util xu(s()); + std::function f = [this](literal_vector const& l, bool b) { add_xr(l,b); }; + xu.set(f); + xu.extract_xors(); + for (clause* cp : xu.removed_clauses()) { cp->set_removed(true); - } - m_clause_removed = true; - bool learned = false; - literal_vector lits; - for (literal l : c) { - lits.push_back(literal(l.var(), false)); - s().set_external(l.var()); - } - if (parity) lits[0].neg(); - add_xr(lits, learned); - } - - bool ba_solver::barbet_extract_xor(bool parity, clause& c, literal l1, literal l2) { - SASSERT(is_visited(l1.var())); - SASSERT(is_visited(l2.var())); - m_barbet_missing.reset(); - unsigned mask = 0; - for (unsigned i = 0; i < c.size(); ++i) { - if (c[i].var() == l1.var()) { - mask |= (l1.sign() << i); - } - else if (c[i].var() == l2.var()) { - mask |= (l2.sign() << i); - } - else { - m_barbet_missing.push_back(i); - } - } - return barbet_update_combinations(c, parity, mask); - } - - bool ba_solver::barbet_extract_xor(bool parity, clause& c, clause& c2) { - bool parity2 = false; - for (literal l : c2) { - if (!is_visited(l.var())) return false; - parity2 ^= l.sign(); - } - if (c2.size() == c.size() && parity2 != parity) { - return false; - } - if (c2.size() == c.size()) { - m_barbet_clauses_to_remove.push_back(&c2); - c2.mark_used(); - } - // insert missing - unsigned mask = 0; - m_barbet_missing.reset(); - SASSERT(c2.size() <= c.size()); - for (unsigned i = 0; i < c.size(); ++i) { - m_barbet_clause[i] = null_literal; - } - for (literal l : c2) { - unsigned pos = m_barbet_var_position[l.var()]; - m_barbet_clause[pos] = l; - } - for (unsigned j = 0; j < c.size(); ++j) { - literal lit = m_barbet_clause[j]; - if (lit == null_literal) { - m_barbet_missing.push_back(j); - } - else { - mask |= (m_barbet_clause[j].sign() << j); - } - } - - return barbet_update_combinations(c, parity, mask); - } - - bool ba_solver::barbet_update_combinations(clause& c, bool parity, unsigned mask) { - unsigned num_missing = m_barbet_missing.size(); - for (unsigned k = 0; k < (1ul << num_missing); ++k) { - unsigned mask2 = mask; - for (unsigned i = 0; i < num_missing; ++i) { - if ((k & (1 << i)) != 0) { - mask2 |= 1ul << m_barbet_missing[i]; - } - } - barbet_set_combination(mask2); - } - // return true if xor clause is covered. - unsigned sz = c.size(); - for (unsigned i = 0; i < (1ul << sz); ++i) { - if (parity == m_barbet_parity[sz][i] && !barbet_get_combination(i)) { - return false; - } - } - return true; - } - - void ba_solver::barbet_init_parity() { - for (unsigned i = m_barbet_parity.size(); i <= m_max_xor_size; ++i) { - bool_vector bv; - for (unsigned j = 0; j < (1ul << i); ++j) { - bool parity = false; - for (unsigned k = 0; k < i; ++k) { - parity ^= ((j & (1 << k)) != 0); - } - bv.push_back(parity); - } - m_barbet_parity.push_back(bv); + m_clause_removed = true; } } - void ba_solver::init_clause_filter() { - m_clause_filters.reset(); - m_clause_filters.resize(s().num_vars()); - init_clause_filter(s().m_clauses); - init_clause_filter(s().m_learned); - } - - void ba_solver::init_clause_filter(clause_vector& clauses) { - for (clause* cp : clauses) { - clause& c = *cp; - if (c.size() <= m_max_xor_size && all_distinct(c)) { - clause_filter cf(get_clause_filter(c), cp); - for (literal l : c) { - m_clause_filters[l.var()].push_back(cf); - } - } - } - } - - unsigned ba_solver::get_clause_filter(clause& c) { - unsigned filter = 0; - for (literal l : c) { - filter |= 1 << ((l.var() % 32)); - } - return filter; - } - - - /** - * \brief replace (lit0, lit1, lit2), (lit0, ~lit1, ~lit2) - * by (lit0, lit), ~lit x lit1 x lit2 - */ - bool ba_solver::extract_xor(clause& c1, clause& c2) { - SASSERT(c1.size() == 3); - SASSERT(c2.size() == 3); - SASSERT(&c1 != &c2); - literal lit0, lit1, lit2; - if (is_visited(c2[0]) && is_visited(~c2[1]) && is_visited(~c2[2])) { - lit0 = c2[0]; - lit1 = c2[1]; - lit2 = c2[2]; - } - else if (is_visited(c2[1]) && is_visited(~c2[0]) && is_visited(~c2[2])) { - lit0 = c2[1]; - lit1 = c2[0]; - lit2 = c2[2]; - } - else if (is_visited(c2[2]) && is_visited(~c2[0]) && is_visited(~c2[1])) { - lit0 = c2[2]; - lit1 = c2[0]; - lit2 = c2[1]; - } - else { - return false; - } - c1.set_removed(true); - c2.set_removed(true); - m_clause_removed = true; - literal_vector lits; - lits.push_back(lit1); - lits.push_back(lit2); - literal lit = add_xor_def(lits); - lits.reset(); - lits.push_back(lit); - lits.push_back(lit0); - s().mk_clause(lits); - TRACE("ba", tout << c1 << " " << c2 << "\n";); - return true; - } - - bool ba_solver::extract_xor(clause& c, literal l0) { - watch_list & wlist = get_wlist(~l0); - unsigned sz = c.size(); - SASSERT(sz > 3); - for (watched const& w : wlist) { - if (!w.is_clause()) continue; - clause& c2 = s().get_clause(w); - if (c2.size() != sz || c2.was_removed()) continue; - bool is_xor = true; - literal lit1 = null_literal; - literal lit2 = null_literal; - for (literal l : c2) { - if (is_visited(l)) { - // no-op - } - else if (is_visited(~l) && lit1 == null_literal) { - lit1 = l; - } - else if (is_visited(~l) && lit2 == null_literal) { - lit2 = l; - } - else { - is_xor = false; - break; - } - } - if (is_xor && lit2 != null_literal && lit1 != lit2) { - // ensure all literals in c2 are distinct - // this destroys visited, so re-initialize it. - bool distinct = all_distinct(c2); - init_visited(); - for (literal l : c) mark_visited(l); - if (!distinct) { - continue; - } - literal_vector lits; - lits.push_back(lit1); - lits.push_back(lit2); - literal lit = add_xor_def(lits); - lits.reset(); - lits.push_back(lit); - for (literal l : c2) { - if (l != lit1 && l != lit2) { - lits.push_back(l); - } - } - s().mk_clause(lits); - c.set_removed(true); - c2.set_removed(true); - m_clause_removed = true; - TRACE("ba", tout << "xor " << lit1 << " " << lit2 << " : " << c << " " << c2 << "\nnew clause: " << lits << "\n";); - return true; - } - } - return false; - } + void ba_solver::init_visited() { s().init_visited(); } + void ba_solver::mark_visited(literal l) { s().mark_visited(l); } + void ba_solver::mark_visited(bool_var v) { s().mark_visited(v); } + bool ba_solver::is_visited(bool_var v) const { return s().is_visited(v); } + bool ba_solver::is_visited(literal l) const { return s().is_visited(l); } void ba_solver::cleanup_clauses() { if (m_clause_removed) { diff --git a/src/sat/ba_solver.h b/src/sat/ba_solver.h index 4fc439d71..4a085fdd5 100644 --- a/src/sat/ba_solver.h +++ b/src/sat/ba_solver.h @@ -31,6 +31,8 @@ Revision History: #include "util/sorting_network.h" namespace sat { + + class xor_util; class ba_solver : public extension { @@ -287,8 +289,6 @@ namespace sat { // simplification routines - svector m_visited; - unsigned m_visited_ts; vector> m_cnstr_use_list; use_list m_clause_use_list; bool m_simplify_change; @@ -307,11 +307,6 @@ namespace sat { void binary_subsumption(card& c1, literal lit); void clause_subsumption(card& c1, literal lit, clause_vector& removed_clauses); void card_subsumption(card& c1, literal lit); - void init_visited(); - void mark_visited(literal l) { m_visited[l.index()] = m_visited_ts; } - void mark_visited(bool_var v) { mark_visited(literal(v, false)); } - bool is_visited(bool_var v) const { return is_visited(literal(v, false)); } - bool is_visited(literal l) const { return m_visited[l.index()] == m_visited_ts; } unsigned get_num_unblocked_bin(literal l); literal get_min_occurrence_literal(card const& c); void init_use_lists(); @@ -352,7 +347,7 @@ namespace sat { lbool add_assign(constraint& c, literal l); bool incremental_mode() const; void simplify(constraint& c); - void pre_simplify(constraint& c); + void pre_simplify(xor_util& xu, constraint& c); void nullify_tracking_literal(constraint& c); void set_conflict(constraint& c, literal lit); void assign(constraint& c, literal lit); @@ -401,38 +396,6 @@ namespace sat { void simplify(xr& x); void extract_xor(); void merge_xor(); - struct clause_filter { - unsigned m_filter; - clause* m_clause; - clause_filter(unsigned f, clause* cp): - m_filter(f), m_clause(cp) {} - }; - typedef svector bool_vector; - unsigned m_max_xor_size; - vector> m_clause_filters; // index of clauses. - unsigned m_barbet_combination; // bit-mask of parities that have been found - vector m_barbet_parity; // lookup parity for clauses - clause_vector m_barbet_clauses_to_remove; // remove clauses that become xors - unsigned_vector m_barbet_var_position; // position of var in main clause - literal_vector m_barbet_clause; // reference clause with literals sorted according to main clause - unsigned_vector m_barbet_missing; // set of indices not occurring in clause. - void init_clause_filter(); - void init_clause_filter(clause_vector& clauses); - inline void barbet_set_combination(unsigned mask) { m_barbet_combination |= (1 << mask); } - inline bool barbet_get_combination(unsigned mask) const { return (m_barbet_combination & (1 << mask)) != 0; } - void barbet_extract_xor(); - void barbet_init_parity(); - void barbet_extract_xor(clause& c); - bool barbet_extract_xor(bool parity, clause& c, clause& c2); - bool barbet_extract_xor(bool parity, clause& c, literal l1, literal l2); - bool barbet_update_combinations(clause& c, bool parity, unsigned mask); - void barbet_add_xor(bool parity, clause& c); - unsigned get_clause_filter(clause& c); - - vector> m_ternary; - void extract_ternary(clause_vector const& clauses); - bool extract_xor(clause& c, literal l); - bool extract_xor(clause& c1, clause& c2); bool clausify(xr& x); void flush_roots(xr& x); lbool eval(xr const& x) const; @@ -469,6 +432,13 @@ namespace sat { void bail_resolve_conflict(unsigned idx); + void init_visited(); + void mark_visited(literal l); + void mark_visited(bool_var v); + bool is_visited(bool_var v) const; + bool is_visited(literal l) const; + + // access solver inline lbool value(bool_var v) const { return value(literal(v, false)); } inline lbool value(literal lit) const { return m_lookahead ? m_lookahead->value(lit) : m_solver->value(lit); } @@ -563,8 +533,8 @@ namespace sat { constraint* add_xr(literal_vector const& lits, bool learned); literal add_xor_def(literal_vector& lits, bool learned = false); bool all_distinct(literal_vector const& lits); + bool all_distinct(clause const& c); bool all_distinct(xr const& x); - bool all_distinct(clause const& cl); void copy_core(ba_solver* result, bool learned); void copy_constraints(ba_solver* result, ptr_vector const& constraints); diff --git a/src/sat/sat_anf_simplifier.cpp b/src/sat/sat_anf_simplifier.cpp new file mode 100644 index 000000000..41ce35111 --- /dev/null +++ b/src/sat/sat_anf_simplifier.cpp @@ -0,0 +1,274 @@ +/*++ + Copyright (c) 2020 Microsoft Corporation + + Module Name: + + sat_anf_simplifier.cpp + + Abstract: + + Simplification based on ANF format. + + Author: + + Nikolaj Bjorner 2020-01-02 + + --*/ + +#include "sat/sat_anf_simplifier.h" +#include "sat/sat_solver.h" +#include "sat/sat_xor_util.h" +#include "math/grobner/pdd_solver.h" + +namespace sat { + + + class pdd_solver : public dd::solver { + public: + pdd_solver(reslimit& lim, dd::pdd_manager& m): dd::solver(lim, m) {} + }; + + void anf_simplifier::operator()() { + + vector xors; + clause_vector clauses; + svector bins; + m_relevant.reset(); + m_relevant.resize(s.num_vars(), false); + for (clause* cp : s.m_clauses) cp->unmark_used(); + collect_xors(xors); + collect_clauses(clauses, bins); + + dd::pdd_manager m(20, dd::pdd_manager::semantics::mod2_e); + pdd_solver solver(s.rlimit(), m); + configure_solver(solver); + + try { + for (literal_vector const& x : xors) { + add_xor(x, solver); + } + for (clause* cp : clauses) { + add_clause(*cp, solver); + } + for (auto const& b : bins) { + add_bin(b, solver); + } + } + catch (dd::pdd_manager::mem_out) { + IF_VERBOSE(2, verbose_stream() << "(sat.anf memout)\n"); + return; + } + + TRACE("anf_simplifier", solver.display(tout);); + + solver.simplify(); + + TRACE("anf_simplifier", solver.display(tout);); + + unsigned num_units = 0, num_eq = 0; + + for (auto* e : solver.equations()) { + auto const& p = e->poly(); + if (p.is_zero()) { + continue; + } + else if (p.is_val()) { + s.set_conflict(); + break; + } + else if (p.is_unary()) { + // unit + literal lit(p.var(), p.lo().val().is_zero()); + s.assign_unit(lit); + ++num_units; + } + else if (p.is_binary()) { + // equivalence + // x + y + c = 0 + literal x(p.var(), false); + literal y(p.lo().var(), p.lo().lo().val().is_zero()); + s.mk_clause(x, y, true); + s.mk_clause(~x, ~y, true); + ++num_eq; + } + // TBD: could learn binary clauses + // TBD: could try simplify equations using BIG subsumption similar to asymm_branch + } + + IF_VERBOSE(10, solver.display_statistics(verbose_stream() << "(sat.anf\n" ) + << "\n" + << " :num-unit " << num_units + << " :num-eq " << num_eq + << " :num-xor " << xors.size() + << " :num-cls " << clauses.size() + << " :num-bin " << bins.size() + << ")\n"); + } + + void anf_simplifier::collect_clauses(clause_vector & clauses, svector& bins) { + clause_vector oclauses; + for (clause* cp : s.clauses()) { + clause const& c = *cp; + if (c.was_used() || is_too_large(c)) + continue; + else if (is_pre_satisfied(c)) { + oclauses.push_back(cp); + } + else { + clauses.push_back(cp); + } + } + svector obins; + s.collect_bin_clauses(obins, false, false); + unsigned j = 0; + for (auto const& b : obins) { + if (is_pre_satisfied(b)) { + obins[j++] = b; + } + else { + bins.push_back(b); + } + } + obins.shrink(j); + + while (bins.size() + clauses.size() < m_config.m_max_clauses) { + + for (auto const& b : bins) set_relevant(b); + for (clause* cp : clauses) set_relevant(*cp); + + j = 0; + for (auto const& b : obins) { + if (has_relevant_var(b)) { + bins.push_back(b); + } + else { + obins[j++] = b; + } + } + obins.shrink(j); + + if (bins.size() + clauses.size() >= m_config.m_max_clauses) { + break; + } + + j = 0; + for (clause* cp : oclauses) { + clause& c = *cp; + if (has_relevant_var(c)) { + clauses.push_back(cp); + } + else { + oclauses.push_back(cp); + } + } + oclauses.shrink(j); + } + } + + void anf_simplifier::set_relevant(solver::bin_clause const& b) { + set_relevant(b.first); + set_relevant(b.second); + } + + void anf_simplifier::set_relevant(clause const& c) { + for (literal l : c) set_relevant(l); + } + + bool anf_simplifier::is_pre_satisfied(clause const& c) { + for (literal l : c) if (phase_is_true(l)) return true; + return false; + } + + bool anf_simplifier::is_pre_satisfied(solver::bin_clause const& b) { + return phase_is_true(b.first) || phase_is_true(b.second); + } + + bool anf_simplifier::phase_is_true(literal l) { + bool ph = (s.m_best_phase_size > 0) ? s.m_best_phase[l.var()] : s.m_phase[l.var()]; + return l.sign() ? !ph : ph; + } + + bool anf_simplifier::has_relevant_var(clause const& c) { + for (literal l : c) if (is_relevant(l)) return true; + return false; + } + + bool anf_simplifier::has_relevant_var(solver::bin_clause const& b) { + return is_relevant(b.first) || is_relevant(b.second); + } + + void anf_simplifier::collect_xors(vector& xors) { + std::function f = + [&](literal_vector const& l, bool) { xors.push_back(l); }; + + xor_util xu(s); + xu.set(f); + xu.extract_xors(); + for (clause* cp : s.m_clauses) cp->unmark_used(); + for (clause* cp : s.m_learned) cp->unmark_used(); + for (clause* cp : xu.removed_clauses()) cp->mark_used(); + } + + void anf_simplifier::configure_solver(pdd_solver& ps) { + // assign levels to variables. + // use s.def_level as a primary source for the level of a variable. + // secondarily, sort variables randomly (each variable is assigned + // a random, unique, id). + unsigned nv = s.num_vars(); + unsigned_vector l2v(nv), var2id(nv), id2var(nv); + svector> vl(nv); + + for (unsigned i = 0; i < nv; ++i) var2id[i] = i; + shuffle(var2id.size(), var2id.c_ptr(), s.rand()); + for (unsigned i = 0; i < nv; ++i) id2var[var2id[i]] = i; + for (unsigned i = 0; i < nv; ++i) vl[i] = std::make_pair(s.def_level(i), var2id[i]); + std::sort(vl.begin(), vl.end()); + for (unsigned i = 0; i < nv; ++i) l2v[i] = id2var[vl[i].second]; + + ps.get_manager().reset(l2v); + + // set configuration parameters. + dd::solver::config cfg; + cfg.m_expr_size_limit = 1000; + cfg.m_max_steps = 1000; + cfg.m_random_seed = s.rand()(); + cfg.m_enable_exlin = true; + + unsigned max_num_nodes = 1 << 18; + ps.get_manager().set_max_num_nodes(max_num_nodes); + ps.set(cfg); + } + + void anf_simplifier::add_bin(solver::bin_clause const& b, pdd_solver& ps) { + auto& m = ps.get_manager(); + auto v = m.mk_var(b.first.var()); + auto w = m.mk_var(b.second.var()); + if (b.first.sign()) v = ~v; + if (b.second.sign()) w = ~w; + dd::pdd p = v | w; + ps.add(p); + } + + void anf_simplifier::add_clause(clause const& c, pdd_solver& ps) { + auto& m = ps.get_manager(); + dd::pdd p = m.zero(); + for (literal l : c) { + auto v = m.mk_var(l.var()); + if (l.sign()) v = ~v; + p |= v; + } + ps.add(p); + } + + void anf_simplifier::add_xor(literal_vector const& x, pdd_solver& ps) { + auto& m = ps.get_manager(); + dd::pdd p = m.zero(); + for (literal l : x) { + auto v = m.mk_var(l.var()); + if (l.sign()) v = ~v; + p ^= v; + } + ps.add(p); + } + +} diff --git a/src/sat/sat_anf_simplifier.h b/src/sat/sat_anf_simplifier.h new file mode 100644 index 000000000..d008b86ce --- /dev/null +++ b/src/sat/sat_anf_simplifier.h @@ -0,0 +1,75 @@ +/*++ + Copyright (c) 2020 Microsoft Corporation + + Module Name: + + sat_anf_simplifier.h + + Abstract: + + Simplification based on ANF format. + + Author: + + Nikolaj Bjorner 2020-01-02 + + Notes: + + + --*/ +#pragma once; + +#include "util/params.h" +#include "util/statistics.h" +#include "sat/sat_clause.h" +#include "sat/sat_types.h" +#include "sat/sat_solver.h" + +namespace sat { + + class pdd_solver; + + class anf_simplifier { + public: + struct config { + unsigned m_max_clause_size; + unsigned m_max_clauses; + config(): + m_max_clause_size(10), + m_max_clauses(10000) + {} + }; + private: + solver& s; + config m_config; + svector m_relevant; + + void collect_clauses(clause_vector & clauses, svector& bins); + void collect_xors(vector& xors); + void configure_solver(pdd_solver& ps); + void add_clause(clause const& c, pdd_solver& ps); + void add_bin(solver::bin_clause const& b, pdd_solver& ps); + void add_xor(literal_vector const& x, pdd_solver& ps); + + bool is_pre_satisfied(clause const& c); + bool is_pre_satisfied(solver::bin_clause const& b); + bool is_too_large(clause const& c) { return c.size() > m_config.m_max_clause_size; } + bool has_relevant_var(clause const& c); + bool has_relevant_var(solver::bin_clause const& b); + bool is_relevant(literal l) { return is_relevant(l.var()); } + bool is_relevant(bool_var v) { return m_relevant[v]; } + bool phase_is_true(literal l); + + void set_relevant(solver::bin_clause const& b); + void set_relevant(clause const& c); + void set_relevant(literal l) { set_relevant(l.var()); } + void set_relevant(bool_var v) { m_relevant[v] = true; } + + public: + anf_simplifier(solver& s) : s(s) {} + ~anf_simplifier() {} + + void operator()(); + void set(config const& cfg) { m_config = cfg; } + }; +} diff --git a/src/sat/sat_config.cpp b/src/sat/sat_config.cpp index 832b58f6e..7935dc994 100644 --- a/src/sat/sat_config.cpp +++ b/src/sat/sat_config.cpp @@ -100,6 +100,7 @@ namespace sat { m_unit_walk = p.unit_walk(); m_unit_walk_threads = p.unit_walk_threads(); m_binspr = p.binspr(); + m_anf_simplify = p.anf(); m_lookahead_simplify = p.lookahead_simplify(); m_lookahead_double = p.lookahead_double(); m_lookahead_simplify_bca = p.lookahead_simplify_bca(); diff --git a/src/sat/sat_config.h b/src/sat/sat_config.h index e3fc957e4..1da814298 100644 --- a/src/sat/sat_config.h +++ b/src/sat/sat_config.h @@ -120,6 +120,7 @@ namespace sat { unsigned m_unit_walk_threads; bool m_unit_walk; bool m_binspr; + bool m_anf_simplify; bool m_lookahead_simplify; bool m_lookahead_simplify_bca; cutoff_t m_lookahead_cube_cutoff; diff --git a/src/sat/sat_params.pyg b/src/sat/sat_params.pyg index b1bd9666a..8985bdfa1 100644 --- a/src/sat/sat_params.pyg +++ b/src/sat/sat_params.pyg @@ -70,6 +70,7 @@ def_module_params('sat', ('unit_walk', BOOL, False, 'use unit-walk search instead of CDCL'), ('unit_walk_threads', UINT, 0, 'number of unit-walk search threads to find satisfiable solution'), ('binspr', BOOL, False, 'enable SPR inferences of binary propagation redundant clauses. This inprocessing step eliminates models'), + ('anf', BOOL, False, 'enable ANF based simplification in-processing'), ('lookahead.cube.cutoff', SYMBOL, 'depth', 'cutoff type used to create lookahead cubes: depth, freevars, psat, adaptive_freevars, adaptive_psat'), # - depth: the maximal cutoff is fixed to the value of lookahead.cube.depth. # So if the value is 10, at most 1024 cubes will be generated of length 10. diff --git a/src/sat/sat_solver.cpp b/src/sat/sat_solver.cpp index 2e2accfa0..9a0c4975c 100644 --- a/src/sat/sat_solver.cpp +++ b/src/sat/sat_solver.cpp @@ -30,6 +30,7 @@ Revision History: #include "sat/sat_unit_walk.h" #include "sat/sat_ddfw.h" #include "sat/sat_prob.h" +#include "sat/sat_anf_simplifier.h" #if defined(_MSC_VER) && !defined(_M_ARM) && !defined(_M_ARM64) # include #endif @@ -149,7 +150,8 @@ namespace sat { m_phase[v] = src.m_phase[v]; m_best_phase[v] = src.m_best_phase[v]; m_prev_phase[v] = src.m_prev_phase[v]; - + m_level[v] = src.m_level[v]; + // inherit activity: m_activity[v] = src.m_activity[v]; m_case_split_queue.activity_changed_eh(v, false); @@ -236,7 +238,7 @@ namespace sat { // // ----------------------- - bool_var solver::mk_var(bool ext, bool dvar) { + bool_var solver::mk_var(bool ext, bool dvar, unsigned level) { m_model_is_current = false; m_stats.m_mk_var++; bool_var v = m_justification.size(); @@ -248,6 +250,7 @@ namespace sat { m_decision.push_back(dvar); m_eliminated.push_back(false); m_external.push_back(ext); + m_level.push_back(level); m_touched.push_back(0); m_activity.push_back(0); m_mark.push_back(false); @@ -1903,6 +1906,11 @@ namespace sat { lh.collect_statistics(m_aux_stats); } + if (m_config.m_anf_simplify) { + anf_simplifier anf(*this); + anf(); + } + reinit_assumptions(); if (inconsistent()) return; @@ -3852,6 +3860,7 @@ namespace sat { m_decision.shrink(v); m_eliminated.shrink(v); m_external.shrink(v); + m_level.shrink(v); m_touched.shrink(v); m_activity.shrink(v); m_mark.shrink(v); @@ -4933,4 +4942,42 @@ namespace sat { return out; } + bool solver::all_distinct(literal_vector const& lits) { + init_visited(); + for (literal l : lits) { + if (is_visited(l.var())) { + return false; + } + mark_visited(l.var()); + } + return true; + } + + bool solver::all_distinct(clause const& c) { + init_visited(); + for (literal l : c) { + if (is_visited(l.var())) { + return false; + } + mark_visited(l.var()); + } + return true; + } + + void solver::init_visited() { + if (m_visited.empty()) { + m_visited_ts = 0; + } + m_visited_ts++; + if (m_visited_ts == 0) { + m_visited_ts = 1; + m_visited.reset(); + } + while (m_visited.size() < 2*num_vars()) { + m_visited.push_back(0); + } + } + + + }; diff --git a/src/sat/sat_solver.h b/src/sat/sat_solver.h index 4fb587c7c..ab10883f5 100644 --- a/src/sat/sat_solver.h +++ b/src/sat/sat_solver.h @@ -122,6 +122,7 @@ namespace sat { svector m_lit_mark; svector m_eliminated; svector m_external; + unsigned_vector m_level; unsigned_vector m_touched; unsigned m_touch_index; literal_vector m_replay_assign; @@ -163,6 +164,9 @@ namespace sat { clause_wrapper_vector m_clauses_to_reinit; std::string m_reason_unknown; + svector m_visited; + unsigned m_visited_ts; + struct scope { unsigned m_trail_lim; unsigned m_clauses_to_reinit_lim; @@ -202,6 +206,7 @@ namespace sat { friend class mus; friend class drat; friend class ba_solver; + friend class anf_simplifier; friend class parallel; friend class lookahead; friend class local_search; @@ -211,6 +216,7 @@ namespace sat { friend struct mk_stat; friend class elim_vars; friend class scoped_detach; + friend class xor_util; public: solver(params_ref const & p, reslimit& l); ~solver() override; @@ -241,9 +247,10 @@ namespace sat { // // ----------------------- void add_clause(unsigned num_lits, literal * lits, bool learned) override { mk_clause(num_lits, lits, learned); } - bool_var add_var(bool ext) override { return mk_var(ext, true); } + bool_var add_var(bool ext, unsigned level = 0) override { return mk_var(ext, true, level); } + + bool_var mk_var(bool ext = false, bool dvar = true, unsigned level = 0); - bool_var mk_var(bool ext = false, bool dvar = true); clause* mk_clause(literal_vector const& lits, bool learned = false) { return mk_clause(lits.size(), lits.c_ptr(), learned); } clause* mk_clause(unsigned num_lits, literal * lits, bool learned = false); clause* mk_clause(literal l1, literal l2, bool learned = false); @@ -302,6 +309,14 @@ namespace sat { void detach_ter_clause(clause & c); void push_reinit_stack(clause & c); + void init_visited(); + void mark_visited(literal l) { m_visited[l.index()] = m_visited_ts; } + void mark_visited(bool_var v) { mark_visited(literal(v, false)); } + bool is_visited(bool_var v) const { return is_visited(literal(v, false)); } + bool is_visited(literal l) const { return m_visited[l.index()] == m_visited_ts; } + bool all_distinct(literal_vector const& lits); + bool all_distinct(clause const& cl); + // ----------------------- // // Basic @@ -319,6 +334,7 @@ namespace sat { bool was_eliminated(bool_var v) const { return m_eliminated[v]; } void set_eliminated(bool_var v, bool f) override; bool was_eliminated(literal l) const { return was_eliminated(l.var()); } + unsigned def_level(bool_var v) const { return m_level[v]; } unsigned scope_lvl() const { return m_scope_lvl; } unsigned search_lvl() const { return m_search_lvl; } bool at_search_lvl() const { return m_scope_lvl == m_search_lvl; } diff --git a/src/sat/sat_solver_core.h b/src/sat/sat_solver_core.h index b3c43ea6a..3bcd4f143 100644 --- a/src/sat/sat_solver_core.h +++ b/src/sat/sat_solver_core.h @@ -63,7 +63,9 @@ namespace sat { add_clause(3, lits, is_redundant); } // create boolean variable, tagged as external (= true) or internal (can be eliminated). - virtual bool_var add_var(bool ext) = 0; + // the level indicates the depth in an ast the variable comes from. + // variables of higher levels are outputs gates relative to lower levels + virtual bool_var add_var(bool ext, unsigned level = 0) = 0; // update parameters virtual void updt_params(params_ref const& p) {} diff --git a/src/sat/sat_xor_util.cpp b/src/sat/sat_xor_util.cpp new file mode 100644 index 000000000..60a7a3565 --- /dev/null +++ b/src/sat/sat_xor_util.cpp @@ -0,0 +1,234 @@ +/*++ + Copyright (c) 2020 Microsoft Corporation + + Module Name: + + sat_xor_util.cpp + + Abstract: + + xor utilities + + Author: + + Nikolaj Bjorner 2020-01-02 + + Notes: + + + --*/ +#pragma once; + +#include "sat/sat_xor_util.h" +#include "sat/sat_solver.h" + +namespace sat { + + void xor_util::extract_xors() { + m_removed_clauses.reset(); + if (!s.get_config().m_xor_solver) { + return; + } + unsigned max_size = m_max_xor_size; + // we better have enough bits in the combination mask to + // handle clauses up to max_size. + // max_size = 5 -> 32 bits + // max_size = 6 -> 64 bits + SASSERT(sizeof(m_combination)*8 <= (1ull << static_cast(max_size))); + init_clause_filter(); + m_var_position.resize(s.num_vars()); + for (clause* cp : s.m_clauses) { + cp->unmark_used(); + } + for (; max_size > 2; --max_size) { + for (clause* cp : s.m_clauses) { + clause& c = *cp; + if (c.size() == max_size && !c.was_removed() && !c.is_learned() && !c.was_used()) { + extract_xor(c); + } + } + } + m_clause_filters.clear(); + } + + void xor_util::extract_xor(clause& c) { + SASSERT(c.size() > 2); + unsigned filter = get_clause_filter(c); + s.init_visited(); + bool parity = false; + unsigned mask = 0, i = 0; + for (literal l : c) { + m_var_position[l.var()] = i; + s.mark_visited(l.var()); + parity ^= l.sign(); + mask |= (l.sign() << (i++)); + } + m_clauses_to_remove.reset(); + m_clauses_to_remove.push_back(&c); + m_clause.resize(c.size()); + m_combination = 0; + set_combination(mask); + c.mark_used(); + for (literal l : c) { + for (auto const& cf : m_clause_filters[l.var()]) { + if ((filter == (filter | cf.m_filter)) && + !cf.m_clause->was_used() && + extract_xor(parity, c, *cf.m_clause)) { + add_xor(parity, c); + return; + } + } + // loop over binary clauses in watch list + for (watched const & w : s.get_wlist(l)) { + if (w.is_binary_clause() && s.is_visited(w.get_literal().var()) && w.get_literal().index() < l.index()) { + if (extract_xor(parity, c, ~l, w.get_literal())) { + add_xor(parity, c); + return; + } + } + } + l.neg(); + for (watched const & w : s.get_wlist(l)) { + if (w.is_binary_clause() && s.is_visited(w.get_literal().var()) && w.get_literal().index() < l.index()) { + if (extract_xor(parity, c, ~l, w.get_literal())) { + add_xor(parity, c); + return; + } + } + } + } + } + + void xor_util::add_xor(bool parity, clause& c) { + DEBUG_CODE(for (clause* cp : m_clauses_to_remove) VERIFY(cp->was_used());); + m_removed_clauses.append(m_clauses_to_remove); + bool learned = false; + literal_vector lits; + for (literal l : c) { + lits.push_back(literal(l.var(), false)); + s.set_external(l.var()); + } + if (parity) lits[0].neg(); + m_add_xr(lits, learned); + } + + bool xor_util::extract_xor(bool parity, clause& c, literal l1, literal l2) { + SASSERT(s.is_visited(l1.var())); + SASSERT(s.is_visited(l2.var())); + m_missing.reset(); + unsigned mask = 0; + for (unsigned i = 0; i < c.size(); ++i) { + if (c[i].var() == l1.var()) { + mask |= (l1.sign() << i); + } + else if (c[i].var() == l2.var()) { + mask |= (l2.sign() << i); + } + else { + m_missing.push_back(i); + } + } + return update_combinations(c, parity, mask); + } + + bool xor_util::extract_xor(bool parity, clause& c, clause& c2) { + bool parity2 = false; + for (literal l : c2) { + if (!s.is_visited(l.var())) return false; + parity2 ^= l.sign(); + } + if (c2.size() == c.size() && parity2 != parity) { + return false; + } + if (c2.size() == c.size()) { + m_clauses_to_remove.push_back(&c2); + c2.mark_used(); + } + // insert missing + unsigned mask = 0; + m_missing.reset(); + SASSERT(c2.size() <= c.size()); + for (unsigned i = 0; i < c.size(); ++i) { + m_clause[i] = null_literal; + } + for (literal l : c2) { + unsigned pos = m_var_position[l.var()]; + m_clause[pos] = l; + } + for (unsigned j = 0; j < c.size(); ++j) { + literal lit = m_clause[j]; + if (lit == null_literal) { + m_missing.push_back(j); + } + else { + mask |= (m_clause[j].sign() << j); + } + } + + return update_combinations(c, parity, mask); + } + + bool xor_util::update_combinations(clause& c, bool parity, unsigned mask) { + unsigned num_missing = m_missing.size(); + for (unsigned k = 0; k < (1ul << num_missing); ++k) { + unsigned mask2 = mask; + for (unsigned i = 0; i < num_missing; ++i) { + if ((k & (1 << i)) != 0) { + mask2 |= 1ul << m_missing[i]; + } + } + set_combination(mask2); + } + // return true if xor clause is covered. + unsigned sz = c.size(); + for (unsigned i = 0; i < (1ul << sz); ++i) { + if (parity == m_parity[sz][i] && !get_combination(i)) { + return false; + } + } + return true; + } + + void xor_util::init_parity() { + for (unsigned i = m_parity.size(); i <= m_max_xor_size; ++i) { + bool_vector bv; + for (unsigned j = 0; j < (1ul << i); ++j) { + bool parity = false; + for (unsigned k = 0; k < i; ++k) { + parity ^= ((j & (1 << k)) != 0); + } + bv.push_back(parity); + } + m_parity.push_back(bv); + } + } + + void xor_util::init_clause_filter() { + m_clause_filters.reset(); + m_clause_filters.resize(s.num_vars()); + init_clause_filter(s.m_clauses); + init_clause_filter(s.m_learned); + } + + void xor_util::init_clause_filter(clause_vector& clauses) { + for (clause* cp : clauses) { + clause& c = *cp; + if (c.size() <= m_max_xor_size && s.all_distinct(c)) { + clause_filter cf(get_clause_filter(c), cp); + for (literal l : c) { + m_clause_filters[l.var()].push_back(cf); + } + } + } + } + + unsigned xor_util::get_clause_filter(clause& c) { + unsigned filter = 0; + for (literal l : c) { + filter |= 1 << ((l.var() % 32)); + } + return filter; + } + + +} diff --git a/src/sat/sat_xor_util.h b/src/sat/sat_xor_util.h new file mode 100644 index 000000000..183608840 --- /dev/null +++ b/src/sat/sat_xor_util.h @@ -0,0 +1,76 @@ +/*++ + Copyright (c) 2020 Microsoft Corporation + + Module Name: + + sat_xor.h + + Abstract: + + xor utilities + + Author: + + Nikolaj Bjorner 2020-01-02 + + Notes: + + Based on xor extraction paper by Meel & Soos, AAAI 2018. + + --*/ + +#pragma once; + +#include "util/params.h" +#include "util/statistics.h" +#include "sat/sat_clause.h" +#include "sat/sat_types.h" +#include "sat/sat_solver.h" + +namespace sat { + + class xor_util { + solver& s; + struct clause_filter { + unsigned m_filter; + clause* m_clause; + clause_filter(unsigned f, clause* cp): + m_filter(f), m_clause(cp) {} + }; + typedef svector bool_vector; + unsigned m_max_xor_size; + vector> m_clause_filters; // index of clauses. + unsigned m_combination; // bit-mask of parities that have been found + vector m_parity; // lookup parity for clauses + clause_vector m_clauses_to_remove; // remove clauses that become xors + unsigned_vector m_var_position; // position of var in main clause + literal_vector m_clause; // reference clause with literals sorted according to main clause + unsigned_vector m_missing; // set of indices not occurring in clause. + clause_vector m_removed_clauses; + std::function m_add_xr; + + inline void set_combination(unsigned mask) { m_combination |= (1 << mask); } + inline bool get_combination(unsigned mask) const { return (m_combination & (1 << mask)) != 0; } + void extract_xor(clause& c); + void add_xor(bool parity, clause& c); + bool extract_xor(bool parity, clause& c, literal l1, literal l2); + bool extract_xor(bool parity, clause& c, clause& c2); + bool update_combinations(clause& c, bool parity, unsigned mask); + void init_parity(); + void init_clause_filter(); + void init_clause_filter(clause_vector& clauses); + unsigned get_clause_filter(clause& c); + + public: + xor_util(solver& s) : s(s), m_max_xor_size(5) { init_parity(); } + ~xor_util() {} + + void set(std::function& f) { m_add_xr = f; } + + bool parity(unsigned i, unsigned j) const { return m_parity[i][j]; } + unsigned max_xor_size() const { return m_max_xor_size; } + + void extract_xors(); + clause_vector& removed_clauses() { return m_removed_clauses; } + }; +} diff --git a/src/sat/tactic/goal2sat.cpp b/src/sat/tactic/goal2sat.cpp index a3e6b857c..7eb24a6a5 100644 --- a/src/sat/tactic/goal2sat.cpp +++ b/src/sat/tactic/goal2sat.cpp @@ -121,7 +121,7 @@ struct goal2sat::imp { sat::literal mk_true() { if (m_true == sat::null_literal) { // create fake variable to represent true; - m_true = sat::literal(m_solver.add_var(false), false); + m_true = sat::literal(m_solver.add_var(false, 0), false); mk_clause(m_true); // v is true } return m_true; @@ -140,7 +140,7 @@ struct goal2sat::imp { } else { bool ext = m_default_external || !is_uninterp_const(t) || m_interface_vars.contains(t); - sat::bool_var v = m_solver.add_var(ext); + sat::bool_var v = m_solver.add_var(ext, get_depth(t)); m_map.insert(t, v); l = sat::literal(v, sign); TRACE("sat", tout << "new_var: " << v << ": " << mk_bounded_pp(t, m, 2) << "\n";); @@ -248,7 +248,7 @@ struct goal2sat::imp { } else { SASSERT(num <= m_result_stack.size()); - sat::bool_var k = m_solver.add_var(false); + sat::bool_var k = m_solver.add_var(false, get_depth(t)); sat::literal l(k, false); m_cache.insert(t, l); sat::literal * lits = m_result_stack.end() - num; @@ -287,7 +287,7 @@ struct goal2sat::imp { } else { SASSERT(num <= m_result_stack.size()); - sat::bool_var k = m_solver.add_var(false); + sat::bool_var k = m_solver.add_var(false, get_depth(t)); sat::literal l(k, false); m_cache.insert(t, l); // l => /\ lits @@ -330,7 +330,7 @@ struct goal2sat::imp { m_result_stack.reset(); } else { - sat::bool_var k = m_solver.add_var(false); + sat::bool_var k = m_solver.add_var(false, get_depth(n)); sat::literal l(k, false); m_cache.insert(n, l); mk_clause(~l, ~c, t); @@ -367,7 +367,7 @@ struct goal2sat::imp { m_result_stack.reset(); } else { - sat::bool_var k = m_solver.add_var(false); + sat::bool_var k = m_solver.add_var(false, get_depth(t)); sat::literal l(k, false); m_cache.insert(t, l); mk_clause(~l, l1, ~l2); @@ -391,7 +391,7 @@ struct goal2sat::imp { return; } sat::literal_vector lits; - sat::bool_var v = m_solver.add_var(true); + sat::bool_var v = m_solver.add_var(true, get_depth(t)); lits.push_back(sat::literal(v, true)); convert_pb_args(num, lits); // ensure that = is converted to xor @@ -473,7 +473,7 @@ struct goal2sat::imp { m_ext->add_pb_ge(sat::null_bool_var, wlits, k1); } else { - sat::bool_var v = m_solver.add_var(true); + sat::bool_var v = m_solver.add_var(true, get_depth(t)); sat::literal lit(v, sign); m_ext->add_pb_ge(v, wlits, k.get_unsigned()); TRACE("goal2sat", tout << "root: " << root << " lit: " << lit << "\n";); @@ -504,7 +504,7 @@ struct goal2sat::imp { m_ext->add_pb_ge(sat::null_bool_var, wlits, k1); } else { - sat::bool_var v = m_solver.add_var(true); + sat::bool_var v = m_solver.add_var(true, get_depth(t)); sat::literal lit(v, sign); m_ext->add_pb_ge(v, wlits, k.get_unsigned()); TRACE("goal2sat", tout << "root: " << root << " lit: " << lit << "\n";); @@ -518,8 +518,8 @@ struct goal2sat::imp { svector wlits; convert_pb_args(t, wlits); bool base_assert = (root && !sign && m_solver.num_user_scopes() == 0); - sat::bool_var v1 = base_assert ? sat::null_bool_var : m_solver.add_var(true); - sat::bool_var v2 = base_assert ? sat::null_bool_var : m_solver.add_var(true); + sat::bool_var v1 = base_assert ? sat::null_bool_var : m_solver.add_var(true, get_depth(t)); + sat::bool_var v2 = base_assert ? sat::null_bool_var : m_solver.add_var(true, get_depth(t)); m_ext->add_pb_ge(v1, wlits, k.get_unsigned()); k.neg(); for (wliteral& wl : wlits) { @@ -533,7 +533,7 @@ struct goal2sat::imp { } else { sat::literal l1(v1, false), l2(v2, false); - sat::bool_var v = m_solver.add_var(false); + sat::bool_var v = m_solver.add_var(false, get_depth(t)); sat::literal l(v, false); mk_clause(~l, l1); mk_clause(~l, l2); @@ -558,7 +558,7 @@ struct goal2sat::imp { m_ext->add_at_least(sat::null_bool_var, lits, k2); } else { - sat::bool_var v = m_solver.add_var(true); + sat::bool_var v = m_solver.add_var(true, get_depth(t)); sat::literal lit(v, false); m_ext->add_at_least(v, lits, k.get_unsigned()); m_cache.insert(t, lit); @@ -585,7 +585,7 @@ struct goal2sat::imp { m_ext->add_at_least(sat::null_bool_var, lits, k2); } else { - sat::bool_var v = m_solver.add_var(true); + sat::bool_var v = m_solver.add_var(true, get_depth(t)); sat::literal lit(v, false); m_ext->add_at_least(v, lits, k2); m_cache.insert(t, lit); @@ -598,8 +598,8 @@ struct goal2sat::imp { SASSERT(k.is_unsigned()); sat::literal_vector lits; convert_pb_args(t->get_num_args(), lits); - sat::bool_var v1 = (root && !sign) ? sat::null_bool_var : m_solver.add_var(true); - sat::bool_var v2 = (root && !sign) ? sat::null_bool_var : m_solver.add_var(true); + sat::bool_var v1 = (root && !sign) ? sat::null_bool_var : m_solver.add_var(true, get_depth(t)); + sat::bool_var v2 = (root && !sign) ? sat::null_bool_var : m_solver.add_var(true, get_depth(t)); m_ext->add_at_least(v1, lits, k.get_unsigned()); for (sat::literal& l : lits) { l.neg(); @@ -612,7 +612,7 @@ struct goal2sat::imp { } else { sat::literal l1(v1, false), l2(v2, false); - sat::bool_var v = m_solver.add_var(false); + sat::bool_var v = m_solver.add_var(false, get_depth(t)); sat::literal l(v, false); mk_clause(~l, l1); mk_clause(~l, l2);