diff --git a/contrib/cmake/src/opt/CMakeLists.txt b/contrib/cmake/src/opt/CMakeLists.txt index b8d17ec89..05a62b6c2 100644 --- a/contrib/cmake/src/opt/CMakeLists.txt +++ b/contrib/cmake/src/opt/CMakeLists.txt @@ -9,6 +9,7 @@ z3_add_component(opt optsmt.cpp opt_solver.cpp pb_sls.cpp + sortmax.cpp wmax.cpp COMPONENT_DEPENDENCIES sat_solver diff --git a/src/muz/pdr/pdr_context.cpp b/src/muz/pdr/pdr_context.cpp index 6d3a57581..01a9a4416 100644 --- a/src/muz/pdr/pdr_context.cpp +++ b/src/muz/pdr/pdr_context.cpp @@ -584,6 +584,7 @@ namespace pdr { init_atom(pts, rule.get_head(), var_reprs, conj, UINT_MAX); for (unsigned i = 0; i < ut_size; ++i) { if (rule.is_neg_tail(i)) { + dealloc(&var_reprs); throw default_exception("PDR does not support negated predicates in rule tails"); } init_atom(pts, rule.get_tail(i), var_reprs, conj, i); @@ -602,7 +603,13 @@ namespace pdr { var_subst(m, false)(tail[i].get(), var_reprs.size(), (expr*const*)var_reprs.c_ptr(), tmp); conj.push_back(tmp); TRACE("pdr", tout << mk_pp(tail[i].get(), m) << "\n" << mk_pp(tmp, m) << "\n";); - SASSERT(is_ground(tmp)); + if (!is_ground(tmp)) { + std::stringstream msg; + msg << "PDR cannot solve non-ground tails: " << tmp; + IF_VERBOSE(0, verbose_stream() << msg.str() << "\n";); + dealloc(&var_reprs); + throw default_exception(msg.str()); + } } expr_ref fml = pm.mk_and(conj); th_rewriter rw(m); diff --git a/src/opt/maxres.cpp b/src/opt/maxres.cpp index 016f1ee0e..7f9183d9a 100644 --- a/src/opt/maxres.cpp +++ b/src/opt/maxres.cpp @@ -162,7 +162,6 @@ public: if (m_asm2weight.find(e, weight)) { weight += w; m_asm2weight.insert(e, weight); - m_upper += w; return; } if (is_literal(e)) { @@ -174,7 +173,6 @@ public: s().assert_expr(fml); } new_assumption(asum, w); - m_upper += w; } void new_assumption(expr* e, rational const& w) { @@ -805,7 +803,6 @@ public: } lbool init_local() { - m_upper.reset(); m_lower.reset(); m_trail.reset(); obj_map new_soft; diff --git a/src/opt/maxsmt.cpp b/src/opt/maxsmt.cpp index 97ce28166..85b19e427 100644 --- a/src/opt/maxsmt.cpp +++ b/src/opt/maxsmt.cpp @@ -235,6 +235,9 @@ namespace opt { else if (maxsat_engine == symbol("wmax")) { m_msolver = mk_wmax(m_c, m_weights, m_soft_constraints); } + else if (maxsat_engine == symbol("sortmax")) { + m_msolver = mk_sortmax(m_c, m_weights, m_soft_constraints); + } else { warning_msg("solver %s is not recognized, using default 'maxres'", maxsat_engine.str().c_str()); m_msolver = mk_maxres(m_c, m_index, m_weights, m_soft_constraints); diff --git a/src/opt/sortmax.cpp b/src/opt/sortmax.cpp new file mode 100644 index 000000000..6df827896 --- /dev/null +++ b/src/opt/sortmax.cpp @@ -0,0 +1,141 @@ +/*++ +Copyright (c) 2014 Microsoft Corporation + +Module Name: + + sortmax.cpp + +Abstract: + + Theory based MaxSAT. + +Author: + + Nikolaj Bjorner (nbjorner) 2016-11-18 + +Notes: + +--*/ +#include "maxsmt.h" +#include "uint_set.h" +#include "ast_pp.h" +#include "model_smt2_pp.h" +#include "smt_theory.h" +#include "smt_context.h" +#include "opt_context.h" +#include "sorting_network.h" +#include "filter_model_converter.h" + +namespace opt { + + class sortmax : public maxsmt_solver_base { + public: + typedef expr* literal; + typedef ptr_vector literal_vector; + psort_nw m_sort; + expr_ref_vector m_trail; + func_decl_ref_vector m_fresh; + ref m_filter; + sortmax(maxsat_context& c, weights_t& ws, expr_ref_vector const& soft): + maxsmt_solver_base(c, ws, soft), m_sort(*this), m_trail(m), m_fresh(m) {} + + virtual ~sortmax() {} + + lbool operator()() { + obj_map soft; + if (!init()) { + return l_false; + } + lbool is_sat = find_mutexes(soft); + if (is_sat != l_true) { + return is_sat; + } + m_filter = alloc(filter_model_converter, m); + rational offset = m_lower; + m_upper = offset; + expr_ref_vector in(m); + expr_ref tmp(m); + ptr_vector out; + obj_map::iterator it = soft.begin(), end = soft.end(); + for (; it != end; ++it) { + unsigned n = it->m_value.get_unsigned(); + while (n > 0) { + in.push_back(it->m_key); + --n; + } + m_upper += it->m_value; + } + m_sort.sorting(in.size(), in.c_ptr(), out); + unsigned first = 0; + while (l_true == is_sat && first < out.size() && m_lower < m_upper) { + trace_bounds("sortmax"); + s().assert_expr(out[first]); + is_sat = s().check_sat(0, 0); + TRACE("opt", tout << is_sat << "\n"; s().display(tout); tout << "\n";); + if (m.canceled()) { + is_sat = l_undef; + } + if (is_sat == l_true) { + ++first; + s().get_model(m_model); + update_assignment(); + for (; first < out.size() && is_true(out[first]); ++first) { + s().assert_expr(out[first]); + } + TRACE("opt", model_smt2_pp(tout, m, *m_model.get(), 0);); + m_upper = m_lower + rational(out.size() - first); + (*m_filter)(m_model); + } + } + if (is_sat == l_false) { + is_sat = l_true; + m_lower = m_upper; + } + TRACE("opt", tout << "min cost: " << m_upper << "\n";); + return is_sat; + } + + void update_assignment() { + for (unsigned i = 0; i < m_soft.size(); ++i) { + m_assignment[i] = is_true(m_soft[i]); + } + } + + bool is_true(expr* e) { + expr_ref tmp(m); + return m_model->eval(e, tmp) && m.is_true(tmp); + } + + // definitions used for sorting network + literal mk_false() { return m.mk_false(); } + literal mk_true() { return m.mk_true(); } + literal mk_max(literal a, literal b) { return trail(m.mk_or(a, b)); } + literal mk_min(literal a, literal b) { return trail(m.mk_and(a, b)); } + literal mk_not(literal a) { if (m.is_not(a,a)) return a; return trail(m.mk_not(a)); } + + std::ostream& pp(std::ostream& out, literal lit) { return out << mk_pp(lit, m); } + + literal trail(literal l) { + m_trail.push_back(l); + return l; + } + literal fresh() { + expr_ref fr(m.mk_fresh_const("sn", m.mk_bool_sort()), m); + func_decl* f = to_app(fr)->get_decl(); + m_fresh.push_back(f); + m_filter->insert(f); + return trail(fr); + } + + void mk_clause(unsigned n, literal const* lits) { + s().assert_expr(mk_or(m, n, lits)); + } + + }; + + + maxsmt_solver_base* mk_sortmax(maxsat_context& c, weights_t& ws, expr_ref_vector const& soft) { + return alloc(sortmax, c, ws, soft); + } + +} diff --git a/src/opt/wmax.cpp b/src/opt/wmax.cpp index d3ceb9a3b..7e0e796ca 100644 --- a/src/opt/wmax.cpp +++ b/src/opt/wmax.cpp @@ -51,7 +51,10 @@ namespace opt { obj_map::iterator it = soft.begin(), end = soft.end(); for (; it != end; ++it) { wth().assert_weighted(it->m_key, it->m_value); - m_upper += it->m_value; + expr_ref tmp(m); + if (!m_model->eval(it->m_key, tmp) || !m.is_true(tmp)) { + m_upper += it->m_value; + } } trace_bounds("wmax"); while (l_true == is_sat && m_lower < m_upper) { diff --git a/src/opt/wmax.h b/src/opt/wmax.h index 094bb8644..3d9d206ad 100644 --- a/src/opt/wmax.h +++ b/src/opt/wmax.h @@ -25,5 +25,7 @@ Notes: namespace opt { maxsmt_solver_base* mk_wmax(maxsat_context& c, weights_t & ws, expr_ref_vector const& soft); + maxsmt_solver_base* mk_sortmax(maxsat_context& c, weights_t & ws, expr_ref_vector const& soft); + } #endif diff --git a/src/sat/sat_solver.cpp b/src/sat/sat_solver.cpp index 7acfda822..9cf13afe4 100644 --- a/src/sat/sat_solver.cpp +++ b/src/sat/sat_solver.cpp @@ -21,6 +21,7 @@ Revision History: #include"luby.h" #include"trace.h" #include"sat_bceq.h" +#include"max_cliques.h" // define to update glue during propagation #define UPDATE_GLUE @@ -3062,74 +3063,44 @@ namespace sat { // // ----------------------- + struct neg_literal { + unsigned negate(unsigned idx) { + return (~to_literal(idx)).index(); + } + }; + lbool solver::find_mutexes(literal_vector const& lits, vector & mutexes) { - literal_vector ps(lits); + max_cliques mc; m_user_bin_clauses.reset(); m_binary_clause_graph.reset(); collect_bin_clauses(m_user_bin_clauses, true); collect_bin_clauses(m_user_bin_clauses, false); + hashtable, default_eq > seen_bc; for (unsigned i = 0; i < m_user_bin_clauses.size(); ++i) { literal l1 = m_user_bin_clauses[i].first; literal l2 = m_user_bin_clauses[i].second; - m_binary_clause_graph.reserve(l1.index() + 1); - m_binary_clause_graph.reserve(l2.index() + 1); - m_binary_clause_graph.reserve((~l1).index() + 1); - m_binary_clause_graph.reserve((~l2).index() + 1); - m_binary_clause_graph[l1.index()].push_back(l2); - m_binary_clause_graph[l2.index()].push_back(l1); + literal_pair p(l1, l2); + if (!seen_bc.contains(p)) { + seen_bc.insert(p); + mc.add_edge(l1.index(), l2.index()); + } } + vector _mutexes; + unsigned_vector ps; for (unsigned i = 0; i < lits.size(); ++i) { - m_binary_clause_graph.reserve(lits[i].index() + 1); - m_binary_clause_graph.reserve((~lits[i]).index() + 1); + ps.push_back(lits[i].index()); } - bool non_empty = true; - m_seen[0].reset(); - while (non_empty) { - literal_vector mutex; - bool turn = false; - m_reachable[turn] = ps; - while (!m_reachable[turn].empty()) { - literal p = m_reachable[turn].pop(); - if (m_seen[0].contains(p)) { - continue; - } - m_reachable[turn].remove(p); - m_seen[0].insert(p); - mutex.push_back(p); - if (m_reachable[turn].empty()) { - break; - } - m_reachable[!turn].reset(); - get_reachable(p, m_reachable[turn], m_reachable[!turn]); - turn = !turn; + mc.cliques(ps, _mutexes); + for (unsigned i = 0; i < _mutexes.size(); ++i) { + literal_vector lits; + for (unsigned j = 0; j < _mutexes[i].size(); ++j) { + lits.push_back(to_literal(_mutexes[i][j])); } - if (mutex.size() > 1) { - mutexes.push_back(mutex); - } - non_empty = !mutex.empty(); + mutexes.push_back(lits); } return l_true; } - void solver::get_reachable(literal p, literal_set const& goal, literal_set& reachable) { - m_seen[1].reset(); - m_todo.reset(); - m_todo.push_back(p); - while (!m_todo.empty()) { - p = m_todo.back(); - m_todo.pop_back(); - if (m_seen[1].contains(p)) { - continue; - } - m_seen[1].insert(p); - literal np = ~p; - if (goal.contains(np)) { - reachable.insert(np); - } - m_todo.append(m_binary_clause_graph[np.index()]); - } - } - // ----------------------- // // Consequence generation. diff --git a/src/sat/sat_solver.h b/src/sat/sat_solver.h index a44f07a23..bcd9a66d2 100644 --- a/src/sat/sat_solver.h +++ b/src/sat/sat_solver.h @@ -451,14 +451,9 @@ namespace sat { u_map m_antecedents; vector m_binary_clause_graph; - literal_set m_reachable[2]; - literal_set m_seen[2]; - literal_vector m_todo; void extract_assumptions(literal lit, index_set& s); - void get_reachable(literal p, literal_set const& goal, literal_set& reachable); - lbool get_consequences(literal_vector const& assms, literal_vector const& lits, vector& conseq); void delete_unfixed(literal_set& unfixed); diff --git a/src/sat/sat_types.h b/src/sat/sat_types.h index 509cc58ba..28d8d761a 100644 --- a/src/sat/sat_types.h +++ b/src/sat/sat_types.h @@ -96,6 +96,7 @@ namespace sat { }; const literal null_literal; + struct literal_hash : obj_hash {}; inline literal to_literal(unsigned x) { return literal(x); } inline bool operator<(literal const & l1, literal const & l2) { return l1.m_val < l2.m_val; } diff --git a/src/smt/params/qi_params.cpp b/src/smt/params/qi_params.cpp index a341040ce..8182222e4 100644 --- a/src/smt/params/qi_params.cpp +++ b/src/smt/params/qi_params.cpp @@ -62,4 +62,4 @@ void qi_params::display(std::ostream & out) const { DISPLAY_PARAM(m_mbqi_trace); DISPLAY_PARAM(m_mbqi_force_template); DISPLAY_PARAM(m_mbqi_id); -} \ No newline at end of file +} diff --git a/src/smt/smt_consequences.cpp b/src/smt/smt_consequences.cpp index 667411be5..e6edbf87e 100644 --- a/src/smt/smt_consequences.cpp +++ b/src/smt/smt_consequences.cpp @@ -20,6 +20,8 @@ Revision History: #include "ast_util.h" #include "datatype_decl_plugin.h" #include "model_pp.h" +#include "max_cliques.h" +#include "stopwatch.h" namespace smt { @@ -367,67 +369,46 @@ namespace smt { << ")\n"; } + struct neg_literal { + unsigned negate(unsigned i) { + return (~to_literal(i)).index(); + } + }; lbool context::find_mutexes(expr_ref_vector const& vars, vector& mutexes) { - uint_set lits; + unsigned_vector ps; + max_cliques mc; + expr_ref lit(m_manager); for (unsigned i = 0; i < vars.size(); ++i) { expr* n = vars[i]; bool neg = m_manager.is_not(n, n); if (b_internalized(n)) { - lits.insert(literal(get_bool_var(n), neg).index()); + ps.push_back(literal(get_bool_var(n), neg).index()); } } - while (!lits.empty()) { - literal_vector mutex; - uint_set other(lits); - while (!other.empty()) { - uint_set conseq; - literal p = to_literal(*other.begin()); - other.remove(p.index()); - mutex.push_back(p); - if (other.empty()) { - break; + for (unsigned i = 0; i < m_watches.size(); ++i) { + watch_list & w = m_watches[i]; + for (literal const* it = w.begin_literals(), *end = w.end_literals(); it != end; ++it) { + unsigned idx1 = (~to_literal(i)).index(); + unsigned idx2 = it->index(); + if (idx1 < idx2) { + mc.add_edge(idx1, idx2); } - get_reachable(p, other, conseq); - other = conseq; - } - if (mutex.size() > 1) { - expr_ref_vector mux(m_manager); - for (unsigned i = 0; i < mutex.size(); ++i) { - expr_ref e(m_manager); - literal2expr(mutex[i], e); - mux.push_back(e); - } - mutexes.push_back(mux); - } - for (unsigned i = 0; i < mutex.size(); ++i) { - lits.remove(mutex[i].index()); } } + vector _mutexes; + mc.cliques(ps, _mutexes); + for (unsigned i = 0; i < _mutexes.size(); ++i) { + expr_ref_vector lits(m_manager); + for (unsigned j = 0; j < _mutexes[i].size(); ++j) { + literal2expr(to_literal(_mutexes[i][j]), lit); + lits.push_back(lit); + } + mutexes.push_back(lits); + } return l_true; } - void context::get_reachable(literal p, uint_set& goal, uint_set& reachable) { - uint_set seen; - literal_vector todo; - todo.push_back(p); - while (!todo.empty()) { - // std::cout << "todo: " << todo.size() << "\n"; - p = todo.back(); - todo.pop_back(); - if (seen.contains(p.index())) { - continue; - } - seen.insert(p.index()); - literal np = ~p; - if (goal.contains(np.index())) { - reachable.insert(np.index()); - } - watch_list & w = m_watches[np.index()]; - todo.append(static_cast(w.end_literals() - w.begin_literals()), w.begin_literals()); - } - } - // // Validate, in a slow pass, that the current consequences are correctly // extracted. diff --git a/src/smt/smt_context.h b/src/smt/smt_context.h index 84459d272..21d628e1a 100644 --- a/src/smt/smt_context.h +++ b/src/smt/smt_context.h @@ -1367,11 +1367,6 @@ namespace smt { void validate_consequences(expr_ref_vector const& assumptions, expr_ref_vector const& vars, expr_ref_vector const& conseq, expr_ref_vector const& unfixed); - /* - \brief Auxiliry function for mutex finding. - */ - - void get_reachable(literal p, uint_set& goal, uint_set& reached); public: context(ast_manager & m, smt_params & fp, params_ref const & p = params_ref()); diff --git a/src/smt/smt_model_checker.cpp b/src/smt/smt_model_checker.cpp index c17211664..a7f415aad 100644 --- a/src/smt/smt_model_checker.cpp +++ b/src/smt/smt_model_checker.cpp @@ -194,6 +194,7 @@ namespace smt { } tout << "\n";); + max_generation = std::max(m_qm->get_generation(q), max_generation); add_instance(q, bindings, max_generation); return true; } diff --git a/src/util/max_cliques.h b/src/util/max_cliques.h new file mode 100644 index 000000000..8668b9931 --- /dev/null +++ b/src/util/max_cliques.h @@ -0,0 +1,140 @@ +/*++ +Copyright (c) 2016 Microsoft Corporation + +Module Name: + + max_cliques.h + +Abstract: + + Utility for enumerating locally maximal sub cliques. + +Author: + + Nikolaj Bjorner (nbjorner) 2016-11-18 + +Notes: + + +--*/ + +#include "vector.h" +#include "uint_set.h" + +class max_cliques_plugin { +public: + virtual unsigned operator()(unsigned i) = 0; +}; + +template +class max_cliques : public T { + vector m_next, m_tc; + uint_set m_reachable[2]; + uint_set m_seen1, m_seen2; + unsigned_vector m_todo; + + void get_reachable(unsigned p, uint_set const& goal, uint_set& reachable) { + m_seen1.reset(); + m_todo.reset(); + m_todo.push_back(p); + for (unsigned i = 0; i < m_todo.size(); ++i) { + p = m_todo[i]; + if (m_seen1.contains(p)) { + continue; + } + m_seen1.insert(p); + if (m_seen2.contains(p)) { + unsigned_vector const& tc = m_tc[p]; + for (unsigned j = 0; j < tc.size(); ++j) { + unsigned np = tc[j]; + if (goal.contains(np)) { + reachable.insert(np); + } + } + } + else { + unsigned np = negate(p); + if (goal.contains(np)) { + reachable.insert(np); + } + m_todo.append(next(np)); + } + } + for (unsigned i = m_todo.size(); i > 0; ) { + --i; + p = m_todo[i]; + if (m_seen2.contains(p)) { + continue; + } + m_seen2.insert(p); + unsigned np = negate(p); + unsigned_vector& tc = m_tc[p]; + if (goal.contains(np)) { + tc.push_back(np); + } + else { + unsigned_vector const& succ = next(np); + for (unsigned j = 0; j < succ.size(); ++j) { + tc.append(m_tc[succ[j]]); + } + } + } + } + + + + + + unsigned_vector const& next(unsigned vertex) const { return m_next[vertex]; } + +public: + max_cliques() {} + + void add_edge(unsigned src, unsigned dst) { + m_next.reserve(std::max(src, dst) + 1); + m_next.reserve(std::max(negate(src), negate(dst)) + 1); + m_next[src].push_back(dst); + m_next[dst].push_back(src); + } + + void cliques(unsigned_vector const& ps, vector& cliques) { + unsigned max = 0; + unsigned num_ps = ps.size(); + for (unsigned i = 0; i < num_ps; ++i) { + unsigned p = ps[i]; + unsigned np = negate(p); + max = std::max(max, std::max(np, p) + 1); + } + m_next.reserve(max); + m_tc.reserve(max); + unsigned_vector clique; + uint_set vars; + for (unsigned i = 0; i < num_ps; ++i) { + vars.insert(ps[i]); + } + + while (!vars.empty()) { + clique.reset(); + bool turn = false; + m_reachable[turn] = vars; + while (!m_reachable[turn].empty()) { + unsigned p = *m_reachable[turn].begin(); + m_reachable[turn].remove(p); + vars.remove(p); + clique.push_back(p); + if (m_reachable[turn].empty()) { + break; + } + m_reachable[!turn].reset(); + get_reachable(p, m_reachable[turn], m_reachable[!turn]); + turn = !turn; + } + if (clique.size() > 1) { + std::cout << clique.size() << "\n"; + cliques.push_back(clique); + } + } + } + + +}; diff --git a/src/util/sorting_network.h b/src/util/sorting_network.h index 87d8bbf3f..2c819db04 100644 --- a/src/util/sorting_network.h +++ b/src/util/sorting_network.h @@ -744,6 +744,7 @@ Notes: return vc_cmp()*std::min(a-1,b); } + public: void sorting(unsigned n, literal const* xs, literal_vector& out) { TRACE("pb", tout << "sorting: " << n << "\n";); switch(n) { @@ -773,8 +774,9 @@ Notes: TRACE("pb", tout << "sorting: " << n << "\n"; pp(tout << "in:", n, xs) << "\n"; pp(tout << "out:", out) << "\n";); - } + + private: vc vc_sorting(unsigned n) { switch(n) { case 0: return vc(0,0); diff --git a/src/util/uint_set.h b/src/util/uint_set.h index 525638f4f..e78a4b0b6 100644 --- a/src/util/uint_set.h +++ b/src/util/uint_set.h @@ -163,10 +163,11 @@ public: class iterator { uint_set const* m_set; unsigned m_index; + unsigned m_last; - bool invariant() const { return m_index <= m_set->get_max_elem(); } + bool invariant() const { return m_index <= m_last; } - bool at_end() const { return m_index == m_set->get_max_elem(); } + bool at_end() const { return m_index == m_last; } void scan_idx() { SASSERT(invariant()); @@ -200,7 +201,7 @@ public: } public: iterator(uint_set const& s, bool at_end): - m_set(&s), m_index(at_end?s.get_max_elem():0) { + m_set(&s), m_index(at_end?s.get_max_elem():0), m_last(s.get_max_elem()) { scan(); SASSERT(invariant()); } @@ -212,6 +213,7 @@ public: iterator & operator=(iterator const& other) { m_set = other.m_set; m_index = other.m_index; + m_last = other.m_last; return *this; } };