3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2026-03-06 21:34:53 +00:00

second round of copilot prompting

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
This commit is contained in:
Nikolaj Bjorner 2026-02-23 13:34:06 -08:00
parent c7bf96325c
commit c4c4d18da3
4 changed files with 429 additions and 50 deletions

View file

@ -461,7 +461,13 @@ public:
\brief Return true if x is a definition.
*/
bool is_definition(var x) const { return m_defs[x] != 0; }
/**
\brief Return the number of clauses/definitions that watch variable x.
Used as an occurrence-count approximation by AriParti-style variable selectors.
*/
unsigned num_watches(var x) const { return m_wlist[x].size(); }
typedef svector<watched> watch_list;
typedef _scoped_numeral_vector<numeral_manager> scoped_numeral_vector;

View file

@ -184,6 +184,216 @@ public:
}
};
/**
\brief Variable selector implementing AriParti's multi-key heuristic (Section 4.2).
Uses four ranking keys applied in a dynamically-rotated order:
Key 0 (split_cnt): fewer previous splits preferred (avoids re-splitting the same variable)
Key 1 (cz): interval contains zero preferred (zero-split heuristic)
Key 2 (occ): more watch-list entries preferred (more constrained variable)
Key 3 (width): wider interval preferred (more room to bisect)
Width is encoded with AriParti-style penalties so that bounded, half-bounded, and
fully-unbounded variables can all be compared on a single scale:
- Fully unbounded: width = PENALTY^2 (= 1 048 576)
- Half-bounded [lo, +), lo>0: width = PENALTY / max(1,lo)
- Half-bounded [lo, +), lo0: width = PENALTY + (-lo)
- Half-bounded (-, up], up0: width = PENALTY + up
- Half-bounded (-, up], up<0: width = PENALTY / max(1,-up)
- Fully bounded [lo, up]: width = up - lo
Key ordering for a child node is derived from its parent's ordering by the same
rotation rule as AriParti: find the first consecutive pair (i-1,i) where
key_rank[i-1] < key_rank[i] and swap them; if no such pair exists reset to [0,1,2,3].
*/
template<typename C>
class ariparti_var_selector : public context_t<C>::var_selector {
typedef typename context_t<C>::numeral_manager numeral_manager;
typedef typename numeral_manager::numeral numeral;
typedef typename context_t<C>::bound bound;
typedef typename context_t<C>::node node;
static const unsigned NUM_KEYS = 4;
static const unsigned PENALTY = 1024;
// Per-variable split counts (incremented each time operator() selects a variable)
unsigned_vector m_split_cnt;
// Per-node key ordering (indexed by node id)
vector<svector<unsigned>> m_key_rank;
svector<unsigned> default_key_rank() const {
svector<unsigned> r;
for (unsigned i = 0; i < NUM_KEYS; ++i) r.push_back(i);
return r;
}
svector<unsigned> child_key_rank(svector<unsigned> const & parent) const {
svector<unsigned> r = parent;
unsigned pos = 0;
for (unsigned i = 1; i < NUM_KEYS; ++i) {
if (r[i - 1] < r[i]) { pos = i; break; }
}
if (pos == 0) {
// no out-of-order consecutive pair: reset to identity
for (unsigned i = 0; i < NUM_KEYS; ++i) r[i] = i;
} else {
std::swap(r[pos - 1], r[pos]);
}
return r;
}
svector<unsigned> const & key_rank_for(node * n) {
unsigned id = n->id();
if (id >= m_key_rank.size()) {
m_key_rank.resize(id + 1);
m_key_rank[id] = default_key_rank();
}
if (m_key_rank[id].empty())
m_key_rank[id] = default_key_rank();
return m_key_rank[id];
}
public:
ariparti_var_selector(context_t<C> * ctx) : context_t<C>::var_selector(ctx) {}
void new_var_eh(var x) override {
if (x >= m_split_cnt.size())
m_split_cnt.resize(x + 1, 0);
}
void new_node_eh(node * n) override {
unsigned id = n->id();
if (id >= m_key_rank.size())
m_key_rank.resize(id + 1);
node * parent = n->parent();
if (parent == nullptr) {
m_key_rank[id] = default_key_rank();
} else {
m_key_rank[id] = child_key_rank(key_rank_for(parent));
}
}
void del_node_eh(node * n) override {
unsigned id = n->id();
if (id < m_key_rank.size())
m_key_rank[id].reset();
}
var operator()(node * n) override {
numeral_manager & nm = this->ctx()->nm();
unsigned num = this->ctx()->num_vars();
if (num == 0) return null_var;
svector<unsigned> const & rank = key_rank_for(n);
// Ensure split_cnt vector is large enough
if (m_split_cnt.size() < num) m_split_cnt.resize(num, 0);
var best = null_var;
unsigned best_split = 0, best_occ = 0;
bool best_cz = false;
_scoped_numeral<numeral_manager> best_width(nm), curr_width(nm);
auto key_lt = [&](var a, var b) {
// Returns true if a is strictly better than b for key k.
// For use inside the selection loop where we compare 'a' against 'best'.
(void)a; (void)b; return false; // placeholder inlined below
};
(void)key_lt;
for (var x = 0; x < num; ++x) {
if (this->ctx()->is_definition(x)) continue;
bound * lo = n->lower(x);
bound * up = n->upper(x);
if (lo != nullptr && up != nullptr && nm.eq(lo->value(), up->value()))
continue; // already fixed
unsigned split = (x < m_split_cnt.size()) ? m_split_cnt[x] : 0;
bool cz = ((lo == nullptr || nm.is_neg(lo->value())) &&
(up == nullptr || nm.is_pos(up->value())));
unsigned occ = this->ctx()->num_watches(x);
if (occ == 0) continue; // variable not in any constraint
// Compute width with penalty encoding
if (lo == nullptr && up == nullptr) {
nm.set(curr_width, (int)(PENALTY * PENALTY));
} else if (lo == nullptr) {
// (-∞, up]
if (nm.is_neg(up->value())) {
// up < 0: width = PENALTY / max(1, -up)
_scoped_numeral<numeral_manager> neg_up(nm);
nm.set(neg_up, up->value());
nm.neg(neg_up);
if (nm.is_zero(neg_up) || nm.lt(neg_up, 1)) nm.set(neg_up, 1);
nm.set(curr_width, (int)PENALTY);
nm.div(curr_width, neg_up, curr_width);
} else {
// up >= 0: width = PENALTY + up
nm.set(curr_width, (int)PENALTY);
nm.add(curr_width, up->value(), curr_width);
}
} else if (up == nullptr) {
// [lo, +∞)
if (nm.is_pos(lo->value())) {
// lo > 0: width = PENALTY / max(1, lo)
_scoped_numeral<numeral_manager> pos_lo(nm);
nm.set(pos_lo, lo->value());
if (nm.lt(pos_lo, 1)) nm.set(pos_lo, 1);
nm.set(curr_width, (int)PENALTY);
nm.div(curr_width, pos_lo, curr_width);
} else {
// lo <= 0: width = PENALTY + (-lo)
_scoped_numeral<numeral_manager> neg_lo(nm);
nm.set(neg_lo, lo->value());
nm.neg(neg_lo);
nm.set(curr_width, (int)PENALTY);
nm.add(curr_width, neg_lo, curr_width);
}
} else {
// Fully bounded: width = up - lo
C::round_to_plus_inf(nm);
nm.sub(up->value(), lo->value(), curr_width);
}
if (best == null_var) {
best = x;
best_split = split; best_cz = cz;
best_occ = occ; nm.set(best_width, curr_width);
continue;
}
// Multi-key comparison using the node's key ordering
bool prefer = false;
for (unsigned ki = 0; ki < NUM_KEYS; ++ki) {
unsigned key = rank[ki];
bool lt = false, eq = false;
switch (key) {
case 0: lt = split < best_split; eq = split == best_split; break;
case 1: lt = cz && !best_cz; eq = cz == best_cz; break;
case 2: lt = occ > best_occ; eq = occ == best_occ; break;
case 3: lt = nm.gt(curr_width, best_width);
eq = nm.eq(curr_width, best_width); break;
default: break;
}
if (lt) { prefer = true; break; }
if (!eq) break; // x is strictly worse on this key
}
if (prefer) {
best = x;
best_split = split; best_cz = cz;
best_occ = occ; nm.set(best_width, curr_width);
}
}
if (best != null_var) {
if (best >= m_split_cnt.size()) m_split_cnt.resize(best + 1, 0);
++m_split_cnt[best];
}
return best;
}
};
template<typename C>
class midpoint_node_splitter : public context_t<C>::node_splitter {
typedef typename context_t<C>::numeral_manager numeral_manager;

View file

@ -138,7 +138,7 @@ namespace smt {
if (m_config.m_max_cube_depth <= cube.size())
goto check_cube_start;
auto atom = get_split_atom();
auto atom = get_split_atom(node);
if (!atom)
goto check_cube_start;
b.split(m_l2g, id, node, atom);
@ -339,6 +339,51 @@ namespace smt {
}
}
svector<unsigned> parallel::batch_manager::compute_child_key_rank(svector<unsigned> const & parent) {
static const unsigned NUM_KEYS = 5;
svector<unsigned> r = parent;
// Find the first consecutive pair (i-1, i) where r[i-1] < r[i] and swap them.
// If no such pair exists (the rank is non-increasing), reset to identity [0..4].
unsigned pos = 0;
for (unsigned i = 1; i < NUM_KEYS; ++i) {
if (r[i - 1] < r[i]) { pos = i; break; }
}
if (pos == 0) {
for (unsigned i = 0; i < NUM_KEYS; ++i) r[i] = i;
} else {
std::swap(r[pos - 1], r[pos]);
}
return r;
}
svector<unsigned> parallel::batch_manager::get_node_key_rank(node * n) {
// Called from workers (lock must NOT be held by caller).
// Returns by value to avoid holding a dangling reference after releasing mux.
std::scoped_lock lock(mux);
if (n == nullptr) {
svector<unsigned> r;
for (unsigned i = 0; i < 5; ++i) r.push_back(i);
return r;
}
auto it = m_node_key_rank.find(n);
if (it == m_node_key_rank.end()) {
svector<unsigned> r;
for (unsigned i = 0; i < 5; ++i) r.push_back(i);
m_node_key_rank[n] = r;
return r;
}
return it->second;
}
unsigned parallel::batch_manager::get_var_split_cnt(expr * var) {
// Called from workers (no mutex held by caller).
std::scoped_lock lock(mux);
unsigned cnt = 0;
m_var_split_cnt.find(var, cnt);
return cnt;
}
void parallel::batch_manager::split(ast_translation &l2g, unsigned source_worker_id,
search_tree::node<cube_config> *node, expr *atom) {
std::scoped_lock lock(mux);
@ -354,9 +399,35 @@ namespace smt {
// then ignore split, and instead set the status of node to open.
++m_stats.m_num_cubes;
m_stats.m_max_cube_depth = std::max(m_stats.m_max_cube_depth, node->depth() + 1);
// Update AriParti-style split tracking before modifying the tree.
// Extract the split variable from the atom (expected form: var <= mid or var >= mid).
if (is_app(lit.get()) && to_app(lit.get())->get_num_args() == 2) {
expr * lhs = to_app(lit.get())->get_arg(0);
if (is_app(lhs) && to_app(lhs)->get_num_args() == 0) {
// lhs is a leaf variable — increment its global split count.
unsigned cnt = 0;
m_var_split_cnt.find(lhs, cnt);
m_var_split_cnt.insert(lhs, cnt + 1);
}
}
m_search_tree.split(node, lit, nlit);
// Initialise key_rank for the two children using AriParti's rotation rule.
svector<unsigned> parent_rank;
auto pit = m_node_key_rank.find(node);
if (pit == m_node_key_rank.end()) {
for (unsigned i = 0; i < 5; ++i) parent_rank.push_back(i);
} else {
parent_rank = pit->second;
}
svector<unsigned> child_rank = compute_child_key_rank(parent_rank);
if (node->left()) m_node_key_rank[node->left()] = child_rank;
if (node->right()) m_node_key_rank[node->right()] = child_rank;
}
void parallel::batch_manager::collect_clause(ast_translation &l2g, unsigned source_worker_id, expr *clause) {
std::scoped_lock lock(mux);
expr *g_clause = l2g(clause);
@ -411,27 +482,43 @@ namespace smt {
return r;
}
expr_ref parallel::worker::get_arith_split_atom() {
expr_ref parallel::worker::get_arith_split_atom(node * cur_node) {
arith_util a(m);
arith_value av(m);
av.init(ctx.get());
// For each arithmetic constant (arity-0 variable of Int/Real sort) collect:
// - occurrence count: number of arithmetic-comparison parent enodes
// - current theory bounds from arith_value (may be absent for unbounded vars)
// AriParti-style 5-key variable selection (Section 4.2 of AriParti paper):
//
// Variables without theory-propagated bounds are still valid split candidates:
// they will be split at 0 (AriParti's zero-split heuristic, paper Section 4.2).
// Key 0 (split_cnt): fewer previous global splits → preferred
// Key 1 (cz): interval contains zero → preferred (zero-split heuristic)
// Key 2 (deg): higher polynomial degree → preferred (prioritise nonlinear vars)
// Key 3 (occ): more occurrences in arithmetic comparisons → preferred
// Key 4 (width): wider interval → preferred
//
// Width uses AriParti penalty encoding so that unbounded/half-bounded/bounded
// variables can all be ranked on a single numeric scale:
// fully unbounded: PENALTY^2 (= 1 048 576)
// half-bounded [lo,+∞) lo>0: PENALTY/max(1,lo)
// half-bounded [lo,+∞) lo≤0: PENALTY + (-lo)
// half-bounded (-∞,up] up≥0: PENALTY + up
// half-bounded (-∞,up] up<0: PENALTY/max(1,-up)
// fully bounded [lo,up]: up - lo
//
// Delta for half-bounded midpoints follows AriParti: 128 (vs old value of 1).
static const rational PENALTY(1024);
static const rational PENALTY_SQ = PENALTY * PENALTY; // 1 048 576
struct VarInfo {
unsigned occ = 0;
unsigned deg = 1; // 1=linear, 2+=nonlinear
bool has_lo = false, has_up = false;
rational lo, up;
bool lo_strict = false, up_strict = false;
};
obj_map<expr, VarInfo> vars;
for (enode *n : ctx->enodes()) {
expr *e = n->get_expr();
for (enode *en : ctx->enodes()) {
expr *e = en->get_expr();
if (!is_app(e))
continue;
if (to_app(e)->get_num_args() != 0)
@ -439,11 +526,33 @@ namespace smt {
if (!a.is_int_real(e))
continue;
VarInfo &info = vars.insert_if_not_there(e, VarInfo{});
for (enode *p : n->get_parents()) {
for (enode *p : en->get_parents()) {
expr *pe = p->get_expr();
if (a.is_le(pe) || a.is_ge(pe) || a.is_lt(pe) || a.is_gt(pe) ||
m.is_eq(pe))
++info.occ;
// Detect nonlinear context: variable appears inside a multiplication
// or power expression where at least one other argument is non-constant.
if (a.is_mul(pe) || a.is_power(pe)) {
app *ppe = to_app(pe);
for (unsigned i = 0; i < ppe->get_num_args(); ++i) {
expr *arg = ppe->get_arg(i);
if (arg != e && !a.is_numeral(arg)) {
info.deg = 2; // at least quadratic / bilinear
break;
}
}
// power: check second argument for the actual exponent
if (a.is_power(pe) && ppe->get_num_args() == 2) {
rational exp_val;
if (a.is_numeral(ppe->get_arg(1), exp_val) &&
exp_val > rational(1)) {
info.deg = std::max(info.deg, (unsigned)exp_val.get_unsigned());
} else {
info.deg = std::max(info.deg, 2u);
}
}
}
}
if (info.occ == 0)
info.occ = 1;
@ -453,55 +562,83 @@ namespace smt {
if (vars.empty())
return expr_ref(m);
// Select the best variable using AriParti's heuristic (Section 4.2):
// 1. More occurrences in arithmetic atoms is better.
// 2. Among ties, wider interval is better.
// 3. Among ties, interval containing zero is better.
// Fully-bounded variables are preferred over half-bounded or unbounded
// because the interval width is more meaningful.
// Get the per-node key ordering (initialised from AriParti's rotation logic).
svector<unsigned> key_rank = b.get_node_key_rank(cur_node);
// Encode width using AriParti penalty scheme.
auto encode_width = [&](VarInfo const &info) -> rational {
if (!info.has_lo && !info.has_up)
return PENALTY_SQ;
if (!info.has_lo) {
// (-∞, up]
if (info.up < rational::zero()) {
rational neg_up = -info.up;
if (neg_up < rational::one()) neg_up = rational::one();
return PENALTY / neg_up;
} else {
return PENALTY + info.up;
}
}
if (!info.has_up) {
// [lo, +∞)
if (info.lo > rational::zero()) {
rational pos_lo = info.lo;
if (pos_lo < rational::one()) pos_lo = rational::one();
return PENALTY / pos_lo;
} else {
return PENALTY + (-info.lo);
}
}
return info.up - info.lo;
};
expr *best_term = nullptr;
rational best_width = rational(-1);
unsigned best_occ = 0;
unsigned best_split = 0;
bool best_cz = false;
bool best_bounded = false;
unsigned best_deg = 0;
unsigned best_occ = 0;
rational best_width;
for (auto const &[term, info] : vars) {
if (info.has_lo && info.has_up && info.lo >= info.up)
continue; // already fixed — nothing useful to split
bool fully_bounded = info.has_lo && info.has_up;
rational width = fully_bounded ? (info.up - info.lo) : rational(-1);
unsigned split = b.get_var_split_cnt(m_l2g(term));
bool cz;
if (fully_bounded)
if (info.has_lo && info.has_up)
cz = (info.lo <= rational::zero() && rational::zero() <= info.up);
else if (info.has_lo)
cz = (info.lo <= rational::zero());
else if (info.has_up)
cz = (rational::zero() <= info.up);
else
cz = true; // unbounded: split at 0 is always valid
cz = true;
rational width = encode_width(info);
bool prefer = !best_term;
if (!prefer) {
if (info.occ > best_occ)
prefer = true;
else if (info.occ == best_occ) {
if (fully_bounded && !best_bounded)
prefer = true;
else if (fully_bounded == best_bounded) {
if (width > best_width)
prefer = true;
else if (width == best_width && cz && !best_cz)
prefer = true;
for (unsigned ki = 0; ki < 5; ++ki) {
unsigned key = key_rank[ki];
bool lt = false, eq = false;
switch (key) {
case 0: lt = split < best_split; eq = split == best_split; break;
case 1: lt = cz && !best_cz; eq = cz == best_cz; break;
case 2: lt = info.deg > best_deg; eq = info.deg == best_deg; break;
case 3: lt = info.occ > best_occ; eq = info.occ == best_occ; break;
case 4: lt = width > best_width; eq = width == best_width; break;
default: break;
}
if (lt) { prefer = true; break; }
if (!eq) break; // current term is strictly worse on this key
}
}
if (prefer) {
best_term = term;
best_width = width;
best_occ = info.occ;
best_cz = cz;
best_bounded = fully_bounded;
best_term = term;
best_split = split;
best_cz = cz;
best_deg = info.deg;
best_occ = info.occ;
best_width = width;
}
}
if (!best_term)
@ -511,6 +648,8 @@ namespace smt {
vars.find(best_term, bi);
// Compute split midpoint following AriParti Section 4.2.
// Delta of 128 (vs old value of 1) for half-bounded intervals.
static const rational DELTA(128);
rational mid;
if (best_cz) {
mid = rational::zero();
@ -519,22 +658,28 @@ namespace smt {
if (a.is_int(best_term))
mid = floor(mid);
} else if (bi.has_lo) {
mid = bi.lo; // split at known lower bound
// [lo, +∞): split at lo + delta
mid = bi.lo + DELTA;
if (a.is_int(best_term))
mid = floor(mid);
} else if (bi.has_up) {
mid = bi.up - rational(1);
if (!a.is_int(best_term))
mid = bi.up - rational(1, 2);
// (-∞, up]: split at floor(up) - delta
mid = floor(bi.up) - DELTA;
} else {
mid = rational::zero();
}
sort *srt = best_term->get_sort();
LOG_WORKER(2, " arith split on " << mk_bounded_pp(best_term, m, 2)
<< " at " << mid << "\n");
<< " at " << mid
<< " (split_cnt=" << best_split
<< " cz=" << best_cz
<< " deg=" << best_deg
<< " occ=" << best_occ << ")\n");
return expr_ref(a.mk_le(best_term, a.mk_numeral(mid, srt)), m);
}
expr_ref parallel::worker::get_split_atom() {
expr_ref parallel::worker::get_split_atom(node * cur_node) {
expr_ref result(m);
double score = 0;
unsigned n = 0;
@ -544,7 +689,7 @@ namespace smt {
// This is particularly effective for arithmetic theories (QF_LRA, QF_LIA,
// QF_NRA, QF_NIA) where splitting at the midpoint of a variable's current
// interval is more informative than a Boolean variable split.
expr_ref arith_atom = get_arith_split_atom();
expr_ref arith_atom = get_arith_split_atom(cur_node);
if (arith_atom)
return arith_atom;
@ -569,6 +714,7 @@ namespace smt {
return result;
}
void parallel::batch_manager::set_sat(ast_translation &l2g, model &m) {
std::scoped_lock lock(mux);
IF_VERBOSE(1, verbose_stream() << "Batch manager setting SAT.\n");

View file

@ -23,6 +23,7 @@ Revision History:
#include "ast/sls/sls_smt_solver.h"
#include <thread>
#include <mutex>
#include <unordered_map>
namespace smt {
@ -43,7 +44,8 @@ namespace smt {
expr_ref clause;
};
class batch_manager {
class batch_manager {
friend class worker; // worker accesses AriParti tracking helpers
enum state {
is_running,
@ -66,7 +68,14 @@ namespace smt {
stats m_stats;
using node = search_tree::node<cube_config>;
search_tree::tree<cube_config> m_search_tree;
// AriParti-style split tracking (protected by mux):
// How many times each arithmetic expression has been chosen as a split variable.
obj_map<expr, unsigned> m_var_split_cnt;
// Per-node key ordering for 5-key AriParti variable selection.
// Key ordering for node n's children is derived from n's ordering via rotation.
std::unordered_map<node*, svector<unsigned>> m_node_key_rank;
unsigned m_exception_code = 0;
std::string m_exception_msg;
vector<shared_clause> shared_clause_trail; // store all shared clauses with worker IDs
@ -108,6 +117,14 @@ namespace smt {
void backtrack(ast_translation& l2g, expr_ref_vector const& core, node* n);
void split(ast_translation& l2g, unsigned id, node* n, expr* atom);
// AriParti-style split-tracking helpers (thread-safe; called by workers).
// Returns the 5-element key ordering for node n (initialises to [0..4] if absent).
svector<unsigned> get_node_key_rank(node * n);
// Returns global split count for the given arithmetic expression.
unsigned get_var_split_cnt(expr * var);
// Computes the child key ordering from parent using AriParti's rotation rule.
static svector<unsigned> compute_child_key_rank(svector<unsigned> const & parent);
void collect_clause(ast_translation& l2g, unsigned source_worker_id, expr* clause);
expr_ref_vector return_shared_clauses(ast_translation& g2l, unsigned& worker_limit, unsigned worker_id);
@ -146,8 +163,8 @@ namespace smt {
unsigned m_num_initial_atoms = 0;
unsigned m_shared_clause_limit = 0; // remembers the index into shared_clause_trail marking the boundary between "old" and "new" clauses to share
expr_ref get_split_atom();
expr_ref get_arith_split_atom();
expr_ref get_split_atom(node * n);
expr_ref get_arith_split_atom(node * n);
lbool check_cube(expr_ref_vector const& cube);
void share_units();