diff --git a/src/math/lp/nex.h b/src/math/lp/nex.h index 1ba5d1ea9..774795cab 100644 --- a/src/math/lp/nex.h +++ b/src/math/lp/nex.h @@ -191,6 +191,21 @@ public: void add_child(nex* e) { add_child_in_power(e, 1); } + + // returns true if the product of scalars gives a number different from 1 + bool has_a_coeff() const { + rational r(1); + for (auto & p : *this) { + if (p.e()->is_scalar()) + r *= to_scalar(p.e())->value(); + } + return !r.is_one(); + } + + const nex_pow* begin() const { return m_children.begin(); } + const nex_pow* end() const { return m_children.end(); } + nex_pow* begin() { return m_children.begin(); } + nex_pow* end() { return m_children.end(); } void add_child_in_power(nex* e, int power) { m_children.push_back(nex_pow(e, power)); } diff --git a/src/math/lp/nex_creator.cpp b/src/math/lp/nex_creator.cpp index 066654695..9b8c799b4 100644 --- a/src/math/lp/nex_creator.cpp +++ b/src/math/lp/nex_creator.cpp @@ -379,14 +379,14 @@ nex* nex_creator::create_child_from_nex_and_coeff(nex *e, } case expr_type::MUL: { nex_mul * em = to_mul(e); - nex_pow *np = em->children().begin(); + nex_pow *np = em->begin(); if (np->e()->is_scalar()) { SASSERT(np->pow() == 1); to_scalar(np->e())->value() = coeff; return e; } em->add_child(mk_scalar(coeff)); - std::sort(em->children().begin(), em->children().end(), [this](const nex_pow& a, + std::sort(em->begin(), em->end(), [this](const nex_pow& a, const nex_pow& b) {return less_than_on_nex_pow(a, b);}); return em; } @@ -415,11 +415,11 @@ bool nex_creator::register_in_join_map(std::map& map, ne // returns true if a simplificatian happens bool nex_creator::process_mul_in_simplify_sum(nex_mul* em, std::map &map) { bool found = false; - auto it = em->children().begin(); + auto it = em->begin(); if (it->e()->is_scalar()) { SASSERT(it->pow() == 1); rational r = to_scalar(it->e())->value(); - auto end = em->children().end(); + auto end = em->end(); if (em->children().size() == 2 && em->children()[1].pow() == 1) { found = register_in_join_map(map, em->children()[1].e(), r); } else { @@ -536,20 +536,21 @@ bool have_no_scalars(const nex_mul* a) { return true; } +nex * nex_creator::mk_div_sum_by_mul(const nex_sum* m, const nex_mul* b) { + nex_sum * r = mk_sum(); + for (auto e : m->children()) { + r->add_child(mk_div_by_mul(e, b)); + } + TRACE("nla_cn_details", tout << *r << "\n";); + return r; +} nex * nex_creator::mk_div_by_mul(const nex* a, const nex_mul* b) { - - // todo: break on shorter funcitons if (a->is_sum()) { - nex_sum * r = mk_sum(); - const nex_sum * m = to_sum(a); - for (auto e : m->children()) { - r->add_child(mk_div_by_mul(e, b)); - } - TRACE("nla_cn_details", tout << *r << "\n";); - return r; + return mk_div_sum_by_mul(to_sum(a), b); } if (a->is_var() || (a->is_mul() && to_mul(a)->children().size() == 1)) { + SASSERT(b->get_degree() == 1 && !b->has_a_coeff() && b->contains(to_var(a)->var())); return mk_scalar(rational(1)); } const nex_mul* am = to_mul(a); diff --git a/src/math/lp/nex_creator.h b/src/math/lp/nex_creator.h index 308c86a06..17912e25b 100644 --- a/src/math/lp/nex_creator.h +++ b/src/math/lp/nex_creator.h @@ -203,6 +203,7 @@ public: nex * mk_div(const nex* a, lpvar j); nex * mk_div(const nex* a, const nex* b); nex * mk_div_by_mul(const nex* a, const nex_mul* b); + nex * mk_div_sum_by_mul(const nex_sum* a, const nex_mul* b); nex * simplify_mul(nex_mul *e); bool is_sorted(const nex_mul * e) const; @@ -236,6 +237,6 @@ public: bool less_than_on_var_nex(const nex_var* a, const nex* b) const; bool less_than_on_mul_nex(const nex_mul* a, const nex* b) const; void fill_map_with_children(std::map & m, ptr_vector & children); - void process_map_pair(nex *e, const rational& coeff, ptr_vector & children, std::unordered_set&); + void process_map_pair(nex *e, const rational& coeff, ptr_vector & children, std::unordered_set&); }; }