From df0e3a100ca46cd5a2990b3cb94a47928c20130d Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Sat, 19 Nov 2016 08:04:06 -0800 Subject: [PATCH] tune initialization for wmax and sortmax Signed-off-by: Nikolaj Bjorner --- src/opt/sortmax.cpp | 21 ++++++++++++++++++++- src/opt/wmax.cpp | 8 ++++++-- src/smt/theory_wmaxsat.cpp | 8 +++++--- src/smt/theory_wmaxsat.h | 2 +- src/util/max_cliques.h | 5 ----- 5 files changed, 32 insertions(+), 12 deletions(-) diff --git a/src/opt/sortmax.cpp b/src/opt/sortmax.cpp index 6df827896..e3cf59de4 100644 --- a/src/opt/sortmax.cpp +++ b/src/opt/sortmax.cpp @@ -58,15 +58,34 @@ namespace opt { ptr_vector out; obj_map::iterator it = soft.begin(), end = soft.end(); for (; it != end; ++it) { + if (!it->m_value.is_unsigned()) { + throw default_exception("sortmax can only handle unsigned weights. Use a different heuristic."); + } unsigned n = it->m_value.get_unsigned(); while (n > 0) { in.push_back(it->m_key); --n; } - m_upper += it->m_value; } m_sort.sorting(in.size(), in.c_ptr(), out); + + // initialize sorting network outputs using the initial assignment. unsigned first = 0; + it = soft.begin(); + for (; it != end; ++it) { + expr_ref tmp(m); + if (m_model->eval(it->m_key, tmp) && m.is_true(tmp)) { + unsigned n = it->m_value.get_unsigned(); + while (n > 0) { + s().assert_expr(out[first]); + ++first; + --n; + } + } + else { + m_upper += it->m_value; + } + } while (l_true == is_sat && first < out.size() && m_lower < m_upper) { trace_bounds("sortmax"); s().assert_expr(out[first]); diff --git a/src/opt/wmax.cpp b/src/opt/wmax.cpp index 7e0e796ca..d8ffe9f66 100644 --- a/src/opt/wmax.cpp +++ b/src/opt/wmax.cpp @@ -48,14 +48,18 @@ namespace opt { rational offset = m_lower; m_upper = offset; bool was_sat = false; + expr_ref_vector disj(m); obj_map::iterator it = soft.begin(), end = soft.end(); for (; it != end; ++it) { - wth().assert_weighted(it->m_key, it->m_value); expr_ref tmp(m); - if (!m_model->eval(it->m_key, tmp) || !m.is_true(tmp)) { + bool is_true = m_model->eval(it->m_key, tmp) && m.is_true(tmp); + expr* c = wth().assert_weighted(it->m_key, it->m_value, is_true); + if (!is_true) { m_upper += it->m_value; + disj.push_back(c); } } + s().assert_expr(mk_or(disj)); trace_bounds("wmax"); while (l_true == is_sat && m_lower < m_upper) { is_sat = s().check_sat(0, 0); diff --git a/src/smt/theory_wmaxsat.cpp b/src/smt/theory_wmaxsat.cpp index b572d8e69..fd07188e5 100644 --- a/src/smt/theory_wmaxsat.cpp +++ b/src/smt/theory_wmaxsat.cpp @@ -87,7 +87,7 @@ void theory_wmaxsat::init_search_eh() { m_propagate = true; } -bool_var theory_wmaxsat::assert_weighted(expr* fml, rational const& w) { +expr* theory_wmaxsat::assert_weighted(expr* fml, rational const& w, bool is_true) { context & ctx = get_context(); ast_manager& m = get_manager(); app_ref var(m), wfml(m); @@ -99,9 +99,11 @@ bool_var theory_wmaxsat::assert_weighted(expr* fml, rational const& w) { m_vars.push_back(var); m_fmls.push_back(fml); m_assigned.push_back(false); - m_rmin_cost += w; + if (!is_true) { + m_rmin_cost += w; + } m_normalize = true; - return register_var(var, true); + return ctx.bool_var2expr(register_var(var, true)); } bool_var theory_wmaxsat::register_var(app* var, bool attach) { diff --git a/src/smt/theory_wmaxsat.h b/src/smt/theory_wmaxsat.h index b0c556c0e..0804c4b68 100644 --- a/src/smt/theory_wmaxsat.h +++ b/src/smt/theory_wmaxsat.h @@ -57,7 +57,7 @@ namespace smt { virtual ~theory_wmaxsat(); void get_assignment(svector& result); virtual void init_search_eh(); - bool_var assert_weighted(expr* fml, rational const& w); + expr* assert_weighted(expr* fml, rational const& w, bool is_true); bool_var register_var(app* var, bool attach); rational const& get_min_cost(); class numeral_trail : public trail { diff --git a/src/util/max_cliques.h b/src/util/max_cliques.h index 8668b9931..e8493a9b9 100644 --- a/src/util/max_cliques.h +++ b/src/util/max_cliques.h @@ -21,10 +21,6 @@ Notes: #include "vector.h" #include "uint_set.h" -class max_cliques_plugin { -public: - virtual unsigned operator()(unsigned i) = 0; -}; template class max_cliques : public T { @@ -130,7 +126,6 @@ public: turn = !turn; } if (clique.size() > 1) { - std::cout << clique.size() << "\n"; cliques.push_back(clique); } }