From bef64961ae985969ca4a02e9d8e11f0aecb49a26 Mon Sep 17 00:00:00 2001
From: Murphy Berzish <murphy.berzish@gmail.com>
Date: Tue, 18 Apr 2017 13:12:03 -0400
Subject: [PATCH] add pre-init assumptions for smt theories

---
 src/smt/smt_context.cpp | 17 ++++++++++++++++-
 src/smt/smt_context.h   | 15 ---------------
 src/smt/smt_setup.cpp   |  2 --
 src/smt/smt_theory.h    |  7 +++++++
 src/smt/theory_str.cpp  | 31 +++++++++++++++++++------------
 src/smt/theory_str.h    |  2 ++
 6 files changed, 44 insertions(+), 30 deletions(-)

diff --git a/src/smt/smt_context.cpp b/src/smt/smt_context.cpp
index dfe396f2b..db09552ef 100644
--- a/src/smt/smt_context.cpp
+++ b/src/smt/smt_context.cpp
@@ -77,7 +77,6 @@ namespace smt {
         m_unknown("unknown"),
         m_unsat_core(m),
         m_use_theory_str_overlap_assumption(false),
-        m_theoryStrOverlapAssumption_term(m_manager),
 #ifdef Z3DEBUG
         m_trail_enabled(true),
 #endif
@@ -3269,6 +3268,7 @@ namespace smt {
             r = l_undef;
         }
 
+        /*
         // PATCH for theory_str:
         // UNSAT + overlapping variables => UNKNOWN
         if (r == l_false && use_theory_str_overlap_assumption()) {
@@ -3304,6 +3304,7 @@ namespace smt {
                 TRACE("t_str", tout << "no overlaps detected in unsat core, answering UNSAT" << std::endl;);
             }
         }
+        */
 
         return r;
     }
@@ -3322,6 +3323,7 @@ namespace smt {
         SASSERT(!m_setup.already_configured());
         setup_context(m_fparams.m_auto_config);
 
+        /*
         // theory_str requires the context to be set up with a special assumption.
         // we need to wait until after setup_context() to know whether this is the case
         if (m_use_theory_str_overlap_assumption) {
@@ -3336,6 +3338,19 @@ namespace smt {
             // this might work, even though we already did a bit of setup
             return check(assumption.size(), assumption.c_ptr(), reset_cancel);
         }
+        */
+
+        expr_ref_vector theory_assumptions(m_manager);
+        ptr_vector<theory>::iterator it  = m_theory_set.begin();
+        ptr_vector<theory>::iterator end = m_theory_set.end();
+        for (; it != end; ++it) {
+            (*it)->add_theory_assumptions(theory_assumptions);
+        }
+        if (!theory_assumptions.empty()) {
+            TRACE("search", tout << "Adding theory assumptions to context" << std::endl;);
+            // this works even though we already did part of setup
+            return check(theory_assumptions.size(), theory_assumptions.c_ptr(), reset_cancel);
+        }
 
         internalize_assertions();
         lbool r = l_undef;
diff --git a/src/smt/smt_context.h b/src/smt/smt_context.h
index 0cf3f8d68..0667f622e 100644
--- a/src/smt/smt_context.h
+++ b/src/smt/smt_context.h
@@ -849,21 +849,6 @@ namespace smt {
          */
         void add_theory_aware_branching_info(bool_var v, double priority, lbool phase);
 
-        // unsat core assumption hint for theory_str
-        void set_use_theory_str_overlap_assumption(bool f) {
-            m_use_theory_str_overlap_assumption = f;
-        }
-
-        bool use_theory_str_overlap_assumption() const {
-            return m_use_theory_str_overlap_assumption;
-        }
-
-        expr_ref get_theory_str_overlap_assumption_term() {
-            return m_theoryStrOverlapAssumption_term;
-        }
-
-    protected:
-        expr_ref m_theoryStrOverlapAssumption_term;
     public:
 
         // helper function for trail
diff --git a/src/smt/smt_setup.cpp b/src/smt/smt_setup.cpp
index fdcf33c0e..78a295e27 100644
--- a/src/smt/smt_setup.cpp
+++ b/src/smt/smt_setup.cpp
@@ -706,7 +706,6 @@ namespace smt {
     }
 
     void setup::setup_QF_S() {
-        m_context.set_use_theory_str_overlap_assumption(true);
         m_context.register_plugin(alloc(smt::theory_mi_arith, m_manager, m_params));
         m_context.register_plugin(alloc(smt::theory_str, m_manager, m_params));
     }
@@ -842,7 +841,6 @@ namespace smt {
 
     void setup::setup_str() {
         setup_arith();
-        m_context.set_use_theory_str_overlap_assumption(true);
         m_context.register_plugin(alloc(theory_str, m_manager, m_params));
     }
 
diff --git a/src/smt/smt_theory.h b/src/smt/smt_theory.h
index cee36535f..e412f2f1b 100644
--- a/src/smt/smt_theory.h
+++ b/src/smt/smt_theory.h
@@ -177,6 +177,13 @@ namespace smt {
         virtual void restart_eh() {
         }
 
+        /**
+           \brief This method is called by smt_context before the search starts to get any
+           extra assumptions the theory wants to use. (see theory_str for an example)
+        */
+        virtual void add_theory_assumptions(expr_ref_vector & assumptions) {
+        }
+
         /**
            \brief This method is invoked before the search starts.
         */
diff --git a/src/smt/theory_str.cpp b/src/smt/theory_str.cpp
index 9d3fef6d7..354589318 100644
--- a/src/smt/theory_str.cpp
+++ b/src/smt/theory_str.cpp
@@ -56,6 +56,7 @@ theory_str::theory_str(ast_manager & m, theory_str_params const & params):
         tmpValTestVarCount(0),
         avoidLoopCut(true),
         loopDetected(false),
+        m_theoryStrOverlapAssumption_term(m),
         contains_map(m),
         string_int_conversion_terms(m),
         totalCacheAccessCount(0),
@@ -3080,7 +3081,7 @@ void theory_str::process_concat_eq_type1(expr * concatAst1, expr * concatAst2) {
 
                 if (!overlapAssumptionUsed) {
                 	overlapAssumptionUsed = true;
-                	assert_implication(ax_l, ctx.get_theory_str_overlap_assumption_term());
+                	assert_implication(ax_l, m_theoryStrOverlapAssumption_term);
                 }
             }
         }
@@ -3143,7 +3144,7 @@ void theory_str::process_concat_eq_type1(expr * concatAst1, expr * concatAst2) {
 
                 if (!overlapAssumptionUsed) {
                 	overlapAssumptionUsed = true;
-                	assert_implication(ax_l, ctx.get_theory_str_overlap_assumption_term());
+                	assert_implication(ax_l, m_theoryStrOverlapAssumption_term);
                 }
             }
         }
@@ -3199,7 +3200,7 @@ void theory_str::process_concat_eq_type1(expr * concatAst1, expr * concatAst2) {
 
                 if (!overlapAssumptionUsed) {
                 	overlapAssumptionUsed = true;
-                	arrangement_disjunction.push_back(ctx.get_theory_str_overlap_assumption_term());
+                	arrangement_disjunction.push_back(m_theoryStrOverlapAssumption_term);
                 }
             }
         }
@@ -3248,7 +3249,7 @@ void theory_str::process_concat_eq_type1(expr * concatAst1, expr * concatAst2) {
 
                 if (!overlapAssumptionUsed) {
                 	overlapAssumptionUsed = true;
-                	arrangement_disjunction.push_back(ctx.get_theory_str_overlap_assumption_term());
+                	arrangement_disjunction.push_back(m_theoryStrOverlapAssumption_term);
                 }
             }
         }
@@ -3495,7 +3496,7 @@ void theory_str::process_concat_eq_type2(expr * concatAst1, expr * concatAst2) {
 
 	                if (!overlapAssumptionUsed) {
 	                	overlapAssumptionUsed = true;
-	                	assert_implication(ax_l, ctx.get_theory_str_overlap_assumption_term());
+	                	assert_implication(ax_l, m_theoryStrOverlapAssumption_term);
 	                }
 	            }
 	        }
@@ -3601,7 +3602,7 @@ void theory_str::process_concat_eq_type2(expr * concatAst1, expr * concatAst2) {
 
 				    if (!overlapAssumptionUsed) {
 				    	overlapAssumptionUsed = true;
-				    	arrangement_disjunction.push_back(ctx.get_theory_str_overlap_assumption_term());
+				    	arrangement_disjunction.push_back(m_theoryStrOverlapAssumption_term);
 				    }
 				}
 			}
@@ -3903,7 +3904,7 @@ void theory_str::process_concat_eq_type3(expr * concatAst1, expr * concatAst2) {
 
                     if (!overlapAssumptionUsed) {
                     	overlapAssumptionUsed = true;
-                    	assert_implication(ax_l, ctx.get_theory_str_overlap_assumption_term());
+                    	assert_implication(ax_l, m_theoryStrOverlapAssumption_term);
                     }
                 }
             }
@@ -3987,7 +3988,7 @@ void theory_str::process_concat_eq_type3(expr * concatAst1, expr * concatAst2) {
 
                     if (!overlapAssumptionUsed) {
                     	overlapAssumptionUsed = true;
-                    	arrangement_disjunction.push_back(ctx.get_theory_str_overlap_assumption_term());
+                    	arrangement_disjunction.push_back(m_theoryStrOverlapAssumption_term);
                     }
                 }
             }
@@ -4393,7 +4394,7 @@ void theory_str::process_concat_eq_type6(expr * concatAst1, expr * concatAst2) {
 
             // only add the overlap assumption one time
             if (!overlapAssumptionUsed) {
-                arrangement_disjunction.push_back(ctx.get_theory_str_overlap_assumption_term());
+                arrangement_disjunction.push_back(m_theoryStrOverlapAssumption_term);
                 overlapAssumptionUsed = true;
             }
         }
@@ -7292,13 +7293,19 @@ void theory_str::set_up_axioms(expr * ex) {
     }
 }
 
+void theory_str::add_theory_assumptions(expr_ref_vector & assumptions) {
+    TRACE("t_str", tout << "add overlap assumption for theory_str" << std::endl;);
+    symbol strOverlap("!!TheoryStrOverlapAssumption!!");
+    seq_util m_sequtil(get_manager());
+    sort * s = get_manager().mk_bool_sort();
+    m_theoryStrOverlapAssumption_term = expr_ref(get_manager().mk_const(strOverlap, s), get_manager());
+    assumptions.push_back(get_manager().mk_not(m_theoryStrOverlapAssumption_term));
+}
+
 void theory_str::init_search_eh() {
     ast_manager & m = get_manager();
     context & ctx = get_context();
 
-    // safety
-    SASSERT(ctx.use_theory_str_overlap_assumption());
-
     TRACE("t_str_detail",
         tout << "dumping all asserted formulas:" << std::endl;
         unsigned nFormulas = ctx.get_num_asserted_formulas();
diff --git a/src/smt/theory_str.h b/src/smt/theory_str.h
index a8857de24..3c273d4e2 100644
--- a/src/smt/theory_str.h
+++ b/src/smt/theory_str.h
@@ -291,6 +291,7 @@ namespace smt {
         bool avoidLoopCut;
         bool loopDetected;
         obj_map<expr, std::stack<T_cut*> > cut_var_map;
+        expr_ref m_theoryStrOverlapAssumption_term;
 
         obj_hashtable<expr> variable_set;
         obj_hashtable<expr> internal_variable_set;
@@ -627,6 +628,7 @@ namespace smt {
 
         virtual theory* mk_fresh(context*) { return alloc(theory_str, get_manager(), m_params); }
         virtual void init_search_eh();
+        virtual void add_theory_assumptions(expr_ref_vector & assumptions);
         virtual void relevant_eh(app * n);
         virtual void assign_eh(bool_var v, bool is_true);
         virtual void push_scope_eh();