3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-13 12:28:44 +00:00

implement canonization of nex expressions

Signed-off-by: Lev Nachmanson <levnach@hotmail.com>
This commit is contained in:
Lev Nachmanson 2019-10-04 16:28:24 -07:00
parent 08de9ecbd1
commit 3e009a237f
6 changed files with 145 additions and 69 deletions

View file

@ -60,8 +60,7 @@ public:
SASSERT(m_nex_creator.is_simplified(e));
m_e = e;
#ifdef Z3DEBUG
// m_e_clone = clone(m_e);
// m_e_clone = normalize(m_e_clone);
m_e_clone = m_nex_creator.clone(m_e);
#endif
vector<nex**> front;
explore_expr_on_front_elem(&m_e, front);
@ -255,13 +254,9 @@ public:
if(front.empty()) {
TRACE("nla_cn", tout << "got the cn form: =" << *m_e << "\n";);
m_done = m_call_on_result(m_e) || ++m_reported > 100;
// #ifdef Z3DEBUG
// nex *ce = clone(m_e);
// TRACE("nla_cn", tout << "ce = " << *ce << "\n";);
// nex *n = normalize(ce);
// TRACE("nla_cn", tout << "n = " << *n << "\nm_e_clone=" << * m_e_clone << "\n";);
// SASSERT(*n == *m_e_clone);
// #endif
#ifdef Z3DEBUG
SASSERT(nex_creator::equal(m_e, m_e_clone));
#endif
} else {
nex** f = pop_front(front);
explore_expr_on_front_elem(f, front);

View file

@ -410,6 +410,7 @@ inline std::unordered_set<lpvar> get_vars_of_expr(const nex *e ) {
for ( lpvar j : get_vars_of_expr(c))
r.insert(j);
}
return r;
case expr_type::MUL:
{
for (auto &c: *to_mul(e))

View file

@ -65,6 +65,8 @@ bool nex_creator::eat_scalar_pow(nex_scalar *& r, nex_pow& p, unsigned pow) {
if (!p.e()->is_scalar())
return false;
nex_scalar *pe = to_scalar(p.e());
if (pe->value().is_one())
return true; // but do not change r here
if (r == nullptr) {
r = pe;
r->value() = r->value().expt(p.pow()*pow);
@ -207,9 +209,7 @@ bool nex_creator::less_than_on_var_nex(const nex_var* a, const nex* b) const {
case expr_type::SUM:
{
nex_sum m;
m.add_child(const_cast<nex_var*>(a));
return lt(&m, to_sum(b));
return !lt((*to_sum(b))[0], a);
}
default:
UNREACHABLE();
@ -243,7 +243,20 @@ bool nex_creator::less_than_on_mul_nex(const nex_mul* a, const nex* b) const {
}
}
bool nex_creator::less_than_on_sum_sum(const nex_sum* a, const nex_sum* b) const {
unsigned size = std::min(a->size(), b->size());
for (unsigned j = 0; j < size; j++) {
if (lt((*a)[j], (*b)[j]))
return true;
if (lt((*b)[j], (*a)[j]))
return false;
}
return size < b->size();
}
bool nex_creator::lt(const nex* a, const nex* b) const {
TRACE("nla_cn_details", tout << *a << " ^ " << *b << "\n";);
bool ret;
switch (a->type()) {
case expr_type::VAR:
@ -261,8 +274,9 @@ bool nex_creator::lt(const nex* a, const nex* b) const {
break;
}
case expr_type::SUM: {
UNREACHABLE();
return false;
if (b->is_sum())
return less_than_on_sum_sum(to_sum(a), to_sum(b));
return lt((*to_sum(a))[0], b);
}
default:
UNREACHABLE();
@ -479,9 +493,7 @@ void nex_creator::sort_join_sum(ptr_vector<nex> & children) {
{ return lt(a, b); });
std::unordered_set<nex*> existing_nex; // handling (nex*) as numbers
nex_scalar * common_scalar;
bool simplified = fill_join_map_for_sum(children, map, existing_nex, common_scalar);
if (!simplified)
return;
fill_join_map_for_sum(children, map, existing_nex, common_scalar);
TRACE("nla_cn_details", for (auto & p : map ) { tout << "(" << *p.first << ", " << p.second << ") ";});
children.clear();
@ -676,4 +688,78 @@ bool nex_creator::is_simplified(const nex *e) const
return sum_is_simplified(to_sum(e));
return true;
}
#ifdef Z3DEBUG
unsigned nex_creator::find_sum_in_mul(const nex_mul* a) const {
for (unsigned j = 0; j < a->size(); j++)
if ((*a)[j].e()->is_sum())
return j;
return -1;
}
nex* nex_creator::canonize_mul(nex_mul *a) {
unsigned j = find_sum_in_mul(a);
if (j + 1 == 0)
return a;
nex_pow& np = (*a)[j];
SASSERT(np.pow());
unsigned power = np.pow();
nex_sum * s = to_sum(np.e()); // s is going to explode
nex_sum * r = mk_sum();
nex *sclone = power > 1? clone(s) : nullptr;
for (nex *e : *s) {
nex_mul *m = mk_mul();
if (power > 1)
m->add_child_in_power(sclone, power - 1);
m->add_child(e);
for (unsigned k = 0; k < a->size(); k++) {
if (k == j)
continue;
m->add_child_in_power(clone((*a)[k].e()), (*a)[k].pow());
}
r->add_child(m);
}
TRACE("nla_cn_details", tout << *r << "\n";);
return canonize(r);
}
nex* nex_creator::canonize(const nex *a) {
if (a->is_elementary())
return clone(a);
nex *t = simplify(clone(a));
if (t->is_sum()) {
nex_sum * s = to_sum(t);
for (unsigned j = 0; j < s->size(); j++) {
(*s)[j] = canonize((*s)[j]);
}
t = simplify(s);
TRACE("nla_cn_details", tout << *t << "\n";);
return t;
}
return canonize_mul(to_mul(t));
}
bool nex_creator::equal(const nex* a, const nex* b) {
nex_creator cn;
unsigned n = 0;
for (lpvar j : get_vars_of_expr(a)) {
n = std::max(j + 1, n);
}
for (lpvar j : get_vars_of_expr(b)) {
n = std::max(j + 1, n);
}
cn.set_number_of_vars(n);
for (lpvar j = 0; j < n; j++) {
cn.set_var_weight(j, j);
}
nex * ca = cn.canonize(a);
nex * cb = cn.canonize(b);
TRACE("nla_cn_test", tout << "a = " << *a << ", canonized a = " << *ca << "\n";);
TRACE("nla_cn_test", tout << "b = " << *b << ", canonized b = " << *cb << "\n";);
return !(cn.lt(ca, cb) || cn.lt(cb, ca));
}
#endif
}

View file

@ -54,7 +54,7 @@ class nex_creator {
ptr_vector<nex> m_allocated;
std::unordered_map<lpvar, occ> m_occurences_map;
std::unordered_map<lpvar, unsigned> m_powers;
svector<var_weight> m_active_vars_weights;
svector<unsigned> m_active_vars_weights;
public:
static char ch(unsigned j) {
@ -64,9 +64,23 @@ public:
return (char)('a'+j);
}
svector<var_weight>& active_vars_weights() { return m_active_vars_weights;}
const svector<var_weight>& active_vars_weights() const { return m_active_vars_weights;}
// assuming that every lpvar is less than this number
void set_number_of_vars(unsigned k) {
m_active_vars_weights.resize(k);
}
unsigned get_number_of_vars() const {
return m_active_vars_weights.size();
}
void set_var_weight(unsigned j, unsigned weight) {
m_active_vars_weights[j] = weight;
}
private:
svector<unsigned>& active_vars_weights() { return m_active_vars_weights;}
const svector<unsigned>& active_vars_weights() const { return m_active_vars_weights;}
public:
nex* simplify(nex* e);
bool less_than(lpvar j, lpvar k) const{
@ -237,7 +251,15 @@ public:
bool less_than_on_mul(const nex_mul* a, const nex_mul* b) const;
bool less_than_on_var_nex(const nex_var* a, const nex* b) const;
bool less_than_on_mul_nex(const nex_mul* a, const nex* b) const;
bool less_than_on_sum_sum(const nex_sum* a, const nex_sum* b) const;
void fill_map_with_children(std::map<nex*, rational, nex_lt> & m, ptr_vector<nex> & children);
void process_map_pair(nex *e, const rational& coeff, ptr_vector<nex> & children, std::unordered_set<nex*>&);
#ifdef Z3DEBUG
static
bool equal(const nex*, const nex* );
nex* canonize(const nex*);
nex* canonize_mul(nex_mul*);
unsigned find_sum_in_mul(const nex_mul* a) const;
#endif
};
}

View file

@ -102,9 +102,9 @@ var_weight nla_grobner::get_var_weight(lpvar j) const {
}
void nla_grobner::set_active_vars_weights() {
m_nex_creator.active_vars_weights().resize(c().m_lar_solver.column_count());
m_nex_creator.set_number_of_vars(c().m_lar_solver.column_count());
for (lpvar j : m_active_vars) {
m_nex_creator.active_vars_weights()[j] = get_var_weight(j);
m_nex_creator.set_var_weight(j, static_cast<unsigned>(get_var_weight(j)));
}
}

View file

@ -108,27 +108,6 @@ bool mul_has_var_in_power(lpvar j, unsigned k, const nex_mul* e) {
return false;
}
bool has_var_in_power(lpvar j, unsigned k, const nex* e) {
TRACE("nla_cn", tout << "j = " << nex_creator::ch(j) << ", e = " << *e << ", k = " << k << "\n";);
if (k == 0)
return true;
if (e->is_scalar())
return false;
if (e->is_var()) {
return k == 1 && to_var(e)->var() == j;
}
if (e->is_sum()) {
for (auto ee : *to_sum(e)) {
if (has_var_in_power(j, k, ee))
return true;
}
return false;
}
if (e->is_mul()) {
return mul_has_var_in_power(j, k, to_mul(e));
}
}
void test_simplify() {
cross_nested cn(
[](const nex* n) {
@ -144,9 +123,9 @@ void test_simplify() {
enable_trace("nla_test");
nex_creator & r = cn.get_nex_creator();
r.active_vars_weights().resize(3);
for (unsigned j = 0; j < r.active_vars_weights().size(); j++)
r.active_vars_weights()[j] = static_cast<var_weight>(5 - j);
r.set_number_of_vars(3);
for (unsigned j = 0; j < r.get_number_of_vars(); j++)
r.set_var_weight(j, j);
nex_var* a = r.mk_var(0);
nex_var* b = r.mk_var(1);
nex_var* c = r.mk_var(2);
@ -199,24 +178,25 @@ void test_simplify() {
}
void test_cn_shorter() {
nex_sum *clone;
cross_nested cn(
[](const nex* n) {
TRACE("nla_test", tout <<"cn form = " << *n << "\n";
SASSERT(has_var_in_power(4, // stands for e
2, n));
);
return false;
} ,
[](unsigned) { return false; },
[]{ return 1; });
enable_trace("nla_test");
// enable_trace("nla_cn");
// enable_trace("nla_cn_details");
// enable_trace("nla_test_details");
enable_trace("nla_cn");
enable_trace("nla_cn_test");
enable_trace("nla_cn_details");
enable_trace("nla_test_details");
auto & cr = cn.get_nex_creator();
cr.active_vars_weights().resize(20);
for (unsigned j = 0; j < cr.active_vars_weights().size(); j++)
cr.active_vars_weights()[j] = static_cast<var_weight>(1);
cr.set_number_of_vars(20);
for (unsigned j = 0; j < cr.get_number_of_vars(); j++)
cr.set_var_weight(j,j);
nex_var* a = cr.mk_var(0);
nex_var* b = cr.mk_var(1);
@ -238,20 +218,11 @@ void test_cn_shorter() {
nex* eac = cr.mk_mul(e, a, c);
nex* ed = cr.mk_mul(e, d);
nex* _6aad = cr.mk_mul(cr.mk_scalar(rational(6)), a, a, d);
#ifdef Z3DEBUG
nex * clone = cr.clone(cr.mk_sum(_6aad, abcd, eae, eac));
clone = cr.simplify(clone);
SASSERT(cr.is_simplified(clone));
clone = to_sum(cr.clone(cr.mk_sum(_6aad, abcd, eae, eac)));
clone = to_sum(cr.simplify(clone));
TRACE("nla_test", tout << "clone = " << *clone << "\n";);
#endif
// test_cn_on_expr(cr.mk_sum(aad, abcd, aaccd, add, eae, eac, ed), cn);
test_cn_on_expr(to_sum(clone), cn);
// TRACE("nla_test", tout << "done\n";);
// test_cn_on_expr(a*b*d + a*b*c + c*b*d + a*c*d);
// TRACE("nla_test", tout << "done\n";);
// test_cn_on_expr(a*b*b*d*d + a*b*b*c*d + c*b*b*d);
// TRACE("nla_test", tout << "done\n";);
// test_cn_on_expr(a*b*d + a*b*c + c*b*d);
test_cn_on_expr(clone, cn);
}
void test_cn() {
@ -264,12 +235,13 @@ void test_cn() {
[](unsigned) { return false; },
[]{ return 1; });
enable_trace("nla_test");
enable_trace("nla_cn_test");
// enable_trace("nla_cn");
// enable_trace("nla_test_details");
auto & cr = cn.get_nex_creator();
cr.active_vars_weights().resize(20);
for (unsigned j = 0; j < cr.active_vars_weights().size(); j++)
cr.active_vars_weights()[j] = static_cast<var_weight>(1);
cr.set_number_of_vars(20);
for (unsigned j = 0; j < cr.get_number_of_vars(); j++)
cr.set_var_weight(j, j);
nex_var* a = cr.mk_var(0);
nex_var* b = cr.mk_var(1);