diff --git a/src/math/lp/cross_nested.h b/src/math/lp/cross_nested.h index 6f8139ac9..4dfd1034f 100644 --- a/src/math/lp/cross_nested.h +++ b/src/math/lp/cross_nested.h @@ -35,7 +35,6 @@ class cross_nested { int m_reported; bool m_random_bit; nex_creator m_nex_creator; - const lt_on_vars& m_lt; std::function m_mk_scalar; #ifdef Z3DEBUG nex* m_e_clone; @@ -46,8 +45,7 @@ public: cross_nested(std::function call_on_result, std::function var_is_fixed, - std::function random, - lt_on_vars lt): + std::function random) : m_call_on_result(call_on_result), m_var_is_fixed(var_is_fixed), m_random(random), @@ -59,7 +57,7 @@ public: void run(nex *e) { TRACE("nla_cn", tout << *e << "\n";); - SASSERT(e->is_simplified(m_lt)); + SASSERT(m_nex_creator.is_simplified(e)); m_e = e; #ifdef Z3DEBUG // m_e_clone = clone(m_e); @@ -128,7 +126,7 @@ public: } nex* c_over_f = m_nex_creator.mk_div(*c, f); - to_sum(c_over_f)->simplify(&c_over_f, m_lt, m_mk_scalar); + c_over_f = m_nex_creator.simplify(c_over_f); nex_mul* cm; *c = cm = m_nex_creator.mk_mul(f, c_over_f); TRACE("nla_cn", tout << "common factor=" << *f << ", c=" << **c << "\ne = " << *m_e << "\n";); @@ -392,8 +390,7 @@ public: } TRACE("nla_cn_details", tout << "a = " << *a << "\n";); SASSERT(a->children().size() >= 2 && m_b_split_vec.size()); - nex* f; - a->simplify(&f, m_lt, m_mk_scalar); + a = to_sum(m_nex_creator.simplify_sum(a)); if (m_b_split_vec.size() == 1) { b = m_b_split_vec[0]; @@ -484,12 +481,8 @@ public: bool done() const { return m_done; } #if Z3DEBUG nex * normalize_sum(nex_sum* a) { - for (unsigned j = 0; j < a->size(); j ++) { - a->children()[j] = normalize(a->children()[j]); - } - nex *r; - a->simplify(&r, m_lt, m_mk_scalar); - return r; + NOT_IMPLEMENTED_YET(); + return nullptr; } nex * normalize_mul(nex_mul* a) { diff --git a/src/math/lp/horner.cpp b/src/math/lp/horner.cpp index f72206a9c..bab09e9b2 100644 --- a/src/math/lp/horner.cpp +++ b/src/math/lp/horner.cpp @@ -92,8 +92,7 @@ bool horner::lemmas_on_row(const T& row) { cross_nested cn( [this](const nex* n) { return check_cross_nested_expr(n); }, [this](unsigned j) { return c().var_is_fixed(j); }, - [this]() { return c().random(); }, - [](lpvar j, lpvar k) { return j < k;}); // todo : consider using weights here - the same way they are used in Grobner basis + [this]() { return c().random(); }); SASSERT (row_is_interesting(row)); create_sum_from_row(row, cn.get_nex_creator(), m_row_sum); diff --git a/src/math/lp/nex_creator.cpp b/src/math/lp/nex_creator.cpp index 9779a0644..fd5040b24 100644 --- a/src/math/lp/nex_creator.cpp +++ b/src/math/lp/nex_creator.cpp @@ -158,6 +158,36 @@ bool nex_creator::is_sorted(const nex_mul* e) const { return true; } +bool nex_creator::less_than_nex(const nex* a, const nex* b) const { + int r = (int)(a->type()) - (int)(b->type()); + if (r) { + return r < 0; + } + SASSERT(a->type() == b->type()); + switch (a->type()) { + case expr_type::VAR: { + return less_than(to_var(a)->var() , to_var(b)->var()); + } + case expr_type::SCALAR: { + return to_scalar(a)->value() < to_scalar(b)->value(); + } + case expr_type::MUL: { + NOT_IMPLEMENTED_YET(); + return false; // to_mul(a)->children() < to_mul(b)->children(); + } + case expr_type::SUM: { + NOT_IMPLEMENTED_YET(); + return false; //to_sum(a)->children() < to_sum(b)->children(); + } + default: + SASSERT(false); + return false; + } + + return false; +} + + bool nex_creator::mul_is_simplified(const nex_mul* e) const { if (size() == 1 && e->children().begin()->pow() == 1) return false; @@ -282,11 +312,20 @@ void nex_creator::sort_join_sum(ptr_vector & children) { rational nex_creator::extract_coeff_from_mul(const nex_mul* m) { const nex* e = m->children().begin()->e(); - if (e->is_scalar()) + if (e->is_scalar()) { + SASSERT(m->children().begin()->pow() == 1); return to_scalar(e)->value(); + } return rational(1); } +rational nex_creator::extract_coeff(const nex* m) { + if (!m->is_mul()) + return rational(1); + return extract_coeff_from_mul(to_mul(m)); +} + + void nex_creator::fill_map_with_children(std::map & m, ptr_vector & children) { nex_scalar * scalar = nullptr; TRACE("nla_cn_details", print_vector_of_ptrs(children, tout);); @@ -347,4 +386,88 @@ void nex_creator::simplify_children_of_sum(ptr_vector & children) { sort_join_sum(children); } +bool all_factors_are_elementary(const nex_mul* a) { + for (auto & p : a->children()) + if (!p.e()->is_elementary()) + return false; + + return true; +} + +bool have_no_scalars(const nex_mul* a) { + for (auto & p : a->children()) + if (p.e()->is_scalar() && !to_scalar(p.e())->value().is_one()) + return false; + + return true; +} + + +nex * nex_creator::mk_div_by_mul(const nex* a, const nex_mul* b) { + if (a->is_sum()) { + nex_sum * r = mk_sum(); + const nex_sum * m = to_sum(a); + for (auto e : m->children()) { + r->add_child(mk_div_by_mul(e, b)); + } + TRACE("nla_cn_details", tout << *r << "\n";); + return r; + } + if (a->is_var() || (a->is_mul() && to_mul(a)->children().size() == 1)) { + return mk_scalar(rational(1)); + } + const nex_mul* am = to_mul(a); + SASSERT(all_factors_are_elementary(am) && all_factors_are_elementary(b) && have_no_scalars(b)); + b->get_powers_from_mul(m_powers); + nex_mul* ret = new nex_mul(); + for (auto& p : am->children()) { + TRACE("nla_cn_details", tout << "p = " << p << "\n";); + const nex* e = p.e(); + if (!e->is_var()) { + SASSERT(e->is_scalar()); + ret->add_child_in_power(clone(e), p.pow()); + TRACE("nla_cn_details", tout << "processed scalar\n";); + continue; + } + SASSERT(e->is_var()); + lpvar j = to_var(e)->var(); + auto it = m_powers.find(j); + if (it == m_powers.end()) { + ret->add_child_in_power(clone(e), p.pow()); + } else { + unsigned pw = p.pow(); + SASSERT(pw); + while (pw--) { + SASSERT(it->second); + it->second --; + if (it->second == 0) { + m_powers.erase(it); + break; + } + } + if (pw) { + ret->add_child_in_power(clone(e), pw); + } + } + TRACE("nla_cn_details", tout << *ret << "\n";); + } + SASSERT(m_powers.size() == 0); + if (ret->children().size() == 0) { + delete ret; + TRACE("nla_cn_details", tout << "return 1\n";); + return mk_scalar(rational(1)); + } + add_to_allocated(ret); + TRACE("nla_cn_details", tout << *ret << "\n";); + return ret; +} + +nex * nex_creator::mk_div(const nex* a, const nex* b) { + TRACE("nla_cn_details", tout << *a <<" / " << *b << "\n";); + if (b->is_var()) { + return mk_div(a, to_var(b)->var()); + } + return mk_div_by_mul(a, to_mul(b)); +} + } diff --git a/src/math/lp/nex_creator.h b/src/math/lp/nex_creator.h index b2e9adec3..61d6d617f 100644 --- a/src/math/lp/nex_creator.h +++ b/src/math/lp/nex_creator.h @@ -57,14 +57,15 @@ class nex_creator { svector m_active_vars_weights; public: + svector& active_vars_weights() { return m_active_vars_weights;} + const svector& active_vars_weights() const { return m_active_vars_weights;} nex* simplify(nex* e) { NOT_IMPLEMENTED_YET(); } rational extract_coeff_from_mul(const nex_mul* m); - - rational extract_coeff(const nex* ); + rational extract_coeff(const nex*); bool is_simplified(const nex *e) { NOT_IMPLEMENTED_YET(); @@ -205,6 +206,7 @@ public: nex * mk_div(const nex* a, lpvar j); nex * mk_div(const nex* a, const nex* b); + nex * mk_div_by_mul(const nex* a, const nex_mul* b); nex * simplify_mul(nex_mul *e); bool is_sorted(const nex_mul * e) const; @@ -226,36 +228,8 @@ public: bool eat_scalar_pow(nex_scalar *& r, nex_pow& p); void simplify_children_of_mul(vector & children, lt_on_vars lt, std::function mk_scalar); -bool sum_simplify_lt(const nex* a, const nex* b); + bool sum_simplify_lt(const nex* a, const nex* b); -bool less_than_nex(const nex* a, const nex* b, const lt_on_vars& lt) { - int r = (int)(a->type()) - (int)(b->type()); - if (r) { - return r < 0; - } - SASSERT(a->type() == b->type()); - switch (a->type()) { - case expr_type::VAR: { - return lt(to_var(a)->var() , to_var(b)->var()); - } - case expr_type::SCALAR: { - return to_scalar(a)->value() < to_scalar(b)->value(); - } - case expr_type::MUL: { - NOT_IMPLEMENTED_YET(); - return false; // to_mul(a)->children() < to_mul(b)->children(); - } - case expr_type::SUM: { - NOT_IMPLEMENTED_YET(); - return false; //to_sum(a)->children() < to_sum(b)->children(); - } - default: - SASSERT(false); - return false; - } - - return false; -} bool mul_simplify_lt(const nex_mul* a, const nex_mul* b); void fill_map_with_children(std::map & m, ptr_vector & children); }; diff --git a/src/math/lp/nla_grobner.cpp b/src/math/lp/nla_grobner.cpp index a9b62bfe9..3533865bd 100644 --- a/src/math/lp/nla_grobner.cpp +++ b/src/math/lp/nla_grobner.cpp @@ -25,12 +25,7 @@ nla_grobner::nla_grobner(core *c ) : common(c), m_nl_gb_exhausted(false), - m_dep_manager(m_val_manager, m_alloc), - m_nex_creator([this](lpvar a, lpvar b) { - if (m_active_vars_weights[a] != m_active_vars_weights[b]) - return m_active_vars_weights[a] < m_active_vars_weights[b]; - return a < b; - }) {} + m_dep_manager(m_val_manager, m_alloc) {} // Scan the grobner basis eqs for equations of the form x - k = 0 or x = 0 is found, and x is not fixed, // then assert bounds for x, and continue @@ -107,9 +102,9 @@ var_weight nla_grobner::get_var_weight(lpvar j) const { } void nla_grobner::set_active_vars_weights() { - m_active_vars_weights.resize(c().m_lar_solver.column_count()); + m_nex_creator.active_vars_weights().resize(c().m_lar_solver.column_count()); for (lpvar j : m_active_vars) { - m_active_vars_weights[j] = get_var_weight(j); + m_nex_creator.active_vars_weights()[j] = get_var_weight(j); } } diff --git a/src/test/lp/lp.cpp b/src/test/lp/lp.cpp index 21e419eeb..cf31713b7 100644 --- a/src/test/lp/lp.cpp +++ b/src/test/lp/lp.cpp @@ -73,7 +73,6 @@ void test_cn_on_expr(nex_sum *t, cross_nested& cn) { cn.run(t); } -lt_on_vars lpvar_lt() { return [](lpvar a, lpvar b) { return a < b; };} void test_simplify() { cross_nested cn( @@ -82,7 +81,8 @@ void test_simplify() { return false; } , [](unsigned) { return false; }, - []{ return 1; }, lpvar_lt()); + []() { return 1; } // for random + ); enable_trace("nla_cn"); enable_trace("nla_cn_details"); nex_creator & r = cn.get_nex_creator(); @@ -103,12 +103,11 @@ void test_simplify() { nex * e = r.mk_sum(a, r.mk_sum(b, m)); TRACE("nla_cn", tout << "e = " << *e << "\n";); - std::function mks = [&r] {return r.mk_scalar(rational(1)); }; - e->simplify(&e, lpvar_lt(), mks); + e = r.simplify(e); TRACE("nla_cn", tout << "simplified e = " << *e << "\n";); nex * l = r.mk_sum(e, r.mk_mul(r.mk_scalar(rational(3)), r.clone(e))); TRACE("nla_cn", tout << "sum l = " << *l << "\n";); - l->simplify(&l, lpvar_lt(), mks); + l = r.simplify(l); TRACE("nla_cn", tout << "simplified sum l = " << *l << "\n";); } @@ -119,7 +118,7 @@ void test_cn() { return false; } , [](unsigned) { return false; }, - []{ return 1; }, lpvar_lt()); + []{ return 1; }); enable_trace("nla_cn"); enable_trace("nla_cn_details"); nex_var* a = cn.get_nex_creator().mk_var(0); @@ -145,8 +144,8 @@ void test_cn() { nex* _6aad = cn.get_nex_creator().mk_mul(cn.get_nex_creator().mk_scalar(rational(6)), a, a, d); #ifdef Z3DEBUG nex * clone = cn.get_nex_creator().clone(cn.get_nex_creator().mk_sum(_6aad, abcd, aaccd, add, eae, eac, ed)); - clone->simplify(&clone, lpvar_lt(), [&cn] {return cn.get_nex_creator().mk_scalar(rational(1));}); - SASSERT(clone->is_simplified(lpvar_lt())); + clone = cn.get_nex_creator().simplify(clone); + SASSERT(cn.get_nex_creator().is_simplified(clone)); TRACE("nla_cn", tout << "clone = " << *clone << "\n";); #endif // test_cn_on_expr(cn.get_nex_creator().mk_sum(aad, abcd, aaccd, add, eae, eac, ed), cn);