diff --git a/src/math/lp/nex_creator.cpp b/src/math/lp/nex_creator.cpp index 21151abfc..066654695 100644 --- a/src/math/lp/nex_creator.cpp +++ b/src/math/lp/nex_creator.cpp @@ -435,13 +435,12 @@ bool nex_creator::process_mul_in_simplify_sum(nex_mul* em, std::map a + 5bc -void nex_creator::sort_join_sum(ptr_vector & children) { - TRACE("nla_cn_details", print_vector_of_ptrs(children, tout);); - std::map map([this](const nex *a , const nex *b) - { return lt(a, b); }); - std::unordered_set existing_nex; // handling (nex*) as numbers - nex_scalar * common_scalar = nullptr; +bool nex_creator::fill_join_map_for_sum(ptr_vector & children, + std::map& map, + std::unordered_set& existing_nex, + nex_scalar*& common_scalar) { + + common_scalar = nullptr; bool simplified = false; for (auto e : children) { if (e->is_scalar()) { @@ -462,7 +461,16 @@ void nex_creator::sort_join_sum(ptr_vector & children) { simplified |= register_in_join_map(map, e, rational(1)); } } - + return simplified; +} + // a + 3bc + 2bc => a + 5bc +void nex_creator::sort_join_sum(ptr_vector & children) { + TRACE("nla_cn_details", print_vector_of_ptrs(children, tout);); + std::map map([this](const nex *a , const nex *b) + { return lt(a, b); }); + std::unordered_set existing_nex; // handling (nex*) as numbers + nex_scalar * common_scalar; + bool simplified = fill_join_map_for_sum(children, map, existing_nex, common_scalar); if (!simplified) return; @@ -472,37 +480,7 @@ void nex_creator::sort_join_sum(ptr_vector & children) { children.push_back(common_scalar); } for (auto& p : map) { - nex *e = p.first; - const rational & coeff = p.second; - if (coeff.is_zero()) - continue; - bool e_is_old = existing_nex.find(e) != existing_nex.end(); - if (e_is_old) { - if (coeff.is_one()) { - children.push_back(e); - } else { - if (e->is_var()) { - children.push_back(mk_mul(mk_scalar(coeff), e)); - } else { - SASSERT(e->is_mul()); - nex* first = to_mul(e)->children()[0].e(); - if (first->is_scalar()) { - to_scalar(first)->value() = coeff; - children.push_back(e); - } else { - e = simplify(mk_mul(mk_scalar(coeff), e)); - children.push_back(e); - } - } - } - } else { // e is new - if (coeff.is_one()) { - m_allocated.push_back(e); - children.push_back(e); - } else { - children.push_back(simplify(mk_mul(mk_scalar(coeff), e))); - } - } + process_map_pair(p.first, p.second, children, existing_nex); } } @@ -560,6 +538,8 @@ bool have_no_scalars(const nex_mul* a) { nex * nex_creator::mk_div_by_mul(const nex* a, const nex_mul* b) { + + // todo: break on shorter funcitons if (a->is_sum()) { nex_sum * r = mk_sum(); const nex_sum * m = to_sum(a); @@ -640,6 +620,39 @@ nex* nex_creator::simplify(nex* e) { return es; } +void nex_creator::process_map_pair(nex *e, const rational& coeff, ptr_vector & children, std::unordered_set& existing_nex) { + // todo : break on shorter functions + if (coeff.is_zero()) + return; + bool e_is_old = existing_nex.find(e) != existing_nex.end(); + if (e_is_old) { + if (coeff.is_one()) { + children.push_back(e); + } else { + if (e->is_var()) { + children.push_back(mk_mul(mk_scalar(coeff), e)); + } else { + SASSERT(e->is_mul()); + nex* first = to_mul(e)->children()[0].e(); + if (first->is_scalar()) { + to_scalar(first)->value() = coeff; + children.push_back(e); + } else { + e = simplify(mk_mul(mk_scalar(coeff), e)); + children.push_back(e); + } + } + } + } else { // e is new + if (coeff.is_one()) { + m_allocated.push_back(e); + children.push_back(e); + } else { + children.push_back(simplify(mk_mul(mk_scalar(coeff), e))); + } + } +} + bool nex_creator::is_simplified(const nex *e) const { if (e->is_mul()) diff --git a/src/math/lp/nex_creator.h b/src/math/lp/nex_creator.h index fee027e8c..308c86a06 100644 --- a/src/math/lp/nex_creator.h +++ b/src/math/lp/nex_creator.h @@ -220,6 +220,10 @@ public: const rational& coeff) ; void sort_join_sum(ptr_vector & children); + bool fill_join_map_for_sum(ptr_vector & children, + std::map& map, + std::unordered_set& existing_nex, + nex_scalar*& common_scalar); bool register_in_join_map(std::map&, nex*, const rational&) const; void simplify_children_of_sum(ptr_vector & children); @@ -231,6 +235,7 @@ public: bool less_than_on_mul(const nex_mul* a, const nex_mul* b) const; bool less_than_on_var_nex(const nex_var* a, const nex* b) const; bool less_than_on_mul_nex(const nex_mul* a, const nex* b) const; - void fill_map_with_children(std::map & m, ptr_vector & children); + void fill_map_with_children(std::map & m, ptr_vector & children); + void process_map_pair(nex *e, const rational& coeff, ptr_vector & children, std::unordered_set&); }; }