diff --git a/src/sat/smt/user_solver.cpp b/src/sat/smt/user_solver.cpp index 01c1786ec..0204573dc 100644 --- a/src/sat/smt/user_solver.cpp +++ b/src/sat/smt/user_solver.cpp @@ -178,8 +178,10 @@ namespace user_solver { void solver::propagate_consequence(prop_info const& prop) { sat::literal lit = ctx.internalize(prop.m_conseq, false, false); if (s().value(lit) != l_true) { - s().assign(lit, mk_justification(m_qhead)); + auto j = mk_justification(m_qhead); + s().assign(lit, j); ++m_stats.m_num_propagations; + persist_clause(lit, j); } } @@ -188,9 +190,17 @@ namespace user_solver { } bool solver::unit_propagate() { - if (m_qhead == m_prop.size()) + if (m_qhead == m_prop.size() && m_replay_qhead == m_clauses_to_replay.size()) return false; force_push(); + + bool replayed = false; + if (m_replay_qhead < m_clauses_to_replay.size()) { + replayed = true; + ctx.push(value_trail(m_replay_qhead)); + for (; m_replay_qhead < m_clauses_to_replay.size(); ++m_replay_qhead) + replay_clause(m_clauses_to_replay.get(m_replay_qhead)); + } ctx.push(value_trail(m_qhead)); unsigned np = m_stats.m_num_propagations; for (; m_qhead < m_prop.size() && !s().inconsistent(); ++m_qhead) { @@ -200,7 +210,37 @@ namespace user_solver { else propagate_new_fixed(prop); } - return np < m_stats.m_num_propagations; + return np < m_stats.m_num_propagations || replayed; + } + + void solver::replay_clause(expr_ref_vector const& clause) { + sat::literal_vector lits; + for (expr* e : clause) + lits.push_back(ctx.mk_literal(e)); + add_clause(lits); + } + + void solver::persist_clause(sat::literal lit, sat::justification const& sj) { + if (!ctx.get_config().m_up_persist_clauses) + return; + + expr_ref_vector clause(m); + auto idx = sj.get_ext_justification_idx(); + auto& j = justification::from_index(idx); + auto const& prop = m_prop[j.m_propagation_index]; + sat::literal_vector r; + for (unsigned id : prop.m_ids) + r.append(m_id2justification[id]); + for (auto lit : r) + clause.push_back(ctx.literal2expr(~lit)); + for (auto const& [a,b] : prop.m_eqs) + clause.push_back(m.mk_not(m.mk_eq(a, b))); + clause.push_back(ctx.literal2expr(lit)); + + m_clauses_to_replay.push_back(clause); + if (m_replay_qhead + 1 < m_clauses_to_replay.size()) + std::swap(m_clauses_to_replay[m_replay_qhead], m_clauses_to_replay[m_clauses_to_replay.size()-1]); + ++m_replay_qhead; } void solver::collect_statistics(::statistics& st) const { diff --git a/src/sat/smt/user_solver.h b/src/sat/smt/user_solver.h index cd94441ea..373b046b8 100644 --- a/src/sat/smt/user_solver.h +++ b/src/sat/smt/user_solver.h @@ -77,6 +77,8 @@ namespace user_solver { stats m_stats; sat::bool_var m_next_split_var = sat::null_bool_var; lbool m_next_split_phase = l_undef; + vector m_clauses_to_replay; + unsigned m_replay_qhead = 0; struct justification { unsigned m_propagation_index { 0 }; @@ -105,6 +107,9 @@ namespace user_solver { sat::bool_var enode_to_bool(euf::enode* n, unsigned idx); + void replay_clause(expr_ref_vector const& clause); + void persist_clause(sat::literal lit, sat::justification const& j); + public: solver(euf::solver& ctx);