diff --git a/src/ast/term_enumeration.cpp b/src/ast/term_enumeration.cpp index 27c2b9823..815ef2fdb 100644 --- a/src/ast/term_enumeration.cpp +++ b/src/ast/term_enumeration.cpp @@ -9,18 +9,15 @@ * - Observational equivalence (OE): two terms that produce the same outputs * on all sample inputs are considered equivalent; only one representative * per equivalence class is kept. - * - A Grammar describes which function symbols (operators) and leaves + * - A grammar describes which function symbols (operators) and leaves * (constants, variables) are available for enumeration. */ #include "ast/term_enumeration.h" #include -#include #include #include #include -#include -#include #include "ast/ast.h" #include "ast/ast_pp.h" #include "ast/ast_ll_pp.h" @@ -29,8 +26,7 @@ #include "ast/seq_decl_plugin.h" #include "model/model.h" #include "model/model_evaluator.h" -#include "solver/solver.h" -#include "smt/smt_solver.h" + #include "util/vector.h" #include "util/ref.h" #include "util/obj_hashtable.h" @@ -38,16 +34,16 @@ namespace term_enum { // ============================================================================ -// Grammar production rule +// grammar production rule // ============================================================================ /** - * A Production describes how to construct a term from child terms. + * A production describes how to construct a term from child terms. * - domain: the sort required for each child * - range: the sort of the produced term * - builder: given a vector of child exprs, produce the result expr */ -struct Production { +struct production { std::string name; sort_ref range; sort_ref_vector domain; @@ -57,25 +53,25 @@ struct Production { }; // ============================================================================ -// Grammar +// grammar // ============================================================================ /** - * A Grammar groups productions into leaves (arity 0) and operators (arity > 0). + * A grammar groups productions into leaves (arity 0) and operators (arity > 0). */ -class Grammar { +class grammar { public: - Grammar(ast_manager& m) : m(m) {} + grammar(ast_manager& m) : m(m) {} - void add_production(Production p) { + void add_production(production p) { if (p.is_leaf()) m_leaves.push_back(std::move(p)); else m_operators.push_back(std::move(p)); } - vector const& leaves() const { return m_leaves; } - vector const& operators() const { return m_operators; } + vector const& leaves() const { return m_leaves; } + vector const& operators() const { return m_operators; } ast_manager& mgr() const { return m; } void add_variable(char const* name, sort* s) { @@ -138,158 +134,10 @@ public: private: ast_manager& m; - vector m_leaves; - vector m_operators; + vector m_leaves; + vector m_operators; }; -// ============================================================================ -// Standard grammar factories - build common operator sets -// ============================================================================ - -namespace grammars { - -/** - * Build a grammar over linear integer arithmetic. - * Operators: +, -, *, ite (with bool condition) - */ -inline void add_lia_operators(Grammar& g) { - ast_manager& m = g.mgr(); - arith_util a(m); - sort_ref isort(a.mk_int(), m); - sort_ref bsort(m.mk_bool_sort(), m); - - sort_ref_vector ii(m); ii.push_back(isort); ii.push_back(isort); - sort_ref_vector i1(m); i1.push_back(isort); - sort_ref_vector bb(m); bb.push_back(bsort); bb.push_back(bsort); - sort_ref_vector b1(m); b1.push_back(bsort); - sort_ref_vector bii(m); bii.push_back(bsort); bii.push_back(isort); bii.push_back(isort); - - g.add_production({"add", isort, ii, - [&m](expr_ref_vector const& ch) { arith_util a(m); return expr_ref(a.mk_add(ch[0], ch[1]), m); }}); - g.add_production({"sub", isort, ii, - [&m](expr_ref_vector const& ch) { arith_util a(m); return expr_ref(a.mk_sub(ch[0], ch[1]), m); }}); - g.add_production({"mul", isort, ii, - [&m](expr_ref_vector const& ch) { arith_util a(m); return expr_ref(a.mk_mul(ch[0], ch[1]), m); }}); - g.add_production({"neg", isort, i1, - [&m](expr_ref_vector const& ch) { arith_util a(m); return expr_ref(a.mk_uminus(ch[0]), m); }}); - - g.add_production({"le", bsort, ii, - [&m](expr_ref_vector const& ch) { arith_util a(m); return expr_ref(a.mk_le(ch[0], ch[1]), m); }}); - g.add_production({"lt", bsort, ii, - [&m](expr_ref_vector const& ch) { arith_util a(m); return expr_ref(a.mk_lt(ch[0], ch[1]), m); }}); - g.add_production({"eq_int", bsort, ii, - [&m](expr_ref_vector const& ch) { return expr_ref(m.mk_eq(ch[0], ch[1]), m); }}); - - g.add_production({"and", bsort, bb, - [&m](expr_ref_vector const& ch) { return expr_ref(m.mk_and(ch[0], ch[1]), m); }}); - g.add_production({"or", bsort, bb, - [&m](expr_ref_vector const& ch) { return expr_ref(m.mk_or(ch[0], ch[1]), m); }}); - g.add_production({"not", bsort, b1, - [&m](expr_ref_vector const& ch) { return expr_ref(m.mk_not(ch[0]), m); }}); - - g.add_production({"ite_int", isort, bii, - [&m](expr_ref_vector const& ch) { return expr_ref(m.mk_ite(ch[0], ch[1], ch[2]), m); }}); -} - -/** - * Build a grammar over bitvectors. - */ -inline void add_bv_operators(Grammar& g, unsigned bits) { - ast_manager& m = g.mgr(); - bv_util bv(m); - sort_ref bvsort(bv.mk_sort(bits), m); - sort_ref bsort(m.mk_bool_sort(), m); - - sort_ref_vector vv(m); vv.push_back(bvsort); vv.push_back(bvsort); - sort_ref_vector v1(m); v1.push_back(bvsort); - sort_ref_vector bvv(m); bvv.push_back(bsort); bvv.push_back(bvsort); bvv.push_back(bvsort); - - g.add_production({"bvadd", bvsort, vv, - [&m](expr_ref_vector const& ch) { bv_util bv(m); return expr_ref(bv.mk_bv_add(ch[0], ch[1]), m); }}); - g.add_production({"bvsub", bvsort, vv, - [&m](expr_ref_vector const& ch) { bv_util bv(m); return expr_ref(bv.mk_bv_sub(ch[0], ch[1]), m); }}); - g.add_production({"bvmul", bvsort, vv, - [&m](expr_ref_vector const& ch) { bv_util bv(m); return expr_ref(bv.mk_bv_mul(ch[0], ch[1]), m); }}); - g.add_production({"bvand", bvsort, vv, - [&m](expr_ref_vector const& ch) { bv_util bv(m); return expr_ref(bv.mk_bv_and(ch[0], ch[1]), m); }}); - g.add_production({"bvor", bvsort, vv, - [&m](expr_ref_vector const& ch) { bv_util bv(m); return expr_ref(bv.mk_bv_or(ch[0], ch[1]), m); }}); - g.add_production({"bvxor", bvsort, vv, - [&m](expr_ref_vector const& ch) { bv_util bv(m); return expr_ref(bv.mk_bv_xor(ch[0], ch[1]), m); }}); - g.add_production({"bvnot", bvsort, v1, - [&m](expr_ref_vector const& ch) { bv_util bv(m); return expr_ref(bv.mk_bv_not(ch[0]), m); }}); - g.add_production({"bvneg", bvsort, v1, - [&m](expr_ref_vector const& ch) { bv_util bv(m); return expr_ref(bv.mk_bv_neg(ch[0]), m); }}); - g.add_production({"bvshl", bvsort, vv, - [&m](expr_ref_vector const& ch) { bv_util bv(m); return expr_ref(bv.mk_bv_shl(ch[0], ch[1]), m); }}); - g.add_production({"bvlshr", bvsort, vv, - [&m](expr_ref_vector const& ch) { bv_util bv(m); return expr_ref(bv.mk_bv_lshr(ch[0], ch[1]), m); }}); - g.add_production({"bvashr", bvsort, vv, - [&m](expr_ref_vector const& ch) { bv_util bv(m); return expr_ref(bv.mk_bv_ashr(ch[0], ch[1]), m); }}); - - g.add_production({"bvult", bsort, vv, - [&m](expr_ref_vector const& ch) { bv_util bv(m); return expr_ref(m.mk_app(bv.get_fid(), OP_ULT, ch[0], ch[1]), m); }}); - g.add_production({"bvslt", bsort, vv, - [&m](expr_ref_vector const& ch) { bv_util bv(m); return expr_ref(bv.mk_slt(ch[0], ch[1]), m); }}); - g.add_production({"bveq", bsort, vv, - [&m](expr_ref_vector const& ch) { return expr_ref(m.mk_eq(ch[0], ch[1]), m); }}); - - g.add_production({"ite_bv", bvsort, bvv, - [&m](expr_ref_vector const& ch) { return expr_ref(m.mk_ite(ch[0], ch[1], ch[2]), m); }}); -} - -/** - * Build a grammar over strings. - */ -inline void add_string_operators(Grammar& g) { - ast_manager& m = g.mgr(); - seq_util seq(m); - arith_util a(m); - sort_ref ssort(seq.str.mk_string_sort(), m); - sort_ref isort(a.mk_int(), m); - sort_ref bsort(m.mk_bool_sort(), m); - - sort_ref_vector ss(m); ss.push_back(ssort); ss.push_back(ssort); - sort_ref_vector s1(m); s1.push_back(ssort); - sort_ref_vector si(m); si.push_back(ssort); si.push_back(isort); - sort_ref_vector sii(m); sii.push_back(ssort); sii.push_back(isort); sii.push_back(isort); - sort_ref_vector ssi(m); ssi.push_back(ssort); ssi.push_back(ssort); ssi.push_back(isort); - sort_ref_vector sss(m); sss.push_back(ssort); sss.push_back(ssort); sss.push_back(ssort); - sort_ref_vector i1(m); i1.push_back(isort); - sort_ref_vector bss(m); bss.push_back(bsort); bss.push_back(ssort); bss.push_back(ssort); - - g.add_production({"str.++", ssort, ss, - [&m](expr_ref_vector const& ch) { seq_util seq(m); return expr_ref(seq.str.mk_concat(ch[0], ch[1]), m); }}); - g.add_production({"str.len", isort, s1, - [&m](expr_ref_vector const& ch) { seq_util seq(m); return expr_ref(seq.str.mk_length(ch[0]), m); }}); - g.add_production({"str.at", ssort, si, - [&m](expr_ref_vector const& ch) { - seq_util seq(m); arith_util a(m); - return expr_ref(seq.str.mk_substr(ch[0], ch[1], a.mk_int(1)), m); - }}); - g.add_production({"str.substr", ssort, sii, - [&m](expr_ref_vector const& ch) { seq_util seq(m); return expr_ref(seq.str.mk_substr(ch[0], ch[1], ch[2]), m); }}); - g.add_production({"str.indexof", isort, ssi, - [&m](expr_ref_vector const& ch) { seq_util seq(m); return expr_ref(seq.str.mk_index(ch[0], ch[1], ch[2]), m); }}); - g.add_production({"str.replace", ssort, sss, - [&m](expr_ref_vector const& ch) { seq_util seq(m); return expr_ref(seq.str.mk_replace(ch[0], ch[1], ch[2]), m); }}); - g.add_production({"str.contains", bsort, ss, - [&m](expr_ref_vector const& ch) { seq_util seq(m); return expr_ref(seq.str.mk_contains(ch[0], ch[1]), m); }}); - g.add_production({"str.prefixof", bsort, ss, - [&m](expr_ref_vector const& ch) { seq_util seq(m); return expr_ref(seq.str.mk_prefix(ch[0], ch[1]), m); }}); - g.add_production({"str.suffixof", bsort, ss, - [&m](expr_ref_vector const& ch) { seq_util seq(m); return expr_ref(seq.str.mk_suffix(ch[0], ch[1]), m); }}); - g.add_production({"int.to.str", ssort, i1, - [&m](expr_ref_vector const& ch) { seq_util seq(m); return expr_ref(seq.str.mk_itos(ch[0]), m); }}); - g.add_production({"str.to.int", isort, s1, - [&m](expr_ref_vector const& ch) { seq_util seq(m); return expr_ref(seq.str.mk_stoi(ch[0]), m); }}); - - g.add_production({"ite_str", ssort, bss, - [&m](expr_ref_vector const& ch) { return expr_ref(m.mk_ite(ch[0], ch[1], ch[2]), m); }}); -} - -} // namespace grammars - // ============================================================================ // Observational Equivalence Manager // ============================================================================ @@ -300,9 +148,9 @@ inline void add_string_operators(Grammar& g) { * * Uses Z3's model evaluation to reduce terms to concrete values. */ -class OEManager { +class oe_manager { public: - OEManager(ast_manager& m) : m(m) {} + oe_manager(ast_manager& m) : m(m) {} void set_samples(vector samples) { m_samples = std::move(samples); @@ -319,36 +167,34 @@ public: * not been seen before). */ bool is_representative(expr* term) { - if (m_samples.empty()) return true; - std::string fingerprint = compute_fingerprint(term); - if (fingerprint.empty()) return false; + auto fingerprint = compute_fingerprint(term); + if (fingerprint == 0) + return false; return m_seen.insert(fingerprint).second; } void clear() { m_seen.clear(); } size_t num_classes() const { return m_seen.size(); } - size_t num_samples() const { return m_samples.size(); } + unsigned num_samples() const { return m_samples.size(); } private: ast_manager& m; vector m_samples; - std::unordered_set m_seen; + std::unordered_set m_seen; - std::string compute_fingerprint(expr* term) { - std::string fp; + uint64_t compute_fingerprint(expr* term) { + uint64_t a = 0, b = 1, c = 2; for (auto& mdl : m_samples) { expr_ref val(m); model_evaluator eval(*mdl); eval.set_model_completion(true); if (!eval.eval(term, val, true)) - return ""; - std::ostringstream os; - os << mk_pp(val, m); - fp += os.str(); - fp += '\x1f'; + continue; + a *= val->hash(); + mix(a, b, c); } - return fp; + return a; } }; @@ -356,9 +202,15 @@ private: // Term Bank - stores enumerated terms by height and sort // ============================================================================ -class TermBank { +class term_bank { + using sort_term_map = obj_map>; public: - TermBank(ast_manager& m) : m(m), m_pinned(m) {} + term_bank(ast_manager& m) : m(m), m_pinned(m) {} + + ~term_bank() { + for (auto s : m_terms) + dealloc(s); + } void reset() { m_pinned.reset(); @@ -371,7 +223,9 @@ public: m_pinned.push_back(term); if (height >= m_terms.size()) m_terms.resize(height + 1); - m_terms[height].insert_if_not_there(s, ptr_vector()).push_back(term); + if (!m_terms[height]) + m_terms[height] = alloc(sort_term_map); + m_terms[height]->insert_if_not_there(s, ptr_vector()).push_back(term); } /** Get all terms of a given sort up to (and including) max_height */ @@ -380,27 +234,19 @@ public: for (unsigned h = 0; h <= max_height; ++h) { if (h >= m_terms.size()) break; - if (!m_terms[h].contains(s)) + if (!m_terms[h]->contains(s)) continue; - for (auto t : m_terms[h].find(s)) + for (auto t : m_terms[h]->find(s)) result.push_back(t); } return result; } - size_t total_terms() const { - size_t n = 0; - for (auto& sm : m_terms) - for (auto& [s, v] : sm) - n += v.size(); - return n; - } - private: ast_manager& m; expr_ref_vector m_pinned; // height -> sort -> terms - vector>> m_terms; + ptr_vector m_terms; }; // ============================================================================ @@ -412,9 +258,9 @@ private: * sort, drawn from the term bank, with at least one child at the current * height - 1 (to avoid regenerating previously seen terms). */ -class ChildrenIterator { +class children_iterator { public: - ChildrenIterator(ast_manager& m, Production const& prod, TermBank const& bank, unsigned current_height) + children_iterator(ast_manager& m, production const& prod, term_bank const& bank, unsigned current_height) : m(m), m_prod(prod), m_current_height(current_height), m_done(false) { m_arity = prod.domain.size(); @@ -451,7 +297,7 @@ public: private: ast_manager& m; - Production const& m_prod; + production const& m_prod; unsigned m_current_height; unsigned m_arity; bool m_done; @@ -473,7 +319,7 @@ private: }; // ============================================================================ -// Enumerator - the main bottom-up term enumeration engine +// bottom_up_enumerator - the main bottom-up term enumeration engine // ============================================================================ /** @@ -482,11 +328,11 @@ private: * * Usage: * ast_manager m; - * Grammar g(m); + * grammar g(m); * // ... add productions ... - * OEManager oe(m); + * oe_manager oe(m); * // ... set samples ... - * Enumerator en(g, oe); + * bottom_up_enumerator en(g, oe); * arith_util a(m); * en.set_target_sort(a.mk_int()); * while (en.has_next()) { @@ -494,9 +340,9 @@ private: * // ... check if term satisfies specification ... * } */ -class Enumerator { +class bottom_up_enumerator { public: - Enumerator(Grammar& grammar, OEManager& oe) + bottom_up_enumerator(grammar& grammar, oe_manager& oe) : m_grammar(grammar), m(grammar.mgr()), m_oe(oe), m_bank(grammar.mgr()), m_height(0), m_leaf_idx(0), m_op_idx(0), m_state(State::Leaves), @@ -529,7 +375,7 @@ public: return results; } - TermBank const& bank() const { return m_bank; } + term_bank const& bank() const { return m_bank; } unsigned current_height() const { return m_height; } void reset() { @@ -544,24 +390,23 @@ public: } private: - enum class State { Leaves, Operators, Done }; + enum class State { Leaves, operators, Done }; - Grammar& m_grammar; + grammar& m_grammar; ast_manager& m; - OEManager& m_oe; - TermBank m_bank; + oe_manager& m_oe; + term_bank m_bank; unsigned m_height; unsigned m_leaf_idx; unsigned m_op_idx; State m_state; sort* m_target_sort; expr_ref m_pending; - std::unique_ptr m_children_iter; + std::unique_ptr m_children_iter; unsigned m_max_height = 100; bool sort_matches(expr* e) const { - if (!m_target_sort) return true; - return e->get_sort() == m_target_sort; + return !m_target_sort || e->get_sort() == m_target_sort; } expr* find_next() { @@ -569,7 +414,7 @@ private: switch (m_state) { case State::Leaves: while (m_leaf_idx < m_grammar.leaves().size()) { - Production const& prod = m_grammar.leaves()[m_leaf_idx]; + production const& prod = m_grammar.leaves()[m_leaf_idx]; m_leaf_idx++; expr_ref_vector empty_args(m); expr_ref term = prod.builder(empty_args); @@ -579,13 +424,13 @@ private: return term; } } - m_state = State::Operators; + m_state = State::operators; m_height = 1; m_op_idx = 0; m_children_iter.reset(); break; - case State::Operators: { + case State::operators: { expr* result = enumerate_operators(); if (result) return result; m_height++; @@ -608,7 +453,7 @@ private: while (true) { if (m_children_iter && m_children_iter->has_next()) { expr_ref_vector children = m_children_iter->next(); - Production const& prod = ops[m_op_idx - 1]; + production const& prod = ops[m_op_idx - 1]; expr_ref term = prod.builder(children); if (m_oe.is_representative(term)) { m_bank.add(term, m_height); @@ -618,102 +463,13 @@ private: continue; } if (m_op_idx >= ops.size()) return nullptr; - Production const& prod = ops[m_op_idx]; + production const& prod = ops[m_op_idx]; m_op_idx++; - m_children_iter = std::make_unique(m, prod, m_bank, m_height); + m_children_iter = std::make_unique(m, prod, m_bank, m_height); } } }; -// ============================================================================ -// CEGIS integration helper -// ============================================================================ - -/** - * Counter-Example Guided Inductive Synthesis loop. - * Combines the enumerator with a solver to verify candidates against a - * specification. - * - * spec: a function (expr* candidate) -> expr_ref that returns the specification - * constraint (should be valid for a correct program). - * variables: the free variables of the specification. - */ -class CEGISLoop { -public: - CEGISLoop(Grammar& grammar, sort* target_sort, - std::function spec, - expr_ref_vector variables) - : m(grammar.mgr()), m_grammar(grammar), m_oe(grammar.mgr()), - m_enumerator(grammar, m_oe), - m_spec(std::move(spec)), m_variables(std::move(variables)) - { - m_enumerator.set_target_sort(target_sort); - params_ref p; - m_solver = mk_smt_solver(m, p, symbol::null); - } - - /** - * Run the CEGIS loop. Returns the synthesized term, or null expr_ref if - * max_height is exceeded. - */ - expr_ref synthesize(unsigned max_height = 10, unsigned max_restarts = 20) { - m_enumerator.set_max_height(max_height); - unsigned restarts = 0; - - while (m_enumerator.has_next()) { - expr_ref candidate = m_enumerator.next(); - - if (!satisfies_samples(candidate)) continue; - - expr_ref spec_expr = m_spec(candidate); - m_solver->push(); - m_solver->assert_expr(m.mk_not(spec_expr)); - lbool result = m_solver->check_sat(0, nullptr); - - if (result == l_false) { - m_solver->pop(1); - return candidate; - } else if (result == l_true) { - model_ref cex; - m_solver->get_model(cex); - m_oe.add_sample(cex); - m_samples.push_back(cex); - m_solver->pop(1); - restarts++; - if (restarts > max_restarts) return expr_ref(m); - m_enumerator.reset(); - } else { - m_solver->pop(1); - } - } - return expr_ref(m); - } - - size_t num_samples() const { return m_oe.num_samples(); } - size_t num_equivalence_classes() const { return m_oe.num_classes(); } - -private: - ast_manager& m; - Grammar& m_grammar; - OEManager m_oe; - Enumerator m_enumerator; - std::function m_spec; - expr_ref_vector m_variables; - ref m_solver; - vector m_samples; - - bool satisfies_samples(expr* candidate) { - expr_ref spec_expr = m_spec(candidate); - for (auto& mdl : m_samples) { - model_evaluator eval(*mdl); - eval.set_model_completion(true); - if (eval.is_false(spec_expr)) - return false; - } - return true; - } -}; - } // namespace term_enum // ============================================================================ @@ -722,13 +478,13 @@ private: struct term_enumeration::imp { ast_manager& m; - term_enum::Grammar m_grammar; - term_enum::OEManager m_oe; - term_enum::Enumerator m_enumerator; + term_enum::grammar m_grammar; + term_enum::oe_manager m_oe; + term_enum::bottom_up_enumerator m_bottom_up_enumerator; std::function m_cost; imp(ast_manager& m) : - m(m), m_grammar(m), m_oe(m), m_enumerator(m_grammar, m_oe) {} + m(m), m_grammar(m), m_oe(m), m_bottom_up_enumerator(m_grammar, m_oe) {} void add_production(func_decl* f) { m_grammar.add_func_decl(f); @@ -758,7 +514,7 @@ struct term_enumeration::imp { } if (height_state >= levels.size() || idx_state >= levels[height_state].size()) return nullptr; - return levels[height_state][idx_state++]; + return levels[height_state].get(idx_state++); } private: @@ -779,16 +535,16 @@ private: expr_ref_vector empty_args(m); expr_ref term = prod.builder(empty_args); if (m_oe.is_representative(term)) { - m_enumerator.bank(); // just to ensure bank is populated + m_bottom_up_enumerator.bank(); // just to ensure bank is populated levels[0].push_back(term); } } } else { - // Operators + // operators for (auto const& prod : m_grammar.operators()) { if (prod.range.get() != s) continue; - term_enum::ChildrenIterator iter(m, prod, m_enumerator.bank(), height); + term_enum::children_iterator iter(m, prod, m_bottom_up_enumerator.bank(), height); while (iter.has_next()) { expr_ref_vector children = iter.next(); expr_ref term = prod.builder(children); @@ -847,7 +603,7 @@ struct term_enumeration::iterator::iter_imp { else { for (auto const& prod : m_imp.m_grammar.operators()) { if (prod.range.get() != m_sort) continue; - term_enum::ChildrenIterator iter(m_imp.m, prod, m_imp.m_enumerator.bank(), m_height); + term_enum::children_iterator iter(m_imp.m, prod, m_imp.m_bottom_up_enumerator.bank(), m_height); while (iter.has_next()) { expr_ref_vector children = iter.next(); expr_ref term = prod.builder(children); @@ -870,7 +626,7 @@ struct term_enumeration::iterator::iter_imp { if (m_height >= m_levels.size()) expand_current_level(); if (m_idx < m_levels[m_height].size()) { - m_current = m_levels[m_height][m_idx]; + m_current = m_levels[m_height].get(m_idx); return; } m_height++;