From 3011b34b3b87ca0bfe20cc7c7471dfd1d2a127a4 Mon Sep 17 00:00:00 2001
From: Nikolaj Bjorner <nbjorner@microsoft.com>
Date: Wed, 31 Aug 2022 18:59:02 -0700
Subject: [PATCH] log E-matching based quantifier instantiations as hints

---
 src/cmd_context/extra_cmds/proof_cmds.cpp |  2 ++
 src/sat/smt/bv_solver.cpp                 |  3 +++
 src/sat/smt/euf_proof.cpp                 |  2 ++
 src/sat/smt/q_ematch.cpp                  |  7 +++++--
 src/sat/smt/q_solver.cpp                  | 24 +++++++++++++++++++++++
 src/sat/smt/q_solver.h                    |  9 +++++++++
 6 files changed, 45 insertions(+), 2 deletions(-)

diff --git a/src/cmd_context/extra_cmds/proof_cmds.cpp b/src/cmd_context/extra_cmds/proof_cmds.cpp
index 8d6b6ad8e..fe75acc84 100644
--- a/src/cmd_context/extra_cmds/proof_cmds.cpp
+++ b/src/cmd_context/extra_cmds/proof_cmds.cpp
@@ -166,6 +166,8 @@ public:
         }
         m_solver->pop(1);
         std::cout << "(verified-smt)\n";
+        if (proof_hint)
+            std::cout << "(missed-hint " << mk_pp(proof_hint, m) << ")\n";
         add_clause(clause);
     }
 
diff --git a/src/sat/smt/bv_solver.cpp b/src/sat/smt/bv_solver.cpp
index 7156058c7..4a14067d1 100644
--- a/src/sat/smt/bv_solver.cpp
+++ b/src/sat/smt/bv_solver.cpp
@@ -442,6 +442,9 @@ namespace bv {
         }
         ctx.get_drat().add(lits, status());
         // TBD, a proper way would be to delete the lemma after use.
+        ctx.set_tmp_bool_var(leq1.var(), nullptr);
+        ctx.set_tmp_bool_var(leq2.var(), nullptr);
+
     }
 
     void solver::asserted(literal l) {
diff --git a/src/sat/smt/euf_proof.cpp b/src/sat/smt/euf_proof.cpp
index 14671d55a..6c90b3c66 100644
--- a/src/sat/smt/euf_proof.cpp
+++ b/src/sat/smt/euf_proof.cpp
@@ -100,6 +100,8 @@ namespace euf {
         if (jst.eq_consequent().first != nullptr) 
             lits.push_back(add_lit(jst.eq_consequent()));
         get_drat().add(lits, sat::status::th(m_is_redundant, jst.ext().get_id(), jst.get_pragma()));
+        for (unsigned i = s().num_vars(); i < nv; ++i)
+            set_tmp_bool_var(i, nullptr);
     }
 
     void solver::on_clause(unsigned n, literal const* lits, sat::status st) {
diff --git a/src/sat/smt/q_ematch.cpp b/src/sat/smt/q_ematch.cpp
index be9d952cd..490bce46e 100644
--- a/src/sat/smt/q_ematch.cpp
+++ b/src/sat/smt/q_ematch.cpp
@@ -381,9 +381,12 @@ namespace q {
         sat::literal_vector lits;
         lits.push_back(~j.m_clause.m_literal);
         for (unsigned i = 0; i < j.m_clause.size(); ++i) 
-            lits.push_back(instantiate(j.m_clause, j.m_binding, j.m_clause[i]));            
+            lits.push_back(instantiate(j.m_clause, j.m_binding, j.m_clause[i])); 
         m_qs.log_instantiation(lits, &j);
-        m_qs.add_clause(lits);               
+        euf::th_proof_hint* ph = nullptr;
+        if (ctx.use_drat()) 
+            ph = q_proof_hint::mk(ctx, j.m_clause.size(), j.m_binding);
+        m_qs.add_clause(lits, ph);               
     }
 
     bool ematch::flush_prop_queue() {
diff --git a/src/sat/smt/q_solver.cpp b/src/sat/smt/q_solver.cpp
index f40aa76c8..565b2536b 100644
--- a/src/sat/smt/q_solver.cpp
+++ b/src/sat/smt/q_solver.cpp
@@ -363,4 +363,28 @@ namespace q {
             ctx.on_instantiation(n, lits, j ? j->m_clause.num_decls() : 0, j ? j->m_binding : nullptr);
         }
     }
+
+    q_proof_hint* q_proof_hint::mk(euf::solver& s, unsigned n, euf::enode* const* bindings) {
+        auto* mem = s.get_region().allocate(q_proof_hint::get_obj_size(n));
+        q_proof_hint* ph = new (mem) q_proof_hint();
+        ph->m_num_bindings = n;
+        for (unsigned i = 0; i < n; ++i)
+            ph->m_bindings[i] = bindings[i];
+        return ph;
+    }
+    
+    expr* q_proof_hint::get_hint(euf::solver& s) const {
+        ast_manager& m = s.get_manager();
+        expr_ref_vector args(m);
+        sort_ref_vector sorts(m);
+        for (unsigned i = 0; i < m_num_bindings; ++i) {
+            args.push_back(m_bindings[i]->get_expr());
+            sorts.push_back(args.back()->get_sort());
+        }
+        sort* range = m.mk_proof_sort();
+        func_decl* d = m.mk_func_decl(symbol("inst"), args.size(), sorts.data(), range);
+        expr* r = m.mk_app(d, args);
+        return r;
+    }
+
 }
diff --git a/src/sat/smt/q_solver.h b/src/sat/smt/q_solver.h
index 755fc6d16..e34da020d 100644
--- a/src/sat/smt/q_solver.h
+++ b/src/sat/smt/q_solver.h
@@ -29,6 +29,15 @@ namespace euf {
 
 namespace q {
 
+    struct q_proof_hint : public euf::th_proof_hint {
+        unsigned     m_num_bindings;
+        euf::enode* m_bindings[0];
+        q_proof_hint() {}
+        static size_t get_obj_size(unsigned num_bindings) { return sizeof(q_proof_hint) + num_bindings*sizeof(euf::enode*); }
+        static q_proof_hint* mk(euf::solver& s, unsigned n, euf::enode* const* bindings);
+        expr* get_hint(euf::solver& s) const override;
+    };
+
     class solver : public euf::th_euf_solver {
 
         typedef obj_map<quantifier, quantifier*> flat_table;