From a24b94828c314ce8b4d1d06add5eb7cdcb3be047 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Sun, 8 Sep 2024 13:31:02 -0700 Subject: [PATCH] Enhance array plugin with early termination and propagation verification, and improve euf and user sort plugins with propagation adjustments and debugging enhancements --- src/ast/sls/sls_array_plugin.cpp | 14 ++- src/ast/sls/sls_array_plugin.h | 3 +- src/ast/sls/sls_context.cpp | 29 ++++-- src/ast/sls/sls_context.h | 3 + src/ast/sls/sls_euf_plugin.cpp | 19 +++- src/ast/sls/sls_euf_plugin.h | 2 +- src/ast/sls/sls_user_sort_plugin.cpp | 136 ++++++++++++++++++++++++--- src/ast/sls/sls_user_sort_plugin.h | 8 +- 8 files changed, 189 insertions(+), 25 deletions(-) diff --git a/src/ast/sls/sls_array_plugin.cpp b/src/ast/sls/sls_array_plugin.cpp index 89590590b..930e8f9d8 100644 --- a/src/ast/sls/sls_array_plugin.cpp +++ b/src/ast/sls/sls_array_plugin.cpp @@ -30,6 +30,8 @@ namespace sls { } bool array_plugin::is_sat() { + if (!m_has_arrays) + return true; m_g = alloc(euf::egraph, m); m_kv = nullptr; init_egraph(*m_g); @@ -88,8 +90,10 @@ namespace sls { auto nsel = mk_select(g, n, n); if (are_distinct(nsel, val)) add_store_axiom1(n->get_app()); - else + else { g.merge(nsel, val, nullptr); + VERIFY(g.propagate()); + } } // i /~ j, b ~ a[i->v], b[j] occurs -> a[j] = b[j] @@ -103,8 +107,10 @@ namespace sls { auto nsel = mk_select(g, sto->get_arg(0), sel); if (are_distinct(nsel, sel)) add_store_axiom2(sto->get_app(), sel->get_app()); - else + else { g.merge(nsel, sel, nullptr); + VERIFY(g.propagate()); + } } // a ~ b, i /~ j, b[j] occurs -> a[i -> v][j] = b[j] @@ -118,8 +124,10 @@ namespace sls { auto nsel = mk_select(g, sto, sel); if (are_distinct(nsel, sel)) add_store_axiom2(sto->get_app(), sel->get_app()); - else + else { g.merge(nsel, sel, nullptr); + VERIFY(g.propagate()); + } } bool array_plugin::are_distinct(euf::enode* a, euf::enode* b) { diff --git a/src/ast/sls/sls_array_plugin.h b/src/ast/sls/sls_array_plugin.h index ac04889ca..d6557f245 100644 --- a/src/ast/sls/sls_array_plugin.h +++ b/src/ast/sls/sls_array_plugin.h @@ -52,6 +52,7 @@ namespace sls { scoped_ptr m_g; scoped_ptr m_kv; bool m_add_conflicts = true; + bool m_has_arrays = false; void init_egraph(euf::egraph& g); void init_kv(euf::egraph& g, kv& kv); @@ -68,7 +69,7 @@ namespace sls { public: array_plugin(context& ctx); ~array_plugin() override {} - void register_term(expr* e) override { } + void register_term(expr* e) override { if (a.is_array(e->get_sort())) m_has_arrays = true; } expr_ref get_value(expr* e) override; void initialize() override { m_g = nullptr; } void propagate_literal(sat::literal lit) override { m_g = nullptr; } diff --git a/src/ast/sls/sls_context.cpp b/src/ast/sls/sls_context.cpp index 3c2c0ce6e..90755652a 100644 --- a/src/ast/sls/sls_context.cpp +++ b/src/ast/sls/sls_context.cpp @@ -85,7 +85,7 @@ namespace sls { propagate_boolean_assignment(); - // verbose_stream() << "propagate " << unsat().size() << " " << m_new_constraint << "\n"; + // verbose_stream() << "propagate " << unsat().size() << " " << m_new_constraint << "\n"; // display(verbose_stream()); @@ -113,11 +113,16 @@ namespace sls { void context::propagate_boolean_assignment() { reinit_relevant(); - for (sat::literal lit : root_literals()) { + for (auto p : m_plugins) + if (p) + p->start_propagation(); + + for (sat::literal lit : root_literals()) propagate_literal(lit); - if (m_new_constraint) - return; - } + + if (m_new_constraint) + return; + while (!m_new_constraint && m.inc() && (!m_repair_up.empty() || !m_repair_down.empty())) { while (!m_repair_down.empty() && !m_new_constraint && m.inc()) { @@ -230,6 +235,8 @@ namespace sls { if (m_visited.contains(id)) return false; m_visited.insert(id); + if (m_parents.size() <= id) + verbose_stream() << "not in map " << mk_bounded_pp(e, m) << "\n"; for (auto p : m_parents[id]) { if (is_relevant(p)) { m_relevant.insert(id); @@ -242,6 +249,7 @@ namespace sls { void context::add_constraint(expr* e) { add_clause(e); m_new_constraint = true; + ++m_stats.m_num_constraints; } void context::add_clause(expr* f) { @@ -303,6 +311,13 @@ namespace sls { } } + void context::add_clause(sat::literal_vector const& lits) { + //verbose_stream() << lits << "\n"; + s.add_clause(lits.size(), lits.data()); + m_new_constraint = true; + ++m_stats.m_num_constraints; + } + sat::literal context::mk_literal() { sat::bool_var v = s.add_var(); return sat::literal(v, false); @@ -427,6 +442,7 @@ namespace sls { if (all_of(*to_app(e), [&](expr* arg) { return is_visited(arg); })) { expr_ref _e(e, m); m_todo.pop_back(); + m_parents.reserve(to_app(e)->get_id() + 1); for (expr* arg : *to_app(e)) { m_parents.reserve(arg->get_id() + 1); m_parents[arg->get_id()].push_back(e); @@ -488,7 +504,7 @@ namespace sls { if (e) m_subterms.push_back(e); std::stable_sort(m_subterms.begin(), m_subterms.end(), - [](expr* a, expr* b) { return a->get_id() < b->get_id(); }); + [](expr* a, expr* b) { return get_depth(a) < get_depth(b); }); return m_subterms; } @@ -547,6 +563,7 @@ namespace sls { p->collect_statistics(st); st.update("sls-repair-down", m_stats.m_num_repair_down); st.update("sls-repair-up", m_stats.m_num_repair_up); + st.update("sls-constraints", m_stats.m_num_constraints); } void context::reset_statistics() { diff --git a/src/ast/sls/sls_context.h b/src/ast/sls/sls_context.h index 4df41b315..05b19bbd3 100644 --- a/src/ast/sls/sls_context.h +++ b/src/ast/sls/sls_context.h @@ -41,6 +41,7 @@ namespace sls { virtual void register_term(expr* e) = 0; virtual expr_ref get_value(expr* e) = 0; virtual void initialize() = 0; + virtual void start_propagation() {}; virtual bool propagate() = 0; virtual void propagate_literal(sat::literal lit) = 0; virtual void repair_literal(sat::literal lit) = 0; @@ -97,6 +98,7 @@ namespace sls { struct stats { unsigned m_num_repair_down = 0; unsigned m_num_repair_up = 0; + unsigned m_num_constraints = 0; void reset() { memset(this, 0, sizeof(*this)); } }; @@ -160,6 +162,7 @@ namespace sls { sat::bool_var atom2bool_var(expr* e) const { return m_atom2bool_var.get(e->get_id(), sat::null_bool_var); } sat::literal mk_literal(expr* e); void add_clause(expr* f); + void add_clause(sat::literal_vector const& lits); void flip(sat::bool_var v) { s.flip(v); } double reward(sat::bool_var v) { return s.reward(v); } indexed_uint_set const& unsat() const { return s.unsat(); } diff --git a/src/ast/sls/sls_euf_plugin.cpp b/src/ast/sls/sls_euf_plugin.cpp index 61e2d747b..84b87de4d 100644 --- a/src/ast/sls/sls_euf_plugin.cpp +++ b/src/ast/sls/sls_euf_plugin.cpp @@ -64,6 +64,22 @@ namespace sls { return true; } + void euf_plugin::propagate_literal(sat::literal lit) { + if (!ctx.is_true(lit)) + return; + auto e = ctx.atom(lit.var()); + expr* x, * y; + if (e && m.is_eq(e, x, y) && m.is_uninterp(x->get_sort())) { + auto vx = ctx.get_value(x); + auto vy = ctx.get_value(y); + verbose_stream() << "check " << mk_bounded_pp(x, m) << " == " << mk_bounded_pp(y, m) << "\n"; + if (lit.sign() && vx == vy) + ctx.flip(lit.var()); + else if (!lit.sign() && vx != vy) + ctx.flip(lit.var()); + } + } + bool euf_plugin::is_sat() { for (auto& [f, ts] : m_app) { if (ts.size() <= 1) @@ -84,7 +100,8 @@ namespace sls { return true; } - bool euf_plugin::propagate() { + bool euf_plugin::propagate() { + return false; bool new_constraint = false; for (auto & [f, ts] : m_app) { if (ts.size() <= 1) diff --git a/src/ast/sls/sls_euf_plugin.h b/src/ast/sls/sls_euf_plugin.h index fd0aa7266..03a3166aa 100644 --- a/src/ast/sls/sls_euf_plugin.h +++ b/src/ast/sls/sls_euf_plugin.h @@ -40,7 +40,7 @@ namespace sls { family_id fid() { return m_fid; } expr_ref get_value(expr* e) override; void initialize() override {} - void propagate_literal(sat::literal lit) override {} + void propagate_literal(sat::literal lit) override; bool propagate() override; bool is_sat() override; void register_term(expr* e) override; diff --git a/src/ast/sls/sls_user_sort_plugin.cpp b/src/ast/sls/sls_user_sort_plugin.cpp index f731eb007..018a5822a 100644 --- a/src/ast/sls/sls_user_sort_plugin.cpp +++ b/src/ast/sls/sls_user_sort_plugin.cpp @@ -28,28 +28,93 @@ namespace sls { m_fid = user_sort_family_id; } + void user_sort_plugin::start_propagation() { + m_g = alloc(euf::egraph, m); + init_egraph(*m_g); + } + + void user_sort_plugin::propagate_literal(sat::literal lit) { + SASSERT(ctx.is_true(lit)); + auto e = ctx.atom(lit.var()); + expr* x, * y; + + auto block = [&](euf::enode* a, euf::enode* b) { + ptr_vector explain; + m_g->explain_eq(explain, nullptr, a, b); + m_g->end_explain(); + unsigned n = 1; + sat::literal_vector lits; + lits.push_back(~lit); + sat::literal flit = lit; + for (auto p : explain) { + sat::literal l = to_literal(p); + if (!ctx.is_true(l)) + return sat::null_literal; + if (ctx.is_unit(l)) + continue; + lits.push_back(~l); + if (ctx.rand(++n) == 0) + flit = l; + } + ctx.add_clause(lits); + return flit; + }; + + if (e && m.is_eq(e, x, y) && m.is_uninterp(x->get_sort())) { + auto vx = get_value(x); + auto vy = get_value(y); + bool should_flip = lit.sign() ? vx == vy : vx != vy; + if (should_flip) { + sat::literal flit = lit; + + if (lit.sign()) { + auto a = m_g->find(x); + auto b = m_g->find(y); + flit = block(a, b); + } + + if (flit != sat::null_literal) + ctx.flip(flit.var()); + } + } + else if (e && lit.sign()) { + auto a = m_g->find(e); + auto b = m_g->find(m.mk_true()); + + if (a->get_root() == b->get_root()) { + verbose_stream() << "block " << lit << "\n"; + auto flit = block(a, b); + if (flit != sat::null_literal) + ctx.flip(flit.var()); + } + } + } + void user_sort_plugin::init_egraph(euf::egraph& g) { ptr_vector args; for (auto t : ctx.subterms()) { args.reset(); - if (is_app(t)) { - for (auto* arg : *to_app(t)) { - args.push_back(g.find(arg)); - } - } + if (is_app(t)) + for (auto* arg : *to_app(t)) + args.push_back(g.find(arg)); g.mk(t, 0, args.size(), args.data()); } + if (!g.find(m.mk_true())) + g.mk(m.mk_true(), 0, 0, nullptr); + if (!g.find(m.mk_false())) + g.mk(m.mk_false(), 0, 0, nullptr); for (auto lit : ctx.root_literals()) { - if (!ctx.is_true(lit) || lit.sign()) - continue; + if (!ctx.is_true(lit)) + lit.neg(); auto e = ctx.atom(lit.var()); expr* x, * y; - if (e && m.is_eq(e, x, y)) - g.merge(g.find(x), g.find(y), nullptr); + if (e && m.is_eq(e, x, y) && !lit.sign()) + g.merge(g.find(x), g.find(y), to_ptr(lit)); + else if (!lit.sign()) + g.merge(g.find(e), g.find(m.mk_true()), to_ptr(lit)); } - display(verbose_stream()); - + g.propagate(); typedef obj_map map1; typedef obj_map map2; @@ -60,6 +125,7 @@ namespace sls { for (auto n : g.nodes()) { if (n->is_root() && is_user_sort(n->get_sort())) { + // verbose_stream() << "init root " << g.pp(n) << "\n"; unsigned num = 0; m_num_elems->find(n->get_sort(), num); expr* v = m.mk_model_value(num, n->get_sort()); @@ -89,4 +155,52 @@ namespace sls { m_g->display(out); return out; } + + bool user_sort_plugin::is_sat() { + return true; + bool flipped = false; + ptr_vector args; + euf::egraph g(m); + auto assert_lit = [&](auto lit) { + auto e = ctx.atom(lit.var()); + expr* x, * y; + if (e && m.is_eq(e, x, y) && !lit.sign()) + g.merge(g.find(x), g.find(y), nullptr); + else if (e && m.is_eq(e) && lit.sign()) + g.merge(g.find(e), g.find(m.mk_false()), nullptr); + else + g.merge(g.find(e), g.find(m.mk_bool_val(!lit.sign())), nullptr); + g.propagate(); + }; + g.mk(m.mk_false(), 0, 0, nullptr); + g.mk(m.mk_true(), 0, 0, nullptr); + for (auto t : ctx.subterms()) { + if (g.find(t)) + continue; + args.reset(); + if (is_app(t)) { + for (auto* arg : *to_app(t)) { + args.push_back(g.find(arg)); + } + } + g.mk(t, 0, args.size(), args.data()); + } + for (auto lit : ctx.root_literals()) { + if (!ctx.is_true(lit)) + lit.neg(); + + g.push(); + assert_lit(lit); + bool is_unsat = g.inconsistent(); + g.pop(1); + if (is_unsat) { + ctx.flip(lit.var()); + lit.neg(); + flipped = true; + } + assert_lit(lit); + + } + return !flipped; + } } diff --git a/src/ast/sls/sls_user_sort_plugin.h b/src/ast/sls/sls_user_sort_plugin.h index 121798bd2..da452c133 100644 --- a/src/ast/sls/sls_user_sort_plugin.h +++ b/src/ast/sls/sls_user_sort_plugin.h @@ -29,6 +29,9 @@ namespace sls { void init_egraph(euf::egraph& g); bool is_user_sort(sort* s) { return s->get_family_id() == user_sort_family_id; } + + size_t* to_ptr(sat::literal l) { return reinterpret_cast((size_t)(l.index() << 4)); }; + sat::literal to_literal(size_t* p) { return sat::to_literal(static_cast(reinterpret_cast(p) >> 4)); }; public: user_sort_plugin(context& ctx); @@ -36,12 +39,13 @@ namespace sls { void register_term(expr* e) override { } expr_ref get_value(expr* e) override; void initialize() override { m_g = nullptr; } - void propagate_literal(sat::literal lit) override { m_g = nullptr; } + void start_propagation() override; + void propagate_literal(sat::literal lit) override; bool propagate() override { return false; } bool repair_down(app* e) override { return true; } void repair_up(app* e) override {} void repair_literal(sat::literal lit) override { m_g = nullptr; } - bool is_sat() override { return true; } + bool is_sat() override; void on_rescale() override {} void on_restart() override {}