From bbf97c5e212a1ca36b89a239d02b1583a5a1acd8 Mon Sep 17 00:00:00 2001 From: Ilana Shapiro Date: Wed, 29 Oct 2025 18:42:23 -0700 Subject: [PATCH] fix some things for clause replay --- src/smt/smt_context.cpp | 3 +-- src/smt/smt_context.h | 4 ++-- src/smt/smt_internalizer.cpp | 23 ++++++++--------------- src/smt/smt_parallel.cpp | 19 +++++++++++++++++++ 4 files changed, 30 insertions(+), 19 deletions(-) diff --git a/src/smt/smt_context.cpp b/src/smt/smt_context.cpp index 46b630231..c0a2f3a4d 100644 --- a/src/smt/smt_context.cpp +++ b/src/smt/smt_context.cpp @@ -82,8 +82,7 @@ namespace smt { m_mk_bool_var_trail(*this), m_mk_enode_trail(*this), m_mk_lambda_trail(*this), - m_lemma_visitor(m), - m_recorded_clauses(m) { + m_lemma_visitor(m) { SASSERT(m_scope_lvl == 0); SASSERT(m_base_lvl == 0); diff --git a/src/smt/smt_context.h b/src/smt/smt_context.h index 328e7dc25..198c099a0 100644 --- a/src/smt/smt_context.h +++ b/src/smt/smt_context.h @@ -137,7 +137,7 @@ namespace smt { scoped_ptr m_fmls; svector m_lit_scores[2]; - expr_ref_vector m_recorded_clauses; + vector m_recorded_clauses; // ----------------------------------- @@ -1302,7 +1302,7 @@ namespace smt { void add_scores(unsigned n, literal const *lits); - void record_clause(clause const* cls); + void record_clause(unsigned n, literal const * lits); // ----------------------------------- diff --git a/src/smt/smt_internalizer.cpp b/src/smt/smt_internalizer.cpp index 90b49dcfa..be884fd18 100644 --- a/src/smt/smt_internalizer.cpp +++ b/src/smt/smt_internalizer.cpp @@ -936,7 +936,6 @@ namespace smt { m_lit_scores[0].reserve(v + 1); m_lit_scores[1].reserve(v + 1); m_lit_scores[0][v] = m_lit_scores[1][v] = 0.0; - m_recorded_clauses.reserve(v + 1); literal l(v, false); literal not_l(v, true); @@ -967,29 +966,23 @@ namespace smt { } // following the pattern of solver::persist_clause in src/sat/smt/user_solver.cpp - void context::record_clause(clause const* cls) { - expr_ref_vector clause(m); - for (unsigned i = 0; i < cls->get_num_literals(); ++i) { - literal lit = cls->get_literal(i); - clause.push_back(literal2expr(~lit)); - } - if (!clause.empty() && m.is_false(clause.back())) - clause.pop_back(); - expr_ref disj(m.mk_or(clause.size(), clause.data()), m); - m_recorded_clauses.push_back(disj); + void context::record_clause(unsigned num_lits, literal const *lits) { + literal_vector clause; + clause.append(num_lits, lits); + m_recorded_clauses.push_back(clause); } void context::add_scores(unsigned n, literal const *lits) { for (unsigned i = 0; i < n; ++i) { auto lit = lits[i]; - unsigned v = lit.var(); // unique key per literal - m_lit_scores[lit.sign()][v] += 1.0 / n; + unsigned v = lit.var(); // uniq0 / n; } } void context::undo_mk_bool_var() { - SASSERT(!m_b_internalized_stack.empty()); + SASSERT(!m_b_internalized_stack.empty(ue key per literal + m_lit_scores[lit.sign()][v] += 1.)); m_stats.m_num_del_bool_var++; expr * n = m_b_internalized_stack.back(); unsigned n_id = n->get_id(); @@ -1447,6 +1440,7 @@ namespace smt { case CLS_LEARNED: dump_lemma(num_lits, lits); add_scores(num_lits, lits); + record_clause(num_lits, lits); break; default: break; @@ -1506,7 +1500,6 @@ namespace smt { if (k == CLS_LEARNED) { int w2_idx = select_learned_watch_lit(cls); cls->swap_lits(1, w2_idx); - record_clause(cls); } else { SASSERT(k == CLS_TH_LEMMA); diff --git a/src/smt/smt_parallel.cpp b/src/smt/smt_parallel.cpp index e619a4107..a31927934 100644 --- a/src/smt/smt_parallel.cpp +++ b/src/smt/smt_parallel.cpp @@ -82,7 +82,26 @@ namespace smt { void parallel::param_generator::replay_proof_prefixes(unsigned max_conflicts_epsilon=200) { unsigned conflict_budget = m_max_prefix_conflicts + max_conflicts_epsilon; + // loop through m_param_probe_contexts + for (unsigned i = 0; i < m_param_probe_contexts.size(); ++i) { + IF_VERBOSE(1, verbose_stream() << " PARAM TUNER: replaying proof prefix in param probe context " << i << "\n"); + context *probe_ctx = m_param_probe_contexts[i]; + probe_ctx->get_fparams().m_max_conflicts = conflict_budget; + + for (auto const& clause : probe_ctx->m_recorded_clauses) { + expr_ref_vector negated_lits(probe_ctx->m); + for (literal lit : clause) { + expr* e = probe_ctx->bool_var2expr(lit.var()); + if (!e) continue; // skip if var not yet mapped + if (!lit.sign()) + e = probe_ctx->m.mk_not(e); // since bool_var2expr discards sign + negated_lits.push_back(e); + } + // Replay the negated clause + lbool r = probe_ctx->check(negated_lits.size(), negated_lits.data()); + } + } } void parallel::param_generator::protocol_iteration() {