diff --git a/src/ast/rewriter/expr_replacer.h b/src/ast/rewriter/expr_replacer.h index c1bfabd12..96418f00b 100644 --- a/src/ast/rewriter/expr_replacer.h +++ b/src/ast/rewriter/expr_replacer.h @@ -39,6 +39,7 @@ public: void operator()(expr * t, expr_ref & result); void operator()(expr_ref & t) { expr_ref s(t, m()); (*this)(s, t); } void operator()(expr_ref_vector& v) { expr_ref t(m()); for (unsigned i = 0; i < v.size(); ++i) (*this)(v.get(i), t), v[i] = t; } + std::pair replace_with_dep(expr* t) { expr_ref r(m()); expr_dependency_ref d(m()); (*this)(t, r, d); return { r, d }; } virtual unsigned get_num_steps() const { return 0; } virtual void reset() = 0; diff --git a/src/ast/simplifiers/bv_slice.cpp b/src/ast/simplifiers/bv_slice.cpp index f39fa932e..75e0a890c 100644 --- a/src/ast/simplifiers/bv_slice.cpp +++ b/src/ast/simplifiers/bv_slice.cpp @@ -109,12 +109,12 @@ namespace bv { }; if (lo > 0 && !b.contains(lo)) { b.insert(lo); - if (m_num_scopes > 0) + if (num_scopes() > 0) m_trail.push(remove_set(b, lo)); } if (hi + 1 < sz && !b.contains(hi + 1)) { b.insert(hi + 1); - if (m_num_scopes > 0) + if (num_scopes() > 0) m_trail.push(remove_set(b, hi+ 1)); } } diff --git a/src/ast/simplifiers/dependent_expr_state.h b/src/ast/simplifiers/dependent_expr_state.h index 32ad59681..803c58510 100644 --- a/src/ast/simplifiers/dependent_expr_state.h +++ b/src/ast/simplifiers/dependent_expr_state.h @@ -34,6 +34,7 @@ Author: #include "util/params.h" #include "ast/converters/model_converter.h" #include "ast/simplifiers/dependent_expr.h" +#include "ast/simplifiers/model_reconstruction_trail.h" /** @@ -46,6 +47,12 @@ public: virtual dependent_expr const& operator[](unsigned i) = 0; virtual void update(unsigned i, dependent_expr const& j) = 0; virtual bool inconsistent() = 0; + + trail_stack m_trail; + void push() { m_trail.push_scope(); } + void pop(unsigned n) { m_trail.pop_scope(n); } + + virtual model_reconstruction_trail* model_trail() { return nullptr; } }; /** @@ -55,20 +62,21 @@ class dependent_expr_simplifier { protected: ast_manager& m; dependent_expr_state& m_fmls; - unsigned m_qhead = 0; // pointer into last processed formula in m_fmls - unsigned m_num_scopes = 0; - trail_stack m_trail; - void advance_qhead(unsigned sz) { if (m_num_scopes > 0) m_trail.push(value_trail(m_qhead)); m_qhead = sz; } + trail_stack& m_trail; + unsigned m_qhead = 0; // pointer into last processed formula in m_fmls + + unsigned num_scopes() const { return m_trail.get_num_scopes(); } + + void advance_qhead(unsigned sz) { if (num_scopes() > 0) m_trail.push(value_trail(m_qhead)); m_qhead = sz; } public: - dependent_expr_simplifier(ast_manager& m, dependent_expr_state& s) : m(m), m_fmls(s) {} + dependent_expr_simplifier(ast_manager& m, dependent_expr_state& s) : m(m), m_fmls(s), m_trail(s.m_trail) {} virtual ~dependent_expr_simplifier() {} - virtual void push() { m_num_scopes++; m_trail.push_scope(); } - virtual void pop(unsigned n) { m_num_scopes -= n; m_trail.pop_scope(n); } + virtual void push() { } + virtual void pop(unsigned n) { } virtual void reduce() = 0; virtual void collect_statistics(statistics& st) const {} virtual void reset_statistics() {} virtual void updt_params(params_ref const& p) {} - virtual model_converter_ref get_model_converter() { return model_converter_ref(); } }; /** diff --git a/src/ast/simplifiers/euf_completion.cpp b/src/ast/simplifiers/euf_completion.cpp index e5b328d7f..e1360fa0b 100644 --- a/src/ast/simplifiers/euf_completion.cpp +++ b/src/ast/simplifiers/euf_completion.cpp @@ -213,7 +213,7 @@ namespace euf { old_value = nullptr; } }; - if (m_num_scopes > 0) + if (num_scopes() > 0) m_trail.push(vtrail(m_canonical, n->get_id())); m_canonical.setx(n->get_id(), e); m_epochs.setx(n->get_id(), m_epoch, 0); diff --git a/src/ast/simplifiers/model_reconstruction_trail.cpp b/src/ast/simplifiers/model_reconstruction_trail.cpp index 077443e7e..a8e75bfa3 100644 --- a/src/ast/simplifiers/model_reconstruction_trail.cpp +++ b/src/ast/simplifiers/model_reconstruction_trail.cpp @@ -21,11 +21,12 @@ void model_reconstruction_trail::replay(dependent_expr const& d, vector rp = mk_default_expr_replacer(m, false); add_vars(d, free_vars); added.push_back(d); + for (auto& t : m_trail) { if (!t->m_active) continue; @@ -45,13 +46,12 @@ void model_reconstruction_trail::replay(dependent_expr const& d, vectorset_substitution(t->m_subst.get()); // rigid entries: // apply substitution to added in case of rigid model convertions for (auto& d : added) { auto [f, dep1] = d(); - expr_ref g(m); - expr_dependency_ref dep2(m); - (*t->m_replace)(f, g, dep2); + auto [g, dep2] = rp->replace_with_dep(f); d = dependent_expr(m, g, m.mk_join(dep1, dep2)); } } @@ -69,9 +69,9 @@ model_converter_ref model_reconstruction_trail::get_model_converter() { // substituted variables by their terms. // - scoped_ptr rp = mk_default_expr_replacer(m, true); - scoped_ptr subst = alloc(expr_substitution, m, true, false); - rp->set_substitution(subst.get()); + scoped_ptr rp = mk_default_expr_replacer(m, false); + expr_substitution subst(m, true, false); + rp->set_substitution(&subst); generic_model_converter_ref mc = alloc(generic_model_converter, m, "dependent-expr-model"); bool first = true; for (unsigned i = m_trail.size(); i-- > 0; ) { @@ -83,19 +83,17 @@ model_converter_ref model_reconstruction_trail::get_model_converter() { first = false; for (auto const& [v, def] : t->m_subst->sub()) { expr_dependency* dep = t->m_subst->dep(v); - subst->insert(v, def, dep); + subst.insert(v, def, dep); mc->add(v, def); } continue; } - expr_dependency_ref new_dep(m); - expr_ref new_def(m); for (auto const& [v, def] : t->m_subst->sub()) { - rp->operator()(def, new_def, new_dep); + auto [new_def, new_dep] = rp->replace_with_dep(def); expr_dependency* dep = t->m_subst->dep(v); new_dep = m.mk_join(dep, new_dep); - subst->insert(v, new_def, new_dep); + subst.insert(v, new_def, new_dep); mc->add(v, new_def); } diff --git a/src/ast/simplifiers/model_reconstruction_trail.h b/src/ast/simplifiers/model_reconstruction_trail.h index 4aa8a54fd..c9b42bc92 100644 --- a/src/ast/simplifiers/model_reconstruction_trail.h +++ b/src/ast/simplifiers/model_reconstruction_trail.h @@ -33,13 +33,12 @@ Author: class model_reconstruction_trail { struct entry { - scoped_ptr m_replace; scoped_ptr m_subst; vector m_removed; bool m_active = true; - entry(expr_replacer* r, expr_substitution* s, vector const& rem) : - m_replace(r), m_subst(s), m_removed(rem) {} + entry(expr_substitution* s, vector const& rem) : + m_subst(s), m_removed(rem) {} bool is_loose() const { return !m_removed.empty(); } @@ -64,7 +63,8 @@ class model_reconstruction_trail { bool intersects(ast_mark const& free_vars, dependent_expr const& d) { expr_ref term(d.fml(), d.get_manager()); - return any_of(subterms::all(term), [&](expr* t) { return free_vars.is_marked(t); }); + auto iter = subterms::all(term); + return any_of(iter, [&](expr* t) { return free_vars.is_marked(t); }); } bool intersects(ast_mark const& free_vars, vector const& added) { @@ -77,10 +77,10 @@ public: m(m), m_trail_stack(tr) {} /** - * add a new substitution to the stack + * add a new substitution to the trail */ - void push(expr_replacer* r, vector const& removed) { - m_trail.push_back(alloc(entry, r, nullptr, removed)); + void push(expr_substitution* s, vector const& removed) { + m_trail.push_back(alloc(entry, s, removed)); m_trail_stack.push(push_back_vector(m_trail)); } diff --git a/src/ast/simplifiers/solve_eqs.cpp b/src/ast/simplifiers/solve_eqs.cpp index e4361ad97..b5b500a96 100644 --- a/src/ast/simplifiers/solve_eqs.cpp +++ b/src/ast/simplifiers/solve_eqs.cpp @@ -60,7 +60,7 @@ namespace euf { m_id2level.reset(); m_id2level.resize(m_id2var.size(), UINT_MAX); m_subst_ids.reset(); - m_subst = alloc(expr_substitution, m, true, false); + m_subst = alloc(expr_substitution, m, true, false); auto is_explored = [&](unsigned id) { return m_id2level[id] != UINT_MAX; @@ -105,30 +105,22 @@ namespace euf { } } - void solve_eqs::add_subst(dependent_eq const& eq) { - SASSERT(can_be_var(eq.var)); - m_subst->insert(eq.var, eq.term, nullptr, eq.dep); - ++m_stats.m_num_elim_vars; - } - void solve_eqs::normalize() { scoped_ptr rp = mk_default_expr_replacer(m, true); - m_subst->reset(); rp->set_substitution(m_subst.get()); std::sort(m_subst_ids.begin(), m_subst_ids.end(), [&](unsigned u, unsigned v) { return m_id2level[u] > m_id2level[v]; }); - expr_dependency_ref new_dep(m); - expr_ref new_def(m); - for (unsigned id : m_subst_ids) { if (!m.inc()) break; auto const& [v, def, dep] = m_next[id][0]; - rp->operator()(def, new_def, new_dep); + auto [new_def, new_dep] = rp->replace_with_dep(def); m_stats.m_num_steps += rp->get_num_steps() + 1; + ++m_stats.m_num_elim_vars; new_dep = m.mk_join(dep, new_dep); - m_subst->insert(v, new_def, nullptr, new_dep); + m_subst->insert(v, new_def, new_dep); + SASSERT(can_be_var(v)); // we updated the substitution, but we don't need to reset rp // because all cached values there do not depend on v. } @@ -147,11 +139,10 @@ namespace euf { return; scoped_ptr rp = mk_default_expr_replacer(m, true); rp->set_substitution(m_subst.get()); - expr_ref new_f(m); - expr_dependency_ref new_dep(m); + for (unsigned i = m_qhead; i < m_fmls.size() && !m_fmls.inconsistent(); ++i) { auto [f, d] = m_fmls[i](); - rp->operator()(f, new_f, new_dep); + auto [new_f, new_dep] = rp->replace_with_dep(f); if (new_f == f) continue; new_dep = m.mk_join(d, new_dep); @@ -164,13 +155,27 @@ namespace euf { for (extract_eq* ex : m_extract_plugins) ex->pre_process(m_fmls); - // TODO add a loop. - dep_eq_vector eqs; - get_eqs(eqs); - extract_dep_graph(eqs); - extract_subst(); - apply_subst(); + unsigned count = 0; + do { + m_subst_ids.reset(); + if (!m.inc()) + return; + dep_eq_vector eqs; + get_eqs(eqs); + extract_dep_graph(eqs); + extract_subst(); + apply_subst(); + ++count; + } + while (!m_subst_ids.empty() && count < 20); + advance_qhead(m_fmls.size()); + save_subst(); + } + + void solve_eqs::save_subst() { + if (!m_subst->empty()) + m_fmls.model_trail()->push(m_subst.detach(), {}); } void solve_eqs::filter_unsafe_vars() { @@ -181,16 +186,7 @@ namespace euf { m_unsafe_vars.mark(term); } - typedef generic_model_converter gmc; - model_converter_ref solve_eqs::get_model_converter() { - model_converter_ref mc = alloc(gmc, m, "solve-eqs"); - for (unsigned id : m_subst_ids) { - auto* v = m_id2var[id]; - static_cast(mc.get())->add(v, m_subst->find(v)); - } - return mc; - } solve_eqs::solve_eqs(ast_manager& m, dependent_expr_state& fmls) : dependent_expr_simplifier(m, fmls), m_rewriter(m) { diff --git a/src/ast/simplifiers/solve_eqs.h b/src/ast/simplifiers/solve_eqs.h index 49cd90ca2..db7a1323b 100644 --- a/src/ast/simplifiers/solve_eqs.h +++ b/src/ast/simplifiers/solve_eqs.h @@ -67,6 +67,7 @@ namespace euf { void extract_dep_graph(dep_eq_vector& eqs); void normalize(); void apply_subst(); + void save_subst(); public: @@ -78,7 +79,5 @@ namespace euf { void updt_params(params_ref const& p) override; void collect_statistics(statistics& st) const override; - - model_converter_ref get_model_converter() override; }; } diff --git a/src/tactic/dependent_expr_state_tactic.h b/src/tactic/dependent_expr_state_tactic.h index ae3635c77..719a29eea 100644 --- a/src/tactic/dependent_expr_state_tactic.h +++ b/src/tactic/dependent_expr_state_tactic.h @@ -24,12 +24,16 @@ class dependent_expr_state_tactic : public tactic, public dependent_expr_state { std::string m_name; ref m_factory; scoped_ptr m_simp; + trail_stack m_trail; + scoped_ptr m_model_trail; goal_ref m_goal; dependent_expr m_dep; void init() { if (!m_simp) m_simp = m_factory->mk(m, m_params, *this); + if (!m_model_trail) + m_model_trail = alloc(model_reconstruction_trail, m, m_trail); } public: @@ -60,6 +64,10 @@ public: bool inconsistent() override { return m_goal->inconsistent(); } + + model_reconstruction_trail* model_trail() override { + return m_model_trail.get(); + } char const* name() const override { return m_name.c_str(); } @@ -83,7 +91,7 @@ public: m_simp->reduce(); m_goal->inc_depth(); if (in->models_enabled()) - in->set(m_simp->get_model_converter().get()); + in->set(m_model_trail->get_model_converter().get()); result.push_back(in.get()); } diff --git a/src/util/util.h b/src/util/util.h index c3f06d8d3..925e20186 100644 --- a/src/util/util.h +++ b/src/util/util.h @@ -367,7 +367,6 @@ bool any_of(S& set, T& p) { return true; return false; } -// #define any_of(S, p) { for (auto const& s : S) if (p(s)) return true; return false; } /** \brief Iterator for the [0..sz[0]) X [0..sz[1]) X ... X [0..sz[n-1]).