mirror of
https://github.com/Z3Prover/z3
synced 2026-06-19 15:16:29 +00:00
updated
Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
This commit is contained in:
parent
cf62a78e8a
commit
19e00e03c1
2 changed files with 313 additions and 2 deletions
|
|
@ -1,6 +1,5 @@
|
|||
#pragma once
|
||||
/**
|
||||
* term_enumeration.h - Bottom-up term enumeration module for Z3
|
||||
* term_enumeration.cpp - Bottom-up term enumeration module for Z3
|
||||
*
|
||||
* Inspired by the Probe synthesizer (Barke et al., "Just-in-Time Learning
|
||||
* for Bottom-Up Enumerative Synthesis"). Adapted to use Z3's internal APIs.
|
||||
|
|
@ -14,12 +13,14 @@
|
|||
* (constants, variables) are available for enumeration.
|
||||
*/
|
||||
|
||||
#include "ast/term_enumeration.h"
|
||||
#include <sstream>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <functional>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <queue>
|
||||
#include "ast/ast.h"
|
||||
#include "ast/ast_pp.h"
|
||||
#include "ast/ast_ll_pp.h"
|
||||
|
|
@ -714,3 +715,267 @@ private:
|
|||
};
|
||||
|
||||
} // namespace term_enum
|
||||
|
||||
// ============================================================================
|
||||
// term_enumeration public interface implementation
|
||||
// ============================================================================
|
||||
|
||||
struct term_enumeration::imp {
|
||||
ast_manager& m;
|
||||
term_enum::Grammar m_grammar;
|
||||
term_enum::OEManager m_oe;
|
||||
term_enum::Enumerator m_enumerator;
|
||||
std::function<unsigned(expr*)> m_cost;
|
||||
|
||||
imp(ast_manager& m) :
|
||||
m(m), m_grammar(m), m_oe(m), m_enumerator(m_grammar, m_oe) {}
|
||||
|
||||
void add_production(func_decl* f) {
|
||||
m_grammar.add_func_decl(f);
|
||||
}
|
||||
|
||||
void add_production(expr* e) {
|
||||
m_grammar.add_expr(e);
|
||||
}
|
||||
|
||||
void set_cost(std::function<unsigned(expr*)> const& cost) {
|
||||
m_cost = cost;
|
||||
}
|
||||
|
||||
// Enumerate terms of given sort up to a height, ordered by cost.
|
||||
// Returns the next term in cost order, or nullptr if exhausted at current height.
|
||||
expr* next_term(sort* s, unsigned& height_state, unsigned& idx_state,
|
||||
vector<expr_ref_vector>& levels) {
|
||||
// Expand levels as needed
|
||||
while (idx_state >= level_size(levels, height_state)) {
|
||||
height_state++;
|
||||
if (height_state > 100)
|
||||
return nullptr;
|
||||
expand_level(s, height_state, levels);
|
||||
idx_state = 0;
|
||||
if (level_size(levels, height_state) > 0)
|
||||
break;
|
||||
}
|
||||
if (height_state >= levels.size() || idx_state >= levels[height_state].size())
|
||||
return nullptr;
|
||||
return levels[height_state][idx_state++];
|
||||
}
|
||||
|
||||
private:
|
||||
unsigned level_size(vector<expr_ref_vector> const& levels, unsigned h) const {
|
||||
if (h >= levels.size()) return 0;
|
||||
return levels[h].size();
|
||||
}
|
||||
|
||||
void expand_level(sort* s, unsigned height, vector<expr_ref_vector>& levels) {
|
||||
if (height >= levels.size())
|
||||
levels.resize(height + 1, expr_ref_vector(m));
|
||||
|
||||
// Collect terms at this height
|
||||
if (height == 0) {
|
||||
// Leaves
|
||||
for (auto const& prod : m_grammar.leaves()) {
|
||||
if (prod.range.get() != s) continue;
|
||||
expr_ref_vector empty_args(m);
|
||||
expr_ref term = prod.builder(empty_args);
|
||||
if (m_oe.is_representative(term)) {
|
||||
m_enumerator.bank(); // just to ensure bank is populated
|
||||
levels[0].push_back(term);
|
||||
}
|
||||
}
|
||||
}
|
||||
else {
|
||||
// Operators
|
||||
for (auto const& prod : m_grammar.operators()) {
|
||||
if (prod.range.get() != s) continue;
|
||||
term_enum::ChildrenIterator iter(m, prod, m_enumerator.bank(), height);
|
||||
while (iter.has_next()) {
|
||||
expr_ref_vector children = iter.next();
|
||||
expr_ref term = prod.builder(children);
|
||||
if (m_oe.is_representative(term))
|
||||
levels[height].push_back(term);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Sort by cost if cost function is set
|
||||
if (m_cost && !levels[height].empty()) {
|
||||
expr_ref_vector& lv = levels[height];
|
||||
std::sort(lv.data(), lv.data() + lv.size(),
|
||||
[&](expr* a, expr* b) { return m_cost(a) < m_cost(b); });
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// -- iterator implementation --
|
||||
|
||||
struct term_enumeration::iterator::iter_imp {
|
||||
imp& m_imp;
|
||||
sort* m_sort;
|
||||
unsigned m_height;
|
||||
unsigned m_idx;
|
||||
vector<expr_ref_vector> m_levels;
|
||||
expr* m_current;
|
||||
bool m_end;
|
||||
|
||||
iter_imp(imp& i, sort* s) :
|
||||
m_imp(i), m_sort(s), m_height(0), m_idx(0), m_current(nullptr), m_end(false) {
|
||||
expand_current_level();
|
||||
advance_to_valid();
|
||||
}
|
||||
|
||||
// Sentinel constructor
|
||||
iter_imp(imp& i) :
|
||||
m_imp(i), m_sort(nullptr), m_height(0), m_idx(0), m_current(nullptr), m_end(true) {}
|
||||
|
||||
void expand_current_level() {
|
||||
if (m_height >= m_levels.size())
|
||||
m_levels.resize(m_height + 1, expr_ref_vector(m_imp.m));
|
||||
|
||||
if (!m_levels[m_height].empty())
|
||||
return;
|
||||
|
||||
if (m_height == 0) {
|
||||
for (auto const& prod : m_imp.m_grammar.leaves()) {
|
||||
if (prod.range.get() != m_sort) continue;
|
||||
expr_ref_vector empty_args(m_imp.m);
|
||||
expr_ref term = prod.builder(empty_args);
|
||||
if (m_imp.m_oe.is_representative(term))
|
||||
m_levels[0].push_back(term);
|
||||
}
|
||||
}
|
||||
else {
|
||||
for (auto const& prod : m_imp.m_grammar.operators()) {
|
||||
if (prod.range.get() != m_sort) continue;
|
||||
term_enum::ChildrenIterator iter(m_imp.m, prod, m_imp.m_enumerator.bank(), m_height);
|
||||
while (iter.has_next()) {
|
||||
expr_ref_vector children = iter.next();
|
||||
expr_ref term = prod.builder(children);
|
||||
if (m_imp.m_oe.is_representative(term))
|
||||
m_levels[m_height].push_back(term);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Sort by cost if cost function is set
|
||||
if (m_imp.m_cost && !m_levels[m_height].empty()) {
|
||||
expr_ref_vector& lv = m_levels[m_height];
|
||||
std::sort(lv.data(), lv.data() + lv.size(),
|
||||
[&](expr* a, expr* b) { return m_imp.m_cost(a) < m_imp.m_cost(b); });
|
||||
}
|
||||
}
|
||||
|
||||
void advance_to_valid() {
|
||||
while (true) {
|
||||
if (m_height >= m_levels.size())
|
||||
expand_current_level();
|
||||
if (m_idx < m_levels[m_height].size()) {
|
||||
m_current = m_levels[m_height][m_idx];
|
||||
return;
|
||||
}
|
||||
m_height++;
|
||||
m_idx = 0;
|
||||
if (m_height > 100) {
|
||||
m_end = true;
|
||||
m_current = nullptr;
|
||||
return;
|
||||
}
|
||||
expand_current_level();
|
||||
}
|
||||
}
|
||||
|
||||
void advance() {
|
||||
if (m_end) return;
|
||||
m_idx++;
|
||||
advance_to_valid();
|
||||
}
|
||||
};
|
||||
|
||||
term_enumeration::iterator::iterator(imp& i, sort* s) {
|
||||
m_imp = alloc(iter_imp, i, s);
|
||||
}
|
||||
|
||||
term_enumeration::iterator::iterator(std::nullptr_t) {
|
||||
m_imp = nullptr;
|
||||
}
|
||||
|
||||
term_enumeration::iterator::iterator(iterator const& other) {
|
||||
m_imp = nullptr;
|
||||
if (other.m_imp)
|
||||
m_imp = alloc(iter_imp, *other.m_imp);
|
||||
}
|
||||
|
||||
term_enumeration::iterator& term_enumeration::iterator::operator=(iterator const& other) {
|
||||
if (this != &other) {
|
||||
dealloc(m_imp);
|
||||
m_imp = nullptr;
|
||||
if (other.m_imp)
|
||||
m_imp = alloc(iter_imp, *other.m_imp);
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
term_enumeration::iterator::~iterator() {
|
||||
dealloc(m_imp);
|
||||
}
|
||||
|
||||
expr* term_enumeration::iterator::operator*() {
|
||||
return m_imp ? m_imp->m_current : nullptr;
|
||||
}
|
||||
|
||||
term_enumeration::iterator& term_enumeration::iterator::operator++() {
|
||||
if (m_imp) m_imp->advance();
|
||||
return *this;
|
||||
}
|
||||
|
||||
term_enumeration::iterator term_enumeration::iterator::operator++(int) {
|
||||
iterator tmp(*this);
|
||||
++(*this);
|
||||
return tmp;
|
||||
}
|
||||
|
||||
bool term_enumeration::iterator::operator!=(iterator const& other) const {
|
||||
if (!m_imp && !other.m_imp) return false;
|
||||
if (!m_imp) return !other.m_imp->m_end;
|
||||
if (!other.m_imp) return !m_imp->m_end;
|
||||
return m_imp->m_end != other.m_imp->m_end ||
|
||||
m_imp->m_current != other.m_imp->m_current;
|
||||
}
|
||||
|
||||
// -- terms implementation --
|
||||
|
||||
term_enumeration::terms::terms(imp* i, sort* s) : m_imp(i), m_sort(s) {}
|
||||
|
||||
term_enumeration::iterator term_enumeration::terms::begin() {
|
||||
return iterator(*m_imp, m_sort);
|
||||
}
|
||||
|
||||
term_enumeration::iterator term_enumeration::terms::end() {
|
||||
return iterator(nullptr);
|
||||
}
|
||||
|
||||
// -- term_enumeration implementation --
|
||||
|
||||
term_enumeration::term_enumeration(ast_manager& m) {
|
||||
m_imp = alloc(imp, m);
|
||||
}
|
||||
|
||||
term_enumeration::~term_enumeration() {
|
||||
dealloc(m_imp);
|
||||
}
|
||||
|
||||
void term_enumeration::add_production(func_decl* f) {
|
||||
m_imp->add_production(f);
|
||||
}
|
||||
|
||||
void term_enumeration::add_production(expr* e) {
|
||||
m_imp->add_production(e);
|
||||
}
|
||||
|
||||
void term_enumeration::set_cost(std::function<unsigned(expr*)> const& cost) {
|
||||
m_imp->set_cost(cost);
|
||||
}
|
||||
|
||||
term_enumeration::terms term_enumeration::enum_terms(sort* s) {
|
||||
return terms(m_imp, s);
|
||||
}
|
||||
|
|
|
|||
46
src/ast/term_enumeration.h
Normal file
46
src/ast/term_enumeration.h
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
#pragma once
|
||||
|
||||
#include "ast/ast.h"
|
||||
#include <functional>
|
||||
|
||||
class term_enumeration {
|
||||
struct imp;
|
||||
imp* m_imp;
|
||||
public:
|
||||
term_enumeration(ast_manager& m);
|
||||
~term_enumeration();
|
||||
|
||||
void add_production(func_decl* f);
|
||||
void add_production(expr* e);
|
||||
|
||||
// cost function associated with expressions.
|
||||
// terms are enumerated with increasing cost.
|
||||
|
||||
void set_cost(std::function<unsigned(expr*)> const& cost);
|
||||
|
||||
class iterator {
|
||||
struct iter_imp;
|
||||
iter_imp* m_imp;
|
||||
public:
|
||||
iterator(imp& i, sort* s);
|
||||
iterator(std::nullptr_t);
|
||||
iterator(iterator const& other);
|
||||
iterator& operator=(iterator const& other);
|
||||
~iterator();
|
||||
expr* operator*();
|
||||
iterator operator++(int);
|
||||
iterator& operator++();
|
||||
bool operator!=(iterator const& other) const;
|
||||
};
|
||||
|
||||
class terms {
|
||||
imp* m_imp;
|
||||
sort* m_sort;
|
||||
public:
|
||||
terms(imp* i, sort* s);
|
||||
iterator begin();
|
||||
iterator end();
|
||||
};
|
||||
|
||||
terms enum_terms(sort* s);
|
||||
};
|
||||
Loading…
Add table
Add a link
Reference in a new issue