diff --git a/src/math/simplex/sparse_matrix.h b/src/math/simplex/sparse_matrix.h index ffbe009e0..669b8c2b9 100644 --- a/src/math/simplex/sparse_matrix.h +++ b/src/math/simplex/sparse_matrix.h @@ -221,13 +221,13 @@ namespace simplex { friend class sparse_matrix; unsigned m_curr; column const& m_col; - vector<_row> const& m_rows; + vector<_row>& m_rows; void move_to_used() { while (m_curr < m_col.num_entries() && m_col.m_entries[m_curr].is_dead()) { ++m_curr; } } - col_iterator(column const& c, vector<_row> const& r, bool begin): + col_iterator(column const& c, vector<_row>& r, bool begin): m_curr(0), m_col(c), m_rows(r) { ++m_col.m_refs; if (begin) { @@ -245,21 +245,21 @@ namespace simplex { row get_row() const { return row(m_col.m_entries[m_curr].m_row_id); } - row_entry const& get_row_entry() const { + row_entry& get_row_entry() { col_entry const& c = m_col.m_entries[m_curr]; int row_id = c.m_row_id; return m_rows[row_id].m_entries[c.m_row_idx]; } - std::pair operator*() { return std::make_pair(get_row(), &get_row_entry()); } + std::pair operator*() { return std::make_pair(get_row(), &get_row_entry()); } col_iterator & operator++() { ++m_curr; move_to_used(); return *this; } col_iterator operator++(int) { col_iterator tmp = *this; ++*this; return tmp; } bool operator==(col_iterator const & it) const { return m_curr == it.m_curr; } bool operator!=(col_iterator const & it) const { return m_curr != it.m_curr; } }; - col_iterator col_begin(int v) const { return col_iterator(m_columns[v], m_rows, true); } - col_iterator col_end(int v) const { return col_iterator(m_columns[v], m_rows, false); } + col_iterator col_begin(int v) { return col_iterator(m_columns[v], m_rows, true); } + col_iterator col_end(int v) { return col_iterator(m_columns[v], m_rows, false); } class var_rows { friend class sparse_matrix; @@ -305,6 +305,13 @@ namespace simplex { all_rows get_rows() { return all_rows(*this); } + + numeral& get_coeff(row r, unsigned v) { + for (auto & [coeff, u] : get_row(r)) + if (u == v) + return coeff; + throw default_exception("variable not in row"); + } void display(std::ostream& out); diff --git a/src/math/simplex/sparse_matrix_ops.h b/src/math/simplex/sparse_matrix_ops.h index e12f421a1..1827f9b12 100644 --- a/src/math/simplex/sparse_matrix_ops.h +++ b/src/math/simplex/sparse_matrix_ops.h @@ -27,7 +27,7 @@ namespace simplex { public: static void kernel(sparse_matrix& M, vector>& K) { mpq_ext::numeral coeff; - rational D1, D2; + rational D; vector d, c; unsigned m = M.num_vars(); auto& mgr = M.get_manager(); @@ -45,27 +45,26 @@ namespace simplex { continue; d.back() = v + 1; c[v] = row.id() + 1; - D1 = rational(-1) / coeff1; + D = rational(-1) / coeff1; mgr.set(coeff1, mpq(-1)); // eliminate v from other rows. - for (auto const& [row2, row_entry2] : M.get_rows(v)) { - if (row.id() >= row2.id() || row_entry2->m_coeff == 0) + for (auto& [row2, row_entry2] : M.get_rows(v)) { + if (row.id() >= row2.id()) continue; - for (auto& [coeff2, w] : M.get_row(row2)) { - if (v == w) - mgr.set(coeff2, (D1*coeff2).to_mpq()); - } + mpq & m_js = row_entry2->m_coeff; + mgr.set(m_js, (D * m_js).to_mpq()); } - - for (auto& [coeff2, w] : M.get_row(row)) { + for (auto& [m_ik, w] : M.get_row(row)) { if (v == w) continue; - D2 = coeff2; - mgr.set(coeff2, mpq(0)); - for (auto const& [row2, row_entry2] : M.get_rows(w)) { - if (row.id() >= row2.id() || row_entry2->m_coeff == 0 || row_entry2->m_var == v) + D = m_ik; + mgr.set(m_ik, mpq(0)); + for (auto& [row2, row_entry2] : M.get_rows(w)) { + if (row.id() >= row2.id()) continue; - // mgr.set(row_entry2->m_coeff, row_entry2->m_coeff + D2*row2[v]); + auto& m_js = M.get_coeff(row2, v); + auto & m_is = row_entry2->m_coeff; + mgr.set(m_is, (m_is + D * m_js).to_mpq()); } } break; @@ -79,8 +78,15 @@ namespace simplex { if (d[k] != 0) continue; K.push_back(vector()); - for (unsigned i = 0; i < d.size(); ++i) { - // K.back().push_back(d[i] > 0 ? M[d[i]-1][k] : (i == k) ? 1 : 0); + for (unsigned i = 0; i < d.size(); ++i) { + if (d[i] > 0) { + // row r = row(i); + // K.back().push_back(M[d[i]-1][k]); + } + else if (i == k) + K.back().push_back(rational(1)); + else + K.back().push_back(rational(0)); } }