From 43294cea1637231ecc87cca2ce981af182b9b752 Mon Sep 17 00:00:00 2001 From: Lev Nachmanson Date: Wed, 2 Oct 2019 16:41:43 -0700 Subject: [PATCH] fix nex simplification Signed-off-by: Lev Nachmanson --- src/math/lp/nex_creator.cpp | 100 ++++++++++++++++++++++++------------ src/math/lp/nex_creator.h | 2 +- src/test/lp/lp.cpp | 7 ++- 3 files changed, 70 insertions(+), 39 deletions(-) diff --git a/src/math/lp/nex_creator.cpp b/src/math/lp/nex_creator.cpp index 8cf7fa2e9..3d729a41a 100644 --- a/src/math/lp/nex_creator.cpp +++ b/src/math/lp/nex_creator.cpp @@ -327,6 +327,8 @@ nex* nex_creator::simplify_sum(nex_sum *e) { } bool nex_creator::sum_is_simplified(const nex_sum* e) const { + TRACE("nla_cn_details", tout << ++ lp::lp_settings::ddd << std::endl;); + if (e->size() < 2) return false; for (nex * ee : e->children()) { if (ee->is_sum()) @@ -395,80 +397,110 @@ nex* nex_creator::create_child_from_nex_and_coeff(nex *e, } } -// returns true if new +// returns true if the key exists already bool nex_creator::register_in_join_map(std::map& map, nex* e, const rational& r) const{ TRACE("nla_cn_details", tout << *e << ", r = " << r << std::endl;); auto map_it = map.find(e); if (map_it == map.end()) { map[e] = r; - return true; + return false; } else { map_it->second += r; - return false; + return true; } } -void nex_creator::process_mul_in_simplify_sum(nex_mul* em, std::map &map, vector & tmp) { +// returns true if a simplificatian happens +bool nex_creator::process_mul_in_simplify_sum(nex_mul* em, std::map &map) { + bool found = false; auto it = em->children().begin(); if (it->e()->is_scalar()) { SASSERT(it->pow() == 1); rational r = to_scalar(it->e())->value(); auto end = em->children().end(); if (em->children().size() == 2 && em->children()[1].pow() == 1) { - register_in_join_map(map, em->children()[1].e(), r); + found = register_in_join_map(map, em->children()[1].e(), r); } else { - tmp.push_back(nex_mul()); - nex_mul * m = &tmp[tmp.size()-1]; + nex_mul * m = new nex_mul(); for (it++; it != end; it++) { m->add_child_in_power(it->e(), it->pow()); } - if (!register_in_join_map(map, m, r)) - tmp.pop_back(); + found = register_in_join_map(map, m, r); } } else { - register_in_join_map(map, em, rational(1)); + found = register_in_join_map(map, em, rational(1)); } + return found; } // a + 3bc + 2bc => a + 5bc void nex_creator::sort_join_sum(ptr_vector & children) { + TRACE("nla_cn_details", print_vector_of_ptrs(children, tout);); std::map map([this](const nex *a , const nex *b) { return lt(a, b); }); - TRACE("nla_cn_details", print_vector_of_ptrs(children, tout);); - vector tmp; - nex_scalar * s = nullptr; - for (auto e : children) { - if (e->is_mul()) { - process_mul_in_simplify_sum(to_mul(e), map, tmp); - } else if (e->is_scalar()) { + std::unordered_set existing_nex; // handling (nex*) as numbers + nex_scalar * common_scalar = nullptr; + bool simplified = false; + for (auto e : children) { + if (e->is_scalar()) { nex_scalar * es = to_scalar(e); - if (s == nullptr) - s = es; - else - s->value() += es->value(); + if (common_scalar == nullptr) { + common_scalar = es; + } else { + simplified = true; + common_scalar->value() += es->value(); + } + continue; } - else { - register_in_join_map(map, e, rational(1)); + existing_nex.insert(e); + if (e->is_mul()) { + simplified |= process_mul_in_simplify_sum(to_mul(e), map); + } else { + SASSERT(e->is_var()); + simplified |= register_in_join_map(map, e, rational(1)); } } - bool simplified; - for (auto& p : map) { - if (!p.second.is_one()) { - simplified = true; - break; - } - } if (!simplified) return; + TRACE("nla_cn_details", for (auto & p : map ) { tout << "(" << *p.first << ", " << p.second << ") ";}); children.clear(); - if (s) { - children.push_back(s); + if (common_scalar) { + children.push_back(common_scalar); } for (auto& p : map) { - if (p.second.is_zero() == false) - children.push_back(create_child_from_nex_and_coeff(p.first, p.second)); + nex *e = p.first; + const rational & coeff = p.second; + if (coeff.is_zero()) + continue; + bool e_is_old = existing_nex.find(e) != existing_nex.end(); + if (e_is_old) { + if (coeff.is_one()) { + children.push_back(e); + } else { + if (e->is_var()) { + children.push_back(mk_mul(mk_scalar(coeff), e)); + } else { + SASSERT(e->is_mul()); + nex* first = to_mul(e)->children()[0].e(); + if (first->is_scalar()) { + to_scalar(first)->value() = coeff; + children.push_back(e); + } else { + e = simplify(mk_mul(mk_scalar(coeff), e)); + children.push_back(e); + } + } + } + } else { // e is new + if (coeff.is_one()) { + m_allocated.push_back(e); + children.push_back(e); + } else { + children.push_back(simplify(mk_mul(mk_scalar(coeff), e))); + } + } } } diff --git a/src/math/lp/nex_creator.h b/src/math/lp/nex_creator.h index f5a89fef3..997457578 100644 --- a/src/math/lp/nex_creator.h +++ b/src/math/lp/nex_creator.h @@ -208,7 +208,7 @@ public: bool is_sorted(const nex_mul * e) const; nex* simplify_sum(nex_sum *e); - void process_mul_in_simplify_sum(nex_mul* e, std::map &, vector &); + bool process_mul_in_simplify_sum(nex_mul* e, std::map &); bool is_simplified(const nex *e) const; bool sum_is_simplified(const nex_sum* e) const; diff --git a/src/test/lp/lp.cpp b/src/test/lp/lp.cpp index 4a578a9ed..824e802da 100644 --- a/src/test/lp/lp.cpp +++ b/src/test/lp/lp.cpp @@ -83,9 +83,9 @@ void test_simplify() { [](unsigned) { return false; }, []() { return 1; } // for random ); - enable_trace("nla_cn"); - enable_trace("nla_cn_details"); - enable_trace("nla_cn_details_"); + // enable_trace("nla_cn"); + // enable_trace("nla_cn_details"); + // enable_trace("nla_cn_details_"); enable_trace("nla_test"); nex_creator & r = cn.get_nex_creator(); @@ -99,7 +99,6 @@ void test_simplify() { auto a_plus_bc = r.mk_sum(a, bc); auto simp_a_plus_bc = r.simplify(a_plus_bc); SASSERT(to_sum(simp_a_plus_bc)->size() > 1); - return; auto m = r.mk_mul(); m->add_child_in_power(c, 2); TRACE("nla_test_", tout << "m = " << *m << "\n";); auto n = r.mk_mul(a);