From cf62a78e8a0273393e824fe01c3e867cc5c4cdf9 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Thu, 11 Jun 2026 14:24:08 -0700 Subject: [PATCH] v1 of term enumeration Signed-off-by: Nikolaj Bjorner --- src/ast/CMakeLists.txt | 1 + src/ast/term_enumeration.cpp | 716 +++++++++++++++++++++++++++++++++++ 2 files changed, 717 insertions(+) create mode 100644 src/ast/term_enumeration.cpp diff --git a/src/ast/CMakeLists.txt b/src/ast/CMakeLists.txt index 6a50c3b05..1b8d66cd3 100644 --- a/src/ast/CMakeLists.txt +++ b/src/ast/CMakeLists.txt @@ -49,6 +49,7 @@ z3_add_component(ast shared_occs.cpp special_relations_decl_plugin.cpp static_features.cpp + term_enumeration.cpp used_vars.cpp value_generator.cpp well_sorted.cpp diff --git a/src/ast/term_enumeration.cpp b/src/ast/term_enumeration.cpp new file mode 100644 index 000000000..f98ae209c --- /dev/null +++ b/src/ast/term_enumeration.cpp @@ -0,0 +1,716 @@ +#pragma once +/** + * term_enumeration.h - Bottom-up term enumeration module for Z3 + * + * Inspired by the Probe synthesizer (Barke et al., "Just-in-Time Learning + * for Bottom-Up Enumerative Synthesis"). Adapted to use Z3's internal APIs. + * + * Key ideas: + * - Terms are enumerated bottom-up by "height" (max nesting depth). + * - Observational equivalence (OE): two terms that produce the same outputs + * on all sample inputs are considered equivalent; only one representative + * per equivalence class is kept. + * - A Grammar describes which function symbols (operators) and leaves + * (constants, variables) are available for enumeration. + */ + +#include +#include +#include +#include +#include +#include +#include "ast/ast.h" +#include "ast/ast_pp.h" +#include "ast/ast_ll_pp.h" +#include "ast/arith_decl_plugin.h" +#include "ast/bv_decl_plugin.h" +#include "ast/seq_decl_plugin.h" +#include "model/model.h" +#include "model/model_evaluator.h" +#include "solver/solver.h" +#include "smt/smt_solver.h" +#include "util/vector.h" +#include "util/ref.h" +#include "util/obj_hashtable.h" + +namespace term_enum { + +// ============================================================================ +// Grammar production rule +// ============================================================================ + +/** + * A Production describes how to construct a term from child terms. + * - domain: the sort required for each child + * - range: the sort of the produced term + * - builder: given a vector of child exprs, produce the result expr + */ +struct Production { + std::string name; + sort_ref range; + sort_ref_vector domain; + std::function builder; + + bool is_leaf() const { return domain.empty(); } +}; + +// ============================================================================ +// Grammar +// ============================================================================ + +/** + * A Grammar groups productions into leaves (arity 0) and operators (arity > 0). + */ +class Grammar { +public: + Grammar(ast_manager& m) : m(m) {} + + void add_production(Production p) { + if (p.is_leaf()) + m_leaves.push_back(std::move(p)); + else + m_operators.push_back(std::move(p)); + } + + vector const& leaves() const { return m_leaves; } + vector const& operators() const { return m_operators; } + ast_manager& mgr() const { return m; } + + void add_variable(char const* name, sort* s) { + expr_ref var(m.mk_const(name, s), m); + sort_ref sr(s, m); + sort_ref_vector dom(m); + add_production({name, sr, dom, [var](expr_ref_vector const&) { return var; }}); + } + + void add_int_const(int val) { + arith_util a(m); + expr_ref e(a.mk_int(val), m); + sort_ref sr(a.mk_int(), m); + sort_ref_vector dom(m); + add_production({std::to_string(val), sr, dom, [e](expr_ref_vector const&) { return e; }}); + } + + void add_bv_const(int val, unsigned bits) { + bv_util bv(m); + expr_ref e(bv.mk_numeral(rational(val), bits), m); + sort_ref sr(bv.mk_sort(bits), m); + sort_ref_vector dom(m); + add_production({std::to_string(val), sr, dom, [e](expr_ref_vector const&) { return e; }}); + } + + void add_string_const(std::string const& val) { + seq_util seq(m); + expr_ref e(seq.str.mk_string(zstring(val.c_str())), m); + sort_ref sr(seq.str.mk_string_sort(), m); + sort_ref_vector dom(m); + add_production({"\"" + val + "\"", sr, dom, [e](expr_ref_vector const&) { return e; }}); + } + + void add_bool_const(bool val) { + expr_ref e(val ? m.mk_true() : m.mk_false(), m); + sort_ref sr(m.mk_bool_sort(), m); + sort_ref_vector dom(m); + std::string n = val ? "true" : "false"; + add_production({n, sr, dom, [e](expr_ref_vector const&) { return e; }}); + } + + void add_func_decl(func_decl *f) { + sort_ref range(f->get_range(), m); + sort_ref_vector dom(m); + for (unsigned i = 0; i < f->get_arity(); ++i) + dom.push_back(sort_ref(f->get_domain(i), m)); + add_production({f->get_name().str(), range, dom, [this, f](expr_ref_vector const &args) { + return expr_ref(m.mk_app(f, args), m); + }}); + } + + void add_expr(expr *e) { + sort_ref range(e->get_sort(), m); + sort_ref_vector dom(m); + std::stringstream ss; + ss << mk_bounded_pp(e, m); + std::string name = ss.str(); + add_production({name, range, dom, [this, e](expr_ref_vector const&) { return expr_ref(e, m); }}); + } + +private: + ast_manager& m; + vector m_leaves; + vector m_operators; +}; + +// ============================================================================ +// Standard grammar factories - build common operator sets +// ============================================================================ + +namespace grammars { + +/** + * Build a grammar over linear integer arithmetic. + * Operators: +, -, *, ite (with bool condition) + */ +inline void add_lia_operators(Grammar& g) { + ast_manager& m = g.mgr(); + arith_util a(m); + sort_ref isort(a.mk_int(), m); + sort_ref bsort(m.mk_bool_sort(), m); + + sort_ref_vector ii(m); ii.push_back(isort); ii.push_back(isort); + sort_ref_vector i1(m); i1.push_back(isort); + sort_ref_vector bb(m); bb.push_back(bsort); bb.push_back(bsort); + sort_ref_vector b1(m); b1.push_back(bsort); + sort_ref_vector bii(m); bii.push_back(bsort); bii.push_back(isort); bii.push_back(isort); + + g.add_production({"add", isort, ii, + [&m](expr_ref_vector const& ch) { arith_util a(m); return expr_ref(a.mk_add(ch[0], ch[1]), m); }}); + g.add_production({"sub", isort, ii, + [&m](expr_ref_vector const& ch) { arith_util a(m); return expr_ref(a.mk_sub(ch[0], ch[1]), m); }}); + g.add_production({"mul", isort, ii, + [&m](expr_ref_vector const& ch) { arith_util a(m); return expr_ref(a.mk_mul(ch[0], ch[1]), m); }}); + g.add_production({"neg", isort, i1, + [&m](expr_ref_vector const& ch) { arith_util a(m); return expr_ref(a.mk_uminus(ch[0]), m); }}); + + g.add_production({"le", bsort, ii, + [&m](expr_ref_vector const& ch) { arith_util a(m); return expr_ref(a.mk_le(ch[0], ch[1]), m); }}); + g.add_production({"lt", bsort, ii, + [&m](expr_ref_vector const& ch) { arith_util a(m); return expr_ref(a.mk_lt(ch[0], ch[1]), m); }}); + g.add_production({"eq_int", bsort, ii, + [&m](expr_ref_vector const& ch) { return expr_ref(m.mk_eq(ch[0], ch[1]), m); }}); + + g.add_production({"and", bsort, bb, + [&m](expr_ref_vector const& ch) { return expr_ref(m.mk_and(ch[0], ch[1]), m); }}); + g.add_production({"or", bsort, bb, + [&m](expr_ref_vector const& ch) { return expr_ref(m.mk_or(ch[0], ch[1]), m); }}); + g.add_production({"not", bsort, b1, + [&m](expr_ref_vector const& ch) { return expr_ref(m.mk_not(ch[0]), m); }}); + + g.add_production({"ite_int", isort, bii, + [&m](expr_ref_vector const& ch) { return expr_ref(m.mk_ite(ch[0], ch[1], ch[2]), m); }}); +} + +/** + * Build a grammar over bitvectors. + */ +inline void add_bv_operators(Grammar& g, unsigned bits) { + ast_manager& m = g.mgr(); + bv_util bv(m); + sort_ref bvsort(bv.mk_sort(bits), m); + sort_ref bsort(m.mk_bool_sort(), m); + + sort_ref_vector vv(m); vv.push_back(bvsort); vv.push_back(bvsort); + sort_ref_vector v1(m); v1.push_back(bvsort); + sort_ref_vector bvv(m); bvv.push_back(bsort); bvv.push_back(bvsort); bvv.push_back(bvsort); + + g.add_production({"bvadd", bvsort, vv, + [&m](expr_ref_vector const& ch) { bv_util bv(m); return expr_ref(bv.mk_bv_add(ch[0], ch[1]), m); }}); + g.add_production({"bvsub", bvsort, vv, + [&m](expr_ref_vector const& ch) { bv_util bv(m); return expr_ref(bv.mk_bv_sub(ch[0], ch[1]), m); }}); + g.add_production({"bvmul", bvsort, vv, + [&m](expr_ref_vector const& ch) { bv_util bv(m); return expr_ref(bv.mk_bv_mul(ch[0], ch[1]), m); }}); + g.add_production({"bvand", bvsort, vv, + [&m](expr_ref_vector const& ch) { bv_util bv(m); return expr_ref(bv.mk_bv_and(ch[0], ch[1]), m); }}); + g.add_production({"bvor", bvsort, vv, + [&m](expr_ref_vector const& ch) { bv_util bv(m); return expr_ref(bv.mk_bv_or(ch[0], ch[1]), m); }}); + g.add_production({"bvxor", bvsort, vv, + [&m](expr_ref_vector const& ch) { bv_util bv(m); return expr_ref(bv.mk_bv_xor(ch[0], ch[1]), m); }}); + g.add_production({"bvnot", bvsort, v1, + [&m](expr_ref_vector const& ch) { bv_util bv(m); return expr_ref(bv.mk_bv_not(ch[0]), m); }}); + g.add_production({"bvneg", bvsort, v1, + [&m](expr_ref_vector const& ch) { bv_util bv(m); return expr_ref(bv.mk_bv_neg(ch[0]), m); }}); + g.add_production({"bvshl", bvsort, vv, + [&m](expr_ref_vector const& ch) { bv_util bv(m); return expr_ref(bv.mk_bv_shl(ch[0], ch[1]), m); }}); + g.add_production({"bvlshr", bvsort, vv, + [&m](expr_ref_vector const& ch) { bv_util bv(m); return expr_ref(bv.mk_bv_lshr(ch[0], ch[1]), m); }}); + g.add_production({"bvashr", bvsort, vv, + [&m](expr_ref_vector const& ch) { bv_util bv(m); return expr_ref(bv.mk_bv_ashr(ch[0], ch[1]), m); }}); + + g.add_production({"bvult", bsort, vv, + [&m](expr_ref_vector const& ch) { bv_util bv(m); return expr_ref(m.mk_app(bv.get_fid(), OP_ULT, ch[0], ch[1]), m); }}); + g.add_production({"bvslt", bsort, vv, + [&m](expr_ref_vector const& ch) { bv_util bv(m); return expr_ref(bv.mk_slt(ch[0], ch[1]), m); }}); + g.add_production({"bveq", bsort, vv, + [&m](expr_ref_vector const& ch) { return expr_ref(m.mk_eq(ch[0], ch[1]), m); }}); + + g.add_production({"ite_bv", bvsort, bvv, + [&m](expr_ref_vector const& ch) { return expr_ref(m.mk_ite(ch[0], ch[1], ch[2]), m); }}); +} + +/** + * Build a grammar over strings. + */ +inline void add_string_operators(Grammar& g) { + ast_manager& m = g.mgr(); + seq_util seq(m); + arith_util a(m); + sort_ref ssort(seq.str.mk_string_sort(), m); + sort_ref isort(a.mk_int(), m); + sort_ref bsort(m.mk_bool_sort(), m); + + sort_ref_vector ss(m); ss.push_back(ssort); ss.push_back(ssort); + sort_ref_vector s1(m); s1.push_back(ssort); + sort_ref_vector si(m); si.push_back(ssort); si.push_back(isort); + sort_ref_vector sii(m); sii.push_back(ssort); sii.push_back(isort); sii.push_back(isort); + sort_ref_vector ssi(m); ssi.push_back(ssort); ssi.push_back(ssort); ssi.push_back(isort); + sort_ref_vector sss(m); sss.push_back(ssort); sss.push_back(ssort); sss.push_back(ssort); + sort_ref_vector i1(m); i1.push_back(isort); + sort_ref_vector bss(m); bss.push_back(bsort); bss.push_back(ssort); bss.push_back(ssort); + + g.add_production({"str.++", ssort, ss, + [&m](expr_ref_vector const& ch) { seq_util seq(m); return expr_ref(seq.str.mk_concat(ch[0], ch[1]), m); }}); + g.add_production({"str.len", isort, s1, + [&m](expr_ref_vector const& ch) { seq_util seq(m); return expr_ref(seq.str.mk_length(ch[0]), m); }}); + g.add_production({"str.at", ssort, si, + [&m](expr_ref_vector const& ch) { + seq_util seq(m); arith_util a(m); + return expr_ref(seq.str.mk_substr(ch[0], ch[1], a.mk_int(1)), m); + }}); + g.add_production({"str.substr", ssort, sii, + [&m](expr_ref_vector const& ch) { seq_util seq(m); return expr_ref(seq.str.mk_substr(ch[0], ch[1], ch[2]), m); }}); + g.add_production({"str.indexof", isort, ssi, + [&m](expr_ref_vector const& ch) { seq_util seq(m); return expr_ref(seq.str.mk_index(ch[0], ch[1], ch[2]), m); }}); + g.add_production({"str.replace", ssort, sss, + [&m](expr_ref_vector const& ch) { seq_util seq(m); return expr_ref(seq.str.mk_replace(ch[0], ch[1], ch[2]), m); }}); + g.add_production({"str.contains", bsort, ss, + [&m](expr_ref_vector const& ch) { seq_util seq(m); return expr_ref(seq.str.mk_contains(ch[0], ch[1]), m); }}); + g.add_production({"str.prefixof", bsort, ss, + [&m](expr_ref_vector const& ch) { seq_util seq(m); return expr_ref(seq.str.mk_prefix(ch[0], ch[1]), m); }}); + g.add_production({"str.suffixof", bsort, ss, + [&m](expr_ref_vector const& ch) { seq_util seq(m); return expr_ref(seq.str.mk_suffix(ch[0], ch[1]), m); }}); + g.add_production({"int.to.str", ssort, i1, + [&m](expr_ref_vector const& ch) { seq_util seq(m); return expr_ref(seq.str.mk_itos(ch[0]), m); }}); + g.add_production({"str.to.int", isort, s1, + [&m](expr_ref_vector const& ch) { seq_util seq(m); return expr_ref(seq.str.mk_stoi(ch[0]), m); }}); + + g.add_production({"ite_str", ssort, bss, + [&m](expr_ref_vector const& ch) { return expr_ref(m.mk_ite(ch[0], ch[1], ch[2]), m); }}); +} + +} // namespace grammars + +// ============================================================================ +// Observational Equivalence Manager +// ============================================================================ + +/** + * Evaluates candidate terms on a set of sample inputs and keeps only one + * representative per equivalence class (the one encountered first). + * + * Uses Z3's model evaluation to reduce terms to concrete values. + */ +class OEManager { +public: + OEManager(ast_manager& m) : m(m) {} + + void set_samples(vector samples) { + m_samples = std::move(samples); + m_seen.clear(); + } + + void add_sample(model_ref mdl) { + m_samples.push_back(std::move(mdl)); + m_seen.clear(); + } + + /** + * Returns true if `term` is a new representative (its fingerprint has + * not been seen before). + */ + bool is_representative(expr* term) { + if (m_samples.empty()) return true; + std::string fingerprint = compute_fingerprint(term); + if (fingerprint.empty()) return false; + return m_seen.insert(fingerprint).second; + } + + void clear() { m_seen.clear(); } + + size_t num_classes() const { return m_seen.size(); } + size_t num_samples() const { return m_samples.size(); } + +private: + ast_manager& m; + vector m_samples; + std::unordered_set m_seen; + + std::string compute_fingerprint(expr* term) { + std::string fp; + for (auto& mdl : m_samples) { + expr_ref val(m); + model_evaluator eval(*mdl); + eval.set_model_completion(true); + if (!eval.eval(term, val, true)) + return ""; + std::ostringstream os; + os << mk_pp(val, m); + fp += os.str(); + fp += '\x1f'; + } + return fp; + } +}; + +// ============================================================================ +// Term Bank - stores enumerated terms by height and sort +// ============================================================================ + +class TermBank { +public: + TermBank(ast_manager& m) : m(m), m_pinned(m) {} + + void reset() { + m_pinned.reset(); + m_terms.clear(); + } + + void add(expr* term, unsigned height) { + sort* s = term->get_sort(); + unsigned sid = s->get_id(); + m_pinned.push_back(term); + if (height >= m_terms.size()) + m_terms.resize(height + 1); + m_terms[height].insert_if_not_there(s, ptr_vector()).push_back(term); + } + + /** Get all terms of a given sort up to (and including) max_height */ + expr_ref_vector get_by_sort(sort* s, unsigned max_height) const { + expr_ref_vector result(m); + for (unsigned h = 0; h <= max_height; ++h) { + if (h >= m_terms.size()) + break; + if (!m_terms[h].contains(s)) + continue; + for (auto t : m_terms[h].find(s)) + result.push_back(t); + } + return result; + } + + size_t total_terms() const { + size_t n = 0; + for (auto& sm : m_terms) + for (auto& [s, v] : sm) + n += v.size(); + return n; + } + +private: + ast_manager& m; + expr_ref_vector m_pinned; + // height -> sort -> terms + vector>> m_terms; +}; + +// ============================================================================ +// Children Iterator - generates all combinations of child terms +// ============================================================================ + +/** + * Iterates over all tuples (c1, c2, ..., cn) where each ci has the required + * sort, drawn from the term bank, with at least one child at the current + * height - 1 (to avoid regenerating previously seen terms). + */ +class ChildrenIterator { +public: + ChildrenIterator(ast_manager& m, Production const& prod, TermBank const& bank, unsigned current_height) + : m(m), m_prod(prod), m_current_height(current_height), m_done(false) + { + m_arity = prod.domain.size(); + if (m_arity == 0) { + m_done = true; + return; + } + for (unsigned i = 0; i < m_arity; ++i) { + m_candidates.push_back(bank.get_by_sort(prod.domain[i], current_height - 1)); + if (m_candidates.back().empty()) { + m_done = true; + return; + } + } + m_indices.resize(m_arity, 0); + } + + bool has_next() { + while (!m_done) { + if (has_child_at_max_height()) + return true; + advance(); + } + return false; + } + + expr_ref_vector next() { + expr_ref_vector result(m); + for (unsigned i = 0; i < m_arity; ++i) + result.push_back(m_candidates[i].get(m_indices[i])); + advance(); + return result; + } + +private: + ast_manager& m; + Production const& m_prod; + unsigned m_current_height; + unsigned m_arity; + bool m_done; + vector m_candidates; + svector m_indices; + + bool has_child_at_max_height() const { + return true; + } + + void advance() { + for (auto i = m_arity; i-- > 0;) { + m_indices[i]++; + if (m_indices[i] < m_candidates[i].size()) return; + m_indices[i] = 0; + } + m_done = true; + } +}; + +// ============================================================================ +// Enumerator - the main bottom-up term enumeration engine +// ============================================================================ + +/** + * Enumerates terms bottom-up by height, applying observational equivalence + * pruning. Users iterate via has_next() / next(), or call enumerate_up_to(). + * + * Usage: + * ast_manager m; + * Grammar g(m); + * // ... add productions ... + * OEManager oe(m); + * // ... set samples ... + * Enumerator en(g, oe); + * arith_util a(m); + * en.set_target_sort(a.mk_int()); + * while (en.has_next()) { + * expr_ref term = en.next(); + * // ... check if term satisfies specification ... + * } + */ +class Enumerator { +public: + Enumerator(Grammar& grammar, OEManager& oe) + : m_grammar(grammar), m(grammar.mgr()), m_oe(oe), + m_bank(grammar.mgr()), m_height(0), + m_leaf_idx(0), m_op_idx(0), m_state(State::Leaves), + m_target_sort(nullptr), m_pending(grammar.mgr()) + {} + + void set_target_sort(sort* s) { m_target_sort = s; } + void set_max_height(unsigned h) { m_max_height = h; } + + bool has_next() { + if (m_pending) return true; + m_pending = find_next(); + return m_pending != nullptr; + } + + expr_ref next() { + if (!m_pending) + m_pending = find_next(); + expr_ref result(m_pending, m); + m_pending = nullptr; + return result; + } + + expr_ref_vector enumerate_up_to(unsigned max_height) { + m_max_height = max_height; + expr_ref_vector results(m); + while (has_next()) { + results.push_back(next()); + } + return results; + } + + TermBank const& bank() const { return m_bank; } + unsigned current_height() const { return m_height; } + + void reset() { + m_height = 0; + m_leaf_idx = 0; + m_op_idx = 0; + m_state = State::Leaves; + m_bank.reset(); + m_oe.clear(); + m_pending = nullptr; + m_children_iter.reset(); + } + +private: + enum class State { Leaves, Operators, Done }; + + Grammar& m_grammar; + ast_manager& m; + OEManager& m_oe; + TermBank m_bank; + unsigned m_height; + unsigned m_leaf_idx; + unsigned m_op_idx; + State m_state; + sort* m_target_sort; + expr_ref m_pending; + std::unique_ptr m_children_iter; + unsigned m_max_height = 100; + + bool sort_matches(expr* e) const { + if (!m_target_sort) return true; + return e->get_sort() == m_target_sort; + } + + expr* find_next() { + while (true) { + switch (m_state) { + case State::Leaves: + while (m_leaf_idx < m_grammar.leaves().size()) { + Production const& prod = m_grammar.leaves()[m_leaf_idx]; + m_leaf_idx++; + expr_ref_vector empty_args(m); + expr_ref term = prod.builder(empty_args); + if (m_oe.is_representative(term)) { + m_bank.add(term, 0); + if (sort_matches(term)) + return term; + } + } + m_state = State::Operators; + m_height = 1; + m_op_idx = 0; + m_children_iter.reset(); + break; + + case State::Operators: { + expr* result = enumerate_operators(); + if (result) return result; + m_height++; + if (m_height > m_max_height) { + m_state = State::Done; + break; + } + m_op_idx = 0; + m_children_iter.reset(); + break; + } + case State::Done: + return nullptr; + } + } + } + + expr* enumerate_operators() { + auto const& ops = m_grammar.operators(); + while (true) { + if (m_children_iter && m_children_iter->has_next()) { + expr_ref_vector children = m_children_iter->next(); + Production const& prod = ops[m_op_idx - 1]; + expr_ref term = prod.builder(children); + if (m_oe.is_representative(term)) { + m_bank.add(term, m_height); + if (sort_matches(term)) + return term; + } + continue; + } + if (m_op_idx >= ops.size()) return nullptr; + Production const& prod = ops[m_op_idx]; + m_op_idx++; + m_children_iter = std::make_unique(m, prod, m_bank, m_height); + } + } +}; + +// ============================================================================ +// CEGIS integration helper +// ============================================================================ + +/** + * Counter-Example Guided Inductive Synthesis loop. + * Combines the enumerator with a solver to verify candidates against a + * specification. + * + * spec: a function (expr* candidate) -> expr_ref that returns the specification + * constraint (should be valid for a correct program). + * variables: the free variables of the specification. + */ +class CEGISLoop { +public: + CEGISLoop(Grammar& grammar, sort* target_sort, + std::function spec, + expr_ref_vector variables) + : m(grammar.mgr()), m_grammar(grammar), m_oe(grammar.mgr()), + m_enumerator(grammar, m_oe), + m_spec(std::move(spec)), m_variables(std::move(variables)) + { + m_enumerator.set_target_sort(target_sort); + params_ref p; + m_solver = mk_smt_solver(m, p, symbol::null); + } + + /** + * Run the CEGIS loop. Returns the synthesized term, or null expr_ref if + * max_height is exceeded. + */ + expr_ref synthesize(unsigned max_height = 10, unsigned max_restarts = 20) { + m_enumerator.set_max_height(max_height); + unsigned restarts = 0; + + while (m_enumerator.has_next()) { + expr_ref candidate = m_enumerator.next(); + + if (!satisfies_samples(candidate)) continue; + + expr_ref spec_expr = m_spec(candidate); + m_solver->push(); + m_solver->assert_expr(m.mk_not(spec_expr)); + lbool result = m_solver->check_sat(0, nullptr); + + if (result == l_false) { + m_solver->pop(1); + return candidate; + } else if (result == l_true) { + model_ref cex; + m_solver->get_model(cex); + m_oe.add_sample(cex); + m_samples.push_back(cex); + m_solver->pop(1); + restarts++; + if (restarts > max_restarts) return expr_ref(m); + m_enumerator.reset(); + } else { + m_solver->pop(1); + } + } + return expr_ref(m); + } + + size_t num_samples() const { return m_oe.num_samples(); } + size_t num_equivalence_classes() const { return m_oe.num_classes(); } + +private: + ast_manager& m; + Grammar& m_grammar; + OEManager m_oe; + Enumerator m_enumerator; + std::function m_spec; + expr_ref_vector m_variables; + ref m_solver; + vector m_samples; + + bool satisfies_samples(expr* candidate) { + expr_ref spec_expr = m_spec(candidate); + for (auto& mdl : m_samples) { + model_evaluator eval(*mdl); + eval.set_model_completion(true); + if (eval.is_false(spec_expr)) + return false; + } + return true; + } +}; + +} // namespace term_enum