3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-06 01:24:08 +00:00

updates to sls

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
This commit is contained in:
Nikolaj Bjorner 2024-03-30 12:59:02 -07:00
parent 111fcb9366
commit 51f1e2655c
13 changed files with 234 additions and 105 deletions

View file

@ -329,6 +329,7 @@ namespace euf {
m_config.m_context_solve = p.get_bool("context_solve", tp.solve_eqs_context_solve());
for (auto* ex : m_extract_plugins)
ex->updt_params(p);
m_rewriter.updt_params(p);
}
void solve_eqs::collect_param_descrs(param_descrs& r) {

View file

@ -25,12 +25,15 @@ Author:
namespace bv {
sls::sls(ast_manager& m):
sls::sls(ast_manager& m, params_ref const& p):
m(m),
bv(m),
m_terms(m),
m_eval(m)
{}
m_eval(m),
m_engine(m, p)
{
updt_params(p);
}
void sls::init() {
m_terms.init();
@ -67,20 +70,48 @@ namespace bv {
}
}
void sls::init_repair_candidates() {
m_to_repair.reset();
ptr_vector<expr> todo;
expr_fast_mark1 mark;
for (auto index : m_repair_roots)
todo.push_back(m_terms.term(index));
for (unsigned i = 0; i < todo.size(); ++i) {
expr* e = todo[i];
if (mark.is_marked(e))
continue;
mark.mark(e);
if (!is_app(e))
continue;
for (expr* arg : *to_app(e))
todo.push_back(arg);
if (is_uninterp_const(e))
m_to_repair.insert(e->get_id());
}
}
void sls::reinit_eval() {
init_repair_candidates();
if (m_to_repair.empty())
return;
std::function<bool(expr*, unsigned)> eval = [&](expr* e, unsigned i) {
auto should_keep = [&]() {
return m_rand() % 100 <= 92;
};
if (m.is_bool(e)) {
if (m_eval.is_fixed0(e) || should_keep())
return m_eval.bval0(e);
}
unsigned id = e->get_id();
bool keep = (m_rand() % 100 <= 50) || !m_to_repair.contains(id);
if (m.is_bool(e) && (m_eval.is_fixed0(e) || keep))
return m_eval.bval0(e);
else if (bv.is_bv(e)) {
auto& w = m_eval.wval(e);
if (w.fixed.get(i) || should_keep())
return w.get_bit(i);
}
if (w.fixed.get(i) || keep)
return w.get_bit(i);
//auto const& z = m_engine.get_value(e);
//return rational(z).get_bit(i);
}
return m_rand() % 2 == 0;
};
m_eval.init_eval(m_terms.assertions(), eval);
@ -119,7 +150,7 @@ namespace bv {
return { false, nullptr };
}
lbool sls::search() {
lbool sls::search1() {
// init and init_eval were invoked
unsigned n = 0;
for (; n++ < m_config.m_max_repairs && m.inc(); ) {
@ -127,7 +158,6 @@ namespace bv {
if (!e)
return l_true;
trace_repair(down, e);
++m_stats.m_moves;
@ -140,16 +170,32 @@ namespace bv {
return l_undef;
}
lbool sls::search2() {
lbool res = l_undef;
if (m_stats.m_restarts == 0)
res = m_engine();
else if (m_stats.m_restarts % 1000 == 0)
res = m_engine.search_loop();
if (res != l_undef)
m_engine_model = true;
return res;
}
lbool sls::operator()() {
lbool res = l_undef;
m_stats.reset();
m_stats.m_restarts = 0;
m_engine_model = false;
do {
res = search();
res = search1();
if (res != l_undef)
break;
trace();
res = search2();
if (res != l_undef)
break;
reinit_eval();
}
while (m.inc() && m_stats.m_restarts++ < m_config.m_max_restarts);
@ -158,34 +204,60 @@ namespace bv {
}
void sls::try_repair_down(app* e) {
unsigned n = e->get_num_args();
if (n == 0) {
if (m.is_bool(e))
m_eval.set(e, m_eval.bval1(e));
else
if (m.is_bool(e)) {
m_eval.set(e, m_eval.bval1(e));
}
else {
VERIFY(m_eval.wval(e).commit_eval());
}
for (auto p : m_terms.parents(e))
m_repair_up.insert(p->get_id());
return;
}
unsigned s = m_rand(n);
for (unsigned i = 0; i < n; ++i) {
auto j = (i + s) % n;
if (m_eval.try_repair(e, j)) {
set_repair_down(e->get_arg(j));
if (n == 2) {
auto d1 = get_depth(e->get_arg(0));
auto d2 = get_depth(e->get_arg(1));
unsigned s = m_rand(d1 + d2 + 2);
if (s <= d1 && m_eval.try_repair(e, 0)) {
set_repair_down(e->get_arg(0));
return;
}
if (m_eval.try_repair(e, 1)) {
set_repair_down(e->get_arg(1));
return;
}
if (m_eval.try_repair(e, 0)) {
set_repair_down(e->get_arg(0));
return;
}
}
// search a new root / random walk to repair
else {
unsigned s = m_rand(n);
for (unsigned i = 0; i < n; ++i) {
auto j = (i + s) % n;
if (m_eval.try_repair(e, j)) {
set_repair_down(e->get_arg(j));
return;
}
}
}
// repair was not successful, so reset the state to find a different way to repair
init_repair();
}
void sls::try_repair_up(app* e) {
if (m_terms.is_assertion(e) || !m_eval.repair_up(e))
m_repair_roots.insert(e->get_id());
if (m_terms.is_assertion(e))
m_repair_roots.insert(e->get_id());
else if (!m_eval.repair_up(e)) {
//m_repair_roots.insert(e->get_id());
init_repair();
}
else {
if (!eval_is_correct(e)) {
verbose_stream() << "incorrect eval #" << e->get_id() << " " << mk_bounded_pp(e, m) << "\n";
@ -224,7 +296,10 @@ namespace bv {
}
model_ref sls::get_model() {
model_ref mdl = alloc(model, m);
if (m_engine_model)
return m_engine.get_model();
model_ref mdl = alloc(model, m);
auto& terms = m_eval.sort_assertions(m_terms.assertions());
for (expr* e : terms) {
if (!re_eval_is_correct(to_app(e))) {
@ -273,7 +348,12 @@ namespace bv {
void sls::updt_params(params_ref const& _p) {
sls_params p(_p);
m_config.m_max_restarts = p.max_restarts();
m_config.m_max_repairs = p.max_repairs();
m_rand.set_seed(p.random_seed());
m_terms.updt_params(_p);
params_ref q = _p;
q.set_uint("max_restarts", 10);
m_engine.updt_params(q);
}
void sls::trace_repair(bool down, expr* e) {

View file

@ -26,6 +26,7 @@ Author:
#include "ast/sls/sls_valuation.h"
#include "ast/sls/bv_sls_terms.h"
#include "ast/sls/bv_sls_eval.h"
#include "ast/sls/sls_engine.h"
#include "ast/bv_decl_plugin.h"
#include "model/model.h"
@ -49,6 +50,8 @@ namespace bv {
ptr_vector<expr> m_todo;
random_gen m_rand;
config m_config;
sls_engine m_engine;
bool m_engine_model = false;
std::pair<bool, app*> next_to_repair();
@ -59,19 +62,23 @@ namespace bv {
void try_repair_up(app* e);
void set_repair_down(expr* e) { m_repair_down = e->get_id(); }
lbool search();
lbool search1();
lbool search2();
void reinit_eval();
void init_repair();
void trace();
void trace_repair(bool down, expr* e);
indexed_uint_set m_to_repair;
void init_repair_candidates();
public:
sls(ast_manager& m);
sls(ast_manager& m, params_ref const& p);
/**
* Add constraints
*/
void assert_expr(expr* e) { m_terms.assert_expr(e); }
void assert_expr(expr* e) { m_terms.assert_expr(e); m_engine.assert_expr(e); }
/*
* Invoke init after all expressions are asserted.
@ -91,10 +98,10 @@ namespace bv {
lbool operator()();
void updt_params(params_ref const& p);
void collect_statistics(statistics & st) const { m_stats.collect_statistics(st); }
void reset_statistics() { m_stats.reset(); }
void collect_statistics(statistics& st) const { m_stats.collect_statistics(st); m_engine.collect_statistics(st); }
void reset_statistics() { m_stats.reset(); m_engine.reset_statistics(); }
sls_stats const& get_stats() const { return m_stats; }
unsigned get_num_moves() { return m_stats.m_moves + m_engine.get_stats().m_moves; }
std::ostream& display(std::ostream& out);

View file

@ -24,8 +24,8 @@ namespace bv {
{}
void sls_eval::init_eval(expr_ref_vector const& es, std::function<bool(expr*, unsigned)> const& eval) {
sort_assertions(es);
for (expr* e : m_todo) {
auto& terms = sort_assertions(es);
for (expr* e : terms) {
if (!is_app(e))
continue;
app* a = to_app(e);
@ -49,7 +49,7 @@ namespace bv {
TRACE("sls", tout << "Unhandled expression " << mk_pp(e, m) << "\n");
}
}
m_todo.reset();
terms.reset();
}
/**
@ -1698,7 +1698,7 @@ namespace bv {
}
if (bv.is_bv(e)) {
auto& v = eval(to_app(e));
// verbose_stream() << "committing: " << v << "\n";
for (unsigned i = 0; i < v.nw; ++i)
if (0 != (v.fixed[i] & (v.bits()[i] ^ v.eval[i]))) {
v.bits().copy_to(v.nw, v.eval);

View file

@ -108,6 +108,13 @@ namespace bv {
else if (bv.is_numeral(t, a))
init_range(s, -a, nullptr, rational(0), false);
}
else if (sign && m.is_eq(e, s, t)) {
if (bv.is_numeral(s, a))
// 1 <= t - a
init_range(nullptr, rational(1), t, -a, false);
else if (bv.is_numeral(t, a))
init_range(nullptr, rational(1), s, -a, false);
}
else if (bv.is_bit2bool(e, s, idx)) {
auto& val = wval(s);
val.try_set_bit(idx, !sign);
@ -157,7 +164,6 @@ namespace bv {
else
v.add_range(-b, -a);
}
}
void sls_fixed::get_offset(expr* e, expr*& x, rational& offset) {

View file

@ -20,12 +20,14 @@ Author:
#include "ast/ast_ll_pp.h"
#include "ast/sls/bv_sls.h"
#include "ast/rewriter/th_rewriter.h"
namespace bv {
sls_terms::sls_terms(ast_manager& m):
m(m),
bv(m),
m_rewriter(m),
m_assertions(m),
m_pinned(m),
m_translated(m),
@ -40,18 +42,20 @@ namespace bv {
expr* top = e;
m_pinned.push_back(e);
m_todo.push_back(e);
expr_fast_mark1 mark;
for (unsigned i = 0; i < m_todo.size(); ++i) {
expr* e = m_todo[i];
if (!is_app(e))
continue;
if (m_translated.get(e->get_id(), nullptr))
continue;
if (mark.is_marked(e))
continue;
mark.mark(e);
for (auto arg : *to_app(e))
m_todo.push_back(arg);
{
expr_fast_mark1 mark;
for (unsigned i = 0; i < m_todo.size(); ++i) {
expr* e = m_todo[i];
if (!is_app(e))
continue;
if (m_translated.get(e->get_id(), nullptr))
continue;
if (mark.is_marked(e))
continue;
mark.mark(e);
for (auto arg : *to_app(e))
m_todo.push_back(arg);
}
}
std::stable_sort(m_todo.begin(), m_todo.end(), [&](expr* a, expr* b) { return get_depth(a) < get_depth(b); });
for (expr* e : m_todo)
@ -127,7 +131,7 @@ namespace bv {
m_translated.setx(e->get_id(), r);
}
expr* sls_terms::mk_sdiv(expr* x, expr* y) {
expr_ref sls_terms::mk_sdiv(expr* x, expr* y) {
// d = udiv(abs(x), abs(y))
// y = 0, x >= 0 -> -1
// y = 0, x < 0 -> 1
@ -141,17 +145,18 @@ namespace bv {
expr_ref z(bv.mk_zero(sz), m);
expr* signx = bv.mk_ule(bv.mk_numeral(N / 2, sz), x);
expr* signy = bv.mk_ule(bv.mk_numeral(N / 2, sz), y);
expr* absx = m.mk_ite(signx, bv.mk_bv_sub(bv.mk_numeral(N - 1, sz), x), x);
expr* absy = m.mk_ite(signy, bv.mk_bv_sub(bv.mk_numeral(N - 1, sz), y), y);
expr* absx = m.mk_ite(signx, bv.mk_bv_neg(x), x);
expr* absy = m.mk_ite(signy, bv.mk_bv_neg(y), y);
expr* d = bv.mk_bv_udiv(absx, absy);
expr* r = m.mk_ite(m.mk_eq(signx, signy), d, bv.mk_bv_neg(d));
expr_ref r(m.mk_ite(m.mk_eq(signx, signy), d, bv.mk_bv_neg(d)), m);
r = m.mk_ite(m.mk_eq(z, y),
m.mk_ite(signx, bv.mk_one(sz), bv.mk_numeral(N - 1, sz)),
m.mk_ite(m.mk_eq(x, z), z, r));
m.mk_ite(signx, bv.mk_one(sz), bv.mk_numeral(N - 1, sz)),
m.mk_ite(m.mk_eq(x, z), z, r));
m_rewriter(r);
return r;
}
expr* sls_terms::mk_smod(expr* x, expr* y) {
expr_ref sls_terms::mk_smod(expr* x, expr* y) {
// u := umod(abs(x), abs(y))
// u = 0 -> 0
// y = 0 -> x
@ -164,21 +169,24 @@ namespace bv {
expr_ref abs_x(m.mk_ite(bv.mk_sle(z, x), x, bv.mk_bv_neg(x)), m);
expr_ref abs_y(m.mk_ite(bv.mk_sle(z, y), y, bv.mk_bv_neg(y)), m);
expr_ref u(bv.mk_bv_urem(abs_x, abs_y), m);
return
m.mk_ite(m.mk_eq(u, z), z,
expr_ref r(m);
r = m.mk_ite(m.mk_eq(u, z), z,
m.mk_ite(m.mk_eq(y, z), x,
m.mk_ite(m.mk_and(bv.mk_sle(z, x), bv.mk_sle(z, x)), u,
m.mk_ite(bv.mk_sle(z, x), bv.mk_bv_add(y, u),
m.mk_ite(bv.mk_sle(z, y), bv.mk_bv_sub(y, u), bv.mk_bv_neg(u))))));
m_rewriter(r);
return r;
}
expr* sls_terms::mk_srem(expr* x, expr* y) {
expr_ref sls_terms::mk_srem(expr* x, expr* y) {
// y = 0 -> x
// else x - sdiv(x, y) * y
return
m.mk_ite(m.mk_eq(y, bv.mk_zero(bv.get_bv_size(x))),
expr_ref r(m);
r = m.mk_ite(m.mk_eq(y, bv.mk_zero(bv.get_bv_size(x))),
x, bv.mk_bv_sub(x, bv.mk_bv_mul(y, mk_sdiv(x, y))));
m_rewriter(r);
return r;
}
@ -209,4 +217,11 @@ namespace bv {
m_assertion_set.insert(a->get_id());
}
void sls_terms::updt_params(params_ref const& p) {
params_ref q = p;
q.set_bool("flat", false);
m_rewriter.updt_params(q);
}
}

View file

@ -21,6 +21,7 @@ Author:
#include "util/scoped_ptr_vector.h"
#include "util/uint_set.h"
#include "ast/ast.h"
#include "ast/rewriter/th_rewriter.h"
#include "ast/sls/sls_stats.h"
#include "ast/sls/sls_powers.h"
#include "ast/sls/sls_valuation.h"
@ -31,6 +32,7 @@ namespace bv {
class sls_terms {
ast_manager& m;
bv_util bv;
th_rewriter m_rewriter;
ptr_vector<expr> m_todo, m_args;
expr_ref_vector m_assertions, m_pinned, m_translated;
app_ref_vector m_terms;
@ -40,12 +42,14 @@ namespace bv {
expr* ensure_binary(expr* e);
void ensure_binary_core(expr* e);
expr* mk_sdiv(expr* x, expr* y);
expr* mk_smod(expr* x, expr* y);
expr* mk_srem(expr* x, expr* y);
expr_ref mk_sdiv(expr* x, expr* y);
expr_ref mk_smod(expr* x, expr* y);
expr_ref mk_srem(expr* x, expr* y);
public:
sls_terms(ast_manager& m);
void updt_params(params_ref const& p);
/**
* Add constraints

View file

@ -421,6 +421,7 @@ lbool sls_engine::search() {
// get candidate variables
ptr_vector<func_decl> & to_evaluate = m_tracker.get_unsat_constants(m_assertions);
if (to_evaluate.empty())
{
res = l_true;
@ -514,6 +515,12 @@ lbool sls_engine::operator()() {
if (m_restart_init)
m_tracker.randomize(m_assertions);
return search_loop();
}
lbool sls_engine::search_loop() {
lbool res = l_undef;
do {
@ -533,7 +540,6 @@ lbool sls_engine::operator()() {
} while (res != l_true && m_stats.m_restarts++ < m_max_restarts);
verbose_stream() << "(restarts: " << m_stats.m_restarts << " flips: " << m_stats.m_moves << " fps: " << (m_stats.m_moves / m_stats.m_stopwatch.get_current_seconds()) << ")" << std::endl;
return res;
}

View file

@ -79,7 +79,11 @@ public:
void mk_inv(unsigned bv_sz, const mpz & old_value, mpz & inverted);
void mk_flip(sort * s, const mpz & old_value, unsigned bit, mpz & flipped);
lbool search();
lbool search();
lbool search_loop();
lbool operator()();

View file

@ -106,10 +106,12 @@ namespace bv {
bool sls_valuation::commit_eval() {
for (unsigned i = 0; i < nw; ++i)
if (0 != (fixed[i] & (m_bits[i] ^ eval[i])))
return false;
if (!in_range(eval))
if (0 != (fixed[i] & (m_bits[i] ^ eval[i])))
return false;
if (!in_range(eval))
return false;
for (unsigned i = 0; i < nw; ++i)
m_bits[i] = eval[i];
SASSERT(well_formed());
@ -491,8 +493,8 @@ namespace bv {
SASSERT(well_formed());
}
void sls_valuation::add_range(rational l, rational h) {
void sls_valuation::add_range(rational l, rational h) {
l = mod(l, rational::power_of_two(bw));
h = mod(h, rational::power_of_two(bw));
if (h == l)
@ -509,21 +511,28 @@ namespace bv {
auto old_lo = lo();
auto old_hi = hi();
if (old_lo < old_hi) {
if (old_lo < l && l < old_hi)
if (old_lo < l && l < old_hi && old_hi <= h)
set_value(m_lo, l),
old_lo = l;
if (old_hi < h && h < old_hi)
if (l <= old_lo && old_lo < h && h < old_hi)
set_value(m_hi, h);
}
else {
SASSERT(old_hi < old_lo);
if (old_lo < l || l < old_hi)
set_value(m_lo, l),
old_lo = l;
if (old_lo < h && h < old_hi)
if (h <= old_hi && old_lo <= l) {
set_value(m_lo, l);
set_value(m_hi, h);
else if (old_hi < old_lo && (h < old_hi || old_lo < h))
}
else if (old_lo <= l && l <= h) {
set_value(m_lo, l);
set_value(m_hi, h);
}
else if (old_lo + 1 == l) {
set_value(m_lo, l);
}
else if (old_hi == h + 1) {
set_value(m_hi, h);
}
}
}
@ -552,8 +561,7 @@ namespace bv {
// lo < hi, set most significant bits based on hi
//
void sls_valuation::tighten_range() {
// verbose_stream() << "tighten " << *this << "\n";
if (m_lo == m_hi)
return;
@ -613,6 +621,9 @@ namespace bv {
break;
}
if (has_range() && !in_range(m_bits))
m_bits = m_lo;
SASSERT(well_formed());
}

View file

@ -3,6 +3,7 @@ def_module_params('sls',
description='Experimental Stochastic Local Search Solver (for QFBV only).',
params=(max_memory_param(),
('max_restarts', UINT, UINT_MAX, 'maximum number of restarts'),
('max_repairs', UINT, 1000, 'maximum number of repairs before restart'),
('walksat', BOOL, 1, 'use walksat assertion selection (instead of gsat)'),
('walksat_ucb', BOOL, 1, 'use bandit heuristic for walksat assertion selection (instead of random)'),
('walksat_ucb_constant', DOUBLE, 20.0, 'the ucb constant c in the term score + c * f(touched)'),

View file

@ -61,9 +61,10 @@ namespace sls {
m_m = alloc(ast_manager, m);
ast_translation tr(m, *m_m);
params_ref p;
m_completed = false;
m_result = l_undef;
m_bvsls = alloc(bv::sls, *m_m);
m_bvsls = alloc(bv::sls, *m_m, p);
// walk clauses, add them
// walk trail stack until search level, add units
// encapsulate bvsls within the arguments of run-local-search.

View file

@ -134,7 +134,7 @@ public:
bv_sls_tactic(ast_manager& _m, params_ref const& p) :
m(_m),
m_params(p) {
m_sls = alloc(bv::sls, m);
m_sls = alloc(bv::sls, m, p);
}
tactic* translate(ast_manager& m) override {
@ -172,12 +172,12 @@ public:
m_sls->init_eval(false_eval);
lbool res = m_sls->operator()();
auto const& stats = m_sls->get_stats();
report_tactic_progress("Number of flips:", stats.m_moves);
IF_VERBOSE(20, verbose_stream() << res << "\n");
IF_VERBOSE(20, m_sls->display(verbose_stream()));
m_st.reset();
m_sls->collect_statistics(m_st);
report_tactic_progress("Number of flips:", m_sls->get_num_moves());
IF_VERBOSE(20, verbose_stream() << res << "\n");
IF_VERBOSE(20, m_sls->display(verbose_stream()));
if (res == l_true) {
if (g->models_enabled()) {
model_ref mdl = m_sls->get_model();
@ -207,7 +207,7 @@ public:
void cleanup() override {
auto* d = alloc(bv::sls, m);
auto* d = alloc(bv::sls, m, m_params);
std::swap(d, m_sls);
dealloc(d);
}
@ -235,12 +235,6 @@ tactic* mk_bv_sls_tactic(ast_manager& m, params_ref const& p) {
static tactic * mk_preamble(ast_manager & m, params_ref const & p) {
params_ref main_p;
main_p.set_bool("elim_and", true);
// main_p.set_bool("pull_cheap_ite", true);
main_p.set_bool("push_ite_bv", true);
main_p.set_bool("blast_distinct", true);
main_p.set_bool("hi_div0", true);
params_ref simp2_p = p;
simp2_p.set_bool("som", true);
@ -249,18 +243,15 @@ static tactic * mk_preamble(ast_manager & m, params_ref const & p) {
simp2_p.set_bool("local_ctx", true);
simp2_p.set_uint("local_ctx_limit", 10000000);
params_ref hoist_p;
params_ref hoist_p = p;
hoist_p.set_bool("hoist_mul", true);
hoist_p.set_bool("som", false);
params_ref gaussian_p;
params_ref gaussian_p = p;
// conservative gaussian elimination.
gaussian_p.set_uint("gaussian_max_occs", 2);
params_ref ctx_p;
ctx_p.set_uint("max_depth", 32);
ctx_p.set_uint("max_steps", 5000000);
return and_then(and_then(mk_simplify_tactic(m),
return and_then(and_then(mk_simplify_tactic(m, p),
mk_propagate_values_tactic(m),
using_params(mk_solve_eqs_tactic(m), gaussian_p),
mk_elim_uncnstr_tactic(m),
@ -278,7 +269,9 @@ tactic * mk_qfbv_sls_tactic(ast_manager & m, params_ref const & p) {
}
tactic* mk_qfbv_new_sls_tactic(ast_manager& m, params_ref const& p) {
tactic* t = and_then(mk_preamble(m, p), mk_bv_sls_tactic(m, p));
t->updt_params(p);
params_ref q = p;
q.set_bool("elim_sign_ext", false);
tactic* t = and_then(mk_preamble(m, q), mk_bv_sls_tactic(m, q));
t->updt_params(q);
return t;
}