diff --git a/src/math/lp/nex.cpp b/src/math/lp/nex.cpp index 8481f5ddc..6d5abe4f9 100644 --- a/src/math/lp/nex.cpp +++ b/src/math/lp/nex.cpp @@ -83,18 +83,33 @@ void promote_children_of_sum(ptr_vector & children, nex_lt lt ) { } } -void promote_children_of_mul(vector & children, nex_lt lt) { +bool eat_scalar(nex_scalar *& r, nex_pow& p) { + if (!p.e()->is_scalar()) + return false; + nex_scalar *pe = to_scalar(p.e()); + if (r == nullptr) { + r = pe; + r->value() = r->value().expt(p.pow()); + } else { + r->value() *= pe->value().expt(p.pow()); + } + return true; +} + +void simplify_children_of_mul(vector & children, nex_lt lt) { + nex_scalar* r = nullptr; TRACE("nla_cn_details", print_vector(children, tout);); vector to_promote; int skipped = 0; - for(unsigned j = 0; j < children.size(); j++) { + for(unsigned j = 0; j < children.size(); j++) { nex_pow& p = children[j]; + if (eat_scalar(r, p)) { + skipped++; + continue; + } (p.e())->simplify(p.ee(), lt); if ((p.e())->is_mul()) { to_promote.push_back(p); - } else if (ignored_child(p.e(), expr_type::MUL)) { - skipped ++; - continue; } else { unsigned offset = to_promote.size() + skipped; if (offset) { @@ -104,14 +119,18 @@ void promote_children_of_mul(vector & children, nex_lt lt) { } children.shrink(children.size() - to_promote.size() - skipped); - + for (nex_pow & p : to_promote) { for (nex_pow& pp : to_mul(p.e())->children()) { - SASSERT(!ignored_child(pp.e(), expr_type::MUL)); - children.push_back(nex_pow(pp.e(), pp.pow() * p.pow())); + if (!eat_scalar(r, pp)) + children.push_back(nex_pow(pp.e(), pp.pow() * p.pow())); } } + if (r != nullptr) { + children.push_back(nex_pow(r)); + } + mul_to_powers(children, lt); TRACE("nla_cn_details", print_vector(children, tout);); diff --git a/src/math/lp/nex.h b/src/math/lp/nex.h index bb9a256f5..c874c7627 100644 --- a/src/math/lp/nex.h +++ b/src/math/lp/nex.h @@ -108,7 +108,7 @@ public: lpvar var() const { return m_j; } lpvar& var() { return m_j; } // the setter std::ostream & print(std::ostream& out) const { - out << 'v' << m_j; + out << (char)('a' + m_j); return out; } @@ -141,7 +141,7 @@ const nex_sum* to_sum(const nex*a); void promote_children_of_sum(ptr_vector & children, nex_lt); class nex_pow; -void promote_children_of_mul(vector & children, nex_lt); +void simplify_children_of_mul(vector & children, nex_lt); class nex_pow { nex* m_e; @@ -243,7 +243,7 @@ public: TRACE("nla_cn_details", tout << "**e = " << **e << "\n";); *e = this; TRACE("nla_cn_details", tout << *this << "\n";); - promote_children_of_mul(m_children, lt); + simplify_children_of_mul(m_children, lt); if (size() == 1 && m_children[0].pow() == 1) *e = m_children[0].e(); TRACE("nla_cn_details", tout << *this << "\n";); @@ -424,6 +424,11 @@ inline const nex_scalar* to_scalar(const nex*a) { return static_cast(a); } +inline nex_scalar* to_scalar(nex*a) { + SASSERT(a->is_scalar()); + return static_cast(a); +} + inline const nex_mul* to_mul(const nex*a) { SASSERT(a->is_mul()); return static_cast(a); diff --git a/src/test/lp/lp.cpp b/src/test/lp/lp.cpp index 129918020..1c5ab4a89 100644 --- a/src/test/lp/lp.cpp +++ b/src/test/lp/lp.cpp @@ -85,19 +85,23 @@ void test_simplify() { ); enable_trace("nla_cn"); enable_trace("nla_cn_details"); - nex_var* a = cn.get_nex_creator().mk_var(0); - nex_var* b = cn.get_nex_creator().mk_var(1); - nex_var* c = cn.get_nex_creator().mk_var(2); - auto & r = cn.get_nex_creator(); - auto m = r.mk_mul(); m->add_child_in_power(c, 2); + auto & creator = cn.get_nex_creator(); + nex_var* a = creator.mk_var(0); + nex_var* b = creator.mk_var(1); + nex_var* c = creator.mk_var(2); + auto m = creator.mk_mul(); m->add_child_in_power(c, 2); TRACE("nla_cn", tout << "m = " << *m << "\n";); - auto n = r.mk_mul(a); + auto n = creator.mk_mul(a); n->add_child_in_power(b, 7); + n->add_child(creator.mk_scalar(rational(3))); + n->add_child_in_power(creator.mk_scalar(rational(4)), 2); + n->add_child(creator.mk_scalar(rational(1))); TRACE("nla_cn", tout << "n = " << *n << "\n";); m->add_child_in_power(n, 3); + n->add_child_in_power(creator.mk_scalar(rational(1, 3)), 2); TRACE("nla_cn", tout << "m = " << *m << "\n";); - nex * e = r.mk_sum(a, r.mk_sum(b, m)); + nex * e = creator.mk_sum(a, creator.mk_sum(b, m)); TRACE("nla_cn", tout << "e = " << *e << "\n";); e->simplify(&e); TRACE("nla_cn", tout << "simplified e = " << *e << "\n";);