mirror of
https://github.com/Z3Prover/z3
synced 2025-08-10 21:20:52 +00:00
adding lookahead mode to arithmetic sls solver
This commit is contained in:
parent
847278fba8
commit
d97bd48669
7 changed files with 575 additions and 207 deletions
File diff suppressed because it is too large
Load diff
|
@ -27,6 +27,14 @@ namespace sls {
|
||||||
|
|
||||||
using theory_var = int;
|
using theory_var = int;
|
||||||
|
|
||||||
|
enum arith_move_type {
|
||||||
|
hillclimb,
|
||||||
|
random_update,
|
||||||
|
random_inc_dec
|
||||||
|
};
|
||||||
|
|
||||||
|
std::ostream& operator<<(std::ostream& out, arith_move_type mt);
|
||||||
|
|
||||||
// local search portion for arithmetic
|
// local search portion for arithmetic
|
||||||
template<typename num_t>
|
template<typename num_t>
|
||||||
class arith_base : public plugin {
|
class arith_base : public plugin {
|
||||||
|
@ -37,6 +45,7 @@ namespace sls {
|
||||||
typedef unsigned atom_t;
|
typedef unsigned atom_t;
|
||||||
|
|
||||||
struct config {
|
struct config {
|
||||||
|
bool config_initialized = false;
|
||||||
double cb = 2.85;
|
double cb = 2.85;
|
||||||
unsigned L = 20;
|
unsigned L = 20;
|
||||||
unsigned t = 45;
|
unsigned t = 45;
|
||||||
|
@ -47,11 +56,22 @@ namespace sls {
|
||||||
bool paws = true;
|
bool paws = true;
|
||||||
unsigned max_moves = 500;
|
unsigned max_moves = 500;
|
||||||
unsigned max_moves_base = 500;
|
unsigned max_moves_base = 500;
|
||||||
|
unsigned wp = 100;
|
||||||
|
bool ucb = true;
|
||||||
|
double ucb_constant = 1.0;
|
||||||
|
double ucb_forget = 0.1;
|
||||||
|
bool ucb_init = false;
|
||||||
|
double ucb_noise = 0.1;
|
||||||
|
unsigned restart_base = 1000;
|
||||||
|
unsigned restart_next = 1000;
|
||||||
|
unsigned restart_init = 1000;
|
||||||
|
bool arith_use_lookahead = false;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct stats {
|
struct stats {
|
||||||
unsigned m_num_steps = 0;
|
unsigned m_num_steps = 0;
|
||||||
unsigned m_moves = 0;
|
unsigned m_moves = 0;
|
||||||
|
unsigned m_restarts = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
@ -93,10 +113,11 @@ namespace sls {
|
||||||
var_sort m_sort;
|
var_sort m_sort;
|
||||||
arith_op_kind m_op = arith_op_kind::LAST_ARITH_OP;
|
arith_op_kind m_op = arith_op_kind::LAST_ARITH_OP;
|
||||||
unsigned m_def_idx = UINT_MAX;
|
unsigned m_def_idx = UINT_MAX;
|
||||||
vector<std::pair<num_t, sat::bool_var>> m_ineqs;
|
vector<std::pair<num_t, sat::bool_var>> m_linear_occurs;
|
||||||
unsigned_vector m_muls;
|
unsigned_vector m_muls;
|
||||||
unsigned_vector m_adds;
|
unsigned_vector m_adds;
|
||||||
optional<bound> m_lo, m_hi;
|
optional<bound> m_lo, m_hi;
|
||||||
|
vector<num_t> m_finite_domain;
|
||||||
|
|
||||||
num_t const& value() const { return m_value; }
|
num_t const& value() const { return m_value; }
|
||||||
void set_value(num_t const& v) { m_value = v; }
|
void set_value(num_t const& v) { m_value = v; }
|
||||||
|
@ -187,6 +208,7 @@ namespace sls {
|
||||||
|
|
||||||
unsigned get_num_vars() const { return m_vars.size(); }
|
unsigned get_num_vars() const { return m_vars.size(); }
|
||||||
|
|
||||||
|
void updt_params();
|
||||||
bool is_distinct(expr* e);
|
bool is_distinct(expr* e);
|
||||||
bool eval_distinct(expr* e);
|
bool eval_distinct(expr* e);
|
||||||
void repair_distinct(expr* e);
|
void repair_distinct(expr* e);
|
||||||
|
@ -247,7 +269,7 @@ namespace sls {
|
||||||
bool find_lin_moves(sat::literal lit);
|
bool find_lin_moves(sat::literal lit);
|
||||||
bool find_reset_moves(sat::literal lit);
|
bool find_reset_moves(sat::literal lit);
|
||||||
void add_reset_update(var_t v);
|
void add_reset_update(var_t v);
|
||||||
void find_linear_moves(ineq const& i, var_t x, num_t const& coeff, num_t const& sum);
|
void find_linear_moves(ineq const& i, var_t x, num_t const& coeff);
|
||||||
void find_quadratic_moves(ineq const& i, var_t x, num_t const& a, num_t const& b, num_t const& sum);
|
void find_quadratic_moves(ineq const& i, var_t x, num_t const& a, num_t const& b, num_t const& sum);
|
||||||
double compute_score(var_t x, num_t const& delta);
|
double compute_score(var_t x, num_t const& delta);
|
||||||
void save_best_values();
|
void save_best_values();
|
||||||
|
@ -273,6 +295,7 @@ namespace sls {
|
||||||
void check_ineqs();
|
void check_ineqs();
|
||||||
void init_bool_var(sat::bool_var bv);
|
void init_bool_var(sat::bool_var bv);
|
||||||
void initialize_unit(sat::literal lit);
|
void initialize_unit(sat::literal lit);
|
||||||
|
void initialize_input_assertion(expr* f);
|
||||||
void add_le(var_t v, num_t const& n);
|
void add_le(var_t v, num_t const& n);
|
||||||
void add_ge(var_t v, num_t const& n);
|
void add_ge(var_t v, num_t const& n);
|
||||||
void add_lt(var_t v, num_t const& n);
|
void add_lt(var_t v, num_t const& n);
|
||||||
|
@ -288,20 +311,25 @@ namespace sls {
|
||||||
struct bool_info {
|
struct bool_info {
|
||||||
unsigned weight = 0;
|
unsigned weight = 0;
|
||||||
double score = 0;
|
double score = 0;
|
||||||
unsigned touched = 0;
|
unsigned touched = 1;
|
||||||
lbool value = l_undef;
|
lbool value = l_undef;
|
||||||
|
sat::bool_var_set fixable_atoms;
|
||||||
|
uint_set fixable_vars;
|
||||||
|
ptr_vector<expr> fixable_exprs;
|
||||||
|
bool_info(unsigned w) : weight(w) {}
|
||||||
};
|
};
|
||||||
|
|
||||||
vector<ptr_vector<app>> m_update_stack;
|
vector<ptr_vector<app>> m_update_stack;
|
||||||
expr_mark m_in_update_stack;
|
expr_mark m_in_update_stack;
|
||||||
svector<bool_info> m_bool_info;
|
svector<bool_info> m_bool_info;
|
||||||
double m_best_score = 0, m_top_score = 0;
|
double m_best_score = 0, m_top_score = 0;
|
||||||
unsigned m_min_depth = 0, m_max_depth = 0;
|
unsigned m_min_depth = 0, m_max_depth = 0;
|
||||||
num_t m_best_value;
|
num_t m_best_value;
|
||||||
expr* m_best_expr = nullptr, * m_last_atom = nullptr;
|
expr* m_best_expr = nullptr, * m_last_atom = nullptr, * m_last_expr = nullptr;
|
||||||
expr_mark m_is_root;
|
expr_mark m_is_root;
|
||||||
sat::bool_var_set m_fixable_atoms;
|
unsigned m_touched = 1;
|
||||||
uint_set m_fixable_vars;
|
sat::bool_var_set m_fixed_atoms;
|
||||||
ptr_vector<expr> m_fixable_exprs;
|
|
||||||
bool_info& get_bool_info(expr* e);
|
bool_info& get_bool_info(expr* e);
|
||||||
bool get_bool_value(expr* e);
|
bool get_bool_value(expr* e);
|
||||||
bool get_bool_value_rec(expr* e);
|
bool get_bool_value_rec(expr* e);
|
||||||
|
@ -313,29 +341,36 @@ namespace sls {
|
||||||
double new_score(expr* e);
|
double new_score(expr* e);
|
||||||
double new_score(expr* e, bool is_true);
|
double new_score(expr* e, bool is_true);
|
||||||
void set_score(expr* e, double s) { get_bool_info(e).score = s; }
|
void set_score(expr* e, double s) { get_bool_info(e).score = s; }
|
||||||
|
|
||||||
void rescore();
|
void rescore();
|
||||||
void recalibrate_weights();
|
void recalibrate_weights();
|
||||||
void inc_weight(expr* e) { ++get_bool_info(e).weight; }
|
void inc_weight(expr* e) { ++get_bool_info(e).weight; }
|
||||||
void dec_weight(expr* e) { auto& i = get_bool_info(e); i.weight = i.weight > m_config.paws_init ? i.weight - 1 : m_config.paws_init; }
|
void dec_weight(expr* e) { auto& i = get_bool_info(e); i.weight = i.weight > m_config.paws_init ? i.weight - 1 : m_config.paws_init; }
|
||||||
unsigned get_weight(expr* e) { return get_bool_info(e).weight; }
|
unsigned get_weight(expr* e) { return get_bool_info(e).weight; }
|
||||||
|
unsigned get_touched(expr* e) { return get_bool_info(e).touched; }
|
||||||
|
void inc_touched(expr* e) { ++get_bool_info(e).touched; }
|
||||||
|
void set_touched(expr* e, unsigned t) { get_bool_info(e).touched = t; }
|
||||||
void insert_update_stack(expr* t);
|
void insert_update_stack(expr* t);
|
||||||
void insert_update_stack_rec(expr* t);
|
void insert_update_stack_rec(expr* t);
|
||||||
void clear_update_stack();
|
void clear_update_stack();
|
||||||
void lookahead_num(var_t v, num_t const& value);
|
void lookahead_num(var_t v, num_t const& value);
|
||||||
|
bool can_update_num(var_t v, num_t const& delta);
|
||||||
|
bool update_num(var_t v, num_t const& delta);
|
||||||
void lookahead_bool(expr* e);
|
void lookahead_bool(expr* e);
|
||||||
double lookahead(expr* e);
|
double lookahead(expr* e, bool update_score);
|
||||||
void add_lookahead(expr* e);
|
void add_lookahead(bool_info& i, expr* e);
|
||||||
void add_fixable(expr* e);
|
ptr_vector<expr> const& get_fixable_exprs(expr* e);
|
||||||
bool apply_move(expr* f, bool randomize);
|
bool apply_move(expr* f, ptr_vector<expr> const& vars, arith_move_type t);
|
||||||
expr* get_candidate_unsat();
|
expr* get_candidate_unsat();
|
||||||
void check_restart();
|
void check_restart();
|
||||||
|
void ucb_forget();
|
||||||
|
void update_args_value(var_t v, num_t const& new_value);
|
||||||
public:
|
public:
|
||||||
arith_base(context& ctx);
|
arith_base(context& ctx);
|
||||||
~arith_base() override {}
|
~arith_base() override {}
|
||||||
void register_term(expr* e) override;
|
void register_term(expr* e) override;
|
||||||
bool set_value(expr* e, expr* v) override;
|
bool set_value(expr* e, expr* v) override;
|
||||||
expr_ref get_value(expr* e) override;
|
expr_ref get_value(expr* e) override;
|
||||||
|
void start_propagation() override;
|
||||||
bool is_fixed(expr* e, expr_ref& value) override;
|
bool is_fixed(expr* e, expr_ref& value) override;
|
||||||
void initialize() override;
|
void initialize() override;
|
||||||
void propagate_literal(sat::literal lit) override;
|
void propagate_literal(sat::literal lit) override;
|
||||||
|
|
|
@ -72,6 +72,10 @@ namespace sls {
|
||||||
APPLY_BOTH(initialize());
|
APPLY_BOTH(initialize());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void arith_plugin::start_propagation() {
|
||||||
|
WITH_FALLBACK(start_propagation());
|
||||||
|
}
|
||||||
|
|
||||||
void arith_plugin::propagate_literal(sat::literal lit) {
|
void arith_plugin::propagate_literal(sat::literal lit) {
|
||||||
WITH_FALLBACK(propagate_literal(lit));
|
WITH_FALLBACK(propagate_literal(lit));
|
||||||
}
|
}
|
||||||
|
|
|
@ -32,6 +32,7 @@ namespace sls {
|
||||||
~arith_plugin() override {}
|
~arith_plugin() override {}
|
||||||
void register_term(expr* e) override;
|
void register_term(expr* e) override;
|
||||||
expr_ref get_value(expr* e) override;
|
expr_ref get_value(expr* e) override;
|
||||||
|
void start_propagation() override;
|
||||||
bool is_fixed(expr* e, expr_ref& value) override;
|
bool is_fixed(expr* e, expr_ref& value) override;
|
||||||
void initialize() override;
|
void initialize() override;
|
||||||
void propagate_literal(sat::literal lit) override;
|
void propagate_literal(sat::literal lit) override;
|
||||||
|
|
|
@ -304,9 +304,9 @@ namespace sls {
|
||||||
|
|
||||||
void bv_lookahead::updt_params(params_ref const& _p) {
|
void bv_lookahead::updt_params(params_ref const& _p) {
|
||||||
sls_params p(_p);
|
sls_params p(_p);
|
||||||
if (m_config.updated)
|
if (m_config.config_initialized)
|
||||||
return;
|
return;
|
||||||
m_config.updated = true;
|
m_config.config_initialized = true;
|
||||||
m_config.walksat = p.walksat();
|
m_config.walksat = p.walksat();
|
||||||
m_config.walksat_repick = p.walksat_repick();
|
m_config.walksat_repick = p.walksat_repick();
|
||||||
m_config.paws_sp = p.paws_sp();
|
m_config.paws_sp = p.paws_sp();
|
||||||
|
|
|
@ -26,7 +26,7 @@ namespace sls {
|
||||||
class bv_lookahead {
|
class bv_lookahead {
|
||||||
|
|
||||||
struct config {
|
struct config {
|
||||||
bool updated = false;
|
bool config_initialized = false;
|
||||||
double cb = 2.85;
|
double cb = 2.85;
|
||||||
unsigned paws_init = 40;
|
unsigned paws_init = 40;
|
||||||
unsigned paws_sp = 52;
|
unsigned paws_sp = 52;
|
||||||
|
@ -181,11 +181,11 @@ namespace sls {
|
||||||
|
|
||||||
void finalize_bool_values();
|
void finalize_bool_values();
|
||||||
|
|
||||||
|
void updt_params(params_ref const& p);
|
||||||
|
|
||||||
public:
|
public:
|
||||||
bv_lookahead(bv_eval& ev);
|
bv_lookahead(bv_eval& ev);
|
||||||
|
|
||||||
void updt_params(params_ref const& p);
|
|
||||||
|
|
||||||
void start_propagation();
|
void start_propagation();
|
||||||
|
|
||||||
void collect_statistics(statistics& st) const;
|
void collect_statistics(statistics& st) const;
|
||||||
|
|
|
@ -25,6 +25,7 @@ def_module_params('sls',
|
||||||
('dt_axiomatic', BOOL, True, 'use axiomatic mode or model reduction for datatype solver'),
|
('dt_axiomatic', BOOL, True, 'use axiomatic mode or model reduction for datatype solver'),
|
||||||
('track_unsat', BOOL, 0, 'keep a list of unsat assertions as done in SAT - currently disabled internally'),
|
('track_unsat', BOOL, 0, 'keep a list of unsat assertions as done in SAT - currently disabled internally'),
|
||||||
('random_seed', UINT, 0, 'random seed'),
|
('random_seed', UINT, 0, 'random seed'),
|
||||||
|
('arith_use_lookahead', BOOL, False, 'use lookahead solver for NIRA'),
|
||||||
('bv_use_top_level_assertions', BOOL, True, 'use top-level assertions for BV lookahead solver'),
|
('bv_use_top_level_assertions', BOOL, True, 'use top-level assertions for BV lookahead solver'),
|
||||||
('bv_use_lookahead', BOOL, True, 'use lookahead solver for BV'),
|
('bv_use_lookahead', BOOL, True, 'use lookahead solver for BV'),
|
||||||
('bv_allow_rotation', BOOL, True, 'allow model rotation when repairing literal assignment'),
|
('bv_allow_rotation', BOOL, True, 'allow model rotation when repairing literal assignment'),
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue