From 5ebcc3e4473b1b194d354d7829bd91c16db8202e Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Fri, 5 Jul 2024 16:16:01 -0700 Subject: [PATCH] reorg sls --- src/CMakeLists.txt | 2 +- src/ast/sls/CMakeLists.txt | 5 +- src/ast/sls/sls_arith_int.cpp | 808 ++++++++++++++++++++++++++ src/ast/sls/sls_arith_int.h | 193 ++++++ src/ast/sls/sls_cc.cpp | 145 +++++ src/ast/sls/sls_cc.h | 51 ++ src/ast/sls/sls_smt.cpp | 256 ++++++++ src/ast/sls/sls_smt.h | 125 ++++ src/sat/sat_ddfw.cpp | 103 ++-- src/sat/sat_ddfw.h | 73 +-- src/sat/sat_solver/sat_smt_solver.cpp | 5 +- src/sat/smt/CMakeLists.txt | 2 - src/sat/smt/arith_sls.cpp | 642 -------------------- src/sat/smt/arith_sls.h | 170 ------ src/sat/smt/arith_solver.cpp | 1 - src/sat/smt/arith_solver.h | 5 - src/sat/smt/euf_internalize.cpp | 6 +- src/sat/smt/euf_local_search.cpp | 50 -- src/sat/smt/euf_solver.cpp | 1 - src/sat/smt/euf_solver.h | 19 +- src/sat/smt/sls_solver.cpp | 232 +++++--- src/sat/smt/sls_solver.h | 22 +- src/sat/tactic/goal2sat.cpp | 13 +- src/util/checked_int64.h | 64 ++ src/util/sat_sls.h | 37 ++ 25 files changed, 1923 insertions(+), 1107 deletions(-) create mode 100644 src/ast/sls/sls_arith_int.cpp create mode 100644 src/ast/sls/sls_arith_int.h create mode 100644 src/ast/sls/sls_cc.cpp create mode 100644 src/ast/sls/sls_cc.h create mode 100644 src/ast/sls/sls_smt.cpp create mode 100644 src/ast/sls/sls_smt.h delete mode 100644 src/sat/smt/arith_sls.cpp delete mode 100644 src/sat/smt/arith_sls.h delete mode 100644 src/sat/smt/euf_local_search.cpp create mode 100644 src/util/sat_sls.h diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 4c09f31aa..5faede21f 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -54,7 +54,6 @@ add_subdirectory(ast/euf) add_subdirectory(ast/converters) add_subdirectory(ast/substitution) add_subdirectory(ast/simplifiers) -add_subdirectory(ast/sls) add_subdirectory(tactic) add_subdirectory(qe/mbp) add_subdirectory(qe/lite) @@ -74,6 +73,7 @@ add_subdirectory(parsers/smt2) add_subdirectory(solver/assertions) add_subdirectory(ast/pattern) add_subdirectory(math/lp) +add_subdirectory(ast/sls) add_subdirectory(sat/smt) add_subdirectory(sat/tactic) add_subdirectory(nlsat/tactic) diff --git a/src/ast/sls/CMakeLists.txt b/src/ast/sls/CMakeLists.txt index 24eaec4dc..be26d70f0 100644 --- a/src/ast/sls/CMakeLists.txt +++ b/src/ast/sls/CMakeLists.txt @@ -4,8 +4,11 @@ z3_add_component(ast_sls bv_sls.cpp bv_sls_eval.cpp bv_sls_fixed.cpp - bv_sls_terms.cpp + bv_sls_terms.cpp + sls_arith_int.cpp + sls_cc.cpp sls_engine.cpp + sls_smt.cpp sls_valuation.cpp COMPONENT_DEPENDENCIES ast diff --git a/src/ast/sls/sls_arith_int.cpp b/src/ast/sls/sls_arith_int.cpp new file mode 100644 index 000000000..fbf5c4e6a --- /dev/null +++ b/src/ast/sls/sls_arith_int.cpp @@ -0,0 +1,808 @@ +/*++ +Copyright (c) 2023 Microsoft Corporation + +Module Name: + + arith_sls_int.cpp + +Abstract: + + Local search dispatch for NIA + +Author: + + Nikolaj Bjorner (nbjorner) 2023-02-07 + +--*/ + +#include "ast/sls/sls_arith_int.h" +#include "ast/ast_ll_pp.h" + +namespace sls { + + template + arith_plugin::arith_plugin(context& ctx) : + plugin(ctx), + a(m) { + m_fid = a.get_family_id(); + } + + template + void arith_plugin::reset() { + m_bool_vars.reset(); + m_vars.reset(); + m_expr2var.reset(); + } + + template + void arith_plugin::save_best_values() { + for (auto& v : m_vars) + v.m_best_value = v.m_value; + check_ineqs(); + } + + template + void arith_plugin::store_best_values() { + } + + // distance to true + template + int_t arith_plugin::dtt(bool sign, int_t const& args, ineq const& ineq) const { + int_t zero{ 0 }; + switch (ineq.m_op) { + case ineq_kind::LE: + if (sign) { + if (args + ineq.m_coeff <= 0) + return -ineq.m_coeff - args + 1; + return zero; + } + if (args + ineq.m_coeff <= 0) + return zero; + return args + ineq.m_coeff; + case ineq_kind::EQ: + if (sign) { + if (args + ineq.m_coeff == 0) + return int_t(1); + return zero; + } + if (args + ineq.m_coeff == 0) + return zero; + return int_t(1); + case ineq_kind::LT: + if (sign) { + if (args + ineq.m_coeff < 0) + return -ineq.m_coeff - args; + return zero; + } + if (args + ineq.m_coeff < 0) + return zero; + return args + ineq.m_coeff + 1; + default: + UNREACHABLE(); + return zero; + } + } + + // + // dtt is high overhead. It walks ineq.m_args + // m_vars[w].m_value can be computed outside and shared among calls + // different data-structures for storing coefficients + // + template + int_t arith_plugin::dtt(bool sign, ineq const& ineq, var_t v, int_t const& new_value) const { + for (auto const& [coeff, w] : ineq.m_args) + if (w == v) + return dtt(sign, ineq.m_args_value + coeff * (new_value - m_vars[v].m_value), ineq); + return int_t(1); + } + + template + int_t arith_plugin::dtt(bool sign, ineq const& ineq, int_t const& coeff, int_t const& old_value, int_t const& new_value) const { + return dtt(sign, ineq.m_args_value + coeff * (new_value - old_value), ineq); + } + + template + bool arith_plugin::cm(ineq const& ineq, var_t v, int_t& new_value) { + for (auto const& [coeff, w] : ineq.m_args) + if (w == v) + return cm(ineq, v, coeff, new_value); + return false; + } + + template + bool arith_plugin::cm(ineq const& ineq, var_t v, int_t const& coeff, int_t& new_value) { + auto bound = -ineq.m_coeff; + auto argsv = ineq.m_args_value; + bool solved = false; + int_t delta = argsv - bound; + + if (ineq.is_true()) { + switch (ineq.m_op) { + case ineq_kind::LE: + // args <= bound -> args > bound + SASSERT(argsv <= bound); + SASSERT(delta <= 0); + delta -= 1 + (ctx.rand() % 10); + new_value = value(v) + div(abs(delta) + abs(coeff) - 1, coeff); + VERIFY(argsv + coeff * (new_value - value(v)) > bound); + return true; + case ineq_kind::LT: + // args < bound -> args >= bound + SASSERT(argsv <= bound); + SASSERT(delta <= 0); + delta = abs(delta) + ctx.rand() % 10; + new_value = value(v) + div(delta + abs(coeff) - 1, coeff); + VERIFY(argsv + coeff * (new_value - value(v)) >= bound); + return true; + case ineq_kind::EQ: { + delta = abs(delta) + 1 + ctx.rand() % 10; + int sign = ctx.rand() % 2 == 0 ? 1 : -1; + new_value = value(v) + sign * div(abs(delta) + abs(coeff) - 1, coeff); + VERIFY(argsv + coeff * (new_value - value(v)) != bound); + return true; + } + default: + UNREACHABLE(); + break; + } + } + else { + switch (ineq.m_op) { + case ineq_kind::LE: + SASSERT(argsv > bound); + SASSERT(delta > 0); + delta += rand() % 10; + new_value = value(v) - div(delta + abs(coeff) - 1, coeff); + VERIFY(argsv + coeff * (new_value - value(v)) <= bound); + return true; + case ineq_kind::LT: + SASSERT(argsv >= bound); + SASSERT(delta >= 0); + delta += 1 + rand() % 10; + new_value = value(v) - div(delta + abs(coeff) - 1, coeff); + VERIFY(argsv + coeff * (new_value - value(v)) < bound); + return true; + case ineq_kind::EQ: + SASSERT(delta != 0); + if (delta < 0) + new_value = value(v) + div(abs(delta) + abs(coeff) - 1, coeff); + else + new_value = value(v) - div(delta + abs(coeff) - 1, coeff); + solved = argsv + coeff * (new_value - value(v)) == bound; + if (!solved && abs(coeff) == 1) { + verbose_stream() << "did not solve equality " << ineq << " for " << v << "\n"; + verbose_stream() << new_value << " " << value(v) << " delta " << delta << " lhs " << (argsv + coeff * (new_value - value(v))) << " bound " << bound << "\n"; + UNREACHABLE(); + } + return solved; + default: + UNREACHABLE(); + break; + } + } + return false; + } + + // flip on the first positive score + // it could be changed to flip on maximal positive score + // or flip on maximal non-negative score + // or flip on first non-negative score + template + void arith_plugin::repair(sat::literal lit, ineq const& ineq) { + int_t new_value; + if (UINT_MAX == ineq.m_var_to_flip) + dtt_reward(lit); + auto v = ineq.m_var_to_flip; + if (v == UINT_MAX) { + IF_VERBOSE(1, verbose_stream() << "no var to flip\n"); + return; + } + // verbose_stream() << "var to flip " << v << "\n"; + if (!cm(ineq, v, new_value)) { + IF_VERBOSE(1, verbose_stream() << "no critical move for " << v << "\n"); + return; + } + update(v, new_value); + } + + // + // dscore(op) = sum_c (dts(c,alpha) - dts(c,alpha_after)) * weight(c) + // TODO - use cached dts instead of computed dts + // cached dts has to be updated when the score of literals are updated. + // + template + double arith_plugin::dscore(var_t v, int_t const& new_value) const { + double score = 0; + auto const& vi = m_vars[v]; + for (auto const& [coeff, bv] : vi.m_bool_vars) { + sat::literal lit(bv, false); + for (auto cl : ctx.get_use_list(lit)) + score += (compute_dts(cl) - dts(cl, v, new_value)).get_int64() * ctx.get_weight(cl); + for (auto cl : ctx.get_use_list(~lit)) + score += (compute_dts(cl) - dts(cl, v, new_value)).get_int64() * ctx.get_weight(cl); + } + return score; + } + + // + // cm_score is costly. It involves several cache misses. + // Note that + // - get_use_list(lit).size() is "often" 1 or 2 + // - dtt_old can be saved + // + template + int arith_plugin::cm_score(var_t v, int_t const& new_value) { + int score = 0; + auto& vi = m_vars[v]; + int_t old_value = vi.m_value; + for (auto const& [coeff, bv] : vi.m_bool_vars) { + auto const& ineq = *atom(bv); + bool old_sign = sign(bv); + int_t dtt_old = dtt(old_sign, ineq); + int_t dtt_new = dtt(old_sign, ineq, coeff, old_value, new_value); + if ((dtt_old == 0) == (dtt_new == 0)) + continue; + sat::literal lit(bv, old_sign); + if (dtt_old == 0) + // flip from true to false + lit.neg(); + + // lit flips form false to true: + + for (auto cl : ctx.get_use_list(lit)) { + auto const& clause = ctx.get_clause(cl); + if (!clause.is_true()) + ++score; + } + + // ignore the situation where clause contains multiple literals using v + for (auto cl : ctx.get_use_list(~lit)) { + auto const& clause = ctx.get_clause(cl); + if (clause.m_num_trues == 1) + --score; + } + } + return score; + } + + template + int_t arith_plugin::compute_dts(unsigned cl) const { + int_t d(1), d2; + bool first = true; + for (auto a : ctx.get_clause(cl)) { + auto const* ineq = atom(a.var()); + if (!ineq) + continue; + d2 = dtt(a.sign(), *ineq); + if (first) + d = d2, first = false; + else + d = std::min(d, d2); + if (d == 0) + break; + } + return d; + } + + template + int_t arith_plugin::dts(unsigned cl, var_t v, int_t const& new_value) const { + int_t d(1), d2; + bool first = true; + for (auto lit : ctx.get_clause(cl)) { + auto const* ineq = atom(lit.var()); + if (!ineq) + continue; + d2 = dtt(lit.sign(), *ineq, v, new_value); + if (first) + d = d2, first = false; + else + d = std::min(d, d2); + if (d == 0) + break; + } + return d; + } + + template + void arith_plugin::update(var_t v, int_t const& new_value) { + auto& vi = m_vars[v]; + auto old_value = vi.m_value; + if (old_value == new_value) + return; + for (auto const& [coeff, bv] : vi.m_bool_vars) { + auto& ineq = *atom(bv); + bool old_sign = sign(bv); + sat::literal lit(bv, old_sign); + SASSERT(ctx.is_true(lit)); + ineq.m_args_value += coeff * (new_value - old_value); + int_t dtt_new = dtt(old_sign, ineq); + if (dtt_new != 0) + ctx.flip(bv); + SASSERT(dtt(sign(bv), ineq) == 0); + } + vi.m_value = new_value; + for (auto idx : vi.m_muls) { + auto const& [v, monomial] = m_muls[idx]; + + int_t prod(1); + for (auto w : monomial) + prod *= value(w); + if (value(v) != prod) + m_vars_to_update.push_back({ v, prod }); + } + for (auto const& idx : vi.m_adds) { + auto const& ad = m_adds[idx]; + auto const& args = ad.m_args; + auto v = ad.m_var; + int_t sum(ad.m_coeff); + for (auto [c, w] : args) + sum += c * value(w); + if (value(v) != sum) + m_vars_to_update.push_back({ v, sum }); + } + if (vi.m_add_idx != UINT_MAX || vi.m_mul_idx != UINT_MAX) + // add repair actions for additions. + m_defs_to_update.push_back(v); + } + + template + typename arith_plugin::ineq& arith_plugin::new_ineq(ineq_kind op, int_t const& coeff) { + auto* i = alloc(ineq); + i->m_coeff = coeff; + i->m_op = op; + return *i; + } + + template + void arith_plugin::add_arg(linear_term& ineq, int_t const& c, var_t v) { + ineq.m_args.push_back({ c, v }); + } + + template + bool arith_plugin::is_int64(expr* e, int_t& i) { + rational r; + if (a.is_numeral(e, r) && r.is_int64()) { + i = int_t(r.get_int64()); + return true; + } + return false; + } + + bool arith_plugin>::is_int(expr* e, checked_int64& i) { + return is_int64(e, i); + } + + bool arith_plugin::is_int(expr* e, rational& i) { + return a.is_numeral(e, i) && i.is_int(); + } + + template + bool arith_plugin::is_int(expr* e, int_t& i) { + return false; + } + + template + void arith_plugin::add_args(linear_term& term, expr* e, int_t const& coeff) { + auto v = m_expr2var.get(e->get_id(), UINT_MAX); + if (v != UINT_MAX) { + add_arg(term, coeff, v); + return; + } + expr* x, * y; + int_t i; + if (is_int(e, i)) { + term.m_coeff += coeff * i; + return; + } + if (a.is_add(e)) { + for (expr* arg : *to_app(e)) + add_args(term, arg, coeff); + return; + } + if (a.is_sub(e, x, y)) { + add_args(term, x, coeff); + add_args(term, y, -coeff); + return; + } + + if (a.is_mul(e)) { + unsigned_vector m; + int_t c = coeff; + for (expr* arg : *to_app(e)) + if (is_int(x, i)) + c *= i; + else + m.push_back(mk_term(arg)); + switch (m.size()) { + case 0: + term.m_coeff += c; + break; + case 1: + add_arg(term, c, m[0]); + break; + default: { + auto v = mk_var(e); + unsigned idx = m_muls.size(); + m_muls.push_back({ v, m }); + int_t prod(1); + for (auto w : m) + m_vars[w].m_muls.push_back(idx), prod *= value(w); + m_vars[v].m_mul_idx = idx; + m_vars[v].m_value = prod; + add_arg(term, c, v); + break; + } + } + return; + } + if (a.is_uminus(e, x)) { + add_args(term, x, -coeff); + return; + } + + if (is_uninterp(e)) { + auto v = mk_var(e); + add_arg(term, coeff, v); + return; + } + + UNREACHABLE(); + } + + template + typename arith_plugin::var_t arith_plugin::mk_term(expr* e) { + auto v = m_expr2var.get(e->get_id(), UINT_MAX); + if (v != UINT_MAX) + return v; + linear_term t = linear_term({ {}, 0 }); + add_args(t, e, int_t(1)); + if (t.m_coeff == 1 && t.m_args.size() == 1 && t.m_args[0].first == 1) + return t.m_args[0].second; + v = mk_var(e); + auto idx = m_adds.size(); + int_t sum(t.m_coeff); + m_adds.push_back({ t.m_args, t.m_coeff, v }); + for (auto const& [c, w] : t.m_args) + m_vars[w].m_adds.push_back(idx), sum += c * value(w); + m_vars[v].m_add_idx = idx; + m_vars[v].m_value = sum; + return v; + } + + template + unsigned arith_plugin::mk_var(expr* e) { + unsigned v = m_expr2var.get(e->get_id(), UINT_MAX); + if (v == UINT_MAX) { + v = m_vars.size(); + m_expr2var.setx(e->get_id(), v, UINT_MAX); + m_vars.push_back(var_info(e, var_kind::INT)); + } + return v; + } + + template + void arith_plugin::init_bool_var(sat::bool_var bv) { + if (m_bool_vars.get(bv, nullptr)) + return; + expr* e = ctx.atom(bv); + // verbose_stream() << "bool var " << bv << " " << mk_bounded_pp(e, m) << "\n"; + if (!e) + return; + expr* x, * y; + m_bool_vars.reserve(bv + 1); + if (a.is_le(e, x, y) || a.is_ge(e, y, x)) { + auto& ineq = new_ineq(ineq_kind::LE, int_t(0)); + add_args(ineq, x, int_t(1)); + add_args(ineq, y, int_t(-1)); + init_ineq(bv, ineq); + } + else if ((a.is_lt(e, x, y) || a.is_gt(e, y, x)) && a.is_int(x)) { + auto& ineq = new_ineq(ineq_kind::LE, int_t(1)); + add_args(ineq, x, int_t(1)); + add_args(ineq, y, int_t(-1)); + init_ineq(bv, ineq); + } + else if (m.is_eq(e, x, y) && a.is_int_real(x)) { + auto& ineq = new_ineq(ineq_kind::EQ, int_t(0)); + add_args(ineq, x, int_t(1)); + add_args(ineq, y, int_t(-1)); + init_ineq(bv, ineq); + } + else { + SASSERT(!a.is_arith_expr(e)); + } + } + + template + void arith_plugin::init_ineq(sat::bool_var bv, ineq& i) { + i.m_args_value = 0; + for (auto const& [coeff, v] : i.m_args) { + m_vars[v].m_bool_vars.push_back({ coeff, bv }); + i.m_args_value += coeff * value(v); + } + m_bool_vars.set(bv, &i); + } + + template + void arith_plugin::init_bool_var_assignment(sat::bool_var v) { + auto* ineq = m_bool_vars.get(v, nullptr); + if (ineq && ctx.is_true(sat::literal(v, false)) != (dtt(false, *ineq) == 0)) + ctx.flip(v); + } + + template + void arith_plugin::repair(sat::literal lit) { + if (!ctx.is_true(lit)) + return; + auto const* ineq = atom(lit.var()); + if (!ineq) + return; + if (ineq->is_true() != lit.sign()) + return; + TRACE("sls", tout << "repair " << lit << "\n"); + repair(lit, *ineq); + } + + template + void arith_plugin::propagate_updates() { + while (!m_defs_to_update.empty() || !m_vars_to_update.empty()) { + while (!m_vars_to_update.empty()) { + auto [w, new_value1] = m_vars_to_update.back(); + m_vars_to_update.pop_back(); + update(w, new_value1); + } + repair_defs(); + } + } + + template + void arith_plugin::repair_defs() { + while (!m_defs_to_update.empty()) { + auto v = m_defs_to_update.back(); + m_defs_to_update.pop_back(); + auto const& vi = m_vars[v]; + if (vi.m_mul_idx != UINT_MAX) + repair_mul(m_muls[vi.m_mul_idx]); + if (vi.m_add_idx != UINT_MAX) + repair_add(m_adds[vi.m_add_idx]); + } + } + + template + void arith_plugin::repair_add(add_def const& ad) { + auto v = ad.m_var; + auto const& coeffs = ad.m_args; + int_t sum(ad.m_coeff); + int_t val = value(v); + for (auto const& [c, w] : coeffs) + sum += c * value(w); + if (val == sum) + return; + if (rand() % 20 == 0) + update(v, sum); + else { + auto const& [c, w] = coeffs[rand() % coeffs.size()]; + int_t delta = sum - val; + int_t new_value = value(w) + div(delta, c); + update(w, new_value); + } + } + + template + void arith_plugin::repair_mul(mul_def const& md) { + int_t product(1); + int_t val = value(md.m_var); + for (auto v : md.m_monomial) + product *= value(v); + if (product == val) + return; + if (rand() % 20 == 0) { + update(md.m_var, product); + } + else if (val == 0) { + auto v = md.m_monomial[rand() % md.m_monomial.size()]; + int_t zero(0); + update(v, zero); + } + else if (val == 1 || val == -1) { + product = 1; + for (auto v : md.m_monomial) { + int_t new_value(1); + if (rand() % 2 == 0) + new_value = -1; + product *= new_value; + update(v, new_value); + } + if (product != val) { + auto last = md.m_monomial.back(); + update(last, -value(last)); + } + } + else { + product = 1; + for (auto v : md.m_monomial) { + int_t new_value{ 1 }; + if (rand() % 2 == 0) + new_value = -1; + product *= new_value; + update(v, new_value); + } + auto v = md.m_monomial[rand() % md.m_monomial.size()]; + if ((product < 0 && 0 < val) || (val < 0 && 0 < product)) + update(v, -val * value(v)); + else + update(v, val * value(v)); + } + } + + template + double arith_plugin::reward(sat::literal lit) { + if (m_dscore_mode) + return dscore_reward(lit.var()); + else + return dtt_reward(lit); + } + + template + double arith_plugin::dtt_reward(sat::literal lit) { + auto* ineq = atom(lit.var()); + if (!ineq) + return -1; + int_t new_value; + double max_result = -1; + unsigned n = 0; + for (auto const& [coeff, x] : ineq->m_args) { + if (!cm(*ineq, x, coeff, new_value)) + continue; + double result = 0; + auto old_value = m_vars[x].m_value; + for (auto const& [coeff, bv] : m_vars[x].m_bool_vars) { + result += ctx.reward(bv); +#if 0 + bool old_sign = sign(bv); + auto dtt_old = dtt(old_sign, *atom(bv)); + auto dtt_new = dtt(old_sign, *atom(bv), coeff, old_value, new_value); + if ((dtt_new == 0) != (dtt_old == 0)) + result += ctx.reward(bv); +#endif + } + if (result > max_result || max_result == -1 || (result == max_result && (rand() % ++n == 0))) { + max_result = result; + ineq->m_var_to_flip = x; + } + } + return max_result; + } + + template + double arith_plugin::dscore_reward(sat::bool_var bv) { + m_dscore_mode = false; + bool old_sign = sign(bv); + sat::literal litv(bv, old_sign); + auto* ineq = atom(bv); + if (!ineq) + return 0; + SASSERT(ineq->is_true() != old_sign); + int_t new_value; + + for (auto const& [coeff, v] : ineq->m_args) { + double result = 0; + if (cm(*ineq, v, coeff, new_value)) + result = dscore(v, new_value); + // just pick first positive, or pick a max? + if (result > 0) { + ineq->m_var_to_flip = v; + return result; + } + } + return 0; + } + + // switch to dscore mode + template + void arith_plugin::on_rescale() { + m_dscore_mode = true; + } + + template + void arith_plugin::on_restart() { + for (unsigned v = 0; v < ctx.num_bool_vars(); ++v) + init_bool_var_assignment(v); + check_ineqs(); + } + + template + void arith_plugin::check_ineqs() { + auto check_bool_var = [&](sat::bool_var bv) { + auto const* ineq = atom(bv); + if (!ineq) + return; + int_t d = dtt(sign(bv), *ineq); + sat::literal lit(bv, sign(bv)); + if (ctx.is_true(lit) != (d == 0)) { + verbose_stream() << "invalid assignment " << bv << " " << *ineq << "\n"; + } + VERIFY(ctx.is_true(lit) == (d == 0)); + }; + for (unsigned v = 0; v < ctx.num_bool_vars(); ++v) + check_bool_var(v); + } + + template + void arith_plugin::register_term(expr* e) { + } + + template + expr_ref arith_plugin::get_value(expr* e) { + auto v = mk_var(e); + return expr_ref(a.mk_numeral(rational(m_vars[v].m_value.get_int64(), rational::i64()), a.is_int(e)), m); + } + + template + lbool arith_plugin::check() { + // repair each root literal + for (sat::literal lit : ctx.root_literals()) + repair(lit); + + propagate_updates(); + + // update literal assignment based on current model + for (unsigned v = 0; v < ctx.num_bool_vars(); ++v) + init_bool_var_assignment(v); + + return ctx.unsat().empty() ? l_true : l_undef; + } + + template + bool arith_plugin::is_sat() { + for (auto const& clause : ctx.clauses()) { + bool sat = false; + for (auto lit : clause.m_clause) { + if (!ctx.is_true(lit)) + continue; + auto ineq = atom(lit.var()); + if (!ineq) { + sat = true; + break; + } + if (ineq->is_true() != lit.sign()) { + sat = true; + break; + } + } + if (!sat) + return false; + } + return true; + } + + template + std::ostream& arith_plugin::display(std::ostream& out) const { + for (unsigned v = 0; v < ctx.num_bool_vars(); ++v) { + auto ineq = atom(v); + if (ineq) + out << v << ": " << *ineq << "\n"; + } + for (unsigned v = 0; v < m_vars.size(); ++v) { + auto const& vi = m_vars[v]; + out << "v" << v << " := " << vi.m_value << " " << vi.m_best_value << " "; + out << mk_bounded_pp(vi.m_expr, m) << " - "; + for (auto [c, bv] : vi.m_bool_vars) + out << c << "@" << bv << " "; + out << "\n"; + } + return out; + } + + template + void arith_plugin::mk_model(model& mdl) { + for (auto const& v : m_vars) { + expr* e = v.m_expr; + if (is_uninterp_const(e)) + mdl.register_decl(to_app(e)->get_decl(), get_value(e)); + } + } +} + +template class sls::arith_plugin>; +template class sls::arith_plugin; \ No newline at end of file diff --git a/src/ast/sls/sls_arith_int.h b/src/ast/sls/sls_arith_int.h new file mode 100644 index 000000000..5afff74f4 --- /dev/null +++ b/src/ast/sls/sls_arith_int.h @@ -0,0 +1,193 @@ +/*++ +Copyright (c) 2020 Microsoft Corporation + +Module Name: + + arith_local_search.h + +Abstract: + + Theory plugin for arithmetic local search + +Author: + + Nikolaj Bjorner (nbjorner) 2020-09-08 + +--*/ +#pragma once + +#include "util/obj_pair_set.h" +#include "util/checked_int64.h" +#include "ast/ast_trail.h" +#include "ast/arith_decl_plugin.h" +#include "ast/sls/sls_smt.h" + +namespace sls { + + using theory_var = int; + + // local search portion for arithmetic + template + class arith_plugin : public plugin { + enum class ineq_kind { EQ, LE, LT}; + enum class var_kind { INT, REAL }; + typedef unsigned var_t; + typedef unsigned atom_t; + + struct config { + double cb = 0.0; + unsigned L = 20; + unsigned t = 45; + unsigned max_no_improve = 500000; + double sp = 0.0003; + }; + + struct stats { + unsigned m_num_flips = 0; + }; + + // typedef checked_int64 int_t; + + public: + struct linear_term { + vector> m_args; + int_t m_coeff; + }; + // encode args <= bound, args = bound, args < bound + struct ineq : public linear_term { + ineq_kind m_op = ineq_kind::LE; + int_t m_args_value; + unsigned m_var_to_flip = UINT_MAX; + + bool is_true() const { + switch (m_op) { + case ineq_kind::LE: + return m_args_value + m_coeff <= 0; + case ineq_kind::EQ: + return m_args_value + m_coeff == 0; + default: + return m_args_value + m_coeff < 0; + } + } + std::ostream& display(std::ostream& out) const { + bool first = true; + for (auto const& [c, v] : m_args) + out << (first ? "" : " + ") << c << " * v" << v, first = false; + if (m_coeff != 0) + out << " + " << m_coeff; + switch (m_op) { + case ineq_kind::LE: + return out << " <= " << 0 << "(" << m_args_value << ")"; + case ineq_kind::EQ: + return out << " == " << 0 << "(" << m_args_value << ")"; + default: + return out << " < " << 0 << "(" << m_args_value << ")"; + } + } + }; + private: + + struct var_info { + var_info(expr* e, var_kind k): m_expr(e), m_kind(k) {} + expr* m_expr; + int_t m_value{ 0 }; + int_t m_best_value{ 0 }; + var_kind m_kind; + unsigned m_add_idx = UINT_MAX; + unsigned m_mul_idx = UINT_MAX; + vector> m_bool_vars; + unsigned_vector m_muls; + unsigned_vector m_adds; + }; + + struct mul_def { + unsigned m_var; + unsigned_vector m_monomial; + }; + + struct add_def : public linear_term { + unsigned m_var; + }; + + stats m_stats; + config m_config; + scoped_ptr_vector m_bool_vars; + vector m_vars; + vector m_muls; + vector m_adds; + unsigned_vector m_expr2var; + bool m_dscore_mode = false; + arith_util a; + + unsigned get_num_vars() const { return m_vars.size(); } + + void repair_mul(mul_def const& md); + void repair_add(add_def const& ad); + unsigned_vector m_defs_to_update; + vector> m_vars_to_update; + void propagate_updates(); + void repair_defs(); + void repair(sat::literal lit); + void repair(sat::literal lit, ineq const& ineq); + + double reward(sat::literal lit); + + bool sign(sat::bool_var v) const { return !ctx.is_true(sat::literal(v, false)); } + ineq* atom(sat::bool_var bv) const { return m_bool_vars.get(bv, nullptr); } + + + int_t dtt(bool sign, ineq const& ineq) const { return dtt(sign, ineq.m_args_value, ineq); } + int_t dtt(bool sign, int_t const& args_value, ineq const& ineq) const; + int_t dtt(bool sign, ineq const& ineq, var_t v, int_t const& new_value) const; + int_t dtt(bool sign, ineq const& ineq, int_t const& coeff, int_t const& old_value, int_t const& new_value) const; + int_t dts(unsigned cl, var_t v, int_t const& new_value) const; + int_t compute_dts(unsigned cl) const; + bool cm(ineq const& ineq, var_t v, int_t& new_value); + bool cm(ineq const& ineq, var_t v, int_t const& coeff, int_t& new_value); + int cm_score(var_t v, int_t const& new_value); + void update(var_t v, int_t const& new_value); + double dscore_reward(sat::bool_var v); + double dtt_reward(sat::literal lit); + double dscore(var_t v, int_t const& new_value) const; + void save_best_values(); + void store_best_values(); + unsigned mk_var(expr* e); + ineq& new_ineq(ineq_kind op, int_t const& bound); + void add_arg(linear_term& term, int_t const& c, var_t v); + void add_args(linear_term& term, expr* e, int_t const& sign); + var_t mk_term(expr* e); + void init_ineq(sat::bool_var bv, ineq& i); + + void init_bool_var_assignment(sat::bool_var v); + + int_t value(var_t v) const { return m_vars[v].m_value; } + bool is_int64(expr* e, int_t& i); + bool is_int(expr* e, int_t& i); + + void check_ineqs(); + + public: + arith_plugin(context& ctx); + ~arith_plugin() override {} + void init_bool_var(sat::bool_var v) override; + void register_term(expr* e) override; + expr_ref get_value(expr* e) override; + lbool check() override; + bool is_sat() override; + void reset() override; + + void on_rescale() override; + void on_restart() override; + std::ostream& display(std::ostream& out) const override; + void mk_model(model& mdl) override; + }; + + + inline std::ostream& operator<<(std::ostream& out, typename arith_plugin>::ineq const& ineq) { + return ineq.display(out); + } + + inline std::ostream& operator<<(std::ostream& out, typename arith_plugin::ineq const& ineq) { + return ineq.display(out); + } +} diff --git a/src/ast/sls/sls_cc.cpp b/src/ast/sls/sls_cc.cpp new file mode 100644 index 000000000..0d5ebf4c7 --- /dev/null +++ b/src/ast/sls/sls_cc.cpp @@ -0,0 +1,145 @@ +/*++ +Copyright (c) 2024 Microsoft Corporation + +Module Name: + + sls_cc.cpp + +Abstract: + + Congruence Closure for SLS + +Author: + + Nikolaj Bjorner (nbjorner) 2024-06-24 + +--*/ + +#include "ast/sls/sls_cc.h" +#include "ast/ast_ll_pp.h" +#include "ast/ast_pp.h" + + +namespace sls { + + cc_plugin::cc_plugin(context& c): + plugin(c), + m_values(8U, value_hash(*this), value_eq(*this)) { + m_fid = m.mk_family_id("cc"); + } + + cc_plugin::~cc_plugin() {} + + expr_ref cc_plugin::get_value(expr* e) { + UNREACHABLE(); + return expr_ref(m); + } + + void cc_plugin::reset() { + m_app.reset(); + } + + void cc_plugin::register_term(expr* e) { + if (!is_app(e)) + return; + if (!is_uninterp(e)) + return; + app* a = to_app(e); + if (a->get_num_args() == 0) + return; + auto f = a->get_decl(); + if (!m_app.contains(f)) + m_app.insert(f, ptr_vector()); + m_app[f].push_back(a); + } + + unsigned cc_plugin::value_hash::operator()(app* t) const { + unsigned r = 0; + for (auto arg : *t) + r *= 3, r += cc.ctx.get_value(arg)->hash(); + return r; + } + + bool cc_plugin::value_eq::operator()(app* a, app* b) const { + SASSERT(a->get_num_args() == b->get_num_args()); + for (unsigned i = a->get_num_args(); i-- > 0; ) + if (cc.ctx.get_value(a->get_arg(i)) != cc.ctx.get_value(b->get_arg(i))) + return false; + return true; + } + + bool cc_plugin::is_sat() { + for (auto& [f, ts] : m_app) { + if (ts.size() <= 1) + continue; + m_values.reset(); + for (auto* t : ts) { + app* u; + if (!ctx.is_relevant(t)) + continue; + if (m_values.find(t, u)) { + if (ctx.get_value(t) != ctx.get_value(u)) + return false; + } + else + m_values.insert(t); + } + } + return true; + } + + lbool cc_plugin::check() { + bool new_constraint = false; + for (auto & [f, ts] : m_app) { + if (ts.size() <= 1) + continue; + m_values.reset(); + for (auto * t : ts) { + app* u; + if (!ctx.is_relevant(t)) + continue; + if (m_values.find(t, u)) { + if (ctx.get_value(t) == ctx.get_value(u)) + continue; + expr_ref_vector ors(m); + for (unsigned i = t->get_num_args(); i-- > 0; ) + ors.push_back(m.mk_not(m.mk_eq(t->get_arg(i), u->get_arg(i)))); + ors.push_back(m.mk_eq(t, u)); + ctx.add_constraint(m.mk_or(ors)); + new_constraint = true; + } + else + m_values.insert(t); + } + } + return new_constraint ? l_undef : l_true; + } + + std::ostream& cc_plugin::display(std::ostream& out) const { + for (auto& [f, ts] : m_app) { + for (auto* t : ts) + out << mk_bounded_pp(t, m) << "\n"; + out << "\n"; + } + return out; + } + + void cc_plugin::mk_model(model& mdl) { + expr_ref_vector args(m); + for (auto& [f, ts] : m_app) { + func_interp* fi = alloc(func_interp, m, f->get_arity()); + mdl.register_decl(f, fi); + m_values.reset(); + for (auto* t : ts) { + if (m_values.contains(t)) + continue; + args.reset(); + expr_ref val = ctx.get_value(t); + for (auto arg : *t) + args.push_back(ctx.get_value(arg)); + fi->insert_new_entry(args.data(), val); + m_values.insert(t); + } + } + } +} diff --git a/src/ast/sls/sls_cc.h b/src/ast/sls/sls_cc.h new file mode 100644 index 000000000..c68671909 --- /dev/null +++ b/src/ast/sls/sls_cc.h @@ -0,0 +1,51 @@ +/*++ +Copyright (c) 2024 Microsoft Corporation + +Module Name: + + cc_sls.h + +Abstract: + + Congruence Closure for SLS + +Author: + + Nikolaj Bjorner (nbjorner) 2024-06-24 + +--*/ +#pragma once + +#include "util/hashtable.h" +#include "ast/sls/sls_smt.h" + +namespace sls { + + class cc_plugin : public plugin { + obj_map> m_app; + struct value_hash { + cc_plugin& cc; + value_hash(cc_plugin& cc) : cc(cc) {} + unsigned operator()(app* t) const; + }; + struct value_eq { + cc_plugin& cc; + value_eq(cc_plugin& cc) : cc(cc) {} + bool operator()(app* a, app* b) const; + }; + hashtable m_values; + public: + cc_plugin(context& c); + ~cc_plugin() override; + family_id fid() { return m_fid; } + expr_ref get_value(expr* e) override; + lbool check() override; + bool is_sat() override; + void reset() override; + void register_term(expr* e) override; + void init_bool_var(sat::bool_var v) override {} + std::ostream& display(std::ostream& out) const override; + void mk_model(model& mdl) override; + }; + +} diff --git a/src/ast/sls/sls_smt.cpp b/src/ast/sls/sls_smt.cpp new file mode 100644 index 000000000..aed245807 --- /dev/null +++ b/src/ast/sls/sls_smt.cpp @@ -0,0 +1,256 @@ +/*++ +Copyright (c) 2024 Microsoft Corporation + +Module Name: + + smt_sls.cpp + +Abstract: + + A Stochastic Local Search (SLS) Context. + +Author: + + Nikolaj Bjorner (nbjorner) 2024-06-24 + +--*/ +#pragma once + +#include "ast/sls/sls_smt.h" +#include "ast/sls/sls_cc.h" +#include "ast/sls/sls_arith_int.h" + +namespace sls { + + plugin::plugin(context& c): + ctx(c), + m(c.get_manager()) { + reset(); + } + + context::context(ast_manager& m, sat_solver_context& s) : + m(m), s(s), m_atoms(m), m_subterms(m) { + reset(); + } + + void context::register_plugin(plugin* p) { + m_plugins.reserve(p->fid() + 1); + m_plugins.set(p->fid(), p); + } + + void context::register_atom(sat::bool_var v, expr* e) { + m_atoms.setx(v, e); + m_atom2bool_var.setx(e->get_id(), v, UINT_MAX); + } + + typedef arith_plugin> arith64; + + void context::reset() { + m_plugins.reset(); + m_atoms.reset(); + m_atom2bool_var.reset(); + m_initialized = false; + m_parents.reset(); + m_relevant.reset(); + m_visited.reset(); + m_subterms.reset(); + register_plugin(alloc(cc_plugin, *this)); + register_plugin(alloc(arith64, *this)); + } + + lbool context::check() { + // + // initialize data-structures if not done before. + // identify minimal feasible assignment to literals. + // sub-expressions within assignment are relevant. + // Use timestamps to make it incremental. + // + init(); + while (unsat().empty()) { + reinit_relevant(); + for (auto p : m_plugins) { + lbool r; + if (p && (r = p->check()) != l_true) + return r; + } + if (m_new_constraint) + return l_undef; + if (all_of(m_plugins, [&](auto* p) { return !p || p->is_sat(); })) { + model_ref mdl = alloc(model, m); + for (auto p : m_plugins) + if (p) + p->mk_model(*mdl); + s.on_model(mdl); + verbose_stream() << *mdl << "\n"; + return l_true; + } + } + return l_undef; + } + + expr_ref context::get_value(expr* e) { + if (m.is_bool(e)) { + auto v = m_atom2bool_var[e->get_id()]; + return expr_ref(is_true(sat::literal(v, false)) ? m.mk_true() : m.mk_false(), m); + } + sort* s = e->get_sort(); + auto fid = s->get_family_id(); + auto p = m_plugins.get(fid, nullptr); + if (p) + return p->get_value(e); + UNREACHABLE(); + return expr_ref(e, m); + } + + void context::set_value(expr* e, expr* v) { + NOT_IMPLEMENTED_YET(); + } + + bool context::is_relevant(expr* e) { + unsigned id = e->get_id(); + if (m_relevant.contains(id)) + return true; + if (m_visited.contains(id)) + return false; + m_visited.insert(id); + for (auto p : m_parents[id]) { + if (is_relevant(p)) { + m_relevant.insert(id); + return true; + } + } + return false; + } + + void context::add_constraint(expr* e) { + expr_ref _e(e, m); + sat::literal_vector lits; + auto add_literal = [&](expr* e) { + bool is_neg = m.is_not(e, e); + auto v = mk_atom(e); + lits.push_back(sat::literal(v, is_neg)); + }; + if (m.is_or(e)) + for (auto arg : *to_app(e)) + add_literal(arg); + else + add_literal(e); + TRACE("sls", tout << "new clause " << lits << "\n"); + s.add_clause(lits.size(), lits.data()); + m_new_constraint = true; + } + + sat::bool_var context::mk_atom(expr* e) { + auto v = m_atom2bool_var.get(e->get_id(), sat::null_bool_var); + if (v == sat::null_bool_var) { + v = s.add_var(); + register_subterms(e); + register_atom(v, e); + init_bool_var(v); + } + return v; + } + + void context::init_bool_var(sat::bool_var v) { + for (auto p : m_plugins) + if (p) + p->init_bool_var(v); + } + + void context::init() { + m_new_constraint = false; + if (m_initialized) + return; + m_initialized = true; + register_terms(); + for (sat::bool_var v = 0; v < num_bool_vars(); ++v) + init_bool_var(v); + } + + void context::register_terms() { + for (auto a : m_atoms) + if (a) + register_subterms(a); + } + + void context::register_subterms(expr* e) { + auto is_visited = [&](expr* e) { + return nullptr != m_subterms.get(e->get_id(), nullptr); + }; + auto visit = [&](expr* e) { + m_subterms.setx(e->get_id(), e); + }; + if (is_visited(e)) + return; + m_todo.push_back(e); + while (!m_todo.empty()) { + expr* e = m_todo.back(); + if (is_visited(e)) + m_todo.pop_back(); + else if (is_app(e)) { + if (all_of(*to_app(e), [&](expr* arg) { return is_visited(arg); })) { + for (expr* arg : *to_app(e)) { + m_parents.reserve(arg->get_id() + 1); + m_parents[arg->get_id()].push_back(e); + } + register_term(e); + visit(e); + m_todo.pop_back(); + } + else { + for (expr* arg : *to_app(e)) + m_todo.push_back(arg); + } + } + else { + register_term(e); + visit(e); + m_todo.pop_back(); + } + } + } + + void context::register_term(expr* e) { + for (auto p : m_plugins) + if (p) + p->register_term(e); + } + + void context::reinit_relevant() { + m_relevant.reset(); + m_visited.reset(); + m_root_literals.reset(); + for (auto const& clause : s.clauses()) { + bool has_relevant = false; + unsigned n = 0; + sat::literal selected_lit = sat::null_literal; + for (auto lit : clause) { + auto atm = m_atoms.get(lit.var(), nullptr); + if (!atm) + continue; + auto a = atm->get_id(); + if (!is_true(lit)) + continue; + if (m_relevant.contains(a)) { + has_relevant = true; + break; + } + if (m_rand() % ++n == 0) + selected_lit = lit; + } + if (!has_relevant && selected_lit != sat::null_literal) { + m_relevant.insert(m_atoms[selected_lit.var()]->get_id()); + m_root_literals.push_back(selected_lit); + } + } + shuffle(m_root_literals.size(), m_root_literals.data(), m_rand); + } + + std::ostream& context::display(std::ostream& out) const { + for (auto p : m_plugins) { + if (p) + p->display(out); + } + return out; + } +} diff --git a/src/ast/sls/sls_smt.h b/src/ast/sls/sls_smt.h new file mode 100644 index 000000000..23565a1e5 --- /dev/null +++ b/src/ast/sls/sls_smt.h @@ -0,0 +1,125 @@ +/*++ +Copyright (c) 2024 Microsoft Corporation + +Module Name: + + smt_sls.h + +Abstract: + + A Stochastic Local Search (SLS) Context. + +Author: + + Nikolaj Bjorner (nbjorner) 2024-06-24 + +--*/ +#pragma once + +#include "util/sat_literal.h" +#include "util/sat_sls.h" +#include "ast/ast.h" +#include "model/model.h" +#include "util/scoped_ptr_vector.h" +#include "util/obj_hashtable.h" + +namespace sls { + + class context; + + class plugin { + protected: + context& ctx; + ast_manager& m; + family_id m_fid; + public: + plugin(context& c); + virtual ~plugin() {} + virtual family_id fid() { return m_fid; } + virtual void register_term(expr* e) = 0; + virtual expr_ref get_value(expr* e) = 0; + virtual void init_bool_var(sat::bool_var v) = 0; + virtual lbool check() = 0; + virtual bool is_sat() = 0; + virtual void reset() {}; + virtual void on_rescale() {}; + virtual void on_restart() {}; + virtual std::ostream& display(std::ostream& out) const = 0; + virtual void mk_model(model& mdl) = 0; + }; + + using clause = std::initializer_list ; + + class sat_solver_context { + public: + virtual vector const& clauses() const = 0; + virtual sat::clause_info const& get_clause(unsigned idx) const = 0; + virtual std::initializer_list get_use_list(sat::literal lit) = 0; + virtual void flip(sat::bool_var v) = 0; + virtual double reward(sat::bool_var v) = 0; + virtual double get_weigth(unsigned clause_idx) = 0; + virtual bool is_true(sat::literal lit) = 0; + virtual unsigned num_vars() const = 0; + virtual indexed_uint_set const& unsat() const = 0; + virtual void on_model(model_ref& mdl) = 0; + virtual sat::bool_var add_var() = 0; + virtual void add_clause(unsigned n, sat::literal const* lits) = 0; + }; + + class context { + ast_manager& m; + sat_solver_context& s; + scoped_ptr_vector m_plugins; + indexed_uint_set m_relevant, m_visited; + expr_ref_vector m_atoms; + unsigned_vector m_atom2bool_var; + vector> m_parents; + sat::literal_vector m_root_literals; + random_gen m_rand; + bool m_initialized = false; + bool m_new_constraint = false; + expr_ref_vector m_subterms; + + void register_plugin(plugin* p); + + void init(); + void init_bool_var(sat::bool_var v); + void register_terms(); + ptr_vector m_todo; + void register_subterms(expr* e); + void register_term(expr* e); + sat::bool_var mk_atom(expr* e); + + public: + context(ast_manager& m, sat_solver_context& s); + + // Between SAT/SMT solver and context. + void register_atom(sat::bool_var v, expr* e); + void reset(); + lbool check(); + + // expose sat_solver to plugins + vector const& clauses() const { return s.clauses(); } + sat::clause_info const& get_clause(unsigned idx) const { return s.get_clause(idx); } + std::initializer_list get_use_list(sat::literal lit) { return s.get_use_list(lit); } + double get_weight(unsigned clause_idx) { return s.get_weigth(clause_idx); } + unsigned num_bool_vars() const { return s.num_vars(); } + bool is_true(sat::literal lit) { return s.is_true(lit); } + expr* atom(sat::bool_var v) { return m_atoms.get(v, nullptr); } + void flip(sat::bool_var v) { s.flip(v); } + double reward(sat::bool_var v) { return s.reward(v); } + indexed_uint_set const& unsat() const { return s.unsat(); } + unsigned rand() { return m_rand(); } + sat::literal_vector const& root_literals() const { return m_root_literals; } + + void reinit_relevant(); + + // Between plugin solvers + expr_ref get_value(expr* e); + void set_value(expr* e, expr* v); + bool is_relevant(expr* e); + void add_constraint(expr* e); + ast_manager& get_manager() { return m; } + std::ostream& display(std::ostream& out) const; + }; +} diff --git a/src/sat/sat_ddfw.cpp b/src/sat/sat_ddfw.cpp index bd7b0d26c..261e225f1 100644 --- a/src/sat/sat_ddfw.cpp +++ b/src/sat/sat_ddfw.cpp @@ -33,8 +33,6 @@ namespace sat { ddfw::~ddfw() { - for (auto& ci : m_clauses) - m_alloc.del_clause(ci.m_clause); } lbool ddfw::check(unsigned sz, literal const* assumptions, parallel* p) { @@ -63,13 +61,12 @@ namespace sat { m_plugin->init_search(); m_steps_since_progress = 0; unsigned steps = 0; - while (m_min_sz > 0 && m_steps_since_progress++ <= 1500000) { + save_best_values(); + while (m_min_sz != 0 && m_steps_since_progress++ <= 1500000) { if (should_reinit_weights()) do_reinit_weights(); else if (steps % 5000 == 0) shift_weights(), m_plugin->on_rescale(); else if (should_restart()) do_restart(), m_plugin->on_restart(); else if (do_flip()); - else if (do_literal_flip()); - else if (should_parallel_sync()) do_parallel_sync(); else shift_weights(), m_plugin->on_rescale(); ++steps; } @@ -78,7 +75,7 @@ namespace sat { void ddfw::log() { double sec = m_stopwatch.get_current_seconds(); - double kflips_per_sec = (m_flips - m_last_flips) / (1000.0 * sec); + double kflips_per_sec = sec > 0 ? (m_flips - m_last_flips) / (1000.0 * sec) : 0.0; if (m_last_flips == 0) { IF_VERBOSE(1, verbose_stream() << "(sat.ddfw :unsat :models :kflips/sec :flips :restarts :reinits :unsat_vars :shifts"; if (m_par) verbose_stream() << " :par"; @@ -112,10 +109,7 @@ namespace sat { return false; if (reward > 0 || (reward == 0 && m_rand(100) <= m_config.m_use_reward_zero_pct)) { - if (uses_plugin && is_external(v)) - m_plugin->flip(v); - else - flip(v); + flip(v); if (m_unsat.size() <= m_min_sz) save_best_values(); return true; @@ -154,67 +148,36 @@ namespace sat { return m_unsat_vars.elem_at(m_rand(m_unsat_vars.size())); } - template - bool ddfw::do_literal_flip() { - double reward = 1; - return apply_flip(pick_literal_var(), reward); - } - - /* - * Pick a random false literal from a satisfied clause such that - * the literal has zero break count and positive reward. - */ - template - bool_var ddfw::pick_literal_var() { -#if false - unsigned sz = m_clauses.size(); - unsigned start = rand(); - for (unsigned i = 0; i < 100; ++i) { - unsigned cl = (i + start) % sz; - if (m_unsat.contains(cl)) - continue; - for (auto lit : *m_clauses[cl].m_clause) { - if (is_true(lit)) - continue; - double r = uses_plugin ? plugin_reward(lit.var()) : reward(lit.var()); - if (r < 0) - continue; - //verbose_stream() << "false " << r << " " << lit << "\n"; - return lit.var(); - } - } -#endif - return null_bool_var; - } - void ddfw::add(unsigned n, literal const* c) { - clause* cls = m_alloc.mk_clause(n, c, false); unsigned idx = m_clauses.size(); - - m_clauses.push_back(clause_info(cls, m_config.m_init_clause_weight)); - for (literal lit : *cls) { + m_clauses.push_back(clause_info(n, c, m_config.m_init_clause_weight)); + for (literal lit : m_clauses.back().m_clause) { m_use_list.reserve(2*(lit.var()+1)); m_vars.reserve(lit.var()+1); m_use_list[lit.index()].push_back(idx); } } + sat::bool_var ddfw::add_var(bool is_internal) { + auto v = m_vars.size(); + m_vars.reserve(v + 1); + m_vars[v].m_internal = is_internal; + return v; + } + /** * Remove the last clause that was added */ void ddfw::del() { auto& info = m_clauses.back(); - for (literal lit : *info.m_clause) + for (literal lit : info.m_clause) m_use_list[lit.index()].pop_back(); - m_alloc.del_clause(info.m_clause); m_clauses.pop_back(); if (m_unsat.contains(m_clauses.size())) m_unsat.remove(m_clauses.size()); } void ddfw::add(solver const& s) { - for (auto& ci : m_clauses) - m_alloc.del_clause(ci.m_clause); m_clauses.reset(); m_use_list.reset(); m_num_non_binary_clauses = 0; @@ -295,6 +258,12 @@ namespace sat { flatten_use_list(); } + void ddfw::reinit() { + add_assumptions(); + init_clause_data(); + flatten_use_list(); + } + void ddfw::flatten_use_list() { m_use_list_index.reset(); m_flat_use_list.reset(); @@ -310,7 +279,7 @@ namespace sat { literal lit = literal(v, !value(v)); literal nlit = ~lit; SASSERT(is_true(lit)); - for (unsigned cls_idx : use_list(*this, lit)) { + for (unsigned cls_idx : use_list(lit)) { clause_info& ci = m_clauses[cls_idx]; ci.del(lit); double w = ci.m_weight; @@ -318,7 +287,7 @@ namespace sat { switch (ci.m_num_trues) { case 0: { m_unsat.insert_fresh(cls_idx); - clause const& c = get_clause(cls_idx); + auto const& c = get_clause(cls_idx); for (literal l : c) { inc_reward(l, w); inc_make(l); @@ -333,7 +302,7 @@ namespace sat { break; } } - for (unsigned cls_idx : use_list(*this, nlit)) { + for (unsigned cls_idx : use_list(nlit)) { clause_info& ci = m_clauses[cls_idx]; double w = ci.m_weight; // the clause used to have a single true (pivot) literal, now it has two. @@ -341,7 +310,7 @@ namespace sat { switch (ci.m_num_trues) { case 0: { m_unsat.remove(cls_idx); - clause const& c = get_clause(cls_idx); + auto const& c = get_clause(cls_idx); for (literal l : c) { dec_reward(l, w); dec_make(l); @@ -388,13 +357,13 @@ namespace sat { for (unsigned v = 0; v < num_vars(); ++v) { make_count(v) = 0; reward(v) = 0; - } + } m_unsat_vars.reset(); m_unsat.reset(); unsigned sz = m_clauses.size(); for (unsigned i = 0; i < sz; ++i) { auto& ci = m_clauses[i]; - clause const& c = get_clause(i); + auto const& c = get_clause(i); ci.m_trues = 0; ci.m_num_trues = 0; for (literal lit : c) @@ -475,7 +444,7 @@ namespace sat { void ddfw::save_best_values() { - if (m_unsat.size() < m_min_sz) { + if (m_unsat.size() < m_min_sz || m_unsat.empty()) { m_steps_since_progress = 0; if (m_unsat.size() < 50 || m_min_sz * 10 > m_unsat.size() * 11) save_model(); @@ -538,11 +507,11 @@ namespace sat { unsigned ddfw::select_max_same_sign(unsigned cf_idx) { auto& ci = m_clauses[cf_idx]; unsigned cl = UINT_MAX; // clause pointer to same sign, max weight satisfied clause. - clause const& c = *ci.m_clause; + auto const& c = ci.m_clause; double max_weight = m_init_weight; unsigned n = 1; for (literal lit : c) { - for (unsigned cn_idx : use_list(*this, lit)) { + for (unsigned cn_idx : use_list(lit)) { auto& cn = m_clauses[cn_idx]; if (select_clause(max_weight, cn, n)) { cl = cn_idx; @@ -608,17 +577,15 @@ namespace sat { std::ostream& ddfw::display(std::ostream& out) const { unsigned num_cls = m_clauses.size(); for (unsigned i = 0; i < num_cls; ++i) { - out << get_clause(i) << " "; + out << get_clause(i) << " nt: "; auto const& ci = m_clauses[i]; - out << ci.m_num_trues << " " << ci.m_weight << "\n"; - } - for (unsigned v = 0; v < num_vars(); ++v) { - out << v << ": " << reward(v) << "\n"; + out << ci.m_num_trues << " w: " << ci.m_weight << "\n"; } + for (unsigned v = 0; v < num_vars(); ++v) + out << (is_true(literal(v, false)) ? "" : "-") << v << " rw: " << get_reward(v) << "\n"; out << "unsat vars: "; - for (bool_var v : m_unsat_vars) { - out << v << " "; - } + for (bool_var v : m_unsat_vars) + out << v << " "; out << "\n"; return out; } diff --git a/src/sat/sat_ddfw.h b/src/sat/sat_ddfw.h index ff86e9b8c..e2a03159e 100644 --- a/src/sat/sat_ddfw.h +++ b/src/sat/sat_ddfw.h @@ -24,6 +24,7 @@ #include "util/rlimit.h" #include "util/params.h" #include "util/ema.h" +#include "util/sat_sls.h" #include "sat/sat_clause.h" #include "sat/sat_types.h" @@ -40,7 +41,6 @@ namespace sat { virtual ~local_search_plugin() {} virtual void init_search() = 0; virtual void finish_search() = 0; - virtual void flip(bool_var v) = 0; virtual double reward(bool_var v) = 0; virtual void on_rescale() = 0; virtual void on_save_model() = 0; @@ -48,30 +48,6 @@ namespace sat { }; class ddfw : public i_local_search { - friend class arith::sls; - public: - struct clause_info { - clause_info(clause* cl, double init_weight): m_weight(init_weight), m_clause(cl) {} - double m_weight; // weight of clause - unsigned m_trues = 0; // set of literals that are true - unsigned m_num_trues = 0; // size of true set - clause* m_clause; - bool is_true() const { return m_num_trues > 0; } - void add(literal lit) { ++m_num_trues; m_trues += lit.index(); } - void del(literal lit) { SASSERT(m_num_trues > 0); --m_num_trues; m_trues -= lit.index(); } - }; - - class use_list { - ddfw& p; - unsigned i; - public: - use_list(ddfw& p, literal lit) : - p(p), i(lit.index()) {} - unsigned const* begin() { return p.m_flat_use_list.data() + p.m_use_list_index[i]; } - unsigned const* end() { return p.m_flat_use_list.data() + p.m_use_list_index[i + 1]; } - unsigned size() const { return p.m_use_list_index[i + 1] - p.m_use_list_index[i]; } - }; - protected: struct config { @@ -96,6 +72,7 @@ namespace sat { struct var_info { var_info() {} + bool m_internal = false; bool m_value = false; double m_reward = 0; double m_last_reward = 0; @@ -107,8 +84,7 @@ namespace sat { config m_config; reslimit m_limit; - clause_allocator m_alloc; - svector m_clauses; + vector m_clauses; literal_vector m_assumptions; svector m_vars; // var -> info svector m_probs; // var -> probability of flipping @@ -132,7 +108,7 @@ namespace sat { stopwatch m_stopwatch; parallel* m_par; - local_search_plugin* m_plugin = nullptr; + scoped_ptr m_plugin = nullptr; void flatten_use_list(); @@ -142,17 +118,13 @@ namespace sat { */ inline double score(double r) { return r; } - inline unsigned num_vars() const { return m_vars.size(); } - inline unsigned& make_count(bool_var v) { return m_vars[v].m_make_count; } inline bool& value(bool_var v) { return m_vars[v].m_value; } inline bool value(bool_var v) const { return m_vars[v].m_value; } - inline double& reward(bool_var v) { return m_vars[v].m_reward; } - - inline double reward(bool_var v) const { return m_vars[v].m_reward; } + inline double& reward(bool_var v) { return m_vars[v].m_reward; } inline double plugin_reward(bool_var v) { return is_external(v) ? (m_vars[v].m_last_reward = m_plugin->reward(v)) : reward(v); } @@ -166,7 +138,7 @@ namespace sat { inline bool is_true(literal lit) const { return value(lit.var()) != lit.sign(); } - inline clause const& get_clause(unsigned idx) const { return *m_clauses[idx].m_clause; } + inline sat::literal_vector const& get_clause(unsigned idx) const { return m_clauses[idx].m_clause; } inline double get_weight(unsigned idx) const { return m_clauses[idx].m_weight; } @@ -203,11 +175,6 @@ namespace sat { template bool apply_flip(bool_var v, double reward); - template - bool do_literal_flip(); - - template - bool_var pick_literal_var(); void save_best_values(); void save_model(); @@ -241,7 +208,7 @@ namespace sat { void invariant(); - void add(unsigned sz, literal const* c); + void del(); @@ -257,7 +224,7 @@ namespace sat { ~ddfw() override; - void set(local_search_plugin* p) { m_plugin = p; } + void set_plugin(local_search_plugin* p) { m_plugin = p; } lbool check(unsigned sz, literal const* assumptions, parallel* p) override; @@ -286,7 +253,7 @@ namespace sat { // access clause information and state of Boolean search indexed_uint_set& unsat_set() { return m_unsat; } - unsigned num_clauses() const { return m_clauses.size(); } + vector const& clauses() const { return m_clauses; } clause_info& get_clause_info(unsigned idx) { return m_clauses[idx]; } @@ -294,7 +261,27 @@ namespace sat { void flip(bool_var v); - use_list get_use_list(literal lit) { return use_list(*this, lit); } + inline double get_reward(bool_var v) const { return m_vars[v].m_reward; } + + void add(unsigned sz, literal const* c); + + sat::bool_var add_var(bool is_internal = true); + + // is this a variable that was added during initialization? + bool is_initial_var(sat::bool_var v) const { + return m_vars.size() > v && !m_vars[v].m_internal; + } + + void reinit(); + + inline unsigned num_vars() const { return m_vars.size(); } + + std::initializer_list use_list(literal lit) { + unsigned i = lit.index(); + auto const* b = m_flat_use_list.data() + m_use_list_index[i]; + auto const* e = m_flat_use_list.data() + m_use_list_index[i + 1]; + return std::initializer_list(b, e); + } }; } diff --git a/src/sat/sat_solver/sat_smt_solver.cpp b/src/sat/sat_solver/sat_smt_solver.cpp index 19b10eb3e..ee989b0f9 100644 --- a/src/sat/sat_solver/sat_smt_solver.cpp +++ b/src/sat/sat_solver/sat_smt_solver.cpp @@ -581,8 +581,9 @@ private: void add_assumption(expr* a) { init_goal2sat(); - m_dep.insert(a, m_goal2sat.internalize(a)); - get_euf()->add_assertion(a); + auto lit = m_goal2sat.internalize(a); + m_dep.insert(a, lit); + get_euf()->add_clause(1, &lit); } void internalize_assumptions(expr_ref_vector const& asms) { diff --git a/src/sat/smt/CMakeLists.txt b/src/sat/smt/CMakeLists.txt index 7747b65cb..1e2193501 100644 --- a/src/sat/smt/CMakeLists.txt +++ b/src/sat/smt/CMakeLists.txt @@ -3,7 +3,6 @@ z3_add_component(sat_smt arith_axioms.cpp arith_diagnostics.cpp arith_internalize.cpp - arith_sls.cpp arith_solver.cpp arith_value.cpp array_axioms.cpp @@ -22,7 +21,6 @@ z3_add_component(sat_smt euf_ackerman.cpp euf_internalize.cpp euf_invariant.cpp - euf_local_search.cpp euf_model.cpp euf_proof.cpp euf_proof_checker.cpp diff --git a/src/sat/smt/arith_sls.cpp b/src/sat/smt/arith_sls.cpp deleted file mode 100644 index 216829980..000000000 --- a/src/sat/smt/arith_sls.cpp +++ /dev/null @@ -1,642 +0,0 @@ -/*++ -Copyright (c) 2023 Microsoft Corporation - -Module Name: - - arith_local_search.cpp - -Abstract: - - Local search dispatch for SMT - -Author: - - Nikolaj Bjorner (nbjorner) 2023-02-07 - ---*/ -#include "sat/sat_solver.h" -#include "sat/smt/arith_solver.h" - - -namespace arith { - - sls::sls(solver& s): - s(s), m(s.m) {} - - void sls::reset() { - m_bool_vars.reset(); - m_vars.reset(); - m_terms.reset(); - } - - void sls::save_best_values() { - for (unsigned v = 0; v < s.get_num_vars(); ++v) - m_vars[v].m_best_value = m_vars[v].m_value; - check_ineqs(); - if (unsat().size() == 1) { - auto idx = *unsat().begin(); - verbose_stream() << idx << "\n"; - auto const& c = *m_bool_search->m_clauses[idx].m_clause; - verbose_stream() << c << "\n"; - for (auto lit : c) { - bool_var bv = lit.var(); - ineq* i = atom(bv); - if (i) - verbose_stream() << lit << ": " << *i << "\n"; - } - verbose_stream() << "\n"; - } - } - - void sls::store_best_values() { - // first compute assignment to terms - // then update non-basic variables in tableau. - - if (!unsat().empty()) - return; - - for (auto const& [t,v] : m_terms) { - int64_t val = 0; - lp::lar_term const& term = s.lp().get_term(t); - for (lp::lar_term::ival const& arg : term) { - auto t2 = arg.j(); - auto w = s.lp().local_to_external(t2); - val += to_numeral(arg.coeff()) * m_vars[w].m_best_value; - } - m_vars[v].m_best_value = val; - } - - for (unsigned v = 0; v < s.get_num_vars(); ++v) { - if (s.is_bool(v)) - continue; - if (!s.lp().external_is_used(v)) - continue; - int64_t new_value = m_vars[v].m_best_value; - s.ensure_column(v); - lp::lpvar vj = s.lp().external_to_local(v); - SASSERT(vj != lp::null_lpvar); - if (!s.lp().is_base(vj)) { - rational new_value_(new_value, rational::i64()); - lp::impq val(new_value_, rational::zero()); - s.lp().set_value_for_nbasic_column(vj, val); - } - } - - lbool r = s.make_feasible(); - VERIFY (!unsat().empty() || r == l_true); -#if 0 - if (unsat().empty()) - s.m_num_conflicts = s.get_config().m_arith_propagation_threshold; -#endif - - auto check_bool_var = [&](sat::bool_var bv) { - auto* ineq = m_bool_vars.get(bv, nullptr); - if (!ineq) - return; - api_bound* b = nullptr; - s.m_bool_var2bound.find(bv, b); - if (!b) - return; - auto bound = b->get_value(); - theory_var v = b->get_var(); - if (s.get_phase(bv) == m_bool_search->get_model()[bv]) - return; - switch (b->get_bound_kind()) { - case lp_api::lower_t: - verbose_stream() << "v" << v << " " << bound << " <= " << s.get_value(v) << " " << m_vars[v].m_best_value << "\n"; - break; - case lp_api::upper_t: - verbose_stream() << "v" << v << " " << bound << " >= " << s.get_value(v) << " " << m_vars[v].m_best_value << "\n"; - break; - } - int64_t value = 0; - for (auto const& [coeff, v] : ineq->m_args) { - value += coeff * m_vars[v].m_best_value; - } - ineq->m_args_value = value; - verbose_stream() << *ineq << " dtt " << dtt(false, *ineq) << " phase " << s.get_phase(bv) << " model " << m_bool_search->get_model()[bv] << "\n"; - for (auto const& [coeff, v] : ineq->m_args) - verbose_stream() << "v" << v << " := " << m_vars[v].m_best_value << "\n"; - s.display(verbose_stream()); - display(verbose_stream()); - UNREACHABLE(); - exit(0); - }; - - if (unsat().empty()) { - for (bool_var v = 0; v < s.s().num_vars(); ++v) - check_bool_var(v); - } - } - - void sls::set(sat::ddfw* d) { - m_bool_search = d; - reset(); - m_bool_vars.reserve(s.s().num_vars()); - add_vars(); - for (unsigned i = 0; i < d->num_clauses(); ++i) - for (sat::literal lit : *d->get_clause_info(i).m_clause) - init_bool_var(lit.var()); - for (unsigned v = 0; v < s.s().num_vars(); ++v) - init_bool_var_assignment(v); - - d->set(this); - } - - // distance to true - int64_t sls::dtt(bool sign, int64_t args, ineq const& ineq) const { - switch (ineq.m_op) { - case ineq_kind::LE: - if (sign) { - if (args <= ineq.m_bound) - return ineq.m_bound - args + 1; - return 0; - } - if (args <= ineq.m_bound) - return 0; - return args - ineq.m_bound; - case ineq_kind::EQ: - if (sign) { - if (args == ineq.m_bound) - return 1; - return 0; - } - if (args == ineq.m_bound) - return 0; - return 1; - case ineq_kind::NE: - if (sign) { - if (args == ineq.m_bound) - return 0; - return 1; - } - if (args == ineq.m_bound) - return 1; - return 0; - case ineq_kind::LT: - if (sign) { - if (args < ineq.m_bound) - return ineq.m_bound - args; - return 0; - } - if (args < ineq.m_bound) - return 0; - return args - ineq.m_bound + 1; - default: - UNREACHABLE(); - return 0; - } - } - - // - // dtt is high overhead. It walks ineq.m_args - // m_vars[w].m_value can be computed outside and shared among calls - // different data-structures for storing coefficients - // - int64_t sls::dtt(bool sign, ineq const& ineq, var_t v, int64_t new_value) const { - for (auto const& [coeff, w] : ineq.m_args) - if (w == v) - return dtt(sign, ineq.m_args_value + coeff * (new_value - m_vars[v].m_value), ineq); - return 1; - } - - int64_t sls::dtt(bool sign, ineq const& ineq, int64_t coeff, int64_t old_value, int64_t new_value) const { - return dtt(sign, ineq.m_args_value + coeff * (new_value - old_value), ineq); - } - - bool sls::cm(bool old_sign, ineq const& ineq, var_t v, int64_t& new_value) { - for (auto const& [coeff, w] : ineq.m_args) - if (w == v) - return cm(old_sign, ineq, v, coeff, new_value); - return false; - } - - bool sls::cm(bool old_sign, ineq const& ineq, var_t v, int64_t coeff, int64_t& new_value) { - SASSERT(ineq.is_true() != old_sign); - VERIFY(ineq.is_true() != old_sign); - auto bound = ineq.m_bound; - auto argsv = ineq.m_args_value; - bool solved = false; - int64_t delta = argsv - bound; - auto make_eq = [&]() { - SASSERT(delta != 0); - if (delta < 0) - new_value = value(v) + (abs(delta) + abs(coeff) - 1) / coeff; - else - new_value = value(v) - (delta + abs(coeff) - 1) / coeff; - solved = argsv + coeff * (new_value - value(v)) == bound; - if (!solved && abs(coeff) == 1) { - verbose_stream() << "did not solve equality " << ineq << " for " << v << "\n"; - verbose_stream() << new_value << " " << value(v) << " delta " << delta << " lhs " << (argsv + coeff * (new_value - value(v))) << " bound " << bound << "\n"; - UNREACHABLE(); - } - return solved; - }; - - auto make_diseq = [&]() { - if (delta >= 0) - delta++; - else - delta--; - new_value = value(v) + (abs(delta) + abs(coeff) - 1) / coeff; - VERIFY(argsv + coeff * (new_value - value(v)) != bound); - return true; - }; - - if (!old_sign) { - switch (ineq.m_op) { - case ineq_kind::LE: - // args <= bound -> args > bound - SASSERT(argsv <= bound); - SASSERT(delta <= 0); - --delta; - new_value = value(v) + (abs(delta) + abs(coeff) - 1) / coeff; - VERIFY(argsv + coeff * (new_value - value(v)) > bound); - return true; - case ineq_kind::LT: - // args < bound -> args >= bound - SASSERT(argsv <= ineq.m_bound); - SASSERT(delta <= 0); - new_value = value(v) + (abs(delta) + abs(coeff) - 1) / coeff; - VERIFY(argsv + coeff * (new_value - value(v)) >= bound); - return true; - case ineq_kind::EQ: - return make_diseq(); - case ineq_kind::NE: - return make_eq(); - default: - UNREACHABLE(); - break; - } - } - else { - switch (ineq.m_op) { - case ineq_kind::LE: - SASSERT(argsv > ineq.m_bound); - SASSERT(delta > 0); - new_value = value(v) - (delta + abs(coeff) - 1) / coeff; - VERIFY(argsv + coeff * (new_value - value(v)) <= bound); - return true; - case ineq_kind::LT: - SASSERT(argsv >= ineq.m_bound); - SASSERT(delta >= 0); - ++delta; - new_value = value(v) - (abs(delta) + abs(coeff) - 1) / coeff; - VERIFY(argsv + coeff * (new_value - value(v)) < bound); - return true; - case ineq_kind::NE: - return make_diseq(); - case ineq_kind::EQ: - return make_eq(); - default: - UNREACHABLE(); - break; - } - } - return false; - } - - // flip on the first positive score - // it could be changed to flip on maximal positive score - // or flip on maximal non-negative score - // or flip on first non-negative score - bool sls::flip(bool sign, ineq const& ineq) { - int64_t new_value; - auto v = ineq.m_var_to_flip; - if (v == UINT_MAX) { - IF_VERBOSE(1, verbose_stream() << "no var to flip\n"); - return false; - } - if (!cm(sign, ineq, v, new_value)) { - verbose_stream() << "no critical move for " << v << "\n"; - return false; - } - update(v, new_value); - return true; - } - - // - // dscore(op) = sum_c (dts(c,alpha) - dts(c,alpha_after)) * weight(c) - // TODO - use cached dts instead of computed dts - // cached dts has to be updated when the score of literals are updated. - // - double sls::dscore(var_t v, int64_t new_value) const { - double score = 0; - auto const& vi = m_vars[v]; - for (auto const& [coeff, bv] : vi.m_bool_vars) { - sat::literal lit(bv, false); - for (auto cl : m_bool_search->get_use_list(lit)) - score += (compute_dts(cl) - dts(cl, v, new_value)) * m_bool_search->get_weight(cl); - for (auto cl : m_bool_search->get_use_list(~lit)) - score += (compute_dts(cl) - dts(cl, v, new_value)) * m_bool_search->get_weight(cl); - } - return score; - } - - // - // cm_score is costly. It involves several cache misses. - // Note that - // - m_bool_search->get_use_list(lit).size() is "often" 1 or 2 - // - dtt_old can be saved - // - int sls::cm_score(var_t v, int64_t new_value) { - int score = 0; - auto& vi = m_vars[v]; - int64_t old_value = vi.m_value; - for (auto const& [coeff, bv] : vi.m_bool_vars) { - auto const& ineq = *atom(bv); - bool old_sign = sign(bv); - int64_t dtt_old = dtt(old_sign, ineq); - int64_t dtt_new = dtt(old_sign, ineq, coeff, old_value, new_value); - if ((dtt_old == 0) == (dtt_new == 0)) - continue; - sat::literal lit(bv, old_sign); - if (dtt_old == 0) - // flip from true to false - lit.neg(); - - // lit flips form false to true: - for (auto cl : m_bool_search->get_use_list(lit)) { - auto const& clause = get_clause_info(cl); - if (!clause.is_true()) - ++score; - } - // ignore the situation where clause contains multiple literals using v - for (auto cl : m_bool_search->get_use_list(~lit)) { - auto const& clause = get_clause_info(cl); - if (clause.m_num_trues == 1) - --score; - } - } - return score; - } - - int64_t sls::compute_dts(unsigned cl) const { - int64_t d(1), d2; - bool first = true; - for (auto a : get_clause(cl)) { - auto const* ineq = atom(a.var()); - if (!ineq) - continue; - d2 = dtt(a.sign(), *ineq); - if (first) - d = d2, first = false; - else - d = std::min(d, d2); - if (d == 0) - break; - } - return d; - } - - int64_t sls::dts(unsigned cl, var_t v, int64_t new_value) const { - int64_t d(1), d2; - bool first = true; - for (auto lit : get_clause(cl)) { - auto const* ineq = atom(lit.var()); - if (!ineq) - continue; - d2 = dtt(lit.sign(), *ineq, v, new_value); - if (first) - d = d2, first = false; - else - d = std::min(d, d2); - if (d == 0) - break; - } - return d; - } - - void sls::update(var_t v, int64_t new_value) { - auto& vi = m_vars[v]; - auto old_value = vi.m_value; - for (auto const& [coeff, bv] : vi.m_bool_vars) { - auto& ineq = *atom(bv); - bool old_sign = sign(bv); - sat::literal lit(bv, old_sign); - SASSERT(is_true(lit)); - ineq.m_args_value += coeff * (new_value - old_value); - int64_t dtt_new = dtt(old_sign, ineq); - if (dtt_new != 0) - m_bool_search->flip(bv); - SASSERT(dtt(sign(bv), ineq) == 0); - } - vi.m_value = new_value; - } - - void sls::add_vars() { - SASSERT(m_vars.empty()); - for (unsigned v = 0; v < s.get_num_vars(); ++v) { - int64_t value = s.is_registered_var(v) ? to_numeral(s.get_ivalue(v).x) : 0; - auto k = s.is_int(v) ? sls::var_kind::INT : sls::var_kind::REAL; - m_vars.push_back({ value, value, k, {} }); - } - } - - sls::ineq& sls::new_ineq(ineq_kind op, int64_t const& bound) { - auto* i = alloc(ineq); - i->m_bound = bound; - i->m_op = op; - return *i; - } - - void sls::add_arg(sat::bool_var bv, ineq& ineq, int64_t const& c, var_t v) { - ineq.m_args.push_back({ c, v }); - ineq.m_args_value += c * value(v); - m_vars[v].m_bool_vars.push_back({ c, bv}); - } - - int64_t sls::to_numeral(rational const& r) { - if (r.is_int64()) - return r.get_int64(); - return 0; - } - - void sls::add_args(sat::bool_var bv, ineq& ineq, lp::lpvar t, theory_var v, int64_t sign) { - if (s.lp().column_has_term(t)) { - lp::lar_term const& term = s.lp().get_term(t); - m_terms.push_back({t,v}); - for (lp::lar_term::ival arg : term) { - auto t2 = arg.j(); - auto w = s.lp().local_to_external(t2); - add_arg(bv, ineq, sign * to_numeral(arg.coeff()), w); - } - } - else - add_arg(bv, ineq, sign, s.lp().local_to_external(t)); - } - - void sls::init_bool_var(sat::bool_var bv) { - if (m_bool_vars.get(bv, nullptr)) - return; - api_bound* b = nullptr; - s.m_bool_var2bound.find(bv, b); - if (b) { - auto t = b->column_index(); - rational bound = b->get_value(); - bool should_minus = false; - sls::ineq_kind op; - should_minus = b->get_bound_kind() == lp_api::bound_kind::lower_t; - op = sls::ineq_kind::LE; - if (should_minus) - bound.neg(); - - auto& ineq = new_ineq(op, to_numeral(bound)); - - - add_args(bv, ineq, t, b->get_var(), should_minus ? -1 : 1); - m_bool_vars.set(bv, &ineq); - m_bool_search->set_external(bv); - return; - } - - expr* e = s.bool_var2expr(bv); - expr* l = nullptr, * r = nullptr; - if (e && m.is_eq(e, l, r) && s.a.is_int_real(l)) { - theory_var u = s.get_th_var(l); - theory_var v = s.get_th_var(r); - lp::lpvar tu = s.get_column(u); - lp::lpvar tv = s.get_column(v); - auto& ineq = new_ineq(sls::ineq_kind::EQ, 0); - add_args(bv, ineq, tu, u, 1); - add_args(bv, ineq, tv, v, -1); - m_bool_vars.set(bv, &ineq); - m_bool_search->set_external(bv); - return; - } - } - - void sls::init_bool_var_assignment(sat::bool_var v) { - auto* ineq = m_bool_vars.get(v, nullptr); - if (ineq && is_true(sat::literal(v, false)) != (dtt(false, *ineq) == 0)) - m_bool_search->flip(v); - } - - void sls::init_search() { - on_restart(); - } - - void sls::finish_search() { - store_best_values(); - } - - void sls::flip(sat::bool_var v) { - sat::literal lit(v, !sign(v)); - SASSERT(!is_true(lit)); - auto const* ineq = atom(v); - if (!ineq) - IF_VERBOSE(0, verbose_stream() << "no inequality for variable " << v << "\n"); - if (!ineq) - return; - SASSERT(ineq->is_true() == lit.sign()); - flip(sign(v), *ineq); - } - - double sls::reward(sat::bool_var v) { - if (m_dscore_mode) - return dscore_reward(v); - else - return dtt_reward(v); - } - - double sls::dtt_reward(sat::bool_var bv0) { - bool sign0 = sign(bv0); - auto* ineq = atom(bv0); - if (!ineq) - return -1; - int64_t new_value; - double max_result = -1; - for (auto const & [coeff, x] : ineq->m_args) { - if (!cm(sign0, *ineq, x, coeff, new_value)) - continue; - double result = 0; - auto old_value = m_vars[x].m_value; - for (auto const& [coeff, bv] : m_vars[x].m_bool_vars) { - result += m_bool_search->reward(bv); - continue; - bool old_sign = sign(bv); - auto dtt_old = dtt(old_sign, *atom(bv)); - auto dtt_new = dtt(old_sign, *atom(bv), coeff, old_value, new_value); - if ((dtt_new == 0) != (dtt_old == 0)) - result += m_bool_search->reward(bv); - } - if (result > max_result) { - max_result = result; - ineq->m_var_to_flip = x; - } - } - return max_result; - } - - double sls::dscore_reward(sat::bool_var bv) { - m_dscore_mode = false; - bool old_sign = sign(bv); - sat::literal litv(bv, old_sign); - auto* ineq = atom(bv); - if (!ineq) - return 0; - SASSERT(ineq->is_true() != old_sign); - int64_t new_value; - - for (auto const& [coeff, v] : ineq->m_args) { - double result = 0; - if (cm(old_sign, *ineq, v, coeff, new_value)) - result = dscore(v, new_value); - // just pick first positive, or pick a max? - if (result > 0) { - ineq->m_var_to_flip = v; - return result; - } - } - return 0; - } - - // switch to dscore mode - void sls::on_rescale() { - m_dscore_mode = true; - } - - void sls::on_save_model() { - save_best_values(); - } - - void sls::on_restart() { - for (unsigned v = 0; v < s.s().num_vars(); ++v) - init_bool_var_assignment(v); - - check_ineqs(); - } - - void sls::check_ineqs() { - - auto check_bool_var = [&](sat::bool_var bv) { - auto const* ineq = atom(bv); - if (!ineq) - return; - int64_t d = dtt(sign(bv), *ineq); - sat::literal lit(bv, sign(bv)); - if (is_true(lit) != (d == 0)) { - verbose_stream() << "invalid assignment " << bv << " " << *ineq << "\n"; - } - VERIFY(is_true(lit) == (d == 0)); - }; - for (unsigned v = 0; v < s.get_num_vars(); ++v) - check_bool_var(v); - } - - std::ostream& sls::display(std::ostream& out) const { - for (bool_var bv = 0; bv < s.s().num_vars(); ++bv) { - auto const* ineq = atom(bv); - if (!ineq) - continue; - out << bv << " " << *ineq << "\n"; - } - for (unsigned v = 0; v < s.get_num_vars(); ++v) { - if (s.is_bool(v)) - continue; - out << "v" << v << " := " << m_vars[v].m_value << " " << m_vars[v].m_best_value << "\n"; - } - return out; - } - -} diff --git a/src/sat/smt/arith_sls.h b/src/sat/smt/arith_sls.h deleted file mode 100644 index 55d39b252..000000000 --- a/src/sat/smt/arith_sls.h +++ /dev/null @@ -1,170 +0,0 @@ -/*++ -Copyright (c) 2020 Microsoft Corporation - -Module Name: - - arith_local_search.h - -Abstract: - - Theory plugin for arithmetic local search - -Author: - - Nikolaj Bjorner (nbjorner) 2020-09-08 - ---*/ -#pragma once - -#include "util/obj_pair_set.h" -#include "ast/ast_trail.h" -#include "ast/arith_decl_plugin.h" -#include "math/lp/indexed_value.h" -#include "math/lp/lar_solver.h" -#include "math/lp/nla_solver.h" -#include "math/lp/lp_types.h" -#include "math/lp/lp_api.h" -#include "math/polynomial/algebraic_numbers.h" -#include "math/polynomial/polynomial.h" -#include "sat/smt/sat_th.h" -#include "sat/sat_ddfw.h" - -namespace arith { - - class solver; - - // local search portion for arithmetic - class sls : public sat::local_search_plugin { - enum class ineq_kind { EQ, LE, LT, NE }; - enum class var_kind { INT, REAL }; - typedef unsigned var_t; - typedef unsigned atom_t; - - struct config { - double cb = 0.0; - unsigned L = 20; - unsigned t = 45; - unsigned max_no_improve = 500000; - double sp = 0.0003; - }; - - struct stats { - unsigned m_num_flips = 0; - }; - - public: - // encode args <= bound, args = bound, args < bound - struct ineq { - vector> m_args; - ineq_kind m_op = ineq_kind::LE; - int64_t m_bound; - int64_t m_args_value; - unsigned m_var_to_flip = UINT_MAX; - - bool is_true() const { - switch (m_op) { - case ineq_kind::LE: - return m_args_value <= m_bound; - case ineq_kind::EQ: - return m_args_value == m_bound; - case ineq_kind::NE: - return m_args_value != m_bound; - default: - return m_args_value < m_bound; - } - } - std::ostream& display(std::ostream& out) const { - bool first = true; - for (auto const& [c, v] : m_args) - out << (first ? "" : " + ") << c << " * v" << v, first = false; - switch (m_op) { - case ineq_kind::LE: - return out << " <= " << m_bound << "(" << m_args_value << ")"; - case ineq_kind::EQ: - return out << " == " << m_bound << "(" << m_args_value << ")"; - case ineq_kind::NE: - return out << " != " << m_bound << "(" << m_args_value << ")"; - default: - return out << " < " << m_bound << "(" << m_args_value << ")"; - } - } - }; - private: - - struct var_info { - int64_t m_value; - int64_t m_best_value; - var_kind m_kind = var_kind::INT; - svector> m_bool_vars; - }; - - solver& s; - ast_manager& m; - sat::ddfw* m_bool_search = nullptr; - stats m_stats; - config m_config; - scoped_ptr_vector m_bool_vars; - vector m_vars; - svector> m_terms; - bool m_dscore_mode = false; - - - indexed_uint_set& unsat() { return m_bool_search->unsat_set(); } - unsigned num_clauses() const { return m_bool_search->num_clauses(); } - sat::clause& get_clause(unsigned idx) { return *get_clause_info(idx).m_clause; } - sat::clause const& get_clause(unsigned idx) const { return *get_clause_info(idx).m_clause; } - sat::ddfw::clause_info& get_clause_info(unsigned idx) { return m_bool_search->get_clause_info(idx); } - sat::ddfw::clause_info const& get_clause_info(unsigned idx) const { return m_bool_search->get_clause_info(idx); } - bool is_true(sat::literal lit) { return lit.sign() != m_bool_search->get_value(lit.var()); } - bool sign(sat::bool_var v) const { return !m_bool_search->get_value(v); } - - void reset(); - ineq* atom(sat::bool_var bv) const { return m_bool_vars[bv]; } - - bool flip(bool sign, ineq const& ineq); - int64_t dtt(bool sign, ineq const& ineq) const { return dtt(sign, ineq.m_args_value, ineq); } - int64_t dtt(bool sign, int64_t args_value, ineq const& ineq) const; - int64_t dtt(bool sign, ineq const& ineq, var_t v, int64_t new_value) const; - int64_t dtt(bool sign, ineq const& ineq, int64_t coeff, int64_t old_value, int64_t new_value) const; - int64_t dts(unsigned cl, var_t v, int64_t new_value) const; - int64_t compute_dts(unsigned cl) const; - bool cm(bool sign, ineq const& ineq, var_t v, int64_t& new_value); - bool cm(bool sign, ineq const& ineq, var_t v, int64_t coeff, int64_t& new_value); - int cm_score(var_t v, int64_t new_value); - void update(var_t v, int64_t new_value); - double dscore_reward(sat::bool_var v); - double dtt_reward(sat::bool_var v); - double dscore(var_t v, int64_t new_value) const; - void save_best_values(); - void store_best_values(); - void add_vars(); - sls::ineq& new_ineq(ineq_kind op, int64_t const& bound); - void add_arg(sat::bool_var bv, ineq& ineq, int64_t const& c, var_t v); - void add_args(sat::bool_var bv, ineq& ineq, lp::lpvar j, euf::theory_var v, int64_t sign); - void init_bool_var(sat::bool_var v); - void init_bool_var_assignment(sat::bool_var v); - - int64_t value(var_t v) const { return m_vars[v].m_value; } - int64_t to_numeral(rational const& r); - - void check_ineqs(); - - std::ostream& display(std::ostream& out) const; - - public: - sls(solver& s); - ~sls() override {} - void set(sat::ddfw* d); - void init_search() override; - void finish_search() override; - void flip(sat::bool_var v) override; - double reward(sat::bool_var v) override; - void on_rescale() override; - void on_save_model() override; - void on_restart() override; - }; - - inline std::ostream& operator<<(std::ostream& out, sls::ineq const& ineq) { - return ineq.display(out); - } -} diff --git a/src/sat/smt/arith_solver.cpp b/src/sat/smt/arith_solver.cpp index 3086d75f4..3da8fa4a0 100644 --- a/src/sat/smt/arith_solver.cpp +++ b/src/sat/smt/arith_solver.cpp @@ -24,7 +24,6 @@ namespace arith { solver::solver(euf::solver& ctx, theory_id id) : th_euf_solver(ctx, symbol("arith"), id), m_model_eqs(DEFAULT_HASHTABLE_INITIAL_CAPACITY, var_value_hash(*this), var_value_eq(*this)), - m_local_search(*this), m_resource_limit(*this), m_bp(*this, m_implied_bounds), a(m), diff --git a/src/sat/smt/arith_solver.h b/src/sat/smt/arith_solver.h index 755611474..48530bc83 100644 --- a/src/sat/smt/arith_solver.h +++ b/src/sat/smt/arith_solver.h @@ -28,7 +28,6 @@ Author: #include "math/polynomial/algebraic_numbers.h" #include "math/polynomial/polynomial.h" #include "sat/smt/sat_th.h" -#include "sat/smt/arith_sls.h" #include "sat/sat_ddfw.h" namespace euf { @@ -186,8 +185,6 @@ namespace arith { coeffs().pop_back(); } }; - - sls m_local_search; typedef vector> var_coeffs; vector m_columns; @@ -518,8 +515,6 @@ namespace arith { bool enable_ackerman_axioms(euf::enode* n) const override { return !a.is_add(n->get_expr()); } bool has_unhandled() const override { return m_not_handled != nullptr; } - void set_bool_search(sat::ddfw* ddfw) override { m_local_search.set(ddfw); } - // bounds and equality propagation callbacks lp::lar_solver& lp() { return *m_solver; } lp::lar_solver const& lp() const { return *m_solver; } diff --git a/src/sat/smt/euf_internalize.cpp b/src/sat/smt/euf_internalize.cpp index ebb6e4b85..602364e7d 100644 --- a/src/sat/smt/euf_internalize.cpp +++ b/src/sat/smt/euf_internalize.cpp @@ -525,8 +525,8 @@ namespace euf { return n; } - void solver::add_assertion(expr* f) { - m_assertions.push_back(f); - m_trail.push(push_back_vector(m_assertions)); + void solver::add_clause(unsigned n, sat::literal const* lits) { + m_top_level_clauses.push_back(sat::literal_vector(n, lits)); + m_trail.push(push_back_vector(m_top_level_clauses)); } } diff --git a/src/sat/smt/euf_local_search.cpp b/src/sat/smt/euf_local_search.cpp deleted file mode 100644 index ca450e513..000000000 --- a/src/sat/smt/euf_local_search.cpp +++ /dev/null @@ -1,50 +0,0 @@ -/*++ -Copyright (c) 2020 Microsoft Corporation - -Module Name: - - euf_local_search.cpp - -Abstract: - - Local search dispatch for SMT - -Author: - - Nikolaj Bjorner (nbjorner) 2023-02-07 - ---*/ -#include "sat/sat_solver.h" -#include "sat/sat_ddfw.h" -#include "sat/smt/euf_solver.h" - - -namespace euf { - - lbool solver::local_search(bool_vector& phase) { - scoped_limits scoped_rl(m.limit()); - sat::ddfw bool_search; - bool_search.reinit(s(), phase); - bool_search.updt_params(s().params()); - bool_search.set_seed(rand()); - scoped_rl.push_child(&(bool_search.rlimit())); - - for (auto* th : m_solvers) - th->set_bool_search(&bool_search); - - bool_search.check(0, nullptr, nullptr); - - auto const& mdl = bool_search.get_model(); - for (unsigned i = 0; i < mdl.size(); ++i) - phase[i] = mdl[i] == l_true; - - if (bool_search.unsat_set().empty()) { - enable_trace("arith"); - enable_trace("sat"); - enable_trace("euf"); - TRACE("sat", s().display(tout)); - } - - return bool_search.unsat_set().empty() ? l_true : l_undef; - } -} diff --git a/src/sat/smt/euf_solver.cpp b/src/sat/smt/euf_solver.cpp index b866990af..b91051b39 100644 --- a/src/sat/smt/euf_solver.cpp +++ b/src/sat/smt/euf_solver.cpp @@ -55,7 +55,6 @@ namespace euf { m_smt_proof_checker(m, p), m_clause(m), m_expr_args(m), - m_assertions(m), m_values(m) { updt_params(p); diff --git a/src/sat/smt/euf_solver.h b/src/sat/smt/euf_solver.h index 22e68d2da..16d4e22ec 100644 --- a/src/sat/smt/euf_solver.h +++ b/src/sat/smt/euf_solver.h @@ -100,15 +100,6 @@ namespace euf { scope(unsigned l) : m_var_lim(l) {} }; - struct local_search_config { - double cb = 0.0; - unsigned L = 20; - unsigned t = 45; - unsigned max_no_improve = 500000; - double sp = 0.0003; - }; - - size_t* to_ptr(sat::literal l) { return TAG(size_t*, reinterpret_cast((size_t)(l.index() << 4)), 1); } size_t* to_ptr(size_t jst) { return TAG(size_t*, reinterpret_cast(jst), 2); } bool is_literal(size_t* p) const { return GET_TAG(p) == 1; } @@ -127,7 +118,6 @@ namespace euf { sat::sat_internalizer& si; relevancy m_relevancy; smt_params m_config; - local_search_config m_ls_config; euf::egraph m_egraph; trail_stack m_trail; stats m_stats; @@ -174,7 +164,7 @@ namespace euf { symbol m_smt = symbol("smt"); expr_ref_vector m_clause; expr_ref_vector m_expr_args; - expr_ref_vector m_assertions; + vector m_top_level_clauses; // internalization @@ -356,7 +346,6 @@ namespace euf { void add_assumptions(sat::literal_set& assumptions) override; bool tracking_assumptions() override; std::string reason_unknown() override { return m_reason_unknown; } - lbool local_search(bool_vector& phase) override; void propagate(literal lit, ext_justification_idx idx); bool propagate(enode* a, enode* b, ext_justification_idx idx); @@ -485,8 +474,10 @@ namespace euf { bool enable_ackerman_axioms(expr* n) const; bool is_fixed(euf::enode* n, expr_ref& val, sat::literal_vector& explain); - void add_assertion(expr* f); - expr_ref_vector const& get_assertions() { return m_assertions; } + // void add_assertion(expr* f); + // expr_ref_vector const& get_assertions() { return m_assertions; } + void add_clause(unsigned n, sat::literal const* lits); + vector const& top_level_clauses() const { return m_top_level_clauses; } model_ref get_sls_model(); // relevancy diff --git a/src/sat/smt/sls_solver.cpp b/src/sat/smt/sls_solver.cpp index a507619ee..5028bf239 100644 --- a/src/sat/smt/sls_solver.cpp +++ b/src/sat/smt/sls_solver.cpp @@ -17,8 +17,7 @@ Author: #include "sat/smt/sls_solver.h" #include "sat/smt/euf_solver.h" - - +#include "ast/sls/sls_smt.h" namespace sls { @@ -38,14 +37,14 @@ namespace sls { } void solver::finalize() { - if (!m_completed && m_sls) { - m_sls->cancel(); - m_thread.join(); - m_sls->collect_statistics(m_st); - m_sls = nullptr; - m_shared = nullptr; + if (!m_completed && m_ddfw) { + m_ddfw->rlimit().cancel(); + m_thread.join(); + m_ddfw->collect_statistics(m_st); + m_ddfw = nullptr; m_slsm = nullptr; - m_units = nullptr; + m_smt_plugin = nullptr; + m_units.reset(); } } @@ -59,107 +58,162 @@ namespace sls { return false; } - bool solver::is_unit(expr* e) { - if (!e) - return false; - m.is_not(e, e); - if (is_uninterp_const(e)) - return true; - bv_util bu(m); - expr* s; - if (bu.is_bit2bool(e, s)) - return is_uninterp_const(s); - return false; - } - void solver::pop_core(unsigned n) { for (; m_trail_lim < s().init_trail_size(); ++m_trail_lim) { auto lit = s().trail_literal(m_trail_lim); - auto e = ctx.literal2expr(lit); - if (is_unit(e)) { - // IF_VERBOSE(1, verbose_stream() << "add unit " << mk_pp(e, m) << "\n"); - std::lock_guard lock(m_mutex); - ast_translation tr(m, *m_shared); - m_units->push_back(tr(e.get())); - m_has_units = true; - } + std::lock_guard lock(m_mutex); + m_units.push_back(lit); + m_has_units = true; } } + class solver::smt_plugin : public sat::local_search_plugin, public sls::sat_solver_context { + ast_manager& m; + sat::ddfw* m_ddfw; + solver& s; + sls::context m_context; + bool m_new_clause_added = false; + public: + smt_plugin(ast_manager& m, solver& s, sat::ddfw* d) : + m(m), s(s), m_ddfw(d), m_context(m, *this) {} + + void init_search() override {} + + void finish_search() override {} + + void on_rescale() override {} + + void on_restart() override { + if (!s.m_has_units) + return; + { + std::lock_guard lock(s.m_mutex); + for (auto lit : s.m_units) + if (m_ddfw->is_initial_var(lit.var())) + m_ddfw->add(1, &lit); + s.m_has_units = false; + s.m_units.reset(); + } + m_ddfw->reinit(); + } + + void on_save_model() override { + TRACE("sls", display(tout)); + while (unsat().empty()) { + m_context.check(); + if (!m_new_clause_added) + break; + m_ddfw->reinit(); + m_new_clause_added = false; + } + } + + void on_model(model_ref& mdl) override { + IF_VERBOSE(1, verbose_stream() << "on-model " << "\n"); + s.m_sls_model = mdl; + } + + void register_atom(sat::bool_var v, expr* e) { + m_context.register_atom(v, e); + } + + std::ostream& display(std::ostream& out) { + m_ddfw->display(out); + m_context.display(out); + return out; + } + + vector const& clauses() const override { return m_ddfw->clauses(); } + sat::clause_info const& get_clause(unsigned idx) const override { return m_ddfw->get_clause_info(idx); } + std::initializer_list get_use_list(sat::literal lit) override { return m_ddfw->use_list(lit); } + void flip(sat::bool_var v) override { m_ddfw->flip(v); } + double reward(sat::bool_var v) override { return m_ddfw->get_reward(v); } + double get_weigth(unsigned clause_idx) override { return m_ddfw->get_clause_info(clause_idx).m_weight; } + bool is_true(sat::literal lit) override { return m_ddfw->get_value(lit.var()) != lit.sign(); } + unsigned num_vars() const override { return m_ddfw->num_vars(); } + indexed_uint_set const& unsat() const override { return m_ddfw->unsat_set(); } + sat::bool_var add_var() override { return m_ddfw->add_var(); } + void add_clause(unsigned n, sat::literal const* lits) override { + m_ddfw->add(n, lits); + m_new_clause_added = true; + } + }; + void solver::init_search() { - if (m_sls) { - m_sls->cancel(); + if (m_ddfw) { + m_ddfw->rlimit().cancel(); m_thread.join(); - m_result = l_undef; - m_completed = false; - m_has_units = false; - m_model = nullptr; - m_units = nullptr; } // set up state for local search solver here - - m_shared = alloc(ast_manager); - m_slsm = alloc(ast_manager); - m_units = alloc(expr_ref_vector, *m_shared); - ast_translation tr(m, *m_slsm); - - m_completed = false; m_result = l_undef; + m_completed = false; + m_slsm = alloc(ast_manager); + m_units.reset(); + m_has_units = false; m_model = nullptr; - m_sls = alloc(bv::sls, *m_slsm, s().params()); - - for (expr* a : ctx.get_assertions()) - m_sls->assert_expr(tr(a)); + m_sls_model = nullptr; + m_ddfw = alloc(sat::ddfw); + ast_translation tr(m, *m_slsm); + scoped_limits scoped_limits(m.limit()); + scoped_limits.push_child(&m_slsm->limit()); + scoped_limits.push_child(&m_ddfw->rlimit()); + m_smt_plugin = alloc(smt_plugin, *m_slsm, *this, m_ddfw.get()); + m_ddfw->set_plugin(m_smt_plugin); + m_ddfw->updt_params(s().params()); + for (auto const& clause : ctx.top_level_clauses()) + m_ddfw->add(clause.size(), clause.data()); + for (sat::bool_var v = 0; v < s().num_vars(); ++v) { + expr* e = ctx.bool_var2expr(v); + if (e) + m_smt_plugin->register_atom(v, tr(e)); + } - std::function eval = [&](expr* e, unsigned r) { - return false; - }; - - m_sls->init(); - m_sls->init_eval(eval); - m_sls->updt_params(s().params()); - m_sls->init_unit([&]() { - if (!m_has_units) - return expr_ref(*m_slsm); - expr_ref e(*m_slsm); - { - std::lock_guard lock(m_mutex); - if (m_units->empty()) - return expr_ref(*m_slsm); - ast_translation tr(*m_shared, *m_slsm); - e = tr(m_units->back()); - m_units->pop_back(); - } - return e; - }); - m_sls->set_model([&](model& mdl) { - std::lock_guard lock(m_mutex); - ast_translation tr(*m_shared, m); - m_model = mdl.translate(tr); - }); - - m_thread = std::thread([this]() { run_local_search(); }); + run_local_search_sync(); + // m_thread = std::thread([this]() { run_local_search_async(); }); } void solver::sample_local_search() { if (!m_completed) return; m_thread.join(); - m_completed = false; - m_sls->collect_statistics(m_st); - if (m_result == l_true) { - IF_VERBOSE(2, verbose_stream() << "(sat.sls :model-completed)\n";); - auto mdl = m_sls->get_model(); - ast_translation tr(*m_slsm, m); - m_model = mdl->translate(tr); - s().set_canceled(); - } - m_sls = nullptr; + local_search_done(); } - void solver::run_local_search() { - m_result = (*m_sls)(); - m_completed = true; + void solver::local_search_done() { + m_completed = false; + + CTRACE("sls", m_smt_plugin, m_smt_plugin->display(tout)); + if (m_ddfw) + m_ddfw->collect_statistics(m_st); + + TRACE("sls", tout << "result " << m_result << "\n"); + + if (m_result == l_true && m_sls_model) { + ast_translation tr(*m_slsm, m); + m_model = m_sls_model->translate(tr); + TRACE("sls", tout << "model: " << *m_sls_model << "\n";); + s().set_canceled(); + } + m_ddfw = nullptr; + m_smt_plugin = nullptr; + m_sls_model = nullptr; + } + + void solver::run_local_search_async() { + if (m_ddfw) { + m_result = m_ddfw->check(0, nullptr, nullptr); + m_completed = true; + } + } + + void solver::run_local_search_sync() { + m_result = m_ddfw->check(0, nullptr, nullptr); + local_search_done(); + } + + std::ostream& solver::display(std::ostream& out) const { + out << "sls-solver\n"; + return out; } #endif diff --git a/src/sat/smt/sls_solver.h b/src/sat/smt/sls_solver.h index e1d8a95b5..9d009b805 100644 --- a/src/sat/smt/sls_solver.h +++ b/src/sat/smt/sls_solver.h @@ -20,6 +20,7 @@ Author: #include "util/rlimit.h" #include "ast/sls/bv_sls.h" #include "sat/smt/sat_th.h" +#include "sat/sat_ddfw.h" #ifdef SINGLE_THREAD @@ -62,23 +63,28 @@ namespace euf { namespace sls { class solver : public euf::th_euf_solver { + class smt_plugin; + std::atomic m_result; std::atomic m_completed, m_has_units; std::thread m_thread; std::mutex m_mutex; // m is accessed by the main thread // m_slsm is accessed by the sls thread - // m_shared is only accessed at synchronization points - scoped_ptr m_shared, m_slsm; - scoped_ptr m_sls; - scoped_ptr m_units; - model_ref m_model; + scoped_ptr m_slsm; + scoped_ptr m_ddfw; + sat::literal_vector m_units; + smt_plugin* m_smt_plugin = nullptr; + model_ref m_model, m_sls_model; unsigned m_trail_lim = 0; statistics m_st; - void run_local_search(); + + + void run_local_search_async(); + void run_local_search_sync(); void sample_local_search(); - bool is_unit(expr*); + void local_search_done(); public: solver(euf::solver& ctx); @@ -98,7 +104,7 @@ namespace sls { void internalize(expr* e) override { UNREACHABLE(); } void get_antecedents(sat::literal l, sat::ext_justification_idx idx, sat::literal_vector & r, bool probing) override { UNREACHABLE(); } sat::check_result check() override; - std::ostream & display(std::ostream & out) const override { return out; } + std::ostream& display(std::ostream& out) const override; std::ostream & display_justification(std::ostream & out, sat::ext_justification_idx idx) const override { UNREACHABLE(); return out; } std::ostream & display_constraint(std::ostream & out, sat::ext_constraint_idx idx) const override { UNREACHABLE(); return out; } diff --git a/src/sat/tactic/goal2sat.cpp b/src/sat/tactic/goal2sat.cpp index 57e3a89b5..639eeb814 100644 --- a/src/sat/tactic/goal2sat.cpp +++ b/src/sat/tactic/goal2sat.cpp @@ -139,10 +139,6 @@ struct goal2sat::imp : public sat::sat_internalizer { return m_euf && ensure_euf()->relevancy_enabled(); } - bool top_level_relevant() { - return m_top_level && relevancy_enabled(); - } - void mk_clause(sat::literal l1, sat::literal l2, euf::th_proof_hint* ph) { sat::literal lits[2] = { l1, l2 }; mk_clause(2, lits, ph); @@ -158,6 +154,7 @@ struct goal2sat::imp : public sat::sat_internalizer { if (relevancy_enabled()) ensure_euf()->add_aux(n, lits); m_solver.add_clause(n, lits, mk_status(ph)); + add_top_level_clause(n, lits); } void mk_root_clause(sat::literal l) { @@ -179,6 +176,7 @@ struct goal2sat::imp : public sat::sat_internalizer { if (relevancy_enabled()) ensure_euf()->add_root(n, lits); m_solver.add_clause(n, lits, ph ? mk_status(ph) : sat::status::input()); + add_top_level_clause(n, lits); } sat::bool_var add_var(bool is_ext, expr* n) { @@ -895,7 +893,6 @@ struct goal2sat::imp : public sat::sat_internalizer { process(n, true); CTRACE("goal2sat", !m_result_stack.empty(), tout << m_result_stack << "\n";); SASSERT(m_result_stack.empty()); - add_assertion(n); } void insert_dep(expr* dep0, expr* dep, bool sign) { @@ -990,10 +987,12 @@ struct goal2sat::imp : public sat::sat_internalizer { } } - void add_assertion(expr* f) { + void add_top_level_clause(unsigned n, sat::literal const* lits) { + if (!m_top_level) + return; auto* ext = dynamic_cast(m_solver.get_extension()); if (ext) - ext->add_assertion(f); + ext->add_clause(n, lits); } void update_model(model_ref& mdl) { diff --git a/src/util/checked_int64.h b/src/util/checked_int64.h index 06b957fcf..64c4e38cc 100644 --- a/src/util/checked_int64.h +++ b/src/util/checked_int64.h @@ -163,6 +163,11 @@ public: return *this; } + checked_int64& operator/=(checked_int64 const& other) { + m_value /= other.m_value; + return *this; + } + friend inline checked_int64 abs(checked_int64 const& i) { return i.abs(); } @@ -174,21 +179,42 @@ inline bool operator!=(checked_int64 const & i1, checked_int64 con return !operator==(i1, i2); } +template +inline bool operator!=(checked_int64 const& i1, int64_t const& i2) { + return !operator==(i1, i2); +} + template inline bool operator>(checked_int64 const & i1, checked_int64 const & i2) { return operator<(i2, i1); } +template +inline bool operator>(checked_int64 const& i1, int64_t i2) { + return operator<(i2, i1); +} + template inline bool operator<=(checked_int64 const & i1, checked_int64 const & i2) { return !operator>(i1, i2); } +template +inline bool operator<=(checked_int64 const& i1, int64_t const& i2) { + return !operator>(i1, i2); +} + template inline bool operator>=(checked_int64 const & i1, checked_int64 const & i2) { return !operator<(i1, i2); } + +template +inline bool operator>=(checked_int64 const& i1, int64_t const& i2) { + return !operator<(i1, i2); +} + template inline checked_int64 operator-(checked_int64 const& i) { checked_int64 result(i); @@ -202,6 +228,14 @@ inline checked_int64 operator+(checked_int64 const& a, checked_int return result; } +template +inline checked_int64 operator+(checked_int64 const& a, int64_t const& b) { + checked_int64 result(a); + checked_int64 _b(b); + result += _b; + return result; +} + template inline checked_int64 operator-(checked_int64 const& a, checked_int64 const& b) { checked_int64 result(a); @@ -209,9 +243,39 @@ inline checked_int64 operator-(checked_int64 const& a, checked_int return result; } +template +inline checked_int64 operator-(checked_int64 const& a, int64_t const& b) { + checked_int64 result(a); + checked_int64 _b(b); + result -= _b; + return result; +} + template inline checked_int64 operator*(checked_int64 const& a, checked_int64 const& b) { checked_int64 result(a); result *= b; return result; } + +template +inline checked_int64 operator*(int64_t const& a, checked_int64 const& b) { + checked_int64 result(a); + result *= b; + return result; +} + +template +inline checked_int64 operator*(checked_int64 const& a, int64_t const& b) { + checked_int64 result(a); + checked_int64 _b(b); + result *= _b; + return result; +} + +template +inline checked_int64 div(checked_int64 const& a, checked_int64 const& b) { + checked_int64 result(a); + result /= b; + return result; +} diff --git a/src/util/sat_sls.h b/src/util/sat_sls.h new file mode 100644 index 000000000..bcd3a1c74 --- /dev/null +++ b/src/util/sat_sls.h @@ -0,0 +1,37 @@ +/*++ +Copyright (c) 2011 Microsoft Corporation + +Module Name: + + sat_sls.h + +Abstract: + + Base types for SLS. + +Author: + + Nikolaj Bjorner (nbjorner) 2024-06027 + +--*/ +#pragma once + +#include "util/sat_literal.h" + +namespace sat { + + struct clause_info { + clause_info(unsigned n, literal const* lits, double init_weight): m_weight(init_weight), m_clause(n, lits) {} + double m_weight; // weight of clause + unsigned m_trues = 0; // set of literals that are true + unsigned m_num_trues = 0; // size of true set + literal_vector m_clause; + literal const* begin() const { return m_clause.begin(); } + literal const* end() const { return m_clause.end(); } + bool is_true() const { return m_num_trues > 0; } + void add(literal lit) { ++m_num_trues; m_trues += lit.index(); } + void del(literal lit) { SASSERT(m_num_trues > 0); --m_num_trues; m_trues -= lit.index(); } + }; +}; + +