From deaced1711db7253e00aba56c1f69877146ac8e5 Mon Sep 17 00:00:00 2001 From: Simon Jeanteur Date: Tue, 13 Jan 2026 19:53:17 +0100 Subject: [PATCH] Subterms Theory (#8115) * somewhat failed attempt at declaring subterm predicate I can't really figure out how to link the smt parser to the rest of the machinery, so I will stop here and try from the other side. I'll start implementing the logic and see if it brings me back to the parser. * initial logic implementation Very primitive, but I don't like having that much work uncommitted. * parser implementation * more theory * Working base * subterm reflexivity * a few optimizations Skip adding obvious equalities or disequalities * removed some optimisations * better handling of backtracking * stupid segfault Add m_subterm to the trail * Update src/smt/theory_datatype.h Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update src/ast/rewriter/datatype_rewriter.cpp Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update src/smt/theory_datatype.cpp Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update src/smt/theory_datatype.cpp Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update src/smt/theory_datatype.cpp Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * review * forgot to update `iterate_subterm`'s signature * fix iterator segfault * Remove duplicate include statement Removed duplicate include of 'theory_datatype.h'. 
* Replace 'optional' with 'std::option' in datatype_decl_plugin.h * Add is_subterm_predicate matcher to datatype_decl_plugin * Change std::option to std::optional for m_subterm * Update pdecl.h * Change has_subterm to use has_value method * Update pdecl.cpp --------- Co-authored-by: Nikolaj Bjorner Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- src/ast/datatype_decl_plugin.cpp | 55 +++- src/ast/datatype_decl_plugin.h | 40 ++- src/ast/rewriter/datatype_rewriter.cpp | 3 + src/cmd_context/cmd_context.cpp | 3 + src/cmd_context/pdecl.cpp | 15 +- src/cmd_context/pdecl.h | 13 + src/parsers/smt2/smt2parser.cpp | 23 +- src/smt/theory_datatype.cpp | 368 +++++++++++++++++++++++-- src/smt/theory_datatype.h | 75 ++++- 9 files changed, 563 insertions(+), 32 deletions(-) diff --git a/src/ast/datatype_decl_plugin.cpp b/src/ast/datatype_decl_plugin.cpp index d0c74bd50..fc3ddfcab 100644 --- a/src/ast/datatype_decl_plugin.cpp +++ b/src/ast/datatype_decl_plugin.cpp @@ -57,6 +57,23 @@ namespace datatype { return alloc(accessor, tr.to(), name(), to_sort(tr(m_range.get()))); } + def const& subterm::get_def() const { return *m_def; } + util& subterm::u() const { return m_def->u(); } + + func_decl_ref subterm::instantiate(sort_ref_vector const& ps) const { + ast_manager& m = ps.get_manager(); + sort_ref dt_sort = get_def().instantiate(ps); + sort* domain[2] = { dt_sort, dt_sort }; + sort_ref range(m.mk_bool_sort(), m); + parameter p(name()); + return func_decl_ref(m.mk_func_decl(u().get_family_id(), OP_DT_SUBTERM, 1, &p, 2, domain, range), m); + } + + func_decl_ref subterm::instantiate(sort* dt) const { + sort_ref_vector sorts = get_def().u().datatype_params(dt); + return instantiate(sorts); + } + constructor::~constructor() { for (accessor* a : m_accessors) dealloc(a); m_accessors.reset(); @@ -235,6 +252,7 @@ namespace datatype { void plugin::reset() { m_datatype2constructors.reset(); + m_datatype2subterm.reset(); m_datatype2nonrec_constructor.reset(); 
m_constructor2accessors.reset(); m_constructor2recognizer.reset(); @@ -443,6 +461,18 @@ namespace datatype { return m.mk_func_decl(name, arity, domain, range, info); } + func_decl * decl::plugin::mk_subterm(unsigned num_parameters, parameter const * parameters, + unsigned arity, sort * const * domain, sort* range) + { + ast_manager& m = *m_manager; + VALIDATE_PARAM(num_parameters == 1 && parameters[0].is_symbol()); + VALIDATE_PARAM(arity == 2 && u().is_datatype(domain[0]) && domain[0] == domain[1] && m.is_bool(range)); + func_decl_info info(m_family_id, OP_DT_SUBTERM, num_parameters, parameters); + info.m_private_parameters = true; + symbol name = parameters[0].get_symbol(); + return m.mk_func_decl(name, arity, domain, range, info); + } + func_decl * decl::plugin::mk_func_decl(decl_kind k, unsigned num_parameters, parameter const * parameters, unsigned arity, sort * const * domain, sort * range) { switch (k) { @@ -453,7 +483,9 @@ namespace datatype { case OP_DT_IS: return mk_is(num_parameters, parameters, arity, domain, range); case OP_DT_ACCESSOR: - return mk_accessor(num_parameters, parameters, arity, domain, range); + return mk_accessor(num_parameters, parameters, arity, domain, range); + case OP_DT_SUBTERM: + return mk_subterm(num_parameters, parameters, arity, domain, range); case OP_DT_UPDATE_FIELD: return mk_update_field(num_parameters, parameters, arity, domain, range); default: @@ -1040,6 +1072,22 @@ namespace datatype { return m_family_id; } + func_decl * util::get_datatype_subterm(sort * ty) { + SASSERT(is_datatype(ty)); + func_decl * r = nullptr; + if (plugin().m_datatype2subterm.find(ty, r)) + return r; + + def const& d = get_def(ty); + if (d.has_subterm()) { + func_decl_ref f = d.get_subterm().instantiate(ty); + r = f; + plugin().add_ast(r); + plugin().m_datatype2subterm.insert(ty, r); + } + return r; + } + ptr_vector const * util::get_datatype_constructors(sort * ty) { SASSERT(is_datatype(ty)); ptr_vector * r = nullptr; @@ -1482,11 +1530,14 @@ 
namespace datatype { } -datatype_decl * mk_datatype_decl(datatype_util& u, symbol const & n, unsigned num_params, sort*const* params, unsigned num_constructors, constructor_decl * const * cs) { +datatype_decl * mk_datatype_decl(datatype_util& u, symbol const & n, unsigned num_params, sort*const* params, unsigned num_constructors, constructor_decl * const * cs, symbol const& subterm_name) { datatype::decl::plugin& p = u.plugin(); datatype::def* d = p.mk(n, num_params, params); for (unsigned i = 0; i < num_constructors; ++i) { d->add(cs[i]); } + if (subterm_name != symbol::null) { + d->attach_subterm(subterm_name, u.get_manager().mk_bool_sort()); + } return d; } diff --git a/src/ast/datatype_decl_plugin.h b/src/ast/datatype_decl_plugin.h index 7876f10c6..41ed2036b 100644 --- a/src/ast/datatype_decl_plugin.h +++ b/src/ast/datatype_decl_plugin.h @@ -39,6 +39,7 @@ enum op_kind { OP_DT_IS, OP_DT_ACCESSOR, OP_DT_UPDATE_FIELD, + OP_DT_SUBTERM, LAST_DT_OP }; @@ -48,6 +49,22 @@ namespace datatype { class def; class accessor; class constructor; + class subterm; + + class subterm { + symbol m_name; + sort_ref m_range; + def* m_def = nullptr; + public: + subterm(ast_manager& m, symbol const& n, sort* r) : m_name(n), m_range(r, m) {} + sort* range() const { return m_range; } + symbol const& name() const { return m_name; } + func_decl_ref instantiate(sort_ref_vector const& ps) const; + func_decl_ref instantiate(sort* dt) const; + util& u() const; + void attach(def* d) { m_def = d; } + def const& get_def() const; + }; class accessor { @@ -166,6 +183,7 @@ namespace datatype { mutable sort_ref m_sort; ptr_vector m_constructors; mutable dictionary m_name2constructor; + std::optional m_subterm; public: def(ast_manager& m, util& u, symbol const& n, unsigned class_id, unsigned num_params, sort * const* params): m(m), @@ -185,6 +203,10 @@ namespace datatype { m_constructors.push_back(c); c->attach(this); } + void attach_subterm(symbol const& n, sort* range) { + m_subterm = subterm(m, n, 
range); + m_subterm->attach(this); + } symbol const& name() const { return m_name; } unsigned id() const { return m_class_id; } sort_ref instantiate(sort_ref_vector const& ps) const; @@ -222,6 +244,8 @@ namespace datatype { SASSERT(result); // Post-condition: get_constructor_by_name returns a non-null result return result; } + bool has_subterm() const { return m_subterm.has_value(); } + subterm const& get_subterm() const { return *m_subterm; } def* translate(ast_translation& tr, util& u); }; @@ -293,6 +317,7 @@ namespace datatype { obj_map*> m_datatype2constructors; + obj_map m_datatype2subterm; obj_map m_datatype2nonrec_constructor; obj_map*> m_constructor2accessors; obj_map m_constructor2recognizer; @@ -324,6 +349,16 @@ namespace datatype { unsigned num_parameters, parameter const * parameters, unsigned arity, sort * const * domain, sort * range); + /** + * \brief declares a subterm predicate + * + * Subterms have the signature `sort -> sort -> bool` and are only + * supported for non-mutually recursive datatypes + */ + func_decl * mk_subterm( + unsigned num_parameters, parameter const * parameters, + unsigned arity, sort * const * domain, sort * range); + func_decl * mk_recognizer( unsigned num_parameters, parameter const * parameters, unsigned arity, sort * const * domain, sort * range); @@ -379,6 +414,8 @@ namespace datatype { bool is_is(func_decl * f) const { return is_decl_of(f, fid(), OP_DT_IS); } bool is_accessor(func_decl * f) const { return is_decl_of(f, fid(), OP_DT_ACCESSOR); } bool is_update_field(func_decl * f) const { return is_decl_of(f, fid(), OP_DT_UPDATE_FIELD); } + bool is_subterm_predicate(func_decl * f) const { return is_decl_of(f, fid(), OP_DT_SUBTERM); } + bool is_subterm_predicate(expr* e) const { return is_app(e) && is_subterm_predicate(to_app(e)->get_decl()); } bool is_constructor(app const * f) const { return is_app_of(f, fid(), OP_DT_CONSTRUCTOR); } bool is_constructor(expr const * e) const { return is_app(e) && 
is_constructor(to_app(e)); } bool is_recognizer0(app const* f) const { return is_app_of(f, fid(), OP_DT_RECOGNISER);} @@ -393,6 +430,7 @@ namespace datatype { bool is_update_field(expr * f) const { return is_app(f) && is_app_of(to_app(f), fid(), OP_DT_UPDATE_FIELD); } app* mk_is(func_decl * c, expr *f); ptr_vector const * get_datatype_constructors(sort * ty); + func_decl * get_datatype_subterm(sort * ty); unsigned get_datatype_num_constructors(sort * ty); unsigned get_datatype_num_parameter_sorts(sort * ty); sort* get_datatype_parameter_sort(sort * ty, unsigned idx); @@ -468,7 +506,7 @@ inline constructor_decl * mk_constructor_decl(symbol const & n, symbol const & r // Remark: the datatype becomes the owner of the constructor_decls -datatype_decl * mk_datatype_decl(datatype_util& u, symbol const & n, unsigned num_params, sort*const* params, unsigned num_constructors, constructor_decl * const * cs); +datatype_decl * mk_datatype_decl(datatype_util& u, symbol const & n, unsigned num_params, sort*const* params, unsigned num_constructors, constructor_decl * const * cs, symbol const& subterm_name = symbol::null); inline void del_datatype_decl(datatype_decl * d) {} inline void del_datatype_decls(unsigned num, datatype_decl * const * ds) {} diff --git a/src/ast/rewriter/datatype_rewriter.cpp b/src/ast/rewriter/datatype_rewriter.cpp index 770aaba4b..dcdc25517 100644 --- a/src/ast/rewriter/datatype_rewriter.cpp +++ b/src/ast/rewriter/datatype_rewriter.cpp @@ -121,6 +121,9 @@ br_status datatype_rewriter::mk_app_core(func_decl * f, unsigned num_args, expr result = m().mk_app(c_decl, num, new_args.data()); return BR_DONE; } + case OP_DT_SUBTERM: + // No rewrite yet for subterms + return BR_FAILED; default: UNREACHABLE(); } diff --git a/src/cmd_context/cmd_context.cpp b/src/cmd_context/cmd_context.cpp index aab16efde..8b5d126ec 100644 --- a/src/cmd_context/cmd_context.cpp +++ b/src/cmd_context/cmd_context.cpp @@ -2546,6 +2546,9 @@ void cmd_context::dt_eh::operator()(sort * dt, 
pdecl* pd) { m_owner.insert(a); } } + if (func_decl * sub = m_dt_util.get_datatype_subterm(dt)) { + m_owner.insert(sub); + } if (!m_owner.m_scopes.empty() && !m_owner.m_global_decls) { m_owner.pm().inc_ref(pd); m_owner.m_psort_inst_stack.push_back(pd); diff --git a/src/cmd_context/pdecl.cpp b/src/cmd_context/pdecl.cpp index c0c63befb..722a66fff 100644 --- a/src/cmd_context/pdecl.cpp +++ b/src/cmd_context/pdecl.cpp @@ -541,6 +541,12 @@ void pconstructor_decl::display(std::ostream & out, pdatatype_decl const * const out << ")"; } +// ~~~~~~~~~~~~ psubterm_decl ~~~~~~~~~~~~ // +std::ostream& psubterm_decl::display(std::ostream & out) const { + return out << ":subterm " << m_name; +} + + pdatatype_decl::pdatatype_decl(unsigned id, unsigned num_params, pdecl_manager & m, symbol const & n, unsigned num_constructors, pconstructor_decl * const * constructors): psort_decl(id, num_params, m, n), @@ -589,7 +595,11 @@ datatype_decl * pdatatype_decl::instantiate_decl(pdecl_manager & m, unsigned n, for (auto c : m_constructors) cs.push_back(c->instantiate_decl(m, n, s)); datatype_util util(m.m()); - return mk_datatype_decl(util, m_name, m_num_params, s, cs.size(), cs.data()); + symbol subterm_name = symbol::null; + if (m_subterm.has_value()) { + subterm_name = m_subterm->get_name(); + } + return mk_datatype_decl(util, m_name, m_num_params, s, cs.size(), cs.data(), subterm_name); } struct datatype_decl_buffer { @@ -647,6 +657,9 @@ std::ostream& pdatatype_decl::display(std::ostream & out) const { } first = false; } + if (m_subterm.has_value()) { + m_subterm->display(out); + } return out << ")"; } diff --git a/src/cmd_context/pdecl.h b/src/cmd_context/pdecl.h index a3005f182..409172671 100644 --- a/src/cmd_context/pdecl.h +++ b/src/cmd_context/pdecl.h @@ -229,11 +229,23 @@ public: void display(std::ostream & out, pdatatype_decl const * const * dts) const; }; +class psubterm_decl: public pdecl { + friend class pdecl_manager; + friend class pdatatype_decl; + symbol m_name; + ptype 
m_type; + symbol const & get_name() const { return m_name; } +public: + psubterm_decl(symbol const& n) : pdecl(0, 0), m_name(n) {} + std::ostream& display(std::ostream & out) const override; +}; + class pdatatype_decl : public psort_decl { friend class pdecl_manager; friend class pdatatypes_decl; ptr_vector m_constructors; pdatatypes_decl * m_parent; + std::optional m_subterm; pdatatype_decl(unsigned id, unsigned num_params, pdecl_manager & m, symbol const & n, unsigned num_constructors, pconstructor_decl * const * constructors); void finalize(pdecl_manager & m) override; @@ -246,6 +258,7 @@ public: bool has_missing_refs(symbol & missing) const; bool has_duplicate_accessors(symbol & repeated) const; bool commit(pdecl_manager& m); + void set_subterm(symbol const& n) { m_subterm = psubterm_decl(n); } }; /** diff --git a/src/parsers/smt2/smt2parser.cpp b/src/parsers/smt2/smt2parser.cpp index 3ce1ece4a..5d0ce85ff 100644 --- a/src/parsers/smt2/smt2parser.cpp +++ b/src/parsers/smt2/smt2parser.cpp @@ -105,6 +105,7 @@ namespace smt2 { symbol m_declare_type_var; symbol m_declare_datatypes; symbol m_declare_datatype; + symbol m_subterm_keyword; symbol m_par; symbol m_push; symbol m_pop; @@ -955,7 +956,7 @@ namespace smt2 { next(); } - // ( declare-datatype symbol datatype_dec) + // ( declare-datatype symbol datatype_dec [:subterm ]) void parse_declare_datatype() { SASSERT(curr_is_identifier()); SASSERT(curr_id() == m_declare_datatype); @@ -974,8 +975,15 @@ namespace smt2 { pdatatype_decl_ref d(pm()); pconstructor_decl_ref_buffer new_ct_decls(pm()); parse_datatype_dec(&dt_name, new_ct_decls); + + symbol subterm_name = parse_subterm_decl(); + d = pm().mk_pdatatype_decl(m_sort_id2param_idx.size(), dt_name, new_ct_decls.size(), new_ct_decls.data()); + if (subterm_name != symbol::null) { + d->set_subterm(subterm_name); + } + check_missing(d, line, pos); check_duplicate(d, line, pos); @@ -985,6 +993,18 @@ namespace smt2 { next(); } + // [:subterm ] + symbol parse_subterm_decl() { 
+ symbol predicate_name = symbol::null; + if ((curr_is_identifier() || curr() == scanner::KEYWORD_TOKEN) && curr_id() == m_subterm_keyword) { + next(); // consume :subterm keyword + check_identifier("expected name for subterm predicate"); + predicate_name = curr_id(); + next(); + } + return predicate_name; + } + // datatype_dec ::= ( constructor_dec+ ) | ( par ( symbol+ ) ( constructor_dec+ ) ) @@ -3088,6 +3108,7 @@ namespace smt2 { m_declare_type_var("declare-type-var"), m_declare_datatypes("declare-datatypes"), m_declare_datatype("declare-datatype"), + m_subterm_keyword(":subterm"), m_par("par"), m_push("push"), m_pop("pop"), diff --git a/src/smt/theory_datatype.cpp b/src/smt/theory_datatype.cpp index b4a3ed4db..00dae2233 100644 --- a/src/smt/theory_datatype.cpp +++ b/src/smt/theory_datatype.cpp @@ -28,6 +28,7 @@ Revision History: #include namespace smt { + class dt_eq_justification : public ext_theory_eq_propagation_justification { public: @@ -260,6 +261,21 @@ namespace smt { ctx.mk_th_axiom(get_id(), 2, lits); } + void theory_datatype::assert_subterm_axioms(enode * n) { + sort * s = n->get_sort(); + if (m_util.is_datatype(s)) { + func_decl * sub_decl = m_util.get_datatype_subterm(s); + if (sub_decl) { + TRACE(datatype, tout << "asserting reflexivity for #" << n->get_owner_id() << " " << mk_pp(n->get_expr(), m) << "\n";); + app_ref reflex(m.mk_app(sub_decl, n->get_expr(), n->get_expr()), m); + ctx.internalize(reflex, false); + literal l(ctx.get_bool_var(reflex)); + ctx.mark_as_relevant(l); + ctx.mk_th_axiom(get_id(), 1, &l); + } + } + } + theory_var theory_datatype::mk_var(enode * n) { theory_var r = theory::mk_var(n); VERIFY(r == static_cast(m_find.mk_var())); @@ -267,6 +283,9 @@ namespace smt { m_var_data.push_back(alloc(var_data)); var_data * d = m_var_data[r]; ctx.attach_th_var(n, this, r); + + assert_subterm_axioms(n); + if (is_constructor(n)) { d->m_constructor = n; assert_accessor_axioms(n); @@ -327,7 +346,7 @@ namespace smt { // it. 
// Moreover, fresh variables of sort S can only be created after the // interpretation for each (relevant) expression of sort S in the - // logical context is created. Returning to the example, + // logical context is created. Returning to the example, // to create the interpretation of x1 we need the // interpretation for x2. So, x2 cannot be a fresh value, // since it would have to be created after x1. @@ -350,6 +369,18 @@ namespace smt { } mk_var(e); } + else if (m_util.is_subterm_predicate(term)) { + SASSERT(term->get_num_args() == 2); + enode * arg1 = e->get_arg(0); + if (!is_attached_to_var(arg1)) + mk_var(arg1); + enode * arg2 = e->get_arg(1); + if (!is_attached_to_var(arg2)) + mk_var(arg2); + SASSERT(is_attached_to_var(arg1)); + SASSERT(is_attached_to_var(arg2)); + // Axiom generation logic for subterm can be added here. + } else { SASSERT(is_accessor(term) || is_recognizer(term)); SASSERT(term->get_num_args() == 1); @@ -413,35 +444,282 @@ namespace smt { void theory_datatype::assign_eh(bool_var v, bool is_true) { force_push(); - enode * n = ctx.bool_var2enode(v); - if (!is_recognizer(n)) - return; - TRACE(datatype, tout << "assigning recognizer: #" << n->get_owner_id() << " is_true: " << is_true << "\n" - << enode_pp(n, ctx) << "\n";); - SASSERT(n->get_num_args() == 1); - enode * arg = n->get_arg(0); - theory_var tv = arg->get_th_var(get_id()); - tv = m_find.find(tv); - var_data * d = m_var_data[tv]; - func_decl * r = n->get_decl(); - func_decl * c = m_util.get_recognizer_constructor(r); - if (is_true) { - SASSERT(tv != null_theory_var); - if (d->m_constructor != nullptr && d->m_constructor->get_decl() == c) - return; // do nothing - assert_is_constructor_axiom(arg, c, literal(v)); - } - else { - if (d->m_constructor != nullptr) { - if (d->m_constructor->get_decl() == c) { - // conflict - sign_recognizer_conflict(d->m_constructor, n); - } + enode *n = ctx.bool_var2enode(v); + if (is_recognizer(n)) { + TRACE(datatype, tout << "assigning recognizer: #" << 
n->get_owner_id() << " is_true: " << is_true << "\n" + << enode_pp(n, ctx) << "\n";); + SASSERT(n->get_num_args() == 1); + enode *arg = n->get_arg(0); + theory_var tv = arg->get_th_var(get_id()); + tv = m_find.find(tv); + var_data *d = m_var_data[tv]; + func_decl *r = n->get_decl(); + func_decl *c = m_util.get_recognizer_constructor(r); + if (is_true) { + SASSERT(tv != null_theory_var); + if (d->m_constructor != nullptr && d->m_constructor->get_decl() == c) + return; // do nothing + assert_is_constructor_axiom(arg, c, literal(v)); + propagate_subterm_with_constructor(tv); } else { - propagate_recognizer(tv, n); + if (d->m_constructor != nullptr) { + if (d->m_constructor->get_decl() == c) { + // conflict + sign_recognizer_conflict(d->m_constructor, n); + } + } + else { + propagate_recognizer(tv, n); + } } } + else if (is_subterm_predicate(n)) { + TRACE(datatype, tout << "assigning subterm: #" << n->get_owner_id() << " is_true: " << is_true << "\n" + << enode_pp(n, ctx) << "\n";); + SASSERT(n->get_num_args() == 2); + + propagate_subterm(n, is_true); + } + } + + void theory_datatype::propagate_subterm_with_constructor(theory_var v) { + v = m_find.find(v); + var_data *d = m_var_data[v]; + if (!d->m_constructor) + return; + + ptr_vector subs(d->m_subterms); + for (enode *n : subs) { + lbool val = ctx.get_assignment(n); + switch (val) { + case l_undef: continue; + case l_true: propagate_subterm(n, true); break; + case l_false: propagate_subterm(n, false); break; + } + } + } + + void theory_datatype::propagate_subterm(enode *n, bool is_true) { + force_push(); // I am fairly sure I need that here + if (is_true) { + propagate_is_subterm(n); + } + else { + propagate_not_is_subterm(n); + } + } + + void theory_datatype::propagate_is_subterm(enode *n) { + SASSERT(is_subterm_predicate(n)); + enode *arg1 = n->get_arg(0); + enode *arg2 = n->get_arg(1); + + // If we are here, n is assigned true. 
+ SASSERT(ctx.get_assignment(n) == l_true); + + TRACE(datatype, tout << "propagate_is_subterm: " << enode_pp(n, ctx) << "\n";); + + if (arg1->get_root() == arg2->get_root()) { + TRACE(datatype, tout << "subterm reflexivity, skipping " << "\n";); + return; + } + + literal_vector lits; + lits.push_back(literal(ctx.enode2bool_var(n), true)); // antecedent: ~n + + bool found_possible = false; + bool has_leaf_root = false; + + ptr_vector candidates = list_subterms(arg2); + + for (enode *s : candidates) { + bool is_leaf = !m_util.is_constructor(s->get_expr()); + + // Case 1: Equality check (arg1 == s) + // Valid if sorts are compatible. + if (s->get_sort() == arg1->get_sort()) { + // trying to be smarter about this causes other problems + TRACE(datatype, tout << "adding equality case: " << mk_pp(arg1->get_expr(), m) + << " == " << mk_pp(s->get_expr(), m) << "\n";); + lits.push_back(mk_eq(arg1->get_expr(), s->get_expr(), false)); + found_possible = true; + } + + // Case 2: Recursive subterm check (arg1 ⊑ s) + // Only if s is a leaf (unexpanded) and not the root itself (to avoid tautology). + if (is_leaf) { + if (s->get_root() == arg2->get_root()) { + // If arg2 is a leaf, we haven't explored its possibilities yet. 
+ has_leaf_root = true; + found_possible = true; + continue; + } + + if (m_util.is_datatype(s->get_sort())) { + // arg1 ⊑ s + func_decl *sub_decl = m_util.get_datatype_subterm(s->get_sort()); + if (sub_decl) { + TRACE(datatype, tout << "adding recursive case: " << mk_pp(arg1->get_expr(), m) << " ⊑ " + << mk_pp(s->get_expr(), m) << "\n";); + auto tmp = m.mk_not( m.mk_app(sub_decl, arg1->get_expr(), s->get_expr())); + lits.push_back(mk_literal(app_ref(tmp, m))); + found_possible = true; + } + } + } + } + + if (has_leaf_root) { + split_leaf_root(arg2); + } + + if (lits.size() > 1) { + if (!has_leaf_root) { + ctx.mk_th_axiom(get_id(), lits.size(), lits.data()); + } + } + else if (!found_possible) { + // Conflict: arg1 cannot be subterm of arg2 (no path matches) + TRACE(datatype, tout << "conflict: no path matches\n";); + ctx.mk_th_axiom(get_id(), lits.size(), lits.data()); + } + } + + void theory_datatype::propagate_not_is_subterm(enode *n) { + SASSERT(is_subterm_predicate(n)); + enode *arg1 = n->get_arg(0); + enode *arg2 = n->get_arg(1); + + // If we are here, n is assigned false. 
+ SASSERT(ctx.get_assignment(n) == l_false); + + if (arg1->get_root() == arg2->get_root()) { + // ~ (a ⊑ a) is a conflict + literal l(ctx.enode2bool_var(n)); + ctx.set_conflict(ctx.mk_justification(ext_theory_conflict_justification(get_id(), ctx, 1, &l, 0, nullptr))); + return; + } + + TRACE(datatype, tout << "propagate_not_is_subterm: " << enode_pp(n, ctx) << "\n";); + + literal antecedent = literal(ctx.enode2bool_var(n), false); + bool has_leaf_root = false; + + ptr_vector candidates = list_subterms(arg2); + + for (enode *s : candidates) { + bool is_leaf = !m_util.is_constructor(s->get_expr()); + + if (s->get_sort() == arg1->get_sort()) { + TRACE(datatype, + tout << "asserting " << mk_pp(arg1->get_expr(), m) << " != " << mk_pp(s->get_expr(), m) << "\n";); + literal eq = mk_eq(arg1->get_expr(), s->get_expr(), true); + literal lits[2] = {antecedent, ~eq}; + ctx.mk_th_axiom(get_id(), 2, lits); + } + + if (is_leaf) { + if (s->get_root() == arg2->get_root()) { + has_leaf_root = true; + continue; + } + + if (m_util.is_datatype(s->get_sort())) { + func_decl *sub_decl = m_util.get_datatype_subterm(s->get_sort()); + if (sub_decl) { + TRACE(datatype, tout << "asserting NOT " << mk_pp(arg1->get_expr(), m) << " subterm " + << mk_pp(s->get_expr(), m) << "\n";); + app_ref sub_app(m.mk_app(sub_decl, arg1->get_expr(), s->get_expr()), m); + ctx.internalize(sub_app, false); + literal sub_lit = literal(ctx.get_bool_var(sub_app)); + literal lits[2] = {antecedent, ~sub_lit}; + ctx.mk_th_axiom(get_id(), 2, lits); + } + } + } + } + + if (has_leaf_root) { + split_leaf_root(arg2); + } + } + + // requesting to split on arg2 + void theory_datatype::split_leaf_root(smt::enode *arg2) { + TRACE(datatype, tout << "arg is a leaf: " << enode_pp(arg2, ctx) << "\n";); + theory_var v = arg2->get_th_var(get_id()); + if (v != null_theory_var) { + v = m_find.find(v); + if (m_var_data[v]->m_constructor == nullptr) { + mk_split(v); + } + } + } + + void subterm_iterator::next() { + m_current = nullptr; + 
if (!m_manager) + return; + + while (!m_todo.empty()) { + enode *curr = m_todo.back(); + m_todo.pop_back(); + enode *root = curr->get_root(); + + if (root->is_marked()) + continue; + root->set_mark(); + m_marked.push_back(root); + + enode *ctor = nullptr; + enode *iter = root; + do { + if (m_util->is_constructor(iter->get_expr())) { + ctor = iter; + break; + } + iter = iter->get_next(); + } while (iter != root); + + if (ctor) { + m_current = ctor; + for (enode *child : enode::args(ctor)) { + m_todo.push_back(child); + } + return; + } + else { + m_current = root; + return; + } + } + } + + subterm_iterator::subterm_iterator(ast_manager &m, datatype_util& m_util, enode *start) : m_manager(&m), m_current(nullptr), m_util(&m_util) { + m_todo.push_back(start); + next(); + } + + subterm_iterator::subterm_iterator(subterm_iterator &&other) : m_manager(nullptr), m_current(nullptr), m_util(nullptr) { + m_todo.swap(other.m_todo); + m_marked.swap(other.m_marked); + std::swap(m_manager, other.m_manager); + std::swap(m_current, other.m_current); + std::swap(m_util, other.m_util); + } + + subterm_iterator::~subterm_iterator() { + for (enode *n : m_marked) + n->unset_mark(); + } + + ptr_vector theory_datatype::list_subterms(enode* arg) { + ptr_vector result; + for (enode* n : iterate_subterms(get_manager(), m_util, arg)) { + result.push_back(n); + } + return result; } void theory_datatype::relevant_eh(app * n) { @@ -455,6 +733,19 @@ namespace smt { SASSERT(v != null_theory_var); add_recognizer(v, e); } + else if (is_subterm_predicate(n)) { + SASSERT(ctx.e_internalized(n)); + + enode * e = ctx.get_enode(n); + theory_var a = e->get_arg(0)->get_th_var(get_id()); // e is 'a ⊑ b' + theory_var b = e->get_arg(1)->get_th_var(get_id()); // e is 'a ⊑ b' + SASSERT(a != null_theory_var && b != null_theory_var); + + add_subterm_predicate(a, e); + add_subterm_predicate(b, e); + + // propagating potentially adds a lot of literals, avoid it if we can + } } void theory_datatype::push_scope_eh() { 
@@ -872,11 +1163,16 @@ namespace smt { } } d1->m_constructor = d2->m_constructor; + propagate_subterm_with_constructor(v1); } } for (enode* e : d2->m_recognizers) if (e) add_recognizer(v1, e); + + for (enode* e : d2->m_subterms) { + add_subterm_predicate(v1, e); + } } void theory_datatype::unmerge_eh(theory_var v1, theory_var v2) { @@ -921,6 +1217,26 @@ namespace smt { } } + /** \brief register `predicate` to `v`'s `var_data` + * + * With `predicate:='a ⊑ b'` this should be called with `v:='a'` and `v:='b'`. + * + * This doesn't handle potential propagation. The responsibility for it + * falls on the caller. + */ + void theory_datatype::add_subterm_predicate(theory_var v, enode * predicate) { + SASSERT(is_subterm_predicate(predicate)); + v = m_find.find(v); + var_data * d = m_var_data[v]; + + if (d->m_subterms.contains(predicate)) return; + + TRACE(datatype, tout << "add subterm predicate\n" << enode_pp(predicate, ctx) << "\n";); + + m_trail_stack.push(restore_vector(d->m_subterms)); + d->m_subterms.push_back(predicate); + } + /** \brief Propagate a recognizer assigned to false. */ diff --git a/src/smt/theory_datatype.h b/src/smt/theory_datatype.h index dfc06ae69..b52cae1b3 100644 --- a/src/smt/theory_datatype.h +++ b/src/smt/theory_datatype.h @@ -33,6 +33,18 @@ namespace smt { struct var_data { ptr_vector m_recognizers; //!< recognizers of this equivalence class that are being watched. enode * m_constructor; //!< constructor of this equivalence class, 0 if there is no constructor in the eqc. + + /** + * \brief subterm predicates that involve this equivalence class + * + * So all terms of the shape `a ⊑ b` where `var_data` represents either `a` or `b`. 
+ * + * This is more a set than a vector, but I'll use `ptr_vector` + * because I know the API better, it's easier to backtrack on it and + * it should be small enough to outperform a hasmap anyway + */ + ptr_vector m_subterms; + var_data(): m_constructor(nullptr) { } @@ -56,11 +68,13 @@ namespace smt { bool is_constructor(app * f) const { return m_util.is_constructor(f); } bool is_recognizer(app * f) const { return m_util.is_recognizer(f); } + bool is_subterm_predicate(app * f) const { return m_util.is_subterm_predicate(f); } bool is_accessor(app * f) const { return m_util.is_accessor(f); } bool is_update_field(app * f) const { return m_util.is_update_field(f); } bool is_constructor(enode * n) const { return is_constructor(n->get_expr()); } bool is_recognizer(enode * n) const { return is_recognizer(n->get_expr()); } + bool is_subterm_predicate(enode * n) const { return is_subterm_predicate(n->get_expr()); } bool is_accessor(enode * n) const { return is_accessor(n->get_expr()); } bool is_update_field(enode * n) const { return m_util.is_update_field(n->get_expr()); } @@ -68,8 +82,15 @@ namespace smt { void assert_is_constructor_axiom(enode * n, func_decl * c, literal antecedent); void assert_accessor_axioms(enode * n); void assert_update_field_axioms(enode * n); + void assert_subterm_axioms(enode * n); void add_recognizer(theory_var v, enode * recognizer); - void propagate_recognizer(theory_var v, enode * r); + void add_subterm_predicate(theory_var v, enode *predicate); + void propagate_subterm(enode * n, bool is_true); + void propagate_is_subterm(enode * n); + void propagate_not_is_subterm(enode *n); + void split_leaf_root(smt::enode *arg2); + void propagate_subterm_with_constructor(theory_var v); + void propagate_recognizer(theory_var v, enode *r); void sign_recognizer_conflict(enode * c, enode * r); typedef enum { ENTER, EXIT } stack_op; @@ -113,6 +134,7 @@ namespace smt { void mk_split(theory_var v); void display_var(std::ostream & out, theory_var v) const; + 
ptr_vector list_subterms(enode* arg); protected: theory_var mk_var(enode * n) override; @@ -148,6 +170,57 @@ namespace smt { }; + /** + * Iterator over the subterms of an enode. + * + * It only takes into account datatype terms when looking for subterms. + * + * It uses the `mark` field of the `enode` struct to mark the node visited. + * It will clean afterwards. *Implementation invariant*: the destructor + * *must* be run *exactly* once otherwise the marks might not be clean or + * might be clean more than once and mid search + */ + class subterm_iterator { + ptr_vector m_todo; + ptr_vector m_marked; + ast_manager* m_manager; + enode* m_current; + datatype_util* m_util; + + void next(); + subterm_iterator() : m_manager(nullptr), m_current(nullptr), m_util(nullptr) {} + + public: + // subterm_iterator(); + subterm_iterator(ast_manager& m, datatype_util& m_util, enode *start); + ~subterm_iterator(); + subterm_iterator(subterm_iterator &&other); + // need to delete this function otherwise the destructor could be ran + // more than once, invalidating the marks used in the dfs. + subterm_iterator(const subterm_iterator& other) = delete; + + subterm_iterator begin() { + return std::move(*this); + } + subterm_iterator end() { + return subterm_iterator(); + } + + bool operator!=(const subterm_iterator &other) const { + return m_current != other.m_current; + } + + enode *operator*() const { + return m_current; + } + + void operator++() { next(); } + subterm_iterator& operator=(const subterm_iterator&) = delete; + }; + + inline subterm_iterator iterate_subterms(ast_manager& m, datatype_util& m_util, enode *arg) { + return subterm_iterator(m, m_util, arg); + } };