From 9cba261a9c50d4bb3dce647b8edea76d447a97c1 Mon Sep 17 00:00:00 2001 From: Lev Nachmanson Date: Mon, 15 Jul 2019 20:27:30 -0700 Subject: [PATCH] simplify and order nla_expr Signed-off-by: Lev Nachmanson --- src/math/lp/nla_expr.h | 131 ++++++++++++++++++++++++++++++----------- src/test/lp/lp.cpp | 45 ++++++-------- 2 files changed, 115 insertions(+), 61 deletions(-) diff --git a/src/math/lp/nla_expr.h b/src/math/lp/nla_expr.h index 83a800249..bdb7a5d07 100644 --- a/src/math/lp/nla_expr.h +++ b/src/math/lp/nla_expr.h @@ -1,21 +1,21 @@ /*++ -Copyright (c) 2017 Microsoft Corporation + Copyright (c) 2017 Microsoft Corporation -Module Name: + Module Name: - + -Abstract: + Abstract: - + -Author: - Lev Nachmanson (levnach) + Author: + Lev Nachmanson (levnach) -Revision History: + Revision History: ---*/ + --*/ #pragma once #include #include "math/lp/nla_defs.h" @@ -50,28 +50,28 @@ inline std::ostream & operator<<(std::ostream& out, expr_type t) { template class nla_expr { -class sorted_children { - std::vector m_es; - // m_order will be sorted according to the non-decreasing order of m_es - svector m_order; -public: - const std::vector& es() const { return m_es; } - std::vector& es() { return m_es; } - void push_back(const nla_expr& e) { - SASSERT(m_es.size() == m_order.size()); - m_order.push_back(m_es.size()); - m_es.push_back(e); - } - const svector& order() const { return m_order; } - const nla_expr& back() const { return m_es.back(); } - nla_expr& back() { return m_es.back(); } - const nla_expr* begin() const { return m_es.begin(); } - const nla_expr* end() const { return m_es.end(); } - unsigned size() const { return m_es.size(); } - void sort() { - std::sort(m_order.begin(), m_order.end(), [this](unsigned i, unsigned j) { return m_es[i] < m_es[j]; }); - } -}; + class sorted_children { + std::vector m_es; + // m_order will be sorted according to the non-decreasing order of m_es + svector m_order; + public: + const std::vector& es() const { return m_es; } + std::vector& es() { return m_es; } + void push_back(const nla_expr& e) { + SASSERT(m_es.size() == m_order.size()); + m_order.push_back(m_es.size()); + m_es.push_back(e); + } + const svector& order() const { return m_order; } + const nla_expr& back() const { return m_es.back(); } + nla_expr& back() { return m_es.back(); } + const nla_expr* begin() const { return m_es.begin(); } + const nla_expr* end() const { return m_es.end(); } + unsigned size() const { return m_es.size(); } + void sort() { + std::sort(m_order.begin(), m_order.end(), [this](unsigned i, unsigned j) { return m_es[i] < m_es[j]; }); + } + }; // todo: use union expr_type m_type; @@ -125,6 +125,43 @@ public: m_children.sort(); } } + + void simplify() { + if (is_sum()) { + bool has_sum = false; + for (unsigned j = 0; j < m_children.es().size(); j++) { + auto& e = m_children.es()[j]; + e.simplify(); + if (e.is_sum()) + has_sum = true; + } + if (has_sum) { + nla_expr n(expr_type::SUM); + for (auto &e : m_children.es()) { + n += e; + } + *this = n; + } + + } else if (is_mul()) { + bool has_mul = false; + for (unsigned j = 0; j < m_children.es().size(); j++) { + auto& e = m_children.es()[j]; + e.simplify(); + if (e.is_mul()) + has_mul = true; + } + if (has_mul) { + nla_expr n(expr_type::MUL); + for (auto &e : m_children.es()) { + n *= e; + } + *this = n; + } + TRACE("nla_cn", tout << "simplified " << *this << "\n";); + } + } + std::ostream & print_mul(std::ostream& out) const { bool first = true; for (unsigned j : m_children.order()) { @@ -278,10 +315,36 @@ public: SASSERT(false); return false; } - - - } + + nla_expr& operator*=(const nla_expr& b) { + if (is_mul()) { + if (b.is_mul()) { + for (auto& e: b.children()) + add_child(e); + } else { + add_child(b); + } + return *this; + } + SASSERT(false); // not impl + return *this; + } + + nla_expr& operator+=(const nla_expr& b) { + if (is_sum()) { + if (b.is_sum()) { + for (auto& e: b.children()) + add_child(e); + } else { + add_child(b); + } + return *this; + } + SASSERT(false); // not impl + return *this; + } + }; template nla_expr operator+(const nla_expr& a, const nla_expr& b) { diff --git a/src/test/lp/lp.cpp b/src/test/lp/lp.cpp index 010e570a9..12c7f61ce 100644 --- a/src/test/lp/lp.cpp +++ b/src/test/lp/lp.cpp @@ -67,37 +67,28 @@ void test_basic_lemma_for_mon_zero_from_factors_to_monomial(); void test_basic_lemma_for_mon_neutral_from_monomial_to_factors(); void test_basic_lemma_for_mon_neutral_from_factors_to_monomial(); +void test_cn_on_expr(horner::nex t) { + TRACE("nla_cn", tout << "t=" << t << '\n';); + cross_nested cn(t, [](const horner::nex& n) { + TRACE("nla_cn", tout << n << "\n";); + auto nn = n; + nn.simplify(); + nn.sort(); + TRACE("nla_cn", tout << "ordered version\n" << nn << "\n______________________\n";); + + } ); + cn.run(); +} + void test_cn() { typedef horner::nex nex; enable_trace("nla_cn"); - // (a(a+(b+c)c+d)d + e(a(e+c)+d) + enable_trace("nla_cn_details"); nex a = nex::var(0), b = nex::var(1), c = nex::var(2), d = nex::var(3), e = nex::var(4); - { - nex t = a*a*d + a*b*c*d + a*c*c*d + a*d*d + e*a*e + e*a*c + e*d; - std::cout << "t = " << t << "\n"; - TRACE("nla_cn", tout << "t=" << t << '\n';); - cross_nested cn(t, [](const nex& n) { - std::cout << n << "\n"; - auto nn = n; - nn.sort(); - std::cout << "ordered version\n" << nn << "\n______________________\n"; - - } ); - cn.run(); - } - { - nex t = a*b*d + a*b*c; - std::cout << "t = " << t << "\n"; - TRACE("nla_cn", tout << "t=" << t << '\n';); - cross_nested cn(t, [](const nex& n) { - std::cout << n << "\n"; - auto nn = n; - nn.sort(); - std::cout << "ordered version\n" << nn << "\n______________________\n"; - - } ); - cn.run(); - } + + nex t = a*a*d + a*b*c*d + a*a*c*c*d + a*d*d + e*a*e + e*a*c + e*d; + test_cn_on_expr(t); + test_cn_on_expr(a*b*d + a*b*c); } } // end of namespace nla