diff --git a/src/sat/smt/user_solver.cpp b/src/sat/smt/user_solver.cpp index 0204573dc..a4273744b 100644 --- a/src/sat/smt/user_solver.cpp +++ b/src/sat/smt/user_solver.cpp @@ -39,8 +39,10 @@ namespace user_solver { ctx.attach_th_var(n, this, v); expr_ref r(m); sat::literal_vector explain; - if (ctx.is_fixed(n, r, explain)) + if (ctx.is_fixed(n, r, explain)) { m_prop.push_back(prop_info(explain, v, r)); + DEBUG_CODE(for (auto lit : explain) VERIFY(s().value(lit) == l_true);); + } } bool solver::propagate_cb( @@ -90,8 +92,20 @@ namespace user_solver { if (!m_fixed_eh) return; force_push(); + if (m_fixed.contains(v)) + return; + m_fixed.insert(v); + ctx.push(insert_map(m_fixed, v)); m_id2justification.setx(v, sat::literal_vector(num_lits, jlits), sat::literal_vector()); - m_fixed_eh(m_user_context, this, var2expr(v), value); + for (unsigned i = 0; i < num_lits; ++i) + if (s().value(m_id2justification[v][i]) == l_false) + m_id2justification[v][i].neg(); + try { + m_fixed_eh(m_user_context, this, var2expr(v), value); + } + catch (...) { + throw default_exception("Exception thrown in \"fixed\"-callback"); + } } bool solver::decide(sat::bool_var& var, lbool& phase) { @@ -179,9 +193,9 @@ namespace user_solver { sat::literal lit = ctx.internalize(prop.m_conseq, false, false); if (s().value(lit) != l_true) { auto j = mk_justification(m_qhead); + persist_clause(lit, j); s().assign(lit, j); ++m_stats.m_num_propagations; - persist_clause(lit, j); } } @@ -193,7 +207,7 @@ namespace user_solver { if (m_qhead == m_prop.size() && m_replay_qhead == m_clauses_to_replay.size()) return false; force_push(); - + bool replayed = false; if (m_replay_qhead < m_clauses_to_replay.size()) { replayed = true; @@ -235,11 +249,15 @@ namespace user_solver { clause.push_back(ctx.literal2expr(~lit)); for (auto const& [a,b] : prop.m_eqs) clause.push_back(m.mk_not(m.mk_eq(a, b))); + clause.push_back(ctx.literal2expr(lit)); + if (m.is_false(clause.back())) + clause.pop_back(); m_clauses_to_replay.push_back(clause); if (m_replay_qhead + 1 < m_clauses_to_replay.size()) std::swap(m_clauses_to_replay[m_replay_qhead], m_clauses_to_replay[m_clauses_to_replay.size()-1]); + ctx.push(value_trail(m_replay_qhead)); ++m_replay_qhead; } @@ -260,6 +278,7 @@ namespace user_solver { auto const& prop = m_prop[j.m_propagation_index]; for (unsigned id : prop.m_ids) r.append(m_id2justification[id]); + DEBUG_CODE(for (auto lit : r) VERIFY(s().value(lit) == l_true);); for (auto const& p : prop.m_eqs) ctx.add_eq_antecedent(probing, expr2enode(p.first), expr2enode(p.second)); } diff --git a/src/sat/smt/user_solver.h b/src/sat/smt/user_solver.h index 373b046b8..4bdfcf064 100644 --- a/src/sat/smt/user_solver.h +++ b/src/sat/smt/user_solver.h @@ -79,6 +79,7 @@ namespace user_solver { lbool m_next_split_phase = l_undef; vector m_clauses_to_replay; unsigned m_replay_qhead = 0; + uint_set m_fixed; struct justification { unsigned m_propagation_index { 0 };