diff --git a/src/math/lp/cross_nested.h b/src/math/lp/cross_nested.h index e82de1683..c5694955c 100644 --- a/src/math/lp/cross_nested.h +++ b/src/math/lp/cross_nested.h @@ -60,8 +60,7 @@ public: SASSERT(m_nex_creator.is_simplified(e)); m_e = e; #ifdef Z3DEBUG - // m_e_clone = clone(m_e); - // m_e_clone = normalize(m_e_clone); + m_e_clone = m_nex_creator.clone(m_e); #endif vector front; explore_expr_on_front_elem(&m_e, front); @@ -255,13 +254,9 @@ public: if(front.empty()) { TRACE("nla_cn", tout << "got the cn form: =" << *m_e << "\n";); m_done = m_call_on_result(m_e) || ++m_reported > 100; -// #ifdef Z3DEBUG -// nex *ce = clone(m_e); -// TRACE("nla_cn", tout << "ce = " << *ce << "\n";); -// nex *n = normalize(ce); -// TRACE("nla_cn", tout << "n = " << *n << "\nm_e_clone=" << * m_e_clone << "\n";); -// SASSERT(*n == *m_e_clone); -// #endif + #ifdef Z3DEBUG + SASSERT(nex_creator::equal(m_e, m_e_clone)); + #endif } else { nex** f = pop_front(front); explore_expr_on_front_elem(f, front); diff --git a/src/math/lp/nex.h b/src/math/lp/nex.h index 397242a62..f7979e6a8 100644 --- a/src/math/lp/nex.h +++ b/src/math/lp/nex.h @@ -410,6 +410,7 @@ inline std::unordered_set get_vars_of_expr(const nex *e ) { for ( lpvar j : get_vars_of_expr(c)) r.insert(j); } + return r; case expr_type::MUL: { for (auto &c: *to_mul(e)) diff --git a/src/math/lp/nex_creator.cpp b/src/math/lp/nex_creator.cpp index 1570cbde9..5e1fefe8f 100644 --- a/src/math/lp/nex_creator.cpp +++ b/src/math/lp/nex_creator.cpp @@ -65,6 +65,8 @@ bool nex_creator::eat_scalar_pow(nex_scalar *& r, nex_pow& p, unsigned pow) { if (!p.e()->is_scalar()) return false; nex_scalar *pe = to_scalar(p.e()); + if (pe->value().is_one()) + return true; // but do not change r here if (r == nullptr) { r = pe; r->value() = r->value().expt(p.pow()*pow); @@ -207,9 +209,7 @@ bool nex_creator::less_than_on_var_nex(const nex_var* a, const nex* b) const { case expr_type::SUM: { - nex_sum m; - m.add_child(const_cast(a)); - return lt(&m, to_sum(b)); + return !lt((*to_sum(b))[0], a); } default: UNREACHABLE(); @@ -243,7 +243,20 @@ bool nex_creator::less_than_on_mul_nex(const nex_mul* a, const nex* b) const { } } +bool nex_creator::less_than_on_sum_sum(const nex_sum* a, const nex_sum* b) const { + unsigned size = std::min(a->size(), b->size()); + for (unsigned j = 0; j < size; j++) { + if (lt((*a)[j], (*b)[j])) + return true; + if (lt((*b)[j], (*a)[j])) + return false; + } + return size < b->size(); + +} + bool nex_creator::lt(const nex* a, const nex* b) const { + TRACE("nla_cn_details", tout << *a << " ^ " << *b << "\n";); bool ret; switch (a->type()) { case expr_type::VAR: @@ -261,8 +274,9 @@ bool nex_creator::lt(const nex* a, const nex* b) const { break; } case expr_type::SUM: { - UNREACHABLE(); - return false; + if (b->is_sum()) + return less_than_on_sum_sum(to_sum(a), to_sum(b)); + return lt((*to_sum(a))[0], b); } default: UNREACHABLE(); @@ -479,9 +493,7 @@ void nex_creator::sort_join_sum(ptr_vector & children) { { return lt(a, b); }); std::unordered_set existing_nex; // handling (nex*) as numbers nex_scalar * common_scalar; - bool simplified = fill_join_map_for_sum(children, map, existing_nex, common_scalar); - if (!simplified) - return; + fill_join_map_for_sum(children, map, existing_nex, common_scalar); TRACE("nla_cn_details", for (auto & p : map ) { tout << "(" << *p.first << ", " << p.second << ") ";}); children.clear(); @@ -676,4 +688,78 @@ bool nex_creator::is_simplified(const nex *e) const return sum_is_simplified(to_sum(e)); return true; } + +#ifdef Z3DEBUG +unsigned nex_creator::find_sum_in_mul(const nex_mul* a) const { + for (unsigned j = 0; j < a->size(); j++) + if ((*a)[j].e()->is_sum()) + return j; + + return -1; +} +nex* nex_creator::canonize_mul(nex_mul *a) { + unsigned j = find_sum_in_mul(a); + if (j + 1 == 0) + return a; + nex_pow& np = (*a)[j]; + SASSERT(np.pow()); + unsigned power = np.pow(); + nex_sum * s = to_sum(np.e()); // s is going to explode + nex_sum * r = mk_sum(); + nex *sclone = power > 1? clone(s) : nullptr; + for (nex *e : *s) { + nex_mul *m = mk_mul(); + if (power > 1) + m->add_child_in_power(sclone, power - 1); + m->add_child(e); + for (unsigned k = 0; k < a->size(); k++) { + if (k == j) + continue; + m->add_child_in_power(clone((*a)[k].e()), (*a)[k].pow()); + } + r->add_child(m); + } + TRACE("nla_cn_details", tout << *r << "\n";); + return canonize(r); +} + + +nex* nex_creator::canonize(const nex *a) { + if (a->is_elementary()) + return clone(a); + + nex *t = simplify(clone(a)); + if (t->is_sum()) { + nex_sum * s = to_sum(t); + for (unsigned j = 0; j < s->size(); j++) { + (*s)[j] = canonize((*s)[j]); + } + t = simplify(s); + TRACE("nla_cn_details", tout << *t << "\n";); + return t; + } + return canonize_mul(to_mul(t)); +} + +bool nex_creator::equal(const nex* a, const nex* b) { + nex_creator cn; + unsigned n = 0; + for (lpvar j : get_vars_of_expr(a)) { + n = std::max(j + 1, n); + } + for (lpvar j : get_vars_of_expr(b)) { + n = std::max(j + 1, n); + } + cn.set_number_of_vars(n); + for (lpvar j = 0; j < n; j++) { + cn.set_var_weight(j, j); + } + nex * ca = cn.canonize(a); + nex * cb = cn.canonize(b); + TRACE("nla_cn_test", tout << "a = " << *a << ", canonized a = " << *ca << "\n";); + TRACE("nla_cn_test", tout << "b = " << *b << ", canonized b = " << *cb << "\n";); + return !(cn.lt(ca, cb) || cn.lt(cb, ca)); +} +#endif + } diff --git a/src/math/lp/nex_creator.h b/src/math/lp/nex_creator.h index 6fbf5e073..598bf5b26 100644 --- a/src/math/lp/nex_creator.h +++ b/src/math/lp/nex_creator.h @@ -54,7 +54,7 @@ class nex_creator { ptr_vector m_allocated; std::unordered_map m_occurences_map; std::unordered_map m_powers; - svector m_active_vars_weights; + svector m_active_vars_weights; public: static char ch(unsigned j) { @@ -64,9 +64,23 @@ public: return (char)('a'+j); } - svector& active_vars_weights() { return m_active_vars_weights;} - const svector& active_vars_weights() const { return m_active_vars_weights;} + // assuming that every lpvar is less than this number + void set_number_of_vars(unsigned k) { + m_active_vars_weights.resize(k); + } + unsigned get_number_of_vars() const { + return m_active_vars_weights.size(); + } + + + void set_var_weight(unsigned j, unsigned weight) { + m_active_vars_weights[j] = weight; + } +private: + svector& active_vars_weights() { return m_active_vars_weights;} + const svector& active_vars_weights() const { return m_active_vars_weights;} +public: nex* simplify(nex* e); bool less_than(lpvar j, lpvar k) const{ @@ -237,7 +251,15 @@ public: bool less_than_on_mul(const nex_mul* a, const nex_mul* b) const; bool less_than_on_var_nex(const nex_var* a, const nex* b) const; bool less_than_on_mul_nex(const nex_mul* a, const nex* b) const; + bool less_than_on_sum_sum(const nex_sum* a, const nex_sum* b) const; void fill_map_with_children(std::map & m, ptr_vector & children); void process_map_pair(nex *e, const rational& coeff, ptr_vector & children, std::unordered_set&); +#ifdef Z3DEBUG + static + bool equal(const nex*, const nex* ); + nex* canonize(const nex*); + nex* canonize_mul(nex_mul*); + unsigned find_sum_in_mul(const nex_mul* a) const; +#endif }; } diff --git a/src/math/lp/nla_grobner.cpp b/src/math/lp/nla_grobner.cpp index 3533865bd..d2df9c66c 100644 --- a/src/math/lp/nla_grobner.cpp +++ b/src/math/lp/nla_grobner.cpp @@ -102,9 +102,9 @@ var_weight nla_grobner::get_var_weight(lpvar j) const { } void nla_grobner::set_active_vars_weights() { - m_nex_creator.active_vars_weights().resize(c().m_lar_solver.column_count()); + m_nex_creator.set_number_of_vars(c().m_lar_solver.column_count()); for (lpvar j : m_active_vars) { - m_nex_creator.active_vars_weights()[j] = get_var_weight(j); + m_nex_creator.set_var_weight(j, static_cast(get_var_weight(j))); } } diff --git a/src/test/lp/lp.cpp b/src/test/lp/lp.cpp index aa00adc64..11177e619 100644 --- a/src/test/lp/lp.cpp +++ b/src/test/lp/lp.cpp @@ -108,27 +108,6 @@ bool mul_has_var_in_power(lpvar j, unsigned k, const nex_mul* e) { return false; } -bool has_var_in_power(lpvar j, unsigned k, const nex* e) { - TRACE("nla_cn", tout << "j = " << nex_creator::ch(j) << ", e = " << *e << ", k = " << k << "\n";); - if (k == 0) - return true; - if (e->is_scalar()) - return false; - if (e->is_var()) { - return k == 1 && to_var(e)->var() == j; - } - if (e->is_sum()) { - for (auto ee : *to_sum(e)) { - if (has_var_in_power(j, k, ee)) - return true; - } - return false; - } - if (e->is_mul()) { - return mul_has_var_in_power(j, k, to_mul(e)); - } -} - void test_simplify() { cross_nested cn( [](const nex* n) { @@ -144,9 +123,9 @@ void test_simplify() { enable_trace("nla_test"); nex_creator & r = cn.get_nex_creator(); - r.active_vars_weights().resize(3); - for (unsigned j = 0; j < r.active_vars_weights().size(); j++) - r.active_vars_weights()[j] = static_cast(5 - j); + r.set_number_of_vars(3); + for (unsigned j = 0; j < r.get_number_of_vars(); j++) + r.set_var_weight(j, j); nex_var* a = r.mk_var(0); nex_var* b = r.mk_var(1); nex_var* c = r.mk_var(2); @@ -199,24 +178,25 @@ void test_simplify() { } void test_cn_shorter() { + nex_sum *clone; cross_nested cn( [](const nex* n) { TRACE("nla_test", tout <<"cn form = " << *n << "\n"; - SASSERT(has_var_in_power(4, // stands for e - 2, n)); + ); return false; } , [](unsigned) { return false; }, []{ return 1; }); enable_trace("nla_test"); - // enable_trace("nla_cn"); - // enable_trace("nla_cn_details"); - // enable_trace("nla_test_details"); + enable_trace("nla_cn"); + enable_trace("nla_cn_test"); + enable_trace("nla_cn_details"); + enable_trace("nla_test_details"); auto & cr = cn.get_nex_creator(); - cr.active_vars_weights().resize(20); - for (unsigned j = 0; j < cr.active_vars_weights().size(); j++) - cr.active_vars_weights()[j] = static_cast(1); + cr.set_number_of_vars(20); + for (unsigned j = 0; j < cr.get_number_of_vars(); j++) + cr.set_var_weight(j,j); nex_var* a = cr.mk_var(0); nex_var* b = cr.mk_var(1); @@ -238,20 +218,11 @@ void test_cn_shorter() { nex* eac = cr.mk_mul(e, a, c); nex* ed = cr.mk_mul(e, d); nex* _6aad = cr.mk_mul(cr.mk_scalar(rational(6)), a, a, d); -#ifdef Z3DEBUG - nex * clone = cr.clone(cr.mk_sum(_6aad, abcd, eae, eac)); - clone = cr.simplify(clone); - SASSERT(cr.is_simplified(clone)); + clone = to_sum(cr.clone(cr.mk_sum(_6aad, abcd, eae, eac))); + clone = to_sum(cr.simplify(clone)); TRACE("nla_test", tout << "clone = " << *clone << "\n";); -#endif // test_cn_on_expr(cr.mk_sum(aad, abcd, aaccd, add, eae, eac, ed), cn); - test_cn_on_expr(to_sum(clone), cn); - // TRACE("nla_test", tout << "done\n";); - // test_cn_on_expr(a*b*d + a*b*c + c*b*d + a*c*d); - // TRACE("nla_test", tout << "done\n";); - // test_cn_on_expr(a*b*b*d*d + a*b*b*c*d + c*b*b*d); - // TRACE("nla_test", tout << "done\n";); - // test_cn_on_expr(a*b*d + a*b*c + c*b*d); + test_cn_on_expr(clone, cn); } void test_cn() { @@ -264,12 +235,13 @@ void test_cn() { [](unsigned) { return false; }, []{ return 1; }); enable_trace("nla_test"); + enable_trace("nla_cn_test"); // enable_trace("nla_cn"); // enable_trace("nla_test_details"); auto & cr = cn.get_nex_creator(); - cr.active_vars_weights().resize(20); - for (unsigned j = 0; j < cr.active_vars_weights().size(); j++) - cr.active_vars_weights()[j] = static_cast(1); + cr.set_number_of_vars(20); + for (unsigned j = 0; j < cr.get_number_of_vars(); j++) + cr.set_var_weight(j, j); nex_var* a = cr.mk_var(0); nex_var* b = cr.mk_var(1);