diff --git a/src/ast/rewriter/bv_rewriter.cpp b/src/ast/rewriter/bv_rewriter.cpp index f77bce5c2..def05f014 100644 --- a/src/ast/rewriter/bv_rewriter.cpp +++ b/src/ast/rewriter/bv_rewriter.cpp @@ -2094,7 +2094,7 @@ br_status bv_rewriter::mk_eq_core(expr * lhs, expr * rhs, expr_ref & result) { if (m_trailing) { st = m_rm_trailing.eq_remove_trailing(lhs, rhs, result); - //m_rm_trailing.reset_cache(); + m_rm_trailing.reset_cache(1 << 12); if (st != BR_FAILED) { TRACE("eq_remove_trailing", tout << mk_ismt2_pp(lhs, m()) << "\n=\n" << mk_ismt2_pp(rhs, m()) << "\n----->\n" << mk_ismt2_pp(result, m()) << "\n";); return st; diff --git a/src/ast/rewriter/bv_trailing.cpp b/src/ast/rewriter/bv_trailing.cpp index ef6a1e577..ed2a17a73 100644 --- a/src/ast/rewriter/bv_trailing.cpp +++ b/src/ast/rewriter/bv_trailing.cpp @@ -18,6 +18,8 @@ #include"bv_decl_plugin.h" #include"ast_smt2_pp.h" +// This is not very elegant at this point, this number shouldn't be too big, +// give up analysis after TRAILING_DEPTH depth. #define TRAILING_DEPTH 4 struct bv_trailing::imp { @@ -35,38 +37,11 @@ struct bv_trailing::imp { { } virtual ~imp() { - reset_cache(); + reset_cache(0); } ast_manager & m() const { return m_util.get_manager(); } - void cache(unsigned depth, expr * e, unsigned min, unsigned max) { - SASSERT(depth <= TRAILING_DEPTH); - m().inc_ref(e); - m_count_cache[depth].insert(e, std::make_pair(min, max)); - TRACE("bv-trailing", tout << "caching@" << depth <<": " << mk_ismt2_pp(e, m()) << '[' << m_util.get_bv_size(e) << "]\n: " << min << '-' << max << "\n";); - } - - bool is_cached(unsigned depth, expr * e, unsigned& min, unsigned& max) { - SASSERT(depth <= TRAILING_DEPTH); - const map::obj_map_entry * const oe = m_count_cache[depth].find_core(e); - if (oe == NULL) return false; - min = oe->get_data().m_value.first; - max = oe->get_data().m_value.second; - TRACE("bv-trailing", tout << "cached@" << depth << ": " << mk_ismt2_pp(e, m()) << '[' << m_util.get_bv_size(e) << "]\n: " << min << '-' << max << "\n";); - return true; - } - - - void reset_cache() { - for (unsigned i = 0; i <= TRAILING_DEPTH; ++i) { - map::iterator it = m_count_cache[i].begin(); - map::iterator end = m_count_cache[i].end(); - for (; it != end; ++it) m().dec_ref(it->m_key); - m_count_cache[i].reset(); - } - } - br_status eq_remove_trailing(expr * e1, expr * e2, expr_ref& result) { TRACE("bv-trailing", tout << mk_ismt2_pp(e1, m()) << "\n=\n" << mk_ismt2_pp(e2, m()) << "\n";); SASSERT(m_util.is_bv(e1) && m_util.is_bv(e2)); @@ -74,28 +49,51 @@ struct bv_trailing::imp { unsigned max1, min1, max2, min2; count_trailing(e1, min1, max1, TRAILING_DEPTH); count_trailing(e2, min2, max2, TRAILING_DEPTH); - if (min1 > max2 || min2 > max1) { + if (min1 > max2 || min2 > max1) { // bounds have empty intersection result = m().mk_false(); return BR_DONE; } - const unsigned min = std::min(min1, min2); - if (min == 0) { + const unsigned min = std::min(min1, min2); // remove the minimum of the two lower bounds + if (min == 0) { // nothing to remove result = m().mk_eq(e1, e2); return BR_FAILED; } const unsigned sz = m_util.get_bv_size(e1); - if (min == sz) { // unlikely but we check anyhow for safety + if (min == sz) { // everything removed, unlikely but we check anyhow for safety result = m().mk_true(); return BR_DONE; } expr_ref out1(m()); expr_ref out2(m()); - remove_trailing(e1, min, out1, TRAILING_DEPTH); - remove_trailing(e2, min, out2, TRAILING_DEPTH); - result = m().mk_eq(out1, out2); - return BR_REWRITE2; + const unsigned rm1 = remove_trailing(e1, min, out1, TRAILING_DEPTH); + const unsigned rm2 = remove_trailing(e2, min, out2, TRAILING_DEPTH); + SASSERT(rm1 == min && rm2 == min); + const bool are_eq = m().are_equal(out1, out2); + result = are_eq ? m().mk_true() : m().mk_eq(out1, out2); + return are_eq ? BR_DONE : BR_REWRITE2; } + // This routine needs to be implemented carefully so that whenever it + // returns a lower bound on trailing zeros min, the routine remove_trailing + // must be capable of removing at least that many zeros from the expression. + void count_trailing(expr * e, unsigned& min, unsigned& max, unsigned depth) { + SASSERT(e && m_util.is_bv(e)); + if (is_cached(depth, e, min, max)) return; + count_trailing_core(e, min, max, depth); + TRACE("bv-trailing", tout << mk_ismt2_pp(e, m()) << "\n:" << min << " - " << max << "\n";); + SASSERT(min <= max); + SASSERT(max <= m_util.get_bv_size(e)); + cache(depth, e, min, max); // store result into the cache + } + + unsigned remove_trailing(expr * e, unsigned n, expr_ref& result, unsigned depth) { + const unsigned retv = remove_trailing_core(e, n, result, depth); + CTRACE("bv-trailing", result.get(), tout << mk_ismt2_pp(e, m()) << "\n--->\n" << mk_ismt2_pp(result.get(), m()) << "\n";); + CTRACE("bv-trailing", !result.get(), tout << mk_ismt2_pp(e, m()) << "\n---> [EMPTY]\n";); + return retv; + } + + // assumes that count_trailing gives me a lower bound, which we can also remove from each summand unsigned remove_trailing_add(app * a, unsigned n, expr_ref& result, unsigned depth) { SASSERT(m_util.is_bv_add(a)); const unsigned num = a->get_num_args(); @@ -104,7 +102,7 @@ struct bv_trailing::imp { return 0; } unsigned min, max; - count_trailing(a, min, max, depth); + count_trailing(a, min, max, depth); // caching is important here const unsigned to_rm = std::min(min, n); if (to_rm == 0) { result = a; @@ -177,7 +175,7 @@ struct bv_trailing::imp { result = a; return 0; } - unsigned num = a->get_num_args(); + const unsigned num = a->get_num_args(); unsigned retv = 0; unsigned i = num; expr_ref new_last(NULL, m()); @@ -221,13 +219,6 @@ struct bv_trailing::imp { return retv; } - unsigned remove_trailing(expr * e, unsigned n, expr_ref& result, unsigned depth) { - const unsigned retv = remove_trailing_core(e, n, result, depth); - CTRACE("bv-trailing", result.get(), tout << mk_ismt2_pp(e, m()) << "\n--->\n" << mk_ismt2_pp(result.get(), m()) << "\n";); - CTRACE("bv-trailing", !result.get(), tout << mk_ismt2_pp(e, m()) << "\n---> [EMPTY]\n";); - return retv; - } - unsigned remove_trailing_core(expr * e, unsigned n, expr_ref& result, unsigned depth) { SASSERT(m_util.is_bv(e)); if (!depth) return 0; @@ -250,16 +241,6 @@ struct bv_trailing::imp { return 0; } - void count_trailing(expr * e, unsigned& min, unsigned& max, unsigned depth) { - if (is_cached(depth, e, min, max)) - return; - SASSERT(e && m_util.is_bv(e)); - count_trailing_core(e, min, max, depth); - TRACE("bv-trailing", tout << mk_ismt2_pp(e, m()) << "\n:" << min << " - " << max << "\n";); - SASSERT(min <= max); - SASSERT(max <= m_util.get_bv_size(e)); - cache(depth, e, min, max); // store result into the cache - } void count_trailing_concat(app * a, unsigned& min, unsigned& max, unsigned depth) { if (depth <= 1) { @@ -352,6 +333,35 @@ struct bv_trailing::imp { max = m_util.get_bv_size(e); } } + + void cache(unsigned depth, expr * e, unsigned min, unsigned max) { + SASSERT(depth <= TRAILING_DEPTH); + m().inc_ref(e); + m_count_cache[depth].insert(e, std::make_pair(min, max)); + TRACE("bv-trailing", tout << "caching@" << depth <<": " << mk_ismt2_pp(e, m()) << '[' << m_util.get_bv_size(e) << "]\n: " << min << '-' << max << "\n";); + } + + bool is_cached(unsigned depth, expr * e, unsigned& min, unsigned& max) { + SASSERT(depth <= TRAILING_DEPTH); + const map::obj_map_entry * const oe = m_count_cache[depth].find_core(e); + if (oe == NULL) return false; + min = oe->get_data().m_value.first; + max = oe->get_data().m_value.second; + TRACE("bv-trailing", tout << "cached@" << depth << ": " << mk_ismt2_pp(e, m()) << '[' << m_util.get_bv_size(e) << "]\n: " << min << '-' << max << "\n";); + return true; + } + + + void reset_cache(unsigned condition) { + for (unsigned i = 0; i <= TRAILING_DEPTH; ++i) { + if (m_count_cache[i].size() < condition) continue; + map::iterator it = m_count_cache[i].begin(); + map::iterator end = m_count_cache[i].end(); + for (; it != end; ++it) m().dec_ref(it->m_key); + m_count_cache[i].reset(); + } + } + }; bv_trailing::bv_trailing(mk_extract_proc& mk_extract) { @@ -366,10 +376,14 @@ br_status bv_trailing::eq_remove_trailing(expr * e1, expr * e2, expr_ref& resul return m_imp->eq_remove_trailing(e1, e2, result); } -unsigned bv_trailing::remove_trailing(expr * e, unsigned n, expr_ref& result, unsigned depth) { - return m_imp->remove_trailing(e, n, result, depth); +void bv_trailing::count_trailing(expr * e, unsigned& min, unsigned& max) { + m_imp->count_trailing(e, min, max, TRAILING_DEPTH); } -void bv_trailing::reset_cache() { - m_imp->reset_cache(); +unsigned bv_trailing::remove_trailing(expr * e, unsigned n, expr_ref& result) { + return m_imp->remove_trailing(e, n, result, TRAILING_DEPTH); +} + +void bv_trailing::reset_cache(unsigned condition) { + m_imp->reset_cache(condition); } diff --git a/src/ast/rewriter/bv_trailing.h b/src/ast/rewriter/bv_trailing.h index 0af909e8f..862a1bea6 100644 --- a/src/ast/rewriter/bv_trailing.h +++ b/src/ast/rewriter/bv_trailing.h @@ -7,6 +7,8 @@ Abstract: + A utility to count trailing zeros of an expression. Treats 2x and x++0 equivalently. + Author: @@ -23,10 +25,20 @@ class bv_trailing { public: bv_trailing(mk_extract_proc& ep); virtual ~bv_trailing(); - void count_trailing(expr * e, unsigned& min, unsigned& max, unsigned depth); + public: + // Remove trailing zeros from both sides of an equality (might give False). br_status eq_remove_trailing(expr * e1, expr * e2, expr_ref& result); - unsigned remove_trailing(expr * e, unsigned n, expr_ref& result, unsigned depth); - void reset_cache(); + + // Gives a lower and upper bound on trailing zeros in e. + void count_trailing(expr * e, unsigned& min, unsigned& max); + + // Attempts removing n trailing zeros from e. Returns how many were successfully removed. + // We're assuming that it can remove at least as many zeros as min returned by count_training. + // Removing the bit-width of e, sets result to NULL. + unsigned remove_trailing(expr * e, unsigned n, expr_ref& result); + + // Reset cache(s) if it exceeded size condition. + void reset_cache(unsigned condition); protected: struct imp; imp * m_imp;