3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-22 16:45:31 +00:00
This commit is contained in:
Nikolaj Bjorner 2024-07-05 17:03:00 -07:00
parent 5ebcc3e447
commit 3e57a9ce1e
4 changed files with 205 additions and 175 deletions

View file

@ -20,35 +20,35 @@ Author:
namespace sls {
template<typename int_t>
arith_plugin<int_t>::arith_plugin(context& ctx) :
template<typename num_t>
arith_plugin<num_t>::arith_plugin(context& ctx) :
plugin(ctx),
a(m) {
m_fid = a.get_family_id();
}
template<typename int_t>
void arith_plugin<int_t>::reset() {
template<typename num_t>
void arith_plugin<num_t>::reset() {
m_bool_vars.reset();
m_vars.reset();
m_expr2var.reset();
}
template<typename int_t>
void arith_plugin<int_t>::save_best_values() {
template<typename num_t>
void arith_plugin<num_t>::save_best_values() {
for (auto& v : m_vars)
v.m_best_value = v.m_value;
check_ineqs();
}
template<typename int_t>
void arith_plugin<int_t>::store_best_values() {
template<typename num_t>
void arith_plugin<num_t>::store_best_values() {
}
// distance to true
template<typename int_t>
int_t arith_plugin<int_t>::dtt(bool sign, int_t const& args, ineq const& ineq) const {
int_t zero{ 0 };
template<typename num_t>
num_t arith_plugin<num_t>::dtt(bool sign, num_t const& args, ineq const& ineq) const {
num_t zero{ 0 };
switch (ineq.m_op) {
case ineq_kind::LE:
if (sign) {
@ -62,12 +62,12 @@ namespace sls {
case ineq_kind::EQ:
if (sign) {
if (args + ineq.m_coeff == 0)
return int_t(1);
return num_t(1);
return zero;
}
if (args + ineq.m_coeff == 0)
return zero;
return int_t(1);
return num_t(1);
case ineq_kind::LT:
if (sign) {
if (args + ineq.m_coeff < 0)
@ -88,33 +88,40 @@ namespace sls {
// m_vars[w].m_value can be computed outside and shared among calls
// different data-structures for storing coefficients
//
template<typename int_t>
int_t arith_plugin<int_t>::dtt(bool sign, ineq const& ineq, var_t v, int_t const& new_value) const {
template<typename num_t>
num_t arith_plugin<num_t>::dtt(bool sign, ineq const& ineq, var_t v, num_t const& new_value) const {
for (auto const& [coeff, w] : ineq.m_args)
if (w == v)
return dtt(sign, ineq.m_args_value + coeff * (new_value - m_vars[v].m_value), ineq);
return int_t(1);
return num_t(1);
}
template<typename int_t>
int_t arith_plugin<int_t>::dtt(bool sign, ineq const& ineq, int_t const& coeff, int_t const& old_value, int_t const& new_value) const {
template<typename num_t>
num_t arith_plugin<num_t>::dtt(bool sign, ineq const& ineq, num_t const& coeff, num_t const& old_value, num_t const& new_value) const {
return dtt(sign, ineq.m_args_value + coeff * (new_value - old_value), ineq);
}
template<typename int_t>
bool arith_plugin<int_t>::cm(ineq const& ineq, var_t v, int_t& new_value) {
template<typename num_t>
bool arith_plugin<num_t>::cm(ineq const& ineq, var_t v, num_t& new_value) {
for (auto const& [coeff, w] : ineq.m_args)
if (w == v)
return cm(ineq, v, coeff, new_value);
return false;
}
template<typename int_t>
bool arith_plugin<int_t>::cm(ineq const& ineq, var_t v, int_t const& coeff, int_t& new_value) {
template<typename num_t>
num_t arith_plugin<num_t>::divide(var_t v, num_t const& delta, num_t const& coeff) {
if (m_vars[v].m_kind == var_kind::REAL)
return delta / coeff;
return div(delta + abs(coeff) - 1, coeff);
}
template<typename num_t>
bool arith_plugin<num_t>::cm(ineq const& ineq, var_t v, num_t const& coeff, num_t& new_value) {
auto bound = -ineq.m_coeff;
auto argsv = ineq.m_args_value;
bool solved = false;
int_t delta = argsv - bound;
num_t delta = argsv - bound;
if (ineq.is_true()) {
switch (ineq.m_op) {
@ -123,7 +130,7 @@ namespace sls {
SASSERT(argsv <= bound);
SASSERT(delta <= 0);
delta -= 1 + (ctx.rand() % 10);
new_value = value(v) + div(abs(delta) + abs(coeff) - 1, coeff);
new_value = value(v) + divide(v, abs(delta), coeff);
VERIFY(argsv + coeff * (new_value - value(v)) > bound);
return true;
case ineq_kind::LT:
@ -131,13 +138,13 @@ namespace sls {
SASSERT(argsv <= bound);
SASSERT(delta <= 0);
delta = abs(delta) + ctx.rand() % 10;
new_value = value(v) + div(delta + abs(coeff) - 1, coeff);
new_value = value(v) + divide(v, delta, coeff);
VERIFY(argsv + coeff * (new_value - value(v)) >= bound);
return true;
case ineq_kind::EQ: {
delta = abs(delta) + 1 + ctx.rand() % 10;
int sign = ctx.rand() % 2 == 0 ? 1 : -1;
new_value = value(v) + sign * div(abs(delta) + abs(coeff) - 1, coeff);
new_value = value(v) + sign * divide(v, abs(delta), coeff);
VERIFY(argsv + coeff * (new_value - value(v)) != bound);
return true;
}
@ -152,22 +159,22 @@ namespace sls {
SASSERT(argsv > bound);
SASSERT(delta > 0);
delta += rand() % 10;
new_value = value(v) - div(delta + abs(coeff) - 1, coeff);
new_value = value(v) - divide(v, delta, coeff);
VERIFY(argsv + coeff * (new_value - value(v)) <= bound);
return true;
case ineq_kind::LT:
SASSERT(argsv >= bound);
SASSERT(delta >= 0);
delta += 1 + rand() % 10;
new_value = value(v) - div(delta + abs(coeff) - 1, coeff);
new_value = value(v) - divide(v, delta, coeff);
VERIFY(argsv + coeff * (new_value - value(v)) < bound);
return true;
case ineq_kind::EQ:
SASSERT(delta != 0);
if (delta < 0)
new_value = value(v) + div(abs(delta) + abs(coeff) - 1, coeff);
new_value = value(v) + divide(v, abs(delta), coeff);
else
new_value = value(v) - div(delta + abs(coeff) - 1, coeff);
new_value = value(v) - divide(v, delta, coeff);
solved = argsv + coeff * (new_value - value(v)) == bound;
if (!solved && abs(coeff) == 1) {
verbose_stream() << "did not solve equality " << ineq << " for " << v << "\n";
@ -187,9 +194,9 @@ namespace sls {
// it could be changed to flip on maximal positive score
// or flip on maximal non-negative score
// or flip on first non-negative score
template<typename int_t>
void arith_plugin<int_t>::repair(sat::literal lit, ineq const& ineq) {
int_t new_value;
template<typename num_t>
void arith_plugin<num_t>::repair(sat::literal lit, ineq const& ineq) {
num_t new_value;
if (UINT_MAX == ineq.m_var_to_flip)
dtt_reward(lit);
auto v = ineq.m_var_to_flip;
@ -210,8 +217,8 @@ namespace sls {
// TODO - use cached dts instead of computed dts
// cached dts has to be updated when the score of literals are updated.
//
template<typename int_t>
double arith_plugin<int_t>::dscore(var_t v, int_t const& new_value) const {
template<typename num_t>
double arith_plugin<num_t>::dscore(var_t v, num_t const& new_value) const {
double score = 0;
auto const& vi = m_vars[v];
for (auto const& [coeff, bv] : vi.m_bool_vars) {
@ -230,16 +237,16 @@ namespace sls {
// - get_use_list(lit).size() is "often" 1 or 2
// - dtt_old can be saved
//
template<typename int_t>
int arith_plugin<int_t>::cm_score(var_t v, int_t const& new_value) {
template<typename num_t>
int arith_plugin<num_t>::cm_score(var_t v, num_t const& new_value) {
int score = 0;
auto& vi = m_vars[v];
int_t old_value = vi.m_value;
num_t old_value = vi.m_value;
for (auto const& [coeff, bv] : vi.m_bool_vars) {
auto const& ineq = *atom(bv);
bool old_sign = sign(bv);
int_t dtt_old = dtt(old_sign, ineq);
int_t dtt_new = dtt(old_sign, ineq, coeff, old_value, new_value);
num_t dtt_old = dtt(old_sign, ineq);
num_t dtt_new = dtt(old_sign, ineq, coeff, old_value, new_value);
if ((dtt_old == 0) == (dtt_new == 0))
continue;
sat::literal lit(bv, old_sign);
@ -265,9 +272,9 @@ namespace sls {
return score;
}
template<typename int_t>
int_t arith_plugin<int_t>::compute_dts(unsigned cl) const {
int_t d(1), d2;
template<typename num_t>
num_t arith_plugin<num_t>::compute_dts(unsigned cl) const {
num_t d(1), d2;
bool first = true;
for (auto a : ctx.get_clause(cl)) {
auto const* ineq = atom(a.var());
@ -284,9 +291,9 @@ namespace sls {
return d;
}
template<typename int_t>
int_t arith_plugin<int_t>::dts(unsigned cl, var_t v, int_t const& new_value) const {
int_t d(1), d2;
template<typename num_t>
num_t arith_plugin<num_t>::dts(unsigned cl, var_t v, num_t const& new_value) const {
num_t d(1), d2;
bool first = true;
for (auto lit : ctx.get_clause(cl)) {
auto const* ineq = atom(lit.var());
@ -303,8 +310,8 @@ namespace sls {
return d;
}
template<typename int_t>
void arith_plugin<int_t>::update(var_t v, int_t const& new_value) {
template<typename num_t>
void arith_plugin<num_t>::update(var_t v, num_t const& new_value) {
auto& vi = m_vars[v];
auto old_value = vi.m_value;
if (old_value == new_value)
@ -315,7 +322,7 @@ namespace sls {
sat::literal lit(bv, old_sign);
SASSERT(ctx.is_true(lit));
ineq.m_args_value += coeff * (new_value - old_value);
int_t dtt_new = dtt(old_sign, ineq);
num_t dtt_new = dtt(old_sign, ineq);
if (dtt_new != 0)
ctx.flip(bv);
SASSERT(dtt(sign(bv), ineq) == 0);
@ -324,7 +331,7 @@ namespace sls {
for (auto idx : vi.m_muls) {
auto const& [v, monomial] = m_muls[idx];
int_t prod(1);
num_t prod(1);
for (auto w : monomial)
prod *= value(w);
if (value(v) != prod)
@ -334,7 +341,7 @@ namespace sls {
auto const& ad = m_adds[idx];
auto const& args = ad.m_args;
auto v = ad.m_var;
int_t sum(ad.m_coeff);
num_t sum(ad.m_coeff);
for (auto [c, w] : args)
sum += c * value(w);
if (value(v) != sum)
@ -345,52 +352,50 @@ namespace sls {
m_defs_to_update.push_back(v);
}
template<typename int_t>
typename arith_plugin<int_t>::ineq& arith_plugin<int_t>::new_ineq(ineq_kind op, int_t const& coeff) {
template<typename num_t>
typename arith_plugin<num_t>::ineq& arith_plugin<num_t>::new_ineq(ineq_kind op, num_t const& coeff) {
auto* i = alloc(ineq);
i->m_coeff = coeff;
i->m_op = op;
return *i;
}
template<typename int_t>
void arith_plugin<int_t>::add_arg(linear_term& ineq, int_t const& c, var_t v) {
template<typename num_t>
void arith_plugin<num_t>::add_arg(linear_term& ineq, num_t const& c, var_t v) {
ineq.m_args.push_back({ c, v });
}
template<typename int_t>
bool arith_plugin<int_t>::is_int64(expr* e, int_t& i) {
bool arith_plugin<checked_int64<true>>::is_num(expr* e, checked_int64<true>& i) {
rational r;
if (a.is_numeral(e, r) && r.is_int64()) {
i = int_t(r.get_int64());
if (a.is_numeral(e, r)) {
if (!r.is_int64())
throw overflow_exception();
i = r.get_int64();
return true;
}
return false;
}
bool arith_plugin<checked_int64<true>>::is_int(expr* e, checked_int64<true>& i) {
return is_int64(e, i);
bool arith_plugin<rational>::is_num(expr* e, rational& i) {
return a.is_numeral(e, i);
}
bool arith_plugin<rational>::is_int(expr* e, rational& i) {
return a.is_numeral(e, i) && i.is_int();
}
template<typename int_t>
bool arith_plugin<int_t>::is_int(expr* e, int_t& i) {
template<typename num_t>
bool arith_plugin<num_t>::is_num(expr* e, num_t& i) {
return false;
}
template<typename int_t>
void arith_plugin<int_t>::add_args(linear_term& term, expr* e, int_t const& coeff) {
template<typename num_t>
void arith_plugin<num_t>::add_args(linear_term& term, expr* e, num_t const& coeff) {
auto v = m_expr2var.get(e->get_id(), UINT_MAX);
if (v != UINT_MAX) {
add_arg(term, coeff, v);
return;
}
expr* x, * y;
int_t i;
if (is_int(e, i)) {
num_t i;
if (is_num(e, i)) {
term.m_coeff += coeff * i;
return;
}
@ -407,9 +412,9 @@ namespace sls {
if (a.is_mul(e)) {
unsigned_vector m;
int_t c = coeff;
num_t c = coeff;
for (expr* arg : *to_app(e))
if (is_int(x, i))
if (is_num(x, i))
c *= i;
else
m.push_back(mk_term(arg));
@ -424,7 +429,7 @@ namespace sls {
auto v = mk_var(e);
unsigned idx = m_muls.size();
m_muls.push_back({ v, m });
int_t prod(1);
num_t prod(1);
for (auto w : m)
m_vars[w].m_muls.push_back(idx), prod *= value(w);
m_vars[v].m_mul_idx = idx;
@ -449,18 +454,18 @@ namespace sls {
UNREACHABLE();
}
template<typename int_t>
typename arith_plugin<int_t>::var_t arith_plugin<int_t>::mk_term(expr* e) {
template<typename num_t>
typename arith_plugin<num_t>::var_t arith_plugin<num_t>::mk_term(expr* e) {
auto v = m_expr2var.get(e->get_id(), UINT_MAX);
if (v != UINT_MAX)
return v;
linear_term t = linear_term({ {}, 0 });
add_args(t, e, int_t(1));
linear_term t;
add_args(t, e, num_t(1));
if (t.m_coeff == 1 && t.m_args.size() == 1 && t.m_args[0].first == 1)
return t.m_args[0].second;
v = mk_var(e);
auto idx = m_adds.size();
int_t sum(t.m_coeff);
num_t sum(t.m_coeff);
m_adds.push_back({ t.m_args, t.m_coeff, v });
for (auto const& [c, w] : t.m_args)
m_vars[w].m_adds.push_back(idx), sum += c * value(w);
@ -469,19 +474,19 @@ namespace sls {
return v;
}
template<typename int_t>
unsigned arith_plugin<int_t>::mk_var(expr* e) {
template<typename num_t>
unsigned arith_plugin<num_t>::mk_var(expr* e) {
unsigned v = m_expr2var.get(e->get_id(), UINT_MAX);
if (v == UINT_MAX) {
v = m_vars.size();
m_expr2var.setx(e->get_id(), v, UINT_MAX);
m_vars.push_back(var_info(e, var_kind::INT));
m_vars.push_back(var_info(e, a.is_int(e) ? var_kind::INT : var_kind::REAL));
}
return v;
}
template<typename int_t>
void arith_plugin<int_t>::init_bool_var(sat::bool_var bv) {
template<typename num_t>
void arith_plugin<num_t>::init_bool_var(sat::bool_var bv) {
if (m_bool_vars.get(bv, nullptr))
return;
expr* e = ctx.atom(bv);
@ -491,21 +496,27 @@ namespace sls {
expr* x, * y;
m_bool_vars.reserve(bv + 1);
if (a.is_le(e, x, y) || a.is_ge(e, y, x)) {
auto& ineq = new_ineq(ineq_kind::LE, int_t(0));
add_args(ineq, x, int_t(1));
add_args(ineq, y, int_t(-1));
auto& ineq = new_ineq(ineq_kind::LE, num_t(0));
add_args(ineq, x, num_t(1));
add_args(ineq, y, num_t(-1));
init_ineq(bv, ineq);
}
else if ((a.is_lt(e, x, y) || a.is_gt(e, y, x)) && a.is_int(x)) {
auto& ineq = new_ineq(ineq_kind::LE, int_t(1));
add_args(ineq, x, int_t(1));
add_args(ineq, y, int_t(-1));
auto& ineq = new_ineq(ineq_kind::LE, num_t(1));
add_args(ineq, x, num_t(1));
add_args(ineq, y, num_t(-1));
init_ineq(bv, ineq);
}
else if ((a.is_lt(e, x, y) || a.is_gt(e, y, x)) && a.is_real(x)) {
auto& ineq = new_ineq(ineq_kind::LT, num_t(0));
add_args(ineq, x, num_t(1));
add_args(ineq, y, num_t(-1));
init_ineq(bv, ineq);
}
else if (m.is_eq(e, x, y) && a.is_int_real(x)) {
auto& ineq = new_ineq(ineq_kind::EQ, int_t(0));
add_args(ineq, x, int_t(1));
add_args(ineq, y, int_t(-1));
auto& ineq = new_ineq(ineq_kind::EQ, num_t(0));
add_args(ineq, x, num_t(1));
add_args(ineq, y, num_t(-1));
init_ineq(bv, ineq);
}
else {
@ -513,8 +524,8 @@ namespace sls {
}
}
template<typename int_t>
void arith_plugin<int_t>::init_ineq(sat::bool_var bv, ineq& i) {
template<typename num_t>
void arith_plugin<num_t>::init_ineq(sat::bool_var bv, ineq& i) {
i.m_args_value = 0;
for (auto const& [coeff, v] : i.m_args) {
m_vars[v].m_bool_vars.push_back({ coeff, bv });
@ -523,15 +534,15 @@ namespace sls {
m_bool_vars.set(bv, &i);
}
template<typename int_t>
void arith_plugin<int_t>::init_bool_var_assignment(sat::bool_var v) {
template<typename num_t>
void arith_plugin<num_t>::init_bool_var_assignment(sat::bool_var v) {
auto* ineq = m_bool_vars.get(v, nullptr);
if (ineq && ctx.is_true(sat::literal(v, false)) != (dtt(false, *ineq) == 0))
ctx.flip(v);
}
template<typename int_t>
void arith_plugin<int_t>::repair(sat::literal lit) {
template<typename num_t>
void arith_plugin<num_t>::repair(sat::literal lit) {
if (!ctx.is_true(lit))
return;
auto const* ineq = atom(lit.var());
@ -543,8 +554,8 @@ namespace sls {
repair(lit, *ineq);
}
template<typename int_t>
void arith_plugin<int_t>::propagate_updates() {
template<typename num_t>
void arith_plugin<num_t>::propagate_updates() {
while (!m_defs_to_update.empty() || !m_vars_to_update.empty()) {
while (!m_vars_to_update.empty()) {
auto [w, new_value1] = m_vars_to_update.back();
@ -555,8 +566,8 @@ namespace sls {
}
}
template<typename int_t>
void arith_plugin<int_t>::repair_defs() {
template<typename num_t>
void arith_plugin<num_t>::repair_defs() {
while (!m_defs_to_update.empty()) {
auto v = m_defs_to_update.back();
m_defs_to_update.pop_back();
@ -568,12 +579,12 @@ namespace sls {
}
}
template<typename int_t>
void arith_plugin<int_t>::repair_add(add_def const& ad) {
template<typename num_t>
void arith_plugin<num_t>::repair_add(add_def const& ad) {
auto v = ad.m_var;
auto const& coeffs = ad.m_args;
int_t sum(ad.m_coeff);
int_t val = value(v);
num_t sum(ad.m_coeff);
num_t val = value(v);
for (auto const& [c, w] : coeffs)
sum += c * value(w);
if (val == sum)
@ -582,16 +593,18 @@ namespace sls {
update(v, sum);
else {
auto const& [c, w] = coeffs[rand() % coeffs.size()];
int_t delta = sum - val;
int_t new_value = value(w) + div(delta, c);
num_t delta = sum - val;
bool is_real = m_vars[w].m_kind == var_kind::REAL;
bool round_down = rand() % 2 == 0;
num_t new_value = value(w) + (is_real ? delta / c : round_down ? div(delta, c) : div(delta + c - 1, c));
update(w, new_value);
}
}
template<typename int_t>
void arith_plugin<int_t>::repair_mul(mul_def const& md) {
int_t product(1);
int_t val = value(md.m_var);
template<typename num_t>
void arith_plugin<num_t>::repair_mul(mul_def const& md) {
num_t product(1);
num_t val = value(md.m_var);
for (auto v : md.m_monomial)
product *= value(v);
if (product == val)
@ -601,13 +614,13 @@ namespace sls {
}
else if (val == 0) {
auto v = md.m_monomial[rand() % md.m_monomial.size()];
int_t zero(0);
num_t zero(0);
update(v, zero);
}
else if (val == 1 || val == -1) {
product = 1;
for (auto v : md.m_monomial) {
int_t new_value(1);
num_t new_value(1);
if (rand() % 2 == 0)
new_value = -1;
product *= new_value;
@ -618,10 +631,22 @@ namespace sls {
update(last, -value(last));
}
}
else if (rand() % 2 == 0 && product != 0) {
// value1(v) * product / value(v) = val
// value1(v) = value(v) * val / product
auto w = md.m_monomial[rand() % md.m_monomial.size()];
auto old_value = value(w);
num_t new_value;
if (m_vars[w].m_kind == var_kind::REAL)
new_value = old_value * val / product;
else
new_value = divide(w, old_value * val, product);
update(w, new_value);
}
else {
product = 1;
for (auto v : md.m_monomial) {
int_t new_value{ 1 };
num_t new_value{ 1 };
if (rand() % 2 == 0)
new_value = -1;
product *= new_value;
@ -635,20 +660,20 @@ namespace sls {
}
}
template<typename int_t>
double arith_plugin<int_t>::reward(sat::literal lit) {
template<typename num_t>
double arith_plugin<num_t>::reward(sat::literal lit) {
if (m_dscore_mode)
return dscore_reward(lit.var());
else
return dtt_reward(lit);
}
template<typename int_t>
double arith_plugin<int_t>::dtt_reward(sat::literal lit) {
template<typename num_t>
double arith_plugin<num_t>::dtt_reward(sat::literal lit) {
auto* ineq = atom(lit.var());
if (!ineq)
return -1;
int_t new_value;
num_t new_value;
double max_result = -1;
unsigned n = 0;
for (auto const& [coeff, x] : ineq->m_args) {
@ -674,8 +699,8 @@ namespace sls {
return max_result;
}
template<typename int_t>
double arith_plugin<int_t>::dscore_reward(sat::bool_var bv) {
template<typename num_t>
double arith_plugin<num_t>::dscore_reward(sat::bool_var bv) {
m_dscore_mode = false;
bool old_sign = sign(bv);
sat::literal litv(bv, old_sign);
@ -683,7 +708,7 @@ namespace sls {
if (!ineq)
return 0;
SASSERT(ineq->is_true() != old_sign);
int_t new_value;
num_t new_value;
for (auto const& [coeff, v] : ineq->m_args) {
double result = 0;
@ -699,25 +724,25 @@ namespace sls {
}
// switch to dscore mode
template<typename int_t>
void arith_plugin<int_t>::on_rescale() {
template<typename num_t>
void arith_plugin<num_t>::on_rescale() {
m_dscore_mode = true;
}
template<typename int_t>
void arith_plugin<int_t>::on_restart() {
template<typename num_t>
void arith_plugin<num_t>::on_restart() {
for (unsigned v = 0; v < ctx.num_bool_vars(); ++v)
init_bool_var_assignment(v);
check_ineqs();
}
template<typename int_t>
void arith_plugin<int_t>::check_ineqs() {
template<typename num_t>
void arith_plugin<num_t>::check_ineqs() {
auto check_bool_var = [&](sat::bool_var bv) {
auto const* ineq = atom(bv);
if (!ineq)
return;
int_t d = dtt(sign(bv), *ineq);
num_t d = dtt(sign(bv), *ineq);
sat::literal lit(bv, sign(bv));
if (ctx.is_true(lit) != (d == 0)) {
verbose_stream() << "invalid assignment " << bv << " " << *ineq << "\n";
@ -728,18 +753,18 @@ namespace sls {
check_bool_var(v);
}
template<typename int_t>
void arith_plugin<int_t>::register_term(expr* e) {
template<typename num_t>
void arith_plugin<num_t>::register_term(expr* e) {
}
template<typename int_t>
expr_ref arith_plugin<int_t>::get_value(expr* e) {
template<typename num_t>
expr_ref arith_plugin<num_t>::get_value(expr* e) {
auto v = mk_var(e);
return expr_ref(a.mk_numeral(rational(m_vars[v].m_value.get_int64(), rational::i64()), a.is_int(e)), m);
}
template<typename int_t>
lbool arith_plugin<int_t>::check() {
template<typename num_t>
lbool arith_plugin<num_t>::check() {
// repair each root literal
for (sat::literal lit : ctx.root_literals())
repair(lit);
@ -753,8 +778,8 @@ namespace sls {
return ctx.unsat().empty() ? l_true : l_undef;
}
template<typename int_t>
bool arith_plugin<int_t>::is_sat() {
template<typename num_t>
bool arith_plugin<num_t>::is_sat() {
for (auto const& clause : ctx.clauses()) {
bool sat = false;
for (auto lit : clause.m_clause) {
@ -776,8 +801,8 @@ namespace sls {
return true;
}
template<typename int_t>
std::ostream& arith_plugin<int_t>::display(std::ostream& out) const {
template<typename num_t>
std::ostream& arith_plugin<num_t>::display(std::ostream& out) const {
for (unsigned v = 0; v < ctx.num_bool_vars(); ++v) {
auto ineq = atom(v);
if (ineq)
@ -794,8 +819,8 @@ namespace sls {
return out;
}
template<typename int_t>
void arith_plugin<int_t>::mk_model(model& mdl) {
template<typename num_t>
void arith_plugin<num_t>::mk_model(model& mdl) {
for (auto const& v : m_vars) {
expr* e = v.m_expr;
if (is_uninterp_const(e))

View file

@ -27,7 +27,7 @@ namespace sls {
using theory_var = int;
// local search portion for arithmetic
template<typename int_t>
template<typename num_t>
class arith_plugin : public plugin {
enum class ineq_kind { EQ, LE, LT};
enum class var_kind { INT, REAL };
@ -46,17 +46,15 @@ namespace sls {
unsigned m_num_flips = 0;
};
// typedef checked_int64<true> int_t;
public:
struct linear_term {
vector<std::pair<int_t, var_t>> m_args;
int_t m_coeff;
vector<std::pair<num_t, var_t>> m_args;
num_t m_coeff{ 0 };
};
// encode args <= bound, args = bound, args < bound
struct ineq : public linear_term {
ineq_kind m_op = ineq_kind::LE;
int_t m_args_value;
num_t m_args_value;
unsigned m_var_to_flip = UINT_MAX;
bool is_true() const {
@ -90,12 +88,12 @@ namespace sls {
struct var_info {
var_info(expr* e, var_kind k): m_expr(e), m_kind(k) {}
expr* m_expr;
int_t m_value{ 0 };
int_t m_best_value{ 0 };
num_t m_value{ 0 };
num_t m_best_value{ 0 };
var_kind m_kind;
unsigned m_add_idx = UINT_MAX;
unsigned m_mul_idx = UINT_MAX;
vector<std::pair<int_t, sat::bool_var>> m_bool_vars;
vector<std::pair<num_t, sat::bool_var>> m_bool_vars;
unsigned_vector m_muls;
unsigned_vector m_adds;
};
@ -124,7 +122,7 @@ namespace sls {
void repair_mul(mul_def const& md);
void repair_add(add_def const& ad);
unsigned_vector m_defs_to_update;
vector<std::pair<var_t, int_t>> m_vars_to_update;
vector<std::pair<var_t, num_t>> m_vars_to_update;
void propagate_updates();
void repair_defs();
void repair(sat::literal lit);
@ -136,33 +134,33 @@ namespace sls {
ineq* atom(sat::bool_var bv) const { return m_bool_vars.get(bv, nullptr); }
int_t dtt(bool sign, ineq const& ineq) const { return dtt(sign, ineq.m_args_value, ineq); }
int_t dtt(bool sign, int_t const& args_value, ineq const& ineq) const;
int_t dtt(bool sign, ineq const& ineq, var_t v, int_t const& new_value) const;
int_t dtt(bool sign, ineq const& ineq, int_t const& coeff, int_t const& old_value, int_t const& new_value) const;
int_t dts(unsigned cl, var_t v, int_t const& new_value) const;
int_t compute_dts(unsigned cl) const;
bool cm(ineq const& ineq, var_t v, int_t& new_value);
bool cm(ineq const& ineq, var_t v, int_t const& coeff, int_t& new_value);
int cm_score(var_t v, int_t const& new_value);
void update(var_t v, int_t const& new_value);
num_t dtt(bool sign, ineq const& ineq) const { return dtt(sign, ineq.m_args_value, ineq); }
num_t dtt(bool sign, num_t const& args_value, ineq const& ineq) const;
num_t dtt(bool sign, ineq const& ineq, var_t v, num_t const& new_value) const;
num_t dtt(bool sign, ineq const& ineq, num_t const& coeff, num_t const& old_value, num_t const& new_value) const;
num_t dts(unsigned cl, var_t v, num_t const& new_value) const;
num_t compute_dts(unsigned cl) const;
bool cm(ineq const& ineq, var_t v, num_t& new_value);
bool cm(ineq const& ineq, var_t v, num_t const& coeff, num_t& new_value);
int cm_score(var_t v, num_t const& new_value);
void update(var_t v, num_t const& new_value);
double dscore_reward(sat::bool_var v);
double dtt_reward(sat::literal lit);
double dscore(var_t v, int_t const& new_value) const;
double dscore(var_t v, num_t const& new_value) const;
void save_best_values();
void store_best_values();
unsigned mk_var(expr* e);
ineq& new_ineq(ineq_kind op, int_t const& bound);
void add_arg(linear_term& term, int_t const& c, var_t v);
void add_args(linear_term& term, expr* e, int_t const& sign);
ineq& new_ineq(ineq_kind op, num_t const& bound);
void add_arg(linear_term& term, num_t const& c, var_t v);
void add_args(linear_term& term, expr* e, num_t const& sign);
var_t mk_term(expr* e);
void init_ineq(sat::bool_var bv, ineq& i);
num_t divide(var_t v, num_t const& delta, num_t const& coeff);
void init_bool_var_assignment(sat::bool_var v);
int_t value(var_t v) const { return m_vars[v].m_value; }
bool is_int64(expr* e, int_t& i);
bool is_int(expr* e, int_t& i);
num_t value(var_t v) const { return m_vars[v].m_value; }
bool is_num(expr* e, num_t& i);
void check_ineqs();

View file

@ -43,7 +43,7 @@ class hilbert_basis {
typedef vector<numeral> num_vector;
static checked_int64<check> to_numeral(rational const& r) {
if (!r.is_int64()) {
throw checked_int64<check>::overflow_exception();
throw overflow_exception();
}
return checked_int64<check>(r.get_int64());
}

View file

@ -26,6 +26,10 @@ Revision History:
#include "util/z3_exception.h"
#include "util/rational.h"
class overflow_exception : public z3_exception {
char const* msg() const override { return "checked_int64 overflow/underflow"; }
};
template<bool CHECK>
class checked_int64 {
int64_t m_value;
@ -38,10 +42,6 @@ public:
checked_int64(): m_value(0) {}
checked_int64(int64_t v): m_value(v) {}
class overflow_exception : public z3_exception {
char const * msg() const override { return "checked_int64 overflow/underflow";}
};
bool is_zero() const { return m_value == 0; }
bool is_pos() const { return m_value > 0; }
bool is_neg() const { return m_value < 0; }
@ -279,3 +279,10 @@ inline checked_int64<CHECK> div(checked_int64<CHECK> const& a, checked_int64<CHE
result /= b;
return result;
}
template<bool CHECK>
inline checked_int64<CHECK> operator/(checked_int64<CHECK> const& a, checked_int64<CHECK> const& b) {
checked_int64<CHECK> result(a);
result /= b;
return result;
}