3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-06-20 12:53:38 +00:00

add xor parity solver feature

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
This commit is contained in:
Nikolaj Bjorner 2017-02-20 16:55:00 -08:00
parent cb050998e5
commit 98c5a779b4
6 changed files with 665 additions and 108 deletions

View file

@ -50,6 +50,7 @@ struct pb2bv_rewriter::imp {
rational m_k; rational m_k;
vector<rational> m_coeffs; vector<rational> m_coeffs;
bool m_keep_cardinality_constraints; bool m_keep_cardinality_constraints;
unsigned m_min_arity;
template<lbool is_le> template<lbool is_le>
expr_ref mk_le_ge(expr_ref_vector& fmls, expr* a, expr* b, expr* bound) { expr_ref mk_le_ge(expr_ref_vector& fmls, expr* a, expr* b, expr* bound) {
@ -416,7 +417,8 @@ struct pb2bv_rewriter::imp {
bv(m), bv(m),
m_trail(m), m_trail(m),
m_args(m), m_args(m),
m_keep_cardinality_constraints(true) m_keep_cardinality_constraints(true),
m_min_arity(8)
{} {}
bool mk_app(bool full, func_decl * f, unsigned sz, expr * const* args, expr_ref & result) { bool mk_app(bool full, func_decl * f, unsigned sz, expr * const* args, expr_ref & result) {
@ -530,27 +532,26 @@ struct pb2bv_rewriter::imp {
bool mk_pb(bool full, func_decl * f, unsigned sz, expr * const* args, expr_ref & result) { bool mk_pb(bool full, func_decl * f, unsigned sz, expr * const* args, expr_ref & result) {
SASSERT(f->get_family_id() == pb.get_family_id()); SASSERT(f->get_family_id() == pb.get_family_id());
if (is_or(f)) { if (is_or(f)) {
if (m_keep_cardinality_constraints) return false;
result = m.mk_or(sz, args); result = m.mk_or(sz, args);
} }
else if (pb.is_at_most_k(f) && pb.get_k(f).is_unsigned()) { else if (pb.is_at_most_k(f) && pb.get_k(f).is_unsigned()) {
if (m_keep_cardinality_constraints) return false; if (m_keep_cardinality_constraints && f->get_arity() >= m_min_arity) return false;
result = m_sort.le(full, pb.get_k(f).get_unsigned(), sz, args); result = m_sort.le(full, pb.get_k(f).get_unsigned(), sz, args);
} }
else if (pb.is_at_least_k(f) && pb.get_k(f).is_unsigned()) { else if (pb.is_at_least_k(f) && pb.get_k(f).is_unsigned()) {
if (m_keep_cardinality_constraints) return false; if (m_keep_cardinality_constraints && f->get_arity() >= m_min_arity) return false;
result = m_sort.ge(full, pb.get_k(f).get_unsigned(), sz, args); result = m_sort.ge(full, pb.get_k(f).get_unsigned(), sz, args);
} }
else if (pb.is_eq(f) && pb.get_k(f).is_unsigned() && pb.has_unit_coefficients(f)) { else if (pb.is_eq(f) && pb.get_k(f).is_unsigned() && pb.has_unit_coefficients(f)) {
if (m_keep_cardinality_constraints) return false; if (m_keep_cardinality_constraints && f->get_arity() >= m_min_arity) return false;
result = m_sort.eq(full, pb.get_k(f).get_unsigned(), sz, args); result = m_sort.eq(full, pb.get_k(f).get_unsigned(), sz, args);
} }
else if (pb.is_le(f) && pb.get_k(f).is_unsigned() && pb.has_unit_coefficients(f)) { else if (pb.is_le(f) && pb.get_k(f).is_unsigned() && pb.has_unit_coefficients(f)) {
if (m_keep_cardinality_constraints) return false; if (m_keep_cardinality_constraints && f->get_arity() >= m_min_arity) return false;
result = m_sort.le(full, pb.get_k(f).get_unsigned(), sz, args); result = m_sort.le(full, pb.get_k(f).get_unsigned(), sz, args);
} }
else if (pb.is_ge(f) && pb.get_k(f).is_unsigned() && pb.has_unit_coefficients(f)) { else if (pb.is_ge(f) && pb.get_k(f).is_unsigned() && pb.has_unit_coefficients(f)) {
if (m_keep_cardinality_constraints) return false; if (m_keep_cardinality_constraints && f->get_arity() >= m_min_arity) return false;
result = m_sort.ge(full, pb.get_k(f).get_unsigned(), sz, args); result = m_sort.ge(full, pb.get_k(f).get_unsigned(), sz, args);
} }
else { else {

View file

@ -7,7 +7,7 @@ Module Name:
Abstract: Abstract:
Extension for cardinality reasoning. Extension for cardinality and xor reasoning.
Author: Author:
@ -42,6 +42,16 @@ namespace sat {
SASSERT(m_size >= m_k && m_k > 0); SASSERT(m_size >= m_k && m_k > 0);
} }
card_extension::xor::xor(unsigned index, literal lit, literal_vector const& lits):
m_index(index),
m_lit(lit),
m_size(lits.size())
{
for (unsigned i = 0; i < lits.size(); ++i) {
m_lits[i] = lits[i];
}
}
void card_extension::init_watch(bool_var v) { void card_extension::init_watch(bool_var v) {
if (m_var_infos.size() <= static_cast<unsigned>(v)) { if (m_var_infos.size() <= static_cast<unsigned>(v)) {
m_var_infos.resize(static_cast<unsigned>(v)+100); m_var_infos.resize(static_cast<unsigned>(v)+100);
@ -120,7 +130,7 @@ namespace sat {
if (m_var_infos.size() <= static_cast<unsigned>(lit.var())) { if (m_var_infos.size() <= static_cast<unsigned>(lit.var())) {
return; return;
} }
ptr_vector<card>*& cards = m_var_infos[lit.var()].m_lit_watch[lit.sign()]; ptr_vector<card>*& cards = m_var_infos[lit.var()].m_card_watch[lit.sign()];
if (!is_tag_empty(cards)) { if (!is_tag_empty(cards)) {
if (remove(*cards, c)) { if (remove(*cards, c)) {
cards = set_tag_empty(cards); cards = set_tag_empty(cards);
@ -128,30 +138,6 @@ namespace sat {
} }
} }
ptr_vector<card_extension::card>* card_extension::set_tag_empty(ptr_vector<card>* c) {
return TAG(ptr_vector<card>*, c, 1);
}
bool card_extension::is_tag_empty(ptr_vector<card> const* c) {
return !c || GET_TAG(c) == 1;
}
ptr_vector<card_extension::card>* card_extension::set_tag_non_empty(ptr_vector<card>* c) {
return UNTAG(ptr_vector<card>*, c);
}
bool card_extension::remove(ptr_vector<card>& cards, card* c) {
unsigned sz = cards.size();
for (unsigned j = 0; j < sz; ++j) {
if (cards[j] == c) {
std::swap(cards[j], cards[sz-1]);
cards.pop_back();
return sz == 1;
}
}
return false;
}
void card_extension::assign(card& c, literal lit) { void card_extension::assign(card& c, literal lit) {
switch (value(lit)) { switch (value(lit)) {
case l_true: case l_true:
@ -183,14 +169,14 @@ namespace sat {
void card_extension::watch_literal(card& c, literal lit) { void card_extension::watch_literal(card& c, literal lit) {
TRACE("sat_verbose", tout << "watch: " << lit << "\n";); TRACE("sat_verbose", tout << "watch: " << lit << "\n";);
init_watch(lit.var()); init_watch(lit.var());
ptr_vector<card>* cards = m_var_infos[lit.var()].m_lit_watch[lit.sign()]; ptr_vector<card>* cards = m_var_infos[lit.var()].m_card_watch[lit.sign()];
if (cards == 0) { if (cards == 0) {
cards = alloc(ptr_vector<card>); cards = alloc(ptr_vector<card>);
m_var_infos[lit.var()].m_lit_watch[lit.sign()] = cards; m_var_infos[lit.var()].m_card_watch[lit.sign()] = cards;
} }
else if (is_tag_empty(cards)) { else if (is_tag_empty(cards)) {
cards = set_tag_non_empty(cards); cards = set_tag_non_empty(cards);
m_var_infos[lit.var()].m_lit_watch[lit.sign()] = cards; m_var_infos[lit.var()].m_card_watch[lit.sign()] = cards;
} }
TRACE("sat_verbose", tout << "insert: " << lit.var() << " " << lit.sign() << "\n";); TRACE("sat_verbose", tout << "insert: " << lit.var() << " " << lit.sign() << "\n";);
cards->push_back(&c); cards->push_back(&c);
@ -203,6 +189,155 @@ namespace sat {
SASSERT(s().inconsistent()); SASSERT(s().inconsistent());
} }
void card_extension::clear_watch(xor& x) {
unwatch_literal(x[0], &x);
unwatch_literal(x[1], &x);
}
void card_extension::unwatch_literal(literal lit, xor* c) {
if (m_var_infos.size() <= static_cast<unsigned>(lit.var())) {
return;
}
xor_watch* xors = m_var_infos[lit.var()].m_xor_watch;
if (!is_tag_empty(xors)) {
if (remove(*xors, c)) {
xors = set_tag_empty(xors);
}
}
}
bool card_extension::parity(xor const& x, unsigned offset) const {
bool odd = false;
unsigned sz = x.size();
for (unsigned i = offset; i < sz; ++i) {
SASSERT(value(x[i]) != l_undef);
if (value(x[i]) == l_true) {
odd = !odd;
}
}
return odd;
}
void card_extension::init_watch(xor& x, bool is_true) {
clear_watch(x);
if (x.lit().sign() == is_true) {
x.negate();
}
unsigned sz = x.size();
unsigned j = 0;
for (unsigned i = 0; i < sz && j < 2; ++i) {
if (value(x[i]) == l_undef) {
x.swap(i, j);
++j;
}
}
switch (j) {
case 0:
if (!parity(x, 0)) {
set_conflict(x, x[0]);
}
break;
case 1:
assign(x, parity(x, 1) ? ~x[0] : x[0]);
break;
default:
SASSERT(j == 2);
watch_literal(x, x[0]);
watch_literal(x, x[1]);
break;
}
}
void card_extension::assign(xor& x, literal lit) {
switch (value(lit)) {
case l_true:
break;
case l_false:
set_conflict(x, lit);
break;
default:
m_stats.m_num_propagations++;
m_num_propagations_since_pop++;
if (s().m_config.m_drat) {
svector<drat::premise> ps;
literal_vector lits;
lits.push_back(~x.lit());
for (unsigned i = 1; i < x.size(); ++i) {
lits.push_back(x[i]);
}
lits.push_back(lit);
ps.push_back(drat::premise(drat::s_ext(), x.lit()));
s().m_drat.add(lits, ps);
}
s().assign(lit, justification::mk_ext_justification(x.index()));
break;
}
}
void card_extension::watch_literal(xor& x, literal lit) {
TRACE("sat_verbose", tout << "watch: " << lit << "\n";);
init_watch(lit.var());
xor_watch*& xors = m_var_infos[lit.var()].m_xor_watch;
if (xors == 0) {
xors = alloc(ptr_vector<xor>);
}
else if (is_tag_empty(xors)) {
xors = set_tag_non_empty(xors);
}
xors->push_back(&x);
TRACE("sat_verbose", tout << "insert: " << lit.var() << " " << lit.sign() << "\n";);
}
void card_extension::set_conflict(xor& x, literal lit) {
TRACE("sat", display(tout, x, true); );
SASSERT(validate_conflict(x));
s().set_conflict(justification::mk_ext_justification(x.index()), ~lit);
SASSERT(s().inconsistent());
}
lbool card_extension::add_assign(xor& x, literal alit) {
// literal is assigned
unsigned sz = x.size();
TRACE("sat", tout << "assign: " << x.lit() << ": " << ~alit << "@" << lvl(~alit) << "\n";);
SASSERT(value(alit) != l_undef);
SASSERT(value(x.lit()) == l_true);
unsigned index = 0;
for (; index <= 2; ++index) {
if (x[index].var() == alit.var()) break;
}
if (index == 2) {
// literal is no longer watched.
return l_undef;
}
SASSERT(x[index].var() == alit.var());
// find a literal to swap with:
for (unsigned i = 2; i < sz; ++i) {
literal lit2 = x[i];
if (value(lit2) == l_undef) {
x.swap(index, i);
watch_literal(x, lit2);
return l_undef;
}
}
if (index == 0) {
x.swap(0, 1);
}
// alit resides at index 1.
SASSERT(x[1].var() == alit.var());
if (value(x[0]) == l_undef) {
bool p = parity(x, 1);
assign(x, p ? ~x[0] : x[0]);
}
else if (!parity(x, 0)) {
set_conflict(x, x[0]);
}
return s().inconsistent() ? l_false : l_true;
}
void card_extension::normalize_active_coeffs() { void card_extension::normalize_active_coeffs() {
while (!m_active_var_set.empty()) m_active_var_set.erase(); while (!m_active_var_set.empty()) m_active_var_set.erase();
unsigned i = 0, j = 0, sz = m_active_vars.size(); unsigned i = 0, j = 0, sz = m_active_vars.size();
@ -288,6 +423,8 @@ namespace sat {
unsigned init_marks = m_num_marks; unsigned init_marks = m_num_marks;
vector<justification> jus;
do { do {
if (offset == 0) { if (offset == 0) {
@ -349,9 +486,22 @@ namespace sat {
} }
case justification::EXT_JUSTIFICATION: { case justification::EXT_JUSTIFICATION: {
unsigned index = js.get_ext_justification_idx(); unsigned index = js.get_ext_justification_idx();
card& c = *m_constraints[index]; if (is_card_index(index)) {
card& c = index2card(index);
m_bound += offset * c.k(); m_bound += offset * c.k();
process_card(c, offset); process_card(c, offset);
}
else {
// jus.push_back(js);
m_lemma.reset();
m_bound += offset;
inc_coeff(consequent, offset);
get_xor_antecedents(idx, m_lemma);
// get_antecedents(consequent, index, m_lemma);
for (unsigned i = 0; i < m_lemma.size(); ++i) {
process_antecedent(~m_lemma[i], offset);
}
}
break; break;
} }
default: default:
@ -424,7 +574,6 @@ namespace sat {
lbool val = m_solver->value(v); lbool val = m_solver->value(v);
bool is_true = val == l_true; bool is_true = val == l_true;
bool append = coeff != 0 && val != l_undef && (coeff < 0 == is_true); bool append = coeff != 0 && val != l_undef && (coeff < 0 == is_true);
if (append) { if (append) {
literal lit(v, !is_true); literal lit(v, !is_true);
if (lvl(lit) == m_conflict_lvl) { if (lvl(lit) == m_conflict_lvl) {
@ -440,6 +589,17 @@ namespace sat {
} }
} }
if (jus.size() > 1) {
std::cout << jus.size() << "\n";
for (unsigned i = 0; i < jus.size(); ++i) {
s().display_justification(std::cout, jus[i]); std::cout << "\n";
}
std::cout << m_lemma << "\n";
active2pb(m_A);
display(std::cout, m_A);
}
if (slack >= 0) { if (slack >= 0) {
IF_VERBOSE(2, verbose_stream() << "(sat.card bail slack objective not met " << slack << ")\n";); IF_VERBOSE(2, verbose_stream() << "(sat.card bail slack objective not met " << slack << ")\n";);
goto bail_out; goto bail_out;
@ -564,7 +724,7 @@ namespace sat {
return p; return p;
} }
card_extension::card_extension(): m_solver(0) { card_extension::card_extension(): m_solver(0), m_has_xor(false) {
TRACE("sat", tout << this << "\n";); TRACE("sat", tout << this << "\n";);
} }
@ -578,20 +738,137 @@ namespace sat {
} }
void card_extension::add_at_least(bool_var v, literal_vector const& lits, unsigned k) { void card_extension::add_at_least(bool_var v, literal_vector const& lits, unsigned k) {
unsigned index = m_constraints.size(); unsigned index = 2*m_cards.size();
card* c = new (memory::allocate(card::get_obj_size(lits.size()))) card(index, literal(v, false), lits, k); card* c = new (memory::allocate(card::get_obj_size(lits.size()))) card(index, literal(v, false), lits, k);
m_constraints.push_back(c); m_cards.push_back(c);
init_watch(v); init_watch(v);
m_var_infos[v].m_card = c; m_var_infos[v].m_card = c;
m_var_trail.push_back(v); m_var_trail.push_back(v);
} }
void card_extension::add_xor(bool_var v, literal_vector const& lits) {
m_has_xor = true;
unsigned index = 2*m_xors.size()+1;
xor* x = new (memory::allocate(xor::get_obj_size(lits.size()))) xor(index, literal(v, false), lits);
m_xors.push_back(x);
init_watch(v);
m_var_infos[v].m_xor = x;
m_var_trail.push_back(v);
}
void card_extension::propagate(literal l, ext_constraint_idx idx, bool & keep) { void card_extension::propagate(literal l, ext_constraint_idx idx, bool & keep) {
UNREACHABLE(); UNREACHABLE();
} }
void card_extension::ensure_parity_size(bool_var v) {
if (m_parity_marks.size() <= static_cast<unsigned>(v)) {
m_parity_marks.resize(static_cast<unsigned>(v) + 1, 0);
}
}
unsigned card_extension::get_parity(bool_var v) {
return m_parity_marks.get(v, 0);
}
void card_extension::inc_parity(bool_var v) {
ensure_parity_size(v);
m_parity_marks[v]++;
}
void card_extension::reset_parity(bool_var v) {
ensure_parity_size(v);
m_parity_marks[v] = 0;
}
/**
\brief perform parity resolution on xor premises.
The idea is to collect premises based on xor resolvents.
Variables that are repeated an even number of times cancel out.
*/
void card_extension::get_xor_antecedents(unsigned index, literal_vector& r) {
literal_vector const& lits = s().m_trail;
literal l = lits[index + 1];
unsigned level = lvl(l);
bool_var v = l.var();
SASSERT(s().m_justification[v].get_kind() == justification::EXT_JUSTIFICATION);
SASSERT(!is_card_index(s().m_justification[v].get_ext_justification_idx()));
unsigned num_marks = 0;
unsigned count = 0;
while (true) {
++count;
justification js = s().m_justification[v];
if (js.get_kind() == justification::EXT_JUSTIFICATION) {
unsigned idx = js.get_ext_justification_idx();
if (is_card_index(idx)) {
r.push_back(l);
}
else {
xor& x = index2xor(idx);
if (lvl(x.lit()) > 0) r.push_back(x.lit());
if (x[1].var() == l.var()) {
x.swap(0, 1);
}
SASSERT(x[0].var() == l.var());
for (unsigned i = 1; i < x.size(); ++i) {
literal lit(value(x[i]) == l_true ? x[i] : ~x[i]);
inc_parity(lit.var());
if (true || lvl(lit) == level) {
++num_marks;
}
else {
m_parity_trail.push_back(lit);
}
}
}
}
else {
r.push_back(l);
}
while (num_marks > 0) {
l = lits[index];
v = l.var();
unsigned n = get_parity(v);
if (n > 0) {
reset_parity(v);
if (n > 1) {
IF_VERBOSE(2, verbose_stream() << "parity greater than 1: " << l << " " << n << "\n";);
}
if (n % 2 == 1) {
break;
}
IF_VERBOSE(2, verbose_stream() << "skip even parity: " << l << "\n";);
--num_marks;
}
--index;
}
if (num_marks == 0) {
break;
}
--index;
--num_marks;
}
// now walk the defined literals
for (unsigned i = 0; i < m_parity_trail.size(); ++i) {
literal lit = m_parity_trail[i];
if (get_parity(lit.var()) % 2 == 1) {
r.push_back(lit);
}
else {
IF_VERBOSE(2, verbose_stream() << "skip even parity: " << lit << "\n";);
}
reset_parity(lit.var());
}
m_parity_trail.reset();
}
void card_extension::get_antecedents(literal l, ext_justification_idx idx, literal_vector & r) { void card_extension::get_antecedents(literal l, ext_justification_idx idx, literal_vector & r) {
card& c = *m_constraints[idx]; if (is_card_index(idx)) {
card& c = index2card(idx);
DEBUG_CODE( DEBUG_CODE(
bool found = false; bool found = false;
@ -607,6 +884,26 @@ namespace sat {
r.push_back(~c[i]); r.push_back(~c[i]);
} }
} }
else {
xor& x = index2xor(idx);
r.push_back(x.lit());
TRACE("sat", display(tout << l << " ", x, true););
SASSERT(value(x.lit()) == l_true);
SASSERT(x[0].var() == l.var() || x[1].var() == l.var());
if (x[0].var() == l.var()) {
SASSERT(value(x[1]) != l_undef);
r.push_back(value(x[1]) == l_true ? x[1] : ~x[1]);
}
else {
SASSERT(value(x[0]) != l_undef);
r.push_back(value(x[0]) == l_true ? x[0] : ~x[0]);
}
for (unsigned i = 2; i < x.size(); ++i) {
SASSERT(value(x[i]) != l_undef);
r.push_back(value(x[i]) == l_true ? x[i] : ~x[i]);
}
}
}
lbool card_extension::add_assign(card& c, literal alit) { lbool card_extension::add_assign(card& c, literal alit) {
@ -670,10 +967,11 @@ namespace sat {
if (s().inconsistent()) return; if (s().inconsistent()) return;
if (v >= m_var_infos.size()) return; if (v >= m_var_infos.size()) return;
var_info& vinfo = m_var_infos[v]; var_info& vinfo = m_var_infos[v];
ptr_vector<card>* cards = vinfo.m_lit_watch[!l.sign()]; ptr_vector<card>* cards = vinfo.m_card_watch[!l.sign()];
//TRACE("sat", tout << "retrieve: " << v << " " << !l.sign() << "\n";); card* crd = vinfo.m_card;
//TRACE("sat", tout << "asserted: " << l << " " << (cards ? "non-empty" : "empty") << "\n";); xor* x = vinfo.m_xor;
static unsigned is_empty = 0, non_empty = 0; ptr_vector<xor>* xors = vinfo.m_xor_watch;
if (!is_tag_empty(cards)) { if (!is_tag_empty(cards)) {
ptr_vector<card>::iterator begin = cards->begin(); ptr_vector<card>::iterator begin = cards->begin();
ptr_vector<card>::iterator it = begin, it2 = it, end = cards->end(); ptr_vector<card>::iterator it = begin, it2 = it, end = cards->end();
@ -702,14 +1000,56 @@ namespace sat {
} }
cards->set_end(it2); cards->set_end(it2);
if (cards->empty()) { if (cards->empty()) {
m_var_infos[v].m_lit_watch[!l.sign()] = set_tag_empty(cards); m_var_infos[v].m_card_watch[!l.sign()] = set_tag_empty(cards);
} }
} }
card* crd = vinfo.m_card;
if (crd != 0 && !s().inconsistent()) { if (crd != 0 && !s().inconsistent()) {
init_watch(*crd, !l.sign()); init_watch(*crd, !l.sign());
} }
if (m_has_xor && !s().inconsistent()) {
asserted_xor(l, xors, x);
}
}
void card_extension::asserted_xor(literal l, ptr_vector<xor>* xors, xor* x) {
TRACE("sat", tout << l << " " << !is_tag_empty(xors) << " " << (x != 0) << "\n";);
if (!is_tag_empty(xors)) {
ptr_vector<xor>::iterator begin = xors->begin();
ptr_vector<xor>::iterator it = begin, it2 = it, end = xors->end();
for (; it != end; ++it) {
xor& c = *(*it);
if (value(c.lit()) != l_true) {
continue;
}
switch (add_assign(c, ~l)) {
case l_false: // conflict
for (; it != end; ++it, ++it2) {
*it2 = *it;
}
SASSERT(s().inconsistent());
xors->set_end(it2);
return;
case l_undef: // watch literal was swapped
break;
case l_true: // unit propagation, keep watching the literal
if (it2 != it) {
*it2 = *it;
}
++it2;
break;
}
}
xors->set_end(it2);
if (xors->empty()) {
m_var_infos[l.var()].m_xor_watch = set_tag_empty(xors);
}
}
if (x != 0 && !s().inconsistent()) {
init_watch(*x, !l.sign());
}
} }
check_result card_extension::check() { return CR_DONE; } check_result card_extension::check() { return CR_DONE; }
@ -730,6 +1070,10 @@ namespace sat {
clear_watch(*c); clear_watch(*c);
m_var_infos[v].m_card = 0; m_var_infos[v].m_card = 0;
dealloc(c); dealloc(c);
xor* x = m_var_infos[v].m_xor;
clear_watch(*x);
m_var_infos[v].m_xor = 0;
dealloc(x);
} }
} }
m_var_lim.resize(new_lim); m_var_lim.resize(new_lim);
@ -743,22 +1087,30 @@ namespace sat {
extension* card_extension::copy(solver* s) { extension* card_extension::copy(solver* s) {
card_extension* result = alloc(card_extension); card_extension* result = alloc(card_extension);
result->set_solver(s); result->set_solver(s);
for (unsigned i = 0; i < m_constraints.size(); ++i) { for (unsigned i = 0; i < m_cards.size(); ++i) {
literal_vector lits; literal_vector lits;
card& c = *m_constraints[i]; card& c = *m_cards[i];
for (unsigned i = 0; i < c.size(); ++i) { for (unsigned i = 0; i < c.size(); ++i) {
lits.push_back(c[i]); lits.push_back(c[i]);
} }
result->add_at_least(c.lit().var(), lits, c.k()); result->add_at_least(c.lit().var(), lits, c.k());
} }
for (unsigned i = 0; i < m_xors.size(); ++i) {
literal_vector lits;
xor& x = *m_xors[i];
for (unsigned i = 0; i < x.size(); ++i) {
lits.push_back(x[i]);
}
result->add_xor(x.lit().var(), lits);
}
return result; return result;
} }
void card_extension::find_mutexes(literal_vector& lits, vector<literal_vector> & mutexes) { void card_extension::find_mutexes(literal_vector& lits, vector<literal_vector> & mutexes) {
literal_set slits(lits); literal_set slits(lits);
bool change = false; bool change = false;
for (unsigned i = 0; i < m_constraints.size(); ++i) { for (unsigned i = 0; i < m_cards.size(); ++i) {
card& c = *m_constraints[i]; card& c = *m_cards[i];
if (c.size() == c.k() + 1) { if (c.size() == c.k() + 1) {
literal_vector mux; literal_vector mux;
for (unsigned j = 0; j < c.size(); ++j) { for (unsigned j = 0; j < c.size(); ++j) {
@ -787,9 +1139,9 @@ namespace sat {
} }
void card_extension::display_watch(std::ostream& out, bool_var v, bool sign) const { void card_extension::display_watch(std::ostream& out, bool_var v, bool sign) const {
watch const* w = m_var_infos[v].m_lit_watch[sign]; card_watch const* w = m_var_infos[v].m_card_watch[sign];
if (!is_tag_empty(w)) { if (!is_tag_empty(w)) {
watch const& wl = *w; card_watch const& wl = *w;
out << literal(v, sign) << " |-> "; out << literal(v, sign) << " |-> ";
for (unsigned i = 0; i < wl.size(); ++i) { for (unsigned i = 0; i < wl.size(); ++i) {
out << wl[i]->lit() << " "; out << wl[i]->lit() << " ";
@ -798,6 +1150,18 @@ namespace sat {
} }
} }
void card_extension::display_watch(std::ostream& out, bool_var v) const {
xor_watch const* w = m_var_infos[v].m_xor_watch;
if (!is_tag_empty(w)) {
xor_watch const& wl = *w;
out << "v" << v << " |-> ";
for (unsigned i = 0; i < wl.size(); ++i) {
out << wl[i]->lit() << " ";
}
out << "\n";
}
}
void card_extension::display(std::ostream& out, ineq& ineq) const { void card_extension::display(std::ostream& out, ineq& ineq) const {
for (unsigned i = 0; i < ineq.m_lits.size(); ++i) { for (unsigned i = 0; i < ineq.m_lits.size(); ++i) {
out << ineq.m_coeffs[i] << "*" << ineq.m_lits[i] << " "; out << ineq.m_coeffs[i] << "*" << ineq.m_lits[i] << " ";
@ -805,6 +1169,35 @@ namespace sat {
out << ">= " << ineq.m_k << "\n"; out << ">= " << ineq.m_k << "\n";
} }
void card_extension::display(std::ostream& out, xor& x, bool values) const {
out << "xor " << x.lit();
if (x.lit() != null_literal && values) {
out << "@(" << value(x.lit());
if (value(x.lit()) != l_undef) {
out << ":" << lvl(x.lit());
}
out << "): ";
}
else {
out << ": ";
}
for (unsigned i = 0; i < x.size(); ++i) {
literal l = x[i];
out << l;
if (values) {
out << "@(" << value(l);
if (value(l) != l_undef) {
out << ":" << lvl(l);
}
out << ") ";
}
else {
out << " ";
}
}
out << "\n";
}
void card_extension::display(std::ostream& out, card& c, bool values) const { void card_extension::display(std::ostream& out, card& c, bool values) const {
out << c.lit() << "[" << c.size() << "]"; out << c.lit() << "[" << c.size() << "]";
if (c.lit() != null_literal && values) { if (c.lit() != null_literal && values) {
@ -838,23 +1231,33 @@ namespace sat {
for (unsigned vi = 0; vi < m_var_infos.size(); ++vi) { for (unsigned vi = 0; vi < m_var_infos.size(); ++vi) {
display_watch(out, vi, false); display_watch(out, vi, false);
display_watch(out, vi, true); display_watch(out, vi, true);
display_watch(out, vi);
} }
for (unsigned vi = 0; vi < m_var_infos.size(); ++vi) { for (unsigned vi = 0; vi < m_var_infos.size(); ++vi) {
card* c = m_var_infos[vi].m_card; card* c = m_var_infos[vi].m_card;
if (c) { if (c) display(out, *c, false);
display(out, *c, false); xor* x = m_var_infos[vi].m_xor;
} if (x) display(out, *x, false);
} }
return out; return out;
} }
std::ostream& card_extension::display_justification(std::ostream& out, ext_justification_idx idx) const { std::ostream& card_extension::display_justification(std::ostream& out, ext_justification_idx idx) const {
card& c = *m_constraints[idx]; if (is_card_index(idx)) {
card& c = index2card(idx);
out << "bound " << c.lit() << ": "; out << "bound " << c.lit() << ": ";
for (unsigned i = 0; i < c.size(); ++i) { for (unsigned i = 0; i < c.size(); ++i) {
out << c[i] << " "; out << c[i] << " ";
} }
out << ">= " << c.k(); out << ">= " << c.k();
}
else {
xor& x = index2xor(idx);
out << "xor " << x.lit() << ": ";
for (unsigned i = 0; i < x.size(); ++i) {
out << x[i] << " ";
}
}
return out; return out;
} }
@ -870,6 +1273,9 @@ namespace sat {
} }
return false; return false;
} }
bool card_extension::validate_conflict(xor& x) {
return !parity(x, 0);
}
bool card_extension::validate_unit_propagation(card const& c) { bool card_extension::validate_unit_propagation(card const& c) {
if (value(c.lit()) != l_true) return false; if (value(c.lit()) != l_true) return false;
for (unsigned i = c.k(); i < c.size(); ++i) { for (unsigned i = c.k(); i < c.size(); ++i) {
@ -933,12 +1339,23 @@ namespace sat {
} }
case justification::EXT_JUSTIFICATION: { case justification::EXT_JUSTIFICATION: {
unsigned index = js.get_ext_justification_idx(); unsigned index = js.get_ext_justification_idx();
card& c = *m_constraints[index]; if (is_card_index(index)) {
card& c = index2card(index);
p.reset(offset*c.k()); p.reset(offset*c.k());
for (unsigned i = 0; i < c.size(); ++i) { for (unsigned i = 0; i < c.size(); ++i) {
p.push(c[i], offset); p.push(c[i], offset);
} }
p.push(~c.lit(), offset*c.k()); p.push(~c.lit(), offset*c.k());
}
else {
literal_vector ls;
get_antecedents(lit, index, ls);
p.reset(offset);
for (unsigned i = 0; i < ls.size(); ++i) {
p.push(~ls[i], offset);
}
p.push(~index2xor(index).lit(), offset);
}
break; break;
} }
default: default:

View file

@ -32,9 +32,7 @@ namespace sat {
void reset() { memset(this, 0, sizeof(*this)); } void reset() { memset(this, 0, sizeof(*this)); }
}; };
// class card_allocator;
class card { class card {
//friend class card_allocator;
unsigned m_index; unsigned m_index;
literal m_lit; literal m_lit;
unsigned m_k; unsigned m_k;
@ -53,6 +51,22 @@ namespace sat {
void negate(); void negate();
}; };
class xor {
unsigned m_index;
literal m_lit;
unsigned m_size;
literal m_lits[0];
public:
static size_t get_obj_size(unsigned num_lits) { return sizeof(xor) + num_lits * sizeof(literal); }
xor(unsigned index, literal lit, literal_vector const& lits);
unsigned index() const { return m_index; }
literal lit() const { return m_lit; }
literal operator[](unsigned i) const { return m_lits[i]; }
unsigned size() const { return m_size; }
void swap(unsigned i, unsigned j) { std::swap(m_lits[i], m_lits[j]); }
void negate() { m_lits[0].neg(); }
};
struct ineq { struct ineq {
literal_vector m_lits; literal_vector m_lits;
unsigned_vector m_coeffs; unsigned_vector m_coeffs;
@ -61,29 +75,48 @@ namespace sat {
void push(literal l, unsigned c) { m_lits.push_back(l); m_coeffs.push_back(c); } void push(literal l, unsigned c) { m_lits.push_back(l); m_coeffs.push_back(c); }
}; };
typedef ptr_vector<card> watch; typedef ptr_vector<card> card_watch;
typedef ptr_vector<xor> xor_watch;
struct var_info { struct var_info {
watch* m_lit_watch[2]; card_watch* m_card_watch[2];
xor_watch* m_xor_watch;
card* m_card; card* m_card;
var_info(): m_card(0) { xor* m_xor;
m_lit_watch[0] = 0; var_info(): m_xor_watch(0), m_card(0), m_xor(0) {
m_lit_watch[1] = 0; m_card_watch[0] = 0;
m_card_watch[1] = 0;
} }
void reset() { void reset() {
dealloc(m_card); dealloc(m_card);
dealloc(card_extension::set_tag_non_empty(m_lit_watch[0])); dealloc(m_xor);
dealloc(card_extension::set_tag_non_empty(m_lit_watch[1])); dealloc(card_extension::set_tag_non_empty(m_card_watch[0]));
dealloc(card_extension::set_tag_non_empty(m_card_watch[1]));
dealloc(card_extension::set_tag_non_empty(m_xor_watch));
} }
}; };
static ptr_vector<card>* set_tag_empty(ptr_vector<card>* c); template<typename T>
static bool is_tag_empty(ptr_vector<card> const* c); static ptr_vector<T>* set_tag_empty(ptr_vector<T>* c) {
static ptr_vector<card>* set_tag_non_empty(ptr_vector<card>* c); return TAG(ptr_vector<T>*, c, 1);
}
template<typename T>
static bool is_tag_empty(ptr_vector<T> const* c) {
return !c || GET_TAG(c) == 1;
}
template<typename T>
static ptr_vector<T>* set_tag_non_empty(ptr_vector<T>* c) {
return UNTAG(ptr_vector<T>*, c);
}
solver* m_solver; solver* m_solver;
stats m_stats; stats m_stats;
ptr_vector<card> m_constraints; ptr_vector<card> m_cards;
ptr_vector<xor> m_xors;
// watch literals // watch literals
svector<var_info> m_var_infos; svector<var_info> m_var_infos;
@ -98,8 +131,14 @@ namespace sat {
int m_bound; int m_bound;
tracked_uint_set m_active_var_set; tracked_uint_set m_active_var_set;
literal_vector m_lemma; literal_vector m_lemma;
// literal_vector m_literals;
unsigned m_num_propagations_since_pop; unsigned m_num_propagations_since_pop;
bool m_has_xor;
unsigned_vector m_parity_marks;
literal_vector m_parity_trail;
void ensure_parity_size(bool_var v);
unsigned get_parity(bool_var v);
void inc_parity(bool_var v);
void reset_parity(bool_var v);
solver& s() const { return *m_solver; } solver& s() const { return *m_solver; }
void init_watch(card& c, bool is_true); void init_watch(card& c, bool is_true);
@ -111,13 +150,44 @@ namespace sat {
void clear_watch(card& c); void clear_watch(card& c);
void reset_coeffs(); void reset_coeffs();
void reset_marked_literals(); void reset_marked_literals();
void unwatch_literal(literal w, card* c);
// xor specific functionality
void clear_watch(xor& x);
void watch_literal(xor& x, literal lit);
void unwatch_literal(literal w, xor* x);
void init_watch(xor& x, bool is_true);
void assign(xor& x, literal lit);
void set_conflict(xor& x, literal lit);
bool parity(xor const& x, unsigned offset) const;
lbool add_assign(xor& x, literal alit);
void asserted_xor(literal l, ptr_vector<xor>* xors, xor* x);
bool is_card_index(unsigned idx) const { return 0 == (idx & 0x1); }
card& index2card(unsigned idx) const { SASSERT(is_card_index(idx)); return *m_cards[idx >> 1]; }
xor& index2xor(unsigned idx) const { SASSERT(!is_card_index(idx)); return *m_xors[idx >> 1]; }
void get_xor_antecedents(unsigned index, literal_vector& r);
template<typename T>
bool remove(ptr_vector<T>& ts, T* t) {
unsigned sz = ts.size();
for (unsigned j = 0; j < sz; ++j) {
if (ts[j] == t) {
std::swap(ts[j], ts[sz-1]);
ts.pop_back();
return sz == 1;
}
}
return false;
}
inline lbool value(literal lit) const { return m_solver->value(lit); } inline lbool value(literal lit) const { return m_solver->value(lit); }
inline unsigned lvl(literal lit) const { return m_solver->lvl(lit); } inline unsigned lvl(literal lit) const { return m_solver->lvl(lit); }
inline unsigned lvl(bool_var v) const { return m_solver->lvl(v); } inline unsigned lvl(bool_var v) const { return m_solver->lvl(v); }
void unwatch_literal(literal w, card* c);
bool remove(ptr_vector<card>& cards, card* c);
void normalize_active_coeffs(); void normalize_active_coeffs();
void inc_coeff(literal l, int offset); void inc_coeff(literal l, int offset);
@ -131,6 +201,7 @@ namespace sat {
// validation utilities // validation utilities
bool validate_conflict(card& c); bool validate_conflict(card& c);
bool validate_conflict(xor& x);
bool validate_assign(literal_vector const& lits, literal lit); bool validate_assign(literal_vector const& lits, literal lit);
bool validate_lemma(); bool validate_lemma();
bool validate_unit_propagation(card const& c); bool validate_unit_propagation(card const& c);
@ -143,12 +214,16 @@ namespace sat {
void display(std::ostream& out, ineq& p) const; void display(std::ostream& out, ineq& p) const;
void display(std::ostream& out, card& c, bool values) const; void display(std::ostream& out, card& c, bool values) const;
void display(std::ostream& out, xor& c, bool values) const;
void display_watch(std::ostream& out, bool_var v) const;
void display_watch(std::ostream& out, bool_var v, bool sign) const; void display_watch(std::ostream& out, bool_var v, bool sign) const;
public: public:
card_extension(); card_extension();
virtual ~card_extension(); virtual ~card_extension();
virtual void set_solver(solver* s) { m_solver = s; } virtual void set_solver(solver* s) { m_solver = s; }
void add_at_least(bool_var v, literal_vector const& lits, unsigned k); void add_at_least(bool_var v, literal_vector const& lits, unsigned k);
void add_xor(bool_var v, literal_vector const& lits);
virtual void propagate(literal l, ext_constraint_idx idx, bool & keep); virtual void propagate(literal l, ext_constraint_idx idx, bool & keep);
virtual bool resolve_conflict(); virtual bool resolve_conflict();
virtual void get_antecedents(literal l, ext_justification_idx idx, literal_vector & r); virtual void get_antecedents(literal l, ext_justification_idx idx, literal_vector & r);

View file

@ -26,5 +26,5 @@ def_module_params('sat',
('dimacs.core', BOOL, False, 'extract core from DIMACS benchmarks'), ('dimacs.core', BOOL, False, 'extract core from DIMACS benchmarks'),
('drat.file', SYMBOL, '', 'file to dump DRAT proofs'), ('drat.file', SYMBOL, '', 'file to dump DRAT proofs'),
('drat.check', BOOL, False, 'build up internal proof and check'), ('drat.check', BOOL, False, 'build up internal proof and check'),
('cardinality.solver', BOOL, True, 'use cardinality solver'), ('cardinality.solver', BOOL, False, 'use cardinality/xor solver'),
)) ))

View file

@ -217,6 +217,7 @@ public:
sat_params p1(p); sat_params p1(p);
m_params.set_bool("elim_vars", false); m_params.set_bool("elim_vars", false);
m_params.set_bool("keep_cardinality_constraints", p1.cardinality_solver()); m_params.set_bool("keep_cardinality_constraints", p1.cardinality_solver());
m_params.set_bool("cardinality_solver", p1.cardinality_solver());
m_solver.updt_params(m_params); m_solver.updt_params(m_params);
m_optimize_model = m_params.get_bool("optimize_model", false); m_optimize_model = m_params.get_bool("optimize_model", false);

View file

@ -65,6 +65,7 @@ struct goal2sat::imp {
expr_ref_vector m_trail; expr_ref_vector m_trail;
expr_ref_vector m_interpreted_atoms; expr_ref_vector m_interpreted_atoms;
bool m_default_external; bool m_default_external;
bool m_cardinality_solver;
imp(ast_manager & _m, params_ref const & p, sat::solver & s, atom2bool_var & map, dep2asm_map& dep2asm, bool default_external): imp(ast_manager & _m, params_ref const & p, sat::solver & s, atom2bool_var & map, dep2asm_map& dep2asm, bool default_external):
m(_m), m(_m),
@ -83,6 +84,8 @@ struct goal2sat::imp {
void updt_params(params_ref const & p) { void updt_params(params_ref const & p) {
m_ite_extra = p.get_bool("ite_extra", true); m_ite_extra = p.get_bool("ite_extra", true);
m_max_memory = megabytes_to_bytes(p.get_uint("max_memory", UINT_MAX)); m_max_memory = megabytes_to_bytes(p.get_uint("max_memory", UINT_MAX));
m_cardinality_solver = p.get_bool("cardinality_solver", false);
std::cout << p << "\n";
} }
void throw_op_not_handled(std::string const& s) { void throw_op_not_handled(std::string const& s) {
@ -339,7 +342,7 @@ struct goal2sat::imp {
} }
} }
void convert_iff(app * t, bool root, bool sign) { void convert_iff2(app * t, bool root, bool sign) {
TRACE("goal2sat", tout << "convert_iff " << root << " " << sign << "\n" << mk_ismt2_pp(t, m) << "\n";); TRACE("goal2sat", tout << "convert_iff " << root << " " << sign << "\n" << mk_ismt2_pp(t, m) << "\n";);
unsigned sz = m_result_stack.size(); unsigned sz = m_result_stack.size();
SASSERT(sz >= 2); SASSERT(sz >= 2);
@ -372,8 +375,33 @@ struct goal2sat::imp {
} }
} }
void convert_pb_args(app* t, sat::literal_vector& lits) { void convert_iff(app * t, bool root, bool sign) {
unsigned num_args = t->get_num_args(); TRACE("goal2sat", tout << "convert_iff " << root << " " << sign << "\n" << mk_ismt2_pp(t, m) << "\n";);
unsigned sz = m_result_stack.size();
unsigned num = get_num_args(t);
SASSERT(sz >= num && num >= 2);
if (num == 2) {
convert_iff2(t, root, sign);
return;
}
sat::literal_vector lits;
convert_pb_args(num, lits);
sat::bool_var v = m_solver.mk_var(true);
ensure_extension();
if (lits.size() % 2 == 0) lits[0].neg();
m_ext->add_xor(v, lits);
sat::literal lit(v, sign);
if (root) {
m_result_stack.reset();
mk_clause(lit);
}
else {
m_result_stack.shrink(sz - num);
m_result_stack.push_back(lit);
}
}
void convert_pb_args(unsigned num_args, sat::literal_vector& lits) {
unsigned sz = m_result_stack.size(); unsigned sz = m_result_stack.size();
for (unsigned i = 0; i < num_args; ++i) { for (unsigned i = 0; i < num_args; ++i) {
sat::literal lit(m_result_stack[sz - num_args + i]); sat::literal lit(m_result_stack[sz - num_args + i]);
@ -396,7 +424,7 @@ struct goal2sat::imp {
SASSERT(k.is_unsigned()); SASSERT(k.is_unsigned());
sat::literal_vector lits; sat::literal_vector lits;
unsigned sz = m_result_stack.size(); unsigned sz = m_result_stack.size();
convert_pb_args(t, lits); convert_pb_args(t->get_num_args(), lits);
sat::bool_var v = m_solver.mk_var(true); sat::bool_var v = m_solver.mk_var(true);
sat::literal lit(v, sign); sat::literal lit(v, sign);
m_ext->add_at_least(v, lits, k.get_unsigned()); m_ext->add_at_least(v, lits, k.get_unsigned());
@ -415,7 +443,7 @@ struct goal2sat::imp {
SASSERT(k.is_unsigned()); SASSERT(k.is_unsigned());
sat::literal_vector lits; sat::literal_vector lits;
unsigned sz = m_result_stack.size(); unsigned sz = m_result_stack.size();
convert_pb_args(t, lits); convert_pb_args(t->get_num_args(), lits);
for (unsigned i = 0; i < lits.size(); ++i) { for (unsigned i = 0; i < lits.size(); ++i) {
lits[i].neg(); lits[i].neg();
} }
@ -434,7 +462,7 @@ struct goal2sat::imp {
void convert_eq_k(app* t, rational k, bool root, bool sign) { void convert_eq_k(app* t, rational k, bool root, bool sign) {
SASSERT(k.is_unsigned()); SASSERT(k.is_unsigned());
sat::literal_vector lits; sat::literal_vector lits;
convert_pb_args(t, lits); convert_pb_args(t->get_num_args(), lits);
sat::bool_var v1 = m_solver.mk_var(true); sat::bool_var v1 = m_solver.mk_var(true);
sat::bool_var v2 = m_solver.mk_var(true); sat::bool_var v2 = m_solver.mk_var(true);
sat::literal l1(v1, false), l2(v2, false); sat::literal l1(v1, false), l2(v2, false);
@ -529,6 +557,41 @@ struct goal2sat::imp {
} }
} }
unsigned get_num_args(app* t) {
if (m.is_iff(t) && m_cardinality_solver) {
unsigned n = 2;
while (m.is_iff(t->get_arg(1))) {
++n;
t = to_app(t->get_arg(1));
}
return n;
}
else {
return t->get_num_args();
}
}
expr* get_arg(app* t, unsigned idx) {
if (m.is_iff(t) && m_cardinality_solver) {
while (idx >= 1) {
SASSERT(m.is_iff(t));
t = to_app(t->get_arg(1));
--idx;
}
if (m.is_iff(t)) {
return t->get_arg(idx);
}
else {
return t;
}
}
else {
return t->get_arg(idx);
}
}
void process(expr * n) { void process(expr * n) {
//SASSERT(m_result_stack.empty()); //SASSERT(m_result_stack.empty());
TRACE("goal2sat", tout << "converting: " << mk_ismt2_pp(n, m) << "\n";); TRACE("goal2sat", tout << "converting: " << mk_ismt2_pp(n, m) << "\n";);
@ -559,9 +622,9 @@ struct goal2sat::imp {
visit(t->get_arg(0), root, !sign); visit(t->get_arg(0), root, !sign);
continue; continue;
} }
unsigned num = t->get_num_args(); unsigned num = get_num_args(t);
while (fr.m_idx < num) { while (fr.m_idx < num) {
expr * arg = t->get_arg(fr.m_idx); expr * arg = get_arg(t, fr.m_idx);
fr.m_idx++; fr.m_idx++;
if (!visit(arg, false, false)) if (!visit(arg, false, false))
goto loop; goto loop;