From 33f4e65fa919349501e7511669131f6742fc6b1b Mon Sep 17 00:00:00 2001
From: Nikolaj Bjorner <nbjorner@microsoft.com>
Date: Tue, 5 Oct 2021 10:15:56 -0700
Subject: [PATCH] redo bindings/fingerprints

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
---
 src/sat/smt/q_clause.h      |  87 ++++++++++++++++++++++-------
 src/sat/smt/q_ematch.cpp    | 107 ++++++++++++++++++------------------
 src/sat/smt/q_ematch.h      |  14 ++---
 src/sat/smt/q_fingerprint.h |  77 --------------------------
 src/sat/smt/q_queue.cpp     |  22 ++++----
 src/sat/smt/q_queue.h       |  18 +++---
 6 files changed, 145 insertions(+), 180 deletions(-)
 delete mode 100644 src/sat/smt/q_fingerprint.h

diff --git a/src/sat/smt/q_clause.h b/src/sat/smt/q_clause.h
index 66daf07ea..08a6f615a 100644
--- a/src/sat/smt/q_clause.h
+++ b/src/sat/smt/q_clause.h
@@ -22,6 +22,7 @@ Author:
 #include "ast/euf/euf_enode.h"
 #include "sat/smt/euf_solver.h"
 
+
 namespace q {
 
     struct lit {
@@ -35,14 +36,40 @@ namespace q {
         std::ostream& display(std::ostream& out) const;
     };
 
+    struct binding;
+
+    struct clause { // a quantifier (m_q), its literals (m_lits) and per-quantifier bookkeeping
+        unsigned            m_index;   // value returned by index(); supplied to the constructor
+        vector<lit>         m_lits;    // size() and operator[] range over these literals
+        quantifier_ref      m_q;       // num_decls() is the number of declarations of this quantifier
+        unsigned            m_watch = 0;
+        sat::literal        m_literal = sat::null_literal;
+        q::quantifier_stat* m_stat = nullptr;
+        binding* m_bindings = nullptr; // NOTE(review): presumably the head of the dll_base-linked bindings of this clause; confirm
+        // clause is defined ahead of binding because binding's inline members dereference their clause pointer.
+
+        clause(ast_manager& m, unsigned idx) : m_index(idx), m_q(m) {}
+
+        std::ostream& display(euf::solver& ctx, std::ostream& out) const;
+        lit const& operator[](unsigned i) const { return m_lits[i]; }
+        lit& operator[](unsigned i) { return m_lits[i]; }
+        unsigned size() const { return m_lits.size(); }
+        unsigned num_decls() const { return m_q->get_num_decls(); }
+        unsigned index() const { return m_index; }
+        quantifier* q() const { return m_q; }
+    };
+
+
     struct binding : public dll_base<binding> {
+        clause*            c;
         app*               m_pattern;
         unsigned           m_max_generation;
         unsigned           m_min_top_generation;
         unsigned           m_max_top_generation;
         euf::enode*        m_nodes[0];
 
-        binding(app* pat, unsigned max_generation, unsigned min_top, unsigned max_top):
+        binding(clause& c, app* pat, unsigned max_generation, unsigned min_top, unsigned max_top):
+            c(&c),
             m_pattern(pat),
             m_max_generation(max_generation),
             m_min_top_generation(min_top),
@@ -53,29 +80,49 @@ namespace q {
         euf::enode* operator[](unsigned i) const { return m_nodes[i]; }
 
         std::ostream& display(euf::solver& ctx, unsigned num_nodes, std::ostream& out) const;
+
+        unsigned size() const { return c->num_decls(); } // one node per declaration of the clause's quantifier
+        // The quantifier of the clause this binding belongs to.
+        quantifier* q() const { return c->m_q; }
+        // Two bindings are equal when they share the quantifier and hold pointer-identical nodes at every index.
+        bool eq(binding const& other) const {
+            if (q() != other.q())
+                return false;
+            for (unsigned i = size(); i-- > 0; ) // same quantifier, so size() is the same for both operands
+                if ((*this)[i] != other[i])
+                    return false;
+            return true;
+        }
     };
 
-    struct clause {
-        unsigned            m_index;
-        vector<lit>         m_lits;
-        quantifier_ref      m_q;
-        unsigned            m_watch = 0;
-        sat::literal        m_literal = sat::null_literal;
-        q::quantifier_stat* m_stat = nullptr;
-        binding*            m_bindings = nullptr;
-
-
-        clause(ast_manager& m, unsigned idx): m_index(idx), m_q(m) {}
-
-        std::ostream& display(euf::solver& ctx, std::ostream& out) const;
-        lit const& operator[](unsigned i) const { return m_lits[i]; }
-        lit& operator[](unsigned i) { return m_lits[i]; }
-        unsigned size() const { return m_lits.size(); }
-        unsigned num_decls() const { return m_q->get_num_decls(); }
-        unsigned index() const { return m_index; }
-        quantifier* q() const { return m_q; }
+    struct binding_khasher { // first hasher passed to get_composite_hash: the id of the binding's quantifier
+        unsigned operator()(binding const* f) const { return f->q()->get_id(); }
     };
 
+    struct binding_chasher { // second hasher passed to get_composite_hash: the hash of the node at position idx
+        unsigned operator()(binding const* f, unsigned idx) const { return f->m_nodes[idx]->hash(); }
+    };
+
+    struct binding_hash_proc { // hash functor of the bindings table: quantifier id plus the f->size() node hashes
+        unsigned operator()(binding const* f) const {
+            return get_composite_hash<binding*, binding_khasher, binding_chasher>(const_cast<binding*>(f), f->size()); // const_cast: the template is instantiated for binding*, the hashers only read f
+        }
+    };
+
+    struct binding_eq_proc { // equality functor of the bindings table; delegates to binding::eq
+        bool operator()(binding const* a, binding const* b) const { return a->eq(*b); }
+    };
+
+    typedef ptr_hashtable<binding, binding_hash_proc, binding_eq_proc> bindings; // set of binding pointers keyed by (quantifier, nodes), not by address
+
+    inline std::ostream& operator<<(std::ostream& out, binding const& f) { // prints "[fp <quantifier id>: <expr id> ...]"
+        out << "[fp " << f.q()->get_id() << ":";
+        for (unsigned i = 0; i < f.size(); ++i)
+            out << " " << f[i]->get_expr_id(); // one expression id per bound node
+        return out << "]";
+    }
+
+
     struct justification {
         expr*     m_lhs, *m_rhs;
         bool      m_sign;
diff --git a/src/sat/smt/q_ematch.cpp b/src/sat/smt/q_ematch.cpp
index 56ab26093..1cfc1f678 100644
--- a/src/sat/smt/q_ematch.cpp
+++ b/src/sat/smt/q_ematch.cpp
@@ -219,18 +219,48 @@ namespace q {
     };
 
 
+
+    binding* ematch::tmp_binding(clause& c, app* pat, euf::enode* const* b) { // fill the reusable scratch binding with (c, pat, b) and return it
+        if (c.num_decls() > m_tmp_binding_capacity) { // scratch binding has too few node slots: replace it with a larger one
+            void* mem = memory::allocate(sizeof(binding) + c.num_decls() * sizeof(euf::enode*));
+            m_tmp_binding = new (mem) binding(c, pat, 0, 0, 0); // generation fields are set to 0 here
+            m_tmp_binding_capacity = c.num_decls();
+        }
+        // The returned pointer aliases m_tmp_binding, so its contents are overwritten by the next call.
+        for (unsigned i = c.num_decls(); i-- > 0; )
+            m_tmp_binding->m_nodes[i] = b[i];
+        m_tmp_binding->m_pattern = pat;
+        m_tmp_binding->c = &c;
+        // Only clause, pattern and nodes are refreshed above; the generation fields are not touched on reuse.
+        return m_tmp_binding.get();
+    }
+
     binding* ematch::alloc_binding(clause& c, app* pat, euf::enode* const* _binding, unsigned max_generation, unsigned min_top, unsigned max_top) {
+        binding* b = tmp_binding(c, pat, _binding); // scratch copy, used only as a lookup key
+        // Return nullptr when an equal binding (same quantifier, same nodes) is already recorded in m_bindings.
+        if (m_bindings.contains(b))
+            return nullptr;
+        // Repeat the lookup with every node of the scratch copy replaced by its get_root().
+        for (unsigned i = c.num_decls(); i-- > 0; )
+            b->m_nodes[i] = b->m_nodes[i]->get_root();
+
+        if (m_bindings.contains(b))
+            return nullptr;
+        // Not seen before: allocate the binding from the context region and fill it from the caller's _binding (not the roots).
         unsigned n = c.num_decls();
         unsigned sz = sizeof(binding) + sizeof(euf::enode* const*) * n;
         void* mem = ctx.get_region().allocate(sz);
-        binding* b = new (mem) binding(pat, max_generation, min_top, max_top);
+        b = new (mem) binding(c, pat, max_generation, min_top, max_top);
         b->init(b);
         for (unsigned i = 0; i < n; ++i)
             b->m_nodes[i] = _binding[i];
+        // Record the binding and push a trail entry; NOTE(review): insert_map presumably removes it again on backtrack; confirm.
+        m_bindings.insert(b);
+        ctx.push(insert_map<bindings, binding*>(m_bindings, b));
         return b;
     }
 
-    euf::enode* const* ematch::alloc_binding(clause& c, euf::enode* const* _binding) {
+    euf::enode* const* ematch::alloc_nodes(clause& c, euf::enode* const* _binding) {
         unsigned sz = sizeof(euf::enode* const*) * c.num_decls();
         euf::enode** binding = (euf::enode**)ctx.get_region().allocate(sz);
         for (unsigned i = 0; i < c.num_decls(); ++i)
@@ -244,8 +274,7 @@ namespace q {
         clause& c = *m_clauses[idx];
         bool new_propagation = false;
         binding* b = alloc_binding(c, pat, _binding, max_generation, min_gen, max_gen);
-        fingerprint* f = add_fingerprint(c, *b, max_generation);
-        if (!f)
+        if (!b)
             return;
 
         if (propagate(false, _binding, max_generation, c, new_propagation))
@@ -276,7 +305,7 @@ namespace q {
         if (ev == l_undef && max_generation > m_generation_propagation_threshold)
             return false;
         if (!is_owned) 
-            binding = alloc_binding(c, binding); 
+            binding = alloc_nodes(c, binding); 
 
         auto j_idx = mk_justification(idx, c, binding);     
 
@@ -312,17 +341,14 @@ namespace q {
         return true;
     }
 
-    void ematch::instantiate(binding& b, clause& c) {
+    void ematch::instantiate(binding& b) { // hand b to the instantiation queue unless the qi_max_instances budget is exceeded
         if (m_stats.m_num_instantiations > ctx.get_config().m_qi_max_instances) 
             return;
         unsigned max_generation = b.m_max_generation;
-        max_generation = std::max(max_generation, c.m_stat->get_generation());
-        c.m_stat->update_max_generation(max_generation);
-        fingerprint * f = add_fingerprint(c, b, max_generation);
-        if (!f)
-            return;
-        m_inst_queue.insert(f);
-        m_stats.m_num_instantiations++;        
+        max_generation = std::max(max_generation, b.c->m_stat->get_generation()); // the clause is reached through the binding
+        b.c->m_stat->update_max_generation(max_generation);
+        m_stats.m_num_instantiations++; // no duplicate check here: alloc_binding returns nullptr for bindings already recorded
+        m_inst_queue.insert(&b);
     }
 
     void ematch::add_instantiation(clause& c, binding& b, sat::literal lit) {
@@ -330,35 +356,6 @@ namespace q {
         ctx.propagate(lit, mk_justification(UINT_MAX, c, b.nodes()));
     }
 
-    void ematch::set_tmp_binding(fingerprint& fp) {               
-        binding& b = *fp.b;
-        clause& c = *fp.c;
-        if (c.num_decls() > m_tmp_binding_capacity) {
-            void* mem = memory::allocate(sizeof(binding) + c.num_decls()*sizeof(euf::enode*));
-            m_tmp_binding = new (mem) binding(b.m_pattern, 0, 0, 0);
-            m_tmp_binding_capacity = c.num_decls();            
-        }
-
-        fp.b = m_tmp_binding.get();
-        for (unsigned i = c.num_decls(); i-- > 0; )
-            fp.b->m_nodes[i] = b[i];
-    }
-
-    fingerprint* ematch::add_fingerprint(clause& c, binding& b, unsigned max_generation) {
-        fingerprint fp(c, b, max_generation);        
-        if (m_fingerprints.contains(&fp))
-            return nullptr;
-        set_tmp_binding(fp);
-        for (unsigned i = c.num_decls(); i-- > 0; )
-            fp.b->m_nodes[i] = fp.b->m_nodes[i]->get_root();
-        if (m_fingerprints.contains(&fp))
-            return nullptr;
-        fingerprint* f = new (ctx.get_region()) fingerprint(c, b, max_generation);
-        m_fingerprints.insert(f);
-        ctx.push(insert_map<fingerprints, fingerprint*>(m_fingerprints, f));
-        return f;
-    }
-
     sat::literal ematch::instantiate(clause& c, euf::enode* const* binding, lit const& l) {
         expr_ref_vector _binding(m);
         for (unsigned i = 0; i < c.num_decls(); ++i)
@@ -552,6 +549,7 @@ namespace q {
 
 
     bool ematch::unit_propagate() {
+        return false; // NOTE(review): unconditional early return; the statement below is unreachable, so this hook never propagates. Confirm the disable is intentional.
         return ctx.get_config().m_ematching && propagate(false);
     }
 
@@ -569,12 +567,13 @@ namespace q {
             if (!b)
                 continue;
 
-            do {
+            do {                
                 if (propagate(true, b->m_nodes, b->m_max_generation, c, propagated)) 
                     to_remove.push_back(b);
                 else if (flush) {
-                    instantiate(*b, c);
+                    instantiate(*b);
                     to_remove.push_back(b);
+                    propagated = true;
                 }
                 b = b->next();
             } 
@@ -600,21 +599,21 @@ namespace q {
         TRACE("q", m_mam->display(tout););
         if (propagate(false))
             return true;
-        if (m_lazy_mam) {
+        if (m_lazy_mam) 
             m_lazy_mam->propagate();
-            if (propagate(false))
-                return true;
-        }
-        unsigned idx = 0;
-        for (clause* c : m_clauses) {
-            if (c->m_bindings) 
-                insert_clause_in_queue(idx);
-            idx++;
-        }
+        if (propagate(false))
+            return true;        
+        for (unsigned i = 0; i < m_clauses.size(); ++i)
+            if (m_clauses[i]->m_bindings)
+                insert_clause_in_queue(i);
         if (propagate(true))
             return true;
         if (m_inst_queue.lazy_propagate())
             return true;
+        for (unsigned i = 0; i < m_clauses.size(); ++i)
+            if (m_clauses[i]->m_bindings)
+                std::cout << "missed propagation " << i << "\n";
+        TRACE("q", tout << "no more propagation\n";);
         return false;
     }
 
diff --git a/src/sat/smt/q_ematch.h b/src/sat/smt/q_ematch.h
index fbedbd65a..bd79511a8 100644
--- a/src/sat/smt/q_ematch.h
+++ b/src/sat/smt/q_ematch.h
@@ -23,7 +23,6 @@ Author:
 #include "sat/smt/sat_th.h"
 #include "sat/smt/q_mam.h"
 #include "sat/smt/q_clause.h"
-#include "sat/smt/q_fingerprint.h"
 #include "sat/smt/q_queue.h"
 #include "sat/smt/q_eval.h"
 
@@ -69,7 +68,7 @@ namespace q {
         ast_manager&                  m;
         eval                          m_eval;
         quantifier_stat_gen           m_qstat_gen;
-        fingerprints                  m_fingerprints;
+        bindings                      m_bindings;
         scoped_ptr<binding>           m_tmp_binding;
         unsigned                      m_tmp_binding_capacity = 0;
         queue                         m_inst_queue;
@@ -90,16 +89,16 @@ namespace q {
         unsigned_vector               m_clause_queue;
         euf::enode_pair_vector        m_evidence;
 
-        euf::enode* const* alloc_binding(clause& c, euf::enode* const* _binding);
-        binding* alloc_binding(clause& c, app* pat, euf::enode* const* _bidning, unsigned max_generation, unsigned min_top, unsigned max_top);
-        void add_binding(clause& c, app* pat, euf::enode* const* _binding, unsigned max_generation, unsigned min_top, unsigned max_top);
+        euf::enode* const* alloc_nodes(clause& c, euf::enode* const* _binding); // region-allocated array of c.num_decls() node pointers; NOTE(review): presumably a copy of _binding
+        binding* tmp_binding(clause& c, app* pat, euf::enode* const* _binding); // fills and returns the scratch binding m_tmp_binding, reused across calls
+        binding* alloc_binding(clause& c, app* pat, euf::enode* const* _binding, unsigned max_generation, unsigned min_top, unsigned max_top); // nullptr when an equal binding is already in m_bindings
         
         sat::ext_justification_idx mk_justification(unsigned idx, clause& c, euf::enode* const* b);
 
         void ensure_ground_enodes(expr* e);
         void ensure_ground_enodes(clause const& c);
 
-        void instantiate(binding& b, clause& c);
+        void instantiate(binding& b);
         sat::literal instantiate(clause& c, euf::enode* const* binding, lit const& l);
 
         // register as callback into egraph.
@@ -115,9 +114,6 @@ namespace q {
         clause* clausify(quantifier* q);
         lit clausify_literal(expr* arg);
 
-        fingerprint* add_fingerprint(clause& c, binding& b, unsigned max_generation);
-        void set_tmp_binding(fingerprint& fp);
-
         bool flush_prop_queue();
         void propagate(bool is_conflict, unsigned idx, sat::ext_justification_idx j_idx);
 
diff --git a/src/sat/smt/q_fingerprint.h b/src/sat/smt/q_fingerprint.h
deleted file mode 100644
index 99ad602b9..000000000
--- a/src/sat/smt/q_fingerprint.h
+++ /dev/null
@@ -1,77 +0,0 @@
-/*++
-Copyright (c) 2020 Microsoft Corporation
-
-Module Name:
-
-    q_fingerprint.h
-
-Abstract:
-
-    Fingerprint summary of a quantifier instantiation
-
-Author:
-
-    Nikolaj Bjorner (nbjorner) 2021-01-24
-
---*/
-#pragma once
-
-#include "util/hashtable.h"
-#include "ast/ast.h"
-#include "ast/quantifier_stat.h"
-#include "ast/euf/euf_enode.h"
-#include "sat/smt/q_clause.h"
-
-
-namespace q {
-
-    struct fingerprint {
-        clause*          c;
-        binding*         b;
-        unsigned         m_max_generation;
-        
-        unsigned size() const { return c->num_decls(); }
-        euf::enode* const* nodes() const { return b->nodes(); }
-        quantifier* q() const { return c->m_q; }
-        
-        fingerprint(clause& _c, binding& _b, unsigned mg) :
-            c(&_c), b(&_b), m_max_generation(mg) {}
-        
-        bool eq(fingerprint const& other) const {
-            if (c->m_q != other.c->m_q)
-                return false;
-            for (unsigned i = size(); i--> 0; ) 
-                if ((*b)[i] != (*other.b)[i])
-                    return false;
-            return true;
-        }
-    };
-
-    struct fingerprint_khasher {
-        unsigned operator()(fingerprint const * f) const { return f->c->m_q->get_id(); }
-    };
-
-    struct fingerprint_chasher {
-        unsigned operator()(fingerprint const * f, unsigned idx) const { return f->b->m_nodes[idx]->hash(); }
-    };
-
-    struct fingerprint_hash_proc {
-        unsigned operator()(fingerprint const * f) const {
-            return get_composite_hash<fingerprint *, fingerprint_khasher, fingerprint_chasher>(const_cast<fingerprint*>(f), f->size());
-        }
-    };
-    
-    struct fingerprint_eq_proc {
-        bool operator()(fingerprint const* a, fingerprint const* b) const { return a->eq(*b); }
-    };
-
-    typedef ptr_hashtable<fingerprint, fingerprint_hash_proc, fingerprint_eq_proc> fingerprints;    
-
-    inline std::ostream& operator<<(std::ostream& out, fingerprint const& f) {
-        out << "[fp " << f.q()->get_id() << ":";
-        for (unsigned i = 0; i < f.size(); ++i)
-            out << " " << (*f.b)[i]->get_expr_id();
-        return out << "]";
-    }
-
-}
diff --git a/src/sat/smt/q_queue.cpp b/src/sat/smt/q_queue.cpp
index 247451fb4..2e8db482f 100644
--- a/src/sat/smt/q_queue.cpp
+++ b/src/sat/smt/q_queue.cpp
@@ -86,13 +86,13 @@ namespace q {
         m_parser.add_var("cs_factor");
     }
 
-    void queue::set_values(fingerprint& f, float cost) {
+    void queue::set_values(binding& f, float cost) {
         quantifier_stat * stat  = f.c->m_stat;
         quantifier* q = f.q();
-        app* pat = f.b->m_pattern;
+        app* pat = f.m_pattern;
         m_vals[COST]               = cost;
-        m_vals[MIN_TOP_GENERATION] = static_cast<float>(f.b->m_min_top_generation);
-        m_vals[MAX_TOP_GENERATION] = static_cast<float>(f.b->m_max_top_generation);
+        m_vals[MIN_TOP_GENERATION] = static_cast<float>(f.m_min_top_generation);
+        m_vals[MAX_TOP_GENERATION] = static_cast<float>(f.m_max_top_generation);
         m_vals[INSTANCES]          = static_cast<float>(stat->get_num_instances_curr_branch());
         m_vals[SIZE]               = static_cast<float>(stat->get_size());
         m_vals[DEPTH]              = static_cast<float>(stat->get_depth());
@@ -108,14 +108,14 @@ namespace q {
         TRACE("q_detail", for (unsigned i = 0; i < m_vals.size(); i++) { tout << m_vals[i] << " "; } tout << "\n";);
     }
 
-    float queue::get_cost(fingerprint& f) {
+    float queue::get_cost(binding& f) { // evaluate m_cost_function for f and report the result to the clause's statistics via update_max_cost
         set_values(f, 0);
         float r = m_evaluator(m_cost_function, m_vals.size(), m_vals.data());
         f.c->m_stat->update_max_cost(r);
         return r;
     }
 
-    unsigned queue::get_new_gen(fingerprint& f, float cost) {
+    unsigned queue::get_new_gen(binding& f, float cost) {
         set_values(f, cost);
         float r = m_evaluator(m_new_gen_function, m_vals.size(), m_vals.data());
         return std::max(f.m_max_generation + 1, static_cast<unsigned>(r));
@@ -129,7 +129,7 @@ namespace q {
         }
     };
 
-    void queue::insert(fingerprint* f) {
+    void queue::insert(binding* f) {
         float cost = get_cost(*f);
         if (m_new_entries.empty()) 
             ctx.push(reset_new_entries(m_new_entries));
@@ -137,7 +137,7 @@ namespace q {
     }
 
     void queue::instantiate(entry& ent) {
-        fingerprint & f          = *ent.m_qb;
+        binding& f               = *ent.m_qb;
         quantifier * q           = f.q();
         unsigned num_bindings    = f.size();
         quantifier_stat * stat   = f.c->m_stat;
@@ -151,7 +151,7 @@ namespace q {
 
         auto* ebindings = m_subst(q, num_bindings);
         for (unsigned i = 0; i < num_bindings; ++i)
-            ebindings[i] = f.nodes()[i]->get_expr();
+            ebindings[i] = f[i]->get_expr();
         expr_ref instance = m_subst();
         ctx.get_rewriter()(instance);
         if (m.is_true(instance)) {
@@ -164,7 +164,7 @@ namespace q {
         
         euf::solver::scoped_generation _sg(ctx, gen);
         sat::literal result_l = ctx.mk_literal(instance);
-        em.add_instantiation(*f.c, *f.b, result_l);
+        em.add_instantiation(*f.c, f, result_l);
     }
 
     bool queue::propagate() {
@@ -178,7 +178,7 @@ namespace q {
             if (0 == since_last_check && ctx.resource_limits_exceeded()) 
                 break;
 
-            fingerprint& f = *curr.m_qb;
+            binding& f = *curr.m_qb;
 
             if (curr.m_cost <= m_eager_cost_threshold) 
                 instantiate(curr);
diff --git a/src/sat/smt/q_queue.h b/src/sat/smt/q_queue.h
index c23cb0377..3750ee31b 100644
--- a/src/sat/smt/q_queue.h
+++ b/src/sat/smt/q_queue.h
@@ -20,7 +20,7 @@ Author:
 #include "ast/cost_evaluator.h"
 #include "ast/rewriter/cached_var_subst.h"
 #include "parsers/util/cost_parser.h"
-#include "sat/smt/q_fingerprint.h"
+#include "sat/smt/q_clause.h"
 
 
 
@@ -51,12 +51,12 @@ namespace q {
         cost_evaluator                m_evaluator;
         cached_var_subst              m_subst;
         svector<float>                m_vals;
-        double                        m_eager_cost_threshold { 0 };
+        double                        m_eager_cost_threshold = 0;
         struct entry {
-            fingerprint * m_qb;
+            binding *     m_qb; // the queued binding; its clause is reached through m_qb->c
             float         m_cost;
-            bool          m_instantiated{ false };
-            entry(fingerprint * f, float c):m_qb(f), m_cost(c) {}
+            bool          m_instantiated = false; // entries start out not instantiated
+            entry(binding * f, float c):m_qb(f), m_cost(c) {} // stores the binding pointer and its cost
         };
         struct reset_new_entries;
         struct reset_instantiated;
@@ -64,18 +64,18 @@ namespace q {
         svector<entry>                m_new_entries;
         svector<entry>                m_delayed_entries;
 
-        float get_cost(fingerprint& f);
-        void set_values(fingerprint& f, float cost);
+        float get_cost(binding& f);
+        void set_values(binding& f, float cost);
         void init_parser_vars();
         void setup();
-        unsigned get_new_gen(fingerprint& f, float cost);
+        unsigned get_new_gen(binding& f, float cost);
         void instantiate(entry& e);
 
     public:
 
         queue(ematch& em, euf::solver& ctx);
             
-        void insert(fingerprint* f);
+        void insert(binding* f);
 
         bool propagate();