diff --git a/src/math/lp/horner.cpp b/src/math/lp/horner.cpp index 133642cdc..615b43905 100644 --- a/src/math/lp/horner.cpp +++ b/src/math/lp/horner.cpp @@ -311,10 +311,7 @@ lp::lar_term horner::expression_to_normalized_term(const nex_sum* e, rational& a // we should have in the case of found a*m_terms[k] + b = e, // where m_terms[k] corresponds to the returned lpvar -lpvar horner::find_term_column(const nex* e, rational& a, rational& b) const { - if (!e->is_sum()) - return -1; - lp::lar_term norm_t = expression_to_normalized_term(to_sum(e), a, b); +lpvar horner::find_term_column(const lp::lar_term & norm_t, rational& a) const { std::pair a_j; if (c().m_lar_solver.fetch_normalized_term_column(norm_t, a_j)) { a /= a_j.first; @@ -359,7 +356,14 @@ interv horner::interval_of_sum_no_terms(const nex_sum* e) { bool horner::interval_from_term(const nex* e, interv & i) const { rational a, b; - lpvar j = find_term_column(e, a, b); + lp::lar_term norm_t = expression_to_normalized_term(to_sum(e), a, b); + lp::explanation exp; + if (c().explain_by_equiv(norm_t, exp)) { + m_intervals.set_zero_interval_with_explanation(i, exp); + TRACE("nla_horner", tout << "explain_by_equiv\n"); + return true; + } + lpvar j = find_term_column(norm_t, a); if (j + 1 == 0) return false; @@ -381,6 +385,7 @@ interv horner::interval_of_sum(const nex_sum* e) { TRACE("nla_horner_details", tout << "e=" << e << "\n";); interv i_e = interval_of_sum_no_terms(e); if (e->is_a_linear_term()) { + SASSERT(e->is_sum() && e->size() > 1); interv i_from_term ; if (interval_from_term(e, i_from_term)) { interv r = m_intervals.intersect(i_e, i_from_term); diff --git a/src/math/lp/horner.h b/src/math/lp/horner.h index 2d8732e69..f62078b44 100644 --- a/src/math/lp/horner.h +++ b/src/math/lp/horner.h @@ -52,7 +52,7 @@ public: template // T has an iterator of (coeff(), var()) bool row_has_monomial_to_refine(const T&) const; - lpvar find_term_column(const nex* e, rational& a, rational& b) const; + lpvar find_term_column(const lp::lar_term &, rational & a) const; static lp::lar_term expression_to_normalized_term(const nex_sum*, rational& a, rational & b); static void add_linear_to_vector(const nex*, vector> &); static void add_mul_to_vector(const nex_mul*, vector> &); diff --git a/src/math/lp/nla_core.cpp b/src/math/lp/nla_core.cpp index 73e9e67a6..9b3d73788 100644 --- a/src/math/lp/nla_core.cpp +++ b/src/math/lp/nla_core.cpp @@ -409,10 +409,10 @@ bool core:: explain_ineq(const lp::lar_term& t, llc cmp, const rational& rs) { /** * \brief - if t is an octagon term -+x -+ y try to explain why the term always + if t is an octagon term -+x -+ y try to explain why the term always is equal zero */ -bool core:: explain_by_equiv(const lp::lar_term& t, lp::explanation& e) { +bool core:: explain_by_equiv(const lp::lar_term& t, lp::explanation& e) const { lpvar i,j; bool sign; if (!is_octagon_term(t, sign, i, j)) diff --git a/src/math/lp/nla_core.h b/src/math/lp/nla_core.h index 5685aab1d..bc7fd68ad 100644 --- a/src/math/lp/nla_core.h +++ b/src/math/lp/nla_core.h @@ -280,7 +280,7 @@ public: bool explain_ineq(const lp::lar_term& t, llc cmp, const rational& rs); - bool explain_by_equiv(const lp::lar_term& t, lp::explanation& e); + bool explain_by_equiv(const lp::lar_term& t, lp::explanation& e) const; bool has_zero_factor(const factorization& factorization) const; diff --git a/src/math/lp/nla_intervals.cpp b/src/math/lp/nla_intervals.cpp index debdeb55e..b9e32944d 100644 --- a/src/math/lp/nla_intervals.cpp +++ b/src/math/lp/nla_intervals.cpp @@ -32,6 +32,17 @@ void intervals::set_var_interval_with_deps(lpvar v, interval& b) const { } } +void intervals::set_zero_interval_with_explanation(interval& i, const lp::explanation& exp) const { + auto val = rational(0); + m_config.set_lower(i, val); + m_config.set_lower_is_open(i, false); + m_config.set_lower_is_inf(i, false); + m_config.set_upper(i, val); + m_config.set_upper_is_open(i, false); + m_config.set_upper_is_inf(i, false); + i.m_lower_dep = i.m_upper_dep = mk_dep(exp); +} + void intervals::set_zero_interval_deps_for_mult(interval& a) { a.m_lower_dep = m_dep_manager.mk_join(a.m_lower_dep, a.m_upper_dep); a.m_upper_dep = a.m_lower_dep; @@ -74,7 +85,19 @@ bool intervals::check_interval_for_conflict_on_zero_lower(const interval & i) { intervals::ci_dependency *intervals::mk_dep(lp::constraint_index ci) const { return m_dep_manager.mk_leaf(ci); } - + +intervals::ci_dependency *intervals::mk_dep(const lp::explanation& expl) const { + intervals::ci_dependency * r = nullptr; + for (auto p : expl) { + if (r == nullptr) { + r = m_dep_manager.mk_leaf(p.second); + } else { + r = m_dep_manager.mk_join(r, m_dep_manager.mk_leaf(p.second)); + } + } + return r; +} + std::ostream& intervals::display(std::ostream& out, const interval& i) const { if (m_imanager.lower_is_inf(i)) { diff --git a/src/math/lp/nla_intervals.h b/src/math/lp/nla_intervals.h index 008c6b2bc..962653451 100644 --- a/src/math/lp/nla_intervals.h +++ b/src/math/lp/nla_intervals.h @@ -147,6 +147,7 @@ public: private: void set_var_interval(lpvar v, interval & b) const; ci_dependency* mk_dep(lp::constraint_index ci) const; + ci_dependency* mk_dep(lp::explanation const &) const; lp::lar_solver& ls(); const lp::lar_solver& ls() const; public: @@ -300,6 +301,7 @@ public: bool lower_is_inf(const interval& a) const { return m_config.lower_is_inf(a); } void set_var_interval_with_deps(lpvar, interval &) const; void set_zero_interval_deps_for_mult(interval&); + void set_zero_interval_with_explanation(interval& , const lp::explanation& exp) const; bool is_inf(const interval& i) const { return m_config.is_inf(i); } bool check_interval_for_conflict_on_zero(const interval & i); bool check_interval_for_conflict_on_zero_lower(const interval & i);