From f8fa2de35b3bb79249c2de4ca17d248a212fdbb7 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Fri, 11 Oct 2024 09:54:46 -0700 Subject: [PATCH] add incremental mode --- src/ast/sls/sls_euf_plugin.cpp | 136 +++++++++++++++++++++++++++++---- src/ast/sls/sls_euf_plugin.h | 13 +++- src/params/sls_params.pyg | 1 + 3 files changed, 133 insertions(+), 17 deletions(-) diff --git a/src/ast/sls/sls_euf_plugin.cpp b/src/ast/sls/sls_euf_plugin.cpp index a85ef5e69..5114e3829 100644 --- a/src/ast/sls/sls_euf_plugin.cpp +++ b/src/ast/sls/sls_euf_plugin.cpp @@ -36,11 +36,15 @@ namespace sls { euf_plugin::~euf_plugin() {} - void euf_plugin::start_propagation() { - m_g = alloc(euf::egraph, m); - init_egraph(*m_g); + void euf_plugin::initialize() { + m_incremental = ctx.get_params().get_bool("euf_incremental", m_incremental); + IF_VERBOSE(2, verbose_stream() << "sls.euf: incremental " << m_incremental << < "\n"); } + void euf_plugin::start_propagation() { + m_g = alloc(euf::egraph, m); + init_egraph(*m_g, !m_incremental); + } void euf_plugin::register_term(expr* e) { if (!is_app(e)) @@ -71,7 +75,103 @@ namespace sls { return true; } + void euf_plugin::propagate_literal_incremental(sat::literal lit) { + m_replay_stack.push_back(lit); + replay(); + } + void euf_plugin::resolve() { + if (!g.inconsistent()) + return; + + unsigned n = 1; + sat::literal_vector lits; + sat::literal flit = sat::null_literal, slit; + ptr_vector explain; + g.begin_explain(); + g.explain(explain, nullptr); + g.end_explain(); + for (auto p : explain) { + sat::literal l = to_literal(p); + SASSERT(ctx.is_true(l)); + lits.push_back(~l); + if (ctx.rand(++n) == 0) + flit = l; + } + ctx.add_clause(lits); + if (flit == sat::null_literal) + return; + do { + slit = m_stack.back(); + g.pop(1); + m_replay_stack.push_back(slit); + m_stack.pop_back(); + } + while (slit != flit); + // flip the last literal on the replay stack + IF_VERBOSE(2, verbose_stream() << "sls.euf - flip " << flit << "\n"); + ctx.flip(flit.var()); + m_replay_stack.back().neg(); + } + + void euf_plugin::replay() { + while (!m_replay_stack.empty()) { + auto l = m_replay_stack.back(); + m_replay_stack.pop_back(); + propagate_literal_incremental_step(l); + if (g.inconsistent()) + resolve(); + } + } + + + void euf_plugin::propagate_literal_incremental_step(sat::literal lit) { + SASSERT(ctx.is_true(lit)); + auto e = ctx.atom(lit.var()); + expr* x, * y; + auto& g = *m_g; + + if (!e) + return; + + m_stack.push_back(lit); + g.push(); + if (!lit.sign() && m.is_eq(e, x, y)) { + auto a = g.find(x); + auto b = g.find(y); + g.merge(a, b, to_ptr(lit)); + } + else if (!lit.sign() && m.is_distinct(e)) { + auto n = to_app(e)->get_num_args(); + for (unsigned i = 0; i < n; ++i) { + expr* a = to_app(e)->get_arg(i); + for (unsigned j = i + 1; j < n; ++j) { + auto b = to_app(e)->get_arg(j); + expr_ref eq(m.mk_eq(a, b), m); + auto c = g.find(eq); + if (!g.find(eq)) { + enode* args[2] = { g.find(a), g.find(b) }; + c = g.mk(eq, 2, args, nullptr); + } + g.merge(c, g.find(m.mk_false()), to_ptr(lit)); + } + } + } + else { + auto a = g.find(e); + auto b = g.find(m.mk_bool_val(!lit.sign())); + g.merge(a, b, to_ptr(lit)); + } + g.propagate(); + } + void euf_plugin::propagate_literal(sat::literal lit) { + if (m_incremental) + propagate_literal_incremental(lit); + else + propagate_literal_non_incremental(lit); + } + + void euf_plugin::propagate_literal_non_incremental(sat::literal lit) { SASSERT(ctx.is_true(lit)); auto e = ctx.atom(lit.var()); expr* x, * y; @@ -126,8 +226,9 @@ namespace sls { } } - void euf_plugin::init_egraph(euf::egraph& g) { + void euf_plugin::init_egraph(euf::egraph& g, bool merge_eqs) { ptr_vector args; + m_stack.reset(); for (auto t : ctx.subterms()) { args.reset(); if (is_app(t)) @@ -139,17 +240,22 @@ namespace sls { g.mk(m.mk_true(), 0, 0, nullptr); if (!g.find(m.mk_false())) g.mk(m.mk_false(), 0, 0, nullptr); - for (auto lit : ctx.root_literals()) { - if (!ctx.is_true(lit)) - lit.neg(); - auto e = ctx.atom(lit.var()); - expr* x, * y; - if (e && m.is_eq(e, x, y) && !lit.sign()) - g.merge(g.find(x), g.find(y), to_ptr(lit)); - else if (!lit.sign()) - g.merge(g.find(e), g.find(m.mk_true()), to_ptr(lit)); + + // merge all equalities + // check for conflict with disequalities during propagation + if (merge_eqs) { + for (auto lit : ctx.root_literals()) { + if (!ctx.is_true(lit)) + lit.neg(); + auto e = ctx.atom(lit.var()); + expr* x, * y; + if (e && m.is_eq(e, x, y) && !lit.sign()) + g.merge(g.find(x), g.find(y), to_ptr(lit)); + else if (!lit.sign()) + g.merge(g.find(e), g.find(m.mk_true()), to_ptr(lit)); + } + g.propagate(); } - g.propagate(); typedef obj_map map1; typedef obj_map map2; @@ -177,7 +283,7 @@ namespace sls { if (!m_g) { m_g = alloc(euf::egraph, m); - init_egraph(*m_g); + init_egraph(*m_g, true); } auto n = m_g->find(e)->get_root(); VERIFY(m_root2value->find(n, e)); diff --git a/src/ast/sls/sls_euf_plugin.h b/src/ast/sls/sls_euf_plugin.h index 39a6dd754..406243ea4 100644 --- a/src/ast/sls/sls_euf_plugin.h +++ b/src/ast/sls/sls_euf_plugin.h @@ -36,12 +36,21 @@ namespace sls { }; hashtable m_values; + bool m_incremental = false; + scoped_ptr m_g; scoped_ptr> m_num_elems; scoped_ptr> m_root2value; scoped_ptr m_pinned; - void init_egraph(euf::egraph& g); + void init_egraph(euf::egraph& g, bool merge_eqs); + sat::literal_vector m_stack, m_replay_stack; + void propagate_literal_incremental(sat::literal lit); + void propagate_literal_incremental_step(sat::literal lit); + void resolve(); + void replay(); + + void propagate_literal_non_incremental(sat::literal lit); bool is_user_sort(sort* s) { return s->get_family_id() == user_sort_family_id; } size_t* to_ptr(sat::literal l) { return reinterpret_cast((size_t)(l.index() << 4)); }; @@ -52,7 +61,7 @@ namespace sls { ~euf_plugin() override; family_id fid() { return m_fid; } expr_ref get_value(expr* e) override; - void initialize() override {} + void initialize() override; void start_propagation() override; void propagate_literal(sat::literal lit) override; bool propagate() override; diff --git a/src/params/sls_params.pyg b/src/params/sls_params.pyg index 18b8d3371..708041db0 100644 --- a/src/params/sls_params.pyg +++ b/src/params/sls_params.pyg @@ -22,6 +22,7 @@ def_module_params('sls', ('early_prune', BOOL, 1, 'use early pruning for score prediction'), ('random_offset', BOOL, 1, 'use random offset for candidate evaluation'), ('rescore', BOOL, 1, 'rescore/normalize top-level score every base restart interval'), + ('euf_incremental', BOOL, False, 'use incremental EUF resolver'), ('track_unsat', BOOL, 0, 'keep a list of unsat assertions as done in SAT - currently disabled internally'), ('random_seed', UINT, 0, 'random seed') ))