diff --git a/src/sat/ba_solver.cpp b/src/sat/ba_solver.cpp index e3a8a7ea2..6819da579 100644 --- a/src/sat/ba_solver.cpp +++ b/src/sat/ba_solver.cpp @@ -2048,40 +2048,30 @@ namespace sat { return lit; } + ba_solver::constraint* ba_solver::add_xr(literal_vector const& _lits, bool learned) { - struct parity { - bool sign; bool lit; - parity(): sign(false), lit(false) {} - // {false, false}, p => {false, true} - // {false, false}, !p => {true, true} - // {false, true}, p => {true, false} - // {false, true}, !p => {true, false} - void add(literal l) { - sign = lit == (sign == l.sign()); - lit = !lit; - } - }; literal_vector lits; - u_map var2parity; + u_map var2sign; + bool sign = false, odd = false; for (literal lit : _lits) { - var2parity.insert_if_not_there2(lit.var(), parity())->get_data().m_value.add(lit); + if (var2sign.find(lit.var(), sign)) { + var2sign.erase(lit.var()); + odd ^= (sign ^ lit.sign()); + } + else { + var2sign.insert(lit.var(), lit.sign()); + } } - bool polarity = false; - for (auto const& kv : var2parity) { - bool lit = kv.m_value.lit; - bool sign = kv.m_value.sign; - if (lit) - lits.push_back(literal(kv.m_key, sign)); - else - polarity = polarity ^ sign; + for (auto const& kv : var2sign) { + lits.push_back(literal(kv.m_key, kv.m_value)); } - if (polarity && !lits.empty()) { + if (odd && !lits.empty()) { lits[0].neg(); } switch (lits.size()) { case 0: - if (polarity) + if (!odd) s().set_conflict(justification(0)); return nullptr; case 1: