From 3047d930e12297ae74d4d11979f5c266e41475f7 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner <nbjorner@microsoft.com> Date: Sat, 13 Jan 2018 19:53:50 -0800 Subject: [PATCH] fix xor processing Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com> --- src/sat/ba_solver.cpp | 65 ++++++++++++++++++++----------------- src/sat/ba_solver.h | 7 ++-- src/sat/tactic/goal2sat.cpp | 10 ++++-- 3 files changed, 46 insertions(+), 36 deletions(-) diff --git a/src/sat/ba_solver.cpp b/src/sat/ba_solver.cpp index e6f16dba5..7865f618e 100644 --- a/src/sat/ba_solver.cpp +++ b/src/sat/ba_solver.cpp @@ -189,8 +189,8 @@ namespace sat { // ----------------------------------- // xor - ba_solver::xor::xor(unsigned id, literal lit, literal_vector const& lits): - constraint(xor_t, id, lit, lits.size(), get_obj_size(lits.size())) { + ba_solver::xor::xor(unsigned id, literal_vector const& lits): + constraint(xor_t, id, null_literal, lits.size(), get_obj_size(lits.size())) { for (unsigned i = 0; i < size(); ++i) { m_lits[i] = lits[i]; } @@ -959,17 +959,19 @@ namespace sat { lbool ba_solver::add_assign(xor& x, literal alit) { // literal is assigned unsigned sz = x.size(); - TRACE("ba", tout << "assign: " << x.lit() << ": " << ~alit << "@" << lvl(~alit) << "\n";); + TRACE("ba", tout << "assign: " << ~alit << "@" << lvl(~alit) << " " << x << "\n"; display(tout, x, true); ); - SASSERT(x.lit() == null_literal || value(x.lit()) == l_true); + SASSERT(x.lit() == null_literal); SASSERT(value(alit) != l_undef); unsigned index = 0; - for (; index <= 2; ++index) { + for (; index < 2; ++index) { if (x[index].var() == alit.var()) break; } if (index == 2) { // literal is no longer watched. - UNREACHABLE(); + // this can happen as both polarities of literals + // are put in watch lists and they are removed only + // one polarity at a time. return l_undef; } SASSERT(x[index].var() == alit.var()); @@ -979,7 +981,10 @@ namespace sat { literal lit2 = x[i]; if (value(lit2) == l_undef) { x.swap(index, i); + // unwatch_literal(alit, x); watch_literal(lit2, x); + watch_literal(~lit2, x); + TRACE("ba", tout << "swap in: " << lit2 << " " << x << "\n";); return l_undef; } } @@ -1627,13 +1632,13 @@ namespace sat { add_pb_ge(lit, wlits, k, false); } - void ba_solver::add_xor(bool_var v, literal_vector const& lits) { - add_xor(literal(v, false), lits, false); + void ba_solver::add_xor(literal_vector const& lits) { + add_xor(lits, false); } - ba_solver::constraint* ba_solver::add_xor(literal lit, literal_vector const& lits, bool learned) { + ba_solver::constraint* ba_solver::add_xor(literal_vector const& lits, bool learned) { void * mem = m_allocator.allocate(xor::get_obj_size(lits.size())); - xor* x = new (mem) xor(next_id(), lit, lits); + xor* x = new (mem) xor(next_id(), lits); x->set_learned(learned); add_constraint(x); for (literal l : lits) s().set_external(l.var()); // TBD: determine if goal2sat does this. @@ -1740,20 +1745,24 @@ namespace sat { unsigned level = lvl(l); bool_var v = l.var(); SASSERT(js.get_kind() == justification::EXT_JUSTIFICATION); - TRACE("ba", tout << l << ": " << js << "\n"; tout << s().m_trail << "\n";); + TRACE("ba", tout << l << ": " << js << "\n"; + for (unsigned i = 0; i <= index; ++i) tout << s().m_trail[i] << " "; tout << "\n"; + s().display_units(tout); + ); unsigned num_marks = 0; unsigned count = 0; while (true) { + TRACE("ba", tout << "process: " << l << "\n";); ++count; if (js.get_kind() == justification::EXT_JUSTIFICATION) { constraint& c = index2constraint(js.get_ext_justification_idx()); + TRACE("ba", tout << c << "\n";); if (!c.is_xor()) { r.push_back(l); } else { - xor& x = c.to_xor(); - if (x.lit() != null_literal && lvl(x.lit()) > 0) r.push_back(x.lit()); + xor& x = c.to_xor(); if (x[1].var() == l.var()) { x.swap(0, 1); } @@ -1762,6 +1771,7 @@ namespace sat { literal lit(value(x[i]) == l_true ? x[i] : ~x[i]); inc_parity(lit.var()); if (lvl(lit) == level) { + TRACE("ba", tout << "mark: " << lit << "\n";); ++num_marks; } else { @@ -1773,24 +1783,25 @@ namespace sat { else { r.push_back(l); } + bool found = false; while (num_marks > 0) { l = s().m_trail[index]; v = l.var(); unsigned n = get_parity(v); if (n > 0) { reset_parity(v); + num_marks -= n; if (n % 2 == 1) { + found = true; break; } - --num_marks; } --index; } - if (num_marks == 0) { + if (!found) { break; } --index; - --num_marks; js = s().m_justification[v]; } @@ -2492,6 +2503,11 @@ namespace sat { m_lits.append(n, lits); s.s().mk_clause(n, m_lits.c_ptr()); } + + std::ostream& ba_solver::ba_sort::pp(std::ostream& out, literal l) const { + return out << l; + } + // ------------------------------- // set literals equivalent @@ -3299,7 +3315,7 @@ namespace sat { xor const& x = cp->to_xor(); lits.reset(); for (literal l : x) lits.push_back(l); - result->add_xor(x.lit(), lits, x.learned()); + result->add_xor(lits, x.learned()); break; } default: @@ -3427,19 +3443,8 @@ namespace sat { } void ba_solver::display(std::ostream& out, xor const& x, bool values) const { - out << "xor " << x.lit(); - if (x.lit() != null_literal && values) { - out << "@(" << value(x.lit()); - if (value(x.lit()) != l_undef) { - out << ":" << lvl(x.lit()); - } - out << "): "; - } - else { - out << ": "; - } - for (unsigned i = 0; i < x.size(); ++i) { - literal l = x[i]; + out << "xor: "; + for (literal l : x) { out << l; if (values) { out << "@(" << value(l); diff --git a/src/sat/ba_solver.h b/src/sat/ba_solver.h index 1116bb166..cd2e941ce 100644 --- a/src/sat/ba_solver.h +++ b/src/sat/ba_solver.h @@ -178,7 +178,7 @@ namespace sat { literal m_lits[0]; public: static size_t get_obj_size(unsigned num_lits) { return sizeof(xor) + num_lits * sizeof(literal); } - xor(unsigned id, literal lit, literal_vector const& lits); + xor(unsigned id, literal_vector const& lits); literal operator[](unsigned i) const { return m_lits[i]; } literal const* begin() const { return m_lits; } literal const* end() const { return begin() + m_size; } @@ -246,6 +246,7 @@ namespace sat { literal mk_max(literal l1, literal l2); literal mk_min(literal l1, literal l2); void mk_clause(unsigned n, literal const* lits); + std::ostream& pp(std::ostream& out, literal l) const; }; ba_sort m_ba; psort_nw<ba_sort> m_sort; @@ -458,7 +459,7 @@ namespace sat { constraint* add_at_least(literal l, literal_vector const& lits, unsigned k, bool learned); constraint* add_pb_ge(literal l, svector<wliteral> const& wlits, unsigned k, bool learned); - constraint* add_xor(literal l, literal_vector const& lits, bool learned); + constraint* add_xor(literal_vector const& lits, bool learned); void copy_core(ba_solver* result); public: @@ -469,7 +470,7 @@ namespace sat { virtual void set_unit_walk(unit_walk* u) { m_unit_walk = u; } void add_at_least(bool_var v, literal_vector const& lits, unsigned k); void add_pb_ge(bool_var v, svector<wliteral> const& wlits, unsigned k); - void add_xor(bool_var v, literal_vector const& lits); + void add_xor(literal_vector const& lits); virtual bool propagate(literal l, ext_constraint_idx idx); virtual lbool resolve_conflict(); diff --git a/src/sat/tactic/goal2sat.cpp b/src/sat/tactic/goal2sat.cpp index 97129a861..596712a7b 100644 --- a/src/sat/tactic/goal2sat.cpp +++ b/src/sat/tactic/goal2sat.cpp @@ -392,11 +392,15 @@ struct goal2sat::imp { return; } sat::literal_vector lits; - convert_pb_args(num, lits); sat::bool_var v = m_solver.mk_var(true); + lits.push_back(sat::literal(v, true)); + convert_pb_args(num, lits); + // ensure that = is converted to xor + for (unsigned i = 1; i + 1 < lits.size(); ++i) { + lits[i].neg(); + } ensure_extension(); - if (lits.size() % 2 == 0) lits[0].neg(); - m_ext->add_xor(v, lits); + m_ext->add_xor(lits); sat::literal lit(v, sign); if (root) { m_result_stack.reset();