3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-08-05 02:40:24 +00:00
z3/src/smt/smt_parallel.cpp
2025-07-23 15:26:02 -07:00

315 lines
12 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/*++
Copyright (c) 2020 Microsoft Corporation
Module Name:
smt_parallel.cpp
Abstract:
Parallel SMT, portfolio loop specialized to SMT core.
Author:
nbjorner 2020-01-31
--*/
#include "util/scoped_ptr_vector.h"
#include "ast/ast_util.h"
#include "ast/ast_pp.h"
#include "ast/ast_ll_pp.h"
#include "ast/ast_translation.h"
#include "smt/smt_parallel.h"
#include "smt/smt_lookahead.h"
#ifdef SINGLE_THREAD
namespace smt {

    /**
       \brief Stub used when the build defines SINGLE_THREAD.

       No worker threads are started in this configuration: the assumptions
       are ignored and the answer is always l_undef.
    */
    lbool parallel::operator()(expr_ref_vector const& asms) {
        (void) asms;                       // intentionally unused in this configuration
        lbool const no_answer = l_undef;
        return no_answer;
    }

}
#else
#include <thread>
namespace smt {

    /**
       \brief Portfolio-style parallel check of \c asms.

       Steps:
       1. Run the main context sequentially with a small conflict budget; return
          immediately if that already answers the query or stops before the budget
          is used up.
       2. Create one copy of the context per worker, each with its own ast_manager
          and a distinct random seed.
       3. Run rounds. In every round each worker checks its assumptions (optionally
          extended by a locally computed cube). Between rounds the workers exchange
          base-level literals, and the per-thread conflict budget is doubled.
       4. When a worker finishes (or an exception is recorded), copy the model or
          unsat core of the finishing worker back into the main context.

       \return the result reported by the finishing worker; l_undef when there are no workers.
       \throws default_exception if a trace stream is active or a worker failed with a z3_exception
               (or an unknown exception, given how ex_kind is handled below);
               z3_error if a worker failed with a z3_error and no worker finished.
    */
    lbool parallel::operator()(expr_ref_vector const& asms) {

        lbool result = l_undef;
        // hardware_concurrency() may return 0 when the value is not computable;
        // in that case rely on the configured thread count alone.
        unsigned hw_threads = std::thread::hardware_concurrency();
        unsigned num_threads = ctx.get_fparams().m_threads;
        if (hw_threads > 0)
            num_threads = std::min(hw_threads, num_threads);
        flet<unsigned> _nt(ctx.m_fparams.m_threads, 1);
        unsigned thread_max_conflicts = ctx.get_fparams().m_threads_max_conflicts;
        unsigned max_conflicts = ctx.get_fparams().m_max_conflicts;

        // try first sequential with a low conflict budget to make super easy problems cheap
        unsigned max_c = std::min(thread_max_conflicts, 40u);
        flet<unsigned> _mc(ctx.get_fparams().m_max_conflicts, max_c);
        result = ctx.check(asms.size(), asms.data());
        if (result != l_undef || ctx.m_num_conflicts < max_c) {
            return result;
        }

        // With no workers the round loop below would never set `done` and spin forever.
        if (num_threads == 0)
            return result;

        enum par_exception_kind {
            DEFAULT_EX,
            ERROR_EX
        };

        vector<smt_params> smt_params;
        scoped_ptr_vector<ast_manager> pms;
        scoped_ptr_vector<context> pctxs;
        vector<expr_ref_vector> pasms;

        ast_manager& m = ctx.m;
        scoped_limits sl(m.limit());
        unsigned finished_id = UINT_MAX;   // index of the worker whose answer is returned
        std::string ex_msg;
        par_exception_kind ex_kind = DEFAULT_EX;
        unsigned error_code = 0;
        bool done = false;                 // set once a worker finished or recorded an exception
        unsigned num_rounds = 0;
        if (m.has_trace_stream())
            throw default_exception("trace streams have to be off in parallel mode");

        for (unsigned i = 0; i < num_threads; ++i) {
            smt_params.push_back(ctx.get_fparams());
        }

        // One manager + context per worker; assumptions are translated into each manager.
        for (unsigned i = 0; i < num_threads; ++i) {
            ast_manager* new_m = alloc(ast_manager, m, true);
            pms.push_back(new_m);
            pctxs.push_back(alloc(context, *new_m, smt_params[i], ctx.get_params()));
            context& new_ctx = *pctxs.back();
            context::copy(ctx, new_ctx, true);
            new_ctx.set_random_seed(i + ctx.get_fparams().m_random_seed);
            ast_translation tr(m, *new_m);
            pasms.push_back(tr(asms));
            sl.push_child(&(new_m->limit()));
        }

        // Previous single-literal strategy, kept for reference:
        // auto cube = [](context& ctx, expr_ref_vector& lasms, expr_ref& c) {
        //     lookahead lh(ctx);
        //     c = lh.choose();
        //     if (c) {
        //         if ((ctx.get_random_value() % 2) == 0)
        //             c = c.get_manager().mk_not(c);
        //         lasms.push_back(c);
        //     }
        // };

        // Extends `lasms` with up to `cube_size` literals chosen by lookahead score.
        // The caller identifies the cube as the literals appended past the original size.
        auto cube = [](context& pctx, expr_ref_vector& lasms) {
            lookahead lh(pctx);                                       // used for get_score()
            std::vector<std::pair<expr_ref, double>> candidates;     // (expression, score)
            unsigned const budget = 10;                               // max number of candidates collected

            for (bool_var v = 0; v < pctx.m_bool_var2expr.size(); ++v) {
                if (pctx.get_assignment(v) != l_undef) continue;      // already assigned
                expr* e = pctx.bool_var2expr(v);
                if (!e) continue;                                     // no expression attached to v

                // Tentatively assign v = true, propagate, and score the resulting state.
                literal lit(v, false);
                pctx.push_scope();
                pctx.assign(lit, b_justification::mk_axiom(), true);
                pctx.propagate();
                if (!pctx.inconsistent()) {                           // conflicting probes are not candidates
                    double score = lh.get_score();
                    candidates.emplace_back(expr_ref(e, pctx.get_manager()), score);
                }
                pctx.pop_scope(1);                                    // undo the probe

                if (candidates.size() >= budget) break;
            }

            // Highest score first.
            std::sort(candidates.begin(), candidates.end(),
                      [](const auto& a, const auto& b) { return a.second > b.second; });

            unsigned cube_size = 2; // compute_cube_size_from_feedback(); // NEED TO IMPLEMENT: Decide how many literals to include (adaptive)
            unsigned num_lits = std::min(cube_size, (unsigned)candidates.size());

            for (unsigned i = 0; i < num_lits; ++i) {
                expr_ref lit = candidates[i].first;
                // Flip the polarity when the context's random value is even.
                if ((pctx.get_random_value() % 2) == 0)
                    lit = pctx.get_manager().mk_not(lit);
                lasms.push_back(lit);
            }
        };

        obj_hashtable<expr> unit_set;       // units already recorded (main manager)
        expr_ref_vector unit_trail(ctx.m);  // recorded units in insertion order
        unsigned_vector unit_lim;           // per worker: number of assigned literals already processed
        for (unsigned i = 0; i < num_threads; ++i) unit_lim.push_back(0);

        // Exchanges literals between workers. Runs on the calling thread between rounds,
        // after all workers have been joined.
        std::function<void(void)> collect_units = [&,this]() {
            // Phase 1: gather newly assigned literals of every worker into unit_trail.
            for (unsigned i = 0; i < num_threads; ++i) {
                context& pctx = *pctxs[i];
                pctx.pop_to_base_lvl();
                ast_translation tr(pctx.m, ctx.m);
                unsigned sz = pctx.assigned_literals().size();
                for (unsigned j = unit_lim[i]; j < sz; ++j) {
                    literal lit = pctx.assigned_literals()[j];
                    expr_ref e(pctx.bool_var2expr(lit.var()), pctx.m);
                    if (lit.sign()) e = pctx.m.mk_not(e);
                    expr_ref ce(tr(e.get()), ctx.m);
                    if (!unit_set.contains(ce)) {
                        unit_set.insert(ce);
                        unit_trail.push_back(ce);
                    }
                }
            }

            // Phase 2: assert recorded units to every worker.
            // NOTE(review): unit_lim[i] counts the worker's assigned literals but is also used
            // as the start index into unit_trail here; confirm this is intended.
            unsigned sz = unit_trail.size();
            for (unsigned i = 0; i < num_threads; ++i) {
                context& pctx = *pctxs[i];
                ast_translation tr(ctx.m, pctx.m);
                for (unsigned j = unit_lim[i]; j < sz; ++j) {
                    expr_ref dst(pctx.m);
                    dst = tr(unit_trail.get(j));
                    pctx.assert_expr(dst);
                }
                unit_lim[i] = pctx.assigned_literals().size();
            }
            IF_VERBOSE(1, verbose_stream() << "(smt.thread :units " << sz << ")\n");
        };

        std::mutex mux;   // guards finished_id, result, done and the exception record

        // Body of worker i for one round.
        auto worker_thread = [&](int i) {
            try {
                context& pctx = *pctxs[i];
                ast_manager& pm = *pms[i];
                expr_ref_vector lasms(pasms[i]);
                // Literals at positions >= num_asms are cube literals local to this round.
                unsigned const num_asms = lasms.size();
                pctx.get_fparams().m_max_conflicts = std::min(thread_max_conflicts, max_conflicts);
                // A frequency of 0 is treated as "never cube" to avoid a modulo by zero.
                unsigned const cube_freq = pctx.get_fparams().m_threads_cube_frequency;
                if (num_rounds > 0 && cube_freq > 0 && (num_rounds % cube_freq) == 0)
                    cube(pctx, lasms);
                IF_VERBOSE(1, verbose_stream() << "(smt.thread " << i;
                           if (num_rounds > 0) verbose_stream() << " :round " << num_rounds;
                           if (lasms.size() > num_asms) verbose_stream() << " :cube";
                           for (unsigned j = num_asms; j < lasms.size(); ++j) verbose_stream() << " " << mk_bounded_pp(lasms.get(j), pm, 3);
                           verbose_stream() << ")\n";);
                lbool r = pctx.check(lasms.size(), lasms.data());

                if (r == l_undef && pctx.m_num_conflicts >= max_conflicts)
                    ; // no-op: overall budget exhausted, report l_undef below
                else if (r == l_undef && pctx.m_num_conflicts >= thread_max_conflicts)
                    return; // round budget exhausted, try again next round
                else if (r == l_false) {
                    // An unsat answer that depends on a cube literal only refutes the cube,
                    // not the original problem: learn the negated core and keep going.
                    bool uses_cube = false;
                    for (unsigned j = num_asms; j < lasms.size() && !uses_cube; ++j)
                        uses_cube = pctx.unsat_core().contains(lasms.get(j));
                    if (uses_cube) {
                        expr_ref lemma = mk_not(mk_and(pctx.unsat_core()));
                        IF_VERBOSE(1, verbose_stream() << "(smt.thread " << i << " :learn " << mk_bounded_pp(lemma, pm, 3) << ")");
                        pctx.assert_expr(lemma);
                        return;
                    }
                }

                bool first = false;
                {
                    std::lock_guard<std::mutex> lock(mux);
                    if (finished_id == UINT_MAX) {
                        finished_id = i;
                        first = true;
                        result = r;
                        done = true;
                    }
                    // A definite answer replaces an earlier l_undef from another worker.
                    if (!first && r != l_undef && result == l_undef) {
                        finished_id = i;
                        result = r;
                    }
                    else if (!first) return;
                }

                // Cancel all other workers.
                for (ast_manager* m : pms) {
                    if (m != &pm) m->limit().cancel();
                }
            }
            // The exception record is shared with the other workers, so it is only
            // touched while holding mux.
            catch (z3_error & err) {
                std::lock_guard<std::mutex> lock(mux);
                if (finished_id == UINT_MAX) {
                    error_code = err.error_code();
                    ex_kind = ERROR_EX;
                    done = true;
                }
            }
            catch (z3_exception & ex) {
                std::lock_guard<std::mutex> lock(mux);
                if (finished_id == UINT_MAX) {
                    ex_msg = ex.what();
                    ex_kind = DEFAULT_EX;
                    done = true;
                }
            }
            catch (...) {
                std::lock_guard<std::mutex> lock(mux);
                if (finished_id == UINT_MAX) {
                    ex_msg = "unknown exception";
                    ex_kind = ERROR_EX;
                    done = true;
                }
            }
        };

        // for debugging: num_threads = 1;

        while (true) {
            vector<std::thread> threads(num_threads);
            for (unsigned i = 0; i < num_threads; ++i) {
                // i is captured by value because it changes while the threads run.
                threads[i] = std::thread([&, i]() { worker_thread(i); });
            }
            for (auto & th : threads) {
                th.join();
            }
            if (done) break;

            // Nobody finished: share units, then double the per-round budget while
            // deducting the budget just spent from the overall one.
            collect_units();
            ++num_rounds;
            max_conflicts = (max_conflicts < thread_max_conflicts) ? 0 : (max_conflicts - thread_max_conflicts);
            thread_max_conflicts *= 2;
        }

        for (context* c : pctxs) {
            c->collect_statistics(ctx.m_aux_stats);
        }

        // done without a finishing worker means an exception was recorded: rethrow it here.
        if (finished_id == UINT_MAX) {
            switch (ex_kind) {
            case ERROR_EX: throw z3_error(error_code);
            default: throw default_exception(std::move(ex_msg));
            }
        }

        // Copy the answer of the finishing worker back into the main context.
        model_ref mdl;
        context& pctx = *pctxs[finished_id];
        ast_translation tr(*pms[finished_id], m);
        switch (result) {
        case l_true:
            pctx.get_model(mdl);
            if (mdl)
                ctx.set_model(mdl->translate(tr));
            break;
        case l_false:
            ctx.m_unsat_core.reset();
            for (expr* e : pctx.unsat_core())
                ctx.m_unsat_core.push_back(tr(e));
            break;
        default:
            break;
        }

        return result;
    }

}
#endif