From 3047d930e12297ae74d4d11979f5c266e41475f7 Mon Sep 17 00:00:00 2001
From: Nikolaj Bjorner <nbjorner@microsoft.com>
Date: Sat, 13 Jan 2018 19:53:50 -0800
Subject: [PATCH] fix xor processing

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
---
 src/sat/ba_solver.cpp       | 65 ++++++++++++++++++++-----------------
 src/sat/ba_solver.h         |  7 ++--
 src/sat/tactic/goal2sat.cpp | 10 ++++--
 3 files changed, 46 insertions(+), 36 deletions(-)

diff --git a/src/sat/ba_solver.cpp b/src/sat/ba_solver.cpp
index e6f16dba5..7865f618e 100644
--- a/src/sat/ba_solver.cpp
+++ b/src/sat/ba_solver.cpp
@@ -189,8 +189,8 @@ namespace sat {
     // -----------------------------------
     // xor
     
-    ba_solver::xor::xor(unsigned id, literal lit, literal_vector const& lits):
-    constraint(xor_t, id, lit, lits.size(), get_obj_size(lits.size())) {
+    ba_solver::xor::xor(unsigned id, literal_vector const& lits):
+    constraint(xor_t, id, null_literal, lits.size(), get_obj_size(lits.size())) {
         for (unsigned i = 0; i < size(); ++i) {
             m_lits[i] = lits[i];
         }
@@ -959,17 +959,19 @@ namespace sat {
     lbool ba_solver::add_assign(xor& x, literal alit) {
         // literal is assigned     
         unsigned sz = x.size();
-        TRACE("ba", tout << "assign: " << x.lit() << ": " << ~alit << "@" << lvl(~alit) << "\n";);
+        TRACE("ba", tout << "assign: "  << ~alit << "@" << lvl(~alit) << " " << x << "\n"; display(tout, x, true); );
 
-        SASSERT(x.lit() == null_literal || value(x.lit()) == l_true);
+        SASSERT(x.lit() == null_literal);
         SASSERT(value(alit) != l_undef);
         unsigned index = 0;
-        for (; index <= 2; ++index) {
+        for (; index < 2; ++index) {
             if (x[index].var() == alit.var()) break;
         }
         if (index == 2) {
             // literal is no longer watched.
-            UNREACHABLE();
+            // this can happen as both polarities of literals
+            // are put in watch lists and they are removed only
+            // one polarity at a time.
             return l_undef;
         }
         SASSERT(x[index].var() == alit.var());
@@ -979,7 +981,10 @@ namespace sat {
             literal lit2 = x[i];
             if (value(lit2) == l_undef) {
                 x.swap(index, i);
+                // unwatch_literal(alit, x);
                 watch_literal(lit2, x);
+                watch_literal(~lit2, x);
+                TRACE("ba", tout << "swap in: " << lit2 << " " << x << "\n";);
                 return l_undef;
             }
         }
@@ -1627,13 +1632,13 @@ namespace sat {
         add_pb_ge(lit, wlits, k, false);
     }
 
-    void ba_solver::add_xor(bool_var v, literal_vector const& lits) {
-        add_xor(literal(v, false), lits, false);
+    void ba_solver::add_xor(literal_vector const& lits) {
+        add_xor(lits, false);
     }
 
-    ba_solver::constraint* ba_solver::add_xor(literal lit, literal_vector const& lits, bool learned) {
+    ba_solver::constraint* ba_solver::add_xor(literal_vector const& lits, bool learned) {
         void * mem = m_allocator.allocate(xor::get_obj_size(lits.size()));
-        xor* x = new (mem) xor(next_id(), lit, lits);
+        xor* x = new (mem) xor(next_id(), lits);
         x->set_learned(learned);
         add_constraint(x);
         for (literal l : lits) s().set_external(l.var()); // TBD: determine if goal2sat does this.
@@ -1740,20 +1745,24 @@ namespace sat {
         unsigned level = lvl(l);
         bool_var v = l.var();
         SASSERT(js.get_kind() == justification::EXT_JUSTIFICATION);
-        TRACE("ba", tout << l << ": " << js << "\n"; tout << s().m_trail << "\n";);
+        TRACE("ba", tout << l << ": " << js << "\n"; 
+              for (unsigned i = 0; i <= index; ++i) tout << s().m_trail[i] << " "; tout << "\n";
+              s().display_units(tout);
+              );
 
         unsigned num_marks = 0;
         unsigned count = 0;
         while (true) {
+            TRACE("ba", tout << "process: " << l << "\n";);
             ++count;
             if (js.get_kind() == justification::EXT_JUSTIFICATION) {
                 constraint& c = index2constraint(js.get_ext_justification_idx());
+                TRACE("ba", tout << c << "\n";);
                 if (!c.is_xor()) {
                     r.push_back(l);
                 }
                 else {
-                    xor& x = c.to_xor();
-                    if (x.lit() != null_literal && lvl(x.lit()) > 0) r.push_back(x.lit());
+                    xor& x = c.to_xor();                    
                     if (x[1].var() == l.var()) {
                         x.swap(0, 1);
                     }
@@ -1762,6 +1771,7 @@ namespace sat {
                         literal lit(value(x[i]) == l_true ? x[i] : ~x[i]);
                         inc_parity(lit.var());
                         if (lvl(lit) == level) {
+                            TRACE("ba", tout << "mark: " << lit << "\n";);
                             ++num_marks;
                         }
                         else {
@@ -1773,24 +1783,25 @@ namespace sat {
             else {
                 r.push_back(l);
             }
+            bool found = false;
             while (num_marks > 0) {
                 l = s().m_trail[index];
                 v = l.var();
                 unsigned n = get_parity(v);
                 if (n > 0) {
                     reset_parity(v);
+                    num_marks -= n;
                     if (n % 2 == 1) {
+                        found = true;
                         break;
                     }
-                    --num_marks;
                 }
                 --index;
             }
-            if (num_marks == 0) {
+            if (!found) {
                 break;
             }
             --index;
-            --num_marks;
             js = s().m_justification[v];
         }
 
@@ -2492,6 +2503,11 @@ namespace sat {
         m_lits.append(n, lits);
         s.s().mk_clause(n, m_lits.c_ptr());
     }
+
+    std::ostream& ba_solver::ba_sort::pp(std::ostream& out, literal l) const {
+        return out << l;
+    }
+
     
     // -------------------------------
     // set literals equivalent
@@ -3299,7 +3315,7 @@ namespace sat {
                 xor const& x = cp->to_xor();
                 lits.reset();
                 for (literal l : x) lits.push_back(l);
-                result->add_xor(x.lit(), lits, x.learned());        
+                result->add_xor(lits, x.learned());        
                 break;
             }
             default:
@@ -3427,19 +3443,8 @@ namespace sat {
     }
 
     void ba_solver::display(std::ostream& out, xor const& x, bool values) const {
-        out << "xor " << x.lit();
-        if (x.lit() != null_literal && values) {
-            out << "@(" << value(x.lit());
-            if (value(x.lit()) != l_undef) {
-                out << ":" << lvl(x.lit());
-            }
-            out << "): ";
-        }
-        else {
-            out << ": ";
-        }
-        for (unsigned i = 0; i < x.size(); ++i) {
-            literal l = x[i];
+        out << "xor: ";
+        for (literal l : x) {
             out << l;
             if (values) {
                 out << "@(" << value(l);
diff --git a/src/sat/ba_solver.h b/src/sat/ba_solver.h
index 1116bb166..cd2e941ce 100644
--- a/src/sat/ba_solver.h
+++ b/src/sat/ba_solver.h
@@ -178,7 +178,7 @@ namespace sat {
             literal        m_lits[0];
         public:
             static size_t get_obj_size(unsigned num_lits) { return sizeof(xor) + num_lits * sizeof(literal); }
-            xor(unsigned id, literal lit, literal_vector const& lits);
+            xor(unsigned id, literal_vector const& lits);
             literal operator[](unsigned i) const { return m_lits[i]; }
             literal const* begin() const { return m_lits; }
             literal const* end() const { return begin() + m_size; }
@@ -246,6 +246,7 @@ namespace sat {
             literal mk_max(literal l1, literal l2);
             literal mk_min(literal l1, literal l2);
             void    mk_clause(unsigned n, literal const* lits);
+            std::ostream& pp(std::ostream& out, literal l) const;
         };
         ba_sort           m_ba;
         psort_nw<ba_sort> m_sort;
@@ -458,7 +459,7 @@ namespace sat {
 
         constraint* add_at_least(literal l, literal_vector const& lits, unsigned k, bool learned);
         constraint* add_pb_ge(literal l, svector<wliteral> const& wlits, unsigned k, bool learned);
-        constraint* add_xor(literal l, literal_vector const& lits, bool learned);
+        constraint* add_xor(literal_vector const& lits, bool learned);
 
         void copy_core(ba_solver* result);
     public:
@@ -469,7 +470,7 @@ namespace sat {
         virtual void set_unit_walk(unit_walk* u) { m_unit_walk = u; }
         void    add_at_least(bool_var v, literal_vector const& lits, unsigned k);
         void    add_pb_ge(bool_var v, svector<wliteral> const& wlits, unsigned k);
-        void    add_xor(bool_var v, literal_vector const& lits);
+        void    add_xor(literal_vector const& lits);
 
         virtual bool propagate(literal l, ext_constraint_idx idx);
         virtual lbool resolve_conflict();
diff --git a/src/sat/tactic/goal2sat.cpp b/src/sat/tactic/goal2sat.cpp
index 97129a861..596712a7b 100644
--- a/src/sat/tactic/goal2sat.cpp
+++ b/src/sat/tactic/goal2sat.cpp
@@ -392,11 +392,15 @@ struct goal2sat::imp {
             return;
         }
         sat::literal_vector lits;
-        convert_pb_args(num, lits);
         sat::bool_var v = m_solver.mk_var(true);
+        lits.push_back(sat::literal(v, true));
+        convert_pb_args(num, lits);
+        // ensure that = is converted to xor
+        for (unsigned i = 1; i + 1 < lits.size(); ++i) {
+            lits[i].neg();
+        }
         ensure_extension();
-        if (lits.size() % 2 == 0) lits[0].neg();
-        m_ext->add_xor(v, lits);
+        m_ext->add_xor(lits);
         sat::literal lit(v, sign);
         if (root) {            
             m_result_stack.reset();