3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-23 09:05:31 +00:00

updates to repair logic, mainly arithmetic

This commit is contained in:
Nikolaj Bjorner 2024-07-21 21:03:14 -07:00
parent 5b0d49cd76
commit 5e62984178
13 changed files with 532 additions and 191 deletions

View file

@ -129,9 +129,10 @@ namespace sls {
template<typename num_t>
num_t arith_base<num_t>::divide(var_t v, num_t const& delta, num_t const& coeff) {
if (m_vars[v].m_sort == var_sort::REAL)
return delta / coeff;
return div(delta + abs(coeff) - 1, coeff);
if (is_int(v))
return div(delta + abs(coeff) - 1, coeff);
else
return delta / coeff;
}
template<typename num_t>
@ -147,7 +148,7 @@ namespace sls {
// args <= bound -> args > bound
SASSERT(argsv <= bound);
SASSERT(delta <= 0);
delta -= 1 + (ctx.rand() % 10);
delta -= 1 + (ctx.rand(10));
new_value = value(v) + divide(v, abs(delta), coeff);
VERIFY(argsv + coeff * (new_value - value(v)) > bound);
return true;
@ -155,13 +156,13 @@ namespace sls {
// args < bound -> args >= bound
SASSERT(argsv <= bound);
SASSERT(delta <= 0);
delta = abs(delta) + ctx.rand() % 10;
delta = abs(delta) + ctx.rand(10);
new_value = value(v) + divide(v, delta, coeff);
VERIFY(argsv + coeff * (new_value - value(v)) >= bound);
return true;
case ineq_kind::EQ: {
delta = abs(delta) + 1 + ctx.rand() % 10;
int sign = ctx.rand() % 2 == 0 ? 1 : -1;
delta = abs(delta) + 1 + ctx.rand(10);
int sign = ctx.rand(2) == 0 ? 1 : -1;
new_value = value(v) + sign * divide(v, abs(delta), coeff);
VERIFY(argsv + coeff * (new_value - value(v)) != bound);
return true;
@ -176,14 +177,14 @@ namespace sls {
case ineq_kind::LE:
SASSERT(argsv > bound);
SASSERT(delta > 0);
delta += rand() % 10;
delta += ctx.rand(10);
new_value = value(v) - divide(v, delta, coeff);
VERIFY(argsv + coeff * (new_value - value(v)) <= bound);
return true;
case ineq_kind::LT:
SASSERT(argsv >= bound);
SASSERT(delta >= 0);
delta += 1 + rand() % 10;
delta += 1 + ctx.rand(10);
new_value = value(v) - divide(v, delta, coeff);
VERIFY(argsv + coeff * (new_value - value(v)) < bound);
return true;
@ -229,6 +230,8 @@ namespace sls {
}
verbose_stream() << "repair " << lit << ": " << ineq << " var: v" << v << " := " << value(v) << " -> " << new_value << "\n";
update(v, new_value);
if (dtt(lit.sign(), ineq) != 0)
ctx.flip(lit.var());
}
//
@ -329,14 +332,55 @@ namespace sls {
return d;
}
template<typename num_t>
void arith_base<num_t>::update(var_t v, num_t const& new_value) {
bool arith_base<num_t>::in_bounds(var_t v, num_t const& value) {
auto const& vi = m_vars[v];
auto const& lo = vi.m_lo;
auto const& hi = vi.m_hi;
if (lo && value < lo->value)
return false;
if (lo && lo->is_strict && value <= lo->value)
return false;
if (hi && value > hi->value)
return false;
if (hi && hi->is_strict && value >= hi->value)
return false;
return true;
}
template<typename num_t>
bool arith_base<num_t>::update(var_t v, num_t const& new_value) {
auto& vi = m_vars[v];
expr* e = vi.m_expr;
auto old_value = vi.m_value;
if (old_value == new_value)
return;
verbose_stream() << mk_bounded_pp(e, m) << " := " << new_value << "\n";
return true;
display(verbose_stream(), v) << " := " << new_value << "\n";
if (!in_bounds(v, new_value)) {
auto const& lo = vi.m_lo;
auto const& hi = vi.m_hi;
if (is_int(v) && lo && !lo->is_strict && new_value < lo->value) {
if (lo->value != old_value)
return update(v, lo->value);
if (in_bounds(v, old_value + 1))
return update(v, old_value + 1);
else
return false;
}
if (is_int(v) && hi && !hi->is_strict && new_value > hi->value) {
if (hi->value != old_value)
return update(v, hi->value);
else if (in_bounds(v, old_value - 1))
return update(v, old_value - 1);
else
return false;
}
verbose_stream() << "out of bounds old value " << old_value << "\n";
display(verbose_stream(), v) << "\n";
SASSERT(false);
return false;
}
for (auto const& [coeff, bv] : vi.m_bool_vars) {
auto& ineq = *atom(bv);
bool old_sign = sign(bv);
@ -344,6 +388,7 @@ namespace sls {
SASSERT(ctx.is_true(lit));
ineq.m_args_value += coeff * (new_value - old_value);
num_t dtt_new = dtt(old_sign, ineq);
// verbose_stream() << "dtt " << lit << " " << ineq << " " << dtt_new << "\n";
if (dtt_new != 0)
ctx.flip(bv);
SASSERT(dtt(sign(bv), ineq) == 0);
@ -358,8 +403,12 @@ namespace sls {
ctx.new_value_eh(m_vars[ad.m_var].m_expr);
}
if (m.is_value(e)) {
display(verbose_stream());
}
SASSERT(!m.is_value(e));
ctx.new_value_eh(e);
return true;
}
template<typename num_t>
@ -433,7 +482,7 @@ namespace sls {
}
else if (a.is_mul(e)) {
unsigned_vector m;
num_t c = coeff;
num_t c(1);
for (expr* arg : *to_app(e))
if (is_num(arg, i))
c *= i;
@ -441,10 +490,10 @@ namespace sls {
m.push_back(mk_term(arg));
switch (m.size()) {
case 0:
term.m_coeff += c;
term.m_coeff += c*coeff;
break;
case 1:
add_arg(term, c, m[0]);
add_arg(term, c*coeff, m[0]);
break;
default: {
v = mk_var(e);
@ -456,7 +505,7 @@ namespace sls {
m_vars[v].m_def_idx = idx;
m_vars[v].m_op = arith_op_kind::OP_MUL;
m_vars[v].m_value = prod;
add_arg(term, num_t(1), v);
add_arg(term, coeff, v);
break;
}
}
@ -517,6 +566,7 @@ namespace sls {
NOT_IMPLEMENTED_YET();
break;
}
verbose_stream() << "mk-op " << mk_bounded_pp(e, m) << "\n";
m_ops.push_back({v, k, v, w});
m_vars[v].m_def_idx = idx;
m_vars[v].m_op = k;
@ -547,6 +597,7 @@ namespace sls {
template<typename num_t>
typename arith_base<num_t>::var_t arith_base<num_t>::mk_var(expr* e) {
SASSERT(!m.is_value(e));
var_t v = m_expr2var.get(e->get_id(), UINT_MAX);
if (v == UINT_MAX) {
v = m_vars.size();
@ -648,8 +699,6 @@ namespace sls {
auto const& vi = m_vars[v];
if (vi.m_def_idx == UINT_MAX)
return;
m_ops.reserve(vi.m_def_idx + 1);
auto const& od = m_ops[vi.m_def_idx];
num_t v1, v2;
switch (vi.m_op) {
case LAST_ARITH_OP:
@ -672,27 +721,27 @@ namespace sls {
break;
}
case OP_MOD:
v1 = value(od.m_arg1);
v2 = value(od.m_arg2);
v1 = value(m_ops[vi.m_def_idx].m_arg1);
v2 = value(m_ops[vi.m_def_idx].m_arg2);
update(v, v2 == 0 ? num_t(0) : mod(v1, v2));
break;
case OP_DIV:
v1 = value(od.m_arg1);
v2 = value(od.m_arg2);
v1 = value(m_ops[vi.m_def_idx].m_arg1);
v2 = value(m_ops[vi.m_def_idx].m_arg2);
update(v, v2 == 0 ? num_t(0) : v1 / v2);
break;
case OP_IDIV:
v1 = value(od.m_arg1);
v2 = value(od.m_arg2);
v1 = value(m_ops[vi.m_def_idx].m_arg1);
v2 = value(m_ops[vi.m_def_idx].m_arg2);
update(v, v2 == 0 ? num_t(0) : div(v1, v2));
break;
case OP_REM:
v1 = value(od.m_arg1);
v2 = value(od.m_arg2);
v1 = value(m_ops[vi.m_def_idx].m_arg1);
v2 = value(m_ops[vi.m_def_idx].m_arg2);
update(v, v2 == 0 ? num_t(0) : v1 %= v2);
break;
case OP_ABS:
update(v, abs(value(od.m_arg1)));
update(v, abs(value(m_ops[vi.m_def_idx].m_arg1)));
break;
default:
NOT_IMPLEMENTED_YET();
@ -700,54 +749,150 @@ namespace sls {
}
template<typename num_t>
void arith_base<num_t>::repair_down(app* e) {
bool arith_base<num_t>::repair_down(app* e) {
auto v = m_expr2var.get(e->get_id(), UINT_MAX);
if (v == UINT_MAX)
return;
return false;
auto const& vi = m_vars[v];
if (vi.m_def_idx == UINT_MAX)
return;
return false;
TRACE("sls", tout << "repair def " << mk_bounded_pp(vi.m_expr, m) << "\n");
switch (vi.m_op) {
case arith_op_kind::LAST_ARITH_OP:
break;
case arith_op_kind::OP_ADD:
repair_add(m_adds[vi.m_def_idx]);
break;
return repair_add(m_adds[vi.m_def_idx]);
case arith_op_kind::OP_MUL:
repair_mul(m_muls[vi.m_def_idx]);
break;
return repair_mul(m_muls[vi.m_def_idx]);
case arith_op_kind::OP_MOD:
repair_mod(m_ops[vi.m_def_idx]);
break;
return repair_mod(m_ops[vi.m_def_idx]);
case arith_op_kind::OP_REM:
repair_rem(m_ops[vi.m_def_idx]);
break;
return repair_rem(m_ops[vi.m_def_idx]);
case arith_op_kind::OP_POWER:
repair_power(m_ops[vi.m_def_idx]);
break;
return repair_power(m_ops[vi.m_def_idx]);
case arith_op_kind::OP_IDIV:
repair_idiv(m_ops[vi.m_def_idx]);
break;
return repair_idiv(m_ops[vi.m_def_idx]);
case arith_op_kind::OP_DIV:
repair_div(m_ops[vi.m_def_idx]);
break;
return repair_div(m_ops[vi.m_def_idx]);
case arith_op_kind::OP_ABS:
repair_abs(m_ops[vi.m_def_idx]);
break;
return repair_abs(m_ops[vi.m_def_idx]);
case arith_op_kind::OP_TO_INT:
repair_to_int(m_ops[vi.m_def_idx]);
break;
return repair_to_int(m_ops[vi.m_def_idx]);
case arith_op_kind::OP_TO_REAL:
repair_to_real(m_ops[vi.m_def_idx]);
break;
return repair_to_real(m_ops[vi.m_def_idx]);
default:
NOT_IMPLEMENTED_YET();
}
return true;
}
template<typename num_t>
void arith_base<num_t>::initialize() {
for (auto lit : ctx.unit_literals())
initialize(lit);
}
template<typename num_t>
void arith_base<num_t>::initialize(sat::literal lit) {
init_bool_var(lit.var());
auto* ineq = atom(lit.var());
if (!ineq)
return;
if (ineq->m_args.size() != 1)
return;
auto [c, v] = ineq->m_args[0];
switch (ineq->m_op) {
case ineq_kind::LE:
if (lit.sign()) {
if (c == -1) // -x + c >= 0 <=> c >= x
add_le(v, ineq->m_coeff);
else if (c == 1) // x + c >= 0 <=> x >= -c
add_ge(v, -ineq->m_coeff);
else
verbose_stream() << "INITIALIZE " << lit << " " << *ineq << "\n";
}
else {
if (c == -1)
add_ge(v, ineq->m_coeff);
else if (c == 1)
add_le(v, -ineq->m_coeff);
else
verbose_stream() << "INITIALIZE " << lit << " " << *ineq << "\n";
}
break;
case ineq_kind::EQ:
if (lit.sign()) {
verbose_stream() << "INITIALIZE " << lit << " " << *ineq << "\n";
}
else {
if (c == -1) {
add_ge(v, ineq->m_coeff);
add_le(v, ineq->m_coeff);
}
else if (c == 1) {
add_ge(v, -ineq->m_coeff);
add_le(v, -ineq->m_coeff);
}
else
verbose_stream() << "INITIALIZE " << lit << " " << *ineq << "\n";
}
break;
case ineq_kind::LT:
if (lit.sign()) {
if (c == -1) // -x + c >= 0 <=> c >= x
add_le(v, ineq->m_coeff);
else if (c == 1) // x + c >= 0 <=> x >= -c
add_ge(v, -ineq->m_coeff);
else
verbose_stream() << "INITIALIZE " << lit << " " << *ineq << "\n";
}
else {
if (c == -1)
add_gt(v, ineq->m_coeff);
else if (c == 1)
add_lt(v, -ineq->m_coeff);
else
verbose_stream() << "INITIALIZE " << lit << " " << *ineq << "\n";
}
break;
}
}
template<typename num_t>
void arith_base<num_t>::repair_add(add_def const& ad) {
void arith_base<num_t>::add_le(var_t v, num_t const& n) {
if (m_vars[v].m_hi && m_vars[v].m_hi->value <= n)
return;
m_vars[v].m_hi = { false, n };
}
template<typename num_t>
void arith_base<num_t>::add_ge(var_t v, num_t const& n) {
if (m_vars[v].m_lo && m_vars[v].m_lo->value >= n)
return;
m_vars[v].m_lo = { false, n };
}
template<typename num_t>
void arith_base<num_t>::add_lt(var_t v, num_t const& n) {
if (is_int(v))
add_le(v, n - 1);
else
m_vars[v].m_hi = { true, n };
}
template<typename num_t>
void arith_base<num_t>::add_gt(var_t v, num_t const& n) {
if (is_int(v))
add_ge(v, n + 1);
else
m_vars[v].m_lo = { true, n };
}
template<typename num_t>
bool arith_base<num_t>::repair_add(add_def const& ad) {
auto v = ad.m_var;
auto const& coeffs = ad.m_args;
num_t sum(ad.m_coeff);
@ -758,21 +903,71 @@ namespace sls {
for (auto const& [c, w] : coeffs)
sum += c * value(w);
if (val == sum)
return;
if (rand() % 20 == 0)
update(v, sum);
return true;
if (ctx.rand(20) == 0)
return update(v, sum);
else {
auto const& [c, w] = coeffs[rand() % coeffs.size()];
auto const& [c, w] = coeffs[ctx.rand(coeffs.size())];
num_t delta = sum - val;
bool is_real = m_vars[w].m_sort == var_sort::REAL;
bool round_down = rand() % 2 == 0;
bool round_down = ctx.rand(2) == 0;
num_t new_value = value(w) + (is_real ? delta / c : round_down ? div(delta, c) : div(delta + c - 1, c));
update(w, new_value);
return update(w, new_value);
}
}
template<typename num_t>
void arith_base<num_t>::repair_mul(mul_def const& md) {
bool arith_base<num_t>::repair_square(mul_def const& md) {
auto const& [v, coeff, monomial] = md;
if (!is_int(v) || monomial.size() != 2 || monomial[0] != monomial[1])
return false;
num_t val = value(v);
val = div(val, coeff);
var_t w = monomial[0];
if (val < 0)
update(w, num_t(ctx.rand(10)));
else {
num_t root = sqrt(val);
if (ctx.rand(3) == 0)
root = -root;
if (root * root == val)
update(w, root);
else
update(w, root + num_t(ctx.rand(3)) - 1);
}
verbose_stream() << "ROOT " << val << " v" << w << " := " << value(w) << "\n";
return true;
}
template<typename num_t>
bool arith_base<num_t>::repair_mul1(mul_def const& md) {
auto const& [v, coeff, monomial] = md;
if (!is_int(v))
return false;
num_t val = value(v);
val = div(val, coeff);
if (val == 0)
return false;
unsigned sz = monomial.size();
unsigned start = ctx.rand(sz);
for (unsigned i = 0; i < sz; ++i) {
unsigned j = (start + i) % sz;
auto w = monomial[j];
num_t product(1);
for (auto v : monomial)
if (v != w)
product *= value(v);
if (product == 0 || !divides(product, val))
continue;
update(w, div(val, product));
return true;
}
return false;
}
template<typename num_t>
bool arith_base<num_t>::repair_mul(mul_def const& md) {
auto const& [v, coeff, monomial] = md;
num_t product(coeff);
num_t val = value(v);
@ -780,118 +975,124 @@ namespace sls {
for (auto v : monomial)
product *= value(v);
if (product == val)
return;
verbose_stream() << "repair mul " << mk_bounded_pp(m_vars[v].m_expr, m) << " := " << val << "(" << product << ")\n";
if (rand() % 20 == 0)
update(v, product);
return true;
// verbose_stream() << "repair mul " << mk_bounded_pp(m_vars[v].m_expr, m) << " := " << val << "(product: " << product << ")\n";
unsigned sz = monomial.size();
if (ctx.rand(20) == 0)
return update(v, product);
else if (val == 0) {
auto v = monomial[ctx.rand(monomial.size())];
auto v = monomial[ctx.rand(sz)];
num_t zero(0);
update(v, zero);
return update(v, zero);
}
else if (val == 1 || val == -1) {
product = coeff;
for (auto v : monomial) {
num_t new_value(1);
if (rand() % 2 == 0)
new_value = -1;
product *= new_value;
update(v, new_value);
}
if (product != val) {
auto last = monomial.back();
update(last, -value(last));
}
else if (repair_square(md))
return true;
else if (ctx.rand(4) != 0 && repair_mul1(md)) {
#if 0
verbose_stream() << "mul1 " << val << " " << coeff << " ";
for (auto v : monomial)
verbose_stream() << "v" << v << " = " << value(v) << " ";
verbose_stream() << "\n";
#endif
return true;
}
else if (rand() % 2 == 0 && product != 0) {
// value1(v) * product / value(v) = val
// value1(v) = value(v) * val / product
auto w = monomial[ctx.rand(monomial.size())];
auto old_value = value(w);
new_value = divide(w, old_value * val, product);
update(w, new_value);
else if (is_int(v)) {
#if 0
verbose_stream() << "repair mul2 - ";
for (auto v : monomial)
verbose_stream() << "v" << v << " = " << value(v) << " ";
#endif
num_t n = div(val, coeff);
if (!divides(coeff, val) && ctx.rand(2) == 0)
n = div(val + coeff - 1, coeff);
auto const& fs = factor(abs(n));
vector<num_t> coeffs(sz, num_t(ctx.rand(2) == 0 ? 1 : -1));
vector<num_t> gcds(sz, num_t(0));
num_t sign(1);
for (auto c : coeffs)
sign *= c;
unsigned i = 0;
for (auto w : monomial) {
for (auto idx : m_vars[w].m_muls) {
auto const& [w1, coeff1, monomial1] = m_muls[idx];
gcds[i] = gcd(gcds[i], abs(value(w1)));
}
++i;
}
for (auto f : fs)
coeffs[ctx.rand(sz)] *= f;
if ((sign == 0) != (n == 0))
coeffs[ctx.rand(sz)] *= -1;
// verbose_stream() << "value " << val << " coeff: " << coeff << " coeffs: " << coeffs << " factors: " << fs << "\n";
i = 0;
for (auto w : monomial)
if (!update(w, coeffs[i++]))
return false;
return true;
}
else {
auto w = monomial[ctx.rand(monomial.size())];
num_t prod(coeff);
for (auto v : monomial) {
if (v == w)
continue;
num_t new_value(1);
if (rand() % 2 == 0)
new_value = -1;
prod *= new_value;
update(v, new_value);
}
verbose_stream() << "select random " << coeff << " " << val << " v" << w << "\n";
new_value = divide(w, val * value(w), coeff);
if ((product < 0 && 0 < new_value) || (new_value < 0 && 0 < product))
update(w, -new_value);
else
update(w, new_value);
NOT_IMPLEMENTED_YET();
}
return false;
}
template<typename num_t>
void arith_base<num_t>::repair_rem(op_def const& od) {
bool arith_base<num_t>::repair_rem(op_def const& od) {
auto v1 = value(od.m_arg1);
auto v2 = value(od.m_arg2);
if (v2 == 0) {
update(od.m_var, num_t(0));
return;
}
if (v2 == 0)
return update(od.m_var, num_t(0));
IF_VERBOSE(0, verbose_stream() << "todo repair rem");
// bail
v1 %= v2;
update(od.m_var, v1);
return update(od.m_var, v1);
}
template<typename num_t>
void arith_base<num_t>::repair_abs(op_def const& od) {
bool arith_base<num_t>::repair_abs(op_def const& od) {
auto val = value(od.m_var);
auto v1 = value(od.m_arg1);
if (val < 0)
update(od.m_var, abs(v1));
else if (rand() % 2 == 0)
update(od.m_arg1, val);
return update(od.m_var, abs(v1));
else if (ctx.rand(2) == 0)
return update(od.m_arg1, val);
else
update(od.m_arg1, -val);
return update(od.m_arg1, -val);
}
template<typename num_t>
void arith_base<num_t>::repair_to_int(op_def const& od) {
bool arith_base<num_t>::repair_to_int(op_def const& od) {
auto val = value(od.m_var);
auto v1 = value(od.m_arg1);
if (val - 1 < v1 && v1 <= val)
return;
update(od.m_arg1, val);
return true;
return update(od.m_arg1, val);
}
template<typename num_t>
void arith_base<num_t>::repair_to_real(op_def const& od) {
if (rand() % 20 == 0)
update(od.m_var, value(od.m_arg1));
bool arith_base<num_t>::repair_to_real(op_def const& od) {
if (ctx.rand(20) == 0)
return update(od.m_var, value(od.m_arg1));
else
update(od.m_arg1, value(od.m_arg1));
return update(od.m_arg1, value(od.m_arg1));
}
template<typename num_t>
void arith_base<num_t>::repair_power(op_def const& od) {
bool arith_base<num_t>::repair_power(op_def const& od) {
auto v1 = value(od.m_arg1);
auto v2 = value(od.m_arg2);
if (v1 == 0 && v2 == 0) {
update(od.m_var, num_t(0));
return;
return update(od.m_var, num_t(0));
}
IF_VERBOSE(0, verbose_stream() << "todo repair ^");
NOT_IMPLEMENTED_YET();
return false;
}
template<typename num_t>
void arith_base<num_t>::repair_mod(op_def const& od) {
bool arith_base<num_t>::repair_mod(op_def const& od) {
auto val = value(od.m_var);
auto v1 = value(od.m_arg1);
auto v2 = value(od.m_arg2);
@ -899,11 +1100,11 @@ namespace sls {
if (val >= 0 && val < v2) {
auto v3 = mod(v1, v2);
if (v3 == val)
return;
return true;
// find r, such that mod(v1 + r, v2) = val
// v1 := v1 + val - v3 (+/- v2)
v1 += val - v3;
switch (rand() % 6) {
switch (ctx.rand(6)) {
case 0:
v1 += v2;
break;
@ -913,28 +1114,27 @@ namespace sls {
default:
break;
}
update(od.m_arg1, v1);
return;
return update(od.m_arg1, v1);
}
update(od.m_var, v2 == 0 ? num_t(0) : mod(v1, v2));
return update(od.m_var, v2 == 0 ? num_t(0) : mod(v1, v2));
}
template<typename num_t>
void arith_base<num_t>::repair_idiv(op_def const& od) {
bool arith_base<num_t>::repair_idiv(op_def const& od) {
auto v1 = value(od.m_arg1);
auto v2 = value(od.m_arg2);
IF_VERBOSE(0, verbose_stream() << "todo repair div");
// bail
update(od.m_var, v2 == 0 ? num_t(0) : div(v1, v2));
return update(od.m_var, v2 == 0 ? num_t(0) : div(v1, v2));
}
template<typename num_t>
void arith_base<num_t>::repair_div(op_def const& od) {
bool arith_base<num_t>::repair_div(op_def const& od) {
auto v1 = value(od.m_arg1);
auto v2 = value(od.m_arg2);
IF_VERBOSE(0, verbose_stream() << "todo repair /");
// bail
update(od.m_var, v2 == 0 ? num_t(0) : v1 / v2);
return update(od.m_var, v2 == 0 ? num_t(0) : v1 / v2);
}
template<typename num_t>
@ -968,7 +1168,7 @@ namespace sls {
result += ctx.reward(bv);
#endif
}
if (result > max_result || max_result == -1 || (result == max_result && (rand() % ++n == 0))) {
if (result > max_result || max_result == -1 || (result == max_result && (ctx.rand(++n) == 0))) {
max_result = result;
ineq->m_var_to_flip = x;
}
@ -976,6 +1176,48 @@ namespace sls {
return max_result;
}
// Newton function for integer square root.
template<typename num_t>
num_t arith_base<num_t>::sqrt(num_t n) {
if (n <= 1)
return n;
auto x0 = div(n, num_t(2));
auto x1 = div(x0 + div(n, x0), num_t(2));
while (x1 < x0) {
x0 = x1;
x1 = div(x0 + div(n, x0), num_t(2));
}
return x0;
}
template<typename num_t>
vector<num_t> const& arith_base<num_t>::factor(num_t n) {
m_factors.reset();
for (auto d : { 2, 3, 5 }) {
while (mod(n, num_t(d)) == 0) {
m_factors.push_back(num_t(d));
n = div(n, num_t(d));
}
}
static int increments[8] = { 4, 2, 4, 2, 4, 6, 2, 6 };
unsigned i = 0;
for (auto d = num_t(7); d * d <= n; d += num_t(increments[i++])) {
while (mod(n, d) == 0) {
m_factors.push_back(d);
n = div(n, d);
}
if (i == 8)
i = 0;
}
if (n > 1)
m_factors.push_back(n);
return m_factors;
}
template<typename num_t>
double arith_base<num_t>::dscore_reward(sat::bool_var bv) {
m_dscore_mode = false;
@ -1063,8 +1305,11 @@ namespace sls {
template<typename num_t>
expr_ref arith_base<num_t>::get_value(expr* e) {
auto v = mk_var(e);
return expr_ref(a.mk_numeral(rational(m_vars[v].m_value.get_int64(), rational::i64()), a.is_int(e)), m);
num_t n;
if (is_num(e, n))
return expr_ref(a.mk_numeral(n.to_rational(), a.is_int(e)), m);
auto v = mk_term(e);
return expr_ref(a.mk_numeral(m_vars[v].m_value.to_rational(), a.is_int(e)), m);
}
template<typename num_t>
@ -1086,6 +1331,7 @@ namespace sls {
}
if (sat)
continue;
verbose_stream() << "not sat:\n";
verbose_stream() << clause << "\n";
for (auto lit : clause.m_clause) {
verbose_stream() << lit << " (" << ctx.is_true(lit) << ") ";
@ -1103,6 +1349,30 @@ namespace sls {
return true;
}
template<typename num_t>
std::ostream& arith_base<num_t>::display(std::ostream& out, var_t v) const {
auto const& vi = m_vars[v];
auto const& lo = vi.m_lo;
auto const& hi = vi.m_hi;
out << "v" << v << " := " << vi.m_value << " ";
if (lo || hi) {
if (lo)
out << (lo->is_strict ? "(": "[") << lo->value;
else
out << "(";
out << " ";
if (hi)
out << hi->value << (hi->is_strict ? ")" : "]");
else
out << ")";
out << " ";
}
out << mk_bounded_pp(vi.m_expr, m) << " : ";
for (auto [c, bv] : vi.m_bool_vars)
out << c << "@" << bv << " ";
return out;
}
template<typename num_t>
std::ostream& arith_base<num_t>::display(std::ostream& out) const {
for (unsigned v = 0; v < ctx.num_bool_vars(); ++v) {
@ -1110,14 +1380,9 @@ namespace sls {
if (ineq)
out << v << ": " << *ineq << "\n";
}
for (unsigned v = 0; v < m_vars.size(); ++v) {
auto const& vi = m_vars[v];
out << "v" << v << " := " << vi.m_value << " (best " << vi.m_best_value << ") ";
out << mk_bounded_pp(vi.m_expr, m) << " : ";
for (auto [c, bv] : vi.m_bool_vars)
out << c << "@" << bv << " ";
out << "\n";
}
for (unsigned v = 0; v < m_vars.size(); ++v)
display(out, v) << "\n";
for (auto md : m_muls) {
out << "v" << md.m_var << " := ";
for (auto w : md.m_monomial)

View file

@ -18,6 +18,7 @@ Author:
#include "util/obj_pair_set.h"
#include "util/checked_int64.h"
#include "util/optional.h"
#include "ast/ast_trail.h"
#include "ast/arith_decl_plugin.h"
#include "ast/sls/sls_context.h"
@ -31,6 +32,7 @@ namespace sls {
class arith_base : public plugin {
enum class ineq_kind { EQ, LE, LT};
enum class var_sort { INT, REAL };
struct bound { bool is_strict = false; num_t value; };
typedef unsigned var_t;
typedef unsigned atom_t;
@ -73,6 +75,7 @@ namespace sls {
vector<std::pair<num_t, sat::bool_var>> m_bool_vars;
unsigned_vector m_muls;
unsigned_vector m_adds;
optional<bound> m_lo, m_hi;
};
struct mul_def {
@ -104,17 +107,24 @@ namespace sls {
unsigned get_num_vars() const { return m_vars.size(); }
void repair_mul(mul_def const& md);
void repair_add(add_def const& ad);
void repair_mod(op_def const& od);
void repair_idiv(op_def const& od);
void repair_div(op_def const& od);
void repair_rem(op_def const& od);
void repair_power(op_def const& od);
void repair_abs(op_def const& od);
void repair_to_int(op_def const& od);
void repair_to_real(op_def const& od);
bool repair_mul1(mul_def const& md);
bool repair_square(mul_def const& md);
bool repair_mul(mul_def const& md);
bool repair_add(add_def const& ad);
bool repair_mod(op_def const& od);
bool repair_idiv(op_def const& od);
bool repair_div(op_def const& od);
bool repair_rem(op_def const& od);
bool repair_power(op_def const& od);
bool repair_abs(op_def const& od);
bool repair_to_int(op_def const& od);
bool repair_to_real(op_def const& od);
void repair(sat::literal lit, ineq const& ineq);
bool in_bounds(var_t v, num_t const& value);
vector<num_t> m_factors;
vector<num_t> const& factor(num_t n);
num_t sqrt(num_t n);
double reward(sat::literal lit);
@ -129,7 +139,7 @@ namespace sls {
bool cm(ineq const& ineq, var_t v, num_t& new_value);
bool cm(ineq const& ineq, var_t v, num_t const& coeff, num_t& new_value);
int cm_score(var_t v, num_t const& new_value);
void update(var_t v, num_t const& new_value);
bool update(var_t v, num_t const& new_value);
double dscore_reward(sat::bool_var v);
double dtt_reward(sat::literal lit);
double dscore(var_t v, num_t const& new_value) const;
@ -146,22 +156,30 @@ namespace sls {
void init_bool_var_assignment(sat::bool_var v);
bool is_int(var_t v) const { return m_vars[v].m_sort == var_sort::INT; }
num_t value(var_t v) const { return m_vars[v].m_value; }
bool is_num(expr* e, num_t& i);
expr_ref from_num(sort* s, num_t const& n);
void check_ineqs();
void init_bool_var(sat::bool_var bv);
void initialize(sat::literal lit);
void add_le(var_t v, num_t const& n);
void add_ge(var_t v, num_t const& n);
void add_lt(var_t v, num_t const& n);
void add_gt(var_t v, num_t const& n);
std::ostream& display(std::ostream& out, var_t v) const;
public:
arith_base(context& ctx);
~arith_base() override {}
void register_term(expr* e) override;
void set_value(expr* e, expr* v) override;
expr_ref get_value(expr* e) override;
void initialize() override {}
void initialize() override;
void propagate_literal(sat::literal lit) override;
bool propagate() override;
void repair_up(app* e) override;
void repair_down(app* e) override;
bool repair_down(app* e) override;
bool is_sat() override;
void on_rescale() override;
void on_restart() override;

View file

@ -102,7 +102,7 @@ namespace sls {
m_arith64->mk_model(mdl);
}
void arith_plugin::repair_down(app* e) {
bool arith_plugin::repair_down(app* e) {
WITH_FALLBACK(repair_down(e));
}

View file

@ -35,7 +35,7 @@ namespace sls {
void initialize() override;
void propagate_literal(sat::literal lit) override;
bool propagate() override;
void repair_down(app* e) override;
bool repair_down(app* e) override;
void repair_up(app* e) override;
bool is_sat() override;

View file

@ -335,21 +335,23 @@ namespace sls {
set_value(e, b);
}
void basic_plugin::repair_down(app* e) {
bool basic_plugin::repair_down(app* e) {
SASSERT(m.is_bool(e));
unsigned n = e->get_num_args();
if (n == 0 || !is_basic(e))
return;
if (!is_basic(e))
return false;
if (n == 0)
return true;
if (bval0(e) == bval1(e))
return;
return true;
unsigned s = ctx.rand(n);
for (unsigned i = 0; i < n; ++i) {
auto j = (i + s) % n;
if (try_repair(e, j))
return;
return true;
}
repair_up(e);
return false;
}
bool basic_plugin::try_repair_distinct(app* e, unsigned i) {

View file

@ -46,7 +46,7 @@ namespace sls {
void initialize() override;
void propagate_literal(sat::literal lit) override;
bool propagate() override;
void repair_down(app* e) override;
bool repair_down(app* e) override;
void repair_up(app* e) override;
bool is_sat() override;

View file

@ -99,34 +99,33 @@ namespace sls {
w.commit_eval();
}
void bv_plugin::repair_down(app* e) {
bool bv_plugin::repair_down(app* e) {
unsigned n = e->get_num_args();
if (n == 0 || m_eval.eval_is_correct(e))
return;
return true;
if (n == 2) {
auto d1 = get_depth(e->get_arg(0));
auto d2 = get_depth(e->get_arg(1));
unsigned s = ctx.rand(d1 + d2 + 2);
if (s <= d1 && m_eval.repair_down(e, 0))
return;
return true;
if (m_eval.repair_down(e, 1))
return;
return true;
if (m_eval.repair_down(e, 0))
return;
return true;
}
else {
unsigned s = ctx.rand(n);
for (unsigned i = 0; i < n; ++i) {
auto j = (i + s) % n;
if (m_eval.repair_down(e, j))
return;
return true;
}
}
IF_VERBOSE(0, verbose_stream() << "revert repair: " << mk_bounded_pp(e, m) << "\n");
repair_up(e);
return false;
}
void bv_plugin::repair_up(app* e) {

View file

@ -43,7 +43,7 @@ namespace sls {
void initialize() override;
void propagate_literal(sat::literal lit) override;
bool propagate() override;
void repair_down(app* e) override;
bool repair_down(app* e) override;
void repair_up(app* e) override;
bool is_sat() override;

View file

@ -104,9 +104,11 @@ namespace sls {
expr* e = term(id);
TRACE("sls", tout << "repair down " << mk_bounded_pp(e, m) << "\n");
if (is_app(e)) {
auto p = m_plugins.get(to_app(e)->get_family_id(), nullptr);
if (p)
p->repair_down(to_app(e));
auto p = m_plugins.get(get_fid(e), nullptr);
if (p && !p->repair_down(to_app(e)) && !m_repair_up.contains(e->get_id())) {
IF_VERBOSE(0, verbose_stream() << "revert repair: " << mk_bounded_pp(e, m) << "\n");
m_repair_up.insert(e->get_id());
}
}
}
while (!m_repair_up.empty() && !m_new_constraint) {
@ -114,7 +116,7 @@ namespace sls {
expr* e = term(id);
TRACE("sls", tout << "repair up " << mk_bounded_pp(e, m) << "\n");
if (is_app(e)) {
auto p = m_plugins.get(to_app(e)->get_family_id(), nullptr);
auto p = m_plugins.get(get_fid(e), nullptr);
if (p)
p->repair_up(to_app(e));
}
@ -129,15 +131,24 @@ namespace sls {
}
}
family_id context::get_fid(expr* e) const {
if (!is_app(e))
return null_family_id;
family_id fid = to_app(e)->get_family_id();
if (m.is_eq(e) || m.is_distinct(e))
fid = to_app(e)->get_arg(0)->get_sort()->get_family_id();
else if (m.is_ite(e))
fid = to_app(e)->get_arg(1)->get_sort()->get_family_id();
return fid;
}
void context::propagate_literal(sat::literal lit) {
if (!is_true(lit))
return;
auto a = atom(lit.var());
if (!a || !is_app(a))
if (!a)
return;
family_id fid = to_app(a)->get_family_id();
if (m.is_eq(a) || m.is_distinct(a))
fid = to_app(a)->get_arg(0)->get_sort()->get_family_id();
family_id fid = get_fid(a);
auto p = m_plugins.get(fid, nullptr);
if (p)
p->propagate_literal(lit);
@ -223,6 +234,11 @@ namespace sls {
if (m_initialized)
return;
m_initialized = true;
m_unit_literals.reset();
for (auto const& clause : s.clauses())
if (clause.m_clause.size() == 1)
m_unit_literals.push_back(clause.m_clause[0]);
verbose_stream() << "UNITS " << m_unit_literals << "\n";
for (auto a : m_atoms)
if (a)
register_terms(a);
@ -310,7 +326,7 @@ namespace sls {
m_relevant.reset();
m_visited.reset();
m_root_literals.reset();
m_unit_literals.reset();
for (auto const& clause : s.clauses()) {
bool has_relevant = false;
unsigned n = 0;
@ -329,8 +345,6 @@ namespace sls {
if (m_rand() % ++n == 0)
selected_lit = lit;
}
if (clause.m_clause.size() == 1)
m_unit_literals.push_back(clause.m_clause[0]);
if (!has_relevant && selected_lit != sat::null_literal) {
m_relevant.insert(m_atoms[selected_lit.var()]->get_id());
m_root_literals.push_back(selected_lit);

View file

@ -42,7 +42,7 @@ namespace sls {
virtual void initialize() = 0;
virtual bool propagate() = 0;
virtual void propagate_literal(sat::literal lit) = 0;
virtual void repair_down(app* e) = 0;
virtual bool repair_down(app* e) = 0;
virtual void repair_up(app* e) = 0;
virtual bool is_sat() = 0;
virtual void on_rescale() {};
@ -116,6 +116,8 @@ namespace sls {
void propagate_boolean_assignment();
void propagate_literal(sat::literal lit);
family_id get_fid(expr* e) const;
public:
context(ast_manager& m, sat_solver_context& s);

View file

@ -49,7 +49,7 @@ namespace sls {
void set_value(expr* e, expr* v) override {}
void repair_up(app* e) override {}
void repair_down(app* e) override {}
bool repair_down(app* e) override { return false; }
};
}

View file

@ -18,6 +18,7 @@ Author:
#include "ast/sls/sls_context.h"
#include "ast/sls/sat_ddfw.h"
#include "ast/sls/sls_smt_solver.h"
#include "ast/ast_ll_pp.h"
namespace sls {
@ -101,7 +102,12 @@ namespace sls {
}
void smt_solver::assert_expr(expr* e) {
m_assertions.push_back(e);
if (m.is_and(e)) {
for (expr* arg : *to_app(e))
assert_expr(arg);
}
else
m_assertions.push_back(e);
}
lbool smt_solver::check() {
@ -116,7 +122,11 @@ namespace sls {
}
void smt_solver::add_clause(expr* f) {
expr* g;
sat::literal_vector clause;
if (m.is_not(f, g) && m.is_not(g, g)) {
add_clause(g);
}
if (m.is_or(f)) {
clause.reset();
for (auto arg : *to_app(f))
@ -127,6 +137,18 @@ namespace sls {
for (auto arg : *to_app(f))
add_clause(arg);
}
else if (m.is_not(f, g) && m.is_or(g)) {
for (auto arg : *to_app(g)) {
expr_ref fml(m.mk_not(arg), m);;
add_clause(fml);
}
}
else if (m.is_not(f, g) && m.is_and(g)) {
clause.reset();
for (auto arg : *to_app(g))
clause.push_back(~mk_literal(arg));
m_solver_ctx->add_clause(clause.size(), clause.data());
}
else {
sat::literal lit = mk_literal(f);
m_solver_ctx->add_clause(1, &lit);