/*++ Copyright (c) 2017 Microsoft Corporation Module Name: xor_solver.cpp Abstract: Extension for xr reasoning. Author: Nikolaj Bjorner (nbjorner) 2017-01-30 Revision History: --*/ #include "sat/sat_types.h" #include "sat/smt/ba_solver.h" #include "sat/sat_simplifier_params.hpp" #include "sat/sat_xor_finder.h" namespace sat { // -------------------- // xr: lbool ba_solver::add_assign(xr& x, literal alit) { // literal is assigned unsigned sz = x.size(); TRACE("ba", tout << "assign: " << ~alit << "@" << lvl(~alit) << " " << x << "\n"; display(tout, x, true); ); VERIFY(x.lit() == null_literal); SASSERT(value(alit) != l_undef); unsigned index = (x[1].var() == alit.var()) ? 1 : 0; VERIFY(x[index].var() == alit.var()); // find a literal to swap with: for (unsigned i = 2; i < sz; ++i) { literal lit = x[i]; if (value(lit) == l_undef) { x.swap(index, i); x.unwatch_literal(*this, ~alit); // alit gets unwatched by propagate_core because we return l_undef x.watch_literal(*this, lit); x.watch_literal(*this, ~lit); TRACE("ba", tout << "swap in: " << lit << " " << x << "\n";); return l_undef; } } if (index == 0) { x.swap(0, 1); } // alit resides at index 1. VERIFY(x[1].var() == alit.var()); if (value(x[0]) == l_undef) { bool p = x.parity(*this, 1); assign(x, p ? ~x[0] : x[0]); } else if (!x.parity(*this, 0)) { set_conflict(x, ~x[1]); } return inconsistent() ? l_false : l_true; } void ba_solver::add_xr(literal_vector const& lits) { add_xr(lits, false); } bool ba_solver::all_distinct(literal_vector const& lits) { return s().all_distinct(lits); } bool ba_solver::all_distinct(clause const& c) { return s().all_distinct(c); } bool ba_solver::all_distinct(xr const& x) { init_visited(); for (literal l : x) { if (is_visited(l.var())) { return false; } mark_visited(l.var()); } return true; } literal ba_solver::add_xor_def(literal_vector& lits, bool learned) { unsigned sz = lits.size(); SASSERT (sz > 1); VERIFY(all_distinct(lits)); init_visited(); bool parity1 = true; for (literal l : lits) { mark_visited(l.var()); parity1 ^= l.sign(); } for (auto const & w : get_wlist(lits[0])) { if (w.get_kind() != watched::EXT_CONSTRAINT) continue; constraint& c = index2constraint(w.get_ext_constraint_idx()); if (!c.is_xr()) continue; xr& x = c.to_xr(); if (sz + 1 != x.size()) continue; bool is_match = true; literal l0 = null_literal; bool parity2 = true; for (literal l : x) { if (!is_visited(l.var())) { if (l0 == null_literal) { l0 = l; } else { is_match = false; break; } } else { parity2 ^= l.sign(); } } if (is_match) { SASSERT(all_distinct(x)); if (parity1 == parity2) l0.neg(); if (!learned && x.learned()) { set_non_learned(x); } return l0; } } bool_var v = s().mk_var(true, true); literal lit(v, false); lits.push_back(~lit); add_xr(lits, learned); return lit; } constraint* ba_solver::add_xr(literal_vector const& _lits, bool learned) { literal_vector lits; u_map var2sign; bool sign = false, odd = false; for (literal lit : _lits) { if (var2sign.find(lit.var(), sign)) { var2sign.erase(lit.var()); odd ^= (sign ^ lit.sign()); } else { var2sign.insert(lit.var(), lit.sign()); } } for (auto const& kv : var2sign) { lits.push_back(literal(kv.m_key, kv.m_value)); } if (odd && !lits.empty()) { lits[0].neg(); } switch (lits.size()) { case 0: if (!odd) s().set_conflict(justification(0)); return nullptr; case 1: s().assign_scoped(lits[0]); return nullptr; default: break; } void * mem = m_allocator.allocate(xr::get_obj_size(lits.size())); constraint_base::initialize(mem, this); xr* x = new (constraint_base::ptr2mem(mem)) xr(next_id(), lits); x->set_learned(learned); add_constraint(x); return x; } /** \brief perform parity resolution on xr premises. The idea is to collect premises based on xr resolvents. Variables that are repeated an even number of times cancel out. */ void ba_solver::get_xr_antecedents(literal l, unsigned index, justification js, literal_vector& r) { unsigned level = lvl(l); bool_var v = l.var(); SASSERT(js.get_kind() == justification::EXT_JUSTIFICATION); TRACE("ba", tout << l << ": " << js << "\n"; for (unsigned i = 0; i <= index; ++i) tout << s().m_trail[i] << " "; tout << "\n"; s().display_units(tout); ); unsigned num_marks = 0; while (true) { TRACE("ba", tout << "process: " << l << " " << js << "\n";); if (js.get_kind() == justification::EXT_JUSTIFICATION) { constraint& c = index2constraint(js.get_ext_justification_idx()); TRACE("ba", tout << c << "\n";); if (!c.is_xr()) { r.push_back(l); } else { xr& x = c.to_xr(); if (x[1].var() == l.var()) { x.swap(0, 1); } VERIFY(x[0].var() == l.var()); for (unsigned i = 1; i < x.size(); ++i) { literal lit(value(x[i]) == l_true ? x[i] : ~x[i]); inc_parity(lit.var()); if (lvl(lit) == level) { TRACE("ba", tout << "mark: " << lit << "\n";); ++num_marks; } else { m_parity_trail.push_back(lit); } } } } else { r.push_back(l); } bool found = false; while (num_marks > 0) { l = s().m_trail[index]; v = l.var(); unsigned n = get_parity(v); if (n > 0 && lvl(l) == level) { reset_parity(v); num_marks -= n; if (n % 2 == 1) { found = true; break; } } --index; } if (!found) { break; } --index; js = s().m_justification[v]; } // now walk the defined literals for (literal lit : m_parity_trail) { if (get_parity(lit.var()) % 2 == 1) { r.push_back(lit); } else { // IF_VERBOSE(2, verbose_stream() << "skip even parity: " << lit << "\n";); } reset_parity(lit.var()); } m_parity_trail.reset(); TRACE("ba", tout << r << "\n";); } void ba_solver::pre_simplify() { VERIFY(s().at_base_lvl()); if (s().inconsistent()) return; m_constraint_removed = false; xor_finder xf(s()); for (unsigned sz = m_constraints.size(), i = 0; i < sz; ++i) pre_simplify(xf, *m_constraints[i]); for (unsigned sz = m_learned.size(), i = 0; i < sz; ++i) pre_simplify(xf, *m_learned[i]); bool change = m_constraint_removed; cleanup_constraints(); if (change) { // remove non-used variables. init_use_lists(); remove_unused_defs(); set_non_external(); } } void ba_solver::simplify(xr& x) { if (x.learned()) { x.set_removed(); m_constraint_removed = true; } } void ba_solver::get_antecedents(literal l, xr const& x, literal_vector& r) { if (x.lit() != null_literal) r.push_back(x.lit()); // TRACE("ba", display(tout << l << " ", x, true);); SASSERT(x.lit() == null_literal || value(x.lit()) == l_true); SASSERT(x[0].var() == l.var() || x[1].var() == l.var()); if (x[0].var() == l.var()) { SASSERT(value(x[1]) != l_undef); r.push_back(value(x[1]) == l_true ? x[1] : ~x[1]); } else { SASSERT(value(x[0]) != l_undef); r.push_back(value(x[0]) == l_true ? x[0] : ~x[0]); } for (unsigned i = 2; i < x.size(); ++i) { SASSERT(value(x[i]) != l_undef); r.push_back(value(x[i]) == l_true ? x[i] : ~x[i]); } } void ba_solver::pre_simplify(xor_finder& xf, constraint& c) { if (c.is_xr() && c.size() <= xf.max_xor_size()) { unsigned sz = c.size(); literal_vector lits; bool parity = false; xr const& x = c.to_xr(); for (literal lit : x) { parity ^= lit.sign(); } // IF_VERBOSE(0, verbose_stream() << "blast: " << c << "\n"); for (unsigned i = 0; i < (1ul << sz); ++i) { if (xf.parity(sz, i) == parity) { lits.reset(); for (unsigned j = 0; j < sz; ++j) { lits.push_back(literal(x[j].var(), (0 != (i & (1 << j))))); } // IF_VERBOSE(0, verbose_stream() << lits << "\n"); s().mk_clause(lits); } } c.set_removed(); m_constraint_removed = true; } } bool ba_solver::clausify(xr& x) { return false; } // merge xors that contain cut variable void ba_solver::merge_xor() { unsigned sz = s().num_vars(); for (unsigned i = 0; i < sz; ++i) { literal lit(i, false); unsigned index = lit.index(); if (m_cnstr_use_list[index].size() == 2) { constraint& c1 = *m_cnstr_use_list[index][0]; constraint& c2 = *m_cnstr_use_list[index][1]; if (c1.is_xr() && c2.is_xr() && m_clause_use_list.get(lit).empty() && m_clause_use_list.get(~lit).empty()) { bool unique = true; for (watched w : get_wlist(lit)) { if (w.is_binary_clause()) unique = false; } for (watched w : get_wlist(~lit)) { if (w.is_binary_clause()) unique = false; } if (!unique) continue; xr const& x1 = c1.to_xr(); xr const& x2 = c2.to_xr(); literal_vector lits, dups; bool parity = false; init_visited(); for (literal l : x1) { mark_visited(l.var()); lits.push_back(l); } for (literal l : x2) { if (is_visited(l.var())) { dups.push_back(l); } else { lits.push_back(l); } } init_visited(); for (literal l : dups) mark_visited(l); unsigned j = 0; for (unsigned i = 0; i < lits.size(); ++i) { literal l = lits[i]; if (is_visited(l)) { // skip } else if (is_visited(~l)) { parity ^= true; } else { lits[j++] = l; } } lits.shrink(j); if (!parity) lits[0].neg(); IF_VERBOSE(1, verbose_stream() << "binary " << lits << " : " << c1 << " " << c2 << "\n"); c1.set_removed(); c2.set_removed(); add_xr(lits, !c1.learned() && !c2.learned()); m_constraint_removed = true; } } } } void ba_solver::extract_xor() { xor_finder xf(s()); std::function f = [this](literal_vector const& l) { add_xr(l, false); }; xf.set(f); clause_vector clauses(s().clauses()); xf(clauses); for (clause* cp : xf.removed_clauses()) { cp->set_removed(true); m_clause_removed = true; } } } // xr specific functionality lbool add_assign(xr& x, literal alit); void get_xr_antecedents(literal l, unsigned index, justification js, literal_vector& r); void get_antecedents(literal l, xr const& x, literal_vector & r); void simplify(xr& x); void extract_xor(); void merge_xor(); bool clausify(xr& x); void flush_roots(xr& x); lbool eval(xr const& x) const; lbool eval(model const& m, xr const& x) const; bool validate_conflict(xr const& x) const; constraint* add_xr(literal_vector const& lits, bool learned); literal add_xor_def(literal_vector& lits, bool learned = false); bool all_distinct(xr const& x); expr_ref get_xor(std::function& l2e, xr const& x); void add_xr(literal_vector const& lits); #include "sat/sat_xor_finder.h"