mirror of
https://github.com/Z3Prover/z3
synced 2025-08-26 13:06:05 +00:00
fixed-size min-heap for tracking top-k literals (#7752)
* very basic setup * ensure solve_eqs is fully disabled when smt.solve_eqs=false, #7743 Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com> * respect smt configuration parameter in elim_unconstrained simplifier Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com> * indentation * add bash files for test runs * add option to selectively disable variable solving for only ground expressions Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com> * remove verbose output Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com> * fix #7745 axioms for len(substr(...)) escaped due to nested rewriting * ensure atomic constraints are processed by arithmetic solver * #7739 optimization add simplification rule for at(x, offset) = "" Introducing j just postpones some rewrites that prevent useful simplifications. Z3 already uses common sub-expressions. The example highlights some opportunities for simplification, noteworthy at(..) = "". The example is solved in both versions after adding this simplification. * fix unsound len(substr) axiom Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com> * FreshConst is_sort (#7748) * #7750 add pre-processing simplification * Add parameter validation for selected API functions * updates to ac-plugin fix incrementality bugs by allowing destructive updates during saturation at the cost of redoing saturation after a pop. * enable passive, add check for bloom up-to-date * add top-k fixed-sized min-heap priority queue for top scoring literals --------- Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com> Co-authored-by: Nikolaj Bjorner <nbjorner@microsoft.com> Co-authored-by: humnrdble <83878671+humnrdble@users.noreply.github.com>
This commit is contained in:
parent
a9b4e35938
commit
435ea6ea99
24 changed files with 762 additions and 303 deletions
|
@ -92,66 +92,86 @@ namespace smt {
|
|||
sl.push_child(&(new_m->limit()));
|
||||
}
|
||||
|
||||
// Access socres as follows:
|
||||
// ctx.m_scores[lit.sign()][lit.var()]
|
||||
|
||||
// auto cube = [](context& ctx, expr_ref_vector& lasms, expr_ref& c) {
|
||||
// lookahead lh(ctx);
|
||||
// c = lh.choose();
|
||||
// if (c) {
|
||||
// if ((ctx.get_random_value() % 2) == 0)
|
||||
// c = c.get_manager().mk_not(c);
|
||||
// lasms.push_back(c);
|
||||
// }
|
||||
// };
|
||||
|
||||
auto cube = [&](context& ctx, expr_ref_vector& lasms, expr_ref& c) {
|
||||
lookahead lh(ctx); // Create lookahead object to use get_score for evaluation
|
||||
|
||||
std::vector<std::pair<expr_ref, double>> candidates; // List of candidate literals and their scores
|
||||
unsigned budget = 10; // Maximum number of variables to sample for building the cubes
|
||||
|
||||
// Loop through all Boolean variables in the context
|
||||
for (bool_var v = 0; v < ctx.m_bool_var2expr.size(); ++v) {
|
||||
if (ctx.get_assignment(v) != l_undef) continue; // Skip already assigned variables
|
||||
|
||||
expr* e = ctx.bool_var2expr(v); // Get expression associated with variable
|
||||
if (!e) continue; // Skip if not a valid variable
|
||||
|
||||
literal lit(v, false); // Create literal for v = true
|
||||
|
||||
ctx.push_scope(); // Save solver state
|
||||
ctx.assign(lit, b_justification::mk_axiom(), true); // Assign v = true with axiom justification
|
||||
ctx.propagate(); // Propagate consequences of assignment
|
||||
|
||||
if (!ctx.inconsistent()) { // Only keep variable if assignment didn’t lead to conflict
|
||||
double score = lh.get_score(); // Evaluate current state using lookahead scoring
|
||||
candidates.emplace_back(expr_ref(e, ctx.get_manager()), score); // Store (expr, score) pair
|
||||
}
|
||||
|
||||
ctx.pop_scope(1); // Restore solver state
|
||||
|
||||
if (candidates.size() >= budget) break; // Stop early if sample budget is exhausted
|
||||
}
|
||||
|
||||
// Sort candidates in descending order by score (higher score = better)
|
||||
std::sort(candidates.begin(), candidates.end(),
|
||||
[](auto& a, auto& b) { return a.second > b.second; });
|
||||
|
||||
unsigned cube_size = 2; // compute_cube_size_from_feedback(); // NEED TO IMPLEMENT: Decide how many literals to include (adaptive)
|
||||
|
||||
// Select top-scoring literals to form the cube
|
||||
for (unsigned i = 0; i < std::min(cube_size, (unsigned)candidates.size()); ++i) {
|
||||
expr_ref lit = candidates[i].first;
|
||||
|
||||
// Randomly flip polarity with 50% chance (introduces polarity diversity)
|
||||
auto cube = [](context& ctx, expr_ref_vector& lasms, expr_ref& c) {
|
||||
lookahead lh(ctx);
|
||||
c = lh.choose();
|
||||
if (c) {
|
||||
if ((ctx.get_random_value() % 2) == 0)
|
||||
lit = ctx.get_manager().mk_not(lit);
|
||||
|
||||
lasms.push_back(lit); // Add literal as thread-local assumption
|
||||
c = c.get_manager().mk_not(c);
|
||||
lasms.push_back(c);
|
||||
}
|
||||
};
|
||||
|
||||
auto cube_batch_pq = [&](context& ctx, expr_ref_vector& lasms, expr_ref& c) {
|
||||
unsigned k = 4; // Number of top literals you want
|
||||
ast_manager& m = ctx.get_manager();
|
||||
|
||||
// Get the entire fixed-size priority queue (it's not that big)
|
||||
auto candidates = ctx.m_pq_scores.get_heap(); // returns vector<node<key, priority>>
|
||||
|
||||
// Sort descending by priority (higher priority first)
|
||||
std::sort(candidates.begin(), candidates.end(),
|
||||
[](const auto& a, const auto& b) { return a.priority > b.priority; });
|
||||
|
||||
expr_ref_vector conjuncts(m);
|
||||
unsigned count = 0;
|
||||
|
||||
for (const auto& node : candidates) {
|
||||
if (ctx.get_assignment(node.key) != l_undef) continue;
|
||||
|
||||
expr* e = ctx.bool_var2expr(node.key);
|
||||
if (!e) continue;
|
||||
|
||||
|
||||
expr_ref lit(e, m);
|
||||
conjuncts.push_back(lit);
|
||||
|
||||
if (++count >= k) break;
|
||||
}
|
||||
|
||||
c = mk_and(conjuncts);
|
||||
lasms.push_back(c);
|
||||
};
|
||||
|
||||
auto cube_batch = [&](context& ctx, expr_ref_vector& lasms, expr_ref& c) {
|
||||
std::vector<std::pair<expr_ref, double>> candidates;
|
||||
unsigned k = 4; // Get top-k scoring literals
|
||||
ast_manager& m = ctx.get_manager();
|
||||
|
||||
// std::cout << ctx.m_bool_var2expr.size() << std::endl; // Prints the size of m_bool_var2expr
|
||||
// Loop over first 100 Boolean vars
|
||||
for (bool_var v = 0; v < 100; ++v) {
|
||||
if (ctx.get_assignment(v) != l_undef) continue;
|
||||
|
||||
expr* e = ctx.bool_var2expr(v);
|
||||
if (!e) continue;
|
||||
|
||||
literal lit(v, false);
|
||||
double score = ctx.get_score(lit);
|
||||
if (score == 0.0) continue;
|
||||
|
||||
candidates.emplace_back(expr_ref(e, m), score);
|
||||
}
|
||||
|
||||
// Sort all candidate literals descending by score
|
||||
std::sort(candidates.begin(), candidates.end(),
|
||||
[](auto& a, auto& b) { return a.second > b.second; });
|
||||
|
||||
// Clear c and build it as conjunction of top-k
|
||||
expr_ref_vector conjuncts(m);
|
||||
|
||||
for (unsigned i = 0; i < std::min(k, (unsigned)candidates.size()); ++i) {
|
||||
expr_ref lit = candidates[i].first;
|
||||
conjuncts.push_back(lit);
|
||||
}
|
||||
|
||||
// Build conjunction and store in c
|
||||
c = mk_and(conjuncts);
|
||||
|
||||
// Add the single cube formula to lasms (not each literal separately)
|
||||
lasms.push_back(c);
|
||||
};
|
||||
|
||||
obj_hashtable<expr> unit_set;
|
||||
expr_ref_vector unit_trail(ctx.m);
|
||||
|
@ -192,33 +212,47 @@ namespace smt {
|
|||
|
||||
std::mutex mux;
|
||||
|
||||
// Lambda defining the work each SMT thread performs
|
||||
auto worker_thread = [&](int i) {
|
||||
try {
|
||||
// Get thread-specific context and AST manager
|
||||
context& pctx = *pctxs[i];
|
||||
ast_manager& pm = *pms[i];
|
||||
|
||||
// Initialize local assumptions and cube
|
||||
expr_ref_vector lasms(pasms[i]);
|
||||
expr_ref c(pm);
|
||||
|
||||
// Set the max conflict limit for this thread
|
||||
pctx.get_fparams().m_max_conflicts = std::min(thread_max_conflicts, max_conflicts);
|
||||
|
||||
// Periodically generate cubes based on frequency
|
||||
if (num_rounds > 0 && (num_rounds % pctx.get_fparams().m_threads_cube_frequency) == 0)
|
||||
cube(pctx, lasms, c);
|
||||
cube_batch(pctx, lasms, c);
|
||||
|
||||
// Optional verbose logging
|
||||
IF_VERBOSE(1, verbose_stream() << "(smt.thread " << i;
|
||||
if (num_rounds > 0) verbose_stream() << " :round " << num_rounds;
|
||||
if (c) verbose_stream() << " :cube " << mk_bounded_pp(c, pm, 3);
|
||||
verbose_stream() << ")\n";);
|
||||
if (num_rounds > 0) verbose_stream() << " :round " << num_rounds;
|
||||
if (c) verbose_stream() << " :cube " << mk_bounded_pp(c, pm, 3);
|
||||
verbose_stream() << ")\n";);
|
||||
|
||||
// Check satisfiability of assumptions
|
||||
lbool r = pctx.check(lasms.size(), lasms.data());
|
||||
|
||||
if (r == l_undef && pctx.m_num_conflicts >= max_conflicts)
|
||||
; // no-op
|
||||
else if (r == l_undef && pctx.m_num_conflicts >= thread_max_conflicts)
|
||||
return;
|
||||
|
||||
// Handle results based on outcome and conflict count
|
||||
if (r == l_undef && pctx.m_num_conflicts >= max_conflicts)
|
||||
; // no-op, allow loop to continue
|
||||
else if (r == l_undef && pctx.m_num_conflicts >= thread_max_conflicts)
|
||||
return; // quit thread early
|
||||
|
||||
// If cube was unsat and it's in the core, learn from it
|
||||
else if (r == l_false && pctx.unsat_core().contains(c)) {
|
||||
IF_VERBOSE(1, verbose_stream() << "(smt.thread " << i << " :learn " << mk_bounded_pp(c, pm, 3) << ")");
|
||||
pctx.assert_expr(mk_not(mk_and(pctx.unsat_core())));
|
||||
return;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Begin thread-safe update of shared result state
|
||||
bool first = false;
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(mux);
|
||||
|
@ -232,29 +266,27 @@ namespace smt {
|
|||
finished_id = i;
|
||||
result = r;
|
||||
}
|
||||
else if (!first) return;
|
||||
else if (!first) return; // nothing new to contribute
|
||||
}
|
||||
|
||||
// Cancel limits on other threads now that a result is known
|
||||
for (ast_manager* m : pms) {
|
||||
if (m != &pm) m->limit().cancel();
|
||||
}
|
||||
|
||||
}
|
||||
catch (z3_error & err) {
|
||||
} catch (z3_error & err) {
|
||||
if (finished_id == UINT_MAX) {
|
||||
error_code = err.error_code();
|
||||
ex_kind = ERROR_EX;
|
||||
done = true;
|
||||
}
|
||||
}
|
||||
catch (z3_exception & ex) {
|
||||
} catch (z3_exception & ex) {
|
||||
if (finished_id == UINT_MAX) {
|
||||
ex_msg = ex.what();
|
||||
ex_kind = DEFAULT_EX;
|
||||
done = true;
|
||||
}
|
||||
}
|
||||
catch (...) {
|
||||
} catch (...) {
|
||||
if (finished_id == UINT_MAX) {
|
||||
ex_msg = "unknown exception";
|
||||
ex_kind = ERROR_EX;
|
||||
|
@ -263,36 +295,45 @@ namespace smt {
|
|||
}
|
||||
};
|
||||
|
||||
// for debugging: num_threads = 1;
|
||||
|
||||
// Thread scheduling loop
|
||||
while (true) {
|
||||
vector<std::thread> threads(num_threads);
|
||||
|
||||
// Launch threads
|
||||
for (unsigned i = 0; i < num_threads; ++i) {
|
||||
// [&, i] is the lambda's capture clause: capture all variables by reference (&) except i, which is captured by value.
|
||||
threads[i] = std::thread([&, i]() { worker_thread(i); });
|
||||
}
|
||||
|
||||
// Wait for all threads to finish
|
||||
for (auto & th : threads) {
|
||||
th.join();
|
||||
}
|
||||
|
||||
// Stop if one finished with a result
|
||||
if (done) break;
|
||||
|
||||
// Otherwise update shared state and retry
|
||||
collect_units();
|
||||
++num_rounds;
|
||||
max_conflicts = (max_conflicts < thread_max_conflicts) ? 0 : (max_conflicts - thread_max_conflicts);
|
||||
thread_max_conflicts *= 2;
|
||||
}
|
||||
|
||||
// Gather statistics from all solver contexts
|
||||
for (context* c : pctxs) {
|
||||
c->collect_statistics(ctx.m_aux_stats);
|
||||
}
|
||||
|
||||
// If no thread finished successfully, throw recorded error
|
||||
if (finished_id == UINT_MAX) {
|
||||
switch (ex_kind) {
|
||||
case ERROR_EX: throw z3_error(error_code);
|
||||
default: throw default_exception(std::move(ex_msg));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle result: translate model/unsat core back to main context
|
||||
model_ref mdl;
|
||||
context& pctx = *pctxs[finished_id];
|
||||
ast_translation tr(*pms[finished_id], m);
|
||||
|
@ -309,7 +350,7 @@ namespace smt {
|
|||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue