mirror of
https://github.com/Z3Prover/z3
synced 2025-04-15 13:28:47 +00:00
sort nla_expr
Signed-off-by: Lev Nachmanson <levnach@hotmail.com>
This commit is contained in:
parent
00d366e3b8
commit
8eaa2bfb02
|
@ -124,7 +124,7 @@ interv horner::interval_of_expr(const nex& e) {
|
|||
}
|
||||
}
|
||||
|
||||
interv horner::interval_of_mul(const vector<nex>& es) {
|
||||
interv horner::interval_of_mul(const std::vector<nex>& es) {
|
||||
interv a = interval_of_expr(es[0]);
|
||||
// std::cout << "a" << std::endl;
|
||||
TRACE("nla_cn_details", tout << "es[0]= "<< es[0] << std::endl << "a = "; m_intervals.display(tout, a); tout << "\n";);
|
||||
|
@ -150,7 +150,7 @@ interv horner::interval_of_mul(const vector<nex>& es) {
|
|||
return a;
|
||||
}
|
||||
|
||||
interv horner::interval_of_sum(const vector<nex>& es) {
|
||||
interv horner::interval_of_sum(const std::vector<nex>& es) {
|
||||
interv a = interval_of_expr(es[0]);
|
||||
TRACE("nla_cn_details", tout << "es[0]= " << es[0] << "\n"; m_intervals.display(tout, a) << "\n";);
|
||||
if (m_intervals.is_inf(a)) {
|
||||
|
|
|
@ -41,8 +41,8 @@ public:
|
|||
intervals::interval interval_of_expr(const nex& e);
|
||||
|
||||
nex nexvar(lpvar j) const;
|
||||
intervals::interval interval_of_sum(const vector<nex>&);
|
||||
intervals::interval interval_of_mul(const vector<nex>&);
|
||||
intervals::interval interval_of_sum(const std::vector<nex>&);
|
||||
intervals::interval interval_of_mul(const std::vector<nex>&);
|
||||
void set_interval_for_scalar(intervals::interval&, const rational&);
|
||||
void set_var_interval(lpvar j, intervals::interval&);
|
||||
std::set<lpvar> get_vars_of_expr(const nex &) const;
|
||||
|
|
|
@ -44,14 +44,40 @@ inline std::ostream & operator<<(std::ostream& out, expr_type t) {
|
|||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
|
||||
// This class is needed in horner calculation with intervals
|
||||
template <typename T>
|
||||
class nla_expr {
|
||||
|
||||
class sorted_children {
|
||||
std::vector<nla_expr> m_es;
|
||||
// m_order will be sorted according to the non-decreasing order of m_es
|
||||
svector<unsigned> m_order;
|
||||
public:
|
||||
const std::vector<nla_expr>& es() const { return m_es; }
|
||||
std::vector<nla_expr>& es() { return m_es; }
|
||||
void push_back(const nla_expr& e) {
|
||||
SASSERT(m_es.size() == m_order.size());
|
||||
m_order.push_back(m_es.size());
|
||||
m_es.push_back(e);
|
||||
}
|
||||
const svector<unsigned>& order() const { return m_order; }
|
||||
const nla_expr& back() const { return m_es.back(); }
|
||||
nla_expr& back() { return m_es.back(); }
|
||||
const nla_expr* begin() const { return m_es.begin(); }
|
||||
const nla_expr* end() const { return m_es.end(); }
|
||||
unsigned size() const { return m_es.size(); }
|
||||
void sort() {
|
||||
std::sort(m_order.begin(), m_order.end(), [this](unsigned i, unsigned j) { return m_es[i] < m_es[j]; });
|
||||
}
|
||||
};
|
||||
|
||||
// todo: use union
|
||||
expr_type m_type;
|
||||
lpvar m_j;
|
||||
T m_v; // for the scalar
|
||||
vector<nla_expr> m_children;
|
||||
sorted_children m_children;
|
||||
public:
|
||||
bool is_sum() const { return m_type == expr_type::SUM; }
|
||||
bool is_var() const { return m_type == expr_type::VAR; }
|
||||
|
@ -60,13 +86,16 @@ public:
|
|||
lpvar var() const { SASSERT(m_type == expr_type::VAR); return m_j; }
|
||||
expr_type type() const { return m_type; }
|
||||
expr_type& type() { return m_type; }
|
||||
const vector<nla_expr>& children() const { return m_children; }
|
||||
vector<nla_expr>& children() { return m_children; }
|
||||
const std::vector<nla_expr>& children() const { return m_children.es(); }
|
||||
std::vector<nla_expr>& children() { return m_children.es(); }
|
||||
const sorted_children & s_children() const { return m_children; }
|
||||
sorted_children & s_children() { return m_children; }
|
||||
const T& value() const { SASSERT(m_type == expr_type::SCALAR); return m_v; }
|
||||
std::string str() const { std::stringstream ss; ss << *this; return ss.str(); }
|
||||
std::ostream & print_sum(std::ostream& out) const {
|
||||
bool first = true;
|
||||
for (const nla_expr& v : m_children) {
|
||||
for (unsigned j : m_children.order()) {
|
||||
const nla_expr& v = m_children.es()[j];
|
||||
std::string s = v.str();
|
||||
if (first) {
|
||||
first = false;
|
||||
|
@ -88,9 +117,18 @@ public:
|
|||
}
|
||||
return out;
|
||||
}
|
||||
void sort () {
|
||||
if (is_sum() || is_mul()) {
|
||||
for (auto & e : m_children.es()) {
|
||||
e.sort();
|
||||
}
|
||||
m_children.sort();
|
||||
}
|
||||
}
|
||||
std::ostream & print_mul(std::ostream& out) const {
|
||||
bool first = true;
|
||||
for (const nla_expr& v : m_children) {
|
||||
for (unsigned j : m_children.order()) {
|
||||
const nla_expr& v = m_children.es()[j];
|
||||
std::string s = v.str();
|
||||
if (first) {
|
||||
first = false;
|
||||
|
@ -221,13 +259,35 @@ public:
|
|||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
bool operator<(const nla_expr& e) const {
|
||||
if (type() != (e.type()))
|
||||
return (int)type() < (int)(e.type());
|
||||
|
||||
SASSERT(type() == (e.type()));
|
||||
|
||||
switch(m_type) {
|
||||
case expr_type::SUM:
|
||||
case expr_type::MUL:
|
||||
return m_children.es() < e.m_children.es();
|
||||
case expr_type::VAR:
|
||||
return m_j < e.m_j;
|
||||
case expr_type::SCALAR:
|
||||
return m_v < e.m_v;
|
||||
default:
|
||||
SASSERT(false);
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
|
||||
}
|
||||
};
|
||||
template <typename T>
|
||||
nla_expr<T> operator+(const nla_expr<T>& a, const nla_expr<T>& b) {
|
||||
if (a.is_sum()) {
|
||||
nla_expr<T> r(expr_type::SUM);
|
||||
r.children() = a.children();
|
||||
r.s_children() = a.s_children();
|
||||
if (b.is_sum()) {
|
||||
for (auto& e: b.children())
|
||||
r.add_child(e);
|
||||
|
@ -238,7 +298,7 @@ nla_expr<T> operator+(const nla_expr<T>& a, const nla_expr<T>& b) {
|
|||
}
|
||||
if (b.is_sum()) {
|
||||
nla_expr<T> r(expr_type::SUM);
|
||||
r.children() = b.children();
|
||||
r.s_children() = b.s_children();
|
||||
r.add_child(a);
|
||||
return r;
|
||||
}
|
||||
|
@ -249,7 +309,7 @@ template <typename T>
|
|||
nla_expr<T> operator*(const nla_expr<T>& a, const nla_expr<T>& b) {
|
||||
if (a.is_mul()) {
|
||||
nla_expr<T> r(expr_type::MUL);
|
||||
r.children() = a.children();
|
||||
r.s_children() = a.s_children();
|
||||
if (b.is_mul()) {
|
||||
for (auto& e: b.children())
|
||||
r.add_child(e);
|
||||
|
@ -260,7 +320,7 @@ nla_expr<T> operator*(const nla_expr<T>& a, const nla_expr<T>& b) {
|
|||
}
|
||||
if (b.is_mul()) {
|
||||
nla_expr<T> r(expr_type::MUL);
|
||||
r.children() = b.children();
|
||||
r.s_children() = b.s_children();
|
||||
r.add_child(a);
|
||||
return r;
|
||||
}
|
||||
|
|
|
@ -491,7 +491,9 @@ namespace smt {
|
|||
std::stringstream strm;
|
||||
strm << "lemma_" << (++m_lemma_id) << ".smt2";
|
||||
std::ofstream out(strm.str());
|
||||
TRACE("lemma", tout << strm.str() << "\n";);
|
||||
TRACE("lemma", tout << strm.str() << "\n";
|
||||
display_lemma_as_smt_problem(tout, num_antecedents, antecedents, num_eq_antecedents, eq_antecedents, consequent, logic);
|
||||
);
|
||||
display_lemma_as_smt_problem(out, num_antecedents, antecedents, num_eq_antecedents, eq_antecedents, consequent, logic);
|
||||
out.close();
|
||||
|
||||
|
|
|
@ -72,11 +72,32 @@ void test_cn() {
|
|||
enable_trace("nla_cn");
|
||||
// (a(a+(b+c)c+d)d + e(a(e+c)+d)
|
||||
nex a = nex::var(0), b = nex::var(1), c = nex::var(2), d = nex::var(3), e = nex::var(4);
|
||||
nex t = a*a*d + a*b*c*d + a*c*c*d + a*d*d + e*a*e + e*a*c + e*d;
|
||||
std::cout << "t = " << t << "\n";
|
||||
TRACE("nla_cn", tout << "t=" << t << '\n';);
|
||||
cross_nested cn(t, [](const nex& n) { std::cout << n << "\n"; } );
|
||||
cn.run();
|
||||
{
|
||||
nex t = a*a*d + a*b*c*d + a*c*c*d + a*d*d + e*a*e + e*a*c + e*d;
|
||||
std::cout << "t = " << t << "\n";
|
||||
TRACE("nla_cn", tout << "t=" << t << '\n';);
|
||||
cross_nested cn(t, [](const nex& n) {
|
||||
std::cout << n << "\n";
|
||||
auto nn = n;
|
||||
nn.sort();
|
||||
std::cout << "ordered version\n" << nn << "\n______________________\n";
|
||||
|
||||
} );
|
||||
cn.run();
|
||||
}
|
||||
{
|
||||
nex t = a*b*d + a*b*c;
|
||||
std::cout << "t = " << t << "\n";
|
||||
TRACE("nla_cn", tout << "t=" << t << '\n';);
|
||||
cross_nested cn(t, [](const nex& n) {
|
||||
std::cout << n << "\n";
|
||||
auto nn = n;
|
||||
nn.sort();
|
||||
std::cout << "ordered version\n" << nn << "\n______________________\n";
|
||||
|
||||
} );
|
||||
cn.run();
|
||||
}
|
||||
}
|
||||
|
||||
} // end of namespace nla
|
||||
|
|
Loading…
Reference in a new issue