3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-13 20:38:43 +00:00

make sure that the returned cross nested form is equal to the original

Signed-off-by: Lev Nachmanson <levnach@hotmail.com>
This commit is contained in:
Lev Nachmanson 2019-08-19 16:51:52 -07:00
parent 4e59976c2f
commit a844b88c32
3 changed files with 101 additions and 36 deletions

View file

@ -270,7 +270,7 @@ public:
}
nex* c_over_f = mk_div(*c, f);
to_sum(c_over_f)->simplify();
to_sum(c_over_f)->simplify(&c_over_f);
*c = mk_mul(f, c_over_f);
TRACE("nla_cn", tout << "common factor=" << *f << ", c=" << **c << "\ne = " << *m_e << "\n";);
@ -463,8 +463,7 @@ public:
|| (ce->is_var() && to_var(ce)->var() == j);
}
// all factors of j go to a, the rest to b
void pre_split(nex_sum * e, lpvar j, nex_sum* & a, nex* & b) {
void pre_split(nex_sum * e, lpvar j, nex_sum*& a, nex*& b) {
a = mk_sum();
m_b_split_vec.clear();
for (nex * ce: e->children()) {
@ -478,7 +477,8 @@ public:
}
TRACE("nla_cn_details", tout << "a = " << *a << "\n";);
SASSERT(a->children().size() >= 2 && m_b_split_vec.size());
a->simplify();
nex* f;
a->simplify(&f);
if (m_b_split_vec.size() == 1) {
b = m_b_split_vec[0];
@ -608,11 +608,13 @@ public:
for (unsigned j = 0; j < a->size(); j ++) {
a->children()[j] = normalize(a->children()[j]);
}
a->simplify();
return a;
nex *r;
a->simplify(&r);
return r;
}
nex * normalize_mul(nex_mul* a) {
TRACE("nla_cn", tout << *a << "\n";);
int sum_j = -1;
for (unsigned j = 0; j < a->size(); j ++) {
a->children()[j] = normalize(a->children()[j]);
@ -620,28 +622,36 @@ public:
sum_j = j;
}
if (sum_j == -1)
return a;
if (sum_j == -1) {
nex * r;
a->simplify(&r);
SASSERT(r->is_simplified());
return r;
}
nex_sum *r = mk_sum();
nex_sum *as = to_sum(a->children()[sum_j]);
for (unsigned k = 0; k < as->size(); k++) {
nex_mul *b = mk_mul(as->children()[k]);
r->add_child(b);
for (unsigned j = 0; j < a->size(); j ++) {
if ((int)j != sum_j)
b->add_child(a->children()[j]);
}
b->simplify();
nex *e;
b->simplify(&e);
r->add_child(e);
}
TRACE("nla_cn", tout << *r << "\n";);
return normalize_sum(r);
TRACE("nla_cn", tout << *r << "\n";);
nex *rs = normalize_sum(r);
SASSERT(rs->is_simplified());
return rs;
}
nex * normalize(nex* a) {
if (a->is_simple())
if (a->is_elementary())
return a;
nex *r;
if (a->is_mul()) {

View file

@ -49,7 +49,7 @@ public:
virtual expr_type type() const = 0;
virtual std::ostream& print(std::ostream&) const = 0;
nex() {}
bool is_simple() const {
bool is_elementary() const {
switch(type()) {
case expr_type::SUM:
case expr_type::MUL:
@ -67,7 +67,10 @@ public:
virtual ~nex() {}
virtual bool contains(lpvar j) const { return false; }
virtual int get_degree() const = 0;
virtual void simplify() {}
virtual void simplify(nex** ) = 0;
virtual bool is_simplified() const {
return true;
}
virtual const ptr_vector<nex> * children_ptr() const {
UNREACHABLE();
return nullptr;
@ -103,6 +106,7 @@ public:
bool contains(lpvar j) const { return j == m_j; }
int get_degree() const { return 1; }
virtual void simplify(nex** e) { *e = this; }
};
class nex_scalar : public nex {
@ -119,29 +123,48 @@ public:
}
int get_degree() const { return 0; }
virtual void simplify(nex** e) { *e = this; }
};
const nex_scalar * to_scalar(const nex* a);
static bool ignored_child(nex* e, expr_type t) {
switch(t) {
case expr_type::MUL:
return e->is_scalar() && to_scalar(e)->value().is_one();
case expr_type::SUM:
return e->is_scalar() && to_scalar(e)->value().is_zero();
default: return false;
}
return false;
}
static void promote_children_by_type(ptr_vector<nex> * children, expr_type t) {
ptr_vector<nex> to_promote;
int skipped = 0;
for(unsigned j = 0; j < children->size(); j++) {
nex* e = (*children)[j];
e->simplify();
if (e->type() == t) {
to_promote.push_back(e);
nex** e = &(*children)[j];
(*e)->simplify(e);
if ((*e)->type() == t) {
to_promote.push_back(*e);
} else if (ignored_child(*e, t)) {
skipped ++;
continue;
} else {
unsigned offset = to_promote.size();
unsigned offset = to_promote.size() + skipped;
if (offset) {
(*children)[j - offset] = e;
(*children)[j - offset] = *e;
}
}
}
children->shrink(children->size() - to_promote.size());
children->shrink(children->size() - to_promote.size() - skipped);
for (nex *e : to_promote) {
for (nex *ee : *(e->children_ptr())) {
children->push_back(ee);
if (!ignored_child(ee, t))
children->push_back(ee);
}
}
}
@ -163,12 +186,12 @@ public:
std::string s = v->str();
if (first) {
first = false;
if (v->is_simple())
if (v->is_elementary())
out << s;
else
out << "(" << s << ")";
} else {
if (v->is_simple()) {
if (v->is_elementary()) {
if (s[0] == '-') {
out << "*(" << s << ")";
} else {
@ -222,12 +245,29 @@ public:
return degree;
}
void simplify() {
void simplify(nex **e) {
*e = this;
TRACE("nla_cn_details", tout << *this << "\n";);
promote_children_by_type(&m_children, expr_type::MUL);
if (size() == 1)
*e = m_children[0];
TRACE("nla_cn_details", tout << *this << "\n";);
SASSERT((*e)->is_simplified());
}
#ifdef Z3DEBUG
virtual bool is_simplified() const {
if (size() < 2)
return false;
for (nex * e : children()) {
if (e->is_mul())
return false;
if (e->is_scalar() && to_scalar(e)->value().is_one())
return false;
}
return true;
}
#ifdef Z3DEBUG
virtual void sort() {
for (nex * c : m_children) {
c->sort();
@ -271,12 +311,12 @@ public:
std::string s = v->str();
if (first) {
first = false;
if (v->is_simple())
if (v->is_elementary())
out << s;
else
out << "(" << s << ")";
} else {
if (v->is_simple()) {
if (v->is_elementary()) {
if (s[0] == '-') {
out << s;
} else {
@ -290,8 +330,21 @@ public:
return out;
}
void simplify() {
void simplify(nex **e) {
*e = this;
promote_children_by_type(&m_children, expr_type::SUM);
if (size() == 1)
*e = m_children[0];
}
virtual bool is_simplified() const {
if (size() < 2) return false;
for (nex * e : children()) {
if (e->is_sum())
return false;
if (e->is_scalar() && to_scalar(e)->value().is_zero())
return false;
}
return true;
}
int get_degree() const {
@ -331,6 +384,11 @@ inline const nex_var* to_var(const nex*a) {
return static_cast<const nex_var*>(a);
}
inline const nex_scalar* to_scalar(const nex*a) {
SASSERT(a->is_scalar());
return static_cast<const nex_scalar*>(a);
}
inline const nex_mul* to_mul(const nex*a) {
SASSERT(a->is_mul());
return static_cast<const nex_mul*>(a);
@ -341,11 +399,6 @@ inline nex_mul* to_mul(nex*a) {
return static_cast<nex_mul*>(a);
}
inline const nex_scalar * to_scalar(const nex* a) {
SASSERT(a->is_scalar());
return static_cast<const nex_scalar*>(a);
}
inline std::ostream& operator<<(std::ostream& out, const nex& e ) {
return e.print(out);
}

View file

@ -102,8 +102,10 @@ void test_cn() {
nex* eac = cn.mk_mul(e, a, c);
nex* ed = cn.mk_mul(e, d);
nex* _6aad = cn.mk_mul(cn.mk_scalar(rational(6)), a, a, d);
#ifdef Z3DEBUG
nex * clone = cn.clone(cn.mk_sum(_6aad, abcd, aaccd, add, eae, eac, ed));
TRACE("nla_cn", tout << "clone = " << *clone << "\n";);
#endif
// test_cn_on_expr(cn.mk_sum(aad, abcd, aaccd, add, eae, eac, ed), cn);
test_cn_on_expr(cn.mk_sum(_6aad, abcd, aaccd, add, eae, eac, ed), cn);
// TRACE("nla_cn", tout << "done\n";);