diff --git a/src/ast/sls/sls_datatype_plugin.cpp b/src/ast/sls/sls_datatype_plugin.cpp index 3a30c6a19..92f1347cf 100644 --- a/src/ast/sls/sls_datatype_plugin.cpp +++ b/src/ast/sls/sls_datatype_plugin.cpp @@ -51,6 +51,7 @@ Axioms: --*/ #include "ast/sls/sls_datatype_plugin.h" +#include "ast/ast_pp.h" namespace sls { @@ -257,13 +258,7 @@ namespace sls { m_values.reserve(id + 1); if (!dt.is_datatype(e)) continue; - euf::enode* con = nullptr; - for (auto sib : euf::enode_class(n)) { - if (dt.is_constructor(sib->get_expr())) { - con = sib; - break; - } - } + euf::enode* con = get_constructor(n); if (con) { auto f = con->get_decl(); args.reset(); @@ -284,13 +279,7 @@ namespace sls { void datatype_plugin::add_dep(euf::enode* n, top_sort& dep) { if (!dt.is_datatype(n->get_expr())) return; - euf::enode* con = nullptr; - for (auto sib : euf::enode_class(n)) { - if (dt.is_constructor(sib->get_expr())) { - con = sib; - break; - } - } + euf::enode* con = get_constructor(n); TRACE("dt", display(tout) << g->bpp(n) << " con: " << g->bpp(con) << "\n";); if (!con) dep.insert(n, nullptr); @@ -304,16 +293,101 @@ namespace sls { void datatype_plugin::start_propagation() { m_values.reset(); } - - void datatype_plugin::propagate_literal(sat::literal lit) {} - bool datatype_plugin::propagate() { return false; } - bool datatype_plugin::is_sat() { return true; } - void datatype_plugin::register_term(expr* e) {} + + euf::enode* datatype_plugin::get_constructor(euf::enode* n) { + euf::enode* con = nullptr; + for (auto sib : euf::enode_class(n)) + if (dt.is_constructor(sib->get_expr())) + return sib; + return nullptr; + } + + bool datatype_plugin::propagate() { + enum color_t { white, grey, black }; + svector color; + unsigned_vector dfsnum; + svector> todo; + for (auto n : g->nodes()) { + if (!n->is_root()) + continue; + expr* e = n->get_expr(); + if (!dt.is_datatype(e)) + continue; + + auto c = color.get(e->get_id(), white); + SASSERT(c != grey); + if (c == black) + continue; + + dfsnum.setx(e->get_id(), 0, UINT_MAX); + + // dfs traversal of enodes, starting with n, + // with outgoing edges the arguments of con, where con + // is a node in the same congruence class as n that is a constructor. + // For every cycle accumulate a conflict. + + todo.push_back({ 0, n, 0 }); + while (!todo.empty()) { + auto [depth, n, parent_idx] = todo.back(); + unsigned id = n->get_root_id(); + c = color[id]; + euf::enode* con; + unsigned idx; + + switch (c) { + case black: + todo.pop_back(); + break; + case grey: + if (dfsnum.get(id, UINT_MAX) < depth) { + expr_ref_vector diseqs(m); + while (true) { + auto [depth2, n2, parent_idx2] = todo[parent_idx]; + auto con2 = get_constructor(n2); + if (n2 != con2) + diseqs.push_back(m.mk_not(m.mk_eq(n2->get_expr(), con2->get_expr()))); + parent_idx = parent_idx2; + if (n2->get_root() == n->get_root()) { + diseqs.push_back(m.mk_not(m.mk_eq(n->get_expr(), n2->get_expr()))); + break; + } + } + verbose_stream() << "cycle\n"; + for (auto e : diseqs) + verbose_stream() << mk_pp(e, m) << "\n"; + ctx.add_clause(m.mk_or(diseqs)); + return true; + } + color[id] = black; + todo.pop_back(); + break; + case white: + color[id] = grey; + dfsnum.setx(id, depth, UINT_MAX); + con = get_constructor(n); + idx = todo.size() - 1; + if (con) + for (auto child : euf::enode_args(con)) + if (color.get(child->get_root_id(), white) == white && dt.is_datatype(child->get_expr())) + todo.push_back({ depth + 1, child, idx }); + break; + } + } + } + return false; + } + std::ostream& datatype_plugin::display(std::ostream& out) const { for (auto a : m_axioms) out << mk_bounded_pp(a, m, 3) << "\n"; return out; } + + void datatype_plugin::propagate_literal(sat::literal lit) {} + + bool datatype_plugin::is_sat() { return true; } + void datatype_plugin::register_term(expr* e) {} + void datatype_plugin::mk_model(model& mdl) { } diff --git a/src/ast/sls/sls_datatype_plugin.h b/src/ast/sls/sls_datatype_plugin.h index 61cc2b938..bacfa9d63 100644 --- a/src/ast/sls/sls_datatype_plugin.h +++ b/src/ast/sls/sls_datatype_plugin.h @@ -48,6 +48,8 @@ namespace sls { void init_values(); void add_dep(euf::enode* n, top_sort& dep); + euf::enode* get_constructor(euf::enode* n); + public: datatype_plugin(context& c); ~datatype_plugin() override;