diff --git a/src/ast/rewriter/bool_rewriter.cpp b/src/ast/rewriter/bool_rewriter.cpp index 0bfdfa01a..1e3449ef3 100644 --- a/src/ast/rewriter/bool_rewriter.cpp +++ b/src/ast/rewriter/bool_rewriter.cpp @@ -34,6 +34,7 @@ void bool_rewriter::updt_params(params_ref const & _p) { m_blast_distinct = p.blast_distinct(); m_blast_distinct_threshold = p.blast_distinct_threshold(); m_ite_extra_rules = p.ite_extra_rules(); + m_elim_ite_value_tree = p.elim_ite_value_tree(); } void bool_rewriter::get_param_descrs(param_descrs & r) { @@ -680,9 +681,84 @@ br_status bool_rewriter::try_ite_value(app * ite, app * val, expr_ref & result) return BR_REWRITE2; } + if (m_elim_ite_value_tree) { + result = simplify_eq_ite(val, ite); + if (result) + return BR_REWRITE_FULL; + } + return BR_FAILED; } +expr_ref bool_rewriter::simplify_eq_ite(expr* value, expr* ite) { + SASSERT(m().is_value(value)); + SASSERT(m().is_ite(ite)); + expr* c = nullptr, * t = nullptr, * e = nullptr; + ptr_buffer todo; + ptr_vector values; + expr_ref_vector pinned(m()); + expr_ref r(m()); + todo.push_back(ite); + while (!todo.empty()) { + expr* arg = todo.back(); + if (m().is_value(arg)) { + todo.pop_back(); + if (m().are_equal(arg, value)) { + values.setx(arg->get_id(), m().mk_true(), nullptr); + continue; + } + if (m().are_distinct(arg, value)) { + values.setx(arg->get_id(), m().mk_false(), nullptr); + continue; + } + return expr_ref(nullptr, m()); + } + if (m().is_ite(arg, c, t, e)) { + unsigned sz = todo.size(); + if (!values.get(t->get_id(), nullptr)) + todo.push_back(t); + + if (!values.get(e->get_id(), nullptr)) + todo.push_back(e); + + if (sz < todo.size()) + continue; + todo.pop_back(); + if (m().is_true(values[t->get_id()])) { + r = m().mk_or(c, values[e->get_id()]); + values.setx(arg->get_id(), r, nullptr); + pinned.push_back(r); + continue; + } + if (m().is_false(values[t->get_id()])) { + r = m().mk_and(m().mk_not(c), values[e->get_id()]); + values.setx(arg->get_id(), r, nullptr); + pinned.push_back(r); + continue; + } + if (m().is_false(values[e->get_id()])) { + r = m().mk_and(c, values[t->get_id()]); + values.setx(arg->get_id(), r, nullptr); + pinned.push_back(r); + continue; + } + if (m().is_true(values[e->get_id()])) { + r = m().mk_or(m().mk_not(c), values[t->get_id()]); + values.setx(arg->get_id(), r, nullptr); + pinned.push_back(r); + continue; + } + r = m().mk_ite(c, values[t->get_id()], values[e->get_id()]); + values.setx(arg->get_id(), r, nullptr); + pinned.push_back(r); + continue; + } + IF_VERBOSE(10, verbose_stream() << "bail " << mk_bounded_pp(arg, m()) << "\n"); + return expr_ref(nullptr, m()); + } + return expr_ref(values[ite->get_id()], m()); +} + app* bool_rewriter::mk_eq_plain(expr* lhs, expr* rhs) { if (m().are_equal(lhs, rhs)) diff --git a/src/ast/rewriter/bool_rewriter.h b/src/ast/rewriter/bool_rewriter.h index aec8e0700..44aed881f 100644 --- a/src/ast/rewriter/bool_rewriter.h +++ b/src/ast/rewriter/bool_rewriter.h @@ -61,6 +61,7 @@ class bool_rewriter { unsigned m_local_ctx_limit; unsigned m_local_ctx_cost; bool m_elim_ite; + bool m_elim_ite_value_tree; ptr_vector m_todo1, m_todo2; unsigned_vector m_counts1, m_counts2; @@ -83,6 +84,8 @@ class bool_rewriter { void push_new_arg(expr* arg, expr_ref_vector& new_args, expr_fast_mark1& neg_lits, expr_fast_mark2& pos_lits); + expr_ref simplify_eq_ite(expr* value, expr* ite); + public: bool_rewriter(ast_manager & m, params_ref const & p = params_ref()):m_manager(m), m_local_ctx_cost(0) { updt_params(p); diff --git a/src/params/bool_rewriter_params.pyg b/src/params/bool_rewriter_params.pyg index 87578470e..e10e08d46 100644 --- a/src/params/bool_rewriter_params.pyg +++ b/src/params/bool_rewriter_params.pyg @@ -7,6 +7,7 @@ def_module_params(module_name='rewriter', ("sort_disjunctions", BOOL, True, "sort subterms in disjunctions"), ("elim_and", BOOL, False, "conjunctions are rewritten using negation and disjunctions"), ('elim_ite', BOOL, True, "eliminate ite in favor of and/or"), + ('elim_ite_value_tree', BOOL, False, "eliminate equations 'v = ite(...)' where v is a value and each leaf in the ite tree is a value"), ("local_ctx", BOOL, False, "perform local (i.e., cheap) context simplifications"), ("local_ctx_limit", UINT, UINT_MAX, "limit for applying local context simplifier"), ("blast_distinct", BOOL, False, "expand a distinct predicate into a quadratic number of disequalities"),