From 8a6cbec4f088ff0a85df85b34a43874ccf1af989 Mon Sep 17 00:00:00 2001 From: Ilana Shapiro Date: Wed, 30 Jul 2025 15:55:03 -0700 Subject: [PATCH] fix bug in parallel solving batch setup --- src/smt/smt_parallel.cpp | 33 +++++++++++++++++++++++---------- 1 file changed, 23 insertions(+), 10 deletions(-) diff --git a/src/smt/smt_parallel.cpp b/src/smt/smt_parallel.cpp index b45d82389..4d55029af 100644 --- a/src/smt/smt_parallel.cpp +++ b/src/smt/smt_parallel.cpp @@ -304,10 +304,10 @@ namespace smt { struct BatchManager { std::mutex mtx; std::vector batches; - size_t batch_idx = 0; - size_t batch_size = 1; // num batches + unsigned batch_idx = 0; + unsigned batch_size = 1; - BatchManager(size_t batch_size) : batch_size(batch_size) {} + BatchManager(unsigned batch_size) : batch_size(batch_size) {} // translate the next SINGLE batch of batch_size cubes to the thread expr_ref_vector get_next_batch( @@ -316,6 +316,7 @@ namespace smt { ) { std::lock_guard lock(mtx); expr_ref_vector cube_batch(thread_m); // ensure bound to thread manager + if (batch_idx >= batches.size()) return cube_batch; for (expr* cube : batches[batch_idx]) { cube_batch.push_back( @@ -324,6 +325,7 @@ namespace smt { } ++batch_idx; + std::cout << "Thread batch " << batch_idx - 1 << " size: " << cube_batch.size() << "\n"; return cube_batch; } @@ -347,7 +349,7 @@ namespace smt { } std::cout << "Top lits:\n"; - for (size_t j = 0; j < top_lits.size(); ++j) { + for (unsigned j = 0; j < top_lits.size(); ++j) { std::cout << " [" << j << "] " << top_lits[j].get() << "\n"; } @@ -380,12 +382,23 @@ namespace smt { std::lock_guard lock(mtx); std::vector cube_batches; - size_t num_batches = 0; - while (num_batches < batch_size) { - expr_ref_vector cube_batch = cube_batch_pq(main_ctx); - cube_batches.push_back(cube_batch); - num_batches += cube_batch.size(); + // Get all cubes in the main context's manager + expr_ref_vector all_cubes = cube_batch_pq(main_ctx); + + ast_manager &m = main_ctx.get_manager(); + + // Partition into batches + for (unsigned start = 0; start < all_cubes.size(); start += batch_size) { + expr_ref_vector batch(m); + + unsigned end = std::min(start + batch_size, all_cubes.size()); + for (unsigned j = start; j < end; ++j) { + batch.push_back(all_cubes[j].get()); + } + + cube_batches.push_back(std::move(batch)); } + return cube_batches; } }; @@ -405,7 +418,7 @@ namespace smt { auto next_batch = batch_manager.get_next_batch(ctx.m, *pms[i]); if (next_batch.empty()) break; // No more work - lbool r = worker_thread(i, next_batch); + worker_thread(i, next_batch); } });