3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2026-06-22 16:40:29 +00:00

Term enumeration (#9908)

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
Signed-off-by: dependabot[bot] <support@github.com>
Signed-off-by: Lev Nachmanson <levnach@hotmail.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Copilot <198982749+Copilot@users.noreply.github.com>
Co-authored-by: davedets <daviddetlefs@gmail.com>
Co-authored-by: Lev Nachmanson <5377127+levnach@users.noreply.github.com>
Co-authored-by: Claude Fable 5 <noreply@anthropic.com>
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Co-authored-by: Margus Veanes <veanes@users.noreply.github.com>
Co-authored-by: Nuno Lopes <nuno.lopes@tecnico.ulisboa.pt>
Co-authored-by: Shantanu Gontia <gontia.shantanu@gmail.com>
Co-authored-by: Peter Chen J. <34339487+peter941221@users.noreply.github.com>
Co-authored-by: Alcides Fonseca <me@alcidesfonseca.com>
Co-authored-by: Can Cebeci <can.cebeci99@gmail.com>
Co-authored-by: Can Cebeci <t-cancebeci@microsoft.com>
This commit is contained in:
Nikolaj Bjorner 2026-06-20 18:14:44 -06:00 committed by GitHub
parent b9cc87ae4b
commit 5699142f5b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 1156 additions and 33 deletions

View file

@ -43,6 +43,7 @@ z3_add_component(rewriter
seq_rewriter.cpp
seq_regex_bisim.cpp
seq_skolem.cpp
term_enumeration.cpp
th_rewriter.cpp
value_sweep.cpp
var_subst.cpp

View file

@ -0,0 +1,674 @@
/**
* term_enumeration.cpp - Bottom-up term enumeration module for Z3
*
* Inspired by the Probe synthesizer (Barke et al., "Just-in-Time Learning
* for Bottom-Up Enumerative Synthesis"). Adapted to use Z3's internal APIs.
*
* Key ideas:
* - Terms are enumerated bottom-up by "cost" (calculated by tree size).
* - A grammar describes which function symbols (operators) and leaves
* (constants, variables) are available for enumeration.
*/
#include <sstream>
#include <functional>
#include <string>
#include "util/vector.h"
#include "util/scoped_ptr_vector.h"
#include "util/obj_hashtable.h"
#include "util/uint_set.h"
#include "ast/ast.h"
#include "ast/ast_ll_pp.h"
#include "ast/ast_pp.h"
#include "ast/rewriter/th_rewriter.h"
#include "ast/rewriter/term_enumeration.h"
namespace term_enum {
// ============================================================================
// grammar production rule
// ============================================================================
/**
* A production describes how to construct a term from child terms.
* - domain: the sort required for each child
* - range: the sort of the produced term
* - builder: given a vector of child exprs, produce the result expr
*/
struct production {
std::string name;
sort_ref range;
sort_ref_vector domain;
std::function<expr_ref(expr_ref_vector const&)> builder;
bool is_leaf() const { return domain.empty(); }
};
// ============================================================================
// grammar
// ============================================================================
/**
* A grammar groups productions into leaves (arity 0) and operators (arity > 0).
*/
class grammar {
public:
grammar(ast_manager& m) : m(m), m_pinned(m) {}
void add_production(production* p) {
if (p->is_leaf())
m_leaves.push_back(p);
else
m_operators.push_back(p);
}
scoped_ptr_vector<production> const& leaves() const { return m_leaves; }
scoped_ptr_vector<production> const& operators() const { return m_operators; }
ast_manager& mgr() const { return m; }
void add_func_decl(func_decl *f) {
if (m_seen.contains(f))
return;
m_pinned.push_back(f);
m_seen.insert(f);
sort_ref range(f->get_range(), m);
sort_ref_vector dom(m);
for (unsigned i = 0; i < f->get_arity(); ++i)
dom.push_back(sort_ref(f->get_domain(i), m));
add_production(alloc(production, {f->get_name().str(), range, dom, [this, f](expr_ref_vector const &args) {
return expr_ref(m.mk_app(f, args), m);
}}));
}
void add_expr(expr *e) {
if (m_seen.contains(e))
return;
m_pinned.push_back(e);
m_seen.insert(e);
sort_ref range(e->get_sort(), m);
sort_ref_vector dom(m);
std::stringstream ss;
ss << mk_bounded_pp(e, m);
std::string name = ss.str();
add_production(alloc(production, {name, range, dom, [this, e](expr_ref_vector const&) { return expr_ref(e, m); }}));
}
std::ostream& display(std::ostream& out) const {
out << "Leaves:\n";
for (auto const *p : m_leaves) {
out << " " << p->name << " : " << mk_pp(p->range, m) << "\n";
}
out << "Operators:\n";
for (auto const *p : m_operators) {
out << " " << p->name << " : (";
for (unsigned i = 0; i < p->domain.size(); ++i) {
if (i > 0)
out << ", ";
out << mk_pp(p->domain[i], m);
}
out << ") -> " << mk_pp(p->range, m) << "\n";
}
return out;
}
private:
ast_manager& m;
ast_ref_vector m_pinned;
scoped_ptr_vector<production> m_leaves;
scoped_ptr_vector<production> m_operators;
obj_hashtable<ast> m_seen;
};
// ============================================================================
// Term Bank - stores enumerated terms by cost and sort
// ============================================================================
using cost_terms = vector<std::pair<expr*, unsigned>>;
class term_bank {
using sort_term_map = obj_map<sort, ptr_vector<expr>>;
public:
term_bank(ast_manager& m) : m(m), m_pinned(m) {}
~term_bank() {
for (auto s : m_terms)
dealloc(s);
}
void reset() {
m_pinned.reset();
m_terms.clear();
}
void add(expr* term, unsigned cost) {
sort* s = term->get_sort();
m_pinned.push_back(term);
if (cost >= m_terms.size())
m_terms.resize(cost + 1);
if (!m_terms[cost])
m_terms[cost] = alloc(sort_term_map);
m_terms[cost]->insert_if_not_there(s, ptr_vector<expr>()).push_back(term);
}
/** Get all terms of a given sort up to (and including) max_cost */
cost_terms get_by_sort(sort* s, unsigned max_cost) const {
cost_terms result;
for (unsigned c = 0; c <= max_cost; ++c) {
if (c >= m_terms.size())
break;
if (!m_terms[c]->contains(s))
continue;
for (auto t : m_terms[c]->find(s))
result.push_back({t, c});
}
return result;
}
// Return true if there is at least one term at/above `cost` whose sort is
// not in `sorts` (i.e., enumeration can still produce a new requested sort).
bool is_productive(unsigned cost, uint_set const& sorts) {
for (unsigned i = cost; i < m_terms.size(); ++i) {
if (!m_terms[i])
continue;
for (auto const& entry : *m_terms[i]) {
sort* term_sort = entry.m_key;
if (!sorts.contains(term_sort->get_small_id()))
return true;
}
}
return false;
}
ptr_vector<expr> null_ptr_vector;
ptr_vector<expr> const &get_by_cost_and_sort(unsigned cost, sort *s) const {
if (cost >= m_terms.size() || !m_terms[cost] || !m_terms[cost]->contains(s))
return null_ptr_vector;
return m_terms[cost]->find(s);
}
std::ostream& display(std::ostream& out) const {
for (unsigned cost = 0; cost < m_terms.size(); ++cost) {
if (!m_terms[cost])
continue;
out << "cost " << cost << ":\n";
for (auto& [s, terms] : *m_terms[cost]) {
out << " sort " << mk_pp(s, m) << ":\n";
for (expr* e : terms) {
out << " #" << e->get_id() << " ";
if (cost == 0) {
out << mk_bounded_pp(e, m);
}
else if (is_app(e)) {
app* a = to_app(e);
out << a->get_decl()->get_name() << "(";
bool first = true;
for (expr* arg : *a) {
if (!first) out << ", ";
first = false;
out << "#" << arg->get_id();
}
out << ")";
}
out << "\n";
}
}
}
return out;
}
private:
ast_manager& m;
expr_ref_vector m_pinned;
// cost -> sort -> terms
ptr_vector<sort_term_map> m_terms;
};
// ============================================================================
// Children Iterator - generates all combinations of child terms
// ============================================================================
/**
* Iterates over all tuples (c1, c2, ..., cn) where each ci has the required
* sort, drawn from the term bank, with at least one child at the current
* cost - 1 (to avoid regenerating previously seen terms).
*/
class children_iterator {
public:
children_iterator(ast_manager& m, production const& prod, term_bank const& bank, unsigned current_cost)
: m(m), m_prod(prod), m_current_cost(current_cost), m_done(false)
{
m_arity = prod.domain.size();
if (m_arity == 0) {
m_done = true;
return;
}
for (unsigned i = 0; i < m_arity; ++i) {
m_candidates.push_back(bank.get_by_sort(prod.domain[i], current_cost - 1));
if (m_candidates.back().empty()) {
m_done = true;
return;
}
}
m_indices.resize(m_arity, 0);
}
bool has_next(unsigned cost) {
while (!m_done) {
if (has_child_at_cost(cost))
return true;
advance();
}
return false;
}
expr_ref_vector next(unsigned& cost) {
expr_ref_vector result(m);
cost = 1;
for (unsigned i = 0; i < m_arity; ++i) {
auto [e, c] = m_candidates[i].get(m_indices[i]);
cost += c;
result.push_back(e);
}
advance();
return result;
}
private:
ast_manager& m;
production const& m_prod;
unsigned m_current_cost;
unsigned m_arity;
bool m_done;
vector<cost_terms> m_candidates;
svector<unsigned> m_indices;
bool has_child_at_cost(unsigned cost) const {
for (unsigned i = 0; i < m_arity; ++i) {
auto [e, c] = m_candidates[i].get(m_indices[i]);
if (c + 1 == cost)
return true;
}
return false;
}
void advance() {
for (auto i = m_arity; i-- > 0;) {
m_indices[i]++;
if (m_indices[i] < m_candidates[i].size()) return;
m_indices[i] = 0;
}
m_done = true;
}
};
// ============================================================================
// bottom_up_enumerator - the main bottom-up term enumeration engine
// ============================================================================
class bottom_up_enumerator {
public:
bottom_up_enumerator(grammar& grammar)
: m_grammar(grammar), m(grammar.mgr()),
m_bank(grammar.mgr()), m_pending(grammar.mgr()), m_rewriter(grammar.mgr())
{}
void set_target_sort(sort *s) {
m_target_sort = s;
}
bool has_next() {
if (m_pending) return true;
m_pending = find_next();
return m_pending != nullptr;
}
expr_ref next() {
if (!m_pending)
m_pending = find_next();
expr_ref result(m_pending, m);
m_pending = nullptr;
return result;
}
term_bank const& bank() const { return m_bank; }
std::ostream& display(std::ostream& out) const {
m_grammar.display(out);
return m_bank.display(out);
}
void reset() {
m_cost = 0;
m_leaf_idx = 0;
m_op_idx = 0;
m_state = State::Leaves;
m_bank.reset();
m_pending = nullptr;
m_rewriter.reset();
m_seen_terms.reset();
m_children_iter.reset();
}
expr* add_term(expr_ref const& term, unsigned cost) {
expr_ref simplified(m);
m_rewriter(term, simplified);
if (m_seen_terms.contains(simplified))
return nullptr;
m_seen_terms.insert(simplified);
m_bank.add(simplified, cost);
return simplified;
}
private:
enum class State { Leaves, Operators, Done };
grammar& m_grammar;
ast_manager& m;
term_bank m_bank;
unsigned m_cost = 0;
unsigned m_leaf_idx = 0;
unsigned m_op_idx = 0;
unsigned m_bank_idx = 0;
unsigned m_bank_size = 0;
bool m_made_progress = false;
uint_set m_sorts_produced;
State m_state = State::Leaves;
expr_ref m_pending;
th_rewriter m_rewriter;
obj_hashtable<expr> m_seen_terms;
std::unique_ptr<children_iterator> m_children_iter;
sort *m_target_sort = nullptr;
bool sort_matches(expr* e) const {
return !m_target_sort || e->get_sort() == m_target_sort;
}
expr* find_next() {
while (true) {
switch (m_state) {
case State::Leaves:
while (m_leaf_idx < m_grammar.leaves().size()) {
production const &prod = *m_grammar.leaves()[m_leaf_idx];
m_leaf_idx++;
expr_ref_vector empty_args(m);
expr_ref term = prod.builder(empty_args);
expr* r = add_term(term, 0);
if (r && sort_matches(r))
return r;
}
m_state = State::Operators;
m_cost = 1;
m_op_idx = 0;
m_bank_idx = 0;
m_bank_size = get_bank_size();
m_made_progress = false;
m_sorts_produced.reset();
m_children_iter.reset();
break;
case State::Operators: {
expr* result = enumerate_operators();
if (result)
return result;
m_cost++;
m_op_idx = 0;
m_bank_idx = 0;
m_bank_size = get_bank_size();
m_children_iter.reset();
if (!m_made_progress && !m_bank.is_productive(m_cost, m_sorts_produced)) {
m_state = State::Done;
return nullptr;
}
if (m_sorts_produced.contains(m_target_sort->get_small_id()))
m_sorts_produced.reset();
m_made_progress = false;
break;
}
case State::Done:
return nullptr;
}
}
}
unsigned get_bank_size() const {
auto const &terms = m_bank.get_by_cost_and_sort(m_cost, m_target_sort);
return terms.size();
}
expr *enumerate_operators() {
auto const &ops = m_grammar.operators();
while (true) {
// first find terms at m_cost that were already created
if (m_bank_idx < m_bank_size) {
auto const &terms = m_bank.get_by_cost_and_sort(m_cost, m_target_sort);
auto t = terms.get(m_bank_idx);
m_bank_idx++;
SASSERT(sort_matches(t));
return t;
}
// then create new terms using children at cost below current m_cost.
if (m_children_iter && m_children_iter->has_next(m_cost)) {
unsigned new_cost = 0;
expr_ref_vector children = m_children_iter->next(new_cost);
production const &prod = *ops[m_op_idx - 1];
expr_ref term = prod.builder(children);
// IF_VERBOSE(0, verbose_stream() << term << "\n");
SASSERT(new_cost >= m_cost);
expr* r = add_term(term, new_cost);
if (!r)
continue;
unsigned sort_id = r->get_sort()->get_small_id();
if (!m_sorts_produced.contains(sort_id))
m_made_progress = true;
m_sorts_produced.insert(sort_id);
if (sort_matches(r) && new_cost == m_cost) {
return r;
}
continue;
}
if (m_op_idx >= ops.size())
return nullptr;
production const &prod = *ops[m_op_idx];
m_op_idx++;
m_children_iter = std::make_unique<children_iterator>(m, prod, m_bank, m_cost);
}
}
};
} // namespace term_enum
// ============================================================================
// term_enumeration public interface implementation
// ============================================================================
struct term_enumeration::imp {
ast_manager& m;
term_enum::grammar m_grammar;
term_enum::bottom_up_enumerator m_bottom_up_enumerator;
std::function<unsigned(expr*)> m_cost;
imp(ast_manager& m) :
m(m), m_grammar(m), m_bottom_up_enumerator(m_grammar) {}
void add_production(func_decl* f) {
m_grammar.add_func_decl(f);
}
void add_production(expr* e) {
m_grammar.add_expr(e);
}
void set_cost(std::function<unsigned(expr*)> const& cost) {
// TODO
}
std::ostream& display(std::ostream& out) const {
return m_bottom_up_enumerator.display(out);
}
};
// -- iterator implementation --
struct term_enumeration::iterator::iter_imp {
imp& m_imp;
ast_manager & m;
sort* m_sort;
unsigned m_cost = 0;
unsigned m_idx = 0;
vector<expr_ref_vector> m_levels;
expr_ref m_current;
bool m_end;
vector<expr_ref_vector> m_vars;
vector<ptr_vector<sort>> m_decls;
vector<vector<symbol>> m_names;
iter_imp(imp& i, sort* s) : m_imp(i), m(i.m), m_sort(s), m_current(i.m), m_end(false) {
m_imp.m_bottom_up_enumerator.reset();
init_sort();
advance();
}
// Sentinel constructor
iter_imp(imp& i) :
m_imp(i), m(i.m), m_sort(nullptr), m_current(i.m), m_end(true) {
UNREACHABLE();
}
void init_sort() {
array_util autil(m);
sort *range = m_sort;
while (autil.is_array(range)) {
m_vars.push_back(expr_ref_vector(m));
m_decls.push_back(ptr_vector<sort>());
m_names.push_back(vector<symbol>());
for (unsigned i = 0; i < get_array_arity(range); ++i) {
m_decls.back().push_back(get_array_domain(range, i));
m_vars.back().push_back(nullptr);
m_names.back().push_back(symbol());
}
expr_ref_vector args(m);
args.push_back(m.mk_const("a", range));
for (unsigned i = 0; i < m_decls.back().size(); ++i) {
args.push_back(m.mk_var(i, m_decls.back().get(i)));
}
app_ref sel(autil.mk_select(args), m);
m_imp.m_grammar.add_func_decl(sel->get_decl());
range = get_array_range(range);
}
unsigned n = 0;
for (unsigned i = m_decls.size(); i-- > 0;) {
for (unsigned j = m_decls[i].size(); j-- > 0;) {
m_vars[i][j] = m.mk_var(n, m_decls[i][j]);
m_names[i][j] = symbol(n);
m_imp.add_production(m_vars[i].get(j));
n++;
}
}
m_sort = range;
m_imp.m_bottom_up_enumerator.set_target_sort(range);
}
void mk_lambda() {
if (!m_current)
return;
for (unsigned i = m_decls.size(); i-- > 0;)
m_current = m.mk_lambda(m_decls[i].size(), m_decls[i].data(), m_names[i].data(), m_current);
}
void advance() {
if (m_end)
return;
m_current = m_imp.m_bottom_up_enumerator.next();
SASSERT(!m_current || m_current->get_sort() == m_sort);
mk_lambda();
if (!m_current)
m_end = true;
}
};
term_enumeration::iterator::iterator(imp& i, sort* s) {
m_imp = alloc(iter_imp, i, s);
}
term_enumeration::iterator::iterator(std::nullptr_t) {
m_imp = nullptr;
}
term_enumeration::iterator::~iterator() {
dealloc(m_imp);
}
expr* term_enumeration::iterator::operator*() {
return m_imp ? m_imp->m_current.get() : nullptr;
}
term_enumeration::iterator& term_enumeration::iterator::operator++() {
if (m_imp) m_imp->advance();
return *this;
}
term_enumeration::iterator term_enumeration::iterator::operator++(int) {
iterator tmp(*this);
++(*this);
return tmp;
}
bool term_enumeration::iterator::operator==(iterator const& other) const {
if (!m_imp && !other.m_imp) return true;
if (!m_imp) return other.m_imp->m_end;
if (!other.m_imp) return m_imp->m_end;
return m_imp->m_end == other.m_imp->m_end &&
m_imp->m_current == other.m_imp->m_current;
}
// -- terms implementation --
term_enumeration::terms::terms(imp* i, sort* s) : m_imp(i), m_sort(s) {}
term_enumeration::iterator term_enumeration::terms::begin() {
return iterator(*m_imp, m_sort);
}
term_enumeration::iterator term_enumeration::terms::end() {
return iterator(nullptr);
}
// -- term_enumeration implementation --
term_enumeration::term_enumeration(ast_manager& m) {
m_imp = alloc(imp, m);
}
term_enumeration::~term_enumeration() {
dealloc(m_imp);
}
void term_enumeration::add_production(func_decl* f) {
m_imp->add_production(f);
}
void term_enumeration::add_production(expr* e) {
m_imp->add_production(e);
}
void term_enumeration::set_cost(std::function<unsigned(expr*)> const& cost) {
m_imp->set_cost(cost);
}
term_enumeration::terms term_enumeration::enum_terms(sort* s) {
return terms(m_imp, s);
}
std::ostream& term_enumeration::display(std::ostream& out) const {
return m_imp->display(out);
}

View file

@ -0,0 +1,50 @@
#pragma once
#include "ast/ast.h"
#include <functional>
class term_enumeration {
struct imp;
imp* m_imp;
public:
term_enumeration(ast_manager& m);
~term_enumeration();
void add_production(func_decl* f);
void add_production(expr* e);
// void add_production(sort *s, std::function<expr *()> g);
// cost function associated with expressions.
// terms are enumerated with increasing cost.
void set_cost(std::function<unsigned(expr*)> const& cost);
class iterator {
struct iter_imp;
iter_imp* m_imp;
public:
iterator(imp& i, sort* s);
iterator(std::nullptr_t);
~iterator();
expr* operator*();
iterator operator++(int);
iterator& operator++();
bool operator!=(iterator const& other) const {
return !(*this == other);
}
bool operator==(iterator const &other) const;
};
class terms {
imp* m_imp;
sort* m_sort;
public:
terms(imp* i, sort* s);
iterator begin();
iterator end();
};
terms enum_terms(sort* s);
std::ostream& display(std::ostream& out) const;
};

View file

@ -513,7 +513,7 @@ void non_auf_macro_solver::collect_candidates(ptr_vector<quantifier> const& qs,
TRACE(model_finder, tout << "considering macro for: " << f->get_name() << "\n";
m->display(tout); tout << "\n";);
if (m->is_unconditional() && (!qi->is_auf() || m->get_weight() >= m_mbqi_force_template)) {
full_macros.insert(f, std::make_pair(m, q));
full_macros.insert(f, {m, q});
cond_macros.erase(f);
}
else if (!full_macros.contains(f) && !qi->is_auf())
@ -524,10 +524,8 @@ void non_auf_macro_solver::collect_candidates(ptr_vector<quantifier> const& qs,
}
void non_auf_macro_solver::process_full_macros(obj_map<func_decl, mq_pair> const& full_macros, obj_hashtable<quantifier>& removed) {
for (auto const& kv : full_macros) {
func_decl* f = kv.m_key;
cond_macro* m = kv.m_value.first;
quantifier* q = kv.m_value.second;
for (auto const &[f, v] : full_macros) {
auto [m, q] = v;
SASSERT(m->is_unconditional());
if (add_macro(f, m->get_def())) {
get_qinfo(q)->set_the_one(f);

View file

@ -219,7 +219,7 @@ namespace smt {
if (use_inv) {
unsigned sk_term_gen = 0;
expr * sk_term = m_model_finder.get_inv(q, i, sk_value, sk_term_gen);
expr * sk_term = m_model_finder.get_inv(q, i, sk_value, *cex, sk_term_gen);
if (sk_term != nullptr) {
TRACE(model_checker, tout << "Found inverse " << mk_pp(sk_term, m) << "\n";);
SASSERT(!m.is_model_value(sk_term));
@ -238,10 +238,6 @@ namespace smt {
TRACE(model_checker, tout << "sk term " << mk_pp(sk_term, m) << "\n");
sk_value = sk_term;
}
// last ditch: am I an array?
else if (false && autil.is_as_array(sk_value, f) && cex->get_func_interp(f) && cex->get_func_interp(f)->get_array_interp(f)) {
sk_value = cex->get_func_interp(f)->get_array_interp(f);
}
}
if (contains_model_value(sk_value)) {

View file

@ -18,6 +18,7 @@ Revision History:
--*/
#include "util/backtrackable_set.h"
#include "ast/ast_util.h"
#include "ast/has_free_vars.h"
#include "ast/macros/macro_util.h"
#include "ast/arith_decl_plugin.h"
#include "ast/bv_decl_plugin.h"
@ -31,6 +32,7 @@ Revision History:
#include "ast/ast_ll_pp.h"
#include "ast/well_sorted.h"
#include "ast/ast_smt2_pp.h"
#include "ast/rewriter/term_enumeration.h"
#include "model/model_pp.h"
#include "model/model_macro_solver.h"
#include "smt/smt_model_finder.h"
@ -107,9 +109,15 @@ namespace smt {
}
}
expr* get_inv(expr* v) const {
expr* get_inv(expr* v, model& mdl) const {
expr* t = nullptr;
m_inv.find(v, t);
if (!t) {
for (auto [k, term] : m_inv) {
if (mdl.are_equal(k, v))
return term;
}
}
return t;
}
@ -120,14 +128,11 @@ namespace smt {
}
void mk_inverse(evaluator& ev) {
for (auto const& kv : m_elems) {
expr* t = kv.m_key;
for (auto const &[t, gen] : m_elems) {
SASSERT(!contains_model_value(t));
unsigned gen = kv.m_value;
expr* t_val = ev.eval(t, true);
if (!t_val) break;
TRACE(model_finder, tout << mk_pp(t, m) << " " << mk_pp(t_val, m) << "\n";);
expr* old_t = nullptr;
if (m_inv.find(t_val, old_t)) {
unsigned old_t_gen = 0;
@ -187,14 +192,14 @@ namespace smt {
\brief Base class used to solve model construction constraints.
*/
class node {
unsigned m_id;
node* m_find{ nullptr };
unsigned m_eqc_size{ 1 };
unsigned m_id = 0;
node* m_find = nullptr;
unsigned m_eqc_size = 1;
sort* m_sort; // sort of the elements in the instantiation set.
sort* m_sort = nullptr; // sort of the elements in the instantiation set.
bool m_mono_proj{ false }; // relevant for integers & reals & bit-vectors
bool m_signed_proj{ false }; // relevant for bit-vectors.
bool m_mono_proj = false; // relevant for integers & reals & bit-vectors
bool m_signed_proj = false; // relevant for bit-vectors.
ptr_vector<node> m_avoid_set;
ptr_vector<expr> m_exceptions;
@ -291,7 +296,7 @@ namespace smt {
}
void insert(expr* n, unsigned generation) {
if (is_ground(n))
if (is_ground(n) || (has_quantifiers(n) && !has_free_vars(n))) // this is a closed term
get_root()->m_set->insert(n, generation);
}
@ -599,7 +604,10 @@ namespace smt {
}
else {
r = tmp;
TRACE(model_finder, tout << "eval\n" << mk_pp(n, m) << "\n----->\n" << mk_pp(r, m) << "\n";);
TRACE(model_finder, tout << "eval-failed\n" << mk_pp(n, m) << "\n----->\n" << mk_pp(r, m) << "\n";);
if (is_lambda(tmp)) {
r = m.mk_fresh_const("lambda", tmp->get_sort());
}
}
m_eval_cache[model_completion].insert(n, r);
m_eval_cache_range.push_back(r);
@ -1235,8 +1243,8 @@ namespace smt {
void populate_inst_sets(quantifier* q, func_decl* mhead, ptr_vector<instantiation_set>& uvar_inst_sets, context* ctx) override {
if (m_f != mhead)
return;
uvar_inst_sets.reserve(m_var_j + 1, 0);
if (uvar_inst_sets[m_var_j] == 0)
uvar_inst_sets.reserve(m_var_j + 1, nullptr);
if (uvar_inst_sets[m_var_j] == nullptr)
uvar_inst_sets[m_var_j] = alloc(instantiation_set, ctx->get_manager());
instantiation_set* s = uvar_inst_sets[m_var_j];
SASSERT(s != nullptr);
@ -1369,6 +1377,74 @@ namespace smt {
};
class ho_var : public qinfo {
unsigned m_var_i;
public:
ho_var(ast_manager& m, unsigned i) : qinfo(m), m_var_i(i) {
}
char const *get_kind() const override {
return "ho_var";
}
bool is_equal(qinfo const *qi) const override {
if (qi->get_kind() != get_kind())
return false;
ho_var const *other = static_cast<ho_var const *>(qi);
return m_var_i == other->m_var_i;
}
void display(std::ostream &out) const override {
out << "(" << "ho-var: " << m_var_i << ")";
}
void process_auf(quantifier *q, auf_solver &s, context *ctx) override {
/* node * S_i = */ s.get_uvar(q, m_var_i);
}
void populate_inst_sets(quantifier *q, auf_solver &s, context *ctx) override {
node *S = s.get_uvar(q, m_var_i);
sort *srt = S->get_sort();
IF_VERBOSE(3, verbose_stream() << "ho_var::populate_inst_sets: " << q->get_id() << " " << mk_pp(srt, m) << "\n";);
term_enumeration tn(m);
// Add ground terms of type S.
// Add productions for functions in E-graph
// add other possible relevant functions such as equality over srt, Boolean operators
ast_mark visited;
for (enode *n : ctx->enodes()) {
if (!ctx->is_relevant(n))
continue;
auto e = n->get_expr();
if (srt == n->get_sort()) {
TRACE(model_finder, tout << "inserting " << mk_pp(e, m) << " into inst set\n");
S->insert(e, n->get_generation());
}
else if (is_uninterp_const(e)) {
TRACE(model_finder, tout << "add production " << mk_pp(e, m) << "\n");
tn.add_production(e);
}
else if (is_uninterp(e)) {
auto f = to_app(e)->get_decl();
if (visited.is_marked(f))
continue;
visited.mark(f, true);
TRACE(model_finder, tout << "add function " << mk_pp(f, m) << "\n");
tn.add_production(f);
}
}
unsigned max_count = 20;
for (auto t : tn.enum_terms(srt)) {
unsigned generation = 0; // todo - inherited from sub-term of t?
TRACE(model_finder, tout << "ho_var: adding term " << mk_ismt2_pp(t, m)
<< " to instantiation set of S" << std::endl;);
S->insert(t, generation);
}
}
};
/**
\brief auf_arr is a term (pattern) of the form:
@ -2105,7 +2181,12 @@ namespace smt {
process_app(to_app(curr));
}
else if (is_var(curr)) {
m_info->m_is_auf = false; // unexpected occurrence of variable.
if (m_array_util.is_array(curr)) {
insert_qinfo(alloc(ho_var, m, to_var(curr)->get_idx()));
}
else {
m_info->m_is_auf = false; // unexpected occurrence of variable.
}
}
else {
SASSERT(is_lambda(curr));
@ -2520,11 +2601,12 @@ namespace smt {
Store in generation the generation of the result
*/
expr* model_finder::get_inv(quantifier* q, unsigned i, expr* val, unsigned& generation) {
expr* model_finder::get_inv(quantifier* q, unsigned i, expr* val, model& mdl,unsigned& generation) {
instantiation_set const* s = get_uvar_inst_set(q, i);
if (s == nullptr)
return nullptr;
expr* t = s->get_inv(val);
expr* t = s->get_inv(val, mdl);
if (m_auf_solver->is_default_representative(t))
return val;
if (t != nullptr) {
@ -2560,16 +2642,27 @@ namespace smt {
obj_map<expr, expr*> const& inv = s->get_inv_map();
if (inv.empty())
continue; // nothing to do
ptr_buffer<expr> eqs;
for (auto const& [val, _] : inv) {
if (val->get_sort() == sk->get_sort())
eqs.push_back(m.mk_eq(sk, val));
expr_ref_vector eqs(m), defs(m);
for (auto const& [val, term] : inv) {
if (val->get_sort() == sk->get_sort()) {
if (is_lambda(term)) {
eqs.push_back(m.mk_eq(sk, val));
defs.push_back(m.mk_eq(val, term));
}
else
eqs.push_back(m.mk_eq(sk, val));
}
}
if (!eqs.empty()) {
expr_ref new_cnstr(m);
new_cnstr = m.mk_or(eqs);
TRACE(model_finder, tout << "assert_restriction:\n" << mk_pp(new_cnstr, m) << "\n";);
aux_ctx->assert_expr(new_cnstr);
for (auto def : defs) {
TRACE(model_finder, tout << "assert_def:\n" << mk_pp(def, m) << "\n";);
aux_ctx->assert_expr(def);
}
asserted_something = true;
}
}

View file

@ -113,7 +113,7 @@ namespace smt {
void fix_model(proto_model * m);
quantifier * get_flat_quantifier(quantifier * q);
expr * get_inv(quantifier * q, unsigned i, expr * val, unsigned & generation);
expr * get_inv(quantifier * q, unsigned i, expr * val, model& m, unsigned & generation);
bool restrict_sks_to_inst_set(context * aux_ctx, quantifier * q, expr_ref_vector const & sks);
void restart_eh();

View file

@ -145,6 +145,7 @@ add_executable(test-z3
symbol.cpp
symbol_table.cpp
tbv.cpp
term_enumeration.cpp
theory_dl.cpp
theory_pb.cpp
timeout.cpp

View file

@ -195,6 +195,7 @@
X(finite_set) \
X(finite_set_rewriter) \
X(fpa) \
X(term_enumeration) \
X(lcube)
#define FOR_EACH_TEST(X, X_ARGV) \

View file

@ -0,0 +1,309 @@
/*++
Copyright (c) 2024 Microsoft Corporation
Module Name:
tst_term_enumeration.cpp
Abstract:
Test term enumeration module
--*/
#include "ast/rewriter/term_enumeration.h"
#include "ast/ast_pp.h"
#include "ast/arith_decl_plugin.h"
#include "ast/bv_decl_plugin.h"
#include "ast/array_decl_plugin.h"
#include "ast/reg_decl_plugins.h"
#include "ast/rewriter/th_rewriter.h"
#include "util/obj_hashtable.h"
#include <iostream>
#include <sstream>
static void tst_basic_enumeration() {
std::cout << "=== test basic enumeration ===\n";
ast_manager m;
reg_decl_plugins(m);
arith_util a(m);
term_enumeration te(m);
// Add some leaf productions (constants)
expr_ref zero(a.mk_int(0), m);
expr_ref one(a.mk_int(1), m);
te.add_production(zero);
te.add_production(one);
// Enumerate terms of Int sort
sort* int_sort = a.mk_int();
unsigned count = 0;
for (expr* e : te.enum_terms(int_sort)) {
std::cout << "Term: " << mk_pp(e, m) << "\n";
count++;
if (count >= 5) break; // Limit output
}
ENSURE(count >= 2); // At least 0 and 1
std::cout << "Enumerated " << count << " terms\n";
}
static void tst_enumeration_with_operators() {
std::cout << "=== test enumeration with operators ===\n";
ast_manager m;
reg_decl_plugins(m);
arith_util a(m);
term_enumeration te(m);
// Add leaf productions
expr_ref zero(a.mk_int(0), m);
expr_ref one(a.mk_int(1), m);
te.add_production(zero);
te.add_production(one);
// Add operator productions (+ and *)
// Get func_decl by creating an app and extracting the decl
app_ref tmp_add(a.mk_add(zero, one), m);
app_ref tmp_mul(a.mk_mul(zero, one), m);
func_decl* add_decl = tmp_add->get_decl();
func_decl* mul_decl = tmp_mul->get_decl();
te.add_production(add_decl);
te.add_production(mul_decl);
sort* int_sort = a.mk_int();
unsigned count = 0;
for (expr* e : te.enum_terms(int_sort)) {
std::cout << "Term: " << mk_pp(e, m) << "\n";
count++;
if (count >= 20) break; // Limit output
}
ENSURE(count >= 2); // At least the leaves
std::cout << "Enumerated " << count << " terms with operators\n";
}
static void tst_observational_equivalence_filter() {
std::cout << "=== test observational equivalence filter ===\n";
ast_manager m;
reg_decl_plugins(m);
arith_util a(m);
th_rewriter rw(m);
term_enumeration te(m);
expr_ref zero(a.mk_int(0), m);
expr_ref one(a.mk_int(1), m);
te.add_production(zero);
te.add_production(one);
app_ref tmp_add(a.mk_add(zero, one), m);
te.add_production(tmp_add->get_decl());
sort* int_sort = a.mk_int();
obj_hashtable<expr> seen;
unsigned count = 0;
for (expr* e : te.enum_terms(int_sort)) {
expr_ref r(m);
rw(e, r);
ENSURE(r == e);
ENSURE(!seen.contains(r));
seen.insert(r);
count++;
if (count >= 20) break;
}
ENSURE(count >= 2);
}
static void tst_display() {
std::cout << "=== test display ===\n";
ast_manager m;
reg_decl_plugins(m);
arith_util a(m);
term_enumeration te(m);
// Add leaf productions
expr_ref zero(a.mk_int(0), m);
expr_ref one(a.mk_int(1), m);
te.add_production(zero);
te.add_production(one);
// Add operator productions
app_ref tmp_add(a.mk_add(zero, one), m);
func_decl* add_decl = tmp_add->get_decl();
te.add_production(add_decl);
sort* int_sort = a.mk_int();
unsigned count = 0;
for (expr* e : te.enum_terms(int_sort)) {
(void)e;
count++;
if (count >= 10) break;
}
std::cout << "Internal state after enumeration:\n";
std::ostringstream oss;
te.display(oss);
std::cout << oss.str();
// Verify display produced some output
ENSURE(!oss.str().empty());
}
static void tst_bitvector_enumeration() {
std::cout << "=== test bitvector enumeration ===\n";
ast_manager m;
reg_decl_plugins(m);
bv_util bv(m);
term_enumeration te(m);
// Add bitvector constants
unsigned bv_size = 8;
expr_ref bv_zero(bv.mk_numeral(0, bv_size), m);
expr_ref bv_one(bv.mk_numeral(1, bv_size), m);
te.add_production(bv_zero);
te.add_production(bv_one);
// Add bvadd operator
app_ref tmp_add(bv.mk_bv_add(bv_zero, bv_one), m);
func_decl* bvadd = tmp_add->get_decl();
te.add_production(bvadd);
sort* bv8 = bv.mk_sort(bv_size);
unsigned count = 0;
for (expr* e : te.enum_terms(bv8)) {
std::cout << "BV Term: " << mk_pp(e, m) << "\n";
count++;
if (count >= 10) break;
}
ENSURE(count >= 2);
std::cout << "Enumerated " << count << " bitvector terms\n";
}
static void tst_multiple_sorts() {
std::cout << "=== test multiple sorts ===\n";
ast_manager m;
reg_decl_plugins(m);
arith_util a(m);
term_enumeration te(m);
// Add Int constants
expr_ref i_zero(a.mk_int(0), m);
expr_ref i_one(a.mk_int(1), m);
te.add_production(i_zero);
te.add_production(i_one);
// Add Real constants
expr_ref r_zero(a.mk_real(0), m);
expr_ref r_one(a.mk_real(1), m);
te.add_production(r_zero);
te.add_production(r_one);
// Enumerate Int terms
sort* int_sort = a.mk_int();
unsigned int_count = 0;
for (expr* e : te.enum_terms(int_sort)) {
std::cout << "Int Term: " << mk_pp(e, m) << "\n";
int_count++;
if (int_count >= 5) break;
}
ENSURE(int_count >= 2);
std::cout << "Enumerated " << int_count << " Int terms\n";
}
static void tst_nested_array_enumeration() {
std::cout << "=== test nested array enumeration (Array(A, Array(B, A))) ===\n";
ast_manager m;
reg_decl_plugins(m);
array_util arr(m);
term_enumeration te(m);
// Create uninterpreted sorts A and B
sort_ref sort_A(m.mk_uninterpreted_sort(symbol("A")), m);
sort_ref sort_B(m.mk_uninterpreted_sort(symbol("B")), m);
// Create nested array sort: Array(B, A) - arrays indexed by B returning A
sort_ref array_B_A(arr.mk_array_sort(sort_B, sort_A), m);
// Create outer array sort: Array(A, Array(B, A)) - arrays indexed by A returning Array(B,A)
sort_ref array_A_arrayBA(arr.mk_array_sort(sort_A, array_B_A), m);
std::cout << "Sort A: " << mk_pp(sort_A.get(), m) << "\n";
std::cout << "Sort B: " << mk_pp(sort_B.get(), m) << "\n";
std::cout << "Sort Array(B, A): " << mk_pp(array_B_A.get(), m) << "\n";
std::cout << "Sort Array(A, Array(B, A)): " << mk_pp(array_A_arrayBA.get(), m) << "\n";
// Add constants of sort A
app_ref a0(m.mk_const(symbol("a0"), sort_A), m);
app_ref a1(m.mk_const(symbol("a1"), sort_A), m);
te.add_production(a0);
te.add_production(a1);
// Add constants of sort B
app_ref b0(m.mk_const(symbol("b0"), sort_B), m);
app_ref b1(m.mk_const(symbol("b1"), sort_B), m);
te.add_production(b0);
te.add_production(b1);
// Add a constant array of inner type Array(B, A) - const_array(a0) : Array(B, A)
app_ref const_inner(arr.mk_const_array(array_B_A, a0), m);
te.add_production(const_inner);
// Add a constant array of outer type Array(A, Array(B, A))
app_ref const_outer(arr.mk_const_array(array_A_arrayBA, const_inner), m);
te.add_production(const_outer);
// Add store operator for the inner array type Array(B, A)
// store(array, index, value) : store(Array(B,A), B, A) -> Array(B,A)
expr* store_inner_args[3] = { const_inner.get(), b0.get(), a0.get() };
app_ref tmp_store_inner(arr.mk_store(3, store_inner_args), m);
func_decl* store_inner_decl = tmp_store_inner->get_decl();
te.add_production(store_inner_decl);
// Add store operator for the outer array type Array(A, Array(B, A))
// store(array, index, value) : store(Array(A, Array(B,A)), A, Array(B,A)) -> Array(A, Array(B,A))
expr* store_outer_args[3] = { const_outer.get(), a0.get(), const_inner.get() };
app_ref tmp_store_outer(arr.mk_store(3, store_outer_args), m);
func_decl* store_outer_decl = tmp_store_outer->get_decl();
te.add_production(store_outer_decl);
// Add select operator for the outer array (returns Array(B, A))
// select(Array(A, Array(B,A)), A) -> Array(B, A)
app_ref tmp_select_outer(arr.mk_select(const_outer.get(), a0.get()), m);
func_decl* select_outer_decl = tmp_select_outer->get_decl();
te.add_production(select_outer_decl);
// Enumerate terms of the nested array sort Array(A, Array(B, A))
std::cout << "\nEnumerating terms of sort Array(A, Array(B, A)):\n";
unsigned count = 0;
for (expr* e : te.enum_terms(array_A_arrayBA)) {
std::cout << " Term " << count << ": " << mk_pp(e, m) << "\n";
count++;
if (count >= 15) break; // Limit output
}
ENSURE(count >= 1); // At least the constant array
std::cout << "Enumerated " << count << " terms of sort Array(A, Array(B, A))\n";
te.display(std::cout);
}
void tst_term_enumeration() {
tst_basic_enumeration();
tst_enumeration_with_operators();
tst_observational_equivalence_filter();
tst_display();
tst_bitvector_enumeration();
tst_multiple_sorts();
tst_nested_array_enumeration();
std::cout << "All term_enumeration tests passed!\n";
}