From 81d97a81af7aeaf1773d5c9569567e8b6dd8549f Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Wed, 27 Apr 2022 09:58:38 +0100 Subject: [PATCH] enable nested ADT and sequences add API to define forward reference to recursively defined datatype. The forward reference should be used only when passed to constructor declarations that are used in a datatype definition (Z3_mk_datatypes). The call to Z3_mk_datatypes ensures that the forward reference can be resolved with respect to constructors. --- src/api/api_datatype.cpp | 14 ++++ src/api/z3_api.h | 13 ++++ src/ast/datatype_decl_plugin.cpp | 65 +++++++++++------- src/ast/datatype_decl_plugin.h | 10 +-- src/cmd_context/pdecl.cpp | 16 ++--- src/model/datatype_factory.cpp | 2 +- src/sat/smt/dt_solver.cpp | 80 +++++++++++++++++----- src/sat/smt/dt_solver.h | 5 +- src/smt/theory_datatype.cpp | 114 ++++++++++++++++++++++--------- src/smt/theory_datatype.h | 5 +- 10 files changed, 232 insertions(+), 92 deletions(-) diff --git a/src/api/api_datatype.cpp b/src/api/api_datatype.cpp index 673ceb3c1..23ff21575 100644 --- a/src/api/api_datatype.cpp +++ b/src/api/api_datatype.cpp @@ -365,6 +365,20 @@ extern "C" { Z3_CATCH; } + Z3_sort Z3_API Z3_mk_datatype_sort(Z3_context c, Z3_symbol name) { + Z3_TRY; + LOG_Z3_mk_datatype_sort(c, name); + RESET_ERROR_CODE(); + ast_manager& m = mk_c(c)->m(); + datatype_util data_util(m); + parameter param(name); + sort * s = m.mk_sort(util.get_family_id(), DATATYPE_SORT, 1, &param); + mk_c(c)->save_ast_trail(s); + RETURN_Z3(of_sort(s)); + Z3_CATCH_RETURN(nullptr); + } + + void Z3_API Z3_mk_datatypes(Z3_context c, unsigned num_sorts, Z3_symbol const sort_names[], diff --git a/src/api/z3_api.h b/src/api/z3_api.h index a9e0c6b7c..7ca693984 100644 --- a/src/api/z3_api.h +++ b/src/api/z3_api.h @@ -2094,6 +2094,19 @@ extern "C" { unsigned num_constructors, Z3_constructor constructors[]); + /** + \brief Create a forward reference to a recursive datatype being declared. 
+ The forward reference can be used in a nested occurrence: as the range of an array + or as the element sort of a sequence. The forward reference should only be used + in an accessor of a recursive datatype that is being declared. + + Forward references can replace the use of sort references, which are unsigned integers, + in the \c Z3_mk_constructor call. + + def_API('Z3_mk_datatype_sort', SORT, (_in(CONTEXT), _in(SYMBOL))) + */ + Z3_sort Z3_API Z3_mk_datatype_sort(Z3_context c, Z3_symbol name); + /** \brief Create list of constructors. diff --git a/src/ast/datatype_decl_plugin.cpp b/src/ast/datatype_decl_plugin.cpp index bc574fc1c..11c33d695 100644 --- a/src/ast/datatype_decl_plugin.cpp +++ b/src/ast/datatype_decl_plugin.cpp @@ -19,6 +19,7 @@ Revision History: #include "util/warning.h" #include "ast/array_decl_plugin.h" +#include "ast/seq_decl_plugin.h" #include "ast/datatype_decl_plugin.h" #include "ast/ast_smt2_pp.h" #include "ast/ast_pp.h" @@ -462,28 +463,27 @@ namespace datatype { } for (symbol const& s : m_def_block) { def& d = *m_defs[s]; - for (constructor* c : d) { - for (accessor* a : *c) { + for (constructor* c : d) + for (accessor* a : *c) a->fix_range(sorts); - } - } } - if (!u().is_well_founded(sorts.size(), sorts.data())) { + if (!u().is_well_founded(sorts.size(), sorts.data())) m_manager->raise_exception("datatype is not well-founded"); - } - if (!u().is_covariant(sorts.size(), sorts.data())) { + if (!u().is_covariant(sorts.size(), sorts.data())) m_manager->raise_exception("datatype is not co-variant"); - } - + array_util autil(m); + seq_util sutil(m); + sort* sr; for (sort* s : sorts) { for (constructor const* c : get_def(s)) { for (accessor const* a : *c) { - if (autil.is_array(a->range())) { - if (sorts.contains(get_array_range(a->range()))) { - m_has_nested_arrays = true; - } - } + if (autil.is_array(a->range()) && sorts.contains(get_array_range(a->range()))) + m_has_nested_rec = true; + else if (sutil.is_seq(a->range(), sr) && sorts.contains(sr)) + 
m_has_nested_rec = true; + else if (sutil.is_re(a->range(), sr) && sorts.contains(sr)) + m_has_nested_rec = true; } } } @@ -1103,12 +1103,19 @@ namespace datatype { return r; } - bool util::is_recursive_array(sort* a) { + bool util::is_recursive_nested(sort* a) { array_util autil(m); - if (!autil.is_array(a)) - return false; - a = autil.get_array_range_rec(a); - return is_datatype(a) && is_recursive(a); + seq_util sutil(m); + sort* sr; + if (autil.is_array(a)) { + a = autil.get_array_range_rec(a); + return is_datatype(a) && is_recursive(a); + } + if (sutil.is_seq(a, sr)) + return is_datatype(sr) && is_recursive(sr); + if (sutil.is_re(a, sr)) + return is_datatype(sr) && is_recursive(sr); + return false; } bool util::is_enum_sort(sort* s) { @@ -1273,14 +1280,22 @@ namespace datatype { */ bool util::are_siblings(sort * s1, sort * s2) { array_util autil(m); - s1 = autil.get_array_range_rec(s1); - s2 = autil.get_array_range_rec(s2); - if (!is_datatype(s1) || !is_datatype(s2)) { + seq_util sutil(m); + auto get_nested = [&](sort* s) { + while (true) { + if (autil.is_array(s)) + s = get_array_range(s); + else if (!sutil.is_seq(s, s)) + break; + } + return s; + }; + s1 = get_nested(s1); + s2 = get_nested(s2); + if (!is_datatype(s1) || !is_datatype(s2)) return s1 == s2; - } - else { + else return get_def(s1).id() == get_def(s2).id(); - } } unsigned util::get_datatype_num_constructors(sort * ty) { diff --git a/src/ast/datatype_decl_plugin.h b/src/ast/datatype_decl_plugin.h index 0172d0b1d..0698bf821 100644 --- a/src/ast/datatype_decl_plugin.h +++ b/src/ast/datatype_decl_plugin.h @@ -207,14 +207,14 @@ namespace datatype { unsigned m_id_counter; svector m_def_block; unsigned m_class_id; - mutable bool m_has_nested_arrays; + mutable bool m_has_nested_rec; void inherit(decl_plugin* other_p, ast_translation& tr) override; void log_axiom_definitions(symbol const& s, sort * new_sort); public: - plugin(): m_id_counter(0), m_class_id(0), m_has_nested_arrays(false) {} + plugin(): 
m_id_counter(0), m_class_id(0), m_has_nested_rec(false) {} ~plugin() override; void finalize() override; @@ -254,7 +254,7 @@ namespace datatype { unsigned get_axiom_base_id(symbol const& s) { return m_axiom_bases[s]; } util & u() const; - bool has_nested_arrays() const { return m_has_nested_arrays; } + bool has_nested_rec() const { return m_has_nested_rec; } private: bool is_value_visit(bool unique, expr * arg, ptr_buffer & todo) const; @@ -334,7 +334,7 @@ namespace datatype { bool is_datatype(sort const* s) const { return is_sort_of(s, fid(), DATATYPE_SORT); } bool is_enum_sort(sort* s); bool is_recursive(sort * ty); - bool is_recursive_array(sort * ty); + bool is_recursive_nested(sort * ty); bool is_constructor(func_decl * f) const { return is_decl_of(f, fid(), OP_DT_CONSTRUCTOR); } bool is_recognizer(func_decl * f) const { return is_recognizer0(f) || is_is(f); } bool is_recognizer0(func_decl * f) const { return is_decl_of(f, fid(), OP_DT_RECOGNISER); } @@ -365,7 +365,7 @@ namespace datatype { func_decl * get_accessor_constructor(func_decl * accessor); func_decl * get_recognizer_constructor(func_decl * recognizer) const; func_decl * get_update_accessor(func_decl * update) const; - bool has_nested_arrays() const { return plugin().has_nested_arrays(); } + bool has_nested_rec() const { return plugin().has_nested_rec(); } family_id get_family_id() const { return fid(); } decl::plugin& plugin() const; bool are_siblings(sort * s1, sort * s2); diff --git a/src/cmd_context/pdecl.cpp b/src/cmd_context/pdecl.cpp index 2bf21de3a..1545487c7 100644 --- a/src/cmd_context/pdecl.cpp +++ b/src/cmd_context/pdecl.cpp @@ -493,18 +493,16 @@ void pconstructor_decl::finalize(pdecl_manager & m) { } bool pconstructor_decl::has_missing_refs(symbol & missing) const { - for (paccessor_decl* a : m_accessors) { + for (paccessor_decl* a : m_accessors) if (a->has_missing_refs(missing)) return true; - } return false; } bool pconstructor_decl::fix_missing_refs(dictionary const & symbol2idx, 
symbol & missing) { - for (paccessor_decl* a : m_accessors) { + for (paccessor_decl* a : m_accessors) if (!a->fix_missing_refs(symbol2idx, missing)) return false; - } return true; } @@ -561,18 +559,16 @@ bool pdatatype_decl::has_duplicate_accessors(symbol & duplicated) const { bool pdatatype_decl::fix_missing_refs(dictionary const & symbol2idx, symbol & missing) { - for (auto c : m_constructors) { + for (auto c : m_constructors) if (!c->fix_missing_refs(symbol2idx, missing)) return false; - } return true; } datatype_decl * pdatatype_decl::instantiate_decl(pdecl_manager & m, unsigned n, sort * const * s) { ptr_buffer cs; - for (auto c : m_constructors) { + for (auto c : m_constructors) cs.push_back(c->instantiate_decl(m, n, s)); - } datatype_util util(m.m()); return mk_datatype_decl(util, m_name, m_num_params, s, cs.size(), cs.data()); } @@ -647,10 +643,8 @@ bool pdatatype_decl::commit(pdecl_manager& m) { sort_ref_vector sorts(m.m()); bool is_ok = m.get_dt_plugin()->mk_datatypes(1, &d_ptr, m_num_params, ps.data(), sorts); m.notify_mk_datatype(m_name); - if (is_ok && m_num_params == 0) { + if (is_ok && m_num_params == 0) m.notify_new_dt(sorts.get(0), this); - } - return is_ok; } diff --git a/src/model/datatype_factory.cpp b/src/model/datatype_factory.cpp index e58812a1f..56312839a 100644 --- a/src/model/datatype_factory.cpp +++ b/src/model/datatype_factory.cpp @@ -166,7 +166,7 @@ expr * datatype_factory::get_fresh_value(sort * s) { for (unsigned i = 0; i < num; i++) { sort * s_arg = constructor->get_domain(i); if (!found_fresh_arg && - !m_util.is_recursive_array(s_arg) && + !m_util.is_recursive_nested(s_arg) && (!m_util.is_recursive(s) || !m_util.is_datatype(s_arg) || !m_util.are_siblings(s, s_arg))) { expr * new_arg = m_model.get_fresh_value(s_arg); if (new_arg != nullptr) { diff --git a/src/sat/smt/dt_solver.cpp b/src/sat/smt/dt_solver.cpp index 76c154a4a..13d9768e2 100644 --- a/src/sat/smt/dt_solver.cpp +++ b/src/sat/smt/dt_solver.cpp @@ -29,6 +29,7 @@ namespace 
dt { th_euf_solver(ctx, ctx.get_manager().get_family_name(id), id), dt(m), m_autil(m), + m_sutil(m), m_find(*this), m_args(m) {} @@ -496,13 +497,41 @@ namespace dt { } ptr_vector const& solver::get_array_args(enode* n) { - m_array_args.reset(); + m_nodes.reset(); array::solver* th = dynamic_cast(ctx.fid2solver(m_autil.get_family_id())); for (enode* p : th->parent_selects(n)) - m_array_args.push_back(p); + m_nodes.push_back(p); app_ref def(m_autil.mk_default(n->get_expr()), m); - m_array_args.push_back(ctx.get_enode(def)); - return m_array_args; + m_nodes.push_back(ctx.get_enode(def)); + return m_nodes; + } + + ptr_vector const& solver::get_seq_args(enode* n) { + m_nodes.reset(); + m_todo.reset(); + auto add_todo = [&](enode* n) { + if (!n->is_marked1()) { + n->mark1(); + m_todo.push_back(n); + } + }; + + for (enode* sib : euf::enode_class(n)) + add_todo(sib); + + for (unsigned i = 0; i < m_todo.size(); ++i) { + enode* n = m_todo[i]; + expr* e = n->get_expr(); + if (m_sutil.str.is_unit(e)) + m_nodes.push_back(n->get_arg(0)); + else if (m_sutil.str.is_concat(e)) + for (expr* arg : *to_app(e)) + add_todo(ctx.get_enode(arg)); + } + for (enode* n : m_todo) + n->unmark1(); + + return m_nodes; } // Assuming `app` is equal to a constructor term, return the constructor enode @@ -536,6 +565,12 @@ namespace dt { for (enode* aarg : get_array_args(arg)) add(aarg); } + sort* se; + if (m_sutil.is_seq(child->get_sort(), se) && dt.is_datatype(se)) { + for (enode* aarg : get_seq_args(child)) + add(aarg); + } + VERIFY(found); } @@ -575,6 +610,21 @@ namespace dt { return false; enode* parent = d->m_constructor; oc_mark_on_stack(parent); + + auto process_arg = [&](enode* aarg) { + if (oc_cycle_free(aarg)) + return false; + if (oc_on_stack(aarg)) { + occurs_check_explain(parent, aarg); + return true; + } + if (dt.is_datatype(aarg->get_sort())) { + m_parent.insert(aarg->get_root(), parent); + oc_push_stack(aarg); + } + return false; + }; + for (enode* arg : euf::enode_args(parent)) { if 
(oc_cycle_free(arg)) continue; @@ -585,24 +635,20 @@ namespace dt { } // explore `arg` (with parent) expr* earg = arg->get_expr(); - sort* s = earg->get_sort(); + sort* s = earg->get_sort(), *se; if (dt.is_datatype(s)) { m_parent.insert(arg->get_root(), parent); oc_push_stack(arg); } - else if (m_autil.is_array(s) && dt.is_datatype(get_array_range(s))) { - for (enode* aarg : get_array_args(arg)) { - if (oc_cycle_free(aarg)) - continue; - if (oc_on_stack(aarg)) { - occurs_check_explain(parent, aarg); + else if (m_sutil.is_seq(s, se) && dt.is_datatype(se)) { + for (enode* sarg : get_seq_args(arg)) + if (process_arg(sarg)) + return true; + } + else if (m_autil.is_array(s) && dt.is_datatype(get_array_range(s))) { + for (enode* sarg : get_array_args(arg)) + if (process_arg(sarg)) return true; - } - if (is_datatype(aarg)) { - m_parent.insert(aarg->get_root(), parent); - oc_push_stack(aarg); - } - } } } return false; diff --git a/src/sat/smt/dt_solver.h b/src/sat/smt/dt_solver.h index cd5529075..e0a076a2d 100644 --- a/src/sat/smt/dt_solver.h +++ b/src/sat/smt/dt_solver.h @@ -19,6 +19,7 @@ Author: #include "sat/smt/sat_th.h" #include "ast/datatype_decl_plugin.h" #include "ast/array_decl_plugin.h" +#include "ast/seq_decl_plugin.h" namespace euf { class solver; @@ -62,6 +63,7 @@ namespace dt { mutable datatype_util dt; array_util m_autil; + seq_util m_sutil; stats m_stats; ptr_vector m_var_data; dt_union_find m_find; @@ -108,8 +110,9 @@ namespace dt { bool oc_cycle_free(enode * n) const { return n->get_root()->is_marked2(); } void oc_push_stack(enode * n); - ptr_vector m_array_args; + ptr_vector m_nodes, m_todo; ptr_vector const& get_array_args(enode* n); + ptr_vector const& get_seq_args(enode* n); void pop_core(unsigned n) override; diff --git a/src/smt/theory_datatype.cpp b/src/smt/theory_datatype.cpp index 035b647dc..27b922dba 100644 --- a/src/smt/theory_datatype.cpp +++ b/src/smt/theory_datatype.cpp @@ -275,7 +275,7 @@ namespace smt { else if (is_update_field(n)) { 
assert_update_field_axioms(n); } - else { + else if (m_util.is_datatype(n->get_sort())) { sort * s = n->get_sort(); if (m_util.get_datatype_num_constructors(s) == 1) { func_decl * c = m_util.get_datatype_constructors(s)->get(0); @@ -343,7 +343,7 @@ namespace smt { } arg = ctx.get_enode(def); } - if (!m_util.is_datatype(s)) + if (!m_util.is_datatype(s) && !m_sutil.is_seq(s)) continue; if (is_attached_to_var(arg)) continue; @@ -393,7 +393,7 @@ namespace smt { if (!is_attached_to_var(n) && (ctx.has_quantifiers() || - (m_util.is_datatype(s) && m_util.has_nested_arrays()) || + (m_util.is_datatype(s) && m_util.has_nested_rec()) || (m_util.is_datatype(s) && !s->is_infinite()))) { mk_var(n); } @@ -485,7 +485,10 @@ namespace smt { for (int v = 0; v < num_vars; v++) { if (v == static_cast(m_find.find(v))) { enode * node = get_enode(v); - if (m_util.is_recursive(node->get_sort()) && !oc_cycle_free(node) && occurs_check(node)) { + sort* s = node->get_sort(); + if (!m_util.is_datatype(s)) + continue; + if (m_util.is_recursive(s) && !oc_cycle_free(node) && occurs_check(node)) { // conflict was detected... // return... 
return FC_CONTINUE; @@ -541,6 +544,17 @@ namespace smt { } } } + sort* se = nullptr; + if (m_sutil.is_seq(s, se) && m_util.is_datatype(se)) { + for (enode* aarg : get_seq_args(arg)) { + if (aarg->get_root() == child->get_root()) { + if (aarg != child) { + m_used_eqs.push_back(enode_pair(aarg, child)); + } + found = true; + } + } + } } VERIFY(found); } @@ -587,6 +601,20 @@ namespace smt { } enode * parent = d->m_constructor; oc_mark_on_stack(parent); + auto process_arg = [&](enode* aarg) { + if (oc_cycle_free(aarg)) + return false; + if (oc_on_stack(aarg)) { + occurs_check_explain(parent, aarg); + return true; + } + if (m_util.is_datatype(aarg->get_sort())) { + m_parent.insert(aarg->get_root(), parent); + oc_push_stack(aarg); + } + return false; + }; + for (enode * arg : enode::args(parent)) { if (oc_cycle_free(arg)) { continue; @@ -598,39 +626,61 @@ namespace smt { } // explore `arg` (with parent) expr* earg = arg->get_expr(); - sort* s = earg->get_sort(); + sort* s = earg->get_sort(), *se = nullptr; if (m_util.is_datatype(s)) { m_parent.insert(arg->get_root(), parent); oc_push_stack(arg); } - else if (m_autil.is_array(s) && m_util.is_datatype(get_array_range(s))) { - for (enode* aarg : get_array_args(arg)) { - if (oc_cycle_free(aarg)) { - continue; - } - if (oc_on_stack(aarg)) { - occurs_check_explain(parent, aarg); + else if (m_sutil.is_seq(s, se) && m_util.is_datatype(se)) { + for (enode* sarg : get_seq_args(arg)) + if (process_arg(sarg)) return true; - } - if (m_util.is_datatype(aarg->get_sort())) { - m_parent.insert(aarg->get_root(), parent); - oc_push_stack(aarg); - } - } - } + } + else if (m_autil.is_array(s) && m_util.is_datatype(get_array_range(s))) { + for (enode* aarg : get_array_args(arg)) + if (process_arg(aarg)) + return true; + } } return false; } - ptr_vector const& theory_datatype::get_array_args(enode* n) { - m_array_args.reset(); - theory_array* th = dynamic_cast(ctx.get_theory(m_autil.get_family_id())); - for (enode* p : th->parent_selects(n)) { 
- m_array_args.push_back(p); + ptr_vector const& theory_datatype::get_seq_args(enode* n) { + m_args.reset(); + m_todo.reset(); + auto add_todo = [&](enode* n) { + if (!n->is_marked()) { + n->set_mark(); + m_todo.push_back(n); + } + }; + + for (enode* sib : *n) + add_todo(sib); + + for (unsigned i = 0; i < m_todo.size(); ++i) { + enode* n = m_todo[i]; + expr* e = n->get_expr(); + if (m_sutil.str.is_unit(e)) + m_args.push_back(n->get_arg(0)); + else if (m_sutil.str.is_concat(e)) + for (expr* arg : *to_app(e)) + add_todo(ctx.get_enode(arg)); } + for (enode* n : m_todo) + n->unset_mark(); + + return m_args; + } + + ptr_vector const& theory_datatype::get_array_args(enode* n) { + m_args.reset(); + theory_array* th = dynamic_cast(ctx.get_theory(m_autil.get_family_id())); + for (enode* p : th->parent_selects(n)) + m_args.push_back(p); app_ref def(m_autil.mk_default(n->get_expr()), m); - m_array_args.push_back(ctx.get_enode(def)); - return m_array_args; + m_args.push_back(ctx.get_enode(def)); + return m_args; } /** @@ -653,18 +703,19 @@ namespace smt { enode * app = m_stack.back().second; m_stack.pop_back(); - if (oc_cycle_free(app)) continue; + if (oc_cycle_free(app)) + continue; TRACE("datatype", tout << "occurs check loop: " << enode_pp(app, ctx) << (op==ENTER?" 
enter":" exit")<< "\n";); switch (op) { case ENTER: - res = occurs_check_enter(app); - break; + res = occurs_check_enter(app); + break; case EXIT: - oc_mark_cycle_free(app); - break; + oc_mark_cycle_free(app); + break; } } @@ -702,6 +753,7 @@ namespace smt { theory(ctx, ctx.get_manager().mk_family_id("datatype")), m_util(m), m_autil(m), + m_sutil(m), m_find(*this), m_trail_stack() { } diff --git a/src/smt/theory_datatype.h b/src/smt/theory_datatype.h index d219b1f9d..c0e06b58d 100644 --- a/src/smt/theory_datatype.h +++ b/src/smt/theory_datatype.h @@ -20,6 +20,7 @@ Revision History: #include "util/union_find.h" #include "ast/array_decl_plugin.h" +#include "ast/seq_decl_plugin.h" #include "ast/datatype_decl_plugin.h" #include "model/datatype_factory.h" #include "smt/smt_theory.h" @@ -46,6 +47,7 @@ namespace smt { datatype_util m_util; array_util m_autil; + seq_util m_sutil; ptr_vector m_var_data; th_union_find m_find; trail_stack m_trail_stack; @@ -90,8 +92,9 @@ namespace smt { bool oc_cycle_free(enode * n) const { return n->get_root()->is_marked2(); } void oc_push_stack(enode * n); - ptr_vector m_array_args; + ptr_vector m_args, m_todo; ptr_vector const& get_array_args(enode* n); + ptr_vector const& get_seq_args(enode* n); // class for managing state of final_check class final_check_st {