3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2026-06-19 15:16:29 +00:00
Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
This commit is contained in:
Nikolaj Bjorner 2026-06-11 18:01:36 -07:00
parent cf62a78e8a
commit 19e00e03c1
2 changed files with 313 additions and 2 deletions

View file

@ -1,6 +1,5 @@
#pragma once
/**
* term_enumeration.h - Bottom-up term enumeration module for Z3
* term_enumeration.cpp - Bottom-up term enumeration module for Z3
*
* Inspired by the Probe synthesizer (Barke et al., "Just-in-Time Learning
* for Bottom-Up Enumerative Synthesis"). Adapted to use Z3's internal APIs.
@ -14,12 +13,14 @@
* (constants, variables) are available for enumeration.
*/
#include "ast/term_enumeration.h"
#include <sstream>
#include <unordered_map>
#include <unordered_set>
#include <functional>
#include <string>
#include <memory>
#include <queue>
#include "ast/ast.h"
#include "ast/ast_pp.h"
#include "ast/ast_ll_pp.h"
@ -714,3 +715,267 @@ private:
};
} // namespace term_enum
// ============================================================================
// term_enumeration public interface implementation
// ============================================================================
struct term_enumeration::imp {
ast_manager& m;
term_enum::Grammar m_grammar;
term_enum::OEManager m_oe;
term_enum::Enumerator m_enumerator;
std::function<unsigned(expr*)> m_cost;
imp(ast_manager& m) :
m(m), m_grammar(m), m_oe(m), m_enumerator(m_grammar, m_oe) {}
void add_production(func_decl* f) {
m_grammar.add_func_decl(f);
}
void add_production(expr* e) {
m_grammar.add_expr(e);
}
void set_cost(std::function<unsigned(expr*)> const& cost) {
m_cost = cost;
}
// Enumerate terms of given sort up to a height, ordered by cost.
// Returns the next term in cost order, or nullptr if exhausted at current height.
expr* next_term(sort* s, unsigned& height_state, unsigned& idx_state,
vector<expr_ref_vector>& levels) {
// Expand levels as needed
while (idx_state >= level_size(levels, height_state)) {
height_state++;
if (height_state > 100)
return nullptr;
expand_level(s, height_state, levels);
idx_state = 0;
if (level_size(levels, height_state) > 0)
break;
}
if (height_state >= levels.size() || idx_state >= levels[height_state].size())
return nullptr;
return levels[height_state][idx_state++];
}
private:
unsigned level_size(vector<expr_ref_vector> const& levels, unsigned h) const {
if (h >= levels.size()) return 0;
return levels[h].size();
}
void expand_level(sort* s, unsigned height, vector<expr_ref_vector>& levels) {
if (height >= levels.size())
levels.resize(height + 1, expr_ref_vector(m));
// Collect terms at this height
if (height == 0) {
// Leaves
for (auto const& prod : m_grammar.leaves()) {
if (prod.range.get() != s) continue;
expr_ref_vector empty_args(m);
expr_ref term = prod.builder(empty_args);
if (m_oe.is_representative(term)) {
m_enumerator.bank(); // just to ensure bank is populated
levels[0].push_back(term);
}
}
}
else {
// Operators
for (auto const& prod : m_grammar.operators()) {
if (prod.range.get() != s) continue;
term_enum::ChildrenIterator iter(m, prod, m_enumerator.bank(), height);
while (iter.has_next()) {
expr_ref_vector children = iter.next();
expr_ref term = prod.builder(children);
if (m_oe.is_representative(term))
levels[height].push_back(term);
}
}
}
// Sort by cost if cost function is set
if (m_cost && !levels[height].empty()) {
expr_ref_vector& lv = levels[height];
std::sort(lv.data(), lv.data() + lv.size(),
[&](expr* a, expr* b) { return m_cost(a) < m_cost(b); });
}
}
};
// -- iterator implementation --
struct term_enumeration::iterator::iter_imp {
imp& m_imp;
sort* m_sort;
unsigned m_height;
unsigned m_idx;
vector<expr_ref_vector> m_levels;
expr* m_current;
bool m_end;
iter_imp(imp& i, sort* s) :
m_imp(i), m_sort(s), m_height(0), m_idx(0), m_current(nullptr), m_end(false) {
expand_current_level();
advance_to_valid();
}
// Sentinel constructor
iter_imp(imp& i) :
m_imp(i), m_sort(nullptr), m_height(0), m_idx(0), m_current(nullptr), m_end(true) {}
void expand_current_level() {
if (m_height >= m_levels.size())
m_levels.resize(m_height + 1, expr_ref_vector(m_imp.m));
if (!m_levels[m_height].empty())
return;
if (m_height == 0) {
for (auto const& prod : m_imp.m_grammar.leaves()) {
if (prod.range.get() != m_sort) continue;
expr_ref_vector empty_args(m_imp.m);
expr_ref term = prod.builder(empty_args);
if (m_imp.m_oe.is_representative(term))
m_levels[0].push_back(term);
}
}
else {
for (auto const& prod : m_imp.m_grammar.operators()) {
if (prod.range.get() != m_sort) continue;
term_enum::ChildrenIterator iter(m_imp.m, prod, m_imp.m_enumerator.bank(), m_height);
while (iter.has_next()) {
expr_ref_vector children = iter.next();
expr_ref term = prod.builder(children);
if (m_imp.m_oe.is_representative(term))
m_levels[m_height].push_back(term);
}
}
}
// Sort by cost if cost function is set
if (m_imp.m_cost && !m_levels[m_height].empty()) {
expr_ref_vector& lv = m_levels[m_height];
std::sort(lv.data(), lv.data() + lv.size(),
[&](expr* a, expr* b) { return m_imp.m_cost(a) < m_imp.m_cost(b); });
}
}
void advance_to_valid() {
while (true) {
if (m_height >= m_levels.size())
expand_current_level();
if (m_idx < m_levels[m_height].size()) {
m_current = m_levels[m_height][m_idx];
return;
}
m_height++;
m_idx = 0;
if (m_height > 100) {
m_end = true;
m_current = nullptr;
return;
}
expand_current_level();
}
}
void advance() {
if (m_end) return;
m_idx++;
advance_to_valid();
}
};
term_enumeration::iterator::iterator(imp& i, sort* s) {
m_imp = alloc(iter_imp, i, s);
}
term_enumeration::iterator::iterator(std::nullptr_t) {
m_imp = nullptr;
}
term_enumeration::iterator::iterator(iterator const& other) {
m_imp = nullptr;
if (other.m_imp)
m_imp = alloc(iter_imp, *other.m_imp);
}
term_enumeration::iterator& term_enumeration::iterator::operator=(iterator const& other) {
if (this != &other) {
dealloc(m_imp);
m_imp = nullptr;
if (other.m_imp)
m_imp = alloc(iter_imp, *other.m_imp);
}
return *this;
}
term_enumeration::iterator::~iterator() {
dealloc(m_imp);
}
expr* term_enumeration::iterator::operator*() {
return m_imp ? m_imp->m_current : nullptr;
}
term_enumeration::iterator& term_enumeration::iterator::operator++() {
if (m_imp) m_imp->advance();
return *this;
}
term_enumeration::iterator term_enumeration::iterator::operator++(int) {
iterator tmp(*this);
++(*this);
return tmp;
}
bool term_enumeration::iterator::operator!=(iterator const& other) const {
if (!m_imp && !other.m_imp) return false;
if (!m_imp) return !other.m_imp->m_end;
if (!other.m_imp) return !m_imp->m_end;
return m_imp->m_end != other.m_imp->m_end ||
m_imp->m_current != other.m_imp->m_current;
}
// -- terms implementation --
term_enumeration::terms::terms(imp* i, sort* s) : m_imp(i), m_sort(s) {}
term_enumeration::iterator term_enumeration::terms::begin() {
return iterator(*m_imp, m_sort);
}
term_enumeration::iterator term_enumeration::terms::end() {
return iterator(nullptr);
}
// -- term_enumeration implementation --
term_enumeration::term_enumeration(ast_manager& m) {
m_imp = alloc(imp, m);
}
term_enumeration::~term_enumeration() {
dealloc(m_imp);
}
void term_enumeration::add_production(func_decl* f) {
m_imp->add_production(f);
}
void term_enumeration::add_production(expr* e) {
m_imp->add_production(e);
}
void term_enumeration::set_cost(std::function<unsigned(expr*)> const& cost) {
m_imp->set_cost(cost);
}
term_enumeration::terms term_enumeration::enum_terms(sort* s) {
return terms(m_imp, s);
}

View file

@ -0,0 +1,46 @@
#pragma once
#include "ast/ast.h"
#include <functional>
class term_enumeration {
struct imp;
imp* m_imp;
public:
term_enumeration(ast_manager& m);
~term_enumeration();
void add_production(func_decl* f);
void add_production(expr* e);
// cost function associated with expressions.
// terms are enumerated with increasing cost.
void set_cost(std::function<unsigned(expr*)> const& cost);
class iterator {
struct iter_imp;
iter_imp* m_imp;
public:
iterator(imp& i, sort* s);
iterator(std::nullptr_t);
iterator(iterator const& other);
iterator& operator=(iterator const& other);
~iterator();
expr* operator*();
iterator operator++(int);
iterator& operator++();
bool operator!=(iterator const& other) const;
};
class terms {
imp* m_imp;
sort* m_sort;
public:
terms(imp* i, sort* s);
iterator begin();
iterator end();
};
terms enum_terms(sort* s);
};