3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-10 19:27:06 +00:00

add xor parity solver feature

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
This commit is contained in:
Nikolaj Bjorner 2017-02-20 16:55:00 -08:00
parent cb050998e5
commit 98c5a779b4
6 changed files with 665 additions and 108 deletions

View file

@ -50,6 +50,7 @@ struct pb2bv_rewriter::imp {
rational m_k;
vector<rational> m_coeffs;
bool m_keep_cardinality_constraints;
unsigned m_min_arity;
template<lbool is_le>
expr_ref mk_le_ge(expr_ref_vector& fmls, expr* a, expr* b, expr* bound) {
@ -416,7 +417,8 @@ struct pb2bv_rewriter::imp {
bv(m),
m_trail(m),
m_args(m),
m_keep_cardinality_constraints(true)
m_keep_cardinality_constraints(true),
m_min_arity(8)
{}
bool mk_app(bool full, func_decl * f, unsigned sz, expr * const* args, expr_ref & result) {
@ -530,27 +532,26 @@ struct pb2bv_rewriter::imp {
bool mk_pb(bool full, func_decl * f, unsigned sz, expr * const* args, expr_ref & result) {
SASSERT(f->get_family_id() == pb.get_family_id());
if (is_or(f)) {
if (m_keep_cardinality_constraints) return false;
result = m.mk_or(sz, args);
}
else if (pb.is_at_most_k(f) && pb.get_k(f).is_unsigned()) {
if (m_keep_cardinality_constraints) return false;
if (m_keep_cardinality_constraints && f->get_arity() >= m_min_arity) return false;
result = m_sort.le(full, pb.get_k(f).get_unsigned(), sz, args);
}
else if (pb.is_at_least_k(f) && pb.get_k(f).is_unsigned()) {
if (m_keep_cardinality_constraints) return false;
if (m_keep_cardinality_constraints && f->get_arity() >= m_min_arity) return false;
result = m_sort.ge(full, pb.get_k(f).get_unsigned(), sz, args);
}
else if (pb.is_eq(f) && pb.get_k(f).is_unsigned() && pb.has_unit_coefficients(f)) {
if (m_keep_cardinality_constraints) return false;
if (m_keep_cardinality_constraints && f->get_arity() >= m_min_arity) return false;
result = m_sort.eq(full, pb.get_k(f).get_unsigned(), sz, args);
}
else if (pb.is_le(f) && pb.get_k(f).is_unsigned() && pb.has_unit_coefficients(f)) {
if (m_keep_cardinality_constraints) return false;
if (m_keep_cardinality_constraints && f->get_arity() >= m_min_arity) return false;
result = m_sort.le(full, pb.get_k(f).get_unsigned(), sz, args);
}
else if (pb.is_ge(f) && pb.get_k(f).is_unsigned() && pb.has_unit_coefficients(f)) {
if (m_keep_cardinality_constraints) return false;
if (m_keep_cardinality_constraints && f->get_arity() >= m_min_arity) return false;
result = m_sort.ge(full, pb.get_k(f).get_unsigned(), sz, args);
}
else {

View file

@ -7,7 +7,7 @@ Module Name:
Abstract:
Extension for cardinality reasoning.
Extension for cardinality and xor reasoning.
Author:
@ -42,6 +42,16 @@ namespace sat {
SASSERT(m_size >= m_k && m_k > 0);
}
card_extension::xor::xor(unsigned index, literal lit, literal_vector const& lits):
m_index(index),
m_lit(lit),
m_size(lits.size())
{
for (unsigned i = 0; i < lits.size(); ++i) {
m_lits[i] = lits[i];
}
}
void card_extension::init_watch(bool_var v) {
if (m_var_infos.size() <= static_cast<unsigned>(v)) {
m_var_infos.resize(static_cast<unsigned>(v)+100);
@ -120,7 +130,7 @@ namespace sat {
if (m_var_infos.size() <= static_cast<unsigned>(lit.var())) {
return;
}
ptr_vector<card>*& cards = m_var_infos[lit.var()].m_lit_watch[lit.sign()];
ptr_vector<card>*& cards = m_var_infos[lit.var()].m_card_watch[lit.sign()];
if (!is_tag_empty(cards)) {
if (remove(*cards, c)) {
cards = set_tag_empty(cards);
@ -128,30 +138,6 @@ namespace sat {
}
}
ptr_vector<card_extension::card>* card_extension::set_tag_empty(ptr_vector<card>* c) {
return TAG(ptr_vector<card>*, c, 1);
}
bool card_extension::is_tag_empty(ptr_vector<card> const* c) {
return !c || GET_TAG(c) == 1;
}
ptr_vector<card_extension::card>* card_extension::set_tag_non_empty(ptr_vector<card>* c) {
return UNTAG(ptr_vector<card>*, c);
}
bool card_extension::remove(ptr_vector<card>& cards, card* c) {
unsigned sz = cards.size();
for (unsigned j = 0; j < sz; ++j) {
if (cards[j] == c) {
std::swap(cards[j], cards[sz-1]);
cards.pop_back();
return sz == 1;
}
}
return false;
}
void card_extension::assign(card& c, literal lit) {
switch (value(lit)) {
case l_true:
@ -183,14 +169,14 @@ namespace sat {
void card_extension::watch_literal(card& c, literal lit) {
TRACE("sat_verbose", tout << "watch: " << lit << "\n";);
init_watch(lit.var());
ptr_vector<card>* cards = m_var_infos[lit.var()].m_lit_watch[lit.sign()];
ptr_vector<card>* cards = m_var_infos[lit.var()].m_card_watch[lit.sign()];
if (cards == 0) {
cards = alloc(ptr_vector<card>);
m_var_infos[lit.var()].m_lit_watch[lit.sign()] = cards;
m_var_infos[lit.var()].m_card_watch[lit.sign()] = cards;
}
else if (is_tag_empty(cards)) {
cards = set_tag_non_empty(cards);
m_var_infos[lit.var()].m_lit_watch[lit.sign()] = cards;
m_var_infos[lit.var()].m_card_watch[lit.sign()] = cards;
}
TRACE("sat_verbose", tout << "insert: " << lit.var() << " " << lit.sign() << "\n";);
cards->push_back(&c);
@ -202,6 +188,155 @@ namespace sat {
s().set_conflict(justification::mk_ext_justification(c.index()), ~lit);
SASSERT(s().inconsistent());
}
void card_extension::clear_watch(xor& x) {
unwatch_literal(x[0], &x);
unwatch_literal(x[1], &x);
}
void card_extension::unwatch_literal(literal lit, xor* c) {
if (m_var_infos.size() <= static_cast<unsigned>(lit.var())) {
return;
}
xor_watch* xors = m_var_infos[lit.var()].m_xor_watch;
if (!is_tag_empty(xors)) {
if (remove(*xors, c)) {
xors = set_tag_empty(xors);
}
}
}
bool card_extension::parity(xor const& x, unsigned offset) const {
bool odd = false;
unsigned sz = x.size();
for (unsigned i = offset; i < sz; ++i) {
SASSERT(value(x[i]) != l_undef);
if (value(x[i]) == l_true) {
odd = !odd;
}
}
return odd;
}
void card_extension::init_watch(xor& x, bool is_true) {
clear_watch(x);
if (x.lit().sign() == is_true) {
x.negate();
}
unsigned sz = x.size();
unsigned j = 0;
for (unsigned i = 0; i < sz && j < 2; ++i) {
if (value(x[i]) == l_undef) {
x.swap(i, j);
++j;
}
}
switch (j) {
case 0:
if (!parity(x, 0)) {
set_conflict(x, x[0]);
}
break;
case 1:
assign(x, parity(x, 1) ? ~x[0] : x[0]);
break;
default:
SASSERT(j == 2);
watch_literal(x, x[0]);
watch_literal(x, x[1]);
break;
}
}
void card_extension::assign(xor& x, literal lit) {
switch (value(lit)) {
case l_true:
break;
case l_false:
set_conflict(x, lit);
break;
default:
m_stats.m_num_propagations++;
m_num_propagations_since_pop++;
if (s().m_config.m_drat) {
svector<drat::premise> ps;
literal_vector lits;
lits.push_back(~x.lit());
for (unsigned i = 1; i < x.size(); ++i) {
lits.push_back(x[i]);
}
lits.push_back(lit);
ps.push_back(drat::premise(drat::s_ext(), x.lit()));
s().m_drat.add(lits, ps);
}
s().assign(lit, justification::mk_ext_justification(x.index()));
break;
}
}
void card_extension::watch_literal(xor& x, literal lit) {
TRACE("sat_verbose", tout << "watch: " << lit << "\n";);
init_watch(lit.var());
xor_watch*& xors = m_var_infos[lit.var()].m_xor_watch;
if (xors == 0) {
xors = alloc(ptr_vector<xor>);
}
else if (is_tag_empty(xors)) {
xors = set_tag_non_empty(xors);
}
xors->push_back(&x);
TRACE("sat_verbose", tout << "insert: " << lit.var() << " " << lit.sign() << "\n";);
}
void card_extension::set_conflict(xor& x, literal lit) {
TRACE("sat", display(tout, x, true); );
SASSERT(validate_conflict(x));
s().set_conflict(justification::mk_ext_justification(x.index()), ~lit);
SASSERT(s().inconsistent());
}
lbool card_extension::add_assign(xor& x, literal alit) {
// literal is assigned
unsigned sz = x.size();
TRACE("sat", tout << "assign: " << x.lit() << ": " << ~alit << "@" << lvl(~alit) << "\n";);
SASSERT(value(alit) != l_undef);
SASSERT(value(x.lit()) == l_true);
unsigned index = 0;
for (; index <= 2; ++index) {
if (x[index].var() == alit.var()) break;
}
if (index == 2) {
// literal is no longer watched.
return l_undef;
}
SASSERT(x[index].var() == alit.var());
// find a literal to swap with:
for (unsigned i = 2; i < sz; ++i) {
literal lit2 = x[i];
if (value(lit2) == l_undef) {
x.swap(index, i);
watch_literal(x, lit2);
return l_undef;
}
}
if (index == 0) {
x.swap(0, 1);
}
// alit resides at index 1.
SASSERT(x[1].var() == alit.var());
if (value(x[0]) == l_undef) {
bool p = parity(x, 1);
assign(x, p ? ~x[0] : x[0]);
}
else if (!parity(x, 0)) {
set_conflict(x, x[0]);
}
return s().inconsistent() ? l_false : l_true;
}
void card_extension::normalize_active_coeffs() {
while (!m_active_var_set.empty()) m_active_var_set.erase();
@ -288,6 +423,8 @@ namespace sat {
unsigned init_marks = m_num_marks;
vector<justification> jus;
do {
if (offset == 0) {
@ -349,9 +486,22 @@ namespace sat {
}
case justification::EXT_JUSTIFICATION: {
unsigned index = js.get_ext_justification_idx();
card& c = *m_constraints[index];
m_bound += offset * c.k();
process_card(c, offset);
if (is_card_index(index)) {
card& c = index2card(index);
m_bound += offset * c.k();
process_card(c, offset);
}
else {
// jus.push_back(js);
m_lemma.reset();
m_bound += offset;
inc_coeff(consequent, offset);
get_xor_antecedents(idx, m_lemma);
// get_antecedents(consequent, index, m_lemma);
for (unsigned i = 0; i < m_lemma.size(); ++i) {
process_antecedent(~m_lemma[i], offset);
}
}
break;
}
default:
@ -424,7 +574,6 @@ namespace sat {
lbool val = m_solver->value(v);
bool is_true = val == l_true;
bool append = coeff != 0 && val != l_undef && (coeff < 0 == is_true);
if (append) {
literal lit(v, !is_true);
if (lvl(lit) == m_conflict_lvl) {
@ -440,6 +589,17 @@ namespace sat {
}
}
if (jus.size() > 1) {
std::cout << jus.size() << "\n";
for (unsigned i = 0; i < jus.size(); ++i) {
s().display_justification(std::cout, jus[i]); std::cout << "\n";
}
std::cout << m_lemma << "\n";
active2pb(m_A);
display(std::cout, m_A);
}
if (slack >= 0) {
IF_VERBOSE(2, verbose_stream() << "(sat.card bail slack objective not met " << slack << ")\n";);
goto bail_out;
@ -564,7 +724,7 @@ namespace sat {
return p;
}
card_extension::card_extension(): m_solver(0) {
card_extension::card_extension(): m_solver(0), m_has_xor(false) {
TRACE("sat", tout << this << "\n";);
}
@ -578,33 +738,170 @@ namespace sat {
}
void card_extension::add_at_least(bool_var v, literal_vector const& lits, unsigned k) {
unsigned index = m_constraints.size();
unsigned index = 2*m_cards.size();
card* c = new (memory::allocate(card::get_obj_size(lits.size()))) card(index, literal(v, false), lits, k);
m_constraints.push_back(c);
m_cards.push_back(c);
init_watch(v);
m_var_infos[v].m_card = c;
m_var_trail.push_back(v);
}
void card_extension::add_xor(bool_var v, literal_vector const& lits) {
m_has_xor = true;
unsigned index = 2*m_xors.size()+1;
xor* x = new (memory::allocate(xor::get_obj_size(lits.size()))) xor(index, literal(v, false), lits);
m_xors.push_back(x);
init_watch(v);
m_var_infos[v].m_xor = x;
m_var_trail.push_back(v);
}
void card_extension::propagate(literal l, ext_constraint_idx idx, bool & keep) {
UNREACHABLE();
}
void card_extension::get_antecedents(literal l, ext_justification_idx idx, literal_vector & r) {
card& c = *m_constraints[idx];
DEBUG_CODE(
bool found = false;
for (unsigned i = 0; !found && i < c.k(); ++i) {
found = c[i] == l;
void card_extension::ensure_parity_size(bool_var v) {
if (m_parity_marks.size() <= static_cast<unsigned>(v)) {
m_parity_marks.resize(static_cast<unsigned>(v) + 1, 0);
}
}
unsigned card_extension::get_parity(bool_var v) {
return m_parity_marks.get(v, 0);
}
void card_extension::inc_parity(bool_var v) {
ensure_parity_size(v);
m_parity_marks[v]++;
}
void card_extension::reset_parity(bool_var v) {
ensure_parity_size(v);
m_parity_marks[v] = 0;
}
/**
\brief perform parity resolution on xor premises.
The idea is to collect premises based on xor resolvents.
Variables that are repeated an even number of times cancel out.
*/
void card_extension::get_xor_antecedents(unsigned index, literal_vector& r) {
literal_vector const& lits = s().m_trail;
literal l = lits[index + 1];
unsigned level = lvl(l);
bool_var v = l.var();
SASSERT(s().m_justification[v].get_kind() == justification::EXT_JUSTIFICATION);
SASSERT(!is_card_index(s().m_justification[v].get_ext_justification_idx()));
unsigned num_marks = 0;
unsigned count = 0;
while (true) {
++count;
justification js = s().m_justification[v];
if (js.get_kind() == justification::EXT_JUSTIFICATION) {
unsigned idx = js.get_ext_justification_idx();
if (is_card_index(idx)) {
r.push_back(l);
}
else {
xor& x = index2xor(idx);
if (lvl(x.lit()) > 0) r.push_back(x.lit());
if (x[1].var() == l.var()) {
x.swap(0, 1);
}
SASSERT(x[0].var() == l.var());
for (unsigned i = 1; i < x.size(); ++i) {
literal lit(value(x[i]) == l_true ? x[i] : ~x[i]);
inc_parity(lit.var());
if (true || lvl(lit) == level) {
++num_marks;
}
else {
m_parity_trail.push_back(lit);
}
}
}
}
SASSERT(found););
else {
r.push_back(l);
}
while (num_marks > 0) {
l = lits[index];
v = l.var();
unsigned n = get_parity(v);
if (n > 0) {
reset_parity(v);
if (n > 1) {
IF_VERBOSE(2, verbose_stream() << "parity greater than 1: " << l << " " << n << "\n";);
}
if (n % 2 == 1) {
break;
}
IF_VERBOSE(2, verbose_stream() << "skip even parity: " << l << "\n";);
--num_marks;
}
--index;
}
if (num_marks == 0) {
break;
}
--index;
--num_marks;
}
// now walk the defined literals
for (unsigned i = 0; i < m_parity_trail.size(); ++i) {
literal lit = m_parity_trail[i];
if (get_parity(lit.var()) % 2 == 1) {
r.push_back(lit);
}
else {
IF_VERBOSE(2, verbose_stream() << "skip even parity: " << lit << "\n";);
}
reset_parity(lit.var());
}
m_parity_trail.reset();
}
void card_extension::get_antecedents(literal l, ext_justification_idx idx, literal_vector & r) {
if (is_card_index(idx)) {
card& c = index2card(idx);
r.push_back(c.lit());
SASSERT(value(c.lit()) == l_true);
for (unsigned i = c.k(); i < c.size(); ++i) {
SASSERT(value(c[i]) == l_false);
r.push_back(~c[i]);
DEBUG_CODE(
bool found = false;
for (unsigned i = 0; !found && i < c.k(); ++i) {
found = c[i] == l;
}
SASSERT(found););
r.push_back(c.lit());
SASSERT(value(c.lit()) == l_true);
for (unsigned i = c.k(); i < c.size(); ++i) {
SASSERT(value(c[i]) == l_false);
r.push_back(~c[i]);
}
}
else {
xor& x = index2xor(idx);
r.push_back(x.lit());
TRACE("sat", display(tout << l << " ", x, true););
SASSERT(value(x.lit()) == l_true);
SASSERT(x[0].var() == l.var() || x[1].var() == l.var());
if (x[0].var() == l.var()) {
SASSERT(value(x[1]) != l_undef);
r.push_back(value(x[1]) == l_true ? x[1] : ~x[1]);
}
else {
SASSERT(value(x[0]) != l_undef);
r.push_back(value(x[0]) == l_true ? x[0] : ~x[0]);
}
for (unsigned i = 2; i < x.size(); ++i) {
SASSERT(value(x[i]) != l_undef);
r.push_back(value(x[i]) == l_true ? x[i] : ~x[i]);
}
}
}
@ -670,10 +967,11 @@ namespace sat {
if (s().inconsistent()) return;
if (v >= m_var_infos.size()) return;
var_info& vinfo = m_var_infos[v];
ptr_vector<card>* cards = vinfo.m_lit_watch[!l.sign()];
//TRACE("sat", tout << "retrieve: " << v << " " << !l.sign() << "\n";);
//TRACE("sat", tout << "asserted: " << l << " " << (cards ? "non-empty" : "empty") << "\n";);
static unsigned is_empty = 0, non_empty = 0;
ptr_vector<card>* cards = vinfo.m_card_watch[!l.sign()];
card* crd = vinfo.m_card;
xor* x = vinfo.m_xor;
ptr_vector<xor>* xors = vinfo.m_xor_watch;
if (!is_tag_empty(cards)) {
ptr_vector<card>::iterator begin = cards->begin();
ptr_vector<card>::iterator it = begin, it2 = it, end = cards->end();
@ -702,14 +1000,56 @@ namespace sat {
}
cards->set_end(it2);
if (cards->empty()) {
m_var_infos[v].m_lit_watch[!l.sign()] = set_tag_empty(cards);
m_var_infos[v].m_card_watch[!l.sign()] = set_tag_empty(cards);
}
}
card* crd = vinfo.m_card;
if (crd != 0 && !s().inconsistent()) {
init_watch(*crd, !l.sign());
}
if (m_has_xor && !s().inconsistent()) {
asserted_xor(l, xors, x);
}
}
void card_extension::asserted_xor(literal l, ptr_vector<xor>* xors, xor* x) {
TRACE("sat", tout << l << " " << !is_tag_empty(xors) << " " << (x != 0) << "\n";);
if (!is_tag_empty(xors)) {
ptr_vector<xor>::iterator begin = xors->begin();
ptr_vector<xor>::iterator it = begin, it2 = it, end = xors->end();
for (; it != end; ++it) {
xor& c = *(*it);
if (value(c.lit()) != l_true) {
continue;
}
switch (add_assign(c, ~l)) {
case l_false: // conflict
for (; it != end; ++it, ++it2) {
*it2 = *it;
}
SASSERT(s().inconsistent());
xors->set_end(it2);
return;
case l_undef: // watch literal was swapped
break;
case l_true: // unit propagation, keep watching the literal
if (it2 != it) {
*it2 = *it;
}
++it2;
break;
}
}
xors->set_end(it2);
if (xors->empty()) {
m_var_infos[l.var()].m_xor_watch = set_tag_empty(xors);
}
}
if (x != 0 && !s().inconsistent()) {
init_watch(*x, !l.sign());
}
}
check_result card_extension::check() { return CR_DONE; }
@ -730,6 +1070,10 @@ namespace sat {
clear_watch(*c);
m_var_infos[v].m_card = 0;
dealloc(c);
xor* x = m_var_infos[v].m_xor;
clear_watch(*x);
m_var_infos[v].m_xor = 0;
dealloc(x);
}
}
m_var_lim.resize(new_lim);
@ -743,22 +1087,30 @@ namespace sat {
extension* card_extension::copy(solver* s) {
card_extension* result = alloc(card_extension);
result->set_solver(s);
for (unsigned i = 0; i < m_constraints.size(); ++i) {
for (unsigned i = 0; i < m_cards.size(); ++i) {
literal_vector lits;
card& c = *m_constraints[i];
card& c = *m_cards[i];
for (unsigned i = 0; i < c.size(); ++i) {
lits.push_back(c[i]);
}
result->add_at_least(c.lit().var(), lits, c.k());
}
for (unsigned i = 0; i < m_xors.size(); ++i) {
literal_vector lits;
xor& x = *m_xors[i];
for (unsigned i = 0; i < x.size(); ++i) {
lits.push_back(x[i]);
}
result->add_xor(x.lit().var(), lits);
}
return result;
}
void card_extension::find_mutexes(literal_vector& lits, vector<literal_vector> & mutexes) {
literal_set slits(lits);
bool change = false;
for (unsigned i = 0; i < m_constraints.size(); ++i) {
card& c = *m_constraints[i];
for (unsigned i = 0; i < m_cards.size(); ++i) {
card& c = *m_cards[i];
if (c.size() == c.k() + 1) {
literal_vector mux;
for (unsigned j = 0; j < c.size(); ++j) {
@ -786,10 +1138,10 @@ namespace sat {
}
}
void card_extension::display_watch(std::ostream& out, bool_var v, bool sign) const {
watch const* w = m_var_infos[v].m_lit_watch[sign];
void card_extension::display_watch(std::ostream& out, bool_var v, bool sign) const {
card_watch const* w = m_var_infos[v].m_card_watch[sign];
if (!is_tag_empty(w)) {
watch const& wl = *w;
card_watch const& wl = *w;
out << literal(v, sign) << " |-> ";
for (unsigned i = 0; i < wl.size(); ++i) {
out << wl[i]->lit() << " ";
@ -798,6 +1150,18 @@ namespace sat {
}
}
void card_extension::display_watch(std::ostream& out, bool_var v) const {
xor_watch const* w = m_var_infos[v].m_xor_watch;
if (!is_tag_empty(w)) {
xor_watch const& wl = *w;
out << "v" << v << " |-> ";
for (unsigned i = 0; i < wl.size(); ++i) {
out << wl[i]->lit() << " ";
}
out << "\n";
}
}
void card_extension::display(std::ostream& out, ineq& ineq) const {
for (unsigned i = 0; i < ineq.m_lits.size(); ++i) {
out << ineq.m_coeffs[i] << "*" << ineq.m_lits[i] << " ";
@ -805,6 +1169,35 @@ namespace sat {
out << ">= " << ineq.m_k << "\n";
}
void card_extension::display(std::ostream& out, xor& x, bool values) const {
out << "xor " << x.lit();
if (x.lit() != null_literal && values) {
out << "@(" << value(x.lit());
if (value(x.lit()) != l_undef) {
out << ":" << lvl(x.lit());
}
out << "): ";
}
else {
out << ": ";
}
for (unsigned i = 0; i < x.size(); ++i) {
literal l = x[i];
out << l;
if (values) {
out << "@(" << value(l);
if (value(l) != l_undef) {
out << ":" << lvl(l);
}
out << ") ";
}
else {
out << " ";
}
}
out << "\n";
}
void card_extension::display(std::ostream& out, card& c, bool values) const {
out << c.lit() << "[" << c.size() << "]";
if (c.lit() != null_literal && values) {
@ -838,23 +1231,33 @@ namespace sat {
for (unsigned vi = 0; vi < m_var_infos.size(); ++vi) {
display_watch(out, vi, false);
display_watch(out, vi, true);
display_watch(out, vi);
}
for (unsigned vi = 0; vi < m_var_infos.size(); ++vi) {
card* c = m_var_infos[vi].m_card;
if (c) {
display(out, *c, false);
}
if (c) display(out, *c, false);
xor* x = m_var_infos[vi].m_xor;
if (x) display(out, *x, false);
}
return out;
}
std::ostream& card_extension::display_justification(std::ostream& out, ext_justification_idx idx) const {
card& c = *m_constraints[idx];
out << "bound " << c.lit() << ": ";
for (unsigned i = 0; i < c.size(); ++i) {
out << c[i] << " ";
if (is_card_index(idx)) {
card& c = index2card(idx);
out << "bound " << c.lit() << ": ";
for (unsigned i = 0; i < c.size(); ++i) {
out << c[i] << " ";
}
out << ">= " << c.k();
}
else {
xor& x = index2xor(idx);
out << "xor " << x.lit() << ": ";
for (unsigned i = 0; i < x.size(); ++i) {
out << x[i] << " ";
}
}
out << ">= " << c.k();
return out;
}
@ -870,6 +1273,9 @@ namespace sat {
}
return false;
}
bool card_extension::validate_conflict(xor& x) {
return !parity(x, 0);
}
bool card_extension::validate_unit_propagation(card const& c) {
if (value(c.lit()) != l_true) return false;
for (unsigned i = c.k(); i < c.size(); ++i) {
@ -933,12 +1339,23 @@ namespace sat {
}
case justification::EXT_JUSTIFICATION: {
unsigned index = js.get_ext_justification_idx();
card& c = *m_constraints[index];
p.reset(offset*c.k());
for (unsigned i = 0; i < c.size(); ++i) {
p.push(c[i], offset);
if (is_card_index(index)) {
card& c = index2card(index);
p.reset(offset*c.k());
for (unsigned i = 0; i < c.size(); ++i) {
p.push(c[i], offset);
}
p.push(~c.lit(), offset*c.k());
}
else {
literal_vector ls;
get_antecedents(lit, index, ls);
p.reset(offset);
for (unsigned i = 0; i < ls.size(); ++i) {
p.push(~ls[i], offset);
}
p.push(~index2xor(index).lit(), offset);
}
p.push(~c.lit(), offset*c.k());
break;
}
default:

View file

@ -32,9 +32,7 @@ namespace sat {
void reset() { memset(this, 0, sizeof(*this)); }
};
// class card_allocator;
class card {
//friend class card_allocator;
unsigned m_index;
literal m_lit;
unsigned m_k;
@ -53,6 +51,22 @@ namespace sat {
void negate();
};
class xor {
unsigned m_index;
literal m_lit;
unsigned m_size;
literal m_lits[0];
public:
static size_t get_obj_size(unsigned num_lits) { return sizeof(xor) + num_lits * sizeof(literal); }
xor(unsigned index, literal lit, literal_vector const& lits);
unsigned index() const { return m_index; }
literal lit() const { return m_lit; }
literal operator[](unsigned i) const { return m_lits[i]; }
unsigned size() const { return m_size; }
void swap(unsigned i, unsigned j) { std::swap(m_lits[i], m_lits[j]); }
void negate() { m_lits[0].neg(); }
};
struct ineq {
literal_vector m_lits;
unsigned_vector m_coeffs;
@ -61,29 +75,48 @@ namespace sat {
void push(literal l, unsigned c) { m_lits.push_back(l); m_coeffs.push_back(c); }
};
typedef ptr_vector<card> watch;
typedef ptr_vector<card> card_watch;
typedef ptr_vector<xor> xor_watch;
struct var_info {
watch* m_lit_watch[2];
card* m_card;
var_info(): m_card(0) {
m_lit_watch[0] = 0;
m_lit_watch[1] = 0;
card_watch* m_card_watch[2];
xor_watch* m_xor_watch;
card* m_card;
xor* m_xor;
var_info(): m_xor_watch(0), m_card(0), m_xor(0) {
m_card_watch[0] = 0;
m_card_watch[1] = 0;
}
void reset() {
dealloc(m_card);
dealloc(card_extension::set_tag_non_empty(m_lit_watch[0]));
dealloc(card_extension::set_tag_non_empty(m_lit_watch[1]));
dealloc(m_xor);
dealloc(card_extension::set_tag_non_empty(m_card_watch[0]));
dealloc(card_extension::set_tag_non_empty(m_card_watch[1]));
dealloc(card_extension::set_tag_non_empty(m_xor_watch));
}
};
template<typename T>
static ptr_vector<T>* set_tag_empty(ptr_vector<T>* c) {
return TAG(ptr_vector<T>*, c, 1);
}
template<typename T>
static bool is_tag_empty(ptr_vector<T> const* c) {
return !c || GET_TAG(c) == 1;
}
template<typename T>
static ptr_vector<T>* set_tag_non_empty(ptr_vector<T>* c) {
return UNTAG(ptr_vector<T>*, c);
}
static ptr_vector<card>* set_tag_empty(ptr_vector<card>* c);
static bool is_tag_empty(ptr_vector<card> const* c);
static ptr_vector<card>* set_tag_non_empty(ptr_vector<card>* c);
solver* m_solver;
stats m_stats;
ptr_vector<card> m_constraints;
ptr_vector<card> m_cards;
ptr_vector<xor> m_xors;
// watch literals
svector<var_info> m_var_infos;
@ -98,8 +131,14 @@ namespace sat {
int m_bound;
tracked_uint_set m_active_var_set;
literal_vector m_lemma;
// literal_vector m_literals;
unsigned m_num_propagations_since_pop;
bool m_has_xor;
unsigned_vector m_parity_marks;
literal_vector m_parity_trail;
void ensure_parity_size(bool_var v);
unsigned get_parity(bool_var v);
void inc_parity(bool_var v);
void reset_parity(bool_var v);
solver& s() const { return *m_solver; }
void init_watch(card& c, bool is_true);
@ -111,13 +150,44 @@ namespace sat {
void clear_watch(card& c);
void reset_coeffs();
void reset_marked_literals();
void unwatch_literal(literal w, card* c);
// xor specific functionality
void clear_watch(xor& x);
void watch_literal(xor& x, literal lit);
void unwatch_literal(literal w, xor* x);
void init_watch(xor& x, bool is_true);
void assign(xor& x, literal lit);
void set_conflict(xor& x, literal lit);
bool parity(xor const& x, unsigned offset) const;
lbool add_assign(xor& x, literal alit);
void asserted_xor(literal l, ptr_vector<xor>* xors, xor* x);
bool is_card_index(unsigned idx) const { return 0 == (idx & 0x1); }
card& index2card(unsigned idx) const { SASSERT(is_card_index(idx)); return *m_cards[idx >> 1]; }
xor& index2xor(unsigned idx) const { SASSERT(!is_card_index(idx)); return *m_xors[idx >> 1]; }
void get_xor_antecedents(unsigned index, literal_vector& r);
template<typename T>
bool remove(ptr_vector<T>& ts, T* t) {
unsigned sz = ts.size();
for (unsigned j = 0; j < sz; ++j) {
if (ts[j] == t) {
std::swap(ts[j], ts[sz-1]);
ts.pop_back();
return sz == 1;
}
}
return false;
}
inline lbool value(literal lit) const { return m_solver->value(lit); }
inline unsigned lvl(literal lit) const { return m_solver->lvl(lit); }
inline unsigned lvl(bool_var v) const { return m_solver->lvl(v); }
void unwatch_literal(literal w, card* c);
bool remove(ptr_vector<card>& cards, card* c);
void normalize_active_coeffs();
void inc_coeff(literal l, int offset);
@ -131,6 +201,7 @@ namespace sat {
// validation utilities
bool validate_conflict(card& c);
bool validate_conflict(xor& x);
bool validate_assign(literal_vector const& lits, literal lit);
bool validate_lemma();
bool validate_unit_propagation(card const& c);
@ -143,12 +214,16 @@ namespace sat {
void display(std::ostream& out, ineq& p) const;
void display(std::ostream& out, card& c, bool values) const;
void display(std::ostream& out, xor& c, bool values) const;
void display_watch(std::ostream& out, bool_var v) const;
void display_watch(std::ostream& out, bool_var v, bool sign) const;
public:
card_extension();
virtual ~card_extension();
virtual void set_solver(solver* s) { m_solver = s; }
void add_at_least(bool_var v, literal_vector const& lits, unsigned k);
void add_xor(bool_var v, literal_vector const& lits);
virtual void propagate(literal l, ext_constraint_idx idx, bool & keep);
virtual bool resolve_conflict();
virtual void get_antecedents(literal l, ext_justification_idx idx, literal_vector & r);

View file

@ -26,5 +26,5 @@ def_module_params('sat',
('dimacs.core', BOOL, False, 'extract core from DIMACS benchmarks'),
('drat.file', SYMBOL, '', 'file to dump DRAT proofs'),
('drat.check', BOOL, False, 'build up internal proof and check'),
('cardinality.solver', BOOL, True, 'use cardinality solver'),
('cardinality.solver', BOOL, False, 'use cardinality/xor solver'),
))

View file

@ -217,6 +217,7 @@ public:
sat_params p1(p);
m_params.set_bool("elim_vars", false);
m_params.set_bool("keep_cardinality_constraints", p1.cardinality_solver());
m_params.set_bool("cardinality_solver", p1.cardinality_solver());
m_solver.updt_params(m_params);
m_optimize_model = m_params.get_bool("optimize_model", false);

View file

@ -65,6 +65,7 @@ struct goal2sat::imp {
expr_ref_vector m_trail;
expr_ref_vector m_interpreted_atoms;
bool m_default_external;
bool m_cardinality_solver;
imp(ast_manager & _m, params_ref const & p, sat::solver & s, atom2bool_var & map, dep2asm_map& dep2asm, bool default_external):
m(_m),
@ -83,6 +84,8 @@ struct goal2sat::imp {
void updt_params(params_ref const & p) {
m_ite_extra = p.get_bool("ite_extra", true);
m_max_memory = megabytes_to_bytes(p.get_uint("max_memory", UINT_MAX));
m_cardinality_solver = p.get_bool("cardinality_solver", false);
std::cout << p << "\n";
}
void throw_op_not_handled(std::string const& s) {
@ -339,7 +342,7 @@ struct goal2sat::imp {
}
}
void convert_iff(app * t, bool root, bool sign) {
void convert_iff2(app * t, bool root, bool sign) {
TRACE("goal2sat", tout << "convert_iff " << root << " " << sign << "\n" << mk_ismt2_pp(t, m) << "\n";);
unsigned sz = m_result_stack.size();
SASSERT(sz >= 2);
@ -372,8 +375,33 @@ struct goal2sat::imp {
}
}
void convert_pb_args(app* t, sat::literal_vector& lits) {
unsigned num_args = t->get_num_args();
void convert_iff(app * t, bool root, bool sign) {
TRACE("goal2sat", tout << "convert_iff " << root << " " << sign << "\n" << mk_ismt2_pp(t, m) << "\n";);
unsigned sz = m_result_stack.size();
unsigned num = get_num_args(t);
SASSERT(sz >= num && num >= 2);
if (num == 2) {
convert_iff2(t, root, sign);
return;
}
sat::literal_vector lits;
convert_pb_args(num, lits);
sat::bool_var v = m_solver.mk_var(true);
ensure_extension();
if (lits.size() % 2 == 0) lits[0].neg();
m_ext->add_xor(v, lits);
sat::literal lit(v, sign);
if (root) {
m_result_stack.reset();
mk_clause(lit);
}
else {
m_result_stack.shrink(sz - num);
m_result_stack.push_back(lit);
}
}
void convert_pb_args(unsigned num_args, sat::literal_vector& lits) {
unsigned sz = m_result_stack.size();
for (unsigned i = 0; i < num_args; ++i) {
sat::literal lit(m_result_stack[sz - num_args + i]);
@ -396,7 +424,7 @@ struct goal2sat::imp {
SASSERT(k.is_unsigned());
sat::literal_vector lits;
unsigned sz = m_result_stack.size();
convert_pb_args(t, lits);
convert_pb_args(t->get_num_args(), lits);
sat::bool_var v = m_solver.mk_var(true);
sat::literal lit(v, sign);
m_ext->add_at_least(v, lits, k.get_unsigned());
@ -415,7 +443,7 @@ struct goal2sat::imp {
SASSERT(k.is_unsigned());
sat::literal_vector lits;
unsigned sz = m_result_stack.size();
convert_pb_args(t, lits);
convert_pb_args(t->get_num_args(), lits);
for (unsigned i = 0; i < lits.size(); ++i) {
lits[i].neg();
}
@ -434,7 +462,7 @@ struct goal2sat::imp {
void convert_eq_k(app* t, rational k, bool root, bool sign) {
SASSERT(k.is_unsigned());
sat::literal_vector lits;
convert_pb_args(t, lits);
convert_pb_args(t->get_num_args(), lits);
sat::bool_var v1 = m_solver.mk_var(true);
sat::bool_var v2 = m_solver.mk_var(true);
sat::literal l1(v1, false), l2(v2, false);
@ -528,6 +556,41 @@ struct goal2sat::imp {
UNREACHABLE();
}
}
unsigned get_num_args(app* t) {
if (m.is_iff(t) && m_cardinality_solver) {
unsigned n = 2;
while (m.is_iff(t->get_arg(1))) {
++n;
t = to_app(t->get_arg(1));
}
return n;
}
else {
return t->get_num_args();
}
}
expr* get_arg(app* t, unsigned idx) {
if (m.is_iff(t) && m_cardinality_solver) {
while (idx >= 1) {
SASSERT(m.is_iff(t));
t = to_app(t->get_arg(1));
--idx;
}
if (m.is_iff(t)) {
return t->get_arg(idx);
}
else {
return t;
}
}
else {
return t->get_arg(idx);
}
}
void process(expr * n) {
//SASSERT(m_result_stack.empty());
@ -559,9 +622,9 @@ struct goal2sat::imp {
visit(t->get_arg(0), root, !sign);
continue;
}
unsigned num = t->get_num_args();
unsigned num = get_num_args(t);
while (fr.m_idx < num) {
expr * arg = t->get_arg(fr.m_idx);
expr * arg = get_arg(t, fr.m_idx);
fr.m_idx++;
if (!visit(arg, false, false))
goto loop;