mirror of
https://github.com/Z3Prover/z3
synced 2025-04-08 10:25:18 +00:00
bug fixes to sls
This commit is contained in:
parent
e380903d61
commit
c7ea4964f2
|
@ -59,8 +59,6 @@ namespace sls {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
template<typename num_t>
|
||||
std::ostream& arith_base<num_t>::ineq::display(std::ostream& out) const {
|
||||
bool first = true;
|
||||
|
@ -118,7 +116,7 @@ namespace sls {
|
|||
template<typename num_t>
|
||||
void arith_base<num_t>::save_best_values() {
|
||||
for (auto& v : m_vars)
|
||||
v.m_best_value = v.m_value;
|
||||
v.set_best_value(v.value());
|
||||
check_ineqs();
|
||||
}
|
||||
|
||||
|
@ -168,8 +166,8 @@ namespace sls {
|
|||
template<typename num_t>
|
||||
num_t arith_base<num_t>::dtt(bool sign, ineq const& ineq, var_t v, num_t const& new_value) const {
|
||||
for (auto const& [coeff, w] : ineq.m_args)
|
||||
if (w == v)
|
||||
return dtt(sign, ineq.m_args_value + coeff * (new_value - m_vars[v].m_value), ineq);
|
||||
if (w == v)
|
||||
return dtt(sign, ineq.m_args_value + coeff * (new_value - m_vars[v].value()), ineq);
|
||||
return num_t(1);
|
||||
}
|
||||
|
||||
|
@ -444,17 +442,19 @@ namespace sls {
|
|||
|
||||
delta_out = delta;
|
||||
|
||||
if (m_last_var == v && m_last_delta == -delta)
|
||||
return false;
|
||||
if (m_last_var == v && m_last_delta == -delta)
|
||||
return false;
|
||||
|
||||
if (m_use_tabu && vi.is_tabu(m_stats.m_num_steps, delta))
|
||||
if (m_use_tabu && vi.is_tabu(m_stats.m_num_steps, delta))
|
||||
return false;
|
||||
|
||||
|
||||
auto old_value = value(v);
|
||||
auto new_value = old_value + delta;
|
||||
if (!vi.in_range(new_value))
|
||||
return false;
|
||||
|
||||
|
||||
if (m_use_tabu && !in_bounds(v, new_value) && in_bounds(v, old_value)) {
|
||||
auto const& lo = m_vars[v].m_lo;
|
||||
auto const& hi = m_vars[v].m_hi;
|
||||
|
@ -490,9 +490,7 @@ namespace sls {
|
|||
void arith_base<num_t>::add_update(var_t v, num_t delta) {
|
||||
num_t delta_out;
|
||||
if (!is_permitted_update(v, delta, delta_out))
|
||||
return;
|
||||
|
||||
|
||||
return;
|
||||
m_updates.push_back({ v, delta_out, 0 });
|
||||
}
|
||||
|
||||
|
@ -647,7 +645,7 @@ namespace sls {
|
|||
bool arith_base<num_t>::update(var_t v, num_t const& new_value) {
|
||||
auto& vi = m_vars[v];
|
||||
expr* e = vi.m_expr;
|
||||
auto old_value = vi.m_value;
|
||||
auto old_value = vi.value();
|
||||
if (old_value == new_value)
|
||||
return true;
|
||||
if (!vi.in_range(new_value))
|
||||
|
@ -665,15 +663,10 @@ namespace sls {
|
|||
}
|
||||
}
|
||||
catch (overflow_exception const&) {
|
||||
verbose_stream() << "overflow1\n";
|
||||
return false;
|
||||
}
|
||||
|
||||
#if 0
|
||||
if (!check_update(v, new_value))
|
||||
return false;
|
||||
apply_checked_update();
|
||||
#else
|
||||
|
||||
buffer<sat::bool_var> to_flip;
|
||||
for (auto const& [coeff, bv] : vi.m_bool_vars) {
|
||||
auto& ineq = *atom(bv);
|
||||
|
@ -687,12 +680,13 @@ namespace sls {
|
|||
|
||||
}
|
||||
IF_VERBOSE(5, verbose_stream() << "repair: v" << v << " := " << old_value << " -> " << new_value << "\n");
|
||||
vi.m_value = new_value;
|
||||
vi.set_value(new_value);
|
||||
ctx.new_value_eh(e);
|
||||
m_last_var = v;
|
||||
|
||||
for (auto bv : to_flip) {
|
||||
ctx.flip(bv);
|
||||
if (dtt(sign(bv), *atom(bv)) != 0)
|
||||
ctx.flip(bv);
|
||||
SASSERT(dtt(sign(bv), *atom(bv)) == 0);
|
||||
}
|
||||
|
||||
|
@ -711,6 +705,7 @@ namespace sls {
|
|||
prod *= power_of(value(w), p);
|
||||
}
|
||||
catch (overflow_exception const&) {
|
||||
verbose_stream() << "overflow\n";
|
||||
return false;
|
||||
}
|
||||
if (value(w) != prod && !update(w, prod))
|
||||
|
@ -727,82 +722,10 @@ namespace sls {
|
|||
if (!update(ad.m_var, sum))
|
||||
return false;
|
||||
}
|
||||
#endif
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
template<typename num_t>
|
||||
bool arith_base<num_t>::check_update(var_t v, num_t new_value) {
|
||||
|
||||
++m_update_timestamp;
|
||||
if (m_update_timestamp == 0) {
|
||||
for (auto& vi : m_vars)
|
||||
vi.set_update_value(num_t(0), 0);
|
||||
++m_update_timestamp;
|
||||
}
|
||||
auto& vi = m_vars[v];
|
||||
m_update_trail.reset();
|
||||
m_update_trail.push_back(v);
|
||||
vi.set_update_value(new_value, m_update_timestamp);
|
||||
|
||||
num_t delta;
|
||||
for (unsigned i = 0; i < m_update_trail.size(); ++i) {
|
||||
auto v = m_update_trail[i];
|
||||
auto& vi = m_vars[v];
|
||||
for (auto idx : vi.m_muls) {
|
||||
auto const& [w, monomial] = m_muls[idx];
|
||||
num_t prod(1);
|
||||
try {
|
||||
for (auto [w, p] : monomial)
|
||||
prod *= power_of(get_update_value(w), p);
|
||||
}
|
||||
catch (overflow_exception const&) {
|
||||
return false;
|
||||
}
|
||||
if (get_update_value(w) != prod && (!is_permitted_update(w, prod - value(w), delta) || prod - value(w) != delta))
|
||||
return false;
|
||||
m_update_trail.push_back(w);
|
||||
m_vars[w].set_update_value(prod, m_update_timestamp);
|
||||
}
|
||||
|
||||
for (auto idx : vi.m_adds) {
|
||||
auto const& ad = m_adds[idx];
|
||||
auto w = ad.m_var;
|
||||
num_t sum(ad.m_coeff);
|
||||
for (auto const& [coeff, w] : ad.m_args)
|
||||
sum += coeff * get_update_value(w);
|
||||
if (get_update_value(v) != sum && !(is_permitted_update(w, sum - value(w), delta) || sum - value(w) != delta))
|
||||
return false;
|
||||
m_update_trail.push_back(w);
|
||||
m_vars[w].set_update_value(sum, m_update_timestamp);
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
template<typename num_t>
|
||||
void arith_base<num_t>::apply_checked_update() {
|
||||
for (auto v : m_update_trail) {
|
||||
auto & vi = m_vars[v];
|
||||
auto old_value = vi.m_value;
|
||||
vi.m_value = vi.get_update_value(m_update_timestamp);
|
||||
auto new_value = vi.m_value;
|
||||
ctx.new_value_eh(vi.m_expr);
|
||||
for (auto const& [coeff, bv] : vi.m_bool_vars) {
|
||||
auto& ineq = *atom(bv);
|
||||
bool old_sign = sign(bv);
|
||||
sat::literal lit(bv, old_sign);
|
||||
SASSERT(ctx.is_true(lit));
|
||||
ineq.m_args_value += coeff * (new_value - old_value);
|
||||
num_t dtt_new = dtt(old_sign, ineq);
|
||||
if (dtt_new != 0)
|
||||
ctx.flip(bv);
|
||||
SASSERT(dtt(sign(bv), ineq) == 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template<typename num_t>
|
||||
typename arith_base<num_t>::ineq& arith_base<num_t>::new_ineq(ineq_kind op, num_t const& coeff) {
|
||||
auto* i = alloc(ineq);
|
||||
|
@ -906,7 +829,7 @@ namespace sls {
|
|||
m_vars[w].m_muls.push_back(idx), prod *= power_of(value(w), p);
|
||||
m_vars[v].m_def_idx = idx;
|
||||
m_vars[v].m_op = arith_op_kind::OP_MUL;
|
||||
m_vars[v].m_value = prod;
|
||||
m_vars[v].set_value(prod);
|
||||
add_arg(term, coeff, v);
|
||||
break;
|
||||
}
|
||||
|
@ -972,7 +895,7 @@ namespace sls {
|
|||
m_ops.push_back({v, k, v, w});
|
||||
m_vars[v].m_def_idx = idx;
|
||||
m_vars[v].m_op = k;
|
||||
m_vars[v].m_value = val;
|
||||
m_vars[v].set_value(val);
|
||||
return v;
|
||||
}
|
||||
|
||||
|
@ -993,7 +916,7 @@ namespace sls {
|
|||
m_vars[w].m_adds.push_back(idx), sum += c * value(w);
|
||||
m_vars[v].m_def_idx = idx;
|
||||
m_vars[v].m_op = arith_op_kind::OP_ADD;
|
||||
m_vars[v].m_value = sum;
|
||||
m_vars[v].set_value(sum);
|
||||
return v;
|
||||
}
|
||||
|
||||
|
@ -1055,6 +978,7 @@ namespace sls {
|
|||
else {
|
||||
SASSERT(!a.is_arith_expr(e));
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
template<typename num_t>
|
||||
|
@ -1345,6 +1269,7 @@ namespace sls {
|
|||
hi_valid = false;
|
||||
}
|
||||
catch (overflow_exception&) {
|
||||
verbose_stream() << "overflow3\n";
|
||||
hi_valid = false;
|
||||
}
|
||||
}
|
||||
|
@ -2021,7 +1946,7 @@ namespace sls {
|
|||
if (is_num(e, n))
|
||||
return expr_ref(a.mk_numeral(n.to_rational(), a.is_int(e)), m);
|
||||
auto v = mk_term(e);
|
||||
return expr_ref(a.mk_numeral(m_vars[v].m_value.to_rational(), a.is_int(e)), m);
|
||||
return expr_ref(a.mk_numeral(m_vars[v].value().to_rational(), a.is_int(e)), m);
|
||||
}
|
||||
|
||||
template<typename num_t>
|
||||
|
@ -2112,7 +2037,7 @@ namespace sls {
|
|||
auto const& vi = m_vars[v];
|
||||
auto const& lo = vi.m_lo;
|
||||
auto const& hi = vi.m_hi;
|
||||
out << "v" << v << " := " << vi.m_value << " ";
|
||||
out << "v" << v << " := " << vi.value() << " ";
|
||||
if (lo || hi) {
|
||||
if (lo)
|
||||
out << (lo->is_strict ? "(": "[") << lo->value;
|
||||
|
|
|
@ -76,13 +76,14 @@ namespace sls {
|
|||
|
||||
class var_info {
|
||||
num_t m_range{ 100000000 };
|
||||
num_t m_update_value{ 0 };
|
||||
unsigned m_update_timestamp = 0;
|
||||
unsigned m_num_out_of_range = 0;
|
||||
unsigned m_num_in_range = 0;
|
||||
num_t m_value{ 0 };
|
||||
num_t m_best_value{ 0 };
|
||||
public:
|
||||
var_info(expr* e, var_sort k): m_expr(e), m_sort(k) {}
|
||||
expr* m_expr;
|
||||
num_t m_value{ 0 };
|
||||
num_t m_best_value{ 0 };
|
||||
|
||||
var_sort m_sort;
|
||||
arith_op_kind m_op = arith_op_kind::LAST_ARITH_OP;
|
||||
unsigned m_def_idx = UINT_MAX;
|
||||
|
@ -91,23 +92,27 @@ namespace sls {
|
|||
unsigned_vector m_adds;
|
||||
optional<bound> m_lo, m_hi;
|
||||
|
||||
// retrieve temporary value during an update.
|
||||
void set_update_value(num_t const& v, unsigned timestamp) {
|
||||
m_update_value = v;
|
||||
m_update_timestamp = timestamp;
|
||||
}
|
||||
num_t const& get_update_value(unsigned ts) const {
|
||||
return ts == m_update_timestamp ? m_update_value : m_value;
|
||||
}
|
||||
num_t const& value() const { return m_value; }
|
||||
void set_value(num_t const& v) { m_value = v; }
|
||||
|
||||
bool in_range(num_t const& n) const {
|
||||
num_t const& best_value() const { return m_best_value; }
|
||||
void set_best_value(num_t const& v) { m_best_value = v; }
|
||||
|
||||
bool in_range(num_t const& n) {
|
||||
if (-m_range < n && n < m_range)
|
||||
return true;
|
||||
bool result = false;
|
||||
if (m_lo && !m_hi)
|
||||
return n < m_lo->value + m_range;
|
||||
if (!m_lo && m_hi)
|
||||
return n > m_hi->value - m_range;
|
||||
return false;
|
||||
result = n < m_lo->value + m_range;
|
||||
else if (!m_lo && m_hi)
|
||||
result = n > m_hi->value - m_range;
|
||||
#if 0
|
||||
if (!result)
|
||||
out_of_range();
|
||||
else
|
||||
++m_num_in_range;
|
||||
#endif
|
||||
return result;
|
||||
}
|
||||
unsigned m_tabu_pos = 0, m_tabu_neg = 0;
|
||||
unsigned m_last_pos = 0, m_last_neg = 0;
|
||||
|
@ -120,6 +125,15 @@ namespace sls {
|
|||
else
|
||||
m_tabu_neg = tabu_step, m_last_neg = step;
|
||||
}
|
||||
void out_of_range() {
|
||||
++m_num_out_of_range;
|
||||
if (m_num_out_of_range < 1000 * (1 + m_num_in_range))
|
||||
return;
|
||||
IF_VERBOSE(2, verbose_stream() << "increase range " << m_range << "\n");
|
||||
m_range *= 2;
|
||||
m_num_out_of_range = 0;
|
||||
m_num_in_range = 0;
|
||||
}
|
||||
};
|
||||
|
||||
struct mul_def {
|
||||
|
@ -187,10 +201,7 @@ namespace sls {
|
|||
|
||||
void add_update(var_t v, num_t delta);
|
||||
bool is_permitted_update(var_t v, num_t const& delta, num_t& delta_out);
|
||||
unsigned m_update_timestamp = 0;
|
||||
svector<var_t> m_update_trail;
|
||||
bool check_update(var_t v, num_t new_value);
|
||||
void apply_checked_update();
|
||||
|
||||
|
||||
num_t value1(var_t v);
|
||||
|
||||
|
@ -247,8 +258,7 @@ namespace sls {
|
|||
|
||||
bool is_int(var_t v) const { return m_vars[v].m_sort == var_sort::INT; }
|
||||
|
||||
num_t value(var_t v) const { return m_vars[v].m_value; }
|
||||
num_t const& get_update_value(var_t v) const { return m_vars[v].get_update_value(m_update_timestamp); }
|
||||
num_t value(var_t v) const { return m_vars[v].value(); }
|
||||
bool is_num(expr* e, num_t& i);
|
||||
expr_ref from_num(sort* s, num_t const& n);
|
||||
void check_ineqs();
|
||||
|
|
|
@ -27,7 +27,7 @@ namespace sls {
|
|||
return m_arith64->_fn_;\
|
||||
}\
|
||||
catch (overflow_exception&) {\
|
||||
throw;\
|
||||
IF_VERBOSE(1, verbose_stream() << "revert to bignum solver " << #_fn_ << "\n");\
|
||||
init_backup();\
|
||||
}\
|
||||
}\
|
||||
|
@ -39,7 +39,7 @@ namespace sls {
|
|||
m_arith64->_fn_;\
|
||||
}\
|
||||
catch (overflow_exception&) {\
|
||||
throw;\
|
||||
IF_VERBOSE(1, verbose_stream() << "revert to bignum solver " << #_fn_ << "\n");\
|
||||
init_backup();\
|
||||
}\
|
||||
}\
|
||||
|
@ -49,11 +49,7 @@ namespace sls {
|
|||
plugin(ctx), m_shared(ctx.get_manager()) {
|
||||
m_arith64 = alloc(arith_base<checked_int64<true>>, ctx);
|
||||
m_arith = alloc(arith_base<rational>, ctx);
|
||||
m_arith64 = nullptr;
|
||||
if (m_arith)
|
||||
m_fid = m_arith->fid();
|
||||
else
|
||||
m_fid = m_arith64->fid();
|
||||
m_fid = m_arith->fid();
|
||||
}
|
||||
|
||||
void arith_plugin::init_backup() {
|
||||
|
|
|
@ -115,7 +115,7 @@ namespace sls {
|
|||
m_ddfw->rlimit().pop();
|
||||
}
|
||||
|
||||
void smt_plugin::finalize(model_ref& mdl, ::statistics& st) {
|
||||
void smt_plugin::finalize(model_ref& mdl) {
|
||||
auto* d = m_ddfw;
|
||||
if (!d)
|
||||
return;
|
||||
|
@ -126,7 +126,6 @@ namespace sls {
|
|||
if (m_thread.joinable())
|
||||
m_thread.join();
|
||||
SASSERT(m_completed);
|
||||
st.copy(m_st);
|
||||
mdl = nullptr;
|
||||
if (m_result == l_true && m_sls_model) {
|
||||
ast_translation tr(m_sls, m);
|
||||
|
@ -140,6 +139,10 @@ namespace sls {
|
|||
dealloc(d);
|
||||
}
|
||||
|
||||
void smt_plugin::collect_statistics(::statistics& st) const {
|
||||
st.copy(m_st);
|
||||
}
|
||||
|
||||
void smt_plugin::get_shared_clauses(vector<sat::literal_vector>& _clauses) {
|
||||
_clauses.reset();
|
||||
for (auto const& clause : clauses()) {
|
||||
|
@ -257,7 +260,7 @@ namespace sls {
|
|||
void smt_plugin::sls_phase_to_smt() {
|
||||
if (!m_has_new_sls_phase)
|
||||
return;
|
||||
IF_VERBOSE(2, verbose_stream() << "SLS -> SMT phase\n");
|
||||
IF_VERBOSE(2, verbose_stream() << "SLS -> SMT phase " << m_min_unsat_size << "\n");
|
||||
for (auto v : m_shared_bool_vars)
|
||||
ctx.force_phase(sat::literal(v, !m_sls_phase[v]));
|
||||
m_has_new_sls_phase = false;
|
||||
|
@ -290,7 +293,7 @@ namespace sls {
|
|||
}
|
||||
|
||||
void smt_plugin::export_from_sls() {
|
||||
if (unsat().size() > m_min_unsat_size)
|
||||
if (unsat().size() >= m_min_unsat_size)
|
||||
return;
|
||||
m_min_unsat_size = unsat().size();
|
||||
export_phase_from_sls();
|
||||
|
|
|
@ -106,7 +106,8 @@ namespace sls {
|
|||
|
||||
// interface to calling solver:
|
||||
void check(expr_ref_vector const& fmls, vector <sat::literal_vector> const& clauses);
|
||||
void finalize(model_ref& md, ::statistics& st);
|
||||
void collect_statistics(::statistics& st) const;
|
||||
void finalize(model_ref& md);
|
||||
void get_shared_clauses(vector<sat::literal_vector>& clauses);
|
||||
void updt_params(params_ref& p) {}
|
||||
std::ostream& display(std::ostream& out) override;
|
||||
|
|
|
@ -72,7 +72,8 @@ namespace sls {
|
|||
if (!m_smt_plugin)
|
||||
return;
|
||||
|
||||
m_smt_plugin->finalize(m_model, m_st);
|
||||
m_smt_plugin->collect_statistics(m_st);
|
||||
m_smt_plugin->finalize(m_model);
|
||||
m_model = nullptr;
|
||||
m_smt_plugin = nullptr;
|
||||
}
|
||||
|
@ -89,7 +90,8 @@ namespace sls {
|
|||
return false;
|
||||
if (!m_smt_plugin->completed())
|
||||
return false;
|
||||
m_smt_plugin->finalize(m_model, m_st);
|
||||
m_smt_plugin->collect_statistics(m_st);
|
||||
m_smt_plugin->finalize(m_model);
|
||||
m_smt_plugin = nullptr;
|
||||
return true;
|
||||
}
|
||||
|
|
|
@ -62,6 +62,8 @@ namespace smt {
|
|||
}
|
||||
|
||||
bool theory_sls::get_smt_value(expr* v, expr_ref& value) {
|
||||
if (!ctx.e_internalized(v))
|
||||
return false;
|
||||
auto* n = ctx.get_enode(v);
|
||||
return n && ctx.get_value(n, value);
|
||||
}
|
||||
|
@ -78,7 +80,8 @@ namespace smt {
|
|||
if (!m_smt_plugin)
|
||||
return;
|
||||
|
||||
m_smt_plugin->finalize(m_model, m_st);
|
||||
m_smt_plugin->collect_statistics(m_st);
|
||||
m_smt_plugin->finalize(m_model);
|
||||
m_model = nullptr;
|
||||
m_smt_plugin = nullptr;
|
||||
}
|
||||
|
@ -98,7 +101,8 @@ namespace smt {
|
|||
else if (!m_parallel_mode)
|
||||
propagate_local_search();
|
||||
else if (m_smt_plugin->completed()) {
|
||||
m_smt_plugin->finalize(m_model, m_st);
|
||||
m_smt_plugin->collect_statistics(m_st);
|
||||
m_smt_plugin->finalize(m_model);
|
||||
m_smt_plugin = nullptr;
|
||||
}
|
||||
}
|
||||
|
@ -184,7 +188,10 @@ namespace smt {
|
|||
}
|
||||
|
||||
void theory_sls::collect_statistics(::statistics& st) const {
|
||||
st.copy(m_st);
|
||||
if (m_smt_plugin)
|
||||
m_smt_plugin->collect_statistics(st);
|
||||
else
|
||||
st.copy(m_st);
|
||||
}
|
||||
|
||||
void theory_sls::restart_eh() {
|
||||
|
@ -205,7 +212,8 @@ namespace smt {
|
|||
void theory_sls::bounded_run(unsigned num_steps) {
|
||||
m_smt_plugin->bounded_run(num_steps);
|
||||
if (m_smt_plugin->result() == l_true) {
|
||||
m_smt_plugin->finalize(m_model, m_st);
|
||||
m_smt_plugin->collect_statistics(m_st);
|
||||
m_smt_plugin->finalize(m_model);
|
||||
m_smt_plugin = nullptr;
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue