diff --git a/src/smt/smt_solver.cpp b/src/smt/smt_solver.cpp index f9ab54327..1f90030d7 100644 --- a/src/smt/smt_solver.cpp +++ b/src/smt/smt_solver.cpp @@ -27,6 +27,7 @@ Notes: #include "smt/params/smt_params_helper.hpp" #include "solver/solver_na2as.h" #include "solver/mus.h" +#include "solver/smtmus.h" namespace { @@ -266,12 +267,28 @@ namespace { if (!m_minimizing_core && smt_params_helper(get_params()).core_minimize()) { scoped_minimize_core scm(*this); - mus mus(*this); - mus.add_soft(r.size(), r.data()); - expr_ref_vector r2(m); - if (l_true == mus.get_mus(r2)) { - r.reset(); - r.append(r2); + bool use_smtmus = false; + if (solver::get_params().get_bool("solver.smtmus", false)) + use_smtmus = true; + if (gparams().get_value("solver.smtmus") == "true") + use_smtmus = true; + if (use_smtmus) { + smtmus mus(*this); + mus.add_soft(r.size(), r.data()); + expr_ref_vector r2(m); + if (l_true == mus.get_mus(r2)) { + r.reset(); + r.append(r2); + } + } + else { + mus mus(*this); + mus.add_soft(r.size(), r.data()); + expr_ref_vector r2(m); + if (l_true == mus.get_mus(r2)) { + r.reset(); + r.append(r2); + } } } diff --git a/src/solver/smtmus.cpp b/src/solver/smtmus.cpp index 960e6d529..c9b8e4ef9 100644 --- a/src/solver/smtmus.cpp +++ b/src/solver/smtmus.cpp @@ -19,6 +19,7 @@ Author: #include "solver/smtmus.h" #include "ast/ast_pp.h" #include "ast/ast_util.h" +#include "ast/for_each_expr.h" #include "model/model_evaluator.h" #include "model/model.h" #include "ast/arith_decl_plugin.h" @@ -31,57 +32,152 @@ struct smtmus::imp { solver& m_solver; ast_manager& m; arith_util a; - expr_ref_vector m_lit2expr; obj_map m_expr2lit; model_ref m_model; expr_ref_vector m_soft; vector m_soft_clauses; obj_map m_lit2vars; + obj_map m_occurs; unsigned m_rotated = 0; unsigned p_max_cores = 30; bool p_crit_ext = false; + unsigned p_limit = 1; imp(solver& s) : - m_solver(s), m(s.get_manager()), a(m), m_lit2expr(m), m_soft(m) + m_solver(s), m(s.get_manager()), a(m), m_soft(m) {} ~imp() { - for (auto& [k, v] : m_lit2vars) - dealloc(v); + reset(); } - unsigned add_soft(expr* lit) { - unsigned idx = m_lit2expr.size(); + unsigned idx = m_soft.size(); m_expr2lit.insert(lit, idx); - m_lit2expr.push_back(lit); - TRACE("mus", tout << idx << ": " << mk_pp(lit, m) << "\n" << m_lit2expr << "\n";); + m_soft.push_back(lit); + TRACE("mus", tout << idx << ": " << mk_pp(lit, m) << "\n" << m_soft << "\n";); return idx; } - void init() { - // initialize soft_clauses based on control variables in mus, or if mus already is a clause. + // initialize soft_clauses based on control variables in mus, or if mus already is a clause. + void init_soft_clauses() { + obj_map lit2clause; + vector clauses; + obj_hashtable softs; + bool initialized = false; + auto init_lit2clause = [&]() { + if (initialized) + return; + initialized = true; + for (expr* s : m_soft) + softs.insert(s); + for (auto* f : m_solver.get_assertions()) { + expr_ref_vector ors(m); + flatten_or(f, ors); + unsigned idx = 0; + for (expr* e : ors) { + expr* s = nullptr; + if (m.is_not(e, s) && softs.contains(s)) { + ors[idx] = ors.back(); + ors.pop_back(); + if (lit2clause.find(s, idx)) { + expr_ref cl(m.mk_and(mk_or(clauses[idx]), mk_or(ors)), m); + ors.reset(); + ors.push_back(cl); + clauses[idx].reset(); + clauses[idx].append(ors); + } + else { + lit2clause.insert(s, clauses.size()); + clauses.push_back(ors); + } + break; + } + ++idx; + } + } + }; + unsigned cl; + for (expr* s : m_soft) { + expr_ref_vector clause(m); + if (m.is_or(s)) + clause.append(to_app(s)->get_num_args(), to_app(s)->get_args()); + else if (is_uninterp_const(s)) { + init_lit2clause(); + if (lit2clause.find(s, cl)) + clause.append(clauses[cl]); + else + clause.push_back(s); + } + else + clause.push_back(s); + m_soft_clauses.push_back(clause); + } + + TRACE("satmus", + for (expr* s : m_soft) + tout << "soft " << mk_pp(s, m) << "\n"; + for (auto const& clause : m_soft_clauses) + tout << "clause " << clause << "\n";); } + void init_occurs() { + unsigned idx = 0; + for (auto const& clause : m_soft_clauses) { + for (auto* lit : clause) { + auto const& vars = get_vars(lit); + for (auto* v : vars) { + if (!m_occurs.contains(v)) + m_occurs.insert(v, alloc(unsigned_vector)); + auto& occ = *m_occurs[v]; + if (!occ.empty() && occ.back() == idx) + continue; + occ.push_back(idx); + } + } + ++idx; + } + } + + void reset() { + m_model.reset(); + for (auto& [k, v] : m_lit2vars) + dealloc(v); + m_lit2vars.reset(); + for (auto& [k, v] : m_occurs) + dealloc(v); + m_occurs.reset(); + m_soft_clauses.reset(); + } + + void init() { + init_soft_clauses(); + init_occurs(); + } lbool get_mus(expr_ref_vector& mus) { - m_model.reset(); mus.reset(); - if (m_lit2expr.size() == 1) { - mus.push_back(m_lit2expr.back()); + if (m_soft.size() == 1) { + mus.push_back(m_soft.back()); return l_true; } - return l_undef; + init(); + + bool_vector shrunk(m_soft_clauses.size(), true); + + if (!shrink(shrunk)) + return l_undef; + + for (unsigned i = 0; i < shrunk.size(); ++i) + if (shrunk[i]) + mus.push_back(m_soft.get(i)); + return l_true; } - // extract clauses corresponding to added soft constraints. - // extract integer, real variables from clauses - // - - bool_vector shrink() { + bool shrink(bool_vector& shrunk) { bool_vector crits(m_soft_clauses.size(), false); - bool_vector shrunk(m_soft_clauses.size(), true); + unsigned max_cores = p_max_cores; for (unsigned i = 0; i < m_soft_clauses.size(); ++i) { if (!shrunk[i] || crits[i]) @@ -98,9 +194,10 @@ struct smtmus::imp { --max_cores; break; default: - break; + return false; } } + return true; } unsigned count_ones(bool_vector const& v) { @@ -150,7 +247,7 @@ struct smtmus::imp { for (auto const& lit : m_soft_clauses[i]) { auto const& vars = get_vars(lit); for (auto v : vars) { - expr_ref_vector flips = rotate_get_flips(lit, v, mdl, 1); + expr_ref_vector flips = rotate_get_flips(lit, v, mdl, p_limit); for (auto& flip : flips) { if (!mdl.eval(v, prev_value)) continue; @@ -177,7 +274,9 @@ struct smtmus::imp { } void extract_vars(expr* e, func_decl_ref_vector& vars) { - NOT_IMPLEMENTED_YET(); + for (expr* t : subterms::ground(expr_ref(e, m))) + if (is_uninterp_const(t)) + vars.push_back(to_app(t)->get_decl()); } expr_ref_vector rotate_get_flips(expr* lit, func_decl* v, model& mdl, unsigned limit) { @@ -194,6 +293,12 @@ struct smtmus::imp { } result = rotate_get_eq_flips(lit, v, mdl, limit); + if (!result.empty()) + return result; + result = rotate_get_lia_flips(lit, v, mdl, limit); + if (!result.empty()) + return result; + result = rotate_get_lra_flips(lit, v, mdl, limit); if (!result.empty()) return result; return rotate_get_flips_agnostic(lit, v, mdl, limit); @@ -212,6 +317,16 @@ struct smtmus::imp { return flips; } + expr_ref_vector rotate_get_lia_flips(expr* lit, func_decl* v, model& mdl, unsigned limit) { + expr_ref_vector flips(m); + return flips; + } + + expr_ref_vector rotate_get_lra_flips(expr* lit, func_decl* v, model& mdl, unsigned limit) { + expr_ref_vector flips(m); + return flips; + } + expr_ref_vector rotate_get_flips_agnostic(expr* lit, func_decl* v, model& mdl, unsigned limit) { solver_ref s2 = mk_smt2_solver(m, m_solver.get_params()); s2->assert_expr(lit); @@ -238,7 +353,20 @@ struct smtmus::imp { bool rotate_get_falsified(bool_vector const& formula, model& mdl, func_decl* f, unsigned& falsified) { falsified = UINT_MAX; - NOT_IMPLEMENTED_YET(); + for (auto i : *m_occurs[f]) { + if (formula[i] && !is_true(mdl, m_soft_clauses.get(i))) { + if (falsified != UINT_MAX) + return false; + falsified = i; + } + } + return falsified != UINT_MAX; + } + + bool is_true(model& mdl, expr_ref_vector const& clause) { + for (expr* lit : clause) + if (m.is_true(lit)) + return true; return false; } @@ -250,10 +378,47 @@ struct smtmus::imp { critical_extension(formula, crits, i); } - void critical_extension(bool_vector const& formula, bool_vector& crits, unsigned i) { - NOT_IMPLEMENTED_YET(); + unsigned critical_extension(bool_vector const& formula, bool_vector& crits, unsigned i) { + unsigned unused_vars = 0; + ast_mark checked_vars; + for (auto* lit : m_soft_clauses[i]) { + auto const& vars = get_vars(lit); + for (auto* v : vars) { + if (checked_vars.is_marked(v)) + continue; + checked_vars.mark(v, true); + unsigned_vector hits; + for (auto j : *m_occurs[v]) { + if (formula[j] && j != i && (are_conflicting(i, j, v) || m.is_bool(v->get_range()))) + hits.push_back(j); + } + if (hits.size() == 1) + mark_critical(formula, crits, hits[0]); + else + ++unused_vars; + } + } + return unused_vars; } + bool are_conflicting(unsigned i, unsigned j, func_decl* v) { + if (!lia_are_conflicting(i, j, v)) + return false; + if (!lra_are_conflicting(i, j, v)) + return false; + // TBD: what is the right default value? + return true; + } + + bool lia_are_conflicting(unsigned i, unsigned j, func_decl* v) { + NOT_IMPLEMENTED_YET(); + return true; + } + + bool lra_are_conflicting(unsigned i, unsigned j, func_decl* v) { + NOT_IMPLEMENTED_YET(); + return true; + } }; smtmus::smtmus(solver& s) { diff --git a/src/solver/solver_params.pyg b/src/solver/solver_params.pyg index 21d0ab530..0d4d29329 100644 --- a/src/solver/solver_params.pyg +++ b/src/solver/solver_params.pyg @@ -4,6 +4,7 @@ def_module_params('solver', export=True, params=(('smtlib2_log', SYMBOL, '', "file to save solver interaction"), ('cancel_backup_file', SYMBOL, '', "file to save partial search state if search is canceled"), + ('smtmus', BOOL, False, "use smt mus extractor instead of default mus"), ('timeout', UINT, UINT_MAX, "timeout on the solver object; overwrites a global timeout"), ))