diff --git a/src/math/lp/random_updater.h b/src/math/lp/random_updater.h index 4805cf937..d5cd4928c 100644 --- a/src/math/lp/random_updater.h +++ b/src/math/lp/random_updater.h @@ -35,11 +35,7 @@ class random_updater { u_set m_var_set; lar_solver & m_lar_solver; unsigned m_range; - void add_column_to_sets(unsigned j); - std::unordered_map, unsigned> m_values; // it maps a value to the number of time it occurs bool shift_var(unsigned j); - void add_value(const numeric_pair& v); - void remove_value(const numeric_pair & v); public: random_updater(lar_solver & solver, const vector & column_list); void update(); diff --git a/src/math/lp/random_updater_def.h b/src/math/lp/random_updater_def.h index 6f6bb47fc..711d0d77f 100644 --- a/src/math/lp/random_updater_def.h +++ b/src/math/lp/random_updater_def.h @@ -32,61 +32,50 @@ random_updater::random_updater( m_range(100000) { m_var_set.resize(m_lar_solver.number_of_vars()); for (unsigned j : column_indices) - add_column_to_sets(j); + m_var_set.insert(j); TRACE("lar_solver_rand", tout << "size = " << m_var_set.size() << "\n";); } -bool random_updater::shift_var(unsigned v) { - SASSERT(!m_lar_solver.column_is_fixed(v)); - return m_lar_solver.get_int_solver()->shift_var(v, m_range); +bool random_updater::shift_var(unsigned j) { + SASSERT(!m_lar_solver.column_is_fixed(j) && !m_lar_solver.is_base(j)); + bool ret = m_lar_solver.get_int_solver()->shift_var(j, m_range); + if (ret) { + const auto & A = m_lar_solver.A_r(); + for (const auto& c : A.m_columns[j]) { + unsigned row_index = c.var(); + unsigned changed_basic = m_lar_solver.get_core_solver().m_r_basis[row_index]; + m_var_set.erase(changed_basic); + } + } + return ret; } void random_updater::update() { - for (auto j : m_var_set) { - if (m_var_set.size() <= m_values.size()) { - break; // we are done - } - auto old_x = m_lar_solver.get_column_value(j); - if (shift_var(j)) { - remove_value(old_x); - add_value(m_lar_solver.get_column_value(j)); + auto columns = m_var_set.index(); // m_var_set is going to change during the loop + for (auto j : columns) { + if (!m_var_set.contains(j)) { + TRACE("lar_solver_rand", tout << "skipped " << j << "\n";); + continue; } + if (!m_lar_solver.is_base(j)) { + shift_var(j); + } else { + unsigned row = m_lar_solver.get_core_solver().m_r_heading[j]; + for (auto & row_c : m_lar_solver.get_core_solver().m_r_A.m_rows[row]) { + unsigned cj = row_c.var(); + if (!m_lar_solver.is_base(cj) && + !m_lar_solver.column_is_fixed(cj) + && + shift_var(cj) + ) { + break; // done with the basic var j + } + } + } } TRACE("lar_solver_rand", tout << "m_var_set.size() = " << m_var_set.size() << ", m_values.size() = " << m_values.size() << "\n";); } -void random_updater::add_value(const numeric_pair& v) { - auto it = m_values.find(v); - if (it == m_values.end()) { - m_values[v] = 1; - } else { - it->second++; - } -} - -void random_updater::remove_value(const numeric_pair& v) { - std::unordered_map, unsigned>::iterator it = m_values.find(v); - lp_assert(it != m_values.end()); - it->second--; - if (it->second == 0) - m_values.erase((std::unordered_map, unsigned>::const_iterator)it); -} - -void random_updater::add_column_to_sets(unsigned j) { - if (m_lar_solver.get_core_solver().m_r_heading[j] < 0) { - m_var_set.insert(j); - add_value(m_lar_solver.get_core_solver().m_r_x[j]); - } else { - unsigned row = m_lar_solver.get_core_solver().m_r_heading[j]; - for (auto & row_c : m_lar_solver.get_core_solver().m_r_A.m_rows[row]) { - unsigned cj = row_c.var(); - if (m_lar_solver.get_core_solver().m_r_heading[cj] < 0 && !m_lar_solver.column_is_fixed(cj)) { - m_var_set.insert(cj); - add_value(m_lar_solver.get_core_solver().m_r_x[cj]); - } - } - } -} }