mirror of
https://github.com/Z3Prover/z3
synced 2025-04-22 16:45:31 +00:00
process with nex simplifications
Signed-off-by: Lev Nachmanson <levnach@hotmail.com>
This commit is contained in:
parent
c076c17df9
commit
8cd9989dcf
5 changed files with 122 additions and 78 deletions
|
@ -36,12 +36,12 @@ class cross_nested {
|
|||
bool m_random_bit;
|
||||
nex_creator m_nex_creator;
|
||||
nex_lt m_lt;
|
||||
|
||||
std::function<nex_scalar*()> m_mk_scalar;
|
||||
#ifdef Z3DEBUG
|
||||
nex* m_e_clone;
|
||||
#endif
|
||||
public:
|
||||
|
||||
|
||||
nex_creator& get_nex_creator() { return m_nex_creator; }
|
||||
|
||||
cross_nested(std::function<bool (const nex*)> call_on_result,
|
||||
|
@ -54,7 +54,9 @@ public:
|
|||
m_done(false),
|
||||
m_reported(0),
|
||||
m_nex_creator(lt),
|
||||
m_lt(lt) {}
|
||||
m_lt(lt),
|
||||
m_mk_scalar([this]{return m_nex_creator.mk_scalar(rational(1));})
|
||||
{}
|
||||
|
||||
|
||||
void run(nex *e) {
|
||||
|
@ -128,7 +130,7 @@ public:
|
|||
}
|
||||
|
||||
nex* c_over_f = m_nex_creator.mk_div(*c, f);
|
||||
to_sum(c_over_f)->simplify(&c_over_f, m_lt);
|
||||
to_sum(c_over_f)->simplify(&c_over_f, m_lt, m_mk_scalar);
|
||||
nex_mul* cm;
|
||||
*c = cm = m_nex_creator.mk_mul(f, c_over_f);
|
||||
TRACE("nla_cn", tout << "common factor=" << *f << ", c=" << **c << "\ne = " << *m_e << "\n";);
|
||||
|
@ -393,7 +395,7 @@ public:
|
|||
TRACE("nla_cn_details", tout << "a = " << *a << "\n";);
|
||||
SASSERT(a->children().size() >= 2 && m_b_split_vec.size());
|
||||
nex* f;
|
||||
a->simplify(&f, m_lt);
|
||||
a->simplify(&f, m_lt, m_mk_scalar);
|
||||
|
||||
if (m_b_split_vec.size() == 1) {
|
||||
b = m_b_split_vec[0];
|
||||
|
@ -488,7 +490,7 @@ public:
|
|||
a->children()[j] = normalize(a->children()[j]);
|
||||
}
|
||||
nex *r;
|
||||
a->simplify(&r, m_lt);
|
||||
a->simplify(&r, m_lt, m_mk_scalar);
|
||||
return r;
|
||||
}
|
||||
|
||||
|
|
|
@ -22,15 +22,8 @@
|
|||
namespace nla {
|
||||
|
||||
|
||||
bool ignored_child(nex* e, expr_type t) {
|
||||
switch(t) {
|
||||
case expr_type::MUL:
|
||||
return e->is_scalar() && to_scalar(e)->value().is_one();
|
||||
case expr_type::SUM:
|
||||
return e->is_scalar() && to_scalar(e)->value().is_zero();
|
||||
default: return false;
|
||||
}
|
||||
return false;
|
||||
bool is_zero_scalar(nex* e) {
|
||||
return e->is_scalar() && to_scalar(e)->value().is_zero();
|
||||
}
|
||||
|
||||
void mul_to_powers(vector<nex_pow>& children, nex_lt lt) {
|
||||
|
@ -54,15 +47,50 @@ void mul_to_powers(vector<nex_pow>& children, nex_lt lt) {
|
|||
});
|
||||
}
|
||||
|
||||
void promote_children_of_sum(ptr_vector<nex> & children, nex_lt lt ) {
|
||||
rational extract_coeff(const nex_mul* m) {
|
||||
const nex* e = m->children().begin()->e();
|
||||
if (e->is_scalar())
|
||||
return to_scalar(e)->value();
|
||||
return rational(1);
|
||||
}
|
||||
|
||||
|
||||
bool sum_simplify_lt(const nex_mul* a, const nex_mul* b, const nex_lt& lt) {
|
||||
NOT_IMPLEMENTED_YET();
|
||||
}
|
||||
|
||||
// a + 3bc + 2bc => a + 5bc
|
||||
void sort_join_sum(ptr_vector<nex> & children, nex_lt& lt, std::function<nex_scalar*()> mk_scalar) {
|
||||
ptr_vector<nex> non_muls;
|
||||
std::map<nex_mul*, rational, std::function<bool(const nex_mul *a , const nex_mul *b)>>
|
||||
m([lt](const nex_mul *a , const nex_mul *b) { return sum_simplify_lt(a, b, lt); });
|
||||
for (nex* e : children) {
|
||||
SASSERT(e->is_simplified(lt));
|
||||
if (!e->is_mul()) {
|
||||
non_muls.push_back(e);
|
||||
} else {
|
||||
nex_mul * em = to_mul(e);
|
||||
rational r = extract_coeff(em);
|
||||
auto it = m.find(em);
|
||||
if (it == m.end()) {
|
||||
m[em] = r;
|
||||
} else {
|
||||
it->second += r;
|
||||
}
|
||||
}
|
||||
}
|
||||
NOT_IMPLEMENTED_YET();
|
||||
}
|
||||
|
||||
void simplify_children_of_sum(ptr_vector<nex> & children, nex_lt lt, std::function<nex_scalar*()> mk_scalar ) {
|
||||
ptr_vector<nex> to_promote;
|
||||
int skipped = 0;
|
||||
for(unsigned j = 0; j < children.size(); j++) {
|
||||
nex** e = &(children[j]);
|
||||
(*e)->simplify(e, lt);
|
||||
(*e)->simplify(e, lt, mk_scalar);
|
||||
if ((*e)->is_sum()) {
|
||||
to_promote.push_back(*e);
|
||||
} else if (ignored_child(*e, expr_type::SUM)) {
|
||||
} else if (is_zero_scalar(*e)) {
|
||||
skipped ++;
|
||||
continue;
|
||||
} else {
|
||||
|
@ -77,13 +105,15 @@ void promote_children_of_sum(ptr_vector<nex> & children, nex_lt lt ) {
|
|||
|
||||
for (nex *e : to_promote) {
|
||||
for (nex *ee : *(to_sum(e)->children_ptr())) {
|
||||
if (!ignored_child(ee, expr_type::SUM))
|
||||
if (!is_zero_scalar(ee))
|
||||
children.push_back(ee);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sort_join_sum(children, lt, mk_scalar);
|
||||
}
|
||||
|
||||
bool eat_scalar(nex_scalar *& r, nex_pow& p) {
|
||||
bool eat_scalar_pow(nex_scalar *& r, nex_pow& p) {
|
||||
if (!p.e()->is_scalar())
|
||||
return false;
|
||||
nex_scalar *pe = to_scalar(p.e());
|
||||
|
@ -96,18 +126,18 @@ bool eat_scalar(nex_scalar *& r, nex_pow& p) {
|
|||
return true;
|
||||
}
|
||||
|
||||
void simplify_children_of_mul(vector<nex_pow> & children, nex_lt lt) {
|
||||
void simplify_children_of_mul(vector<nex_pow> & children, nex_lt lt, std::function<nex_scalar*()> mk_scalar) {
|
||||
nex_scalar* r = nullptr;
|
||||
TRACE("nla_cn_details", print_vector(children, tout););
|
||||
vector<nex_pow> to_promote;
|
||||
int skipped = 0;
|
||||
for(unsigned j = 0; j < children.size(); j++) {
|
||||
nex_pow& p = children[j];
|
||||
if (eat_scalar(r, p)) {
|
||||
if (eat_scalar_pow(r, p)) {
|
||||
skipped++;
|
||||
continue;
|
||||
}
|
||||
(p.e())->simplify(p.ee(), lt);
|
||||
(p.e())->simplify(p.ee(), lt, mk_scalar );
|
||||
if ((p.e())->is_mul()) {
|
||||
to_promote.push_back(p);
|
||||
} else {
|
||||
|
@ -122,7 +152,7 @@ void simplify_children_of_mul(vector<nex_pow> & children, nex_lt lt) {
|
|||
|
||||
for (nex_pow & p : to_promote) {
|
||||
for (nex_pow& pp : to_mul(p.e())->children()) {
|
||||
if (!eat_scalar(r, pp))
|
||||
if (!eat_scalar_pow(r, pp))
|
||||
children.push_back(nex_pow(pp.e(), pp.pow() * p.pow()));
|
||||
}
|
||||
}
|
||||
|
@ -133,7 +163,36 @@ void simplify_children_of_mul(vector<nex_pow> & children, nex_lt lt) {
|
|||
|
||||
mul_to_powers(children, lt);
|
||||
|
||||
TRACE("nla_cn_details", print_vector(children, tout););
|
||||
|
||||
TRACE("nla_cn_details", print_vector(children, tout););
|
||||
}
|
||||
|
||||
bool less_than_nex(const nex* a, const nex* b, lt_on_vars lt) {
|
||||
int r = (int)(a->type()) - (int)(b->type());
|
||||
if (r) {
|
||||
return r < 0;
|
||||
}
|
||||
SASSERT(a->type() == b->type());
|
||||
switch (a->type()) {
|
||||
case expr_type::VAR: {
|
||||
return lt(to_var(a)->var() , to_var(b)->var());
|
||||
}
|
||||
case expr_type::SCALAR: {
|
||||
return to_scalar(a)->value() < to_scalar(b)->value();
|
||||
}
|
||||
case expr_type::MUL: {
|
||||
NOT_IMPLEMENTED_YET();
|
||||
return false; // to_mul(a)->children() < to_mul(b)->children();
|
||||
}
|
||||
case expr_type::SUM: {
|
||||
NOT_IMPLEMENTED_YET();
|
||||
return false; //to_sum(a)->children() < to_sum(b)->children();
|
||||
}
|
||||
default:
|
||||
SASSERT(false);
|
||||
return false;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -52,6 +52,7 @@ inline std::ostream & operator<<(std::ostream& out, expr_type t) {
|
|||
class nex;
|
||||
bool less_than_nex_standard(const nex* a, const nex* b);
|
||||
|
||||
class nex_scalar;
|
||||
// This is the class of non-linear expressions
|
||||
class nex {
|
||||
public:
|
||||
|
@ -78,8 +79,8 @@ public:
|
|||
virtual bool contains(lpvar j) const { return false; }
|
||||
virtual int get_degree() const = 0;
|
||||
// simplifies the expression and also assigns the address of "this" to *e
|
||||
virtual void simplify(nex** e, nex_lt) { *e = this; }
|
||||
void simplify(nex** e) { return simplify(e, less_than_nex_standard); }
|
||||
virtual void simplify(nex** e, nex_lt, std::function<nex_scalar*()>) = 0;
|
||||
void simplify(nex** e, std::function<nex_scalar*()> mk_scalar) { return simplify(e, less_than_nex_standard, mk_scalar); }
|
||||
virtual bool is_simplified(nex_lt) const {
|
||||
return true;
|
||||
}
|
||||
|
@ -115,6 +116,7 @@ public:
|
|||
bool contains(lpvar j) const { return j == m_j; }
|
||||
int get_degree() const { return 1; }
|
||||
bool virtual is_linear() const { return true; }
|
||||
void simplify(nex** e, nex_lt, std::function<nex_scalar*()>) {*e = this;}
|
||||
};
|
||||
|
||||
class nex_scalar : public nex {
|
||||
|
@ -132,6 +134,7 @@ public:
|
|||
|
||||
int get_degree() const { return 0; }
|
||||
bool is_linear() const { return true; }
|
||||
void simplify(nex** e, nex_lt, std::function<nex_scalar*()>) {*e = this;}
|
||||
|
||||
};
|
||||
|
||||
|
@ -139,9 +142,9 @@ const nex_scalar * to_scalar(const nex* a);
|
|||
class nex_sum;
|
||||
const nex_sum* to_sum(const nex*a);
|
||||
|
||||
void promote_children_of_sum(ptr_vector<nex> & children, nex_lt);
|
||||
void simplify_children_of_sum(ptr_vector<nex> & children, nex_lt, std::function<nex_scalar*()>);
|
||||
class nex_pow;
|
||||
void simplify_children_of_mul(vector<nex_pow> & children, nex_lt);
|
||||
void simplify_children_of_mul(vector<nex_pow> & children, nex_lt, std::function<nex_scalar*()>);
|
||||
|
||||
class nex_pow {
|
||||
nex* m_e;
|
||||
|
@ -238,12 +241,12 @@ public:
|
|||
return degree;
|
||||
}
|
||||
// the second argument is the comparison less than operator
|
||||
void simplify(nex **e, nex_lt lt) {
|
||||
void simplify(nex **e, nex_lt lt, std::function<nex_scalar*()> mk_scalar) {
|
||||
TRACE("nla_cn_details", tout << *this << "\n";);
|
||||
TRACE("nla_cn_details", tout << "**e = " << **e << "\n";);
|
||||
*e = this;
|
||||
TRACE("nla_cn_details", tout << *this << "\n";);
|
||||
simplify_children_of_mul(m_children, lt);
|
||||
simplify_children_of_mul(m_children, lt, mk_scalar);
|
||||
if (size() == 1 && m_children[0].pow() == 1)
|
||||
*e = m_children[0].e();
|
||||
TRACE("nla_cn_details", tout << *this << "\n";);
|
||||
|
@ -361,9 +364,9 @@ public:
|
|||
return out;
|
||||
}
|
||||
|
||||
void simplify(nex **e, nex_lt lt ) {
|
||||
void simplify(nex **e, nex_lt lt, std::function<nex_scalar*()> mk_scalar) {
|
||||
*e = this;
|
||||
promote_children_of_sum(m_children, lt);
|
||||
simplify_children_of_sum(m_children, lt, mk_scalar);
|
||||
if (size() == 1)
|
||||
*e = m_children[0];
|
||||
}
|
||||
|
@ -444,37 +447,11 @@ inline std::ostream& operator<<(std::ostream& out, const nex& e ) {
|
|||
}
|
||||
|
||||
|
||||
inline bool less_than_nex(const nex* a, const nex* b, lt_on_vars lt) {
|
||||
int r = (int)(a->type()) - (int)(b->type());
|
||||
if (r) {
|
||||
return r < 0;
|
||||
}
|
||||
// here a and b have the same type
|
||||
switch (a->type()) {
|
||||
case expr_type::VAR: {
|
||||
return lt(to_var(a)->var() , to_var(b)->var());
|
||||
}
|
||||
case expr_type::SCALAR: {
|
||||
return to_scalar(a)->value() < to_scalar(b)->value();
|
||||
}
|
||||
case expr_type::MUL: {
|
||||
NOT_IMPLEMENTED_YET();
|
||||
return false; // to_mul(a)->children() < to_mul(b)->children();
|
||||
}
|
||||
case expr_type::SUM: {
|
||||
NOT_IMPLEMENTED_YET();
|
||||
return false; //to_sum(a)->children() < to_sum(b)->children();
|
||||
}
|
||||
default:
|
||||
SASSERT(false);
|
||||
return false;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
bool less_than_nex(const nex* a, const nex* b, lt_on_vars lt);
|
||||
|
||||
inline bool less_than_nex_standard(const nex* a, const nex* b) {
|
||||
return less_than_nex(a, b, [](lpvar j, lpvar k) { return j < k; });
|
||||
lt_on_vars lt = [](lpvar j, lpvar k) { return j < k; };
|
||||
return less_than_nex(a, b, lt);
|
||||
}
|
||||
|
||||
#if Z3DEBUG
|
||||
|
|
|
@ -170,7 +170,8 @@ private:
|
|||
}
|
||||
|
||||
bool less_than_on_expr(const nex* a, const nex* b) const {
|
||||
return less_than_nex(a, b, [this](lpvar j, lpvar k) {return less_than_on_vars(j, k);});
|
||||
lt_on_vars lt = [this](lpvar j, lpvar k) {return less_than_on_vars(j, k);};
|
||||
return less_than_nex(a, b, lt);
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -85,26 +85,31 @@ void test_simplify() {
|
|||
);
|
||||
enable_trace("nla_cn");
|
||||
enable_trace("nla_cn_details");
|
||||
auto & creator = cn.get_nex_creator();
|
||||
nex_var* a = creator.mk_var(0);
|
||||
nex_var* b = creator.mk_var(1);
|
||||
nex_var* c = creator.mk_var(2);
|
||||
auto m = creator.mk_mul(); m->add_child_in_power(c, 2);
|
||||
nex_creator & r = cn.get_nex_creator();
|
||||
nex_var* a = r.mk_var(0);
|
||||
nex_var* b = r.mk_var(1);
|
||||
nex_var* c = r.mk_var(2);
|
||||
auto m = r.mk_mul(); m->add_child_in_power(c, 2);
|
||||
TRACE("nla_cn", tout << "m = " << *m << "\n";);
|
||||
auto n = creator.mk_mul(a);
|
||||
auto n = r.mk_mul(a);
|
||||
n->add_child_in_power(b, 7);
|
||||
n->add_child(creator.mk_scalar(rational(3)));
|
||||
n->add_child_in_power(creator.mk_scalar(rational(4)), 2);
|
||||
n->add_child(creator.mk_scalar(rational(1)));
|
||||
n->add_child(r.mk_scalar(rational(3)));
|
||||
n->add_child_in_power(r.mk_scalar(rational(4)), 2);
|
||||
n->add_child(r.mk_scalar(rational(1)));
|
||||
TRACE("nla_cn", tout << "n = " << *n << "\n";);
|
||||
m->add_child_in_power(n, 3);
|
||||
n->add_child_in_power(creator.mk_scalar(rational(1, 3)), 2);
|
||||
n->add_child_in_power(r.mk_scalar(rational(1, 3)), 2);
|
||||
TRACE("nla_cn", tout << "m = " << *m << "\n";);
|
||||
|
||||
nex * e = creator.mk_sum(a, creator.mk_sum(b, m));
|
||||
nex * e = r.mk_sum(a, r.mk_sum(b, m));
|
||||
TRACE("nla_cn", tout << "e = " << *e << "\n";);
|
||||
e->simplify(&e);
|
||||
std::function<nex_scalar*()> mks = [&r] {return r.mk_scalar(rational(1)); };
|
||||
e->simplify(&e, mks);
|
||||
TRACE("nla_cn", tout << "simplified e = " << *e << "\n";);
|
||||
nex * l = r.mk_sum(e, r.mk_mul(r.mk_scalar(rational(3)), r.clone(e)));
|
||||
TRACE("nla_cn", tout << "sum l = " << *l << "\n";);
|
||||
l->simplify(&l, mks);
|
||||
TRACE("nla_cn", tout << "simplified sum l = " << *l << "\n";);
|
||||
}
|
||||
|
||||
void test_cn() {
|
||||
|
@ -142,7 +147,7 @@ void test_cn() {
|
|||
nex* _6aad = cn.get_nex_creator().mk_mul(cn.get_nex_creator().mk_scalar(rational(6)), a, a, d);
|
||||
#ifdef Z3DEBUG
|
||||
nex * clone = cn.get_nex_creator().clone(cn.get_nex_creator().mk_sum(_6aad, abcd, aaccd, add, eae, eac, ed));
|
||||
clone->simplify(&clone);
|
||||
clone->simplify(&clone,[&cn] {return cn.get_nex_creator().mk_scalar(rational(1));});
|
||||
SASSERT(clone->is_simplified());
|
||||
TRACE("nla_cn", tout << "clone = " << *clone << "\n";);
|
||||
#endif
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue