3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2026-06-20 15:40:37 +00:00
z3/src/ast/term_enumeration.cpp
Nikolaj Bjorner be26deb365 use scoped_ptr_vector instead of vector of structs
Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
2026-06-19 19:27:30 -07:00

629 lines
19 KiB
C++

/**
* term_enumeration.cpp - Bottom-up term enumeration module for Z3
*
* Inspired by the Probe synthesizer (Barke et al., "Just-in-Time Learning
* for Bottom-Up Enumerative Synthesis"). Adapted to use Z3's internal APIs.
*
* Key ideas:
* - Terms are enumerated bottom-up by "cost" (calculated by tree size).
* - A grammar describes which function symbols (operators) and leaves
* (constants, variables) are available for enumeration.
*/
#include <sstream>
#include <functional>
#include <string>
#include "util/vector.h"
#include "util/scoped_ptr_vector.h"
#include "util/obj_hashtable.h"
#include "ast/ast.h"
#include "ast/ast_ll_pp.h"
#include "ast/ast_pp.h"
#include "ast/term_enumeration.h"
namespace term_enum {
// ============================================================================
// grammar production rule
// ============================================================================
/**
* A production describes how to construct a term from child terms.
* - domain: the sort required for each child
* - range: the sort of the produced term
* - builder: given a vector of child exprs, produce the result expr
*/
struct production {
std::string name;
sort_ref range;
sort_ref_vector domain;
std::function<expr_ref(expr_ref_vector const&)> builder;
bool is_leaf() const { return domain.empty(); }
};
// ============================================================================
// grammar
// ============================================================================
/**
* A grammar groups productions into leaves (arity 0) and operators (arity > 0).
*/
class grammar {
public:
grammar(ast_manager& m) : m(m), m_pinned(m) {}
void add_production(production* p) {
if (p->is_leaf())
m_leaves.push_back(p);
else
m_operators.push_back(p);
}
scoped_ptr_vector<production> const& leaves() const { return m_leaves; }
scoped_ptr_vector<production> const& operators() const { return m_operators; }
ast_manager& mgr() const { return m; }
void add_func_decl(func_decl *f) {
m_pinned.push_back(f);
sort_ref range(f->get_range(), m);
sort_ref_vector dom(m);
for (unsigned i = 0; i < f->get_arity(); ++i)
dom.push_back(sort_ref(f->get_domain(i), m));
add_production(alloc(production, {f->get_name().str(), range, dom, [this, f](expr_ref_vector const &args) {
return expr_ref(m.mk_app(f, args), m);
}}));
}
void add_expr(expr *e) {
m_pinned.push_back(e);
sort_ref range(e->get_sort(), m);
sort_ref_vector dom(m);
std::stringstream ss;
ss << mk_bounded_pp(e, m);
std::string name = ss.str();
add_production(alloc(production, {name, range, dom, [this, e](expr_ref_vector const&) { return expr_ref(e, m); }}));
}
std::ostream& display(std::ostream& out) const {
out << "Leaves:\n";
for (auto const *p : m_leaves) {
out << " " << p->name << " : " << mk_pp(p->range, m) << "\n";
}
out << "Operators:\n";
for (auto const *p : m_operators) {
out << " " << p->name << " : (";
for (unsigned i = 0; i < p->domain.size(); ++i) {
if (i > 0)
out << ", ";
out << mk_pp(p->domain[i], m);
}
out << ") -> " << mk_pp(p->range, m) << "\n";
}
return out;
}
private:
ast_manager& m;
ast_ref_vector m_pinned;
scoped_ptr_vector<production> m_leaves;
scoped_ptr_vector<production> m_operators;
};
// ============================================================================
// Term Bank - stores enumerated terms by cost and sort
// ============================================================================
using cost_terms = vector<std::pair<expr*, unsigned>>;
class term_bank {
using sort_term_map = obj_map<sort, ptr_vector<expr>>;
public:
term_bank(ast_manager& m) : m(m), m_pinned(m) {}
~term_bank() {
for (auto s : m_terms)
dealloc(s);
}
void reset() {
m_pinned.reset();
m_terms.clear();
}
void add(expr* term, unsigned cost) {
sort* s = term->get_sort();
m_pinned.push_back(term);
if (cost >= m_terms.size())
m_terms.resize(cost + 1);
if (!m_terms[cost])
m_terms[cost] = alloc(sort_term_map);
m_terms[cost]->insert_if_not_there(s, ptr_vector<expr>()).push_back(term);
}
/** Get all terms of a given sort up to (and including) max_cost */
cost_terms get_by_sort(sort* s, unsigned max_cost) const {
cost_terms result;
for (unsigned c = 0; c <= max_cost; ++c) {
if (c >= m_terms.size())
break;
if (!m_terms[c]->contains(s))
continue;
for (auto t : m_terms[c]->find(s))
result.push_back({t, c});
}
return result;
}
ptr_vector<expr> null_ptr_vector;
ptr_vector<expr> const &get_by_cost_and_sort(unsigned cost, sort *s) const {
if (cost >= m_terms.size() || !m_terms[cost] || !m_terms[cost]->contains(s))
return null_ptr_vector;
return m_terms[cost]->find(s);
}
std::ostream& display(std::ostream& out) const {
for (unsigned cost = 0; cost < m_terms.size(); ++cost) {
if (!m_terms[cost])
continue;
out << "cost " << cost << ":\n";
for (auto& [s, terms] : *m_terms[cost]) {
out << " sort " << mk_pp(s, m) << ":\n";
for (expr* e : terms) {
out << " #" << e->get_id() << " ";
if (cost == 0) {
out << mk_bounded_pp(e, m);
}
else if (is_app(e)) {
app* a = to_app(e);
out << a->get_decl()->get_name() << "(";
bool first = true;
for (expr* arg : *a) {
if (!first) out << ", ";
first = false;
out << "#" << arg->get_id();
}
out << ")";
}
out << "\n";
}
}
}
return out;
}
private:
ast_manager& m;
expr_ref_vector m_pinned;
// cost -> sort -> terms
ptr_vector<sort_term_map> m_terms;
};
// ============================================================================
// Children Iterator - generates all combinations of child terms
// ============================================================================
/**
* Iterates over all tuples (c1, c2, ..., cn) where each ci has the required
* sort, drawn from the term bank, with at least one child at the current
* cost - 1 (to avoid regenerating previously seen terms).
*/
class children_iterator {
public:
children_iterator(ast_manager& m, production const& prod, term_bank const& bank, unsigned current_cost)
: m(m), m_prod(prod), m_current_cost(current_cost), m_done(false)
{
m_arity = prod.domain.size();
if (m_arity == 0) {
m_done = true;
return;
}
for (unsigned i = 0; i < m_arity; ++i) {
m_candidates.push_back(bank.get_by_sort(prod.domain[i], current_cost - 1));
if (m_candidates.back().empty()) {
m_done = true;
return;
}
}
m_indices.resize(m_arity, 0);
}
bool has_next(unsigned cost) {
while (!m_done) {
if (has_child_at_cost(cost))
return true;
advance();
}
return false;
}
expr_ref_vector next(unsigned& cost) {
expr_ref_vector result(m);
cost = 1;
for (unsigned i = 0; i < m_arity; ++i) {
auto [e, c] = m_candidates[i].get(m_indices[i]);
cost += c;
result.push_back(e);
}
advance();
return result;
}
private:
ast_manager& m;
production const& m_prod;
unsigned m_current_cost;
unsigned m_arity;
bool m_done;
vector<cost_terms> m_candidates;
svector<unsigned> m_indices;
bool has_child_at_cost(unsigned cost) const {
for (unsigned i = 0; i < m_arity; ++i) {
auto [e, c] = m_candidates[i].get(m_indices[i]);
if (c + 1 == cost)
return true;
}
return false;
}
void advance() {
for (auto i = m_arity; i-- > 0;) {
m_indices[i]++;
if (m_indices[i] < m_candidates[i].size()) return;
m_indices[i] = 0;
}
m_done = true;
}
};
// ============================================================================
// bottom_up_enumerator - the main bottom-up term enumeration engine
// ============================================================================
class bottom_up_enumerator {
public:
bottom_up_enumerator(grammar& grammar)
: m_grammar(grammar), m(grammar.mgr()),
m_bank(grammar.mgr()), m_pending(grammar.mgr())
{}
void set_target_sort(sort *s) {
m_target_sort = s;
}
bool has_next() {
if (m_pending) return true;
m_pending = find_next();
return m_pending != nullptr;
}
expr_ref next() {
if (!m_pending)
m_pending = find_next();
expr_ref result(m_pending, m);
m_pending = nullptr;
return result;
}
term_bank const& bank() const { return m_bank; }
std::ostream& display(std::ostream& out) const {
m_grammar.display(out);
return m_bank.display(out);
}
void reset() {
m_cost = 0;
m_leaf_idx = 0;
m_op_idx = 0;
m_state = State::Leaves;
m_bank.reset();
m_pending = nullptr;
m_children_iter.reset();
}
private:
enum class State { Leaves, Operators, Done };
grammar& m_grammar;
ast_manager& m;
term_bank m_bank;
unsigned m_cost = 0;
unsigned m_leaf_idx = 0;
unsigned m_op_idx = 0;
unsigned m_bank_idx = 0;
unsigned m_bank_size = 0;
bool m_has_range = false;
State m_state = State::Leaves;
expr_ref m_pending;
std::unique_ptr<children_iterator> m_children_iter;
sort *m_target_sort = nullptr;
bool sort_matches(expr* e) const {
return !m_target_sort || e->get_sort() == m_target_sort;
}
expr* find_next() {
while (true) {
switch (m_state) {
case State::Leaves:
while (m_leaf_idx < m_grammar.leaves().size()) {
production const &prod = *m_grammar.leaves()[m_leaf_idx];
m_leaf_idx++;
expr_ref_vector empty_args(m);
expr_ref term = prod.builder(empty_args);
m_bank.add(term, 0);
if (sort_matches(term))
return term;
}
m_state = State::Operators;
m_cost = 1;
m_op_idx = 0;
m_bank_idx = 0;
m_bank_size = get_bank_size();
m_has_range = false;
m_children_iter.reset();
break;
case State::Operators: {
expr* result = enumerate_operators();
if (result)
return result;
m_cost++;
m_op_idx = 0;
m_bank_idx = 0;
m_bank_size = get_bank_size();
m_children_iter.reset();
break;
}
case State::Done:
return nullptr;
}
}
}
unsigned get_bank_size() const {
auto const &terms = m_bank.get_by_cost_and_sort(m_cost, m_target_sort);
return terms.size();
}
expr *enumerate_operators() {
auto const &ops = m_grammar.operators();
while (true) {
// first find terms at m_cost that were already created
if (m_bank_idx < m_bank_size) {
auto const &terms = m_bank.get_by_cost_and_sort(m_cost, m_target_sort);
auto t = terms.get(m_bank_idx);
m_bank_idx++;
SASSERT(sort_matches(t));
return t;
}
// then create new terms using children at cost below current m_cost.
if (m_children_iter && m_children_iter->has_next(m_cost)) {
unsigned new_cost = 0;
expr_ref_vector children = m_children_iter->next(new_cost);
production const &prod = *ops[m_op_idx - 1];
expr_ref term = prod.builder(children);
// IF_VERBOSE(0, verbose_stream() << term << "\n");
SASSERT(new_cost >= m_cost);
m_bank.add(term, new_cost);
if (sort_matches(term) && new_cost == m_cost) {
m_has_range = true;
return term;
}
continue;
}
if (ops.empty()) {
m_state = State::Done;
return nullptr;
}
if (m_op_idx >= ops.size()) {
if (!m_has_range)
m_state = State::Done;
return nullptr;
}
production const &prod = *ops[m_op_idx];
m_op_idx++;
m_children_iter = std::make_unique<children_iterator>(m, prod, m_bank, m_cost);
}
}
};
} // namespace term_enum
// ============================================================================
// term_enumeration public interface implementation
// ============================================================================
struct term_enumeration::imp {
ast_manager& m;
term_enum::grammar m_grammar;
term_enum::bottom_up_enumerator m_bottom_up_enumerator;
std::function<unsigned(expr*)> m_cost;
imp(ast_manager& m) :
m(m), m_grammar(m), m_bottom_up_enumerator(m_grammar) {}
void add_production(func_decl* f) {
m_grammar.add_func_decl(f);
}
void add_production(expr* e) {
m_grammar.add_expr(e);
}
void set_cost(std::function<unsigned(expr*)> const& cost) {
// TODO
}
std::ostream& display(std::ostream& out) const {
return m_bottom_up_enumerator.display(out);
}
};
// -- iterator implementation --
struct term_enumeration::iterator::iter_imp {
imp& m_imp;
ast_manager & m;
sort* m_sort;
unsigned m_cost = 0;
unsigned m_idx = 0;
vector<expr_ref_vector> m_levels;
expr_ref m_current;
bool m_end;
vector<expr_ref_vector> m_vars;
vector<ptr_vector<sort>> m_decls;
vector<vector<symbol>> m_names;
iter_imp(imp& i, sort* s) : m_imp(i), m(i.m), m_sort(s), m_current(i.m), m_end(false) {
m_imp.m_bottom_up_enumerator.reset();
init_sort();
advance();
}
// Sentinel constructor
iter_imp(imp& i) :
m_imp(i), m(i.m), m_sort(nullptr), m_current(i.m), m_end(true) {
UNREACHABLE();
}
void init_sort() {
array_util autil(m);
sort *range = m_sort;
while (autil.is_array(range)) {
m_vars.push_back(expr_ref_vector(m));
m_decls.push_back(ptr_vector<sort>());
m_names.push_back(vector<symbol>());
for (unsigned i = 0; i < get_array_arity(range); ++i) {
m_decls.back().push_back(get_array_domain(range, i));
m_vars.back().push_back(nullptr);
m_names.back().push_back(symbol());
}
expr_ref_vector args(m);
args.push_back(m.mk_const("a", range));
for (unsigned i = 0; i < m_decls.back().size(); ++i) {
args.push_back(m.mk_var(i, m_decls.back().get(i)));
}
app_ref sel(autil.mk_select(args), m);
m_imp.m_grammar.add_func_decl(sel->get_decl());
range = get_array_range(range);
}
unsigned n = 0;
for (unsigned i = m_decls.size(); i-- > 0;) {
for (unsigned j = m_decls[i].size(); j-- > 0;) {
m_vars[i][j] = m.mk_var(n, m_decls[i][j]);
m_names[i][j] = symbol(n);
m_imp.add_production(m_vars[i].get(j));
n++;
}
}
m_sort = range;
m_imp.m_bottom_up_enumerator.set_target_sort(range);
}
void mk_lambda() {
if (!m_current)
return;
for (unsigned i = m_decls.size(); i-- > 0;)
m_current = m.mk_lambda(m_decls[i].size(), m_decls[i].data(), m_names[i].data(), m_current);
}
void advance() {
if (m_end)
return;
m_current = m_imp.m_bottom_up_enumerator.next();
SASSERT(!m_current || m_current->get_sort() == m_sort);
mk_lambda();
if (!m_current)
m_end = true;
}
};
term_enumeration::iterator::iterator(imp& i, sort* s) {
m_imp = alloc(iter_imp, i, s);
}
term_enumeration::iterator::iterator(std::nullptr_t) {
m_imp = nullptr;
}
term_enumeration::iterator::~iterator() {
dealloc(m_imp);
}
expr* term_enumeration::iterator::operator*() {
return m_imp ? m_imp->m_current.get() : nullptr;
}
term_enumeration::iterator& term_enumeration::iterator::operator++() {
if (m_imp) m_imp->advance();
return *this;
}
term_enumeration::iterator term_enumeration::iterator::operator++(int) {
iterator tmp(*this);
++(*this);
return tmp;
}
bool term_enumeration::iterator::operator==(iterator const& other) const {
if (!m_imp && !other.m_imp) return true;
if (!m_imp) return other.m_imp->m_end;
if (!other.m_imp) return m_imp->m_end;
return m_imp->m_end == other.m_imp->m_end &&
m_imp->m_current == other.m_imp->m_current;
}
// -- terms implementation --
term_enumeration::terms::terms(imp* i, sort* s) : m_imp(i), m_sort(s) {}
term_enumeration::iterator term_enumeration::terms::begin() {
return iterator(*m_imp, m_sort);
}
term_enumeration::iterator term_enumeration::terms::end() {
return iterator(nullptr);
}
// -- term_enumeration implementation --
term_enumeration::term_enumeration(ast_manager& m) {
m_imp = alloc(imp, m);
}
term_enumeration::~term_enumeration() {
dealloc(m_imp);
}
void term_enumeration::add_production(func_decl* f) {
m_imp->add_production(f);
}
void term_enumeration::add_production(expr* e) {
m_imp->add_production(e);
}
void term_enumeration::set_cost(std::function<unsigned(expr*)> const& cost) {
m_imp->set_cost(cost);
}
term_enumeration::terms term_enumeration::enum_terms(sort* s) {
return terms(m_imp, s);
}
std::ostream& term_enumeration::display(std::ostream& out) const {
return m_imp->display(out);
}