mirror of
				https://github.com/Z3Prover/z3
				synced 2025-10-31 03:32:28 +00:00 
			
		
		
		
	add clausal lookahead to arithmetic solver as part of portfolio
have legacy qfbv-sls solver use nnf pre-processing. It relies on it for correctness of the score updates.
This commit is contained in:
		
							parent
							
								
									a941f5ae84
								
							
						
					
					
						commit
						22e4054674
					
				
					 13 changed files with 678 additions and 92 deletions
				
			
		|  | @ -3,6 +3,7 @@ z3_add_component(ast_sls | |||
|     bvsls_opt_engine.cpp | ||||
|     sat_ddfw.cpp | ||||
|     sls_arith_base.cpp | ||||
|     sls_arith_clausal.cpp | ||||
|     sls_arith_plugin.cpp | ||||
|     sls_array_plugin.cpp | ||||
|     sls_basic_plugin.cpp | ||||
|  |  | |||
|  | @ -99,6 +99,16 @@ namespace sat { | |||
|         m_last_flips = m_flips; | ||||
|     } | ||||
| 
 | ||||
|     sat::bool_var ddfw::bool_flip() { | ||||
|         flet<bool> _in_bool_flip(m_in_bool_flip, true); | ||||
|         double reward = 0; | ||||
|         bool_var v = pick_var(reward); | ||||
|         if (apply_flip(v, reward)) | ||||
|             return v; | ||||
|         shift_weights(); | ||||
|         return sat::null_bool_var; | ||||
|     } | ||||
| 
 | ||||
|     bool ddfw::do_flip() { | ||||
|         double reward = 0; | ||||
|         bool_var v = pick_var(reward); | ||||
|  | @ -125,7 +135,9 @@ namespace sat { | |||
|         bool_var v0 = null_bool_var; | ||||
|         for (bool_var v : m_unsat_vars) { | ||||
|             r = reward(v); | ||||
|             if (r > 0.0)     | ||||
|             if (m_in_bool_flip && m_plugin->is_external(v)) | ||||
|                 ; | ||||
|             else if (r > 0.0)     | ||||
|                 sum_pos += score(r);             | ||||
|             else if (r == 0.0 && sum_pos == 0 && (m_rand() % (n++)) == 0)  | ||||
|                 v0 = v;             | ||||
|  | @ -134,6 +146,8 @@ namespace sat { | |||
|             double lim_pos = ((double) m_rand() / (1.0 + m_rand.max_value())) * sum_pos;                 | ||||
|             for (bool_var v : m_unsat_vars) { | ||||
|                 r = reward(v); | ||||
|                 if (m_in_bool_flip && m_plugin->is_external(v)) | ||||
|                     continue; | ||||
|                 if (r > 0) { | ||||
|                     lim_pos -= score(r); | ||||
|                     if (lim_pos <= 0)  | ||||
|  | @ -146,6 +160,8 @@ namespace sat { | |||
|             return v0; | ||||
|         if (m_unsat_vars.empty()) | ||||
|             return null_bool_var; | ||||
|         if (m_in_bool_flip) | ||||
|             return false; | ||||
|         return m_unsat_vars.elem_at(m_rand(m_unsat_vars.size())); | ||||
|     } | ||||
| 
 | ||||
|  | @ -332,6 +348,7 @@ namespace sat { | |||
|             m_vars[v].m_reward = 0; | ||||
|         }         | ||||
|         m_unsat_vars.reset(); | ||||
|         m_num_external_in_unsat_vars = 0; | ||||
|         m_unsat.reset(); | ||||
|         unsigned sz = m_clauses.size(); | ||||
|         for (unsigned i = 0; i < sz; ++i) { | ||||
|  | @ -400,7 +417,7 @@ namespace sat { | |||
|         for (unsigned i = 0; i < num_vars(); ++i)  | ||||
|             m_model[i] = to_lbool(value(i)); | ||||
|         save_priorities(); | ||||
|         if (m_plugin) | ||||
|         if (m_plugin && !m_in_bool_flip) | ||||
|             m_last_result = m_plugin->on_save_model();    | ||||
|     } | ||||
| 
 | ||||
|  |  | |||
|  | @ -36,11 +36,10 @@ namespace sat { | |||
|     class local_search_plugin { | ||||
|     public: | ||||
|         virtual ~local_search_plugin() {} | ||||
|         //virtual void init_search() = 0;
 | ||||
|         //virtual void finish_search() = 0;
 | ||||
|         virtual void on_rescale() = 0; | ||||
|         virtual lbool on_save_model() = 0; | ||||
|         virtual void on_restart() = 0; | ||||
|         virtual bool is_external(sat::bool_var v) = 0; | ||||
|     }; | ||||
|      | ||||
|     class ddfw { | ||||
|  | @ -140,14 +139,26 @@ namespace sat { | |||
| 
 | ||||
|         unsigned select_max_same_sign(unsigned cf_idx); | ||||
| 
 | ||||
|         unsigned m_num_external_in_unsat_vars = 0; | ||||
| 
 | ||||
|         inline void inc_make(literal lit) {  | ||||
|             bool_var v = lit.var();  | ||||
|             if (make_count(v)++ == 0) m_unsat_vars.insert_fresh(v);  | ||||
|             if (make_count(v)++ == 0) { | ||||
|                 m_unsat_vars.insert_fresh(v); | ||||
|                 if (m_plugin && m_plugin->is_external(v)) | ||||
|                     ++m_num_external_in_unsat_vars; | ||||
|             } | ||||
|         } | ||||
| 
 | ||||
|         inline void dec_make(literal lit) {  | ||||
|             bool_var v = lit.var();  | ||||
|             if (--make_count(v) == 0) m_unsat_vars.remove(v);  | ||||
|             if (--make_count(v) == 0) { | ||||
|                 if (m_unsat_vars.contains(v)) { | ||||
|                     m_unsat_vars.remove(v); | ||||
|                     if (m_plugin && m_plugin->is_external(v)) | ||||
|                         --m_num_external_in_unsat_vars; | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
| 
 | ||||
|         inline void inc_reward(literal lit, double w) { m_vars[lit.var()].m_reward += w; } | ||||
|  | @ -164,13 +175,12 @@ namespace sat { | |||
| 
 | ||||
|         bool apply_flip(bool_var v, double reward); | ||||
| 
 | ||||
| 
 | ||||
|         void save_best_values(); | ||||
|         void save_model(); | ||||
|         void save_priorities(); | ||||
| 
 | ||||
|         // shift activity
 | ||||
|         void shift_weights(); | ||||
| 
 | ||||
|         inline double calculate_transfer_weight(double w); | ||||
| 
 | ||||
|         // reinitialize weights activity
 | ||||
|  | @ -204,6 +214,8 @@ namespace sat { | |||
|         bool_var_set m_rotate_tabu; | ||||
|         bool_var_vector m_new_tabu_vars; | ||||
| 
 | ||||
|         bool m_in_bool_flip = false; | ||||
| 
 | ||||
|     public: | ||||
| 
 | ||||
|         ddfw() {} | ||||
|  | @ -241,6 +253,10 @@ namespace sat { | |||
| 
 | ||||
|         indexed_uint_set const& unsat_set() const { return m_unsat; } | ||||
| 
 | ||||
|         indexed_uint_set const& unsat_vars() const { return m_unsat_vars; } | ||||
| 
 | ||||
|         unsigned num_external_in_unsat_vars() const { return m_num_external_in_unsat_vars; } | ||||
| 
 | ||||
|         vector<clause_info> const& clauses() const { return m_clauses; } | ||||
| 
 | ||||
|         clause_info& get_clause_info(unsigned idx) { return m_clauses[idx]; } | ||||
|  | @ -251,6 +267,10 @@ namespace sat { | |||
| 
 | ||||
|         void flip(bool_var v); | ||||
| 
 | ||||
|         sat::bool_var bool_flip(); | ||||
| 
 | ||||
|         void shift_weights(); | ||||
| 
 | ||||
|         inline double reward(bool_var v) const { return m_vars[v].m_reward; } | ||||
| 
 | ||||
|         void set_reward(bool_var v, double r) { m_vars[v].m_reward = r; } | ||||
|  |  | |||
|  | @ -111,7 +111,8 @@ namespace sls { | |||
|     arith_base<num_t>::arith_base(context& ctx) : | ||||
|         plugin(ctx), | ||||
|         a(m), | ||||
|         m_new_terms(m) { | ||||
|         m_new_terms(m), | ||||
|         m_clausal_sls(*this) { | ||||
|         m_fid = a.get_family_id(); | ||||
|     } | ||||
| 
 | ||||
|  | @ -447,12 +448,12 @@ namespace sls { | |||
|         delta_out = delta; | ||||
| 
 | ||||
|         if (m_last_var == v && m_last_delta == -delta) { | ||||
|             TRACE("arith", tout << "flip back " << v << " " << delta << "\n";); | ||||
|             TRACE("arith_verbose", tout << "flip back " << v << " " << delta << "\n";); | ||||
|             return false; | ||||
|         } | ||||
| 
 | ||||
|         if (m_use_tabu && vi.is_tabu(m_stats.m_num_steps, delta)) { | ||||
|             TRACE("arith", tout << "tabu\n"); | ||||
|         if (m_use_tabu && vi.is_tabu(m_stats.m_steps, delta)) { | ||||
|             TRACE("arith_verbose", tout << "tabu v" << v << " delta:" << delta << "\n"); | ||||
|             return false; | ||||
|         } | ||||
|          | ||||
|  | @ -545,8 +546,8 @@ namespace sls { | |||
| 
 | ||||
|             if (update(v, new_value)) { | ||||
|                 m_last_delta = delta; | ||||
|                 m_stats.m_num_steps++; | ||||
|                 m_vars[v].set_step(m_stats.m_num_steps, m_stats.m_num_steps + 3 + ctx.rand(10), delta); | ||||
|                 m_stats.m_steps++; | ||||
|                 m_vars[v].set_step(m_stats.m_steps, m_stats.m_steps + 3 + ctx.rand(10), delta); | ||||
|                 return true; | ||||
|             } | ||||
|             sum_score -= score; | ||||
|  | @ -1106,6 +1107,7 @@ namespace sls { | |||
| 
 | ||||
|         // attach i to bv
 | ||||
|         m_ineqs.set(bv, &i); | ||||
|         m_bool_var_atoms.insert(bv); | ||||
|      } | ||||
| 
 | ||||
|     template<typename num_t> | ||||
|  | @ -1403,6 +1405,40 @@ namespace sls { | |||
|                 throw default_exception("repair is not supported for " + mk_pp(e, m)); | ||||
|             } | ||||
|         } | ||||
|         for (unsigned v = 0; v < m_vars.size(); ++v) | ||||
|             initialize_bool_vars_of(v); | ||||
|     } | ||||
| 
 | ||||
|     template<typename num_t> | ||||
|     void arith_base<num_t>::initialize_bool_vars_of(var_t v) { | ||||
|         if (!m_vars[v].m_bool_vars_of.empty()) | ||||
|             return; | ||||
|         buffer<var_t> todo; | ||||
|         todo.push_back(v); | ||||
|         auto& vi = m_vars[v]; | ||||
|         for (unsigned i = 0; i < todo.size(); ++i) { | ||||
|             var_t u = todo[i]; | ||||
|             auto& ui = m_vars[u]; | ||||
|             for (auto const& idx : ui.m_muls) { | ||||
|                 auto& [x, monomial] = m_muls[idx]; | ||||
|                 if (all_of(todo, [x](var_t v) { return x != v; })) | ||||
|                     todo.push_back(x); | ||||
|             } | ||||
|             for (auto const& idx : ui.m_adds) { | ||||
|                 auto x = m_adds[idx].m_var; | ||||
|                 if (all_of(todo, [x](var_t v) { return x != v; })) | ||||
|                     todo.push_back(x); | ||||
|             } | ||||
|             for (auto const& [coeff, bv] : ui.m_linear_occurs) | ||||
|                 vi.m_bool_vars_of.insert(bv); | ||||
|         } | ||||
|         ; | ||||
|         for (auto bv : vi.m_bool_vars_of) { | ||||
|             for (auto i : ctx.get_use_list(sat::literal(bv, true))) | ||||
|                 vi.m_clauses_of.insert(i); | ||||
|             for (auto i : ctx.get_use_list(sat::literal(bv, false))) | ||||
|                 vi.m_clauses_of.insert(i); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     template<typename num_t> | ||||
|  | @ -2274,7 +2310,7 @@ namespace sls { | |||
|         auto const& vi = m_vars[v]; | ||||
|         if (vi.m_def_idx == UINT_MAX) | ||||
|             return true; | ||||
|         IF_VERBOSE(4, verbose_stream() << vi.m_op << " repair def " << mk_bounded_pp(vi.m_expr, m) << "\n"); | ||||
|         IF_VERBOSE(10, verbose_stream() << vi.m_op << " repair def " << mk_bounded_pp(vi.m_expr, m) << "\n"); | ||||
|         TRACE("sls", tout << "repair def " << mk_bounded_pp(vi.m_expr, m) << "\n"); | ||||
|         switch (vi.m_op) { | ||||
|         case arith_op_kind::LAST_ARITH_OP: | ||||
|  | @ -2398,13 +2434,12 @@ namespace sls { | |||
| 
 | ||||
|     template<typename num_t> | ||||
|     void arith_base<num_t>::collect_statistics(statistics& st) const { | ||||
|         st.update("sls-arith-flips", m_stats.m_num_steps); | ||||
|         st.update("sls-arith-moves", m_stats.m_moves); | ||||
|         st.update("sls-arith-steps", m_stats.m_steps); | ||||
|     } | ||||
| 
 | ||||
|     template<typename num_t> | ||||
|     void arith_base<num_t>::reset_statistics() { | ||||
|         m_stats.m_num_steps = 0; | ||||
|         m_stats.m_steps = 0; | ||||
|     } | ||||
| 
 | ||||
|     // global lookahead mode
 | ||||
|  | @ -2708,7 +2743,6 @@ namespace sls { | |||
|     template<typename num_t> | ||||
|     void arith_base<num_t>::lookahead_num(var_t v, num_t const& delta) {        | ||||
|         num_t old_value = value(v); | ||||
| 
 | ||||
|         expr* e = m_vars[v].m_expr; | ||||
|         if (m_last_expr != e) { | ||||
|             if (m_last_expr) | ||||
|  | @ -2779,6 +2813,31 @@ namespace sls { | |||
|         m_last_expr = nullptr; | ||||
|     } | ||||
| 
 | ||||
|     template<typename num_t> | ||||
|     void arith_base<num_t>::add_lookahead(bool_info& i, sat::bool_var bv) { | ||||
|         if (!i.fixable_atoms.contains(bv)) | ||||
|             return; | ||||
|         if (m_fixed_atoms.contains(bv)) | ||||
|             return; | ||||
|         auto* ineq = get_ineq(bv); | ||||
|         if (!ineq) | ||||
|             return; | ||||
|         num_t na, nb; | ||||
|         for (auto const& [x, nl] : ineq->m_nonlinear) { | ||||
|             if (!i.fixable_vars.contains(x)) | ||||
|                 continue; | ||||
|             if (is_fixed(x)) | ||||
|                 continue; | ||||
|             if (is_linear(x, nl, nb)) | ||||
|                 find_linear_moves(*ineq, x, nb); | ||||
|             else if (is_quadratic(x, nl, na, nb)) | ||||
|                 find_quadratic_moves(*ineq, x, na, nb, ineq->m_args_value); | ||||
|             else | ||||
|                 ; | ||||
|         } | ||||
|         m_fixed_atoms.insert(bv); | ||||
|     } | ||||
| 
 | ||||
|     // for every variable e, for every atom containing e
 | ||||
|     // add lookahead for e.
 | ||||
|     // m_fixable_atoms contains atoms that can be fixed.
 | ||||
|  | @ -2786,33 +2845,6 @@ namespace sls { | |||
|     template<typename num_t> | ||||
|     void arith_base<num_t>::add_lookahead(bool_info& i, expr* e) { | ||||
| 
 | ||||
|         auto add_atom = [&](sat::bool_var bv) { | ||||
|             if (!i.fixable_atoms.contains(bv)) | ||||
|                 return; | ||||
|             if (m_fixed_atoms.contains(bv)) | ||||
|                 return; | ||||
|             auto a = ctx.atom(bv); | ||||
|             if (!a) | ||||
|                 return; | ||||
|             auto* ineq = get_ineq(bv); | ||||
|             if (!ineq) | ||||
|                 return; | ||||
|             num_t na, nb; | ||||
|             for (auto const& [x, nl] : ineq->m_nonlinear) { | ||||
|                 if (!i.fixable_vars.contains(x)) | ||||
|                     continue; | ||||
|                 if (is_fixed(x)) | ||||
|                     continue; | ||||
|                 if (is_linear(x, nl, nb)) | ||||
|                     find_linear_moves(*ineq, x, nb); | ||||
|                 else if (is_quadratic(x, nl, na, nb)) | ||||
|                     find_quadratic_moves(*ineq, x, na, nb, ineq->m_args_value); | ||||
|                 else | ||||
|                     ; | ||||
|             } | ||||
|             m_fixed_atoms.insert(bv); | ||||
|         }; | ||||
| 
 | ||||
|         auto add_finite_domain = [&](var_t v) { | ||||
|             auto old_value = value(v); | ||||
|             for (auto const& n : m_vars[v].m_finite_domain)  | ||||
|  | @ -2832,13 +2864,8 @@ namespace sls { | |||
|                 add_finite_domain(v); | ||||
|                 return; | ||||
|             } | ||||
|             for (auto const& [coeff, bv] : vi.m_linear_occurs)  | ||||
|                 add_atom(bv); | ||||
|             for (auto const& idx : vi.m_muls) { | ||||
|                 auto const& [x, monomial] = m_muls[idx]; | ||||
|                 for (auto [coeff, bv] : m_vars[x].m_linear_occurs) | ||||
|                     add_atom(bv); | ||||
|             } | ||||
|             for (auto bv : vi.m_bool_vars_of) | ||||
|                 add_lookahead(i, bv);             | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|  | @ -2927,7 +2954,7 @@ namespace sls { | |||
|                 add_lookahead(info, vars[(start + i) % sz]); | ||||
|             if (m_updates.empty()) | ||||
|                 return false; | ||||
|             unsigned idx = ctx.rand() % m_updates.size(); | ||||
|             unsigned idx = ctx.rand(m_updates.size()); | ||||
|             auto& [v, delta, score] = m_updates[idx]; | ||||
|             m_best_expr = m_vars[v].m_expr; | ||||
|             if (false && !m_vars[v].m_finite_domain.empty()) | ||||
|  | @ -3015,13 +3042,13 @@ namespace sls { | |||
|     void  arith_base<num_t>::global_search() { | ||||
|         initialize_bool_assignment(); | ||||
|         rescore(); | ||||
|         m_config.max_moves = m_stats.m_moves + m_config.max_moves_base; | ||||
|         TRACE("arith", tout << "search " << m_stats.m_moves << " " << m_config.max_moves << "\n";); | ||||
|         IF_VERBOSE(3, verbose_stream() << "lookahead-search moves:" << m_stats.m_moves << " max-moves:" << m_config.max_moves << "\n"); | ||||
|         m_config.max_moves = m_stats.m_steps + m_config.max_moves_base; | ||||
|         TRACE("arith", tout << "search " << m_stats.m_steps << " " << m_config.max_moves << "\n";); | ||||
|         IF_VERBOSE(3, verbose_stream() << "lookahead-search steps:" << m_stats.m_steps << " max-moves:" << m_config.max_moves << "\n"); | ||||
|         TRACE("arith", display(tout)); | ||||
| 
 | ||||
|         while (ctx.rlimit().inc() && m_stats.m_moves < m_config.max_moves) { | ||||
|             m_stats.m_moves++; | ||||
|         while (ctx.rlimit().inc() && m_stats.m_steps < m_config.max_moves) { | ||||
|             m_stats.m_steps++; | ||||
|             check_restart(); | ||||
| 
 | ||||
|             auto t = get_candidate_unsat(); | ||||
|  | @ -3043,7 +3070,7 @@ namespace sls { | |||
|             if (apply_move(t, vars, arith_move_type::random_update)) | ||||
|                 recalibrate_weights(); | ||||
|         } | ||||
|         if (m_stats.m_moves >= m_config.max_moves) | ||||
|         if (m_stats.m_steps >= m_config.max_moves) | ||||
|             m_config.max_moves_base += 100; | ||||
|         finalize_bool_assignment(); | ||||
|     } | ||||
|  | @ -3098,11 +3125,11 @@ namespace sls { | |||
|         if (old_value == new_value) | ||||
|             return true; | ||||
|         if (!vi.in_range(new_value)) { | ||||
|             TRACE("arith", tout << "Not in range v" << v << " " << new_value << "\n"); | ||||
|             TRACE("arith_verbose", tout << "Not in range v" << v << " " << new_value << "\n"); | ||||
|             return false; | ||||
|         } | ||||
|         if (!in_bounds(v, new_value) && in_bounds(v, old_value)) { | ||||
|             TRACE("arith", tout << "out of bounds v" << v << " " << new_value << "\n"); | ||||
|             TRACE("arith_verbose", tout << "out of bounds v" << v << " " << new_value << "\n"); | ||||
|             //verbose_stream() << "out of bounds v" << v << " " << new_value << "\n";
 | ||||
|             return false; | ||||
|         } | ||||
|  | @ -3166,16 +3193,16 @@ namespace sls { | |||
| 
 | ||||
|     template<typename num_t> | ||||
|     void arith_base<num_t>::check_restart() { | ||||
|         if (m_stats.m_moves % m_config.restart_base == 0) { | ||||
|         if (m_stats.m_steps % m_config.restart_base == 0) { | ||||
|             ucb_forget(); | ||||
|             rescore(); | ||||
|         } | ||||
| 
 | ||||
|         if (m_stats.m_moves < m_config.restart_next) | ||||
|         if (m_stats.m_steps < m_config.restart_next) | ||||
|             return; | ||||
| 
 | ||||
|         ++m_stats.m_restarts; | ||||
|         m_config.restart_next = std::max(m_config.restart_next, m_stats.m_moves); | ||||
|         m_config.restart_next = std::max(m_config.restart_next, m_stats.m_steps); | ||||
| 
 | ||||
|         if (0x1 == (m_stats.m_restarts & 0x1)) | ||||
|             m_config.restart_next += m_config.restart_base; | ||||
|  | @ -3184,10 +3211,8 @@ namespace sls { | |||
| 
 | ||||
|         // reset_uninterp_in_false_literals
 | ||||
|         rescore(); | ||||
| 
 | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     template<typename num_t> | ||||
|     void arith_base<num_t>::ucb_forget() { | ||||
|         if (m_config.ucb_forget >= 1.0) | ||||
|  | @ -3214,18 +3239,21 @@ namespace sls { | |||
|         //m_config.ucb_forget = p.ucb_forget();
 | ||||
|         m_config.wp = p.wp(); | ||||
|         m_config.restart_base = p.restart_base(); | ||||
|         //m_config.restart_next = p.restart_next();
 | ||||
|         m_config.restart_next = p.restart_base(); | ||||
|         //m_config.max_moves_base = p.max_moves_base();
 | ||||
|         //m_config.max_moves = p.max_moves();
 | ||||
|         m_config.arith_use_lookahead = p.arith_use_lookahead(); | ||||
|         m_config.use_lookahead = p.arith_use_lookahead(); | ||||
|         m_config.use_clausal_lookahead = p.arith_use_clausal_lookahead(); | ||||
|         m_config.allow_plateau = p.arith_allow_plateau(); | ||||
|         m_config.config_initialized = true; | ||||
|     } | ||||
| 
 | ||||
|     template<typename num_t> | ||||
|     void arith_base<num_t>::start_propagation() { | ||||
|         updt_params();         | ||||
|         if (m_config.arith_use_lookahead) | ||||
|         updt_params();     | ||||
|         if (m_config.use_clausal_lookahead) | ||||
|             m_clausal_sls.search(); | ||||
|         else if (m_config.use_lookahead) | ||||
|             global_search(); | ||||
|     } | ||||
| 
 | ||||
|  |  | |||
|  | @ -22,6 +22,7 @@ Author: | |||
| #include "ast/ast_trail.h" | ||||
| #include "ast/arith_decl_plugin.h" | ||||
| #include "ast/sls/sls_context.h" | ||||
| #include "ast/sls/sls_arith_clausal.h" | ||||
| 
 | ||||
| namespace sls { | ||||
| 
 | ||||
|  | @ -36,6 +37,8 @@ namespace sls { | |||
| 
 | ||||
|     std::ostream& operator<<(std::ostream& out, arith_move_type mt); | ||||
| 
 | ||||
|     static const unsigned null_arith_var = UINT_MAX; | ||||
| 
 | ||||
|     // local search portion for arithmetic
 | ||||
|     template<typename num_t> | ||||
|     class arith_base : public plugin { | ||||
|  | @ -66,13 +69,13 @@ namespace sls { | |||
|             unsigned restart_base = 1000; | ||||
|             unsigned restart_next = 1000; | ||||
|             unsigned restart_init = 1000; | ||||
|             bool     arith_use_lookahead = false; | ||||
|             bool     use_lookahead = false; | ||||
|             bool     use_clausal_lookahead = false; | ||||
|             bool     allow_plateau = false; | ||||
|         }; | ||||
| 
 | ||||
|         struct stats { | ||||
|             unsigned m_num_steps = 0; | ||||
|             unsigned m_moves = 0; | ||||
|             unsigned m_steps = 0; | ||||
|             unsigned m_restarts = 0; | ||||
|         }; | ||||
| 
 | ||||
|  | @ -116,6 +119,8 @@ namespace sls { | |||
|             arith_op_kind m_op = arith_op_kind::LAST_ARITH_OP; | ||||
|             unsigned     m_def_idx = UINT_MAX; | ||||
|             vector<std::pair<num_t, sat::bool_var>> m_linear_occurs; | ||||
|             indexed_uint_set m_bool_vars_of; | ||||
|             indexed_uint_set m_clauses_of; | ||||
|             unsigned_vector m_muls; | ||||
|             unsigned_vector m_adds; | ||||
|             optional<bound> m_lo, m_hi; | ||||
|  | @ -154,6 +159,9 @@ namespace sls { | |||
|                 else | ||||
|                     m_tabu_neg = tabu_step, m_last_neg = step; | ||||
|             } | ||||
|             unsigned last_step(num_t const& delta) const { | ||||
|                 return delta > 0 ? m_last_pos : m_last_neg; | ||||
|             } | ||||
|             void out_of_range() { | ||||
|                 ++m_num_out_of_range; | ||||
|                 if (m_num_out_of_range < 1000 * (1 + m_num_in_range)) | ||||
|  | @ -204,7 +212,10 @@ namespace sls { | |||
|         bool                         m_use_tabu = true; | ||||
|         unsigned                     m_updates_max_size = 45; | ||||
|         arith_util                   a; | ||||
|         friend class arith_clausal<num_t>; | ||||
|         arith_clausal<num_t>         m_clausal_sls; | ||||
|         svector<double>              m_prob_break; | ||||
|         indexed_uint_set             m_bool_var_atoms; | ||||
| 
 | ||||
|         void invariant(); | ||||
|         void invariant(ineq const& i); | ||||
|  | @ -277,6 +288,7 @@ namespace sls { | |||
|         double compute_score(var_t x, num_t const& delta); | ||||
|         void save_best_values(); | ||||
| 
 | ||||
|         void initialize_bool_vars_of(var_t v); | ||||
|         var_t mk_var(expr* e); | ||||
|         var_t mk_term(expr* e); | ||||
|         var_t mk_op(arith_op_kind k, expr* e, expr* x, expr* y); | ||||
|  | @ -318,7 +330,7 @@ namespace sls { | |||
|             double   score = 0; | ||||
|             unsigned touched = 1; | ||||
|             lbool    value = l_undef; | ||||
|             sat::bool_var_set fixable_atoms; | ||||
|             indexed_uint_set fixable_atoms; | ||||
|             uint_set          fixable_vars; | ||||
|             ptr_vector<expr>  fixable_exprs; | ||||
|             bool_info(unsigned w) : weight(w) {} | ||||
|  | @ -335,6 +347,7 @@ namespace sls { | |||
|         unsigned m_touched = 1; | ||||
|         sat::bool_var_set m_fixed_atoms; | ||||
|         uint64_t m_tabu_set = 0; | ||||
|         unsigned m_global_search_count = 0; | ||||
| 
 | ||||
|         bool in_tabu_set(expr* e, num_t const& n); | ||||
|         void insert_tabu_set(expr* e, num_t const& n); | ||||
|  | @ -344,6 +357,7 @@ namespace sls { | |||
|         void set_bool_value(expr* e, bool v) { get_bool_info(e).value = to_lbool(v); } | ||||
|         bool get_basic_bool_value(app* e); | ||||
|         void initialize_bool_assignment(); | ||||
| 
 | ||||
|         void finalize_bool_assignment(); | ||||
|         double old_score(expr* e) { return get_bool_info(e).score; } | ||||
|         double new_score(expr* e); | ||||
|  | @ -366,6 +380,7 @@ namespace sls { | |||
|         void lookahead_bool(expr* e); | ||||
|         double lookahead(expr* e, bool update_score); | ||||
|         void add_lookahead(bool_info& i, expr* e); | ||||
|         void add_lookahead(bool_info& i, sat::bool_var bv); | ||||
|         ptr_vector<expr> const& get_fixable_exprs(expr* e); | ||||
|         bool apply_move(expr* f, ptr_vector<expr> const& vars, arith_move_type t); | ||||
|         expr* get_candidate_unsat(); | ||||
|  |  | |||
							
								
								
									
										368
									
								
								src/ast/sls/sls_arith_clausal.cpp
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										368
									
								
								src/ast/sls/sls_arith_clausal.cpp
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,368 @@ | |||
| /*++
 | ||||
| Copyright (c) 2025 Microsoft Corporation | ||||
| 
 | ||||
| Module Name: | ||||
| 
 | ||||
|     sls_arith_clausal | ||||
| 
 | ||||
| Abstract: | ||||
| 
 | ||||
|     Theory plugin for arithmetic local search | ||||
|     based on clausal search as used in HybridSMT (nia_ls) | ||||
| 
 | ||||
|     In contrast to HybridSMT/nia_ls we reuse ddfw  | ||||
|     for everything Boolean. It requiers exposing the following: | ||||
| 
 | ||||
|     - unsat_vars - Boolean variables that are in unsat clauses. | ||||
|     - num_external_vars_in_unsat - External variables in unsat clauses | ||||
|     - shift_weights - allow plugin to invoke shift-weights | ||||
| 
 | ||||
|      | ||||
| Author: | ||||
| 
 | ||||
|     Nikolaj Bjorner (nbjorner) 2025-01-16 | ||||
| 
 | ||||
| --*/ | ||||
| 
 | ||||
| #include "ast/sls/sls_arith_clausal.h" | ||||
| #include "ast/sls/sls_arith_base.h" | ||||
| 
 | ||||
| namespace sls { | ||||
|     template<typename num_t> | ||||
|     arith_clausal<num_t>::arith_clausal(arith_base<num_t>& a) : | ||||
|         ctx(a.ctx), | ||||
|         a(a) {  | ||||
|     }             | ||||
| 
 | ||||
|     template<typename num_t> | ||||
|     void arith_clausal<num_t>::search() { | ||||
|         num_t delta; | ||||
| 
 | ||||
|         initialize(); | ||||
| 
 | ||||
|         TRACE("arith", ctx.display_all(tout)); | ||||
| 
 | ||||
|         a.m_config.max_moves = a.m_stats.m_steps + a.m_config.max_moves_base; | ||||
| 
 | ||||
|         while (ctx.rlimit().inc() && a.m_stats.m_steps < a.m_config.max_moves && !ctx.unsat().empty()) { | ||||
|             a.m_stats.m_steps++; | ||||
| 
 | ||||
|             check_restart(); | ||||
|              | ||||
|             unsigned vars_in_unsat = ctx.unsat_vars().size(); | ||||
|             unsigned ext_in_unsat = ctx.num_external_in_unsat_vars(); | ||||
|             unsigned bool_in_unsat =  vars_in_unsat - ext_in_unsat; | ||||
|             bool time_up_bool  = m_no_improve_bool  * vars_in_unsat >  5 * bool_in_unsat; | ||||
|             bool time_up_arith = m_no_improve_arith * vars_in_unsat > 20 * ext_in_unsat; | ||||
|             if ((m_bool_mode && bool_in_unsat < vars_in_unsat && time_up_bool) || bool_in_unsat == 0) | ||||
|                 enter_arith_mode(); | ||||
|             else if ((!m_bool_mode && bool_in_unsat > 0 && time_up_arith) || vars_in_unsat == bool_in_unsat) | ||||
|                 enter_bool_mode(); | ||||
|             if (m_bool_mode) { | ||||
|                 sat::bool_var v = ctx.bool_flip();  | ||||
|                 TRACE("arith", tout << "bool flip v:" << v << "\n"; | ||||
|                 tout << "unsat-vars " << vars_in_unsat << "\n"; | ||||
|                 tout << "bools: " << bool_in_unsat << " timeup-bool " << time_up_bool << "\n"; | ||||
|                 tout << "no-improve bool: " << m_no_improve_bool << "\n"; | ||||
|                 tout << "ext: " << ext_in_unsat << " timeup-arith " << time_up_arith << "\n";); | ||||
|                  | ||||
|                 m_no_improve_bool = update_outer_best_solution() ? 0 : m_no_improve_bool + 1; | ||||
|             } | ||||
|             else { | ||||
|                 move_arith_variable();                 | ||||
|                 m_no_improve_arith = update_inner_best_solution() ? 0 : m_no_improve_arith + 1; | ||||
|             } | ||||
|             m_no_improve = update_best_solution() ? 0 : m_no_improve + 1; | ||||
|         } | ||||
|         if (a.m_stats.m_steps >= a.m_config.max_moves) | ||||
|             a.m_config.max_moves_base += 100; | ||||
|     } | ||||
| 
 | ||||
|     template<typename num_t> | ||||
|     void arith_clausal<num_t>::move_arith_variable() { | ||||
| 
 | ||||
|         var_t v = null_arith_var; | ||||
| 
 | ||||
|         { | ||||
|             a.m_best_score = 1; | ||||
|             flet<bool> _use_tabu(a.m_use_tabu, true); | ||||
|             if (v == null_arith_var) { | ||||
|                 add_lookahead_on_unsat_vars(); | ||||
|                 v = critical_move_on_updates(unsat_var_move); | ||||
|             } | ||||
|             if (v == null_arith_var) { | ||||
|                 add_lookahead_on_false_literals(); | ||||
|                 v = critical_move_on_updates(false_literal_move); | ||||
|             } | ||||
|         } | ||||
| 
 | ||||
|         // tabu flips were not possible
 | ||||
| 
 | ||||
|         if (v == null_arith_var) | ||||
|             ctx.shift_weights(); | ||||
| 
 | ||||
|         if (v == null_arith_var) { | ||||
|             a.m_best_score = -1; | ||||
|             flet<bool> _use_tabu(a.m_use_tabu, false); | ||||
|             add_lookahead_on_unsat_vars(); | ||||
|             v = random_move_on_updates(); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     template<typename num_t> | ||||
|     void arith_clausal<num_t>::add_lookahead_on_unsat_vars() { | ||||
|         a.m_updates.reset(); | ||||
|         a.m_fixed_atoms.reset(); | ||||
|         TRACE("arith_verbose", tout << "unsat-vars "; | ||||
|         for (auto v : ctx.unsat_vars()) | ||||
|             if (a.get_ineq(v)) tout << mk_bounded_pp(ctx.atom(v), a.m) << " "; | ||||
|         tout << "\n";); | ||||
| 
 | ||||
|         for (auto v : ctx.unsat_vars()) { | ||||
| 
 | ||||
|             auto* ineq = a.get_ineq(v); | ||||
|             if (!ineq) | ||||
|                 continue; | ||||
|             auto e = ctx.atom(v); | ||||
|             auto& i = a.get_bool_info(e);   | ||||
|             auto const& vars = a.get_fixable_exprs(e);             | ||||
|             for (auto v : vars) | ||||
|                 a.add_lookahead(i, v); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     /**
 | ||||
|     * \brief walk over literals that are false in some clause. | ||||
|     * Try to determine if flipping them to true improves the overall score. | ||||
|     */ | ||||
|     template<typename num_t> | ||||
|     void arith_clausal<num_t>::add_lookahead_on_false_literals() { | ||||
|         a.m_updates.reset(); | ||||
|         a.m_fixed_atoms.reset(); | ||||
| 
 | ||||
|         for (auto bv : a.m_bool_var_atoms) { | ||||
|             if (ctx.unsat_vars().contains(bv)) | ||||
|                 continue; | ||||
|             auto* ineq = a.get_ineq(bv);             | ||||
|             if (!ineq) | ||||
|                 continue; | ||||
|             sat::literal lit(bv, !ineq->is_true()); | ||||
|             auto const& ul = ctx.get_use_list(~lit); | ||||
|             if (ul.begin() == ul.end()) | ||||
|                 continue; | ||||
|             auto v = lit.var(); | ||||
|             // literal is false in some clause but none of the clauses where it occurs false are unsat.
 | ||||
| 
 | ||||
|             auto e = ctx.atom(v); | ||||
|             auto& i = a.get_bool_info(e); | ||||
|             a.add_lookahead(i, v); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     template<typename num_t> | ||||
|     var_t arith_clausal<num_t>::critical_move_on_updates(move_t mt) { | ||||
|         if (a.m_updates.empty()) | ||||
|             return null_arith_var; | ||||
|         std::stable_sort(a.m_updates.begin(), a.m_updates.end(), [](auto const& a, auto const& b) { return a.m_var < b.m_var || (a.m_var == b.m_var && a.m_delta < b.m_delta); }); | ||||
|         m_last_var = null_arith_var; | ||||
|         m_last_delta = 0; | ||||
|         m_best_var = null_arith_var; | ||||
|         m_best_delta = 0; | ||||
|         m_best_abs_value = num_t(-1); | ||||
|         m_best_last_step = UINT_MAX; | ||||
|         for (auto const& u : a.m_updates) | ||||
|             lookahead(u.m_var, u.m_delta); | ||||
|         critical_move(m_best_var, m_best_delta, mt);         | ||||
|         return m_best_var; | ||||
|     } | ||||
| 
 | ||||
|     template<typename num_t> | ||||
|     var_t arith_clausal<num_t>::random_move_on_updates() { | ||||
|         if (a.m_updates.empty()) | ||||
|             return null_arith_var; | ||||
|         unsigned idx = ctx.rand(a.m_updates.size()); | ||||
|         auto& [v, delta, score] = a.m_updates[idx]; | ||||
|         if (!a.can_update_num(v, delta)) | ||||
|             return null_arith_var; | ||||
|         critical_move(v, delta, random_move); | ||||
|         return v; | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     template<typename num_t> | ||||
|     void arith_clausal<num_t>::lookahead(var_t v, num_t const& delta) { | ||||
|         if (v == m_last_var && delta == m_last_delta) | ||||
|             return; | ||||
|         if (delta == 0) | ||||
|             return; | ||||
|         m_last_var = v; | ||||
|         m_last_delta = delta; | ||||
|         if (!a.can_update_num(v, delta)) | ||||
|             return; | ||||
|         auto score = get_score(v, delta); | ||||
|         auto& vi = a.m_vars[v];         | ||||
|         num_t abs_value = abs(vi.value() + delta); | ||||
|         unsigned last_step = vi.last_step(delta); | ||||
|         if (score < a.m_best_score) | ||||
|             return; | ||||
|         if (score > a.m_best_score || | ||||
|             (m_best_abs_value == -1) || | ||||
|             (abs_value < m_best_abs_value) || | ||||
|             (abs_value == m_best_abs_value && last_step < m_best_last_step)) { | ||||
|             a.m_best_score = score; | ||||
|             m_best_var = v; | ||||
|             m_best_delta = delta; | ||||
|             m_best_last_step = last_step; | ||||
|             m_best_abs_value = abs_value; | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     template<typename num_t> | ||||
|     void arith_clausal<num_t>::critical_move(var_t v, num_t const& delta, move_t mt) { | ||||
|         if (v == null_arith_var) | ||||
|             return; | ||||
|         a.m_last_delta = delta; | ||||
|         a.m_last_var = v; | ||||
|         TRACE("arith", tout << mt << " v" << v << " " << mk_bounded_pp(a.m_vars[v].m_expr, a.m)  | ||||
|                             << " += " << delta << " score:" << a.m_best_score << "\n"); | ||||
|         a.m_vars[v].set_step(a.m_stats.m_steps, a.m_stats.m_steps + 3 + ctx.rand(10), delta); | ||||
|         VERIFY(a.update_num(v, delta)); | ||||
|         for (auto bv : a.m_vars[v].m_bool_vars_of)  | ||||
|             if (a.get_ineq(bv) && a.get_ineq(bv)->is_true() != ctx.is_true(bv))  | ||||
|                 ctx.flip(bv);    | ||||
| 
 | ||||
|         DEBUG_CODE( | ||||
|             for (sat::bool_var bv = 0; bv < ctx.num_bool_vars(); ++bv) { | ||||
|                 if (a.get_ineq(bv) && a.get_ineq(bv)->is_true() != ctx.is_true(bv)) { | ||||
|                     TRACE("arith", tout << bv << " " << *a.get_ineq(bv) << "\n"; | ||||
|                     tout << a.m_vars[v].m_bool_vars_of << "\n"); | ||||
|                 } | ||||
|                 VERIFY(!a.get_ineq(bv) || a.get_ineq(bv)->is_true() == ctx.is_true(bv)); | ||||
|             }); | ||||
|     } | ||||
| 
 | ||||
|     template<typename num_t> | ||||
|     double arith_clausal<num_t>::get_score(var_t v, num_t const& delta) { | ||||
|         auto& vi = a.m_vars[v]; | ||||
|         VERIFY(a.update_num(v, delta)); | ||||
|         double score = 0; | ||||
|         for (auto ci : vi.m_clauses_of) { | ||||
|             auto const& c = ctx.get_clause(ci); | ||||
|             unsigned num_true = 0; | ||||
|             for (auto lit : c) { | ||||
|                 auto bv = lit.var(); | ||||
|                 auto ineq = a.get_ineq(bv); | ||||
|                 if (ineq) { | ||||
|                     if (ineq->is_true() != lit.sign()) | ||||
|                         ++num_true; | ||||
|                 } | ||||
|                 else if (ctx.is_true(lit)) | ||||
|                     ++num_true; | ||||
|             } | ||||
|             CTRACE("arith_verbose", c.m_num_trues != num_true && (c.m_num_trues == 0 || num_true == 0), | ||||
|                 tout << "clause: " << c | ||||
|                 << " v" << v << " += " << delta | ||||
|                 << " new-true lits: " << num_true | ||||
|                 << " old-true lits: " << c.m_num_trues | ||||
|                 << " w: " << c.m_weight << "\n"; | ||||
|             for (auto lit : c)  | ||||
|                 if (a.get_ineq(lit.var())) | ||||
|                     tout << lit << " " << *a.get_ineq(lit.var()) << "\n";); | ||||
|             if (c.m_num_trues > 0 && num_true == 0)  | ||||
|                 score -= c.m_weight;             | ||||
|             else if (c.m_num_trues == 0 && num_true > 0)  | ||||
|                 score += c.m_weight;             | ||||
|         } | ||||
|         // revert the update
 | ||||
|         VERIFY(a.update_num(v, -delta)); | ||||
|         return score; | ||||
|     } | ||||
| 
 | ||||
|     template<typename num_t> | ||||
|     void arith_clausal<num_t>::check_restart() { | ||||
|         if (m_no_improve <= 500000) | ||||
|             return; | ||||
|          | ||||
|         IF_VERBOSE(2, verbose_stream() << "restart sls-arith\n"); | ||||
|         TRACE("arith", tout << "restart\n";); | ||||
|         // reset values of (arithmetical) variables at bounds.
 | ||||
|         for (auto& vi : a.m_vars) { | ||||
|             if (vi.m_lo && !vi.m_lo->is_strict && vi.m_lo->value > 0) | ||||
|                 vi.set_value(vi.m_lo->value); | ||||
|             else if (vi.m_hi && !vi.m_hi->is_strict && vi.m_hi->value < 0) | ||||
|                 vi.set_value(vi.m_hi->value); | ||||
|             else | ||||
|                 vi.set_value(num_t(0)); | ||||
|             vi.m_bool_vars_of.reset(); | ||||
|             vi.m_clauses_of.reset(); | ||||
|         } | ||||
|         initialize(); | ||||
|     } | ||||
| 
 | ||||
|     template<typename num_t> | ||||
|     void arith_clausal<num_t>::initialize() { | ||||
|         a.initialize_bool_assignment(); | ||||
|         for (sat::bool_var v = 0; v < ctx.num_bool_vars(); ++v) | ||||
|             a.init_bool_var_assignment(v);    | ||||
|          | ||||
|         m_best_found_cost_bool = ctx.unsat().size(); | ||||
|         m_best_found_cost_arith = ctx.unsat().size(); | ||||
|         m_best_found_cost_restart = ctx.unsat().size(); | ||||
|         m_no_improve = 0; | ||||
|         m_no_improve_bool = 0; | ||||
|         m_no_improve_arith = 0; | ||||
|     }     | ||||
|     | ||||
| 
 | ||||
|     template<typename num_t> | ||||
|     bool arith_clausal<num_t>::update_outer_best_solution() { | ||||
|         if (ctx.unsat().size() >= m_best_found_cost_bool) | ||||
|             return false; | ||||
|         m_best_found_cost_bool = ctx.unsat().size(); | ||||
|         return true; | ||||
|     } | ||||
| 
 | ||||
|     template<typename num_t> | ||||
|     void arith_clausal<num_t>::enter_bool_mode() { | ||||
|         CTRACE("arith", !m_bool_mode, tout << "enter bool mode\n";); | ||||
|         m_best_found_cost_bool = ctx.unsat().size(); | ||||
|         if (!m_bool_mode)  | ||||
|             m_no_improve_bool = 0;  | ||||
|         m_bool_mode = true; | ||||
|     } | ||||
| 
 | ||||
|     template<typename num_t> | ||||
|     bool arith_clausal<num_t>::update_inner_best_solution() { | ||||
|         if (ctx.unsat().size() >= m_best_found_cost_arith) | ||||
|             return false; | ||||
|         m_best_found_cost_arith = ctx.unsat().size(); | ||||
|         return true; | ||||
|     } | ||||
| 
 | ||||
|     template<typename num_t> | ||||
|     void arith_clausal<num_t>::enter_arith_mode() { | ||||
|         CTRACE("arith", m_bool_mode, tout << "enter arith mode\n";); | ||||
|         m_best_found_cost_arith = ctx.unsat().size(); | ||||
|         if (m_bool_mode) | ||||
|             m_no_improve_arith = 0; | ||||
|         m_bool_mode = false; | ||||
|     } | ||||
| 
 | ||||
|     template<typename num_t> | ||||
|     bool arith_clausal<num_t>::update_best_solution() { | ||||
|         bool improved = false; | ||||
|         if (ctx.unsat().size() < m_best_found_cost_restart) { | ||||
|             improved = true; | ||||
|             m_best_found_cost_restart = ctx.unsat().size(); | ||||
|         } | ||||
|         if (ctx.unsat().size() < m_best_found_cost) { | ||||
|             improved = true; | ||||
|             m_best_found_cost = ctx.unsat().size(); | ||||
|         } | ||||
|         return improved; | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| template class sls::arith_clausal<checked_int64<true>>; | ||||
| template class sls::arith_clausal<rational>; | ||||
| 
 | ||||
							
								
								
									
										90
									
								
								src/ast/sls/sls_arith_clausal.h
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										90
									
								
								src/ast/sls/sls_arith_clausal.h
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,90 @@ | |||
| /*++
 | ||||
| Copyright (c) 2025 Microsoft Corporation | ||||
| 
 | ||||
| Module Name: | ||||
| 
 | ||||
|     sls_arith_clausal | ||||
| 
 | ||||
| Abstract: | ||||
| 
 | ||||
|     Theory plugin for arithmetic local search  | ||||
|     based on clausal search as used in HybridSMT | ||||
| 
 | ||||
| Author: | ||||
| 
 | ||||
|     Nikolaj Bjorner (nbjorner) 2025-01-16 | ||||
| 
 | ||||
| --*/ | ||||
| #pragma once | ||||
| 
 | ||||
| #include "util/checked_int64.h" | ||||
| #include "util/optional.h" | ||||
| #include "ast/ast_trail.h" | ||||
| #include "ast/arith_decl_plugin.h" | ||||
| #include "ast/sls/sls_context.h" | ||||
| 
 | ||||
| namespace sls { | ||||
|      | ||||
|     template<typename num_t> | ||||
|     class arith_base; | ||||
| 
 | ||||
|     using var_t = unsigned; | ||||
| 
 | ||||
|     template<typename num_t> | ||||
|     class arith_clausal { | ||||
|         context& ctx; | ||||
|         class arith_base<num_t>& a; | ||||
| 
 | ||||
|         void check_restart(); | ||||
|         void initialize(); | ||||
| 
 | ||||
|         enum move_t { | ||||
|             unsat_var_move, | ||||
|             false_literal_move, | ||||
|             random_move | ||||
|         }; | ||||
|         friend std::ostream& operator<<(std::ostream& out, move_t mt) {  | ||||
|             return out << (mt == unsat_var_move ?  | ||||
|                 "unsat-var" : mt == false_literal_move ?  | ||||
|                 "false-literal" : "random");  | ||||
|         } | ||||
|         void enter_arith_mode(); | ||||
|         void enter_bool_mode(); | ||||
| 
 | ||||
|         bool update_outer_best_solution(); | ||||
|         bool update_inner_best_solution(); | ||||
|         bool update_best_solution(); | ||||
|         void move_arith_variable(); | ||||
|         var_t critical_move_on_updates(move_t mt); | ||||
|         var_t random_move_on_updates(); | ||||
|         void add_lookahead_on_unsat_vars(); | ||||
|         void add_lookahead_on_false_literals(); | ||||
|         void critical_move(var_t v, num_t const& delta, move_t mt); | ||||
|         void lookahead(var_t v, num_t const& delta); | ||||
|         double get_score(var_t v, num_t const& delta); | ||||
|          | ||||
| 
 | ||||
|         unsigned m_no_improve_bool = 0; | ||||
|         unsigned m_no_improve_arith = 0; | ||||
|         unsigned m_no_improve = 0; | ||||
|         bool     m_bool_mode = true; | ||||
|         unsigned m_best_found_cost_bool = 0; | ||||
|         unsigned m_best_found_cost_arith = 0; | ||||
|         unsigned m_best_found_cost_restart = 0; | ||||
|         unsigned m_best_found_cost = 0; | ||||
|         num_t    m_best_abs_value; | ||||
|         num_t    m_best_delta; | ||||
|         var_t    m_best_var = UINT_MAX; | ||||
|         unsigned m_best_last_step = 0; | ||||
| 
 | ||||
|         // avoid checking the same updates twice
 | ||||
|         var_t m_last_var = UINT_MAX; | ||||
|         num_t m_last_delta; | ||||
| 
 | ||||
|     public: | ||||
|         arith_clausal(arith_base<num_t>& a); | ||||
|         void search(); | ||||
|     }; | ||||
| } | ||||
| 
 | ||||
| 
 | ||||
|  | @ -112,6 +112,18 @@ namespace sls { | |||
|             if (p) | ||||
|                 p->on_restart(); | ||||
|     } | ||||
| 
 | ||||
|     bool context::is_external(sat::bool_var v) { | ||||
|         auto a = atom(v); | ||||
|         if (!a) | ||||
|             return false; | ||||
|         family_id fid = get_fid(a); | ||||
|         if (fid == basic_family_id) | ||||
|             return false; | ||||
|         auto p = m_plugins.get(fid, nullptr); | ||||
|         CTRACE("sls_verbose", p != nullptr, tout << "external " << mk_bounded_pp(a, m) << "\n"); | ||||
|         return p != nullptr;      | ||||
|     } | ||||
|      | ||||
|     lbool context::check() { | ||||
|         //
 | ||||
|  | @ -438,6 +450,7 @@ namespace sls { | |||
| 
 | ||||
|     sat::literal context::mk_literal(expr* e) { | ||||
|         expr_ref _e(e, m); | ||||
|         SASSERT(!m_input_assertions.contains(e)); | ||||
|         sat::literal lit; | ||||
|         bool neg = false; | ||||
|         expr* a, * b, * c; | ||||
|  | @ -528,8 +541,11 @@ namespace sls { | |||
|         for (unsigned i = 0; i < m_atoms.size(); ++i) | ||||
|             if (m_atoms.get(i)) | ||||
|                 register_terms(m_atoms.get(i)); | ||||
|         for (auto e : m_input_assertions) | ||||
|             register_terms(e); | ||||
|         { | ||||
|             flet<bool> _is_input_assertion(m_is_input_assertion, true); | ||||
|             for (auto e : m_input_assertions) | ||||
|                 register_terms(e); | ||||
|         } | ||||
|         for (auto p : m_plugins) | ||||
|             if (p) | ||||
|                 p->initialize(); | ||||
|  | @ -564,7 +580,7 @@ namespace sls { | |||
|                         m_parents.reserve(arg->get_id() + 1); | ||||
|                         m_parents[arg->get_id()].push_back(e); | ||||
|                     } | ||||
|                     if (m.is_bool(e)) | ||||
|                     if (m.is_bool(e) && !m_is_input_assertion) | ||||
|                         mk_literal(e); | ||||
|                     visit(e); | ||||
|                 } | ||||
|  | @ -629,7 +645,6 @@ namespace sls { | |||
|         m_visited.reset(); | ||||
|         m_root_literals.reset(); | ||||
| 
 | ||||
| 
 | ||||
|         for (auto const& clause : s.clauses()) { | ||||
|             bool has_relevant = false; | ||||
|             unsigned n = 0; | ||||
|  |  | |||
|  | @ -69,12 +69,16 @@ namespace sls { | |||
|         virtual sat::clause_info const& get_clause(unsigned idx) const = 0; | ||||
|         virtual ptr_iterator<unsigned> get_use_list(sat::literal lit) = 0; | ||||
|         virtual void flip(sat::bool_var v) = 0; | ||||
|         virtual sat::bool_var bool_flip() = 0; | ||||
|         virtual bool try_rotate(sat::bool_var v, sat::bool_var_set& rotated, unsigned& budget) = 0; | ||||
|         virtual double reward(sat::bool_var v) = 0; | ||||
|         virtual double get_weigth(unsigned clause_idx) = 0; | ||||
|         virtual bool is_true(sat::literal lit) = 0; | ||||
|         virtual unsigned num_vars() const = 0; | ||||
|         virtual indexed_uint_set const& unsat() const = 0; | ||||
|         virtual indexed_uint_set const& unsat_vars() const = 0; | ||||
|         virtual void shift_weights() = 0; | ||||
|         virtual unsigned num_external_in_unsat_vars() const = 0; | ||||
|         virtual void on_model(model_ref& mdl) = 0; | ||||
|         virtual sat::bool_var add_var() = 0; | ||||
|         virtual void add_clause(unsigned n, sat::literal const* lits) = 0; | ||||
|  | @ -136,6 +140,7 @@ namespace sls { | |||
| 
 | ||||
|         void init(); | ||||
|         expr_ref_vector m_todo; | ||||
|         bool m_is_input_assertion = false; | ||||
|         void register_terms(expr* e); | ||||
|         void register_term(expr* e); | ||||
| 
 | ||||
|  | @ -162,6 +167,7 @@ namespace sls { | |||
|         void register_atom(sat::bool_var v, expr* e); | ||||
|         lbool check();        | ||||
| 
 | ||||
|         bool is_external(sat::bool_var v); | ||||
|         void on_restart(); | ||||
|         void updt_params(params_ref const& p); | ||||
|         params_ref const& get_params() const { return m_params;  } | ||||
|  | @ -183,9 +189,13 @@ namespace sls { | |||
|         void add_theory_axiom(expr* f) { add_assertion(f, false); } | ||||
|         void add_clause(sat::literal_vector const& lits); | ||||
|         void flip(sat::bool_var v) { s.flip(v); } | ||||
|         sat::bool_var bool_flip() { return s.bool_flip(); } | ||||
|         void shift_weights() { s.shift_weights(); } | ||||
|         bool try_rotate(sat::bool_var v, sat::bool_var_set& rotated, unsigned& budget) { return s.try_rotate(v, rotated, budget); } | ||||
|         double reward(sat::bool_var v) { return s.reward(v); } | ||||
|         indexed_uint_set const& unsat() const { return s.unsat(); } | ||||
|         indexed_uint_set const& unsat_vars() const { return s.unsat_vars(); } | ||||
|         unsigned num_external_in_unsat_vars() const { return s.num_external_in_unsat_vars(); } | ||||
|         unsigned rand() { return m_rand(); } | ||||
|         unsigned rand(unsigned n) { return m_rand(n); } | ||||
|         reslimit& rlimit() { return s.rlimit(); } | ||||
|  |  | |||
|  | @ -124,6 +124,8 @@ namespace sls { | |||
|                 m_ddfw->reinit(); | ||||
|         } | ||||
| 
 | ||||
|         void shift_weights() override { m_ddfw->shift_weights(); } | ||||
| 
 | ||||
|         lbool on_save_model() override; | ||||
| 
 | ||||
|         void on_model(model_ref& mdl) override { | ||||
|  | @ -131,6 +133,14 @@ namespace sls { | |||
|             m_sls_model = mdl; | ||||
|         } | ||||
| 
 | ||||
|         sat::bool_var bool_flip() override { | ||||
|             return m_ddfw->bool_flip(); | ||||
|         } | ||||
| 
 | ||||
|         bool is_external(sat::bool_var v) override { | ||||
|             return m_context.is_external(v); | ||||
|         } | ||||
| 
 | ||||
|         void on_rescale() override {} | ||||
| 
 | ||||
|         reslimit& rlimit() override { return m_ddfw->rlimit(); } | ||||
|  | @ -160,6 +170,8 @@ namespace sls { | |||
|         } | ||||
|         unsigned num_vars() const override { return m_ddfw->num_vars(); } | ||||
|         indexed_uint_set const& unsat() const override { return m_ddfw->unsat_set(); } | ||||
|         indexed_uint_set const& unsat_vars() const override { return m_ddfw->unsat_vars(); } | ||||
|         unsigned num_external_in_unsat_vars() const override { return m_ddfw->num_external_in_unsat_vars(); } | ||||
|         sat::bool_var add_var() override {  | ||||
|             return m_ddfw->add_var();  | ||||
|         } | ||||
|  |  | |||
|  | @ -67,10 +67,14 @@ namespace sls { | |||
|             return r; | ||||
|         } | ||||
| 
 | ||||
|         void on_model(model_ref& mdl) override {            | ||||
|         void on_model(model_ref& mdl) override { | ||||
|             m_model = mdl; | ||||
|         } | ||||
| 
 | ||||
|         bool is_external(sat::bool_var v) override { | ||||
|             return m_context.is_external(v); | ||||
|         } | ||||
| 
 | ||||
|         void register_atom(sat::bool_var v, expr* e) { | ||||
|             m_context.register_atom(v, e); | ||||
|         } | ||||
|  | @ -85,15 +89,19 @@ namespace sls { | |||
|         sat::clause_info const& get_clause(unsigned idx) const override { return m_ddfw.get_clause_info(idx); } | ||||
|         ptr_iterator<unsigned> get_use_list(sat::literal lit) override { return m_ddfw.use_list(lit); } | ||||
|         void flip(sat::bool_var v) override { if (m_dirty) m_ddfw.reinit(), m_dirty = false;  m_ddfw.flip(v); } | ||||
|         sat::bool_var bool_flip() override { if (m_dirty) m_ddfw.reinit(), m_dirty = false; return m_ddfw.bool_flip(); } | ||||
|         bool try_rotate(sat::bool_var v, sat::bool_var_set& rotated, unsigned& budget) override { if (m_dirty) m_ddfw.reinit(), m_dirty = false; return m_ddfw.try_rotate(v, rotated, budget); } | ||||
|         double reward(sat::bool_var v) override { return m_ddfw.reward(v); } | ||||
|         double get_weigth(unsigned clause_idx) override { return m_ddfw.get_clause_info(clause_idx).m_weight; } | ||||
|         bool is_true(sat::literal lit) override { return m_ddfw.get_value(lit.var()) != lit.sign(); } | ||||
|         unsigned num_vars() const override { return m_ddfw.num_vars(); } | ||||
|         indexed_uint_set const& unsat() const override { return m_ddfw.unsat_set(); } | ||||
|         indexed_uint_set const& unsat_vars() const override { return m_ddfw.unsat_vars(); } | ||||
|         unsigned num_external_in_unsat_vars() const override { return m_ddfw.num_external_in_unsat_vars(); } | ||||
|         sat::bool_var add_var() override { m_dirty = true;  return m_ddfw.add_var(); }   | ||||
|         void add_input_assertion(expr* f) { m_context.add_input_assertion(f); } | ||||
|         reslimit& rlimit() { return m_ddfw.rlimit(); } | ||||
|         void shift_weights() override { m_ddfw.shift_weights(); } | ||||
| 
 | ||||
|         void force_restart() override { m_ddfw.force_restart(); } | ||||
| 
 | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue