3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2026-03-17 02:30:01 +00:00

add xor parity solver feature

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
This commit is contained in:
Nikolaj Bjorner 2017-02-20 16:55:00 -08:00
parent cb050998e5
commit 98c5a779b4
6 changed files with 665 additions and 108 deletions

View file

@ -7,7 +7,7 @@ Module Name:
Abstract:
Extension for cardinality reasoning.
Extension for cardinality and xor reasoning.
Author:
@ -42,6 +42,16 @@ namespace sat {
SASSERT(m_size >= m_k && m_k > 0);
}
card_extension::xor::xor(unsigned index, literal lit, literal_vector const& lits):
m_index(index),
m_lit(lit),
m_size(lits.size())
{
for (unsigned i = 0; i < lits.size(); ++i) {
m_lits[i] = lits[i];
}
}
void card_extension::init_watch(bool_var v) {
if (m_var_infos.size() <= static_cast<unsigned>(v)) {
m_var_infos.resize(static_cast<unsigned>(v)+100);
@ -120,7 +130,7 @@ namespace sat {
if (m_var_infos.size() <= static_cast<unsigned>(lit.var())) {
return;
}
ptr_vector<card>*& cards = m_var_infos[lit.var()].m_lit_watch[lit.sign()];
ptr_vector<card>*& cards = m_var_infos[lit.var()].m_card_watch[lit.sign()];
if (!is_tag_empty(cards)) {
if (remove(*cards, c)) {
cards = set_tag_empty(cards);
@ -128,30 +138,6 @@ namespace sat {
}
}
ptr_vector<card_extension::card>* card_extension::set_tag_empty(ptr_vector<card>* c) {
return TAG(ptr_vector<card>*, c, 1);
}
bool card_extension::is_tag_empty(ptr_vector<card> const* c) {
return !c || GET_TAG(c) == 1;
}
ptr_vector<card_extension::card>* card_extension::set_tag_non_empty(ptr_vector<card>* c) {
return UNTAG(ptr_vector<card>*, c);
}
bool card_extension::remove(ptr_vector<card>& cards, card* c) {
unsigned sz = cards.size();
for (unsigned j = 0; j < sz; ++j) {
if (cards[j] == c) {
std::swap(cards[j], cards[sz-1]);
cards.pop_back();
return sz == 1;
}
}
return false;
}
void card_extension::assign(card& c, literal lit) {
switch (value(lit)) {
case l_true:
@ -183,14 +169,14 @@ namespace sat {
void card_extension::watch_literal(card& c, literal lit) {
TRACE("sat_verbose", tout << "watch: " << lit << "\n";);
init_watch(lit.var());
ptr_vector<card>* cards = m_var_infos[lit.var()].m_lit_watch[lit.sign()];
ptr_vector<card>* cards = m_var_infos[lit.var()].m_card_watch[lit.sign()];
if (cards == 0) {
cards = alloc(ptr_vector<card>);
m_var_infos[lit.var()].m_lit_watch[lit.sign()] = cards;
m_var_infos[lit.var()].m_card_watch[lit.sign()] = cards;
}
else if (is_tag_empty(cards)) {
cards = set_tag_non_empty(cards);
m_var_infos[lit.var()].m_lit_watch[lit.sign()] = cards;
m_var_infos[lit.var()].m_card_watch[lit.sign()] = cards;
}
TRACE("sat_verbose", tout << "insert: " << lit.var() << " " << lit.sign() << "\n";);
cards->push_back(&c);
@ -202,6 +188,155 @@ namespace sat {
s().set_conflict(justification::mk_ext_justification(c.index()), ~lit);
SASSERT(s().inconsistent());
}
void card_extension::clear_watch(xor& x) {
unwatch_literal(x[0], &x);
unwatch_literal(x[1], &x);
}
void card_extension::unwatch_literal(literal lit, xor* c) {
if (m_var_infos.size() <= static_cast<unsigned>(lit.var())) {
return;
}
xor_watch* xors = m_var_infos[lit.var()].m_xor_watch;
if (!is_tag_empty(xors)) {
if (remove(*xors, c)) {
xors = set_tag_empty(xors);
}
}
}
bool card_extension::parity(xor const& x, unsigned offset) const {
bool odd = false;
unsigned sz = x.size();
for (unsigned i = offset; i < sz; ++i) {
SASSERT(value(x[i]) != l_undef);
if (value(x[i]) == l_true) {
odd = !odd;
}
}
return odd;
}
void card_extension::init_watch(xor& x, bool is_true) {
clear_watch(x);
if (x.lit().sign() == is_true) {
x.negate();
}
unsigned sz = x.size();
unsigned j = 0;
for (unsigned i = 0; i < sz && j < 2; ++i) {
if (value(x[i]) == l_undef) {
x.swap(i, j);
++j;
}
}
switch (j) {
case 0:
if (!parity(x, 0)) {
set_conflict(x, x[0]);
}
break;
case 1:
assign(x, parity(x, 1) ? ~x[0] : x[0]);
break;
default:
SASSERT(j == 2);
watch_literal(x, x[0]);
watch_literal(x, x[1]);
break;
}
}
void card_extension::assign(xor& x, literal lit) {
switch (value(lit)) {
case l_true:
break;
case l_false:
set_conflict(x, lit);
break;
default:
m_stats.m_num_propagations++;
m_num_propagations_since_pop++;
if (s().m_config.m_drat) {
svector<drat::premise> ps;
literal_vector lits;
lits.push_back(~x.lit());
for (unsigned i = 1; i < x.size(); ++i) {
lits.push_back(x[i]);
}
lits.push_back(lit);
ps.push_back(drat::premise(drat::s_ext(), x.lit()));
s().m_drat.add(lits, ps);
}
s().assign(lit, justification::mk_ext_justification(x.index()));
break;
}
}
void card_extension::watch_literal(xor& x, literal lit) {
TRACE("sat_verbose", tout << "watch: " << lit << "\n";);
init_watch(lit.var());
xor_watch*& xors = m_var_infos[lit.var()].m_xor_watch;
if (xors == 0) {
xors = alloc(ptr_vector<xor>);
}
else if (is_tag_empty(xors)) {
xors = set_tag_non_empty(xors);
}
xors->push_back(&x);
TRACE("sat_verbose", tout << "insert: " << lit.var() << " " << lit.sign() << "\n";);
}
void card_extension::set_conflict(xor& x, literal lit) {
TRACE("sat", display(tout, x, true); );
SASSERT(validate_conflict(x));
s().set_conflict(justification::mk_ext_justification(x.index()), ~lit);
SASSERT(s().inconsistent());
}
lbool card_extension::add_assign(xor& x, literal alit) {
// literal is assigned
unsigned sz = x.size();
TRACE("sat", tout << "assign: " << x.lit() << ": " << ~alit << "@" << lvl(~alit) << "\n";);
SASSERT(value(alit) != l_undef);
SASSERT(value(x.lit()) == l_true);
unsigned index = 0;
for (; index <= 2; ++index) {
if (x[index].var() == alit.var()) break;
}
if (index == 2) {
// literal is no longer watched.
return l_undef;
}
SASSERT(x[index].var() == alit.var());
// find a literal to swap with:
for (unsigned i = 2; i < sz; ++i) {
literal lit2 = x[i];
if (value(lit2) == l_undef) {
x.swap(index, i);
watch_literal(x, lit2);
return l_undef;
}
}
if (index == 0) {
x.swap(0, 1);
}
// alit resides at index 1.
SASSERT(x[1].var() == alit.var());
if (value(x[0]) == l_undef) {
bool p = parity(x, 1);
assign(x, p ? ~x[0] : x[0]);
}
else if (!parity(x, 0)) {
set_conflict(x, x[0]);
}
return s().inconsistent() ? l_false : l_true;
}
void card_extension::normalize_active_coeffs() {
while (!m_active_var_set.empty()) m_active_var_set.erase();
@ -288,6 +423,8 @@ namespace sat {
unsigned init_marks = m_num_marks;
vector<justification> jus;
do {
if (offset == 0) {
@ -349,9 +486,22 @@ namespace sat {
}
case justification::EXT_JUSTIFICATION: {
unsigned index = js.get_ext_justification_idx();
card& c = *m_constraints[index];
m_bound += offset * c.k();
process_card(c, offset);
if (is_card_index(index)) {
card& c = index2card(index);
m_bound += offset * c.k();
process_card(c, offset);
}
else {
// jus.push_back(js);
m_lemma.reset();
m_bound += offset;
inc_coeff(consequent, offset);
get_xor_antecedents(idx, m_lemma);
// get_antecedents(consequent, index, m_lemma);
for (unsigned i = 0; i < m_lemma.size(); ++i) {
process_antecedent(~m_lemma[i], offset);
}
}
break;
}
default:
@ -424,7 +574,6 @@ namespace sat {
lbool val = m_solver->value(v);
bool is_true = val == l_true;
bool append = coeff != 0 && val != l_undef && (coeff < 0 == is_true);
if (append) {
literal lit(v, !is_true);
if (lvl(lit) == m_conflict_lvl) {
@ -440,6 +589,17 @@ namespace sat {
}
}
if (jus.size() > 1) {
std::cout << jus.size() << "\n";
for (unsigned i = 0; i < jus.size(); ++i) {
s().display_justification(std::cout, jus[i]); std::cout << "\n";
}
std::cout << m_lemma << "\n";
active2pb(m_A);
display(std::cout, m_A);
}
if (slack >= 0) {
IF_VERBOSE(2, verbose_stream() << "(sat.card bail slack objective not met " << slack << ")\n";);
goto bail_out;
@ -564,7 +724,7 @@ namespace sat {
return p;
}
card_extension::card_extension(): m_solver(0) {
card_extension::card_extension(): m_solver(0), m_has_xor(false) {
TRACE("sat", tout << this << "\n";);
}
@ -578,33 +738,170 @@ namespace sat {
}
void card_extension::add_at_least(bool_var v, literal_vector const& lits, unsigned k) {
unsigned index = m_constraints.size();
unsigned index = 2*m_cards.size();
card* c = new (memory::allocate(card::get_obj_size(lits.size()))) card(index, literal(v, false), lits, k);
m_constraints.push_back(c);
m_cards.push_back(c);
init_watch(v);
m_var_infos[v].m_card = c;
m_var_trail.push_back(v);
}
void card_extension::add_xor(bool_var v, literal_vector const& lits) {
m_has_xor = true;
unsigned index = 2*m_xors.size()+1;
xor* x = new (memory::allocate(xor::get_obj_size(lits.size()))) xor(index, literal(v, false), lits);
m_xors.push_back(x);
init_watch(v);
m_var_infos[v].m_xor = x;
m_var_trail.push_back(v);
}
void card_extension::propagate(literal l, ext_constraint_idx idx, bool & keep) {
UNREACHABLE();
}
void card_extension::get_antecedents(literal l, ext_justification_idx idx, literal_vector & r) {
card& c = *m_constraints[idx];
DEBUG_CODE(
bool found = false;
for (unsigned i = 0; !found && i < c.k(); ++i) {
found = c[i] == l;
void card_extension::ensure_parity_size(bool_var v) {
if (m_parity_marks.size() <= static_cast<unsigned>(v)) {
m_parity_marks.resize(static_cast<unsigned>(v) + 1, 0);
}
}
unsigned card_extension::get_parity(bool_var v) {
return m_parity_marks.get(v, 0);
}
void card_extension::inc_parity(bool_var v) {
ensure_parity_size(v);
m_parity_marks[v]++;
}
void card_extension::reset_parity(bool_var v) {
ensure_parity_size(v);
m_parity_marks[v] = 0;
}
/**
\brief perform parity resolution on xor premises.
The idea is to collect premises based on xor resolvents.
Variables that are repeated an even number of times cancel out.
*/
void card_extension::get_xor_antecedents(unsigned index, literal_vector& r) {
literal_vector const& lits = s().m_trail;
literal l = lits[index + 1];
unsigned level = lvl(l);
bool_var v = l.var();
SASSERT(s().m_justification[v].get_kind() == justification::EXT_JUSTIFICATION);
SASSERT(!is_card_index(s().m_justification[v].get_ext_justification_idx()));
unsigned num_marks = 0;
unsigned count = 0;
while (true) {
++count;
justification js = s().m_justification[v];
if (js.get_kind() == justification::EXT_JUSTIFICATION) {
unsigned idx = js.get_ext_justification_idx();
if (is_card_index(idx)) {
r.push_back(l);
}
else {
xor& x = index2xor(idx);
if (lvl(x.lit()) > 0) r.push_back(x.lit());
if (x[1].var() == l.var()) {
x.swap(0, 1);
}
SASSERT(x[0].var() == l.var());
for (unsigned i = 1; i < x.size(); ++i) {
literal lit(value(x[i]) == l_true ? x[i] : ~x[i]);
inc_parity(lit.var());
if (true || lvl(lit) == level) {
++num_marks;
}
else {
m_parity_trail.push_back(lit);
}
}
}
}
SASSERT(found););
else {
r.push_back(l);
}
while (num_marks > 0) {
l = lits[index];
v = l.var();
unsigned n = get_parity(v);
if (n > 0) {
reset_parity(v);
if (n > 1) {
IF_VERBOSE(2, verbose_stream() << "parity greater than 1: " << l << " " << n << "\n";);
}
if (n % 2 == 1) {
break;
}
IF_VERBOSE(2, verbose_stream() << "skip even parity: " << l << "\n";);
--num_marks;
}
--index;
}
if (num_marks == 0) {
break;
}
--index;
--num_marks;
}
// now walk the defined literals
for (unsigned i = 0; i < m_parity_trail.size(); ++i) {
literal lit = m_parity_trail[i];
if (get_parity(lit.var()) % 2 == 1) {
r.push_back(lit);
}
else {
IF_VERBOSE(2, verbose_stream() << "skip even parity: " << lit << "\n";);
}
reset_parity(lit.var());
}
m_parity_trail.reset();
}
void card_extension::get_antecedents(literal l, ext_justification_idx idx, literal_vector & r) {
if (is_card_index(idx)) {
card& c = index2card(idx);
r.push_back(c.lit());
SASSERT(value(c.lit()) == l_true);
for (unsigned i = c.k(); i < c.size(); ++i) {
SASSERT(value(c[i]) == l_false);
r.push_back(~c[i]);
DEBUG_CODE(
bool found = false;
for (unsigned i = 0; !found && i < c.k(); ++i) {
found = c[i] == l;
}
SASSERT(found););
r.push_back(c.lit());
SASSERT(value(c.lit()) == l_true);
for (unsigned i = c.k(); i < c.size(); ++i) {
SASSERT(value(c[i]) == l_false);
r.push_back(~c[i]);
}
}
else {
xor& x = index2xor(idx);
r.push_back(x.lit());
TRACE("sat", display(tout << l << " ", x, true););
SASSERT(value(x.lit()) == l_true);
SASSERT(x[0].var() == l.var() || x[1].var() == l.var());
if (x[0].var() == l.var()) {
SASSERT(value(x[1]) != l_undef);
r.push_back(value(x[1]) == l_true ? x[1] : ~x[1]);
}
else {
SASSERT(value(x[0]) != l_undef);
r.push_back(value(x[0]) == l_true ? x[0] : ~x[0]);
}
for (unsigned i = 2; i < x.size(); ++i) {
SASSERT(value(x[i]) != l_undef);
r.push_back(value(x[i]) == l_true ? x[i] : ~x[i]);
}
}
}
@ -670,10 +967,11 @@ namespace sat {
if (s().inconsistent()) return;
if (v >= m_var_infos.size()) return;
var_info& vinfo = m_var_infos[v];
ptr_vector<card>* cards = vinfo.m_lit_watch[!l.sign()];
//TRACE("sat", tout << "retrieve: " << v << " " << !l.sign() << "\n";);
//TRACE("sat", tout << "asserted: " << l << " " << (cards ? "non-empty" : "empty") << "\n";);
static unsigned is_empty = 0, non_empty = 0;
ptr_vector<card>* cards = vinfo.m_card_watch[!l.sign()];
card* crd = vinfo.m_card;
xor* x = vinfo.m_xor;
ptr_vector<xor>* xors = vinfo.m_xor_watch;
if (!is_tag_empty(cards)) {
ptr_vector<card>::iterator begin = cards->begin();
ptr_vector<card>::iterator it = begin, it2 = it, end = cards->end();
@ -702,14 +1000,56 @@ namespace sat {
}
cards->set_end(it2);
if (cards->empty()) {
m_var_infos[v].m_lit_watch[!l.sign()] = set_tag_empty(cards);
m_var_infos[v].m_card_watch[!l.sign()] = set_tag_empty(cards);
}
}
card* crd = vinfo.m_card;
if (crd != 0 && !s().inconsistent()) {
init_watch(*crd, !l.sign());
}
if (m_has_xor && !s().inconsistent()) {
asserted_xor(l, xors, x);
}
}
void card_extension::asserted_xor(literal l, ptr_vector<xor>* xors, xor* x) {
TRACE("sat", tout << l << " " << !is_tag_empty(xors) << " " << (x != 0) << "\n";);
if (!is_tag_empty(xors)) {
ptr_vector<xor>::iterator begin = xors->begin();
ptr_vector<xor>::iterator it = begin, it2 = it, end = xors->end();
for (; it != end; ++it) {
xor& c = *(*it);
if (value(c.lit()) != l_true) {
continue;
}
switch (add_assign(c, ~l)) {
case l_false: // conflict
for (; it != end; ++it, ++it2) {
*it2 = *it;
}
SASSERT(s().inconsistent());
xors->set_end(it2);
return;
case l_undef: // watch literal was swapped
break;
case l_true: // unit propagation, keep watching the literal
if (it2 != it) {
*it2 = *it;
}
++it2;
break;
}
}
xors->set_end(it2);
if (xors->empty()) {
m_var_infos[l.var()].m_xor_watch = set_tag_empty(xors);
}
}
if (x != 0 && !s().inconsistent()) {
init_watch(*x, !l.sign());
}
}
check_result card_extension::check() { return CR_DONE; }
@ -730,6 +1070,10 @@ namespace sat {
clear_watch(*c);
m_var_infos[v].m_card = 0;
dealloc(c);
xor* x = m_var_infos[v].m_xor;
clear_watch(*x);
m_var_infos[v].m_xor = 0;
dealloc(x);
}
}
m_var_lim.resize(new_lim);
@ -743,22 +1087,30 @@ namespace sat {
extension* card_extension::copy(solver* s) {
card_extension* result = alloc(card_extension);
result->set_solver(s);
for (unsigned i = 0; i < m_constraints.size(); ++i) {
for (unsigned i = 0; i < m_cards.size(); ++i) {
literal_vector lits;
card& c = *m_constraints[i];
card& c = *m_cards[i];
for (unsigned i = 0; i < c.size(); ++i) {
lits.push_back(c[i]);
}
result->add_at_least(c.lit().var(), lits, c.k());
}
for (unsigned i = 0; i < m_xors.size(); ++i) {
literal_vector lits;
xor& x = *m_xors[i];
for (unsigned i = 0; i < x.size(); ++i) {
lits.push_back(x[i]);
}
result->add_xor(x.lit().var(), lits);
}
return result;
}
void card_extension::find_mutexes(literal_vector& lits, vector<literal_vector> & mutexes) {
literal_set slits(lits);
bool change = false;
for (unsigned i = 0; i < m_constraints.size(); ++i) {
card& c = *m_constraints[i];
for (unsigned i = 0; i < m_cards.size(); ++i) {
card& c = *m_cards[i];
if (c.size() == c.k() + 1) {
literal_vector mux;
for (unsigned j = 0; j < c.size(); ++j) {
@ -786,10 +1138,10 @@ namespace sat {
}
}
void card_extension::display_watch(std::ostream& out, bool_var v, bool sign) const {
watch const* w = m_var_infos[v].m_lit_watch[sign];
void card_extension::display_watch(std::ostream& out, bool_var v, bool sign) const {
card_watch const* w = m_var_infos[v].m_card_watch[sign];
if (!is_tag_empty(w)) {
watch const& wl = *w;
card_watch const& wl = *w;
out << literal(v, sign) << " |-> ";
for (unsigned i = 0; i < wl.size(); ++i) {
out << wl[i]->lit() << " ";
@ -798,6 +1150,18 @@ namespace sat {
}
}
void card_extension::display_watch(std::ostream& out, bool_var v) const {
xor_watch const* w = m_var_infos[v].m_xor_watch;
if (!is_tag_empty(w)) {
xor_watch const& wl = *w;
out << "v" << v << " |-> ";
for (unsigned i = 0; i < wl.size(); ++i) {
out << wl[i]->lit() << " ";
}
out << "\n";
}
}
void card_extension::display(std::ostream& out, ineq& ineq) const {
for (unsigned i = 0; i < ineq.m_lits.size(); ++i) {
out << ineq.m_coeffs[i] << "*" << ineq.m_lits[i] << " ";
@ -805,6 +1169,35 @@ namespace sat {
out << ">= " << ineq.m_k << "\n";
}
void card_extension::display(std::ostream& out, xor& x, bool values) const {
out << "xor " << x.lit();
if (x.lit() != null_literal && values) {
out << "@(" << value(x.lit());
if (value(x.lit()) != l_undef) {
out << ":" << lvl(x.lit());
}
out << "): ";
}
else {
out << ": ";
}
for (unsigned i = 0; i < x.size(); ++i) {
literal l = x[i];
out << l;
if (values) {
out << "@(" << value(l);
if (value(l) != l_undef) {
out << ":" << lvl(l);
}
out << ") ";
}
else {
out << " ";
}
}
out << "\n";
}
void card_extension::display(std::ostream& out, card& c, bool values) const {
out << c.lit() << "[" << c.size() << "]";
if (c.lit() != null_literal && values) {
@ -838,23 +1231,33 @@ namespace sat {
for (unsigned vi = 0; vi < m_var_infos.size(); ++vi) {
display_watch(out, vi, false);
display_watch(out, vi, true);
display_watch(out, vi);
}
for (unsigned vi = 0; vi < m_var_infos.size(); ++vi) {
card* c = m_var_infos[vi].m_card;
if (c) {
display(out, *c, false);
}
if (c) display(out, *c, false);
xor* x = m_var_infos[vi].m_xor;
if (x) display(out, *x, false);
}
return out;
}
std::ostream& card_extension::display_justification(std::ostream& out, ext_justification_idx idx) const {
card& c = *m_constraints[idx];
out << "bound " << c.lit() << ": ";
for (unsigned i = 0; i < c.size(); ++i) {
out << c[i] << " ";
if (is_card_index(idx)) {
card& c = index2card(idx);
out << "bound " << c.lit() << ": ";
for (unsigned i = 0; i < c.size(); ++i) {
out << c[i] << " ";
}
out << ">= " << c.k();
}
else {
xor& x = index2xor(idx);
out << "xor " << x.lit() << ": ";
for (unsigned i = 0; i < x.size(); ++i) {
out << x[i] << " ";
}
}
out << ">= " << c.k();
return out;
}
@ -870,6 +1273,9 @@ namespace sat {
}
return false;
}
bool card_extension::validate_conflict(xor& x) {
return !parity(x, 0);
}
bool card_extension::validate_unit_propagation(card const& c) {
if (value(c.lit()) != l_true) return false;
for (unsigned i = c.k(); i < c.size(); ++i) {
@ -933,12 +1339,23 @@ namespace sat {
}
case justification::EXT_JUSTIFICATION: {
unsigned index = js.get_ext_justification_idx();
card& c = *m_constraints[index];
p.reset(offset*c.k());
for (unsigned i = 0; i < c.size(); ++i) {
p.push(c[i], offset);
if (is_card_index(index)) {
card& c = index2card(index);
p.reset(offset*c.k());
for (unsigned i = 0; i < c.size(); ++i) {
p.push(c[i], offset);
}
p.push(~c.lit(), offset*c.k());
}
else {
literal_vector ls;
get_antecedents(lit, index, ls);
p.reset(offset);
for (unsigned i = 0; i < ls.size(); ++i) {
p.push(~ls[i], offset);
}
p.push(~index2xor(index).lit(), offset);
}
p.push(~c.lit(), offset*c.k());
break;
}
default: