diff --git a/src/sat/smt/dt_solver.cpp b/src/sat/smt/dt_solver.cpp index 22e446a13..daecb7325 100644 --- a/src/sat/smt/dt_solver.cpp +++ b/src/sat/smt/dt_solver.cpp @@ -554,12 +554,12 @@ namespace dt { } // Assuming `app` is equal to a constructor term, return the constructor enode - inline euf::enode* solver::oc_get_cstor(enode* app) { + inline euf::enode* solver::oc_get_cstor(enode* app) const { theory_var v = app->get_root()->get_th_var(get_id()); - SASSERT(v != euf::null_theory_var); + if (v == euf::null_theory_var) + return nullptr; v = m_find.find(v); var_data* d = m_var_data[v]; - SASSERT(d->m_constructor); return d->m_constructor; } @@ -783,7 +783,7 @@ namespace dt { if (v == euf::null_theory_var) return false; euf::enode* con = m_var_data[m_find.find(v)]->m_constructor; - CTRACE("dt", !con, display(tout) << ctx.bpp(n) << "\n";); + TRACE("dt", display(tout) << ctx.bpp(n) << " con: " << ctx.bpp(con) << "\n";); if (con->num_args() == 0) dep.insert(n, nullptr); for (enode* arg : euf::enode_args(con)) @@ -794,16 +794,15 @@ namespace dt { bool solver::include_func_interp(func_decl* f) const { if (!dt.is_accessor(f)) return false; - func_decl* con = dt.get_accessor_constructor(f); - for (enode* app : ctx.get_egraph().enodes_of(f)) { - enode* arg = app->get_arg(0)->get_root(); - if (is_constructor(arg) && arg->get_decl() != con) + func_decl* con_decl = dt.get_accessor_constructor(f); + for (enode* app : ctx.get_egraph().enodes_of(f)) { + enode* con = oc_get_cstor(app->get_arg(0)); + if (con && is_constructor(con) && con->get_decl() != con_decl) return true; } return false; } - sat::literal solver::internalize(expr* e, bool sign, bool root) { if (!visit_rec(m, e, sign, root)) return sat::null_literal; diff --git a/src/sat/smt/dt_solver.h b/src/sat/smt/dt_solver.h index 51a7679fd..b2cbba63b 100644 --- a/src/sat/smt/dt_solver.h +++ b/src/sat/smt/dt_solver.h @@ -116,7 +116,7 @@ namespace dt { void pop_core(unsigned n) override; - enode * oc_get_cstor(enode * n); + enode * oc_get_cstor(enode * n) const; bool occurs_check(enode * n); bool occurs_check_enter(enode * n); void occurs_check_explain(enode * top, enode * root);