3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-12 04:03:39 +00:00

consolidate functionality

This commit is contained in:
Nikolaj Bjorner 2025-01-25 22:34:58 -08:00
parent a7010574c8
commit 12e8082d86
3 changed files with 56 additions and 77 deletions

View file

@ -489,11 +489,11 @@ namespace sls {
if (vi.m_op == arith_op_kind::OP_NUM) if (vi.m_op == arith_op_kind::OP_NUM)
return; return;
if (is_add(v) && m_allow_recursive_delta) if (is_add(v) && m_allow_recursive_delta)
add_update_add(m_adds[vi.m_def_idx], delta_out); add_update_add(get_add(v), delta_out);
else if (is_mul(v) && m_allow_recursive_delta) else if (is_mul(v) && m_allow_recursive_delta)
add_update_mul(m_muls[vi.m_def_idx], delta_out); add_update_mul(get_mul(v), delta_out);
else if (is_op(v) && m_allow_recursive_delta) else if (is_op(v) && m_allow_recursive_delta)
add_update(m_ops[vi.m_def_idx], delta_out); add_update(get_op(v), delta_out);
else if (vi.is_if_op() && m_allow_recursive_delta) { else if (vi.is_if_op() && m_allow_recursive_delta) {
expr* c, * t, * e; expr* c, * t, * e;
VERIFY(m.is_ite(vi.m_expr, c, t, e)); VERIFY(m.is_ite(vi.m_expr, c, t, e));
@ -1283,7 +1283,7 @@ namespace sls {
case LAST_ARITH_OP: case LAST_ARITH_OP:
break; break;
case OP_ADD: { case OP_ADD: {
auto const& ad = m_adds[vi.m_def_idx]; auto const& ad = get_add(v);
auto const& args = ad.m_args; auto const& args = ad.m_args;
result = ad.m_coeff; result = ad.m_coeff;
for (auto [c, w] : args) for (auto [c, w] : args)
@ -1291,40 +1291,40 @@ namespace sls {
break; break;
} }
case OP_MUL: { case OP_MUL: {
auto const& [w, monomial] = m_muls[vi.m_def_idx]; auto const& [w, monomial] = get_mul(v);
result = num_t(1); result = num_t(1);
for (auto [w, p] : monomial) for (auto [w, p] : monomial)
result *= power_of(value(w), p); result *= power_of(value(w), p);
break; break;
} }
case OP_MOD: case OP_MOD:
v1 = value(m_ops[vi.m_def_idx].m_arg1); v1 = value(get_op(v).m_arg1);
v2 = value(m_ops[vi.m_def_idx].m_arg2); v2 = value(get_op(v).m_arg2);
result = v2 == 0 ? num_t(0) : mod(v1, v2); result = v2 == 0 ? num_t(0) : mod(v1, v2);
break; break;
case OP_DIV: case OP_DIV:
v1 = value(m_ops[vi.m_def_idx].m_arg1); v1 = value(get_op(v).m_arg1);
v2 = value(m_ops[vi.m_def_idx].m_arg2); v2 = value(get_op(v).m_arg2);
result = v2 == 0 ? num_t(0) : v1 / v2; result = v2 == 0 ? num_t(0) : v1 / v2;
break; break;
case OP_IDIV: case OP_IDIV:
v1 = value(m_ops[vi.m_def_idx].m_arg1); v1 = value(get_op(v).m_arg1);
v2 = value(m_ops[vi.m_def_idx].m_arg2); v2 = value(get_op(v).m_arg2);
result = v2 == 0 ? num_t(0) : div(v1, v2); result = v2 == 0 ? num_t(0) : div(v1, v2);
break; break;
case OP_REM: case OP_REM:
v1 = value(m_ops[vi.m_def_idx].m_arg1); v1 = value(get_op(v).m_arg1);
v2 = value(m_ops[vi.m_def_idx].m_arg2); v2 = value(get_op(v).m_arg2);
result = v2 == 0 ? num_t(0) : v1 %= v2; result = v2 == 0 ? num_t(0) : v1 %= v2;
break; break;
case OP_ABS: case OP_ABS:
result = abs(value(m_ops[vi.m_def_idx].m_arg1)); result = abs(value(get_op(v).m_arg1));
break; break;
case OP_TO_REAL: case OP_TO_REAL:
result = value(m_ops[vi.m_def_idx].m_arg1); result = value(get_op(v).m_arg1);
break; break;
case OP_TO_INT: { case OP_TO_INT: {
rational r = value(m_ops[vi.m_def_idx].m_arg1).to_rational(); rational r = value(get_op(v).m_arg1).to_rational();
result = to_num(floor(r)); result = to_num(floor(r));
break; break;
} }
@ -1368,25 +1368,25 @@ namespace sls {
case arith_op_kind::LAST_ARITH_OP: case arith_op_kind::LAST_ARITH_OP:
break; break;
case arith_op_kind::OP_ADD: case arith_op_kind::OP_ADD:
return repair_add(m_adds[vi.m_def_idx]); return repair_add(get_add(v));
case arith_op_kind::OP_MUL: case arith_op_kind::OP_MUL:
return repair_mul(m_muls[vi.m_def_idx]); return repair_mul(get_mul(v));
case arith_op_kind::OP_MOD: case arith_op_kind::OP_MOD:
return repair_mod(m_ops[vi.m_def_idx]); return repair_mod(get_op(v));
case arith_op_kind::OP_REM: case arith_op_kind::OP_REM:
return repair_rem(m_ops[vi.m_def_idx]); return repair_rem(get_op(v));
case arith_op_kind::OP_POWER: case arith_op_kind::OP_POWER:
return repair_power(m_ops[vi.m_def_idx]); return repair_power(get_op(v));
case arith_op_kind::OP_IDIV: case arith_op_kind::OP_IDIV:
return repair_idiv(m_ops[vi.m_def_idx]); return repair_idiv(get_op(v));
case arith_op_kind::OP_DIV: case arith_op_kind::OP_DIV:
return repair_div(m_ops[vi.m_def_idx]); return repair_div(get_op(v));
case arith_op_kind::OP_ABS: case arith_op_kind::OP_ABS:
return repair_abs(m_ops[vi.m_def_idx]); return repair_abs(get_op(v));
case arith_op_kind::OP_TO_INT: case arith_op_kind::OP_TO_INT:
return repair_to_int(m_ops[vi.m_def_idx]); return repair_to_int(get_op(v));
case arith_op_kind::OP_TO_REAL: case arith_op_kind::OP_TO_REAL:
return repair_to_real(m_ops[vi.m_def_idx]); return repair_to_real(get_op(v));
default: default:
throw default_exception("no repair " + mk_pp(e, m)); throw default_exception("no repair " + mk_pp(e, m));
} }
@ -1514,7 +1514,7 @@ namespace sls {
case OP_REM: case OP_REM:
break; break;
case OP_MOD: { case OP_MOD: {
auto v2 = m_ops[vi.m_def_idx].m_arg2; auto v2 = get_op(v).m_arg2;
auto const& vi2 = m_vars[v2]; auto const& vi2 = m_vars[v2];
if (vi2.m_lo && vi2.m_hi && vi2.m_lo->value == vi2.m_hi->value && vi2.m_lo->value > 0) { if (vi2.m_lo && vi2.m_hi && vi2.m_lo->value == vi2.m_hi->value && vi2.m_lo->value > 0) {
add_le(v, vi2.m_lo->value - 1); add_le(v, vi2.m_lo->value - 1);
@ -1532,7 +1532,7 @@ namespace sls {
} }
template<typename num_t> template<typename num_t>
void arith_base<num_t>::initialize_of_bool_var(sat::bool_var bv) { void arith_base<num_t>::initialize_vars_of(sat::bool_var bv) {
auto* ineq = get_ineq(bv); auto* ineq = get_ineq(bv);
if (!ineq) if (!ineq)
return; return;
@ -1542,11 +1542,9 @@ namespace sls {
m_tmp_set.reset(); m_tmp_set.reset();
for (unsigned i = 0; i < todo.size(); ++i) { for (unsigned i = 0; i < todo.size(); ++i) {
var_t u = todo[i]; var_t u = todo[i];
auto& ui = m_vars[u];
if (m_tmp_set.contains(u)) if (m_tmp_set.contains(u))
continue; continue;
m_tmp_set.insert(u); m_tmp_set.insert(u);
ui.m_bool_vars_of.push_back(bv);
if (is_add(u)) { if (is_add(u)) {
auto const& ad = get_add(u); auto const& ad = get_add(u);
for (auto const& [c, w] : ad.m_args) for (auto const& [c, w] : ad.m_args)
@ -1558,45 +1556,25 @@ namespace sls {
todo.push_back(w); todo.push_back(w);
} }
if (is_op(u)) { if (is_op(u)) {
auto const& op = m_ops[ui.m_def_idx]; auto const& op = get_op(u);
todo.push_back(op.m_arg1); todo.push_back(op.m_arg1);
todo.push_back(op.m_arg2); todo.push_back(op.m_arg2);
} }
} }
} }
template<typename num_t>
void arith_base<num_t>::initialize_of_bool_var(sat::bool_var bv) {
initialize_vars_of(bv);
for (auto v : m_tmp_set)
m_vars[v].m_bool_vars_of.push_back(bv);
}
template<typename num_t> template<typename num_t>
void arith_base<num_t>::initialize_clauses_of(sat::bool_var bv, unsigned ci) { void arith_base<num_t>::initialize_clauses_of(sat::bool_var bv, unsigned ci) {
auto* ineq = get_ineq(bv); initialize_vars_of(bv);
if (!ineq) for (auto v : m_tmp_set)
return; m_vars[v].m_clauses_of.push_back(ci);
buffer<var_t> todo;
for (auto const& [coeff, v] : ineq->m_args)
todo.push_back(v);
m_tmp_set.reset();
for (unsigned i = 0; i < todo.size(); ++i) {
var_t u = todo[i];
auto& ui = m_vars[u];
if (m_tmp_set.contains(u))
continue;
m_tmp_set.insert(u);
ui.m_clauses_of.push_back(ci);
if (is_add(u)) {
auto const& ad = get_add(u);
for (auto const& [c, w] : ad.m_args)
todo.push_back(w);
}
if (is_mul(u)) {
auto const& [w, monomial] = get_mul(u);
for (auto [w, p] : monomial)
todo.push_back(w);
}
if (is_op(u)) {
auto const& op = m_ops[ui.m_def_idx];
todo.push_back(op.m_arg1);
todo.push_back(op.m_arg2);
}
}
} }
template<typename num_t> template<typename num_t>
@ -1942,8 +1920,7 @@ namespace sls {
template<typename num_t> template<typename num_t>
num_t arith_base<num_t>::mul_value_without(var_t m, var_t x) { num_t arith_base<num_t>::mul_value_without(var_t m, var_t x) {
auto const& vi = m_vars[m]; auto const& [w, monomial] = get_mul(m);
auto const& [w, monomial] = m_muls[vi.m_def_idx];
SASSERT(m == w); SASSERT(m == w);
num_t r(1); num_t r(1);
for (auto [y, p] : monomial) for (auto [y, p] : monomial)
@ -2477,52 +2454,52 @@ namespace sls {
case arith_op_kind::LAST_ARITH_OP: case arith_op_kind::LAST_ARITH_OP:
break; break;
case arith_op_kind::OP_ADD: { case arith_op_kind::OP_ADD: {
auto ad = m_adds[vi.m_def_idx]; auto ad = get_add(v);
num_t sum(ad.m_coeff); num_t sum(ad.m_coeff);
for (auto [c, w] : ad.m_args) for (auto [c, w] : ad.m_args)
sum += c * value(w); sum += c * value(w);
return sum == value(v); return sum == value(v);
} }
case arith_op_kind::OP_MUL: { case arith_op_kind::OP_MUL: {
auto md = m_muls[vi.m_def_idx]; auto md = get_mul(v);
num_t prod(1); num_t prod(1);
for (auto [w, p] : md.m_monomial) for (auto [w, p] : md.m_monomial)
prod *= power_of(value(w), p); prod *= power_of(value(w), p);
return prod == value(v); return prod == value(v);
} }
case arith_op_kind::OP_MOD: { case arith_op_kind::OP_MOD: {
auto od = m_ops[vi.m_def_idx]; auto od = get_op(v);
return value(v) == (value(od.m_arg2) == 0 ? num_t(0) : mod(value(od.m_arg1), value(od.m_arg2))); return value(v) == (value(od.m_arg2) == 0 ? num_t(0) : mod(value(od.m_arg1), value(od.m_arg2)));
} }
case arith_op_kind::OP_REM: { case arith_op_kind::OP_REM: {
auto od = m_ops[vi.m_def_idx]; auto od = get_op(v);
return value(v) == (value(od.m_arg2) == 0 ? num_t(0) : mod(value(od.m_arg1), value(od.m_arg2))); return value(v) == (value(od.m_arg2) == 0 ? num_t(0) : mod(value(od.m_arg1), value(od.m_arg2)));
} }
case arith_op_kind::OP_POWER: { case arith_op_kind::OP_POWER: {
//auto od = m_ops[vi.m_def_idx]; //auto od = get_op(v);
throw default_exception("unsupported " + mk_pp(vi.m_expr, m)); throw default_exception("unsupported " + mk_pp(vi.m_expr, m));
break; break;
} }
case arith_op_kind::OP_IDIV: { case arith_op_kind::OP_IDIV: {
auto od = m_ops[vi.m_def_idx]; auto od = get_op(v);
return value(v) == (value(od.m_arg2) == 0 ? num_t(0) : div(value(od.m_arg1), value(od.m_arg2))); return value(v) == (value(od.m_arg2) == 0 ? num_t(0) : div(value(od.m_arg1), value(od.m_arg2)));
} }
case arith_op_kind::OP_DIV: { case arith_op_kind::OP_DIV: {
auto od = m_ops[vi.m_def_idx]; auto od = get_op(v);
return value(v) == (value(od.m_arg2) == 0 ? num_t(0) : value(od.m_arg1) / value(od.m_arg2)); return value(v) == (value(od.m_arg2) == 0 ? num_t(0) : value(od.m_arg1) / value(od.m_arg2));
} }
case arith_op_kind::OP_ABS: { case arith_op_kind::OP_ABS: {
auto od = m_ops[vi.m_def_idx]; auto od = get_op(v);
return value(v) == abs(value(od.m_arg1)); return value(v) == abs(value(od.m_arg1));
} }
case arith_op_kind::OP_TO_INT: { case arith_op_kind::OP_TO_INT: {
auto od = m_ops[vi.m_def_idx]; auto od = get_op(v);
auto val = value(od.m_var); auto val = value(od.m_var);
auto v1 = value(od.m_arg1); auto v1 = value(od.m_arg1);
return val - 1 < v1 && v1 <= val; return val - 1 < v1 && v1 <= val;
} }
case arith_op_kind::OP_TO_REAL: { case arith_op_kind::OP_TO_REAL: {
auto od = m_ops[vi.m_def_idx]; auto od = get_op(v);
auto val = value(od.m_var); auto val = value(od.m_var);
auto v1 = value(od.m_arg1); auto v1 = value(od.m_arg1);
return val == v1; return val == v1;

View file

@ -283,6 +283,7 @@ namespace sls {
bool is_if(var_t v) const { return m.is_ite(m_vars[v].m_expr); } bool is_if(var_t v) const { return m.is_ite(m_vars[v].m_expr); }
mul_def const& get_mul(var_t v) const { SASSERT(is_mul(v)); return m_muls[m_vars[v].m_def_idx]; } mul_def const& get_mul(var_t v) const { SASSERT(is_mul(v)); return m_muls[m_vars[v].m_def_idx]; }
add_def const& get_add(var_t v) const { SASSERT(is_add(v)); return m_adds[m_vars[v].m_def_idx]; } add_def const& get_add(var_t v) const { SASSERT(is_add(v)); return m_adds[m_vars[v].m_def_idx]; }
op_def const& get_op(var_t v) const { SASSERT(is_op(v)); return m_ops[m_vars[v].m_def_idx]; }
bool update(var_t v, num_t const& new_value); bool update(var_t v, num_t const& new_value);
bool apply_update(); bool apply_update();
@ -295,8 +296,9 @@ namespace sls {
double compute_score(var_t x, num_t const& delta); double compute_score(var_t x, num_t const& delta);
void save_best_values(); void save_best_values();
void initialize_of_bool_var(sat::bool_var v); void initialize_vars_of(sat::bool_var bv);
void initialize_clauses_of(sat::bool_var v, unsigned cl); void initialize_of_bool_var(sat::bool_var bv);
void initialize_clauses_of(sat::bool_var bv, unsigned cl);
var_t mk_var(expr* e); var_t mk_var(expr* e);
var_t mk_term(expr* e); var_t mk_term(expr* e);
var_t mk_op(arith_op_kind k, expr* e, expr* x, expr* y); var_t mk_op(arith_op_kind k, expr* e, expr* x, expr* y);

View file

@ -73,8 +73,8 @@ namespace sls {
if (bv != sat::null_bool_var) tout << "bool flip " << bv << "\n"; if (bv != sat::null_bool_var) tout << "bool flip " << bv << "\n";
else if (v != null_arith_var) tout << "arith flip v" << v << "\n"; else if (v != null_arith_var) tout << "arith flip v" << v << "\n";
else tout << "no flip\n"; else tout << "no flip\n";
tout << "unsat-vars " << vars_in_unsat << "\n"; tout << "unsat-vars " << ctx.unsat_vars().size() << "\n";
tout << "bools: " << bool_in_unsat << " timeup-bool " << time_up_bool << "\n"; tout << "bools: " << (ctx.unsat_vars().size() - ctx.num_external_in_unsat_vars()) << " timeup-bool " << time_up_bool << "\n";
tout << "no-improve bool: " << m_no_improve_bool << "\n"; tout << "no-improve bool: " << m_no_improve_bool << "\n";
tout << "no-improve arith: " << m_no_improve_arith << "\n"; tout << "no-improve arith: " << m_no_improve_arith << "\n";
tout << "ext: " << ext_in_unsat << " timeup-arith " << time_up_arith << "\n"; tout << "ext: " << ext_in_unsat << " timeup-arith " << time_up_arith << "\n";