3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-29 20:05:51 +00:00

fixes and tests for arith-sls

This commit is contained in:
Nikolaj Bjorner 2023-02-28 17:40:00 -08:00
parent e87fa1c299
commit 25d45a3500
7 changed files with 182 additions and 136 deletions

View file

@ -29,42 +29,49 @@ namespace arith {
m_terms.reset();
}
void sls::log() {
IF_VERBOSE(2, verbose_stream() << "(sls :flips " << m_stats.m_num_flips << " :unsat " << unsat().size() << ")\n");
}
void sls::save_best_values() {
for (unsigned v = 0; v < s.get_num_vars(); ++v)
m_vars[v].m_best_value = m_vars[v].m_value;
auto check_bool_var = [&](sat::bool_var bv) {
auto const* ineq = atom(bv);
if (!ineq)
return;
sat::literal lit(bv, !m_bool_search->get_value(bv));
int64_t d = dtt(lit.sign(), *ineq);
// verbose_stream() << "check " << lit << " " << *ineq << "\n";
if (is_true(lit) != (d == 0)) {
verbose_stream() << lit << " " << *ineq << "\n";
check_ineqs();
if (unsat().size() == 1) {
auto idx = *unsat().begin();
verbose_stream() << idx << "\n";
auto const& c = *m_bool_search->m_clauses[idx].m_clause;
verbose_stream() << c << "\n";
for (auto lit : c) {
bool_var bv = lit.var();
ineq* i = atom(bv);
if (i)
verbose_stream() << lit << ": " << *i << "\n";
}
VERIFY(is_true(lit) == (d == 0));
};
for (unsigned v = 0; v < s.get_num_vars(); ++v)
check_bool_var(v);
verbose_stream() << "\n";
}
}
void sls::store_best_values() {
// first compute assignment to terms
// then update non-basic variables in tableau.
for (auto const& [t, v] : m_terms) {
if (!unsat().empty())
return;
for (auto const& [t,v] : m_terms) {
int64_t val = 0;
lp::lar_term const& term = s.lp().get_term(t);
for (lp::lar_term::ival arg : term) {
for (lp::lar_term::ival const& arg : term) {
auto t2 = s.lp().column2tv(arg.column());
auto w = s.lp().local_to_external(t2.id());
val += to_numeral(arg.coeff()) * m_vars[w].m_best_value;
}
update(v, val);
if (v == 52) {
verbose_stream() << "update v" << v << " := " << val << "\n";
for (lp::lar_term::ival const& arg : term) {
auto t2 = s.lp().column2tv(arg.column());
auto w = s.lp().local_to_external(t2.id());
verbose_stream() << "v" << w << " := " << m_vars[w].m_best_value << " * " << to_numeral(arg.coeff()) << "\n";
}
}
m_vars[v].m_best_value = val;
}
for (unsigned v = 0; v < s.get_num_vars(); ++v) {
@ -80,16 +87,15 @@ namespace arith {
rational new_value_(new_value, rational::i64());
lp::impq val(new_value_, rational::zero());
s.lp().set_value_for_nbasic_column(vj.index(), val);
// TODO - figure out why this leads to unsound (unsat).
}
}
lbool r = s.make_feasible();
VERIFY (!unsat().empty() || r == l_true);
if (unsat().empty()) {
#if 0
if (unsat().empty())
s.m_num_conflicts = s.get_config().m_arith_propagation_threshold;
}
verbose_stream() << "has changed " << s.m_solver->has_changed_columns() << "\n";
#endif
auto check_bool_var = [&](sat::bool_var bv) {
auto* ineq = m_bool_vars.get(bv, nullptr);
@ -105,10 +111,10 @@ namespace arith {
return;
switch (b->get_bound_kind()) {
case lp_api::lower_t:
verbose_stream() << bv << " " << bound << " <= " << s.get_value(v) << "\n";
verbose_stream() << "v" << v << " " << bound << " <= " << s.get_value(v) << " " << m_vars[v].m_best_value << "\n";
break;
case lp_api::upper_t:
verbose_stream() << bv << " " << bound << " >= " << s.get_value(v) << "\n";
verbose_stream() << "v" << v << " " << bound << " >= " << s.get_value(v) << " " << m_vars[v].m_best_value << "\n";
break;
}
int64_t value = 0;
@ -117,6 +123,12 @@ namespace arith {
}
ineq->m_args_value = value;
verbose_stream() << *ineq << " dtt " << dtt(false, *ineq) << " phase " << s.get_phase(bv) << " model " << m_bool_search->get_model()[bv] << "\n";
for (auto const& [coeff, v] : ineq->m_args)
verbose_stream() << "v" << v << " := " << m_vars[v].m_best_value << "\n";
s.display(verbose_stream());
display(verbose_stream());
UNREACHABLE();
exit(0);
};
if (unsat().empty()) {
@ -200,16 +212,16 @@ namespace arith {
return dtt(sign, ineq.m_args_value + coeff * (new_value - old_value), ineq);
}
bool sls::cm(bool sign, ineq const& ineq, var_t v, int64_t& new_value) {
bool sls::cm(bool old_sign, ineq const& ineq, var_t v, int64_t& new_value) {
for (auto const& [coeff, w] : ineq.m_args)
if (w == v)
return cm(sign, ineq, v, coeff, new_value);
return cm(old_sign, ineq, v, coeff, new_value);
return false;
}
bool sls::cm(bool new_sign, ineq const& ineq, var_t v, int64_t coeff, int64_t& new_value) {
SASSERT(ineq.is_true() == new_sign);
VERIFY(ineq.is_true() == new_sign);
bool sls::cm(bool old_sign, ineq const& ineq, var_t v, int64_t coeff, int64_t& new_value) {
SASSERT(ineq.is_true() != old_sign);
VERIFY(ineq.is_true() != old_sign);
auto bound = ineq.m_bound;
auto argsv = ineq.m_args_value;
bool solved = false;
@ -239,7 +251,7 @@ namespace arith {
return true;
};
if (new_sign) {
if (!old_sign) {
switch (ineq.m_op) {
case ineq_kind::LE:
// args <= bound -> args > bound
@ -300,10 +312,10 @@ namespace arith {
int64_t new_value;
auto v = ineq.m_var_to_flip;
if (v == UINT_MAX) {
// verbose_stream() << "no var to flip\n";
IF_VERBOSE(1, verbose_stream() << "no var to flip\n");
return false;
}
if (!cm(!sign, ineq, v, new_value)) {
if (!cm(sign, ineq, v, new_value)) {
verbose_stream() << "no critical move for " << v << "\n";
return false;
}
@ -316,16 +328,16 @@ namespace arith {
// TODO - use cached dts instead of computed dts
// cached dts has to be updated when the score of literals are updated.
//
double sls::dscore(var_t v, int64_t new_value) const {
verbose_stream() << "dscore\n";
double sls::dscore(var_t v, int64_t new_value) const {
double score = 0;
#if 0
auto const& vi = m_vars[v];
verbose_stream() << "dscore " << v << "\n";
for (auto const& [coeff, lit] : vi.m_literals)
for (auto cl : m_bool_search->get_use_list(lit))
score += (compute_dts(cl) - dts(cl, v, new_value)) * m_bool_search->get_weight(cl);
#endif
for (auto const& [coeff, bv] : vi.m_bool_vars) {
sat::literal lit(bv, false);
for (auto cl : m_bool_search->get_use_list(lit))
score += (compute_dts(cl) - dts(cl, v, new_value)) * m_bool_search->get_weight(cl);
for (auto cl : m_bool_search->get_use_list(~lit))
score += (compute_dts(cl) - dts(cl, v, new_value)) * m_bool_search->get_weight(cl);
}
return score;
}
@ -341,12 +353,12 @@ namespace arith {
int64_t old_value = vi.m_value;
for (auto const& [coeff, bv] : vi.m_bool_vars) {
auto const& ineq = *atom(bv);
bool sign = !m_bool_search->value(bv);
int64_t dtt_old = dtt(sign, ineq);
int64_t dtt_new = dtt(sign, ineq, coeff, old_value, new_value);
bool old_sign = sign(bv);
int64_t dtt_old = dtt(old_sign, ineq);
int64_t dtt_new = dtt(old_sign, ineq, coeff, old_value, new_value);
if ((dtt_old == 0) == (dtt_new == 0))
continue;
sat::literal lit(bv, sign);
sat::literal lit(bv, old_sign);
if (dtt_old == 0)
// flip from true to false
lit.neg();
@ -408,14 +420,14 @@ namespace arith {
auto old_value = vi.m_value;
for (auto const& [coeff, bv] : vi.m_bool_vars) {
auto& ineq = *atom(bv);
bool sign = !m_bool_search->value(bv);
sat::literal lit(bv, sign);
bool old_sign = sign(bv);
sat::literal lit(bv, old_sign);
SASSERT(is_true(lit));
ineq.m_args_value += coeff * (new_value - old_value);
int64_t dtt_new = dtt(sign, ineq);
int64_t dtt_new = dtt(old_sign, ineq);
if (dtt_new != 0)
m_bool_search->flip(bv);
SASSERT(dtt(!m_bool_search->value(bv), ineq) == 0);
SASSERT(dtt(sign(bv), ineq) == 0);
}
vi.m_value = new_value;
}
@ -451,7 +463,7 @@ namespace arith {
void sls::add_args(sat::bool_var bv, ineq& ineq, lp::tv t, theory_var v, int64_t sign) {
if (t.is_term()) {
lp::lar_term const& term = s.lp().get_term(t);
m_terms.push_back({t,v});
for (lp::lar_term::ival arg : term) {
auto t2 = s.lp().column2tv(arg.column());
auto w = s.lp().local_to_external(t2.id());
@ -479,6 +491,7 @@ namespace arith {
auto& ineq = new_ineq(op, to_numeral(bound));
add_args(bv, ineq, t, b->get_var(), should_minus ? -1 : 1);
m_bool_vars.set(bv, &ineq);
m_bool_search->set_external(bv);
@ -516,7 +529,7 @@ namespace arith {
}
void sls::flip(sat::bool_var v) {
sat::literal lit(v, m_bool_search->get_value(v));
sat::literal lit(v, !sign(v));
SASSERT(!is_true(lit));
auto const* ineq = atom(v);
if (!ineq)
@ -524,7 +537,7 @@ namespace arith {
if (!ineq)
return;
SASSERT(ineq->is_true() == lit.sign());
flip(!lit.sign(), *ineq);
flip(sign(v), *ineq);
}
double sls::reward(sat::bool_var v) {
@ -535,21 +548,23 @@ namespace arith {
}
double sls::dtt_reward(sat::bool_var bv0) {
bool sign0 = !m_bool_search->get_value(bv0);
bool sign0 = sign(bv0);
auto* ineq = atom(bv0);
if (!ineq)
return -1;
int64_t new_value;
double max_result = -1;
for (auto const & [coeff, x] : ineq->m_args) {
if (!cm(!sign0, *ineq, x, coeff, new_value))
if (!cm(sign0, *ineq, x, coeff, new_value))
continue;
double result = 0;
auto old_value = m_vars[x].m_value;
for (auto const& [coeff, bv] : m_vars[x].m_bool_vars) {
bool sign = !m_bool_search->value(bv);
auto dtt_old = dtt(sign, *atom(bv));
auto dtt_new = dtt(sign, *atom(bv), coeff, old_value, new_value);
result += m_bool_search->reward(bv);
continue;
bool old_sign = sign(bv);
auto dtt_old = dtt(old_sign, *atom(bv));
auto dtt_new = dtt(old_sign, *atom(bv), coeff, old_value, new_value);
if ((dtt_new == 0) != (dtt_old == 0))
result += m_bool_search->reward(bv);
}
@ -563,17 +578,17 @@ namespace arith {
double sls::dscore_reward(sat::bool_var bv) {
m_dscore_mode = false;
bool sign = !m_bool_search->get_value(bv);
sat::literal litv(bv, sign);
bool old_sign = sign(bv);
sat::literal litv(bv, old_sign);
auto* ineq = atom(bv);
if (!ineq)
return 0;
SASSERT(ineq->is_true() != sign);
SASSERT(ineq->is_true() != old_sign);
int64_t new_value;
for (auto const& [coeff, v] : ineq->m_args) {
double result = 0;
if (cm(!sign, *ineq, v, coeff, new_value))
if (cm(old_sign, *ineq, v, coeff, new_value))
result = dscore(v, new_value);
// just pick first positive, or pick a max?
if (result > 0) {
@ -586,7 +601,7 @@ namespace arith {
// switch to dscore mode
void sls::on_rescale() {
// m_dscore_mode = true;
m_dscore_mode = true;
}
void sls::on_save_model() {
@ -597,23 +612,39 @@ namespace arith {
for (unsigned v = 0; v < s.s().num_vars(); ++v)
init_bool_var_assignment(v);
verbose_stream() << "on-restart\n";
check_ineqs();
}
void sls::check_ineqs() {
auto check_bool_var = [&](sat::bool_var bv) {
auto const* ineq = atom(bv);
if (!ineq)
return;
bool sign = !m_bool_search->get_value(bv);
int64_t d = dtt(sign, *ineq);
sat::literal lit(bv, sign);
// verbose_stream() << "check " << lit << " " << *ineq << "\n";
int64_t d = dtt(sign(bv), *ineq);
sat::literal lit(bv, sign(bv));
if (is_true(lit) != (d == 0)) {
verbose_stream() << "restart " << bv << " " << *ineq << "\n";
verbose_stream() << "invalid assignment " << bv << " " << *ineq << "\n";
}
VERIFY(is_true(lit) == (d == 0));
};
for (unsigned v = 0; v < s.get_num_vars(); ++v)
for (unsigned v = 0; v < s.get_num_vars(); ++v)
check_bool_var(v);
verbose_stream() << "on-restart-done\n";
}
std::ostream& sls::display(std::ostream& out) const {
for (bool_var bv = 0; bv < s.s().num_vars(); ++bv) {
auto const* ineq = atom(bv);
if (!ineq)
continue;
out << bv << " " << *ineq << "\n";
}
for (unsigned v = 0; v < s.get_num_vars(); ++v) {
if (s.is_bool(v))
continue;
out << "v" << v << " := " << m_vars[v].m_value << " " << m_vars[v].m_best_value << "\n";
}
return out;
}
}