diff --git a/src/ast/expr_substitution.h b/src/ast/expr_substitution.h index 03e96db42..f71e7d232 100644 --- a/src/ast/expr_substitution.h +++ b/src/ast/expr_substitution.h @@ -51,6 +51,8 @@ public: void reset(); void cleanup(); + obj_map const sub() const { return m_subst; } + std::ostream& display(std::ostream& out); }; diff --git a/src/ast/rewriter/th_rewriter.cpp b/src/ast/rewriter/th_rewriter.cpp index 4937e480c..b4521a216 100644 --- a/src/ast/rewriter/th_rewriter.cpp +++ b/src/ast/rewriter/th_rewriter.cpp @@ -30,6 +30,7 @@ Notes: #include "ast/rewriter/seq_rewriter.h" #include "ast/rewriter/rewriter_def.h" #include "ast/rewriter/var_subst.h" +#include "ast/rewriter/expr_safe_replace.h" #include "ast/expr_substitution.h" #include "ast/ast_smt2_pp.h" #include "ast/ast_pp.h" @@ -50,6 +51,7 @@ struct th_rewriter_cfg : public default_rewriter_cfg { recfun_rewriter m_rec_rw; arith_util m_a_util; bv_util m_bv_util; + expr_ref_vector m_pinned; unsigned long long m_max_memory; // in bytes unsigned m_max_steps; bool m_pull_cheap_ite; @@ -663,6 +665,20 @@ struct th_rewriter_cfg : public default_rewriter_cfg { return result; } + void apply_subst(ptr_buffer& patterns) { + if (!m_subst) + return; + expr_ref tmp(m()); + expr_safe_replace rep(m()); + for (auto kv : m_subst->sub()) + rep.insert(kv.m_key, kv.m_value); + for (unsigned i = 0; i < patterns.size(); ++i) { + rep(patterns[i], tmp); + m_pinned.push_back(tmp); + patterns[i] = tmp; + } + } + bool reduce_quantifier(quantifier * old_q, expr * new_body, @@ -721,9 +737,12 @@ struct th_rewriter_cfg : public default_rewriter_cfg { remove_duplicates(new_patterns_buf); remove_duplicates(new_no_patterns_buf); + apply_subst(new_patterns_buf); + q1 = m().update_quantifier(old_q, new_patterns_buf.size(), new_patterns_buf.c_ptr(), new_no_patterns_buf.size(), new_no_patterns_buf.c_ptr(), new_body); + m_pinned.reset(); TRACE("reduce_quantifier", tout << mk_ismt2_pp(old_q, m()) << "\n----->\n" << mk_ismt2_pp(q1, m()) << "\n";); SASSERT(is_well_sorted(m(), q1)); if (m().proofs_enabled() && q1 != old_q) { @@ -760,6 +779,7 @@ struct th_rewriter_cfg : public default_rewriter_cfg { m_rec_rw(m), m_a_util(m), m_bv_util(m), + m_pinned(m), m_used_dependencies(m), m_subst(nullptr) { updt_local_params(p);