3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-23 09:05:31 +00:00

fixes to sls

This commit is contained in:
Nikolaj Bjorner 2024-07-27 03:29:54 +02:00
parent 5e62984178
commit fce21981c6
16 changed files with 521 additions and 80 deletions

View file

@ -68,8 +68,9 @@ namespace sat {
else if (should_restart()) do_restart(), m_plugin->on_restart();
else if (do_flip<true>());
else shift_weights(), m_plugin->on_rescale();
verbose_stream() << "steps: " << steps << " min_sz: " << m_min_sz << " unsat: " << m_unsat.size() << "\n";
//verbose_stream() << "steps: " << steps << " min_sz: " << m_min_sz << " unsat: " << m_unsat.size() << "\n";
++steps;
SASSERT(m_unsat.size() >= m_min_sz);
}
}
catch (z3_exception& ex) {
@ -114,7 +115,7 @@ namespace sat {
if (reward > 0 || (reward == 0 && m_rand(100) <= m_config.m_use_reward_zero_pct)) {
flip(v);
if (m_unsat.size() <= m_min_sz)
if (m_unsat.size() <= m_min_sz)
save_best_values();
return true;
}
@ -124,32 +125,46 @@ namespace sat {
template<bool uses_plugin>
bool_var ddfw::pick_var(double& r) {
double sum_pos = 0;
unsigned n = 1;
unsigned n = 1, m = 1;
bool_var v0 = null_bool_var;
bool_var v1 = null_bool_var;
if (m_unsat_vars.empty())
return null_bool_var;
for (bool_var v : m_unsat_vars) {
r = uses_plugin ? plugin_reward(v) : reward(v);
r = reward(v);
if (r > 0.0)
sum_pos += score(r);
else if (r == 0.0 && sum_pos == 0 && (m_rand() % (n++)) == 0)
v0 = v;
else if (m_rand(m++) == 0)
v1 = v;
}
if (v0 != null_bool_var && m_rand(20) == 0)
return v0;
if (v1 != null_bool_var && m_rand(20) == 0)
return v1;
if (sum_pos > 0) {
double lim_pos = ((double) m_rand() / (1.0 + m_rand.max_value())) * sum_pos;
double lim_pos = ((double) m_rand() / (1.0 + m_rand.max_value())) * sum_pos;
for (bool_var v : m_unsat_vars) {
r = uses_plugin && is_external(v) ? m_vars[v].m_last_reward : reward(v);
r = reward(v);
if (r > 0) {
lim_pos -= score(r);
if (lim_pos <= 0)
return v;
if (lim_pos <= 0) {
return v;
}
}
}
}
r = 0;
if (v0 != null_bool_var)
return v0;
if (m_unsat_vars.empty())
return null_bool_var;
return m_unsat_vars.elem_at(m_rand(m_unsat_vars.size()));
if (v0 == null_bool_var)
v0 = m_unsat_vars.elem_at(m_rand(m_unsat_vars.size()));
return v0;
}
void ddfw::add(unsigned n, literal const* c) {
@ -351,13 +366,15 @@ namespace sat {
break;
}
}
save_best_values();
}
bool ddfw::should_restart() {
return m_flips >= m_restart_next;
}
void ddfw::do_restart() {
void ddfw::do_restart() {
verbose_stream() << "restart\n";
reinit_values();
init_clause_data();
m_restart_next += m_config.m_restart_base*get_luby(++m_restart_count);
@ -403,6 +420,11 @@ namespace sat {
if (m_unsat.size() < 50 || m_min_sz * 10 > m_unsat.size() * 11)
save_model();
}
#if 0
if (m_unsat.size() <= m_min_sz) {
verbose_stream() << "unsat " << m_clauses[m_unsat[0]] << "\n";
}
#endif
if (m_unsat.size() < m_min_sz) {
m_models.reset();
// skip saving the first model.

View file

@ -38,7 +38,6 @@ namespace sat {
virtual ~local_search_plugin() {}
virtual void init_search() = 0;
virtual void finish_search() = 0;
virtual double reward(bool_var v) = 0;
virtual void on_rescale() = 0;
virtual void on_save_model() = 0;
virtual void on_restart() = 0;
@ -124,8 +123,6 @@ namespace sat {
inline double& reward(bool_var v) { return m_vars[v].m_reward; }
inline double plugin_reward(bool_var v) { return is_external(v) ? (m_vars[v].m_last_reward = m_plugin->reward(v)) : reward(v); }
void set_external(bool_var v) { m_vars[v].m_external = true; }
inline bool is_external(bool_var v) const { return m_vars[v].m_external; }

View file

@ -32,6 +32,8 @@ namespace sls {
}
}
template<typename num_t>
std::ostream& arith_base<num_t>::ineq::display(std::ostream& out) const {
bool first = true;
@ -141,6 +143,49 @@ namespace sls {
auto argsv = ineq.m_args_value;
bool solved = false;
num_t delta = argsv - bound;
auto const& lo = m_vars[v].m_lo;
auto const& hi = m_vars[v].m_hi;
if (is_fixed(v))
return false;
auto well_formed = [&]() {
num_t new_args = argsv + coeff * (new_value - value(v));
if (ineq.is_true()) {
switch (ineq.m_op) {
case ineq_kind::LE: return new_args > bound;
case ineq_kind::LT: return new_args >= bound;
case ineq_kind::EQ: return new_args != bound;
}
}
else {
switch (ineq.m_op) {
case ineq_kind::LE: return new_args <= bound;
case ineq_kind::LT: return new_args < bound;
case ineq_kind::EQ: return new_args == bound;
}
}
return false;
};
auto move_to_bounds = [&]() {
VERIFY(well_formed());
if (!in_bounds(v, value(v)))
return true;
if (in_bounds(v, new_value))
return true;
if (lo && lo->value > new_value) {
new_value = lo->value;
if (!well_formed())
new_value += 1;
}
if (hi && hi->value < new_value) {
new_value = hi->value;
if (!well_formed())
new_value -= 1;
}
return well_formed() && in_bounds(v, new_value);
};
if (ineq.is_true()) {
switch (ineq.m_op) {
@ -148,24 +193,22 @@ namespace sls {
// args <= bound -> args > bound
SASSERT(argsv <= bound);
SASSERT(delta <= 0);
delta -= 1 + (ctx.rand(10));
new_value = value(v) + divide(v, abs(delta), coeff);
VERIFY(argsv + coeff * (new_value - value(v)) > bound);
return true;
delta -= 1;
new_value = value(v) + divide(v, abs(delta - ctx.rand(3)), coeff);
return move_to_bounds();
case ineq_kind::LT:
// args < bound -> args >= bound
SASSERT(argsv <= bound);
SASSERT(delta <= 0);
delta = abs(delta) + ctx.rand(10);
new_value = value(v) + divide(v, delta, coeff);
delta = abs(delta);
new_value = value(v) + divide(v, delta + ctx.rand(3), coeff);
VERIFY(argsv + coeff * (new_value - value(v)) >= bound);
return true;
return move_to_bounds();
case ineq_kind::EQ: {
delta = abs(delta) + 1 + ctx.rand(10);
int sign = ctx.rand(2) == 0 ? 1 : -1;
new_value = value(v) + sign * divide(v, abs(delta), coeff);
VERIFY(argsv + coeff * (new_value - value(v)) != bound);
return true;
return move_to_bounds();
}
default:
UNREACHABLE();
@ -178,16 +221,14 @@ namespace sls {
SASSERT(argsv > bound);
SASSERT(delta > 0);
delta += ctx.rand(10);
new_value = value(v) - divide(v, delta, coeff);
VERIFY(argsv + coeff * (new_value - value(v)) <= bound);
return true;
new_value = value(v) - divide(v, delta + ctx.rand(3), coeff);
return move_to_bounds();
case ineq_kind::LT:
SASSERT(argsv >= bound);
SASSERT(delta >= 0);
delta += 1 + ctx.rand(10);
new_value = value(v) - divide(v, delta, coeff);
VERIFY(argsv + coeff * (new_value - value(v)) < bound);
return true;
new_value = value(v) - divide(v, delta + ctx.rand(3), coeff);
return move_to_bounds();
case ineq_kind::EQ:
SASSERT(delta != 0);
if (delta < 0)
@ -195,12 +236,7 @@ namespace sls {
else
new_value = value(v) - divide(v, delta, coeff);
solved = argsv + coeff * (new_value - value(v)) == bound;
if (!solved && abs(coeff) == 1) {
verbose_stream() << "did not solve equality " << ineq << " for " << v << "\n";
verbose_stream() << new_value << " " << value(v) << " delta " << delta << " lhs " << (argsv + coeff * (new_value - value(v))) << " bound " << bound << "\n";
UNREACHABLE();
}
return solved;
return solved && move_to_bounds();
default:
UNREACHABLE();
break;
@ -209,6 +245,130 @@ namespace sls {
return false;
}
template<typename num_t>
bool arith_base<num_t>::solve_eq_pairs(ineq const& ineq) {
SASSERT(ineq.m_op == ineq_kind::EQ);
auto v = ineq.m_var_to_flip;
if (is_fixed(v))
return false;
auto bound = -ineq.m_coeff;
auto argsv = ineq.m_args_value;
num_t a;
for (auto const& [c, w] : ineq.m_args)
if (v == w) {
a = c;
argsv -= value(v) * c;
}
if (abs(a) == 1)
return false;
verbose_stream() << "solve_eq_pairs " << ineq << " for v" << v << "\n";
unsigned start = ctx.rand();
for (unsigned i = 0; i < ineq.m_args.size(); ++i) {
unsigned j = (start + i) % ineq.m_args.size();
auto const& [b, w] = ineq.m_args[j];
if (w == v)
continue;
if (b == 1 || b == -1)
continue;
argsv -= value(w) * b;
if (solve_eq_pairs(a, v, b, w, bound - argsv))
return true;
argsv += value(w) * b;
}
return false;
}
// ax0 + by0 = r
// (x, y) = (x0 - k*b/g, y0 + k*a/g)
// find the min x1 >= x0 satisfying progression and where x1 >= lo(x)
// k*ab/g - k*ab/g = 0
template<typename num_t>
bool arith_base<num_t>::solve_eq_pairs(num_t const& _a, var_t x, num_t const& _b, var_t y, num_t const& r) {
if (is_fixed(y))
return false;
num_t x0, y0;
num_t a = _a, b = _b;
num_t g = gcd(a, b, x0, y0);
SASSERT(g >= 1);
SASSERT(g == a * x0 + b * y0);
if (!divides(g, r))
return false;
//verbose_stream() << g << " == " << a << "*" << x0 << " + " << b << "*" << y0 << "\n";
x0 *= div(r, g);
y0 *= div(r, g);
//verbose_stream() << r << " == " << a << "*" << x0 << " + " << b << "*" << y0 << "\n";
auto adjust_lo = [&](num_t& x0, num_t& y0, num_t a, num_t b, optional<bound> const& lo, optional<bound> const& hi) {
if (!lo || lo->value <= x0)
return true;
// x0 + k*b/g >= lo
// k*(b/g) >= lo - x0
// k >= (lo - x0)/(b/g)
// x1 := x0 + k*b/g
auto delta = lo->value - x0;
auto bg = abs(div(b, g));
verbose_stream() << g << " " << bg << " " << " " << delta << "\n";
auto k = divide(x, delta, bg);
auto x1 = x0 + k * bg;
if (hi && hi->value < x1)
return false;
x0 = x1;
y0 = y0 + k * (div(b, g) > 0 ? -div(a, g) : div(a, g));
SASSERT(r == a * x0 + b * y0);
return true;
};
auto adjust_hi = [&](num_t& x0, num_t& y0, num_t a, num_t b, optional<bound> const& lo, optional<bound> const& hi) {
if (!hi || hi->value >= x0)
return true;
// x0 + k*b/g <= hi
// k <= (x0 - hi)/(b/g)
auto delta = x0 - hi->value;
auto bg = abs(div(b, g));
auto k = div(delta, bg);
auto x1 = x0 - k * bg;
if (lo && lo->value < x1)
return false;
x0 = x1;
y0 = y0 - k * (div(b, g) > 0 ? -div(a, g) : div(a, g));
SASSERT(r == a * x0 + b * y0);
return true;
};
auto const& lo_x = m_vars[x].m_lo;
auto const& hi_x = m_vars[x].m_hi;
if (!adjust_lo(x0, y0, a, b, lo_x, hi_x))
return false;
if (!adjust_hi(x0, y0, a, b, lo_x, hi_x))
return false;
auto const& lo_y = m_vars[y].m_lo;
auto const& hi_y = m_vars[y].m_hi;
if (!adjust_lo(y0, x0, b, a, lo_y, hi_y))
return false;
if (!adjust_hi(y0, x0, b, a, lo_y, hi_y))
return false;
if (lo_x && lo_x->value > x0)
return false;
if (hi_x && hi_x->value < x0)
return false;
if (x0 == value(x))
return false;
if (abs(value(x)) * 2 < abs(x0))
return false;
if (abs(value(y)) * 2 < abs(y0))
return false;
update(x, x0);
update(y, y0);
return true;
}
// flip on the first positive score
// it could be changed to flip on maximal positive score
// or flip on maximal non-negative score
@ -216,24 +376,63 @@ namespace sls {
template<typename num_t>
void arith_base<num_t>::repair(sat::literal lit, ineq const& ineq) {
num_t new_value, old_value;
if (UINT_MAX == ineq.m_var_to_flip)
dtt_reward(lit);
dtt_reward(lit);
auto v = ineq.m_var_to_flip;
if (v == UINT_MAX) {
IF_VERBOSE(0, verbose_stream() << "no var to flip\n");
return;
}
if (repair_eq(lit, ineq))
return;
if (!cm(ineq, v, new_value)) {
display(verbose_stream(), v) << "\n";
IF_VERBOSE(0, verbose_stream() << "no critical move for " << v << "\n");
if (dtt(!ctx.is_true(lit), ineq) != 0)
ctx.flip(lit.var());
return;
}
verbose_stream() << "repair " << lit << ": " << ineq << " var: v" << v << " := " << value(v) << " -> " << new_value << "\n";
//for (auto const& [coeff, w] : ineq.m_args)
// display(verbose_stream(), w) << "\n";
update(v, new_value);
if (dtt(lit.sign(), ineq) != 0)
invariant(ineq);
if (dtt(!ctx.is_true(lit), ineq) != 0)
ctx.flip(lit.var());
}
template<typename num_t>
bool arith_base<num_t>::repair_eq(sat::literal lit, ineq const& ineq) {
if (lit.sign() || ineq.m_op != ineq_kind::EQ)
return false;
auto v = ineq.m_var_to_flip;
num_t new_value;
verbose_stream() << ineq << "\n";
for (auto const& [coeff, w] : ineq.m_args)
display(verbose_stream(), w) << "\n";
if (ctx.rand(10) == 0 && solve_eq_pairs(ineq)) {
verbose_stream() << ineq << "\n";
for (auto const& [coeff, w] : ineq.m_args)
display(verbose_stream(), w) << "\n";
}
else if (cm(ineq, v, new_value) && update(v, new_value))
;
else if (solve_eq_pairs(ineq)) {
verbose_stream() << ineq << "\n";
for (auto const& [coeff, w] : ineq.m_args)
display(verbose_stream(), w) << "\n";
}
else
return false;
SASSERT(dtt(!ctx.is_true(lit), ineq) == 0);
if (dtt(!ctx.is_true(lit), ineq) != 0)
ctx.flip(lit.var());
return true;
}
//
// dscore(op) = sum_c (dts(c,alpha) - dts(c,alpha_after)) * weight(c)
// TODO - use cached dts instead of computed dts
@ -349,6 +548,14 @@ namespace sls {
return true;
}
template<typename num_t>
bool arith_base<num_t>::is_fixed(var_t v) {
auto const& vi = m_vars[v];
auto const& lo = vi.m_lo;
auto const& hi = vi.m_hi;
return lo && hi && lo->value == hi->value && lo->value == value(v);
}
template<typename num_t>
bool arith_base<num_t>::update(var_t v, num_t const& new_value) {
auto& vi = m_vars[v];
@ -394,20 +601,29 @@ namespace sls {
SASSERT(dtt(sign(bv), ineq) == 0);
}
vi.m_value = new_value;
SASSERT(!m.is_value(e));
verbose_stream() << "new value eh " << mk_bounded_pp(e, m) << "\n";
ctx.new_value_eh(e);
for (auto idx : vi.m_muls) {
auto const& [w, coeff, monomial] = m_muls[idx];
ctx.new_value_eh(m_vars[w].m_expr);
num_t prod(coeff);
for (auto w : monomial)
prod *= value(w);
if (value(w) != prod)
update(w, prod);
}
for (auto idx : vi.m_adds) {
auto const& ad = m_adds[idx];
ctx.new_value_eh(m_vars[ad.m_var].m_expr);
num_t sum(ad.m_coeff);
for (auto const& [coeff, w] : ad.m_args)
sum += coeff * value(w);
if (sum != ad.m_coeff)
update(ad.m_var, sum);
}
if (m.is_value(e)) {
display(verbose_stream());
}
SASSERT(!m.is_value(e));
ctx.new_value_eh(e);
return true;
}
@ -421,7 +637,8 @@ namespace sls {
template<typename num_t>
void arith_base<num_t>::add_arg(linear_term& ineq, num_t const& c, var_t v) {
ineq.m_args.push_back({ c, v });
if (c != 0)
ineq.m_args.push_back({ c, v });
}
template<>
@ -686,6 +903,14 @@ namespace sls {
repair(lit, *ineq);
}
template<typename num_t>
void arith_base<num_t>::repair_literal(sat::literal lit) {
auto v = lit.var();
auto const* ineq = atom(v);
if (ineq && ineq->is_true() != ctx.is_true(v))
ctx.flip(v);
}
template<typename num_t>
bool arith_base<num_t>::propagate() {
return false;
@ -960,8 +1185,8 @@ namespace sls {
product *= value(v);
if (product == 0 || !divides(product, val))
continue;
update(w, div(val, product));
return true;
if (update(w, div(val, product)))
return true;
}
return false;
}
@ -976,7 +1201,7 @@ namespace sls {
product *= value(v);
if (product == val)
return true;
// verbose_stream() << "repair mul " << mk_bounded_pp(m_vars[v].m_expr, m) << " := " << val << "(product: " << product << ")\n";
verbose_stream() << "repair mul " << mk_bounded_pp(m_vars[v].m_expr, m) << " := " << val << "(product: " << product << ")\n";
unsigned sz = monomial.size();
if (ctx.rand(20) == 0)
return update(v, product);
@ -1006,28 +1231,39 @@ namespace sls {
if (!divides(coeff, val) && ctx.rand(2) == 0)
n = div(val + coeff - 1, coeff);
auto const& fs = factor(abs(n));
vector<num_t> coeffs(sz, num_t(ctx.rand(2) == 0 ? 1 : -1));
vector<num_t> coeffs(sz, num_t(1));
vector<num_t> gcds(sz, num_t(0));
num_t sign(1);
for (auto c : coeffs)
sign *= c;
unsigned i = 0;
unsigned i = 0;
for (auto w : monomial) {
for (auto idx : m_vars[w].m_muls) {
auto const& [w1, coeff1, monomial1] = m_muls[idx];
gcds[i] = gcd(gcds[i], abs(value(w1)));
}
auto const& vi = m_vars[w];
if (vi.m_lo && vi.m_lo->value >= 0)
coeffs[i] = 1;
else if (vi.m_hi && vi.m_hi->value < 0)
coeffs[i] = -1;
else
coeffs[i] = num_t(ctx.rand(2) == 0 ? 1 : -1);
++i;
}
for (auto f : fs)
coeffs[ctx.rand(sz)] *= f;
if ((sign == 0) != (n == 0))
coeffs[ctx.rand(sz)] *= -1;
// verbose_stream() << "value " << val << " coeff: " << coeff << " coeffs: " << coeffs << " factors: " << fs << "\n";
verbose_stream() << "value " << val << " coeff: " << coeff << " coeffs: " << coeffs << " factors: " << fs << "\n";
i = 0;
for (auto w : monomial)
if (!update(w, coeffs[i++]))
for (auto w : monomial) {
if (!update(w, coeffs[i++])) {
verbose_stream() << "failed to update v" << w << " to " << coeffs[i - 1] << "\n";
return false;
}
}
verbose_stream() << "all updated for v" << v << " := " << value(v) << "\n";
return true;
}
else {
@ -1151,31 +1387,72 @@ namespace sls {
if (!ineq)
return -1;
num_t new_value;
double max_result = -1;
unsigned n = 0;
double max_result = -100;
unsigned n = 0, mult = 2;
double sum_prob = 0;
unsigned i = 0;
m_probs.reserve(ineq->m_args.size());
for (auto const& [coeff, x] : ineq->m_args) {
if (!cm(*ineq, x, coeff, new_value))
continue;
double result = 0;
// auto old_value = m_vars[x].m_value;
for (auto const& [coeff, bv] : m_vars[x].m_bool_vars) {
result += ctx.reward(bv);
#if 0
bool old_sign = sign(bv);
auto dtt_old = dtt(old_sign, *atom(bv));
auto dtt_new = dtt(old_sign, *atom(bv), coeff, old_value, new_value);
if ((dtt_new == 0) != (dtt_old == 0))
result += ctx.reward(bv);
#endif
}
if (result > max_result || max_result == -1 || (result == max_result && (ctx.rand(++n) == 0))) {
max_result = result;
ineq->m_var_to_flip = x;
double prob = 0;
if (is_fixed(x))
prob = 0;
else if (!cm(*ineq, x, coeff, new_value))
prob = 0.5;
else {
auto old_value = m_vars[x].m_value;
for (auto const& [coeff, bv] : m_vars[x].m_bool_vars) {
bool old_sign = sign(bv);
auto dtt_old = dtt(old_sign, *atom(bv));
auto dtt_new = dtt(old_sign, *atom(bv), coeff, old_value, new_value);
if (dtt_new == 0 && dtt_old != 0)
result += 1;
if (dtt_new != 0 && dtt_old == 0)
result -= 1;
}
if (result > max_result || max_result == -100 || (result == max_result && (ctx.rand(++n) == 0)))
max_result = result;
if (result < 0)
prob = 0.1;
else if (result == 0)
prob = 0.2;
else
prob = result;
}
// verbose_stream() << "prob v" << x << " " << prob << "\n";
m_probs[i++] = prob;
sum_prob += prob;
}
double lim = sum_prob * ((double)ctx.rand() / random_gen().max_value());
do {
lim -= m_probs[--i];
}
while (lim >= 0 && i > 0);
ineq->m_var_to_flip = ineq->m_args[i].second;
return max_result;
}
#if 0
double sum_prob = 0;
unsigned i = 0;
clause const& c = get_clause(cls_idx);
for (literal lit : c) {
double prob = m_prob_break[m_breaks[lit.var()]];
m_probs[i++] = prob;
sum_prob += prob;
}
double lim = sum_prob * ((double)m_rand() / m_rand.max_value());
do {
lim -= m_probs[--i];
} while (lim >= 0 && i > 0);
#endif
// Newton function for integer square root.
template<typename num_t>
num_t arith_base<num_t>::sqrt(num_t n) {
@ -1203,14 +1480,14 @@ namespace sls {
}
}
static int increments[8] = { 4, 2, 4, 2, 4, 6, 2, 6 };
unsigned i = 0;
for (auto d = num_t(7); d * d <= n; d += num_t(increments[i++])) {
unsigned i = 0, j = 0;
for (auto d = num_t(7); d * d <= n && j < 3; d += num_t(increments[i++]), ++j) {
while (mod(n, d) == 0) {
m_factors.push_back(d);
n = div(n, d);
}
if (i == 8)
i = 0;
i = 0;
}
if (n > 1)
m_factors.push_back(n);
@ -1314,6 +1591,7 @@ namespace sls {
template<typename num_t>
bool arith_base<num_t>::is_sat() {
invariant();
for (auto const& clause : ctx.clauses()) {
bool sat = false;
for (auto lit : clause.m_clause) {
@ -1405,6 +1683,61 @@ namespace sls {
return out;
}
template<typename num_t>
void arith_base<num_t>::invariant() {
for (unsigned v = 0; v < ctx.num_bool_vars(); ++v) {
auto ineq = atom(v);
if (ineq)
invariant(*ineq);
}
auto& out = verbose_stream();
for (auto md : m_muls) {
auto const& [w, coeff, monomial] = md;
num_t prod(coeff);
for (auto v : monomial)
prod *= value(v);
//verbose_stream() << "check " << w << " " << monomial << "\n";
if (prod != value(w)) {
out << prod << " " << value(w) << "\n";
out << "v" << w << " := ";
for (auto w : monomial)
out << "v" << w << " ";
out << "\n";
}
SASSERT(prod == value(w));
}
for (auto ad : m_adds) {
//out << "check add " << ad.m_var << "\n";
num_t sum(ad.m_coeff);
for (auto [c, w] : ad.m_args)
sum += c * value(w);
if (sum != value(ad.m_var)) {
out << "v" << ad.m_var << " := ";
bool first = true;
for (auto [c, w] : ad.m_args)
out << (first ? "" : " + ") << c << "* v" << w;
if (ad.m_coeff != 0)
out << " + " << ad.m_coeff;
out << "\n";
}
SASSERT(sum == value(ad.m_var));
}
}
template<typename num_t>
void arith_base<num_t>::invariant(ineq const& i) {
num_t val(0);
for (auto const& [c, v] : i.m_args)
val += c * value(v);
//verbose_stream() << "invariant " << i << "\n";
if (val != i.m_args_value)
verbose_stream() << i << "\n";
SASSERT(val == i.m_args_value);
}
template<typename num_t>
void arith_base<num_t>::mk_model(model& mdl) {
}

View file

@ -102,9 +102,13 @@ namespace sls {
vector<add_def> m_adds;
vector<op_def> m_ops;
unsigned_vector m_expr2var;
svector<double> m_probs;
bool m_dscore_mode = false;
arith_util a;
void invariant();
void invariant(ineq const& i);
unsigned get_num_vars() const { return m_vars.size(); }
bool repair_mul1(mul_def const& md);
@ -120,7 +124,9 @@ namespace sls {
bool repair_to_int(op_def const& od);
bool repair_to_real(op_def const& od);
void repair(sat::literal lit, ineq const& ineq);
bool repair_eq(sat::literal lit, ineq const& ineq);
bool in_bounds(var_t v, num_t const& value);
bool is_fixed(var_t v);
vector<num_t> m_factors;
vector<num_t> const& factor(num_t n);
@ -144,6 +150,8 @@ namespace sls {
double dtt_reward(sat::literal lit);
double dscore(var_t v, num_t const& new_value) const;
void save_best_values();
bool solve_eq_pairs(ineq const& ineq);
bool solve_eq_pairs(num_t const& a, var_t x, num_t const& b, var_t y, num_t const& r);
var_t mk_var(expr* e);
var_t mk_term(expr* e);
@ -180,6 +188,7 @@ namespace sls {
bool propagate() override;
void repair_up(app* e) override;
bool repair_down(app* e) override;
void repair_literal(sat::literal lit) override;
bool is_sat() override;
void on_rescale() override;
void on_restart() override;

View file

@ -110,6 +110,10 @@ namespace sls {
WITH_FALLBACK(repair_up(e));
}
void arith_plugin::repair_literal(sat::literal lit) {
WITH_FALLBACK(repair_literal(lit));
}
void arith_plugin::set_value(expr* e, expr* v) {
WITH_FALLBACK(set_value(e, v));
}

View file

@ -37,6 +37,7 @@ namespace sls {
bool propagate() override;
bool repair_down(app* e) override;
void repair_up(app* e) override;
void repair_literal(sat::literal lit) override;
bool is_sat() override;
void on_rescale() override;

View file

@ -335,8 +335,17 @@ namespace sls {
set_value(e, b);
}
void basic_plugin::repair_literal(sat::literal lit) {
auto a = ctx.atom(lit.var());
if (!is_basic(a))
return;
if (bval1(to_app(a)) != bval0(to_app(a)))
ctx.flip(lit.var());
}
bool basic_plugin::repair_down(app* e) {
SASSERT(m.is_bool(e));
unsigned n = e->get_num_args();
if (!is_basic(e))
return false;
@ -345,6 +354,7 @@ namespace sls {
if (bval0(e) == bval1(e))
return true;
verbose_stream() << "basic repair down " << mk_bounded_pp(e, m) << "\n";
unsigned s = ctx.rand(n);
for (unsigned i = 0; i < n; ++i) {
auto j = (i + s) % n;

View file

@ -48,6 +48,7 @@ namespace sls {
bool propagate() override;
bool repair_down(app* e) override;
void repair_up(app* e) override;
void repair_literal(sat::literal lit) override;
bool is_sat() override;
void on_rescale() override {}

View file

@ -143,6 +143,15 @@ namespace sls {
}
}
void bv_plugin::repair_literal(sat::literal lit) {
SASSERT(ctx.is_true(lit));
auto a = ctx.atom(lit.var());
if (!a || !is_app(a))
return;
if (!m_eval.eval_is_correct(to_app(a)))
ctx.flip(lit.var());
}
std::ostream& bv_plugin::trace_repair(bool down, expr* e) {
verbose_stream() << (down ? "d #" : "u #")
<< e->get_id() << ": "

View file

@ -45,6 +45,7 @@ namespace sls {
bool propagate() override;
bool repair_down(app* e) override;
void repair_up(app* e) override;
void repair_literal(sat::literal lit) override;
bool is_sat() override;
void on_rescale() override {}

View file

@ -64,7 +64,10 @@ namespace sls {
propagate_boolean_assignment();
verbose_stream() << "propagate " << unsat().size() << " " << m_new_constraint << "\n";
// display(verbose_stream());
if (m_new_constraint || !unsat().empty())
@ -129,6 +132,16 @@ namespace sls {
for (auto p : m_plugins)
propagated |= p && !m_new_constraint && p->propagate();
}
for (sat::bool_var v = 0; v < s.num_vars(); ++v) {
auto a = atom(v);
if (!a)
continue;
sat::literal lit(v, !is_true(v));
auto p = m_plugins.get(get_fid(a), nullptr);
if (p)
p->repair_literal(lit);
}
}
family_id context::get_fid(expr* e) const {

View file

@ -42,6 +42,7 @@ namespace sls {
virtual void initialize() = 0;
virtual bool propagate() = 0;
virtual void propagate_literal(sat::literal lit) = 0;
virtual void repair_literal(sat::literal lit) = 0;
virtual bool repair_down(app* e) = 0;
virtual void repair_up(app* e) = 0;
virtual bool is_sat() = 0;
@ -69,6 +70,7 @@ namespace sls {
virtual void on_model(model_ref& mdl) = 0;
virtual sat::bool_var add_var() = 0;
virtual void add_clause(unsigned n, sat::literal const* lits) = 0;
virtual std::ostream& display(std::ostream& out) = 0;
};
class context {

View file

@ -41,7 +41,7 @@ namespace sls {
expr_ref get_value(expr* e) override;
void initialize() override {}
void propagate_literal(sat::literal lit) override {}
bool propagate() override;
bool propagate() override;
bool is_sat() override;
void register_term(expr* e) override;
std::ostream& display(std::ostream& out) const override;
@ -50,6 +50,7 @@ namespace sls {
void repair_up(app* e) override {}
bool repair_down(app* e) override { return false; }
void repair_literal(sat::literal lit) override {}
};
}

View file

@ -122,7 +122,7 @@ namespace sls {
}
void smt_solver::add_clause(expr* f) {
expr* g;
expr* g, * h;
sat::literal_vector clause;
if (m.is_not(f, g) && m.is_not(g, g)) {
add_clause(g);
@ -149,6 +149,18 @@ namespace sls {
clause.push_back(~mk_literal(arg));
m_solver_ctx->add_clause(clause.size(), clause.data());
}
else if (m.is_eq(f, g, h) && m.is_bool(g)) {
auto lit1 = mk_literal(g);
auto lit2 = mk_literal(h);
clause.reset();
clause.push_back(~lit1);
clause.push_back(lit2);
m_solver_ctx->add_clause(clause.size(), clause.data());
clause.reset();
clause.push_back(lit1);
clause.push_back(~lit2);
m_solver_ctx->add_clause(clause.size(), clause.data());
}
else {
sat::literal lit = mk_literal(f);
m_solver_ctx->add_clause(1, &lit);