From 847278fba8d376d6190c11442fccf4ca06a44a5f Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Thu, 9 Jan 2025 16:47:33 -0800 Subject: [PATCH] adding global lookahead variant to sls arith solver --- src/ast/sls/sls_arith_base.cpp | 485 +++++++++++++++++++++++++++++++ src/ast/sls/sls_arith_base.h | 56 ++++ src/ast/sls/sls_bv_eval.cpp | 8 +- src/ast/sls/sls_bv_lookahead.cpp | 3 - src/util/checked_int64.h | 1 + src/util/util.h | 8 + 6 files changed, 552 insertions(+), 9 deletions(-) diff --git a/src/ast/sls/sls_arith_base.cpp b/src/ast/sls/sls_arith_base.cpp index 7515c58de..da97ca823 100644 --- a/src/ast/sls/sls_arith_base.cpp +++ b/src/ast/sls/sls_arith_base.cpp @@ -2316,6 +2316,491 @@ namespace sls { void arith_base::reset_statistics() { m_stats.m_num_steps = 0; } + + // global lookahead mode + // + + template + arith_base::bool_info& arith_base::get_bool_info(expr* e) { + m_bool_info.reserve(e->get_id() + 1, { m_config.paws_init, 0, 1, l_undef }); + return m_bool_info[e->get_id()]; + } + + template + bool arith_base::get_bool_value_rec(expr* e) { + if (!is_app(e)) + return ctx.get_value(e) == l_true; + + if (is_uninterp(e)) + return ctx.get_value(e) == l_true; + + app* a = to_app(e); + if (a->get_family_id() == basic_family_id) { + bool r = get_basic_bool_value(a); + get_bool_info(e).value = to_lbool(r); + return r; + } + + auto v = ctx.atom2bool_var(e); + if (v == sat::null_bool_var) + return false; + auto const* ineq = get_ineq(v); + if (!ineq) + return false; + auto r = ineq->is_true() == ctx.is_true(v); + get_bool_info(e).value = to_lbool(r); + return r; + } + + template + bool arith_base::get_bool_value(expr* e) { + auto& info = get_bool_info(e); + if (info.value != l_undef) + return info.value == l_true; + + auto r = get_bool_value_rec(e); + info.value = to_lbool(r); + return r; + } + + + template + bool arith_base::get_basic_bool_value(app* e) { + switch (e->get_decl_kind()) { + case OP_TRUE: + return true; + case OP_FALSE: + return false; + case OP_NOT: + return !get_bool_value(e->get_arg(0)); + case OP_AND: + return all_of(*e, [&](expr* arg) { return get_bool_value(arg); }); + case OP_OR: + return any_of(*e, [&](expr* arg) { return get_bool_value(arg); }); + case OP_XOR: + return xor_of(*e, [&](expr* arg) { return get_bool_value(arg); }); + case OP_IMPLIES: + return !get_bool_value(e->get_arg(0)) || get_bool_value(e->get_arg(1)); + case OP_EQ: + if (m.is_bool(e->get_arg(0))) + return get_bool_value(e->get_arg(0)) == get_bool_value(e->get_arg(1)); + NOT_IMPLEMENTED_YET(); + case OP_DISTINCT: + NOT_IMPLEMENTED_YET(); + default: + NOT_IMPLEMENTED_YET(); + } + return false; + } + + template + void arith_base::initialize_bool_assignment() { + for (auto t : ctx.subterms()) + if (m.is_bool(t)) + get_bool_value(t); + } + + template + void arith_base::finalize_bool_assignment() { + for (unsigned v = ctx.num_bool_vars(); v-- > 0; ) { + auto a = ctx.atom(v); + if (!a) + continue; + if (get_bool_value(a) != ctx.is_true(v)) + ctx.flip(v); + } + } + + template + double arith_base::new_score(expr* e) { + return new_score(e, true); + } + + template + double arith_base::new_score(expr* a, bool is_true) { + bool is_true_new = get_bool_value(a); + + //verbose_stream() << "compute score " << mk_bounded_pp(a, m) << " is-true " << is_true << " is-true-new " << is_true_new << "\n"; + if (is_true == is_true_new) + return 1; + if (is_uninterp(a)) + return 0; + if (m.is_true(a)) + return is_true ? 1 : 0; + if (m.is_false(a)) + return is_true ? 0 : 1; + expr* x, * y, * z; + if (m.is_not(a, x)) + return new_score(x, !is_true); + if ((m.is_and(a) && is_true) || (m.is_or(a) && !is_true)) { + double score = 1; + for (auto arg : *to_app(a)) + score = std::min(score, new_score(arg, is_true)); + return score; + } + if ((m.is_and(a) && !is_true) || (m.is_or(a) && is_true)) { + double score = 0; + for (auto arg : *to_app(a)) + score = std::max(score, new_score(arg, is_true)); + return score; + } + if (m.is_iff(a, x, y)) { + auto v0 = get_bool_value(x); + auto v1 = get_bool_value(y); + return (is_true == (v0 == v1)) ? 1 : 0; + } + if (m.is_ite(a, x, y, z)) + return get_bool_value(x) ? new_score(y, is_true) : new_score(z, is_true); + + + auto v = ctx.atom2bool_var(a); + if (v == sat::null_bool_var) + return 0; + auto const* ineq = get_ineq(v); + if (!ineq) + return 0; + + auto const& args = ineq->m_args_value; + auto const& coeff = ineq->m_coeff; + auto value = args + coeff; + + switch (ineq->m_op) { + case ineq_kind::LE: + if (is_true) { + if (value <= 0) + return 1.0; + } + else { + if (value > 0) + return 1.0; + value = -value + 1; + } + break; + case ineq_kind::LT: + if (is_true) { + if (value < 0) + return 1.0; + } + else { + if (value >= 0) + return 1.0; + value = -value; + } + break; + case ineq_kind::EQ: + if (is_true) { + if (value == 0) + return 1.0; + if (value < 0) + value = -value; + } + else { + if (value != 0) + return 1.0; + return 0.0; + } + break; + } + + SASSERT(value > 0); + unsigned max_value = 10000; + if (value > max_value) + return 1.0; + auto d = value.get_double(); + return 1.0 - ((d * d) / ((double)max_value * (double)max_value)); + } + + template + void arith_base::rescore() { + m_top_score = 0; + m_is_root.reset(); + for (auto a : ctx.input_assertions()) { + double score = new_score(a); + set_score(a, score); + m_top_score += score; + m_is_root.mark(a); + } + } + + template + void arith_base::recalibrate_weights() { + for (auto a : ctx.input_assertions()) { + if (ctx.rand(2047) < m_config.paws_sp) { + if (get_bool_value(a)) + dec_weight(a); + } + else if (!get_bool_value(a)) + inc_weight(a); + } + } + + template + void arith_base::insert_update_stack_rec(expr* t) { + m_min_depth = m_max_depth = get_depth(t); + insert_update_stack(t); + for (unsigned depth = m_max_depth; depth <= m_max_depth; ++depth) { + for (unsigned i = 0; i < m_update_stack[depth].size(); ++i) { + auto a = m_update_stack[depth][i]; + for (auto p : ctx.parents(a)) { + insert_update_stack(p); + m_max_depth = std::max(m_max_depth, get_depth(p)); + } + } + } + } + template + double arith_base::lookahead(expr* t) { + SASSERT(a.is_int_real(t) || m.is_bool(t)); + double score = m_top_score; + for (unsigned depth = m_min_depth; depth <= m_max_depth; ++depth) { + for (unsigned i = 0; i < m_update_stack[depth].size(); ++i) { + auto* a = m_update_stack[depth][i]; + TRACE("bv_verbose", tout << "update " << mk_bounded_pp(a, m) << " depth: " << depth << "\n";); + if (t != a) + set_bool_value(a, get_bool_value_rec(a)); + if (m_is_root.is_marked(a)) + score += get_weight(a) * (new_score(a) - old_score(a)); + } + } + return score; + } + + template + void arith_base::insert_update_stack(expr* t) { + if (!m.is_bool(t)) + return; + unsigned depth = get_depth(t); + m_update_stack.reserve(depth + 1); + if (!m_in_update_stack.is_marked(t) && is_app(t)) { + m_in_update_stack.mark(t); + m_update_stack[depth].push_back(to_app(t)); + } + } + + template + void arith_base::clear_update_stack() { + lookahead(nullptr); + m_in_update_stack.reset(); + for (unsigned i = m_min_depth; i <= m_max_depth; ++i) + m_update_stack[i].reset(); + } + + template + void arith_base::lookahead_num(var_t v, num_t const& new_value) { + num_t old_value = value(v); + if (!update(v, new_value)) + return; + + expr* e = m_vars[v].m_expr; + auto score = lookahead(e); + if (score > m_best_score) { + m_best_score = score; + m_best_value = new_value; + m_best_expr = e; + } + VERIFY(update(v, old_value)); + lookahead(e); + } + + template + void arith_base::lookahead_bool(expr* e) { + bool b = get_bool_value(e); + set_bool_value(e, !b); + auto score = lookahead(e); + if (score > m_best_score) { + m_best_score = score; + m_best_expr = e; + } + set_bool_value(e, b); + } + + // for every variable e, for every atom containing e + // add lookahead for e. + // m_fixable_atoms contains atoms that can be fixed. + // m_fixable_vars contains variables that can be updated. + template + void arith_base::add_lookahead(expr* e) { + if (m.is_bool(e)) { + auto bv = ctx.atom2bool_var(e); + if (m_fixable_atoms.contains(bv)) + lookahead_bool(e); + } + else if (a.is_int_real(e)) { + auto v = mk_term(e); + auto& vi = m_vars[v]; + for (auto [coeff, bv] : vi.m_ineqs) { + if (!m_fixable_atoms.contains(bv)) + continue; + auto a = ctx.atom(bv); + if (!a) + continue; + auto* ineq = get_ineq(bv); + if (!ineq) + continue; + num_t na, nb; + for (auto const& [x, nl] : ineq->m_nonlinear) { + if (!m_fixable_vars.contains(x)) + continue; + if (is_fixed(x)) + continue; + if (is_linear(x, nl, nb)) + find_linear_moves(*ineq, x, nb, ineq->m_args_value); + else if (is_quadratic(x, nl, na, nb)) + find_quadratic_moves(*ineq, x, na, nb, ineq->m_args_value); + else + ; + } + m_fixable_atoms.remove(bv); + } + } + } + + // + // e is a formula that is false, + // assemble candidates that can flip the formula to true. + // candidate expressions may be either numeric or boolean variables. + // + template + void arith_base::add_fixable(expr* e) { + m_fixable_exprs.reset(); + m_fixable_atoms.reset(); + m_fixable_vars.reset(); + expr_mark visited; + buffer> todo; + expr* x, * y, * z; + todo.push_back({ e, l_true }); + while (!todo.empty()) { + auto [e, is_true] = todo.back(); + todo.pop_back(); + if (visited.is_marked(e)) + continue; + visited.mark(e); + if (is_true == l_true && get_bool_value(e)) + continue; + if (is_true == l_false && !get_bool_value(e)) + continue; + if (m.is_not(e, e)) + todo.push_back({ e, ~is_true }); + else if (m.is_and(e) || m.is_or(e)) { + for (auto arg : *to_app(e)) + todo.push_back({ arg, is_true }); + } + else if (m.is_implies(e, x, y)) { + todo.push_back({ x, ~is_true }); + todo.push_back({ y, is_true }); + } + else if (m.is_iff(e, x, y)) { + todo.push_back({ x, l_undef }); + todo.push_back({ y, l_undef }); + } + else if (m.is_ite(e, x, y, z)) { + todo.push_back({ x, l_undef }); + todo.push_back({ y, is_true }); + todo.push_back({ z, ~is_true }); + } + else { + auto bv = ctx.atom2bool_var(e); + if (bv == sat::null_bool_var) + continue; + if (is_uninterp(e)) { + if (!m_fixable_atoms.contains(bv)) { + m_fixable_atoms.insert(bv); + m_fixable_exprs.push_back(e); + } + continue; + } + auto* ineq = get_ineq(bv); + if (!ineq) + continue; + m_fixable_atoms.insert(bv); + for (auto& [v, occ] : ineq->m_nonlinear) { + if (m_fixable_vars.contains(v)) + continue; + m_fixable_vars.insert(v); + m_fixable_exprs.push_back(m_vars[v].m_expr); + } + } + } + } + + template + bool arith_base::apply_move(expr* t, bool randomize) { + add_fixable(t); + auto& vars = m_fixable_exprs; + if (vars.empty()) + return false; + m_best_expr = nullptr; + m_best_score = m_top_score; + unsigned sz = vars.size(); + unsigned start = ctx.rand(); + m_updates.reset(); + insert_update_stack_rec(t); + for (unsigned i = 0; i < sz; ++i) + add_lookahead(vars[(start + i) % sz]); + + if (randomize) { + if (m_updates.empty()) + return false; + auto& [v, new_value, score] = m_updates[ctx.rand() % m_updates.size()]; + m_best_expr = m_vars[v].m_expr; + } + else { + for (auto const& [v, new_value, score] : m_updates) + lookahead_num(v, new_value); + } + if (m_best_expr) + m_top_score = lookahead(m_best_expr); + + clear_update_stack(); + + CTRACE("bv", !m_best_expr, tout << "no guided move\n";); + return !!m_best_expr; + } + + template + void arith_base::global_search() { + initialize_bool_assignment(); + rescore(); + m_config.max_moves = m_stats.m_moves + m_config.max_moves_base; + TRACE("bv", tout << "search " << m_stats.m_moves << " " << m_config.max_moves << "\n";); + IF_VERBOSE(1, verbose_stream() << "lookahead-search moves:" << m_stats.m_moves << " max-moves:" << m_config.max_moves << "\n"); + + while (m.inc() && m_stats.m_moves < m_config.max_moves) { + m_stats.m_moves++; + check_restart(); + + auto t = get_candidate_unsat(); + + if (!t) + break; + + if (apply_move(t, false)) + continue; + + if (apply_move(t, true)) + recalibrate_weights(); + } + m_config.max_moves_base += 100; + finalize_bool_assignment(); + } + + template + expr* arith_base::get_candidate_unsat() { + unsigned n = 0; + expr* r = nullptr; + for (auto a : ctx.input_assertions()) { + if (!get_bool_value(a) && (ctx.rand() % (++n)) == 0) + r = a; + } + return r; + } + + template + void arith_base::check_restart() { + + + } + } template class sls::arith_base>; diff --git a/src/ast/sls/sls_arith_base.h b/src/ast/sls/sls_arith_base.h index df2f95ee4..07b952d14 100644 --- a/src/ast/sls/sls_arith_base.h +++ b/src/ast/sls/sls_arith_base.h @@ -42,10 +42,16 @@ namespace sls { unsigned t = 45; unsigned max_no_improve = 500000; double sp = 0.0003; + unsigned paws_init = 40; + unsigned paws_sp = 52; + bool paws = true; + unsigned max_moves = 500; + unsigned max_moves_base = 500; }; struct stats { unsigned m_num_steps = 0; + unsigned m_moves = 0; }; public: @@ -274,6 +280,56 @@ namespace sls { std::ostream& display(std::ostream& out, var_t v) const; std::ostream& display(std::ostream& out, add_def const& ad) const; std::ostream& display(std::ostream& out, mul_def const& md) const; + + + + // for global lookahead search mode + void global_search(); + struct bool_info { + unsigned weight = 0; + double score = 0; + unsigned touched = 0; + lbool value = l_undef; + }; + vector> m_update_stack; + expr_mark m_in_update_stack; + svector m_bool_info; + double m_best_score = 0, m_top_score = 0; + unsigned m_min_depth = 0, m_max_depth = 0; + num_t m_best_value; + expr* m_best_expr = nullptr, * m_last_atom = nullptr; + expr_mark m_is_root; + sat::bool_var_set m_fixable_atoms; + uint_set m_fixable_vars; + ptr_vector m_fixable_exprs; + bool_info& get_bool_info(expr* e); + bool get_bool_value(expr* e); + bool get_bool_value_rec(expr* e); + void set_bool_value(expr* e, bool v) { get_bool_info(e).value = to_lbool(v); } + bool get_basic_bool_value(app* e); + void initialize_bool_assignment(); + void finalize_bool_assignment(); + double old_score(expr* e) { return get_bool_info(e).score; } + double new_score(expr* e); + double new_score(expr* e, bool is_true); + void set_score(expr* e, double s) { get_bool_info(e).score = s; } + + void rescore(); + void recalibrate_weights(); + void inc_weight(expr* e) { ++get_bool_info(e).weight; } + void dec_weight(expr* e) { auto& i = get_bool_info(e); i.weight = i.weight > m_config.paws_init ? i.weight - 1 : m_config.paws_init; } + unsigned get_weight(expr* e) { return get_bool_info(e).weight; } + void insert_update_stack(expr* t); + void insert_update_stack_rec(expr* t); + void clear_update_stack(); + void lookahead_num(var_t v, num_t const& value); + void lookahead_bool(expr* e); + double lookahead(expr* e); + void add_lookahead(expr* e); + void add_fixable(expr* e); + bool apply_move(expr* f, bool randomize); + expr* get_candidate_unsat(); + void check_restart(); public: arith_base(context& ctx); ~arith_base() override {} diff --git a/src/ast/sls/sls_bv_eval.cpp b/src/ast/sls/sls_bv_eval.cpp index 7908781e4..afaf71a89 100644 --- a/src/ast/sls/sls_bv_eval.cpp +++ b/src/ast/sls/sls_bv_eval.cpp @@ -257,12 +257,8 @@ namespace sls { return !get_bool_value(e->get_arg(0)) || get_bool_value(e->get_arg(1)); case OP_ITE: return get_bool_value(e->get_arg(0)) ? get_bool_value(e->get_arg(1)) : get_bool_value(e->get_arg(2)); - case OP_XOR: { - bool r = false; - for (expr* arg : *e) - r ^= get_bool_value(arg); - return r; - } + case OP_XOR: + return xor_of(*e, [&](expr* arg) { return get_bool_value(arg); }); case OP_TRUE: return true; case OP_FALSE: diff --git a/src/ast/sls/sls_bv_lookahead.cpp b/src/ast/sls/sls_bv_lookahead.cpp index 93cd6adf0..630642f1d 100644 --- a/src/ast/sls/sls_bv_lookahead.cpp +++ b/src/ast/sls/sls_bv_lookahead.cpp @@ -496,9 +496,6 @@ namespace sls { for (unsigned i = 0; i < m_update_stack[depth].size(); ++i) { auto const& [a, is_bv] = m_update_stack[depth][i]; TRACE("bv_verbose", tout << "update " << mk_bounded_pp(a, m) << " depth: " << depth << "\n";); - bool before; - if (m.is_bool(a)) - before = m_ev.get_bool_value(a); if (t != a) { if (is_bv) diff --git a/src/util/checked_int64.h b/src/util/checked_int64.h index 65a80de6a..ec71bef13 100644 --- a/src/util/checked_int64.h +++ b/src/util/checked_int64.h @@ -58,6 +58,7 @@ public: static checked_int64 minus_one() { return ci(-1);} int64_t get_int64() const { return m_value; } + double get_double() const { return (double)m_value; } rational to_rational() const { return r64(m_value); } checked_int64 abs() const { diff --git a/src/util/util.h b/src/util/util.h index 7d1265b33..104617e4f 100644 --- a/src/util/util.h +++ b/src/util/util.h @@ -389,6 +389,14 @@ bool all_of(S const& set, T const& p) { return true; } +template +bool xor_of(S const& set, T const& p) { + bool r = false; + for (auto const& s : set) + r ^= p(s); + return r; +} + template R find(S const& set, std::function p) { for (auto const& s : set)