From 85f9c7eefaa721ea9eeb6484f310681db257967b Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Mon, 28 Nov 2022 11:45:56 +0700 Subject: [PATCH] replace restore_size_trail by more generic restore_vector other updates: - change signature of advance_qhead to simplify call sites - have model reconstruction replay work on a tail of dependent_expr state, while adding formulas to the tail. --- .../converters/generic_model_converter.cpp | 1 - src/ast/converters/generic_model_converter.h | 3 +- src/ast/converters/model_converter.h | 4 +++ src/ast/simplifiers/bit_blaster.cpp | 2 +- src/ast/simplifiers/bv_slice.cpp | 2 +- src/ast/simplifiers/card2bv.cpp | 2 +- src/ast/simplifiers/dependent_expr_state.h | 3 +- src/ast/simplifiers/elim_unconstrained.cpp | 2 +- src/ast/simplifiers/euf_completion.cpp | 2 +- src/ast/simplifiers/max_bv_sharing.cpp | 2 +- .../model_reconstruction_trail.cpp | 33 +++++++++++-------- .../simplifiers/model_reconstruction_trail.h | 17 +++++++--- src/ast/simplifiers/propagate_values.cpp | 5 +-- src/ast/simplifiers/propagate_values.h | 2 +- src/ast/simplifiers/seq_simplifier.h | 1 + src/ast/simplifiers/solve_eqs.cpp | 2 +- src/sat/sat_solver.h | 1 + src/sat/smt/arith_solver.cpp | 4 +-- src/sat/smt/bv_solver.cpp | 2 +- src/sat/smt/euf_proof.cpp | 16 ++++----- src/sat/smt/euf_solver.cpp | 2 +- src/smt/theory_arith_aux.h | 4 +-- src/smt/theory_lra.cpp | 2 +- src/tactic/core/propagate_values2_tactic.h | 2 +- src/util/trail.h | 23 ++++++------- 25 files changed, 80 insertions(+), 59 deletions(-) diff --git a/src/ast/converters/generic_model_converter.cpp b/src/ast/converters/generic_model_converter.cpp index f805e169b..50c3b071a 100644 --- a/src/ast/converters/generic_model_converter.cpp +++ b/src/ast/converters/generic_model_converter.cpp @@ -36,7 +36,6 @@ generic_model_converter::~generic_model_converter() { void generic_model_converter::add(func_decl * d, expr* e) { VERIFY(e); VERIFY(d->get_range() == e->get_sort()); - m_first_idx.insert_if_not_there(d, m_entries.size()); m_entries.push_back(entry(d, e, m, ADD)); } diff --git a/src/ast/converters/generic_model_converter.h b/src/ast/converters/generic_model_converter.h index 85e8d0390..0706b181f 100644 --- a/src/ast/converters/generic_model_converter.h +++ b/src/ast/converters/generic_model_converter.h @@ -35,7 +35,6 @@ private: ast_manager& m; std::string m_orig; vector m_entries; - obj_map m_first_idx; expr_ref simplify_def(entry const& e); @@ -71,6 +70,8 @@ public: void get_units(obj_map& units) override; vector const& entries() const { return m_entries; } + + void shrink(unsigned j) { m_entries.shrink(j); } }; typedef ref generic_model_converter_ref; diff --git a/src/ast/converters/model_converter.h b/src/ast/converters/model_converter.h index 335e0d276..164becc09 100644 --- a/src/ast/converters/model_converter.h +++ b/src/ast/converters/model_converter.h @@ -101,6 +101,10 @@ typedef sref_buffer model_converter_ref_buffer; model_converter * concat(model_converter * mc1, model_converter * mc2); +inline model_converter * concat(model_converter * mc1, model_converter * mc2, model_converter* mc3) { + return concat(mc1, concat(mc2, mc3)); +} + model_converter * model2model_converter(model * m); model_converter * model_and_labels2model_converter(model * m, labels_vec const &r); diff --git a/src/ast/simplifiers/bit_blaster.cpp b/src/ast/simplifiers/bit_blaster.cpp index ceb3c56a6..218765d28 100644 --- a/src/ast/simplifiers/bit_blaster.cpp +++ b/src/ast/simplifiers/bit_blaster.cpp @@ -61,7 +61,7 @@ void bit_blaster::reduce() { } m_rewriter.cleanup(); - advance_qhead(m_fmls.size()); + advance_qhead(); } diff --git a/src/ast/simplifiers/bv_slice.cpp b/src/ast/simplifiers/bv_slice.cpp index 75e0a890c..995231b34 100644 --- a/src/ast/simplifiers/bv_slice.cpp +++ b/src/ast/simplifiers/bv_slice.cpp @@ -24,7 +24,7 @@ namespace bv { void slice::reduce() { process_eqs(); apply_subst(); - advance_qhead(m_fmls.size()); + advance_qhead(); } void slice::process_eqs() { diff --git a/src/ast/simplifiers/card2bv.cpp b/src/ast/simplifiers/card2bv.cpp index 2da9b6e44..44bc8e589 100644 --- a/src/ast/simplifiers/card2bv.cpp +++ b/src/ast/simplifiers/card2bv.cpp @@ -49,7 +49,7 @@ void card2bv::reduce() { for (func_decl* f : fns) m_fmls.model_trail().hide(f); - advance_qhead(m_fmls.size()); + advance_qhead(); } void card2bv::collect_statistics(statistics& st) const { diff --git a/src/ast/simplifiers/dependent_expr_state.h b/src/ast/simplifiers/dependent_expr_state.h index b94b422ad..90b95ab7d 100644 --- a/src/ast/simplifiers/dependent_expr_state.h +++ b/src/ast/simplifiers/dependent_expr_state.h @@ -69,7 +69,7 @@ protected: unsigned num_scopes() const { return m_trail.get_num_scopes(); } - void advance_qhead(unsigned sz) { if (num_scopes() > 0) m_trail.push(value_trail(m_qhead)); m_qhead = sz; } + void advance_qhead() { if (num_scopes() > 0) m_trail.push(value_trail(m_qhead)); m_qhead = m_fmls.size(); } public: dependent_expr_simplifier(ast_manager& m, dependent_expr_state& s) : m(m), m_fmls(s), m_trail(s.m_trail) {} virtual ~dependent_expr_simplifier() {} @@ -80,6 +80,7 @@ public: virtual void reset_statistics() {} virtual void updt_params(params_ref const& p) {} virtual void collect_param_descrs(param_descrs& r) {} + unsigned qhead() const { return m_qhead; } }; /** diff --git a/src/ast/simplifiers/elim_unconstrained.cpp b/src/ast/simplifiers/elim_unconstrained.cpp index 52dc95f05..ac5cc339d 100644 --- a/src/ast/simplifiers/elim_unconstrained.cpp +++ b/src/ast/simplifiers/elim_unconstrained.cpp @@ -302,5 +302,5 @@ void elim_unconstrained::reduce() { vector old_fmls; assert_normalized(old_fmls); update_model_trail(*mc, old_fmls); - advance_qhead(m_fmls.size()); + advance_qhead(); } diff --git a/src/ast/simplifiers/euf_completion.cpp b/src/ast/simplifiers/euf_completion.cpp index b768180cf..5056d818c 100644 --- a/src/ast/simplifiers/euf_completion.cpp +++ b/src/ast/simplifiers/euf_completion.cpp @@ -128,7 +128,7 @@ namespace euf { CTRACE("euf_completion", g != f, tout << mk_bounded_pp(f, m) << " -> " << mk_bounded_pp(g, m) << "\n"); } if (!m_has_new_eq) - advance_qhead(m_fmls.size()); + advance_qhead(); } bool completion::is_new_eq(expr* a, expr* b) { diff --git a/src/ast/simplifiers/max_bv_sharing.cpp b/src/ast/simplifiers/max_bv_sharing.cpp index 2abacb7f7..3003a47a0 100644 --- a/src/ast/simplifiers/max_bv_sharing.cpp +++ b/src/ast/simplifiers/max_bv_sharing.cpp @@ -261,7 +261,7 @@ public: m_fmls.update(idx, dependent_expr(m, new_curr, d)); } m_rw.cfg().cleanup(); - advance_qhead(m_fmls.size()); + advance_qhead(); } }; diff --git a/src/ast/simplifiers/model_reconstruction_trail.cpp b/src/ast/simplifiers/model_reconstruction_trail.cpp index cfc65ce67..19c7d9381 100644 --- a/src/ast/simplifiers/model_reconstruction_trail.cpp +++ b/src/ast/simplifiers/model_reconstruction_trail.cpp @@ -15,6 +15,7 @@ Author: #include "ast/for_each_expr.h" #include "ast/rewriter/macro_replacer.h" #include "ast/simplifiers/model_reconstruction_trail.h" +#include "ast/simplifiers/dependent_expr_state.h" #include "ast/converters/generic_model_converter.h" @@ -22,13 +23,11 @@ Author: // substitutions that use variables from the dependent expressions. // TODO: add filters to skip sections of the trail that do not touch the current free variables. -void model_reconstruction_trail::replay(dependent_expr const& d, vector& added) { - +void model_reconstruction_trail::replay(unsigned qhead, dependent_expr_state& st) { ast_mark free_vars; scoped_ptr rp = mk_default_expr_replacer(m, false); - add_vars(d, free_vars); - - added.push_back(d); + for (unsigned i = qhead; i < st.size(); ++i) + add_vars(st[i], free_vars); for (auto& t : m_trail) { if (!t->m_active) @@ -44,9 +43,10 @@ void model_reconstruction_trail::replay(dependent_expr const& d, vectoris_loose()) { - added.append(t->m_removed); - for (auto r : t->m_removed) - add_vars(r, free_vars); + for (auto r : t->m_removed) { + add_vars(r, free_vars); + st.add(r); + } m_trail_stack.push(value_trail(t->m_active)); t->m_active = false; continue; @@ -64,12 +64,12 @@ void model_reconstruction_trail::replay(dependent_expr const& d, vectorm_def, t->m_dep); add_vars(de, free_vars); - for (auto& d : added) { - auto [f, dep1] = d(); + for (unsigned i = qhead; i < st.size(); ++i) { + auto [f, dep1] = st[i](); expr_ref g(m); expr_dependency_ref dep2(m); mrp(f, g, dep2); - d = dependent_expr(m, g, m.mk_join(dep1, dep2)); + st.update(i, dependent_expr(m, g, m.mk_join(dep1, dep2))); } continue; } @@ -77,11 +77,12 @@ void model_reconstruction_trail::replay(dependent_expr const& d, vectorset_substitution(t->m_subst.get()); // rigid entries: // apply substitution to added in case of rigid model convertions - for (auto& d : added) { - auto [f, dep1] = d(); + for (unsigned i = qhead; i < st.size(); ++i) { + auto [f, dep1] = st[i](); auto [g, dep2] = rp->replace_with_dep(f); - d = dependent_expr(m, g, m.mk_join(dep1, dep2)); + dependent_expr d(m, g, m.mk_join(dep1, dep2)); add_vars(d, free_vars); + st.update(i, d); } } } @@ -116,3 +117,7 @@ void model_reconstruction_trail::append(generic_model_converter& mc, unsigned& i } +void model_reconstruction_trail::append(generic_model_converter& mc) { + m_trail_stack.push(value_trail(m_trail_index)); + append(mc, m_trail_index); +} diff --git a/src/ast/simplifiers/model_reconstruction_trail.h b/src/ast/simplifiers/model_reconstruction_trail.h index 6aa91e550..5ad204bf7 100644 --- a/src/ast/simplifiers/model_reconstruction_trail.h +++ b/src/ast/simplifiers/model_reconstruction_trail.h @@ -31,6 +31,8 @@ Author: #include "ast/converters/model_converter.h" #include "ast/converters/generic_model_converter.h" +class dependent_expr_state; + class model_reconstruction_trail { struct entry { @@ -41,7 +43,6 @@ class model_reconstruction_trail { expr_dependency_ref m_dep; bool m_active = true; - entry(ast_manager& m, expr_substitution* s, vector const& rem) : m_subst(s), m_removed(rem), m_decl(m), m_def(m), m_dep(m) {} @@ -71,6 +72,7 @@ class model_reconstruction_trail { ast_manager& m; trail_stack& m_trail_stack; scoped_ptr_vector m_trail; + unsigned m_trail_index = 0; void add_vars(dependent_expr const& d, ast_mark& free_vars) { for (expr* t : subterms::all(expr_ref(d.fml(), d.get_manager()))) @@ -87,6 +89,10 @@ class model_reconstruction_trail { return any_of(added, [&](dependent_expr const& d) { return intersects(free_vars, d); }); } + /** + * Append new updates to model converter, update the current index into the trail in the process. + */ + void append(generic_model_converter& mc, unsigned& index); public: model_reconstruction_trail(ast_manager& m, trail_stack& tr): @@ -120,16 +126,17 @@ public: * register a new depedent expression, update the trail * by removing substitutions that are not equivalence preserving. */ - void replay(dependent_expr const& d, vector& added); - + void replay(unsigned qhead, dependent_expr_state& fmls); + /** * retrieve the current model converter corresponding to chaining substitutions from the trail. */ model_converter_ref get_model_converter(); + /** - * Append new updates to model converter, update the current index into the trail in the process. + * Append new updates to model converter, update m_trail_index in the process. */ - void append(generic_model_converter& mc, unsigned& index); + void append(generic_model_converter& mc); }; diff --git a/src/ast/simplifiers/propagate_values.cpp b/src/ast/simplifiers/propagate_values.cpp index 6a179674b..6d6bc976f 100644 --- a/src/ast/simplifiers/propagate_values.cpp +++ b/src/ast/simplifiers/propagate_values.cpp @@ -25,11 +25,12 @@ Notes: #include "ast/shared_occs.h" #include "ast/simplifiers/propagate_values.h" -propagate_values::propagate_values(ast_manager& m, dependent_expr_state& fmls): +propagate_values::propagate_values(ast_manager& m, params_ref const& p, dependent_expr_state& fmls): dependent_expr_simplifier(m, fmls), m_rewriter(m) { m_rewriter.set_order_eq(true); m_rewriter.set_flat_and_or(false); + updt_params(p); } void propagate_values::reduce() { @@ -96,7 +97,7 @@ void propagate_values::reduce() { m_rewriter.set_substitution(nullptr); m_rewriter.reset(); - advance_qhead(m_fmls.size()); + advance_qhead(); } void propagate_values::collect_statistics(statistics& st) const { diff --git a/src/ast/simplifiers/propagate_values.h b/src/ast/simplifiers/propagate_values.h index 3219f9796..d263377b1 100644 --- a/src/ast/simplifiers/propagate_values.h +++ b/src/ast/simplifiers/propagate_values.h @@ -36,7 +36,7 @@ class propagate_values : public dependent_expr_simplifier { unsigned m_max_rounds = 4; public: - propagate_values(ast_manager& m, dependent_expr_state& fmls); + propagate_values(ast_manager& m, params_ref const& p, dependent_expr_state& fmls); void reduce() override; void collect_statistics(statistics& st) const override; void reset_statistics() override { m_stats.reset(); } diff --git a/src/ast/simplifiers/seq_simplifier.h b/src/ast/simplifiers/seq_simplifier.h index 5a3a7da60..4288a3475 100644 --- a/src/ast/simplifiers/seq_simplifier.h +++ b/src/ast/simplifiers/seq_simplifier.h @@ -38,6 +38,7 @@ public: break; s->reduce(); } + advance_qhead(); } void collect_statistics(statistics& st) const override { diff --git a/src/ast/simplifiers/solve_eqs.cpp b/src/ast/simplifiers/solve_eqs.cpp index 5090e1b43..1bea283ac 100644 --- a/src/ast/simplifiers/solve_eqs.cpp +++ b/src/ast/simplifiers/solve_eqs.cpp @@ -237,7 +237,7 @@ namespace euf { save_subst(old_fmls); } - advance_qhead(m_fmls.size()); + advance_qhead(); } void solve_eqs::save_subst(vector const& old_fmls) { diff --git a/src/sat/sat_solver.h b/src/sat/sat_solver.h index 982a84307..227568f3d 100644 --- a/src/sat/sat_solver.h +++ b/src/sat/sat_solver.h @@ -485,6 +485,7 @@ namespace sat { // ----------------------- public: lbool check(unsigned num_lits = 0, literal const* lits = nullptr); + lbool check(literal_vector const& lits) { return check(lits.size(), lits.data()); } // retrieve model if solver return sat model const & get_model() const { return m_model; } diff --git a/src/sat/smt/arith_solver.cpp b/src/sat/smt/arith_solver.cpp index 8be98edfb..a69f3604e 100644 --- a/src/sat/smt/arith_solver.cpp +++ b/src/sat/smt/arith_solver.cpp @@ -893,11 +893,11 @@ namespace arith { theory_var other = m_model_eqs.insert_if_not_there(v); TRACE("arith", tout << "insert: v" << v << " := " << get_value(v) << " found: v" << other << "\n";); if (!is_equal(other, v)) - m_assume_eq_candidates.push_back(std::make_pair(v, other)); + m_assume_eq_candidates.push_back({ v, other }); } if (m_assume_eq_candidates.size() > old_sz) - ctx.push(restore_size_trail, false>(m_assume_eq_candidates, old_sz)); + ctx.push(restore_vector(m_assume_eq_candidates, old_sz)); return delayed_assume_eqs(); } diff --git a/src/sat/smt/bv_solver.cpp b/src/sat/smt/bv_solver.cpp index a5a35878e..a0bcea43b 100644 --- a/src/sat/smt/bv_solver.cpp +++ b/src/sat/smt/bv_solver.cpp @@ -419,7 +419,7 @@ namespace bv { } ctx.push(value_trail(m_lit_tail)); - ctx.push(restore_size_trail(m_proof_literals)); + ctx.push(restore_vector(m_proof_literals)); sat::literal_vector lits; switch (c.m_kind) { diff --git a/src/sat/smt/euf_proof.cpp b/src/sat/smt/euf_proof.cpp index 5fbd632cd..ac9e81311 100644 --- a/src/sat/smt/euf_proof.cpp +++ b/src/sat/smt/euf_proof.cpp @@ -84,7 +84,7 @@ namespace euf { return nullptr; push(value_trail(m_lit_tail)); push(value_trail(m_cc_tail)); - push(restore_size_trail(m_proof_literals)); + push(restore_vector(m_proof_literals)); if (conseq != sat::null_literal) m_proof_literals.push_back(~conseq); m_proof_literals.append(r); @@ -101,8 +101,8 @@ namespace euf { SASSERT(a->get_decl() == b->get_decl()); push(value_trail(m_lit_tail)); push(value_trail(m_cc_tail)); - push(restore_size_trail(m_proof_literals)); - push(restore_size_trail(m_explain_cc, m_explain_cc.size())); + push(restore_vector(m_proof_literals)); + push(restore_vector(m_explain_cc)); for (auto lit : ante) m_proof_literals.push_back(~lit); @@ -121,7 +121,7 @@ namespace euf { return nullptr; push(value_trail(m_lit_tail)); push(value_trail(m_cc_tail)); - push(restore_size_trail(m_proof_literals)); + push(restore_vector(m_proof_literals)); for (unsigned i = 0; i < 3; ++i) m_proof_literals.push_back(~clause[i]); @@ -171,7 +171,7 @@ namespace euf { if (!use_drat()) return nullptr; push(value_trail(m_lit_tail)); - push(restore_size_trail(m_proof_literals)); + push(restore_vector(m_proof_literals)); for (unsigned i = 0; i < nl; ++i) m_proof_literals.push_back(~lits[i]); @@ -190,7 +190,7 @@ namespace euf { if (!use_drat()) return nullptr; push(value_trail(m_lit_tail)); - push(restore_size_trail(m_proof_literals)); + push(restore_vector(m_proof_literals)); for (unsigned i = 0; i < nl; ++i) if (sat::null_literal != lits[i]) { @@ -203,11 +203,11 @@ namespace euf { } push(value_trail(m_eq_tail)); - push(restore_size_trail(m_proof_eqs)); + push(restore_vector(m_proof_eqs)); m_proof_eqs.append(ne, eqs); push(value_trail(m_deq_tail)); - push(restore_size_trail(m_proof_deqs)); + push(restore_vector(m_proof_deqs)); m_proof_deqs.append(nd, deqs); m_lit_head = m_lit_tail; diff --git a/src/sat/smt/euf_solver.cpp b/src/sat/smt/euf_solver.cpp index 7473be61d..865608339 100644 --- a/src/sat/smt/euf_solver.cpp +++ b/src/sat/smt/euf_solver.cpp @@ -225,7 +225,7 @@ namespace euf { m_egraph.begin_explain(); m_explain.reset(); if (use_drat() && !probing) { - push(restore_size_trail(m_explain_cc, m_explain_cc.size())); + push(restore_vector(m_explain_cc)); } auto* ext = sat::constraint_base::to_extension(idx); th_proof_hint* hint = nullptr; diff --git a/src/smt/theory_arith_aux.h b/src/smt/theory_arith_aux.h index 593c32d83..d5eca0bc4 100644 --- a/src/smt/theory_arith_aux.h +++ b/src/smt/theory_arith_aux.h @@ -2223,12 +2223,12 @@ namespace smt { continue; } TRACE("func_interp_bug", tout << "adding to assume_eq queue #" << n->get_owner_id() << " #" << n2->get_owner_id() << "\n";); - m_assume_eq_candidates.push_back(std::make_pair(other, v)); + m_assume_eq_candidates.push_back({ other , v }); result = true; } if (result) - ctx.push_trail(restore_size_trail, false>(m_assume_eq_candidates, old_sz)); + ctx.push_trail(restore_vector(m_assume_eq_candidates, old_sz)); return delayed_assume_eqs(); } diff --git a/src/smt/theory_lra.cpp b/src/smt/theory_lra.cpp index 4075f39dd..1309724d2 100644 --- a/src/smt/theory_lra.cpp +++ b/src/smt/theory_lra.cpp @@ -1515,7 +1515,7 @@ public: } if (num_candidates > 0) { - ctx().push_trail(restore_size_trail, false>(m_assume_eq_candidates, old_sz)); + ctx().push_trail(restore_vector(m_assume_eq_candidates, old_sz)); } return delayed_assume_eqs(); diff --git a/src/tactic/core/propagate_values2_tactic.h b/src/tactic/core/propagate_values2_tactic.h index ab9646128..58e263e80 100644 --- a/src/tactic/core/propagate_values2_tactic.h +++ b/src/tactic/core/propagate_values2_tactic.h @@ -26,7 +26,7 @@ Author: class propagate_values2_tactic_factory : public dependent_expr_simplifier_factory { public: dependent_expr_simplifier* mk(ast_manager& m, params_ref const& p, dependent_expr_state& s) override { - return alloc(propagate_values, m, s); + return alloc(propagate_values, m, p, s); } }; diff --git a/src/util/trail.h b/src/util/trail.h index 20a525cf7..1aa7e4441 100644 --- a/src/util/trail.h +++ b/src/util/trail.h @@ -98,20 +98,21 @@ public: } }; -template -class restore_size_trail : public trail { - vector & m_vector; - unsigned m_old_size; +template +class restore_vector : public trail { + V& m_vector; + unsigned m_old_size; public: - restore_size_trail(vector & v, unsigned sz): + restore_vector(V& v): m_vector(v), - m_old_size(sz) { - } - restore_size_trail(vector & v): + m_old_size(v.size()) + {} + + restore_vector(V& v, unsigned sz): m_vector(v), - m_old_size(v.size()) { - } - + m_old_size(sz) + {} + void undo() override { m_vector.shrink(m_old_size); }