diff --git a/src/math/lp/lar_solver.cpp b/src/math/lp/lar_solver.cpp index 4c26aaff6..92f6b01ec 100644 --- a/src/math/lp/lar_solver.cpp +++ b/src/math/lp/lar_solver.cpp @@ -230,8 +230,6 @@ namespace lp { m_crossed_bounds_column = null_lpvar; m_crossed_bounds_deps = nullptr; m_mpq_lar_core_solver.push(); - m_term_count = m_terms.size(); - m_term_count.push(); m_constraints.push(); m_usage_in_terms.push(); m_dependencies.push_scope(); @@ -267,14 +265,11 @@ namespace lp { lp_assert(m_mpq_lar_core_solver.m_r_solver.m_costs.size() == A_r().column_count()); lp_assert(m_mpq_lar_core_solver.m_r_solver.m_basis.size() == A_r().row_count()); lp_assert(m_mpq_lar_core_solver.m_r_solver.basis_heading_is_correct()); - lp_assert(A_r().column_count() == n); TRACE("lar_solver_details", for (unsigned j = 0; j < n; j++) print_column_info(j, tout) << "\n";); m_mpq_lar_core_solver.pop(k); remove_non_fixed_from_fixed_var_table(); - clean_popped_elements(n, m_columns_with_changed_bounds); - clean_popped_elements(n, m_incorrect_columns); for (auto rid : m_row_bounds_to_replay) add_touched_row(rid); @@ -288,14 +283,6 @@ namespace lp { m_mpq_lar_core_solver.m_r_solver.reduced_costs_are_correct_tableau()); m_constraints.pop(k); - m_term_count.pop(k); - for (unsigned i = m_term_count; i < m_terms.size(); i++) { - if (m_need_register_terms) - deregister_normalized_term(*m_terms[i]); - delete m_terms[i]; - } - m_term_register.shrink(m_term_count); - m_terms.resize(m_term_count); m_simplex_strategy.pop(k); m_settings.set_simplex_strategy(m_simplex_strategy); lp_assert(sizes_are_correct()); @@ -1473,12 +1460,30 @@ namespace lp { return j; } - struct lar_solver::add_column : public trail { + struct lar_solver::undo_add_column : public trail { lar_solver& s; - add_column(lar_solver& s) : s(s) {} + undo_add_column(lar_solver& s) : s(s) {} virtual void undo() { s.remove_last_column_from_tableau(); s.m_columns_to_ul_pairs.pop_back(); + unsigned j = s.m_columns_to_ul_pairs.size(); + if (s.m_columns_with_changed_bounds.contains(j)) + s.m_columns_with_changed_bounds.remove(j); + if (s.m_incorrect_columns.contains(j)) + s.m_incorrect_columns.remove(j); + } + }; + + struct lar_solver::undo_add_term : public trail { + lar_solver& s; + undo_add_term(lar_solver& s):s(s) {} + void undo() override { + auto* t = s.m_terms.back(); + if (s.m_need_register_terms) + s.deregister_normalized_term(*t); + delete t; + s.m_terms.pop_back(); + s.m_term_register.shrink(s.m_terms.size()); } }; @@ -1492,7 +1497,7 @@ namespace lp { lp_assert(m_columns_to_ul_pairs.size() == A_r().column_count()); local_j = A_r().column_count(); m_columns_to_ul_pairs.push_back(ul_pair(false)); // not associated with a row - m_trail.push(add_column(*this)); + m_trail.push(undo_add_column(*this)); while (m_usage_in_terms.size() <= ext_j) m_usage_in_terms.push_back(0); add_non_basic_var_to_core_fields(ext_j, is_int); @@ -1575,8 +1580,10 @@ namespace lp { return false; } #endif + void lar_solver::push_term(lar_term* t) { m_terms.push_back(t); + m_trail.push(undo_add_term(*this)); } // terms @@ -1645,7 +1652,7 @@ namespace lp { unsigned j = A_r().column_count(); ul_pair ul(true); // to mark this column as associated_with_row m_columns_to_ul_pairs.push_back(ul); - m_trail.push(add_column(*this)); + m_trail.push(undo_add_column(*this)); add_basic_var_to_core_fields(); A_r().fill_last_row_with_pivoting(*term, diff --git a/src/math/lp/lar_solver.h b/src/math/lp/lar_solver.h index 6c7b30664..06d7da5b6 100644 --- a/src/math/lp/lar_solver.h +++ b/src/math/lp/lar_solver.h @@ -87,7 +87,6 @@ class lar_solver : public column_namer { bool m_need_register_terms = false; var_register m_var_register; var_register m_term_register; - struct add_column; svector m_columns_to_ul_pairs; constraint_set m_constraints; // the set of column indices j such that bounds have changed for j @@ -103,7 +102,6 @@ class lar_solver : public column_namer { indexed_uint_set m_incorrect_columns; // copy of m_r_solver.inf_heap() unsigned_vector m_inf_index_copy; - stacked_value m_term_count; vector m_terms; indexed_vector m_column_buffer; std::unordered_map, term_hasher, term_comparer> @@ -119,6 +117,10 @@ class lar_solver : public column_namer { indexed_uint_set m_fixed_base_var_set; // end of fields + ////////////////// nested structs ///////////////////////// + struct undo_add_column; + struct undo_add_term; + ////////////////// methods //////////////////////////////// static bool valid_index(unsigned j) { return static_cast(j) >= 0; } @@ -395,7 +397,7 @@ class lar_solver : public column_namer { inline column_index to_column_index(unsigned v) const { return column_index(external_to_column_index(v)); } bool external_is_used(unsigned) const; void pop(unsigned k); - unsigned num_scopes() const { return m_term_count.stack_size(); } + unsigned num_scopes() const { return m_trail.get_num_scopes(); } bool compare_values(var_index j, lconstraint_kind kind, const mpq& right_side); var_index add_term(const vector>& coeffs, unsigned ext_i); void register_existing_terms(); diff --git a/src/math/lp/monomial_bounds.cpp b/src/math/lp/monomial_bounds.cpp index 2f2e524aa..446bf4698 100644 --- a/src/math/lp/monomial_bounds.cpp +++ b/src/math/lp/monomial_bounds.cpp @@ -52,7 +52,7 @@ namespace nla { * a bounds axiom. */ bool monomial_bounds::propagate_value(dep_interval& range, lpvar v) { - // auto val = c().val(v); + bool propagated = false; if (should_propagate_upper(range, v, 1)) { auto const& upper = dep.upper(range); @@ -88,37 +88,21 @@ namespace nla { bool monomial_bounds::should_propagate_lower(dep_interval const& range, lpvar v, unsigned p) { if (dep.lower_is_inf(range)) return false; - u_dependency* d = nullptr; - rational bound; - bool is_strict; - if (!c().has_lower_bound(v, d, bound, is_strict)) - return true; + auto bound = c().val(v); auto const& lower = dep.lower(range); if (p > 1) bound = power(bound, p); - if (bound < lower) - return true; - if (bound > lower) - return false; - return !is_strict && dep.lower_is_open(range); + return bound < lower; } bool monomial_bounds::should_propagate_upper(dep_interval const& range, lpvar v, unsigned p) { if (dep.upper_is_inf(range)) return false; - u_dependency* d = nullptr; - rational bound; - bool is_strict; - if (!c().has_upper_bound(v, d, bound, is_strict)) - return true; + auto bound = c().val(v); auto const& upper = dep.upper(range); if (p > 1) bound = power(bound, p); - if (bound > upper) - return true; - if (bound < upper) - return false; - return !is_strict && dep.upper_is_open(range); + return bound > upper; } /**