3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2026-06-21 08:00:27 +00:00

updated term enumerator

This commit is contained in:
Nikolaj Bjorner 2026-06-19 16:18:07 -07:00
parent c0888b9ecd
commit 04ddb66931
8 changed files with 499 additions and 332 deletions

View file

@ -22,6 +22,7 @@
#include "util/obj_hashtable.h"
#include "ast/ast.h"
#include "ast/ast_ll_pp.h"
#include "ast/ast_pp.h"
#include "ast/arith_decl_plugin.h"
#include "ast/bv_decl_plugin.h"
#include "ast/seq_decl_plugin.h"
@ -60,7 +61,7 @@ struct production {
*/
class grammar {
public:
grammar(ast_manager& m) : m(m) {}
grammar(ast_manager& m) : m(m), m_pinned(m) {}
void add_production(production p) {
if (p.is_leaf())
@ -73,46 +74,8 @@ public:
// Read-only view of the stored operator productions (m_operators).
vector<production> const& operators() const { return m_operators; }
// The ast_manager this grammar was constructed with.
ast_manager& mgr() const { return m; }
void add_variable(char const* name, sort* s) {
expr_ref var(m.mk_const(name, s), m);
sort_ref sr(s, m);
sort_ref_vector dom(m);
add_production({name, sr, dom, [var](expr_ref_vector const&) { return var; }});
}
void add_int_const(int val) {
arith_util a(m);
expr_ref e(a.mk_int(val), m);
sort_ref sr(a.mk_int(), m);
sort_ref_vector dom(m);
add_production({std::to_string(val), sr, dom, [e](expr_ref_vector const&) { return e; }});
}
void add_bv_const(int val, unsigned bits) {
bv_util bv(m);
expr_ref e(bv.mk_numeral(rational(val), bits), m);
sort_ref sr(bv.mk_sort(bits), m);
sort_ref_vector dom(m);
add_production({std::to_string(val), sr, dom, [e](expr_ref_vector const&) { return e; }});
}
void add_string_const(std::string const& val) {
seq_util seq(m);
expr_ref e(seq.str.mk_string(zstring(val.c_str())), m);
sort_ref sr(seq.str.mk_string_sort(), m);
sort_ref_vector dom(m);
add_production({"\"" + val + "\"", sr, dom, [e](expr_ref_vector const&) { return e; }});
}
void add_bool_const(bool val) {
expr_ref e(val ? m.mk_true() : m.mk_false(), m);
sort_ref sr(m.mk_bool_sort(), m);
sort_ref_vector dom(m);
std::string n = val ? "true" : "false";
add_production({n, sr, dom, [e](expr_ref_vector const&) { return e; }});
}
void add_func_decl(func_decl *f) {
m_pinned.push_back(f);
sort_ref range(f->get_range(), m);
sort_ref_vector dom(m);
for (unsigned i = 0; i < f->get_arity(); ++i)
@ -123,6 +86,7 @@ public:
}
void add_expr(expr *e) {
m_pinned.push_back(e);
sort_ref range(e->get_sort(), m);
sort_ref_vector dom(m);
std::stringstream ss;
@ -131,72 +95,31 @@ public:
add_production({name, range, dom, [this, e](expr_ref_vector const&) { return expr_ref(e, m); }});
}
// Writes the grammar to `out`: first every leaf as "name : range", then
// every operator as "name : (dom1, dom2, ...) -> range". Returns `out`.
std::ostream& display(std::ostream& out) const {
    out << "Leaves:\n";
    for (auto const& leaf : m_leaves)
        out << " " << leaf.name << " : " << mk_pp(leaf.range, m) << "\n";

    out << "Operators:\n";
    for (auto const& op : m_operators) {
        out << " " << op.name << " : (";
        // Empty before the first domain sort, ", " before every later one.
        char const* separator = "";
        for (unsigned idx = 0; idx < op.domain.size(); ++idx) {
            out << separator;
            separator = ", ";
            out << mk_pp(op.domain[idx], m);
        }
        out << ") -> " << mk_pp(op.range, m) << "\n";
    }
    return out;
}
private:
ast_manager& m;
ast_ref_vector m_pinned;
vector<production> m_leaves;
vector<production> m_operators;
};
// ============================================================================
// Observational Equivalence Manager
// ============================================================================
/**
* Observational-equivalence filter for candidate terms.
*
* Each term is evaluated in every stored sample model and the hashes of the
* resulting values are folded into a single 64-bit fingerprint. A term is
* reported as a representative only the first time its fingerprint is seen,
* so at most one term per fingerprint is kept (the one encountered first).
*
* Uses Z3's model evaluation to reduce terms to concrete values.
*
* NOTE(review): equivalence is decided by fingerprint equality only, so two
* terms with different value vectors but equal fingerprints would be
* conflated - confirm this is acceptable.
*/
class oe_manager {
public:
oe_manager(ast_manager& m) : m(m) {}
// Replaces the whole sample set and forgets every fingerprint seen so far.
void set_samples(vector<model_ref> samples) {
m_samples = std::move(samples);
m_seen.clear();
}
// Appends one sample and forgets every fingerprint seen so far.
void add_sample(model_ref mdl) {
m_samples.push_back(std::move(mdl));
m_seen.clear();
}
/**
* Returns true if `term` is a new representative (its fingerprint has
* not been seen before), and records the fingerprint in that case.
*
* A fingerprint of 0 is always rejected. compute_fingerprint returns 0
* when there are no samples, so with an empty sample set every term is
* rejected.
* NOTE(review): confirm that rejecting everything without samples is intended.
*/
bool is_representative(expr* term) {
auto fingerprint = compute_fingerprint(term);
if (fingerprint == 0)
return false;
return m_seen.insert(fingerprint).second;
}
// Forgets the recorded fingerprints; the samples are kept.
void clear() { m_seen.clear(); }
// Number of distinct fingerprints recorded so far.
size_t num_classes() const { return m_seen.size(); }
// Number of stored sample models.
unsigned num_samples() const { return m_samples.size(); }
private:
ast_manager& m;
vector<model_ref> m_samples;
std::unordered_set<uint64_t> m_seen;
// Folds the hash of `term`'s value under each sample into one number.
// A fresh model_evaluator with model completion enabled is built per sample.
// `mix` is defined outside this view - presumably Z3's hash mixing helper,
// which updates a, b and c in place; TODO confirm it is meant for 64-bit operands.
uint64_t compute_fingerprint(expr* term) {
uint64_t a = 1, b = 2, c = 0;
for (auto& mdl : m_samples) {
expr_ref val(m);
model_evaluator eval(*mdl);
eval.set_model_completion(true);
// A sample on which evaluation fails contributes nothing to the fingerprint.
// NOTE(review): such a term can then share a fingerprint with unrelated terms.
if (!eval.eval(term, val, true))
continue;
a *= val->hash();
mix(a, b, c);
}
return c;
}
};
// ============================================================================
// Term Bank - stores enumerated terms by cost and sort
// ============================================================================
@ -249,6 +172,36 @@ public:
return m_terms[cost]->find(s);
}
// Dumps the bank to `out`, grouped by cost and then by sort, and returns `out`.
// Every term is printed as "#<id>" followed by:
//   - at cost 0: the term itself via mk_bounded_pp;
//   - at higher cost, for applications: "<decl name>(#<arg id>, ...)", i.e.
//     arguments are shown by AST id only rather than being re-printed.
// A term at cost > 0 that is not an application prints only its id.
std::ostream& display(std::ostream& out) const {
for (unsigned cost = 0; cost < m_terms.size(); ++cost) {
// Cost levels without an allocated table are skipped.
if (!m_terms[cost])
continue;
out << "cost " << cost << ":\n";
for (auto& [s, terms] : *m_terms[cost]) {
out << " sort " << mk_pp(s, m) << ":\n";
for (expr* e : terms) {
out << " #" << e->get_id() << " ";
if (cost == 0) {
out << mk_bounded_pp(e, m);
}
else if (is_app(e)) {
app* a = to_app(e);
out << a->get_decl()->get_name() << "(";
// `first` suppresses the separator before the first argument.
bool first = true;
for (expr* arg : *a) {
if (!first) out << ", ";
first = false;
out << "#" << arg->get_id();
}
out << ")";
}
out << "\n";
}
}
}
return out;
}
private:
ast_manager& m;
expr_ref_vector m_pinned;
@ -338,33 +291,17 @@ private:
// bottom_up_enumerator - the main bottom-up term enumeration engine
// ============================================================================
/**
* Enumerates terms bottom-up by cost, applying observational equivalence
* pruning. Users iterate via has_next() / next().
*
* Usage:
* ast_manager m;
* grammar g(m);
* // ... add productions ...
* oe_manager oe(m);
* // ... set samples ...
* bottom_up_enumerator en(g, oe);
* arith_util a(m);
* en.set_target_sort(a.mk_int());
* while (en.has_next()) {
* expr_ref term = en.next();
* // ... check if term satisfies specification ...
* }
*/
class bottom_up_enumerator {
public:
bottom_up_enumerator(grammar& grammar, oe_manager& oe)
: m_grammar(grammar), m(grammar.mgr()), m_oe(oe),
bottom_up_enumerator(grammar& grammar)
: m_grammar(grammar), m(grammar.mgr()),
m_bank(grammar.mgr()), m_pending(grammar.mgr())
{}
void set_target_sort(sort* s) { m_target_sort = s; }
void set_target_sort(sort *s) {
m_target_sort = s;
}
bool has_next() {
if (m_pending) return true;
m_pending = find_next();
@ -377,17 +314,21 @@ public:
expr_ref result(m_pending, m);
m_pending = nullptr;
return result;
}
}
term_bank const& bank() const { return m_bank; }
// Prints the grammar followed by the term bank to `out`; returns `out`.
std::ostream& display(std::ostream& out) const {
m_grammar.display(out);
return m_bank.display(out);
}
// Rewinds the enumerator so enumeration starts over from the leaves:
// cost and both production cursors go back to 0, the state becomes Leaves,
// the term bank and the equivalence manager are cleared, and any pending
// term and in-progress children iterator are dropped.
// m_target_sort is left untouched.
// NOTE(review): m_bank_idx, m_bank_size and m_has_range are not reset here;
// they appear to be re-initialised when the Leaves state finishes - confirm.
void reset() {
m_cost = 0;
m_leaf_idx = 0;
m_op_idx = 0;
m_state = State::Leaves;
m_bank.reset();
m_oe.clear();
m_pending = nullptr;
m_children_iter.reset();
}
@ -397,17 +338,17 @@ private:
grammar& m_grammar;
ast_manager& m;
oe_manager& m_oe;
term_bank m_bank;
unsigned m_cost = 0;
unsigned m_leaf_idx = 0;
unsigned m_op_idx = 0;
unsigned m_bank_idx = 0;
unsigned m_bank_size = 0;
bool m_has_range = false;
State m_state = State::Leaves;
sort* m_target_sort = nullptr;
expr_ref m_pending;
std::unique_ptr<children_iterator> m_children_iter;
sort *m_target_sort = nullptr;
bool sort_matches(expr* e) const {
return !m_target_sort || e->get_sort() == m_target_sort;
@ -418,21 +359,20 @@ private:
switch (m_state) {
case State::Leaves:
while (m_leaf_idx < m_grammar.leaves().size()) {
production const& prod = m_grammar.leaves()[m_leaf_idx];
production const &prod = m_grammar.leaves()[m_leaf_idx];
m_leaf_idx++;
expr_ref_vector empty_args(m);
expr_ref term = prod.builder(empty_args);
if (m_oe.is_representative(term)) {
m_bank.add(term, 0);
if (sort_matches(term))
return term;
}
m_bank.add(term, 0);
if (sort_matches(term))
return term;
}
m_state = State::Operators;
m_cost = 1;
m_op_idx = 0;
m_bank_idx = 0;
m_bank_size = get_bank_size();
m_bank_size = get_bank_size();
m_has_range = false;
m_children_iter.reset();
break;
@ -440,6 +380,7 @@ private:
expr* result = enumerate_operators();
if (result)
return result;
m_cost++;
m_op_idx = 0;
m_bank_idx = 0;
@ -477,16 +418,25 @@ private:
expr_ref_vector children = m_children_iter->next(new_cost);
production const &prod = ops[m_op_idx - 1];
expr_ref term = prod.builder(children);
// IF_VERBOSE(0, verbose_stream() << term << "\n");
SASSERT(new_cost >= m_cost);
if (m_oe.is_representative(term)) {
m_bank.add(term, new_cost);
if (sort_matches(term) && new_cost == m_cost)
return term;
}
m_bank.add(term, new_cost);
if (sort_matches(term) && new_cost == m_cost) {
m_has_range = true;
return term;
}
continue;
}
if (m_op_idx >= ops.size())
if (ops.empty()) {
m_state = State::Done;
return nullptr;
}
if (m_op_idx >= ops.size()) {
if (!m_has_range)
m_state = State::Done;
return nullptr;
}
production const &prod = ops[m_op_idx];
m_op_idx++;
m_children_iter = std::make_unique<children_iterator>(m, prod, m_bank, m_cost);
@ -503,12 +453,11 @@ private:
struct term_enumeration::imp {
ast_manager& m;
term_enum::grammar m_grammar;
term_enum::oe_manager m_oe;
term_enum::bottom_up_enumerator m_bottom_up_enumerator;
std::function<unsigned(expr*)> m_cost;
imp(ast_manager& m) :
m(m), m_grammar(m), m_oe(m), m_bottom_up_enumerator(m_grammar, m_oe) {}
m(m), m_grammar(m), m_bottom_up_enumerator(m_grammar) {}
void add_production(func_decl* f) {
m_grammar.add_func_decl(f);
@ -519,71 +468,11 @@ struct term_enumeration::imp {
}
void set_cost(std::function<unsigned(expr*)> const& cost) {
m_cost = cost;
// TODO
}
// Enumerate terms of given sort up to a cost, ordered by cost.
// Returns the next term in cost order, or nullptr if exhausted at current cost.
//
// `cost_state` / `idx_state` form an in/out cursor: the current cost level and
// the index of the next term inside that level. `levels` is the caller-owned
// per-cost cache of terms; it is grown through expand_level as needed.
//
// NOTE(review): when the cursor is at or past the end of the current level the
// loop bumps cost_state *before* calling expand_level, so this function never
// expands the level it starts on (e.g. level 0 on a first call) - presumably
// the caller populates it; confirm.
// NOTE(review): the loop only stops once a non-empty level is found and no
// upper bound on cost_state is visible here; the nullptr return below looks
// unreachable once the loop has exited - confirm.
expr* next_term(sort* s, unsigned& cost_state, unsigned& idx_state,
vector<expr_ref_vector>& levels) {
// Expand levels as needed
while (idx_state >= level_size(levels, cost_state)) {
cost_state++;
expand_level(s, cost_state, levels);
idx_state = 0;
if (level_size(levels, cost_state) > 0)
break;
}
if (cost_state >= levels.size() || idx_state >= levels[cost_state].size())
return nullptr;
// Hand out the term and advance the cursor within the level.
return levels[cost_state].get(idx_state++);
}
private:
// Number of terms stored at `cost`; a cost level that does not exist yet
// counts as empty.
unsigned level_size(vector<expr_ref_vector> const& levels, unsigned cost) const {
    bool const level_exists = cost < levels.size();
    return level_exists ? levels[cost].size() : 0;
}
void expand_level(sort* s, unsigned cost, vector<expr_ref_vector>& levels) {
if (cost >= levels.size())
levels.resize(cost + 1, expr_ref_vector(m));
// Collect terms at this cost
if (cost == 0) {
// Leaves
for (auto const& prod : m_grammar.leaves()) {
if (prod.range.get() != s) continue;
expr_ref_vector empty_args(m);
expr_ref term = prod.builder(empty_args);
if (m_oe.is_representative(term)) {
m_bottom_up_enumerator.bank(); // just to ensure bank is populated
levels[0].push_back(term);
}
}
}
else {
// operators
for (auto const& prod : m_grammar.operators()) {
if (prod.range.get() != s) continue;
term_enum::children_iterator iter(m, prod, m_bottom_up_enumerator.bank(), cost);
while (iter.has_next(cost)) {
unsigned new_cost = 0;
expr_ref_vector children = iter.next(new_cost);
expr_ref term = prod.builder(children);
levels.reserve(new_cost + 1, expr_ref_vector(m));
if (m_oe.is_representative(term))
levels[new_cost].push_back(term);
}
}
}
// Sort by cost if cost function is set
if (m_cost && !levels[cost].empty()) {
expr_ref_vector& lv = levels[cost];
std::sort(lv.data(), lv.data() + lv.size(),
[&](expr* a, expr* b) { return m_cost(a) < m_cost(b); });
}
std::ostream& display(std::ostream& out) const {
return m_bottom_up_enumerator.display(out);
}
};
@ -591,84 +480,85 @@ private:
struct term_enumeration::iterator::iter_imp {
imp& m_imp;
ast_manager & m;
sort* m_sort;
unsigned m_cost = 0;
unsigned m_idx = 0;
vector<expr_ref_vector> m_levels;
expr* m_current = nullptr;
expr_ref m_current;
bool m_end;
vector<expr_ref_vector> m_vars;
vector<ptr_vector<sort>> m_decls;
vector<vector<symbol>> m_names;
iter_imp(imp& i, sort* s) : m_imp(i), m_sort(s), m_end(false) {
expand_current_level();
advance_to_valid();
iter_imp(imp& i, sort* s) : m_imp(i), m(i.m), m_sort(s), m_current(i.m), m_end(false) {
m_imp.m_bottom_up_enumerator.reset();
init_sort();
advance();
}
// Sentinel constructor
iter_imp(imp& i) :
m_imp(i), m_sort(nullptr), m_end(true) {}
void expand_current_level() {
if (m_cost >= m_levels.size())
m_levels.resize(m_cost + 1, expr_ref_vector(m_imp.m));
if (!m_levels[m_cost].empty())
return;
if (m_cost == 0) {
for (auto const& prod : m_imp.m_grammar.leaves()) {
if (prod.range.get() != m_sort) continue;
expr_ref_vector empty_args(m_imp.m);
expr_ref term = prod.builder(empty_args);
if (m_imp.m_oe.is_representative(term))
m_levels[0].push_back(term);
}
}
else {
for (auto const& prod : m_imp.m_grammar.operators()) {
if (prod.range.get() != m_sort) continue;
term_enum::children_iterator iter(m_imp.m, prod, m_imp.m_bottom_up_enumerator.bank(), m_cost);
while (iter.has_next(m_cost)) {
unsigned new_cost = 0;
expr_ref_vector children = iter.next(new_cost);
expr_ref term = prod.builder(children);
m_levels.reserve(new_cost + 1, expr_ref_vector(m_imp.m));
if (m_imp.m_oe.is_representative(term))
m_levels[new_cost].push_back(term);
}
}
}
// Sort by cost if cost function is set
if (m_imp.m_cost && !m_levels[m_cost].empty()) {
expr_ref_vector& lv = m_levels[m_cost];
std::sort(lv.data(), lv.data() + lv.size(),
[&](expr* a, expr* b) { return m_imp.m_cost(a) < m_imp.m_cost(b); });
}
m_imp(i), m(i.m), m_sort(nullptr), m_current(i.m), m_end(true) {
UNREACHABLE();
}
void advance_to_valid() {
while (true) {
if (m_cost >= m_levels.size())
expand_current_level();
if (m_idx < m_levels[m_cost].size()) {
m_current = m_levels[m_cost].get(m_idx);
return;
void init_sort() {
array_util autil(m);
sort *range = m_sort;
while (autil.is_array(range)) {
m_vars.push_back(expr_ref_vector(m));
m_decls.push_back(ptr_vector<sort>());
m_names.push_back(vector<symbol>());
for (unsigned i = 0; i < get_array_arity(range); ++i) {
m_decls.back().push_back(get_array_domain(range, i));
m_vars.back().push_back(nullptr);
m_names.back().push_back(symbol());
}
m_cost++;
m_idx = 0;
if (m_cost > 100) {
m_end = true;
m_current = nullptr;
return;
#if 0
// TODO: don't enable this until we ensure only generating whnf (beta-redex free) expressions.
expr_ref_vector args(m);
args.push_back(m.mk_const("a", range));
for (unsigned i = 0; i < m_decls.back().size(); ++i) {
args.push_back(m.mk_var(i, m_decls.back().get(i)));
}
expand_current_level();
app_ref sel(autil.mk_select(args), m);
m_imp.m_grammar.add_func_decl(sel->get_decl());
#endif
range = get_array_range(range);
}
unsigned n = 0;
for (unsigned i = m_decls.size(); i-- > 0;) {
for (unsigned j = m_decls[i].size(); j-- > 0;) {
m_vars[i][j] = m.mk_var(n, m_decls[i][j]);
m_names[i][j] = symbol(n);
m_imp.add_production(m_vars[i].get(j));
n++;
}
}
m_sort = range;
m_imp.m_bottom_up_enumerator.set_target_sort(range);
}
// Wraps m_current in one lambda per entry of m_decls, using that entry's
// sorts and the matching names from m_names. Entries are applied from the
// last one to the first, so m_decls[0] ends up as the outermost binder.
// Does nothing when there is no current term.
void mk_lambda() {
if (!m_current)
return;
for (unsigned i = m_decls.size(); i-- > 0;)
m_current = m.mk_lambda(m_decls[i].size(), m_decls[i].data(), m_names[i].data(), m_current);
}
void advance() {
if (m_end) return;
m_idx++;
advance_to_valid();
if (m_end)
return;
m_current = m_imp.m_bottom_up_enumerator.next();
SASSERT(!m_current || m_current->get_sort() == m_sort);
mk_lambda();
if (!m_current)
m_end = true;
}
};
@ -680,22 +570,6 @@ term_enumeration::iterator::iterator(std::nullptr_t) {
m_imp = nullptr;
}
// Copy constructor: duplicates the other iterator's iter_imp when it has one,
// otherwise leaves this iterator without an implementation object.
term_enumeration::iterator::iterator(iterator const& other) {
    m_imp = other.m_imp ? alloc(iter_imp, *other.m_imp) : nullptr;
}
// Copy assignment: releases the current iter_imp and replaces it with a copy
// of the other iterator's iter_imp (or nullptr when the other has none).
// Self-assignment is a no-op.
term_enumeration::iterator& term_enumeration::iterator::operator=(iterator const& other) {
if (this != &other) {
// The old state is released before the copy is made.
dealloc(m_imp);
m_imp = nullptr;
if (other.m_imp)
m_imp = alloc(iter_imp, *other.m_imp);
}
return *this;
}
// Releases the iter_imp held in m_imp.
// NOTE(review): m_imp can be nullptr (see the nullptr_t constructor);
// presumably dealloc tolerates a null pointer - confirm.
term_enumeration::iterator::~iterator() {
dealloc(m_imp);
}
@ -715,12 +589,12 @@ term_enumeration::iterator term_enumeration::iterator::operator++(int) {
return tmp;
}
bool term_enumeration::iterator::operator!=(iterator const& other) const {
if (!m_imp && !other.m_imp) return false;
if (!m_imp) return !other.m_imp->m_end;
if (!other.m_imp) return !m_imp->m_end;
return m_imp->m_end != other.m_imp->m_end ||
m_imp->m_current != other.m_imp->m_current;
bool term_enumeration::iterator::operator==(iterator const& other) const {
if (!m_imp && !other.m_imp) return true;
if (!m_imp) return other.m_imp->m_end;
if (!other.m_imp) return m_imp->m_end;
return m_imp->m_end == other.m_imp->m_end &&
m_imp->m_current == other.m_imp->m_current;
}
// -- terms implementation --
@ -760,3 +634,7 @@ void term_enumeration::set_cost(std::function<unsigned(expr*)> const& cost) {
// Returns a `terms` object built from this enumeration's implementation
// object and the requested sort `s`.
term_enumeration::terms term_enumeration::enum_terms(sort* s) {
return terms(m_imp, s);
}
// Forwards to the implementation object's display and returns its result.
std::ostream& term_enumeration::display(std::ostream& out) const {
return m_imp->display(out);
}