mirror of
https://github.com/Z3Prover/z3
synced 2026-06-19 15:16:29 +00:00
trim implementation
Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
This commit is contained in:
parent
19e00e03c1
commit
7063ab4646
1 changed files with 73 additions and 317 deletions
|
|
@ -9,18 +9,15 @@
|
|||
* - Observational equivalence (OE): two terms that produce the same outputs
|
||||
* on all sample inputs are considered equivalent; only one representative
|
||||
* per equivalence class is kept.
|
||||
* - A Grammar describes which function symbols (operators) and leaves
|
||||
* - A grammar describes which function symbols (operators) and leaves
|
||||
* (constants, variables) are available for enumeration.
|
||||
*/
|
||||
|
||||
#include "ast/term_enumeration.h"
|
||||
#include <sstream>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <functional>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <queue>
|
||||
#include "ast/ast.h"
|
||||
#include "ast/ast_pp.h"
|
||||
#include "ast/ast_ll_pp.h"
|
||||
|
|
@ -29,8 +26,7 @@
|
|||
#include "ast/seq_decl_plugin.h"
|
||||
#include "model/model.h"
|
||||
#include "model/model_evaluator.h"
|
||||
#include "solver/solver.h"
|
||||
#include "smt/smt_solver.h"
|
||||
|
||||
#include "util/vector.h"
|
||||
#include "util/ref.h"
|
||||
#include "util/obj_hashtable.h"
|
||||
|
|
@ -38,16 +34,16 @@
|
|||
namespace term_enum {
|
||||
|
||||
// ============================================================================
|
||||
// Grammar production rule
|
||||
// grammar production rule
|
||||
// ============================================================================
|
||||
|
||||
/**
|
||||
* A Production describes how to construct a term from child terms.
|
||||
* A production describes how to construct a term from child terms.
|
||||
* - domain: the sort required for each child
|
||||
* - range: the sort of the produced term
|
||||
* - builder: given a vector of child exprs, produce the result expr
|
||||
*/
|
||||
struct Production {
|
||||
struct production {
|
||||
std::string name;
|
||||
sort_ref range;
|
||||
sort_ref_vector domain;
|
||||
|
|
@ -57,25 +53,25 @@ struct Production {
|
|||
};
|
||||
|
||||
// ============================================================================
|
||||
// Grammar
|
||||
// grammar
|
||||
// ============================================================================
|
||||
|
||||
/**
|
||||
* A Grammar groups productions into leaves (arity 0) and operators (arity > 0).
|
||||
* A grammar groups productions into leaves (arity 0) and operators (arity > 0).
|
||||
*/
|
||||
class Grammar {
|
||||
class grammar {
|
||||
public:
|
||||
Grammar(ast_manager& m) : m(m) {}
|
||||
grammar(ast_manager& m) : m(m) {}
|
||||
|
||||
void add_production(Production p) {
|
||||
void add_production(production p) {
|
||||
if (p.is_leaf())
|
||||
m_leaves.push_back(std::move(p));
|
||||
else
|
||||
m_operators.push_back(std::move(p));
|
||||
}
|
||||
|
||||
vector<Production> const& leaves() const { return m_leaves; }
|
||||
vector<Production> const& operators() const { return m_operators; }
|
||||
vector<production> const& leaves() const { return m_leaves; }
|
||||
vector<production> const& operators() const { return m_operators; }
|
||||
ast_manager& mgr() const { return m; }
|
||||
|
||||
void add_variable(char const* name, sort* s) {
|
||||
|
|
@ -138,158 +134,10 @@ public:
|
|||
|
||||
private:
|
||||
ast_manager& m;
|
||||
vector<Production> m_leaves;
|
||||
vector<Production> m_operators;
|
||||
vector<production> m_leaves;
|
||||
vector<production> m_operators;
|
||||
};
|
||||
|
||||
// ============================================================================
|
||||
// Standard grammar factories - build common operator sets
|
||||
// ============================================================================
|
||||
|
||||
namespace grammars {
|
||||
|
||||
/**
|
||||
* Build a grammar over linear integer arithmetic.
|
||||
* Operators: +, -, *, ite (with bool condition)
|
||||
*/
|
||||
inline void add_lia_operators(Grammar& g) {
|
||||
ast_manager& m = g.mgr();
|
||||
arith_util a(m);
|
||||
sort_ref isort(a.mk_int(), m);
|
||||
sort_ref bsort(m.mk_bool_sort(), m);
|
||||
|
||||
sort_ref_vector ii(m); ii.push_back(isort); ii.push_back(isort);
|
||||
sort_ref_vector i1(m); i1.push_back(isort);
|
||||
sort_ref_vector bb(m); bb.push_back(bsort); bb.push_back(bsort);
|
||||
sort_ref_vector b1(m); b1.push_back(bsort);
|
||||
sort_ref_vector bii(m); bii.push_back(bsort); bii.push_back(isort); bii.push_back(isort);
|
||||
|
||||
g.add_production({"add", isort, ii,
|
||||
[&m](expr_ref_vector const& ch) { arith_util a(m); return expr_ref(a.mk_add(ch[0], ch[1]), m); }});
|
||||
g.add_production({"sub", isort, ii,
|
||||
[&m](expr_ref_vector const& ch) { arith_util a(m); return expr_ref(a.mk_sub(ch[0], ch[1]), m); }});
|
||||
g.add_production({"mul", isort, ii,
|
||||
[&m](expr_ref_vector const& ch) { arith_util a(m); return expr_ref(a.mk_mul(ch[0], ch[1]), m); }});
|
||||
g.add_production({"neg", isort, i1,
|
||||
[&m](expr_ref_vector const& ch) { arith_util a(m); return expr_ref(a.mk_uminus(ch[0]), m); }});
|
||||
|
||||
g.add_production({"le", bsort, ii,
|
||||
[&m](expr_ref_vector const& ch) { arith_util a(m); return expr_ref(a.mk_le(ch[0], ch[1]), m); }});
|
||||
g.add_production({"lt", bsort, ii,
|
||||
[&m](expr_ref_vector const& ch) { arith_util a(m); return expr_ref(a.mk_lt(ch[0], ch[1]), m); }});
|
||||
g.add_production({"eq_int", bsort, ii,
|
||||
[&m](expr_ref_vector const& ch) { return expr_ref(m.mk_eq(ch[0], ch[1]), m); }});
|
||||
|
||||
g.add_production({"and", bsort, bb,
|
||||
[&m](expr_ref_vector const& ch) { return expr_ref(m.mk_and(ch[0], ch[1]), m); }});
|
||||
g.add_production({"or", bsort, bb,
|
||||
[&m](expr_ref_vector const& ch) { return expr_ref(m.mk_or(ch[0], ch[1]), m); }});
|
||||
g.add_production({"not", bsort, b1,
|
||||
[&m](expr_ref_vector const& ch) { return expr_ref(m.mk_not(ch[0]), m); }});
|
||||
|
||||
g.add_production({"ite_int", isort, bii,
|
||||
[&m](expr_ref_vector const& ch) { return expr_ref(m.mk_ite(ch[0], ch[1], ch[2]), m); }});
|
||||
}
|
||||
|
||||
/**
|
||||
* Build a grammar over bitvectors.
|
||||
*/
|
||||
inline void add_bv_operators(Grammar& g, unsigned bits) {
|
||||
ast_manager& m = g.mgr();
|
||||
bv_util bv(m);
|
||||
sort_ref bvsort(bv.mk_sort(bits), m);
|
||||
sort_ref bsort(m.mk_bool_sort(), m);
|
||||
|
||||
sort_ref_vector vv(m); vv.push_back(bvsort); vv.push_back(bvsort);
|
||||
sort_ref_vector v1(m); v1.push_back(bvsort);
|
||||
sort_ref_vector bvv(m); bvv.push_back(bsort); bvv.push_back(bvsort); bvv.push_back(bvsort);
|
||||
|
||||
g.add_production({"bvadd", bvsort, vv,
|
||||
[&m](expr_ref_vector const& ch) { bv_util bv(m); return expr_ref(bv.mk_bv_add(ch[0], ch[1]), m); }});
|
||||
g.add_production({"bvsub", bvsort, vv,
|
||||
[&m](expr_ref_vector const& ch) { bv_util bv(m); return expr_ref(bv.mk_bv_sub(ch[0], ch[1]), m); }});
|
||||
g.add_production({"bvmul", bvsort, vv,
|
||||
[&m](expr_ref_vector const& ch) { bv_util bv(m); return expr_ref(bv.mk_bv_mul(ch[0], ch[1]), m); }});
|
||||
g.add_production({"bvand", bvsort, vv,
|
||||
[&m](expr_ref_vector const& ch) { bv_util bv(m); return expr_ref(bv.mk_bv_and(ch[0], ch[1]), m); }});
|
||||
g.add_production({"bvor", bvsort, vv,
|
||||
[&m](expr_ref_vector const& ch) { bv_util bv(m); return expr_ref(bv.mk_bv_or(ch[0], ch[1]), m); }});
|
||||
g.add_production({"bvxor", bvsort, vv,
|
||||
[&m](expr_ref_vector const& ch) { bv_util bv(m); return expr_ref(bv.mk_bv_xor(ch[0], ch[1]), m); }});
|
||||
g.add_production({"bvnot", bvsort, v1,
|
||||
[&m](expr_ref_vector const& ch) { bv_util bv(m); return expr_ref(bv.mk_bv_not(ch[0]), m); }});
|
||||
g.add_production({"bvneg", bvsort, v1,
|
||||
[&m](expr_ref_vector const& ch) { bv_util bv(m); return expr_ref(bv.mk_bv_neg(ch[0]), m); }});
|
||||
g.add_production({"bvshl", bvsort, vv,
|
||||
[&m](expr_ref_vector const& ch) { bv_util bv(m); return expr_ref(bv.mk_bv_shl(ch[0], ch[1]), m); }});
|
||||
g.add_production({"bvlshr", bvsort, vv,
|
||||
[&m](expr_ref_vector const& ch) { bv_util bv(m); return expr_ref(bv.mk_bv_lshr(ch[0], ch[1]), m); }});
|
||||
g.add_production({"bvashr", bvsort, vv,
|
||||
[&m](expr_ref_vector const& ch) { bv_util bv(m); return expr_ref(bv.mk_bv_ashr(ch[0], ch[1]), m); }});
|
||||
|
||||
g.add_production({"bvult", bsort, vv,
|
||||
[&m](expr_ref_vector const& ch) { bv_util bv(m); return expr_ref(m.mk_app(bv.get_fid(), OP_ULT, ch[0], ch[1]), m); }});
|
||||
g.add_production({"bvslt", bsort, vv,
|
||||
[&m](expr_ref_vector const& ch) { bv_util bv(m); return expr_ref(bv.mk_slt(ch[0], ch[1]), m); }});
|
||||
g.add_production({"bveq", bsort, vv,
|
||||
[&m](expr_ref_vector const& ch) { return expr_ref(m.mk_eq(ch[0], ch[1]), m); }});
|
||||
|
||||
g.add_production({"ite_bv", bvsort, bvv,
|
||||
[&m](expr_ref_vector const& ch) { return expr_ref(m.mk_ite(ch[0], ch[1], ch[2]), m); }});
|
||||
}
|
||||
|
||||
/**
|
||||
* Build a grammar over strings.
|
||||
*/
|
||||
inline void add_string_operators(Grammar& g) {
|
||||
ast_manager& m = g.mgr();
|
||||
seq_util seq(m);
|
||||
arith_util a(m);
|
||||
sort_ref ssort(seq.str.mk_string_sort(), m);
|
||||
sort_ref isort(a.mk_int(), m);
|
||||
sort_ref bsort(m.mk_bool_sort(), m);
|
||||
|
||||
sort_ref_vector ss(m); ss.push_back(ssort); ss.push_back(ssort);
|
||||
sort_ref_vector s1(m); s1.push_back(ssort);
|
||||
sort_ref_vector si(m); si.push_back(ssort); si.push_back(isort);
|
||||
sort_ref_vector sii(m); sii.push_back(ssort); sii.push_back(isort); sii.push_back(isort);
|
||||
sort_ref_vector ssi(m); ssi.push_back(ssort); ssi.push_back(ssort); ssi.push_back(isort);
|
||||
sort_ref_vector sss(m); sss.push_back(ssort); sss.push_back(ssort); sss.push_back(ssort);
|
||||
sort_ref_vector i1(m); i1.push_back(isort);
|
||||
sort_ref_vector bss(m); bss.push_back(bsort); bss.push_back(ssort); bss.push_back(ssort);
|
||||
|
||||
g.add_production({"str.++", ssort, ss,
|
||||
[&m](expr_ref_vector const& ch) { seq_util seq(m); return expr_ref(seq.str.mk_concat(ch[0], ch[1]), m); }});
|
||||
g.add_production({"str.len", isort, s1,
|
||||
[&m](expr_ref_vector const& ch) { seq_util seq(m); return expr_ref(seq.str.mk_length(ch[0]), m); }});
|
||||
g.add_production({"str.at", ssort, si,
|
||||
[&m](expr_ref_vector const& ch) {
|
||||
seq_util seq(m); arith_util a(m);
|
||||
return expr_ref(seq.str.mk_substr(ch[0], ch[1], a.mk_int(1)), m);
|
||||
}});
|
||||
g.add_production({"str.substr", ssort, sii,
|
||||
[&m](expr_ref_vector const& ch) { seq_util seq(m); return expr_ref(seq.str.mk_substr(ch[0], ch[1], ch[2]), m); }});
|
||||
g.add_production({"str.indexof", isort, ssi,
|
||||
[&m](expr_ref_vector const& ch) { seq_util seq(m); return expr_ref(seq.str.mk_index(ch[0], ch[1], ch[2]), m); }});
|
||||
g.add_production({"str.replace", ssort, sss,
|
||||
[&m](expr_ref_vector const& ch) { seq_util seq(m); return expr_ref(seq.str.mk_replace(ch[0], ch[1], ch[2]), m); }});
|
||||
g.add_production({"str.contains", bsort, ss,
|
||||
[&m](expr_ref_vector const& ch) { seq_util seq(m); return expr_ref(seq.str.mk_contains(ch[0], ch[1]), m); }});
|
||||
g.add_production({"str.prefixof", bsort, ss,
|
||||
[&m](expr_ref_vector const& ch) { seq_util seq(m); return expr_ref(seq.str.mk_prefix(ch[0], ch[1]), m); }});
|
||||
g.add_production({"str.suffixof", bsort, ss,
|
||||
[&m](expr_ref_vector const& ch) { seq_util seq(m); return expr_ref(seq.str.mk_suffix(ch[0], ch[1]), m); }});
|
||||
g.add_production({"int.to.str", ssort, i1,
|
||||
[&m](expr_ref_vector const& ch) { seq_util seq(m); return expr_ref(seq.str.mk_itos(ch[0]), m); }});
|
||||
g.add_production({"str.to.int", isort, s1,
|
||||
[&m](expr_ref_vector const& ch) { seq_util seq(m); return expr_ref(seq.str.mk_stoi(ch[0]), m); }});
|
||||
|
||||
g.add_production({"ite_str", ssort, bss,
|
||||
[&m](expr_ref_vector const& ch) { return expr_ref(m.mk_ite(ch[0], ch[1], ch[2]), m); }});
|
||||
}
|
||||
|
||||
} // namespace grammars
|
||||
|
||||
// ============================================================================
|
||||
// Observational Equivalence Manager
|
||||
// ============================================================================
|
||||
|
|
@ -300,9 +148,9 @@ inline void add_string_operators(Grammar& g) {
|
|||
*
|
||||
* Uses Z3's model evaluation to reduce terms to concrete values.
|
||||
*/
|
||||
class OEManager {
|
||||
class oe_manager {
|
||||
public:
|
||||
OEManager(ast_manager& m) : m(m) {}
|
||||
oe_manager(ast_manager& m) : m(m) {}
|
||||
|
||||
void set_samples(vector<model_ref> samples) {
|
||||
m_samples = std::move(samples);
|
||||
|
|
@ -319,36 +167,34 @@ public:
|
|||
* not been seen before).
|
||||
*/
|
||||
bool is_representative(expr* term) {
|
||||
if (m_samples.empty()) return true;
|
||||
std::string fingerprint = compute_fingerprint(term);
|
||||
if (fingerprint.empty()) return false;
|
||||
auto fingerprint = compute_fingerprint(term);
|
||||
if (fingerprint == 0)
|
||||
return false;
|
||||
return m_seen.insert(fingerprint).second;
|
||||
}
|
||||
|
||||
void clear() { m_seen.clear(); }
|
||||
|
||||
size_t num_classes() const { return m_seen.size(); }
|
||||
size_t num_samples() const { return m_samples.size(); }
|
||||
unsigned num_samples() const { return m_samples.size(); }
|
||||
|
||||
private:
|
||||
ast_manager& m;
|
||||
vector<model_ref> m_samples;
|
||||
std::unordered_set<std::string> m_seen;
|
||||
std::unordered_set<uint64_t> m_seen;
|
||||
|
||||
std::string compute_fingerprint(expr* term) {
|
||||
std::string fp;
|
||||
uint64_t compute_fingerprint(expr* term) {
|
||||
uint64_t a = 0, b = 1, c = 2;
|
||||
for (auto& mdl : m_samples) {
|
||||
expr_ref val(m);
|
||||
model_evaluator eval(*mdl);
|
||||
eval.set_model_completion(true);
|
||||
if (!eval.eval(term, val, true))
|
||||
return "";
|
||||
std::ostringstream os;
|
||||
os << mk_pp(val, m);
|
||||
fp += os.str();
|
||||
fp += '\x1f';
|
||||
continue;
|
||||
a *= val->hash();
|
||||
mix(a, b, c);
|
||||
}
|
||||
return fp;
|
||||
return a;
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -356,9 +202,15 @@ private:
|
|||
// Term Bank - stores enumerated terms by height and sort
|
||||
// ============================================================================
|
||||
|
||||
class TermBank {
|
||||
class term_bank {
|
||||
using sort_term_map = obj_map<sort, ptr_vector<expr>>;
|
||||
public:
|
||||
TermBank(ast_manager& m) : m(m), m_pinned(m) {}
|
||||
term_bank(ast_manager& m) : m(m), m_pinned(m) {}
|
||||
|
||||
~term_bank() {
|
||||
for (auto s : m_terms)
|
||||
dealloc(s);
|
||||
}
|
||||
|
||||
void reset() {
|
||||
m_pinned.reset();
|
||||
|
|
@ -371,7 +223,9 @@ public:
|
|||
m_pinned.push_back(term);
|
||||
if (height >= m_terms.size())
|
||||
m_terms.resize(height + 1);
|
||||
m_terms[height].insert_if_not_there(s, ptr_vector<expr>()).push_back(term);
|
||||
if (!m_terms[height])
|
||||
m_terms[height] = alloc(sort_term_map);
|
||||
m_terms[height]->insert_if_not_there(s, ptr_vector<expr>()).push_back(term);
|
||||
}
|
||||
|
||||
/** Get all terms of a given sort up to (and including) max_height */
|
||||
|
|
@ -380,27 +234,19 @@ public:
|
|||
for (unsigned h = 0; h <= max_height; ++h) {
|
||||
if (h >= m_terms.size())
|
||||
break;
|
||||
if (!m_terms[h].contains(s))
|
||||
if (!m_terms[h]->contains(s))
|
||||
continue;
|
||||
for (auto t : m_terms[h].find(s))
|
||||
for (auto t : m_terms[h]->find(s))
|
||||
result.push_back(t);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
size_t total_terms() const {
|
||||
size_t n = 0;
|
||||
for (auto& sm : m_terms)
|
||||
for (auto& [s, v] : sm)
|
||||
n += v.size();
|
||||
return n;
|
||||
}
|
||||
|
||||
private:
|
||||
ast_manager& m;
|
||||
expr_ref_vector m_pinned;
|
||||
// height -> sort -> terms
|
||||
vector<obj_map<sort, ptr_vector<expr>>> m_terms;
|
||||
ptr_vector<sort_term_map> m_terms;
|
||||
};
|
||||
|
||||
// ============================================================================
|
||||
|
|
@ -412,9 +258,9 @@ private:
|
|||
* sort, drawn from the term bank, with at least one child at the current
|
||||
* height - 1 (to avoid regenerating previously seen terms).
|
||||
*/
|
||||
class ChildrenIterator {
|
||||
class children_iterator {
|
||||
public:
|
||||
ChildrenIterator(ast_manager& m, Production const& prod, TermBank const& bank, unsigned current_height)
|
||||
children_iterator(ast_manager& m, production const& prod, term_bank const& bank, unsigned current_height)
|
||||
: m(m), m_prod(prod), m_current_height(current_height), m_done(false)
|
||||
{
|
||||
m_arity = prod.domain.size();
|
||||
|
|
@ -451,7 +297,7 @@ public:
|
|||
|
||||
private:
|
||||
ast_manager& m;
|
||||
Production const& m_prod;
|
||||
production const& m_prod;
|
||||
unsigned m_current_height;
|
||||
unsigned m_arity;
|
||||
bool m_done;
|
||||
|
|
@ -473,7 +319,7 @@ private:
|
|||
};
|
||||
|
||||
// ============================================================================
|
||||
// Enumerator - the main bottom-up term enumeration engine
|
||||
// bottom_up_enumerator - the main bottom-up term enumeration engine
|
||||
// ============================================================================
|
||||
|
||||
/**
|
||||
|
|
@ -482,11 +328,11 @@ private:
|
|||
*
|
||||
* Usage:
|
||||
* ast_manager m;
|
||||
* Grammar g(m);
|
||||
* grammar g(m);
|
||||
* // ... add productions ...
|
||||
* OEManager oe(m);
|
||||
* oe_manager oe(m);
|
||||
* // ... set samples ...
|
||||
* Enumerator en(g, oe);
|
||||
* bottom_up_enumerator en(g, oe);
|
||||
* arith_util a(m);
|
||||
* en.set_target_sort(a.mk_int());
|
||||
* while (en.has_next()) {
|
||||
|
|
@ -494,9 +340,9 @@ private:
|
|||
* // ... check if term satisfies specification ...
|
||||
* }
|
||||
*/
|
||||
class Enumerator {
|
||||
class bottom_up_enumerator {
|
||||
public:
|
||||
Enumerator(Grammar& grammar, OEManager& oe)
|
||||
bottom_up_enumerator(grammar& grammar, oe_manager& oe)
|
||||
: m_grammar(grammar), m(grammar.mgr()), m_oe(oe),
|
||||
m_bank(grammar.mgr()), m_height(0),
|
||||
m_leaf_idx(0), m_op_idx(0), m_state(State::Leaves),
|
||||
|
|
@ -529,7 +375,7 @@ public:
|
|||
return results;
|
||||
}
|
||||
|
||||
TermBank const& bank() const { return m_bank; }
|
||||
term_bank const& bank() const { return m_bank; }
|
||||
unsigned current_height() const { return m_height; }
|
||||
|
||||
void reset() {
|
||||
|
|
@ -544,24 +390,23 @@ public:
|
|||
}
|
||||
|
||||
private:
|
||||
enum class State { Leaves, Operators, Done };
|
||||
enum class State { Leaves, operators, Done };
|
||||
|
||||
Grammar& m_grammar;
|
||||
grammar& m_grammar;
|
||||
ast_manager& m;
|
||||
OEManager& m_oe;
|
||||
TermBank m_bank;
|
||||
oe_manager& m_oe;
|
||||
term_bank m_bank;
|
||||
unsigned m_height;
|
||||
unsigned m_leaf_idx;
|
||||
unsigned m_op_idx;
|
||||
State m_state;
|
||||
sort* m_target_sort;
|
||||
expr_ref m_pending;
|
||||
std::unique_ptr<ChildrenIterator> m_children_iter;
|
||||
std::unique_ptr<children_iterator> m_children_iter;
|
||||
unsigned m_max_height = 100;
|
||||
|
||||
bool sort_matches(expr* e) const {
|
||||
if (!m_target_sort) return true;
|
||||
return e->get_sort() == m_target_sort;
|
||||
return !m_target_sort || e->get_sort() == m_target_sort;
|
||||
}
|
||||
|
||||
expr* find_next() {
|
||||
|
|
@ -569,7 +414,7 @@ private:
|
|||
switch (m_state) {
|
||||
case State::Leaves:
|
||||
while (m_leaf_idx < m_grammar.leaves().size()) {
|
||||
Production const& prod = m_grammar.leaves()[m_leaf_idx];
|
||||
production const& prod = m_grammar.leaves()[m_leaf_idx];
|
||||
m_leaf_idx++;
|
||||
expr_ref_vector empty_args(m);
|
||||
expr_ref term = prod.builder(empty_args);
|
||||
|
|
@ -579,13 +424,13 @@ private:
|
|||
return term;
|
||||
}
|
||||
}
|
||||
m_state = State::Operators;
|
||||
m_state = State::operators;
|
||||
m_height = 1;
|
||||
m_op_idx = 0;
|
||||
m_children_iter.reset();
|
||||
break;
|
||||
|
||||
case State::Operators: {
|
||||
case State::operators: {
|
||||
expr* result = enumerate_operators();
|
||||
if (result) return result;
|
||||
m_height++;
|
||||
|
|
@ -608,7 +453,7 @@ private:
|
|||
while (true) {
|
||||
if (m_children_iter && m_children_iter->has_next()) {
|
||||
expr_ref_vector children = m_children_iter->next();
|
||||
Production const& prod = ops[m_op_idx - 1];
|
||||
production const& prod = ops[m_op_idx - 1];
|
||||
expr_ref term = prod.builder(children);
|
||||
if (m_oe.is_representative(term)) {
|
||||
m_bank.add(term, m_height);
|
||||
|
|
@ -618,102 +463,13 @@ private:
|
|||
continue;
|
||||
}
|
||||
if (m_op_idx >= ops.size()) return nullptr;
|
||||
Production const& prod = ops[m_op_idx];
|
||||
production const& prod = ops[m_op_idx];
|
||||
m_op_idx++;
|
||||
m_children_iter = std::make_unique<ChildrenIterator>(m, prod, m_bank, m_height);
|
||||
m_children_iter = std::make_unique<children_iterator>(m, prod, m_bank, m_height);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// ============================================================================
|
||||
// CEGIS integration helper
|
||||
// ============================================================================
|
||||
|
||||
/**
|
||||
* Counter-Example Guided Inductive Synthesis loop.
|
||||
* Combines the enumerator with a solver to verify candidates against a
|
||||
* specification.
|
||||
*
|
||||
* spec: a function (expr* candidate) -> expr_ref that returns the specification
|
||||
* constraint (should be valid for a correct program).
|
||||
* variables: the free variables of the specification.
|
||||
*/
|
||||
class CEGISLoop {
|
||||
public:
|
||||
CEGISLoop(Grammar& grammar, sort* target_sort,
|
||||
std::function<expr_ref(expr*)> spec,
|
||||
expr_ref_vector variables)
|
||||
: m(grammar.mgr()), m_grammar(grammar), m_oe(grammar.mgr()),
|
||||
m_enumerator(grammar, m_oe),
|
||||
m_spec(std::move(spec)), m_variables(std::move(variables))
|
||||
{
|
||||
m_enumerator.set_target_sort(target_sort);
|
||||
params_ref p;
|
||||
m_solver = mk_smt_solver(m, p, symbol::null);
|
||||
}
|
||||
|
||||
/**
|
||||
* Run the CEGIS loop. Returns the synthesized term, or null expr_ref if
|
||||
* max_height is exceeded.
|
||||
*/
|
||||
expr_ref synthesize(unsigned max_height = 10, unsigned max_restarts = 20) {
|
||||
m_enumerator.set_max_height(max_height);
|
||||
unsigned restarts = 0;
|
||||
|
||||
while (m_enumerator.has_next()) {
|
||||
expr_ref candidate = m_enumerator.next();
|
||||
|
||||
if (!satisfies_samples(candidate)) continue;
|
||||
|
||||
expr_ref spec_expr = m_spec(candidate);
|
||||
m_solver->push();
|
||||
m_solver->assert_expr(m.mk_not(spec_expr));
|
||||
lbool result = m_solver->check_sat(0, nullptr);
|
||||
|
||||
if (result == l_false) {
|
||||
m_solver->pop(1);
|
||||
return candidate;
|
||||
} else if (result == l_true) {
|
||||
model_ref cex;
|
||||
m_solver->get_model(cex);
|
||||
m_oe.add_sample(cex);
|
||||
m_samples.push_back(cex);
|
||||
m_solver->pop(1);
|
||||
restarts++;
|
||||
if (restarts > max_restarts) return expr_ref(m);
|
||||
m_enumerator.reset();
|
||||
} else {
|
||||
m_solver->pop(1);
|
||||
}
|
||||
}
|
||||
return expr_ref(m);
|
||||
}
|
||||
|
||||
size_t num_samples() const { return m_oe.num_samples(); }
|
||||
size_t num_equivalence_classes() const { return m_oe.num_classes(); }
|
||||
|
||||
private:
|
||||
ast_manager& m;
|
||||
Grammar& m_grammar;
|
||||
OEManager m_oe;
|
||||
Enumerator m_enumerator;
|
||||
std::function<expr_ref(expr*)> m_spec;
|
||||
expr_ref_vector m_variables;
|
||||
ref<solver> m_solver;
|
||||
vector<model_ref> m_samples;
|
||||
|
||||
bool satisfies_samples(expr* candidate) {
|
||||
expr_ref spec_expr = m_spec(candidate);
|
||||
for (auto& mdl : m_samples) {
|
||||
model_evaluator eval(*mdl);
|
||||
eval.set_model_completion(true);
|
||||
if (eval.is_false(spec_expr))
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace term_enum
|
||||
|
||||
// ============================================================================
|
||||
|
|
@ -722,13 +478,13 @@ private:
|
|||
|
||||
struct term_enumeration::imp {
|
||||
ast_manager& m;
|
||||
term_enum::Grammar m_grammar;
|
||||
term_enum::OEManager m_oe;
|
||||
term_enum::Enumerator m_enumerator;
|
||||
term_enum::grammar m_grammar;
|
||||
term_enum::oe_manager m_oe;
|
||||
term_enum::bottom_up_enumerator m_bottom_up_enumerator;
|
||||
std::function<unsigned(expr*)> m_cost;
|
||||
|
||||
imp(ast_manager& m) :
|
||||
m(m), m_grammar(m), m_oe(m), m_enumerator(m_grammar, m_oe) {}
|
||||
m(m), m_grammar(m), m_oe(m), m_bottom_up_enumerator(m_grammar, m_oe) {}
|
||||
|
||||
void add_production(func_decl* f) {
|
||||
m_grammar.add_func_decl(f);
|
||||
|
|
@ -758,7 +514,7 @@ struct term_enumeration::imp {
|
|||
}
|
||||
if (height_state >= levels.size() || idx_state >= levels[height_state].size())
|
||||
return nullptr;
|
||||
return levels[height_state][idx_state++];
|
||||
return levels[height_state].get(idx_state++);
|
||||
}
|
||||
|
||||
private:
|
||||
|
|
@ -779,16 +535,16 @@ private:
|
|||
expr_ref_vector empty_args(m);
|
||||
expr_ref term = prod.builder(empty_args);
|
||||
if (m_oe.is_representative(term)) {
|
||||
m_enumerator.bank(); // just to ensure bank is populated
|
||||
m_bottom_up_enumerator.bank(); // just to ensure bank is populated
|
||||
levels[0].push_back(term);
|
||||
}
|
||||
}
|
||||
}
|
||||
else {
|
||||
// Operators
|
||||
// operators
|
||||
for (auto const& prod : m_grammar.operators()) {
|
||||
if (prod.range.get() != s) continue;
|
||||
term_enum::ChildrenIterator iter(m, prod, m_enumerator.bank(), height);
|
||||
term_enum::children_iterator iter(m, prod, m_bottom_up_enumerator.bank(), height);
|
||||
while (iter.has_next()) {
|
||||
expr_ref_vector children = iter.next();
|
||||
expr_ref term = prod.builder(children);
|
||||
|
|
@ -847,7 +603,7 @@ struct term_enumeration::iterator::iter_imp {
|
|||
else {
|
||||
for (auto const& prod : m_imp.m_grammar.operators()) {
|
||||
if (prod.range.get() != m_sort) continue;
|
||||
term_enum::ChildrenIterator iter(m_imp.m, prod, m_imp.m_enumerator.bank(), m_height);
|
||||
term_enum::children_iterator iter(m_imp.m, prod, m_imp.m_bottom_up_enumerator.bank(), m_height);
|
||||
while (iter.has_next()) {
|
||||
expr_ref_vector children = iter.next();
|
||||
expr_ref term = prod.builder(children);
|
||||
|
|
@ -870,7 +626,7 @@ struct term_enumeration::iterator::iter_imp {
|
|||
if (m_height >= m_levels.size())
|
||||
expand_current_level();
|
||||
if (m_idx < m_levels[m_height].size()) {
|
||||
m_current = m_levels[m_height][m_idx];
|
||||
m_current = m_levels[m_height].get(m_idx);
|
||||
return;
|
||||
}
|
||||
m_height++;
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue