diff --git a/src/math/lp/cross_nested.h b/src/math/lp/cross_nested.h index dadede624..50d01a2ee 100644 --- a/src/math/lp/cross_nested.h +++ b/src/math/lp/cross_nested.h @@ -127,7 +127,7 @@ public: nex* c_over_f = m_nex_creator.mk_div(*c, f); c_over_f = m_nex_creator.simplify(c_over_f); - TRACE("nla_cn", tout << "c_over_f =" << *c_over_f << std::endl;); + TRACE("nla_cn", tout << "c_over_f = " << *c_over_f << std::endl;); nex_mul* cm; *c = cm = m_nex_creator.mk_mul(f, c_over_f); TRACE("nla_cn", tout << "common factor=" << *f << ", c=" << **c << "\ne = " << *m_e << "\n";); diff --git a/src/math/lp/nex_creator.cpp b/src/math/lp/nex_creator.cpp index f8ffb7369..961c3dcee 100644 --- a/src/math/lp/nex_creator.cpp +++ b/src/math/lp/nex_creator.cpp @@ -19,6 +19,8 @@ --*/ #include "math/lp/nex_creator.h" #include +#include + namespace nla { nex * nex_creator::mk_div(const nex* a, lpvar j) { @@ -114,10 +116,11 @@ void nex_creator::simplify_children_of_mul(vector & children) { TRACE("nla_cn_details", print_vector(children, tout);); } -bool nex_creator::less_than_on_mul(const nex_mul* a, const nex_mul* b, bool skip_scalar) const { +bool nex_creator::less_than_on_mul(const nex_mul* a, const nex_mul* b) const { // the scalar, if it is there, is at the beginning of the children() - TRACE("nla_cn_details", tout << "a = " << *a << ", b = " << *b << ", skip_scalar = " << skip_scalar << "\n";); - SASSERT(is_simplified(a) && is_simplified(b)); + TRACE("nla_cn_details", tout << "a = " << *a << ", b = " << *b << "\n";); + SASSERT(is_simplified(a)); + SASSERT(is_simplified(b)); unsigned a_deg = a->get_degree(); unsigned b_deg = b->get_degree(); if (a_deg > b_deg) @@ -125,11 +128,7 @@ bool nex_creator::less_than_on_mul(const nex_mul* a, const nex_mul* b, bool skip if (a_deg < b_deg) return false; auto it_a = a->children().begin(); - if (skip_scalar && it_a->e()->is_scalar()) - it_a ++; auto it_b = b->children().begin(); - if (skip_scalar && it_b->e()->is_scalar()) - it_b ++; auto a_end = a->children().end(); auto b_end = b->children().end(); unsigned a_pow, b_pow; @@ -152,9 +151,9 @@ bool nex_creator::less_than_on_mul(const nex_mul* a, const nex_mul* b, bool skip b_pow = it_b->pow(); } - if (lt(ae, be, skip_scalar)) + if (lt(ae, be)) return true; - if (lt(be, ae, skip_scalar)) + if (lt(be, ae)) return false; if (a_pow == b_pow) { inside_a_p = inside_b_p = false; @@ -189,23 +188,26 @@ bool nex_creator::less_than_on_mul(const nex_mul* a, const nex_mul* b, bool skip } -bool nex_creator::less_than_on_var_nex(const nex_var* a, const nex* b, bool skip_scalar) const { +bool nex_creator::less_than_on_var_nex(const nex_var* a, const nex* b) const { switch(b->type()) { case expr_type::SCALAR: return false; case expr_type::VAR: return less_than(a->var() , to_var(b)->var()); case expr_type::MUL: { - nex_mul m; - m.add_child(const_cast(a)); - return less_than_on_mul(&m, to_mul(b), skip_scalar); + if (b->get_degree() > 1) + return false; + auto it = to_mul(b)->children().begin(); + const nex_pow & c = *it; + const nex * f = c.e(); + return less_than_on_var_nex(a, f); } case expr_type::SUM: { nex_sum m; m.add_child(const_cast(a)); - return lt(&m, to_sum(b), skip_scalar); + return lt(&m, to_sum(b)); } default: UNREACHABLE(); @@ -214,21 +216,24 @@ bool nex_creator::less_than_on_var_nex(const nex_var* a, const nex* b, bool skip } -bool nex_creator::less_than_on_mul_nex(const nex_mul* a, const nex* b, bool skip_scalar) const { +bool nex_creator::less_than_on_mul_nex(const nex_mul* a, const nex* b) const { switch(b->type()) { case expr_type::SCALAR: return false; case expr_type::VAR: { - nex_mul m; - m.add_child(const_cast(b)); - return less_than_on_mul(a, &m, skip_scalar); + if (a->get_degree() > 1) + return false; + auto it = a->children().begin(); + const nex_pow & c = *it; + const nex * f = c.e(); + return lt(f, a); } case expr_type::MUL: - return less_than_on_mul(a, to_mul(b), skip_scalar); + return less_than_on_mul(a, to_mul(b)); case expr_type::SUM: { const nex* fc = *(to_sum(b)->children().begin()); - return lt(a, fc, skip_scalar); + return lt(a, fc); } default: UNREACHABLE(); @@ -236,32 +241,35 @@ bool nex_creator::less_than_on_mul_nex(const nex_mul* a, const nex* b, bool skip } } -bool nex_creator::lt(const nex* a, const nex* b, bool skip_scalar) const { - TRACE("nla_cn_details", tout << "a = " << *a << ", b = " << *b << ", skip_scalar = " << skip_scalar << "\n";); - +bool nex_creator::lt(const nex* a, const nex* b) const { + bool ret; switch (a->type()) { case expr_type::VAR: - return less_than_on_var_nex(to_var(a), b, skip_scalar); - + ret = less_than_on_var_nex(to_var(a), b); + break; case expr_type::SCALAR: { if (b->is_scalar()) - return - to_scalar(a)->value() < to_scalar(b)->value(); - return true; // the scalars are the smallest + ret = to_scalar(a)->value() < to_scalar(b)->value(); + else + ret = true; // the scalars are the smallest + break; } case expr_type::MUL: { - return less_than_on_mul_nex(to_mul(a), b, skip_scalar); + ret = less_than_on_mul_nex(to_mul(a), b); + break; } case expr_type::SUM: { UNREACHABLE(); return false; } default: - SASSERT(false); + UNREACHABLE(); return false; } - return false; + TRACE("nla_cn_details_", tout << *a << (ret?" < ":" >= ") << *b << "\n";); + + return ret; } @@ -274,40 +282,12 @@ bool nex_creator::is_sorted(const nex_mul* e) const { } -bool nex_creator::less_than_nex(const nex* a, const nex* b) const { - int r = (int)(a->type()) - (int)(b->type()); - if (r) { - return r < 0; - } - SASSERT(a->type() == b->type()); - switch (a->type()) { - case expr_type::VAR: { - return less_than(to_var(a)->var() , to_var(b)->var()); - } - case expr_type::SCALAR: { - return to_scalar(a)->value() < to_scalar(b)->value(); - } - case expr_type::MUL: { - NOT_IMPLEMENTED_YET(); - return false; // to_mul(a)->children() < to_mul(b)->children(); - } - case expr_type::SUM: { - NOT_IMPLEMENTED_YET(); - return false; //to_sum(a)->children() < to_sum(b)->children(); - } - default: - SASSERT(false); - return false; - } - - return false; -} bool nex_creator::mul_is_simplified(const nex_mul* e) const { - if (size() == 1 && e->children().begin()->pow() == 1) + if (e->size() == 1 && e->children().begin()->pow() == 1) return false; - std::set s([this](const nex* a, const nex* b) {return less_than_nex(a, b); }); + std::set s([this](const nex* a, const nex* b) {return lt(a, b); }); for (const auto &p : e->children()) { const nex* ee = p.e(); if (p.pow() == 0) @@ -331,7 +311,7 @@ bool nex_creator::mul_is_simplified(const nex_mul* e) const { nex * nex_creator::simplify_mul(nex_mul *e) { TRACE("nla_cn_details", tout << *e << "\n";); simplify_children_of_mul(e->children()); - if (size() == 1 && e->children()[0].pow() == 1) + if (e->size() == 1 && e->children()[0].pow() == 1) return e->children()[0].e(); TRACE("nla_cn_details", tout << *e << "\n";); SASSERT(is_simplified(e)); @@ -357,7 +337,7 @@ bool nex_creator::sum_is_simplified(const nex_sum* e) const { } void nex_creator::mul_to_powers(vector& children) { - std::map m([this](const nex* a, const nex* b) {return less_than_nex(a, b); }); + std::map m([this](const nex* a, const nex* b) {return lt(a, b); }); for (auto & p : children) { auto it = m.find(p.e()); @@ -378,95 +358,116 @@ void nex_creator::mul_to_powers(vector& children) { } nex* nex_creator::create_child_from_nex_and_coeff(nex *e, - const rational& coeff) { + const rational& coeff) { + TRACE("nla_cn_details", tout << *e << ", coeff = " << coeff << "\n";); if (coeff.is_one()) return e; SASSERT(is_simplified(e)); - switch (e->type()) { - case expr_type::VAR: { - if (coeff.is_one()) - return e; - return mk_mul(mk_scalar(coeff), e); - } - case expr_type::SCALAR: { - return mk_scalar(coeff); - } - case expr_type::MUL: { - nex_mul * em = to_mul(e); - nex_pow *np = em->children().begin(); - if (np->e()->is_scalar()) { - SASSERT(np->pow() == 1); - to_scalar(np->e())->value() = coeff; - return e; - } - em->add_child(mk_scalar(coeff)); - std::sort(em->children().begin(), em->children().end(), [this](const nex_pow& a, - const nex_pow& b) {return less_than_on_nex_pow(a, b);}); - return em; - } - case expr_type::SUM: { - return mk_mul(mk_scalar(coeff), e); - } - default: - UNREACHABLE(); - return nullptr; - } - + switch (e->type()) { + case expr_type::VAR: { + if (coeff.is_one()) + return e; + return mk_mul(mk_scalar(coeff), e); } + case expr_type::SCALAR: { + return mk_scalar(coeff); + } + case expr_type::MUL: { + nex_mul * em = to_mul(e); + nex_pow *np = em->children().begin(); + if (np->e()->is_scalar()) { + SASSERT(np->pow() == 1); + to_scalar(np->e())->value() = coeff; + return e; + } + em->add_child(mk_scalar(coeff)); + std::sort(em->children().begin(), em->children().end(), [this](const nex_pow& a, + const nex_pow& b) {return less_than_on_nex_pow(a, b);}); + return em; + } + case expr_type::SUM: { + return mk_mul(mk_scalar(coeff), e); + } + default: + UNREACHABLE(); + return nullptr; + } + +} +// returns true if new +bool nex_creator::register_in_join_map(std::map& map, nex* e, const rational& r) const{ + TRACE("nla_cn_details", tout << *e << ", r = " << r << std::endl;); + auto map_it = map.find(e); + if (map_it == map.end()) { + map[e] = r; + return true; + } else { + map_it->second += r; + return false; + } +} + +void nex_creator::process_mul_in_simplify_sum(nex_mul* em, std::map &map, vector & tmp) { + auto it = em->children().begin(); + if (it->e()->is_scalar()) { + rational r = to_scalar(it->e())->value(); + auto end = em->children().end(); + if (em->children().size() == 2 && em->children()[1].pow() == 1) { + register_in_join_map(map, em->children()[1].e(), r); + } + SASSERT(it->pow() == 1); + tmp.push_back(nex_mul()); + nex_mul * m = &tmp[tmp.size()-1]; + for (it++; it != end; it++) { + m->add_child_in_power(it->e(), it->pow()); + } + if (!register_in_join_map(map, m, r)) + tmp.pop_back(); + } else { + register_in_join_map(map, em, rational(1)); + } +} // a + 3bc + 2bc => a + 5bc void nex_creator::sort_join_sum(ptr_vector & children) { - std::map m([this](const nex *a , const nex *b) - { return lt(a, b, true); }); + std::map map([this](const nex *a , const nex *b) + { return lt(a, b); }); TRACE("nla_cn_details", print_vector_of_ptrs(children, tout);); - fill_map_with_children(m, children); - TRACE("nla_cn_details", for (auto & p : m ) { tout << "(" << *p.first << ", " << p.second << ") ";}); - children.clear(); - for (auto& p : m) { - children.push_back(create_child_from_nex_and_coeff(p.first, p.second)); - } -} - -rational nex_creator::extract_coeff_from_mul(const nex_mul* m) { - const nex* e = m->children().begin()->e(); - if (e->is_scalar()) { - SASSERT(m->children().begin()->pow() == 1); - return to_scalar(e)->value(); - } - return rational(1); -} - -rational nex_creator::extract_coeff(const nex* m) { - if (!m->is_mul()) - return rational(1); - return extract_coeff_from_mul(to_mul(m)); -} - - -void nex_creator::fill_map_with_children(std::map & m, ptr_vector & children) { - nex_scalar * scalar = nullptr; - TRACE("nla_cn_details", print_vector_of_ptrs(children, tout);); - for (nex* e : children) { - if (e->is_scalar()) { - if (scalar == nullptr) { - scalar = to_scalar(e); - } else { - scalar->value() += to_scalar(e)->value(); - } - } else { - rational r = extract_coeff(e); - auto it = m.find(e); - if (it == m.end()) { - m[e] = r; - } else { - it->second += r; - } + vector tmp; + nex_scalar * s = nullptr; + for (auto e : children) { + if (e->is_mul()) { + process_mul_in_simplify_sum(to_mul(e), map, tmp); + } else if (e->is_scalar()) { + nex_scalar * es = to_scalar(e); + if (s == nullptr) + s = es; + else + s->value() += es->value(); + } + else { + register_in_join_map(map, e, rational(1)); } } - if (scalar && scalar->value().is_zero() == false) { - m[scalar] = rational(scalar->value()); + + bool simplified; + for (auto& p : map) { + if (!p.second.is_one()) { + simplified = true; + break; + } + } + if (!simplified) + return; + TRACE("nla_cn_details", for (auto & p : map ) { tout << "(" << *p.first << ", " << p.second << ") ";}); + children.clear(); + if (s) { + children.push_back(s); + } + for (auto& p : map) { + if (p.second.is_zero() == false) + children.push_back(create_child_from_nex_and_coeff(p.first, p.second)); } - } bool is_zero_scalar(nex *e) { diff --git a/src/math/lp/nex_creator.h b/src/math/lp/nex_creator.h index 5845b1314..f5a89fef3 100644 --- a/src/math/lp/nex_creator.h +++ b/src/math/lp/nex_creator.h @@ -68,8 +68,6 @@ public: const svector& active_vars_weights() const { return m_active_vars_weights;} nex* simplify(nex* e); - rational extract_coeff_from_mul(const nex_mul* m); - rational extract_coeff(const nex*); bool less_than(lpvar j, lpvar k) const{ unsigned wj = (unsigned)m_active_vars_weights[j]; @@ -77,10 +75,8 @@ public: return wj != wk ? wj < wk : j < k; } - bool less_than_nex(const nex* a, const nex* b) const; - bool less_than_on_nex_pow(const nex_pow & a, const nex_pow& b) const { - return (a.pow() < b.pow()) || (a.pow() == b.pow() && less_than_nex(a.e(), b.e())); + return (a.pow() < b.pow()) || (a.pow() == b.pow() && lt(a.e(), b.e())); } void simplify_children_of_mul(vector & children); @@ -211,7 +207,8 @@ public: nex * simplify_mul(nex_mul *e); bool is_sorted(const nex_mul * e) const; - nex* simplify_sum(nex_sum *e); + nex* simplify_sum(nex_sum *e); + void process_mul_in_simplify_sum(nex_mul* e, std::map &, vector &); bool is_simplified(const nex *e) const; bool sum_is_simplified(const nex_sum* e) const; @@ -223,17 +220,17 @@ public: const rational& coeff) ; void sort_join_sum(ptr_vector & children); + bool register_in_join_map(std::map&, nex*, const rational&) const; void simplify_children_of_sum(ptr_vector & children); bool eat_scalar_pow(nex_scalar *& r, nex_pow& p); void simplify_children_of_mul(vector & children, lt_on_vars lt, std::function mk_scalar); - bool lt(const nex* a, const nex* b, bool skip_scalar) const; - - bool less_than_on_mul(const nex_mul* a, const nex_mul* b, bool skip_scalar) const; - bool less_than_on_var_nex(const nex_var* a, const nex* b, bool skip_scalar) const; - bool less_than_on_mul_nex(const nex_mul* a, const nex* b, bool skip_scalar) const; + bool lt(const nex* a, const nex* b) const; + bool less_than_on_mul(const nex_mul* a, const nex_mul* b) const; + bool less_than_on_var_nex(const nex_var* a, const nex* b) const; + bool less_than_on_mul_nex(const nex_mul* a, const nex* b) const; void fill_map_with_children(std::map & m, ptr_vector & children); }; }