From 4d8860c0bc4f4a587e9e66b3494df446a52c3112 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Sat, 5 Nov 2022 10:34:57 -0700 Subject: [PATCH] wip - adding context equation solver the solve_eqs_tactic is to be replaced by a re-implementation that uses solve_eqs in the simplifiers directory. The re-implementation should address efficiency issues with the previous code. At this point it punts on low level proofs. The plan is to use coarser dependency tracking instead of low level proofs for pre-processing. Dependencies can be converted into a proof hint representation that can be checked using a stronger checker. --- src/ast/occurs.cpp | 43 +++++ src/ast/occurs.h | 9 +- src/ast/simplifiers/CMakeLists.txt | 1 + src/ast/simplifiers/dependent_expr.h | 2 + src/ast/simplifiers/dependent_expr_state.h | 3 +- src/ast/simplifiers/extract_eqs.cpp | 52 +++--- src/ast/simplifiers/extract_eqs.h | 7 +- src/ast/simplifiers/solve_context_eqs.cpp | 203 +++++++++++++++++++++ src/ast/simplifiers/solve_context_eqs.h | 58 ++++++ src/ast/simplifiers/solve_eqs.cpp | 39 +++- src/ast/simplifiers/solve_eqs.h | 45 ++--- src/tactic/core/solve_eqs_tactic.cpp | 51 +----- src/tactic/dependent_expr_state_tactic.h | 4 +- src/util/mpn.cpp | 8 +- src/util/util.h | 8 + 15 files changed, 416 insertions(+), 117 deletions(-) create mode 100644 src/ast/simplifiers/solve_context_eqs.cpp create mode 100644 src/ast/simplifiers/solve_context_eqs.h diff --git a/src/ast/occurs.cpp b/src/ast/occurs.cpp index 21e7f5906..2bcd98396 100644 --- a/src/ast/occurs.cpp +++ b/src/ast/occurs.cpp @@ -74,3 +74,46 @@ bool occurs(func_decl * d, expr * n) { return false; } +void mark_occurs(ptr_vector& to_check, expr* v, expr_mark& occ) { + expr_fast_mark2 visited; + occ.mark(v, true); + visited.mark(v, true); + while (!to_check.empty()) { + expr* e = to_check.back(); + if (visited.is_marked(e)) { + to_check.pop_back(); + continue; + } + if (is_app(e)) { + bool does_occur = false; + bool all_visited = true; + for (expr* arg : *to_app(e)) { + if (!visited.is_marked(arg)) { + to_check.push_back(arg); + all_visited = false; + } + else + does_occur |= occ.is_marked(arg); + } + if (all_visited) { + occ.mark(e, does_occur); + visited.mark(e, true); + to_check.pop_back(); + } + } + else if (is_quantifier(e)) { + expr* body = to_quantifier(e)->get_expr(); + if (visited.is_marked(body)) { + visited.mark(e, true); + occ.mark(e, occ.is_marked(body)); + to_check.pop_back(); + } + else + to_check.push_back(body); + } + else { + visited.mark(e, true); + to_check.pop_back(); + } + } +} \ No newline at end of file diff --git a/src/ast/occurs.h b/src/ast/occurs.h index 15a33ddf5..7475a292c 100644 --- a/src/ast/occurs.h +++ b/src/ast/occurs.h @@ -18,8 +18,8 @@ Revision History: --*/ #pragma once -class expr; -class func_decl; +#include "util/vector.h" +#include "ast/ast.h" /** \brief Return true if n1 occurs in n2 @@ -31,4 +31,9 @@ bool occurs(expr * n1, expr * n2); */ bool occurs(func_decl * d, expr * n); +/** +* \brief Mark sub-expressions of to_check by whether v occurs in these. +*/ +void mark_occurs(ptr_vector& to_check, expr* v, expr_mark& occurs); + diff --git a/src/ast/simplifiers/CMakeLists.txt b/src/ast/simplifiers/CMakeLists.txt index dc7aa6fb6..ef04cc433 100644 --- a/src/ast/simplifiers/CMakeLists.txt +++ b/src/ast/simplifiers/CMakeLists.txt @@ -4,6 +4,7 @@ z3_add_component(simplifiers euf_completion.cpp extract_eqs.cpp model_reconstruction_trail.cpp + solve_context_eqs.cpp solve_eqs.cpp COMPONENT_DEPENDENCIES euf diff --git a/src/ast/simplifiers/dependent_expr.h b/src/ast/simplifiers/dependent_expr.h index 53f9cb9d8..f789bf332 100644 --- a/src/ast/simplifiers/dependent_expr.h +++ b/src/ast/simplifiers/dependent_expr.h @@ -72,6 +72,8 @@ public: ast_manager& get_manager() const { return m; } expr* fml() const { return m_fml; } + + expr_dependency* dep() const { return m_dep; } std::tuple operator()() const { return { m_fml, m_dep }; diff --git a/src/ast/simplifiers/dependent_expr_state.h b/src/ast/simplifiers/dependent_expr_state.h index 803c58510..6bdd34626 100644 --- a/src/ast/simplifiers/dependent_expr_state.h +++ b/src/ast/simplifiers/dependent_expr_state.h @@ -47,12 +47,13 @@ public: virtual dependent_expr const& operator[](unsigned i) = 0; virtual void update(unsigned i, dependent_expr const& j) = 0; virtual bool inconsistent() = 0; + virtual model_reconstruction_trail& model_trail() = 0; trail_stack m_trail; void push() { m_trail.push_scope(); } void pop(unsigned n) { m_trail.pop_scope(n); } - virtual model_reconstruction_trail* model_trail() { return nullptr; } + }; /** diff --git a/src/ast/simplifiers/extract_eqs.cpp b/src/ast/simplifiers/extract_eqs.cpp index 1e9b576e1..e77ce9e06 100644 --- a/src/ast/simplifiers/extract_eqs.cpp +++ b/src/ast/simplifiers/extract_eqs.cpp @@ -38,9 +38,9 @@ namespace euf { expr* x, * y; if (m.is_eq(f, x, y)) { if (is_uninterp_const(x)) - eqs.push_back(dependent_eq(to_app(x), expr_ref(y, m), d)); + eqs.push_back(dependent_eq(e.fml(), to_app(x), expr_ref(y, m), d)); if (is_uninterp_const(y)) - eqs.push_back(dependent_eq(to_app(y), expr_ref(x, m), d)); + eqs.push_back(dependent_eq(e.fml(), to_app(y), expr_ref(x, m), d)); } expr* c, * th, * el, * x1, * y1, * x2, * y2; if (m_ite_solver && m.is_ite(f, c, th, el)) { @@ -52,13 +52,13 @@ namespace euf { if (x2 == y1 && is_uninterp_const(x2)) std::swap(x1, y1); if (x1 == x2 && is_uninterp_const(x1)) - eqs.push_back(dependent_eq(to_app(x1), expr_ref(m.mk_ite(c, y1, y2), m), d)); + eqs.push_back(dependent_eq(e.fml(), to_app(x1), expr_ref(m.mk_ite(c, y1, y2), m), d)); } } if (is_uninterp_const(f)) - eqs.push_back(dependent_eq(to_app(f), expr_ref(m.mk_true(), m), d)); + eqs.push_back(dependent_eq(e.fml(), to_app(f), expr_ref(m.mk_true(), m), d)); if (m.is_not(f, x) && is_uninterp_const(x)) - eqs.push_back(dependent_eq(to_app(x), expr_ref(m.mk_false(), m), d)); + eqs.push_back(dependent_eq(e.fml(), to_app(x), expr_ref(m.mk_false(), m), d)); } void updt_params(params_ref const& p) { @@ -76,7 +76,7 @@ namespace euf { // solve u mod r1 = y -> u = r1*mod!1 + y - void solve_mod(expr* x, expr* y, expr_dependency* d, dep_eq_vector& eqs) { + void solve_mod(expr* orig, expr* x, expr* y, expr_dependency* d, dep_eq_vector& eqs) { expr* u, * z; rational r1, r2; if (!a.is_mod(x, u, z)) @@ -87,7 +87,11 @@ namespace euf { return; expr_ref term(m); term = a.mk_add(a.mk_mul(z, m.mk_fresh_const("mod", a.mk_int())), y); - solve_eq(u, term, d, eqs); + + if (is_uninterp_const(u)) + eqs.push_back(dependent_eq(orig, to_app(u), term, d)); + else + solve_eq(orig, u, term, d, eqs); } /*** @@ -96,7 +100,7 @@ namespace euf { * -1*x + Y = Z -> x = Y - Z * a*x + Y = Z -> x = (Z - Y)/a for is-real(x) */ - void solve_add(expr* x, expr* y, expr_dependency* d, dep_eq_vector& eqs) { + void solve_add(expr* orig, expr* x, expr* y, expr_dependency* d, dep_eq_vector& eqs) { if (!a.is_add(x)) return; expr* u, * z; @@ -115,18 +119,18 @@ namespace euf { for (expr* arg : *to_app(x)) { if (is_uninterp_const(arg)) { mk_term(i); - eqs.push_back(dependent_eq(to_app(arg), term, d)); + eqs.push_back(dependent_eq(orig, to_app(arg), term, d)); } else if (a.is_mul(arg, u, z) && a.is_numeral(u, r) && is_uninterp_const(z)) { if (r == -1) { mk_term(i); term = a.mk_uminus(term); - eqs.push_back(dependent_eq(to_app(z), term, d)); + eqs.push_back(dependent_eq(orig, to_app(z), term, d)); } else if (a.is_real(arg) && r != 0) { mk_term(i); term = a.mk_div(term, u); - eqs.push_back(dependent_eq(to_app(z), term, d)); + eqs.push_back(dependent_eq(orig, to_app(z), term, d)); } } else if (a.is_real(arg) && a.is_mul(arg)) { @@ -155,7 +159,7 @@ namespace euf { } mk_term(i); term = a.mk_div(term, a.mk_mul(args.size(), args.data())); - eqs.push_back(dependent_eq(to_app(xarg), term, d)); + eqs.push_back(dependent_eq(orig, to_app(xarg), term, d)); } } ++i; @@ -165,7 +169,7 @@ namespace euf { /*** * Solve for x * Y = Z, where Y != 0 -> x = Z / Y */ - void solve_mul(expr* x, expr* y, expr_dependency* d, dep_eq_vector& eqs) { + void solve_mul(expr* orig, expr* x, expr* y, expr_dependency* d, dep_eq_vector& eqs) { if (!a.is_mul(x)) return; rational r; @@ -193,7 +197,7 @@ namespace euf { args.push_back(arg2); } term = a.mk_div(y, a.mk_mul(args)); - eqs.push_back(dependent_eq(to_app(arg), term, d)); + eqs.push_back(dependent_eq(orig, to_app(arg), term, d)); } } @@ -214,22 +218,24 @@ namespace euf { } } - void solve_eq(expr* x, expr* y, expr_dependency* d, dep_eq_vector& eqs) { - solve_add(x, y, d, eqs); - solve_mod(x, y, d, eqs); - solve_mul(x, y, d, eqs); + void solve_eq(expr* orig, expr* x, expr* y, expr_dependency* d, dep_eq_vector& eqs) { + solve_add(orig, x, y, d, eqs); + solve_mod(orig, x, y, d, eqs); + solve_mul(orig, x, y, d, eqs); } public: + arith_extract_eq(ast_manager& m) : m(m), a(m), m_args(m) {} + void get_eqs(dependent_expr const& e, dep_eq_vector& eqs) override { if (!m_enabled) return; auto [f, d] = e(); expr* x, * y; if (m.is_eq(f, x, y) && a.is_int_real(x)) { - solve_eq(x, y, d, eqs); - solve_eq(y, x, d, eqs); + solve_eq(f, x, y, d, eqs); + solve_eq(f, y, x, d, eqs); } } @@ -237,10 +243,8 @@ namespace euf { if (!m_enabled) return; m_nonzero.reset(); - for (unsigned i = 0; i < fmls.size(); ++i) { - auto [f, d] = fmls[i](); - add_pos(f); - } + for (unsigned i = 0; i < fmls.size(); ++i) + add_pos(fmls[i].fml()); } diff --git a/src/ast/simplifiers/extract_eqs.h b/src/ast/simplifiers/extract_eqs.h index 00f96f59b..f38829dfc 100644 --- a/src/ast/simplifiers/extract_eqs.h +++ b/src/ast/simplifiers/extract_eqs.h @@ -27,10 +27,11 @@ Author: namespace euf { struct dependent_eq { - app* var; - expr_ref term; + expr* orig; // original expression that encoded equation + app* var; // isolated variable + expr_ref term; // defined term expr_dependency* dep; - dependent_eq(app* var, expr_ref const& term, expr_dependency* d) : var(var), term(term), dep(d) {} + dependent_eq(expr* orig, app* var, expr_ref const& term, expr_dependency* d) : orig(orig), var(var), term(term), dep(d) {} }; typedef vector dep_eq_vector; diff --git a/src/ast/simplifiers/solve_context_eqs.cpp b/src/ast/simplifiers/solve_context_eqs.cpp new file mode 100644 index 000000000..766c18535 --- /dev/null +++ b/src/ast/simplifiers/solve_context_eqs.cpp @@ -0,0 +1,203 @@ +/*++ +Copyright (c) 2022 Microsoft Corporation + +Module Name: + + solve_context_eqs.cpp + +Abstract: + + simplifier for solving equations within a context + +Author: + + Nikolaj Bjorner (nbjorner) 2022-11-2. + +Notes: + +The variable v is solved based on expression e. +Check that every occurrence of v uses e in conjunctive context. + +Walk formulas containing v in as and-or. +Equalities that occur within at least one alternation of or are +considered as candidates. + +To constrain how formulas are traversed, first +label sub-expressions that contain v. An equality eq is safe for v +if every occurrence of v occurs in the same conjunctive context as eq. + +--*/ + +#include "ast/ast.h" +#include "ast/ast_pp.h" +#include "ast/occurs.h" +#include "ast/simplifiers/solve_context_eqs.h" +#include "ast/simplifiers/solve_eqs.h" + +namespace euf { + + + solve_context_eqs::solve_context_eqs(solve_eqs& s): m(s.m), m_fmls(s.m_fmls), m_solve_eqs(s) {} + + bool solve_context_eqs::is_safe_eq(expr* e) { + m_and_pos.reset(); m_and_neg.reset(); m_or_pos.reset(); m_or_neg.reset(); + for (unsigned i = 0; i < m_fmls.size(); ++i) + if (!is_safe_eq(m_fmls[i].fml(), e)) + return false; + return true; + } + + /** + * Check if some conjunction of f contains equality 'e'. + * If this is not the case, then check that every conjunct that contains v + * recursively contains a disjunction that contains 'e'. + */ + bool solve_context_eqs::is_safe_eq(unsigned recursion_depth, expr* f, bool sign, expr* e) { + if (!contains_v(f)) + return true; + signed_expressions conjuncts; + if (contains_conjunctively(f, sign, e, conjuncts)) + return true; + if (recursion_depth > 3) + return false; + return all_of(conjuncts, [&](std::pair const& p) { return is_disjunctively_safe(recursion_depth, p.second, p.first, e); }); + } + + /* + * Every disjunction in f that contains v also contains the equation e. + */ + bool solve_context_eqs::is_disjunctively_safe(unsigned recursion_depth, expr* f, bool sign, expr* e) { + signed_expressions todo; + todo.push_back({sign, f}); + while (!todo.empty()) { + auto [s, f] = todo.back(); + todo.pop_back(); + if (s && m_or_neg.is_marked(f)) + continue; + if (!s && m_or_pos.is_marked(f)) + continue; + if (s) + m_or_neg.mark(f, true); + else + m_or_pos.mark(f, true); + if (!s && f == e) + continue; + else if (!contains_v(f)) + continue; + else if (s && m.is_and(f)) + for (auto* arg : *to_app(f)) + todo.push_back({s, arg}); + else if (!s && m.is_or(f)) + for (auto* arg : *to_app(f)) + todo.push_back({s, arg}); + else if (m.is_not(f, f)) + todo.push_back({!s, f}); + else if (!is_safe_eq(recursion_depth + 1, f, s, e)) + return false; + } + return true; + } + + /** + * Determine whether some conjunction in f contains e. + * If no conjunction contains e, then return the set of conjunctions that contain v. + */ + bool solve_context_eqs::contains_conjunctively(expr* f, bool sign, expr* e, signed_expressions& conjuncts) { + signed_expressions todo; + todo.push_back({sign, f}); + while (!todo.empty()) { + auto [s, f] = todo.back(); + todo.pop_back(); + if (!s && f == e) + return true; + if (!s && m_and_pos.is_marked(f)) + continue; + if (s && m_and_neg.is_marked(f)) + continue; + if (s) + m_and_neg.mark(f, true); + else + m_and_pos.mark(f, true); + if (!contains_v(f)) + continue; + if (!s && m.is_and(f)) + for (auto* arg : *to_app(f)) + todo.push_back({false, arg}); + else if (s && m.is_or(f)) + for (auto* arg : *to_app(f)) + todo.push_back({true, arg}); + else if (m.is_not(f, f)) + todo.push_back({!s, f}); + else + conjuncts.push_back({s, f}); + } + return false; + } + + void solve_context_eqs::init_contains(expr* v) { + m_contains_v.reset(); + for (unsigned i = 0; i < m_fmls.size(); ++i) + m_todo.push_back(m_fmls[i].fml()); + mark_occurs(m_todo, v, m_contains_v); + SASSERT(m_todo.empty()); + } + + void solve_context_eqs::collect_nested_equalities(dep_eq_vector& eqs) { + expr_mark visited; + for (unsigned i = m_solve_eqs.m_qhead; i < m_fmls.size(); ++i) + collect_nested_equalities(m_fmls[i], visited, eqs); + + unsigned j = 0; + for (auto const& eq : eqs) { + init_contains(eq.var); + if (is_safe_eq(eq.orig)) + eqs[j++] = eq; + } + eqs.shrink(j); + } + + void solve_context_eqs::collect_nested_equalities(dependent_expr const& df, expr_mark& visited, dep_eq_vector& eqs) { + + svector> todo; + todo.push_back({ false, 0, df.fml()}); + + // even depth is conjunctive context, odd is disjunctive + // when alternating between conjunctive and disjunctive context, increment depth. + auto inc_or = [](unsigned depth) { + return (0 == depth % 2) ? depth + 1 : depth; + }; + auto inc_and = [](unsigned depth) { + return (0 == depth % 2) ? depth : depth + 1; + }; + + while (!todo.empty()) { + auto [s, depth, f] = todo.back(); + todo.pop_back(); + if (visited.is_marked(f)) + continue; + visited.mark(f, true); + if (s && m.is_and(f)) { + for (auto* arg : *to_app(f)) + todo.push_back({ s, inc_or(depth), arg }); + } + else if (!s && m.is_or(f)) { + for (auto* arg : *to_app(f)) + todo.push_back({ s, inc_or(depth), arg }); + } + if (!s && m.is_and(f)) { + for (auto* arg : *to_app(f)) + todo.push_back({ s, inc_and(depth), arg }); + } + else if (s && m.is_or(f)) { + for (auto* arg : *to_app(f)) + todo.push_back({ s, inc_and(depth), arg }); + } + else if (m.is_not(f, f)) + todo.push_back({ !s, depth, f }); + else if (!s && 1 == depth % 2) { + for (extract_eq* ex : m_solve_eqs.m_extract_plugins) + ex->get_eqs(dependent_expr(m, f, df.dep()), eqs); + } + } + } +} diff --git a/src/ast/simplifiers/solve_context_eqs.h b/src/ast/simplifiers/solve_context_eqs.h new file mode 100644 index 000000000..b3db74127 --- /dev/null +++ b/src/ast/simplifiers/solve_context_eqs.h @@ -0,0 +1,58 @@ +/*++ +Copyright (c) 2022 Microsoft Corporation + +Module Name: + + solve_context_eqs.h + +Abstract: + + simplifier for solving equations within a context + +Author: + + Nikolaj Bjorner (nbjorner) 2022-11-2. + +--*/ + + +#pragma once + +#include "ast/simplifiers/dependent_expr_state.h" +#include "ast/simplifiers/extract_eqs.h" + +namespace euf { + + class solve_eqs; + + + class solve_context_eqs { + + ast_manager& m; + dependent_expr_state& m_fmls; + solve_eqs& m_solve_eqs; + expr_mark m_and_pos, m_and_neg, m_or_pos, m_or_neg; + expr_mark m_contains_v; + ptr_vector m_todo; + + typedef svector> signed_expressions; + + bool contains_v(expr* f) const { return m_contains_v.is_marked(f); } + bool is_safe_eq(expr* e); + bool is_safe_eq(unsigned recursion_depth, expr* f, bool sign, expr* e); + bool is_safe_eq(expr* f, expr* e) { return is_safe_eq(0, f, false, e); } + bool is_disjunctively_safe(unsigned recursion_depth, expr* f, bool sign, expr* e); + bool contains_conjunctively(expr* f, bool sign, expr* e, signed_expressions& conjuncts); + + void collect_nested_equalities(dependent_expr const& f, expr_mark& visited, dep_eq_vector& eqs); + void init_contains(expr* v); + + + public: + + solve_context_eqs(solve_eqs& s); + + void collect_nested_equalities(dep_eq_vector& eqs); + + }; +} diff --git a/src/ast/simplifiers/solve_eqs.cpp b/src/ast/simplifiers/solve_eqs.cpp index b5b500a96..31e063119 100644 --- a/src/ast/simplifiers/solve_eqs.cpp +++ b/src/ast/simplifiers/solve_eqs.cpp @@ -23,22 +23,29 @@ Author: #include "ast/recfun_decl_plugin.h" #include "ast/rewriter/expr_replacer.h" #include "ast/simplifiers/solve_eqs.h" +#include "ast/simplifiers/solve_context_eqs.h" #include "ast/converters/generic_model_converter.h" #include "params/tactic_params.hpp" namespace euf { + void solve_eqs::get_eqs(dep_eq_vector& eqs) { + for (extract_eq* ex : m_extract_plugins) + for (unsigned i = m_qhead; i < m_fmls.size(); ++i) + ex->get_eqs(m_fmls[i], eqs); + } + // initialize graph that maps variable ids to next ids void solve_eqs::extract_dep_graph(dep_eq_vector& eqs) { m_var2id.reset(); m_id2var.reset(); m_next.reset(); unsigned sz = 0; - for (auto const& [v, t, d] : eqs) + for (auto const& [orig, v, t, d] : eqs) sz = std::max(sz, v->get_id()); m_var2id.resize(sz + 1, UINT_MAX); - for (auto const& [v, t, d] : eqs) { + for (auto const& [orig, v, t, d] : eqs) { if (is_var(v) || !can_be_var(v)) continue; m_var2id[v->get_id()] = m_id2var.size(); @@ -91,7 +98,7 @@ namespace euf { continue; m_id2level[id] = curr_level++; for (auto const& eq : m_next[j]) { - auto const& [v, t, d] = eq; + auto const& [orig, v, t, d] = eq; if (!is_safe(curr_level, t)) continue; m_next[j][0] = eq; @@ -114,7 +121,7 @@ namespace euf { for (unsigned id : m_subst_ids) { if (!m.inc()) break; - auto const& [v, def, dep] = m_next[id][0]; + auto const& [orig, v, def, dep] = m_next[id][0]; auto [new_def, new_dep] = rp->replace_with_dep(def); m_stats.m_num_steps += rp->get_num_steps() + 1; ++m_stats.m_num_elim_vars; @@ -134,7 +141,7 @@ namespace euf { }); } - void solve_eqs::apply_subst() { + void solve_eqs::apply_subst(vector& old_fmls) { if (!m.inc()) return; scoped_ptr rp = mk_default_expr_replacer(m, true); @@ -146,6 +153,7 @@ namespace euf { if (new_f == f) continue; new_dep = m.mk_join(d, new_dep); + old_fmls.push_back(m_fmls[i]); m_fmls.update(i, dependent_expr(m, new_f, new_dep)); } } @@ -157,6 +165,7 @@ namespace euf { unsigned count = 0; do { + vector old_fmls; m_subst_ids.reset(); if (!m.inc()) return; @@ -164,18 +173,30 @@ namespace euf { get_eqs(eqs); extract_dep_graph(eqs); extract_subst(); - apply_subst(); + apply_subst(old_fmls); ++count; } while (!m_subst_ids.empty() && count < 20); + save_subst({}); + + if (m_config.m_context_solve) { + vector old_fmls; + dep_eq_vector eqs; + m_subst_ids.reset(); + solve_context_eqs context_solve(*this); + context_solve.collect_nested_equalities(eqs); + extract_dep_graph(eqs); + extract_subst(); + apply_subst(old_fmls); + save_subst(old_fmls); + } advance_qhead(m_fmls.size()); - save_subst(); } - void solve_eqs::save_subst() { + void solve_eqs::save_subst(vector const& old_fmls) { if (!m_subst->empty()) - m_fmls.model_trail()->push(m_subst.detach(), {}); + m_fmls.model_trail().push(m_subst.detach(), old_fmls); } void solve_eqs::filter_unsafe_vars() { diff --git a/src/ast/simplifiers/solve_eqs.h b/src/ast/simplifiers/solve_eqs.h index db7a1323b..35044b373 100644 --- a/src/ast/simplifiers/solve_eqs.h +++ b/src/ast/simplifiers/solve_eqs.h @@ -26,58 +26,51 @@ Author: namespace euf { class solve_eqs : public dependent_expr_simplifier { + + friend class solve_context_eqs; + struct stats { unsigned m_num_steps = 0; unsigned m_num_elim_vars = 0; }; + struct config { bool m_context_solve = true; unsigned m_max_occs = UINT_MAX; }; - th_rewriter m_rewriter; - scoped_ptr_vector m_extract_plugins; - unsigned_vector m_var2id, m_id2level, m_subst_ids; - ptr_vector m_id2var; - vector m_next; - scoped_ptr m_subst; - - expr_mark m_unsafe_vars; // expressions that cannot be replaced stats m_stats; config m_config; - - void add_subst(dependent_eq const& eq); + th_rewriter m_rewriter; + scoped_ptr_vector m_extract_plugins; + unsigned_vector m_var2id; // app->get_id() |-> small numeral + ptr_vector m_id2var; // small numeral |-> app + unsigned_vector m_id2level; // small numeral |-> level in substitution ordering + unsigned_vector m_subst_ids; // sorted list of small numeral by level + vector m_next; // adjacency list for solved equations + scoped_ptr m_subst; // current substitution + expr_mark m_unsafe_vars; // expressions that cannot be replaced bool is_var(expr* e) const { return e->get_id() < m_var2id.size() && m_var2id[e->get_id()] != UINT_MAX; } unsigned var2id(expr* v) const { return m_var2id[v->get_id()]; } - - void get_eqs(dep_eq_vector& eqs) { - for (unsigned i = m_qhead; i < m_fmls.size(); ++i) - get_eqs(m_fmls[i], eqs); - } - - void get_eqs(dependent_expr const& f, dep_eq_vector& eqs) { - for (extract_eq* ex : m_extract_plugins) - ex->get_eqs(f, eqs); - } - - void filter_unsafe_vars(); bool can_be_var(expr* e) const { return is_uninterp_const(e) && !m_unsafe_vars.is_marked(e); } + void get_eqs(dep_eq_vector& eqs); + void filter_unsafe_vars(); void extract_subst(); void extract_dep_graph(dep_eq_vector& eqs); void normalize(); - void apply_subst(); - void save_subst(); + void apply_subst(vector& old_fmls); + void save_subst(vector const& old_fmls); public: solve_eqs(ast_manager& m, dependent_expr_state& fmls); - void push() override { dependent_expr_simplifier::push(); } - void pop(unsigned n) override { dependent_expr_simplifier::pop(n); } void reduce() override; void updt_params(params_ref const& p) override; + void collect_statistics(statistics& st) const override; + }; } diff --git a/src/tactic/core/solve_eqs_tactic.cpp b/src/tactic/core/solve_eqs_tactic.cpp index 3e338b57e..cfc4d8eeb 100644 --- a/src/tactic/core/solve_eqs_tactic.cpp +++ b/src/tactic/core/solve_eqs_tactic.cpp @@ -479,52 +479,11 @@ class solve_eqs_tactic : public tactic { ptr_vector m_todo; void mark_occurs(expr_mark& occ, goal const& g, expr* v) { - expr_fast_mark2 visited; - occ.mark(v, true); - visited.mark(v, true); - for (unsigned j = 0; j < g.size(); ++j) { - m_todo.push_back(g.form(j)); - } - while (!m_todo.empty()) { - expr* e = m_todo.back(); - if (visited.is_marked(e)) { - m_todo.pop_back(); - continue; - } - if (is_app(e)) { - bool does_occur = false; - bool all_visited = true; - for (expr* arg : *to_app(e)) { - if (!visited.is_marked(arg)) { - m_todo.push_back(arg); - all_visited = false; - } - else { - does_occur |= occ.is_marked(arg); - } - } - if (all_visited) { - occ.mark(e, does_occur); - visited.mark(e, true); - m_todo.pop_back(); - } - } - else if (is_quantifier(e)) { - expr* body = to_quantifier(e)->get_expr(); - if (visited.is_marked(body)) { - visited.mark(e, true); - occ.mark(e, occ.is_marked(body)); - m_todo.pop_back(); - } - else { - m_todo.push_back(body); - } - } - else { - visited.mark(e, true); - m_todo.pop_back(); - } - } + SASSERT(m_todo.empty()); + for (unsigned j = 0; j < g.size(); ++j) + m_todo.push_back(g.form(j)); + ::mark_occurs(m_todo, v, occ); + SASSERT(m_todo.empty()); } expr_mark m_compatible_tried; diff --git a/src/tactic/dependent_expr_state_tactic.h b/src/tactic/dependent_expr_state_tactic.h index 719a29eea..41b32baac 100644 --- a/src/tactic/dependent_expr_state_tactic.h +++ b/src/tactic/dependent_expr_state_tactic.h @@ -65,8 +65,8 @@ public: return m_goal->inconsistent(); } - model_reconstruction_trail* model_trail() override { - return m_model_trail.get(); + model_reconstruction_trail& model_trail() override { + return *m_model_trail; } char const* name() const override { return m_name.c_str(); } diff --git a/src/util/mpn.cpp b/src/util/mpn.cpp index 0cbe9e9f8..bc9017726 100644 --- a/src/util/mpn.cpp +++ b/src/util/mpn.cpp @@ -34,8 +34,8 @@ int mpn_manager::compare(mpn_digit const * a, unsigned lnga, trace(a, lnga); - unsigned j = max(lnga, lngb) - 1; - for (; j != -1u && res == 0; j--) { + unsigned j = max(lnga, lngb); + for (; j-- > 0 && res == 0;) { mpn_digit const & u_j = (j < lnga) ? a[j] : zero; mpn_digit const & v_j = (j < lngb) ? b[j] : zero; if (u_j > v_j) @@ -310,7 +310,7 @@ bool mpn_manager::div_n(mpn_sbuffer & numer, mpn_sbuffer const & denom, mpn_double_digit q_hat, temp, r_hat; mpn_digit borrow; - for (unsigned j = m-1; j != -1u; j--) { + for (unsigned j = m; j-- > 0; ) { temp = (((mpn_double_digit)numer[j+n]) << DIGIT_BITS) | ((mpn_double_digit)numer[j+n-1]); q_hat = temp / (mpn_double_digit) denom[n-1]; r_hat = temp % (mpn_double_digit) denom[n-1]; @@ -388,7 +388,7 @@ char * mpn_manager::to_string(mpn_digit const * a, unsigned lng, char * buf, uns void mpn_manager::display_raw(std::ostream & out, mpn_digit const * a, unsigned lng) const { out << "["; - for (unsigned i = lng-1; i != -1u; i-- ) { out << a[i]; if (i != 0) out << "|"; } + for (unsigned i = lng; i-- > 0; ) { out << a[i]; if (i != 0) out << "|"; } out << "]"; } diff --git a/src/util/util.h b/src/util/util.h index 2a037770d..121031492 100644 --- a/src/util/util.h +++ b/src/util/util.h @@ -368,6 +368,14 @@ bool any_of(S& set, T const& p) { return false; } +template +bool all_of(S& set, T const& p) { + for (auto const& s : set) + if (!p(s)) + return false; + return true; +} + /** \brief Iterator for the [0..sz[0]) X [0..sz[1]) X ... X [0..sz[n-1]). it contains the current value.