mirror of
https://github.com/Z3Prover/z3
synced 2026-06-22 16:40:29 +00:00
Term enumeration (#9908)
Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com> Signed-off-by: dependabot[bot] <support@github.com> Signed-off-by: Lev Nachmanson <levnach@hotmail.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Copilot <198982749+Copilot@users.noreply.github.com> Co-authored-by: davedets <daviddetlefs@gmail.com> Co-authored-by: Lev Nachmanson <5377127+levnach@users.noreply.github.com> Co-authored-by: Claude Fable 5 <noreply@anthropic.com> Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Co-authored-by: Margus Veanes <veanes@users.noreply.github.com> Co-authored-by: Nuno Lopes <nuno.lopes@tecnico.ulisboa.pt> Co-authored-by: Shantanu Gontia <gontia.shantanu@gmail.com> Co-authored-by: Peter Chen J. <34339487+peter941221@users.noreply.github.com> Co-authored-by: Alcides Fonseca <me@alcidesfonseca.com> Co-authored-by: Can Cebeci <can.cebeci99@gmail.com> Co-authored-by: Can Cebeci <t-cancebeci@microsoft.com>
This commit is contained in:
parent
b9cc87ae4b
commit
5699142f5b
10 changed files with 1156 additions and 33 deletions
|
|
@ -43,6 +43,7 @@ z3_add_component(rewriter
|
|||
seq_rewriter.cpp
|
||||
seq_regex_bisim.cpp
|
||||
seq_skolem.cpp
|
||||
term_enumeration.cpp
|
||||
th_rewriter.cpp
|
||||
value_sweep.cpp
|
||||
var_subst.cpp
|
||||
|
|
|
|||
674
src/ast/rewriter/term_enumeration.cpp
Normal file
674
src/ast/rewriter/term_enumeration.cpp
Normal file
|
|
@ -0,0 +1,674 @@
|
|||
/**
|
||||
* term_enumeration.cpp - Bottom-up term enumeration module for Z3
|
||||
*
|
||||
* Inspired by the Probe synthesizer (Barke et al., "Just-in-Time Learning
|
||||
* for Bottom-Up Enumerative Synthesis"). Adapted to use Z3's internal APIs.
|
||||
*
|
||||
* Key ideas:
|
||||
* - Terms are enumerated bottom-up by "cost" (calculated by tree size).
|
||||
* - A grammar describes which function symbols (operators) and leaves
|
||||
* (constants, variables) are available for enumeration.
|
||||
*/
|
||||
|
||||
#include <sstream>
|
||||
#include <functional>
|
||||
#include <string>
|
||||
#include "util/vector.h"
|
||||
#include "util/scoped_ptr_vector.h"
|
||||
#include "util/obj_hashtable.h"
|
||||
#include "util/uint_set.h"
|
||||
#include "ast/ast.h"
|
||||
#include "ast/ast_ll_pp.h"
|
||||
#include "ast/ast_pp.h"
|
||||
#include "ast/rewriter/th_rewriter.h"
|
||||
#include "ast/rewriter/term_enumeration.h"
|
||||
|
||||
|
||||
namespace term_enum {
|
||||
|
||||
// ============================================================================
|
||||
// grammar production rule
|
||||
// ============================================================================
|
||||
|
||||
/**
|
||||
* A production describes how to construct a term from child terms.
|
||||
* - domain: the sort required for each child
|
||||
* - range: the sort of the produced term
|
||||
* - builder: given a vector of child exprs, produce the result expr
|
||||
*/
|
||||
struct production {
|
||||
std::string name;
|
||||
sort_ref range;
|
||||
sort_ref_vector domain;
|
||||
std::function<expr_ref(expr_ref_vector const&)> builder;
|
||||
|
||||
bool is_leaf() const { return domain.empty(); }
|
||||
};
|
||||
|
||||
// ============================================================================
|
||||
// grammar
|
||||
// ============================================================================
|
||||
|
||||
/**
|
||||
* A grammar groups productions into leaves (arity 0) and operators (arity > 0).
|
||||
*/
|
||||
class grammar {
|
||||
public:
|
||||
grammar(ast_manager& m) : m(m), m_pinned(m) {}
|
||||
|
||||
void add_production(production* p) {
|
||||
if (p->is_leaf())
|
||||
m_leaves.push_back(p);
|
||||
else
|
||||
m_operators.push_back(p);
|
||||
}
|
||||
|
||||
scoped_ptr_vector<production> const& leaves() const { return m_leaves; }
|
||||
scoped_ptr_vector<production> const& operators() const { return m_operators; }
|
||||
ast_manager& mgr() const { return m; }
|
||||
|
||||
void add_func_decl(func_decl *f) {
|
||||
if (m_seen.contains(f))
|
||||
return;
|
||||
m_pinned.push_back(f);
|
||||
m_seen.insert(f);
|
||||
sort_ref range(f->get_range(), m);
|
||||
sort_ref_vector dom(m);
|
||||
for (unsigned i = 0; i < f->get_arity(); ++i)
|
||||
dom.push_back(sort_ref(f->get_domain(i), m));
|
||||
add_production(alloc(production, {f->get_name().str(), range, dom, [this, f](expr_ref_vector const &args) {
|
||||
return expr_ref(m.mk_app(f, args), m);
|
||||
}}));
|
||||
}
|
||||
|
||||
void add_expr(expr *e) {
|
||||
if (m_seen.contains(e))
|
||||
return;
|
||||
m_pinned.push_back(e);
|
||||
m_seen.insert(e);
|
||||
sort_ref range(e->get_sort(), m);
|
||||
sort_ref_vector dom(m);
|
||||
std::stringstream ss;
|
||||
ss << mk_bounded_pp(e, m);
|
||||
std::string name = ss.str();
|
||||
add_production(alloc(production, {name, range, dom, [this, e](expr_ref_vector const&) { return expr_ref(e, m); }}));
|
||||
}
|
||||
|
||||
std::ostream& display(std::ostream& out) const {
|
||||
out << "Leaves:\n";
|
||||
for (auto const *p : m_leaves) {
|
||||
out << " " << p->name << " : " << mk_pp(p->range, m) << "\n";
|
||||
}
|
||||
out << "Operators:\n";
|
||||
for (auto const *p : m_operators) {
|
||||
out << " " << p->name << " : (";
|
||||
for (unsigned i = 0; i < p->domain.size(); ++i) {
|
||||
if (i > 0)
|
||||
out << ", ";
|
||||
out << mk_pp(p->domain[i], m);
|
||||
}
|
||||
out << ") -> " << mk_pp(p->range, m) << "\n";
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
private:
|
||||
ast_manager& m;
|
||||
ast_ref_vector m_pinned;
|
||||
scoped_ptr_vector<production> m_leaves;
|
||||
scoped_ptr_vector<production> m_operators;
|
||||
obj_hashtable<ast> m_seen;
|
||||
};
|
||||
|
||||
// ============================================================================
|
||||
// Term Bank - stores enumerated terms by cost and sort
|
||||
// ============================================================================
|
||||
|
||||
using cost_terms = vector<std::pair<expr*, unsigned>>;
|
||||
|
||||
class term_bank {
|
||||
using sort_term_map = obj_map<sort, ptr_vector<expr>>;
|
||||
public:
|
||||
term_bank(ast_manager& m) : m(m), m_pinned(m) {}
|
||||
|
||||
~term_bank() {
|
||||
for (auto s : m_terms)
|
||||
dealloc(s);
|
||||
}
|
||||
|
||||
void reset() {
|
||||
m_pinned.reset();
|
||||
m_terms.clear();
|
||||
}
|
||||
|
||||
void add(expr* term, unsigned cost) {
|
||||
sort* s = term->get_sort();
|
||||
m_pinned.push_back(term);
|
||||
if (cost >= m_terms.size())
|
||||
m_terms.resize(cost + 1);
|
||||
if (!m_terms[cost])
|
||||
m_terms[cost] = alloc(sort_term_map);
|
||||
m_terms[cost]->insert_if_not_there(s, ptr_vector<expr>()).push_back(term);
|
||||
}
|
||||
|
||||
/** Get all terms of a given sort up to (and including) max_cost */
|
||||
cost_terms get_by_sort(sort* s, unsigned max_cost) const {
|
||||
cost_terms result;
|
||||
for (unsigned c = 0; c <= max_cost; ++c) {
|
||||
if (c >= m_terms.size())
|
||||
break;
|
||||
if (!m_terms[c]->contains(s))
|
||||
continue;
|
||||
for (auto t : m_terms[c]->find(s))
|
||||
result.push_back({t, c});
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// Return true if there is at least one term at/above `cost` whose sort is
|
||||
// not in `sorts` (i.e., enumeration can still produce a new requested sort).
|
||||
bool is_productive(unsigned cost, uint_set const& sorts) {
|
||||
for (unsigned i = cost; i < m_terms.size(); ++i) {
|
||||
if (!m_terms[i])
|
||||
continue;
|
||||
for (auto const& entry : *m_terms[i]) {
|
||||
sort* term_sort = entry.m_key;
|
||||
if (!sorts.contains(term_sort->get_small_id()))
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
ptr_vector<expr> null_ptr_vector;
|
||||
ptr_vector<expr> const &get_by_cost_and_sort(unsigned cost, sort *s) const {
|
||||
if (cost >= m_terms.size() || !m_terms[cost] || !m_terms[cost]->contains(s))
|
||||
return null_ptr_vector;
|
||||
return m_terms[cost]->find(s);
|
||||
}
|
||||
|
||||
std::ostream& display(std::ostream& out) const {
|
||||
for (unsigned cost = 0; cost < m_terms.size(); ++cost) {
|
||||
if (!m_terms[cost])
|
||||
continue;
|
||||
out << "cost " << cost << ":\n";
|
||||
for (auto& [s, terms] : *m_terms[cost]) {
|
||||
out << " sort " << mk_pp(s, m) << ":\n";
|
||||
for (expr* e : terms) {
|
||||
out << " #" << e->get_id() << " ";
|
||||
if (cost == 0) {
|
||||
out << mk_bounded_pp(e, m);
|
||||
}
|
||||
else if (is_app(e)) {
|
||||
app* a = to_app(e);
|
||||
out << a->get_decl()->get_name() << "(";
|
||||
bool first = true;
|
||||
for (expr* arg : *a) {
|
||||
if (!first) out << ", ";
|
||||
first = false;
|
||||
out << "#" << arg->get_id();
|
||||
}
|
||||
out << ")";
|
||||
}
|
||||
out << "\n";
|
||||
}
|
||||
}
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
private:
|
||||
ast_manager& m;
|
||||
expr_ref_vector m_pinned;
|
||||
// cost -> sort -> terms
|
||||
ptr_vector<sort_term_map> m_terms;
|
||||
};
|
||||
|
||||
// ============================================================================
|
||||
// Children Iterator - generates all combinations of child terms
|
||||
// ============================================================================
|
||||
|
||||
/**
|
||||
* Iterates over all tuples (c1, c2, ..., cn) where each ci has the required
|
||||
* sort, drawn from the term bank, with at least one child at the current
|
||||
* cost - 1 (to avoid regenerating previously seen terms).
|
||||
*/
|
||||
class children_iterator {
|
||||
public:
|
||||
children_iterator(ast_manager& m, production const& prod, term_bank const& bank, unsigned current_cost)
|
||||
: m(m), m_prod(prod), m_current_cost(current_cost), m_done(false)
|
||||
{
|
||||
m_arity = prod.domain.size();
|
||||
if (m_arity == 0) {
|
||||
m_done = true;
|
||||
return;
|
||||
}
|
||||
for (unsigned i = 0; i < m_arity; ++i) {
|
||||
m_candidates.push_back(bank.get_by_sort(prod.domain[i], current_cost - 1));
|
||||
if (m_candidates.back().empty()) {
|
||||
m_done = true;
|
||||
return;
|
||||
}
|
||||
}
|
||||
m_indices.resize(m_arity, 0);
|
||||
}
|
||||
|
||||
bool has_next(unsigned cost) {
|
||||
while (!m_done) {
|
||||
if (has_child_at_cost(cost))
|
||||
return true;
|
||||
advance();
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
expr_ref_vector next(unsigned& cost) {
|
||||
expr_ref_vector result(m);
|
||||
cost = 1;
|
||||
for (unsigned i = 0; i < m_arity; ++i) {
|
||||
auto [e, c] = m_candidates[i].get(m_indices[i]);
|
||||
cost += c;
|
||||
result.push_back(e);
|
||||
}
|
||||
advance();
|
||||
return result;
|
||||
}
|
||||
|
||||
private:
|
||||
ast_manager& m;
|
||||
production const& m_prod;
|
||||
unsigned m_current_cost;
|
||||
unsigned m_arity;
|
||||
bool m_done;
|
||||
vector<cost_terms> m_candidates;
|
||||
svector<unsigned> m_indices;
|
||||
|
||||
bool has_child_at_cost(unsigned cost) const {
|
||||
for (unsigned i = 0; i < m_arity; ++i) {
|
||||
auto [e, c] = m_candidates[i].get(m_indices[i]);
|
||||
if (c + 1 == cost)
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void advance() {
|
||||
for (auto i = m_arity; i-- > 0;) {
|
||||
m_indices[i]++;
|
||||
if (m_indices[i] < m_candidates[i].size()) return;
|
||||
m_indices[i] = 0;
|
||||
}
|
||||
m_done = true;
|
||||
}
|
||||
};
|
||||
|
||||
// ============================================================================
|
||||
// bottom_up_enumerator - the main bottom-up term enumeration engine
|
||||
// ============================================================================
|
||||
|
||||
|
||||
class bottom_up_enumerator {
|
||||
public:
|
||||
bottom_up_enumerator(grammar& grammar)
|
||||
: m_grammar(grammar), m(grammar.mgr()),
|
||||
m_bank(grammar.mgr()), m_pending(grammar.mgr()), m_rewriter(grammar.mgr())
|
||||
{}
|
||||
|
||||
void set_target_sort(sort *s) {
|
||||
m_target_sort = s;
|
||||
}
|
||||
bool has_next() {
|
||||
if (m_pending) return true;
|
||||
m_pending = find_next();
|
||||
return m_pending != nullptr;
|
||||
}
|
||||
|
||||
expr_ref next() {
|
||||
if (!m_pending)
|
||||
m_pending = find_next();
|
||||
expr_ref result(m_pending, m);
|
||||
m_pending = nullptr;
|
||||
return result;
|
||||
}
|
||||
|
||||
term_bank const& bank() const { return m_bank; }
|
||||
|
||||
std::ostream& display(std::ostream& out) const {
|
||||
m_grammar.display(out);
|
||||
return m_bank.display(out);
|
||||
}
|
||||
|
||||
void reset() {
|
||||
m_cost = 0;
|
||||
m_leaf_idx = 0;
|
||||
m_op_idx = 0;
|
||||
m_state = State::Leaves;
|
||||
m_bank.reset();
|
||||
m_pending = nullptr;
|
||||
m_rewriter.reset();
|
||||
m_seen_terms.reset();
|
||||
m_children_iter.reset();
|
||||
}
|
||||
|
||||
expr* add_term(expr_ref const& term, unsigned cost) {
|
||||
expr_ref simplified(m);
|
||||
m_rewriter(term, simplified);
|
||||
if (m_seen_terms.contains(simplified))
|
||||
return nullptr;
|
||||
m_seen_terms.insert(simplified);
|
||||
m_bank.add(simplified, cost);
|
||||
return simplified;
|
||||
}
|
||||
|
||||
private:
|
||||
enum class State { Leaves, Operators, Done };
|
||||
|
||||
grammar& m_grammar;
|
||||
ast_manager& m;
|
||||
term_bank m_bank;
|
||||
unsigned m_cost = 0;
|
||||
unsigned m_leaf_idx = 0;
|
||||
unsigned m_op_idx = 0;
|
||||
unsigned m_bank_idx = 0;
|
||||
unsigned m_bank_size = 0;
|
||||
bool m_made_progress = false;
|
||||
uint_set m_sorts_produced;
|
||||
State m_state = State::Leaves;
|
||||
expr_ref m_pending;
|
||||
th_rewriter m_rewriter;
|
||||
obj_hashtable<expr> m_seen_terms;
|
||||
std::unique_ptr<children_iterator> m_children_iter;
|
||||
sort *m_target_sort = nullptr;
|
||||
|
||||
bool sort_matches(expr* e) const {
|
||||
return !m_target_sort || e->get_sort() == m_target_sort;
|
||||
}
|
||||
|
||||
expr* find_next() {
|
||||
while (true) {
|
||||
switch (m_state) {
|
||||
case State::Leaves:
|
||||
while (m_leaf_idx < m_grammar.leaves().size()) {
|
||||
production const &prod = *m_grammar.leaves()[m_leaf_idx];
|
||||
m_leaf_idx++;
|
||||
expr_ref_vector empty_args(m);
|
||||
expr_ref term = prod.builder(empty_args);
|
||||
expr* r = add_term(term, 0);
|
||||
if (r && sort_matches(r))
|
||||
return r;
|
||||
}
|
||||
m_state = State::Operators;
|
||||
m_cost = 1;
|
||||
m_op_idx = 0;
|
||||
m_bank_idx = 0;
|
||||
m_bank_size = get_bank_size();
|
||||
m_made_progress = false;
|
||||
m_sorts_produced.reset();
|
||||
m_children_iter.reset();
|
||||
break;
|
||||
|
||||
case State::Operators: {
|
||||
expr* result = enumerate_operators();
|
||||
if (result)
|
||||
return result;
|
||||
|
||||
m_cost++;
|
||||
m_op_idx = 0;
|
||||
m_bank_idx = 0;
|
||||
m_bank_size = get_bank_size();
|
||||
m_children_iter.reset();
|
||||
if (!m_made_progress && !m_bank.is_productive(m_cost, m_sorts_produced)) {
|
||||
m_state = State::Done;
|
||||
return nullptr;
|
||||
}
|
||||
if (m_sorts_produced.contains(m_target_sort->get_small_id()))
|
||||
m_sorts_produced.reset();
|
||||
m_made_progress = false;
|
||||
break;
|
||||
}
|
||||
case State::Done:
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsigned get_bank_size() const {
|
||||
auto const &terms = m_bank.get_by_cost_and_sort(m_cost, m_target_sort);
|
||||
return terms.size();
|
||||
}
|
||||
|
||||
expr *enumerate_operators() {
|
||||
auto const &ops = m_grammar.operators();
|
||||
while (true) {
|
||||
|
||||
// first find terms at m_cost that were already created
|
||||
if (m_bank_idx < m_bank_size) {
|
||||
auto const &terms = m_bank.get_by_cost_and_sort(m_cost, m_target_sort);
|
||||
auto t = terms.get(m_bank_idx);
|
||||
m_bank_idx++;
|
||||
SASSERT(sort_matches(t));
|
||||
return t;
|
||||
}
|
||||
|
||||
// then create new terms using children at cost below current m_cost.
|
||||
if (m_children_iter && m_children_iter->has_next(m_cost)) {
|
||||
unsigned new_cost = 0;
|
||||
expr_ref_vector children = m_children_iter->next(new_cost);
|
||||
production const &prod = *ops[m_op_idx - 1];
|
||||
expr_ref term = prod.builder(children);
|
||||
// IF_VERBOSE(0, verbose_stream() << term << "\n");
|
||||
SASSERT(new_cost >= m_cost);
|
||||
expr* r = add_term(term, new_cost);
|
||||
if (!r)
|
||||
continue;
|
||||
unsigned sort_id = r->get_sort()->get_small_id();
|
||||
if (!m_sorts_produced.contains(sort_id))
|
||||
m_made_progress = true;
|
||||
m_sorts_produced.insert(sort_id);
|
||||
if (sort_matches(r) && new_cost == m_cost) {
|
||||
return r;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (m_op_idx >= ops.size())
|
||||
return nullptr;
|
||||
|
||||
production const &prod = *ops[m_op_idx];
|
||||
m_op_idx++;
|
||||
m_children_iter = std::make_unique<children_iterator>(m, prod, m_bank, m_cost);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace term_enum
|
||||
|
||||
// ============================================================================
|
||||
// term_enumeration public interface implementation
|
||||
// ============================================================================
|
||||
|
||||
struct term_enumeration::imp {
|
||||
ast_manager& m;
|
||||
term_enum::grammar m_grammar;
|
||||
term_enum::bottom_up_enumerator m_bottom_up_enumerator;
|
||||
std::function<unsigned(expr*)> m_cost;
|
||||
|
||||
imp(ast_manager& m) :
|
||||
m(m), m_grammar(m), m_bottom_up_enumerator(m_grammar) {}
|
||||
|
||||
void add_production(func_decl* f) {
|
||||
m_grammar.add_func_decl(f);
|
||||
}
|
||||
|
||||
void add_production(expr* e) {
|
||||
m_grammar.add_expr(e);
|
||||
}
|
||||
|
||||
void set_cost(std::function<unsigned(expr*)> const& cost) {
|
||||
// TODO
|
||||
}
|
||||
|
||||
std::ostream& display(std::ostream& out) const {
|
||||
return m_bottom_up_enumerator.display(out);
|
||||
}
|
||||
};
|
||||
|
||||
// -- iterator implementation --
|
||||
|
||||
struct term_enumeration::iterator::iter_imp {
|
||||
imp& m_imp;
|
||||
ast_manager & m;
|
||||
sort* m_sort;
|
||||
unsigned m_cost = 0;
|
||||
unsigned m_idx = 0;
|
||||
vector<expr_ref_vector> m_levels;
|
||||
expr_ref m_current;
|
||||
bool m_end;
|
||||
vector<expr_ref_vector> m_vars;
|
||||
vector<ptr_vector<sort>> m_decls;
|
||||
vector<vector<symbol>> m_names;
|
||||
|
||||
iter_imp(imp& i, sort* s) : m_imp(i), m(i.m), m_sort(s), m_current(i.m), m_end(false) {
|
||||
m_imp.m_bottom_up_enumerator.reset();
|
||||
init_sort();
|
||||
advance();
|
||||
}
|
||||
|
||||
// Sentinel constructor
|
||||
iter_imp(imp& i) :
|
||||
m_imp(i), m(i.m), m_sort(nullptr), m_current(i.m), m_end(true) {
|
||||
UNREACHABLE();
|
||||
}
|
||||
|
||||
|
||||
void init_sort() {
|
||||
array_util autil(m);
|
||||
sort *range = m_sort;
|
||||
|
||||
while (autil.is_array(range)) {
|
||||
m_vars.push_back(expr_ref_vector(m));
|
||||
m_decls.push_back(ptr_vector<sort>());
|
||||
m_names.push_back(vector<symbol>());
|
||||
for (unsigned i = 0; i < get_array_arity(range); ++i) {
|
||||
m_decls.back().push_back(get_array_domain(range, i));
|
||||
m_vars.back().push_back(nullptr);
|
||||
m_names.back().push_back(symbol());
|
||||
}
|
||||
|
||||
expr_ref_vector args(m);
|
||||
args.push_back(m.mk_const("a", range));
|
||||
for (unsigned i = 0; i < m_decls.back().size(); ++i) {
|
||||
args.push_back(m.mk_var(i, m_decls.back().get(i)));
|
||||
}
|
||||
app_ref sel(autil.mk_select(args), m);
|
||||
m_imp.m_grammar.add_func_decl(sel->get_decl());
|
||||
|
||||
range = get_array_range(range);
|
||||
}
|
||||
unsigned n = 0;
|
||||
for (unsigned i = m_decls.size(); i-- > 0;) {
|
||||
for (unsigned j = m_decls[i].size(); j-- > 0;) {
|
||||
m_vars[i][j] = m.mk_var(n, m_decls[i][j]);
|
||||
m_names[i][j] = symbol(n);
|
||||
m_imp.add_production(m_vars[i].get(j));
|
||||
n++;
|
||||
}
|
||||
}
|
||||
m_sort = range;
|
||||
m_imp.m_bottom_up_enumerator.set_target_sort(range);
|
||||
}
|
||||
|
||||
void mk_lambda() {
|
||||
if (!m_current)
|
||||
return;
|
||||
for (unsigned i = m_decls.size(); i-- > 0;)
|
||||
m_current = m.mk_lambda(m_decls[i].size(), m_decls[i].data(), m_names[i].data(), m_current);
|
||||
}
|
||||
|
||||
void advance() {
|
||||
if (m_end)
|
||||
return;
|
||||
m_current = m_imp.m_bottom_up_enumerator.next();
|
||||
SASSERT(!m_current || m_current->get_sort() == m_sort);
|
||||
mk_lambda();
|
||||
if (!m_current)
|
||||
m_end = true;
|
||||
}
|
||||
};
|
||||
|
||||
term_enumeration::iterator::iterator(imp& i, sort* s) {
|
||||
m_imp = alloc(iter_imp, i, s);
|
||||
}
|
||||
|
||||
term_enumeration::iterator::iterator(std::nullptr_t) {
|
||||
m_imp = nullptr;
|
||||
}
|
||||
|
||||
term_enumeration::iterator::~iterator() {
|
||||
dealloc(m_imp);
|
||||
}
|
||||
|
||||
expr* term_enumeration::iterator::operator*() {
|
||||
return m_imp ? m_imp->m_current.get() : nullptr;
|
||||
}
|
||||
|
||||
term_enumeration::iterator& term_enumeration::iterator::operator++() {
|
||||
if (m_imp) m_imp->advance();
|
||||
return *this;
|
||||
}
|
||||
|
||||
term_enumeration::iterator term_enumeration::iterator::operator++(int) {
|
||||
iterator tmp(*this);
|
||||
++(*this);
|
||||
return tmp;
|
||||
}
|
||||
|
||||
bool term_enumeration::iterator::operator==(iterator const& other) const {
|
||||
if (!m_imp && !other.m_imp) return true;
|
||||
if (!m_imp) return other.m_imp->m_end;
|
||||
if (!other.m_imp) return m_imp->m_end;
|
||||
return m_imp->m_end == other.m_imp->m_end &&
|
||||
m_imp->m_current == other.m_imp->m_current;
|
||||
}
|
||||
|
||||
// -- terms implementation --
|
||||
|
||||
term_enumeration::terms::terms(imp* i, sort* s) : m_imp(i), m_sort(s) {}
|
||||
|
||||
term_enumeration::iterator term_enumeration::terms::begin() {
|
||||
return iterator(*m_imp, m_sort);
|
||||
}
|
||||
|
||||
term_enumeration::iterator term_enumeration::terms::end() {
|
||||
return iterator(nullptr);
|
||||
}
|
||||
|
||||
// -- term_enumeration implementation --
|
||||
|
||||
term_enumeration::term_enumeration(ast_manager& m) {
|
||||
m_imp = alloc(imp, m);
|
||||
}
|
||||
|
||||
term_enumeration::~term_enumeration() {
|
||||
dealloc(m_imp);
|
||||
}
|
||||
|
||||
void term_enumeration::add_production(func_decl* f) {
|
||||
m_imp->add_production(f);
|
||||
}
|
||||
|
||||
void term_enumeration::add_production(expr* e) {
|
||||
m_imp->add_production(e);
|
||||
}
|
||||
|
||||
void term_enumeration::set_cost(std::function<unsigned(expr*)> const& cost) {
|
||||
m_imp->set_cost(cost);
|
||||
}
|
||||
|
||||
term_enumeration::terms term_enumeration::enum_terms(sort* s) {
|
||||
return terms(m_imp, s);
|
||||
}
|
||||
|
||||
std::ostream& term_enumeration::display(std::ostream& out) const {
|
||||
return m_imp->display(out);
|
||||
}
|
||||
50
src/ast/rewriter/term_enumeration.h
Normal file
50
src/ast/rewriter/term_enumeration.h
Normal file
|
|
@ -0,0 +1,50 @@
|
|||
#pragma once
|
||||
|
||||
#include "ast/ast.h"
|
||||
#include <functional>
|
||||
|
||||
class term_enumeration {
|
||||
struct imp;
|
||||
imp* m_imp;
|
||||
public:
|
||||
term_enumeration(ast_manager& m);
|
||||
~term_enumeration();
|
||||
|
||||
void add_production(func_decl* f);
|
||||
void add_production(expr* e);
|
||||
// void add_production(sort *s, std::function<expr *()> g);
|
||||
|
||||
// cost function associated with expressions.
|
||||
// terms are enumerated with increasing cost.
|
||||
|
||||
void set_cost(std::function<unsigned(expr*)> const& cost);
|
||||
|
||||
class iterator {
|
||||
struct iter_imp;
|
||||
iter_imp* m_imp;
|
||||
public:
|
||||
iterator(imp& i, sort* s);
|
||||
iterator(std::nullptr_t);
|
||||
~iterator();
|
||||
expr* operator*();
|
||||
iterator operator++(int);
|
||||
iterator& operator++();
|
||||
bool operator!=(iterator const& other) const {
|
||||
return !(*this == other);
|
||||
}
|
||||
bool operator==(iterator const &other) const;
|
||||
};
|
||||
|
||||
class terms {
|
||||
imp* m_imp;
|
||||
sort* m_sort;
|
||||
public:
|
||||
terms(imp* i, sort* s);
|
||||
iterator begin();
|
||||
iterator end();
|
||||
};
|
||||
|
||||
terms enum_terms(sort* s);
|
||||
|
||||
std::ostream& display(std::ostream& out) const;
|
||||
};
|
||||
|
|
@ -513,7 +513,7 @@ void non_auf_macro_solver::collect_candidates(ptr_vector<quantifier> const& qs,
|
|||
TRACE(model_finder, tout << "considering macro for: " << f->get_name() << "\n";
|
||||
m->display(tout); tout << "\n";);
|
||||
if (m->is_unconditional() && (!qi->is_auf() || m->get_weight() >= m_mbqi_force_template)) {
|
||||
full_macros.insert(f, std::make_pair(m, q));
|
||||
full_macros.insert(f, {m, q});
|
||||
cond_macros.erase(f);
|
||||
}
|
||||
else if (!full_macros.contains(f) && !qi->is_auf())
|
||||
|
|
@ -524,10 +524,8 @@ void non_auf_macro_solver::collect_candidates(ptr_vector<quantifier> const& qs,
|
|||
}
|
||||
|
||||
void non_auf_macro_solver::process_full_macros(obj_map<func_decl, mq_pair> const& full_macros, obj_hashtable<quantifier>& removed) {
|
||||
for (auto const& kv : full_macros) {
|
||||
func_decl* f = kv.m_key;
|
||||
cond_macro* m = kv.m_value.first;
|
||||
quantifier* q = kv.m_value.second;
|
||||
for (auto const &[f, v] : full_macros) {
|
||||
auto [m, q] = v;
|
||||
SASSERT(m->is_unconditional());
|
||||
if (add_macro(f, m->get_def())) {
|
||||
get_qinfo(q)->set_the_one(f);
|
||||
|
|
|
|||
|
|
@ -219,7 +219,7 @@ namespace smt {
|
|||
|
||||
if (use_inv) {
|
||||
unsigned sk_term_gen = 0;
|
||||
expr * sk_term = m_model_finder.get_inv(q, i, sk_value, sk_term_gen);
|
||||
expr * sk_term = m_model_finder.get_inv(q, i, sk_value, *cex, sk_term_gen);
|
||||
if (sk_term != nullptr) {
|
||||
TRACE(model_checker, tout << "Found inverse " << mk_pp(sk_term, m) << "\n";);
|
||||
SASSERT(!m.is_model_value(sk_term));
|
||||
|
|
@ -238,10 +238,6 @@ namespace smt {
|
|||
TRACE(model_checker, tout << "sk term " << mk_pp(sk_term, m) << "\n");
|
||||
sk_value = sk_term;
|
||||
}
|
||||
// last ditch: am I an array?
|
||||
else if (false && autil.is_as_array(sk_value, f) && cex->get_func_interp(f) && cex->get_func_interp(f)->get_array_interp(f)) {
|
||||
sk_value = cex->get_func_interp(f)->get_array_interp(f);
|
||||
}
|
||||
|
||||
}
|
||||
if (contains_model_value(sk_value)) {
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ Revision History:
|
|||
--*/
|
||||
#include "util/backtrackable_set.h"
|
||||
#include "ast/ast_util.h"
|
||||
#include "ast/has_free_vars.h"
|
||||
#include "ast/macros/macro_util.h"
|
||||
#include "ast/arith_decl_plugin.h"
|
||||
#include "ast/bv_decl_plugin.h"
|
||||
|
|
@ -31,6 +32,7 @@ Revision History:
|
|||
#include "ast/ast_ll_pp.h"
|
||||
#include "ast/well_sorted.h"
|
||||
#include "ast/ast_smt2_pp.h"
|
||||
#include "ast/rewriter/term_enumeration.h"
|
||||
#include "model/model_pp.h"
|
||||
#include "model/model_macro_solver.h"
|
||||
#include "smt/smt_model_finder.h"
|
||||
|
|
@ -107,9 +109,15 @@ namespace smt {
|
|||
}
|
||||
}
|
||||
|
||||
expr* get_inv(expr* v) const {
|
||||
expr* get_inv(expr* v, model& mdl) const {
|
||||
expr* t = nullptr;
|
||||
m_inv.find(v, t);
|
||||
if (!t) {
|
||||
for (auto [k, term] : m_inv) {
|
||||
if (mdl.are_equal(k, v))
|
||||
return term;
|
||||
}
|
||||
}
|
||||
return t;
|
||||
}
|
||||
|
||||
|
|
@ -120,14 +128,11 @@ namespace smt {
|
|||
}
|
||||
|
||||
void mk_inverse(evaluator& ev) {
|
||||
for (auto const& kv : m_elems) {
|
||||
expr* t = kv.m_key;
|
||||
for (auto const &[t, gen] : m_elems) {
|
||||
SASSERT(!contains_model_value(t));
|
||||
unsigned gen = kv.m_value;
|
||||
expr* t_val = ev.eval(t, true);
|
||||
if (!t_val) break;
|
||||
TRACE(model_finder, tout << mk_pp(t, m) << " " << mk_pp(t_val, m) << "\n";);
|
||||
|
||||
expr* old_t = nullptr;
|
||||
if (m_inv.find(t_val, old_t)) {
|
||||
unsigned old_t_gen = 0;
|
||||
|
|
@ -187,14 +192,14 @@ namespace smt {
|
|||
\brief Base class used to solve model construction constraints.
|
||||
*/
|
||||
class node {
|
||||
unsigned m_id;
|
||||
node* m_find{ nullptr };
|
||||
unsigned m_eqc_size{ 1 };
|
||||
unsigned m_id = 0;
|
||||
node* m_find = nullptr;
|
||||
unsigned m_eqc_size = 1;
|
||||
|
||||
sort* m_sort; // sort of the elements in the instantiation set.
|
||||
sort* m_sort = nullptr; // sort of the elements in the instantiation set.
|
||||
|
||||
bool m_mono_proj{ false }; // relevant for integers & reals & bit-vectors
|
||||
bool m_signed_proj{ false }; // relevant for bit-vectors.
|
||||
bool m_mono_proj = false; // relevant for integers & reals & bit-vectors
|
||||
bool m_signed_proj = false; // relevant for bit-vectors.
|
||||
ptr_vector<node> m_avoid_set;
|
||||
ptr_vector<expr> m_exceptions;
|
||||
|
||||
|
|
@ -291,7 +296,7 @@ namespace smt {
|
|||
}
|
||||
|
||||
void insert(expr* n, unsigned generation) {
|
||||
if (is_ground(n))
|
||||
if (is_ground(n) || (has_quantifiers(n) && !has_free_vars(n))) // this is a closed term
|
||||
get_root()->m_set->insert(n, generation);
|
||||
}
|
||||
|
||||
|
|
@ -599,7 +604,10 @@ namespace smt {
|
|||
}
|
||||
else {
|
||||
r = tmp;
|
||||
TRACE(model_finder, tout << "eval\n" << mk_pp(n, m) << "\n----->\n" << mk_pp(r, m) << "\n";);
|
||||
TRACE(model_finder, tout << "eval-failed\n" << mk_pp(n, m) << "\n----->\n" << mk_pp(r, m) << "\n";);
|
||||
if (is_lambda(tmp)) {
|
||||
r = m.mk_fresh_const("lambda", tmp->get_sort());
|
||||
}
|
||||
}
|
||||
m_eval_cache[model_completion].insert(n, r);
|
||||
m_eval_cache_range.push_back(r);
|
||||
|
|
@ -1235,8 +1243,8 @@ namespace smt {
|
|||
void populate_inst_sets(quantifier* q, func_decl* mhead, ptr_vector<instantiation_set>& uvar_inst_sets, context* ctx) override {
|
||||
if (m_f != mhead)
|
||||
return;
|
||||
uvar_inst_sets.reserve(m_var_j + 1, 0);
|
||||
if (uvar_inst_sets[m_var_j] == 0)
|
||||
uvar_inst_sets.reserve(m_var_j + 1, nullptr);
|
||||
if (uvar_inst_sets[m_var_j] == nullptr)
|
||||
uvar_inst_sets[m_var_j] = alloc(instantiation_set, ctx->get_manager());
|
||||
instantiation_set* s = uvar_inst_sets[m_var_j];
|
||||
SASSERT(s != nullptr);
|
||||
|
|
@ -1369,6 +1377,74 @@ namespace smt {
|
|||
|
||||
};
|
||||
|
||||
class ho_var : public qinfo {
|
||||
unsigned m_var_i;
|
||||
public:
|
||||
ho_var(ast_manager& m, unsigned i) : qinfo(m), m_var_i(i) {
|
||||
}
|
||||
|
||||
char const *get_kind() const override {
|
||||
return "ho_var";
|
||||
}
|
||||
|
||||
bool is_equal(qinfo const *qi) const override {
|
||||
if (qi->get_kind() != get_kind())
|
||||
return false;
|
||||
ho_var const *other = static_cast<ho_var const *>(qi);
|
||||
return m_var_i == other->m_var_i;
|
||||
}
|
||||
|
||||
void display(std::ostream &out) const override {
|
||||
out << "(" << "ho-var: " << m_var_i << ")";
|
||||
}
|
||||
|
||||
void process_auf(quantifier *q, auf_solver &s, context *ctx) override {
|
||||
/* node * S_i = */ s.get_uvar(q, m_var_i);
|
||||
}
|
||||
|
||||
void populate_inst_sets(quantifier *q, auf_solver &s, context *ctx) override {
|
||||
node *S = s.get_uvar(q, m_var_i);
|
||||
sort *srt = S->get_sort();
|
||||
|
||||
IF_VERBOSE(3, verbose_stream() << "ho_var::populate_inst_sets: " << q->get_id() << " " << mk_pp(srt, m) << "\n";);
|
||||
term_enumeration tn(m);
|
||||
// Add ground terms of type S.
|
||||
// Add productions for functions in E-graph
|
||||
// add other possible relevant functions such as equality over srt, Boolean operators
|
||||
|
||||
ast_mark visited;
|
||||
for (enode *n : ctx->enodes()) {
|
||||
if (!ctx->is_relevant(n))
|
||||
continue;
|
||||
auto e = n->get_expr();
|
||||
if (srt == n->get_sort()) {
|
||||
TRACE(model_finder, tout << "inserting " << mk_pp(e, m) << " into inst set\n");
|
||||
S->insert(e, n->get_generation());
|
||||
}
|
||||
else if (is_uninterp_const(e)) {
|
||||
TRACE(model_finder, tout << "add production " << mk_pp(e, m) << "\n");
|
||||
tn.add_production(e);
|
||||
}
|
||||
else if (is_uninterp(e)) {
|
||||
auto f = to_app(e)->get_decl();
|
||||
if (visited.is_marked(f))
|
||||
continue;
|
||||
visited.mark(f, true);
|
||||
TRACE(model_finder, tout << "add function " << mk_pp(f, m) << "\n");
|
||||
tn.add_production(f);
|
||||
}
|
||||
}
|
||||
|
||||
unsigned max_count = 20;
|
||||
for (auto t : tn.enum_terms(srt)) {
|
||||
unsigned generation = 0; // todo - inherited from sub-term of t?
|
||||
TRACE(model_finder, tout << "ho_var: adding term " << mk_ismt2_pp(t, m)
|
||||
<< " to instantiation set of S" << std::endl;);
|
||||
S->insert(t, generation);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
/**
|
||||
\brief auf_arr is a term (pattern) of the form:
|
||||
|
|
@ -2105,7 +2181,12 @@ namespace smt {
|
|||
process_app(to_app(curr));
|
||||
}
|
||||
else if (is_var(curr)) {
|
||||
m_info->m_is_auf = false; // unexpected occurrence of variable.
|
||||
if (m_array_util.is_array(curr)) {
|
||||
insert_qinfo(alloc(ho_var, m, to_var(curr)->get_idx()));
|
||||
}
|
||||
else {
|
||||
m_info->m_is_auf = false; // unexpected occurrence of variable.
|
||||
}
|
||||
}
|
||||
else {
|
||||
SASSERT(is_lambda(curr));
|
||||
|
|
@ -2520,11 +2601,12 @@ namespace smt {
|
|||
|
||||
Store in generation the generation of the result
|
||||
*/
|
||||
expr* model_finder::get_inv(quantifier* q, unsigned i, expr* val, unsigned& generation) {
|
||||
expr* model_finder::get_inv(quantifier* q, unsigned i, expr* val, model& mdl,unsigned& generation) {
|
||||
instantiation_set const* s = get_uvar_inst_set(q, i);
|
||||
if (s == nullptr)
|
||||
return nullptr;
|
||||
expr* t = s->get_inv(val);
|
||||
|
||||
expr* t = s->get_inv(val, mdl);
|
||||
if (m_auf_solver->is_default_representative(t))
|
||||
return val;
|
||||
if (t != nullptr) {
|
||||
|
|
@ -2560,16 +2642,27 @@ namespace smt {
|
|||
obj_map<expr, expr*> const& inv = s->get_inv_map();
|
||||
if (inv.empty())
|
||||
continue; // nothing to do
|
||||
ptr_buffer<expr> eqs;
|
||||
for (auto const& [val, _] : inv) {
|
||||
if (val->get_sort() == sk->get_sort())
|
||||
eqs.push_back(m.mk_eq(sk, val));
|
||||
expr_ref_vector eqs(m), defs(m);
|
||||
|
||||
for (auto const& [val, term] : inv) {
|
||||
if (val->get_sort() == sk->get_sort()) {
|
||||
if (is_lambda(term)) {
|
||||
eqs.push_back(m.mk_eq(sk, val));
|
||||
defs.push_back(m.mk_eq(val, term));
|
||||
}
|
||||
else
|
||||
eqs.push_back(m.mk_eq(sk, val));
|
||||
}
|
||||
}
|
||||
if (!eqs.empty()) {
|
||||
expr_ref new_cnstr(m);
|
||||
new_cnstr = m.mk_or(eqs);
|
||||
TRACE(model_finder, tout << "assert_restriction:\n" << mk_pp(new_cnstr, m) << "\n";);
|
||||
aux_ctx->assert_expr(new_cnstr);
|
||||
for (auto def : defs) {
|
||||
TRACE(model_finder, tout << "assert_def:\n" << mk_pp(def, m) << "\n";);
|
||||
aux_ctx->assert_expr(def);
|
||||
}
|
||||
asserted_something = true;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -113,7 +113,7 @@ namespace smt {
|
|||
void fix_model(proto_model * m);
|
||||
|
||||
quantifier * get_flat_quantifier(quantifier * q);
|
||||
expr * get_inv(quantifier * q, unsigned i, expr * val, unsigned & generation);
|
||||
expr * get_inv(quantifier * q, unsigned i, expr * val, model& m, unsigned & generation);
|
||||
bool restrict_sks_to_inst_set(context * aux_ctx, quantifier * q, expr_ref_vector const & sks);
|
||||
|
||||
void restart_eh();
|
||||
|
|
|
|||
|
|
@ -145,6 +145,7 @@ add_executable(test-z3
|
|||
symbol.cpp
|
||||
symbol_table.cpp
|
||||
tbv.cpp
|
||||
term_enumeration.cpp
|
||||
theory_dl.cpp
|
||||
theory_pb.cpp
|
||||
timeout.cpp
|
||||
|
|
|
|||
|
|
@ -195,6 +195,7 @@
|
|||
X(finite_set) \
|
||||
X(finite_set_rewriter) \
|
||||
X(fpa) \
|
||||
X(term_enumeration) \
|
||||
X(lcube)
|
||||
|
||||
#define FOR_EACH_TEST(X, X_ARGV) \
|
||||
|
|
|
|||
309
src/test/term_enumeration.cpp
Normal file
309
src/test/term_enumeration.cpp
Normal file
|
|
@ -0,0 +1,309 @@
|
|||
/*++
|
||||
Copyright (c) 2024 Microsoft Corporation
|
||||
|
||||
Module Name:
|
||||
|
||||
tst_term_enumeration.cpp
|
||||
|
||||
Abstract:
|
||||
|
||||
Test term enumeration module
|
||||
|
||||
--*/
|
||||
|
||||
|
||||
#include "ast/rewriter/term_enumeration.h"
|
||||
#include "ast/ast_pp.h"
|
||||
#include "ast/arith_decl_plugin.h"
|
||||
#include "ast/bv_decl_plugin.h"
|
||||
#include "ast/array_decl_plugin.h"
|
||||
#include "ast/reg_decl_plugins.h"
|
||||
#include "ast/rewriter/th_rewriter.h"
|
||||
#include "util/obj_hashtable.h"
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
static void tst_basic_enumeration() {
|
||||
std::cout << "=== test basic enumeration ===\n";
|
||||
ast_manager m;
|
||||
reg_decl_plugins(m);
|
||||
arith_util a(m);
|
||||
|
||||
term_enumeration te(m);
|
||||
|
||||
// Add some leaf productions (constants)
|
||||
expr_ref zero(a.mk_int(0), m);
|
||||
expr_ref one(a.mk_int(1), m);
|
||||
te.add_production(zero);
|
||||
te.add_production(one);
|
||||
|
||||
// Enumerate terms of Int sort
|
||||
sort* int_sort = a.mk_int();
|
||||
unsigned count = 0;
|
||||
for (expr* e : te.enum_terms(int_sort)) {
|
||||
std::cout << "Term: " << mk_pp(e, m) << "\n";
|
||||
count++;
|
||||
if (count >= 5) break; // Limit output
|
||||
}
|
||||
|
||||
ENSURE(count >= 2); // At least 0 and 1
|
||||
std::cout << "Enumerated " << count << " terms\n";
|
||||
}
|
||||
|
||||
static void tst_enumeration_with_operators() {
|
||||
std::cout << "=== test enumeration with operators ===\n";
|
||||
ast_manager m;
|
||||
reg_decl_plugins(m);
|
||||
arith_util a(m);
|
||||
|
||||
term_enumeration te(m);
|
||||
|
||||
// Add leaf productions
|
||||
expr_ref zero(a.mk_int(0), m);
|
||||
expr_ref one(a.mk_int(1), m);
|
||||
te.add_production(zero);
|
||||
te.add_production(one);
|
||||
|
||||
// Add operator productions (+ and *)
|
||||
// Get func_decl by creating an app and extracting the decl
|
||||
app_ref tmp_add(a.mk_add(zero, one), m);
|
||||
app_ref tmp_mul(a.mk_mul(zero, one), m);
|
||||
func_decl* add_decl = tmp_add->get_decl();
|
||||
func_decl* mul_decl = tmp_mul->get_decl();
|
||||
te.add_production(add_decl);
|
||||
te.add_production(mul_decl);
|
||||
|
||||
sort* int_sort = a.mk_int();
|
||||
unsigned count = 0;
|
||||
for (expr* e : te.enum_terms(int_sort)) {
|
||||
std::cout << "Term: " << mk_pp(e, m) << "\n";
|
||||
count++;
|
||||
if (count >= 20) break; // Limit output
|
||||
}
|
||||
|
||||
ENSURE(count >= 2); // At least the leaves
|
||||
std::cout << "Enumerated " << count << " terms with operators\n";
|
||||
}
|
||||
|
||||
static void tst_observational_equivalence_filter() {
|
||||
std::cout << "=== test observational equivalence filter ===\n";
|
||||
ast_manager m;
|
||||
reg_decl_plugins(m);
|
||||
arith_util a(m);
|
||||
th_rewriter rw(m);
|
||||
|
||||
term_enumeration te(m);
|
||||
|
||||
expr_ref zero(a.mk_int(0), m);
|
||||
expr_ref one(a.mk_int(1), m);
|
||||
te.add_production(zero);
|
||||
te.add_production(one);
|
||||
|
||||
app_ref tmp_add(a.mk_add(zero, one), m);
|
||||
te.add_production(tmp_add->get_decl());
|
||||
|
||||
sort* int_sort = a.mk_int();
|
||||
obj_hashtable<expr> seen;
|
||||
unsigned count = 0;
|
||||
for (expr* e : te.enum_terms(int_sort)) {
|
||||
expr_ref r(m);
|
||||
rw(e, r);
|
||||
ENSURE(r == e);
|
||||
ENSURE(!seen.contains(r));
|
||||
seen.insert(r);
|
||||
count++;
|
||||
if (count >= 20) break;
|
||||
}
|
||||
|
||||
ENSURE(count >= 2);
|
||||
}
|
||||
|
||||
static void tst_display() {
|
||||
std::cout << "=== test display ===\n";
|
||||
ast_manager m;
|
||||
reg_decl_plugins(m);
|
||||
arith_util a(m);
|
||||
|
||||
term_enumeration te(m);
|
||||
|
||||
// Add leaf productions
|
||||
expr_ref zero(a.mk_int(0), m);
|
||||
expr_ref one(a.mk_int(1), m);
|
||||
te.add_production(zero);
|
||||
te.add_production(one);
|
||||
|
||||
// Add operator productions
|
||||
app_ref tmp_add(a.mk_add(zero, one), m);
|
||||
func_decl* add_decl = tmp_add->get_decl();
|
||||
te.add_production(add_decl);
|
||||
|
||||
sort* int_sort = a.mk_int();
|
||||
unsigned count = 0;
|
||||
for (expr* e : te.enum_terms(int_sort)) {
|
||||
(void)e;
|
||||
count++;
|
||||
if (count >= 10) break;
|
||||
}
|
||||
|
||||
std::cout << "Internal state after enumeration:\n";
|
||||
std::ostringstream oss;
|
||||
te.display(oss);
|
||||
std::cout << oss.str();
|
||||
|
||||
// Verify display produced some output
|
||||
ENSURE(!oss.str().empty());
|
||||
}
|
||||
|
||||
static void tst_bitvector_enumeration() {
|
||||
std::cout << "=== test bitvector enumeration ===\n";
|
||||
ast_manager m;
|
||||
reg_decl_plugins(m);
|
||||
bv_util bv(m);
|
||||
|
||||
term_enumeration te(m);
|
||||
|
||||
// Add bitvector constants
|
||||
unsigned bv_size = 8;
|
||||
expr_ref bv_zero(bv.mk_numeral(0, bv_size), m);
|
||||
expr_ref bv_one(bv.mk_numeral(1, bv_size), m);
|
||||
te.add_production(bv_zero);
|
||||
te.add_production(bv_one);
|
||||
|
||||
// Add bvadd operator
|
||||
app_ref tmp_add(bv.mk_bv_add(bv_zero, bv_one), m);
|
||||
func_decl* bvadd = tmp_add->get_decl();
|
||||
te.add_production(bvadd);
|
||||
|
||||
sort* bv8 = bv.mk_sort(bv_size);
|
||||
unsigned count = 0;
|
||||
for (expr* e : te.enum_terms(bv8)) {
|
||||
std::cout << "BV Term: " << mk_pp(e, m) << "\n";
|
||||
count++;
|
||||
if (count >= 10) break;
|
||||
}
|
||||
|
||||
ENSURE(count >= 2);
|
||||
std::cout << "Enumerated " << count << " bitvector terms\n";
|
||||
}
|
||||
|
||||
static void tst_multiple_sorts() {
|
||||
std::cout << "=== test multiple sorts ===\n";
|
||||
ast_manager m;
|
||||
reg_decl_plugins(m);
|
||||
arith_util a(m);
|
||||
|
||||
term_enumeration te(m);
|
||||
|
||||
// Add Int constants
|
||||
expr_ref i_zero(a.mk_int(0), m);
|
||||
expr_ref i_one(a.mk_int(1), m);
|
||||
te.add_production(i_zero);
|
||||
te.add_production(i_one);
|
||||
|
||||
// Add Real constants
|
||||
expr_ref r_zero(a.mk_real(0), m);
|
||||
expr_ref r_one(a.mk_real(1), m);
|
||||
te.add_production(r_zero);
|
||||
te.add_production(r_one);
|
||||
|
||||
// Enumerate Int terms
|
||||
sort* int_sort = a.mk_int();
|
||||
unsigned int_count = 0;
|
||||
for (expr* e : te.enum_terms(int_sort)) {
|
||||
std::cout << "Int Term: " << mk_pp(e, m) << "\n";
|
||||
int_count++;
|
||||
if (int_count >= 5) break;
|
||||
}
|
||||
|
||||
ENSURE(int_count >= 2);
|
||||
std::cout << "Enumerated " << int_count << " Int terms\n";
|
||||
}
|
||||
|
||||
static void tst_nested_array_enumeration() {
|
||||
std::cout << "=== test nested array enumeration (Array(A, Array(B, A))) ===\n";
|
||||
ast_manager m;
|
||||
reg_decl_plugins(m);
|
||||
array_util arr(m);
|
||||
|
||||
term_enumeration te(m);
|
||||
|
||||
// Create uninterpreted sorts A and B
|
||||
sort_ref sort_A(m.mk_uninterpreted_sort(symbol("A")), m);
|
||||
sort_ref sort_B(m.mk_uninterpreted_sort(symbol("B")), m);
|
||||
|
||||
// Create nested array sort: Array(B, A) - arrays indexed by B returning A
|
||||
sort_ref array_B_A(arr.mk_array_sort(sort_B, sort_A), m);
|
||||
|
||||
// Create outer array sort: Array(A, Array(B, A)) - arrays indexed by A returning Array(B,A)
|
||||
sort_ref array_A_arrayBA(arr.mk_array_sort(sort_A, array_B_A), m);
|
||||
|
||||
std::cout << "Sort A: " << mk_pp(sort_A.get(), m) << "\n";
|
||||
std::cout << "Sort B: " << mk_pp(sort_B.get(), m) << "\n";
|
||||
std::cout << "Sort Array(B, A): " << mk_pp(array_B_A.get(), m) << "\n";
|
||||
std::cout << "Sort Array(A, Array(B, A)): " << mk_pp(array_A_arrayBA.get(), m) << "\n";
|
||||
|
||||
// Add constants of sort A
|
||||
app_ref a0(m.mk_const(symbol("a0"), sort_A), m);
|
||||
app_ref a1(m.mk_const(symbol("a1"), sort_A), m);
|
||||
te.add_production(a0);
|
||||
te.add_production(a1);
|
||||
|
||||
// Add constants of sort B
|
||||
app_ref b0(m.mk_const(symbol("b0"), sort_B), m);
|
||||
app_ref b1(m.mk_const(symbol("b1"), sort_B), m);
|
||||
te.add_production(b0);
|
||||
te.add_production(b1);
|
||||
|
||||
// Add a constant array of inner type Array(B, A) - const_array(a0) : Array(B, A)
|
||||
app_ref const_inner(arr.mk_const_array(array_B_A, a0), m);
|
||||
te.add_production(const_inner);
|
||||
|
||||
// Add a constant array of outer type Array(A, Array(B, A))
|
||||
app_ref const_outer(arr.mk_const_array(array_A_arrayBA, const_inner), m);
|
||||
te.add_production(const_outer);
|
||||
|
||||
// Add store operator for the inner array type Array(B, A)
|
||||
// store(array, index, value) : store(Array(B,A), B, A) -> Array(B,A)
|
||||
expr* store_inner_args[3] = { const_inner.get(), b0.get(), a0.get() };
|
||||
app_ref tmp_store_inner(arr.mk_store(3, store_inner_args), m);
|
||||
func_decl* store_inner_decl = tmp_store_inner->get_decl();
|
||||
te.add_production(store_inner_decl);
|
||||
|
||||
// Add store operator for the outer array type Array(A, Array(B, A))
|
||||
// store(array, index, value) : store(Array(A, Array(B,A)), A, Array(B,A)) -> Array(A, Array(B,A))
|
||||
expr* store_outer_args[3] = { const_outer.get(), a0.get(), const_inner.get() };
|
||||
app_ref tmp_store_outer(arr.mk_store(3, store_outer_args), m);
|
||||
func_decl* store_outer_decl = tmp_store_outer->get_decl();
|
||||
te.add_production(store_outer_decl);
|
||||
|
||||
// Add select operator for the outer array (returns Array(B, A))
|
||||
// select(Array(A, Array(B,A)), A) -> Array(B, A)
|
||||
app_ref tmp_select_outer(arr.mk_select(const_outer.get(), a0.get()), m);
|
||||
func_decl* select_outer_decl = tmp_select_outer->get_decl();
|
||||
te.add_production(select_outer_decl);
|
||||
|
||||
// Enumerate terms of the nested array sort Array(A, Array(B, A))
|
||||
std::cout << "\nEnumerating terms of sort Array(A, Array(B, A)):\n";
|
||||
unsigned count = 0;
|
||||
for (expr* e : te.enum_terms(array_A_arrayBA)) {
|
||||
std::cout << " Term " << count << ": " << mk_pp(e, m) << "\n";
|
||||
count++;
|
||||
if (count >= 15) break; // Limit output
|
||||
}
|
||||
|
||||
ENSURE(count >= 1); // At least the constant array
|
||||
std::cout << "Enumerated " << count << " terms of sort Array(A, Array(B, A))\n";
|
||||
|
||||
te.display(std::cout);
|
||||
}
|
||||
|
||||
void tst_term_enumeration() {
|
||||
tst_basic_enumeration();
|
||||
tst_enumeration_with_operators();
|
||||
tst_observational_equivalence_filter();
|
||||
tst_display();
|
||||
tst_bitvector_enumeration();
|
||||
tst_multiple_sorts();
|
||||
tst_nested_array_enumeration();
|
||||
std::cout << "All term_enumeration tests passed!\n";
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue