diff --git a/src/math/lp/cross_nested.h b/src/math/lp/cross_nested.h index 50d01a2ee..d09c92653 100644 --- a/src/math/lp/cross_nested.h +++ b/src/math/lp/cross_nested.h @@ -78,7 +78,7 @@ public: nex* extract_common_factor(nex* e) { nex_sum* c = to_sum(e); TRACE("nla_cn", tout << "c=" << *c << "\n"; tout << "occs:"; dump_occurences(tout, m_nex_creator.occurences_map()) << "\n";); - unsigned size = c->children().size(); + unsigned size = c->size(); bool have_factor = false; for(const auto & p : m_nex_creator.occurences_map()) { if (p.second.m_occs == size) { @@ -131,7 +131,7 @@ public: nex_mul* cm; *c = cm = m_nex_creator.mk_mul(f, c_over_f); TRACE("nla_cn", tout << "common factor=" << *f << ", c=" << **c << "\ne = " << *m_e << "\n";); - explore_expr_on_front_elem(cm->children()[1].ee(), front); + explore_expr_on_front_elem((*cm)[1].ee(), front); return true; } @@ -193,7 +193,7 @@ public: void calc_occurences(nex_sum* e) { clear_maps(); - for (const auto * ce : e->children()) { + for (const auto * ce : *e) { if (ce->is_mul()) { to_mul(ce)->get_powers_from_mul(m_nex_creator.powers()); update_occurences_with_powers(); @@ -338,7 +338,7 @@ public: // The result is sorted by large number of occurences first vector> get_mult_occurences(const nex_sum* e) { clear_maps(); - for (const auto * ce : e->children()) { + for (const auto * ce : *e) { if (ce->is_mul()) { to_mul(ce)->get_powers_from_mul(m_nex_creator.powers()); update_occurences_with_powers(); @@ -375,7 +375,7 @@ public: TRACE("nla_cn_details", tout << "e = " << * e << ", j = " << m_nex_creator.ch(j) << std::endl;); a = m_nex_creator.mk_sum(); m_b_split_vec.clear(); - for (nex * ce: e->children()) { + for (nex * ce: *e) { if (is_divisible_by_var(ce, j)) { a->add_child(m_nex_creator.mk_div(ce , j)); } else { @@ -385,7 +385,7 @@ public: } } TRACE("nla_cn_details", tout << "a = " << *a << "\n";); - SASSERT(a->children().size() >= 2 && m_b_split_vec.size()); + SASSERT(a->size() >= 2 && m_b_split_vec.size()); a = to_sum(m_nex_creator.simplify_sum(a)); if (m_b_split_vec.size() == 1) { @@ -402,12 +402,12 @@ public: TRACE("nla_cn_details", tout << "b = " << *b << "\n";); e = m_nex_creator.mk_sum(m_nex_creator.mk_mul(m_nex_creator.mk_var(j), a), b); // e = j*a + b if (!a->is_linear()) { - nex **ptr_to_a = (to_mul(to_sum(e)->children()[0]))->children()[1].ee(); + nex **ptr_to_a = ((*to_mul((*to_sum(e))[0])))[1].ee(); push_to_front(front, ptr_to_a); } if (b->is_sum() && !to_sum(b)->is_linear()) { - nex **ptr_to_a = &(to_sum(e)->children()[1]); + nex **ptr_to_a = &((*to_sum(e))[1]); push_to_front(front, ptr_to_a); } } @@ -416,7 +416,7 @@ public: if (b == nullptr) { e = m_nex_creator.mk_mul(m_nex_creator.mk_var(j), a); if (!to_sum(a)->is_linear()) - push_to_front(front, to_mul(e)->children()[1].ee()); + push_to_front(front, (*to_mul(e))[1].ee()); } else { update_front_with_split_with_non_empty_b(e, j, front, a, b); } @@ -442,33 +442,6 @@ public: return true; } - static std::unordered_set get_vars_of_expr(const nex *e ) { - std::unordered_set r; - switch (e->type()) { - case expr_type::SCALAR: - return r; - case expr_type::SUM: - { - for (auto c: to_sum(e)->children()) - for ( lpvar j : get_vars_of_expr(c)) - r.insert(j); - } - case expr_type::MUL: - { - for (auto &c: to_mul(e)->children()) - for ( lpvar j : get_vars_of_expr(c.e())) - r.insert(j); - } - return r; - case expr_type::VAR: - r.insert(to_var(e)->var()); - return r; - default: - TRACE("nla_cn_details", tout << e->type() << "\n";); - SASSERT(false); - return r; - } - } ~cross_nested() { m_nex_creator.clear(); diff --git a/src/math/lp/nex.h b/src/math/lp/nex.h index 881d2d71b..397242a62 100644 --- a/src/math/lp/nex.h +++ b/src/math/lp/nex.h @@ -398,5 +398,33 @@ inline bool less_than_nex_standard(const nex* a, const nex* b) { lt_on_vars lt = [](lpvar j, lpvar k) { return j < k; }; return less_than_nex(a, b, lt); } + +inline std::unordered_set get_vars_of_expr(const nex *e ) { + std::unordered_set r; + switch (e->type()) { + case expr_type::SCALAR: + return r; + case expr_type::SUM: + { + for (auto c: *to_sum(e)) + for ( lpvar j : get_vars_of_expr(c)) + r.insert(j); + } + case expr_type::MUL: + { + for (auto &c: *to_mul(e)) + for ( lpvar j : get_vars_of_expr(c.e())) + r.insert(j); + } + return r; + case expr_type::VAR: + r.insert(to_var(e)->var()); + return r; + default: + TRACE("nla_cn_details", tout << e->type() << "\n";); + SASSERT(false); + return r; + } + } } diff --git a/src/math/lp/nex_creator.cpp b/src/math/lp/nex_creator.cpp index 078ab6b6f..5916a9dae 100644 --- a/src/math/lp/nex_creator.cpp +++ b/src/math/lp/nex_creator.cpp @@ -329,13 +329,22 @@ nex* nex_creator::simplify_sum(nex_sum *e) { } bool nex_creator::sum_is_simplified(const nex_sum* e) const { - TRACE("nla_cn_details", tout << ++ lp::lp_settings::ddd << std::endl;); if (e->size() < 2) return false; + bool scalar = false; for (nex * ee : *e) { if (ee->is_sum()) return false; - if (ee->is_scalar() && to_scalar(ee)->value().is_zero()) + if (ee->is_scalar()) { + if (scalar) { + return false; + } + if (to_scalar(ee)->value().is_zero()) { + return false; + } + scalar = true; + } + if (!is_simplified(ee)) return false; } return true; @@ -550,7 +559,7 @@ nex * nex_creator::mk_div_by_mul(const nex* a, const nex_mul* b) { return mk_div_sum_by_mul(to_sum(a), b); } if (a->is_var() || (a->is_mul() && to_mul(a)->size() == 1)) { - SASSERT(b->get_degree() == 1 && !b->has_a_coeff() && b->contains(to_var(a)->var())); + SASSERT(b->get_degree() == 1 && !b->has_a_coeff() && get_vars_of_expr(a) == get_vars_of_expr(b)); return mk_scalar(rational(1)); } const nex_mul* am = to_mul(a); diff --git a/src/test/lp/lp.cpp b/src/test/lp/lp.cpp index 50794e1cd..5fc25ebf7 100644 --- a/src/test/lp/lp.cpp +++ b/src/test/lp/lp.cpp @@ -69,7 +69,8 @@ void test_basic_lemma_for_mon_neutral_from_monomial_to_factors(); void test_basic_lemma_for_mon_neutral_from_factors_to_monomial(); void test_cn_on_expr(nex_sum *t, cross_nested& cn) { - TRACE("nla_cn", tout << "t=" << *t << '\n';); + t = to_sum(cn.get_nex_creator().simplify(t)); + TRACE("nla_test", tout << "t=" << *t << '\n';); cn.run(t); } @@ -146,13 +147,14 @@ void test_simplify() { void test_cn() { cross_nested cn( [](const nex* n) { - TRACE("nla_test", tout << *n << "\n";); - return false; - } , + TRACE("nla_test", tout <<"cn form = " << *n << "\n";); + return false; + } , [](unsigned) { return false; }, []{ return 1; }); enable_trace("nla_test"); - enable_trace("nla_test_details"); + // enable_trace("nla_cn"); + // enable_trace("nla_test_details"); auto & cr = cn.get_nex_creator(); cr.active_vars_weights().resize(20); for (unsigned j = 0; j < cr.active_vars_weights().size(); j++) @@ -164,6 +166,10 @@ void test_cn() { nex_var* d = cr.mk_var(3); nex_var* e = cr.mk_var(4); nex_var* g = cr.mk_var(6); + nex_sum * a_p_ae_sq = cr.mk_sum(a, cr.mk_mul(a, e, e)); + a_p_ae_sq = to_sum(cr.simplify(a_p_ae_sq)); + test_cn_on_expr(a_p_ae_sq, cn); + nex* min_1 = cr.mk_scalar(rational(-1)); // test_cn_on_expr(min_1*c*e + min_1*b*d + min_1*a*b + a*c); nex* bcd = cr.mk_mul(b, c, d); @@ -171,7 +177,6 @@ void test_cn() { bcg->add_child(min_1); nex_sum* t = cr.mk_sum(bcd, bcg); test_cn_on_expr(t, cn); - nex* aad = cr.mk_mul(a, a, d); nex* abcd = cr.mk_mul(a, b, c, d); nex* aaccd = cr.mk_mul(a, a, c, c, d); nex* add = cr.mk_mul(a, d, d);