From fe43f8df8f2c5298bca4546430a9c27defb36ff8 Mon Sep 17 00:00:00 2001
From: Nikolaj Bjorner <nbjorner@microsoft.com>
Date: Thu, 3 Sep 2020 08:11:43 -0700
Subject: [PATCH] na

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
---
 scripts/mk_project.py          |  2 +-
 src/sat/smt/bv_internalize.cpp | 61 +++++++++++++++++++++++-----------
 src/sat/smt/bv_solver.cpp      | 55 ++++++++++++++++++++++++++++++
 src/sat/smt/bv_solver.h        | 32 +++++++++++++++++-
 4 files changed, 129 insertions(+), 21 deletions(-)

diff --git a/scripts/mk_project.py b/scripts/mk_project.py
index 548733b3f..e829f3a1c 100644
--- a/scripts/mk_project.py
+++ b/scripts/mk_project.py
@@ -49,7 +49,7 @@ def init_project_def():
     add_lib('core_tactics', ['tactic', 'macros', 'normal_forms', 'rewriter', 'pattern'], 'tactic/core')
     add_lib('arith_tactics', ['core_tactics', 'sat'], 'tactic/arith')
 
-    add_lib('sat_smt', ['sat', 'euf', 'tactic', 'smt_params'], 'sat/smt')
+    add_lib('sat_smt', ['sat', 'euf', 'tactic', 'smt_params', 'bit_blaster'], 'sat/smt')
     add_lib('sat_tactic', ['tactic', 'sat', 'solver', 'sat_smt'], 'sat/tactic')
     add_lib('nlsat_tactic', ['nlsat', 'sat_tactic', 'arith_tactics'], 'nlsat/tactic')
     add_lib('subpaving_tactic', ['core_tactics', 'subpaving'], 'math/subpaving/tactic')
diff --git a/src/sat/smt/bv_internalize.cpp b/src/sat/smt/bv_internalize.cpp
index b290f7b16..28b132a3b 100644
--- a/src/sat/smt/bv_internalize.cpp
+++ b/src/sat/smt/bv_internalize.cpp
@@ -22,6 +22,28 @@ Author:
 
 namespace bv {
 
+    class add_var_pos_trail : public trail<euf::solver> {
+        solver::bit_atom * m_atom;
+    public:
+        add_var_pos_trail(solver::bit_atom * a):m_atom(a) {}
+        void undo(euf::solver & euf) override {
+            SASSERT(m_atom->m_occs);
+            m_atom->m_occs = m_atom->m_occs->m_next;
+        }
+    };
+
+    class mk_atom_trail : public trail<euf::solver> {
+        solver& th;
+        sat::bool_var m_var;
+    public:
+        mk_atom_trail(sat::bool_var v, solver& th):m_var(v), th(th) {}
+        void undo(euf::solver & euf) override {
+            solver::atom * a = th.get_bv2a(m_var);
+            a->~atom();
+            th.erase_bv2a(m_var);
+        }
+    };
+
     euf::theory_var solver::mk_var(euf::enode* n) {
         theory_var r = euf::th_euf_solver::mk_var(n);
         m_find.mk_var();
@@ -80,8 +102,11 @@ namespace bv {
         return e;
     }
 
-    void solver::register_true_false_bit(theory_var v, unsigned i) {
-
+    void solver::register_true_false_bit(theory_var v, unsigned idx) {
+        SASSERT(s().value(m_bits[v][idx]) != l_undef);
+        bool is_true = (s().value(m_bits[v][idx]) == l_true);
+        zero_one_bits & bits = m_zero_one_bits[v];
+        bits.push_back(zero_one_bit(v, idx, is_true));
     }
 
     /**
@@ -91,32 +116,26 @@ namespace bv {
         literal_vector & bits = m_bits[v];
         unsigned idx          = bits.size();
         bits.push_back(l);
-#if 0
-        if (l.var() == true_bool_var) {
+        if (s().value(l) != l_undef && s().lvl(l) == 0) {
             register_true_false_bit(v, idx);
         }
         else {
-            theory_id th_id = ctx.get_var_theory(l.var());
-            if (th_id == get_id()) {
-                atom * a = get_bv2a(l.var());
-                SASSERT(a && a->is_bit());
-                bit_atom * b = static_cast<bit_atom*>(a);
-                find_new_diseq_axioms(b->m_occs, v, idx);
-                ctx.push(add_var_pos_trail(b));
-                b->m_occs = new (get_region()) var_pos_occ(v, idx, b->m_occs);
-            }
-            else {
-                SASSERT(th_id == null_theory_id);
-                ctx.set_var_theory(l.var(), get_id());
-                SASSERT(ctx.get_var_theory(l.var()) == get_id());
+            atom * a = get_bv2a(l.var());
+            SASSERT(!a || a->is_bit());
+            if (a) {
                 bit_atom * b = new (get_region()) bit_atom();
                 insert_bv2a(l.var(), b);
-                ctx.push(mk_atom_trail(l.var()));
+                ctx.push(mk_atom_trail(l.var(), *this));
                 SASSERT(b->m_occs == 0);
                 b->m_occs = new (get_region()) var_pos_occ(v, idx);
             }
+            else {
+                bit_atom * b = static_cast<bit_atom*>(a);
+                find_new_diseq_axioms(b->m_occs, v, idx);
+                ctx.push(add_var_pos_trail(b));
+                b->m_occs = new (get_region()) var_pos_occ(v, idx, b->m_occs);                
+            }
         }
-#endif
     }
 
     void solver::init_bits(euf::enode * n, expr_ref_vector const & bits) {
@@ -134,6 +153,10 @@ namespace bv {
         return bv.get_bv_size(n->get_owner());
     }
 
+    unsigned solver::get_bv_size(theory_var v) {
+        return get_bv_size(get_enode(v));
+    }
+
     void solver::internalize_num(app* n, theory_var v) {
         numeral val;
         unsigned sz = 0;
diff --git a/src/sat/smt/bv_solver.cpp b/src/sat/smt/bv_solver.cpp
index b67fc90b8..f8469f7be 100644
--- a/src/sat/smt/bv_solver.cpp
+++ b/src/sat/smt/bv_solver.cpp
@@ -52,5 +52,60 @@ namespace bv {
         fixed_var_eh(v);
     }
 
+    /**
+   \brief v[idx] = ~v'[idx], then v /= v' is a theory axiom.
+*/
+    void solver::find_new_diseq_axioms(var_pos_occ* occs, theory_var v, unsigned idx) {
+        literal l = m_bits[v][idx];
+        l.neg();
+        while (occs) {
+            theory_var v2 = occs->m_var;
+            unsigned   idx2 = occs->m_idx;
+            if (idx == idx2 && m_bits[v2][idx2] == l && get_bv_size(v2) == get_bv_size(v))
+                mk_new_diseq_axiom(v, v2, idx);
+            occs = occs->m_next;
+        }
+    }
 
+
+    /**
+       \brief v1[idx] = ~v2[idx], then v1 /= v2 is a theory axiom.
+    */
+    void solver::mk_new_diseq_axiom(theory_var v1, theory_var v2, unsigned idx) {
+        if (!get_config().m_bv_eq_axioms)
+            return;
+
+        // TBD: disabled until new literal creation is supported
+        return;
+        SASSERT(m_bits[v1][idx] == ~m_bits[v2][idx]);
+        TRACE("bv_solver", tout << "found new diseq axiom\n" << pp(v1) << pp(v2);); 
+        m_stats.m_num_diseq_static++;
+        expr_ref eq(m.mk_eq(get_expr(v1), get_expr(v2)), m);
+        sat::literal not_eq = ctx.internalize(eq, true, false, m_is_redundant);
+        s().add_clause(1, &not_eq, sat::status::th(m_is_redundant, get_id()));
+    }
+
+    std::ostream& solver::display(std::ostream& out, theory_var v) const {
+        out << "v";
+        out.width(4);
+        out << std::left << v;
+        out << " #";
+        out.width(4);
+        out << get_enode(v)->get_owner_id() << " -> #";
+        out.width(4);
+#if 0
+        out << get_enode(find(v))->get_owner_id();
+        out << std::right << ", bits:";
+        literal_vector const& bits = m_bits[v];
+        for (literal lit : bits) {
+            out << " " << lit << ":";
+            ctx.display_literal(out, lit);
+        }
+        numeral val;
+        if (get_fixed_value(v, val))
+            out << ", value: " << val;
+        out << "\n";
+#endif
+        return out;
+    }
 }
diff --git a/src/sat/smt/bv_solver.h b/src/sat/smt/bv_solver.h
index 393f3d40f..b0c101ad4 100644
--- a/src/sat/smt/bv_solver.h
+++ b/src/sat/smt/bv_solver.h
@@ -34,6 +34,14 @@ namespace bv {
         typedef map<value_sort_pair, theory_var, value_sort_pair_hash, default_eq<value_sort_pair> > value2var;
         typedef union_find<solver>  th_union_find;
 
+        struct stats {
+            unsigned   m_num_diseq_static, m_num_diseq_dynamic, m_num_bit2core, m_num_th2core_eq, m_num_conflicts;
+            unsigned   m_num_eq_dynamic;
+            void reset() { memset(this, 0, sizeof(stats)); }
+            stats() { reset(); }
+        };
+
+
         /**
            \brief Structure used to store the position of a bitvector variable that
            contains the true_literal/false_literal.
@@ -86,25 +94,36 @@ namespace bv {
             bool is_bit() const override { return false; }
         };
 
+        friend class add_var_pos_trail;
+        friend class mk_atom_trail;
+        typedef ptr_vector<atom> bool_var2atom;
+
         bv_util                  bv;
         arith_util               m_autil;
+        stats                    m_stats;
         bit_blaster              m_bb;
         th_union_find            m_find;
         vector<literal_vector>   m_bits;     // per var, the bits of a given variable.
         ptr_vector<expr>         m_bits_expr;
         svector<unsigned>        m_wpos;     // per var, watch position for fixed variable detection. 
         vector<zero_one_bits>    m_zero_one_bits; // per var, see comment in the struct zero_one_bit
-//        bool_var2atom            m_bool_var2atom;
+        bool_var2atom            m_bool_var2atom;
         sat::solver* m_solver;
         sat::solver& s() { return *m_solver;  }
 
         // internalize:
+
+        void insert_bv2a(bool_var bv, atom * a) { m_bool_var2atom.setx(bv, a, 0); }
+        void erase_bv2a(bool_var bv) { m_bool_var2atom[bv] = 0; }
+        atom * get_bv2a(bool_var bv) const { return m_bool_var2atom.get(bv, 0); }
+
         sat::literal false_literal;
         sat::literal true_literal;
         bool visit(expr* e) override;
         bool visited(expr* e) override;
         bool post_visit(expr* e, bool sign, bool root) override;
         unsigned get_bv_size(euf::enode* n);
+        unsigned get_bv_size(theory_var v);
         euf::enode* mk_enode(app* n, ptr_vector<euf::enode> const& args);
         void fixed_var_eh(theory_var v);
         void register_true_false_bit(theory_var v, unsigned i);
@@ -147,7 +166,11 @@ namespace bv {
         void internalize_smul_no_overflow(app *n);
         void internalize_smul_no_underflow(app *n);
 
+        // solving
         void find_wpos(theory_var v);
+        void find_new_diseq_axioms(var_pos_occ* occs, theory_var v, unsigned idx);
+        void mk_new_diseq_axiom(theory_var v1, theory_var v2, unsigned idx);
+
 
     public:
         solver(euf::solver& ctx);
@@ -189,7 +212,14 @@ namespace bv {
         sat::literal internalize(expr* e, bool sign, bool root, bool learned) override;
         euf::theory_var mk_var(euf::enode* n) override;
 
+
+        // disagnostics
+        std::ostream& display(std::ostream& out, theory_var v) const;
+        typedef std::pair<solver const*, theory_var> pp_var;
+        pp_var pp(theory_var v) const { return pp_var(this, v); }
     };
 
+    inline std::ostream& operator<<(std::ostream& out, solver::pp_var const& p) { return p.first->display(out, p.second); }
+
 
 }