diff --git a/src/math/simplex/model_based_opt.cpp b/src/math/simplex/model_based_opt.cpp index 577ce4633..d0b4a03ea 100644 --- a/src/math/simplex/model_based_opt.cpp +++ b/src/math/simplex/model_based_opt.cpp @@ -1206,7 +1206,6 @@ namespace opt { m_rows[row_index].m_alive = false; replace_var(row_index, x, rational::zero()); - // compute a_inv rational a_inv, m_inv; rational g = gcd(a, m, a_inv, m_inv); @@ -1268,7 +1267,7 @@ namespace opt { def z_def = project(z, compute_def); if (compute_def) { - result = (y_def * m) + z_def; + result = (y_def * m) + z_def * a_inv; m_var2value[x] = eval(result); } @@ -1286,6 +1285,16 @@ namespace opt { // where k := (b.value mod m + a*z.value) div m // k is between 0 and a // + // - k*m <= b mod m + a*z < (k+1)*m + // + // A better version using a^-1 + // - v = (a*m*y + a^-1*a*z + b) div m + // = a*y + ((m*A + g)*z + b) div m where we write a*a^-1 = m*A + g + // = a*y + A + (g*z + b) div m + // - k*m <= b mod m + gz < (k+1)*m + // where k is between 0 and g + // when gcd(a, m) = 1, then there are only two cases. + // model_based_opt::def model_based_opt::solve_div(unsigned x, unsigned_vector const& div_rows, bool compute_def) { def result; SASSERT(!div_rows.empty()); @@ -1299,24 +1308,31 @@ namespace opt { replace_var(row_index, x, rational::zero()); rational b_value = m_rows[row_index].m_value; - // solve for x_value = m*y_value + z_value, 0 <= z_value < m. + // compute a_inv + rational a_inv, m_inv; + rational g = gcd(a, m, a_inv, m_inv); + if (a_inv.is_neg()) + a_inv = mod(a_inv, m); + SASSERT(mod(a_inv * a, m) == g); + + // solve for x_value = m*y_value + a_inv*z_value, 0 <= z_value < m. rational z_value = mod(x_value, m); - rational y_value = div(x_value, m); - SASSERT(x_value == m*y_value + z_value); + rational y_value = div(x_value, m) - div(z_value*a_inv, m); + SASSERT(x_value == m*y_value + a_inv*z_value); SASSERT(0 <= z_value && z_value < m); // add new variables unsigned y = add_var(y_value, true); unsigned z = add_var(z_value, true); - // replace x by m*y + z in other rows. + // replace x by m*y + a^-1*z in other rows. unsigned_vector const& row_ids = m_var2row_ids[x]; uint_set visited; visited.insert(row_index); for (unsigned row_id : row_ids) { if (visited.contains(row_id)) continue; - replace_var(row_id, x, m, y, rational::one(), z); + replace_var(row_id, x, m, y, a_inv, z); visited.insert(row_id); normalize(row_id); } @@ -1330,28 +1346,45 @@ namespace opt { rational coeff = m_rows[row_index].m_coeff; unsigned w = add_div(coeffs, coeff, m); - // // w = b div m - // v = a*y + w + k - // k = (a*z_value + (b_value mod m)) div m + // v = a*y + w + (a*a_inv div m) + k + // k = (g*z_value + (b_value mod m)) div m + // k*m <= g*z + b mod m < (k+1)*m // rational k = div(a*z_value + mod(b_value, m), m); + rational n = div(a_inv * a, m); vector div_coeffs; div_coeffs.push_back(var(v, rational::minus_one())); div_coeffs.push_back(var(y, a)); div_coeffs.push_back(var(w, rational::one())); + if (n != 0) div_coeffs.push_back(var(z, n)); add_constraint(div_coeffs, k, t_eq); + unsigned u = add_mod(coeffs, coeff, m); + + // add g*z + (b mod m) < (k + 1)*m + vector bound_coeffs; + bound_coeffs.push_back(var(z, g)); + bound_coeffs.push_back(var(u, rational::one())); + add_constraint(bound_coeffs, 1 - m * (k + 1), t_le); + + // add k*m <= gz + (b mod m) + for (auto& c : bound_coeffs) + c.m_coeff.neg(); + add_constraint(bound_coeffs, k * m, t_le); + // allow to recycle row. m_retired_rows.push_back(row_index); + + // project internal variables. project(v, false); def y_def = project(y, compute_def); def z_def = project(z, compute_def); if (compute_def) { - result = (y_def * m) + z_def; + result = (y_def * m) + (z_def * a_inv); m_var2value[x] = eval(result); } return result; diff --git a/src/qe/mbp/mbp_arith.cpp b/src/qe/mbp/mbp_arith.cpp index 7674a1577..11db720ef 100644 --- a/src/qe/mbp/mbp_arith.cpp +++ b/src/qe/mbp/mbp_arith.cpp @@ -225,19 +225,17 @@ namespace mbp { rational c0 = add_def(t1, mul1, coeffs); mbo.add_divides(coeffs, c0 - r, mul1); } - else if (false && a.is_mod(t, t1, t2) && is_numeral(t2, mul1) && !mul1.is_zero()) { + else if (false && a.is_mod(t, t1, t2) && is_numeral(t2, mul1) && mul1 > 0) { // v = t1 mod mul1 vars coeffs; rational c0 = add_def(t1, mul1, coeffs); - unsigned v = mbo.add_mod(coeffs, c0, mul1); - tids.insert(t, v); + tids.insert(t, mbo.add_mod(coeffs, c0, mul1)); } else if (false && a.is_idiv(t, t1, t2) && is_numeral(t2, mul1) && mul1 > 0) { // v = t1 div mul1 vars coeffs; rational c0 = add_def(t1, mul1, coeffs); - unsigned v = mbo.add_div(coeffs, c0, mul1); - tids.insert(t, v); + tids.insert(t, mbo.add_div(coeffs, c0, mul1)); } else insert_mul(t, mul, ts);