diff --git a/src/util/lp/lar_solver.cpp b/src/util/lp/lar_solver.cpp index 4c7cd2d94..ac2b36e0b 100644 --- a/src/util/lp/lar_solver.cpp +++ b/src/util/lp/lar_solver.cpp @@ -397,7 +397,7 @@ void lar_solver::pop(unsigned k) { m_settings.simplex_strategy() = m_simplex_strategy; lp_assert(sizes_are_correct()); lp_assert((!m_settings.use_tableau()) || m_mpq_lar_core_solver.m_r_solver.reduced_costs_are_correct_tableau()); - lp_assert(m_cube_rounded_rows.size() != 0 || ax_is_correct()); + lp_assert(ax_is_correct()); set_status(lp_status::UNKNOWN); } @@ -875,8 +875,8 @@ void lar_solver::update_x_and_inf_costs_for_columns_with_changed_bounds_tableau( } void lar_solver::fix_Ax_b_on_rounded_row(unsigned i) { - if (A_r().m_rows[i].size() <= i) - return; + if (A_r().m_rows.size() <= i) + return; unsigned bj = m_mpq_lar_core_solver.m_r_basis[i]; auto& v = m_mpq_lar_core_solver.m_r_x[bj]; v = zero_of_type>(); @@ -885,11 +885,28 @@ void lar_solver::fix_Ax_b_on_rounded_row(unsigned i) { v -= c.coeff() * m_mpq_lar_core_solver.m_r_x[c.var()]; } } +void lar_solver::collect_rounded_rows_to_fix() { + lp_assert(m_cube_rounded_rows.size() == 0); + for (unsigned j : m_cube_rounded_columns) { + if (j >= A_r().m_columns.size()) + continue; + int j_raw = m_mpq_lar_core_solver.m_r_solver.m_basis_heading[j]; + if (j_raw >= 0) { + m_cube_rounded_rows.insert(j_raw); + } else { + for (const auto & c : A_r().m_columns[j]) { + m_cube_rounded_rows.insert(c.var()); + } + } + } +} void lar_solver::fix_Ax_b_on_rounded_rows() { + collect_rounded_rows_to_fix(); for (unsigned i : m_cube_rounded_rows) { fix_Ax_b_on_rounded_row(i); } m_cube_rounded_rows.clear(); + m_cube_rounded_columns.clear(); lp_assert(ax_is_correct()); } diff --git a/src/util/lp/lar_solver.h b/src/util/lp/lar_solver.h index 96e0795d8..75ca8d5fe 100644 --- a/src/util/lp/lar_solver.h +++ b/src/util/lp/lar_solver.h @@ -88,7 +88,8 @@ class lar_solver : public column_namer { #endif //////////////////// fields ////////////////////////// - vector m_cube_rounded_rows; + std::unordered_set m_cube_rounded_columns; + std::unordered_set m_cube_rounded_rows; lp_settings m_settings; lp_status m_status; stacked_value m_simplex_strategy; @@ -639,5 +640,6 @@ public: bool sum_first_coords(const lar_term& t, mpq & val) const; void fix_Ax_b_on_rounded_rows(); void fix_Ax_b_on_rounded_row(unsigned); + void collect_rounded_rows_to_fix(); }; }