From 554b325124058d823313ade1aafbce8ff5e07914 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Sun, 8 Sep 2024 17:56:57 -0700 Subject: [PATCH] replace user plugin by euf plugin Signed-off-by: Nikolaj Bjorner --- src/ast/sls/.#sls_user_sort_plugin.cpp | 1 + src/ast/sls/CMakeLists.txt | 1 - src/ast/sls/sls_context.cpp | 37 +++++-- src/ast/sls/sls_context.h | 1 + src/ast/sls/sls_euf_plugin.cpp | 135 +++++++++++++++++++++---- src/ast/sls/sls_euf_plugin.h | 14 +++ 6 files changed, 161 insertions(+), 28 deletions(-) create mode 100644 src/ast/sls/.#sls_user_sort_plugin.cpp diff --git a/src/ast/sls/.#sls_user_sort_plugin.cpp b/src/ast/sls/.#sls_user_sort_plugin.cpp new file mode 100644 index 000000000..48891894f --- /dev/null +++ b/src/ast/sls/.#sls_user_sort_plugin.cpp @@ -0,0 +1 @@ +nbjorner@LAPTOP-04AEAFKH.15652:1725565681 \ No newline at end of file diff --git a/src/ast/sls/CMakeLists.txt b/src/ast/sls/CMakeLists.txt index c2a1376ee..519497e9f 100644 --- a/src/ast/sls/CMakeLists.txt +++ b/src/ast/sls/CMakeLists.txt @@ -15,7 +15,6 @@ z3_add_component(ast_sls sls_context.cpp sls_euf_plugin.cpp sls_smt_solver.cpp - sls_user_sort_plugin.cpp COMPONENT_DEPENDENCIES ast euf diff --git a/src/ast/sls/sls_context.cpp b/src/ast/sls/sls_context.cpp index 2ed898add..610841087 100644 --- a/src/ast/sls/sls_context.cpp +++ b/src/ast/sls/sls_context.cpp @@ -21,8 +21,6 @@ Author: #include "ast/sls/sls_array_plugin.h" #include "ast/sls/sls_bv_plugin.h" #include "ast/sls/sls_basic_plugin.h" -#include "ast/sls/sls_model_value_plugin.h" -#include "ast/sls/sls_user_sort_plugin.h" #include "ast/ast_ll_pp.h" #include "ast/ast_pp.h" #include "smt/params/smt_params_helper.hpp" @@ -41,13 +39,6 @@ namespace sls { m_repair_down(m.get_num_asts(), m_gd), m_repair_up(m.get_num_asts(), m_ld), m_todo(m) { - register_plugin(alloc(euf_plugin, *this)); - register_plugin(alloc(arith_plugin, *this)); - register_plugin(alloc(bv_plugin, *this)); - register_plugin(alloc(basic_plugin, *this)); - register_plugin(alloc(array_plugin, *this)); - register_plugin(alloc(user_sort_plugin, *this)); - register_plugin(alloc(model_value_plugin, *this)); } void context::updt_params(params_ref const& p) { @@ -60,6 +51,27 @@ namespace sls { m_plugins.set(p->fid(), p); } + void context::ensure_plugin(expr* e) { + auto fid = get_fid(e); + if (m_plugins.get(fid, nullptr)) + return; + else if (fid == arith_family_id) + register_plugin(alloc(arith_plugin, *this)); + else if (fid == user_sort_family_id) + register_plugin(alloc(euf_plugin, *this)); + else if (fid == basic_family_id) + register_plugin(alloc(basic_plugin, *this)); + else if (fid == bv_util(m).get_family_id()) + register_plugin(alloc(bv_plugin, *this)); + else if (fid == array_util(m).get_family_id()) + register_plugin(alloc(array_plugin, *this)); + else + verbose_stream() << "did not find plugin for " << mk_bounded_pp(e, m) << "\n"; + + // add arrays and bv dynamically too. + } + + void context::register_atom(sat::bool_var v, expr* e) { m_atoms.setx(v, e); m_atom2bool_var.setx(e->get_id(), v, sat::null_bool_var); @@ -177,12 +189,14 @@ namespace sls { family_id context::get_fid(expr* e) const { if (!is_app(e)) - return null_family_id; + return user_sort_family_id; family_id fid = to_app(e)->get_family_id(); if (m.is_eq(e)) fid = to_app(e)->get_arg(0)->get_sort()->get_family_id(); if (m.is_distinct(e)) fid = to_app(e)->get_arg(0)->get_sort()->get_family_id(); + if (fid == null_family_id || fid == model_value_family_id) + fid = user_sort_family_id; return fid; } @@ -196,6 +210,8 @@ namespace sls { auto p = m_plugins.get(fid, nullptr); if (p) p->propagate_literal(lit); + if (!is_true(lit)) + m_new_constraint = true; } bool context::is_true(expr* e) { @@ -429,6 +445,7 @@ namespace sls { auto visit = [&](expr* e) { m_allterms.setx(e->get_id(), e); + ensure_plugin(e); }; if (is_visited(e)) return; diff --git a/src/ast/sls/sls_context.h b/src/ast/sls/sls_context.h index 05b19bbd3..5cdbdd3f7 100644 --- a/src/ast/sls/sls_context.h +++ b/src/ast/sls/sls_context.h @@ -134,6 +134,7 @@ namespace sls { void propagate_literal(sat::literal lit); void repair_literals(); + void ensure_plugin(expr* e); family_id get_fid(expr* e) const; diff --git a/src/ast/sls/sls_euf_plugin.cpp b/src/ast/sls/sls_euf_plugin.cpp index 84b87de4d..2d36611ab 100644 --- a/src/ast/sls/sls_euf_plugin.cpp +++ b/src/ast/sls/sls_euf_plugin.cpp @@ -25,16 +25,17 @@ namespace sls { euf_plugin::euf_plugin(context& c): plugin(c), m_values(8U, value_hash(*this), value_eq(*this)) { - m_fid = m.mk_family_id("cc"); + m_fid = user_sort_family_id; } euf_plugin::~euf_plugin() {} - - expr_ref euf_plugin::get_value(expr* e) { - UNREACHABLE(); - return expr_ref(m); + + void euf_plugin::start_propagation() { + m_g = alloc(euf::egraph, m); + init_egraph(*m_g); } + void euf_plugin::register_term(expr* e) { if (!is_app(e)) return; @@ -64,22 +65,120 @@ namespace sls { return true; } - void euf_plugin::propagate_literal(sat::literal lit) { - if (!ctx.is_true(lit)) - return; + void euf_plugin::propagate_literal(sat::literal lit) { + SASSERT(ctx.is_true(lit)); auto e = ctx.atom(lit.var()); expr* x, * y; - if (e && m.is_eq(e, x, y) && m.is_uninterp(x->get_sort())) { - auto vx = ctx.get_value(x); - auto vy = ctx.get_value(y); - verbose_stream() << "check " << mk_bounded_pp(x, m) << " == " << mk_bounded_pp(y, m) << "\n"; - if (lit.sign() && vx == vy) - ctx.flip(lit.var()); - else if (!lit.sign() && vx != vy) - ctx.flip(lit.var()); + + if (!e) + return; + + auto block = [&](euf::enode* a, euf::enode* b) { + if (a->get_root() != b->get_root()) + return; + ptr_vector explain; + m_g->explain_eq(explain, nullptr, a, b); + m_g->end_explain(); + unsigned n = 1; + sat::literal_vector lits; + sat::literal flit = sat::null_literal; + if (!ctx.is_unit(lit)) { + flit = lit; + lits.push_back(~lit); + } + for (auto p : explain) { + sat::literal l = to_literal(p); + if (!ctx.is_true(l)) + return; + if (ctx.is_unit(l)) + continue; + lits.push_back(~l); + if (ctx.rand(++n) == 0) + flit = l; + } + ctx.add_clause(lits); + if (flit != sat::null_literal) + ctx.flip(flit.var()); + }; + + if (lit.sign() && m.is_eq(e, x, y)) + block(m_g->find(x), m_g->find(y)); + else if (!lit.sign() && m.is_distinct(e)) { + auto n = to_app(e)->get_num_args(); + for (unsigned i = 0; i < n; ++i) { + auto a = m_g->find(to_app(e)->get_arg(i)); + for (unsigned j = i + 1; j < n; ++j) { + auto b = m_g->find(to_app(e)->get_arg(j)); + block(a, b); + } + } + } + else if (lit.sign()) { + auto a = m_g->find(e); + auto b = m_g->find(m.mk_true()); + block(a, b); } } + void euf_plugin::init_egraph(euf::egraph& g) { + ptr_vector args; + for (auto t : ctx.subterms()) { + args.reset(); + if (is_app(t)) + for (auto* arg : *to_app(t)) + args.push_back(g.find(arg)); + g.mk(t, 0, args.size(), args.data()); + } + if (!g.find(m.mk_true())) + g.mk(m.mk_true(), 0, 0, nullptr); + if (!g.find(m.mk_false())) + g.mk(m.mk_false(), 0, 0, nullptr); + for (auto lit : ctx.root_literals()) { + if (!ctx.is_true(lit)) + lit.neg(); + auto e = ctx.atom(lit.var()); + expr* x, * y; + if (e && m.is_eq(e, x, y) && !lit.sign()) + g.merge(g.find(x), g.find(y), to_ptr(lit)); + else if (!lit.sign()) + g.merge(g.find(e), g.find(m.mk_true()), to_ptr(lit)); + } + g.propagate(); + + typedef obj_map map1; + typedef obj_map map2; + + m_num_elems = alloc(map1); + m_root2value = alloc(map2); + m_pinned = alloc(expr_ref_vector, m); + + for (auto n : g.nodes()) { + if (n->is_root() && is_user_sort(n->get_sort())) { + // verbose_stream() << "init root " << g.pp(n) << "\n"; + unsigned num = 0; + m_num_elems->find(n->get_sort(), num); + expr* v = m.mk_model_value(num, n->get_sort()); + m_pinned->push_back(v); + m_root2value->insert(n, v); + m_num_elems->insert(n->get_sort(), num + 1); + } + } + } + + expr_ref euf_plugin::get_value(expr* e) { + if (m.is_model_value(e)) + return expr_ref(e, m); + + if (!m_g) { + m_g = alloc(euf::egraph, m); + init_egraph(*m_g); + } + auto n = m_g->find(e)->get_root(); + VERIFY(m_root2value->find(n, e)); + return expr_ref(e, m); + } + + bool euf_plugin::is_sat() { for (auto& [f, ts] : m_app) { if (ts.size() <= 1) @@ -101,7 +200,6 @@ namespace sls { } bool euf_plugin::propagate() { - return false; bool new_constraint = false; for (auto & [f, ts] : m_app) { if (ts.size() <= 1) @@ -135,6 +233,9 @@ namespace sls { } std::ostream& euf_plugin::display(std::ostream& out) const { + if (m_g) + m_g->display(out); + for (auto& [f, ts] : m_app) { for (auto* t : ts) out << mk_bounded_pp(t, m) << "\n"; diff --git a/src/ast/sls/sls_euf_plugin.h b/src/ast/sls/sls_euf_plugin.h index 03a3166aa..39a6dd754 100644 --- a/src/ast/sls/sls_euf_plugin.h +++ b/src/ast/sls/sls_euf_plugin.h @@ -18,6 +18,7 @@ Author: #include "util/hashtable.h" #include "ast/sls/sls_context.h" +#include "ast/euf/euf_egraph.h" namespace sls { @@ -34,12 +35,25 @@ namespace sls { bool operator()(app* a, app* b) const; }; hashtable m_values; + + scoped_ptr m_g; + scoped_ptr> m_num_elems; + scoped_ptr> m_root2value; + scoped_ptr m_pinned; + + void init_egraph(euf::egraph& g); + bool is_user_sort(sort* s) { return s->get_family_id() == user_sort_family_id; } + + size_t* to_ptr(sat::literal l) { return reinterpret_cast((size_t)(l.index() << 4)); }; + sat::literal to_literal(size_t* p) { return sat::to_literal(static_cast(reinterpret_cast(p) >> 4)); }; + public: euf_plugin(context& c); ~euf_plugin() override; family_id fid() { return m_fid; } expr_ref get_value(expr* e) override; void initialize() override {} + void start_propagation() override; void propagate_literal(sat::literal lit) override; bool propagate() override; bool is_sat() override;