mirror of
https://github.com/Z3Prover/z3
synced 2025-06-28 08:58:44 +00:00
sort nla_expr
Signed-off-by: Lev Nachmanson <levnach@hotmail.com>
This commit is contained in:
parent
00d366e3b8
commit
8eaa2bfb02
5 changed files with 103 additions and 20 deletions
|
@ -124,7 +124,7 @@ interv horner::interval_of_expr(const nex& e) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
interv horner::interval_of_mul(const vector<nex>& es) {
|
interv horner::interval_of_mul(const std::vector<nex>& es) {
|
||||||
interv a = interval_of_expr(es[0]);
|
interv a = interval_of_expr(es[0]);
|
||||||
// std::cout << "a" << std::endl;
|
// std::cout << "a" << std::endl;
|
||||||
TRACE("nla_cn_details", tout << "es[0]= "<< es[0] << std::endl << "a = "; m_intervals.display(tout, a); tout << "\n";);
|
TRACE("nla_cn_details", tout << "es[0]= "<< es[0] << std::endl << "a = "; m_intervals.display(tout, a); tout << "\n";);
|
||||||
|
@ -150,7 +150,7 @@ interv horner::interval_of_mul(const vector<nex>& es) {
|
||||||
return a;
|
return a;
|
||||||
}
|
}
|
||||||
|
|
||||||
interv horner::interval_of_sum(const vector<nex>& es) {
|
interv horner::interval_of_sum(const std::vector<nex>& es) {
|
||||||
interv a = interval_of_expr(es[0]);
|
interv a = interval_of_expr(es[0]);
|
||||||
TRACE("nla_cn_details", tout << "es[0]= " << es[0] << "\n"; m_intervals.display(tout, a) << "\n";);
|
TRACE("nla_cn_details", tout << "es[0]= " << es[0] << "\n"; m_intervals.display(tout, a) << "\n";);
|
||||||
if (m_intervals.is_inf(a)) {
|
if (m_intervals.is_inf(a)) {
|
||||||
|
|
|
@ -41,8 +41,8 @@ public:
|
||||||
intervals::interval interval_of_expr(const nex& e);
|
intervals::interval interval_of_expr(const nex& e);
|
||||||
|
|
||||||
nex nexvar(lpvar j) const;
|
nex nexvar(lpvar j) const;
|
||||||
intervals::interval interval_of_sum(const vector<nex>&);
|
intervals::interval interval_of_sum(const std::vector<nex>&);
|
||||||
intervals::interval interval_of_mul(const vector<nex>&);
|
intervals::interval interval_of_mul(const std::vector<nex>&);
|
||||||
void set_interval_for_scalar(intervals::interval&, const rational&);
|
void set_interval_for_scalar(intervals::interval&, const rational&);
|
||||||
void set_var_interval(lpvar j, intervals::interval&);
|
void set_var_interval(lpvar j, intervals::interval&);
|
||||||
std::set<lpvar> get_vars_of_expr(const nex &) const;
|
std::set<lpvar> get_vars_of_expr(const nex &) const;
|
||||||
|
|
|
@ -44,14 +44,40 @@ inline std::ostream & operator<<(std::ostream& out, expr_type t) {
|
||||||
}
|
}
|
||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// This class is needed in horner calculation with intervals
|
// This class is needed in horner calculation with intervals
|
||||||
template <typename T>
|
template <typename T>
|
||||||
class nla_expr {
|
class nla_expr {
|
||||||
|
|
||||||
|
class sorted_children {
|
||||||
|
std::vector<nla_expr> m_es;
|
||||||
|
// m_order will be sorted according to the non-decreasing order of m_es
|
||||||
|
svector<unsigned> m_order;
|
||||||
|
public:
|
||||||
|
const std::vector<nla_expr>& es() const { return m_es; }
|
||||||
|
std::vector<nla_expr>& es() { return m_es; }
|
||||||
|
void push_back(const nla_expr& e) {
|
||||||
|
SASSERT(m_es.size() == m_order.size());
|
||||||
|
m_order.push_back(m_es.size());
|
||||||
|
m_es.push_back(e);
|
||||||
|
}
|
||||||
|
const svector<unsigned>& order() const { return m_order; }
|
||||||
|
const nla_expr& back() const { return m_es.back(); }
|
||||||
|
nla_expr& back() { return m_es.back(); }
|
||||||
|
const nla_expr* begin() const { return m_es.begin(); }
|
||||||
|
const nla_expr* end() const { return m_es.end(); }
|
||||||
|
unsigned size() const { return m_es.size(); }
|
||||||
|
void sort() {
|
||||||
|
std::sort(m_order.begin(), m_order.end(), [this](unsigned i, unsigned j) { return m_es[i] < m_es[j]; });
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
// todo: use union
|
// todo: use union
|
||||||
expr_type m_type;
|
expr_type m_type;
|
||||||
lpvar m_j;
|
lpvar m_j;
|
||||||
T m_v; // for the scalar
|
T m_v; // for the scalar
|
||||||
vector<nla_expr> m_children;
|
sorted_children m_children;
|
||||||
public:
|
public:
|
||||||
bool is_sum() const { return m_type == expr_type::SUM; }
|
bool is_sum() const { return m_type == expr_type::SUM; }
|
||||||
bool is_var() const { return m_type == expr_type::VAR; }
|
bool is_var() const { return m_type == expr_type::VAR; }
|
||||||
|
@ -60,13 +86,16 @@ public:
|
||||||
lpvar var() const { SASSERT(m_type == expr_type::VAR); return m_j; }
|
lpvar var() const { SASSERT(m_type == expr_type::VAR); return m_j; }
|
||||||
expr_type type() const { return m_type; }
|
expr_type type() const { return m_type; }
|
||||||
expr_type& type() { return m_type; }
|
expr_type& type() { return m_type; }
|
||||||
const vector<nla_expr>& children() const { return m_children; }
|
const std::vector<nla_expr>& children() const { return m_children.es(); }
|
||||||
vector<nla_expr>& children() { return m_children; }
|
std::vector<nla_expr>& children() { return m_children.es(); }
|
||||||
|
const sorted_children & s_children() const { return m_children; }
|
||||||
|
sorted_children & s_children() { return m_children; }
|
||||||
const T& value() const { SASSERT(m_type == expr_type::SCALAR); return m_v; }
|
const T& value() const { SASSERT(m_type == expr_type::SCALAR); return m_v; }
|
||||||
std::string str() const { std::stringstream ss; ss << *this; return ss.str(); }
|
std::string str() const { std::stringstream ss; ss << *this; return ss.str(); }
|
||||||
std::ostream & print_sum(std::ostream& out) const {
|
std::ostream & print_sum(std::ostream& out) const {
|
||||||
bool first = true;
|
bool first = true;
|
||||||
for (const nla_expr& v : m_children) {
|
for (unsigned j : m_children.order()) {
|
||||||
|
const nla_expr& v = m_children.es()[j];
|
||||||
std::string s = v.str();
|
std::string s = v.str();
|
||||||
if (first) {
|
if (first) {
|
||||||
first = false;
|
first = false;
|
||||||
|
@ -88,9 +117,18 @@ public:
|
||||||
}
|
}
|
||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
|
void sort () {
|
||||||
|
if (is_sum() || is_mul()) {
|
||||||
|
for (auto & e : m_children.es()) {
|
||||||
|
e.sort();
|
||||||
|
}
|
||||||
|
m_children.sort();
|
||||||
|
}
|
||||||
|
}
|
||||||
std::ostream & print_mul(std::ostream& out) const {
|
std::ostream & print_mul(std::ostream& out) const {
|
||||||
bool first = true;
|
bool first = true;
|
||||||
for (const nla_expr& v : m_children) {
|
for (unsigned j : m_children.order()) {
|
||||||
|
const nla_expr& v = m_children.es()[j];
|
||||||
std::string s = v.str();
|
std::string s = v.str();
|
||||||
if (first) {
|
if (first) {
|
||||||
first = false;
|
first = false;
|
||||||
|
@ -221,13 +259,35 @@ public:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
bool operator<(const nla_expr& e) const {
|
||||||
|
if (type() != (e.type()))
|
||||||
|
return (int)type() < (int)(e.type());
|
||||||
|
|
||||||
|
SASSERT(type() == (e.type()));
|
||||||
|
|
||||||
|
switch(m_type) {
|
||||||
|
case expr_type::SUM:
|
||||||
|
case expr_type::MUL:
|
||||||
|
return m_children.es() < e.m_children.es();
|
||||||
|
case expr_type::VAR:
|
||||||
|
return m_j < e.m_j;
|
||||||
|
case expr_type::SCALAR:
|
||||||
|
return m_v < e.m_v;
|
||||||
|
default:
|
||||||
|
SASSERT(false);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
};
|
};
|
||||||
template <typename T>
|
template <typename T>
|
||||||
nla_expr<T> operator+(const nla_expr<T>& a, const nla_expr<T>& b) {
|
nla_expr<T> operator+(const nla_expr<T>& a, const nla_expr<T>& b) {
|
||||||
if (a.is_sum()) {
|
if (a.is_sum()) {
|
||||||
nla_expr<T> r(expr_type::SUM);
|
nla_expr<T> r(expr_type::SUM);
|
||||||
r.children() = a.children();
|
r.s_children() = a.s_children();
|
||||||
if (b.is_sum()) {
|
if (b.is_sum()) {
|
||||||
for (auto& e: b.children())
|
for (auto& e: b.children())
|
||||||
r.add_child(e);
|
r.add_child(e);
|
||||||
|
@ -238,7 +298,7 @@ nla_expr<T> operator+(const nla_expr<T>& a, const nla_expr<T>& b) {
|
||||||
}
|
}
|
||||||
if (b.is_sum()) {
|
if (b.is_sum()) {
|
||||||
nla_expr<T> r(expr_type::SUM);
|
nla_expr<T> r(expr_type::SUM);
|
||||||
r.children() = b.children();
|
r.s_children() = b.s_children();
|
||||||
r.add_child(a);
|
r.add_child(a);
|
||||||
return r;
|
return r;
|
||||||
}
|
}
|
||||||
|
@ -249,7 +309,7 @@ template <typename T>
|
||||||
nla_expr<T> operator*(const nla_expr<T>& a, const nla_expr<T>& b) {
|
nla_expr<T> operator*(const nla_expr<T>& a, const nla_expr<T>& b) {
|
||||||
if (a.is_mul()) {
|
if (a.is_mul()) {
|
||||||
nla_expr<T> r(expr_type::MUL);
|
nla_expr<T> r(expr_type::MUL);
|
||||||
r.children() = a.children();
|
r.s_children() = a.s_children();
|
||||||
if (b.is_mul()) {
|
if (b.is_mul()) {
|
||||||
for (auto& e: b.children())
|
for (auto& e: b.children())
|
||||||
r.add_child(e);
|
r.add_child(e);
|
||||||
|
@ -260,7 +320,7 @@ nla_expr<T> operator*(const nla_expr<T>& a, const nla_expr<T>& b) {
|
||||||
}
|
}
|
||||||
if (b.is_mul()) {
|
if (b.is_mul()) {
|
||||||
nla_expr<T> r(expr_type::MUL);
|
nla_expr<T> r(expr_type::MUL);
|
||||||
r.children() = b.children();
|
r.s_children() = b.s_children();
|
||||||
r.add_child(a);
|
r.add_child(a);
|
||||||
return r;
|
return r;
|
||||||
}
|
}
|
||||||
|
|
|
@ -491,7 +491,9 @@ namespace smt {
|
||||||
std::stringstream strm;
|
std::stringstream strm;
|
||||||
strm << "lemma_" << (++m_lemma_id) << ".smt2";
|
strm << "lemma_" << (++m_lemma_id) << ".smt2";
|
||||||
std::ofstream out(strm.str());
|
std::ofstream out(strm.str());
|
||||||
TRACE("lemma", tout << strm.str() << "\n";);
|
TRACE("lemma", tout << strm.str() << "\n";
|
||||||
|
display_lemma_as_smt_problem(tout, num_antecedents, antecedents, num_eq_antecedents, eq_antecedents, consequent, logic);
|
||||||
|
);
|
||||||
display_lemma_as_smt_problem(out, num_antecedents, antecedents, num_eq_antecedents, eq_antecedents, consequent, logic);
|
display_lemma_as_smt_problem(out, num_antecedents, antecedents, num_eq_antecedents, eq_antecedents, consequent, logic);
|
||||||
out.close();
|
out.close();
|
||||||
|
|
||||||
|
|
|
@ -72,11 +72,32 @@ void test_cn() {
|
||||||
enable_trace("nla_cn");
|
enable_trace("nla_cn");
|
||||||
// (a(a+(b+c)c+d)d + e(a(e+c)+d)
|
// (a(a+(b+c)c+d)d + e(a(e+c)+d)
|
||||||
nex a = nex::var(0), b = nex::var(1), c = nex::var(2), d = nex::var(3), e = nex::var(4);
|
nex a = nex::var(0), b = nex::var(1), c = nex::var(2), d = nex::var(3), e = nex::var(4);
|
||||||
nex t = a*a*d + a*b*c*d + a*c*c*d + a*d*d + e*a*e + e*a*c + e*d;
|
{
|
||||||
std::cout << "t = " << t << "\n";
|
nex t = a*a*d + a*b*c*d + a*c*c*d + a*d*d + e*a*e + e*a*c + e*d;
|
||||||
TRACE("nla_cn", tout << "t=" << t << '\n';);
|
std::cout << "t = " << t << "\n";
|
||||||
cross_nested cn(t, [](const nex& n) { std::cout << n << "\n"; } );
|
TRACE("nla_cn", tout << "t=" << t << '\n';);
|
||||||
cn.run();
|
cross_nested cn(t, [](const nex& n) {
|
||||||
|
std::cout << n << "\n";
|
||||||
|
auto nn = n;
|
||||||
|
nn.sort();
|
||||||
|
std::cout << "ordered version\n" << nn << "\n______________________\n";
|
||||||
|
|
||||||
|
} );
|
||||||
|
cn.run();
|
||||||
|
}
|
||||||
|
{
|
||||||
|
nex t = a*b*d + a*b*c;
|
||||||
|
std::cout << "t = " << t << "\n";
|
||||||
|
TRACE("nla_cn", tout << "t=" << t << '\n';);
|
||||||
|
cross_nested cn(t, [](const nex& n) {
|
||||||
|
std::cout << n << "\n";
|
||||||
|
auto nn = n;
|
||||||
|
nn.sort();
|
||||||
|
std::cout << "ordered version\n" << nn << "\n______________________\n";
|
||||||
|
|
||||||
|
} );
|
||||||
|
cn.run();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
} // end of namespace nla
|
} // end of namespace nla
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue