diff --git a/src/solver/simplifier_solver.cpp b/src/solver/simplifier_solver.cpp index 219b2c46b..29021af76 100644 --- a/src/solver/simplifier_solver.cpp +++ b/src/solver/simplifier_solver.cpp @@ -95,10 +95,21 @@ class simplifier_solver : public solver { expr_ref_vector m_assumptions; model_converter_ref m_mc; bool m_inconsistent = false; + expr_safe_replace m_core_replace; + + void replace(expr_ref_vector& r) { + expr_ref tmp(m); + for (unsigned i = 0; i < r.size(); ++i) { + m_core_replace(r.get(i), tmp); + r[i] = tmp; + } + } void flush(expr_ref_vector& assumptions) { unsigned qhead = m_preprocess_state.qhead(); - if (qhead < m_fmls.size()) { + expr_ref_vector orig_assumptions(assumptions); + m_core_replace.reset(); + if (qhead < m_fmls.size() || !assumptions.empty()) { for (expr* a : assumptions) m_preprocess_state.freeze(a); TRACE("solver", tout << "qhead " << qhead << "\n"); @@ -107,6 +118,8 @@ class simplifier_solver : public solver { if (!m.inc()) return; m_preprocess_state.advance_qhead(); + for (unsigned i = 0; i < assumptions.size(); ++i) + m_core_replace.insert(assumptions.get(i), orig_assumptions.get(i)); } m_mc = m_preprocess_state.model_trail().get_model_converter(); m_cached_mc = nullptr; @@ -148,6 +161,7 @@ public: m_preprocess_state(*this), m_preprocess(m, s->get_params(), m_preprocess_state), m_assumptions(m), + m_core_replace(m), m_proof(m) { if (fac) @@ -189,7 +203,7 @@ public: lbool check_sat_core(unsigned num_assumptions, expr* const* assumptions) override { expr_ref_vector _assumptions(m, num_assumptions, assumptions); flush(_assumptions); - return s->check_sat_core(num_assumptions, assumptions); + return s->check_sat_core(num_assumptions, _assumptions.data()); } void collect_statistics(statistics& st) const override { @@ -258,7 +272,7 @@ public: std::string reason_unknown() const override { return s->reason_unknown(); } void set_reason_unknown(char const* msg) override { s->set_reason_unknown(msg); } void get_labels(svector& r) override { s->get_labels(r); } - void get_unsat_core(expr_ref_vector& r) { s->get_unsat_core(r); } + void get_unsat_core(expr_ref_vector& r) { s->get_unsat_core(r); replace(r); } ast_manager& get_manager() const override { return s->get_manager(); } void reset_params(params_ref const& p) override { s->reset_params(p); } params_ref const& get_params() const override { return s->get_params(); } @@ -276,7 +290,13 @@ public: lbool check_sat_cc(expr_ref_vector const& cube, vector const& clauses) override { return check_sat_cc(cube, clauses); } void set_progress_callback(progress_callback* callback) override { s->set_progress_callback(callback); } lbool get_consequences(expr_ref_vector const& asms, expr_ref_vector const& vars, expr_ref_vector& consequences) override { - return s->get_consequences(asms, vars, consequences); + expr_ref_vector es(m); + es.append(asms); + es.append(vars); + flush(es); + lbool r = s->get_consequences(asms, vars, consequences); + replace(consequences); + return r; } lbool find_mutexes(expr_ref_vector const& vars, vector& mutexes) override { return s->find_mutexes(vars, mutexes); } lbool preferred_sat(expr_ref_vector const& asms, vector& cores) override { return s->preferred_sat(asms, cores); }