3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-15 13:28:47 +00:00

update hitting set implementation

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
This commit is contained in:
Nikolaj Bjorner 2014-06-23 11:28:38 -07:00
parent 04407938be
commit 519c9dba25
2 changed files with 213 additions and 100 deletions

View file

@ -130,6 +130,7 @@ namespace simplex {
void simplex<Ext>::add_patch(var_t v) {
SASSERT(is_base(v));
if (outside_bounds(v)) {
TRACE("simplex", tout << "Add patch: v" << v << "\n";);
m_to_patch.insert(v);
}
}
@ -200,12 +201,18 @@ namespace simplex {
var_info& vi = m_vars[var];
em.set(vi.m_lower, b);
vi.m_lower_valid = true;
TRACE("simplex", em.display(tout << "v" << var << " lower: ", b);
em.display(tout << " value: ", vi.m_value););
SASSERT(!vi.m_upper_valid || em.le(b, vi.m_upper));
if (!vi.m_is_base && em.lt(vi.m_value, b)) {
scoped_eps_numeral delta(em);
em.sub(b, vi.m_value, delta);
update_value(var, delta);
}
else if (vi.m_is_base && em.lt(vi.m_value, b)) {
SASSERT(outside_bounds(var));
add_patch(var);
}
SASSERT(well_formed());
}
@ -220,6 +227,10 @@ namespace simplex {
em.sub(b, vi.m_value, delta);
update_value(var, delta);
}
else if (vi.m_is_base && em.lt(b, vi.m_value)) {
SASSERT(outside_bounds(var));
add_patch(var);
}
SASSERT(well_formed());
}

View file

@ -30,7 +30,6 @@ typedef simplex::sparse_matrix<simplex::mpz_ext> sparse_matrix;
namespace opt {
struct hitting_sets::imp {
typedef unsigned_vector set;
class justification {
public:
enum kind_t { AXIOM, DECISION, CLAUSE };
@ -55,14 +54,49 @@ namespace opt {
kind_t kind() const { return m_kind; }
bool pos() const { return m_pos; }
};
class set {
unsigned m_num_elems;
unsigned m_elems[0];
set(): m_num_elems(0) {}
public:
static set* mk(small_object_allocator& alloc, unsigned sz, unsigned const* elems) {
unsigned size = (sz+1)*sizeof(unsigned);
void * mem = alloc.allocate(size);
set* result = new (mem) set();
result->m_num_elems = sz;
memcpy(result->m_elems, elems, sizeof(unsigned)*sz);
return result;
}
inline unsigned operator[](unsigned idx) const {
SASSERT(idx < m_num_elems);
return m_elems[idx];
}
inline unsigned& operator[](unsigned idx) {
SASSERT(idx < m_num_elems);
return m_elems[idx];
}
unsigned size() const { return m_num_elems; }
unsigned alloc_size() const { return (m_num_elems + 1)*sizeof(unsigned); }
bool empty() const { return 0 == size(); }
};
volatile bool m_cancel;
rational m_lower;
rational m_upper;
vector<rational> m_weights;
vector<rational> m_weights_inv;
rational m_max_weight;
rational m_denominator;
vector<set> m_T;
vector<set> m_F;
small_object_allocator m_alloc;
ptr_vector<set> m_T;
ptr_vector<set> m_F;
svector<lbool> m_value;
svector<lbool> m_model;
vector<unsigned_vector> m_tuse_list;
@ -82,6 +116,18 @@ namespace opt {
rational m_weight; // current weight of assignment.
unsigned_vector m_indices;
unsigned_vector m_scores;
vector<rational> m_scored_weights;
svector<bool> m_score_updated;
bool m_enable_simplex;
struct compare_scores {
imp* m_imp;
compare_scores(imp* i):m_imp(i) {}
bool operator()(int v1, int v2) const {
return m_imp->m_scored_weights[v1] > m_imp->m_scored_weights[v2];
}
};
compare_scores m_compare_scores;
heap<compare_scores> m_heap;
svector<bool> m_mark;
struct scope {
unsigned m_trail_lim;
@ -105,7 +151,11 @@ namespace opt {
m_qhead(0),
m_scope_lvl(0),
m_conflict_j(justification(justification::AXIOM)),
m_inconsistent(false) {
m_inconsistent(false),
m_compare_scores(this),
m_heap(0, m_compare_scores) {
m_enable_simplex = true;
}
~imp() {}
@ -116,6 +166,7 @@ namespace opt {
m_simplex.set_lower(var, mpq_inf(mpq(0),mpq(0)));
m_simplex.set_upper(var, mpq_inf(mpq(1),mpq(0)));
m_weights.push_back(w);
m_weights_inv.push_back(rational::one());
m_value.push_back(l_undef);
m_justification.push_back(justification(justification::DECISION));
m_tuse_list.push_back(unsigned_vector());
@ -127,6 +178,8 @@ namespace opt {
m_model.push_back(l_undef);
m_mark.push_back(false);
m_scores.push_back(0);
m_scored_weights.push_back(rational(0));
m_score_updated.push_back(true);
m_max_weight += w;
}
@ -142,7 +195,7 @@ namespace opt {
vector<unsigned_vector>& use_list = sign?m_fuse_list:m_tuse_list;
lbool val = sign?l_false:l_true;
justification j(justification::AXIOM);
vector<set>& Sets = sign?m_F:m_T;
ptr_vector<set>& Sets = sign?m_F:m_T;
vector<unsigned_vector>& watch = sign?m_fwatch:m_twatch;
init_weights();
if (sz == 0) {
@ -162,26 +215,28 @@ namespace opt {
j = justification(clause_id, !sign);
watch[S[0]].push_back(clause_id);
watch[S[1]].push_back(clause_id);
Sets.push_back(unsigned_vector(sz, S));
Sets.push_back(set::mk(m_alloc, sz, S));
if (!sign) {
pop(scope_lvl());
inc_score(clause_id);
}
TRACE("opt", display(tout, j););
// add_simplex_row(!sign, sz, S);
if (!sign && m_enable_simplex) {
add_simplex_row(!sign, sz, S);
}
}
return j;
}
lbool compute_lower() {
m_lower.reset();
// L3() disabled: mostly a waste of time.
if (L1() && L2()) {
return l_true;
}
else {
return l_undef;
}
rational w1 = L1();
rational w2 = L2();
rational w3 = L3();
if (w1 > m_lower) m_lower = w1;
if (w2 > m_lower) m_lower = w2;
if (w3 > m_lower) m_lower = w3;
return l_true;
}
lbool compute_upper() {
@ -190,11 +245,14 @@ namespace opt {
lbool r = search();
pop(scope_lvl());
std::cout << m_T.size() << " " << m_F.size() << "\n";
#if 0
// garbage collect agressively on exit.
// all learned clases for negative branches are
// pruned.
for (unsigned i = fsz; i < m_F.size(); ++i) {
m_alloc.deallocate(m_F[i]->alloc_size(), m_F[i]);
}
m_F.resize(fsz);
for (unsigned i = 0; i < m_fuse_list.size(); ++i) {
unsigned_vector & uses = m_fuse_list[i];
@ -209,6 +267,7 @@ namespace opt {
}
watch.resize(k);
}
#endif
return r;
}
@ -263,6 +322,20 @@ namespace opt {
m_weights[i] *= d;
}
}
rational lc(1);
for (unsigned i = 0; i < m_weights.size(); ++i) {
lc = lcm(lc, m_weights[i]);
}
for (unsigned i = 0; i < m_weights.size(); ++i) {
m_weights_inv[i] = lc/m_weights[i];
}
m_heap.set_bounds(m_weights.size());
for (unsigned i = 0; i < m_weights.size(); ++i) {
m_heap.insert(i);
}
update_heap();
// set up Simplex objective function.
for (unsigned i = 0; i < m_weights.size(); ++i) {
vars.push_back(i);
@ -272,6 +345,7 @@ namespace opt {
vars.push_back(m_weights_var);
coeffs.push_back(mpz(-1));
m_simplex.add_row(m_weights_var, coeffs.size(), vars.c_ptr(), coeffs.c_ptr());
}
void display(std::ostream& out) const {
@ -281,10 +355,10 @@ namespace opt {
out << i << ": " << value(i) << " w: " << m_weights[i] << " s: " << m_scores[i] << "\n";
}
for (unsigned i = 0; i < m_T.size(); ++i) {
display(out << "+" << i << ": ", m_T[i]);
display(out << "+" << i << ": ", *m_T[i]);
}
for (unsigned i = 0; i < m_F.size(); ++i) {
display(out << "-" << i << ": ", m_F[i]);
display(out << "-" << i << ": ", *m_F[i]);
}
out << "watch lists:\n";
for (unsigned i = 0; i < m_fwatch.size(); ++i) {
@ -321,7 +395,7 @@ namespace opt {
break;
case justification::CLAUSE: {
out << "clause: ";
set const& S = j.pos()?m_T[j.clause()]:m_F[j.clause()];
set const& S = j.pos()?(*m_T[j.clause()]):(*m_F[j.clause()]);
for (unsigned i = 0; i < S.size(); ++i) {
out << S[i] << " ";
}
@ -330,50 +404,38 @@ namespace opt {
}
}
struct scoped_select {
struct scoped_push {
imp& s;
unsigned sz;
scoped_select(imp& s):s(s), sz(s.m_trail.size()) {
}
~scoped_select() {
s.unassign(sz);
}
scoped_push(imp& s):s(s) { s.push(); }
~scoped_push() { s.pop(1); }
};
struct value_lt {
vector<rational> const& weights;
unsigned_vector const& scores;
value_lt(vector<rational> const& weights, unsigned_vector const& scores):
weights(weights), scores(scores) {}
value_lt(vector<rational> const& weights):
weights(weights) {}
bool operator()(int v1, int v2) const {
// - score1 / w1 < - score2 / w2
// <=>
// score1 / w1 > score2 / w2
// <=>
// score1*w2 > score2*w1
unsigned score1 = scores[v1];
unsigned score2 = scores[v2];
rational w1 = weights[v1];
rational w2 = weights[v2];
return rational(score1)*w2 > rational(score2)*w1;
return weights[v1] > weights[v2];
}
};
void inc_score(unsigned clause_id) {
set const& S = m_T[clause_id];
set const& S = *m_T[clause_id];
if (!has_selected(S)) {
for (unsigned j = 0; j < S.size(); ++j) {
++m_scores[S[j]];
m_score_updated[S[j]] = true;
}
}
}
void dec_score(unsigned clause_id) {
set const& S = m_T[clause_id];
set const& S = *m_T[clause_id];
if (!has_selected(S)) {
for (unsigned j = 0; j < S.size(); ++j) {
SASSERT(m_scores[S[j]] > 0);
--m_scores[S[j]];
m_score_updated[S[j]] = true;
}
}
}
@ -390,11 +452,11 @@ namespace opt {
}
}
bool L1() {
rational w(0);
scoped_select _sc(*this);
rational L1() {
rational w(m_weight);
scoped_push _sc(*this);
for (unsigned i = 0; !canceled() && i < m_T.size(); ++i) {
set const& S = m_T[i];
set const& S = *m_T[i];
SASSERT(!S.empty());
if (!has_selected(S)) {
w += m_weights[select_min(S)];
@ -403,25 +465,42 @@ namespace opt {
}
}
}
if (m_lower < w) {
m_lower = w;
}
return !canceled();
return w;
}
bool L2() {
rational w(0);
scoped_select _sc(*this);
void update_heap() {
for (unsigned i = 0; i < m_scored_weights.size(); ++i) {
if (m_score_updated[i]) {
rational const& old_w = m_scored_weights[i];
rational new_w = rational(m_scores[i])*m_weights_inv[i];
if (new_w > old_w) {
m_scored_weights[i] = new_w;
//m_heap.decreased(i);
}
else if (new_w < old_w) {
m_scored_weights[i] = new_w;
//m_heap.increased(i);
}
m_score_updated[i] = false;
}
}
}
rational L2() {
rational w(m_weight);
scoped_push _sc(*this);
int n = 0;
for (unsigned i = 0; i < m_T.size(); ++i) {
if (!has_selected(m_T[i])) ++n;
if (!has_selected(*m_T[i])) ++n;
}
value_lt lt(m_weights, m_scores);
update_heap();
value_lt lt(m_scored_weights);
std::sort(m_indices.begin(), m_indices.end(), lt);
for(unsigned i = 0; i < m_indices.size() && n > 0; ++i) {
// deg(c) = score(c)
// wt(c) = m_weights[c]
unsigned idx = m_indices[i];
if (m_scores[idx] == 0) {
break;
@ -434,13 +513,10 @@ namespace opt {
}
n -= m_scores[idx];
}
if (m_lower < w) {
m_lower = w;
}
return !canceled();
return w;
}
bool L3() {
rational L3() {
TRACE("simplex", m_simplex.display(tout););
VERIFY(l_true == m_simplex.make_feasible());
TRACE("simplex", m_simplex.display(tout););
@ -450,11 +526,12 @@ namespace opt {
unsynch_mpq_manager& mq = mg.mpq_manager();
scoped_mpq c(mq);
mg.ceil(val, c);
rational w = rational(c);
if (w > m_lower) {
m_lower = w;
}
return true;
rational w(c);
CTRACE("simplex",
w >= m_weight, tout << w << " " << m_weight << " !!!!\n";
display(tout););
SASSERT(w >= m_weight);
return w;
}
void add_simplex_row(bool is_some_true, unsigned sz, unsigned const* S) {
@ -490,9 +567,9 @@ namespace opt {
return result;
}
bool have_selected(lbool val, vector<set> const& Sets, unsigned& i) {
bool have_selected(lbool val, ptr_vector<set> const& Sets, unsigned& i) {
for (i = 0; i < Sets.size(); ++i) {
if (!has_selected(val, Sets[i])) return false;
if (!has_selected(val, *Sets[i])) return false;
}
return true;
}
@ -508,7 +585,7 @@ namespace opt {
bool values_satisfy_Fs(unsigned& i) {
unsigned j = 0;
for (i = 0; i < m_F.size(); ++i) {
set const& F = m_F[i];
set const& F = *m_F[i];
for (j = 0; j < F.size(); ++j) {
if (m_model[F[j]] == l_false) {
break;
@ -575,6 +652,9 @@ namespace opt {
if (val == l_true) {
m_weight += m_weights[idx];
update_score(idx, false);
if (m_enable_simplex) {
m_simplex.set_lower(idx, mpq_inf(mpq(1),mpq(0)));
}
}
SASSERT(val != l_true || m_scores[idx] == 0);
m_value[idx] = val;
@ -594,7 +674,10 @@ namespace opt {
m_value[idx] = l_undef;
if (val == l_true) {
m_weight -= m_weights[idx];
update_score(idx, true);
update_score(idx, true);
if (m_enable_simplex) {
m_simplex.set_lower(idx, mpq_inf(mpq(0),mpq(0)));
}
}
if (m_justification[idx].is_axiom()) {
m_replay_idx.push_back(idx);
@ -639,7 +722,7 @@ namespace opt {
bool validate_model() {
for (unsigned i = 0; i < m_T.size(); ++i) {
set const& S = m_T[i];
set const& S = *m_T[i];
bool found = false;
for (unsigned j = 0; !found && j < S.size(); ++j) {
found = value(S[j]) == l_true;
@ -650,7 +733,7 @@ namespace opt {
SASSERT(found);
}
for (unsigned i = 0; i < m_F.size(); ++i) {
set const& S = m_F[i];
set const& S = *m_F[i];
bool found = false;
for (unsigned j = 0; !found && j < S.size(); ++j) {
found = value(S[j]) != l_true;
@ -667,13 +750,13 @@ namespace opt {
bool invariant() {
for (unsigned i = 0; i < m_fwatch.size(); ++i) {
for (unsigned j = 0; j < m_fwatch[i].size(); ++j) {
set const& S = m_F[m_fwatch[i][j]];
set const& S = *m_F[m_fwatch[i][j]];
SASSERT(S[0] == i || S[1] == i);
}
}
for (unsigned i = 0; i < m_twatch.size(); ++i) {
for (unsigned j = 0; j < m_twatch[i].size(); ++j) {
set const& S = m_T[m_twatch[i][j]];
set const& S = *m_T[m_twatch[i][j]];
SASSERT(S[0] == i || S[1] == i);
}
}
@ -692,9 +775,9 @@ namespace opt {
unsigned r = lvl(conflict_l);
if (conflict_j.is_clause()) {
unsigned clause = conflict_j.clause();
vector<unsigned_vector> const& S = conflict_j.pos()?m_T:m_F;
r = std::max(r, lvl(S[clause][0]));
r = std::max(r, lvl(S[clause][1]));
ptr_vector<set> const& S = conflict_j.pos()?m_T:m_F;
r = std::max(r, lvl((*S[clause])[0]));
r = std::max(r, lvl((*S[clause])[1]));
}
return r;
}
@ -723,7 +806,7 @@ namespace opt {
unsigned cl = conflict_j.clause();
unsigned i = 0;
SASSERT(value(conflict_l) != l_undef);
set const& T = conflict_j.pos()?m_T[cl]:m_F[cl];
set const& T = conflict_j.pos()?(*m_T[cl]):(*m_F[cl]);
if (T[0] == conflict_l) {
i = 1;
}
@ -823,13 +906,24 @@ namespace opt {
}
unsigned next_var() {
value_lt lt(m_weights, m_scores);
update_heap();
value_lt lt(m_scored_weights);
std::sort(m_indices.begin(), m_indices.end(), lt);
unsigned idx = m_indices[0];
if (m_scores[idx] == 0) {
idx = UINT_MAX;
if (m_scores[idx] == 0) return UINT_MAX;
return idx;
#if 0
int min_val = m_heap.min_value();
if (min_val == -1) {
return UINT_MAX;
}
return idx;
SASSERT(0 <= min_val && static_cast<unsigned>(min_val) < m_weights.size());
if (m_scores[min_val] == 0) {
return UINT_MAX;
}
return static_cast<unsigned>(min_val);
#endif
}
bool decide() {
@ -866,7 +960,7 @@ namespace opt {
prune_branch();
}
void propagate(unsigned idx, lbool good_val, vector<unsigned_vector>& watch, vector<set>& Fs)
void propagate(unsigned idx, lbool good_val, vector<unsigned_vector>& watch, ptr_vector<set>& Fs)
{
TRACE("opt", tout << idx << " " << value(idx) << "\n";);
unsigned_vector& w = watch[idx];
@ -876,10 +970,10 @@ namespace opt {
unsigned l = 0;
for (unsigned i = 0; i < sz && !canceled(); ++i, ++l) {
unsigned clause_id = w[i];
set& F = Fs[clause_id];
set& F = *Fs[clause_id];
SASSERT(F.size() >= 2);
unsigned k1 = (F[0] == idx)?0:1;
unsigned k2 = 1 - k1;
bool k1 = (F[0] != idx);
bool k2 = !k1;
SASSERT(F[k1] == idx);
SASSERT(value(F[k1]) == bad_val);
if (value(F[k2]) == good_val) {
@ -924,27 +1018,35 @@ namespace opt {
}
bool infeasible_lookahead() {
// TBD: make this more powerful
// by using L1, L2, L3 pruning criteria.
return (m_weight >= m_max_weight);
if (m_enable_simplex && L3() >= m_max_weight) {
return true;
}
return
(L1() >= m_max_weight) ||
(L2() >= m_max_weight);
}
void prune_branch() {
if (!inconsistent() && infeasible_lookahead()) {
IF_VERBOSE(4, verbose_stream() << "(hs.prune-branch " << m_weight << ")\n";);
m_lemma.reset();
for (unsigned i = m_trail.size(); i > 0; ) {
--i;
unsigned idx = m_trail[i];
if (m_justification[idx].is_decision()) {
SASSERT(value(idx) == l_true);
m_lemma.push_back(idx);
}
}
justification j = add_exists_false(m_lemma.size(), m_lemma.c_ptr());
TRACE("opt", display(tout, j););
set_conflict(m_lemma[0], j);
if (inconsistent() || !infeasible_lookahead()) {
return;
}
IF_VERBOSE(4, verbose_stream() << "(hs.prune-branch " << m_weight << ")\n";);
m_lemma.reset();
unsigned i = 0;
rational w(0);
for (; i < m_trail.size() && w < m_max_weight; ++i) {
unsigned idx = m_trail[i];
if (m_justification[idx].is_decision()) {
SASSERT(value(idx) == l_true);
m_lemma.push_back(idx);
w += m_weights[idx];
}
}
// undo the lower bounds.
justification j = add_exists_false(m_lemma.size(), m_lemma.c_ptr());
TRACE("opt", display(tout, j););
set_conflict(m_lemma[0], j);
}
// TBD: derive strong inequalities and add them to Simplex.