3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-08-26 21:16:02 +00:00

add top-k fixed-sized min-heap priority queue for top scoring literals

This commit is contained in:
Ilana Shapiro 2025-07-27 18:12:07 -07:00
parent 65504953f7
commit 36fbee3a2d
5 changed files with 335 additions and 85 deletions

View file

@ -21,11 +21,12 @@ for rel_path in "${REL_TEST_FILES[@]}"; do
test_name="$rel_path" test_name="$rel_path"
echo "Running: $test_name" echo "Running: $test_name"
echo "===== $test_name =====" >> "$OUT_FILE" echo "===== $test_name =====" | tee -a "$OUT_FILE"
$Z3 "$full_path" $OPTIONS >> "$OUT_FILE" 2>&1 # Run Z3 and pipe output to both screen and file
$Z3 "$full_path" $OPTIONS 2>&1 | tee -a "$OUT_FILE"
echo "" >> "$OUT_FILE" echo "" | tee -a "$OUT_FILE"
done done
echo "Results written to $OUT_FILE" echo "Results written to $OUT_FILE"

191
src/smt/priority_queue.h Normal file
View file

@ -0,0 +1,191 @@
// SOURCE: https://github.com/Ten0/updatable_priority_queue/blob/master/updatable_priority_queue.h
#include <cstddef>
#include <limits>
#include <utility>
#include <vector>
namespace updatable_priority_queue {

    /**
     * Heap entry: a key together with its current priority.
     *
     * The relational operators are intentionally inverted with respect to the
     * priority values: `a > b` is true when `a.priority < b.priority`. The sift
     * routines of priority_queue move the "greater" node towards the root, so
     * this inversion is what keeps the entry with the smallest priority on top.
     */
    template <typename Key, typename Priority>
    struct priority_queue_node {
        Priority priority;
        Key key;

        priority_queue_node(const Key& key, const Priority& priority) : priority(priority), key(key) {}

        friend bool operator<(const priority_queue_node& pqn1, const priority_queue_node& pqn2) {
            return pqn1.priority > pqn2.priority;
        }
        friend bool operator>(const priority_queue_node& pqn1, const priority_queue_node& pqn2) {
            return pqn1.priority < pqn2.priority;
        }
    };

    /**
     * Updatable, optionally bounded priority queue indexed by key.
     *
     * Key has to be an unsigned integral value (convertible to size_t); it is
     * used directly as an index into an internal position table.
     *
     * This is a MIN-heap on the priority value: top() is the entry with the
     * smallest priority. When the queue is bounded by max_size, a push into a
     * full queue evicts that smallest entry if the new priority is strictly
     * larger, so the queue retains the max_size highest-priority keys seen.
     */
    template <typename Key, typename Priority>
    class priority_queue {
    protected:
        /** Position-table marker: key is not in the heap and was never remembered. */
        static constexpr std::size_t ABSENT = static_cast<std::size_t>(-1);
        /** Position-table marker: key is not in the heap but was popped/evicted and remembered. */
        static constexpr std::size_t REMEMBERED = static_cast<std::size_t>(-2);

        std::vector<std::size_t> id_to_heappos;                 // key -> index in `heap`, or a marker above
        std::vector<priority_queue_node<Key, Priority>> heap;   // binary heap, smallest priority at index 0
        std::size_t max_size = std::numeric_limits<std::size_t>::max(); // capacity bound (unbounded by default)

    public:
        /** Creates an empty queue holding at most max_size entries. */
        priority_queue(std::size_t max_size = std::numeric_limits<std::size_t>::max()): max_size(max_size) {}

        /** Returns a const reference to the internal heap storage (heap order, not sorted). */
        const std::vector<priority_queue_node<Key, Priority>>& get_heap() const {
            return heap;
        }

        bool empty() const { return heap.empty(); }
        std::size_t size() const { return heap.size(); }

        /** Entry with the smallest priority. Precondition: the queue is not empty. */
        const priority_queue_node<Key, Priority>& top() const { return heap.front(); }

        /**
         * Removes the top entry; does nothing on an empty queue.
         * If remember_key is true the key is marked as remembered, which makes
         * push(..., only_if_unknown = true) refuse to re-add it.
         */
        void pop(bool remember_key = false) {
            if (heap.empty()) return;
            take_top(remember_key);
        }

        /**
         * Removes and returns the top entry.
         * On an empty queue returns a node with key Key(-1) and a
         * value-initialized priority.
         */
        priority_queue_node<Key, Priority> pop_value(bool remember_key = true) {
            if (heap.empty()) return priority_queue_node<Key, Priority>(static_cast<Key>(-1), Priority());
            return take_top(remember_key);
        }

        /**
         * Sets the priority for the given key. If not present, it will be added,
         * otherwise it will be updated. only_if_higher is forwarded to update()
         * (as only_if_higher) or to push() (as only_if_unknown).
         * Returns true if the queue was changed.
         */
        bool set(const Key& key, const Priority& priority, bool only_if_higher = false) {
            if (key < id_to_heappos.size() && id_to_heappos[key] < REMEMBERED) // key is currently in the heap
                return update(key, priority, only_if_higher);
            return push(key, priority, only_if_higher);
        }

        /** Returns {true, priority} if key is currently in the heap, otherwise {false, Priority()}. */
        std::pair<bool, Priority> get_priority(const Key& key) const {
            if (key < id_to_heappos.size()) {
                const std::size_t pos = id_to_heappos[key];
                if (pos < REMEMBERED) {
                    return {true, heap[pos].priority};
                }
            }
            return {false, Priority()};
        }

        /**
         * Adds key with the given priority.
         * Returns false without changes when
         *  - the key is already in the heap,
         *  - the key is remembered and only_if_unknown is true,
         *  - the queue is full and priority is not strictly larger than the
         *    smallest stored priority (or the capacity is zero).
         * When the queue is full and the new priority is larger, the entry with
         * the smallest priority is evicted and marked as remembered.
         */
        bool push(const Key& key, const Priority& priority, bool only_if_unknown = false) {
            extend_ids(key);
            const std::size_t state = id_to_heappos[key];
            if (state < REMEMBERED) return false;                      // already inside
            if (only_if_unknown && state == REMEMBERED) return false;  // was popped/evicted before

            if (heap.size() < max_size) {
                // Room left: append and restore heap order.
                const std::size_t n = heap.size();
                id_to_heappos[key] = n;
                heap.emplace_back(key, priority);
                sift_up(n);
                return true;
            }

            // Full. An empty heap here means max_size == 0: nothing can be stored.
            if (heap.empty()) return false;
            // heap[0] holds the smallest priority; a candidate that does not beat it is dropped.
            if (priority <= heap.front().priority) return false;

            id_to_heappos[heap.front().key] = REMEMBERED;              // mark evicted key
            heap.front() = priority_queue_node<Key, Priority>(key, priority);
            id_to_heappos[key] = 0;
            sift_down(0);
            return true;
        }

        /**
         * Changes the priority of a key that is already in the heap.
         * With only_if_higher, the priority is only ever raised.
         * Returns true if the stored priority was changed.
         */
        bool update(const Key& key, const Priority& new_priority, bool only_if_higher = false) {
            if (key >= id_to_heappos.size()) return false;
            const std::size_t heappos = id_to_heappos[key];
            if (heappos >= REMEMBERED) return false;
            Priority& priority = heap[heappos].priority;
            if (new_priority > priority) {
                // Min-heap: a larger priority can only violate order w.r.t. the children.
                priority = new_priority;
                sift_down(heappos);
                return true;
            }
            if (!only_if_higher && new_priority < priority) {
                // Min-heap: a smaller priority can only violate order w.r.t. the parent.
                priority = new_priority;
                sift_up(heappos);
                return true;
            }
            return false;
        }

        /** Removes all entries and forgets all remembered keys. */
        void clear() {
            heap.clear();
            id_to_heappos.clear();
        }

    private:
        /** Removes heap[0], records its key state and restores heap order. Precondition: heap not empty. */
        priority_queue_node<Key, Priority> take_top(bool remember_key) {
            priority_queue_node<Key, Priority> removed = std::move(heap.front());
            id_to_heappos[removed.key] = remember_key ? REMEMBERED : ABSENT;
            if (heap.size() > 1) {
                heap.front() = std::move(heap.back());
                id_to_heappos[heap.front().key] = 0;
            }
            heap.pop_back();
            sift_down(0);
            return removed;
        }

        /** Grows the position table so that index k is valid; new slots are ABSENT. */
        void extend_ids(Key k) {
            const std::size_t new_size = k + 1;
            if (id_to_heappos.size() < new_size)
                id_to_heappos.resize(new_size, ABSENT);
        }

        /** Moves the node at heappos towards the leaves while a child has a smaller priority. */
        void sift_down(std::size_t heappos) {
            const std::size_t len = heap.size();
            std::size_t child = heappos * 2 + 1;
            if (len < 2 || child >= len) return;
            if (child + 1 < len && heap[child + 1] > heap[child]) ++child; // pick the child with the smaller priority
            if (!(heap[child] > heap[heappos])) return;                    // already in heap order
            priority_queue_node<Key, Priority> val = std::move(heap[heappos]);
            do {
                heap[heappos] = std::move(heap[child]);
                id_to_heappos[heap[heappos].key] = heappos;
                heappos = child;
                child = 2 * child + 1;
                if (child >= len) break;
                if (child + 1 < len && heap[child + 1] > heap[child]) ++child;
            } while (heap[child] > val);
            heap[heappos] = std::move(val);
            id_to_heappos[heap[heappos].key] = heappos;
        }

        /** Moves the node at heappos towards the root while its parent has a larger priority. */
        void sift_up(std::size_t heappos) {
            const std::size_t len = heap.size();
            if (len < 2 || heappos == 0) return;
            std::size_t parent = (heappos - 1) / 2;
            if (!(heap[heappos] > heap[parent])) return;
            priority_queue_node<Key, Priority> val = std::move(heap[heappos]);
            do {
                heap[heappos] = std::move(heap[parent]);
                id_to_heappos[heap[heappos].key] = heappos;
                heappos = parent;
                if (heappos == 0) break;
                parent = (parent - 1) / 2;
            } while (val > heap[parent]);
            heap[heappos] = std::move(val);
            id_to_heappos[heap[heappos].key] = heappos;
        }
    };
}

View file

@ -50,6 +50,7 @@ Revision History:
#include "model/model.h" #include "model/model.h"
#include "solver/progress_callback.h" #include "solver/progress_callback.h"
#include "solver/assertions/asserted_formulas.h" #include "solver/assertions/asserted_formulas.h"
#include "smt/priority_queue.h"
#include <tuple> #include <tuple>
// there is a significant space overhead with allocating 1000+ contexts in // there is a significant space overhead with allocating 1000+ contexts in
@ -189,7 +190,8 @@ namespace smt {
unsigned_vector m_lit_occs; //!< occurrence count of literals unsigned_vector m_lit_occs; //!< occurrence count of literals
svector<bool_var_data> m_bdata; //!< mapping bool_var -> data svector<bool_var_data> m_bdata; //!< mapping bool_var -> data
svector<double> m_activity; svector<double> m_activity;
svector<std::array<double, 2>> m_scores; updatable_priority_queue::priority_queue<bool_var, double> m_pq_scores;
svector<std::array<double, 2>> m_lit_scores;
clause_vector m_aux_clauses; clause_vector m_aux_clauses;
clause_vector m_lemmas; clause_vector m_lemmas;
vector<clause_vector> m_clauses_to_reinit; vector<clause_vector> m_clauses_to_reinit;
@ -932,10 +934,11 @@ namespace smt {
void dump_axiom(unsigned n, literal const* lits); void dump_axiom(unsigned n, literal const* lits);
void add_scores(unsigned n, literal const* lits); void add_scores(unsigned n, literal const* lits);
void reset_scores() { void reset_scores() {
for (auto& s : m_scores) s[0] = s[1] = 0.0; for (auto& s : m_lit_scores) s[0] = s[1] = 0.0;
m_pq_scores.clear(); // Clear the priority queue heap as well
} }
double get_score(literal l) const { double get_score(literal l) const {
return m_scores[l.var()][l.sign()]; return m_lit_scores[l.var()][l.sign()];
} }
public: public:

View file

@ -928,8 +928,8 @@ namespace smt {
set_bool_var(id, v); set_bool_var(id, v);
m_bdata.reserve(v+1); m_bdata.reserve(v+1);
m_activity.reserve(v+1); m_activity.reserve(v+1);
m_scores.reserve(v + 1); m_lit_scores.reserve(v + 1);
m_scores[v][0] = m_scores[v][1] = 0.0; m_lit_scores[v][0] = m_lit_scores[v][1] = 0.0;
m_bool_var2expr.reserve(v+1); m_bool_var2expr.reserve(v+1);
m_bool_var2expr[v] = n; m_bool_var2expr[v] = n;
literal l(v, false); literal l(v, false);
@ -1527,11 +1527,24 @@ namespace smt {
}} }}
} }
// void context::add_scores(unsigned n, literal const* lits) {
// for (unsigned i = 0; i < n; ++i) {
// auto lit = lits[i];
// unsigned v = lit.var();
// m_lit_scores[v][lit.sign()] += 1.0 / n;
// }
// }
void context::add_scores(unsigned n, literal const* lits) { void context::add_scores(unsigned n, literal const* lits) {
for (unsigned i = 0; i < n; ++i) { for (unsigned i = 0; i < n; ++i) {
auto lit = lits[i]; auto lit = lits[i];
unsigned v = lit.var(); unsigned v = lit.var(); // unique key per literal
m_scores[v][lit.sign()] += 1.0 / n;
auto curr_score = m_lit_scores[v][0] * m_lit_scores[v][1];
m_lit_scores[v][lit.sign()] += 1.0 / n;
auto new_score = m_lit_scores[v][0] * m_lit_scores[v][1];
m_pq_scores.set(v, new_score);
} }
} }

View file

@ -92,63 +92,84 @@ namespace smt {
sl.push_child(&(new_m->limit())); sl.push_child(&(new_m->limit()));
} }
// auto cube = [](context& ctx, expr_ref_vector& lasms, expr_ref& c) { auto cube = [](context& ctx, expr_ref_vector& lasms, expr_ref& c) {
// lookahead lh(ctx); lookahead lh(ctx);
// c = lh.choose(); c = lh.choose();
// if (c) { if (c) {
// if ((ctx.get_random_value() % 2) == 0)
// c = c.get_manager().mk_not(c);
// lasms.push_back(c);
// }
// };
auto cube = [&](context& ctx, expr_ref_vector& lasms, expr_ref& c) {
lookahead lh(ctx); // Create lookahead object to use get_score for evaluation
std::vector<std::pair<expr_ref, double>> candidates; // List of candidate literals and their scores
unsigned budget = 10; // Maximum number of variables to sample for building the cubes
// Loop through all Boolean variables in the context
for (bool_var v = 0; v < ctx.m_bool_var2expr.size(); ++v) {
if (ctx.get_assignment(v) != l_undef) continue; // Skip already assigned variables
expr* e = ctx.bool_var2expr(v); // Get expression associated with variable
if (!e) continue; // Skip if not a valid variable
literal lit(v, false); // Create literal for v = true
ctx.push_scope(); // Save solver state
ctx.assign(lit, b_justification::mk_axiom(), true); // Assign v = true with axiom justification
ctx.propagate(); // Propagate consequences of assignment
if (!ctx.inconsistent()) { // Only keep variable if assignment didnt lead to conflict
double score = lh.get_score(); // Evaluate current state using lookahead scoring
candidates.emplace_back(expr_ref(e, ctx.get_manager()), score); // Store (expr, score) pair
}
ctx.pop_scope(1); // Restore solver state
if (candidates.size() >= budget) break; // Stop early if sample budget is exhausted
}
// Sort candidates in descending order by score (higher score = better)
std::sort(candidates.begin(), candidates.end(),
[](auto& a, auto& b) { return a.second > b.second; });
unsigned cube_size = 2; // compute_cube_size_from_feedback(); // NEED TO IMPLEMENT: Decide how many literals to include (adaptive)
// Select top-scoring literals to form the cube
for (unsigned i = 0; i < std::min(cube_size, (unsigned)candidates.size()); ++i) {
expr_ref lit = candidates[i].first;
// Randomly flip polarity with 50% chance (introduces polarity diversity)
if ((ctx.get_random_value() % 2) == 0) if ((ctx.get_random_value() % 2) == 0)
lit = ctx.get_manager().mk_not(lit); c = c.get_manager().mk_not(c);
lasms.push_back(c);
lasms.push_back(lit); // Add literal as thread-local assumption
} }
}; };
auto cube_batch_pq = [&](context& ctx, expr_ref_vector& lasms, expr_ref& c) {
unsigned k = 4; // Number of top literals you want
ast_manager& m = ctx.get_manager();
// Get the entire fixed-size priority queue (it's not that big)
auto candidates = ctx.m_pq_scores.get_heap(); // returns vector<node<key, priority>>
// Sort descending by priority (higher priority first)
std::sort(candidates.begin(), candidates.end(),
[](const auto& a, const auto& b) { return a.priority > b.priority; });
expr_ref_vector conjuncts(m);
unsigned count = 0;
for (const auto& node : candidates) {
if (ctx.get_assignment(node.key) != l_undef) continue;
expr* e = ctx.bool_var2expr(node.key);
if (!e) continue;
expr_ref lit(e, m);
conjuncts.push_back(lit);
if (++count >= k) break;
}
c = mk_and(conjuncts);
lasms.push_back(c);
};
auto cube_batch = [&](context& ctx, expr_ref_vector& lasms, expr_ref& c) {
std::vector<std::pair<expr_ref, double>> candidates;
unsigned k = 4; // Get top-k scoring literals
ast_manager& m = ctx.get_manager();
// std::cout << ctx.m_bool_var2expr.size() << std::endl; // Prints the size of m_bool_var2expr
// Loop over first 100 Boolean vars
for (bool_var v = 0; v < 100; ++v) {
if (ctx.get_assignment(v) != l_undef) continue;
expr* e = ctx.bool_var2expr(v);
if (!e) continue;
literal lit(v, false);
double score = ctx.get_score(lit);
if (score == 0.0) continue;
candidates.emplace_back(expr_ref(e, m), score);
}
// Sort all candidate literals descending by score
std::sort(candidates.begin(), candidates.end(),
[](auto& a, auto& b) { return a.second > b.second; });
// Clear c and build it as conjunction of top-k
expr_ref_vector conjuncts(m);
for (unsigned i = 0; i < std::min(k, (unsigned)candidates.size()); ++i) {
expr_ref lit = candidates[i].first;
conjuncts.push_back(lit);
}
// Build conjunction and store in c
c = mk_and(conjuncts);
// Add the single cube formula to lasms (not each literal separately)
lasms.push_back(c);
};
obj_hashtable<expr> unit_set; obj_hashtable<expr> unit_set;
expr_ref_vector unit_trail(ctx.m); expr_ref_vector unit_trail(ctx.m);
@ -189,33 +210,47 @@ namespace smt {
std::mutex mux; std::mutex mux;
// Lambda defining the work each SMT thread performs
auto worker_thread = [&](int i) { auto worker_thread = [&](int i) {
try { try {
// Get thread-specific context and AST manager
context& pctx = *pctxs[i]; context& pctx = *pctxs[i];
ast_manager& pm = *pms[i]; ast_manager& pm = *pms[i];
// Initialize local assumptions and cube
expr_ref_vector lasms(pasms[i]); expr_ref_vector lasms(pasms[i]);
expr_ref c(pm); expr_ref c(pm);
// Set the max conflict limit for this thread
pctx.get_fparams().m_max_conflicts = std::min(thread_max_conflicts, max_conflicts); pctx.get_fparams().m_max_conflicts = std::min(thread_max_conflicts, max_conflicts);
// Periodically generate cubes based on frequency
if (num_rounds > 0 && (num_rounds % pctx.get_fparams().m_threads_cube_frequency) == 0) if (num_rounds > 0 && (num_rounds % pctx.get_fparams().m_threads_cube_frequency) == 0)
cube(pctx, lasms, c); cube_batch(pctx, lasms, c);
// Optional verbose logging
IF_VERBOSE(1, verbose_stream() << "(smt.thread " << i; IF_VERBOSE(1, verbose_stream() << "(smt.thread " << i;
if (num_rounds > 0) verbose_stream() << " :round " << num_rounds; if (num_rounds > 0) verbose_stream() << " :round " << num_rounds;
if (c) verbose_stream() << " :cube " << mk_bounded_pp(c, pm, 3); if (c) verbose_stream() << " :cube " << mk_bounded_pp(c, pm, 3);
verbose_stream() << ")\n";); verbose_stream() << ")\n";);
// Check satisfiability of assumptions
lbool r = pctx.check(lasms.size(), lasms.data()); lbool r = pctx.check(lasms.size(), lasms.data());
if (r == l_undef && pctx.m_num_conflicts >= max_conflicts) // Handle results based on outcome and conflict count
; // no-op if (r == l_undef && pctx.m_num_conflicts >= max_conflicts)
else if (r == l_undef && pctx.m_num_conflicts >= thread_max_conflicts) ; // no-op, allow loop to continue
return; else if (r == l_undef && pctx.m_num_conflicts >= thread_max_conflicts)
return; // quit thread early
// If cube was unsat and it's in the core, learn from it
else if (r == l_false && pctx.unsat_core().contains(c)) { else if (r == l_false && pctx.unsat_core().contains(c)) {
IF_VERBOSE(1, verbose_stream() << "(smt.thread " << i << " :learn " << mk_bounded_pp(c, pm, 3) << ")"); IF_VERBOSE(1, verbose_stream() << "(smt.thread " << i << " :learn " << mk_bounded_pp(c, pm, 3) << ")");
pctx.assert_expr(mk_not(mk_and(pctx.unsat_core()))); pctx.assert_expr(mk_not(mk_and(pctx.unsat_core())));
return; return;
} }
// Begin thread-safe update of shared result state
bool first = false; bool first = false;
{ {
std::lock_guard<std::mutex> lock(mux); std::lock_guard<std::mutex> lock(mux);
@ -229,29 +264,27 @@ namespace smt {
finished_id = i; finished_id = i;
result = r; result = r;
} }
else if (!first) return; else if (!first) return; // nothing new to contribute
} }
// Cancel limits on other threads now that a result is known
for (ast_manager* m : pms) { for (ast_manager* m : pms) {
if (m != &pm) m->limit().cancel(); if (m != &pm) m->limit().cancel();
} }
} } catch (z3_error & err) {
catch (z3_error & err) {
if (finished_id == UINT_MAX) { if (finished_id == UINT_MAX) {
error_code = err.error_code(); error_code = err.error_code();
ex_kind = ERROR_EX; ex_kind = ERROR_EX;
done = true; done = true;
} }
} } catch (z3_exception & ex) {
catch (z3_exception & ex) {
if (finished_id == UINT_MAX) { if (finished_id == UINT_MAX) {
ex_msg = ex.what(); ex_msg = ex.what();
ex_kind = DEFAULT_EX; ex_kind = DEFAULT_EX;
done = true; done = true;
} }
} } catch (...) {
catch (...) {
if (finished_id == UINT_MAX) { if (finished_id == UINT_MAX) {
ex_msg = "unknown exception"; ex_msg = "unknown exception";
ex_kind = ERROR_EX; ex_kind = ERROR_EX;
@ -260,36 +293,45 @@ namespace smt {
} }
}; };
// for debugging: num_threads = 1; // Thread scheduling loop
while (true) { while (true) {
vector<std::thread> threads(num_threads); vector<std::thread> threads(num_threads);
// Launch threads
for (unsigned i = 0; i < num_threads; ++i) { for (unsigned i = 0; i < num_threads; ++i) {
// [&, i] is the lambda's capture clause: capture all variables by reference (&) except i, which is captured by value. // [&, i] is the lambda's capture clause: capture all variables by reference (&) except i, which is captured by value.
threads[i] = std::thread([&, i]() { worker_thread(i); }); threads[i] = std::thread([&, i]() { worker_thread(i); });
} }
// Wait for all threads to finish
for (auto & th : threads) { for (auto & th : threads) {
th.join(); th.join();
} }
// Stop if one finished with a result
if (done) break; if (done) break;
// Otherwise update shared state and retry
collect_units(); collect_units();
++num_rounds; ++num_rounds;
max_conflicts = (max_conflicts < thread_max_conflicts) ? 0 : (max_conflicts - thread_max_conflicts); max_conflicts = (max_conflicts < thread_max_conflicts) ? 0 : (max_conflicts - thread_max_conflicts);
thread_max_conflicts *= 2; thread_max_conflicts *= 2;
} }
// Gather statistics from all solver contexts
for (context* c : pctxs) { for (context* c : pctxs) {
c->collect_statistics(ctx.m_aux_stats); c->collect_statistics(ctx.m_aux_stats);
} }
// If no thread finished successfully, throw recorded error
if (finished_id == UINT_MAX) { if (finished_id == UINT_MAX) {
switch (ex_kind) { switch (ex_kind) {
case ERROR_EX: throw z3_error(error_code); case ERROR_EX: throw z3_error(error_code);
default: throw default_exception(std::move(ex_msg)); default: throw default_exception(std::move(ex_msg));
} }
} }
// Handle result: translate model/unsat core back to main context
model_ref mdl; model_ref mdl;
context& pctx = *pctxs[finished_id]; context& pctx = *pctxs[finished_id];
ast_translation tr(*pms[finished_id], m); ast_translation tr(*pms[finished_id], m);
@ -306,7 +348,7 @@ namespace smt {
break; break;
default: default:
break; break;
} }
return result; return result;
} }