diff --git a/src/ast/sls/sls_arith_base.cpp b/src/ast/sls/sls_arith_base.cpp index 5abb17734..60f765386 100644 --- a/src/ast/sls/sls_arith_base.cpp +++ b/src/ast/sls/sls_arith_base.cpp @@ -129,9 +129,10 @@ namespace sls { template num_t arith_base::divide(var_t v, num_t const& delta, num_t const& coeff) { - if (m_vars[v].m_sort == var_sort::REAL) - return delta / coeff; - return div(delta + abs(coeff) - 1, coeff); + if (is_int(v)) + return div(delta + abs(coeff) - 1, coeff); + else + return delta / coeff; } template @@ -147,7 +148,7 @@ namespace sls { // args <= bound -> args > bound SASSERT(argsv <= bound); SASSERT(delta <= 0); - delta -= 1 + (ctx.rand() % 10); + delta -= 1 + (ctx.rand(10)); new_value = value(v) + divide(v, abs(delta), coeff); VERIFY(argsv + coeff * (new_value - value(v)) > bound); return true; @@ -155,13 +156,13 @@ namespace sls { // args < bound -> args >= bound SASSERT(argsv <= bound); SASSERT(delta <= 0); - delta = abs(delta) + ctx.rand() % 10; + delta = abs(delta) + ctx.rand(10); new_value = value(v) + divide(v, delta, coeff); VERIFY(argsv + coeff * (new_value - value(v)) >= bound); return true; case ineq_kind::EQ: { - delta = abs(delta) + 1 + ctx.rand() % 10; - int sign = ctx.rand() % 2 == 0 ? 1 : -1; + delta = abs(delta) + 1 + ctx.rand(10); + int sign = ctx.rand(2) == 0 ? 1 : -1; new_value = value(v) + sign * divide(v, abs(delta), coeff); VERIFY(argsv + coeff * (new_value - value(v)) != bound); return true; @@ -176,14 +177,14 @@ namespace sls { case ineq_kind::LE: SASSERT(argsv > bound); SASSERT(delta > 0); - delta += rand() % 10; + delta += ctx.rand(10); new_value = value(v) - divide(v, delta, coeff); VERIFY(argsv + coeff * (new_value - value(v)) <= bound); return true; case ineq_kind::LT: SASSERT(argsv >= bound); SASSERT(delta >= 0); - delta += 1 + rand() % 10; + delta += 1 + ctx.rand(10); new_value = value(v) - divide(v, delta, coeff); VERIFY(argsv + coeff * (new_value - value(v)) < bound); return true; @@ -229,6 +230,8 @@ namespace sls { } verbose_stream() << "repair " << lit << ": " << ineq << " var: v" << v << " := " << value(v) << " -> " << new_value << "\n"; update(v, new_value); + if (dtt(lit.sign(), ineq) != 0) + ctx.flip(lit.var()); } // @@ -329,14 +332,55 @@ namespace sls { return d; } + template - void arith_base::update(var_t v, num_t const& new_value) { + bool arith_base::in_bounds(var_t v, num_t const& value) { + auto const& vi = m_vars[v]; + auto const& lo = vi.m_lo; + auto const& hi = vi.m_hi; + if (lo && value < lo->value) + return false; + if (lo && lo->is_strict && value <= lo->value) + return false; + if (hi && value > hi->value) + return false; + if (hi && hi->is_strict && value >= hi->value) + return false; + return true; + } + + template + bool arith_base::update(var_t v, num_t const& new_value) { auto& vi = m_vars[v]; expr* e = vi.m_expr; auto old_value = vi.m_value; if (old_value == new_value) - return; - verbose_stream() << mk_bounded_pp(e, m) << " := " << new_value << "\n"; + return true; + display(verbose_stream(), v) << " := " << new_value << "\n"; + if (!in_bounds(v, new_value)) { + auto const& lo = vi.m_lo; + auto const& hi = vi.m_hi; + if (is_int(v) && lo && !lo->is_strict && new_value < lo->value) { + if (lo->value != old_value) + return update(v, lo->value); + if (in_bounds(v, old_value + 1)) + return update(v, old_value + 1); + else + return false; + } + if (is_int(v) && hi && !hi->is_strict && new_value > hi->value) { + if (hi->value != old_value) + return update(v, hi->value); + else if (in_bounds(v, old_value - 1)) + return update(v, old_value - 1); + else + return false; + } + verbose_stream() << "out of bounds old value " << old_value << "\n"; + display(verbose_stream(), v) << "\n"; + SASSERT(false); + return false; + } for (auto const& [coeff, bv] : vi.m_bool_vars) { auto& ineq = *atom(bv); bool old_sign = sign(bv); @@ -344,6 +388,7 @@ namespace sls { SASSERT(ctx.is_true(lit)); ineq.m_args_value += coeff * (new_value - old_value); num_t dtt_new = dtt(old_sign, ineq); + // verbose_stream() << "dtt " << lit << " " << ineq << " " << dtt_new << "\n"; if (dtt_new != 0) ctx.flip(bv); SASSERT(dtt(sign(bv), ineq) == 0); @@ -358,8 +403,12 @@ namespace sls { ctx.new_value_eh(m_vars[ad.m_var].m_expr); } + if (m.is_value(e)) { + display(verbose_stream()); + } SASSERT(!m.is_value(e)); ctx.new_value_eh(e); + return true; } template @@ -433,7 +482,7 @@ namespace sls { } else if (a.is_mul(e)) { unsigned_vector m; - num_t c = coeff; + num_t c(1); for (expr* arg : *to_app(e)) if (is_num(arg, i)) c *= i; @@ -441,10 +490,10 @@ namespace sls { m.push_back(mk_term(arg)); switch (m.size()) { case 0: - term.m_coeff += c; + term.m_coeff += c*coeff; break; case 1: - add_arg(term, c, m[0]); + add_arg(term, c*coeff, m[0]); break; default: { v = mk_var(e); @@ -456,7 +505,7 @@ namespace sls { m_vars[v].m_def_idx = idx; m_vars[v].m_op = arith_op_kind::OP_MUL; m_vars[v].m_value = prod; - add_arg(term, num_t(1), v); + add_arg(term, coeff, v); break; } } @@ -517,6 +566,7 @@ namespace sls { NOT_IMPLEMENTED_YET(); break; } + verbose_stream() << "mk-op " << mk_bounded_pp(e, m) << "\n"; m_ops.push_back({v, k, v, w}); m_vars[v].m_def_idx = idx; m_vars[v].m_op = k; @@ -547,6 +597,7 @@ namespace sls { template typename arith_base::var_t arith_base::mk_var(expr* e) { + SASSERT(!m.is_value(e)); var_t v = m_expr2var.get(e->get_id(), UINT_MAX); if (v == UINT_MAX) { v = m_vars.size(); @@ -648,8 +699,6 @@ namespace sls { auto const& vi = m_vars[v]; if (vi.m_def_idx == UINT_MAX) return; - m_ops.reserve(vi.m_def_idx + 1); - auto const& od = m_ops[vi.m_def_idx]; num_t v1, v2; switch (vi.m_op) { case LAST_ARITH_OP: @@ -672,27 +721,27 @@ namespace sls { break; } case OP_MOD: - v1 = value(od.m_arg1); - v2 = value(od.m_arg2); + v1 = value(m_ops[vi.m_def_idx].m_arg1); + v2 = value(m_ops[vi.m_def_idx].m_arg2); update(v, v2 == 0 ? num_t(0) : mod(v1, v2)); break; case OP_DIV: - v1 = value(od.m_arg1); - v2 = value(od.m_arg2); + v1 = value(m_ops[vi.m_def_idx].m_arg1); + v2 = value(m_ops[vi.m_def_idx].m_arg2); update(v, v2 == 0 ? num_t(0) : v1 / v2); break; case OP_IDIV: - v1 = value(od.m_arg1); - v2 = value(od.m_arg2); + v1 = value(m_ops[vi.m_def_idx].m_arg1); + v2 = value(m_ops[vi.m_def_idx].m_arg2); update(v, v2 == 0 ? num_t(0) : div(v1, v2)); break; case OP_REM: - v1 = value(od.m_arg1); - v2 = value(od.m_arg2); + v1 = value(m_ops[vi.m_def_idx].m_arg1); + v2 = value(m_ops[vi.m_def_idx].m_arg2); update(v, v2 == 0 ? num_t(0) : v1 %= v2); break; case OP_ABS: - update(v, abs(value(od.m_arg1))); + update(v, abs(value(m_ops[vi.m_def_idx].m_arg1))); break; default: NOT_IMPLEMENTED_YET(); @@ -700,54 +749,150 @@ namespace sls { } template - void arith_base::repair_down(app* e) { + bool arith_base::repair_down(app* e) { auto v = m_expr2var.get(e->get_id(), UINT_MAX); if (v == UINT_MAX) - return; + return false; auto const& vi = m_vars[v]; if (vi.m_def_idx == UINT_MAX) - return; + return false; TRACE("sls", tout << "repair def " << mk_bounded_pp(vi.m_expr, m) << "\n"); switch (vi.m_op) { case arith_op_kind::LAST_ARITH_OP: break; case arith_op_kind::OP_ADD: - repair_add(m_adds[vi.m_def_idx]); - break; + return repair_add(m_adds[vi.m_def_idx]); case arith_op_kind::OP_MUL: - repair_mul(m_muls[vi.m_def_idx]); - break; + return repair_mul(m_muls[vi.m_def_idx]); case arith_op_kind::OP_MOD: - repair_mod(m_ops[vi.m_def_idx]); - break; + return repair_mod(m_ops[vi.m_def_idx]); case arith_op_kind::OP_REM: - repair_rem(m_ops[vi.m_def_idx]); - break; + return repair_rem(m_ops[vi.m_def_idx]); case arith_op_kind::OP_POWER: - repair_power(m_ops[vi.m_def_idx]); - break; + return repair_power(m_ops[vi.m_def_idx]); case arith_op_kind::OP_IDIV: - repair_idiv(m_ops[vi.m_def_idx]); - break; + return repair_idiv(m_ops[vi.m_def_idx]); case arith_op_kind::OP_DIV: - repair_div(m_ops[vi.m_def_idx]); - break; + return repair_div(m_ops[vi.m_def_idx]); case arith_op_kind::OP_ABS: - repair_abs(m_ops[vi.m_def_idx]); - break; + return repair_abs(m_ops[vi.m_def_idx]); case arith_op_kind::OP_TO_INT: - repair_to_int(m_ops[vi.m_def_idx]); - break; + return repair_to_int(m_ops[vi.m_def_idx]); case arith_op_kind::OP_TO_REAL: - repair_to_real(m_ops[vi.m_def_idx]); - break; + return repair_to_real(m_ops[vi.m_def_idx]); default: NOT_IMPLEMENTED_YET(); } + return true; + } + + template + void arith_base::initialize() { + for (auto lit : ctx.unit_literals()) + initialize(lit); + } + + template + void arith_base::initialize(sat::literal lit) { + init_bool_var(lit.var()); + auto* ineq = atom(lit.var()); + if (!ineq) + return; + + if (ineq->m_args.size() != 1) + return; + auto [c, v] = ineq->m_args[0]; + + switch (ineq->m_op) { + case ineq_kind::LE: + if (lit.sign()) { + if (c == -1) // -x + c >= 0 <=> c >= x + add_le(v, ineq->m_coeff); + else if (c == 1) // x + c >= 0 <=> x >= -c + add_ge(v, -ineq->m_coeff); + else + verbose_stream() << "INITIALIZE " << lit << " " << *ineq << "\n"; + } + else { + if (c == -1) + add_ge(v, ineq->m_coeff); + else if (c == 1) + add_le(v, -ineq->m_coeff); + else + verbose_stream() << "INITIALIZE " << lit << " " << *ineq << "\n"; + } + break; + case ineq_kind::EQ: + if (lit.sign()) { + verbose_stream() << "INITIALIZE " << lit << " " << *ineq << "\n"; + } + else { + if (c == -1) { + add_ge(v, ineq->m_coeff); + add_le(v, ineq->m_coeff); + } + else if (c == 1) { + add_ge(v, -ineq->m_coeff); + add_le(v, -ineq->m_coeff); + } + else + verbose_stream() << "INITIALIZE " << lit << " " << *ineq << "\n"; + } + break; + case ineq_kind::LT: + + if (lit.sign()) { + if (c == -1) // -x + c >= 0 <=> c >= x + add_le(v, ineq->m_coeff); + else if (c == 1) // x + c >= 0 <=> x >= -c + add_ge(v, -ineq->m_coeff); + else + verbose_stream() << "INITIALIZE " << lit << " " << *ineq << "\n"; + } + else { + if (c == -1) + add_gt(v, ineq->m_coeff); + else if (c == 1) + add_lt(v, -ineq->m_coeff); + else + verbose_stream() << "INITIALIZE " << lit << " " << *ineq << "\n"; + } + break; + } } template - void arith_base::repair_add(add_def const& ad) { + void arith_base::add_le(var_t v, num_t const& n) { + if (m_vars[v].m_hi && m_vars[v].m_hi->value <= n) + return; + m_vars[v].m_hi = { false, n }; + } + + template + void arith_base::add_ge(var_t v, num_t const& n) { + if (m_vars[v].m_lo && m_vars[v].m_lo->value >= n) + return; + m_vars[v].m_lo = { false, n }; + } + + template + void arith_base::add_lt(var_t v, num_t const& n) { + if (is_int(v)) + add_le(v, n - 1); + else + m_vars[v].m_hi = { true, n }; + } + + template + void arith_base::add_gt(var_t v, num_t const& n) { + if (is_int(v)) + add_ge(v, n + 1); + else + m_vars[v].m_lo = { true, n }; + } + + template + bool arith_base::repair_add(add_def const& ad) { auto v = ad.m_var; auto const& coeffs = ad.m_args; num_t sum(ad.m_coeff); @@ -758,21 +903,71 @@ namespace sls { for (auto const& [c, w] : coeffs) sum += c * value(w); if (val == sum) - return; - if (rand() % 20 == 0) - update(v, sum); + return true; + if (ctx.rand(20) == 0) + return update(v, sum); else { - auto const& [c, w] = coeffs[rand() % coeffs.size()]; + auto const& [c, w] = coeffs[ctx.rand(coeffs.size())]; num_t delta = sum - val; bool is_real = m_vars[w].m_sort == var_sort::REAL; - bool round_down = rand() % 2 == 0; + bool round_down = ctx.rand(2) == 0; num_t new_value = value(w) + (is_real ? delta / c : round_down ? div(delta, c) : div(delta + c - 1, c)); - update(w, new_value); + return update(w, new_value); } } template - void arith_base::repair_mul(mul_def const& md) { + bool arith_base::repair_square(mul_def const& md) { + auto const& [v, coeff, monomial] = md; + if (!is_int(v) || monomial.size() != 2 || monomial[0] != monomial[1]) + return false; + + num_t val = value(v); + val = div(val, coeff); + var_t w = monomial[0]; + if (val < 0) + update(w, num_t(ctx.rand(10))); + else { + num_t root = sqrt(val); + if (ctx.rand(3) == 0) + root = -root; + if (root * root == val) + update(w, root); + else + update(w, root + num_t(ctx.rand(3)) - 1); + } + verbose_stream() << "ROOT " << val << " v" << w << " := " << value(w) << "\n"; + return true; + } + + template + bool arith_base::repair_mul1(mul_def const& md) { + auto const& [v, coeff, monomial] = md; + if (!is_int(v)) + return false; + num_t val = value(v); + val = div(val, coeff); + if (val == 0) + return false; + unsigned sz = monomial.size(); + unsigned start = ctx.rand(sz); + for (unsigned i = 0; i < sz; ++i) { + unsigned j = (start + i) % sz; + auto w = monomial[j]; + num_t product(1); + for (auto v : monomial) + if (v != w) + product *= value(v); + if (product == 0 || !divides(product, val)) + continue; + update(w, div(val, product)); + return true; + } + return false; + } + + template + bool arith_base::repair_mul(mul_def const& md) { auto const& [v, coeff, monomial] = md; num_t product(coeff); num_t val = value(v); @@ -780,118 +975,124 @@ namespace sls { for (auto v : monomial) product *= value(v); if (product == val) - return; - verbose_stream() << "repair mul " << mk_bounded_pp(m_vars[v].m_expr, m) << " := " << val << "(" << product << ")\n"; - if (rand() % 20 == 0) - update(v, product); + return true; +// verbose_stream() << "repair mul " << mk_bounded_pp(m_vars[v].m_expr, m) << " := " << val << "(product: " << product << ")\n"; + unsigned sz = monomial.size(); + if (ctx.rand(20) == 0) + return update(v, product); else if (val == 0) { - auto v = monomial[ctx.rand(monomial.size())]; + auto v = monomial[ctx.rand(sz)]; num_t zero(0); - update(v, zero); + return update(v, zero); } - else if (val == 1 || val == -1) { - product = coeff; - for (auto v : monomial) { - num_t new_value(1); - if (rand() % 2 == 0) - new_value = -1; - product *= new_value; - update(v, new_value); - } - if (product != val) { - auto last = monomial.back(); - update(last, -value(last)); - } + else if (repair_square(md)) + return true; + else if (ctx.rand(4) != 0 && repair_mul1(md)) { +#if 0 + verbose_stream() << "mul1 " << val << " " << coeff << " "; + for (auto v : monomial) + verbose_stream() << "v" << v << " = " << value(v) << " "; + verbose_stream() << "\n"; +#endif + return true; } - else if (rand() % 2 == 0 && product != 0) { - // value1(v) * product / value(v) = val - // value1(v) = value(v) * val / product - auto w = monomial[ctx.rand(monomial.size())]; - auto old_value = value(w); - new_value = divide(w, old_value * val, product); - update(w, new_value); + else if (is_int(v)) { +#if 0 + verbose_stream() << "repair mul2 - "; + for (auto v : monomial) + verbose_stream() << "v" << v << " = " << value(v) << " "; +#endif + num_t n = div(val, coeff); + if (!divides(coeff, val) && ctx.rand(2) == 0) + n = div(val + coeff - 1, coeff); + auto const& fs = factor(abs(n)); + vector coeffs(sz, num_t(ctx.rand(2) == 0 ? 1 : -1)); + vector gcds(sz, num_t(0)); + num_t sign(1); + for (auto c : coeffs) + sign *= c; + unsigned i = 0; + for (auto w : monomial) { + for (auto idx : m_vars[w].m_muls) { + auto const& [w1, coeff1, monomial1] = m_muls[idx]; + gcds[i] = gcd(gcds[i], abs(value(w1))); + } + ++i; + } + for (auto f : fs) + coeffs[ctx.rand(sz)] *= f; + if ((sign == 0) != (n == 0)) + coeffs[ctx.rand(sz)] *= -1; +// verbose_stream() << "value " << val << " coeff: " << coeff << " coeffs: " << coeffs << " factors: " << fs << "\n"; + i = 0; + for (auto w : monomial) + if (!update(w, coeffs[i++])) + return false; + return true; } else { - auto w = monomial[ctx.rand(monomial.size())]; - num_t prod(coeff); - for (auto v : monomial) { - if (v == w) - continue; - num_t new_value(1); - if (rand() % 2 == 0) - new_value = -1; - prod *= new_value; - update(v, new_value); - } - - verbose_stream() << "select random " << coeff << " " << val << " v" << w << "\n"; - new_value = divide(w, val * value(w), coeff); - - if ((product < 0 && 0 < new_value) || (new_value < 0 && 0 < product)) - update(w, -new_value); - else - update(w, new_value); + NOT_IMPLEMENTED_YET(); } + return false; } template - void arith_base::repair_rem(op_def const& od) { + bool arith_base::repair_rem(op_def const& od) { auto v1 = value(od.m_arg1); auto v2 = value(od.m_arg2); - if (v2 == 0) { - update(od.m_var, num_t(0)); - return; - } + if (v2 == 0) + return update(od.m_var, num_t(0)); + IF_VERBOSE(0, verbose_stream() << "todo repair rem"); // bail v1 %= v2; - update(od.m_var, v1); + return update(od.m_var, v1); } template - void arith_base::repair_abs(op_def const& od) { + bool arith_base::repair_abs(op_def const& od) { auto val = value(od.m_var); auto v1 = value(od.m_arg1); if (val < 0) - update(od.m_var, abs(v1)); - else if (rand() % 2 == 0) - update(od.m_arg1, val); + return update(od.m_var, abs(v1)); + else if (ctx.rand(2) == 0) + return update(od.m_arg1, val); else - update(od.m_arg1, -val); + return update(od.m_arg1, -val); } template - void arith_base::repair_to_int(op_def const& od) { + bool arith_base::repair_to_int(op_def const& od) { auto val = value(od.m_var); auto v1 = value(od.m_arg1); if (val - 1 < v1 && v1 <= val) - return; - update(od.m_arg1, val); + return true; + return update(od.m_arg1, val); } template - void arith_base::repair_to_real(op_def const& od) { - if (rand() % 20 == 0) - update(od.m_var, value(od.m_arg1)); + bool arith_base::repair_to_real(op_def const& od) { + if (ctx.rand(20) == 0) + return update(od.m_var, value(od.m_arg1)); else - update(od.m_arg1, value(od.m_arg1)); + return update(od.m_arg1, value(od.m_arg1)); } template - void arith_base::repair_power(op_def const& od) { + bool arith_base::repair_power(op_def const& od) { auto v1 = value(od.m_arg1); auto v2 = value(od.m_arg2); if (v1 == 0 && v2 == 0) { - update(od.m_var, num_t(0)); - return; + return update(od.m_var, num_t(0)); } IF_VERBOSE(0, verbose_stream() << "todo repair ^"); NOT_IMPLEMENTED_YET(); + return false; } template - void arith_base::repair_mod(op_def const& od) { + bool arith_base::repair_mod(op_def const& od) { auto val = value(od.m_var); auto v1 = value(od.m_arg1); auto v2 = value(od.m_arg2); @@ -899,11 +1100,11 @@ namespace sls { if (val >= 0 && val < v2) { auto v3 = mod(v1, v2); if (v3 == val) - return; + return true; // find r, such that mod(v1 + r, v2) = val // v1 := v1 + val - v3 (+/- v2) v1 += val - v3; - switch (rand() % 6) { + switch (ctx.rand(6)) { case 0: v1 += v2; break; @@ -913,28 +1114,27 @@ namespace sls { default: break; } - update(od.m_arg1, v1); - return; + return update(od.m_arg1, v1); } - update(od.m_var, v2 == 0 ? num_t(0) : mod(v1, v2)); + return update(od.m_var, v2 == 0 ? num_t(0) : mod(v1, v2)); } template - void arith_base::repair_idiv(op_def const& od) { + bool arith_base::repair_idiv(op_def const& od) { auto v1 = value(od.m_arg1); auto v2 = value(od.m_arg2); IF_VERBOSE(0, verbose_stream() << "todo repair div"); // bail - update(od.m_var, v2 == 0 ? num_t(0) : div(v1, v2)); + return update(od.m_var, v2 == 0 ? num_t(0) : div(v1, v2)); } template - void arith_base::repair_div(op_def const& od) { + bool arith_base::repair_div(op_def const& od) { auto v1 = value(od.m_arg1); auto v2 = value(od.m_arg2); IF_VERBOSE(0, verbose_stream() << "todo repair /"); // bail - update(od.m_var, v2 == 0 ? num_t(0) : v1 / v2); + return update(od.m_var, v2 == 0 ? num_t(0) : v1 / v2); } template @@ -968,7 +1168,7 @@ namespace sls { result += ctx.reward(bv); #endif } - if (result > max_result || max_result == -1 || (result == max_result && (rand() % ++n == 0))) { + if (result > max_result || max_result == -1 || (result == max_result && (ctx.rand(++n) == 0))) { max_result = result; ineq->m_var_to_flip = x; } @@ -976,6 +1176,48 @@ namespace sls { return max_result; } + // Newton function for integer square root. + template + num_t arith_base::sqrt(num_t n) { + if (n <= 1) + return n; + + auto x0 = div(n, num_t(2)); + + auto x1 = div(x0 + div(n, x0), num_t(2)); + + while (x1 < x0) { + x0 = x1; + x1 = div(x0 + div(n, x0), num_t(2)); + } + return x0; + } + + template + vector const& arith_base::factor(num_t n) { + m_factors.reset(); + for (auto d : { 2, 3, 5 }) { + while (mod(n, num_t(d)) == 0) { + m_factors.push_back(num_t(d)); + n = div(n, num_t(d)); + } + } + static int increments[8] = { 4, 2, 4, 2, 4, 6, 2, 6 }; + unsigned i = 0; + for (auto d = num_t(7); d * d <= n; d += num_t(increments[i++])) { + while (mod(n, d) == 0) { + m_factors.push_back(d); + n = div(n, d); + } + if (i == 8) + i = 0; + } + if (n > 1) + m_factors.push_back(n); + return m_factors; + } + + template double arith_base::dscore_reward(sat::bool_var bv) { m_dscore_mode = false; @@ -1063,8 +1305,11 @@ namespace sls { template expr_ref arith_base::get_value(expr* e) { - auto v = mk_var(e); - return expr_ref(a.mk_numeral(rational(m_vars[v].m_value.get_int64(), rational::i64()), a.is_int(e)), m); + num_t n; + if (is_num(e, n)) + return expr_ref(a.mk_numeral(n.to_rational(), a.is_int(e)), m); + auto v = mk_term(e); + return expr_ref(a.mk_numeral(m_vars[v].m_value.to_rational(), a.is_int(e)), m); } template @@ -1086,6 +1331,7 @@ namespace sls { } if (sat) continue; + verbose_stream() << "not sat:\n"; verbose_stream() << clause << "\n"; for (auto lit : clause.m_clause) { verbose_stream() << lit << " (" << ctx.is_true(lit) << ") "; @@ -1103,6 +1349,30 @@ namespace sls { return true; } + template + std::ostream& arith_base::display(std::ostream& out, var_t v) const { + auto const& vi = m_vars[v]; + auto const& lo = vi.m_lo; + auto const& hi = vi.m_hi; + out << "v" << v << " := " << vi.m_value << " "; + if (lo || hi) { + if (lo) + out << (lo->is_strict ? "(": "[") << lo->value; + else + out << "("; + out << " "; + if (hi) + out << hi->value << (hi->is_strict ? ")" : "]"); + else + out << ")"; + out << " "; + } + out << mk_bounded_pp(vi.m_expr, m) << " : "; + for (auto [c, bv] : vi.m_bool_vars) + out << c << "@" << bv << " "; + return out; + } + template std::ostream& arith_base::display(std::ostream& out) const { for (unsigned v = 0; v < ctx.num_bool_vars(); ++v) { @@ -1110,14 +1380,9 @@ namespace sls { if (ineq) out << v << ": " << *ineq << "\n"; } - for (unsigned v = 0; v < m_vars.size(); ++v) { - auto const& vi = m_vars[v]; - out << "v" << v << " := " << vi.m_value << " (best " << vi.m_best_value << ") "; - out << mk_bounded_pp(vi.m_expr, m) << " : "; - for (auto [c, bv] : vi.m_bool_vars) - out << c << "@" << bv << " "; - out << "\n"; - } + for (unsigned v = 0; v < m_vars.size(); ++v) + display(out, v) << "\n"; + for (auto md : m_muls) { out << "v" << md.m_var << " := "; for (auto w : md.m_monomial) diff --git a/src/ast/sls/sls_arith_base.h b/src/ast/sls/sls_arith_base.h index 93b5b717d..e24059135 100644 --- a/src/ast/sls/sls_arith_base.h +++ b/src/ast/sls/sls_arith_base.h @@ -18,6 +18,7 @@ Author: #include "util/obj_pair_set.h" #include "util/checked_int64.h" +#include "util/optional.h" #include "ast/ast_trail.h" #include "ast/arith_decl_plugin.h" #include "ast/sls/sls_context.h" @@ -31,6 +32,7 @@ namespace sls { class arith_base : public plugin { enum class ineq_kind { EQ, LE, LT}; enum class var_sort { INT, REAL }; + struct bound { bool is_strict = false; num_t value; }; typedef unsigned var_t; typedef unsigned atom_t; @@ -73,6 +75,7 @@ namespace sls { vector> m_bool_vars; unsigned_vector m_muls; unsigned_vector m_adds; + optional m_lo, m_hi; }; struct mul_def { @@ -104,17 +107,24 @@ namespace sls { unsigned get_num_vars() const { return m_vars.size(); } - void repair_mul(mul_def const& md); - void repair_add(add_def const& ad); - void repair_mod(op_def const& od); - void repair_idiv(op_def const& od); - void repair_div(op_def const& od); - void repair_rem(op_def const& od); - void repair_power(op_def const& od); - void repair_abs(op_def const& od); - void repair_to_int(op_def const& od); - void repair_to_real(op_def const& od); + bool repair_mul1(mul_def const& md); + bool repair_square(mul_def const& md); + bool repair_mul(mul_def const& md); + bool repair_add(add_def const& ad); + bool repair_mod(op_def const& od); + bool repair_idiv(op_def const& od); + bool repair_div(op_def const& od); + bool repair_rem(op_def const& od); + bool repair_power(op_def const& od); + bool repair_abs(op_def const& od); + bool repair_to_int(op_def const& od); + bool repair_to_real(op_def const& od); void repair(sat::literal lit, ineq const& ineq); + bool in_bounds(var_t v, num_t const& value); + + vector m_factors; + vector const& factor(num_t n); + num_t sqrt(num_t n); double reward(sat::literal lit); @@ -129,7 +139,7 @@ namespace sls { bool cm(ineq const& ineq, var_t v, num_t& new_value); bool cm(ineq const& ineq, var_t v, num_t const& coeff, num_t& new_value); int cm_score(var_t v, num_t const& new_value); - void update(var_t v, num_t const& new_value); + bool update(var_t v, num_t const& new_value); double dscore_reward(sat::bool_var v); double dtt_reward(sat::literal lit); double dscore(var_t v, num_t const& new_value) const; @@ -146,22 +156,30 @@ namespace sls { void init_bool_var_assignment(sat::bool_var v); + bool is_int(var_t v) const { return m_vars[v].m_sort == var_sort::INT; } + num_t value(var_t v) const { return m_vars[v].m_value; } bool is_num(expr* e, num_t& i); expr_ref from_num(sort* s, num_t const& n); void check_ineqs(); void init_bool_var(sat::bool_var bv); + void initialize(sat::literal lit); + void add_le(var_t v, num_t const& n); + void add_ge(var_t v, num_t const& n); + void add_lt(var_t v, num_t const& n); + void add_gt(var_t v, num_t const& n); + std::ostream& display(std::ostream& out, var_t v) const; public: arith_base(context& ctx); ~arith_base() override {} void register_term(expr* e) override; void set_value(expr* e, expr* v) override; expr_ref get_value(expr* e) override; - void initialize() override {} + void initialize() override; void propagate_literal(sat::literal lit) override; bool propagate() override; void repair_up(app* e) override; - void repair_down(app* e) override; + bool repair_down(app* e) override; bool is_sat() override; void on_rescale() override; void on_restart() override; diff --git a/src/ast/sls/sls_arith_plugin.cpp b/src/ast/sls/sls_arith_plugin.cpp index 55c041df9..aefcbb0ee 100644 --- a/src/ast/sls/sls_arith_plugin.cpp +++ b/src/ast/sls/sls_arith_plugin.cpp @@ -102,7 +102,7 @@ namespace sls { m_arith64->mk_model(mdl); } - void arith_plugin::repair_down(app* e) { + bool arith_plugin::repair_down(app* e) { WITH_FALLBACK(repair_down(e)); } diff --git a/src/ast/sls/sls_arith_plugin.h b/src/ast/sls/sls_arith_plugin.h index 1dbc6f027..70ec38565 100644 --- a/src/ast/sls/sls_arith_plugin.h +++ b/src/ast/sls/sls_arith_plugin.h @@ -35,7 +35,7 @@ namespace sls { void initialize() override; void propagate_literal(sat::literal lit) override; bool propagate() override; - void repair_down(app* e) override; + bool repair_down(app* e) override; void repair_up(app* e) override; bool is_sat() override; diff --git a/src/ast/sls/sls_basic_plugin.cpp b/src/ast/sls/sls_basic_plugin.cpp index eb825db81..0600e1df5 100644 --- a/src/ast/sls/sls_basic_plugin.cpp +++ b/src/ast/sls/sls_basic_plugin.cpp @@ -335,21 +335,23 @@ namespace sls { set_value(e, b); } - void basic_plugin::repair_down(app* e) { + bool basic_plugin::repair_down(app* e) { SASSERT(m.is_bool(e)); unsigned n = e->get_num_args(); - if (n == 0 || !is_basic(e)) - return; + if (!is_basic(e)) + return false; + if (n == 0) + return true; if (bval0(e) == bval1(e)) - return; + return true; unsigned s = ctx.rand(n); for (unsigned i = 0; i < n; ++i) { auto j = (i + s) % n; if (try_repair(e, j)) - return; + return true; } - repair_up(e); + return false; } bool basic_plugin::try_repair_distinct(app* e, unsigned i) { diff --git a/src/ast/sls/sls_basic_plugin.h b/src/ast/sls/sls_basic_plugin.h index 8f9f6b621..64890f81e 100644 --- a/src/ast/sls/sls_basic_plugin.h +++ b/src/ast/sls/sls_basic_plugin.h @@ -46,7 +46,7 @@ namespace sls { void initialize() override; void propagate_literal(sat::literal lit) override; bool propagate() override; - void repair_down(app* e) override; + bool repair_down(app* e) override; void repair_up(app* e) override; bool is_sat() override; diff --git a/src/ast/sls/sls_bv_plugin.cpp b/src/ast/sls/sls_bv_plugin.cpp index ad0f3f99d..032ff397d 100644 --- a/src/ast/sls/sls_bv_plugin.cpp +++ b/src/ast/sls/sls_bv_plugin.cpp @@ -99,34 +99,33 @@ namespace sls { w.commit_eval(); } - void bv_plugin::repair_down(app* e) { + bool bv_plugin::repair_down(app* e) { unsigned n = e->get_num_args(); if (n == 0 || m_eval.eval_is_correct(e)) - return; + return true; if (n == 2) { auto d1 = get_depth(e->get_arg(0)); auto d2 = get_depth(e->get_arg(1)); unsigned s = ctx.rand(d1 + d2 + 2); if (s <= d1 && m_eval.repair_down(e, 0)) - return; + return true; if (m_eval.repair_down(e, 1)) - return; + return true; if (m_eval.repair_down(e, 0)) - return; + return true; } else { unsigned s = ctx.rand(n); for (unsigned i = 0; i < n; ++i) { auto j = (i + s) % n; if (m_eval.repair_down(e, j)) - return; + return true; } } - IF_VERBOSE(0, verbose_stream() << "revert repair: " << mk_bounded_pp(e, m) << "\n"); - repair_up(e); + return false; } void bv_plugin::repair_up(app* e) { diff --git a/src/ast/sls/sls_bv_plugin.h b/src/ast/sls/sls_bv_plugin.h index 9657b2e98..e675f04cc 100644 --- a/src/ast/sls/sls_bv_plugin.h +++ b/src/ast/sls/sls_bv_plugin.h @@ -43,7 +43,7 @@ namespace sls { void initialize() override; void propagate_literal(sat::literal lit) override; bool propagate() override; - void repair_down(app* e) override; + bool repair_down(app* e) override; void repair_up(app* e) override; bool is_sat() override; diff --git a/src/ast/sls/sls_context.cpp b/src/ast/sls/sls_context.cpp index 2a2d59573..2b802b90f 100644 --- a/src/ast/sls/sls_context.cpp +++ b/src/ast/sls/sls_context.cpp @@ -104,9 +104,11 @@ namespace sls { expr* e = term(id); TRACE("sls", tout << "repair down " << mk_bounded_pp(e, m) << "\n"); if (is_app(e)) { - auto p = m_plugins.get(to_app(e)->get_family_id(), nullptr); - if (p) - p->repair_down(to_app(e)); + auto p = m_plugins.get(get_fid(e), nullptr); + if (p && !p->repair_down(to_app(e)) && !m_repair_up.contains(e->get_id())) { + IF_VERBOSE(0, verbose_stream() << "revert repair: " << mk_bounded_pp(e, m) << "\n"); + m_repair_up.insert(e->get_id()); + } } } while (!m_repair_up.empty() && !m_new_constraint) { @@ -114,7 +116,7 @@ namespace sls { expr* e = term(id); TRACE("sls", tout << "repair up " << mk_bounded_pp(e, m) << "\n"); if (is_app(e)) { - auto p = m_plugins.get(to_app(e)->get_family_id(), nullptr); + auto p = m_plugins.get(get_fid(e), nullptr); if (p) p->repair_up(to_app(e)); } @@ -129,15 +131,24 @@ namespace sls { } } + family_id context::get_fid(expr* e) const { + if (!is_app(e)) + return null_family_id; + family_id fid = to_app(e)->get_family_id(); + if (m.is_eq(e) || m.is_distinct(e)) + fid = to_app(e)->get_arg(0)->get_sort()->get_family_id(); + else if (m.is_ite(e)) + fid = to_app(e)->get_arg(1)->get_sort()->get_family_id(); + return fid; + } + void context::propagate_literal(sat::literal lit) { if (!is_true(lit)) return; auto a = atom(lit.var()); - if (!a || !is_app(a)) + if (!a) return; - family_id fid = to_app(a)->get_family_id(); - if (m.is_eq(a) || m.is_distinct(a)) - fid = to_app(a)->get_arg(0)->get_sort()->get_family_id(); + family_id fid = get_fid(a); auto p = m_plugins.get(fid, nullptr); if (p) p->propagate_literal(lit); @@ -223,6 +234,11 @@ namespace sls { if (m_initialized) return; m_initialized = true; + m_unit_literals.reset(); + for (auto const& clause : s.clauses()) + if (clause.m_clause.size() == 1) + m_unit_literals.push_back(clause.m_clause[0]); + verbose_stream() << "UNITS " << m_unit_literals << "\n"; for (auto a : m_atoms) if (a) register_terms(a); @@ -310,7 +326,7 @@ namespace sls { m_relevant.reset(); m_visited.reset(); m_root_literals.reset(); - m_unit_literals.reset(); + for (auto const& clause : s.clauses()) { bool has_relevant = false; unsigned n = 0; @@ -329,8 +345,6 @@ namespace sls { if (m_rand() % ++n == 0) selected_lit = lit; } - if (clause.m_clause.size() == 1) - m_unit_literals.push_back(clause.m_clause[0]); if (!has_relevant && selected_lit != sat::null_literal) { m_relevant.insert(m_atoms[selected_lit.var()]->get_id()); m_root_literals.push_back(selected_lit); diff --git a/src/ast/sls/sls_context.h b/src/ast/sls/sls_context.h index 3d91255d7..7b481d3a8 100644 --- a/src/ast/sls/sls_context.h +++ b/src/ast/sls/sls_context.h @@ -42,7 +42,7 @@ namespace sls { virtual void initialize() = 0; virtual bool propagate() = 0; virtual void propagate_literal(sat::literal lit) = 0; - virtual void repair_down(app* e) = 0; + virtual bool repair_down(app* e) = 0; virtual void repair_up(app* e) = 0; virtual bool is_sat() = 0; virtual void on_rescale() {}; @@ -116,6 +116,8 @@ namespace sls { void propagate_boolean_assignment(); void propagate_literal(sat::literal lit); + + family_id get_fid(expr* e) const; public: context(ast_manager& m, sat_solver_context& s); diff --git a/src/ast/sls/sls_euf_plugin.h b/src/ast/sls/sls_euf_plugin.h index 60504212f..6578e0a3b 100644 --- a/src/ast/sls/sls_euf_plugin.h +++ b/src/ast/sls/sls_euf_plugin.h @@ -49,7 +49,7 @@ namespace sls { void set_value(expr* e, expr* v) override {} void repair_up(app* e) override {} - void repair_down(app* e) override {} + bool repair_down(app* e) override { return false; } }; } diff --git a/src/ast/sls/sls_smt_solver.cpp b/src/ast/sls/sls_smt_solver.cpp index 22d4e8d69..511682350 100644 --- a/src/ast/sls/sls_smt_solver.cpp +++ b/src/ast/sls/sls_smt_solver.cpp @@ -18,6 +18,7 @@ Author: #include "ast/sls/sls_context.h" #include "ast/sls/sat_ddfw.h" #include "ast/sls/sls_smt_solver.h" +#include "ast/ast_ll_pp.h" namespace sls { @@ -101,7 +102,12 @@ namespace sls { } void smt_solver::assert_expr(expr* e) { - m_assertions.push_back(e); + if (m.is_and(e)) { + for (expr* arg : *to_app(e)) + assert_expr(arg); + } + else + m_assertions.push_back(e); } lbool smt_solver::check() { @@ -116,7 +122,11 @@ namespace sls { } void smt_solver::add_clause(expr* f) { + expr* g; sat::literal_vector clause; + if (m.is_not(f, g) && m.is_not(g, g)) { + add_clause(g); + } if (m.is_or(f)) { clause.reset(); for (auto arg : *to_app(f)) @@ -127,6 +137,18 @@ namespace sls { for (auto arg : *to_app(f)) add_clause(arg); } + else if (m.is_not(f, g) && m.is_or(g)) { + for (auto arg : *to_app(g)) { + expr_ref fml(m.mk_not(arg), m);; + add_clause(fml); + } + } + else if (m.is_not(f, g) && m.is_and(g)) { + clause.reset(); + for (auto arg : *to_app(g)) + clause.push_back(~mk_literal(arg)); + m_solver_ctx->add_clause(clause.size(), clause.data()); + } else { sat::literal lit = mk_literal(f); m_solver_ctx->add_clause(1, &lit); diff --git a/src/util/checked_int64.h b/src/util/checked_int64.h index c2a9121fd..174b7b10a 100644 --- a/src/util/checked_int64.h +++ b/src/util/checked_int64.h @@ -35,7 +35,7 @@ class checked_int64 { int64_t m_value; typedef checked_int64 ci; - rational r64(int64_t i) { return rational(i, rational::i64()); } + rational r64(int64_t i) const { return rational(i, rational::i64()); } public: @@ -56,6 +56,7 @@ public: static checked_int64 minus_one() { return ci(-1);} int64_t get_int64() const { return m_value; } + rational to_rational() const { return r64(m_value); } checked_int64 abs() const { if (m_value >= 0) { @@ -305,3 +306,21 @@ inline checked_int64 mod(checked_int64 const& a, checked_int64 +inline bool divides(checked_int64 const& a, checked_int64 const& b) { + return mod(b, a) == 0; +} + +template +inline checked_int64 gcd(checked_int64 const& a, checked_int64 const& b) { + checked_int64 _a = abs(a); + checked_int64 _b = abs(b); + if (_a == 0) + return _b; + while (_b != 0) { + checked_int64 r = mod(_a, _b); + _a = _b; + _b = r; + } + return _a; +} \ No newline at end of file