diff --git a/src/math/lp/lar_solver.cpp b/src/math/lp/lar_solver.cpp index 589aad85a..b3fc36a65 100644 --- a/src/math/lp/lar_solver.cpp +++ b/src/math/lp/lar_solver.cpp @@ -8,7 +8,46 @@ namespace lp { + struct column_update { + bool m_is_upper; + unsigned m_j; + impq m_bound; + column m_column; + }; + + struct imp { + lar_solver &lra; + vector m_column_updates; + void set_r_upper_bound(unsigned j, const impq& b) { + lra.m_mpq_lar_core_solver.m_r_upper_bounds[j] = b; + } + void set_r_lower_bound(unsigned j, const impq& b) { + lra.m_mpq_lar_core_solver.m_r_lower_bounds[j] = b; + } + + imp(lar_solver& s) : lra(s) {} + + void set_column(unsigned j, const column& c) { + lra.m_columns[j] = c; + } + + struct column_update_trail : public trail { + imp& m_imp; + column_update_trail(imp & i) : m_imp(i) {} + void undo() override { + auto& [is_upper, j, bound, column] = m_imp.m_column_updates.back(); + if (is_upper) + m_imp.set_r_upper_bound(j, bound); + else + m_imp.set_r_lower_bound(j, bound); + m_imp.set_column(j, column); + m_imp.m_column_updates.pop_back(); + } + }; + }; + + imp* m_imp; lp_settings& lar_solver::settings() { return m_settings; } lp_settings const& lar_solver::settings() const { return m_settings; } @@ -25,7 +64,8 @@ namespace lp { lar_solver::lar_solver() : m_mpq_lar_core_solver(m_settings, *this), m_var_register(), - m_constraints(m_dependencies, *this) {} + m_constraints(m_dependencies, *this), m_imp(alloc(imp, *this)) { + } // start or ends tracking the rows that were changed by solve() void lar_solver::track_touched_rows(bool v) { @@ -574,38 +614,25 @@ namespace lp { A_r().pop(k); } - struct lar_solver::column_update_trail : public trail { - lar_solver& s; - column_update_trail(lar_solver& s) : s(s) {} - void undo() override { - auto& [is_upper, j, bound, column] = s.m_column_updates.back(); - if (is_upper) - s.m_mpq_lar_core_solver.m_r_upper_bounds[j] = bound; - else - s.m_mpq_lar_core_solver.m_r_lower_bounds[j] = bound; - s.m_columns[j] = column; - s.m_column_updates.pop_back(); - } - }; - + void lar_solver::set_upper_bound_witness(lpvar j, u_dependency* dep, impq const& high) { bool has_upper = m_columns[j].upper_bound_witness() != nullptr; - m_column_updates.push_back({true, j, get_upper_bound(j), m_columns[j]}); - m_trail.push(column_update_trail(*this)); + m_imp->m_column_updates.push_back({true, j, get_upper_bound(j), m_columns[j]}); + m_trail.push(imp::column_update_trail(*this->m_imp)); m_columns[j].set_upper_bound_witness(dep); if (has_upper) - m_columns[j].set_previous_upper(m_column_updates.size() - 1); + m_columns[j].set_previous_upper(m_imp->m_column_updates.size() - 1); m_mpq_lar_core_solver.m_r_upper_bounds[j] = high; insert_to_columns_with_changed_bounds(j); } void lar_solver::set_lower_bound_witness(lpvar j, u_dependency* dep, impq const& low) { bool has_lower = m_columns[j].lower_bound_witness() != nullptr; - m_column_updates.push_back({false, j, get_lower_bound(j), m_columns[j]}); - m_trail.push(column_update_trail(*this)); + m_imp->m_column_updates.push_back({false, j, get_lower_bound(j), m_columns[j]}); + m_trail.push(imp::column_update_trail(*this->m_imp)); m_columns[j].set_lower_bound_witness(dep); if (has_lower) - m_columns[j].set_previous_lower(m_column_updates.size() - 1); + m_columns[j].set_previous_lower(m_imp->m_column_updates.size() - 1); m_mpq_lar_core_solver.m_r_lower_bounds[j] = low; insert_to_columns_with_changed_bounds(j); } @@ -1196,26 +1223,34 @@ namespace lp { #if 1 if(is_upper) { - if (ul.previous_upper() != UINT_MAX) { - auto const& [_is_upper, _j, _bound, _column] = m_column_updates[ul.previous_upper()]; + unsigned current_update_index = ul.previous_upper(); + while (current_update_index != UINT_MAX) { + auto const& [_is_upper, _j, _bound, _column] = m_imp->m_column_updates[current_update_index]; auto new_slack = slack + coeff * (_bound - get_upper_bound(j)); if (sign == get_sign(new_slack)) { - // verbose_stream() << "can weaken j" << j << " " << coeff << " " << get_upper_bound(j) << " " << _bound << "\n"; + //verbose_stream() << "can weaken upper j" << j << " " << coeff << " " << get_upper_bound(j) << " " << _bound << "\n"; slack = new_slack; bound_constr_i = _column.upper_bound_witness(); - } + current_update_index = _column.previous_upper(); // Move to the next previous bound in the list + } else + break; // Stop if weakening is not possible with this previous bound + } - } + } else { - if (ul.previous_lower() != UINT_MAX) { - auto const& [_is_upper, _j, _bound, _column] = m_column_updates[ul.previous_lower()]; + unsigned prev_l = ul.previous_lower(); + while (prev_l != UINT_MAX) { + auto const& [_is_upper, _j, _bound, _column] = m_imp->m_column_updates[prev_l]; auto new_slack = slack + coeff * (_bound - get_lower_bound(j)); if (sign == get_sign(new_slack)) { - // verbose_stream() << "can weaken j" << j << " " << coeff << " " << get_lower_bound(j) << " " << _bound << "\n"; + //verbose_stream() << "can weaken lower j" << j << " " << coeff << " " << get_lower_bound(j) << " " << _bound << "\n"; slack = new_slack; bound_constr_i = _column.lower_bound_witness(); + prev_l = _column.previous_lower(); } - } + else + break; + } } #endif diff --git a/src/math/lp/lar_solver.h b/src/math/lp/lar_solver.h index cdccbe52d..1ea22ebd6 100644 --- a/src/math/lp/lar_solver.h +++ b/src/math/lp/lar_solver.h @@ -47,7 +47,7 @@ namespace lp { class int_branch; class int_solver; - +struct imp; class lar_solver : public column_namer { struct term_hasher { @@ -73,13 +73,7 @@ class lar_solver : public column_namer { } }; - struct column_update { - bool m_is_upper; - unsigned m_j; - impq m_bound; - column m_column; - }; - struct column_update_trail; + //////////////////// fields ////////////////////////// trail_stack m_trail; @@ -94,7 +88,6 @@ class lar_solver : public column_namer { bool m_need_register_terms = false; var_register m_var_register; svector m_columns; - vector m_column_updates; constraint_set m_constraints; // the set of column indices j such that bounds have changed for j @@ -123,6 +116,7 @@ class lar_solver : public column_namer { map, default_eq> m_fixed_var_table_real; // the set of fixed variables which are also base variables indexed_uint_set m_fixed_base_var_set; + imp* m_imp; // end of fields ////////////////// nested structs ///////////////////////// @@ -742,5 +736,7 @@ public: std::function m_find_monics_with_changed_bounds_func = nullptr; friend int_solver; friend int_branch; + friend imp; + }; } // namespace lp