From 1b8b09cddbe102036af40a9e5850374a6c6853e4 Mon Sep 17 00:00:00 2001 From: Lev Nachmanson Date: Wed, 7 Aug 2019 11:40:39 -0700 Subject: [PATCH] fixes in horner's heuristic Signed-off-by: Lev Nachmanson --- src/math/lp/horner.cpp | 28 +++++--- src/math/lp/lar_solver.cpp | 19 +++-- src/math/lp/lar_solver.h | 10 +-- src/math/lp/lar_term.h | 17 ++--- src/math/lp/nla_core.cpp | 5 +- src/math/lp/nla_intervals.h | 135 ++++++++++++++++++++++++++---------- 6 files changed, 145 insertions(+), 69 deletions(-) diff --git a/src/math/lp/horner.cpp b/src/math/lp/horner.cpp index 3bb84f9c7..f951af5de 100644 --- a/src/math/lp/horner.cpp +++ b/src/math/lp/horner.cpp @@ -283,10 +283,18 @@ lp::lar_term horner::expression_to_normalized_term(nex& e, rational& a, rational return t; } + +// we should have in the case of found a*m_terms[k] + b = e, +// where m_terms[k] corresponds to the returned lpvar lpvar horner::find_term_column(const nex& e, rational& a, rational& b) const { nex n = e; lp::lar_term norm_t = expression_to_normalized_term(n, a, b); - return c().m_lar_solver.fetch_normalized_term_column(norm_t); + std::pair a_j; + if (c().m_lar_solver.fetch_normalized_term_column(norm_t, a_j)) { + a /= a_j.first; + return a_j.second; + } + return -1; } interv horner::interval_of_sum_no_terms(const nex& e) { @@ -350,12 +358,16 @@ interv horner::interval_of_sum(const nex& e) { interv i_e = interval_of_sum_no_terms(e); if (e.sum_is_a_linear_term()) { interv i_from_term ; - if (interval_from_term(e, i_from_term) - && - is_tighter(i_from_term, i_e)) - return i_from_term; + if (interval_from_term(e, i_from_term)) { + interv r = m_intervals.intersect(i_e, i_from_term); + TRACE("nla_horner_details", tout << "intersection="; m_intervals.display(tout, r) << "\n";); + if (m_intervals.is_empty(r)) { + SASSERT(false); // not implemented + } + return r; + + } } - return i_e; } @@ -365,10 +377,6 @@ void horner::set_var_interval(lpvar v, interv& b) const{ TRACE("nla_horner_details_var", tout << "v = "; print_var(v, tout) << "\n"; m_intervals.display(tout, b);); } -bool horner::is_tighter(const interv& a, const interv& b) const { - return m_intervals.is_tighter(a, b); -} - } diff --git a/src/math/lp/lar_solver.cpp b/src/math/lp/lar_solver.cpp index c489aed15..d6971f184 100644 --- a/src/math/lp/lar_solver.cpp +++ b/src/math/lp/lar_solver.cpp @@ -2372,12 +2372,14 @@ void lar_solver::set_cut_strategy(unsigned cut_frequency) { } } + void lar_solver::register_normalized_term(const lar_term& t, lpvar j) { - lar_term normalized_t = t.get_normalized_by_min_var(); + mpq a; + lar_term normalized_t = t.get_normalized_by_min_var(a); TRACE("lar_solver_terms", tout << "t="; print_term_as_indices(t, tout); tout << ", normalized_t="; print_term_as_indices(normalized_t, tout) << "\n";); if (m_normalized_terms_to_columns.find(normalized_t) == m_normalized_terms_to_columns.end()) { - m_normalized_terms_to_columns[normalized_t] = j; + m_normalized_terms_to_columns[normalized_t] = std::make_pair(a, j); } else { TRACE("lar_solver_terms", tout << "the term has been seen already\n";); } @@ -2386,7 +2388,8 @@ void lar_solver::register_normalized_term(const lar_term& t, lpvar j) { void lar_solver::deregister_normalized_term(const lar_term& t) { TRACE("lar_solver_terms", tout << "deregister term "; print_term_as_indices(t, tout) << "\n";); - lar_term normalized_t = t.get_normalized_by_min_var(); + mpq a; + lar_term normalized_t = t.get_normalized_by_min_var(a); m_normalized_terms_to_columns.erase(normalized_t); } @@ -2397,18 +2400,20 @@ void lar_solver::register_existing_terms() { register_normalized_term(*m_terms[k], j); } } - -unsigned lar_solver::fetch_normalized_term_column(const lar_term& c) const { +// a_j.first gives the normalised coefficient, +// a_j.second givis the column +bool lar_solver::fetch_normalized_term_column(const lar_term& c, std::pair & a_j) const { TRACE("lar_solver_terms", tout << "looking for term "; print_term_as_indices(c, tout) << "\n";); lp_assert(c.is_normalized()); auto it = m_normalized_terms_to_columns.find(c); if (it != m_normalized_terms_to_columns.end()) { TRACE("lar_solver_terms", tout << "got " << it->second << "\n" ;); - return it->second; + a_j = it->second; + return true; } TRACE("lar_solver_terms", tout << "have not found\n";); - return -1; + return false; } } // namespace lp diff --git a/src/math/lp/lar_solver.h b/src/math/lp/lar_solver.h index 1782b02f8..56e995871 100644 --- a/src/math/lp/lar_solver.h +++ b/src/math/lp/lar_solver.h @@ -57,8 +57,8 @@ class lar_solver : public column_namer { size_t seed = 0; int i = 0; for (const auto& p : t.coeffs()) { - hash_combine(seed, p.m_key); - hash_combine(seed, p.m_value); + hash_combine(seed, p.first); + hash_combine(seed, p.second); if (i++ > 10) break; } @@ -106,7 +106,9 @@ public: vector m_terms; indexed_vector m_column_buffer; bool m_need_register_terms; - std::unordered_map m_normalized_terms_to_columns; // end of fields + std::unordered_map, term_hasher, term_comparer> + m_normalized_terms_to_columns; + // end of fields unsigned terms_start_index() const { return m_terms_start_index; } const vector & terms() const { return m_terms; } @@ -646,6 +648,6 @@ public: void register_existing_terms(); void register_normalized_term(const lar_term&, lpvar); void deregister_normalized_term(const lar_term&); - lpvar fetch_normalized_term_column(const lar_term& t) const; + bool fetch_normalized_term_column(const lar_term& t, std::pair& ) const; }; } diff --git a/src/math/lp/lar_term.h b/src/math/lp/lar_term.h index 108325045..2330a647c 100644 --- a/src/math/lp/lar_term.h +++ b/src/math/lp/lar_term.h @@ -19,7 +19,8 @@ --*/ #pragma once #include "math/lp/indexed_vector.h" -#include "util/map.h" +#include + namespace lp { class lar_term { // the term evaluates to sum of m_coeffs @@ -69,7 +70,7 @@ public: vector> coeffs_as_vector() const { vector> ret; for (const auto & p : m_coeffs) { - ret.push_back(std::make_pair(p.m_value, p.m_key)); + ret.push_back(std::make_pair(p.second, p.first)); } return ret; } @@ -95,25 +96,25 @@ public: } bool contains(lpvar j) const { - return m_coeffs.contains(j); + return m_coeffs.find(j) != m_coeffs.end(); } void negate() { for (auto & t : m_coeffs) - t.m_value.neg(); + t.second.neg(); } template T apply(const vector& x) const { T ret(0); for (const auto & t : m_coeffs) { - ret += t.m_value * x[t.m_key]; + ret += t.second * x[t.first]; } return ret; } void clear() { - m_coeffs.reset(); + m_coeffs.clear(); } struct ival { @@ -137,13 +138,13 @@ public: typedef std::forward_iterator_tag iterator_category; reference operator*() const { - return ival(m_it->m_key, m_it->m_value); + return ival(m_it->first, m_it->second); } self_type operator++() { self_type i = *this; m_it++; return i; } self_type operator++(int) { m_it++; return *this; } - const_iterator(u_map::iterator it) : m_it(it) {} + const_iterator(std::map::const_iterator it) : m_it(it) {} bool operator==(const self_type &other) const { return m_it == other.m_it; } diff --git a/src/math/lp/nla_core.cpp b/src/math/lp/nla_core.cpp index cff739a99..5a69b0cae 100644 --- a/src/math/lp/nla_core.cpp +++ b/src/math/lp/nla_core.cpp @@ -1391,11 +1391,8 @@ std::ostream& core::print_terms(std::ostream& out) const { } const lp::lar_term & t = *m_lar_solver.m_terms[i]; - print_term(t, out) << std::endl; + out << "term:"; print_term(t, out) << std::endl; lpvar j = m_lar_solver.external_to_local(ext); - SASSERT(j + 1); - auto e = mk_expr(t); - out << "e= " << e << "\n"; print_var(j, out); } return out; diff --git a/src/math/lp/nla_intervals.h b/src/math/lp/nla_intervals.h index cbc5a35e8..0e34371a3 100644 --- a/src/math/lp/nla_intervals.h +++ b/src/math/lp/nla_intervals.h @@ -101,8 +101,7 @@ class intervals : common { (!lower_is_open(a)) && (!upper_is_open(a)) && unsynch_mpq_manager::is_zero(a.m_lower) && unsynch_mpq_manager::is_zero(a.m_upper); } - - + // Setters void set_lower(interval & a, mpq const & n) const { m_manager.set(a.m_lower, n); } void set_upper(interval & a, mpq const & n) const { m_manager.set(a.m_upper, n); } @@ -138,7 +137,7 @@ class intervals : common { small_object_allocator m_alloc; ci_value_manager m_val_manager; - unsynch_mpq_manager m_num_manager; + mutable unsynch_mpq_manager m_num_manager; mutable ci_dependency_manager m_dep_manager; im_config m_config; mutable interval_manager m_imanager; @@ -181,7 +180,7 @@ public: b.m_upper_dep = a.m_lower_dep; b.m_lower_dep = a.m_upper_dep; } - } + } void add(const rational& r, interval& a) const { if (!a.m_lower_inf) { @@ -206,45 +205,96 @@ public: m_config.add_deps(a, b, deps, i); } - bool is_tighter_on_lower(const interval& a, const interval& b) const { - if (lower_is_inf(a)) - return false; - if (lower_is_inf(b)) - return true; - if (rational(lower(a)) < rational(lower(b))) - return true; - if (lower(a) > lower(b)) - return false; + void update_lower_for_intersection(const interval& a, const interval& b, interval & i) const { + if (a.m_lower_inf) { + if (b.m_lower_inf) + return; + copy_lower_bound(b, i); + return; + } + + if (b.m_lower_inf) { + SASSERT(!a.m_lower_inf); + copy_lower_bound(a, i); + return; + } - if (!a.m_lower_open) - return false; - if (b.m_lower_open) - return false; + if (m_num_manager.lt(a.m_lower, b.m_lower)) { + copy_lower_bound(b, i); + return; + } - return true; + if (m_num_manager.gt(a.m_lower, b.m_lower)) { + copy_lower_bound(a, i); + return; + } + + SASSERT(m_num_manager.eq(a.m_lower, b.m_lower)); + if (a.m_lower_open) { // we might consider to look at b.m_lower_open too here + copy_lower_bound(a, i); + return; + } + + copy_lower_bound(b, i); } - bool is_tighter_on_upper(const interval& a, const interval& b) const { - if (upper_is_inf(a)) - return false; - if (upper_is_inf(b)) - return true; - if (rational(upper(a)) > rational(upper(b))) - return true; - if (rational(upper(a)) < rational(upper(b))) - return false; + void copy_upper_bound(const interval& a, interval & i) const { + SASSERT(a.m_upper_inf == false); + i.m_upper_inf = false; + m_config.set_upper(i, a.m_upper); + i.m_upper_dep = a.m_upper_dep; + i.m_upper_open = a.m_upper_open; + } - if (!a.m_upper_open) - return false; - if (b.m_upper_open) - return false; - - return true; + void copy_lower_bound(const interval& a, interval & i) const { + SASSERT(a.m_lower_open == false); + i.m_lower_inf = false; + m_config.set_lower(i, a.m_lower); + i.m_lower_dep = a.m_lower_dep; + i.m_lower_open = a.m_lower_open; } - bool is_tighter(const interval& a, const interval& b) const { - return (is_tighter_on_lower(a, b) && !is_tighter_on_upper(b, a)) || - (is_tighter_on_upper(a, b) && is_tighter_on_lower(b, a)); + void update_upper_for_intersection(const interval& a, const interval& b, interval & i) const { + if (a.m_upper_inf) { + if (b.m_upper_inf) + return; + copy_upper_bound(b, i); + return; + } + + if (b.m_upper_inf) { + SASSERT(!a.m_upper_inf); + copy_upper_bound(a, i); + return; + } + + if (m_num_manager.gt(a.m_upper, b.m_upper)) { + copy_upper_bound(b, i); + return; + } + + if (m_num_manager.lt(a.m_upper, b.m_upper)) { + copy_upper_bound(a, i); + return; + } + + SASSERT(m_num_manager.eq(a.m_upper, b.m_upper)); + if (a.m_upper_open) { // we might consider to look at b.m_upper_open too here + copy_upper_bound(a, i); + return; + } + + copy_upper_bound(b, i); + } + + interval intersect(const interval& a, const interval& b) const { + interval i; + TRACE("nla_interval_compare", tout << "a="; display(tout, a) << "\nb="; display(tout, b);); + update_lower_for_intersection(a, b, i); + TRACE("nla_interval_compare", tout << "i="; display(tout, i) << "\n";); + update_upper_for_intersection(a, b, i); + TRACE("nla_interval_compare", tout << "i="; display(tout, i) << "\n";); + return i; } bool upper_is_inf(const interval& a) const { return m_config.upper_is_inf(a); } @@ -257,5 +307,18 @@ public: bool check_interval_for_conflict_on_zero_upper(const interval & i); mpq const & lower(interval const & a) const { return m_config.lower(a); } mpq const & upper(interval const & a) const { return m_config.upper(a); } + bool is_empty(interval const & a) const { + if (a.m_lower_inf || a.m_upper_inf) + return false; + + if (m_num_manager.gt(a.m_lower, a.m_upper)) + return true; + if (m_num_manager.lt(a.m_lower, a.m_upper)) + return false; + if (a.m_lower_open || a.m_upper_open) + return true; + return false; + } + }; // end of intervals } // end of namespace nla