From 1974c224ab74fb7957afba04fc3f13649a8929e8 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Sun, 4 Dec 2022 09:39:28 -0800 Subject: [PATCH] add demodulator simplifier refactor demodulator-rewriter a bit to separate reusable features. --- scripts/mk_project.py | 2 +- src/CMakeLists.txt | 2 +- src/ast/rewriter/th_rewriter.h | 1 + src/ast/simplifiers/CMakeLists.txt | 2 + .../simplifiers/demodulator_simplifier.cpp | 199 +++++++++ src/ast/simplifiers/demodulator_simplifier.h | 60 +++ src/ast/substitution/demodulator_rewriter.cpp | 377 +++++++++++++----- src/ast/substitution/demodulator_rewriter.h | 136 ++++--- 8 files changed, 624 insertions(+), 155 deletions(-) create mode 100644 src/ast/simplifiers/demodulator_simplifier.cpp create mode 100644 src/ast/simplifiers/demodulator_simplifier.h diff --git a/scripts/mk_project.py b/scripts/mk_project.py index a979359f3..a16913317 100644 --- a/scripts/mk_project.py +++ b/scripts/mk_project.py @@ -39,7 +39,7 @@ def init_project_def(): add_lib('macros', ['rewriter'], 'ast/macros') add_lib('model', ['macros']) add_lib('converters', ['model'], 'ast/converters') - add_lib('simplifiers', ['euf', 'normal_forms', 'bit_blaster', 'converters'], 'ast/simplifiers') + add_lib('simplifiers', ['euf', 'normal_forms', 'bit_blaster', 'converters', 'substitution'], 'ast/simplifiers') add_lib('tactic', ['simplifiers']) add_lib('solver', ['params', 'model', 'tactic', 'proofs']) add_lib('cmd_context', ['solver', 'rewriter', 'params']) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index dadd70bba..652aef4ac 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -52,9 +52,9 @@ add_subdirectory(ast/macros) add_subdirectory(model) add_subdirectory(ast/euf) add_subdirectory(ast/converters) +add_subdirectory(ast/substitution) add_subdirectory(ast/simplifiers) add_subdirectory(tactic) -add_subdirectory(ast/substitution) add_subdirectory(smt/params) add_subdirectory(parsers/util) add_subdirectory(math/grobner) diff --git a/src/ast/rewriter/th_rewriter.h b/src/ast/rewriter/th_rewriter.h index a3f003799..71c39b18e 100644 --- a/src/ast/rewriter/th_rewriter.h +++ b/src/ast/rewriter/th_rewriter.h @@ -52,6 +52,7 @@ public: expr_ref mk_app(func_decl* f, unsigned num_args, expr* const* args); expr_ref mk_app(func_decl* f, ptr_vector const& args) { return mk_app(f, args.size(), args.data()); } + expr_ref mk_app(func_decl* f, expr_ref_vector const& args) { return mk_app(f, args.size(), args.data()); } expr_ref mk_eq(expr* a, expr* b); bool reduce_quantifier(quantifier * old_q, diff --git a/src/ast/simplifiers/CMakeLists.txt b/src/ast/simplifiers/CMakeLists.txt index c6a8469ee..df44427cf 100644 --- a/src/ast/simplifiers/CMakeLists.txt +++ b/src/ast/simplifiers/CMakeLists.txt @@ -3,6 +3,7 @@ z3_add_component(simplifiers bit_blaster.cpp bv_slice.cpp card2bv.cpp + demodulator_simplifier.cpp dependent_expr_state.cpp elim_unconstrained.cpp eliminate_predicates.cpp @@ -18,4 +19,5 @@ z3_add_component(simplifiers rewriter bit_blaster normal_forms + substitution ) diff --git a/src/ast/simplifiers/demodulator_simplifier.cpp b/src/ast/simplifiers/demodulator_simplifier.cpp new file mode 100644 index 000000000..eb3585606 --- /dev/null +++ b/src/ast/simplifiers/demodulator_simplifier.cpp @@ -0,0 +1,199 @@ +/*++ +Copyright (c) 2022 Microsoft Corporation + +Module Name: + + demodulator_simplifier.cpp + +Author: + + Nikolaj Bjorner (nbjorner) 2022-12-4 + +--*/ + +#include "ast/simplifiers/demodulator_simplifier.h" + +demodulator_index::~demodulator_index() { + reset(); +} + +void demodulator_index::reset() { + for (auto& [k, v] : m_fwd_index) + dealloc(v); + for (auto& [k, v] : m_bwd_index) + dealloc(v); + m_fwd_index.reset(); + m_bwd_index.reset(); +} + +void demodulator_index::add(func_decl* f, unsigned i, obj_map& map) { + uint_set* s; + if (!map.find(f, s)) { + s = alloc(uint_set); + map.insert(f, s); + } + s->insert(i); +} + +void demodulator_index::del(func_decl* f, unsigned i, obj_map& map) { + uint_set* s; + if (map.find(f, s)) + s->remove(i); +} + +void demodulator_index::insert_bwd(expr* e, unsigned i) { + struct proc { + unsigned i; + demodulator_index& idx; + proc(unsigned i, demodulator_index& idx) :i(i), idx(idx) {} + void operator()(app* a) { + if (a->get_num_args() > 0 && is_uninterp(a)) + idx.add(a->get_decl(), i, idx.m_bwd_index); + } + void operator()(expr* e) {} + }; + proc p(i, *this); + for_each_expr(p, e); +} + +void demodulator_index::remove_bwd(expr* e, unsigned i) { + struct proc { + unsigned i; + demodulator_index& idx; + proc(unsigned i, demodulator_index& idx) :i(i), idx(idx) {} + void operator()(app* a) { + if (a->get_num_args() > 0 && is_uninterp(a)) + idx.del(a->get_decl(), i, idx.m_bwd_index); + } + void operator()(expr* e) {} + }; + proc p(i, *this); + for_each_expr(p, e); +} + +demodulator_simplifier::demodulator_simplifier(ast_manager& m, params_ref const& p, dependent_expr_state& st): + dependent_expr_simplifier(m, st), + m_util(m), + m_match_subst(m), + m_rewriter(m), + m_pinned(m) +{ + std::function rw = [&](func_decl* f, expr_ref_vector const& args, expr_ref& r) { + return rewrite1(f, args, r); + }; + m_rewriter.set_rewrite1(rw); +} + +void demodulator_simplifier::rewrite(unsigned i) { + if (m_index.empty()) + return; + + m_dependencies.reset(); + expr* f = fml(i); + expr_ref r = m_rewriter.rewrite(f); + if (r == f) + return; + expr_dependency_ref d(dep(i), m); + for (unsigned j : m_dependencies) + d = m.mk_join(d, dep(j)); + m_fmls.update(i, dependent_expr(m, r, d)); +} + +bool demodulator_simplifier::rewrite1(func_decl* f, expr_ref_vector const& args, expr_ref& np) { + uint_set* set; + if (!m_index.find_fwd(f, set)) + return false; + + TRACE("demodulator", tout << "trying to rewrite: " << f->get_name() << " args:\n" << m_new_args << "\n";); + + for (unsigned i : *set) { + + auto const& [lhs, rhs] = m_rewrites[i]; + + if (lhs->get_num_args() != args.size()) + continue; + + SASSERT(lhs->get_decl() == f); + + TRACE("demodulator", tout << "Matching with demodulator: " << mk_pp(d, m) << std::endl; ); + + if (m_match_subst(lhs, rhs, args.data(), np)) { + TRACE("demodulator_bug", tout << "succeeded...\n" << mk_pp(rhs, m) << "\n===>\n" << np << "\n";); + m_dependencies.insert(i); + return true; + } + } + + return false; +} + +void demodulator_simplifier::reschedule_processed(func_decl* f) { + uint_set* set = nullptr; + if (!m_index.find_bwd(f, set)) + return; + uint_set tmp; + for (auto i : *set) + if (m_processed.contains(i)) + tmp.insert(i); + for (auto i : tmp) { + m_processed.remove(i); + m_index.remove_fwd(f, i); + m_index.remove_bwd(fml(i), i); + m_todo.push_back(i); + } +} + +void demodulator_simplifier::reschedule_demodulators(func_decl* f, expr* lhs) { + uint_set* set; + if (!m_index.find_bwd(f, set)) + return; + uint_set all_occurrences(*set); + for (unsigned i : all_occurrences) { + app_expr_pair p; + if (!m_rewrites.find(i, p)) + continue; + if (!m_match_subst.can_rewrite(fml(i), lhs)) + continue; + func_decl* f = p.first->get_decl(); + m_index.remove_fwd(f, i); + m_index.remove_bwd(fml(i), i); + m_todo.push_back(i); + } +} + +void demodulator_simplifier::reset() { + m_pinned.reset(); + m_index.reset(); + m_processed.reset(); + m_todo.reset(); + unsigned max_vid = 1; + for (unsigned i : indices()) + max_vid = std::max(max_vid, m_util.max_var_id(fml(i))); + m_match_subst.reserve(max_vid); +} + +void demodulator_simplifier::reduce() { + reset(); + for (unsigned i : indices()) + m_todo.push_back(i); + + app_ref large(m); + expr_ref small(m); + while (!m_todo.empty()) { + unsigned i = m_todo.back(); + m_todo.pop_back(); + rewrite(i); + if (m_util.is_demodulator(fml(i), large, small)) { + func_decl* f = large->get_decl(); + reschedule_processed(f); + reschedule_demodulators(f, large); + m_index.insert_fwd(f, i); + m_rewrites.insert(i, app_expr_pair(large, small)); + m_pinned.push_back(large); + m_pinned.push_back(small); + } + else + m_processed.insert(i); + m_index.insert_bwd(fml(i), i); + } +} diff --git a/src/ast/simplifiers/demodulator_simplifier.h b/src/ast/simplifiers/demodulator_simplifier.h new file mode 100644 index 000000000..e104c056b --- /dev/null +++ b/src/ast/simplifiers/demodulator_simplifier.h @@ -0,0 +1,60 @@ +/*++ +Copyright (c) 2022 Microsoft Corporation + +Module Name: + + demodulator_simplifier.h + +Author: + + Nikolaj Bjorner (nbjorner) 2022-12-4 + +--*/ + +#pragma once + +#include "ast/substitution/demodulator_rewriter.h" +#include "ast/simplifiers/dependent_expr_state.h" +#include "util/uint_set.h" + +class demodulator_index { + obj_map m_fwd_index, m_bwd_index; + void add(func_decl* f, unsigned i, obj_map& map); + void del(func_decl* f, unsigned i, obj_map& map); + public: + ~demodulator_index(); + void reset(); + void insert_fwd(func_decl* f, unsigned i) { add(f, i, m_fwd_index); } + void remove_fwd(func_decl* f, unsigned i) { del(f, i, m_fwd_index); } + void insert_bwd(expr* e, unsigned i); + void remove_bwd(expr* e, unsigned i); + bool find_fwd(func_decl* f, uint_set*& s) { return m_bwd_index.find(f, s); } + bool find_bwd(func_decl* f, uint_set*& s) { return m_fwd_index.find(f, s); } + bool empty() const { return m_fwd_index.empty(); } +}; + +class demodulator_simplifier : public dependent_expr_simplifier { + typedef std::pair app_expr_pair; + demodulator_index m_index; + demodulator_util m_util; + demodulator_match_subst m_match_subst; + demodulator_rewriter_util m_rewriter; + u_map m_rewrites; + uint_set m_processed, m_dependencies; + unsigned_vector m_todo; + expr_ref_vector m_pinned; + + void rewrite(unsigned i); + bool rewrite1(func_decl* f, expr_ref_vector const& args, expr_ref& np); + expr* fml(unsigned i) { return m_fmls[i].fml(); } + expr_dependency* dep(unsigned i) { return m_fmls[i].dep(); } + void reschedule_processed(func_decl* f); + void reschedule_demodulators(func_decl* f, expr* lhs); + void reset(); + + public: + demodulator_simplifier(ast_manager& m, params_ref const& p, dependent_expr_state& st); + void reduce() override; + + +}; diff --git a/src/ast/substitution/demodulator_rewriter.cpp b/src/ast/substitution/demodulator_rewriter.cpp index 6e79242e0..0b37308c9 100644 --- a/src/ast/substitution/demodulator_rewriter.cpp +++ b/src/ast/substitution/demodulator_rewriter.cpp @@ -27,32 +27,73 @@ Revision History: #include "ast/rewriter/var_subst.h" #include "ast/substitution/demodulator_rewriter.h" -demodulator_rewriter::demodulator_rewriter(ast_manager & m): - m(m), - m_match_subst(m), - m_bsimp(m), - m_todo(m), - m_in_processed(m), - m_new_args(m), - m_rewrite_todo(m), - m_rewrite_cache(m), - m_new_exprs(m) { - params_ref p; - p.set_bool("elim_and", true); - m_bsimp.updt_params(p); + +class var_set_proc { + uint_set & m_set; +public: + var_set_proc(uint_set &s):m_set(s) {} + void operator()(var * n) { m_set.insert(n->get_idx()); } + void operator()(quantifier * n) {} + void operator()(app * n) {} +}; + +int demodulator_util::is_subset(expr * e1, expr * e2) const { + uint_set ev1, ev2; + + if (m.is_value(e1)) + return 1; // values are always a subset! + + var_set_proc proc1(ev1); + for_each_expr(proc1, e1); + var_set_proc proc2(ev2); + for_each_expr(proc2, e2); + + return (ev1==ev2 ) ? +2 : // We return +2 if the sets are equal. + (ev1.subset_of(ev2)) ? +1 : + (ev2.subset_of(ev1)) ? -1 : + 0 ; } -demodulator_rewriter::~demodulator_rewriter() { - reset_dealloc_values(m_fwd_idx); - reset_dealloc_values(m_back_idx); - for (auto & kv : m_demodulator2lhs_rhs) { - m.dec_ref(kv.m_key); - m.dec_ref(kv.m_value.first); - m.dec_ref(kv.m_value.second); +int demodulator_util::is_smaller(expr * e1, expr * e2) const { + unsigned sz1 = 0, sz2 = 0; + + // values are always smaller! + if (m.is_value(e1)) + return +1; + else if (m.is_value(e2)) + return -1; + + // interpreted stuff is always better than uninterpreted. + if (!is_uninterp(e1) && is_uninterp(e2)) + return +1; + else if (is_uninterp(e1) && !is_uninterp(e2)) + return -1; + + // two uninterpreted functions are ordered first by the number of + // arguments, then by their id. + if (is_uninterp(e1) && is_uninterp(e2)) { + if (to_app(e1)->get_num_args() < to_app(e2)->get_num_args()) + return +1; + else if (to_app(e1)->get_num_args() > to_app(e2)->get_num_args()) + return -1; + else { + unsigned a = to_app(e1)->get_decl()->get_id(); + unsigned b = to_app(e2)->get_decl()->get_id(); + if (a < b) + return +1; + else if (a > b) + return -1; + } } + sz1 = get_depth(e1); + sz2 = get_depth(e2); + + return (sz1 == sz2) ? 0 : + (sz1 < sz2) ? +1 : + -1 ; } -bool demodulator_rewriter::is_demodulator(expr * e, app_ref & large, expr_ref & small) const { +bool demodulator_util::is_demodulator(expr * e, app_ref & large, expr_ref & small) const { if (!is_forall(e)) { return false; } @@ -109,71 +150,6 @@ bool demodulator_rewriter::is_demodulator(expr * e, app_ref & large, expr_ref & return false; } -class var_set_proc { - uint_set & m_set; -public: - var_set_proc(uint_set &s):m_set(s) {} - void operator()(var * n) { m_set.insert(n->get_idx()); } - void operator()(quantifier * n) {} - void operator()(app * n) {} -}; - -int demodulator_rewriter::is_subset(expr * e1, expr * e2) const { - uint_set ev1, ev2; - - if (m.is_value(e1)) - return 1; // values are always a subset! - - var_set_proc proc1(ev1); - for_each_expr(proc1, e1); - var_set_proc proc2(ev2); - for_each_expr(proc2, e2); - - return (ev1==ev2 ) ? +2 : // We return +2 if the sets are equal. - (ev1.subset_of(ev2)) ? +1 : - (ev2.subset_of(ev1)) ? -1 : - 0 ; -} - -int demodulator_rewriter::is_smaller(expr * e1, expr * e2) const { - unsigned sz1 = 0, sz2 = 0; - - // values are always smaller! - if (m.is_value(e1)) - return +1; - else if (m.is_value(e2)) - return -1; - - // interpreted stuff is always better than uninterpreted. - if (!is_uninterp(e1) && is_uninterp(e2)) - return +1; - else if (is_uninterp(e1) && !is_uninterp(e2)) - return -1; - - // two uninterpreted functions are ordered first by the number of - // arguments, then by their id. - if (is_uninterp(e1) && is_uninterp(e2)) { - if (to_app(e1)->get_num_args() < to_app(e2)->get_num_args()) - return +1; - else if (to_app(e1)->get_num_args() > to_app(e2)->get_num_args()) - return -1; - else { - unsigned a = to_app(e1)->get_decl()->get_id(); - unsigned b = to_app(e2)->get_decl()->get_id(); - if (a < b) - return +1; - else if (a > b) - return -1; - } - } - sz1 = get_depth(e1); - sz2 = get_depth(e2); - - return (sz1 == sz2) ? 0 : - (sz1 < sz2) ? +1 : - -1 ; -} - class max_var_id_proc { unsigned m_max_var_id; public: @@ -187,13 +163,202 @@ public: unsigned get_max() { return m_max_var_id; } }; -unsigned demodulator_rewriter::max_var_id(expr_ref_vector const& es) { +unsigned demodulator_util::max_var_id(expr* e) { + max_var_id_proc proc; + for_each_expr(proc, e); + return proc.get_max(); +} + +unsigned demodulator_util::max_var_id(expr_ref_vector const& es) { max_var_id_proc proc; for (expr* e : es) for_each_expr(proc, e); return proc.get_max(); } + +// ------------------ + +demodulator_rewriter_util::demodulator_rewriter_util(ast_manager& m): + m(m), + m_th_rewriter(m), + m_rewrite_todo(m), + m_rewrite_cache(m), + m_new_exprs(m), + m_new_args(m) +{} + +expr_ref demodulator_rewriter_util::rewrite(expr * n) { + + TRACE("demodulator", tout << "rewrite: " << mk_pp(n, m) << std::endl; ); + app * a; + + SASSERT(m_rewrite_todo.empty()); + m_new_exprs.reset(); + m_rewrite_cache.reset(); + + m_rewrite_todo.push_back(n); + while (!m_rewrite_todo.empty()) { + TRACE("demodulator_stack", tout << "STACK: " << std::endl; + for (unsigned i = 0; i < m_rewrite_todo.size(); i++) + tout << std::dec << i << ": " << std::hex << (size_t)m_rewrite_todo[i] << + " = " << mk_pp(m_rewrite_todo[i], m) << std::endl; + ); + + expr * e = m_rewrite_todo.back(); + expr_ref actual(e, m); + + if (m_rewrite_cache.contains(e)) { + const expr_bool_pair &ebp = m_rewrite_cache.get(e); + if (ebp.second) { + m_rewrite_todo.pop_back(); + continue; + } + else { + actual = ebp.first; + } + } + + switch (actual->get_kind()) { + case AST_VAR: + rewrite_cache(e, actual, true); + m_rewrite_todo.pop_back(); + break; + case AST_APP: + a = to_app(actual); + if (rewrite_visit_children(a)) { + func_decl * f = a->get_decl(); + m_new_args.reset(); + bool all_untouched = true; + for (expr* o_child : *a) { + expr * n_child; + SASSERT(m_rewrite_cache.contains(o_child) && m_rewrite_cache.get(o_child).second); + expr_bool_pair const & ebp = m_rewrite_cache.get(o_child); + n_child = ebp.first; + if (n_child != o_child) + all_untouched = false; + m_new_args.push_back(n_child); + } + expr_ref np(m); + if (m_rewrite1(f, m_new_args, np)) { + rewrite_cache(e, np, false); + // No pop. + } + else { + if (all_untouched) { + rewrite_cache(e, actual, true); + } + else { + expr_ref na(m); + na = m_th_rewriter.mk_app(f, m_new_args); + TRACE("demodulator_bug", tout << "e:\n" << mk_pp(e, m) << "\nnew_args: \n"; + tout << m_new_args << "\n"; + tout << "=====>\n"; + tout << "na:\n " << na << "\n";); + rewrite_cache(e, na, true); + } + m_rewrite_todo.pop_back(); + } + } + break; + case AST_QUANTIFIER: { + expr * body = to_quantifier(actual)->get_expr(); + if (m_rewrite_cache.contains(body)) { + const expr_bool_pair ebp = m_rewrite_cache.get(body); + SASSERT(ebp.second); + expr * new_body = ebp.first; + quantifier_ref q(m); + q = m.update_quantifier(to_quantifier(actual), new_body); + m_new_exprs.push_back(q); + expr_ref new_q = elim_unused_vars(m, q, params_ref()); + m_new_exprs.push_back(new_q); + rewrite_cache(e, new_q, true); + m_rewrite_todo.pop_back(); + } else { + m_rewrite_todo.push_back(body); + } + break; + } + default: + UNREACHABLE(); + } + } + + SASSERT(m_rewrite_cache.contains(n)); + const expr_bool_pair & ebp = m_rewrite_cache.get(n); + SASSERT(ebp.second); + expr * r = ebp.first; + + TRACE("demodulator", tout << "rewrite result: " << mk_pp(r, m) << std::endl; ); + + return expr_ref(r, m); +} + +bool demodulator_rewriter_util::rewrite_visit_children(app * a) { + bool res = true; + for (expr* e : *a) { + if (m_rewrite_cache.contains(e) && m_rewrite_cache.get(e).second) + continue; + bool recursive = false; + expr * v = e; + if (m_rewrite_cache.contains(e)) { + auto const & [t, marked] = m_rewrite_cache.get(e); + if (marked) + v = t; + } + for (expr* t : m_rewrite_todo) { + if (t == v) { + recursive = true; + TRACE("demodulator", tout << "Detected demodulator cycle: " << + mk_pp(a, m) << " --> " << mk_pp(v, m) << std::endl;); + rewrite_cache(e, v, true); + break; + } + } + if (!recursive) { + m_rewrite_todo.push_back(e); + res = false; + } + } + return res; +} + +void demodulator_rewriter_util::rewrite_cache(expr * e, expr * new_e, bool done) { + m_rewrite_cache.insert(e, expr_bool_pair(new_e, done)); +} + + + +// ------------------ + +demodulator_rewriter::demodulator_rewriter(ast_manager & m): + m(m), + m_match_subst(m), + m_util(m), + m_bsimp(m), + m_todo(m), + m_in_processed(m), + m_new_args(m), + m_rewrite_todo(m), + m_rewrite_cache(m), + m_new_exprs(m) { + params_ref p; + p.set_bool("elim_and", true); + m_bsimp.updt_params(p); +} + +demodulator_rewriter::~demodulator_rewriter() { + reset_dealloc_values(m_fwd_idx); + reset_dealloc_values(m_back_idx); + for (auto & kv : m_demodulator2lhs_rhs) { + m.dec_ref(kv.m_key); + m.dec_ref(kv.m_value.first); + m.dec_ref(kv.m_value.second); + } +} + + + void demodulator_rewriter::insert_fwd_idx(app * large, expr * small, quantifier * demodulator) { SASSERT(demodulator); SASSERT(large && small); @@ -265,17 +430,18 @@ bool demodulator_rewriter::rewrite1(func_decl * f, expr_ref_vector const & args, for (quantifier* d : *set) { - auto const& [large, rhs] = m_demodulator2lhs_rhs[d]; + auto const& [lhs, rhs] = m_demodulator2lhs_rhs[d]; - if (large->get_num_args() != args.size()) + if (lhs->get_num_args() != args.size()) continue; TRACE("demodulator_bug", tout << "Matching with demodulator: " << mk_pp(d, m) << std::endl; ); - SASSERT(large->get_decl() == f); + SASSERT(lhs->get_decl() == f); - if (m_match_subst(large, rhs, args.data(), np)) { + if (m_match_subst(lhs, rhs, args.data(), np)) { TRACE("demodulator_bug", tout << "succeeded...\n" << mk_pp(rhs, m) << "\n===>\n" << mk_pp(np, m) << "\n";); + m_new_exprs.push_back(np); return true; } } @@ -289,15 +455,14 @@ bool demodulator_rewriter::rewrite_visit_children(app * a) { if (m_rewrite_cache.contains(e) && m_rewrite_cache.get(e).second) continue; bool recursive = false; - unsigned sz = m_rewrite_todo.size(); expr * v = e; if (m_rewrite_cache.contains(e)) { - expr_bool_pair const & ebp = m_rewrite_cache.get(e); - if (ebp.second) - v = ebp.first; + auto const & [t, marked] = m_rewrite_cache.get(e); + if (marked) + v = t; } - for (unsigned i = sz; i-- > 0;) { - if (m_rewrite_todo[i] == v) { + for (expr* t : m_rewrite_todo) { + if (t == v) { recursive = true; TRACE("demodulator", tout << "Detected demodulator cycle: " << mk_pp(a, m) << " --> " << mk_pp(v, m) << std::endl;); @@ -504,7 +669,7 @@ void demodulator_rewriter::reschedule_processed(func_decl * f) { } } -bool demodulator_rewriter::can_rewrite(expr * n, expr * lhs) { +bool demodulator_match_subst::can_rewrite(expr * n, expr * lhs) { // this is a quick check, we just traverse d and check if there is an expression in d that is an instance of lhs of n'. // we cannot use the trick used for m_processed, since the main loop would not terminate. @@ -530,7 +695,7 @@ bool demodulator_rewriter::can_rewrite(expr * n, expr * lhs) { case AST_APP: if (for_each_expr_args(stack, visited, to_app(curr)->get_num_args(), to_app(curr)->get_args())) { - if (m_match_subst(lhs, curr)) + if ((*this)(lhs, curr)) return true; visited.mark(curr, true); stack.pop_back(); @@ -582,7 +747,7 @@ void demodulator_rewriter::reschedule_demodulators(func_decl * f, expr * lhs) { func_decl_ref df(l->get_decl(), m); // Now we know there is an occurrence of f in d - if (!can_rewrite(d, lhs)) + if (!m_match_subst.can_rewrite(d, lhs)) continue; TRACE("demodulator", tout << "Rescheduling: " << std::endl << mk_pp(d, m) << std::endl); @@ -602,7 +767,7 @@ void demodulator_rewriter::operator()(expr_ref_vector const& exprs, for (expr* e : exprs) m_todo.push_back(e); - m_match_subst.reserve(max_var_id(exprs)); + m_match_subst.reserve(m_util.max_var_id(exprs)); while (!m_todo.empty()) { // let n be the next formula in m_todo. @@ -618,7 +783,7 @@ void demodulator_rewriter::operator()(expr_ref_vector const& exprs, app_ref large(m); expr_ref small(m); - if (!is_demodulator(np, large, small)) { + if (!m_util.is_demodulator(np, large, small)) { // insert n' into m_processed m_processed.insert(np); m_in_processed.push_back(np); @@ -661,7 +826,7 @@ void demodulator_rewriter::operator()(expr_ref_vector const& exprs, } -demodulator_rewriter::match_subst::match_subst(ast_manager & m): +demodulator_match_subst::demodulator_match_subst(ast_manager & m): m(m), m_subst(m) { } @@ -693,7 +858,7 @@ struct match_args_aux_proc { void operator()(app * n) {} }; -bool demodulator_rewriter::match_subst::match_args(app * lhs, expr * const * args) { +bool demodulator_match_subst::match_args(app * lhs, expr * const * args) { m_cache.reset(); m_todo.reset(); @@ -809,7 +974,7 @@ bool demodulator_rewriter::match_subst::match_args(app * lhs, expr * const * arg } -bool demodulator_rewriter::match_subst::operator()(app * lhs, expr * rhs, expr * const * args, expr_ref & new_rhs) { +bool demodulator_match_subst::operator()(app * lhs, expr * rhs, expr * const * args, expr_ref & new_rhs) { if (match_args(lhs, args)) { if (m_all_args_eq) { @@ -824,7 +989,7 @@ bool demodulator_rewriter::match_subst::operator()(app * lhs, expr * rhs, expr * return false; } -bool demodulator_rewriter::match_subst::operator()(expr * t, expr * i) { +bool demodulator_match_subst::operator()(expr * t, expr * i) { m_cache.reset(); m_todo.reset(); if (is_var(t)) diff --git a/src/ast/substitution/demodulator_rewriter.h b/src/ast/substitution/demodulator_rewriter.h index 2152520ce..18befc198 100644 --- a/src/ast/substitution/demodulator_rewriter.h +++ b/src/ast/substitution/demodulator_rewriter.h @@ -24,6 +24,7 @@ Revision History: #include "ast/ast.h" #include "ast/substitution/substitution.h" #include "ast/rewriter/bool_rewriter.h" +#include "ast/rewriter/th_rewriter.h" #include "util/obj_hashtable.h" #include "util/obj_pair_hashtable.h" #include "util/array_map.h" @@ -92,6 +93,92 @@ The code in spc_rewriter.* does something like that. We cannot reuse this code d for the superposion engine in Z3, but we can adapt it for our needs in the preprocessor. */ +class demodulator_util { + ast_manager& m; + int is_subset(expr*, expr*) const; + int is_smaller(expr*, expr*) const; + public: + demodulator_util(ast_manager& m):m(m) {} + bool is_demodulator(expr* e, app_ref& large, expr_ref & small) const; + unsigned max_var_id(expr* e); + unsigned max_var_id(expr_ref_vector const& e); +}; + +/** + \brief Custom matcher & substitution application +*/ +class demodulator_match_subst { + typedef std::pair expr_pair; + typedef obj_pair_hashtable cache; + + void reset(); + + ast_manager & m; + substitution m_subst; + cache m_cache; + svector m_todo; + bool m_all_args_eq; + + bool match_args(app * t, expr * const * args); + +public: + demodulator_match_subst(ast_manager & m); + + void reserve(unsigned max_vid) { m_subst.reserve(2, max_vid+1); } + /** + \brief Let f be the top symbol of lhs. If (f args) is an + instance of lhs, that is, there is a substitution s + s.t. s[lhs] = (f args), then return true and store s[rhs] + into new_rhs. Where s[t] represents the application of the + substitution s into t. + + Assumptions, the variables in lhs and (f args) are assumed to be distinct. + So, (f x y) matches (f y x). + Moreover, the result should be in terms of the variables in (f args). + */ + bool operator()(app * lhs, expr * rhs, expr * const * args, expr_ref & new_rhs); + + /** + \brief Return true if \c i is an instance of \c t. + */ + bool operator()(expr * t, expr * i); + + bool can_rewrite(expr* n, expr* lhs); +}; + +class demodulator_rewriter_util { + ast_manager& m; + std::function m_rewrite1; + + typedef std::pair expr_bool_pair; + + class plugin { + ast_manager& m; + public: + plugin(ast_manager& m): m(m) { } + void ins_eh(expr* k, expr_bool_pair v) { m.inc_ref(k); m.inc_ref(v.first); } + void del_eh(expr* k, expr_bool_pair v) { m.dec_ref(k); m.dec_ref(v.first); } + static unsigned to_int(expr const * k) { return k->get_id(); } + }; + typedef array_map expr_map; + + typedef expr_map rewrite_cache_map; + + th_rewriter m_th_rewriter; + expr_ref_buffer m_rewrite_todo; + rewrite_cache_map m_rewrite_cache; + expr_ref_buffer m_new_exprs; + expr_ref_vector m_new_args; + + bool rewrite_visit_children(app * a); + void rewrite_cache(expr * e, expr * new_e, bool done); + +public: + demodulator_rewriter_util(ast_manager& m); + void set_rewrite1(std::function& fn) { m_rewrite1 = fn; } + expr_ref rewrite(expr * n); +}; + class demodulator_rewriter final { class rewrite_proc; class add_back_idx_proc; @@ -119,47 +206,10 @@ class demodulator_rewriter final { typedef obj_map demodulator2lhs_rhs; typedef expr_map rewrite_cache_map; - /** - \brief Custom matcher & substitution application - */ - class match_subst { - typedef std::pair expr_pair; - typedef obj_pair_hashtable cache; - - void reset(); - - ast_manager & m; - substitution m_subst; - cache m_cache; - svector m_todo; - bool m_all_args_eq; - - bool match_args(app * t, expr * const * args); - - public: - match_subst(ast_manager & m); - void reserve(unsigned max_vid) { m_subst.reserve(2, max_vid+1); } - /** - \brief Let f be the top symbol of lhs. If (f args) is an - instance of lhs, that is, there is a substitution s - s.t. s[lhs] = (f args), then return true and store s[rhs] - into new_rhs. Where s[t] represents the application of the - substitution s into t. - - Assumptions, the variables in lhs and (f args) are assumed to be distinct. - So, (f x y) matches (f y x). - Moreover, the result should be in terms of the variables in (f args). - */ - bool operator()(app * lhs, expr * rhs, expr * const * args, expr_ref & new_rhs); - - /** - \brief Return true if \c i is an instance of \c t. - */ - bool operator()(expr * t, expr * i); - }; ast_manager & m; - match_subst m_match_subst; + demodulator_match_subst m_match_subst; + demodulator_util m_util; bool_rewriter m_bsimp; fwd_idx_map m_fwd_idx; back_idx_map m_back_idx; @@ -179,7 +229,6 @@ class demodulator_rewriter final { void remove_bwd_idx(expr* q); bool check_fwd_idx_consistency(); void show_fwd_idx(std::ostream & out); - bool is_demodulator(expr * e, app_ref & large, expr_ref & small) const; bool can_rewrite(expr * n, expr * lhs); expr * rewrite(expr * n); @@ -188,13 +237,6 @@ class demodulator_rewriter final { void rewrite_cache(expr * e, expr * new_e, bool done); void reschedule_processed(func_decl * f); void reschedule_demodulators(func_decl * f, expr * np); - unsigned max_var_id(expr_ref_vector const& es); - - // is_smaller returns -1 for e1e2. - int is_smaller(expr * e1, expr * e2) const; - - // is_subset returns -1 for e1 subset e2, +1 for e2 subset e1, 0 else. - int is_subset(expr * e1, expr * e2) const; public: demodulator_rewriter(ast_manager & m);