3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-23 09:05:31 +00:00

bug fixes

This commit is contained in:
Nikolaj Bjorner 2024-08-13 14:50:17 -07:00
parent 920c207a27
commit afef727b88
4 changed files with 154 additions and 45 deletions

View file

@ -171,8 +171,10 @@ namespace sls {
num_t arith_base<num_t>::divide_floor(var_t v, num_t const& a, num_t const& b) {
if (!is_int(v))
return a / b;
if (b > 0)
if (b > 0 && a >= 0)
return div(a, b);
else if (b > 0)
return -div(-a + b - 1, b);
else if (a > 0)
return -div(a - b - 1, -b);
else
@ -183,8 +185,10 @@ namespace sls {
num_t arith_base<num_t>::divide_ceil(var_t v, num_t const& a, num_t const& b) {
if (!is_int(v))
return a / b;
if (b > 0)
if (b > 0 && a >= 0)
return div(a + b - 1, b);
else if (b > 0)
return -div(-a, b);
else if (a > 0)
return -div(a, -b);
else
@ -256,19 +260,6 @@ namespace sls {
lh += eps;
if (a * rl * rl + b * rl + c <= 0)
rl -= eps;
if (is_square && a * lh * lh + b * lh + c <= 0) {
num_t ll = divide_floor(x, -b - root, 2 * a);
num_t lh = divide_ceil(x, -b - root, 2 * a);
num_t rl = divide_floor(x, -b + root, 2 * a);
num_t rh = divide_ceil(x, -b + root, 2 * a);
verbose_stream() << a << " " << b << " " << c << "\n";
verbose_stream() << (-b - root) << " " << (2 * a) << " " << ll << " " << lh << "\n";
verbose_stream() << (-b + root) << " " << (2 * a) << " " << rl << " " << rh << "\n";
verbose_stream() << "root " << root << "\n";
UNREACHABLE();
}
SASSERT(!is_square || a * lh * lh + b * lh + c > 0);
SASSERT(!is_square || a * rl * rl + b * rl + c > 0);
add_update(x, lh - value(x));
@ -420,9 +411,7 @@ namespace sls {
num_t delta = sum;
SASSERT(sum != 0);
delta = sum < 0 ? divide(v, abs(sum), coeff) : -divide(v, sum, coeff);
if (sum + coeff * delta != 0)
solve_eq_pairs(v, ineq);
else
if (sum + coeff * delta == 0)
add_update(v, delta);
break;
}
@ -440,7 +429,7 @@ namespace sls {
if (m_last_var == v && m_last_delta == -delta)
return false;
if (m_tabu && vi.is_tabu(m_stats.m_num_steps, delta))
if (false && m_use_tabu && vi.is_tabu(m_stats.m_num_steps, delta))
return false;
auto old_value = value(v);
@ -448,7 +437,7 @@ namespace sls {
if (!vi.in_range(new_value))
return false;
if (!in_bounds(v, new_value) && in_bounds(v, old_value)) {
if (m_use_tabu && !in_bounds(v, new_value) && in_bounds(v, old_value)) {
auto const& lo = m_vars[v].m_lo;
auto const& hi = m_vars[v].m_hi;
if (lo && (lo->is_strict ? lo->value >= new_value : lo->value > new_value)) {
@ -492,7 +481,7 @@ namespace sls {
if (is_fixed(v))
return false;
auto argsv = ineq.m_args_value;
num_t a;
num_t a(0);
for (auto const& [c, w] : ineq.m_args)
if (v == w) {
a = c;
@ -501,6 +490,7 @@ namespace sls {
if (abs(a) == 1)
return false;
IF_VERBOSE(3, verbose_stream() << "solve_eq_pairs " << ineq << " for v" << v << "\n");
SASSERT(a != 0);
unsigned start = ctx.rand();
for (unsigned i = 0; i < ineq.m_args.size(); ++i) {
unsigned j = (start + i) % ineq.m_args.size();
@ -525,6 +515,7 @@ namespace sls {
if (is_fixed(y))
return false;
num_t x0, y0;
std::cout << "solve_eq_pairs " << _a << " v" << x << " " << _b << " v" << y << " " << r << "\n";
num_t a = _a, b = _b;
num_t g = gcd(a, b, x0, y0);
SASSERT(g >= 1);
@ -752,6 +743,14 @@ namespace sls {
IF_VERBOSE(10, display(verbose_stream(), v) << " := " << new_value << "\n");
#if 0
if (!check_update(v, new_value))
return false;
apply_checked_update();
#else
for (auto const& [coeff, bv] : vi.m_bool_vars) {
auto& ineq = *atom(bv);
bool old_sign = sign(bv);
@ -771,25 +770,106 @@ namespace sls {
for (auto idx : vi.m_muls) {
auto const& [w, coeff, monomial] = m_muls[idx];
ctx.new_value_eh(m_vars[w].m_expr);
num_t prod(coeff);
for (auto [w, p] : monomial)
prod *= power_of(value(w), p);
try {
for (auto [w, p] : monomial)
prod *= power_of(value(w), p);
}
catch (overflow_exception const&) {
return false;
}
if (value(w) != prod && !update(w, prod))
return false;
}
for (auto idx : vi.m_adds) {
auto const& ad = m_adds[idx];
auto w = ad.m_var;
ctx.new_value_eh(m_vars[w].m_expr);
num_t sum(ad.m_coeff);
for (auto const& [coeff, w] : ad.m_args)
sum += coeff * value(w);
if (!update(ad.m_var, sum))
return false;
}
#endif
return true;
}
template<typename num_t>
bool arith_base<num_t>::check_update(var_t v, num_t new_value) {
++m_update_timestamp;
if (m_update_timestamp == 0) {
for (auto& vi : m_vars)
vi.set_update_value(num_t(0), 0);
++m_update_timestamp;
}
auto& vi = m_vars[v];
m_update_trail.reset();
m_update_trail.push_back(v);
vi.set_update_value(new_value, m_update_timestamp);
for (unsigned i = 0; i < m_update_trail.size(); ++i) {
auto v = m_update_trail[i];
auto& vi = m_vars[v];
for (auto idx : vi.m_muls) {
auto const& [w, coeff, monomial] = m_muls[idx];
num_t prod(coeff);
try {
for (auto [w, p] : monomial)
prod *= power_of(get_update_value(w), p);
}
catch (overflow_exception const&) {
return false;
}
if (get_update_value(w) != prod && !is_permitted_update(w, prod - value(w)))
return false;
m_update_trail.push_back(w);
m_vars[w].set_update_value(prod, m_update_timestamp);
}
for (auto idx : vi.m_adds) {
auto const& ad = m_adds[idx];
auto w = ad.m_var;
num_t sum(ad.m_coeff);
for (auto const& [coeff, w] : ad.m_args)
sum += coeff * get_update_value(w);
if (get_update_value(v) != sum && !is_permitted_update(w, sum - value(w)))
return false;
m_update_trail.push_back(w);
m_vars[w].set_update_value(sum, m_update_timestamp);
}
}
return true;
}
template<typename num_t>
void arith_base<num_t>::apply_checked_update() {
for (auto v : m_update_trail) {
auto & vi = m_vars[v];
auto old_value = vi.m_value;
vi.m_value = vi.get_update_value(m_update_timestamp);
auto new_value = vi.m_value;
ctx.new_value_eh(vi.m_expr);
for (auto const& [coeff, bv] : vi.m_bool_vars) {
auto& ineq = *atom(bv);
bool old_sign = sign(bv);
sat::literal lit(bv, old_sign);
SASSERT(ctx.is_true(lit));
ineq.m_args_value += coeff * (new_value - old_value);
num_t dtt_new = dtt(old_sign, ineq);
// verbose_stream() << "dtt " << lit << " " << ineq << " " << dtt_new << "\n";
if (dtt_new != 0)
ctx.flip(bv);
SASSERT(dtt(sign(bv), ineq) == 0);
}
}
}
template<typename num_t>
typename arith_base<num_t>::ineq& arith_base<num_t>::new_ineq(ineq_kind op, num_t const& coeff) {
auto* i = alloc(ineq);
@ -1144,53 +1224,55 @@ namespace sls {
auto const& vi = m_vars[v];
if (vi.m_def_idx == UINT_MAX)
return;
num_t v1, v2;
num_t new_value, v1, v2;
switch (vi.m_op) {
case LAST_ARITH_OP:
break;
case OP_ADD: {
auto const& ad = m_adds[vi.m_def_idx];
auto const& args = ad.m_args;
num_t sum(ad.m_coeff);
new_value = ad.m_coeff;
for (auto [c, w] : args)
sum += c * value(w);
update(v, sum);
new_value += c * value(w);
break;
}
case OP_MUL: {
auto const& [w, coeff, monomial] = m_muls[vi.m_def_idx];
num_t prod(coeff);
new_value = coeff;
for (auto [w, p] : monomial)
prod *= power_of(value(w), p);
update(v, prod);
new_value *= power_of(value(w), p);
break;
}
case OP_MOD:
v1 = value(m_ops[vi.m_def_idx].m_arg1);
v2 = value(m_ops[vi.m_def_idx].m_arg2);
update(v, v2 == 0 ? num_t(0) : mod(v1, v2));
new_value = v2 == 0 ? num_t(0) : mod(v1, v2);
break;
case OP_DIV:
v1 = value(m_ops[vi.m_def_idx].m_arg1);
v2 = value(m_ops[vi.m_def_idx].m_arg2);
update(v, v2 == 0 ? num_t(0) : v1 / v2);
new_value = v2 == 0 ? num_t(0) : v1 / v2;
break;
case OP_IDIV:
v1 = value(m_ops[vi.m_def_idx].m_arg1);
v2 = value(m_ops[vi.m_def_idx].m_arg2);
update(v, v2 == 0 ? num_t(0) : div(v1, v2));
new_value = v2 == 0 ? num_t(0) : div(v1, v2);
break;
case OP_REM:
v1 = value(m_ops[vi.m_def_idx].m_arg1);
v2 = value(m_ops[vi.m_def_idx].m_arg2);
update(v, v2 == 0 ? num_t(0) : v1 %= v2);
new_value = v2 == 0 ? num_t(0) : v1 %= v2;
break;
case OP_ABS:
update(v, abs(value(m_ops[vi.m_def_idx].m_arg1)));
new_value = abs(value(m_ops[vi.m_def_idx].m_arg1));
break;
default:
NOT_IMPLEMENTED_YET();
}
if (!update(v, new_value))
ctx.new_value_eh(e);
}
template<typename num_t>
@ -1201,6 +1283,7 @@ namespace sls {
auto const& vi = m_vars[v];
if (vi.m_def_idx == UINT_MAX)
return false;
flet<bool> _tabu(m_use_tabu, false);
TRACE("sls", tout << "repair def " << mk_bounded_pp(vi.m_expr, m) << "\n");
switch (vi.m_op) {
case arith_op_kind::LAST_ARITH_OP:
@ -1234,11 +1317,11 @@ namespace sls {
template<typename num_t>
void arith_base<num_t>::initialize() {
for (auto lit : ctx.unit_literals())
initialize(lit);
initialize_unit(lit);
}
template<typename num_t>
void arith_base<num_t>::initialize(sat::literal lit) {
void arith_base<num_t>::initialize_unit(sat::literal lit) {
init_bool_var(lit.var());
auto* ineq = atom(lit.var());
if (!ineq)

View file

@ -69,7 +69,11 @@ namespace sls {
};
private:
struct var_info {
class var_info {
num_t m_range{ 100000000 };
num_t m_update_value{ 0 };
unsigned m_update_timestamp = 0;
public:
var_info(expr* e, var_sort k): m_expr(e), m_sort(k) {}
expr* m_expr;
num_t m_value{ 0 };
@ -81,7 +85,16 @@ namespace sls {
unsigned_vector m_muls;
unsigned_vector m_adds;
optional<bound> m_lo, m_hi;
num_t m_range{ 100000000 };
// retrieve temporary value during an update.
void set_update_value(num_t const& v, unsigned timestamp) {
m_update_value = v;
m_update_timestamp = timestamp;
}
num_t const& get_update_value(unsigned ts) const {
return ts == m_update_timestamp ? m_update_value : m_value;
}
bool in_range(num_t const& n) const {
if (-m_range < n && n < m_range)
return true;
@ -139,7 +152,7 @@ namespace sls {
vector<var_change> m_updates;
var_t m_last_var = 0;
num_t m_last_delta { 0 };
bool m_tabu = false;
bool m_use_tabu = true;
arith_util a;
void invariant();
@ -172,6 +185,10 @@ namespace sls {
void add_update(var_t v, num_t delta);
bool is_permitted_update(var_t v, num_t& delta);
unsigned m_update_timestamp = 0;
svector<var_t> m_update_trail;
bool check_update(var_t v, num_t new_value);
void apply_checked_update();
vector<num_t> m_factors;
vector<num_t> const& factor(num_t n);
@ -260,11 +277,12 @@ namespace sls {
bool is_int(var_t v) const { return m_vars[v].m_sort == var_sort::INT; }
num_t value(var_t v) const { return m_vars[v].m_value; }
num_t const& get_update_value(var_t v) const { return m_vars[v].get_update_value(m_update_timestamp); }
bool is_num(expr* e, num_t& i);
expr_ref from_num(sort* s, num_t const& n);
void check_ineqs();
void init_bool_var(sat::bool_var bv);
void initialize(sat::literal lit);
void initialize_unit(sat::literal lit);
void add_le(var_t v, num_t const& n);
void add_ge(var_t v, num_t const& n);
void add_lt(var_t v, num_t const& n);

View file

@ -101,8 +101,8 @@ namespace sls {
propagate_literal(lit);
}
while (!m_new_constraint && (!m_repair_up.empty() || !m_repair_down.empty())) {
while (!m_repair_down.empty() && !m_new_constraint) {
while (!m_new_constraint && m.inc() && (!m_repair_up.empty() || !m_repair_down.empty())) {
while (!m_repair_down.empty() && !m_new_constraint && m.inc()) {
auto id = m_repair_down.erase_min();
expr* e = term(id);
TRACE("sls", tout << "repair down " << mk_bounded_pp(e, m) << "\n");
@ -114,7 +114,7 @@ namespace sls {
}
}
}
while (!m_repair_up.empty() && !m_new_constraint) {
while (!m_repair_up.empty() && !m_new_constraint && m.inc()) {
auto id = m_repair_up.erase_min();
expr* e = term(id);
TRACE("sls", tout << "repair up " << mk_bounded_pp(e, m) << "\n");
@ -308,12 +308,14 @@ namespace sls {
}
}
);
// verbose_stream() << "new value " << mk_bounded_pp(e, m) << " " << mk_bounded_pp(get_value(e), m) << "\n";
m_repair_down.reserve(e->get_id() + 1);
m_repair_up.reserve(e->get_id() + 1);
if (!m_repair_down.contains(e->get_id()))
m_repair_down.insert(e->get_id());
for (auto p : parents(e)) {
m_repair_up.reserve(p->get_id() + 1);
m_repair_down.reserve(p->get_id() + 1);
if (!m_repair_up.contains(p->get_id()))
m_repair_up.insert(p->get_id());
}

View file

@ -46,7 +46,11 @@ namespace sls {
void on_restart() override {}
bool m_on_save_model = false;
void on_save_model() override {
if (m_on_save_model)
return;
flet<bool> _on_save_model(m_on_save_model, true);
TRACE("sls", display(tout));
while (unsat().empty()) {
m_context.check();
@ -185,6 +189,8 @@ namespace sls {
}
else {
sat::literal lit = mk_literal(f);
if (sign)
lit.neg();
m_solver_ctx->add_clause(1, &lit);
}
}