diff --git a/src/ast/term_enumeration.cpp b/src/ast/term_enumeration.cpp index f98ae209c..27c2b9823 100644 --- a/src/ast/term_enumeration.cpp +++ b/src/ast/term_enumeration.cpp @@ -1,6 +1,5 @@ -#pragma once /** - * term_enumeration.h - Bottom-up term enumeration module for Z3 + * term_enumeration.cpp - Bottom-up term enumeration module for Z3 * * Inspired by the Probe synthesizer (Barke et al., "Just-in-Time Learning * for Bottom-Up Enumerative Synthesis"). Adapted to use Z3's internal APIs. @@ -14,12 +13,14 @@ * (constants, variables) are available for enumeration. */ +#include "ast/term_enumeration.h" #include #include #include #include #include #include +#include #include "ast/ast.h" #include "ast/ast_pp.h" #include "ast/ast_ll_pp.h" @@ -714,3 +715,267 @@ private: }; } // namespace term_enum + +// ============================================================================ +// term_enumeration public interface implementation +// ============================================================================ + +struct term_enumeration::imp { + ast_manager& m; + term_enum::Grammar m_grammar; + term_enum::OEManager m_oe; + term_enum::Enumerator m_enumerator; + std::function m_cost; + + imp(ast_manager& m) : + m(m), m_grammar(m), m_oe(m), m_enumerator(m_grammar, m_oe) {} + + void add_production(func_decl* f) { + m_grammar.add_func_decl(f); + } + + void add_production(expr* e) { + m_grammar.add_expr(e); + } + + void set_cost(std::function const& cost) { + m_cost = cost; + } + + // Enumerate terms of given sort up to a height, ordered by cost. + // Returns the next term in cost order, or nullptr if exhausted at current height. + expr* next_term(sort* s, unsigned& height_state, unsigned& idx_state, + vector& levels) { + // Expand levels as needed + while (idx_state >= level_size(levels, height_state)) { + height_state++; + if (height_state > 100) + return nullptr; + expand_level(s, height_state, levels); + idx_state = 0; + if (level_size(levels, height_state) > 0) + break; + } + if (height_state >= levels.size() || idx_state >= levels[height_state].size()) + return nullptr; + return levels[height_state][idx_state++]; + } + +private: + unsigned level_size(vector const& levels, unsigned h) const { + if (h >= levels.size()) return 0; + return levels[h].size(); + } + + void expand_level(sort* s, unsigned height, vector& levels) { + if (height >= levels.size()) + levels.resize(height + 1, expr_ref_vector(m)); + + // Collect terms at this height + if (height == 0) { + // Leaves + for (auto const& prod : m_grammar.leaves()) { + if (prod.range.get() != s) continue; + expr_ref_vector empty_args(m); + expr_ref term = prod.builder(empty_args); + if (m_oe.is_representative(term)) { + m_enumerator.bank(); // just to ensure bank is populated + levels[0].push_back(term); + } + } + } + else { + // Operators + for (auto const& prod : m_grammar.operators()) { + if (prod.range.get() != s) continue; + term_enum::ChildrenIterator iter(m, prod, m_enumerator.bank(), height); + while (iter.has_next()) { + expr_ref_vector children = iter.next(); + expr_ref term = prod.builder(children); + if (m_oe.is_representative(term)) + levels[height].push_back(term); + } + } + } + + // Sort by cost if cost function is set + if (m_cost && !levels[height].empty()) { + expr_ref_vector& lv = levels[height]; + std::sort(lv.data(), lv.data() + lv.size(), + [&](expr* a, expr* b) { return m_cost(a) < m_cost(b); }); + } + } +}; + +// -- iterator implementation -- + +struct term_enumeration::iterator::iter_imp { + imp& m_imp; + sort* m_sort; + unsigned m_height; + unsigned m_idx; + vector m_levels; + expr* m_current; + bool m_end; + + iter_imp(imp& i, sort* s) : + m_imp(i), m_sort(s), m_height(0), m_idx(0), m_current(nullptr), m_end(false) { + expand_current_level(); + advance_to_valid(); + } + + // Sentinel constructor + iter_imp(imp& i) : + m_imp(i), m_sort(nullptr), m_height(0), m_idx(0), m_current(nullptr), m_end(true) {} + + void expand_current_level() { + if (m_height >= m_levels.size()) + m_levels.resize(m_height + 1, expr_ref_vector(m_imp.m)); + + if (!m_levels[m_height].empty()) + return; + + if (m_height == 0) { + for (auto const& prod : m_imp.m_grammar.leaves()) { + if (prod.range.get() != m_sort) continue; + expr_ref_vector empty_args(m_imp.m); + expr_ref term = prod.builder(empty_args); + if (m_imp.m_oe.is_representative(term)) + m_levels[0].push_back(term); + } + } + else { + for (auto const& prod : m_imp.m_grammar.operators()) { + if (prod.range.get() != m_sort) continue; + term_enum::ChildrenIterator iter(m_imp.m, prod, m_imp.m_enumerator.bank(), m_height); + while (iter.has_next()) { + expr_ref_vector children = iter.next(); + expr_ref term = prod.builder(children); + if (m_imp.m_oe.is_representative(term)) + m_levels[m_height].push_back(term); + } + } + } + + // Sort by cost if cost function is set + if (m_imp.m_cost && !m_levels[m_height].empty()) { + expr_ref_vector& lv = m_levels[m_height]; + std::sort(lv.data(), lv.data() + lv.size(), + [&](expr* a, expr* b) { return m_imp.m_cost(a) < m_imp.m_cost(b); }); + } + } + + void advance_to_valid() { + while (true) { + if (m_height >= m_levels.size()) + expand_current_level(); + if (m_idx < m_levels[m_height].size()) { + m_current = m_levels[m_height][m_idx]; + return; + } + m_height++; + m_idx = 0; + if (m_height > 100) { + m_end = true; + m_current = nullptr; + return; + } + expand_current_level(); + } + } + + void advance() { + if (m_end) return; + m_idx++; + advance_to_valid(); + } +}; + +term_enumeration::iterator::iterator(imp& i, sort* s) { + m_imp = alloc(iter_imp, i, s); +} + +term_enumeration::iterator::iterator(std::nullptr_t) { + m_imp = nullptr; +} + +term_enumeration::iterator::iterator(iterator const& other) { + m_imp = nullptr; + if (other.m_imp) + m_imp = alloc(iter_imp, *other.m_imp); +} + +term_enumeration::iterator& term_enumeration::iterator::operator=(iterator const& other) { + if (this != &other) { + dealloc(m_imp); + m_imp = nullptr; + if (other.m_imp) + m_imp = alloc(iter_imp, *other.m_imp); + } + return *this; +} + +term_enumeration::iterator::~iterator() { + dealloc(m_imp); +} + +expr* term_enumeration::iterator::operator*() { + return m_imp ? m_imp->m_current : nullptr; +} + +term_enumeration::iterator& term_enumeration::iterator::operator++() { + if (m_imp) m_imp->advance(); + return *this; +} + +term_enumeration::iterator term_enumeration::iterator::operator++(int) { + iterator tmp(*this); + ++(*this); + return tmp; +} + +bool term_enumeration::iterator::operator!=(iterator const& other) const { + if (!m_imp && !other.m_imp) return false; + if (!m_imp) return !other.m_imp->m_end; + if (!other.m_imp) return !m_imp->m_end; + return m_imp->m_end != other.m_imp->m_end || + m_imp->m_current != other.m_imp->m_current; +} + +// -- terms implementation -- + +term_enumeration::terms::terms(imp* i, sort* s) : m_imp(i), m_sort(s) {} + +term_enumeration::iterator term_enumeration::terms::begin() { + return iterator(*m_imp, m_sort); +} + +term_enumeration::iterator term_enumeration::terms::end() { + return iterator(nullptr); +} + +// -- term_enumeration implementation -- + +term_enumeration::term_enumeration(ast_manager& m) { + m_imp = alloc(imp, m); +} + +term_enumeration::~term_enumeration() { + dealloc(m_imp); +} + +void term_enumeration::add_production(func_decl* f) { + m_imp->add_production(f); +} + +void term_enumeration::add_production(expr* e) { + m_imp->add_production(e); +} + +void term_enumeration::set_cost(std::function const& cost) { + m_imp->set_cost(cost); +} + +term_enumeration::terms term_enumeration::enum_terms(sort* s) { + return terms(m_imp, s); +} diff --git a/src/ast/term_enumeration.h b/src/ast/term_enumeration.h new file mode 100644 index 000000000..7934c0f4e --- /dev/null +++ b/src/ast/term_enumeration.h @@ -0,0 +1,46 @@ +#pragma once + +#include "ast/ast.h" +#include + +class term_enumeration { + struct imp; + imp* m_imp; +public: + term_enumeration(ast_manager& m); + ~term_enumeration(); + + void add_production(func_decl* f); + void add_production(expr* e); + + // cost function associated with expressions. + // terms are enumerated with increasing cost. + + void set_cost(std::function const& cost); + + class iterator { + struct iter_imp; + iter_imp* m_imp; + public: + iterator(imp& i, sort* s); + iterator(std::nullptr_t); + iterator(iterator const& other); + iterator& operator=(iterator const& other); + ~iterator(); + expr* operator*(); + iterator operator++(int); + iterator& operator++(); + bool operator!=(iterator const& other) const; + }; + + class terms { + imp* m_imp; + sort* m_sort; + public: + terms(imp* i, sort* s); + iterator begin(); + iterator end(); + }; + + terms enum_terms(sort* s); +}; \ No newline at end of file