3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2026-07-02 05:16:08 +00:00

integrate size measure for cost and avoid duplicates

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
This commit is contained in:
Nikolaj Bjorner 2026-06-18 10:16:54 -06:00
parent dcbbb55b73
commit 8e6ed26fc2
2 changed files with 159 additions and 134 deletions

View file

@ -5,7 +5,7 @@
* for Bottom-Up Enumerative Synthesis"). Adapted to use Z3's internal APIs. * for Bottom-Up Enumerative Synthesis"). Adapted to use Z3's internal APIs.
* *
* Key ideas: * Key ideas:
* - Terms are enumerated bottom-up by "height" (max nesting depth). * - Terms are enumerated bottom-up by "cost" (calculated by tree size).
* - Observational equivalence (OE): two terms that produce the same outputs * - Observational equivalence (OE): two terms that produce the same outputs
* on all sample inputs are considered equivalent; only one representative * on all sample inputs are considered equivalent; only one representative
* per equivalence class is kept. * per equivalence class is kept.
@ -13,23 +13,22 @@
* (constants, variables) are available for enumeration. * (constants, variables) are available for enumeration.
*/ */
#include "ast/term_enumeration.h"
#include <sstream> #include <sstream>
#include <unordered_set> #include <unordered_set>
#include <functional> #include <functional>
#include <string> #include <string>
#include "util/vector.h"
#include "util/ref.h"
#include "util/obj_hashtable.h"
#include "ast/ast.h" #include "ast/ast.h"
#include "ast/ast_pp.h"
#include "ast/ast_ll_pp.h" #include "ast/ast_ll_pp.h"
#include "ast/arith_decl_plugin.h" #include "ast/arith_decl_plugin.h"
#include "ast/bv_decl_plugin.h" #include "ast/bv_decl_plugin.h"
#include "ast/seq_decl_plugin.h" #include "ast/seq_decl_plugin.h"
#include "ast/term_enumeration.h"
#include "model/model.h" #include "model/model.h"
#include "model/model_evaluator.h" #include "model/model_evaluator.h"
#include "util/vector.h"
#include "util/ref.h"
#include "util/obj_hashtable.h"
namespace term_enum { namespace term_enum {
@ -199,9 +198,11 @@ private:
}; };
// ============================================================================ // ============================================================================
// Term Bank - stores enumerated terms by height and sort // Term Bank - stores enumerated terms by cost and sort
// ============================================================================ // ============================================================================
using cost_terms = vector<std::pair<expr*, unsigned>>;
class term_bank { class term_bank {
using sort_term_map = obj_map<sort, ptr_vector<expr>>; using sort_term_map = obj_map<sort, ptr_vector<expr>>;
public: public:
@ -217,35 +218,41 @@ public:
m_terms.clear(); m_terms.clear();
} }
void add(expr* term, unsigned height) { void add(expr* term, unsigned cost) {
sort* s = term->get_sort(); sort* s = term->get_sort();
unsigned sid = s->get_id();
m_pinned.push_back(term); m_pinned.push_back(term);
if (height >= m_terms.size()) if (cost >= m_terms.size())
m_terms.resize(height + 1); m_terms.resize(cost + 1);
if (!m_terms[height]) if (!m_terms[cost])
m_terms[height] = alloc(sort_term_map); m_terms[cost] = alloc(sort_term_map);
m_terms[height]->insert_if_not_there(s, ptr_vector<expr>()).push_back(term); m_terms[cost]->insert_if_not_there(s, ptr_vector<expr>()).push_back(term);
} }
/** Get all terms of a given sort up to (and including) max_height */ /** Get all terms of a given sort up to (and including) max_cost */
expr_ref_vector get_by_sort(sort* s, unsigned max_height) const { cost_terms get_by_sort(sort* s, unsigned max_cost) const {
expr_ref_vector result(m); cost_terms result;
for (unsigned h = 0; h <= max_height; ++h) { for (unsigned c = 0; c <= max_cost; ++c) {
if (h >= m_terms.size()) if (c >= m_terms.size())
break; break;
if (!m_terms[h]->contains(s)) if (!m_terms[c]->contains(s))
continue; continue;
for (auto t : m_terms[h]->find(s)) for (auto t : m_terms[c]->find(s))
result.push_back(t); result.push_back({t, c});
} }
return result; return result;
} }
ptr_vector<expr> null_ptr_vector;
ptr_vector<expr> const &get_by_cost_and_sort(unsigned cost, sort *s) const {
if (cost >= m_terms.size() || !m_terms[cost] || !m_terms[cost]->contains(s))
return null_ptr_vector;
return m_terms[cost]->find(s);
}
private: private:
ast_manager& m; ast_manager& m;
expr_ref_vector m_pinned; expr_ref_vector m_pinned;
// height -> sort -> terms // cost -> sort -> terms
ptr_vector<sort_term_map> m_terms; ptr_vector<sort_term_map> m_terms;
}; };
@ -256,12 +263,12 @@ private:
/** /**
* Iterates over all tuples (c1, c2, ..., cn) where each ci has the required * Iterates over all tuples (c1, c2, ..., cn) where each ci has the required
* sort, drawn from the term bank, with at least one child at the current * sort, drawn from the term bank, with at least one child at the current
* height - 1 (to avoid regenerating previously seen terms). * cost - 1 (to avoid regenerating previously seen terms).
*/ */
class children_iterator { class children_iterator {
public: public:
children_iterator(ast_manager& m, production const& prod, term_bank const& bank, unsigned current_height) children_iterator(ast_manager& m, production const& prod, term_bank const& bank, unsigned current_cost)
: m(m), m_prod(prod), m_current_height(current_height), m_done(false) : m(m), m_prod(prod), m_current_cost(current_cost), m_done(false)
{ {
m_arity = prod.domain.size(); m_arity = prod.domain.size();
if (m_arity == 0) { if (m_arity == 0) {
@ -269,7 +276,7 @@ public:
return; return;
} }
for (unsigned i = 0; i < m_arity; ++i) { for (unsigned i = 0; i < m_arity; ++i) {
m_candidates.push_back(bank.get_by_sort(prod.domain[i], current_height - 1)); m_candidates.push_back(bank.get_by_sort(prod.domain[i], current_cost - 1));
if (m_candidates.back().empty()) { if (m_candidates.back().empty()) {
m_done = true; m_done = true;
return; return;
@ -278,19 +285,23 @@ public:
m_indices.resize(m_arity, 0); m_indices.resize(m_arity, 0);
} }
bool has_next() { bool has_next(unsigned cost) {
while (!m_done) { while (!m_done) {
if (has_child_at_max_height()) if (has_child_at_cost(cost))
return true; return true;
advance(); advance();
} }
return false; return false;
} }
expr_ref_vector next() { expr_ref_vector next(unsigned& cost) {
expr_ref_vector result(m); expr_ref_vector result(m);
for (unsigned i = 0; i < m_arity; ++i) cost = 1;
result.push_back(m_candidates[i].get(m_indices[i])); for (unsigned i = 0; i < m_arity; ++i) {
auto [e, c] = m_candidates[i].get(m_indices[i]);
cost += c;
result.push_back(e);
}
advance(); advance();
return result; return result;
} }
@ -298,18 +309,23 @@ public:
private: private:
ast_manager& m; ast_manager& m;
production const& m_prod; production const& m_prod;
unsigned m_current_height; unsigned m_current_cost;
unsigned m_arity; unsigned m_arity;
bool m_done; bool m_done;
vector<expr_ref_vector> m_candidates; vector<cost_terms> m_candidates;
svector<unsigned> m_indices; svector<unsigned> m_indices;
bool has_child_at_max_height() const { bool has_child_at_cost(unsigned cost) const {
return true; for (unsigned i = 0; i < m_arity; ++i) {
auto [e, c] = m_candidates[i].get(m_indices[i]);
if (c + 1 == cost)
return true;
}
return false;
} }
void advance() { void advance() {
for (auto i = m_arity; i-- > 0;) { for (auto i = m_arity; i-- > 0;) {
m_indices[i]++; m_indices[i]++;
if (m_indices[i] < m_candidates[i].size()) return; if (m_indices[i] < m_candidates[i].size()) return;
m_indices[i] = 0; m_indices[i] = 0;
@ -323,8 +339,8 @@ private:
// ============================================================================ // ============================================================================
/** /**
* Enumerates terms bottom-up by height, applying observational equivalence * Enumerates terms bottom-up by cost, applying observational equivalence
* pruning. Users iterate via has_next() / next(), or call enumerate_up_to(). * pruning. Users iterate via has_next() / next().
* *
* Usage: * Usage:
* ast_manager m; * ast_manager m;
@ -344,13 +360,10 @@ class bottom_up_enumerator {
public: public:
bottom_up_enumerator(grammar& grammar, oe_manager& oe) bottom_up_enumerator(grammar& grammar, oe_manager& oe)
: m_grammar(grammar), m(grammar.mgr()), m_oe(oe), : m_grammar(grammar), m(grammar.mgr()), m_oe(oe),
m_bank(grammar.mgr()), m_height(0), m_bank(grammar.mgr()), m_pending(grammar.mgr())
m_leaf_idx(0), m_op_idx(0), m_state(State::Leaves),
m_target_sort(nullptr), m_pending(grammar.mgr())
{} {}
void set_target_sort(sort* s) { m_target_sort = s; } void set_target_sort(sort* s) { m_target_sort = s; }
void set_max_height(unsigned h) { m_max_height = h; }
bool has_next() { bool has_next() {
if (m_pending) return true; if (m_pending) return true;
@ -366,20 +379,10 @@ public:
return result; return result;
} }
expr_ref_vector enumerate_up_to(unsigned max_height) {
m_max_height = max_height;
expr_ref_vector results(m);
while (has_next()) {
results.push_back(next());
}
return results;
}
term_bank const& bank() const { return m_bank; } term_bank const& bank() const { return m_bank; }
unsigned current_height() const { return m_height; }
void reset() { void reset() {
m_height = 0; m_cost = 0;
m_leaf_idx = 0; m_leaf_idx = 0;
m_op_idx = 0; m_op_idx = 0;
m_state = State::Leaves; m_state = State::Leaves;
@ -390,20 +393,21 @@ public:
} }
private: private:
enum class State { Leaves, operators, Done }; enum class State { Leaves, Operators, Done };
grammar& m_grammar; grammar& m_grammar;
ast_manager& m; ast_manager& m;
oe_manager& m_oe; oe_manager& m_oe;
term_bank m_bank; term_bank m_bank;
unsigned m_height; unsigned m_cost = 0;
unsigned m_leaf_idx; unsigned m_leaf_idx = 0;
unsigned m_op_idx; unsigned m_op_idx = 0;
State m_state; unsigned m_bank_idx = 0;
sort* m_target_sort; unsigned m_bank_size = 0;
State m_state = State::Leaves;
sort* m_target_sort = nullptr;
expr_ref m_pending; expr_ref m_pending;
std::unique_ptr<children_iterator> m_children_iter; std::unique_ptr<children_iterator> m_children_iter;
unsigned m_max_height = 100;
bool sort_matches(expr* e) const { bool sort_matches(expr* e) const {
return !m_target_sort || e->get_sort() == m_target_sort; return !m_target_sort || e->get_sort() == m_target_sort;
@ -424,21 +428,22 @@ private:
return term; return term;
} }
} }
m_state = State::operators; m_state = State::Operators;
m_height = 1; m_cost = 1;
m_op_idx = 0; m_op_idx = 0;
m_bank_idx = 0;
m_bank_size = get_bank_size();
m_children_iter.reset(); m_children_iter.reset();
break; break;
case State::operators: { case State::Operators: {
expr* result = enumerate_operators(); expr* result = enumerate_operators();
if (result) return result; if (result)
m_height++; return result;
if (m_height > m_max_height) { m_cost++;
m_state = State::Done;
break;
}
m_op_idx = 0; m_op_idx = 0;
m_bank_idx = 0;
m_bank_size = get_bank_size();
m_children_iter.reset(); m_children_iter.reset();
break; break;
} }
@ -448,24 +453,43 @@ private:
} }
} }
expr* enumerate_operators() { unsigned get_bank_size() const {
auto const& ops = m_grammar.operators(); auto const &terms = m_bank.get_by_cost_and_sort(m_cost, m_target_sort);
return terms.size();
}
expr *enumerate_operators() {
auto const &ops = m_grammar.operators();
while (true) { while (true) {
if (m_children_iter && m_children_iter->has_next()) {
expr_ref_vector children = m_children_iter->next(); // first find terms at m_cost that were already created
production const& prod = ops[m_op_idx - 1]; if (m_bank_idx < m_bank_size) {
auto const &terms = m_bank.get_by_cost_and_sort(m_cost, m_target_sort);
auto t = terms.get(m_bank_idx);
m_bank_idx++;
SASSERT(sort_matches(t));
return t;
}
// then create new terms using children at cost below current m_cost.
if (m_children_iter && m_children_iter->has_next(m_cost)) {
unsigned new_cost = 0;
expr_ref_vector children = m_children_iter->next(new_cost);
production const &prod = ops[m_op_idx - 1];
expr_ref term = prod.builder(children); expr_ref term = prod.builder(children);
SASSERT(new_cost >= m_cost);
if (m_oe.is_representative(term)) { if (m_oe.is_representative(term)) {
m_bank.add(term, m_height); m_bank.add(term, new_cost);
if (sort_matches(term)) if (sort_matches(term) && new_cost == m_cost)
return term; return term;
} }
continue; continue;
} }
if (m_op_idx >= ops.size()) return nullptr; if (m_op_idx >= ops.size())
production const& prod = ops[m_op_idx]; return nullptr;
production const &prod = ops[m_op_idx];
m_op_idx++; m_op_idx++;
m_children_iter = std::make_unique<children_iterator>(m, prod, m_bank, m_height); m_children_iter = std::make_unique<children_iterator>(m, prod, m_bank, m_cost);
} }
} }
}; };
@ -498,37 +522,35 @@ struct term_enumeration::imp {
m_cost = cost; m_cost = cost;
} }
// Enumerate terms of given sort up to a height, ordered by cost. // Enumerate terms of given sort up to a cost, ordered by cost.
// Returns the next term in cost order, or nullptr if exhausted at current height. // Returns the next term in cost order, or nullptr if exhausted at current cost.
expr* next_term(sort* s, unsigned& height_state, unsigned& idx_state, expr* next_term(sort* s, unsigned& cost_state, unsigned& idx_state,
vector<expr_ref_vector>& levels) { vector<expr_ref_vector>& levels) {
// Expand levels as needed // Expand levels as needed
while (idx_state >= level_size(levels, height_state)) { while (idx_state >= level_size(levels, cost_state)) {
height_state++; cost_state++;
if (height_state > 100) expand_level(s, cost_state, levels);
return nullptr;
expand_level(s, height_state, levels);
idx_state = 0; idx_state = 0;
if (level_size(levels, height_state) > 0) if (level_size(levels, cost_state) > 0)
break; break;
} }
if (height_state >= levels.size() || idx_state >= levels[height_state].size()) if (cost_state >= levels.size() || idx_state >= levels[cost_state].size())
return nullptr; return nullptr;
return levels[height_state].get(idx_state++); return levels[cost_state].get(idx_state++);
} }
private: private:
unsigned level_size(vector<expr_ref_vector> const& levels, unsigned h) const { unsigned level_size(vector<expr_ref_vector> const& levels, unsigned cost) const {
if (h >= levels.size()) return 0; if (cost >= levels.size()) return 0;
return levels[h].size(); return levels[cost].size();
} }
void expand_level(sort* s, unsigned height, vector<expr_ref_vector>& levels) { void expand_level(sort* s, unsigned cost, vector<expr_ref_vector>& levels) {
if (height >= levels.size()) if (cost >= levels.size())
levels.resize(height + 1, expr_ref_vector(m)); levels.resize(cost + 1, expr_ref_vector(m));
// Collect terms at this height // Collect terms at this cost
if (height == 0) { if (cost == 0) {
// Leaves // Leaves
for (auto const& prod : m_grammar.leaves()) { for (auto const& prod : m_grammar.leaves()) {
if (prod.range.get() != s) continue; if (prod.range.get() != s) continue;
@ -544,19 +566,21 @@ private:
// operators // operators
for (auto const& prod : m_grammar.operators()) { for (auto const& prod : m_grammar.operators()) {
if (prod.range.get() != s) continue; if (prod.range.get() != s) continue;
term_enum::children_iterator iter(m, prod, m_bottom_up_enumerator.bank(), height); term_enum::children_iterator iter(m, prod, m_bottom_up_enumerator.bank(), cost);
while (iter.has_next()) { while (iter.has_next(cost)) {
expr_ref_vector children = iter.next(); unsigned new_cost = 0;
expr_ref_vector children = iter.next(new_cost);
expr_ref term = prod.builder(children); expr_ref term = prod.builder(children);
levels.reserve(new_cost + 1, expr_ref_vector(m));
if (m_oe.is_representative(term)) if (m_oe.is_representative(term))
levels[height].push_back(term); levels[new_cost].push_back(term);
} }
} }
} }
// Sort by cost if cost function is set // Sort by cost if cost function is set
if (m_cost && !levels[height].empty()) { if (m_cost && !levels[cost].empty()) {
expr_ref_vector& lv = levels[height]; expr_ref_vector& lv = levels[cost];
std::sort(lv.data(), lv.data() + lv.size(), std::sort(lv.data(), lv.data() + lv.size(),
[&](expr* a, expr* b) { return m_cost(a) < m_cost(b); }); [&](expr* a, expr* b) { return m_cost(a) < m_cost(b); });
} }
@ -568,30 +592,29 @@ private:
struct term_enumeration::iterator::iter_imp { struct term_enumeration::iterator::iter_imp {
imp& m_imp; imp& m_imp;
sort* m_sort; sort* m_sort;
unsigned m_height; unsigned m_cost = 0;
unsigned m_idx; unsigned m_idx = 0;
vector<expr_ref_vector> m_levels; vector<expr_ref_vector> m_levels;
expr* m_current; expr* m_current = nullptr;
bool m_end; bool m_end;
iter_imp(imp& i, sort* s) : iter_imp(imp& i, sort* s) : m_imp(i), m_sort(s), m_end(false) {
m_imp(i), m_sort(s), m_height(0), m_idx(0), m_current(nullptr), m_end(false) {
expand_current_level(); expand_current_level();
advance_to_valid(); advance_to_valid();
} }
// Sentinel constructor // Sentinel constructor
iter_imp(imp& i) : iter_imp(imp& i) :
m_imp(i), m_sort(nullptr), m_height(0), m_idx(0), m_current(nullptr), m_end(true) {} m_imp(i), m_sort(nullptr), m_end(true) {}
void expand_current_level() { void expand_current_level() {
if (m_height >= m_levels.size()) if (m_cost >= m_levels.size())
m_levels.resize(m_height + 1, expr_ref_vector(m_imp.m)); m_levels.resize(m_cost + 1, expr_ref_vector(m_imp.m));
if (!m_levels[m_height].empty()) if (!m_levels[m_cost].empty())
return; return;
if (m_height == 0) { if (m_cost == 0) {
for (auto const& prod : m_imp.m_grammar.leaves()) { for (auto const& prod : m_imp.m_grammar.leaves()) {
if (prod.range.get() != m_sort) continue; if (prod.range.get() != m_sort) continue;
expr_ref_vector empty_args(m_imp.m); expr_ref_vector empty_args(m_imp.m);
@ -603,19 +626,21 @@ struct term_enumeration::iterator::iter_imp {
else { else {
for (auto const& prod : m_imp.m_grammar.operators()) { for (auto const& prod : m_imp.m_grammar.operators()) {
if (prod.range.get() != m_sort) continue; if (prod.range.get() != m_sort) continue;
term_enum::children_iterator iter(m_imp.m, prod, m_imp.m_bottom_up_enumerator.bank(), m_height); term_enum::children_iterator iter(m_imp.m, prod, m_imp.m_bottom_up_enumerator.bank(), m_cost);
while (iter.has_next()) { while (iter.has_next(m_cost)) {
expr_ref_vector children = iter.next(); unsigned new_cost = 0;
expr_ref_vector children = iter.next(new_cost);
expr_ref term = prod.builder(children); expr_ref term = prod.builder(children);
m_levels.reserve(new_cost + 1, expr_ref_vector(m_imp.m));
if (m_imp.m_oe.is_representative(term)) if (m_imp.m_oe.is_representative(term))
m_levels[m_height].push_back(term); m_levels[new_cost].push_back(term);
} }
} }
} }
// Sort by cost if cost function is set // Sort by cost if cost function is set
if (m_imp.m_cost && !m_levels[m_height].empty()) { if (m_imp.m_cost && !m_levels[m_cost].empty()) {
expr_ref_vector& lv = m_levels[m_height]; expr_ref_vector& lv = m_levels[m_cost];
std::sort(lv.data(), lv.data() + lv.size(), std::sort(lv.data(), lv.data() + lv.size(),
[&](expr* a, expr* b) { return m_imp.m_cost(a) < m_imp.m_cost(b); }); [&](expr* a, expr* b) { return m_imp.m_cost(a) < m_imp.m_cost(b); });
} }
@ -623,15 +648,15 @@ struct term_enumeration::iterator::iter_imp {
void advance_to_valid() { void advance_to_valid() {
while (true) { while (true) {
if (m_height >= m_levels.size()) if (m_cost >= m_levels.size())
expand_current_level(); expand_current_level();
if (m_idx < m_levels[m_height].size()) { if (m_idx < m_levels[m_cost].size()) {
m_current = m_levels[m_height].get(m_idx); m_current = m_levels[m_cost].get(m_idx);
return; return;
} }
m_height++; m_cost++;
m_idx = 0; m_idx = 0;
if (m_height > 100) { if (m_cost > 100) {
m_end = true; m_end = true;
m_current = nullptr; m_current = nullptr;
return; return;

View file

@ -187,14 +187,14 @@ namespace smt {
\brief Base class used to solve model construction constraints. \brief Base class used to solve model construction constraints.
*/ */
class node { class node {
unsigned m_id; unsigned m_id = 0;
node* m_find{ nullptr }; node* m_find = nullptr;
unsigned m_eqc_size{ 1 }; unsigned m_eqc_size = 1;
sort* m_sort; // sort of the elements in the instantiation set. sort* m_sort = nullptr; // sort of the elements in the instantiation set.
bool m_mono_proj{ false }; // relevant for integers & reals & bit-vectors bool m_mono_proj = false; // relevant for integers & reals & bit-vectors
bool m_signed_proj{ false }; // relevant for bit-vectors. bool m_signed_proj = false; // relevant for bit-vectors.
ptr_vector<node> m_avoid_set; ptr_vector<node> m_avoid_set;
ptr_vector<expr> m_exceptions; ptr_vector<expr> m_exceptions;
@ -1235,8 +1235,8 @@ namespace smt {
void populate_inst_sets(quantifier* q, func_decl* mhead, ptr_vector<instantiation_set>& uvar_inst_sets, context* ctx) override { void populate_inst_sets(quantifier* q, func_decl* mhead, ptr_vector<instantiation_set>& uvar_inst_sets, context* ctx) override {
if (m_f != mhead) if (m_f != mhead)
return; return;
uvar_inst_sets.reserve(m_var_j + 1, 0); uvar_inst_sets.reserve(m_var_j + 1, nullptr);
if (uvar_inst_sets[m_var_j] == 0) if (uvar_inst_sets[m_var_j] == nullptr)
uvar_inst_sets[m_var_j] = alloc(instantiation_set, ctx->get_manager()); uvar_inst_sets[m_var_j] = alloc(instantiation_set, ctx->get_manager());
instantiation_set* s = uvar_inst_sets[m_var_j]; instantiation_set* s = uvar_inst_sets[m_var_j];
SASSERT(s != nullptr); SASSERT(s != nullptr);