diff --git a/src/ast/simplifiers/solve_eqs.cpp b/src/ast/simplifiers/solve_eqs.cpp index 77808c142..a47a05cee 100644 --- a/src/ast/simplifiers/solve_eqs.cpp +++ b/src/ast/simplifiers/solve_eqs.cpp @@ -16,16 +16,69 @@ Author: --*/ - +#include "util/trace.h" +#include "ast/ast_util.h" +#include "ast/for_each_expr.h" +#include "ast/ast_pp.h" +#include "ast/arith_decl_plugin.h" +#include "ast/rewriter/expr_replacer.h" #include "ast/simplifiers/solve_eqs.h" namespace euf { + class basic_extract_eq : public extract_eq { + ast_manager& m; + public: + basic_extract_eq(ast_manager& m) : m(m) {} + void get_eqs(dependent_expr const& e, dep_eq_vector& eqs) { + auto [f, d] = e(); + expr* x, * y; + if (m.is_eq(f, x, y)) { + if (is_uninterp_const(x)) + eqs.push_back(dependent_eq(to_app(x), expr_ref(y, m), d)); + if (is_uninterp_const(y)) + eqs.push_back(dependent_eq(to_app(y), expr_ref(x, m), d)); + } + expr* c, * th, * el, * x1, * y1, * x2, * y2; + if (m.is_ite(f, c, th, el)) { + if (m.is_eq(th, x1, y1) && m.is_eq(el, x2, y2)) { + if (x1 == y2 && is_uninterp_const(x1)) + std::swap(x2, y2); + if (x2 == y2 && is_uninterp_const(x2)) + std::swap(x2, y2), std::swap(x1, y1); + if (x2 == y1 && is_uninterp_const(x2)) + std::swap(x1, y1); + if (x1 == x2 && is_uninterp_const(x1)) + eqs.push_back(dependent_eq(to_app(x1), expr_ref(m.mk_ite(c, y1, y2), m), d)); + } + } + if (is_uninterp_const(f)) + eqs.push_back(dependent_eq(to_app(f), expr_ref(m.mk_true(), m), d)); + if (m.is_not(f, x) && is_uninterp_const(x)) + eqs.push_back(dependent_eq(to_app(x), expr_ref(m.mk_false(), m), d)); + } + }; - void solve_eqs::init() { + class arith_extract_eq : public extract_eq { + ast_manager& m; + arith_util a; +#if 0 + void solve_eq(expr* f, expr_depedency* d) { - } + } +#endif + public: + arith_extract_eq(ast_manager& m) : m(m), a(m) {} + void get_eqs(dependent_expr const& e, dep_eq_vector& eqs) { +#if 0 + auto [f, d] = e(); + expr* x, * y; + if (m.is_eq(f, x, y) && a.is_int_real(x)) + ; +#endif + } + }; // initialize graph that maps variable ids to next ids void solve_eqs::extract_dep_graph(dep_eq_vector& eqs) { @@ -44,12 +97,8 @@ namespace euf { } m_next.resize(m_id2var.size()); - for (auto const& [v, t, d] : eqs) - m_next[var2id(v)].push_back(t); - } - - void solve_eqs::add_subst(app* v, expr* term) { - + for (auto const& eq : eqs) + m_next[var2id(eq.var)].push_back(eq); } /** @@ -58,22 +107,30 @@ namespace euf { * Free variables have higher levels. */ void solve_eqs::extract_subst() { - m_var2level.reset(); - m_var2level.resize(m_id2var.size(), UINT_MAX); + m_id2level.reset(); + m_id2level.resize(m_id2var.size(), UINT_MAX); + m_subst_ids.reset(); + m_subst = alloc(expr_substitution, m, false, false); + auto is_explored = [&](unsigned id) { - return m_var2level[id] != UINT_MAX; + return m_id2level[id] != UINT_MAX; }; + auto is_safe = [&](unsigned lvl, expr* t) { for (auto* e : subterms::all(expr_ref(t, m))) - if (is_var(e) && m_var2level[var2id(e)] < lvl) + if (is_var(e) && m_id2level[var2id(e)] < lvl) return false; + return true; }; unsigned init_level = UINT_MAX; + unsigned_vector todo; for (unsigned id = 0; id < m_id2var.size(); ++id) { if (is_explored(id)) continue; // initialize current level to have enough room to assign different levels to all variables. + if (init_level < m_id2var.size() + 1) + return; init_level -= m_id2var.size() + 1; unsigned curr_level = init_level; todo.push_back(id); @@ -82,12 +139,14 @@ namespace euf { todo.pop_back(); if (is_explored(j)) continue; - m_var2level[id] = curr_level++; - for (expr* t : m_next[j]) { + m_id2level[id] = curr_level++; + for (auto const& eq : m_next[j]) { + auto const& [v, t, d] = eq; if (!is_safe(curr_level, t)) continue; - add_subst(m_id2var[j], t); - for (auto* e : subterms::all(expr_ref(t, m))) + m_next[j][0] = eq; + m_subst_ids.push_back(id); + for (expr* e : subterms::all(expr_ref(t, m))) if (is_var(e) && !is_explored(var2id(e))) todo.push_back(var2id(e)); break; @@ -96,19 +155,65 @@ namespace euf { } } - void solve_eqs::extract_subst(dep_eq_vector& eqs, dep_eq_vector& subst) { + void solve_eqs::add_subst(dependent_eq const& eq) { + m_subst->insert(eq.var, eq.term, nullptr, eq.dep); + } + void solve_eqs::normalize() { + scoped_ptr rp = mk_default_expr_replacer(m, true); + m_subst->reset(); + rp->set_substitution(m_subst.get()); + + std::sort(m_subst_ids.begin(), m_subst_ids.end(), [&](unsigned u, unsigned v) { return m_id2level[u] > m_id2level[v]; }); + + expr_dependency_ref new_dep(m); + expr_ref new_def(m); + proof_ref new_pr(m); + + for (unsigned id : m_subst_ids) { + // checkpoint(); + auto const& [v, def, dep] = m_next[id][0]; + rp->operator()(def, new_def, new_pr, new_dep); + // m_num_steps += rp->get_num_steps() + 1; + new_dep = m.mk_join(dep, new_dep); + m_subst->insert(v, new_def, new_pr, new_dep); + // we updated the substitution, but we don't need to reset rp + // because all cached values there do not depend on v. + } + + TRACE("solve_eqs", + tout << "after normalizing variables\n"; + for (unsigned id : m_subst_ids) { + auto const& eq = m_next[id][0]; + expr* def = nullptr; + proof* pr = nullptr; + expr_dependency* dep = nullptr; + m_subst->find(eq.var, def, pr, dep); + tout << mk_pp(eq.var, m) << "\n----->\n" << mk_pp(def, m) << "\n\n"; + }); } void solve_eqs::apply_subst() { - + scoped_ptr rp = mk_default_expr_replacer(m, true); + rp->set_substitution(m_subst.get()); + expr_ref new_f(m); + proof_ref new_pr(m); + expr_dependency_ref new_dep(m); + for (unsigned i = m_qhead; i < m_fmls.size() && !m_fmls.inconsistent(); ++i) { + auto [f, d] = m_fmls[i](); + rp->operator()(f, new_f, new_pr, new_dep); + if (new_f == f) + continue; + new_dep = m.mk_join(d, new_dep); + m_fmls.update(i, dependent_expr(m, new_f, new_dep)); + } } void solve_eqs::reduce() { - init(); - dep_eq_vector eqs, subst; + dep_eq_vector eqs; get_eqs(eqs); - extract_subst(eqs, subst); + extract_dep_graph(eqs); + extract_subst(); apply_subst(); advance_qhead(m_fmls.size()); } diff --git a/src/ast/simplifiers/solve_eqs.h b/src/ast/simplifiers/solve_eqs.h index 936816bc4..55cad7e67 100644 --- a/src/ast/simplifiers/solve_eqs.h +++ b/src/ast/simplifiers/solve_eqs.h @@ -20,6 +20,8 @@ Author: #include "ast/simplifiers/dependent_expr_state.h" #include "ast/rewriter/th_rewriter.h" +#include "ast/expr_substitution.h" +#include "util/scoped_ptr_vector.h" namespace euf { @@ -34,34 +36,37 @@ namespace euf { typedef vector dep_eq_vector; class extract_eq { - pulic: + public: virtual ~extract_eq() {} - virtual void get_eqs(depdendent_expr const& e, dep_eq_vector& eqs) = 0; + virtual void get_eqs(dependent_expr const& e, dep_eq_vector& eqs) = 0; }; class solve_eqs : public dependent_expr_simplifier { th_rewriter m_rewriter; scoped_ptr_vector m_extract_plugins; - unsigned_vector m_var2id; + unsigned_vector m_var2id, m_id2level, m_subst_ids; ptr_vector m_id2var; - vector m_next; + vector m_next; + scoped_ptr m_subst; - void init(); + void add_subst(dependent_eq const& eq); - bool is_var(expr* v) const; + bool is_var(expr* e) const { return e->get_id() < m_var2id.size() && m_var2id[e->get_id()] != UINT_MAX; } unsigned var2id(expr* v) const { return m_var2id[v->get_id()]; } void get_eqs(dep_eq_vector& eqs) { for (unsigned i = m_qhead; i < m_fmls.size(); ++i) - get_eqs(m_fmls[i](), eqs); + get_eqs(m_fmls[i], eqs); } void get_eqs(dependent_expr const& f, dep_eq_vector& eqs) { - for (auto* ex : m_extract_plugins) + for (extract_eq* ex : m_extract_plugins) ex->get_eqs(f, eqs); } - void extract_subst(dep_eq_vector& eqs, dep_eq_vector& subst); + void extract_subst(); + void extract_dep_graph(dep_eq_vector& eqs); + void normalize(); void apply_subst(); public: