From 3d7bd40a87cee19d8b1135b3352a86b91f006a98 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Sun, 4 Dec 2022 06:07:45 -0800 Subject: [PATCH] a round of cleanup --- src/ast/rewriter/demodulator_rewriter.cpp | 319 ++++++++++------------ src/ast/rewriter/demodulator_rewriter.h | 13 +- src/tactic/ufbv/ufbv_rewriter_tactic.cpp | 7 +- 3 files changed, 150 insertions(+), 189 deletions(-) diff --git a/src/ast/rewriter/demodulator_rewriter.cpp b/src/ast/rewriter/demodulator_rewriter.cpp index 56be99529..f1fc6f969 100644 --- a/src/ast/rewriter/demodulator_rewriter.cpp +++ b/src/ast/rewriter/demodulator_rewriter.cpp @@ -3,7 +3,7 @@ Copyright (c) 2006 Microsoft Corporation Module Name: - demodulator.cpp + demodulator_rewriter.cpp Abstract: @@ -17,6 +17,7 @@ Revision History: Christoph M. Wintersteiger (cwinter) 2010-04-21: Implementation Christoph M. Wintersteiger (cwinter) 2012-10-24: Moved from demodulator.h to ufbv_rewriter.h + Nikolaj Bjorner (nbjorner) 2022-12-4: Moved to demodulator_rewriter.h --*/ @@ -186,48 +187,45 @@ public: unsigned get_max() { return m_max_var_id; } }; -unsigned demodulator_rewriter::max_var_id(expr * e) -{ +unsigned demodulator_rewriter::max_var_id(expr_ref_vector const& es) { max_var_id_proc proc; - for_each_expr(proc, e); + for (expr* e : es) + for_each_expr(proc, e); return proc.get_max(); } -void demodulator_rewriter::insert_fwd_idx(expr * large, expr * small, quantifier * demodulator) { - SASSERT(large->get_kind() == AST_APP); +void demodulator_rewriter::insert_fwd_idx(app * large, expr * small, quantifier * demodulator) { SASSERT(demodulator); SASSERT(large && small); TRACE("demodulator_fwd", tout << "INSERT: " << mk_pp(demodulator, m) << std::endl; ); func_decl * fd = to_app(large)->get_decl(); - fwd_idx_map::iterator it = m_fwd_idx.find_iterator(fd); - if (it == m_fwd_idx.end()) { - quantifier_set * qs = alloc(quantifier_set, 1); + quantifier_set * qs; + if (!m_fwd_idx.find(fd, qs)) { + qs = alloc(quantifier_set, 1); m_fwd_idx.insert(fd, qs); - it = m_fwd_idx.find_iterator(fd); } - SASSERT(it->m_value); - it->m_value->insert(demodulator); + SASSERT(qs); + qs->insert(demodulator); m.inc_ref(demodulator); m.inc_ref(large); m.inc_ref(small); - m_demodulator2lhs_rhs.insert(demodulator, expr_pair(large, small)); + m_demodulator2lhs_rhs.insert(demodulator, app_expr_pair(large, small)); } void demodulator_rewriter::remove_fwd_idx(func_decl * f, quantifier * demodulator) { TRACE("demodulator_fwd", tout << "REMOVE: " << std::hex << (size_t)demodulator << std::endl; ); - fwd_idx_map::iterator it = m_fwd_idx.find_iterator(f); - if (it != m_fwd_idx.end()) { - demodulator2lhs_rhs::iterator fit = m_demodulator2lhs_rhs.find_iterator(demodulator); - expr_pair p = fit->m_value; + quantifier_set* qs; + if (m_fwd_idx.find(f, qs)) { + auto [lhs, rhs] = m_demodulator2lhs_rhs[demodulator]; m_demodulator2lhs_rhs.erase(demodulator); - it->m_value->erase(demodulator); - m.dec_ref(p.first); - m.dec_ref(p.second); + qs->erase(demodulator); + m.dec_ref(lhs); + m.dec_ref(rhs); m.dec_ref(demodulator); } else { SASSERT(m_demodulator2lhs_rhs.contains(demodulator)); @@ -235,59 +233,50 @@ void demodulator_rewriter::remove_fwd_idx(func_decl * f, quantifier * demodulato } bool demodulator_rewriter::check_fwd_idx_consistency() { - for (auto & kv : m_fwd_idx) { - quantifier_set * set = kv.m_value; + for (auto & [k, set] : m_fwd_idx) { SASSERT(set); - for (auto e : *set) { + for (auto e : *set) if (!m_demodulator2lhs_rhs.contains(e)) return false; - } } - return true; } void demodulator_rewriter::show_fwd_idx(std::ostream & out) { - for (auto & kv : m_fwd_idx) { - quantifier_set * set = kv.m_value; - SASSERT(!set); - - out << kv.m_key->get_name() << ": " << std::endl; - - for (auto e : *set) { - out << std::hex << (size_t)e << std::endl; - } + for (auto & [k, set] : m_fwd_idx) { + out << k->get_name() << ": " << std::endl; + if (set) + for (auto e : *set) + out << std::hex << (size_t)e << std::endl; } out << "D2LR: " << std::endl; - for (auto & kv : m_demodulator2lhs_rhs) { - out << (size_t) kv.m_key << std::endl; + for (auto & [k, v] : m_demodulator2lhs_rhs) { + out << (size_t) k << std::endl; } } -bool demodulator_rewriter::rewrite1(func_decl * f, expr_ref_vector & m_new_args, expr_ref & np) { - fwd_idx_map::iterator it = m_fwd_idx.find_iterator(f); - if (it != m_fwd_idx.end()) { - TRACE("demodulator_bug", tout << "trying to rewrite: " << f->get_name() << " args:\n"; - tout << m_new_args << "\n";); - for (quantifier* d : *it->m_value) { +bool demodulator_rewriter::rewrite1(func_decl * f, expr_ref_vector const & args, expr_ref & np) { + quantifier_set* set; + if (!m_fwd_idx.find(f, set)) + return false; + TRACE("demodulator_bug", tout << "trying to rewrite: " << f->get_name() << " args:\n"; + tout << m_new_args << "\n";); - SASSERT(m_demodulator2lhs_rhs.contains(d)); - expr_pair l_s; - m_demodulator2lhs_rhs.find(d, l_s); - app * large = to_app(l_s.first); + for (quantifier* d : *set) { - if (large->get_num_args() != m_new_args.size()) - continue; - - TRACE("demodulator_bug", tout << "Matching with demodulator: " << mk_pp(d, m) << std::endl; ); - - SASSERT(large->get_decl() == f); - - if (m_match_subst(large, l_s.second, m_new_args.data(), np)) { - TRACE("demodulator_bug", tout << "succeeded...\n" << mk_pp(l_s.second, m) << "\n===>\n" << mk_pp(np, m) << "\n";); - return true; - } + auto const& [large, rhs] = m_demodulator2lhs_rhs[d]; + + if (large->get_num_args() != args.size()) + continue; + + TRACE("demodulator_bug", tout << "Matching with demodulator: " << mk_pp(d, m) << std::endl; ); + + SASSERT(large->get_decl() == f); + + if (m_match_subst(large, rhs, args.data(), np)) { + TRACE("demodulator_bug", tout << "succeeded...\n" << mk_pp(rhs, m) << "\n===>\n" << mk_pp(np, m) << "\n";); + return true; } } @@ -449,21 +438,19 @@ public: void operator()(var * n) {} void operator()(quantifier * n) {} void operator()(app * n) { - // We track only uninterpreted and constant functions. - if (n->get_num_args()==0) return; + // We track only uninterpreted functions. + if (n->get_num_args() == 0) + return; SASSERT(m_expr && m_expr != (expr*) 0x00000003); - func_decl * d=n->get_decl(); - if (d->get_family_id() == null_family_id) { - back_idx_map::iterator it = m_back_idx.find_iterator(d); - if (it != m_back_idx.end()) { - SASSERT(it->m_value); - it->m_value->insert(m_expr); - } else { - expr_set * e = alloc(expr_set); - e->insert(m_expr); - m_back_idx.insert(d, e); - } + func_decl * d = n->get_decl(); + if (d->get_family_id() != null_family_id) + return; + expr_set* set = nullptr; + if (!m_back_idx.find(d, set)) { + set = alloc(expr_set); + m_back_idx.insert(d, set); } + set->insert(m_expr); } }; @@ -475,39 +462,48 @@ public: void operator()(var * n) {} void operator()(quantifier * n) {} void operator()(app * n) { - // We track only uninterpreted and constant functions. - if (n->get_num_args()==0) return; - func_decl * d=n->get_decl(); - if (d->get_family_id() == null_family_id) { - back_idx_map::iterator it = m_back_idx.find_iterator(d); - if (it != m_back_idx.end()) { - SASSERT(it->m_value); - it->m_value->remove(m_expr); - } - } + // We track only uninterpreted functions. + if (n->get_num_args() == 0) + return; + func_decl * d = n->get_decl(); + if (d->get_family_id() != null_family_id) + return; + expr_set* set = nullptr; + if (m_back_idx.find(d, set)) + set->remove(m_expr); } }; + +void demodulator_rewriter::insert_bwd_idx(expr* e) { + add_back_idx_proc proc(m_back_idx, e); + for_each_expr(proc, e); +} + +void demodulator_rewriter::remove_bwd_idx(expr* e) { + remove_back_idx_proc proc(m_back_idx, e); + for_each_expr(proc, e); +} + void demodulator_rewriter::reschedule_processed(func_decl * f) { //use m_back_idx to find all formulas p in m_processed that contains f { - back_idx_map::iterator it = m_back_idx.find_iterator(f); - if (it != m_back_idx.end()) { - SASSERT(it->m_value); - expr_set temp; + expr_set* set = nullptr; + if (!m_back_idx.find(f, set)) + return; + SASSERT(set); + expr_set temp; - for (expr* p : *it->m_value) { - if (m_processed.contains(p)) + for (expr* p : *set) + if (m_processed.contains(p)) temp.insert(p); - } - for (expr * p : temp) { - // remove p from m_processed and m_back_idx - m_processed.remove(p); - remove_back_idx_proc proc(m_back_idx, p); // this could change it->m_value, thus we need the `temp' set. - for_each_expr(proc, p); - // insert p into m_todo - m_todo.push_back(p); - } + for (expr * p : temp) { + // remove p from m_processed and m_back_idx + m_processed.remove(p); + // this could change `set', thus we need the `temp' set. + remove_bwd_idx(p); + // insert p into m_todo + m_todo.push_back(p); } } @@ -545,20 +541,10 @@ bool demodulator_rewriter::can_rewrite(expr * n, expr * lhs) { break; case AST_QUANTIFIER: - if (!for_each_expr_args(stack, visited, to_quantifier(curr)->get_num_patterns(), - to_quantifier(curr)->get_patterns())) { - break; - } - if (!for_each_expr_args(stack, visited, to_quantifier(curr)->get_num_no_patterns(), - to_quantifier(curr)->get_no_patterns())) { - break; - } - if (!visited.is_marked(to_quantifier(curr)->get_expr())) { + if (visited.is_marked(to_quantifier(curr)->get_expr())) + stack.pop_back(); + else stack.push_back(to_quantifier(curr)->get_expr()); - break; - } - - stack.pop_back(); break; default: UNREACHABLE(); @@ -571,66 +557,55 @@ bool demodulator_rewriter::can_rewrite(expr * n, expr * lhs) { void demodulator_rewriter::reschedule_demodulators(func_decl * f, expr * lhs) { // use m_back_idx to find all demodulators d in m_fwd_idx that contains f { - //ptr_vector to_remove; - back_idx_map::iterator it = m_back_idx.find_iterator(f); - if (it != m_back_idx.end()) { - SASSERT(it->m_value); - expr_set all_occurrences; - expr_ref l(m); + expr_set* set = nullptr; + if (!m_back_idx.find(f, set)) + return; + SASSERT(set); + expr_set all_occurrences; + app_ref l(m); - for (auto s : *it->m_value) - all_occurrences.insert(s); + for (auto s : *set) + all_occurrences.insert(s); + + // Run over all f-demodulators + for (expr* occ : all_occurrences) { + + if (!is_quantifier(occ)) + continue; + quantifier* qe = to_quantifier(occ); + + // Use the fwd idx to find out whether this is a demodulator. + app_expr_pair p; + if (!m_demodulator2lhs_rhs.find(qe, p)) + continue; - // Run over all f-demodulators - for (expr* occ : all_occurrences) { + l = p.first; + quantifier_ref d(qe, m); + func_decl_ref df(l->get_decl(), m); + + // Now we know there is an occurrence of f in d + if (!can_rewrite(d, lhs)) + continue; - if (!is_quantifier(occ)) - continue; + TRACE("demodulator", tout << "Rescheduling: " << std::endl << mk_pp(d, m) << std::endl); - // Use the fwd idx to find out whether this is a demodulator. - demodulator2lhs_rhs::iterator d2lr_it = m_demodulator2lhs_rhs.find_iterator(to_quantifier(occ)); - if (d2lr_it != m_demodulator2lhs_rhs.end()) { - l = d2lr_it->m_value.first; - quantifier_ref d(m); - func_decl_ref df(m); - d = to_quantifier(occ); - df = to_app(l)->get_decl(); - - // Now we know there is an occurrence of f in d - // if n' can rewrite d { - if (can_rewrite(d, lhs)) { - TRACE("demodulator", tout << "Rescheduling: " << std::endl << mk_pp(d, m) << std::endl; ); - // remove d from m_fwd_idx - remove_fwd_idx(df, d); - // remove d from m_back_idx - // just remember it here, because otherwise it and/or esit might become invalid? - // to_remove.insert(d); - remove_back_idx_proc proc(m_back_idx, d); - for_each_expr(proc, d); - // insert d into m_todo - m_todo.push_back(d); - } - } - } + remove_fwd_idx(df, d); + remove_bwd_idx(d); + m_todo.push_back(d); } } -void demodulator_rewriter::operator()(unsigned n, expr * const * exprs, - expr_ref_vector & new_exprs) { +void demodulator_rewriter::operator()(expr_ref_vector const& exprs, + expr_ref_vector & new_exprs) { - TRACE("demodulator", tout << "before demodulator:\n"; - for ( unsigned i = 0 ; i < n ; i++ ) - tout << mk_pp(exprs[i], m) << std::endl; ); + TRACE("demodulator", tout << "before demodulator:\n" << exprs); // Initially, m_todo contains all formulas. That is, it contains the argument exprs. m_fwd_idx, m_processed, m_back_idx are empty. - unsigned max_vid = 0; - for ( unsigned i = 0 ; i < n ; i++ ) { - m_todo.push_back(exprs[i]); - max_vid = std::max(max_vid, max_var_id(exprs[i])); - } + for (expr* e : exprs) + m_todo.push_back(e); - m_match_subst.reserve(max_vid); + m_match_subst.reserve(max_var_id(exprs)); while (!m_todo.empty()) { // let n be the next formula in m_todo. @@ -644,7 +619,6 @@ void demodulator_rewriter::operator()(unsigned n, expr * const * exprs, // unless there is a demodulator cycle // SASSERT(rewrite(np)==np); - // if (n' is not a demodulator) { app_ref large(m); expr_ref small(m); if (!is_demodulator(np, large, small)) { @@ -652,22 +626,11 @@ void demodulator_rewriter::operator()(unsigned n, expr * const * exprs, m_processed.insert(np); m_in_processed.push_back(np); // update m_back_idx (traverse n' and for each uninterpreted function declaration f in n' add the entry f->n' to m_back_idx) - add_back_idx_proc proc(m_back_idx, np); - for_each_expr(proc, np); - } else { + insert_bwd_idx(np); + } + else { // np is a demodulator that allows us to replace 'large' with 'small'. - TRACE("demodulator", tout << "Found demodulator: " << std::endl; - tout << mk_pp(large.get(), m) << std::endl << " ---> " << - std::endl << mk_pp(small.get(), m) << std::endl; ); - - TRACE("demodulator_s", tout << "Found demodulator: " << std::endl; - tout << to_app(large)->get_decl()->get_name() << - "[" << to_app(large)->get_depth() << "]" << " ---> "; - if (is_app(small)) - tout << to_app(small)->get_decl()->get_name() << - "[" << to_app(small)->get_depth() << "]" << std::endl; - else - tout << mk_pp(small.get(), m) << std::endl; ); + TRACE("demodulator", tout << "Found demodulator:\n" << large << "\n ---> " << small << "\n"); // let f be the top symbol of n' func_decl * f = large->get_decl(); @@ -679,8 +642,7 @@ void demodulator_rewriter::operator()(unsigned n, expr * const * exprs, insert_fwd_idx(large, small, to_quantifier(np)); // update m_back_idx - add_back_idx_proc proc(m_back_idx, np); - for_each_expr(proc, np); + insert_bwd_idx(np); } } @@ -690,12 +652,11 @@ void demodulator_rewriter::operator()(unsigned n, expr * const * exprs, TRACE("demodulator", tout << mk_pp(e, m) << std::endl; ); } - for (auto const& kv : m_fwd_idx) { - if (kv.m_value) { - for (expr* e : *kv.m_value) { + for (auto const& [k, set] : m_fwd_idx) { + if (set) { + for (expr* e : *set) new_exprs.push_back(e); - TRACE("demodulator", tout << mk_pp(e, m) << std::endl; ); - } + TRACE("demodulator", for (expr* e : *set) tout << mk_pp(e, m) << std::endl; ); } } diff --git a/src/ast/rewriter/demodulator_rewriter.h b/src/ast/rewriter/demodulator_rewriter.h index 90d15bcec..2152520ce 100644 --- a/src/ast/rewriter/demodulator_rewriter.h +++ b/src/ast/rewriter/demodulator_rewriter.h @@ -111,11 +111,12 @@ class demodulator_rewriter final { typedef array_map expr_map; typedef std::pair expr_pair; + typedef std::pair app_expr_pair; typedef obj_hashtable expr_set; typedef obj_map back_idx_map; typedef obj_hashtable quantifier_set; typedef obj_map fwd_idx_map; - typedef obj_map demodulator2lhs_rhs; + typedef obj_map demodulator2lhs_rhs; typedef expr_map rewrite_cache_map; /** @@ -172,20 +173,22 @@ class demodulator_rewriter final { rewrite_cache_map m_rewrite_cache; expr_ref_buffer m_new_exprs; - void insert_fwd_idx(expr * large, expr * small, quantifier * demodulator); + void insert_fwd_idx(app * large, expr * small, quantifier * demodulator); void remove_fwd_idx(func_decl * f, quantifier * demodulator); + void insert_bwd_idx(expr* q); + void remove_bwd_idx(expr* q); bool check_fwd_idx_consistency(); void show_fwd_idx(std::ostream & out); bool is_demodulator(expr * e, app_ref & large, expr_ref & small) const; bool can_rewrite(expr * n, expr * lhs); expr * rewrite(expr * n); - bool rewrite1(func_decl * f, expr_ref_vector & m_new_args, expr_ref & np); + bool rewrite1(func_decl * f, expr_ref_vector const & args, expr_ref & np); bool rewrite_visit_children(app * a); void rewrite_cache(expr * e, expr * new_e, bool done); void reschedule_processed(func_decl * f); void reschedule_demodulators(func_decl * f, expr * np); - unsigned max_var_id(expr * e); + unsigned max_var_id(expr_ref_vector const& es); // is_smaller returns -1 for e1e2. int is_smaller(expr * e1, expr * e2) const; @@ -197,7 +200,7 @@ public: demodulator_rewriter(ast_manager & m); ~demodulator_rewriter(); - void operator()(unsigned n, expr * const * exprs, expr_ref_vector & new_exprs); + void operator()(expr_ref_vector const& exprs, expr_ref_vector & new_exprs); /** Given a demodulator (aka rewrite rule) of the form diff --git a/src/tactic/ufbv/ufbv_rewriter_tactic.cpp b/src/tactic/ufbv/ufbv_rewriter_tactic.cpp index f8f3153c5..f374ea114 100644 --- a/src/tactic/ufbv/ufbv_rewriter_tactic.cpp +++ b/src/tactic/ufbv/ufbv_rewriter_tactic.cpp @@ -55,15 +55,12 @@ public: demodulator_rewriter dem(m_manager); expr_ref_vector forms(m_manager), new_forms(m_manager); - proof_ref_vector proofs(m_manager), new_proofs(m_manager); unsigned size = g->size(); - for (unsigned i = 0; i < size; i++) { + for (unsigned i = 0; i < size; i++) forms.push_back(g->form(i)); - proofs.push_back(g->pr(i)); - } - dem(forms.size(), forms.data(), new_forms); + dem(forms, new_forms); g->reset(); for (unsigned i = 0; i < new_forms.size(); i++)