From 94b3a4681114b04701e954d7604b7763a6037a2d Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Tue, 18 Mar 2014 16:06:04 -0700 Subject: [PATCH] working on pb sls Signed-off-by: Nikolaj Bjorner --- src/opt/pb_sls.cpp | 186 ++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 177 insertions(+), 9 deletions(-) diff --git a/src/opt/pb_sls.cpp b/src/opt/pb_sls.cpp index 0aee3cbd5..3ac91b33a 100644 --- a/src/opt/pb_sls.cpp +++ b/src/opt/pb_sls.cpp @@ -1,6 +1,7 @@ #include "pb_sls.h" #include "smt_literal.h" #include "ast_pp.h" +#include "uint_set.h" namespace smt { struct pb_sls::imp { @@ -41,21 +42,29 @@ namespace smt { vector m_clauses; // clauses to be satisfied vector m_soft; // soft constraints vector m_weights; // weights of soft constraints - rational m_value; // current value of soft constraints - vector m_pos, m_neg; // positive and negative occurs. + rational m_penalty; // current penalty of soft constraints + vector m_hard_occ, m_soft_occ; // variable occurrence svector m_assignment; // current assignment. obj_map m_expr2var; // map expressions to Boolean variables. ptr_vector m_var2expr; // reverse map - + uint_set m_hard_false; // list of hard clauses that are false. + uint_set m_soft_false; // list of soft clauses that are false. + unsigned m_max_flips; imp(ast_manager& m): m(m), pb(m), m_cancel(false) - {} + { + m_max_flips = 100; + } ~imp() { } + unsigned max_flips() { + return m_max_flips; + } + void add(expr* f) { clause cls(mgr); if (compile_clause(f, cls)) { @@ -72,12 +81,17 @@ namespace smt { void init_value(expr* f, bool b) { literal lit = mk_literal(f); SASSERT(!lit.sign()); - //if (m_assignment[lit.var()] != b) { m_assignment[lit.var()] = b; - //} } lbool operator()() { + init(); + for (unsigned i = 0; i < max_flips(); ++i) { + flip(); + if (m_cancel) { + return l_undef; + } + } return l_undef; } @@ -102,6 +116,157 @@ namespace smt { void updt_params(params_ref& p) { } + bool eval(clause& cls) { + unsigned sz = cls.m_lits.size(); + cls.m_value.reset(); + for (unsigned i = 0; i < sz; ++i) { + if (get_value(cls.m_lits[i])) { + cls.m_value += cls.m_weights[i]; + } + } + if (cls.m_eq) { + return cls.m_value == cls.m_k; + } + else { + return cls.m_value >= cls.m_k; + } + } + + void init_occ(vector const& clauses, vector & occ) { + for (unsigned i = 0; i < clauses.size(); ++i) { + clause const& cls = clauses[i]; + for (unsigned j = 0; j < cls.m_lits.size(); ++j) { + literal lit = cls.m_lits[j]; + occ[lit.var()].push_back(i); + } + } + } + + void init() { + // initialize the occurs vectors. + init_occ(m_clauses, m_hard_occ); + init_occ(m_soft, m_soft_occ); + // add clauses that are false. + for (unsigned i = 0; i < m_clauses.size(); ++i) { + if (!eval(m_clauses[i])) { + m_hard_false.insert(i); + } + } + m_penalty.reset(); + for (unsigned i = 0; i < m_soft.size(); ++i) { + if (!eval(m_soft[i])) { + m_soft_false.insert(i); + m_penalty += m_weights[i]; + } + } + } + + void flip() { + if (m_hard_false.empty()) { + flip_soft(); + } + else { + flip_hard(); + } + } + + void flip_hard() { + SASSERT(!m_hard_false.empty()); + clause const& cls = pick_hard_clause(); + int break_count; + int min_bc = INT_MAX; + unsigned min_bc_index = 0; + for (unsigned i = 0; i < cls.m_lits.size(); ++i) { + literal lit = cls.m_lits[i]; + break_count = flip(lit); + if (break_count <= 0) { + return; + } + if (break_count < min_bc) { + min_bc = break_count; + min_bc_index = i; + } + VERIFY(-break_count == flip(~lit)); + } + // just do a greedy move: + flip(cls.m_lits[min_bc_index]); + } + + void flip_soft() { + NOT_IMPLEMENTED_YET(); + } + + // crude selection strategy. + clause const& pick_hard_clause() { + SASSERT(!m_hard_false.empty()); + uint_set::iterator it = m_hard_false.begin(); + uint_set::iterator end = m_hard_false.end(); + SASSERT(it != end); + return m_clauses[*it]; + } + + clause const& pick_soft_clause() { + SASSERT(!m_soft_false.empty()); + uint_set::iterator it = m_soft_false.begin(); + uint_set::iterator end = m_soft_false.end(); + SASSERT(it != end); + unsigned index = *it; + rational penalty = m_weights[index]; + ++it; + for (; it != end; ++it) { + if (m_weights[*it] > penalty) { + index = *it; + penalty = m_weights[*it]; + } + } + return m_soft[index]; + } + + int flip(literal l) { + SASSERT(get_value(l)); + m_assignment[l.var()] = !m_assignment[l.var()]; + int break_count = 0; + { + unsigned_vector const& occ = m_hard_occ[l.var()]; + for (unsigned i = 0; i < occ.size(); ++i) { + unsigned j = occ[i]; + if (eval(m_clauses[j])) { + if (m_hard_false.contains(j)) { + break_count--; + m_hard_false.remove(j); + } + } + else { + if (!m_hard_false.contains(j)) { + break_count++; + m_hard_false.insert(j); + } + } + } + } + { + unsigned_vector const& occ = m_soft_occ[l.var()]; + for (unsigned i = 0; i < occ.size(); ++i) { + unsigned j = occ[i]; + if (eval(m_soft[j])) { + if (m_soft_false.contains(j)) { + m_penalty -= m_weights[j]; + m_soft_false.remove(j); + } + } + else { + if (!m_soft_false.contains(j)) { + m_penalty += m_weights[j]; + m_soft_false.insert(j); + } + } + } + } + + SASSERT(get_value(~l)); + return break_count; + } + literal mk_literal(expr* f) { literal result; bool sign = false; @@ -117,10 +282,10 @@ namespace smt { else { unsigned var; if (!m_expr2var.find(f, var)) { - var = m_pos.size(); + var = m_hard_occ.size(); SASSERT(m_expr2var.size() == var); - m_pos.push_back(unsigned_vector()); - m_neg.push_back(unsigned_vector()); + m_hard_occ.push_back(unsigned_vector()); + m_soft_occ.push_back(unsigned_vector()); m_assignment.push_back(false); m_expr2var.insert(f, var); m_var2expr.push_back(f); @@ -146,6 +311,7 @@ namespace smt { SASSERT(coeff.is_int()); lit = mk_literal(args[i]); if (lit == null_literal) return false; + SASSERT(lit != false_literal && lit != true_literal); cls.m_lits.push_back(lit); cls.m_weights.push_back(coeff.to_mpq().numerator()); if (get_value(lit)) { @@ -161,6 +327,7 @@ namespace smt { for (unsigned i = 0; i < sz; ++i) { lit = mk_literal(args[i]); if (lit == null_literal) return false; + SASSERT(lit != false_literal && lit != true_literal); cls.m_lits.push_back(lit); cls.m_weights.push_back(mpz(1)); if (get_value(lit)) { @@ -173,6 +340,7 @@ namespace smt { else { lit = mk_literal(f); if (lit == null_literal) return false; + SASSERT(lit != false_literal && lit != true_literal); cls.m_lits.push_back(lit); cls.m_weights.push_back(mpz(1)); cls.m_eq = true;