From 5669cf65bc99863f21db6fbe0d49df613c11474c Mon Sep 17 00:00:00 2001
From: Nikolaj Bjorner <nbjorner@microsoft.com>
Date: Sat, 13 Aug 2022 06:18:13 -0700
Subject: [PATCH] bug fixes to mod/div quantifier elimination features

---
 .gitignore                           |  3 ++-
 src/math/simplex/model_based_opt.cpp | 23 ++++++++++++++---------
 2 files changed, 16 insertions(+), 10 deletions(-)

diff --git a/.gitignore b/.gitignore
index 3fe3a3110..ffc50c1ba 100644
--- a/.gitignore
+++ b/.gitignore
@@ -91,4 +91,5 @@ examples/**/obj
 CMakeSettings.json
 # Editor temp files
 *.swp
-.DS_Store
\ No newline at end of file
+.DS_Store
+dbg/**
diff --git a/src/math/simplex/model_based_opt.cpp b/src/math/simplex/model_based_opt.cpp
index d0b4a03ea..7e654a3f6 100644
--- a/src/math/simplex/model_based_opt.cpp
+++ b/src/math/simplex/model_based_opt.cpp
@@ -871,6 +871,7 @@ namespace opt {
 
     unsigned model_based_opt::add_var(rational const& value, bool is_int) {
         unsigned v = m_var2value.size();
+        verbose_stream() << "add var " << v << "\n";
         m_var2value.push_back(value);
         m_var2is_int.push_back(is_int);
         SASSERT(value.is_int() || !is_int);
@@ -1208,7 +1209,7 @@ namespace opt {
 
         // compute a_inv 
         rational a_inv, m_inv;
-        rational g = gcd(a, m, a_inv, m_inv);
+        rational g = mod(gcd(a, m, a_inv, m_inv), m);
         if (a_inv.is_neg())
             a_inv = mod(a_inv, m);
         SASSERT(mod(a_inv * a, m) == g);
@@ -1252,7 +1253,7 @@ namespace opt {
         // add g*z + w - v - k*m = 0, for k = (g*z_value + w_value) div m
         rational km = div(g*z_value + w_value, m)*m;
         vector<var> mod_coeffs;
-        mod_coeffs.push_back(var(z, g));
+        if (g != 0) mod_coeffs.push_back(var(z, g));
         mod_coeffs.push_back(var(w, rational::one()));
         mod_coeffs.push_back(var(v, rational::minus_one()));
         add_constraint(mod_coeffs, km, t_eq);
@@ -1270,7 +1271,7 @@ namespace opt {
             result = (y_def * m) + z_def * a_inv;
             m_var2value[x] = eval(result);
         }
-
+        TRACE("opt", display(tout << "solve_mod\n"));
         return result;
     }
 
@@ -1308,11 +1309,14 @@ namespace opt {
         replace_var(row_index, x, rational::zero());
         rational b_value = m_rows[row_index].m_value;
 
+        TRACE("opt", display(tout << "solve_div\n"));
+
         // compute a_inv 
         rational a_inv, m_inv;
-        rational g = gcd(a, m, a_inv, m_inv);
+        rational g = mod(gcd(a, m, a_inv, m_inv), m);
         if (a_inv.is_neg())
             a_inv = mod(a_inv, m);
+
         SASSERT(mod(a_inv * a, m) == g);
 
         // solve for x_value = m*y_value + a_inv*z_value, 0 <= z_value < m.
@@ -1366,7 +1370,7 @@ namespace opt {
 
         // add g*z + (b mod m) < (k + 1)*m
         vector<var> bound_coeffs;
-        bound_coeffs.push_back(var(z, g));
+        if (g != 0) bound_coeffs.push_back(var(z, g));
         bound_coeffs.push_back(var(u, rational::one()));       
         add_constraint(bound_coeffs, 1 - m * (k + 1), t_le);
 
@@ -1387,6 +1391,7 @@ namespace opt {
             result = (y_def * m) + (z_def * a_inv);
             m_var2value[x] = eval(result);
         }
+        TRACE("opt", display(tout << "solve_div\n"));
         return result;
     }
 
@@ -1505,13 +1510,13 @@ namespace opt {
         if (coeff.is_zero() || !r.m_alive)
             return;
         replace_var(row_id, x, rational::zero());        
-        r.m_vars.push_back(var(y, coeff*A));
-        r.m_vars.push_back(var(z, coeff*B));
+        if (A != 0) r.m_vars.push_back(var(y, coeff*A));
+        if (B != 0) r.m_vars.push_back(var(z, coeff*B));
         r.m_value += coeff*A*m_var2value[y];
         r.m_value += coeff*B*m_var2value[z];
         std::sort(r.m_vars.begin(), r.m_vars.end(), var::compare());
-        m_var2row_ids[y].push_back(row_id);
-        m_var2row_ids[z].push_back(row_id);
+        if (A != 0) m_var2row_ids[y].push_back(row_id);
+        if (B != 0) m_var2row_ids[z].push_back(row_id);
         SASSERT(invariant(row_id, r));
     }