diff --git a/src/solver/smtmus.cpp b/src/solver/smtmus.cpp index 9d4f8f392..0c975d12e 100644 --- a/src/solver/smtmus.cpp +++ b/src/solver/smtmus.cpp @@ -204,6 +204,9 @@ struct smtmus::imp { } void init_occurs() { + m_soft_occurs.reset(); + m_hard_occurs.reset(); + m_soft_vars.reset(); unsigned idx = 0; for (auto const& clause : m_soft_clauses) { for (auto* lit : clause) { @@ -216,22 +219,26 @@ struct smtmus::imp { } idx = 0; func_decl_ref_vector vars(m); + for (auto const& [v, w] : m_soft_occurs) + m_hard_occurs.insert(v, unsigned_vector()); for (auto* fml : m_hard) { vars.reset(); extract_vars(fml, vars); for (auto* v : vars) if (m_soft_vars.contains(v)) - init_occurs(idx, v, m_hard_occurs); + m_hard_occurs[v].push_back(idx); ++idx; } } - void init_lit2ineq() { - for (auto const& [lit, vars] : m_lit2vars) - init_lit2ineq(lit); + ineq* lit2ineq(expr* lit) { + ineq* e = nullptr; + if (m_lit2ineq.find(lit, e)) + return e; + return init_lit2ineq(lit); } - void init_lit2ineq(expr* lit) { + ineq* init_lit2ineq(expr* lit) { bool is_not = m.is_not(lit, lit); expr* x, * y; auto mul = [&](rational const& coeff, expr* t) -> expr* { @@ -286,10 +293,12 @@ struct smtmus::imp { else e->m_base = a.mk_add(basis); m_lit2ineq.insert(lit, e); + return e; } else { // literals that don't correspond to inequalities are associated with null. m_lit2ineq.insert(lit, nullptr); + return nullptr; } } @@ -311,7 +320,6 @@ struct smtmus::imp { void init() { init_soft_clauses(); init_occurs(); - init_lit2ineq(); } lbool get_mus(expr_ref_vector& mus) { @@ -458,23 +466,23 @@ struct smtmus::imp { } expr_ref_vector rotate_get_flips(expr* lit, func_decl* v, model& mdl, unsigned limit) { - expr_ref_vector result(m); + expr_ref_vector flips(m); if (m.is_bool(v->get_range())) { expr_ref val(m); expr* lit2 = lit; m.is_not(lit, lit2); if (is_app(lit2) && to_app(lit2)->get_decl() == v && mdl.eval(v, val)) { - result.push_back(m.mk_bool_val(m.is_false(val))); - return result; + flips.push_back(m.mk_bool_val(m.is_false(val))); + return flips; } } - result = rotate_get_eq_flips(lit, v, mdl, limit); - if (!result.empty()) - return result; - result = rotate_get_ineq_flips(lit, v, mdl, limit); - if (!result.empty()) - return result; + flips = rotate_get_eq_flips(lit, v, mdl, limit); + if (!flips.empty()) + return flips; + flips = rotate_get_ineq_flips(lit, v, mdl, limit); + if (!flips.empty()) + return flips; return rotate_get_flips_agnostic(lit, v, mdl, limit); } @@ -495,7 +503,7 @@ struct smtmus::imp { expr_ref_vector rotate_get_ineq_flips(expr* lit, func_decl* v, model& mdl, unsigned limit) { ineq* e = nullptr; expr_ref_vector flips(m); - if (m_lit2ineq.find(lit, e) && e && e->m_coeffs.contains(v)) { + if ((e = lit2ineq(lit)) && e && e->m_coeffs.contains(v)) { rational coeff = e->m_coeffs[v]; rational val = e->get_value(mdl, a, v); bool is_int = a.is_int(v->get_range()); @@ -511,7 +519,7 @@ struct smtmus::imp { s2->assert_expr(lit); auto const& vars = get_vars(lit); expr_ref val(m); - expr_ref_vector result(m); + expr_ref_vector flips(m); for (auto& v2 : vars) { if (v2 == v) continue; @@ -524,10 +532,10 @@ struct smtmus::imp { s2->get_model(m2); if (!m2->eval(v, val)) break; - result.push_back(val); + flips.push_back(val); s2->assert_expr(m.mk_not(m.mk_eq(val, m.mk_const(v)))); } - return result; + return flips; } bool rotate_get_falsified(bool_vector const& formula, model& mdl, func_decl* f, unsigned& falsified) { @@ -595,9 +603,7 @@ struct smtmus::imp { bool arith_are_conflicting(unsigned i, unsigned j, func_decl* v) { auto insert_bounds = [&](vector& bounds, expr_ref_vector const& lits) { for (auto* lit : lits) { - ineq* e = nullptr; - if (!m_lit2ineq.find(lit, e)) - return true; + ineq* e = lit2ineq(lit); if (!e) return true; if (!a.is_numeral(e->m_base))