diff --git a/src/smt/theory_datatype.cpp b/src/smt/theory_datatype.cpp index e21067a56..f17bdc3d6 100644 --- a/src/smt/theory_datatype.cpp +++ b/src/smt/theory_datatype.cpp @@ -411,19 +411,41 @@ namespace smt { return r; } + // Assuming `app` is equal to a constructor term, return the constructor enode + inline enode * theory_datatype::oc_get_cstor(enode * app) { + theory_var v = app->get_root()->get_th_var(get_id()); + SASSERT(v != null_theory_var); + v = m_find.find(v); + var_data * d = m_var_data[v]; + SASSERT(d->m_constructor); + return d->m_constructor; + } + // explain the cycle root -> … -> app -> root void theory_datatype::occurs_check_explain(enode * app, enode * root) { TRACE("datatype", tout << "occurs_check_explain " << mk_bounded_pp(app->get_owner(), get_manager()) << " <-> " << mk_bounded_pp(root->get_owner(), get_manager()) << "\n";); enode* app_parent = nullptr; + // first: explain that root=v, given that app=cstor(…,v,…) + { + enode * app_cstor = oc_get_cstor(app); + unsigned n_args_app = app_cstor->get_num_args(); + for (unsigned i=0; i < n_args_app; ++i) { + enode * arg = app_cstor->get_arg(i); + // found an argument which is equal to root + if (arg->get_root() == root->get_root()) { + if (arg != root) + m_used_eqs.push_back(enode_pair(arg, root)); + break; + } + } + } + + // now explain app=cstor(…,v,…) where v=root, and recurse with parent of app while (app->get_root() != root->get_root()) { - theory_var v = app->get_root()->get_th_var(get_id()); - SASSERT(v != null_theory_var); - v = m_find.find(v); - var_data * d = m_var_data[v]; - SASSERT(d->m_constructor); - if (app != d->m_constructor) - m_used_eqs.push_back(enode_pair(app, d->m_constructor)); + enode * app_cstor = oc_get_cstor(app); + if (app != app_cstor) + m_used_eqs.push_back(enode_pair(app, app_cstor)); app_parent = m_parent[app->get_root()]; app = app_parent; } diff --git a/src/smt/theory_datatype.h b/src/smt/theory_datatype.h index 77da52360..020949296 100644 --- a/src/smt/theory_datatype.h +++ b/src/smt/theory_datatype.h @@ -121,6 +121,7 @@ namespace smt { } }; + enode * oc_get_cstor(enode * n); bool occurs_check(enode * n); bool occurs_check_enter(enode * n); void occurs_check_explain(enode * top, enode * root);