From 49faaaa8f10b4fdac35a5386e4c0beb6b8eb1bc0 Mon Sep 17 00:00:00 2001
From: Nikolaj Bjorner <nbjorner@microsoft.com>
Date: Tue, 23 May 2017 15:01:00 -0700
Subject: [PATCH] allowing non-literal assumptions

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
---
 src/sat/sat_solver/inc_sat_solver.cpp | 35 ++++++++++++++++++++++++---
 src/smt/theory_lra.cpp                |  2 +-
 2 files changed, 32 insertions(+), 5 deletions(-)

diff --git a/src/sat/sat_solver/inc_sat_solver.cpp b/src/sat/sat_solver/inc_sat_solver.cpp
index 83c31715d..4a4e0af38 100644
--- a/src/sat/sat_solver/inc_sat_solver.cpp
+++ b/src/sat/sat_solver/inc_sat_solver.cpp
@@ -130,13 +130,36 @@ public:
         m_solver.display_wcnf(out, m_asms.size(), m_asms.c_ptr(), nweights.c_ptr());
     }
 
+    bool is_literal(expr* e) const {
+        return 
+            is_uninterp_const(e) ||
+            (m.is_not(e, e) && is_uninterp_const(e));
+    }
+
     virtual lbool check_sat(unsigned sz, expr * const * assumptions) {
         m_solver.pop_to_base_level();
+        expr_ref_vector _assumptions(m);
+        obj_map<expr, expr*> asm2fml;
+        for (unsigned i = 0; i < sz; ++i) {
+            if (!is_literal(assumptions[i])) {
+                expr_ref a(m.mk_fresh_const("s", m.mk_bool_sort()), m);
+                expr_ref fml(m.mk_eq(a, assumptions[i]), m);
+                assert_expr(fml);
+                _assumptions.push_back(a);
+                asm2fml.insert(a, assumptions[i]);
+            }
+            else {
+                _assumptions.push_back(assumptions[i]);
+                asm2fml.insert(assumptions[i], assumptions[i]);
+            }
+        }
+        
+        TRACE("sat", tout << _assumptions << "\n";);
         dep2asm_t dep2asm;
         m_model = 0;
         lbool r = internalize_formulas();
         if (r != l_true) return r;
-        r = internalize_assumptions(sz, assumptions, dep2asm);
+        r = internalize_assumptions(sz, _assumptions.c_ptr(), dep2asm);
         if (r != l_true) return r;
 
         r = m_solver.check(m_asms.size(), m_asms.c_ptr());
@@ -150,7 +173,7 @@ public:
         case l_false:
             // TBD: expr_dependency core is not accounted for.
             if (!m_asms.empty()) {
-                extract_core(dep2asm);
+                extract_core(dep2asm, asm2fml);
             }
             break;
         default:
@@ -241,6 +264,7 @@ public:
         sat::bool_var_vector bvars;
         vector<sat::literal_vector> lconseq;
         dep2asm_t dep2asm;
+        obj_map<expr, expr*> asm2fml;
         m_solver.pop_to_base_level();
         lbool r = internalize_formulas();
         if (r != l_true) return r;
@@ -251,7 +275,7 @@ public:
         r = m_solver.get_consequences(m_asms, bvars, lconseq);
         if (r == l_false) {
             if (!m_asms.empty()) {
-                extract_core(dep2asm);
+                extract_core(dep2asm, asm2fml);
             }
             return r;
         }
@@ -569,7 +593,7 @@ private:
         }
     }
 
-    void extract_core(dep2asm_t& dep2asm) {
+    void extract_core(dep2asm_t& dep2asm, obj_map<expr, expr*> const& asm2fml) {
         u_map<expr*> asm2dep;
         extract_asm2dep(dep2asm, asm2dep);
         sat::literal_vector const& core = m_solver.get_core();
@@ -590,6 +614,9 @@ private:
         for (unsigned i = 0; i < core.size(); ++i) {
             expr* e = 0;
             VERIFY(asm2dep.find(core[i].index(), e));
+            if (asm2fml.contains(e)) {
+                e = asm2fml.find(e);
+            }
             m_core.push_back(e);
         }
     }
diff --git a/src/smt/theory_lra.cpp b/src/smt/theory_lra.cpp
index 05aa33d13..a5a34a079 100644
--- a/src/smt/theory_lra.cpp
+++ b/src/smt/theory_lra.cpp
@@ -487,7 +487,7 @@ namespace smt {
                 result = m_theory_var2var_index[v];
             }
             if (result == UINT_MAX) {
-                result = m_solver->add_var(v);
+                result = m_solver->add_var(v); // TBD: is_int(v);
                 m_theory_var2var_index.setx(v, result, UINT_MAX);
                 m_var_index2theory_var.setx(result, v, UINT_MAX);
                 m_var_trail.push_back(v);