diff --git a/src/ast/term_enumeration.cpp b/src/ast/term_enumeration.cpp index 84b868220..458bc4f61 100644 --- a/src/ast/term_enumeration.cpp +++ b/src/ast/term_enumeration.cpp @@ -5,7 +5,7 @@ * for Bottom-Up Enumerative Synthesis"). Adapted to use Z3's internal APIs. * * Key ideas: - * - Terms are enumerated bottom-up by "height" (max nesting depth). + * - Terms are enumerated bottom-up by "cost" (calculated by tree size). * - Observational equivalence (OE): two terms that produce the same outputs * on all sample inputs are considered equivalent; only one representative * per equivalence class is kept. @@ -13,23 +13,22 @@ * (constants, variables) are available for enumeration. */ -#include "ast/term_enumeration.h" #include #include #include #include +#include "util/vector.h" +#include "util/ref.h" +#include "util/obj_hashtable.h" #include "ast/ast.h" -#include "ast/ast_pp.h" #include "ast/ast_ll_pp.h" #include "ast/arith_decl_plugin.h" #include "ast/bv_decl_plugin.h" #include "ast/seq_decl_plugin.h" +#include "ast/term_enumeration.h" #include "model/model.h" #include "model/model_evaluator.h" -#include "util/vector.h" -#include "util/ref.h" -#include "util/obj_hashtable.h" namespace term_enum { @@ -199,9 +198,11 @@ private: }; // ============================================================================ -// Term Bank - stores enumerated terms by height and sort +// Term Bank - stores enumerated terms by cost and sort // ============================================================================ +using cost_terms = vector>; + class term_bank { using sort_term_map = obj_map>; public: @@ -217,35 +218,41 @@ public: m_terms.clear(); } - void add(expr* term, unsigned height) { + void add(expr* term, unsigned cost) { sort* s = term->get_sort(); - unsigned sid = s->get_id(); m_pinned.push_back(term); - if (height >= m_terms.size()) - m_terms.resize(height + 1); - if (!m_terms[height]) - m_terms[height] = alloc(sort_term_map); - m_terms[height]->insert_if_not_there(s, ptr_vector()).push_back(term); + if (cost >= m_terms.size()) + m_terms.resize(cost + 1); + if (!m_terms[cost]) + m_terms[cost] = alloc(sort_term_map); + m_terms[cost]->insert_if_not_there(s, ptr_vector()).push_back(term); } - /** Get all terms of a given sort up to (and including) max_height */ - expr_ref_vector get_by_sort(sort* s, unsigned max_height) const { - expr_ref_vector result(m); - for (unsigned h = 0; h <= max_height; ++h) { - if (h >= m_terms.size()) + /** Get all terms of a given sort up to (and including) max_cost */ + cost_terms get_by_sort(sort* s, unsigned max_cost) const { + cost_terms result; + for (unsigned c = 0; c <= max_cost; ++c) { + if (c >= m_terms.size()) break; - if (!m_terms[h]->contains(s)) + if (!m_terms[c]->contains(s)) continue; - for (auto t : m_terms[h]->find(s)) - result.push_back(t); + for (auto t : m_terms[c]->find(s)) + result.push_back({t, c}); } return result; } + ptr_vector null_ptr_vector; + ptr_vector const &get_by_cost_and_sort(unsigned cost, sort *s) const { + if (cost >= m_terms.size() || !m_terms[cost] || !m_terms[cost]->contains(s)) + return null_ptr_vector; + return m_terms[cost]->find(s); + } + private: ast_manager& m; expr_ref_vector m_pinned; - // height -> sort -> terms + // cost -> sort -> terms ptr_vector m_terms; }; @@ -256,12 +263,12 @@ private: /** * Iterates over all tuples (c1, c2, ..., cn) where each ci has the required * sort, drawn from the term bank, with at least one child at the current - * height - 1 (to avoid regenerating previously seen terms). + * cost - 1 (to avoid regenerating previously seen terms). */ class children_iterator { public: - children_iterator(ast_manager& m, production const& prod, term_bank const& bank, unsigned current_height) - : m(m), m_prod(prod), m_current_height(current_height), m_done(false) + children_iterator(ast_manager& m, production const& prod, term_bank const& bank, unsigned current_cost) + : m(m), m_prod(prod), m_current_cost(current_cost), m_done(false) { m_arity = prod.domain.size(); if (m_arity == 0) { @@ -269,7 +276,7 @@ public: return; } for (unsigned i = 0; i < m_arity; ++i) { - m_candidates.push_back(bank.get_by_sort(prod.domain[i], current_height - 1)); + m_candidates.push_back(bank.get_by_sort(prod.domain[i], current_cost - 1)); if (m_candidates.back().empty()) { m_done = true; return; @@ -278,19 +285,23 @@ public: m_indices.resize(m_arity, 0); } - bool has_next() { + bool has_next(unsigned cost) { while (!m_done) { - if (has_child_at_max_height()) + if (has_child_at_cost(cost)) return true; advance(); } return false; } - expr_ref_vector next() { + expr_ref_vector next(unsigned& cost) { expr_ref_vector result(m); - for (unsigned i = 0; i < m_arity; ++i) - result.push_back(m_candidates[i].get(m_indices[i])); + cost = 1; + for (unsigned i = 0; i < m_arity; ++i) { + auto [e, c] = m_candidates[i].get(m_indices[i]); + cost += c; + result.push_back(e); + } advance(); return result; } @@ -298,18 +309,23 @@ public: private: ast_manager& m; production const& m_prod; - unsigned m_current_height; + unsigned m_current_cost; unsigned m_arity; bool m_done; - vector m_candidates; + vector m_candidates; svector m_indices; - bool has_child_at_max_height() const { - return true; + bool has_child_at_cost(unsigned cost) const { + for (unsigned i = 0; i < m_arity; ++i) { + auto [e, c] = m_candidates[i].get(m_indices[i]); + if (c + 1 == cost) + return true; + } + return false; } void advance() { - for (auto i = m_arity; i-- > 0;) { + for (auto i = m_arity; i-- > 0;) { m_indices[i]++; if (m_indices[i] < m_candidates[i].size()) return; m_indices[i] = 0; @@ -323,8 +339,8 @@ private: // ============================================================================ /** - * Enumerates terms bottom-up by height, applying observational equivalence - * pruning. Users iterate via has_next() / next(), or call enumerate_up_to(). + * Enumerates terms bottom-up by cost, applying observational equivalence + * pruning. Users iterate via has_next() / next(). * * Usage: * ast_manager m; @@ -344,13 +360,10 @@ class bottom_up_enumerator { public: bottom_up_enumerator(grammar& grammar, oe_manager& oe) : m_grammar(grammar), m(grammar.mgr()), m_oe(oe), - m_bank(grammar.mgr()), m_height(0), - m_leaf_idx(0), m_op_idx(0), m_state(State::Leaves), - m_target_sort(nullptr), m_pending(grammar.mgr()) + m_bank(grammar.mgr()), m_pending(grammar.mgr()) {} void set_target_sort(sort* s) { m_target_sort = s; } - void set_max_height(unsigned h) { m_max_height = h; } bool has_next() { if (m_pending) return true; @@ -366,20 +379,10 @@ public: return result; } - expr_ref_vector enumerate_up_to(unsigned max_height) { - m_max_height = max_height; - expr_ref_vector results(m); - while (has_next()) { - results.push_back(next()); - } - return results; - } - term_bank const& bank() const { return m_bank; } - unsigned current_height() const { return m_height; } void reset() { - m_height = 0; + m_cost = 0; m_leaf_idx = 0; m_op_idx = 0; m_state = State::Leaves; @@ -390,20 +393,21 @@ public: } private: - enum class State { Leaves, operators, Done }; + enum class State { Leaves, Operators, Done }; grammar& m_grammar; ast_manager& m; oe_manager& m_oe; term_bank m_bank; - unsigned m_height; - unsigned m_leaf_idx; - unsigned m_op_idx; - State m_state; - sort* m_target_sort; + unsigned m_cost = 0; + unsigned m_leaf_idx = 0; + unsigned m_op_idx = 0; + unsigned m_bank_idx = 0; + unsigned m_bank_size = 0; + State m_state = State::Leaves; + sort* m_target_sort = nullptr; expr_ref m_pending; std::unique_ptr m_children_iter; - unsigned m_max_height = 100; bool sort_matches(expr* e) const { return !m_target_sort || e->get_sort() == m_target_sort; @@ -424,21 +428,22 @@ private: return term; } } - m_state = State::operators; - m_height = 1; + m_state = State::Operators; + m_cost = 1; m_op_idx = 0; + m_bank_idx = 0; + m_bank_size = get_bank_size(); m_children_iter.reset(); break; - case State::operators: { + case State::Operators: { expr* result = enumerate_operators(); - if (result) return result; - m_height++; - if (m_height > m_max_height) { - m_state = State::Done; - break; - } + if (result) + return result; + m_cost++; m_op_idx = 0; + m_bank_idx = 0; + m_bank_size = get_bank_size(); m_children_iter.reset(); break; } @@ -448,24 +453,43 @@ private: } } - expr* enumerate_operators() { - auto const& ops = m_grammar.operators(); + unsigned get_bank_size() const { + auto const &terms = m_bank.get_by_cost_and_sort(m_cost, m_target_sort); + return terms.size(); + } + + expr *enumerate_operators() { + auto const &ops = m_grammar.operators(); while (true) { - if (m_children_iter && m_children_iter->has_next()) { - expr_ref_vector children = m_children_iter->next(); - production const& prod = ops[m_op_idx - 1]; + + // first find terms at m_cost that were already created + if (m_bank_idx < m_bank_size) { + auto const &terms = m_bank.get_by_cost_and_sort(m_cost, m_target_sort); + auto t = terms.get(m_bank_idx); + m_bank_idx++; + SASSERT(sort_matches(t)); + return t; + } + + // then create new terms using children at cost below current m_cost. + if (m_children_iter && m_children_iter->has_next(m_cost)) { + unsigned new_cost = 0; + expr_ref_vector children = m_children_iter->next(new_cost); + production const &prod = ops[m_op_idx - 1]; expr_ref term = prod.builder(children); + SASSERT(new_cost >= m_cost); if (m_oe.is_representative(term)) { - m_bank.add(term, m_height); - if (sort_matches(term)) + m_bank.add(term, new_cost); + if (sort_matches(term) && new_cost == m_cost) return term; } continue; } - if (m_op_idx >= ops.size()) return nullptr; - production const& prod = ops[m_op_idx]; + if (m_op_idx >= ops.size()) + return nullptr; + production const &prod = ops[m_op_idx]; m_op_idx++; - m_children_iter = std::make_unique(m, prod, m_bank, m_height); + m_children_iter = std::make_unique(m, prod, m_bank, m_cost); } } }; @@ -498,37 +522,35 @@ struct term_enumeration::imp { m_cost = cost; } - // Enumerate terms of given sort up to a height, ordered by cost. - // Returns the next term in cost order, or nullptr if exhausted at current height. - expr* next_term(sort* s, unsigned& height_state, unsigned& idx_state, + // Enumerate terms of given sort up to a cost, ordered by cost. + // Returns the next term in cost order, or nullptr if exhausted at current cost. + expr* next_term(sort* s, unsigned& cost_state, unsigned& idx_state, vector& levels) { // Expand levels as needed - while (idx_state >= level_size(levels, height_state)) { - height_state++; - if (height_state > 100) - return nullptr; - expand_level(s, height_state, levels); + while (idx_state >= level_size(levels, cost_state)) { + cost_state++; + expand_level(s, cost_state, levels); idx_state = 0; - if (level_size(levels, height_state) > 0) + if (level_size(levels, cost_state) > 0) break; } - if (height_state >= levels.size() || idx_state >= levels[height_state].size()) + if (cost_state >= levels.size() || idx_state >= levels[cost_state].size()) return nullptr; - return levels[height_state].get(idx_state++); + return levels[cost_state].get(idx_state++); } private: - unsigned level_size(vector const& levels, unsigned h) const { - if (h >= levels.size()) return 0; - return levels[h].size(); + unsigned level_size(vector const& levels, unsigned cost) const { + if (cost >= levels.size()) return 0; + return levels[cost].size(); } - void expand_level(sort* s, unsigned height, vector& levels) { - if (height >= levels.size()) - levels.resize(height + 1, expr_ref_vector(m)); + void expand_level(sort* s, unsigned cost, vector& levels) { + if (cost >= levels.size()) + levels.resize(cost + 1, expr_ref_vector(m)); - // Collect terms at this height - if (height == 0) { + // Collect terms at this cost + if (cost == 0) { // Leaves for (auto const& prod : m_grammar.leaves()) { if (prod.range.get() != s) continue; @@ -544,19 +566,21 @@ private: // operators for (auto const& prod : m_grammar.operators()) { if (prod.range.get() != s) continue; - term_enum::children_iterator iter(m, prod, m_bottom_up_enumerator.bank(), height); - while (iter.has_next()) { - expr_ref_vector children = iter.next(); + term_enum::children_iterator iter(m, prod, m_bottom_up_enumerator.bank(), cost); + while (iter.has_next(cost)) { + unsigned new_cost = 0; + expr_ref_vector children = iter.next(new_cost); expr_ref term = prod.builder(children); + levels.reserve(new_cost + 1, expr_ref_vector(m)); if (m_oe.is_representative(term)) - levels[height].push_back(term); + levels[new_cost].push_back(term); } } } // Sort by cost if cost function is set - if (m_cost && !levels[height].empty()) { - expr_ref_vector& lv = levels[height]; + if (m_cost && !levels[cost].empty()) { + expr_ref_vector& lv = levels[cost]; std::sort(lv.data(), lv.data() + lv.size(), [&](expr* a, expr* b) { return m_cost(a) < m_cost(b); }); } @@ -568,30 +592,29 @@ private: struct term_enumeration::iterator::iter_imp { imp& m_imp; sort* m_sort; - unsigned m_height; - unsigned m_idx; + unsigned m_cost = 0; + unsigned m_idx = 0; vector m_levels; - expr* m_current; + expr* m_current = nullptr; bool m_end; - iter_imp(imp& i, sort* s) : - m_imp(i), m_sort(s), m_height(0), m_idx(0), m_current(nullptr), m_end(false) { + iter_imp(imp& i, sort* s) : m_imp(i), m_sort(s), m_end(false) { expand_current_level(); advance_to_valid(); } // Sentinel constructor iter_imp(imp& i) : - m_imp(i), m_sort(nullptr), m_height(0), m_idx(0), m_current(nullptr), m_end(true) {} + m_imp(i), m_sort(nullptr), m_end(true) {} void expand_current_level() { - if (m_height >= m_levels.size()) - m_levels.resize(m_height + 1, expr_ref_vector(m_imp.m)); + if (m_cost >= m_levels.size()) + m_levels.resize(m_cost + 1, expr_ref_vector(m_imp.m)); - if (!m_levels[m_height].empty()) + if (!m_levels[m_cost].empty()) return; - if (m_height == 0) { + if (m_cost == 0) { for (auto const& prod : m_imp.m_grammar.leaves()) { if (prod.range.get() != m_sort) continue; expr_ref_vector empty_args(m_imp.m); @@ -603,19 +626,21 @@ struct term_enumeration::iterator::iter_imp { else { for (auto const& prod : m_imp.m_grammar.operators()) { if (prod.range.get() != m_sort) continue; - term_enum::children_iterator iter(m_imp.m, prod, m_imp.m_bottom_up_enumerator.bank(), m_height); - while (iter.has_next()) { - expr_ref_vector children = iter.next(); + term_enum::children_iterator iter(m_imp.m, prod, m_imp.m_bottom_up_enumerator.bank(), m_cost); + while (iter.has_next(m_cost)) { + unsigned new_cost = 0; + expr_ref_vector children = iter.next(new_cost); expr_ref term = prod.builder(children); + m_levels.reserve(new_cost + 1, expr_ref_vector(m_imp.m)); if (m_imp.m_oe.is_representative(term)) - m_levels[m_height].push_back(term); + m_levels[new_cost].push_back(term); } } } // Sort by cost if cost function is set - if (m_imp.m_cost && !m_levels[m_height].empty()) { - expr_ref_vector& lv = m_levels[m_height]; + if (m_imp.m_cost && !m_levels[m_cost].empty()) { + expr_ref_vector& lv = m_levels[m_cost]; std::sort(lv.data(), lv.data() + lv.size(), [&](expr* a, expr* b) { return m_imp.m_cost(a) < m_imp.m_cost(b); }); } @@ -623,15 +648,15 @@ struct term_enumeration::iterator::iter_imp { void advance_to_valid() { while (true) { - if (m_height >= m_levels.size()) + if (m_cost >= m_levels.size()) expand_current_level(); - if (m_idx < m_levels[m_height].size()) { - m_current = m_levels[m_height].get(m_idx); + if (m_idx < m_levels[m_cost].size()) { + m_current = m_levels[m_cost].get(m_idx); return; } - m_height++; + m_cost++; m_idx = 0; - if (m_height > 100) { + if (m_cost > 100) { m_end = true; m_current = nullptr; return; diff --git a/src/smt/smt_model_finder.cpp b/src/smt/smt_model_finder.cpp index 27516b3dc..1af371e8b 100644 --- a/src/smt/smt_model_finder.cpp +++ b/src/smt/smt_model_finder.cpp @@ -187,14 +187,14 @@ namespace smt { \brief Base class used to solve model construction constraints. */ class node { - unsigned m_id; - node* m_find{ nullptr }; - unsigned m_eqc_size{ 1 }; + unsigned m_id = 0; + node* m_find = nullptr; + unsigned m_eqc_size = 1; - sort* m_sort; // sort of the elements in the instantiation set. + sort* m_sort = nullptr; // sort of the elements in the instantiation set. - bool m_mono_proj{ false }; // relevant for integers & reals & bit-vectors - bool m_signed_proj{ false }; // relevant for bit-vectors. + bool m_mono_proj = false; // relevant for integers & reals & bit-vectors + bool m_signed_proj = false; // relevant for bit-vectors. ptr_vector m_avoid_set; ptr_vector m_exceptions; @@ -1235,8 +1235,8 @@ namespace smt { void populate_inst_sets(quantifier* q, func_decl* mhead, ptr_vector& uvar_inst_sets, context* ctx) override { if (m_f != mhead) return; - uvar_inst_sets.reserve(m_var_j + 1, 0); - if (uvar_inst_sets[m_var_j] == 0) + uvar_inst_sets.reserve(m_var_j + 1, nullptr); + if (uvar_inst_sets[m_var_j] == nullptr) uvar_inst_sets[m_var_j] = alloc(instantiation_set, ctx->get_manager()); instantiation_set* s = uvar_inst_sets[m_var_j]; SASSERT(s != nullptr);