3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-23 09:05:31 +00:00

custom HS solver

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
This commit is contained in:
Nikolaj Bjorner 2014-06-18 17:31:00 -07:00
parent d7d85aa18a
commit 04407938be

View file

@ -22,7 +22,6 @@ Notes:
#include "simplex.h"
#include "sparse_matrix_def.h"
#include "simplex_def.h"
#include "sat_solver.h"
typedef simplex::simplex<simplex::mpz_ext> Simplex;
typedef simplex::sparse_matrix<simplex::mpz_ext> sparse_matrix;
@ -42,7 +41,7 @@ namespace opt {
public:
explicit justification(kind_t k):m_kind(k), m_value(0), m_pos(false) {}
explicit justification(unsigned v, bool pos):m_kind(CLAUSE), m_value(v), m_pos(pos) {}
explicit justification(justification const& other): m_kind(other.m_kind), m_value(other.m_value), m_pos(other.m_pos) {}
justification(justification const& other): m_kind(other.m_kind), m_value(other.m_value), m_pos(other.m_pos) {}
justification& operator=(justification const& other) {
m_kind = other.m_kind;
m_value = other.m_value;
@ -96,11 +95,6 @@ namespace opt {
Simplex m_simplex;
unsigned m_weights_var;
// sat solver
params_ref m_params;
sat::solver m_solver;
svector<sat::bool_var> m_vars;
static unsigned const null_idx = UINT_MAX;
imp():
@ -111,10 +105,7 @@ namespace opt {
m_qhead(0),
m_scope_lvl(0),
m_conflict_j(justification(justification::AXIOM)),
m_inconsistent(false),
m_solver(m_params,0) {
m_params.set_bool("elim_vars", false);
m_solver.updt_params(m_params);
m_inconsistent(false) {
}
~imp() {}
@ -137,46 +128,49 @@ namespace opt {
m_mark.push_back(false);
m_scores.push_back(0);
m_max_weight += w;
m_vars.push_back(m_solver.mk_var());
}
unsigned add_exists_false(unsigned sz, unsigned const* S) {
justification add_exists_false(unsigned sz, unsigned const* S) {
return add_exists(sz, S, true);
}
unsigned add_exists_true(unsigned sz, unsigned const* S) {
justification add_exists_true(unsigned sz, unsigned const* S) {
return add_exists(sz, S, false);
}
unsigned add_exists(unsigned sz, unsigned const* S, bool sign) {
justification add_exists(unsigned sz, unsigned const* S, bool sign) {
vector<unsigned_vector>& use_list = sign?m_fuse_list:m_tuse_list;
lbool val = sign?l_false:l_true;
unsigned clause_id;
justification j(justification::AXIOM);
vector<set>& Sets = sign?m_F:m_T;
vector<unsigned_vector>& watch = sign?m_fwatch:m_twatch;
SASSERT(sz > 0);
for (unsigned i = 0; i < sz; ++i) {
use_list[S[i]].push_back(Sets.size());
}
init_weights();
if (sz == 1) {
clause_id = UINT_MAX;
if (sz == 0) {
// TBD
IF_VERBOSE(0, verbose_stream() << "empty clause\n";);
set_conflict(0, justification(justification::AXIOM));
}
else if (sz == 1) {
IF_VERBOSE(1, verbose_stream() << "unit literal : " << S[0] << " " << val << "\n";);
assign(S[0], val, justification(justification::AXIOM));
}
else {
clause_id = Sets.size();
unsigned clause_id = Sets.size();
for (unsigned i = 0; i < sz; ++i) {
use_list[S[i]].push_back(clause_id);
}
j = justification(clause_id, !sign);
watch[S[0]].push_back(clause_id);
watch[S[1]].push_back(clause_id);
Sets.push_back(unsigned_vector(sz, S));
if (!sign) {
pop(scope_lvl());
inc_score(clause_id);
}
TRACE("opt", display(tout, j););
// add_simplex_row(!sign, sz, S);
}
add_simplex_row(!sign, sz, S);
// Add clause to SAT solver:
svector<sat::literal> lits;
for (unsigned i = 0; i < sz; ++i) {
lits.push_back(sat::literal(m_vars[S[i]], sign));
}
m_solver.mk_clause(lits.size(), lits.c_ptr());
return clause_id;
return j;
}
lbool compute_lower() {
@ -192,8 +186,30 @@ namespace opt {
lbool compute_upper() {
m_upper = m_max_weight;
return search();
// return U1();
unsigned fsz = m_F.size();
lbool r = search();
pop(scope_lvl());
std::cout << m_T.size() << " " << m_F.size() << "\n";
// garbage collect agressively on exit.
// all learned clases for negative branches are
// pruned.
m_F.resize(fsz);
for (unsigned i = 0; i < m_fuse_list.size(); ++i) {
unsigned_vector & uses = m_fuse_list[i];
while (!uses.empty() && uses.back() >= fsz) uses.pop_back();
unsigned_vector & watch = m_fwatch[i];
unsigned j = 0, k = 0;
for (; j < watch.size(); ++j) {
if (watch[j] < fsz) {
watch[k] = watch[j];
++k;
}
}
watch.resize(k);
}
return r;
}
rational get_lower() {
@ -217,12 +233,10 @@ namespace opt {
void set_cancel(bool f) {
m_cancel = f;
m_simplex.set_cancel(f);
m_solver.set_cancel(f);
}
void collect_statistics(::statistics& st) const {
m_simplex.collect_statistics(st);
m_solver.collect_statistics(st);
}
void reset() {
@ -264,7 +278,7 @@ namespace opt {
out << "inconsistent: " << m_inconsistent << "\n";
out << "weight: " << m_weight << "\n";
for (unsigned i = 0; i < m_weights.size(); ++i) {
out << i << ": " << value(i) << " " << m_weights[i] << "\n";
out << i << ": " << value(i) << " w: " << m_weights[i] << " s: " << m_scores[i] << "\n";
}
for (unsigned i = 0; i < m_T.size(); ++i) {
display(out << "+" << i << ": ", m_T[i]);
@ -345,139 +359,33 @@ namespace opt {
}
};
lbool U1() {
scoped_select _sc(*this);
while (true) {
lbool is_sat = compute_U1();
if (is_sat != l_true) {
return is_sat;
void inc_score(unsigned clause_id) {
set const& S = m_T[clause_id];
if (!has_selected(S)) {
for (unsigned j = 0; j < S.size(); ++j) {
++m_scores[S[j]];
}
unsigned i = 0, j = 0;
set_undef_to_false();
if (values_satisfy_Fs(i)) {
if (m_upper > m_max_weight) {
IF_VERBOSE(1, verbose_stream() << "(hs.bound_degradation " << m_upper << " )\n";);
}
return l_true;
}
//
// pick some unsatisfied clause from m_F,
// and set the value of the most expensive
// literal to true.
//
IF_VERBOSE(1, verbose_stream() << "(hs.refining exclusion set " << i << ")\n";);
set const& F = m_F[i];
rational max_value(0);
j = 0;
for (i = 0; i < F.size(); ++i) {
SASSERT(m_model[F[i]] == l_true);
if (max_value < m_weights[F[i]]) {
max_value = m_weights[F[i]];
j = F[i];
}
}
IF_VERBOSE(1, verbose_stream() << "(hs.unselect " << j << ")\n";);
assign(j, l_false, justification(justification::DECISION));
for (i = 0; i < m_T.size(); ++i) {
set const& S = m_T[i];
for (j = 0; j < S.size(); ++j) {
if (l_false != value(S[j])) break;
}
if (j == S.size()) {
IF_VERBOSE(1, verbose_stream() << "(hs.fallback-to-SAT)\n";);
return compute_U2();
}
}
TRACE("opt", display(tout););
}
}
lbool compute_U2() {
lbool is_sat = l_true;
while (true) {
is_sat = m_solver.check();
if (is_sat == l_true) {
sat::model const& model = m_solver.get_model();
m_model.reset();
m_upper.reset();
for (unsigned i = 0; i < m_vars.size(); ++i) {
m_model.push_back(model[m_vars[i]]);
if (model[m_vars[i]] == l_true) {
m_upper += m_weights[i];
}
}
IF_VERBOSE(1, verbose_stream() << "(hs.upper " << m_upper << ")\n";);
m_solver.pop(m_solver.scope_lvl());
void dec_score(unsigned clause_id) {
set const& S = m_T[clause_id];
if (!has_selected(S)) {
for (unsigned j = 0; j < S.size(); ++j) {
SASSERT(m_scores[S[j]] > 0);
--m_scores[S[j]];
}
break;
}
return is_sat;
}
bool block_model(sat::model const& model) {
rational value(0);
svector<sat::literal> lits;
for (unsigned i = 0; i < m_vars.size(); ++i) {
if (value >= m_max_weight) {
m_solver.mk_clause(lits.size(), lits.c_ptr());
return true;
void update_score(unsigned idx, bool inc) {
unsigned_vector const& uses = m_tuse_list[idx];
for (unsigned i = 0; i < uses.size(); ++i) {
if (inc) {
inc_score(uses[i]);
}
if (model[m_vars[i]] == l_true) {
value += m_weights[i];
lits.push_back(sat::literal(m_vars[i], true));
}
}
return false;
}
// compute upper bound for hitting set.
lbool compute_U1() {
rational w(0);
scoped_select _sc(*this);
// score each variable by the number of
// unassigned sets they occur in.
//
// Sort indices.
// The least literals are those where score/w is maximized.
//
value_lt lt(m_weights, m_scores);
while (true) {
if (canceled()) {
return l_undef;
}
init_scores();
std::sort(m_indices.begin(), m_indices.end(), lt);
unsigned idx = m_indices[0];
if (m_scores[idx] == 0) {
break;
}
assign(idx, l_true, justification(justification::DECISION));
}
m_upper = m_weight;
m_model.reset();
m_model.append(m_value);
return l_true;
}
void init_scores() {
unsigned_vector & scores = m_scores;
scores.reset();
for (unsigned i = 0; i < m_value.size(); ++i) {
scores.push_back(0);
}
for (unsigned i = 0; i < m_T.size(); ++i) {
set const& S = m_T[i];
if (!has_selected(S)) {
for (unsigned j = 0; j < S.size(); ++j) {
if (value(S[j]) != l_false) {
++scores[S[j]];
}
}
else {
dec_score(uses[i]);
}
}
}
@ -508,7 +416,6 @@ namespace opt {
for (unsigned i = 0; i < m_T.size(); ++i) {
if (!has_selected(m_T[i])) ++n;
}
init_scores();
value_lt lt(m_weights, m_scores);
std::sort(m_indices.begin(), m_indices.end(), lt);
@ -664,28 +571,47 @@ namespace opt {
}
}
void assign(unsigned j, lbool val, justification const& justification) {
void assign(unsigned idx, lbool val, justification const& justification) {
if (val == l_true) {
m_weight += m_weights[j];
m_weight += m_weights[idx];
update_score(idx, false);
}
m_value[j] = val;
m_justification[j] = justification;
m_trail.push_back(j);
m_level[j] = scope_lvl();
TRACE("opt", tout << j << " := " << val << " scope: " << scope_lvl() << " w: " << m_weight << "\n";);
SASSERT(val != l_true || m_scores[idx] == 0);
m_value[idx] = val;
m_justification[idx] = justification;
m_trail.push_back(idx);
m_level[idx] = scope_lvl();
TRACE("opt", tout << idx << " := " << val << " scope: " << scope_lvl() << " w: " << m_weight << "\n";);
}
svector<unsigned> m_replay_idx;
svector<lbool> m_replay_val;
void unassign(unsigned sz) {
for (unsigned j = sz; j < m_trail.size(); ++j) {
unsigned idx = m_trail[j];
if (value(idx) == l_true) {
m_weight -= m_weights[idx];
}
unsigned idx = m_trail[j];
lbool val = value(idx);
m_value[idx] = l_undef;
if (val == l_true) {
m_weight -= m_weights[idx];
update_score(idx, true);
}
if (m_justification[idx].is_axiom()) {
m_replay_idx.push_back(idx);
m_replay_val.push_back(val);
}
}
TRACE("opt", tout << m_weight << "\n";);
m_trail.shrink(sz);
m_qhead = sz;
for (unsigned i = m_replay_idx.size(); i > 0; ) {
--i;
unsigned idx = m_replay_idx[i];
lbool val = m_replay_val[i];
assign(idx, val, justification(justification::AXIOM));
}
m_replay_idx.reset();
m_replay_val.reset();
}
@ -778,6 +704,9 @@ namespace opt {
TRACE("opt", display(tout););
unsigned conflict_l = m_conflict_l;
justification conflict_j(m_conflict_j);
if (conflict_j.is_axiom()) {
return false;
}
m_conflict_lvl = get_max_lvl(conflict_l, conflict_j);
if (m_conflict_lvl == 0) {
return false;
@ -808,6 +737,17 @@ namespace opt {
process_antecedent(T[i], num_marks);
}
}
else if (conflict_j.is_decision()) {
--num_marks;
SASSERT(num_marks == 0);
break;
}
else if (conflict_j.is_axiom()) {
IF_VERBOSE(0, verbose_stream() << "axiom " << conflict_l << " " << value(conflict_l) << " " << num_marks << "\n";);
--num_marks;
SASSERT(num_marks == 0);
break;
}
while (true) {
unsigned l = m_trail[idx];
if (is_marked(l)) break;
@ -818,16 +758,19 @@ namespace opt {
conflict_j = m_justification[conflict_l];
--idx;
--num_marks;
if (num_marks == 0 && value(conflict_l) == l_false) {
++num_marks;
}
reset_mark(conflict_l);
}
while (num_marks > 0);
m_lemma[0] = conflict_l;
TRACE("opt",
for (unsigned i = 0; i < m_lemma.size(); ++i) {
tout << m_lemma[i] << " " << value(m_lemma[i]) << " ";
tout << m_lemma[i] << " ";
}
tout << "\n";);
SASSERT(value(conflict_l) == l_true);
unsigned new_scope_lvl = 0;
for (unsigned i = 1; i < m_lemma.size(); ++i) {
SASSERT(l_true == value(m_lemma[i]));
@ -836,17 +779,16 @@ namespace opt {
}
pop(scope_lvl() - new_scope_lvl);
SASSERT(l_undef == value(conflict_l));
unsigned clause_id = add_exists_false(m_lemma.size(), m_lemma.c_ptr());
if (clause_id != UINT_MAX) {
assign(conflict_l, l_false, justification(clause_id, false));
}
justification j = add_exists_false(m_lemma.size(), m_lemma.c_ptr());
if (!j.is_axiom()) assign(conflict_l, l_false, j);
return true;
}
void process_antecedent(unsigned antecedent, unsigned& num_marks) {
unsigned alvl = lvl(antecedent);
SASSERT(alvl <= m_conflict_lvl);
if (!is_marked(antecedent) && alvl > 0) {
if (!is_marked(antecedent) && alvl > 0 && !m_justification[antecedent].is_axiom()) {
mark(antecedent);
if (alvl == m_conflict_lvl || value(antecedent) == l_false) {
++num_marks;
@ -882,7 +824,6 @@ namespace opt {
unsigned next_var() {
value_lt lt(m_weights, m_scores);
init_scores();
std::sort(m_indices.begin(), m_indices.end(), lt);
unsigned idx = m_indices[0];
if (m_scores[idx] == 0) {
@ -922,18 +863,19 @@ namespace opt {
break;
}
}
//prune_branch();
prune_branch();
}
void propagate(unsigned idx, lbool good_val, vector<unsigned_vector>& watch, vector<set>& Fs)
{
TRACE("opt", tout << idx << " " << value(idx) << "\n";);
unsigned sz = watch[idx].size();
unsigned_vector& w = watch[idx];
unsigned sz = w.size();
lbool bad_val = ~good_val;
SASSERT(value(idx) == bad_val);
unsigned l = 0;
for (unsigned i = 0; i < sz && !canceled(); ++i, ++l) {
unsigned clause_id = watch[idx][i];
unsigned clause_id = w[i];
set& F = Fs[clause_id];
SASSERT(F.size() >= 2);
unsigned k1 = (F[0] == idx)?0:1;
@ -941,11 +883,12 @@ namespace opt {
SASSERT(F[k1] == idx);
SASSERT(value(F[k1]) == bad_val);
if (value(F[k2]) == good_val) {
watch[idx][l] = watch[idx][i];
w[l] = w[i];
continue;
}
bool found = false;
for (unsigned j = 2; !found && j < F.size(); ++j) {
unsigned sz2 = F.size();
for (unsigned j = 2; !found && j < sz2; ++j) {
unsigned idx2 = F[j];
if (value(idx2) != bad_val) {
found = true;
@ -957,15 +900,20 @@ namespace opt {
if (!found) {
if (value(F[k2]) == bad_val) {
set_conflict(F[k2], justification(clause_id, good_val == l_true));
for (; l <= i && i < sz; ++i, ++l) {
watch[idx][l] = watch[idx][i];
if (i == l) {
l = sz;
}
else {
for (; i < sz; ++i, ++l) {
w[l] = w[i];
}
}
break;
}
else {
SASSERT(value(F[k2]) == l_undef);
assign(F[k2], good_val, justification(clause_id, good_val == l_true));
watch[idx][l] = watch[idx][i];
w[l] = w[i];
}
}
}
@ -982,16 +930,20 @@ namespace opt {
}
void prune_branch() {
if (infeasible_lookahead()) {
m_inconsistent = true;
if (!inconsistent() && infeasible_lookahead()) {
IF_VERBOSE(4, verbose_stream() << "(hs.prune-branch " << m_weight << ")\n";);
m_lemma.reset();
for (unsigned i = m_trail.size(); i > 0; ) {
--i;
if (value(m_trail[i]) == l_true) {
m_conflict_l = m_trail[i];
m_conflict_j = m_justification[m_conflict_l];
break;
}
unsigned idx = m_trail[i];
if (m_justification[idx].is_decision()) {
SASSERT(value(idx) == l_true);
m_lemma.push_back(idx);
}
}
justification j = add_exists_false(m_lemma.size(), m_lemma.c_ptr());
TRACE("opt", display(tout, j););
set_conflict(m_lemma[0], j);
}
}