diff --git a/src/ast/for_each_expr.cpp b/src/ast/for_each_expr.cpp index 1e7b6da3b..374d7b496 100644 --- a/src/ast/for_each_expr.cpp +++ b/src/ast/for_each_expr.cpp @@ -44,6 +44,60 @@ unsigned get_num_exprs(expr * n) { return get_num_exprs(n, visited); } + +static void get_num_internal_exprs(unsigned_vector& counts, sbuffer& todo, expr * n) { + counts.reserve(n->get_id() + 1); + unsigned& rc = counts[n->get_id()]; + if (rc > 0) { + --rc; + return; + } + rc = n->get_ref_count() - 1; + unsigned i = todo.size(); + todo.push_back(n); + unsigned count = 0; + for (; i < todo.size(); ++i) { + n = todo[i]; + if (!is_app(n)) + continue; + for (expr* arg : *to_app(n)) { + unsigned id = arg->get_id(); + counts.reserve(id + 1); + unsigned& rc = counts[id]; + if (rc > 0) { + --rc; + continue; + } + rc = arg->get_ref_count() - 1; + todo.push_back(arg); + } + } +} + +unsigned get_num_internal_exprs(expr * n) { + unsigned_vector counts; + sbuffer todo; + unsigned internal_nodes = 0; + get_num_internal_exprs(counts, todo, n); + for (expr* t : todo) + if (counts[t->get_id()] == 0) + ++internal_nodes; + return internal_nodes; +} + +unsigned get_num_internal_exprs(unsigned sz, expr * const * args) { + unsigned_vector counts; + sbuffer todo; + unsigned internal_nodes = 0; + for (unsigned i = 0; i < sz; ++i) + get_num_internal_exprs(counts, todo, args[i]); + for (expr* t : todo) + if (counts[t->get_id()] == 0) + ++internal_nodes; + return internal_nodes; +} + + namespace has_skolem_functions_ns { struct found {}; struct proc { diff --git a/src/ast/for_each_expr.h b/src/ast/for_each_expr.h index 2d5ed05ae..c94348964 100644 --- a/src/ast/for_each_expr.h +++ b/src/ast/for_each_expr.h @@ -163,6 +163,8 @@ struct for_each_expr_proc : public EscapeProc { unsigned get_num_exprs(expr * n); unsigned get_num_exprs(expr * n, expr_mark & visited); unsigned get_num_exprs(expr * n, expr_fast_mark1 & visited); +unsigned get_num_internal_exprs(expr * n); +unsigned get_num_internal_exprs(unsigned sz, expr * const * args); bool has_skolem_functions(expr * n); diff --git a/src/ast/rewriter/bool_rewriter.cpp b/src/ast/rewriter/bool_rewriter.cpp index 378c794cd..9ebdbe7fd 100644 --- a/src/ast/rewriter/bool_rewriter.cpp +++ b/src/ast/rewriter/bool_rewriter.cpp @@ -20,6 +20,7 @@ Notes: #include "params/bool_rewriter_params.hpp" #include "ast/rewriter/rewriter_def.h" #include "ast/ast_lt.h" +#include "ast/for_each_expr.h" #include void bool_rewriter::updt_params(params_ref const & _p) { @@ -268,14 +269,18 @@ br_status bool_rewriter::mk_nflat_or_core(unsigned num_args, expr * const * args return BR_DONE; } -#if 1 br_status st; st = m_hoist.mk_or(buffer.size(), buffer.data(), result); + if (st != BR_FAILED) { + unsigned count1 = get_num_internal_exprs(result); + unsigned count2 = get_num_internal_exprs(buffer.size(), buffer.data()); + if (count1 > count2) + st = BR_FAILED; + } if (st == BR_DONE) return BR_REWRITE1; if (st != BR_FAILED) return st; -#endif if (s) { ast_lt lt;