From f4c5e14b6b13d5412a33c1f68839857797fd7fda Mon Sep 17 00:00:00 2001
From: Nikolaj Bjorner
Date: Fri, 7 Nov 2025 08:04:41 -0800
Subject: [PATCH] add fold-unfold simplification

this substitutes solve-eqs by including fold-unfold reductions.
---
 src/ast/simplifiers/CMakeLists.txt     |   1 +
 src/ast/simplifiers/euf_completion.cpp |   4 +
 src/ast/simplifiers/fold_unfold.cpp    | 396 +++++++++++++++++++++++++
 src/ast/simplifiers/fold_unfold.h      | 108 +++++++
 src/tactic/core/CMakeLists.txt         |   1 +
 src/tactic/core/fold_unfold_tactic.h   |  43 +++
 6 files changed, 553 insertions(+)
 create mode 100644 src/ast/simplifiers/fold_unfold.cpp
 create mode 100644 src/ast/simplifiers/fold_unfold.h
 create mode 100644 src/tactic/core/fold_unfold_tactic.h

diff --git a/src/ast/simplifiers/CMakeLists.txt b/src/ast/simplifiers/CMakeLists.txt
index d947011ae..a59550c17 100644
--- a/src/ast/simplifiers/CMakeLists.txt
+++ b/src/ast/simplifiers/CMakeLists.txt
@@ -15,6 +15,7 @@ z3_add_component(simplifiers
     eliminate_predicates.cpp
     euf_completion.cpp
     extract_eqs.cpp
+    fold_unfold.cpp
     linear_equation.cpp
     max_bv_sharing.cpp
     model_reconstruction_trail.cpp
diff --git a/src/ast/simplifiers/euf_completion.cpp b/src/ast/simplifiers/euf_completion.cpp
index a78338226..cd56aea60 100644
--- a/src/ast/simplifiers/euf_completion.cpp
+++ b/src/ast/simplifiers/euf_completion.cpp
@@ -806,6 +806,7 @@ namespace euf {
 
     // callback when mam finds a binding
     void completion::on_binding(quantifier* q, app* pat, enode* const* binding, unsigned max_global, unsigned min_top, unsigned max_top) {
+        IF_VERBOSE(10, verbose_stream() << "on-binding\n");  // diagnostic only: silent unless verbosity >= 10
         if (should_stop())
             return;
         if (max_top >= m_max_generation)
@@ -863,6 +864,8 @@ namespace euf {
         pr = m.mk_quant_inst(m.mk_or(m.mk_not(q), r), s.size(), s.data());
         m_consequences.push_back(r);
         TRACE(euf_completion, tout << "new instantiation: " << r << " q: " << mk_pp(q, m) << "\n");
+        IF_VERBOSE(10, verbose_stream() << mk_pp(q, m) << " " << r
+                                        << "\n ");
         add_constraint(r, pr, d);
         propagate_rules();
         m_egraph.propagate();
@@ -1131,6 +1134,7 @@ namespace euf {
         }
         enode* n = m_egraph.find(f);
         if (!n) n = mk_enode(f);
+
         enode* r = n->get_root();
         d = m.mk_join(d, explain_eq(n, r));
         d = m.mk_join(d, m_deps.get(r->get_id(), nullptr));
diff --git a/src/ast/simplifiers/fold_unfold.cpp b/src/ast/simplifiers/fold_unfold.cpp
new file mode 100644
index 000000000..346f96683
--- /dev/null
+++ b/src/ast/simplifiers/fold_unfold.cpp
@@ -0,0 +1,396 @@
+/*++
+Copyright (c) 2022 Microsoft Corporation
+
+Module Name:
+
+    fold_unfold.cpp
+
+Abstract:
+
+    fold-unfold simplifier
+
+Author:
+
+    Nikolaj Bjorner (nbjorner) 2025-11-5.
+
+- remove alias x = y
+- remove alias with const x = k
+- fold-unfold simplification x = f(y), y = g(z), f(g(z)) = u -> x |-> u
+
+- assign levels to E-nodes:
+  - dfs over roots.
+  - visit children, assign level
+  - all members of a class share one level, assigned after the classes of their arguments
+- remove alias with linear x = f(y) -> x |-> f(y) if level y < level x
+--*/
+
+#include "ast/ast_pp.h"
+#include "ast/simplifiers/fold_unfold.h"
+#include "ast/rewriter/expr_replacer.h"
+#include "util/union_find.h"
+#include "params/smt_params_helper.hpp"
+
+namespace euf {
+
+    fold_unfold::fold_unfold(ast_manager& m, dependent_expr_state& fmls)
+        : dependent_expr_simplifier(m, fmls),
+        m_rewriter(m),
+        m_egraph(m) {
+        register_extract_eqs(m, m_extract_plugins);
+        m_rewriter.set_flat_and_or(false);
+        // flat sum/prod := false
+    }
+
+    void fold_unfold::reduce() {  // three passes: alias removal with fold-unfold, linear terms, plain alias removal
+        if (!m_config.m_enabled)
+            return;
+
+        m_fmls.freeze_suffix();
+
+        for (extract_eq* ex : m_extract_plugins)
+            ex->pre_process(m_fmls);
+
+        reduce_alias(true);
+        reduce_linear();
+        reduce_alias(false);
+    }
+
+    void fold_unfold::reduce_alias(bool fuf) {  // fuf: additionally derive equalities by unfold + fold
+        m_subst = nullptr;  // start from an empty substitution; insert_subst allocates on demand
+        dep_eq_vector eqs;
+        get_eqs(eqs);
+        extract_subst(fuf, eqs);
+        vector<dependent_expr> old_fmls;
+        apply_subst(old_fmls);
+    }
+
+    void fold_unfold::get_eqs(dep_eq_vector& eqs) {  // let every extract plugin collect equalities from each indexed formula
+        for (extract_eq* ex : m_extract_plugins)
+            for (unsigned i : indices())
+                ex->get_eqs(m_fmls[i], eqs);
+    }
+
+    void fold_unfold::extract_subst(bool fuf, dep_eq_vector const& eqs) {  // merge eqs into the e-graph, then map aliased constants to a class representative
+        m_find.reset();
+        for (auto const& [orig, v, t, d] : eqs) {
+            auto a = mk_enode(v);
+            auto b = mk_enode(t);
+            // verbose_stream() << mk_bounded_pp(v, m) << " == " << mk_bounded_pp(t, m) << "\n";
+            proof_ref pr(m);
+            auto j = to_ptr(push_pr_dep(pr, d));  // justification = index into m_pr_dep, disguised as a pointer
+            m_egraph.merge(a, b, j);
+        }
+
+        // choose uninterpreted or value representative
+        auto find_rep = [&](enode *a, ptr_buffer<enode>& vars) {  // also collects the class's uninterpreted constants in vars
+            enode *rep = nullptr;
+            for (auto b : euf::enode_class(a)) {
+                expr *t = b->get_expr();
+                if (is_uninterp_const(t))
+                    vars.push_back(b);
+                if (m.is_value(t))
+                    rep = b;  // a value in the class takes precedence over any constant
+            }
+            if (!rep) {
+                for (auto v : vars)
+                    if (!rep || v->get_id() < rep->get_id())  // otherwise: the constant with the smallest node id
+                        rep = v;
+            }
+            return rep;
+        };
+
+        for (auto a : m_egraph.nodes()) {  // provisional aliases: unfold() reads m_find to skip constants mapped elsewhere
+            if (!a->is_root())
+                continue;
+            ptr_buffer<enode> vars;
+            enode *rep = find_rep(a, vars);
+            if (!rep)
+                continue;
+            for (auto w : vars) {
+                if (w != rep)
+                    m_find.setx(w->get_id(), rep, nullptr);
+            }
+        }
+        if (fuf) {
+            // find new equalities by performing fold-unfold
+            vector<std::tuple<enode*, expr_ref, proof_ref, expr_dependency*>> new_eqs;  // collected first, merged after the node loop
+            for (auto n : m_egraph.nodes()) {
+                if (!n->is_root())
+                    continue;
+                auto ne = n->get_expr();
+                unsigned depth = 3;
+                vector<std::pair<expr_ref, expr_dependency*>> es;
+                unfold(depth, n, nullptr, es);
+                // verbose_stream() << "unfolds " << es.size() << "\n";
+                for (auto [e, d] : es) {
+                    expr_ref r(m);
+                    proof_ref pr(m);
+                    fold(e, r, pr);
+                    if (ne == r)  // rewriting gave back the root's own expression: nothing new
+                        continue;
+                    new_eqs.push_back({n, r, pr, d});
+                }
+            }
+            for (auto const &[a, t, pr, d] : new_eqs) {
+                auto b = mk_enode(t);
+                auto j = to_ptr(push_pr_dep(pr, d));
+                m_egraph.merge(a, b, j);
+            }
+        }
+
+        for (auto a : m_egraph.nodes()) {  // final pass: record var |-> rep in m_find and in the substitution
+            if (!a->is_root())
+                continue;
+            ptr_buffer<enode> vars;
+            enode *rep = find_rep(a, vars);
+            if (!rep)
+                continue;
+            for (auto v : vars) {
+                if (v == rep)
+                    continue;
+                m_find.setx(v->get_id(), rep, nullptr);
+                // verbose_stream() << "insert " << mk_pp(v->get_expr(), m) << " " << mk_pp(rep->get_expr(), m) << "\n";
+                insert_subst(v->get_expr(), rep->get_expr(), explain_eq(v, rep));
+                m_stats.m_num_elim_vars++;
+            }
+        }
+    }
+
+    expr_dependency *fold_unfold::explain_eq(enode *a, enode *b) {  // join of the dependencies justifying a == b; nullptr if a and b are the same node
+        if (a == b)
+            return nullptr;
+        ptr_vector<size_t> just;
+        m_egraph.begin_explain();
+        m_egraph.explain_eq(just, nullptr, a, b);
+        m_egraph.end_explain();
+        expr_dependency *d = nullptr;
+        for (size_t *j : just)
+            d = m.mk_join(d, m_pr_dep[from_ptr(j)].second);  // each justification is an index into m_pr_dep
+        return d;
+    }
+
+    unsigned fold_unfold::push_pr_dep(proof *pr, expr_dependency *d) {  // append (pr, d) to m_pr_dep and return its index
+        unsigned sz = m_pr_dep.size();
+        SASSERT(!m.proofs_enabled() || pr);
+        m_pr_dep.push_back({proof_ref(pr, m), d});
+        m_trail.push(push_back_vector(m_pr_dep));
+        return sz;
+    }
+
+    enode *fold_unfold::mk_enode(expr *e) {  // iterative construction with m_todo as explicit stack: children before parents
+        m_todo.push_back(e);
+        enode *n;
+        while (!m_todo.empty()) {
+            e = m_todo.back();
+            if (m_egraph.find(e)) {
+                m_todo.pop_back();
+                continue;
+            }
+            if (!is_app(e)) {
+                m_egraph.mk(e, m_generation, 0, nullptr);
+                m_todo.pop_back();
+                continue;
+            }
+            m_args.reset();
+            unsigned sz = m_todo.size();
+            for (expr *arg : *to_app(e)) {
+                n = m_egraph.find(arg);
+                if (n)
+                    m_args.push_back(n);
+                else
+                    m_todo.push_back(arg);
+            }
+            if (sz == m_todo.size()) {  // no argument was missing, so e itself can be created now
+                n = m_egraph.mk(e, m_generation, m_args.size(), m_args.data());
+                if (m_egraph.get_plugin(e->get_sort()->get_family_id()))
+                    m_egraph.add_th_var(n, m_th_var++, e->get_sort()->get_family_id());
+                if (!m.is_eq(e)) {
+                    for (auto ch : m_args)
+                        for (auto idv : euf::enode_th_vars(*ch))
+                            m_egraph.register_shared(n, idv.get_id());
+                }
+                m_todo.pop_back();
+            }
+        }
+        return m_egraph.find(e);  // e is the last expression processed, i.e. the original argument
+    }
+
+
+    void fold_unfold::fold(expr *e, expr_ref &result, proof_ref &pr) {  // fold = rewrite e with m_rewriter
+        m_rewriter(e, result, pr);
+    }
+
+    void fold_unfold::unfold(unsigned n, enode *e, expr_dependency* d, vector<std::pair<expr_ref, expr_dependency*>>& es) {  // append to es terms built from e's class, expanding n levels of arguments
+        if (n == 0) {
+            es.push_back({expr_ref(e->get_expr(), m), d});  // depth exhausted: use e's own expression
+            return;
+        }
+        if (es.size() > 10)  // cap on the number of collected terms
+            return;
+        unsigned count = 0;
+        for (auto sib : euf::enode_class(e)) {
+            auto sib_e = sib->get_expr();
+            if (!is_app(sib_e))
+                continue;
+            if (is_uninterp_const(sib_e)) {
+                auto f = m_find.get(sib->get_id(), nullptr);
+                if (f && f != sib)  // constant already mapped to another node: skip it
+                    continue;
+            }
+            ++count;
+            expr_ref_vector args(m);
+            expr_dependency *d1 = m.mk_join(d, explain_eq(sib, e));
+            unfold_arg(n, 0, sib, args, d1, es);
+            if (count > 2)  // expand at most three members per class
+                break;
+        }
+        // verbose_stream() << "count " << count << "\n";
+    }
+
+    void fold_unfold::unfold_arg(unsigned n, unsigned i, enode* e, expr_ref_vector& args, expr_dependency* d,
+                                 vector<std::pair<expr_ref, expr_dependency*>>& es) {  // choose an unfolding for argument i, then recurse on i + 1
+        if (i == e->num_args()) {
+            es.push_back({expr_ref(m.mk_app(e->get_decl(), args), m), d});  // all arguments chosen: build the application
+            return;
+        }
+        vector<std::pair<expr_ref, expr_dependency*>> es_arg;
+        unfold(n - 1, e->get_arg(i), d, es_arg);
+        for (auto [arg, dep] : es_arg) {
+            args.push_back(arg);
+            unfold_arg(n, i + 1, e, args, dep, es);  // dep already includes d (it was passed into unfold above)
+            args.pop_back();
+            if (es.size() > 10)
+                return;
+        }
+    }
+
+    void fold_unfold::insert_subst(expr * v, expr * t, expr_dependency* d) {  // record v |-> t under dependency d, allocating m_subst on first use
+        if (!m_subst)
+            m_subst = alloc(expr_substitution, m, true, false);
+        m_subst->insert(v, t, d);
+    }
+
+    void fold_unfold::apply_subst(vector<dependent_expr> &old_fmls) {  // rewrite every indexed formula under m_subst; replaced originals are appended to old_fmls
+        if (!m.inc())
+            return;
+        if (!m_subst)
+            return;
+
+        scoped_ptr<expr_replacer> rp = mk_default_expr_replacer(m, false);
+        rp->set_substitution(m_subst.get());
+
+        for (unsigned i : indices()) {
+            auto [f, p, d] = m_fmls[i]();
+            auto [new_f, new_dep] = rp->replace_with_dep(f);
+            proof_ref new_pr(m);
+            expr_ref tmp(m);
+            m_rewriter(new_f, tmp, new_pr);
+            if (tmp == f)  // unchanged after substitution and rewriting
+                continue;
+            new_dep = m.mk_join(d, new_dep);
+            old_fmls.push_back(m_fmls[i]);
+            m_fmls.update(i, dependent_expr(m, tmp, mp(p, new_pr), new_dep));
+        }
+        m_fmls.model_trail().push(m_subst.detach(), old_fmls, false);  // the model trail receives the substitution and the replaced formulas
+    }
+
+    void fold_unfold::set_levels() {  // recompute m_node2level / m_level2node for all classes, then clear the marks used by set_level
+        m_node2level.reset();
+        m_level2node.reset();
+        m_level_count = 0;
+        for (auto n : m_egraph.nodes())
+            if (n->is_root())
+                set_level(n);
+        for (auto n : m_egraph.nodes())
+            if (n->is_root())
+                n->unmark1();
+    }
+
+    void fold_unfold::set_level(enode* n) {  // number the class of root n after the classes of all its members' arguments
+        SASSERT(n->is_root());
+
+        if (m_node2level.get(n->get_id(), UINT_MAX) != UINT_MAX)
+            return;
+
+        if (!n->is_marked1()) {  // mark1 stops the recursion from re-entering a class that is still being visited
+            n->mark1();
+            for (auto b : enode_class(n)) {
+                for (auto arg : enode_args(b))
+                    set_level(arg->get_root());
+            }
+        }
+        if (m_node2level.get(n->get_id(), UINT_MAX) != UINT_MAX)  // a nested call may already have numbered this class
+            return;
+        for (auto a : enode_class(n)) {
+            m_node2level.setx(a->get_id(), m_level_count, UINT_MAX);
+        }
+        m_level2node.setx(m_level_count, n, nullptr);  // store the root itself: reduce_linear asserts is_root() on every entry
+        ++m_level_count;
+    }
+
+    void fold_unfold::reduce_linear() {  // in level order, replace a constant by a linear term of its class (see is_linear_term)
+        set_levels();
+        m_subst = alloc(expr_substitution, m, true, false);
+        scoped_ptr<expr_replacer> rp = mk_default_expr_replacer(m, false);
+        rp->set_substitution(m_subst.get());
+        for (auto n : m_level2node) {
+            SASSERT(n);
+            SASSERT(n->is_root());
+            // if a is uninterpreted and is not eliminated,
+            // n is equal to a linear term with lower level argument
+            // back-substitute the linear term using existing subst.
+            // update subst with a -> linear term
+            enode *var = nullptr;
+            enode *term = nullptr;
+            for (auto a : enode_class(n)) {
+                if (m_find.get(a->get_id(), nullptr) != nullptr) // already substituted
+                    continue;
+                if (is_uninterp_const(a->get_expr()))
+                    var = a;
+                else if (is_linear_term(a))
+                    term = a;
+            }
+            if (var && term) {
+                m_find.setx(var->get_id(), term, nullptr); // record that var was replaced
+                auto dep = explain_eq(var, term);
+                auto [new_term, new_dep] = rp->replace_with_dep(term->get_expr());
+                expr_ref r(m);
+                proof_ref pr(m);
+                m_rewriter(new_term, r, pr);
+                m_subst->insert(var->get_expr(), r, m.mk_join(dep, new_dep));
+            }
+        }
+        vector<dependent_expr> old_fmls;
+        apply_subst(old_fmls);
+    }
+
+    bool fold_unfold::is_linear_term(enode *n) {  // true iff n has at most one non-value argument and each such argument has a strictly lower level
+        unsigned num_vars = 0;
+        unsigned level = m_node2level[n->get_root_id()];
+        for (auto arg : enode_args(n))
+            if (!m.is_value(arg->get_expr())) {
+                if (m_node2level[arg->get_root_id()] >= level)
+                    return false;
+                ++num_vars;
+            }
+        return num_vars <= 1;
+    }
+
+    void fold_unfold::updt_params(params_ref const &p) {
+        m_config.m_enabled = true;  // NOTE(review): set unconditionally; p is only forwarded to the extract plugins
+        params_ref p1;
+        p1.set_bool("eliminate_mod", false);
+        for (auto ex : m_extract_plugins) {
+            ex->updt_params(p);
+            ex->updt_params(p1);  // applied after p, with eliminate_mod set to false
+        }
+    }
+
+    void fold_unfold::collect_param_descrs(param_descrs &r) {}
+
+    void fold_unfold::collect_statistics(statistics &st) const {
+        st.update("fold-unfold-steps", m_stats.m_num_steps);
+        st.update("fold-unfold-elim-vars", m_stats.m_num_elim_vars);
+    }
+
+}
diff --git a/src/ast/simplifiers/fold_unfold.h b/src/ast/simplifiers/fold_unfold.h
new file mode 100644
index 000000000..577801f2d
--- /dev/null
+++ b/src/ast/simplifiers/fold_unfold.h
@@ -0,0 +1,108 @@
+
+/*++
+Copyright (c) 2022 Microsoft Corporation
+
+Module Name:
+
+    fold_unfold.h
+
+Abstract:
+
+    fold-unfold simplifier
+
+Author:
+
+    Nikolaj Bjorner (nbjorner) 2025-11-5.
+
+--*/
+
+#pragma once
+
+#include "util/scoped_ptr_vector.h"
+#include "ast/expr_substitution.h"
+#include "ast/rewriter/th_rewriter.h"
+#include "ast/simplifiers/extract_eqs.h"
+#include "ast/euf/euf_egraph.h"
+
+namespace euf {
+
+    class fold_unfold : public dependent_expr_simplifier {
+        friend class solve_context_eqs;
+
+        struct stats {
+            unsigned m_num_steps = 0;
+            unsigned m_num_elim_vars = 0;
+            void reset() {
+                m_num_steps = 0;
+                m_num_elim_vars = 0;
+            }
+        };
+
+        struct config {
+            bool m_enabled = true;
+        };
+
+        stats m_stats;
+        config m_config;
+        th_rewriter m_rewriter;
+        egraph m_egraph;
+        scoped_ptr_vector<extract_eq> m_extract_plugins;
+        unsigned_vector m_var2id;                  // app->get_id() |-> small numeral
+        scoped_ptr<expr_substitution> m_subst;     // current substitution
+        vector<std::pair<proof_ref, expr_dependency*>> m_pr_dep;  // justification table, indexed by the value returned from push_pr_dep
+
+        void get_eqs(dep_eq_vector &eqs);
+        void extract_subst(bool fuf, dep_eq_vector const &eqs);
+        void insert_subst(expr *v, expr *t, expr_dependency* d);
+        void apply_subst(vector<dependent_expr> &old_fmls);
+        void reduce_alias(bool fuf);
+        void reduce_linear();
+
+        size_t *to_ptr(size_t i) const {  // encode an m_pr_dep index as an e-graph justification pointer
+            return reinterpret_cast<size_t*>(i);
+        }
+        unsigned from_ptr(size_t *s) const {  // inverse of to_ptr
+            return static_cast<unsigned>(reinterpret_cast<size_t>(s));
+        }
+        unsigned push_pr_dep(proof *pr, expr_dependency *d);
+        expr_dependency *explain_eq(enode *a, enode *b);
+
+        ptr_vector<expr> m_todo;
+        enode_vector m_args, m_find;   // m_find: node id |-> node that replaced it (nullptr if not replaced)
+        unsigned_vector m_node2level;  // node id |-> level of its class (UINT_MAX while unassigned)
+        enode_vector m_level2node;     // level |-> node of the class assigned that level
+        unsigned m_level_count = 0;
+
+        void set_levels();
+        void set_level(enode *n);
+        bool is_linear_term(enode *n);
+
+        unsigned m_generation = 0;
+        unsigned m_th_var = 0;
+        enode *mk_enode(expr *e);
+
+        void fold(expr *e, expr_ref &result, proof_ref &pr);
+        void unfold(unsigned n, enode *e, expr_dependency* d, vector<std::pair<expr_ref, expr_dependency*>> &es);
+        void unfold_arg(unsigned n, unsigned i, enode *e, expr_ref_vector &args, expr_dependency *d,
+                        vector<std::pair<expr_ref, expr_dependency*>> &es);
+
+    public:
+        fold_unfold(ast_manager &m, dependent_expr_state &fmls);
+
+        char const *name() const override {
+            return "fold-unfold";
+        }
+
+        void reduce() override;
+
+        void updt_params(params_ref const &p) override;
+
+        void collect_param_descrs(param_descrs &r) override;
+
+        void collect_statistics(statistics &st) const override;
+
+        void reset_statistics() override {
+            m_stats.reset();
+        }
+    };
+} // namespace euf
diff --git a/src/tactic/core/CMakeLists.txt b/src/tactic/core/CMakeLists.txt
index a191b6251..e65b53078 100644
--- a/src/tactic/core/CMakeLists.txt
+++ b/src/tactic/core/CMakeLists.txt
@@ -37,6 +37,7 @@ z3_add_component(core_tactics
     elim_uncnstr_tactic.h
     elim_uncnstr2_tactic.h
     eliminate_predicates_tactic.h
+    fold_unfold_tactic.h
     injectivity_tactic.h
     nnf_tactic.h
     occf_tactic.h
diff --git a/src/tactic/core/fold_unfold_tactic.h b/src/tactic/core/fold_unfold_tactic.h
new file mode 100644
index 000000000..5963ad5ac
--- /dev/null
+++ b/src/tactic/core/fold_unfold_tactic.h
@@ -0,0 +1,43 @@
+/*++
+Copyright (c) 2025 Microsoft Corporation
+
+Module Name:
+
+    fold_unfold_tactic.h
+
+Abstract:
+
+    Tactic for solving variables using fold/unfold transformations.
+
+Author:
+
+    Nikolaj Bjorner (nbjorner) 2025-11-05
+
+Tactic Documentation:
+
+## Tactic fold-unfold
+
+### Short Description
+
+Apply fold-unfold simplifications to solve for equalities
+
+
+--*/
+
+#pragma once
+#include "util/params.h"
+#include "tactic/tactic.h"
+#include "tactic/dependent_expr_state_tactic.h"
+#include "ast/simplifiers/fold_unfold.h"
+
+inline tactic *mk_fold_unfold_tactic(ast_manager &m, params_ref const &p = params_ref()) {
+    return alloc(dependent_expr_state_tactic, m, p,
+                 [](auto &m, auto &p, auto &s) -> dependent_expr_simplifier * { return alloc(euf::fold_unfold, m, s); });
+}
+
+/*
+  ADD_TACTIC("fold-unfold", "solve for variables.", "mk_fold_unfold_tactic(m, p)")
+  ADD_SIMPLIFIER("fold-unfold", "solve for variables.", "alloc(euf::fold_unfold, m, s)")
+*/
+
+