From b0dd83cc600b140b49434a10b325789c005659d0 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Mon, 21 Oct 2024 13:27:01 -0700 Subject: [PATCH] debugging parallel integration --- src/ast/sls/sat_ddfw.cpp | 34 ++++++++++++++++-------- src/ast/sls/sat_ddfw.h | 5 +++- src/ast/sls/sls_context.cpp | 4 --- src/sat/smt/sls_solver.cpp | 53 ++++++++++++++++++++++++------------- 4 files changed, 62 insertions(+), 34 deletions(-) diff --git a/src/ast/sls/sat_ddfw.cpp b/src/ast/sls/sat_ddfw.cpp index 581c1a355..5473271ff 100644 --- a/src/ast/sls/sat_ddfw.cpp +++ b/src/ast/sls/sat_ddfw.cpp @@ -71,7 +71,6 @@ namespace sat { else shift_weights(), m_plugin->on_rescale(); //verbose_stream() << "steps: " << steps << " min_sz: " << m_min_sz << " unsat: " << m_unsat.size() << "\n"; ++steps; - SASSERT(m_unsat.size() >= m_min_sz); } } catch (z3_exception& ex) { @@ -152,7 +151,7 @@ namespace sat { return m_unsat_vars.elem_at(m_rand(m_unsat_vars.size())); } - void ddfw::add(unsigned n, literal const* c) { + void ddfw::add(unsigned n, literal const* c) { unsigned idx = m_clauses.size(); m_clauses.push_back(clause_info(n, c, m_config.m_init_clause_weight)); if (n > 2) @@ -207,8 +206,9 @@ namespace sat { for (unsigned v = 0; v < num_vars(); ++v) { value(v) = (m_rand() % 2) == 0; // m_use_list[lit.index()].size() >= m_use_list[nlit.index()].size(); } - init_clause_data(); - flatten_use_list(); + + if (!flatten_use_list()) + init_clause_data(); m_reinit_count = 0; m_reinit_next = m_config.m_reinit_base; @@ -216,21 +216,23 @@ namespace sat { m_restart_count = 0; m_restart_next = m_config.m_restart_base*2; - m_min_sz = m_unsat.size(); + m_min_sz = m_clauses.size(); m_flips = 0; m_last_flips = 0; m_shifts = 0; m_stopwatch.start(); - verbose_stream() << "unsat " << m_min_sz << "\n"; } void ddfw::reinit() { add_assumptions(); - init_clause_data(); flatten_use_list(); } - void ddfw::flatten_use_list() { + bool ddfw::flatten_use_list() { + if (num_vars() == m_use_list_vars && m_clauses.size() == m_use_list_clauses) + return false; + m_use_list_vars = num_vars(); + m_use_list_clauses = m_clauses.size(); m_use_list_index.reset(); m_flat_use_list.reset(); for (auto const& ul : m_use_list) { @@ -238,6 +240,8 @@ namespace sat { m_flat_use_list.append(ul); } m_use_list_index.push_back(m_flat_use_list.size()); + init_clause_data(); + return true; } void ddfw::flip(bool_var v) { @@ -246,7 +250,7 @@ namespace sat { literal nlit = ~lit; SASSERT(is_true(lit)); for (unsigned cls_idx : use_list(lit)) { - clause_info& ci = m_clauses[cls_idx]; + clause_info& ci = m_clauses[cls_idx]; ci.del(lit); double w = ci.m_weight; // cls becomes false: flip any variable in clause to receive reward w @@ -402,8 +406,16 @@ namespace sat { } void ddfw::save_best_values() { - if ((m_unsat.size() < m_min_sz || m_unsat.empty()) && - ((m_unsat.size() < 50 || m_min_sz * 10 > m_unsat.size() * 11))) + if (m_save_best_values) + return; + if (m_plugin && !m_unsat.empty()) + return; + flet _save_best_values(m_save_best_values, true); + + bool do_save_model = ((m_unsat.size() < m_min_sz || m_unsat.empty()) && + ((m_unsat.size() < 50 || m_min_sz * 10 > m_unsat.size() * 11))); + + if (do_save_model) save_model(); if (m_unsat.size() < m_min_sz) { diff --git a/src/ast/sls/sat_ddfw.h b/src/ast/sls/sat_ddfw.h index 60e9424a7..a00a196c9 100644 --- a/src/ast/sls/sat_ddfw.h +++ b/src/ast/sls/sat_ddfw.h @@ -89,6 +89,7 @@ namespace sat { vector m_use_list; unsigned_vector m_flat_use_list; unsigned_vector m_use_list_index; + unsigned m_use_list_vars = 0, m_use_list_clauses = 0; indexed_uint_set m_unsat; indexed_uint_set m_unsat_vars; // set of variables that are in unsat clauses @@ -102,11 +103,12 @@ namespace sat { u_map m_models; stopwatch m_stopwatch; unsigned_vector m_num_models; + bool m_save_best_values = false; scoped_ptr m_plugin = nullptr; std::function m_parallel_sync; - void flatten_use_list(); + bool flatten_use_list(); /** * TBD: map reward value to a score, possibly through an exponential function, such as @@ -267,6 +269,7 @@ namespace sat { ptr_iterator use_list(literal lit) { + flatten_use_list(); unsigned i = lit.index(); auto const* b = m_flat_use_list.data() + m_use_list_index[i]; auto const* e = m_flat_use_list.data() + m_use_list_index[i + 1]; diff --git a/src/ast/sls/sls_context.cpp b/src/ast/sls/sls_context.cpp index 4dcfa8bbb..507bee701 100644 --- a/src/ast/sls/sls_context.cpp +++ b/src/ast/sls/sls_context.cpp @@ -116,7 +116,6 @@ namespace sls { // Use timestamps to make it incremental. // init(); - //verbose_stream() << "check " << unsat().size() << "\n"; while (unsat().empty() && m.inc()) { propagate_boolean_assignment(); @@ -124,9 +123,6 @@ namespace sls { // verbose_stream() << "propagate " << unsat().size() << " " << m_new_constraint << "\n"; - - // display(verbose_stream()); - if (m_new_constraint || !unsat().empty()) return l_undef; diff --git a/src/sat/smt/sls_solver.cpp b/src/sat/smt/sls_solver.cpp index 1c5dbd79b..cfdf30c64 100644 --- a/src/sat/smt/sls_solver.cpp +++ b/src/sat/smt/sls_solver.cpp @@ -49,10 +49,10 @@ namespace sls { ast_manager m_sync_manager; obj_map m_sls2sync_uninterp; // hashtable from sls-uninterp to sync uninterp obj_map m_smt2sync_uninterp; // hashtable from external uninterp to sync uninterp + ast_translation m_smt2sync_tr, m_smt2sls_tr; expr_ref_vector m_sync_uninterp; expr_ref_vector m_sync_values; - std::atomic m_has_new_sls_values = false; - + std::atomic m_has_new_sls_values = false; // export from SAT to SLS: // - unit literals @@ -64,8 +64,10 @@ namespace sls { std::lock_guard lock(s.m_mutex); IF_VERBOSE(1, verbose_stream() << "SMT -> SLS units " << s.m_units << "\n"); for (auto lit : s.m_units) - if (m_shared_vars.contains(lit.var())) + if (m_shared_vars.contains(lit.var())) { + IF_VERBOSE(1, verbose_stream() << "unit " << lit << "\n"); m_ddfw->add(1, &lit); + } s.m_has_units = false; s.m_units.reset(); return true; @@ -87,7 +89,7 @@ namespace sls { bool export_to_sls() { bool updated = false; - if (false && export_units_to_sls()) + if (export_units_to_sls()) updated = true; if (export_phase_to_sls()) updated = true; @@ -127,7 +129,10 @@ namespace sls { smt_plugin(ast_manager& m, solver& s, sat::ddfw* d) : m(m), s(s), m_ddfw(d), m_context(m, *this), m_sync_uninterp(m_sync_manager), - m_sync_values(m_sync_manager) { + m_sync_values(m_sync_manager), + m_smt2sync_tr(s.ctx.get_manager(), m_sync_manager), + m_smt2sls_tr(s.ctx.get_manager(), m) + { } uint_set m_shared_vars; @@ -140,10 +145,8 @@ namespace sls { svector m_rewards; void add_uninterp(expr* smt_t) { - ast_translation tr1(s.ctx.get_manager(), m_sync_manager); - ast_translation tr2(s.ctx.get_manager(), m); - auto sync_t = tr1(smt_t); - auto sls_t = tr2(smt_t); + auto sync_t = m_smt2sync_tr(smt_t); + auto sls_t = m_smt2sls_tr(smt_t); m_sync_uninterp.push_back(sync_t); m_smt2sync_uninterp.insert(smt_t, sync_t); m_sls2sync_uninterp.insert(sls_t, sync_t); @@ -151,6 +154,7 @@ namespace sls { void add_shared_var(sat::bool_var v) { m_sls_phase.reserve(v + 1); + m_sat_phase.reserve(v + 1); m_rewards.reserve(v + 1); m_shared_vars.insert(v); } @@ -175,11 +179,11 @@ namespace sls { m_ddfw->reinit(); m_new_clause_added = false; } - import_from_sls(); + //import_from_sls(); } void on_model(model_ref& mdl) override { - IF_VERBOSE(1, verbose_stream() << "on-model " << "\n"); + IF_VERBOSE(3, verbose_stream() << "on-model " << "\n"); s.m_sls_model = mdl; } @@ -196,13 +200,19 @@ namespace sls { vector const& clauses() const override { return m_ddfw->clauses(); } sat::clause_info const& get_clause(unsigned idx) const override { return m_ddfw->get_clause_info(idx); } ptr_iterator get_use_list(sat::literal lit) override { return m_ddfw->use_list(lit); } - void flip(sat::bool_var v) override { m_ddfw->flip(v); } + void flip(sat::bool_var v) override { + m_ddfw->flip(v); + } double reward(sat::bool_var v) override { return m_ddfw->get_reward(v); } double get_weigth(unsigned clause_idx) override { return m_ddfw->get_clause_info(clause_idx).m_weight; } - bool is_true(sat::literal lit) override { return m_ddfw->get_value(lit.var()) != lit.sign(); } + bool is_true(sat::literal lit) override { + return m_ddfw->get_value(lit.var()) != lit.sign(); + } unsigned num_vars() const override { return m_ddfw->num_vars(); } indexed_uint_set const& unsat() const override { return m_ddfw->unsat_set(); } - sat::bool_var add_var() override { return m_ddfw->add_var(); } + sat::bool_var add_var() override { + return m_ddfw->add_var(); + } void add_clause(unsigned n, sat::literal const* lits) override { m_ddfw->add(n, lits); m_new_clause_added = true; @@ -238,8 +248,8 @@ namespace sls { m_has_new_sat_phase = true; s.s().set_has_new_best_phase(false); std::lock_guard lock(s.m_mutex); - for (unsigned i = 0; i < m_sat_phase.size(); ++i) - m_sat_phase[i] = s.s().get_best_phase(i); + for (auto v : m_shared_vars) + m_sat_phase[v] = s.s().get_best_phase(v); } void export_activity_to_smt() { @@ -271,13 +281,18 @@ namespace sls { } void solver::pop_core(unsigned n) { + if (!m_smt_plugin) + return; for (; m_trail_lim < s().init_trail_size(); ++m_trail_lim) { auto lit = s().trail_literal(m_trail_lim); + if (!m_smt_plugin->m_shared_vars.contains(lit.var())) + continue; + IF_VERBOSE(10, verbose_stream() << "push unit " << lit << " " << mk_bounded_pp(ctx.literal2expr(lit), m) << "\n"); std::lock_guard lock(m_mutex); m_units.push_back(lit); m_has_units = true; - } - if (s().at_base_lvl() && s().has_new_best_phase()) + } + if (s().has_new_best_phase()) m_smt_plugin->import_phase_from_smt(); m_smt_plugin->export_phase_to_smt(); @@ -334,6 +349,7 @@ namespace sls { } void solver::local_search_done() { + IF_VERBOSE(1, verbose_stream() << "local-search-done\n"); m_completed = false; CTRACE("sls", m_smt_plugin, m_smt_plugin->display(tout)); @@ -356,6 +372,7 @@ namespace sls { void solver::run_local_search_async() { if (m_ddfw) { m_result = m_ddfw->check(0, nullptr); + IF_VERBOSE(1, verbose_stream() << "sls-result " << m_result << "\n"); m_completed = true; } }