3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-08 10:25:18 +00:00

adding lookahead mode to arithmetic sls solver

This commit is contained in:
Nikolaj Bjorner 2025-01-11 15:47:17 -08:00
parent 847278fba8
commit d97bd48669
7 changed files with 575 additions and 207 deletions

File diff suppressed because it is too large Load diff

View file

@ -27,6 +27,14 @@ namespace sls {
using theory_var = int;
enum arith_move_type {
hillclimb,
random_update,
random_inc_dec
};
std::ostream& operator<<(std::ostream& out, arith_move_type mt);
// local search portion for arithmetic
template<typename num_t>
class arith_base : public plugin {
@ -37,6 +45,7 @@ namespace sls {
typedef unsigned atom_t;
struct config {
bool config_initialized = false;
double cb = 2.85;
unsigned L = 20;
unsigned t = 45;
@ -47,11 +56,22 @@ namespace sls {
bool paws = true;
unsigned max_moves = 500;
unsigned max_moves_base = 500;
unsigned wp = 100;
bool ucb = true;
double ucb_constant = 1.0;
double ucb_forget = 0.1;
bool ucb_init = false;
double ucb_noise = 0.1;
unsigned restart_base = 1000;
unsigned restart_next = 1000;
unsigned restart_init = 1000;
bool arith_use_lookahead = false;
};
struct stats {
unsigned m_num_steps = 0;
unsigned m_moves = 0;
unsigned m_restarts = 0;
};
public:
@ -93,10 +113,11 @@ namespace sls {
var_sort m_sort;
arith_op_kind m_op = arith_op_kind::LAST_ARITH_OP;
unsigned m_def_idx = UINT_MAX;
vector<std::pair<num_t, sat::bool_var>> m_ineqs;
vector<std::pair<num_t, sat::bool_var>> m_linear_occurs;
unsigned_vector m_muls;
unsigned_vector m_adds;
optional<bound> m_lo, m_hi;
vector<num_t> m_finite_domain;
num_t const& value() const { return m_value; }
void set_value(num_t const& v) { m_value = v; }
@ -187,6 +208,7 @@ namespace sls {
unsigned get_num_vars() const { return m_vars.size(); }
void updt_params();
bool is_distinct(expr* e);
bool eval_distinct(expr* e);
void repair_distinct(expr* e);
@ -247,7 +269,7 @@ namespace sls {
bool find_lin_moves(sat::literal lit);
bool find_reset_moves(sat::literal lit);
void add_reset_update(var_t v);
void find_linear_moves(ineq const& i, var_t x, num_t const& coeff, num_t const& sum);
void find_linear_moves(ineq const& i, var_t x, num_t const& coeff);
void find_quadratic_moves(ineq const& i, var_t x, num_t const& a, num_t const& b, num_t const& sum);
double compute_score(var_t x, num_t const& delta);
void save_best_values();
@ -273,6 +295,7 @@ namespace sls {
void check_ineqs();
void init_bool_var(sat::bool_var bv);
void initialize_unit(sat::literal lit);
void initialize_input_assertion(expr* f);
void add_le(var_t v, num_t const& n);
void add_ge(var_t v, num_t const& n);
void add_lt(var_t v, num_t const& n);
@ -288,20 +311,25 @@ namespace sls {
struct bool_info {
unsigned weight = 0;
double score = 0;
unsigned touched = 0;
unsigned touched = 1;
lbool value = l_undef;
sat::bool_var_set fixable_atoms;
uint_set fixable_vars;
ptr_vector<expr> fixable_exprs;
bool_info(unsigned w) : weight(w) {}
};
vector<ptr_vector<app>> m_update_stack;
expr_mark m_in_update_stack;
svector<bool_info> m_bool_info;
double m_best_score = 0, m_top_score = 0;
unsigned m_min_depth = 0, m_max_depth = 0;
num_t m_best_value;
expr* m_best_expr = nullptr, * m_last_atom = nullptr;
expr* m_best_expr = nullptr, * m_last_atom = nullptr, * m_last_expr = nullptr;
expr_mark m_is_root;
sat::bool_var_set m_fixable_atoms;
uint_set m_fixable_vars;
ptr_vector<expr> m_fixable_exprs;
unsigned m_touched = 1;
sat::bool_var_set m_fixed_atoms;
bool_info& get_bool_info(expr* e);
bool get_bool_value(expr* e);
bool get_bool_value_rec(expr* e);
@ -313,29 +341,36 @@ namespace sls {
double new_score(expr* e);
double new_score(expr* e, bool is_true);
void set_score(expr* e, double s) { get_bool_info(e).score = s; }
void rescore();
void recalibrate_weights();
void inc_weight(expr* e) { ++get_bool_info(e).weight; }
void dec_weight(expr* e) { auto& i = get_bool_info(e); i.weight = i.weight > m_config.paws_init ? i.weight - 1 : m_config.paws_init; }
unsigned get_weight(expr* e) { return get_bool_info(e).weight; }
unsigned get_touched(expr* e) { return get_bool_info(e).touched; }
void inc_touched(expr* e) { ++get_bool_info(e).touched; }
void set_touched(expr* e, unsigned t) { get_bool_info(e).touched = t; }
void insert_update_stack(expr* t);
void insert_update_stack_rec(expr* t);
void clear_update_stack();
void lookahead_num(var_t v, num_t const& value);
bool can_update_num(var_t v, num_t const& delta);
bool update_num(var_t v, num_t const& delta);
void lookahead_bool(expr* e);
double lookahead(expr* e);
void add_lookahead(expr* e);
void add_fixable(expr* e);
bool apply_move(expr* f, bool randomize);
double lookahead(expr* e, bool update_score);
void add_lookahead(bool_info& i, expr* e);
ptr_vector<expr> const& get_fixable_exprs(expr* e);
bool apply_move(expr* f, ptr_vector<expr> const& vars, arith_move_type t);
expr* get_candidate_unsat();
void check_restart();
void ucb_forget();
void update_args_value(var_t v, num_t const& new_value);
public:
arith_base(context& ctx);
~arith_base() override {}
void register_term(expr* e) override;
bool set_value(expr* e, expr* v) override;
expr_ref get_value(expr* e) override;
void start_propagation() override;
bool is_fixed(expr* e, expr_ref& value) override;
void initialize() override;
void propagate_literal(sat::literal lit) override;

View file

@ -72,6 +72,10 @@ namespace sls {
APPLY_BOTH(initialize());
}
void arith_plugin::start_propagation() {
WITH_FALLBACK(start_propagation());
}
void arith_plugin::propagate_literal(sat::literal lit) {
WITH_FALLBACK(propagate_literal(lit));
}

View file

@ -32,6 +32,7 @@ namespace sls {
~arith_plugin() override {}
void register_term(expr* e) override;
expr_ref get_value(expr* e) override;
void start_propagation() override;
bool is_fixed(expr* e, expr_ref& value) override;
void initialize() override;
void propagate_literal(sat::literal lit) override;

View file

@ -304,9 +304,9 @@ namespace sls {
void bv_lookahead::updt_params(params_ref const& _p) {
sls_params p(_p);
if (m_config.updated)
if (m_config.config_initialized)
return;
m_config.updated = true;
m_config.config_initialized = true;
m_config.walksat = p.walksat();
m_config.walksat_repick = p.walksat_repick();
m_config.paws_sp = p.paws_sp();

View file

@ -26,7 +26,7 @@ namespace sls {
class bv_lookahead {
struct config {
bool updated = false;
bool config_initialized = false;
double cb = 2.85;
unsigned paws_init = 40;
unsigned paws_sp = 52;
@ -181,11 +181,11 @@ namespace sls {
void finalize_bool_values();
void updt_params(params_ref const& p);
public:
bv_lookahead(bv_eval& ev);
void updt_params(params_ref const& p);
void start_propagation();
void collect_statistics(statistics& st) const;

View file

@ -25,6 +25,7 @@ def_module_params('sls',
('dt_axiomatic', BOOL, True, 'use axiomatic mode or model reduction for datatype solver'),
('track_unsat', BOOL, 0, 'keep a list of unsat assertions as done in SAT - currently disabled internally'),
('random_seed', UINT, 0, 'random seed'),
('arith_use_lookahead', BOOL, False, 'use lookahead solver for NIRA'),
('bv_use_top_level_assertions', BOOL, True, 'use top-level assertions for BV lookahead solver'),
('bv_use_lookahead', BOOL, True, 'use lookahead solver for BV'),
('bv_allow_rotation', BOOL, True, 'allow model rotation when repairing literal assignment'),