diff --git a/src/math/simplex/model_based_opt.cpp b/src/math/simplex/model_based_opt.cpp index 267171d22..5cd4e24aa 100644 --- a/src/math/simplex/model_based_opt.cpp +++ b/src/math/simplex/model_based_opt.cpp @@ -736,6 +736,8 @@ namespace opt { void model_based_opt::normalize(unsigned row_id) { row& r = m_rows[row_id]; + if (!r.m_alive) + return; if (r.m_vars.empty()) { retire_row(row_id); return; @@ -934,6 +936,7 @@ namespace opt { else { row_id = m_retired_rows.back(); m_retired_rows.pop_back(); + SASSERT(!m_rows[row_id].m_alive); m_rows[row_id].reset(); m_rows[row_id].m_alive = true; } @@ -995,10 +998,10 @@ namespace opt { return v; } - void model_based_opt::add_constraint(vector const& coeffs, rational const& c, rational const& m, ineq_type rel, unsigned id) { + unsigned model_based_opt::add_constraint(vector const& coeffs, rational const& c, rational const& m, ineq_type rel, unsigned id) { auto const& r = m_rows.back(); if (r.m_vars == coeffs && r.m_coeff == c && r.m_mod == m && r.m_type == rel && r.m_id == id && r.m_alive) - return; + return m_rows.size() - 1; unsigned row_id = new_row(); set_row(row_id, coeffs, c, m, rel); m_rows[row_id].m_id = id; @@ -1006,6 +1009,7 @@ namespace opt { m_var2row_ids[coeff.m_id].push_back(row_id); SASSERT(invariant(row_id, m_rows[row_id])); normalize(row_id); + return row_id; } void model_based_opt::set_objective(vector const& coeffs, rational const& c) { @@ -1236,10 +1240,15 @@ namespace opt { unsigned y = add_var(y_value, true); uint_set visited; + unsigned j = 0; for (unsigned ri : mod_rows) { + if (visited.contains(ri)) + continue; m_rows[ri].m_alive = false; visited.insert(ri); + mod_rows[j++] = ri; } + mod_rows.shrink(j); // replace x by K*y + z in other rows. for (unsigned ri : m_var2row_ids[x]) { @@ -1266,11 +1275,12 @@ namespace opt { unsigned w = UINT_MAX; rational offset(0); - if (coeffs.empty()) + if (coeffs.empty() || K == 1) offset = mod(coeff, K); else w = add_mod(coeffs, coeff, K); + rational w_value = w == UINT_MAX ? offset : m_var2value[w]; // add v = a*z + w - V, for k = (a*z_value + w_value) div K @@ -1284,8 +1294,7 @@ namespace opt { add_lower_bound(v, rational::zero()); add_upper_bound(v, K - 1); - // allow to recycle row. - m_retired_rows.push_back(ri); + retire_row(ri); project(v, false); } @@ -1341,12 +1350,17 @@ namespace opt { unsigned y = add_var(y_value, true); uint_set visited; + unsigned j = 0; for (unsigned ri : div_rows) { + if (visited.contains(ri)) + continue; row& r = m_rows[ri]; mul(ri, K / r.m_mod); r.m_alive = false; visited.insert(ri); + div_rows[j++] = ri; } + div_rows.shrink(j); // replace x by K*y + z in other rows. for (unsigned ri : m_var2row_ids[x]) { @@ -1375,9 +1389,11 @@ namespace opt { rational coeff = m_rows[ri].m_coeff; unsigned w = UINT_MAX; rational offset(0); - if (coeffs.empty()) + if (K == 1) + offset = coeff; + else if (coeffs.empty()) offset = div(coeff, K); - else + else w = add_div(coeffs, coeff, K); // @@ -1412,20 +1428,27 @@ namespace opt { vector div_coeffs; div_coeffs.push_back(var(v, rational::minus_one())); div_coeffs.push_back(var(y, a)); - if (w != UINT_MAX) div_coeffs.push_back(var(w, rational::one())); + if (w != UINT_MAX) + div_coeffs.push_back(var(w, rational::one())); + else if (K == 1) + div_coeffs.append(coeffs); add_constraint(div_coeffs, k + offset, t_eq); unsigned u = UINT_MAX; offset = 0; - if (coeffs.empty()) + if (K == 1) + offset = 0; + else if (coeffs.empty()) offset = mod(coeff, K); else u = add_mod(coeffs, coeff, K); + // add a*z + (b mod K) < (k + 1)*K vector bound_coeffs; bound_coeffs.push_back(var(z, a)); - if (u != UINT_MAX) bound_coeffs.push_back(var(u, rational::one())); + if (u != UINT_MAX) + bound_coeffs.push_back(var(u, rational::one())); add_constraint(bound_coeffs, 1 - K * (k + 1) + offset, t_le); // add k*K <= az + (b mod K) @@ -1433,7 +1456,7 @@ namespace opt { c.m_coeff.neg(); add_constraint(bound_coeffs, k * K - offset, t_le); // allow to recycle row. - m_retired_rows.push_back(ri); + retire_row(ri); project(v, false); } diff --git a/src/math/simplex/model_based_opt.h b/src/math/simplex/model_based_opt.h index a62c265a9..96b6508de 100644 --- a/src/math/simplex/model_based_opt.h +++ b/src/math/simplex/model_based_opt.h @@ -60,14 +60,13 @@ namespace opt { } }; struct row { - row(): m_type(t_le), m_value(0), m_alive(false) {} - vector m_vars; // variables with coefficients - rational m_coeff; // constant in inequality - rational m_mod; // value the term divide - ineq_type m_type; // inequality type - rational m_value; // value of m_vars + m_coeff under interpretation of m_var2value. - bool m_alive; // rows can be marked dead if they have been processed. - unsigned m_id; // variable defined by row (used for mod_t and div_t) + vector m_vars; // variables with coefficients + rational m_coeff = rational::zero(); // constant in inequality + rational m_mod = rational::zero(); // value the term divide + ineq_type m_type = t_le; // inequality type + rational m_value = rational::zero(); // value of m_vars + m_coeff under interpretation of m_var2value. + bool m_alive = false; // rows can be marked dead if they have been processed. + unsigned m_id = UINT_MAX; // variable defined by row (used for mod_t and div_t) void reset() { m_vars.reset(); m_coeff.reset(); m_value.reset(); } row& normalize(); @@ -139,7 +138,7 @@ namespace opt { void add_upper_bound(unsigned x, rational const& hi); - void add_constraint(vector const& coeffs, rational const& c, rational const& m, ineq_type r, unsigned id); + unsigned add_constraint(vector const& coeffs, rational const& c, rational const& m, ineq_type r, unsigned id); void replace_var(unsigned row_id, unsigned x, rational const& A, unsigned y, rational const& B);