From 98c5a779b43e22ea79b8dea4f88a3d635644beb9 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Mon, 20 Feb 2017 16:55:00 -0800 Subject: [PATCH] add xor parity solver feature Signed-off-by: Nikolaj Bjorner --- src/ast/rewriter/pb2bv_rewriter.cpp | 15 +- src/sat/card_extension.cpp | 567 ++++++++++++++++++++++---- src/sat/card_extension.h | 109 ++++- src/sat/sat_params.pyg | 2 +- src/sat/sat_solver/inc_sat_solver.cpp | 1 + src/sat/tactic/goal2sat.cpp | 79 +++- 6 files changed, 665 insertions(+), 108 deletions(-) diff --git a/src/ast/rewriter/pb2bv_rewriter.cpp b/src/ast/rewriter/pb2bv_rewriter.cpp index c0a13fa77..24f9bf44b 100644 --- a/src/ast/rewriter/pb2bv_rewriter.cpp +++ b/src/ast/rewriter/pb2bv_rewriter.cpp @@ -50,6 +50,7 @@ struct pb2bv_rewriter::imp { rational m_k; vector m_coeffs; bool m_keep_cardinality_constraints; + unsigned m_min_arity; template expr_ref mk_le_ge(expr_ref_vector& fmls, expr* a, expr* b, expr* bound) { @@ -416,7 +417,8 @@ struct pb2bv_rewriter::imp { bv(m), m_trail(m), m_args(m), - m_keep_cardinality_constraints(true) + m_keep_cardinality_constraints(true), + m_min_arity(8) {} bool mk_app(bool full, func_decl * f, unsigned sz, expr * const* args, expr_ref & result) { @@ -530,27 +532,26 @@ struct pb2bv_rewriter::imp { bool mk_pb(bool full, func_decl * f, unsigned sz, expr * const* args, expr_ref & result) { SASSERT(f->get_family_id() == pb.get_family_id()); if (is_or(f)) { - if (m_keep_cardinality_constraints) return false; result = m.mk_or(sz, args); } else if (pb.is_at_most_k(f) && pb.get_k(f).is_unsigned()) { - if (m_keep_cardinality_constraints) return false; + if (m_keep_cardinality_constraints && f->get_arity() >= m_min_arity) return false; result = m_sort.le(full, pb.get_k(f).get_unsigned(), sz, args); } else if (pb.is_at_least_k(f) && pb.get_k(f).is_unsigned()) { - if (m_keep_cardinality_constraints) return false; + if (m_keep_cardinality_constraints && f->get_arity() >= m_min_arity) return false; result = m_sort.ge(full, pb.get_k(f).get_unsigned(), sz, args); } else if (pb.is_eq(f) && pb.get_k(f).is_unsigned() && pb.has_unit_coefficients(f)) { - if (m_keep_cardinality_constraints) return false; + if (m_keep_cardinality_constraints && f->get_arity() >= m_min_arity) return false; result = m_sort.eq(full, pb.get_k(f).get_unsigned(), sz, args); } else if (pb.is_le(f) && pb.get_k(f).is_unsigned() && pb.has_unit_coefficients(f)) { - if (m_keep_cardinality_constraints) return false; + if (m_keep_cardinality_constraints && f->get_arity() >= m_min_arity) return false; result = m_sort.le(full, pb.get_k(f).get_unsigned(), sz, args); } else if (pb.is_ge(f) && pb.get_k(f).is_unsigned() && pb.has_unit_coefficients(f)) { - if (m_keep_cardinality_constraints) return false; + if (m_keep_cardinality_constraints && f->get_arity() >= m_min_arity) return false; result = m_sort.ge(full, pb.get_k(f).get_unsigned(), sz, args); } else { diff --git a/src/sat/card_extension.cpp b/src/sat/card_extension.cpp index b69b0aae5..fd8986cae 100644 --- a/src/sat/card_extension.cpp +++ b/src/sat/card_extension.cpp @@ -7,7 +7,7 @@ Module Name: Abstract: - Extension for cardinality reasoning. + Extension for cardinality and xor reasoning. Author: @@ -42,6 +42,16 @@ namespace sat { SASSERT(m_size >= m_k && m_k > 0); } + card_extension::xor::xor(unsigned index, literal lit, literal_vector const& lits): + m_index(index), + m_lit(lit), + m_size(lits.size()) + { + for (unsigned i = 0; i < lits.size(); ++i) { + m_lits[i] = lits[i]; + } + } + void card_extension::init_watch(bool_var v) { if (m_var_infos.size() <= static_cast(v)) { m_var_infos.resize(static_cast(v)+100); @@ -120,7 +130,7 @@ namespace sat { if (m_var_infos.size() <= static_cast(lit.var())) { return; } - ptr_vector*& cards = m_var_infos[lit.var()].m_lit_watch[lit.sign()]; + ptr_vector*& cards = m_var_infos[lit.var()].m_card_watch[lit.sign()]; if (!is_tag_empty(cards)) { if (remove(*cards, c)) { cards = set_tag_empty(cards); @@ -128,30 +138,6 @@ namespace sat { } } - ptr_vector* card_extension::set_tag_empty(ptr_vector* c) { - return TAG(ptr_vector*, c, 1); - } - - bool card_extension::is_tag_empty(ptr_vector const* c) { - return !c || GET_TAG(c) == 1; - } - - ptr_vector* card_extension::set_tag_non_empty(ptr_vector* c) { - return UNTAG(ptr_vector*, c); - } - - bool card_extension::remove(ptr_vector& cards, card* c) { - unsigned sz = cards.size(); - for (unsigned j = 0; j < sz; ++j) { - if (cards[j] == c) { - std::swap(cards[j], cards[sz-1]); - cards.pop_back(); - return sz == 1; - } - } - return false; - } - void card_extension::assign(card& c, literal lit) { switch (value(lit)) { case l_true: @@ -183,14 +169,14 @@ namespace sat { void card_extension::watch_literal(card& c, literal lit) { TRACE("sat_verbose", tout << "watch: " << lit << "\n";); init_watch(lit.var()); - ptr_vector* cards = m_var_infos[lit.var()].m_lit_watch[lit.sign()]; + ptr_vector* cards = m_var_infos[lit.var()].m_card_watch[lit.sign()]; if (cards == 0) { cards = alloc(ptr_vector); - m_var_infos[lit.var()].m_lit_watch[lit.sign()] = cards; + m_var_infos[lit.var()].m_card_watch[lit.sign()] = cards; } else if (is_tag_empty(cards)) { cards = set_tag_non_empty(cards); - m_var_infos[lit.var()].m_lit_watch[lit.sign()] = cards; + m_var_infos[lit.var()].m_card_watch[lit.sign()] = cards; } TRACE("sat_verbose", tout << "insert: " << lit.var() << " " << lit.sign() << "\n";); cards->push_back(&c); @@ -202,6 +188,155 @@ namespace sat { s().set_conflict(justification::mk_ext_justification(c.index()), ~lit); SASSERT(s().inconsistent()); } + + void card_extension::clear_watch(xor& x) { + unwatch_literal(x[0], &x); + unwatch_literal(x[1], &x); + } + + void card_extension::unwatch_literal(literal lit, xor* c) { + if (m_var_infos.size() <= static_cast(lit.var())) { + return; + } + xor_watch* xors = m_var_infos[lit.var()].m_xor_watch; + if (!is_tag_empty(xors)) { + if (remove(*xors, c)) { + xors = set_tag_empty(xors); + } + } + } + + bool card_extension::parity(xor const& x, unsigned offset) const { + bool odd = false; + unsigned sz = x.size(); + for (unsigned i = offset; i < sz; ++i) { + SASSERT(value(x[i]) != l_undef); + if (value(x[i]) == l_true) { + odd = !odd; + } + } + return odd; + } + + void card_extension::init_watch(xor& x, bool is_true) { + clear_watch(x); + if (x.lit().sign() == is_true) { + x.negate(); + } + unsigned sz = x.size(); + unsigned j = 0; + for (unsigned i = 0; i < sz && j < 2; ++i) { + if (value(x[i]) == l_undef) { + x.swap(i, j); + ++j; + } + } + switch (j) { + case 0: + if (!parity(x, 0)) { + set_conflict(x, x[0]); + } + break; + case 1: + assign(x, parity(x, 1) ? ~x[0] : x[0]); + break; + default: + SASSERT(j == 2); + watch_literal(x, x[0]); + watch_literal(x, x[1]); + break; + } + } + + void card_extension::assign(xor& x, literal lit) { + switch (value(lit)) { + case l_true: + break; + case l_false: + set_conflict(x, lit); + break; + default: + m_stats.m_num_propagations++; + m_num_propagations_since_pop++; + if (s().m_config.m_drat) { + svector ps; + literal_vector lits; + lits.push_back(~x.lit()); + for (unsigned i = 1; i < x.size(); ++i) { + lits.push_back(x[i]); + } + lits.push_back(lit); + ps.push_back(drat::premise(drat::s_ext(), x.lit())); + s().m_drat.add(lits, ps); + } + s().assign(lit, justification::mk_ext_justification(x.index())); + break; + } + } + + void card_extension::watch_literal(xor& x, literal lit) { + TRACE("sat_verbose", tout << "watch: " << lit << "\n";); + init_watch(lit.var()); + xor_watch*& xors = m_var_infos[lit.var()].m_xor_watch; + if (xors == 0) { + xors = alloc(ptr_vector); + } + else if (is_tag_empty(xors)) { + xors = set_tag_non_empty(xors); + } + xors->push_back(&x); + TRACE("sat_verbose", tout << "insert: " << lit.var() << " " << lit.sign() << "\n";); + } + + + void card_extension::set_conflict(xor& x, literal lit) { + TRACE("sat", display(tout, x, true); ); + SASSERT(validate_conflict(x)); + s().set_conflict(justification::mk_ext_justification(x.index()), ~lit); + SASSERT(s().inconsistent()); + } + + lbool card_extension::add_assign(xor& x, literal alit) { + // literal is assigned + unsigned sz = x.size(); + TRACE("sat", tout << "assign: " << x.lit() << ": " << ~alit << "@" << lvl(~alit) << "\n";); + + SASSERT(value(alit) != l_undef); + SASSERT(value(x.lit()) == l_true); + unsigned index = 0; + for (; index <= 2; ++index) { + if (x[index].var() == alit.var()) break; + } + if (index == 2) { + // literal is no longer watched. + return l_undef; + } + SASSERT(x[index].var() == alit.var()); + + // find a literal to swap with: + for (unsigned i = 2; i < sz; ++i) { + literal lit2 = x[i]; + if (value(lit2) == l_undef) { + x.swap(index, i); + watch_literal(x, lit2); + return l_undef; + } + } + if (index == 0) { + x.swap(0, 1); + } + // alit resides at index 1. + SASSERT(x[1].var() == alit.var()); + if (value(x[0]) == l_undef) { + bool p = parity(x, 1); + assign(x, p ? ~x[0] : x[0]); + } + else if (!parity(x, 0)) { + set_conflict(x, x[0]); + } + return s().inconsistent() ? l_false : l_true; + } + void card_extension::normalize_active_coeffs() { while (!m_active_var_set.empty()) m_active_var_set.erase(); @@ -288,6 +423,8 @@ namespace sat { unsigned init_marks = m_num_marks; + vector jus; + do { if (offset == 0) { @@ -349,9 +486,22 @@ namespace sat { } case justification::EXT_JUSTIFICATION: { unsigned index = js.get_ext_justification_idx(); - card& c = *m_constraints[index]; - m_bound += offset * c.k(); - process_card(c, offset); + if (is_card_index(index)) { + card& c = index2card(index); + m_bound += offset * c.k(); + process_card(c, offset); + } + else { + // jus.push_back(js); + m_lemma.reset(); + m_bound += offset; + inc_coeff(consequent, offset); + get_xor_antecedents(idx, m_lemma); + // get_antecedents(consequent, index, m_lemma); + for (unsigned i = 0; i < m_lemma.size(); ++i) { + process_antecedent(~m_lemma[i], offset); + } + } break; } default: @@ -424,7 +574,6 @@ namespace sat { lbool val = m_solver->value(v); bool is_true = val == l_true; bool append = coeff != 0 && val != l_undef && (coeff < 0 == is_true); - if (append) { literal lit(v, !is_true); if (lvl(lit) == m_conflict_lvl) { @@ -440,6 +589,17 @@ namespace sat { } } + if (jus.size() > 1) { + std::cout << jus.size() << "\n"; + for (unsigned i = 0; i < jus.size(); ++i) { + s().display_justification(std::cout, jus[i]); std::cout << "\n"; + } + std::cout << m_lemma << "\n"; + active2pb(m_A); + display(std::cout, m_A); + } + + if (slack >= 0) { IF_VERBOSE(2, verbose_stream() << "(sat.card bail slack objective not met " << slack << ")\n";); goto bail_out; @@ -564,7 +724,7 @@ namespace sat { return p; } - card_extension::card_extension(): m_solver(0) { + card_extension::card_extension(): m_solver(0), m_has_xor(false) { TRACE("sat", tout << this << "\n";); } @@ -578,33 +738,170 @@ namespace sat { } void card_extension::add_at_least(bool_var v, literal_vector const& lits, unsigned k) { - unsigned index = m_constraints.size(); + unsigned index = 2*m_cards.size(); card* c = new (memory::allocate(card::get_obj_size(lits.size()))) card(index, literal(v, false), lits, k); - m_constraints.push_back(c); + m_cards.push_back(c); init_watch(v); m_var_infos[v].m_card = c; m_var_trail.push_back(v); } + void card_extension::add_xor(bool_var v, literal_vector const& lits) { + m_has_xor = true; + unsigned index = 2*m_xors.size()+1; + xor* x = new (memory::allocate(xor::get_obj_size(lits.size()))) xor(index, literal(v, false), lits); + m_xors.push_back(x); + init_watch(v); + m_var_infos[v].m_xor = x; + m_var_trail.push_back(v); + } + + void card_extension::propagate(literal l, ext_constraint_idx idx, bool & keep) { UNREACHABLE(); } - void card_extension::get_antecedents(literal l, ext_justification_idx idx, literal_vector & r) { - card& c = *m_constraints[idx]; - - DEBUG_CODE( - bool found = false; - for (unsigned i = 0; !found && i < c.k(); ++i) { - found = c[i] == l; + + void card_extension::ensure_parity_size(bool_var v) { + if (m_parity_marks.size() <= static_cast(v)) { + m_parity_marks.resize(static_cast(v) + 1, 0); + } + } + + unsigned card_extension::get_parity(bool_var v) { + return m_parity_marks.get(v, 0); + } + + void card_extension::inc_parity(bool_var v) { + ensure_parity_size(v); + m_parity_marks[v]++; + } + + void card_extension::reset_parity(bool_var v) { + ensure_parity_size(v); + m_parity_marks[v] = 0; + } + + /** + \brief perform parity resolution on xor premises. + The idea is to collect premises based on xor resolvents. + Variables that are repeated an even number of times cancel out. + */ + void card_extension::get_xor_antecedents(unsigned index, literal_vector& r) { + literal_vector const& lits = s().m_trail; + literal l = lits[index + 1]; + unsigned level = lvl(l); + bool_var v = l.var(); + SASSERT(s().m_justification[v].get_kind() == justification::EXT_JUSTIFICATION); + SASSERT(!is_card_index(s().m_justification[v].get_ext_justification_idx())); + + unsigned num_marks = 0; + unsigned count = 0; + while (true) { + ++count; + justification js = s().m_justification[v]; + if (js.get_kind() == justification::EXT_JUSTIFICATION) { + unsigned idx = js.get_ext_justification_idx(); + if (is_card_index(idx)) { + r.push_back(l); + } + else { + xor& x = index2xor(idx); + if (lvl(x.lit()) > 0) r.push_back(x.lit()); + if (x[1].var() == l.var()) { + x.swap(0, 1); + } + SASSERT(x[0].var() == l.var()); + for (unsigned i = 1; i < x.size(); ++i) { + literal lit(value(x[i]) == l_true ? x[i] : ~x[i]); + inc_parity(lit.var()); + if (true || lvl(lit) == level) { + ++num_marks; + } + else { + m_parity_trail.push_back(lit); + } + } + } } - SASSERT(found);); + else { + r.push_back(l); + } + while (num_marks > 0) { + l = lits[index]; + v = l.var(); + unsigned n = get_parity(v); + if (n > 0) { + reset_parity(v); + if (n > 1) { + IF_VERBOSE(2, verbose_stream() << "parity greater than 1: " << l << " " << n << "\n";); + } + if (n % 2 == 1) { + break; + } + IF_VERBOSE(2, verbose_stream() << "skip even parity: " << l << "\n";); + --num_marks; + } + --index; + } + if (num_marks == 0) { + break; + } + --index; + --num_marks; + } + + // now walk the defined literals + + for (unsigned i = 0; i < m_parity_trail.size(); ++i) { + literal lit = m_parity_trail[i]; + if (get_parity(lit.var()) % 2 == 1) { + r.push_back(lit); + } + else { + IF_VERBOSE(2, verbose_stream() << "skip even parity: " << lit << "\n";); + } + reset_parity(lit.var()); + } + m_parity_trail.reset(); + } + + void card_extension::get_antecedents(literal l, ext_justification_idx idx, literal_vector & r) { + if (is_card_index(idx)) { + card& c = index2card(idx); - r.push_back(c.lit()); - SASSERT(value(c.lit()) == l_true); - for (unsigned i = c.k(); i < c.size(); ++i) { - SASSERT(value(c[i]) == l_false); - r.push_back(~c[i]); + DEBUG_CODE( + bool found = false; + for (unsigned i = 0; !found && i < c.k(); ++i) { + found = c[i] == l; + } + SASSERT(found);); + + r.push_back(c.lit()); + SASSERT(value(c.lit()) == l_true); + for (unsigned i = c.k(); i < c.size(); ++i) { + SASSERT(value(c[i]) == l_false); + r.push_back(~c[i]); + } + } + else { + xor& x = index2xor(idx); + r.push_back(x.lit()); + TRACE("sat", display(tout << l << " ", x, true);); + SASSERT(value(x.lit()) == l_true); + SASSERT(x[0].var() == l.var() || x[1].var() == l.var()); + if (x[0].var() == l.var()) { + SASSERT(value(x[1]) != l_undef); + r.push_back(value(x[1]) == l_true ? x[1] : ~x[1]); + } + else { + SASSERT(value(x[0]) != l_undef); + r.push_back(value(x[0]) == l_true ? x[0] : ~x[0]); + } + for (unsigned i = 2; i < x.size(); ++i) { + SASSERT(value(x[i]) != l_undef); + r.push_back(value(x[i]) == l_true ? x[i] : ~x[i]); + } } } @@ -670,10 +967,11 @@ namespace sat { if (s().inconsistent()) return; if (v >= m_var_infos.size()) return; var_info& vinfo = m_var_infos[v]; - ptr_vector* cards = vinfo.m_lit_watch[!l.sign()]; - //TRACE("sat", tout << "retrieve: " << v << " " << !l.sign() << "\n";); - //TRACE("sat", tout << "asserted: " << l << " " << (cards ? "non-empty" : "empty") << "\n";); - static unsigned is_empty = 0, non_empty = 0; + ptr_vector* cards = vinfo.m_card_watch[!l.sign()]; + card* crd = vinfo.m_card; + xor* x = vinfo.m_xor; + ptr_vector* xors = vinfo.m_xor_watch; + if (!is_tag_empty(cards)) { ptr_vector::iterator begin = cards->begin(); ptr_vector::iterator it = begin, it2 = it, end = cards->end(); @@ -702,14 +1000,56 @@ namespace sat { } cards->set_end(it2); if (cards->empty()) { - m_var_infos[v].m_lit_watch[!l.sign()] = set_tag_empty(cards); + m_var_infos[v].m_card_watch[!l.sign()] = set_tag_empty(cards); } } - card* crd = vinfo.m_card; if (crd != 0 && !s().inconsistent()) { init_watch(*crd, !l.sign()); } + if (m_has_xor && !s().inconsistent()) { + asserted_xor(l, xors, x); + } + } + + + void card_extension::asserted_xor(literal l, ptr_vector* xors, xor* x) { + TRACE("sat", tout << l << " " << !is_tag_empty(xors) << " " << (x != 0) << "\n";); + if (!is_tag_empty(xors)) { + ptr_vector::iterator begin = xors->begin(); + ptr_vector::iterator it = begin, it2 = it, end = xors->end(); + for (; it != end; ++it) { + xor& c = *(*it); + if (value(c.lit()) != l_true) { + continue; + } + switch (add_assign(c, ~l)) { + case l_false: // conflict + for (; it != end; ++it, ++it2) { + *it2 = *it; + } + SASSERT(s().inconsistent()); + xors->set_end(it2); + return; + case l_undef: // watch literal was swapped + break; + case l_true: // unit propagation, keep watching the literal + if (it2 != it) { + *it2 = *it; + } + ++it2; + break; + } + } + xors->set_end(it2); + if (xors->empty()) { + m_var_infos[l.var()].m_xor_watch = set_tag_empty(xors); + } + } + + if (x != 0 && !s().inconsistent()) { + init_watch(*x, !l.sign()); + } } check_result card_extension::check() { return CR_DONE; } @@ -730,6 +1070,10 @@ namespace sat { clear_watch(*c); m_var_infos[v].m_card = 0; dealloc(c); + xor* x = m_var_infos[v].m_xor; + clear_watch(*x); + m_var_infos[v].m_xor = 0; + dealloc(x); } } m_var_lim.resize(new_lim); @@ -743,22 +1087,30 @@ namespace sat { extension* card_extension::copy(solver* s) { card_extension* result = alloc(card_extension); result->set_solver(s); - for (unsigned i = 0; i < m_constraints.size(); ++i) { + for (unsigned i = 0; i < m_cards.size(); ++i) { literal_vector lits; - card& c = *m_constraints[i]; + card& c = *m_cards[i]; for (unsigned i = 0; i < c.size(); ++i) { lits.push_back(c[i]); } result->add_at_least(c.lit().var(), lits, c.k()); } + for (unsigned i = 0; i < m_xors.size(); ++i) { + literal_vector lits; + xor& x = *m_xors[i]; + for (unsigned i = 0; i < x.size(); ++i) { + lits.push_back(x[i]); + } + result->add_xor(x.lit().var(), lits); + } return result; } void card_extension::find_mutexes(literal_vector& lits, vector & mutexes) { literal_set slits(lits); bool change = false; - for (unsigned i = 0; i < m_constraints.size(); ++i) { - card& c = *m_constraints[i]; + for (unsigned i = 0; i < m_cards.size(); ++i) { + card& c = *m_cards[i]; if (c.size() == c.k() + 1) { literal_vector mux; for (unsigned j = 0; j < c.size(); ++j) { @@ -786,10 +1138,10 @@ namespace sat { } } - void card_extension::display_watch(std::ostream& out, bool_var v, bool sign) const { - watch const* w = m_var_infos[v].m_lit_watch[sign]; + void card_extension::display_watch(std::ostream& out, bool_var v, bool sign) const { + card_watch const* w = m_var_infos[v].m_card_watch[sign]; if (!is_tag_empty(w)) { - watch const& wl = *w; + card_watch const& wl = *w; out << literal(v, sign) << " |-> "; for (unsigned i = 0; i < wl.size(); ++i) { out << wl[i]->lit() << " "; @@ -798,6 +1150,18 @@ namespace sat { } } + void card_extension::display_watch(std::ostream& out, bool_var v) const { + xor_watch const* w = m_var_infos[v].m_xor_watch; + if (!is_tag_empty(w)) { + xor_watch const& wl = *w; + out << "v" << v << " |-> "; + for (unsigned i = 0; i < wl.size(); ++i) { + out << wl[i]->lit() << " "; + } + out << "\n"; + } + } + void card_extension::display(std::ostream& out, ineq& ineq) const { for (unsigned i = 0; i < ineq.m_lits.size(); ++i) { out << ineq.m_coeffs[i] << "*" << ineq.m_lits[i] << " "; @@ -805,6 +1169,35 @@ namespace sat { out << ">= " << ineq.m_k << "\n"; } + void card_extension::display(std::ostream& out, xor& x, bool values) const { + out << "xor " << x.lit(); + if (x.lit() != null_literal && values) { + out << "@(" << value(x.lit()); + if (value(x.lit()) != l_undef) { + out << ":" << lvl(x.lit()); + } + out << "): "; + } + else { + out << ": "; + } + for (unsigned i = 0; i < x.size(); ++i) { + literal l = x[i]; + out << l; + if (values) { + out << "@(" << value(l); + if (value(l) != l_undef) { + out << ":" << lvl(l); + } + out << ") "; + } + else { + out << " "; + } + } + out << "\n"; + } + void card_extension::display(std::ostream& out, card& c, bool values) const { out << c.lit() << "[" << c.size() << "]"; if (c.lit() != null_literal && values) { @@ -838,23 +1231,33 @@ namespace sat { for (unsigned vi = 0; vi < m_var_infos.size(); ++vi) { display_watch(out, vi, false); display_watch(out, vi, true); + display_watch(out, vi); } for (unsigned vi = 0; vi < m_var_infos.size(); ++vi) { card* c = m_var_infos[vi].m_card; - if (c) { - display(out, *c, false); - } + if (c) display(out, *c, false); + xor* x = m_var_infos[vi].m_xor; + if (x) display(out, *x, false); } return out; } std::ostream& card_extension::display_justification(std::ostream& out, ext_justification_idx idx) const { - card& c = *m_constraints[idx]; - out << "bound " << c.lit() << ": "; - for (unsigned i = 0; i < c.size(); ++i) { - out << c[i] << " "; + if (is_card_index(idx)) { + card& c = index2card(idx); + out << "bound " << c.lit() << ": "; + for (unsigned i = 0; i < c.size(); ++i) { + out << c[i] << " "; + } + out << ">= " << c.k(); + } + else { + xor& x = index2xor(idx); + out << "xor " << x.lit() << ": "; + for (unsigned i = 0; i < x.size(); ++i) { + out << x[i] << " "; + } } - out << ">= " << c.k(); return out; } @@ -870,6 +1273,9 @@ namespace sat { } return false; } + bool card_extension::validate_conflict(xor& x) { + return !parity(x, 0); + } bool card_extension::validate_unit_propagation(card const& c) { if (value(c.lit()) != l_true) return false; for (unsigned i = c.k(); i < c.size(); ++i) { @@ -933,12 +1339,23 @@ namespace sat { } case justification::EXT_JUSTIFICATION: { unsigned index = js.get_ext_justification_idx(); - card& c = *m_constraints[index]; - p.reset(offset*c.k()); - for (unsigned i = 0; i < c.size(); ++i) { - p.push(c[i], offset); + if (is_card_index(index)) { + card& c = index2card(index); + p.reset(offset*c.k()); + for (unsigned i = 0; i < c.size(); ++i) { + p.push(c[i], offset); + } + p.push(~c.lit(), offset*c.k()); + } + else { + literal_vector ls; + get_antecedents(lit, index, ls); + p.reset(offset); + for (unsigned i = 0; i < ls.size(); ++i) { + p.push(~ls[i], offset); + } + p.push(~index2xor(index).lit(), offset); } - p.push(~c.lit(), offset*c.k()); break; } default: diff --git a/src/sat/card_extension.h b/src/sat/card_extension.h index 65a991ec2..4e0c10cd2 100644 --- a/src/sat/card_extension.h +++ b/src/sat/card_extension.h @@ -32,9 +32,7 @@ namespace sat { void reset() { memset(this, 0, sizeof(*this)); } }; - // class card_allocator; class card { - //friend class card_allocator; unsigned m_index; literal m_lit; unsigned m_k; @@ -53,6 +51,22 @@ namespace sat { void negate(); }; + class xor { + unsigned m_index; + literal m_lit; + unsigned m_size; + literal m_lits[0]; + public: + static size_t get_obj_size(unsigned num_lits) { return sizeof(xor) + num_lits * sizeof(literal); } + xor(unsigned index, literal lit, literal_vector const& lits); + unsigned index() const { return m_index; } + literal lit() const { return m_lit; } + literal operator[](unsigned i) const { return m_lits[i]; } + unsigned size() const { return m_size; } + void swap(unsigned i, unsigned j) { std::swap(m_lits[i], m_lits[j]); } + void negate() { m_lits[0].neg(); } + }; + struct ineq { literal_vector m_lits; unsigned_vector m_coeffs; @@ -61,29 +75,48 @@ namespace sat { void push(literal l, unsigned c) { m_lits.push_back(l); m_coeffs.push_back(c); } }; - typedef ptr_vector watch; + typedef ptr_vector card_watch; + typedef ptr_vector xor_watch; struct var_info { - watch* m_lit_watch[2]; - card* m_card; - var_info(): m_card(0) { - m_lit_watch[0] = 0; - m_lit_watch[1] = 0; + card_watch* m_card_watch[2]; + xor_watch* m_xor_watch; + card* m_card; + xor* m_xor; + var_info(): m_xor_watch(0), m_card(0), m_xor(0) { + m_card_watch[0] = 0; + m_card_watch[1] = 0; } void reset() { dealloc(m_card); - dealloc(card_extension::set_tag_non_empty(m_lit_watch[0])); - dealloc(card_extension::set_tag_non_empty(m_lit_watch[1])); + dealloc(m_xor); + dealloc(card_extension::set_tag_non_empty(m_card_watch[0])); + dealloc(card_extension::set_tag_non_empty(m_card_watch[1])); + dealloc(card_extension::set_tag_non_empty(m_xor_watch)); } }; + + template + static ptr_vector* set_tag_empty(ptr_vector* c) { + return TAG(ptr_vector*, c, 1); + } + + template + static bool is_tag_empty(ptr_vector const* c) { + return !c || GET_TAG(c) == 1; + } + + template + static ptr_vector* set_tag_non_empty(ptr_vector* c) { + return UNTAG(ptr_vector*, c); + } + - static ptr_vector* set_tag_empty(ptr_vector* c); - static bool is_tag_empty(ptr_vector const* c); - static ptr_vector* set_tag_non_empty(ptr_vector* c); solver* m_solver; stats m_stats; - ptr_vector m_constraints; + ptr_vector m_cards; + ptr_vector m_xors; // watch literals svector m_var_infos; @@ -98,8 +131,14 @@ namespace sat { int m_bound; tracked_uint_set m_active_var_set; literal_vector m_lemma; -// literal_vector m_literals; unsigned m_num_propagations_since_pop; + bool m_has_xor; + unsigned_vector m_parity_marks; + literal_vector m_parity_trail; + void ensure_parity_size(bool_var v); + unsigned get_parity(bool_var v); + void inc_parity(bool_var v); + void reset_parity(bool_var v); solver& s() const { return *m_solver; } void init_watch(card& c, bool is_true); @@ -111,13 +150,44 @@ namespace sat { void clear_watch(card& c); void reset_coeffs(); void reset_marked_literals(); + void unwatch_literal(literal w, card* c); + + // xor specific functionality + void clear_watch(xor& x); + void watch_literal(xor& x, literal lit); + void unwatch_literal(literal w, xor* x); + void init_watch(xor& x, bool is_true); + void assign(xor& x, literal lit); + void set_conflict(xor& x, literal lit); + bool parity(xor const& x, unsigned offset) const; + lbool add_assign(xor& x, literal alit); + void asserted_xor(literal l, ptr_vector* xors, xor* x); + + bool is_card_index(unsigned idx) const { return 0 == (idx & 0x1); } + card& index2card(unsigned idx) const { SASSERT(is_card_index(idx)); return *m_cards[idx >> 1]; } + xor& index2xor(unsigned idx) const { SASSERT(!is_card_index(idx)); return *m_xors[idx >> 1]; } + void get_xor_antecedents(unsigned index, literal_vector& r); + + + template + bool remove(ptr_vector& ts, T* t) { + unsigned sz = ts.size(); + for (unsigned j = 0; j < sz; ++j) { + if (ts[j] == t) { + std::swap(ts[j], ts[sz-1]); + ts.pop_back(); + return sz == 1; + } + } + return false; + } + + inline lbool value(literal lit) const { return m_solver->value(lit); } inline unsigned lvl(literal lit) const { return m_solver->lvl(lit); } inline unsigned lvl(bool_var v) const { return m_solver->lvl(v); } - void unwatch_literal(literal w, card* c); - bool remove(ptr_vector& cards, card* c); void normalize_active_coeffs(); void inc_coeff(literal l, int offset); @@ -131,6 +201,7 @@ namespace sat { // validation utilities bool validate_conflict(card& c); + bool validate_conflict(xor& x); bool validate_assign(literal_vector const& lits, literal lit); bool validate_lemma(); bool validate_unit_propagation(card const& c); @@ -143,12 +214,16 @@ namespace sat { void display(std::ostream& out, ineq& p) const; void display(std::ostream& out, card& c, bool values) const; + void display(std::ostream& out, xor& c, bool values) const; + void display_watch(std::ostream& out, bool_var v) const; void display_watch(std::ostream& out, bool_var v, bool sign) const; + public: card_extension(); virtual ~card_extension(); virtual void set_solver(solver* s) { m_solver = s; } void add_at_least(bool_var v, literal_vector const& lits, unsigned k); + void add_xor(bool_var v, literal_vector const& lits); virtual void propagate(literal l, ext_constraint_idx idx, bool & keep); virtual bool resolve_conflict(); virtual void get_antecedents(literal l, ext_justification_idx idx, literal_vector & r); diff --git a/src/sat/sat_params.pyg b/src/sat/sat_params.pyg index e3ca12b73..a51e72e25 100644 --- a/src/sat/sat_params.pyg +++ b/src/sat/sat_params.pyg @@ -26,5 +26,5 @@ def_module_params('sat', ('dimacs.core', BOOL, False, 'extract core from DIMACS benchmarks'), ('drat.file', SYMBOL, '', 'file to dump DRAT proofs'), ('drat.check', BOOL, False, 'build up internal proof and check'), - ('cardinality.solver', BOOL, True, 'use cardinality solver'), + ('cardinality.solver', BOOL, False, 'use cardinality/xor solver'), )) diff --git a/src/sat/sat_solver/inc_sat_solver.cpp b/src/sat/sat_solver/inc_sat_solver.cpp index d2ae541a8..83b8362cf 100644 --- a/src/sat/sat_solver/inc_sat_solver.cpp +++ b/src/sat/sat_solver/inc_sat_solver.cpp @@ -217,6 +217,7 @@ public: sat_params p1(p); m_params.set_bool("elim_vars", false); m_params.set_bool("keep_cardinality_constraints", p1.cardinality_solver()); + m_params.set_bool("cardinality_solver", p1.cardinality_solver()); m_solver.updt_params(m_params); m_optimize_model = m_params.get_bool("optimize_model", false); diff --git a/src/sat/tactic/goal2sat.cpp b/src/sat/tactic/goal2sat.cpp index 87d69b35e..93367013e 100644 --- a/src/sat/tactic/goal2sat.cpp +++ b/src/sat/tactic/goal2sat.cpp @@ -65,6 +65,7 @@ struct goal2sat::imp { expr_ref_vector m_trail; expr_ref_vector m_interpreted_atoms; bool m_default_external; + bool m_cardinality_solver; imp(ast_manager & _m, params_ref const & p, sat::solver & s, atom2bool_var & map, dep2asm_map& dep2asm, bool default_external): m(_m), @@ -83,6 +84,8 @@ struct goal2sat::imp { void updt_params(params_ref const & p) { m_ite_extra = p.get_bool("ite_extra", true); m_max_memory = megabytes_to_bytes(p.get_uint("max_memory", UINT_MAX)); + m_cardinality_solver = p.get_bool("cardinality_solver", false); + std::cout << p << "\n"; } void throw_op_not_handled(std::string const& s) { @@ -339,7 +342,7 @@ struct goal2sat::imp { } } - void convert_iff(app * t, bool root, bool sign) { + void convert_iff2(app * t, bool root, bool sign) { TRACE("goal2sat", tout << "convert_iff " << root << " " << sign << "\n" << mk_ismt2_pp(t, m) << "\n";); unsigned sz = m_result_stack.size(); SASSERT(sz >= 2); @@ -372,8 +375,33 @@ struct goal2sat::imp { } } - void convert_pb_args(app* t, sat::literal_vector& lits) { - unsigned num_args = t->get_num_args(); + void convert_iff(app * t, bool root, bool sign) { + TRACE("goal2sat", tout << "convert_iff " << root << " " << sign << "\n" << mk_ismt2_pp(t, m) << "\n";); + unsigned sz = m_result_stack.size(); + unsigned num = get_num_args(t); + SASSERT(sz >= num && num >= 2); + if (num == 2) { + convert_iff2(t, root, sign); + return; + } + sat::literal_vector lits; + convert_pb_args(num, lits); + sat::bool_var v = m_solver.mk_var(true); + ensure_extension(); + if (lits.size() % 2 == 0) lits[0].neg(); + m_ext->add_xor(v, lits); + sat::literal lit(v, sign); + if (root) { + m_result_stack.reset(); + mk_clause(lit); + } + else { + m_result_stack.shrink(sz - num); + m_result_stack.push_back(lit); + } + } + + void convert_pb_args(unsigned num_args, sat::literal_vector& lits) { unsigned sz = m_result_stack.size(); for (unsigned i = 0; i < num_args; ++i) { sat::literal lit(m_result_stack[sz - num_args + i]); @@ -396,7 +424,7 @@ struct goal2sat::imp { SASSERT(k.is_unsigned()); sat::literal_vector lits; unsigned sz = m_result_stack.size(); - convert_pb_args(t, lits); + convert_pb_args(t->get_num_args(), lits); sat::bool_var v = m_solver.mk_var(true); sat::literal lit(v, sign); m_ext->add_at_least(v, lits, k.get_unsigned()); @@ -415,7 +443,7 @@ struct goal2sat::imp { SASSERT(k.is_unsigned()); sat::literal_vector lits; unsigned sz = m_result_stack.size(); - convert_pb_args(t, lits); + convert_pb_args(t->get_num_args(), lits); for (unsigned i = 0; i < lits.size(); ++i) { lits[i].neg(); } @@ -434,7 +462,7 @@ struct goal2sat::imp { void convert_eq_k(app* t, rational k, bool root, bool sign) { SASSERT(k.is_unsigned()); sat::literal_vector lits; - convert_pb_args(t, lits); + convert_pb_args(t->get_num_args(), lits); sat::bool_var v1 = m_solver.mk_var(true); sat::bool_var v2 = m_solver.mk_var(true); sat::literal l1(v1, false), l2(v2, false); @@ -528,6 +556,41 @@ struct goal2sat::imp { UNREACHABLE(); } } + + + unsigned get_num_args(app* t) { + + if (m.is_iff(t) && m_cardinality_solver) { + unsigned n = 2; + while (m.is_iff(t->get_arg(1))) { + ++n; + t = to_app(t->get_arg(1)); + } + return n; + } + else { + return t->get_num_args(); + } + } + + expr* get_arg(app* t, unsigned idx) { + if (m.is_iff(t) && m_cardinality_solver) { + while (idx >= 1) { + SASSERT(m.is_iff(t)); + t = to_app(t->get_arg(1)); + --idx; + } + if (m.is_iff(t)) { + return t->get_arg(idx); + } + else { + return t; + } + } + else { + return t->get_arg(idx); + } + } void process(expr * n) { //SASSERT(m_result_stack.empty()); @@ -559,9 +622,9 @@ struct goal2sat::imp { visit(t->get_arg(0), root, !sign); continue; } - unsigned num = t->get_num_args(); + unsigned num = get_num_args(t); while (fr.m_idx < num) { - expr * arg = t->get_arg(fr.m_idx); + expr * arg = get_arg(t, fr.m_idx); fr.m_idx++; if (!visit(arg, false, false)) goto loop;