3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2026-06-21 08:00:27 +00:00

updated term enumerator

This commit is contained in:
Nikolaj Bjorner 2026-06-19 16:18:07 -07:00
parent c0888b9ecd
commit 04ddb66931
8 changed files with 499 additions and 332 deletions

View file

@ -22,6 +22,7 @@
#include "util/obj_hashtable.h"
#include "ast/ast.h"
#include "ast/ast_ll_pp.h"
#include "ast/ast_pp.h"
#include "ast/arith_decl_plugin.h"
#include "ast/bv_decl_plugin.h"
#include "ast/seq_decl_plugin.h"
@ -60,7 +61,7 @@ struct production {
*/
class grammar {
public:
grammar(ast_manager& m) : m(m) {}
grammar(ast_manager& m) : m(m), m_pinned(m) {}
void add_production(production p) {
if (p.is_leaf())
@ -73,46 +74,8 @@ public:
vector<production> const& operators() const { return m_operators; }
ast_manager& mgr() const { return m; }
void add_variable(char const* name, sort* s) {
expr_ref var(m.mk_const(name, s), m);
sort_ref sr(s, m);
sort_ref_vector dom(m);
add_production({name, sr, dom, [var](expr_ref_vector const&) { return var; }});
}
void add_int_const(int val) {
arith_util a(m);
expr_ref e(a.mk_int(val), m);
sort_ref sr(a.mk_int(), m);
sort_ref_vector dom(m);
add_production({std::to_string(val), sr, dom, [e](expr_ref_vector const&) { return e; }});
}
void add_bv_const(int val, unsigned bits) {
bv_util bv(m);
expr_ref e(bv.mk_numeral(rational(val), bits), m);
sort_ref sr(bv.mk_sort(bits), m);
sort_ref_vector dom(m);
add_production({std::to_string(val), sr, dom, [e](expr_ref_vector const&) { return e; }});
}
void add_string_const(std::string const& val) {
seq_util seq(m);
expr_ref e(seq.str.mk_string(zstring(val.c_str())), m);
sort_ref sr(seq.str.mk_string_sort(), m);
sort_ref_vector dom(m);
add_production({"\"" + val + "\"", sr, dom, [e](expr_ref_vector const&) { return e; }});
}
void add_bool_const(bool val) {
expr_ref e(val ? m.mk_true() : m.mk_false(), m);
sort_ref sr(m.mk_bool_sort(), m);
sort_ref_vector dom(m);
std::string n = val ? "true" : "false";
add_production({n, sr, dom, [e](expr_ref_vector const&) { return e; }});
}
void add_func_decl(func_decl *f) {
m_pinned.push_back(f);
sort_ref range(f->get_range(), m);
sort_ref_vector dom(m);
for (unsigned i = 0; i < f->get_arity(); ++i)
@ -123,6 +86,7 @@ public:
}
void add_expr(expr *e) {
m_pinned.push_back(e);
sort_ref range(e->get_sort(), m);
sort_ref_vector dom(m);
std::stringstream ss;
@ -131,72 +95,31 @@ public:
add_production({name, range, dom, [this, e](expr_ref_vector const&) { return expr_ref(e, m); }});
}
std::ostream& display(std::ostream& out) const {
out << "Leaves:\n";
for (auto const &p : m_leaves) {
out << " " << p.name << " : " << mk_pp(p.range, m) << "\n";
}
out << "Operators:\n";
for (auto const &p : m_operators) {
out << " " << p.name << " : (";
for (unsigned i = 0; i < p.domain.size(); ++i) {
if (i > 0)
out << ", ";
out << mk_pp(p.domain[i], m);
}
out << ") -> " << mk_pp(p.range, m) << "\n";
}
return out;
}
private:
ast_manager& m;
ast_ref_vector m_pinned;
vector<production> m_leaves;
vector<production> m_operators;
};
// ============================================================================
// Observational Equivalence Manager
// ============================================================================
/**
* Evaluates candidate terms on a set of sample inputs and keeps only one
* representative per equivalence class (the one encountered first).
*
* Uses Z3's model evaluation to reduce terms to concrete values.
*/
class oe_manager {
public:
oe_manager(ast_manager& m) : m(m) {}
void set_samples(vector<model_ref> samples) {
m_samples = std::move(samples);
m_seen.clear();
}
void add_sample(model_ref mdl) {
m_samples.push_back(std::move(mdl));
m_seen.clear();
}
/**
* Returns true if `term` is a new representative (its fingerprint has
* not been seen before).
*/
bool is_representative(expr* term) {
auto fingerprint = compute_fingerprint(term);
if (fingerprint == 0)
return false;
return m_seen.insert(fingerprint).second;
}
void clear() { m_seen.clear(); }
size_t num_classes() const { return m_seen.size(); }
unsigned num_samples() const { return m_samples.size(); }
private:
ast_manager& m;
vector<model_ref> m_samples;
std::unordered_set<uint64_t> m_seen;
uint64_t compute_fingerprint(expr* term) {
uint64_t a = 1, b = 2, c = 0;
for (auto& mdl : m_samples) {
expr_ref val(m);
model_evaluator eval(*mdl);
eval.set_model_completion(true);
if (!eval.eval(term, val, true))
continue;
a *= val->hash();
mix(a, b, c);
}
return c;
}
};
// ============================================================================
// Term Bank - stores enumerated terms by cost and sort
// ============================================================================
@ -249,6 +172,36 @@ public:
return m_terms[cost]->find(s);
}
std::ostream& display(std::ostream& out) const {
for (unsigned cost = 0; cost < m_terms.size(); ++cost) {
if (!m_terms[cost])
continue;
out << "cost " << cost << ":\n";
for (auto& [s, terms] : *m_terms[cost]) {
out << " sort " << mk_pp(s, m) << ":\n";
for (expr* e : terms) {
out << " #" << e->get_id() << " ";
if (cost == 0) {
out << mk_bounded_pp(e, m);
}
else if (is_app(e)) {
app* a = to_app(e);
out << a->get_decl()->get_name() << "(";
bool first = true;
for (expr* arg : *a) {
if (!first) out << ", ";
first = false;
out << "#" << arg->get_id();
}
out << ")";
}
out << "\n";
}
}
}
return out;
}
private:
ast_manager& m;
expr_ref_vector m_pinned;
@ -338,33 +291,17 @@ private:
// bottom_up_enumerator - the main bottom-up term enumeration engine
// ============================================================================
/**
* Enumerates terms bottom-up by cost, applying observational equivalence
* pruning. Users iterate via has_next() / next().
*
* Usage:
* ast_manager m;
* grammar g(m);
* // ... add productions ...
* oe_manager oe(m);
* // ... set samples ...
* bottom_up_enumerator en(g, oe);
* arith_util a(m);
* en.set_target_sort(a.mk_int());
* while (en.has_next()) {
* expr_ref term = en.next();
* // ... check if term satisfies specification ...
* }
*/
class bottom_up_enumerator {
public:
bottom_up_enumerator(grammar& grammar, oe_manager& oe)
: m_grammar(grammar), m(grammar.mgr()), m_oe(oe),
bottom_up_enumerator(grammar& grammar)
: m_grammar(grammar), m(grammar.mgr()),
m_bank(grammar.mgr()), m_pending(grammar.mgr())
{}
void set_target_sort(sort* s) { m_target_sort = s; }
void set_target_sort(sort *s) {
m_target_sort = s;
}
bool has_next() {
if (m_pending) return true;
m_pending = find_next();
@ -377,17 +314,21 @@ public:
expr_ref result(m_pending, m);
m_pending = nullptr;
return result;
}
}
term_bank const& bank() const { return m_bank; }
std::ostream& display(std::ostream& out) const {
m_grammar.display(out);
return m_bank.display(out);
}
void reset() {
m_cost = 0;
m_leaf_idx = 0;
m_op_idx = 0;
m_state = State::Leaves;
m_bank.reset();
m_oe.clear();
m_pending = nullptr;
m_children_iter.reset();
}
@ -397,17 +338,17 @@ private:
grammar& m_grammar;
ast_manager& m;
oe_manager& m_oe;
term_bank m_bank;
unsigned m_cost = 0;
unsigned m_leaf_idx = 0;
unsigned m_op_idx = 0;
unsigned m_bank_idx = 0;
unsigned m_bank_size = 0;
bool m_has_range = false;
State m_state = State::Leaves;
sort* m_target_sort = nullptr;
expr_ref m_pending;
std::unique_ptr<children_iterator> m_children_iter;
sort *m_target_sort = nullptr;
bool sort_matches(expr* e) const {
return !m_target_sort || e->get_sort() == m_target_sort;
@ -418,21 +359,20 @@ private:
switch (m_state) {
case State::Leaves:
while (m_leaf_idx < m_grammar.leaves().size()) {
production const& prod = m_grammar.leaves()[m_leaf_idx];
production const &prod = m_grammar.leaves()[m_leaf_idx];
m_leaf_idx++;
expr_ref_vector empty_args(m);
expr_ref term = prod.builder(empty_args);
if (m_oe.is_representative(term)) {
m_bank.add(term, 0);
if (sort_matches(term))
return term;
}
m_bank.add(term, 0);
if (sort_matches(term))
return term;
}
m_state = State::Operators;
m_cost = 1;
m_op_idx = 0;
m_bank_idx = 0;
m_bank_size = get_bank_size();
m_bank_size = get_bank_size();
m_has_range = false;
m_children_iter.reset();
break;
@ -440,6 +380,7 @@ private:
expr* result = enumerate_operators();
if (result)
return result;
m_cost++;
m_op_idx = 0;
m_bank_idx = 0;
@ -477,16 +418,25 @@ private:
expr_ref_vector children = m_children_iter->next(new_cost);
production const &prod = ops[m_op_idx - 1];
expr_ref term = prod.builder(children);
// IF_VERBOSE(0, verbose_stream() << term << "\n");
SASSERT(new_cost >= m_cost);
if (m_oe.is_representative(term)) {
m_bank.add(term, new_cost);
if (sort_matches(term) && new_cost == m_cost)
return term;
}
m_bank.add(term, new_cost);
if (sort_matches(term) && new_cost == m_cost) {
m_has_range = true;
return term;
}
continue;
}
if (m_op_idx >= ops.size())
if (ops.empty()) {
m_state = State::Done;
return nullptr;
}
if (m_op_idx >= ops.size()) {
if (!m_has_range)
m_state = State::Done;
return nullptr;
}
production const &prod = ops[m_op_idx];
m_op_idx++;
m_children_iter = std::make_unique<children_iterator>(m, prod, m_bank, m_cost);
@ -503,12 +453,11 @@ private:
struct term_enumeration::imp {
ast_manager& m;
term_enum::grammar m_grammar;
term_enum::oe_manager m_oe;
term_enum::bottom_up_enumerator m_bottom_up_enumerator;
std::function<unsigned(expr*)> m_cost;
imp(ast_manager& m) :
m(m), m_grammar(m), m_oe(m), m_bottom_up_enumerator(m_grammar, m_oe) {}
m(m), m_grammar(m), m_bottom_up_enumerator(m_grammar) {}
void add_production(func_decl* f) {
m_grammar.add_func_decl(f);
@ -519,71 +468,11 @@ struct term_enumeration::imp {
}
void set_cost(std::function<unsigned(expr*)> const& cost) {
m_cost = cost;
// TODO
}
// Enumerate terms of given sort up to a cost, ordered by cost.
// Returns the next term in cost order, or nullptr if exhausted at current cost.
expr* next_term(sort* s, unsigned& cost_state, unsigned& idx_state,
vector<expr_ref_vector>& levels) {
// Expand levels as needed
while (idx_state >= level_size(levels, cost_state)) {
cost_state++;
expand_level(s, cost_state, levels);
idx_state = 0;
if (level_size(levels, cost_state) > 0)
break;
}
if (cost_state >= levels.size() || idx_state >= levels[cost_state].size())
return nullptr;
return levels[cost_state].get(idx_state++);
}
private:
unsigned level_size(vector<expr_ref_vector> const& levels, unsigned cost) const {
if (cost >= levels.size()) return 0;
return levels[cost].size();
}
void expand_level(sort* s, unsigned cost, vector<expr_ref_vector>& levels) {
if (cost >= levels.size())
levels.resize(cost + 1, expr_ref_vector(m));
// Collect terms at this cost
if (cost == 0) {
// Leaves
for (auto const& prod : m_grammar.leaves()) {
if (prod.range.get() != s) continue;
expr_ref_vector empty_args(m);
expr_ref term = prod.builder(empty_args);
if (m_oe.is_representative(term)) {
m_bottom_up_enumerator.bank(); // just to ensure bank is populated
levels[0].push_back(term);
}
}
}
else {
// operators
for (auto const& prod : m_grammar.operators()) {
if (prod.range.get() != s) continue;
term_enum::children_iterator iter(m, prod, m_bottom_up_enumerator.bank(), cost);
while (iter.has_next(cost)) {
unsigned new_cost = 0;
expr_ref_vector children = iter.next(new_cost);
expr_ref term = prod.builder(children);
levels.reserve(new_cost + 1, expr_ref_vector(m));
if (m_oe.is_representative(term))
levels[new_cost].push_back(term);
}
}
}
// Sort by cost if cost function is set
if (m_cost && !levels[cost].empty()) {
expr_ref_vector& lv = levels[cost];
std::sort(lv.data(), lv.data() + lv.size(),
[&](expr* a, expr* b) { return m_cost(a) < m_cost(b); });
}
std::ostream& display(std::ostream& out) const {
return m_bottom_up_enumerator.display(out);
}
};
@ -591,84 +480,85 @@ private:
struct term_enumeration::iterator::iter_imp {
imp& m_imp;
ast_manager & m;
sort* m_sort;
unsigned m_cost = 0;
unsigned m_idx = 0;
vector<expr_ref_vector> m_levels;
expr* m_current = nullptr;
expr_ref m_current;
bool m_end;
vector<expr_ref_vector> m_vars;
vector<ptr_vector<sort>> m_decls;
vector<vector<symbol>> m_names;
iter_imp(imp& i, sort* s) : m_imp(i), m_sort(s), m_end(false) {
expand_current_level();
advance_to_valid();
iter_imp(imp& i, sort* s) : m_imp(i), m(i.m), m_sort(s), m_current(i.m), m_end(false) {
m_imp.m_bottom_up_enumerator.reset();
init_sort();
advance();
}
// Sentinel constructor
iter_imp(imp& i) :
m_imp(i), m_sort(nullptr), m_end(true) {}
void expand_current_level() {
if (m_cost >= m_levels.size())
m_levels.resize(m_cost + 1, expr_ref_vector(m_imp.m));
if (!m_levels[m_cost].empty())
return;
if (m_cost == 0) {
for (auto const& prod : m_imp.m_grammar.leaves()) {
if (prod.range.get() != m_sort) continue;
expr_ref_vector empty_args(m_imp.m);
expr_ref term = prod.builder(empty_args);
if (m_imp.m_oe.is_representative(term))
m_levels[0].push_back(term);
}
}
else {
for (auto const& prod : m_imp.m_grammar.operators()) {
if (prod.range.get() != m_sort) continue;
term_enum::children_iterator iter(m_imp.m, prod, m_imp.m_bottom_up_enumerator.bank(), m_cost);
while (iter.has_next(m_cost)) {
unsigned new_cost = 0;
expr_ref_vector children = iter.next(new_cost);
expr_ref term = prod.builder(children);
m_levels.reserve(new_cost + 1, expr_ref_vector(m_imp.m));
if (m_imp.m_oe.is_representative(term))
m_levels[new_cost].push_back(term);
}
}
}
// Sort by cost if cost function is set
if (m_imp.m_cost && !m_levels[m_cost].empty()) {
expr_ref_vector& lv = m_levels[m_cost];
std::sort(lv.data(), lv.data() + lv.size(),
[&](expr* a, expr* b) { return m_imp.m_cost(a) < m_imp.m_cost(b); });
}
m_imp(i), m(i.m), m_sort(nullptr), m_current(i.m), m_end(true) {
UNREACHABLE();
}
void advance_to_valid() {
while (true) {
if (m_cost >= m_levels.size())
expand_current_level();
if (m_idx < m_levels[m_cost].size()) {
m_current = m_levels[m_cost].get(m_idx);
return;
void init_sort() {
array_util autil(m);
sort *range = m_sort;
while (autil.is_array(range)) {
m_vars.push_back(expr_ref_vector(m));
m_decls.push_back(ptr_vector<sort>());
m_names.push_back(vector<symbol>());
for (unsigned i = 0; i < get_array_arity(range); ++i) {
m_decls.back().push_back(get_array_domain(range, i));
m_vars.back().push_back(nullptr);
m_names.back().push_back(symbol());
}
m_cost++;
m_idx = 0;
if (m_cost > 100) {
m_end = true;
m_current = nullptr;
return;
#if 0
// TODO: don't enable this until we ensure only generating whnf (beta-redex free) expressions.
expr_ref_vector args(m);
args.push_back(m.mk_const("a", range));
for (unsigned i = 0; i < m_decls.back().size(); ++i) {
args.push_back(m.mk_var(i, m_decls.back().get(i)));
}
expand_current_level();
app_ref sel(autil.mk_select(args), m);
m_imp.m_grammar.add_func_decl(sel->get_decl());
#endif
range = get_array_range(range);
}
unsigned n = 0;
for (unsigned i = m_decls.size(); i-- > 0;) {
for (unsigned j = m_decls[i].size(); j-- > 0;) {
m_vars[i][j] = m.mk_var(n, m_decls[i][j]);
m_names[i][j] = symbol(n);
m_imp.add_production(m_vars[i].get(j));
n++;
}
}
m_sort = range;
m_imp.m_bottom_up_enumerator.set_target_sort(range);
}
void mk_lambda() {
if (!m_current)
return;
for (unsigned i = m_decls.size(); i-- > 0;)
m_current = m.mk_lambda(m_decls[i].size(), m_decls[i].data(), m_names[i].data(), m_current);
}
void advance() {
if (m_end) return;
m_idx++;
advance_to_valid();
if (m_end)
return;
m_current = m_imp.m_bottom_up_enumerator.next();
SASSERT(!m_current || m_current->get_sort() == m_sort);
mk_lambda();
if (!m_current)
m_end = true;
}
};
@ -680,22 +570,6 @@ term_enumeration::iterator::iterator(std::nullptr_t) {
m_imp = nullptr;
}
term_enumeration::iterator::iterator(iterator const& other) {
m_imp = nullptr;
if (other.m_imp)
m_imp = alloc(iter_imp, *other.m_imp);
}
term_enumeration::iterator& term_enumeration::iterator::operator=(iterator const& other) {
if (this != &other) {
dealloc(m_imp);
m_imp = nullptr;
if (other.m_imp)
m_imp = alloc(iter_imp, *other.m_imp);
}
return *this;
}
term_enumeration::iterator::~iterator() {
dealloc(m_imp);
}
@ -715,12 +589,12 @@ term_enumeration::iterator term_enumeration::iterator::operator++(int) {
return tmp;
}
bool term_enumeration::iterator::operator!=(iterator const& other) const {
if (!m_imp && !other.m_imp) return false;
if (!m_imp) return !other.m_imp->m_end;
if (!other.m_imp) return !m_imp->m_end;
return m_imp->m_end != other.m_imp->m_end ||
m_imp->m_current != other.m_imp->m_current;
bool term_enumeration::iterator::operator==(iterator const& other) const {
if (!m_imp && !other.m_imp) return true;
if (!m_imp) return other.m_imp->m_end;
if (!other.m_imp) return m_imp->m_end;
return m_imp->m_end == other.m_imp->m_end &&
m_imp->m_current == other.m_imp->m_current;
}
// -- terms implementation --
@ -760,3 +634,7 @@ void term_enumeration::set_cost(std::function<unsigned(expr*)> const& cost) {
term_enumeration::terms term_enumeration::enum_terms(sort* s) {
return terms(m_imp, s);
}
std::ostream& term_enumeration::display(std::ostream& out) const {
return m_imp->display(out);
}

View file

@ -12,6 +12,7 @@ public:
void add_production(func_decl* f);
void add_production(expr* e);
// void add_production(sort *s, std::function<expr *()> g);
// cost function associated with expressions.
// terms are enumerated with increasing cost.
@ -24,13 +25,14 @@ public:
public:
iterator(imp& i, sort* s);
iterator(std::nullptr_t);
iterator(iterator const& other);
iterator& operator=(iterator const& other);
~iterator();
expr* operator*();
iterator operator++(int);
iterator& operator++();
bool operator!=(iterator const& other) const;
bool operator!=(iterator const& other) const {
return !(*this == other);
}
bool operator==(iterator const &other) const;
};
class terms {
@ -43,4 +45,6 @@ public:
};
terms enum_terms(sort* s);
std::ostream& display(std::ostream& out) const;
};

View file

@ -219,7 +219,7 @@ namespace smt {
if (use_inv) {
unsigned sk_term_gen = 0;
expr * sk_term = m_model_finder.get_inv(q, i, sk_value, sk_term_gen);
expr * sk_term = m_model_finder.get_inv(q, i, sk_value, *cex, sk_term_gen);
if (sk_term != nullptr) {
TRACE(model_checker, tout << "Found inverse " << mk_pp(sk_term, m) << "\n";);
SASSERT(!m.is_model_value(sk_term));
@ -238,10 +238,6 @@ namespace smt {
TRACE(model_checker, tout << "sk term " << mk_pp(sk_term, m) << "\n");
sk_value = sk_term;
}
// last ditch: am I an array?
else if (false && autil.is_as_array(sk_value, f) && cex->get_func_interp(f) && cex->get_func_interp(f)->get_array_interp(f)) {
sk_value = cex->get_func_interp(f)->get_array_interp(f);
}
}
if (contains_model_value(sk_value)) {

View file

@ -18,6 +18,7 @@ Revision History:
--*/
#include "util/backtrackable_set.h"
#include "ast/ast_util.h"
#include "ast/has_free_vars.h"
#include "ast/macros/macro_util.h"
#include "ast/arith_decl_plugin.h"
#include "ast/bv_decl_plugin.h"
@ -108,9 +109,15 @@ namespace smt {
}
}
expr* get_inv(expr* v) const {
expr* get_inv(expr* v, model& mdl) const {
expr* t = nullptr;
m_inv.find(v, t);
if (!t) {
for (auto [k, term] : m_inv) {
if (mdl.are_equal(k, v))
return term;
}
}
return t;
}
@ -121,14 +128,11 @@ namespace smt {
}
void mk_inverse(evaluator& ev) {
for (auto const& kv : m_elems) {
expr* t = kv.m_key;
for (auto const &[t, gen] : m_elems) {
SASSERT(!contains_model_value(t));
unsigned gen = kv.m_value;
expr* t_val = ev.eval(t, true);
if (!t_val) break;
TRACE(model_finder, tout << mk_pp(t, m) << " " << mk_pp(t_val, m) << "\n";);
expr* old_t = nullptr;
if (m_inv.find(t_val, old_t)) {
unsigned old_t_gen = 0;
@ -292,7 +296,7 @@ namespace smt {
}
void insert(expr* n, unsigned generation) {
if (is_ground(n))
if (is_ground(n) || (has_quantifiers(n) && !has_free_vars(n))) // this is a closed term
get_root()->m_set->insert(n, generation);
}
@ -600,7 +604,10 @@ namespace smt {
}
else {
r = tmp;
TRACE(model_finder, tout << "eval\n" << mk_pp(n, m) << "\n----->\n" << mk_pp(r, m) << "\n";);
TRACE(model_finder, tout << "eval-failed\n" << mk_pp(n, m) << "\n----->\n" << mk_pp(r, m) << "\n";);
if (is_lambda(tmp)) {
r = m.mk_fresh_const("lambda", tmp->get_sort());
}
}
m_eval_cache[model_completion].insert(n, r);
m_eval_cache_range.push_back(r);
@ -1388,43 +1395,30 @@ namespace smt {
}
void display(std::ostream &out) const override {
out << "(" << "ho-var " << ":" << m_var_i << ")";
out << "(" << "ho-var: " << m_var_i << ")";
}
void process_auf(quantifier *q, auf_solver &s, context *ctx) override {
/* node * S_i = */ s.get_uvar(q, m_var_i);
}
void populate_inst_sets(quantifier *q, auf_solver &s, context *ctx) override {
node *S = s.get_uvar(q, m_var_i);
sort *srt = S->get_sort();
sort* range = get_array_range(srt);
unsigned arity = get_array_arity(srt);
IF_VERBOSE(0, verbose_stream() << "ho_var::populate_inst_sets: " << q->get_id() << " " << mk_pp(srt, m) << "\n";);
term_enumeration tn(m);
// Add ground terms of type S.
// Add productions for functions in E-graph
// add other possible relevant functions such as equality over srt, Boolean operators
// TODO: use term_enumerator to produce instances int the instantiation set of S.
expr_ref_vector vars(m);
ptr_vector<sort> sorts;
vector<symbol> names;
for (unsigned i = 0; i < arity; ++i) {
vars.push_back(m.mk_var(i, get_array_domain(srt, i)));
auto v = vars.back();
tn.add_production(v);
sorts.push_back(v->get_sort());
names.push_back(symbol(i));
}
auto mk_lambda = [&](expr* body) {
return m.mk_lambda(vars.size(), sorts.data(), names.data(), body);
};
ast_mark visited;
for (enode *n : ctx->enodes()) {
if (false && !ctx->is_relevant(n))
if (!ctx->is_relevant(n))
continue;
auto e = n->get_expr();
if (srt == n->get_sort()) {
IF_VERBOSE(0, verbose_stream() << "inserting " << mk_pp(e, m) << " into inst set\n");
S->insert(e, n->get_generation());
}
else if (is_uninterp_const(e)) {
@ -1443,11 +1437,10 @@ namespace smt {
unsigned max_count = 20;
for (auto t : tn.enum_terms(srt)) {
auto lam = mk_lambda(t);
unsigned generation = 0; // todo - inherited from sub-term of t?
IF_VERBOSE(0, verbose_stream() << "ho_var: adding term " << mk_ismt2_pp(t, m)
<< " to instantiation set of S" << std::endl;);
S->insert(lam, generation);
S->insert(t, generation);
}
}
};
@ -2251,7 +2244,6 @@ namespace smt {
}
SASSERT(is_quantifier(atom));
UNREACHABLE();
}
void process_literal(expr* atom, polarity pol) {
@ -2603,11 +2595,12 @@ namespace smt {
Store in generation the generation of the result
*/
expr* model_finder::get_inv(quantifier* q, unsigned i, expr* val, unsigned& generation) {
expr* model_finder::get_inv(quantifier* q, unsigned i, expr* val, model& mdl,unsigned& generation) {
instantiation_set const* s = get_uvar_inst_set(q, i);
if (s == nullptr)
return nullptr;
expr* t = s->get_inv(val);
expr* t = s->get_inv(val, mdl);
if (m_auf_solver->is_default_representative(t))
return val;
if (t != nullptr) {
@ -2643,16 +2636,27 @@ namespace smt {
obj_map<expr, expr*> const& inv = s->get_inv_map();
if (inv.empty())
continue; // nothing to do
ptr_buffer<expr> eqs;
for (auto const& [val, _] : inv) {
if (val->get_sort() == sk->get_sort())
eqs.push_back(m.mk_eq(sk, val));
expr_ref_vector eqs(m), defs(m);
for (auto const& [val, term] : inv) {
if (val->get_sort() == sk->get_sort()) {
if (is_lambda(term)) {
eqs.push_back(m.mk_eq(sk, val));
defs.push_back(m.mk_eq(val, term));
}
else
eqs.push_back(m.mk_eq(sk, val));
}
}
if (!eqs.empty()) {
expr_ref new_cnstr(m);
new_cnstr = m.mk_or(eqs);
TRACE(model_finder, tout << "assert_restriction:\n" << mk_pp(new_cnstr, m) << "\n";);
aux_ctx->assert_expr(new_cnstr);
for (auto def : defs) {
TRACE(model_finder, tout << "assert_def:\n" << mk_pp(def, m) << "\n";);
aux_ctx->assert_expr(def);
}
asserted_something = true;
}
}

View file

@ -113,7 +113,7 @@ namespace smt {
void fix_model(proto_model * m);
quantifier * get_flat_quantifier(quantifier * q);
expr * get_inv(quantifier * q, unsigned i, expr * val, unsigned & generation);
expr * get_inv(quantifier * q, unsigned i, expr * val, model& m, unsigned & generation);
bool restrict_sks_to_inst_set(context * aux_ctx, quantifier * q, expr_ref_vector const & sks);
void restart_eh();

View file

@ -143,6 +143,7 @@ add_executable(test-z3
symbol.cpp
symbol_table.cpp
tbv.cpp
term_enumeration.cpp
theory_dl.cpp
theory_pb.cpp
timeout.cpp

View file

@ -193,7 +193,8 @@
X(ho_matcher) \
X(finite_set) \
X(finite_set_rewriter) \
X(fpa)
X(fpa) \
X(term_enumeration)
#define FOR_EACH_TEST(X, X_ARGV) \
FOR_EACH_ALL_TEST(X, X_ARGV) \

View file

@ -0,0 +1,283 @@
/*++
Copyright (c) 2024 Microsoft Corporation
Module Name:
tst_term_enumeration.cpp
Abstract:
Test term enumeration module
--*/
#include "ast/term_enumeration.h"
#include "ast/ast_pp.h"
#include "ast/arith_decl_plugin.h"
#include "ast/bv_decl_plugin.h"
#include "ast/array_decl_plugin.h"
#include "ast/reg_decl_plugins.h"
#include <iostream>
#include <sstream>
static void tst_basic_enumeration() {
std::cout << "=== test basic enumeration ===\n";
ast_manager m;
reg_decl_plugins(m);
arith_util a(m);
term_enumeration te(m);
// Add some leaf productions (constants)
expr_ref zero(a.mk_int(0), m);
expr_ref one(a.mk_int(1), m);
te.add_production(zero);
te.add_production(one);
// Enumerate terms of Int sort
sort* int_sort = a.mk_int();
unsigned count = 0;
for (expr* e : te.enum_terms(int_sort)) {
std::cout << "Term: " << mk_pp(e, m) << "\n";
count++;
if (count >= 5) break; // Limit output
}
ENSURE(count >= 2); // At least 0 and 1
std::cout << "Enumerated " << count << " terms\n";
}
static void tst_enumeration_with_operators() {
std::cout << "=== test enumeration with operators ===\n";
ast_manager m;
reg_decl_plugins(m);
arith_util a(m);
term_enumeration te(m);
// Add leaf productions
expr_ref zero(a.mk_int(0), m);
expr_ref one(a.mk_int(1), m);
te.add_production(zero);
te.add_production(one);
// Add operator productions (+ and *)
// Get func_decl by creating an app and extracting the decl
app_ref tmp_add(a.mk_add(zero, one), m);
app_ref tmp_mul(a.mk_mul(zero, one), m);
func_decl* add_decl = tmp_add->get_decl();
func_decl* mul_decl = tmp_mul->get_decl();
te.add_production(add_decl);
te.add_production(mul_decl);
sort* int_sort = a.mk_int();
unsigned count = 0;
for (expr* e : te.enum_terms(int_sort)) {
std::cout << "Term: " << mk_pp(e, m) << "\n";
count++;
if (count >= 20) break; // Limit output
}
ENSURE(count >= 2); // At least the leaves
std::cout << "Enumerated " << count << " terms with operators\n";
}
static void tst_display() {
std::cout << "=== test display ===\n";
ast_manager m;
reg_decl_plugins(m);
arith_util a(m);
term_enumeration te(m);
// Add leaf productions
expr_ref zero(a.mk_int(0), m);
expr_ref one(a.mk_int(1), m);
te.add_production(zero);
te.add_production(one);
// Add operator productions
app_ref tmp_add(a.mk_add(zero, one), m);
func_decl* add_decl = tmp_add->get_decl();
te.add_production(add_decl);
sort* int_sort = a.mk_int();
unsigned count = 0;
for (expr* e : te.enum_terms(int_sort)) {
(void)e;
count++;
if (count >= 10) break;
}
std::cout << "Internal state after enumeration:\n";
std::ostringstream oss;
te.display(oss);
std::cout << oss.str();
// Verify display produced some output
ENSURE(!oss.str().empty());
}
static void tst_bitvector_enumeration() {
std::cout << "=== test bitvector enumeration ===\n";
ast_manager m;
reg_decl_plugins(m);
bv_util bv(m);
term_enumeration te(m);
// Add bitvector constants
unsigned bv_size = 8;
expr_ref bv_zero(bv.mk_numeral(0, bv_size), m);
expr_ref bv_one(bv.mk_numeral(1, bv_size), m);
te.add_production(bv_zero);
te.add_production(bv_one);
// Add bvadd operator
app_ref tmp_add(bv.mk_bv_add(bv_zero, bv_one), m);
func_decl* bvadd = tmp_add->get_decl();
te.add_production(bvadd);
sort* bv8 = bv.mk_sort(bv_size);
unsigned count = 0;
for (expr* e : te.enum_terms(bv8)) {
std::cout << "BV Term: " << mk_pp(e, m) << "\n";
count++;
if (count >= 10) break;
}
ENSURE(count >= 2);
std::cout << "Enumerated " << count << " bitvector terms\n";
}
static void tst_multiple_sorts() {
std::cout << "=== test multiple sorts ===\n";
ast_manager m;
reg_decl_plugins(m);
arith_util a(m);
term_enumeration te(m);
// Add Int constants
expr_ref i_zero(a.mk_int(0), m);
expr_ref i_one(a.mk_int(1), m);
te.add_production(i_zero);
te.add_production(i_one);
// Add Real constants
expr_ref r_zero(a.mk_real(0), m);
expr_ref r_one(a.mk_real(1), m);
te.add_production(r_zero);
te.add_production(r_one);
// Enumerate Int terms
sort* int_sort = a.mk_int();
unsigned int_count = 0;
for (expr* e : te.enum_terms(int_sort)) {
std::cout << "Int Term: " << mk_pp(e, m) << "\n";
int_count++;
if (int_count >= 5) break;
}
ENSURE(int_count >= 2);
std::cout << "Enumerated " << int_count << " Int terms\n";
}
static void tst_nested_array_enumeration() {
std::cout << "=== test nested array enumeration (Array(A, Array(B, A))) ===\n";
ast_manager m;
reg_decl_plugins(m);
array_util arr(m);
term_enumeration te(m);
// Create uninterpreted sorts A and B
sort_ref sort_A(m.mk_uninterpreted_sort(symbol("A")), m);
sort_ref sort_B(m.mk_uninterpreted_sort(symbol("B")), m);
// Create nested array sort: Array(B, A) - arrays indexed by B returning A
sort_ref array_B_A(arr.mk_array_sort(sort_B, sort_A), m);
// Create outer array sort: Array(A, Array(B, A)) - arrays indexed by A returning Array(B,A)
sort_ref array_A_arrayBA(arr.mk_array_sort(sort_A, array_B_A), m);
std::cout << "Sort A: " << mk_pp(sort_A.get(), m) << "\n";
std::cout << "Sort B: " << mk_pp(sort_B.get(), m) << "\n";
std::cout << "Sort Array(B, A): " << mk_pp(array_B_A.get(), m) << "\n";
std::cout << "Sort Array(A, Array(B, A)): " << mk_pp(array_A_arrayBA.get(), m) << "\n";
// Add constants of sort A
app_ref a0(m.mk_const(symbol("a0"), sort_A), m);
app_ref a1(m.mk_const(symbol("a1"), sort_A), m);
te.add_production(a0);
te.add_production(a1);
// Add constants of sort B
app_ref b0(m.mk_const(symbol("b0"), sort_B), m);
app_ref b1(m.mk_const(symbol("b1"), sort_B), m);
te.add_production(b0);
te.add_production(b1);
// Add a constant array of inner type Array(B, A) - const_array(a0) : Array(B, A)
app_ref const_inner(arr.mk_const_array(array_B_A, a0), m);
te.add_production(const_inner);
// Add a constant array of outer type Array(A, Array(B, A))
app_ref const_outer(arr.mk_const_array(array_A_arrayBA, const_inner), m);
te.add_production(const_outer);
// Add store operator for the inner array type Array(B, A)
// store(array, index, value) : store(Array(B,A), B, A) -> Array(B,A)
expr* store_inner_args[3] = { const_inner.get(), b0.get(), a0.get() };
app_ref tmp_store_inner(arr.mk_store(3, store_inner_args), m);
func_decl* store_inner_decl = tmp_store_inner->get_decl();
te.add_production(store_inner_decl);
// Add store operator for the outer array type Array(A, Array(B, A))
// store(array, index, value) : store(Array(A, Array(B,A)), A, Array(B,A)) -> Array(A, Array(B,A))
expr* store_outer_args[3] = { const_outer.get(), a0.get(), const_inner.get() };
app_ref tmp_store_outer(arr.mk_store(3, store_outer_args), m);
func_decl* store_outer_decl = tmp_store_outer->get_decl();
te.add_production(store_outer_decl);
// Add select operator for the outer array (returns Array(B, A))
// select(Array(A, Array(B,A)), A) -> Array(B, A)
app_ref tmp_select_outer(arr.mk_select(const_outer.get(), a0.get()), m);
func_decl* select_outer_decl = tmp_select_outer->get_decl();
te.add_production(select_outer_decl);
// Enumerate terms of the nested array sort Array(A, Array(B, A))
std::cout << "\nEnumerating terms of sort Array(A, Array(B, A)):\n";
unsigned count = 0;
for (expr* e : te.enum_terms(array_A_arrayBA)) {
std::cout << " Term " << count << ": " << mk_pp(e, m) << "\n";
count++;
if (count >= 15) break; // Limit output
}
ENSURE(count >= 1); // At least the constant array
std::cout << "Enumerated " << count << " terms of sort Array(A, Array(B, A))\n";
// Also enumerate terms of the inner array sort Array(B, A)
std::cout << "\nEnumerating terms of sort Array(B, A):\n";
unsigned inner_count = 0;
for (expr* e : te.enum_terms(array_B_A)) {
std::cout << " Term " << inner_count << ": " << mk_pp(e, m) << "\n";
inner_count++;
if (inner_count >= 10) break;
}
// ENSURE(inner_count >= 1);
std::cout << "Enumerated " << inner_count << " terms of sort Array(B, A)\n";
te.display(std::cout);
}
void tst_term_enumeration() {
tst_basic_enumeration();
tst_enumeration_with_operators();
tst_display();
tst_bitvector_enumeration();
tst_multiple_sorts();
tst_nested_array_enumeration();
std::cout << "All term_enumeration tests passed!\n";
}