From 9fbd0da93102fe42ce8a59e45e2a232bf74c3750 Mon Sep 17 00:00:00 2001 From: Lev Nachmanson <levnach@hotmail.com> Date: Thu, 15 Aug 2019 17:15:45 -0700 Subject: [PATCH] rewrite horner scheme on top of nex_expr as a pointer Signed-off-by: Lev Nachmanson <levnach@hotmail.com> --- src/math/lp/cross_nested.h | 337 +++++++++++++------ src/math/lp/horner.cpp | 122 +++---- src/math/lp/horner.h | 28 +- src/math/lp/nla_core.cpp | 66 ++-- src/math/lp/nla_core.h | 5 - src/math/lp/nla_expr.h | 665 +++++++++++-------------------------- src/test/lp/lp.cpp | 35 +- 7 files changed, 563 insertions(+), 695 deletions(-) diff --git a/src/math/lp/cross_nested.h b/src/math/lp/cross_nested.h index f1de18ff4..cda1c98cb 100644 --- a/src/math/lp/cross_nested.h +++ b/src/math/lp/cross_nested.h @@ -22,7 +22,6 @@ #include "math/lp/nla_expr.h" namespace nla { class cross_nested { - typedef nla_expr<rational> nex; struct occ { unsigned m_occs; unsigned m_power; @@ -36,61 +35,179 @@ class cross_nested { }; // fields - nex& m_e; - std::function<bool (const nex&)> m_call_on_result; + nex_sum * m_e; + std::function<bool (const nex*)> m_call_on_result; std::function<bool (unsigned)> m_var_is_fixed; bool m_done; std::unordered_map<lpvar, occ> m_occurences_map; std::unordered_map<lpvar, unsigned> m_powers; - + vector<nex*> m_allocated; + vector<nex*> m_b_vec; public: - cross_nested(nex &e, - std::function<bool (const nex&)> call_on_result, + cross_nested(std::function<bool (const nex*)> call_on_result, std::function<bool (unsigned)> var_is_fixed): - m_e(e), m_call_on_result(call_on_result), m_var_is_fixed(var_is_fixed), m_done(false) {} - void run() { - vector<nex*> front; - explore_expr_on_front_elem(&m_e, front); // true for trivial form - no change + void run(nex_sum *e) { + m_e = e; + + vector<nex_sum*> front; + explore_expr_on_front_elem(m_e, front); } - static nex* pop_back(vector<nex*>& front) { - nex* c = front.back(); + static nex_sum* pop_back(vector<nex_sum*>& front) { + nex_sum* c = front.back(); TRACE("nla_cn", tout << *c << "\n";); front.pop_back(); return c; } - static bool extract_common_factor(nex* c, nex& f, const vector<std::pair<lpvar, occ>> & occurences) { + nex_sum* mk_sum() { + auto r = new nex_sum(); + m_allocated.push_back(r); + return r; + } + + nex_sum* mk_sum(const vector<nex*>& v) { + auto r = new nex_sum(); + m_allocated.push_back(r); + r->children() = v; + return r; + } + + nex_sum* mk_sum(nex *a, nex* b) { + auto r = new nex_sum(); + m_allocated.push_back(r); + r->children().push_back(a); + r->children().push_back(b); + return r; + } + + nex_var* mk_var(lpvar j) { + auto r = new nex_var(j); + m_allocated.push_back(r); + return r; + } + + nex_mul* mk_mul() { + auto r = new nex_mul(); + m_allocated.push_back(r); + return r; + } + + nex_mul* mk_mul(nex * a, nex * b) { + auto r = new nex_mul(); + m_allocated.push_back(r); + r->add_child(a); r->add_child(b); + return r; + } + + nex_mul* mk_mul(nex * a, nex * b, nex *c) { + auto r = new nex_mul(); + m_allocated.push_back(r); + r->add_child(a); r->add_child(b); r->add_child(c); + return r; + } + + nex_scalar* mk_scalar(const rational& v) { + auto r = new nex_scalar(v); + m_allocated.push_back(r); + return r; + } + + + nex * mk_div(const nex* a, lpvar j) { + SASSERT(false); + return nullptr; + } + + nex * mk_div(const nex* a, const nex* b) { + TRACE("nla_cn_details", tout << *a <<" / " << *b << "\n";); + if (b->is_var()) { + return mk_div(a, to_var(b)->var()); + } + SASSERT(b->is_mul()); + const nex_mul *bm = to_mul(b); + if (a->is_sum()) { + nex_sum * r = mk_sum(); + const nex_sum * m = to_sum(a); + for (auto e : m->children()) { + r->add_child(mk_div(e, bm)); + } + TRACE("nla_cn_details", tout << *r << "\n";); + return r; + } + if (a->is_var() || (a->is_mul() && to_mul(a)->children().size() == 1)) { + return mk_scalar(rational(1)); + } + SASSERT(a->is_mul()); + const nex_mul* am = to_mul(a); + bm->get_powers_from_mul(m_powers); + nex_mul* ret = new nex_mul(); + for (auto e : am->children()) { + TRACE("nla_cn_details", tout << "e=" << *e << "\n";); + + if (!e->is_var()) { + SASSERT(e->is_scalar()); + ret->add_child(e); + TRACE("nla_cn_details", tout << "continue\n";); + continue; + } + SASSERT(e->is_var()); + lpvar j = to_var(e)->var(); + auto it = m_powers.find(j); + if (it == m_powers.end()) { + ret->add_child(e); + } else { + it->second --; + if (it->second == 0) + m_powers.erase(it); + } + TRACE("nla_cn_details", tout << *ret << "\n";); + } + SASSERT(m_powers.size() == 0); + if (ret->children().size() == 0) { + delete ret; + TRACE("nla_cn_details", tout << "return 1\n";); + return mk_scalar(rational(1)); + } + m_allocated.push_back(ret); + TRACE("nla_cn_details", tout << *ret << "\n";); + return ret; + } + + nex* extract_common_factor(nex* e, const vector<std::pair<lpvar, occ>> & occurences) { + nex_sum* c = to_sum(e); TRACE("nla_cn", tout << "c=" << *c << "\n";); - SASSERT(c->is_sum()); - f.type() = expr_type::MUL; - SASSERT(f.children().empty()); unsigned size = c->children().size(); for(const auto & p : occurences) { + if (p.second.m_occs < size) { + return nullptr; + } + } + nex_mul* f = mk_mul(); + for(const auto & p : occurences) { // randomize here: todo if (p.second.m_occs == size) { unsigned pow = p.second.m_power; while (pow --) { - f *= nex::var(p.first); + f->add_child(mk_var(p.first)); } } } - return !f.children().empty(); + return f; } - static bool has_common_factor(const nex& c) { - TRACE("nla_cn", tout << "c=" << c << "\n";); - SASSERT(c.is_sum()); - auto & ch = c.children(); + static bool has_common_factor(const nex_sum* c) { + TRACE("nla_cn", tout << "c=" << *c << "\n";); + auto & ch = c->children(); auto common_vars = get_vars_of_expr(ch[0]); for (lpvar j : common_vars) { bool divides_the_rest = true; for(unsigned i = 1; i < ch.size() && divides_the_rest; i++) { - if (!ch[i].contains(j)) + if (!ch[i]->contains(j)) divides_the_rest = false; } if (divides_the_rest) { @@ -101,45 +218,45 @@ public: return false; } - bool proceed_with_common_factor(nex* c, vector<nex*>& front, const vector<std::pair<lpvar, occ>> & occurences) { + bool proceed_with_common_factor(nex*& c, vector<nex_sum*>& front, const vector<std::pair<lpvar, occ>> & occurences) { TRACE("nla_cn", tout << "c=" << *c << "\n";); - SASSERT(c->is_sum()); - nex f; - if (!extract_common_factor(c, f, occurences)) + nex* f = extract_common_factor(c, occurences); + if (f == nullptr) return false; - *c /= f; - f.simplify(); - * c = nex::mul(f, *c); - TRACE("nla_cn", tout << "common factor=" << f << ", c=" << *c << "\n";); - explore_expr_on_front_elem(&(c->children()[1]), front); + nex_sum* c_over_f = to_sum(mk_div(c, f)); + c_over_f->simplify(); + c = mk_mul(f, c_over_f); + TRACE("nla_cn", tout << "common factor=" << *f << ", c=" << *c << "\ne = " << *m_e << "\n";); + + explore_expr_on_front_elem(c_over_f, front); return true; } - static void push(vector<nex*>& front, nex* e) { + static void push(vector<nex_sum*>& front, nex_sum* e) { TRACE("nla_cn", tout << *e << "\n";); front.push_back(e); } - static vector<nex> copy_front(const vector<nex*>& front) { - vector<nex> v; - for (nex* n: front) - v.push_back(*n); + static vector<nex_sum*> copy_front(const vector<nex_sum*>& front) { + vector<nex_sum*> v; + for (nex_sum* n: front) + v.push_back(n); return v; } - static void restore_front(const vector<nex> ©, vector<nex*>& front) { + static void restore_front(const vector<nex_sum*> ©, vector<nex_sum*>& front) { SASSERT(copy.size() == front.size()); for (unsigned i = 0; i < front.size(); i++) - *(front[i]) = copy[i]; + front[i] = copy[i]; } - void explore_expr_on_front_elem_occs(nex* c, vector<nex*>& front, const vector<std::pair<lpvar, occ>> & occurences) { + void explore_expr_on_front_elem_occs(nex* c, vector<nex_sum*>& front, const vector<std::pair<lpvar, occ>> & occurences) { if (proceed_with_common_factor(c, front, occurences)) return; TRACE("nla_cn", tout << "save c=" << *c << "; front:"; print_vector_of_ptrs(front, tout) << "\n";); - nex copy_of_c = *c; - vector<nex> copy_of_front = copy_front(front); + nex* copy_of_c = c; + auto copy_of_front = copy_front(front); for(auto& p : occurences) { SASSERT(p.second.m_occs > 1); lpvar j = p.first; @@ -152,7 +269,7 @@ public: explore_of_expr_on_sum_and_var(c, j, front); if (m_done) return; - *c = copy_of_c; + c = copy_of_c; TRACE("nla_cn", tout << "restore c=" << *c << ", m_e=" << m_e << "\n";); restore_front(copy_of_front, front); TRACE("nla_cn", tout << "restore c=" << *c << "\n";); @@ -171,9 +288,8 @@ public: return out; } - void explore_expr_on_front_elem(nex* c, vector<nex*>& front) { - SASSERT(c->is_sum()); - auto occurences = get_mult_occurences(*c); + void explore_expr_on_front_elem(nex_sum* c, vector<nex_sum*>& front) { + auto occurences = get_mult_occurences(c); TRACE("nla_cn", tout << "m_e=" << m_e << "\nc=" << *c << ", c occurences="; dump_occurences(tout, occurences) << "; front:"; print_vector_of_ptrs(front, tout) << "\n";); @@ -182,7 +298,7 @@ public: TRACE("nla_cn", tout << "got the cn form: =" << m_e << "\n";); m_done = m_call_on_result(m_e); } else { - nex* c = pop_back(front); + auto c = pop_back(front); explore_expr_on_front_elem(c, front); } } else { @@ -196,17 +312,17 @@ public: // return (char)('a'+j); } // e is the global expression, c is the sub expressiond which is going to changed from sum to the cross nested form - void explore_of_expr_on_sum_and_var(nex* c, lpvar j, vector<nex*> front) { + void explore_of_expr_on_sum_and_var(nex* & c, lpvar j, vector<nex_sum*> front) { TRACE("nla_cn", tout << "m_e=" << m_e << "\nc=" << *c << "\nj = " << ch(j) << "\nfront="; print_vector_of_ptrs(front, tout) << "\n";); - if (!split_with_var(*c, j, front)) + if (!split_with_var(c, j, front)) return; TRACE("nla_cn", tout << "after split c=" << *c << "\nfront="; print_vector_of_ptrs(front, tout) << "\n";); SASSERT(front.size()); - nex* n = pop_back(front); TRACE("nla_cn", tout << "n=" << *n <<"\n";); + auto n = pop_back(front); TRACE("nla_cn", tout << "n=" << *n <<"\n";); explore_expr_on_front_elem(n, front); } - void process_var_occurences(lpvar j) { + void add_var_occs(lpvar j) { auto it = m_occurences_map.find(j); if (it != m_occurences_map.end()) { it->second.m_occs++; @@ -251,15 +367,14 @@ public: // j -> the number of expressions j appears in as a multiplier // The result is sorted by large number of occurences first - vector<std::pair<lpvar, occ>> get_mult_occurences(const nex& e) { + vector<std::pair<lpvar, occ>> get_mult_occurences(const nex_sum* e) { clear_maps(); - SASSERT(e.type() == expr_type::SUM); - for (const auto & ce : e.children()) { - if (ce.is_mul()) { - auto powers = ce.get_powers_from_mul(); + for (const auto * ce : e->children()) { + if (ce->is_mul()) { + to_mul(ce)->get_powers_from_mul(m_powers); update_occurences_with_powers(); - } else if (ce.type() == expr_type::VAR) { - process_var_occurences(ce.var()); + } else if (ce->is_var()) { + add_var_occs(to_var(ce)->var()); } } remove_singular_occurences(); @@ -281,63 +396,65 @@ public: }); return ret; } + + static bool is_divisible_by_var(nex* ce, lpvar j) { + return (ce->is_mul() && to_mul(ce)->contains(j)) + || (ce->is_var() && to_var(ce)->var() == j); + } // all factors of j go to a, the rest to b - static void pre_split(nex &e, lpvar j, nex &a, nex&b) { - for (const nex & ce: e.children()) { - if ((ce.is_mul() && ce.contains(j)) || (ce.is_var() && ce.var() == j)) { - a.add_child(ce / j); + void pre_split(nex_sum * e, lpvar j, nex_sum* & a, nex* & b) { + + a = mk_sum(); + m_b_vec.clear(); + for (nex * ce: e->children()) { + if (is_divisible_by_var(ce, j)) { + a->add_child(mk_div(ce , j)); } else { - b.add_child(ce); + m_b_vec.push_back(ce); } } - a.type() = expr_type::SUM; - TRACE("nla_cn_details", tout << "a = " << a << "\n";); - SASSERT(a.children().size() >= 2); - a.simplify(); + TRACE("nla_cn_details", tout << "a = " << *a << "\n";); + SASSERT(a->children().size() >= 2 && m_b_vec.size()); + a->simplify(); - if (b.children().size() == 1) { - nex t = b.children()[0]; - b = t; - } else if (b.children().size() > 1) { - b.type() = expr_type::SUM; - } + if (m_b_vec.size() == 1) { + b = m_b_vec[0]; + } else { + SASSERT(m_b_vec.size() > 1); + b = mk_sum(m_b_vec); + } } - // returns true if the recursion is done inside - void update_front_with_split_with_non_empty_b(nex& e, lpvar j, vector<nex*> & front, nex& a, nex& b) { - nex f; - SASSERT(a.is_sum()); + void update_front_with_split_with_non_empty_b(nex* &e, lpvar j, vector<nex_sum*> & front, nex_sum* a, nex* b) { + + SASSERT(a->is_sum()); TRACE("nla_cn_details", tout << "b = " << b << "\n";); - e = nex::sum(nex::mul(nex::var(j), a), b); - push(front, &(e.children()[0].children()[1])); // pushing 'a' - TRACE("nla_cn", tout << "push to front " << e.children()[0].children()[1] << "\n";); + e = mk_sum(mk_mul(mk_var(j), a), b); // e = j*a + b + push(front, a); // pushing 'a' + TRACE("nla_cn", tout << "push to front " << *a << "\n";); - if (b.is_sum()) { - push(front, &(e.children()[1])); - TRACE("nla_cn", tout << "push to front " << e.children()[1] << "\n";); + if (b->is_sum()) { + push(front, to_sum(b)); + TRACE("nla_cn", tout << "push to front " << *b << "\n";); } } - void update_front_with_split(nex& e, lpvar j, vector<nex*> & front, nex& a, nex& b) { - if (b.is_undef()) { - SASSERT(b.children().size() == 0); - e = nex(expr_type::MUL); - e.add_child(nex::var(j)); - e.add_child(a); - if (a.size() > 1) { - push(front, &e.children().back()); - TRACE("nla_cn_details", tout << "push to front " << e.children().back() << "\n";); - } + void update_front_with_split(nex* & e, lpvar j, vector<nex_sum*> & front, nex_sum* a, nex* b) { + if (b == nullptr) { + e = mk_mul(mk_var(j), a); + push(front, a); + TRACE("nla_cn_details", tout << "push to front " << *a << "\n";); + } else { + update_front_with_split_with_non_empty_b(e, j, front, a, b); } - update_front_with_split_with_non_empty_b(e, j, front, a, b); } // it returns true if the recursion brings a cross-nested form - bool split_with_var(nex& e, lpvar j, vector<nex*> & front) { + bool split_with_var(nex*& e, lpvar j, vector<nex_sum*> & front) { + SASSERT(e->is_sum()); TRACE("nla_cn", tout << "e = " << e << ", j=" << ch(j) << "\n";); - if (!e.is_sum()) return false; - nex a, b; - pre_split(e, j, a, b); + nex_sum* a; nex * b; + pre_split(to_sum(e), j, a, b); /* When we have e without a non-trivial common factor then there is a variable j such that e = jP + Q, where Q has all members @@ -352,28 +469,42 @@ public: update_front_with_split(e, j, front, a, b); return true; } - static std::unordered_set<lpvar> get_vars_of_expr(const nex &e ) { + + static std::unordered_set<lpvar> get_vars_of_expr(const nex *e ) { std::unordered_set<lpvar> r; - switch (e.type()) { + switch (e->type()) { case expr_type::SCALAR: return r; case expr_type::SUM: + { + for (auto c: to_sum(e)->children()) + for ( lpvar j : get_vars_of_expr(c)) + r.insert(j); + } case expr_type::MUL: { - for (const auto & c: e.children()) + for (auto c: to_mul(e)->children()) for ( lpvar j : get_vars_of_expr(c)) r.insert(j); } return r; case expr_type::VAR: - r.insert(e.var()); + r.insert(to_var(e)->var()); return r; default: - TRACE("nla_cn_details", tout << e.type() << "\n";); + TRACE("nla_cn_details", tout << e->type() << "\n";); SASSERT(false); return r; } } + + ~cross_nested() { + for (auto e: m_allocated) + delete e; + m_allocated.clear(); + } + bool done() const { return m_done; } + }; } diff --git a/src/math/lp/horner.cpp b/src/math/lp/horner.cpp index 0d87a728b..5c842dd78 100644 --- a/src/math/lp/horner.cpp +++ b/src/math/lp/horner.cpp @@ -63,31 +63,31 @@ bool horner::row_is_interesting(const T& row) const { return false; } -bool horner::lemmas_on_expr(nex& e) { +bool horner::lemmas_on_expr(nex_sum* e, cross_nested& cn) { TRACE("nla_horner", tout << "e = " << e << "\n";); - bool conflict = false; - cross_nested cn(e, [this, & conflict](const nex& n) { + cn.run(e); + return cn.done(); +} + + +template <typename T> +bool horner::lemmas_on_row(const T& row) { + cross_nested cn([this](const nex* n) { TRACE("nla_horner", tout << "cross-nested n = " << n << "\n";); auto i = interval_of_expr(n); TRACE("nla_horner", tout << "callback n = " << n << "\ni="; m_intervals.display(tout, i) << "\n";); - conflict = m_intervals.check_interval_for_conflict_on_zero(i); + bool conflict = m_intervals.check_interval_for_conflict_on_zero(i); c().lp_settings().st().m_cross_nested_forms++; m_intervals.reset(); // clean the memory allocated by the interval bound dependencies return conflict; }, [this](unsigned j) { return c().var_is_fixed(j); } ); - cn.run(); - return conflict; -} - -template <typename T> -bool horner::lemmas_on_row(const T& row) { SASSERT (row_is_interesting(row)); - nex e = create_sum_from_row(row); - return lemmas_on_expr(e); + nex_sum* e = create_sum_from_row(row, cn); + return lemmas_on_expr(e, cn); } void horner::horner_lemmas() { @@ -120,27 +120,28 @@ void horner::horner_lemmas() { } } -typedef nla_expr<rational> nex; - -nex horner::nexvar(lpvar j) const { +nex * horner::nexvar(lpvar j, cross_nested& cn) const { // todo: consider deepen the recursion if (!c().is_monomial_var(j)) - return nex::var(j); + return cn.mk_var(j); const monomial& m = c().emons()[j]; - nex e(expr_type::MUL); + nex_mul * e = cn.mk_mul(); for (lpvar k : m.vars()) { - e.add_child(nex::var(k)); + e->add_child(cn.mk_var(k)); CTRACE("nla_horner", c().is_monomial_var(k), c().print_var(k, tout) << "\n";); } return e; } -template <typename T> nex horner::create_sum_from_row(const T& row) { +template <typename T> nex_sum* horner::create_sum_from_row(const T& row, cross_nested& cn) { TRACE("nla_horner", tout << "row="; m_core->print_term(row, tout) << "\n";); SASSERT(row.size() > 1); - nex e(expr_type::SUM); - for (const auto &p : row) { - e.add_child(nex::scalar(p.coeff())* nexvar(p.var())); + nex_sum *e = cn.mk_sum(); + for (const auto &p : row) { + if (p.coeff().is_one()) + e->add_child(nexvar(p.var(), cn)); + else + e->add_child(cn.mk_mul(cn.mk_scalar(p.coeff()), nexvar(p.var(), cn))); } return e; } @@ -155,28 +156,28 @@ void horner::set_interval_for_scalar(interv& a, const rational& v) { m_intervals.set_upper_is_inf(a, false); } -interv horner::interval_of_expr(const nex& e) { +interv horner::interval_of_expr(const nex* e) { interv a; - switch (e.type()) { + switch (e->type()) { case expr_type::SCALAR: - set_interval_for_scalar(a, e.value()); + set_interval_for_scalar(a, to_scalar(e)->value()); return a; case expr_type::SUM: - return interval_of_sum(e); + return interval_of_sum(to_sum(e)); case expr_type::MUL: - return interval_of_mul(e); + return interval_of_mul(to_mul(e)); case expr_type::VAR: - set_var_interval(e.var(), a); + set_var_interval(to_var(e)->var(), a); return a; default: - TRACE("nla_horner_details", tout << e.type() << "\n";); + TRACE("nla_horner_details", tout << e->type() << "\n";); SASSERT(false); return interv(); } } -interv horner::interval_of_mul(const nex& e) { - SASSERT(e.is_mul()); - auto & es = e.children(); +interv horner::interval_of_mul(const nex_mul* e) { + SASSERT(e->is_mul()); + auto & es = to_mul(e)->children(); interv a = interval_of_expr(es[0]); if (m_intervals.is_zero(a)) { m_intervals.set_zero_interval_deps_for_mult(a); @@ -208,25 +209,25 @@ interv horner::interval_of_mul(const nex& e) { return a; } -void horner::add_mul_to_vector(const nex& e, vector<std::pair<rational, lpvar>> &v) { +void horner::add_mul_to_vector(const nex_mul* e, vector<std::pair<rational, lpvar>> &v) { TRACE("nla_horner_details", tout << e << "\n";); - SASSERT(e.is_mul() && e.size() > 0); - if (e.size() == 1) { - add_linear_to_vector(*(e.children().begin()), v); + SASSERT(e->size() > 0); + if (e->size() == 1) { + add_linear_to_vector(*(e->children().begin()), v); return; } rational r; lpvar j = -1; - for (const nex & c : e.children()) { - switch (c.type()) { + for (const nex * c : e->children()) { + switch (c->type()) { case expr_type::SCALAR: - r = c.value(); + r = to_scalar(c)->value(); break; case expr_type::VAR: - j = c.var(); + j = to_var(c)->var(); break; default: - TRACE("nla_horner_details", tout << e.type() << "\n";); + TRACE("nla_horner_details", tout << e->type() << "\n";); SASSERT(false); } } @@ -234,30 +235,30 @@ void horner::add_mul_to_vector(const nex& e, vector<std::pair<rational, lpvar>> v.push_back(std::make_pair(r, j)); } -void horner::add_linear_to_vector(const nex& e, vector<std::pair<rational, lpvar>> &v) { +void horner::add_linear_to_vector(const nex* e, vector<std::pair<rational, lpvar>> &v) { TRACE("nla_horner_details", tout << e << "\n";); - switch (e.type()) { + switch (e->type()) { case expr_type::MUL: - add_mul_to_vector(e, v); + add_mul_to_vector(to_mul(e), v); break; case expr_type::VAR: - v.push_back(std::make_pair(rational(1), e.var())); + v.push_back(std::make_pair(rational(1), to_var(e)->var())); break; default: - SASSERT(!e.is_sum()); + SASSERT(!e->is_sum()); // noop } } // e = a * can_t + b -lp::lar_term horner::expression_to_normalized_term(nex& e, rational& a, rational& b) { +lp::lar_term horner::expression_to_normalized_term(const nex_sum* e, rational& a, rational& b) { TRACE("nla_horner_details", tout << e << "\n";); lpvar smallest_j; vector<std::pair<rational, lpvar>> v; b = rational(0); unsigned a_index; - for (const nex& c : e.children()) { - if (c.is_scalar()) { - b += c.value(); + for (const nex* c : e->children()) { + if (c->is_scalar()) { + b += to_scalar(c)->value(); } else { add_linear_to_vector(c, v); if (v.empty()) @@ -295,9 +296,10 @@ lp::lar_term horner::expression_to_normalized_term(nex& e, rational& a, rational // we should have in the case of found a*m_terms[k] + b = e, // where m_terms[k] corresponds to the returned lpvar -lpvar horner::find_term_column(const nex& e, rational& a, rational& b) const { - nex n = e; - lp::lar_term norm_t = expression_to_normalized_term(n, a, b); +lpvar horner::find_term_column(const nex* e, rational& a, rational& b) const { + if (!e->is_sum()) + return -1; + lp::lar_term norm_t = expression_to_normalized_term(to_sum(e), a, b); std::pair<rational, lpvar> a_j; if (c().m_lar_solver.fetch_normalized_term_column(norm_t, a_j)) { a /= a_j.first; @@ -306,8 +308,8 @@ lpvar horner::find_term_column(const nex& e, rational& a, rational& b) const { return -1; } -interv horner::interval_of_sum_no_terms(const nex& e) { - auto & es = e.children(); +interv horner::interval_of_sum_no_terms(const nex_sum* e) { + auto & es = e->children(); interv a = interval_of_expr(es[0]); if (m_intervals.is_inf(a)) { TRACE("nla_horner_details", tout << "e=" << e << "\n"; @@ -340,10 +342,9 @@ interv horner::interval_of_sum_no_terms(const nex& e) { return a; } -bool horner::interval_from_term(const nex& e, interv & i) const { +bool horner::interval_from_term(const nex* e, interv & i) const { rational a, b; - nex n = e; - lpvar j = find_term_column(n, a, b); + lpvar j = find_term_column(e, a, b); if (j + 1 == 0) return false; @@ -361,11 +362,10 @@ bool horner::interval_from_term(const nex& e, interv & i) const { } -interv horner::interval_of_sum(const nex& e) { +interv horner::interval_of_sum(const nex_sum* e) { TRACE("nla_horner_details", tout << "e=" << e << "\n";); - SASSERT(e.is_sum()); interv i_e = interval_of_sum_no_terms(e); - if (e.sum_is_a_linear_term()) { + if (e->is_a_linear_term()) { interv i_from_term ; if (interval_from_term(e, i_from_term)) { interv r = m_intervals.intersect(i_e, i_from_term); diff --git a/src/math/lp/horner.h b/src/math/lp/horner.h index f5944b24c..cd9d60c20 100644 --- a/src/math/lp/horner.h +++ b/src/math/lp/horner.h @@ -22,6 +22,7 @@ #include "math/lp/nla_common.h" #include "math/lp/nla_intervals.h" #include "math/lp/nla_expr.h" +#include "math/lp/cross_nested.h" namespace nla { class core; @@ -30,31 +31,30 @@ class core; class horner : common { intervals m_intervals; public: - typedef nla_expr<rational> nex; typedef intervals::interval interv; horner(core *core); void horner_lemmas(); template <typename T> // T has an iterator of (coeff(), var()) bool lemmas_on_row(const T&); template <typename T> bool row_is_interesting(const T&) const; - template <typename T> nex create_sum_from_row(const T&); - intervals::interval interval_of_expr(const nex& e); + template <typename T> + nex_sum* create_sum_from_row(const T&, cross_nested&); + intervals::interval interval_of_expr(const nex* e); - nex nexvar(lpvar j) const; - intervals::interval interval_of_sum(const nex&); - intervals::interval interval_of_sum_no_terms(const nex&); - intervals::interval interval_of_mul(const nex&); + nex* nexvar(lpvar j, cross_nested& cn) const; + intervals::interval interval_of_sum(const nex_sum*); + intervals::interval interval_of_sum_no_terms(const nex_sum*); + intervals::interval interval_of_mul(const nex_mul*); void set_interval_for_scalar(intervals::interval&, const rational&); void set_var_interval(lpvar j, intervals::interval&) const; - bool lemmas_on_expr(nex &); + bool lemmas_on_expr(nex_sum* , cross_nested&); template <typename T> // T has an iterator of (coeff(), var()) bool row_has_monomial_to_refine(const T&) const; - lpvar find_term_column(const nex& e, rational& a, rational& b) const; - static lp::lar_term expression_to_normalized_term(nex&, rational& a, rational & b); - static void add_linear_to_vector(const nex&, vector<std::pair<rational, lpvar>> &); - static void add_mul_to_vector(const nex&, vector<std::pair<rational, lpvar>> &); - bool is_tighter(const interv&, const interv&) const; - bool interval_from_term(const nex& e, interv&) const; + lpvar find_term_column(const nex* e, rational& a, rational& b) const; + static lp::lar_term expression_to_normalized_term(const nex_sum*, rational& a, rational & b); + static void add_linear_to_vector(const nex*, vector<std::pair<rational, lpvar>> &); + static void add_mul_to_vector(const nex_mul*, vector<std::pair<rational, lpvar>> &); + bool interval_from_term(const nex* e, interv&) const; }; // end of horner } diff --git a/src/math/lp/nla_core.cpp b/src/math/lp/nla_core.cpp index 5a69b0cae..70883c5c1 100644 --- a/src/math/lp/nla_core.cpp +++ b/src/math/lp/nla_core.cpp @@ -1346,41 +1346,41 @@ lbool core::test_check( return check(l); } -nla_expr<rational> core::mk_expr(lpvar j) const { - return nla_expr<rational>::var(j); -} +// nla_expr<rational> core::mk_expr(lpvar j) const { +// return nla_expr<rational>::var(j); +// } -nla_expr<rational> core::mk_expr(const rational &a, lpvar j) const { - if (a == 1) - return mk_expr(j); - nla_expr<rational> r(expr_type::MUL); - r.add_child(nla_expr<rational>::scalar(a)); - r.add_child(nla_expr<rational>::var(j)); - return r; -} +// nla_expr<rational> core::mk_expr(const rational &a, lpvar j) const { +// if (a == 1) +// return mk_expr(j); +// nla_expr<rational> r(expr_type::MUL); +// r.add_child(nla_expr<rational>::scalar(a)); +// r.add_child(nla_expr<rational>::var(j)); +// return r; +// } -nla_expr<rational> core::mk_expr(const rational &a, const svector<lpvar>& vs) const { - nla_expr<rational> r(expr_type::MUL); - r.add_child(nla_expr<rational>::scalar(a)); - for (lpvar j : vs) - r.add_child(nla_expr<rational>::var(j)); - return r; -} -nla_expr<rational> core::mk_expr(const lp::lar_term& t) const { - auto coeffs = t.coeffs_as_vector(); - if (coeffs.size() == 1) { - return mk_expr(coeffs[0].first, coeffs[0].second); - } - nla_expr<rational> r(expr_type::SUM); - for (const auto & p : coeffs) { - lpvar j = p.second; - if (is_monomial_var(j)) - r.add_child(mk_expr(p.first, m_emons[j].vars())); - else - r.add_child(mk_expr(p.first, j)); - } - return r; -} +// nla_expr<rational> core::mk_expr(const rational &a, const svector<lpvar>& vs) const { +// nla_expr<rational> r(expr_type::MUL); +// r.add_child(nla_expr<rational>::scalar(a)); +// for (lpvar j : vs) +// r.add_child(nla_expr<rational>::var(j)); +// return r; +// } +// nla_expr<rational> core::mk_expr(const lp::lar_term& t) const { +// auto coeffs = t.coeffs_as_vector(); +// if (coeffs.size() == 1) { +// return mk_expr(coeffs[0].first, coeffs[0].second); +// } +// nla_expr<rational> r(expr_type::SUM); +// for (const auto & p : coeffs) { +// lpvar j = p.second; +// if (is_monomial_var(j)) +// r.add_child(mk_expr(p.first, m_emons[j].vars())); +// else +// r.add_child(mk_expr(p.first, j)); +// } +// return r; +// } std::ostream& core::print_terms(std::ostream& out) const { for (unsigned i=0; i< m_lar_solver.m_terms.size(); i++) { diff --git a/src/math/lp/nla_core.h b/src/math/lp/nla_core.h index 5fee3abc0..e865f11cb 100644 --- a/src/math/lp/nla_core.h +++ b/src/math/lp/nla_core.h @@ -350,11 +350,6 @@ public: lpvar map_to_root(lpvar) const; std::ostream& print_terms(std::ostream&) const; std::ostream& print_term( const lp::lar_term&, std::ostream&) const; - nla_expr<rational> mk_expr(lpvar j) const; - nla_expr<rational> mk_expr(const rational &a, lpvar j) const; - - nla_expr<rational> mk_expr(const rational &a, const svector<lpvar>& vs) const; - nla_expr<rational> mk_expr(const lp::lar_term& t) const; }; // end of core struct pp_mon { diff --git a/src/math/lp/nla_expr.h b/src/math/lp/nla_expr.h index 83be7f5eb..f35a4fc9a 100644 --- a/src/math/lp/nla_expr.h +++ b/src/math/lp/nla_expr.h @@ -47,38 +47,178 @@ inline std::ostream & operator<<(std::ostream& out, expr_type t) { // This class is needed in horner calculation with intervals -template <typename T> -class nla_expr { - // todo: use union - expr_type m_type; - lpvar m_j; - T m_v; // for the scalar - vector<nla_expr> m_children; +class nex { public: - bool is_sum() const { return m_type == expr_type::SUM; } - bool is_var() const { return m_type == expr_type::VAR; } - bool is_mul() const { return m_type == expr_type::MUL; } - bool is_undef() const { return m_type == expr_type::UNDEF; } - bool is_scalar() const { return m_type == expr_type::SCALAR; } - lpvar var() const { SASSERT(m_type == expr_type::VAR); return m_j; } - expr_type type() const { return m_type; } - expr_type& type() { return m_type; } - const vector<nla_expr>& children() const { return m_children; } - vector<nla_expr>& children() { return m_children; } - const T& value() const { SASSERT(m_type == expr_type::SCALAR); return m_v; } - std::string str() const { std::stringstream ss; ss << *this; return ss.str(); } - std::ostream & print_sum(std::ostream& out) const { + virtual expr_type type() const = 0; + virtual std::ostream& print(std::ostream&) const = 0; + nex() {} + bool is_simple() const { + switch(type()) { + case expr_type::SUM: + case expr_type::MUL: + return false; + default: + return true; + } + } + + bool is_sum() const { return type() == expr_type::SUM; } + bool is_mul() const { return type() == expr_type::MUL; } + bool is_var() const { return type() == expr_type::VAR; } + bool is_scalar() const { return type() == expr_type::SCALAR; } + std::string str() const { std::stringstream ss; print(ss); return ss.str(); } + virtual ~nex() {} + virtual bool contains(lpvar j) const { return false; } + virtual int get_degree() const = 0; +}; +std::ostream& operator<<(std::ostream& out, const nex&); + +class nex_var : public nex { + lpvar m_j; +public: + nex_var(lpvar j) : m_j(j) {} + nex_var() {} + expr_type type() const { return expr_type::VAR; } + lpvar var() const { return m_j; } + lpvar& var() { return m_j; } // the setter + std::ostream & print(std::ostream& out) const { + out << 'v' << m_j; + return out; + } + + bool contains(lpvar j) const { return j == m_j; } + int get_degree() const { return 1; } +}; + +class nex_scalar : public nex { + rational m_v; +public: + nex_scalar(const rational& v) : m_v(v) {} + nex_scalar() {} + expr_type type() const { return expr_type::SCALAR; } + const rational& value() const { return m_v; } + rational& value() { return m_v; } // the setter + std::ostream& print(std::ostream& out) const { + out << m_v; + return out; + } + + int get_degree() const { return 0; } + +}; + +class nex_mul : public nex { + vector<nex*> m_children; +public: + nex_mul() {} + unsigned size() const { return m_children.size(); } + expr_type type() const { return expr_type::MUL; } + vector<nex*>& children() { return m_children;} + const vector<nex*>& children() const { return m_children;} + std::ostream & print(std::ostream& out) const { bool first = true; - for (const nla_expr& v : m_children) { - std::string s = v.str(); + for (const nex* v : m_children) { + std::string s = v->str(); if (first) { first = false; - if (v.is_simple()) - out << v; + if (v->is_simple()) + out << s; else out << "(" << s << ")"; } else { - if (v.is_simple()) { + if (v->is_simple()) { + if (s[0] == '-') { + out << "*(" << s << ")"; + } else { + out << "*" << s; + } + } else { + out << "*(" << s << ")"; + } + } + } + return out; + } + + void add_child(nex* e) { m_children.push_back(e); } + + bool contains(lpvar j) const { + for (const nex* c : children()) { + if (c->contains(j)) + return true; + } + return false; + } + + static const nex_var* to_var(const nex*a) { + SASSERT(a->is_var()); + return static_cast<const nex_var*>(a); + } + + void get_powers_from_mul(std::unordered_map<lpvar, unsigned> & r) const { + r.clear(); + for (const auto & c : children()) { + if (!c->is_var()) { + continue; + } + lpvar j = to_var(c)->var(); + auto it = r.find(j); + if (it == r.end()) { + r[j] = 1; + } else { + it->second++; + } + } + TRACE("nla_cn_details", tout << "powers of " << *this << "\n"; print_vector(r, tout)<< "\n";); + } + + int get_degree() const { + int degree = 0; + for (auto e : children()) { + degree += e->get_degree(); + } + return degree; + } + + +}; + +class nex_sum : public nex { + vector<nex*> m_children; +public: + nex_sum() {} + expr_type type() const { return expr_type::SUM; } + vector<nex*>& children() { return m_children;} + const vector<nex*>& children() const { return m_children;} + unsigned size() const { return m_children.size(); } + + // we need a linear combination of at least two variables + bool is_a_linear_term() const { + TRACE("nex_details", tout << *this << "\n";); + unsigned number_of_non_scalars = 0; + for (auto e : children()) { + int d = e->get_degree(); + if (d == 0) continue; + if (d > 1) return false; + + number_of_non_scalars++; + } + TRACE("nex_details", tout << (number_of_non_scalars > 1?"linear":"non-linear") << "\n";); + return number_of_non_scalars > 1; + } + + std::ostream & print(std::ostream& out) const { + bool first = true; + for (const nex* v : m_children) { + std::string s = v->str(); + if (first) { + first = false; + if (v->is_simple()) + out << s; + else + out << "(" << s << ")"; + } else { + if (v->is_simple()) { if (s[0] == '-') { out << s; } else { @@ -93,457 +233,52 @@ public: } void simplify() { - if (is_simple()) return; - bool has_sum = false; - if (is_sum()) { - for (auto & e : m_children) { - e.simplify(); - has_sum |= e.is_sum(); - } - if (has_sum) { - nla_expr n(expr_type::SUM); - for (auto &e : m_children) { - n += e; - } - m_children = n.m_children; - } - } else if (is_mul()) { - bool has_mul = false; - for (auto & e : m_children) { - e.simplify(); - has_mul |= e.is_mul(); - } - if (has_mul) { - nla_expr n(expr_type::MUL); - for (auto &e : m_children) { - n *= e; - } - m_children = n.m_children; - } - TRACE("nla_cn_details", tout << "simplified " << *this << "\n";); - } - } - - std::ostream & print_mul(std::ostream& out) const { - bool first = true; - for (const nla_expr& v : m_children) { - std::string s = v.str(); - if (first) { - first = false; - if (v.is_simple()) - out << s; - else - out << "(" << s << ")"; - } else { - if (v.is_simple()) { - if (s[0] == '-') { - out << "*(" << s << ")"; - } else { - out << "*" << s; - } - } else { - out << "*(" << s << ")"; - } - } - } - return out; - } - std::ostream & print(std::ostream& out) const { - switch(m_type) { - case expr_type::SUM: - return print_sum(out); - case expr_type::MUL: - return print_mul(out); - case expr_type::VAR: - out << 'v' << m_j; - return out; - case expr_type::SCALAR: - out << m_v; - return out; - default: - out << "undef"; - return out; - } - } - - bool is_simple() const { - switch(m_type) { - case expr_type::SUM: - case expr_type::MUL: - return false; - - default: - return true; - } - } - - unsigned size() const { - switch(m_type) { - case expr_type::SUM: - case expr_type::MUL: - return m_children.size(); - - default: - return 1; - } - } - nla_expr(expr_type t): m_type(t) {} - nla_expr(): m_type(expr_type::UNDEF) {} - - void add_child(const nla_expr& e) { - m_children.push_back(e); - } - - void add_child(const T& k) { - m_children.push_back(scalar(k)); - } - - void add_children() { } - - template <typename K, typename ...Args> - void add_children(K e, Args ... es) { - add_child(e); - add_children(es ...); - } - - template <typename K, typename ... Args> - static nla_expr sum(K e, Args ... es) { - nla_expr r(expr_type::SUM); - r.add_children(e, es...); - return r; - } - - template <typename K, typename ... Args> - static nla_expr mul(K e, Args ... es) { - nla_expr r(expr_type::MUL); - r.add_children(e, es...); - return r; - } - - static nla_expr mul(const T& v, nla_expr & w) { - if (v == 1) - return w; - nla_expr r(expr_type::MUL); - r.add_child(scalar(v)); - r.add_child(w); - return r; - } - - static nla_expr mul() { - return nla_expr(expr_type::MUL); - } - - static nla_expr mul(const T& v, lpvar j) { - if (v == 1) - return var(j); - return mul(scalar(v), var(j)); - } - - static nla_expr scalar(const T& v) { - nla_expr r(expr_type::SCALAR); - r.m_v = v; - return r; - } - - static nla_expr var(lpvar j) { - nla_expr r(expr_type::VAR); - r.m_j = j; - return r; - } - - bool contains(lpvar j) const { - if (is_var()) - return m_j == j; - if (is_mul()) { - for (const nla_expr<T>& c : children()) { - if (c.contains(j)) - return true; - } - } - return false; - } - - nla_expr& operator*=(const nla_expr& b) { - if (is_mul()) { - if (b.is_mul()) { - for (auto& e: b.children()) - add_child(e); - } else { - add_child(b); - } - return *this; - } - SASSERT(false); // not impl - return *this; - } - - std::unordered_map<lpvar, int> get_powers_from_mul() const { - SASSERT(is_mul()); - std::unordered_map<lpvar, int> r; - for (const auto & c : children()) { - if (!c.is_var()) { - continue; - } - lpvar j = c.var(); - auto it = r.find(j); - if (it == r.end()) { - r[j] = 1; - } else { - it->second++; - } - } - TRACE("nla_cn_details", tout << "powers of " << *this << "\n"; print_vector(r, tout)<< "\n";); - return r; - } - - friend nla_expr operator-(const nla_expr& a, const nla_expr&b) { - return a + scalar(T(-1))*b; + SASSERT(false); } - nla_expr& operator/=(const nla_expr& b) { - TRACE("nla_cn_details", tout << *this <<" / " << b << "\n";); - if (b.is_var()) { - *this = (*this) / b.var(); - TRACE("nla_cn_details", tout << *this << "\n";); - return *this; - } - SASSERT(b.is_mul()); - if (is_sum()) { - for (auto & e : children()) { - e /= b; - } - TRACE("nla_cn_details", tout << *this << "\n";); - return *this; - } - if (is_var() || children().size() == 1) { - *this = scalar(T(1)); - TRACE("nla_cn_details", tout << *this << "\n";); - return *this; - } - SASSERT(is_mul()); - auto powers = b.get_powers_from_mul(); - unsigned i = 0, k = 0; - for (; i < children().size(); i++, k++) { - auto & e = children()[i]; - TRACE("nla_cn_details", tout << "e=" << e << ",i=" <<i<< ",k=" << k<< "\n";); - - if (!e.is_var()) { - SASSERT(e.is_scalar()); - if (i != k) - children()[k] = children()[i]; - - TRACE("nla_cn_details", tout << "continue\n";); - continue; - - } - lpvar j = e.var(); - auto it = powers.find(j); - if (it == powers.end()) { - if (i != k) - children()[k] = children()[i]; - } else { - it->second --; - if (it->second == 0) - powers.erase(it); - k--; - } - TRACE("nla_cn_details", tout << *this << "\n";); - - } - SASSERT(powers.size() == 0); - while(k ++ < i) - children().pop_back(); - - if (children().size() == 0) - *this = scalar(T(1)); - TRACE("nla_cn_details", tout << *this << "\n";); - - return *this; - } - - - nla_expr& operator+=(const nla_expr& b) { - if (is_sum()) { - if (b.is_sum()) { - for (auto& e: b.children()) - add_child(e); - } else { - add_child(b); - } - return *this; - } - SASSERT(false); // not impl - return *this; - } - - // we need a linear combination of at least two variables - bool sum_is_a_linear_term() const { - SASSERT(is_sum()); - TRACE("nla_expr_details", tout << *this << "\n";); - unsigned number_of_non_scalars = 0; - for (auto & e : children()) { - int d = e.get_degree(); - if (d == 0) continue; - if (d > 1) return false; - - number_of_non_scalars++; - } - TRACE("nla_expr_details", tout << (number_of_non_scalars > 1?"linear":"non-linear") << "\n";); - return number_of_non_scalars > 1; - } - int get_degree() const { - switch (type()) { - case expr_type::SUM: { - int degree = 0; - for (auto & e : children()) { - degree = std::max(degree, e.get_degree()); - } - return degree; + int degree = 0; + for (auto e : children()) { + degree = std::max(degree, e->get_degree()); } - - case expr_type::MUL: { - int degree = 0; - for (auto & e : children()) { - degree += e.get_degree(); - } - return degree; - } - case expr_type::VAR: - return 1; - case expr_type::SCALAR: - return 0; - case expr_type::UNDEF: - default: - UNREACHABLE(); - break; - } - return 0; - } + return degree; + } + + void add_child(nex* e) { m_children.push_back(e); } }; -/* -nla_expr operator/=(const nla_expr &a, const nla_expr& b) { - TRACE("nla_cn_details", tout << a <<" / " << b << "\n";); - if (b.is_var()) { - return a / b.var(); - } - SASSERT(b.is_mul()); - if (a.is_sum()) { - auto r = nex::sum(); - for (auto & e : a.children()) { - r.add_child(e/b); - } - return r; - } - if (is_var()) { - return scalar(T(1)); - return *this; - } - SASSERT(a.is_mul()); - auto powers = b.get_powers_from_mul(); - auto r=nex::mul(); - for (unsigned i = 0; i < a.children().size(); i++, k++) { - auto & e = children()[i]; - if (!e.is_var()) { - SASSERT(e.is_scalar()); - r.add_child(e); - continue; - } - lpvar j = e.var(); - auto it = powers.find(j); - if (it == powers.end()) { - r.add_child(e); - } else { - it->second --; // finish h - if (it->second == 0) - powers.erase(it); - } - } - - return r; - } -*/ -template <typename T> -nla_expr<T> operator+(const nla_expr<T>& a, const nla_expr<T>& b) { - if (a.is_sum()) { - nla_expr<T> r(expr_type::SUM); - r.children() = a.children(); - if (b.is_sum()) { - for (auto& e: b.children()) - r.add_child(e); - } else { - r.add_child(b); - } - return r; - } - if (b.is_sum()) { - nla_expr<T> r(expr_type::SUM); - r.children() = b.children(); - r.add_child(a); - return r; - } - return nla_expr<T>::sum(a, b); +inline const nex_sum* to_sum(const nex*a) { + SASSERT(a->is_sum()); + return static_cast<const nex_sum*>(a); } -template <typename T> -nla_expr<T> operator*(const nla_expr<T>& a, const nla_expr<T>& b) { - if (a.is_scalar() && a.value() == T(1)) - return b; - if (b.is_scalar() && b.value() == T(1)) - return a; - if (a.is_mul()) { - nla_expr<T> r(expr_type::MUL); - r.children() = a.children(); - if (b.is_mul()) { - for (auto& e: b.children()) - r.add_child(e); - } else { - r.add_child(b); - } - return r; - } - if (b.is_mul()) { - nla_expr<T> r(expr_type::MUL); - r.children() = b.children(); - r.add_child(a); - return r; - } - return nla_expr<T>::mul(a, b); +inline nex_sum* to_sum(nex * a) { + SASSERT(a->is_sum()); + return static_cast<nex_sum*>(a); } - -template <typename T> -nla_expr<T> operator/(const nla_expr<T>& a, lpvar j) { - TRACE("nla_cn_details", tout << "a=" << a << ", v" << j << "\n";); - SASSERT((a.is_mul() && a.contains(j)) || (a.is_var() && a.var() == j)); - if (a.is_var()) - return nla_expr<T>::scalar(T(1)); - nla_expr<T> b; - bool seenj = false; - for (const nla_expr<T>& c : a.children()) { - if (!seenj) { - if (c.contains(j)) { - if (!c.is_var()) - b.add_child(c / j); - seenj = true; - continue; - } - } - b.add_child(c); - } - if (b.children().size() > 1) { - b.type() = expr_type::MUL; - } else if (b.children().size() == 1) { - auto t = b.children()[0]; - b = t; - } else { - b = nla_expr<T>::scalar(T(1)); - } - return b; + +inline const nex_var* to_var(const nex*a) { + SASSERT(a->is_var()); + return static_cast<const nex_var*>(a); } -template <typename T> -std::ostream& operator<<(std::ostream& out, const nla_expr<T>& e ) { + +inline const nex_mul* to_mul(const nex*a) { + SASSERT(a->is_mul()); + return static_cast<const nex_mul*>(a); +} + +inline nex_mul* to_mul(nex*a) { + SASSERT(a->is_mul()); + return static_cast<nex_mul*>(a); +} + +inline const nex_scalar * to_scalar(const nex* a) { + SASSERT(a->is_scalar()); + return static_cast<const nex_scalar*>(a); +} + +inline std::ostream& operator<<(std::ostream& out, const nex& e ) { return e.print(out); } diff --git a/src/test/lp/lp.cpp b/src/test/lp/lp.cpp index e6c65a2d1..a0ff48209 100644 --- a/src/test/lp/lp.cpp +++ b/src/test/lp/lp.cpp @@ -68,26 +68,33 @@ void test_basic_lemma_for_mon_zero_from_factors_to_monomial(); void test_basic_lemma_for_mon_neutral_from_monomial_to_factors(); void test_basic_lemma_for_mon_neutral_from_factors_to_monomial(); -void test_cn_on_expr(horner::nex t) { - TRACE("nla_cn", tout << "t=" << t << '\n';); - cross_nested cn(t, [](const horner::nex& n) { - TRACE("nla_cn_test", tout << n << "\n";); - return false; - } , - [](unsigned) { return false; }); - cn.run(); +void test_cn_on_expr(nex_sum *t, cross_nested& cn) { + TRACE("nla_cn", tout << "t=" << *t << '\n';); + cn.run(t); } void test_cn() { - typedef horner::nex nex; + cross_nested cn([](const nex* n) { + TRACE("nla_cn_test", tout << *n << "\n";); + return false; + } , + [](unsigned) { return false; }); enable_trace("nla_cn"); enable_trace("nla_cn_details"); - nex a = nex::var(0), b = nex::var(1), c = nex::var(2), d = nex::var(3), e = nex::var(4), f = nex::var(5), g = nex::var(6); - nex min_1 = nex::scalar(rational(-1)); + nex_var* a = cn.mk_var(0); + nex_var* b = cn.mk_var(1); + nex_var* c = cn.mk_var(2); + nex_var* d = cn.mk_var(3); + nex_var* e = cn.mk_var(4); + nex_var* f = cn.mk_var(5); + nex_var* g = cn.mk_var(6); + nex* min_1 = cn.mk_scalar(rational(-1)); // test_cn_on_expr(min_1*c*e + min_1*b*d + min_1*a*b + a*c); - TRACE("nla_cn", tout << "done\n";); - - test_cn_on_expr(b*c*d - b*c*g); + nex* bcd = cn.mk_mul(b, c, d); + nex_mul* bcg = cn.mk_mul(b, c, g); + bcg->add_child(min_1); + nex_sum* t = cn.mk_sum(bcd, bcg); + test_cn_on_expr(t, cn); // test_cn_on_expr(a*a*d + a*b*c*d + a*a*c*c*d + a*d*d + e*a*e + e*a*c + e*d); // TRACE("nla_cn", tout << "done\n";); // test_cn_on_expr(a*b*d + a*b*c + c*b*d + a*c*d);