3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-15 13:28:47 +00:00

sort nla_expr

Signed-off-by: Lev Nachmanson <levnach@hotmail.com>
This commit is contained in:
Lev Nachmanson 2019-07-15 15:07:59 -07:00
parent 00d366e3b8
commit 8eaa2bfb02
5 changed files with 103 additions and 20 deletions

View file

@ -124,7 +124,7 @@ interv horner::interval_of_expr(const nex& e) {
}
}
interv horner::interval_of_mul(const vector<nex>& es) {
interv horner::interval_of_mul(const std::vector<nex>& es) {
interv a = interval_of_expr(es[0]);
// std::cout << "a" << std::endl;
TRACE("nla_cn_details", tout << "es[0]= "<< es[0] << std::endl << "a = "; m_intervals.display(tout, a); tout << "\n";);
@ -150,7 +150,7 @@ interv horner::interval_of_mul(const vector<nex>& es) {
return a;
}
interv horner::interval_of_sum(const vector<nex>& es) {
interv horner::interval_of_sum(const std::vector<nex>& es) {
interv a = interval_of_expr(es[0]);
TRACE("nla_cn_details", tout << "es[0]= " << es[0] << "\n"; m_intervals.display(tout, a) << "\n";);
if (m_intervals.is_inf(a)) {

View file

@ -41,8 +41,8 @@ public:
intervals::interval interval_of_expr(const nex& e);
nex nexvar(lpvar j) const;
intervals::interval interval_of_sum(const vector<nex>&);
intervals::interval interval_of_mul(const vector<nex>&);
intervals::interval interval_of_sum(const std::vector<nex>&);
intervals::interval interval_of_mul(const std::vector<nex>&);
void set_interval_for_scalar(intervals::interval&, const rational&);
void set_var_interval(lpvar j, intervals::interval&);
std::set<lpvar> get_vars_of_expr(const nex &) const;

View file

@ -44,14 +44,40 @@ inline std::ostream & operator<<(std::ostream& out, expr_type t) {
}
return out;
}
// This class is needed in horner calculation with intervals
template <typename T>
class nla_expr {
class sorted_children {
std::vector<nla_expr> m_es;
// m_order will be sorted according to the non-decreasing order of m_es
svector<unsigned> m_order;
public:
const std::vector<nla_expr>& es() const { return m_es; }
std::vector<nla_expr>& es() { return m_es; }
void push_back(const nla_expr& e) {
SASSERT(m_es.size() == m_order.size());
m_order.push_back(m_es.size());
m_es.push_back(e);
}
const svector<unsigned>& order() const { return m_order; }
const nla_expr& back() const { return m_es.back(); }
nla_expr& back() { return m_es.back(); }
const nla_expr* begin() const { return m_es.begin(); }
const nla_expr* end() const { return m_es.end(); }
unsigned size() const { return m_es.size(); }
void sort() {
std::sort(m_order.begin(), m_order.end(), [this](unsigned i, unsigned j) { return m_es[i] < m_es[j]; });
}
};
// todo: use union
expr_type m_type;
lpvar m_j;
T m_v; // for the scalar
vector<nla_expr> m_children;
sorted_children m_children;
public:
bool is_sum() const { return m_type == expr_type::SUM; }
bool is_var() const { return m_type == expr_type::VAR; }
@ -60,13 +86,16 @@ public:
lpvar var() const { SASSERT(m_type == expr_type::VAR); return m_j; }
expr_type type() const { return m_type; }
expr_type& type() { return m_type; }
const vector<nla_expr>& children() const { return m_children; }
vector<nla_expr>& children() { return m_children; }
const std::vector<nla_expr>& children() const { return m_children.es(); }
std::vector<nla_expr>& children() { return m_children.es(); }
const sorted_children & s_children() const { return m_children; }
sorted_children & s_children() { return m_children; }
const T& value() const { SASSERT(m_type == expr_type::SCALAR); return m_v; }
std::string str() const { std::stringstream ss; ss << *this; return ss.str(); }
std::ostream & print_sum(std::ostream& out) const {
bool first = true;
for (const nla_expr& v : m_children) {
for (unsigned j : m_children.order()) {
const nla_expr& v = m_children.es()[j];
std::string s = v.str();
if (first) {
first = false;
@ -88,9 +117,18 @@ public:
}
return out;
}
void sort () {
if (is_sum() || is_mul()) {
for (auto & e : m_children.es()) {
e.sort();
}
m_children.sort();
}
}
std::ostream & print_mul(std::ostream& out) const {
bool first = true;
for (const nla_expr& v : m_children) {
for (unsigned j : m_children.order()) {
const nla_expr& v = m_children.es()[j];
std::string s = v.str();
if (first) {
first = false;
@ -221,13 +259,35 @@ public:
}
}
return false;
}
}
bool operator<(const nla_expr& e) const {
if (type() != (e.type()))
return (int)type() < (int)(e.type());
SASSERT(type() == (e.type()));
switch(m_type) {
case expr_type::SUM:
case expr_type::MUL:
return m_children.es() < e.m_children.es();
case expr_type::VAR:
return m_j < e.m_j;
case expr_type::SCALAR:
return m_v < e.m_v;
default:
SASSERT(false);
return false;
}
}
};
template <typename T>
nla_expr<T> operator+(const nla_expr<T>& a, const nla_expr<T>& b) {
if (a.is_sum()) {
nla_expr<T> r(expr_type::SUM);
r.children() = a.children();
r.s_children() = a.s_children();
if (b.is_sum()) {
for (auto& e: b.children())
r.add_child(e);
@ -238,7 +298,7 @@ nla_expr<T> operator+(const nla_expr<T>& a, const nla_expr<T>& b) {
}
if (b.is_sum()) {
nla_expr<T> r(expr_type::SUM);
r.children() = b.children();
r.s_children() = b.s_children();
r.add_child(a);
return r;
}
@ -249,7 +309,7 @@ template <typename T>
nla_expr<T> operator*(const nla_expr<T>& a, const nla_expr<T>& b) {
if (a.is_mul()) {
nla_expr<T> r(expr_type::MUL);
r.children() = a.children();
r.s_children() = a.s_children();
if (b.is_mul()) {
for (auto& e: b.children())
r.add_child(e);
@ -260,7 +320,7 @@ nla_expr<T> operator*(const nla_expr<T>& a, const nla_expr<T>& b) {
}
if (b.is_mul()) {
nla_expr<T> r(expr_type::MUL);
r.children() = b.children();
r.s_children() = b.s_children();
r.add_child(a);
return r;
}

View file

@ -491,7 +491,9 @@ namespace smt {
std::stringstream strm;
strm << "lemma_" << (++m_lemma_id) << ".smt2";
std::ofstream out(strm.str());
TRACE("lemma", tout << strm.str() << "\n";);
TRACE("lemma", tout << strm.str() << "\n";
display_lemma_as_smt_problem(tout, num_antecedents, antecedents, num_eq_antecedents, eq_antecedents, consequent, logic);
);
display_lemma_as_smt_problem(out, num_antecedents, antecedents, num_eq_antecedents, eq_antecedents, consequent, logic);
out.close();

View file

@ -72,11 +72,32 @@ void test_cn() {
enable_trace("nla_cn");
// (a(a+(b+c)c+d)d + e(a(e+c)+d)
nex a = nex::var(0), b = nex::var(1), c = nex::var(2), d = nex::var(3), e = nex::var(4);
nex t = a*a*d + a*b*c*d + a*c*c*d + a*d*d + e*a*e + e*a*c + e*d;
std::cout << "t = " << t << "\n";
TRACE("nla_cn", tout << "t=" << t << '\n';);
cross_nested cn(t, [](const nex& n) { std::cout << n << "\n"; } );
cn.run();
{
nex t = a*a*d + a*b*c*d + a*c*c*d + a*d*d + e*a*e + e*a*c + e*d;
std::cout << "t = " << t << "\n";
TRACE("nla_cn", tout << "t=" << t << '\n';);
cross_nested cn(t, [](const nex& n) {
std::cout << n << "\n";
auto nn = n;
nn.sort();
std::cout << "ordered version\n" << nn << "\n______________________\n";
} );
cn.run();
}
{
nex t = a*b*d + a*b*c;
std::cout << "t = " << t << "\n";
TRACE("nla_cn", tout << "t=" << t << '\n';);
cross_nested cn(t, [](const nex& n) {
std::cout << n << "\n";
auto nn = n;
nn.sort();
std::cout << "ordered version\n" << nn << "\n______________________\n";
} );
cn.run();
}
}
} // end of namespace nla