diff --git a/src/smt/smt_solver.cpp b/src/smt/smt_solver.cpp index 1f90030d7..4f917ba00 100644 --- a/src/smt/smt_solver.cpp +++ b/src/smt/smt_solver.cpp @@ -196,6 +196,7 @@ namespace { m_context.pop(n); } + lbool check_sat_core2(unsigned num_assumptions, expr * const * assumptions) override { TRACE("solver_na2as", tout << "smt_solver::check_sat_core: " << num_assumptions << "\n";); return m_context.check(num_assumptions, assumptions); @@ -275,6 +276,7 @@ namespace { if (use_smtmus) { smtmus mus(*this); mus.add_soft(r.size(), r.data()); + mus.set_assumptions(m_last_assumptions); expr_ref_vector r2(m); if (l_true == mus.get_mus(r2)) { r.reset(); diff --git a/src/solver/smtmus.cpp b/src/solver/smtmus.cpp index 8f6e8ca5f..f0947927f 100644 --- a/src/solver/smtmus.cpp +++ b/src/solver/smtmus.cpp @@ -125,6 +125,7 @@ struct smtmus::imp { obj_map m_soft_occurs; // map from variables to soft clause occurrences obj_map m_hard_occurs; // map from variables to hard clause occurrences obj_map m_lit2ineq; // map from literals to inequality abstraction + expr_ref_vector m_assumptions; // set of assumptions used. unsigned m_rotated = 0; unsigned p_max_cores = 30; @@ -134,7 +135,7 @@ struct smtmus::imp { bool p_use_reset = false; imp(solver& s) : - m_solver(s), m_main_solver(&s), m(s.get_manager()), a(m), m_soft(m), m_hard(m) + m_solver(s), m_main_solver(&s), m(s.get_manager()), a(m), m_soft(m), m_hard(m), m_assumptions(m) {} ~imp() { @@ -145,23 +146,16 @@ struct smtmus::imp { m_soft.push_back(lit); } - void init_soft_clauses() { + void set_assumptions(expr_ref_vector const& assumptions) { + m_assumptions.reset(); + m_assumptions.append(assumptions); + } + + void init_softs(expr_ref_vector const& soft, obj_map& softs) { obj_hashtable dups; - obj_map soft2hard; - obj_map softs; - u_map hard2soft; - unsigned idx = 0; - - // initialize hard clauses - m_hard.reset(); - m_hard.append(m_solver.get_assertions()); - // initialize soft clauses. - m_soft_clauses.reset(); - for (expr* s : m_soft) - m_soft_clauses.push_back(expr_ref_vector(m, 1, &s)); - // collect indicator variable candidates - for (expr* s : m_soft) { + unsigned idx = 0; + for (expr* s : soft) { if (is_uninterp_const(s)) { if (softs.contains(s)) dups.insert(s); @@ -172,11 +166,11 @@ struct smtmus::imp { } for (auto* s : dups) softs.remove(s); - if (softs.empty()) - return; + } + void init_soft2hard(obj_map& soft2hard, u_map& hard2soft, obj_mapconst & softs) { // find all clauses where soft indicators are used. - idx = 0; + unsigned idx = 0; for (expr* f : m_hard) { expr_ref_vector ors(m); flatten_or(f, ors); @@ -190,7 +184,6 @@ struct smtmus::imp { } ++idx; } - // remove hard2soft associations if soft clauses don't occur uniquely. idx = 0; unsigned_vector to_remove; @@ -206,6 +199,32 @@ struct smtmus::imp { } for (auto i : to_remove) hard2soft.remove(i); + } + + void simplify_hard(u_map const& hard2soft) { + for (auto const& [i, s] : hard2soft) + m_hard[i] = m.mk_true(); + } + + void init_soft_clauses() { + obj_map soft2hard; + obj_map softs; + u_map hard2soft; + unsigned idx = 0; + + // initialize hard clauses + m_hard.reset(); + m_hard.append(m_solver.get_assertions()); + // initialize soft clauses. + m_soft_clauses.reset(); + for (expr* s : m_soft) + m_soft_clauses.push_back(expr_ref_vector(m, 1, &s)); + + init_softs(m_soft, softs); + if (softs.empty()) + return; + + init_soft2hard(soft2hard, hard2soft, softs); // // update soft clauses using hard clauses. @@ -231,16 +250,27 @@ struct smtmus::imp { ++idx; } SASSERT(idx <= ors.size()); - m_hard[i] = m.mk_true(); + } + simplify_hard(hard2soft); + softs.reset(); + hard2soft.reset(); + init_softs(m_assumptions, softs); + if (!softs.empty()) { + init_soft2hard(soft2hard, hard2soft, softs); + simplify_hard(hard2soft); + } TRACE("satmus", - for (expr* s : m_soft) - tout << "soft " << mk_pp(s, m) << "\n"; - for (auto const& clause : m_soft_clauses) - tout << "clause " << clause << "\n"; - for (expr* h : m_hard) - tout << "hard " << mk_pp(h, m) << "\n";); + for (expr* s : m_soft) + tout << "soft " << mk_pp(s, m) << "\n"; + for (auto const& clause : m_soft_clauses) + tout << "clause " << clause << "\n"; + for (expr* h : m_hard) + tout << "hard " << mk_pp(h, m) << "\n"; + for (expr* a : m_assumptions) + tout << "assumption " << mk_pp(a, m) << "\n"; + ); } void init_occurs(unsigned idx, func_decl* v, obj_map& occurs) { @@ -287,12 +317,16 @@ struct smtmus::imp { return init_lit2ineq(lit); } - ineq* init_lit2ineq(expr* lit) { + ineq* init_lit2ineq(expr* _lit) { + expr* lit = _lit; bool is_not = m.is_not(lit, lit); expr* x, * y; auto mul = [&](rational const& coeff, expr* t) -> expr* { + rational coeff2; if (coeff == 1) return t; + else if (a.is_numeral(t, coeff2)) + return a.mk_numeral(coeff*coeff2, t->get_sort()); return a.mk_mul(a.mk_numeral(coeff, a.is_int(t)), t); }; if (a.is_le(lit, x, y) || a.is_lt(lit, x, y) || a.is_ge(lit, y, x) || a.is_gt(lit, y, x)) { @@ -342,12 +376,12 @@ struct smtmus::imp { e->m_base = a.mk_numeral(rational::zero(), a.is_int(x)); else e->m_base = a.mk_add(basis); - m_lit2ineq.insert(lit, e); + m_lit2ineq.insert(_lit, e); return e; } else { // literals that don't correspond to inequalities are associated with null. - m_lit2ineq.insert(lit, nullptr); + m_lit2ineq.insert(_lit, nullptr); return nullptr; } } @@ -735,3 +769,6 @@ lbool smtmus::get_mus(expr_ref_vector& mus) { return m_imp->get_mus(mus); } +void smtmus::set_assumptions(expr_ref_vector const& assumptions) { + m_imp->set_assumptions(assumptions); +} diff --git a/src/solver/smtmus.h b/src/solver/smtmus.h index 7445d37e4..3af9b1c5e 100644 --- a/src/solver/smtmus.h +++ b/src/solver/smtmus.h @@ -35,6 +35,8 @@ class smtmus { add_soft(clss[i]); } + void set_assumptions(expr_ref_vector const& assumptions); + /** Retrieve mus over soft constraints */ diff --git a/src/solver/solver_na2as.cpp b/src/solver/solver_na2as.cpp index 4951f8833..9420b30e4 100644 --- a/src/solver/solver_na2as.cpp +++ b/src/solver/solver_na2as.cpp @@ -25,7 +25,8 @@ Notes: solver_na2as::solver_na2as(ast_manager & m): m(m), - m_assumptions(m) { + m_assumptions(m), + m_last_assumptions(m) { } solver_na2as::~solver_na2as() {} @@ -45,36 +46,37 @@ void solver_na2as::assert_expr_core2(expr * t, expr * a) { } } -struct append_assumptions { - expr_ref_vector & m_assumptions; +struct solver_na2as::append_assumptions { + solver_na2as& s; unsigned m_old_sz; - append_assumptions(expr_ref_vector & _m_assumptions, + append_assumptions(solver_na2as& s, unsigned num_assumptions, expr * const * assumptions): - m_assumptions(_m_assumptions) { - m_old_sz = m_assumptions.size(); - m_assumptions.append(num_assumptions, assumptions); + s(s), m_old_sz(s.m_assumptions.size()) { + s.m_assumptions.append(num_assumptions, assumptions); + s.m_last_assumptions.reset(); + s.m_last_assumptions.append(s.m_assumptions); } ~append_assumptions() { - m_assumptions.shrink(m_old_sz); + s.m_assumptions.shrink(m_old_sz); } }; lbool solver_na2as::check_sat_core(unsigned num_assumptions, expr * const * assumptions) { - append_assumptions app(m_assumptions, num_assumptions, assumptions); + append_assumptions app(*this, num_assumptions, assumptions); TRACE("solver_na2as", display(tout);); return check_sat_core2(m_assumptions.size(), m_assumptions.data()); } lbool solver_na2as::check_sat_cc(const expr_ref_vector &assumptions, vector const &clauses) { if (clauses.empty()) return check_sat(assumptions.size(), assumptions.data()); - append_assumptions app(m_assumptions, assumptions.size(), assumptions.data()); + append_assumptions app(*this, assumptions.size(), assumptions.data()); return check_sat_cc_core(m_assumptions, clauses); } lbool solver_na2as::get_consequences(expr_ref_vector const& asms, expr_ref_vector const& vars, expr_ref_vector& consequences) { - append_assumptions app(m_assumptions, asms.size(), asms.data()); + append_assumptions app(*this, asms.size(), asms.data()); return get_consequences_core(m_assumptions, vars, consequences); } diff --git a/src/solver/solver_na2as.h b/src/solver/solver_na2as.h index c8340bd6e..c835a77da 100644 --- a/src/solver/solver_na2as.h +++ b/src/solver/solver_na2as.h @@ -26,9 +26,10 @@ Notes: class solver_na2as : public solver { protected: ast_manager & m; - expr_ref_vector m_assumptions; + expr_ref_vector m_assumptions, m_last_assumptions; unsigned_vector m_scopes; void restore_assumptions(unsigned old_sz); + struct append_assumptions; public: solver_na2as(ast_manager & m); ~solver_na2as() override;