diff --git a/src/math/subpaving/subpaving_t.h b/src/math/subpaving/subpaving_t.h index 7300e3da3..2b3806bd6 100644 --- a/src/math/subpaving/subpaving_t.h +++ b/src/math/subpaving/subpaving_t.h @@ -461,7 +461,13 @@ public: \brief Return true if x is a definition. */ bool is_definition(var x) const { return m_defs[x] != 0; } - + + /** + \brief Return the number of clauses/definitions that watch variable x. + Used as an occurrence-count approximation by AriParti-style variable selectors. + */ + unsigned num_watches(var x) const { return m_wlist[x].size(); } + typedef svector watch_list; typedef _scoped_numeral_vector scoped_numeral_vector; diff --git a/src/math/subpaving/subpaving_t_def.h b/src/math/subpaving/subpaving_t_def.h index b71b10fae..239db99ef 100644 --- a/src/math/subpaving/subpaving_t_def.h +++ b/src/math/subpaving/subpaving_t_def.h @@ -184,6 +184,216 @@ public: } }; +/** + \brief Variable selector implementing AriParti's multi-key heuristic (Section 4.2). + + Uses four ranking keys applied in a dynamically-rotated order: + Key 0 (split_cnt): fewer previous splits → preferred (avoids re-splitting the same variable) + Key 1 (cz): interval contains zero → preferred (zero-split heuristic) + Key 2 (occ): more watch-list entries → preferred (more constrained variable) + Key 3 (width): wider interval → preferred (more room to bisect) + + Width is encoded with AriParti-style penalties so that bounded, half-bounded, and + fully-unbounded variables can all be compared on a single scale: + - Fully unbounded: width = PENALTY^2 (= 1 048 576) + - Half-bounded [lo, +∞), lo>0: width = PENALTY / max(1,lo) + - Half-bounded [lo, +∞), lo≤0: width = PENALTY + (-lo) + - Half-bounded (-∞, up], up≥0: width = PENALTY + up + - Half-bounded (-∞, up], up<0: width = PENALTY / max(1,-up) + - Fully bounded [lo, up]: width = up - lo + + Key ordering for a child node is derived from its parent's ordering by the same + rotation rule as AriParti: find the first consecutive pair (i-1,i) where + key_rank[i-1] < key_rank[i] and swap them; if no such pair exists reset to [0,1,2,3]. +*/ +template +class ariparti_var_selector : public context_t::var_selector { + typedef typename context_t::numeral_manager numeral_manager; + typedef typename numeral_manager::numeral numeral; + typedef typename context_t::bound bound; + typedef typename context_t::node node; + + static const unsigned NUM_KEYS = 4; + static const unsigned PENALTY = 1024; + + // Per-variable split counts (incremented each time operator() selects a variable) + unsigned_vector m_split_cnt; + + // Per-node key ordering (indexed by node id) + vector> m_key_rank; + + svector default_key_rank() const { + svector r; + for (unsigned i = 0; i < NUM_KEYS; ++i) r.push_back(i); + return r; + } + + svector child_key_rank(svector const & parent) const { + svector r = parent; + unsigned pos = 0; + for (unsigned i = 1; i < NUM_KEYS; ++i) { + if (r[i - 1] < r[i]) { pos = i; break; } + } + if (pos == 0) { + // no out-of-order consecutive pair: reset to identity + for (unsigned i = 0; i < NUM_KEYS; ++i) r[i] = i; + } else { + std::swap(r[pos - 1], r[pos]); + } + return r; + } + + svector const & key_rank_for(node * n) { + unsigned id = n->id(); + if (id >= m_key_rank.size()) { + m_key_rank.resize(id + 1); + m_key_rank[id] = default_key_rank(); + } + if (m_key_rank[id].empty()) + m_key_rank[id] = default_key_rank(); + return m_key_rank[id]; + } + +public: + ariparti_var_selector(context_t * ctx) : context_t::var_selector(ctx) {} + + void new_var_eh(var x) override { + if (x >= m_split_cnt.size()) + m_split_cnt.resize(x + 1, 0); + } + + void new_node_eh(node * n) override { + unsigned id = n->id(); + if (id >= m_key_rank.size()) + m_key_rank.resize(id + 1); + node * parent = n->parent(); + if (parent == nullptr) { + m_key_rank[id] = default_key_rank(); + } else { + m_key_rank[id] = child_key_rank(key_rank_for(parent)); + } + } + + void del_node_eh(node * n) override { + unsigned id = n->id(); + if (id < m_key_rank.size()) + m_key_rank[id].reset(); + } + + var operator()(node * n) override { + numeral_manager & nm = this->ctx()->nm(); + unsigned num = this->ctx()->num_vars(); + if (num == 0) return null_var; + + svector const & rank = key_rank_for(n); + + // Ensure split_cnt vector is large enough + if (m_split_cnt.size() < num) m_split_cnt.resize(num, 0); + + var best = null_var; + unsigned best_split = 0, best_occ = 0; + bool best_cz = false; + _scoped_numeral best_width(nm), curr_width(nm); + + auto key_lt = [&](var a, var b) { + // Returns true if a is strictly better than b for key k. + // For use inside the selection loop where we compare 'a' against 'best'. + (void)a; (void)b; return false; // placeholder – inlined below + }; + (void)key_lt; + + for (var x = 0; x < num; ++x) { + if (this->ctx()->is_definition(x)) continue; + bound * lo = n->lower(x); + bound * up = n->upper(x); + if (lo != nullptr && up != nullptr && nm.eq(lo->value(), up->value())) + continue; // already fixed + + unsigned split = (x < m_split_cnt.size()) ? m_split_cnt[x] : 0; + bool cz = ((lo == nullptr || nm.is_neg(lo->value())) && + (up == nullptr || nm.is_pos(up->value()))); + unsigned occ = this->ctx()->num_watches(x); + if (occ == 0) continue; // variable not in any constraint + + // Compute width with penalty encoding + if (lo == nullptr && up == nullptr) { + nm.set(curr_width, (int)(PENALTY * PENALTY)); + } else if (lo == nullptr) { + // (-∞, up] + if (nm.is_neg(up->value())) { + // up < 0: width = PENALTY / max(1, -up) + _scoped_numeral neg_up(nm); + nm.set(neg_up, up->value()); + nm.neg(neg_up); + if (nm.is_zero(neg_up) || nm.lt(neg_up, 1)) nm.set(neg_up, 1); + nm.set(curr_width, (int)PENALTY); + nm.div(curr_width, neg_up, curr_width); + } else { + // up >= 0: width = PENALTY + up + nm.set(curr_width, (int)PENALTY); + nm.add(curr_width, up->value(), curr_width); + } + } else if (up == nullptr) { + // [lo, +∞) + if (nm.is_pos(lo->value())) { + // lo > 0: width = PENALTY / max(1, lo) + _scoped_numeral pos_lo(nm); + nm.set(pos_lo, lo->value()); + if (nm.lt(pos_lo, 1)) nm.set(pos_lo, 1); + nm.set(curr_width, (int)PENALTY); + nm.div(curr_width, pos_lo, curr_width); + } else { + // lo <= 0: width = PENALTY + (-lo) + _scoped_numeral neg_lo(nm); + nm.set(neg_lo, lo->value()); + nm.neg(neg_lo); + nm.set(curr_width, (int)PENALTY); + nm.add(curr_width, neg_lo, curr_width); + } + } else { + // Fully bounded: width = up - lo + C::round_to_plus_inf(nm); + nm.sub(up->value(), lo->value(), curr_width); + } + + if (best == null_var) { + best = x; + best_split = split; best_cz = cz; + best_occ = occ; nm.set(best_width, curr_width); + continue; + } + + // Multi-key comparison using the node's key ordering + bool prefer = false; + for (unsigned ki = 0; ki < NUM_KEYS; ++ki) { + unsigned key = rank[ki]; + bool lt = false, eq = false; + switch (key) { + case 0: lt = split < best_split; eq = split == best_split; break; + case 1: lt = cz && !best_cz; eq = cz == best_cz; break; + case 2: lt = occ > best_occ; eq = occ == best_occ; break; + case 3: lt = nm.gt(curr_width, best_width); + eq = nm.eq(curr_width, best_width); break; + default: break; + } + if (lt) { prefer = true; break; } + if (!eq) break; // x is strictly worse on this key + } + if (prefer) { + best = x; + best_split = split; best_cz = cz; + best_occ = occ; nm.set(best_width, curr_width); + } + } + + if (best != null_var) { + if (best >= m_split_cnt.size()) m_split_cnt.resize(best + 1, 0); + ++m_split_cnt[best]; + } + return best; + } +}; + template class midpoint_node_splitter : public context_t::node_splitter { typedef typename context_t::numeral_manager numeral_manager; diff --git a/src/smt/smt_parallel.cpp b/src/smt/smt_parallel.cpp index 93931ef54..823611fd8 100644 --- a/src/smt/smt_parallel.cpp +++ b/src/smt/smt_parallel.cpp @@ -138,7 +138,7 @@ namespace smt { if (m_config.m_max_cube_depth <= cube.size()) goto check_cube_start; - auto atom = get_split_atom(); + auto atom = get_split_atom(node); if (!atom) goto check_cube_start; b.split(m_l2g, id, node, atom); @@ -339,6 +339,51 @@ namespace smt { } } + svector parallel::batch_manager::compute_child_key_rank(svector const & parent) { + static const unsigned NUM_KEYS = 5; + svector r = parent; + // Find the first consecutive pair (i-1, i) where r[i-1] < r[i] and swap them. + // If no such pair exists (the rank is non-increasing), reset to identity [0..4]. + unsigned pos = 0; + for (unsigned i = 1; i < NUM_KEYS; ++i) { + if (r[i - 1] < r[i]) { pos = i; break; } + } + if (pos == 0) { + for (unsigned i = 0; i < NUM_KEYS; ++i) r[i] = i; + } else { + std::swap(r[pos - 1], r[pos]); + } + return r; + } + + + svector parallel::batch_manager::get_node_key_rank(node * n) { + // Called from workers (lock must NOT be held by caller). + // Returns by value to avoid holding a dangling reference after releasing mux. + std::scoped_lock lock(mux); + if (n == nullptr) { + svector r; + for (unsigned i = 0; i < 5; ++i) r.push_back(i); + return r; + } + auto it = m_node_key_rank.find(n); + if (it == m_node_key_rank.end()) { + svector r; + for (unsigned i = 0; i < 5; ++i) r.push_back(i); + m_node_key_rank[n] = r; + return r; + } + return it->second; + } + + unsigned parallel::batch_manager::get_var_split_cnt(expr * var) { + // Called from workers (no mutex held by caller). + std::scoped_lock lock(mux); + unsigned cnt = 0; + m_var_split_cnt.find(var, cnt); + return cnt; + } + void parallel::batch_manager::split(ast_translation &l2g, unsigned source_worker_id, search_tree::node *node, expr *atom) { std::scoped_lock lock(mux); @@ -354,9 +399,35 @@ namespace smt { // then ignore split, and instead set the status of node to open. ++m_stats.m_num_cubes; m_stats.m_max_cube_depth = std::max(m_stats.m_max_cube_depth, node->depth() + 1); + + // Update AriParti-style split tracking before modifying the tree. + // Extract the split variable from the atom (expected form: var <= mid or var >= mid). + if (is_app(lit.get()) && to_app(lit.get())->get_num_args() == 2) { + expr * lhs = to_app(lit.get())->get_arg(0); + if (is_app(lhs) && to_app(lhs)->get_num_args() == 0) { + // lhs is a leaf variable — increment its global split count. + unsigned cnt = 0; + m_var_split_cnt.find(lhs, cnt); + m_var_split_cnt.insert(lhs, cnt + 1); + } + } + m_search_tree.split(node, lit, nlit); + + // Initialise key_rank for the two children using AriParti's rotation rule. + svector parent_rank; + auto pit = m_node_key_rank.find(node); + if (pit == m_node_key_rank.end()) { + for (unsigned i = 0; i < 5; ++i) parent_rank.push_back(i); + } else { + parent_rank = pit->second; + } + svector child_rank = compute_child_key_rank(parent_rank); + if (node->left()) m_node_key_rank[node->left()] = child_rank; + if (node->right()) m_node_key_rank[node->right()] = child_rank; } + void parallel::batch_manager::collect_clause(ast_translation &l2g, unsigned source_worker_id, expr *clause) { std::scoped_lock lock(mux); expr *g_clause = l2g(clause); @@ -411,27 +482,43 @@ namespace smt { return r; } - expr_ref parallel::worker::get_arith_split_atom() { + expr_ref parallel::worker::get_arith_split_atom(node * cur_node) { arith_util a(m); arith_value av(m); av.init(ctx.get()); - // For each arithmetic constant (arity-0 variable of Int/Real sort) collect: - // - occurrence count: number of arithmetic-comparison parent enodes - // - current theory bounds from arith_value (may be absent for unbounded vars) + // AriParti-style 5-key variable selection (Section 4.2 of AriParti paper): // - // Variables without theory-propagated bounds are still valid split candidates: - // they will be split at 0 (AriParti's zero-split heuristic, paper Section 4.2). + // Key 0 (split_cnt): fewer previous global splits → preferred + // Key 1 (cz): interval contains zero → preferred (zero-split heuristic) + // Key 2 (deg): higher polynomial degree → preferred (prioritise nonlinear vars) + // Key 3 (occ): more occurrences in arithmetic comparisons → preferred + // Key 4 (width): wider interval → preferred + // + // Width uses AriParti penalty encoding so that unbounded/half-bounded/bounded + // variables can all be ranked on a single numeric scale: + // fully unbounded: PENALTY^2 (= 1 048 576) + // half-bounded [lo,+∞) lo>0: PENALTY/max(1,lo) + // half-bounded [lo,+∞) lo≤0: PENALTY + (-lo) + // half-bounded (-∞,up] up≥0: PENALTY + up + // half-bounded (-∞,up] up<0: PENALTY/max(1,-up) + // fully bounded [lo,up]: up - lo + // + // Delta for half-bounded midpoints follows AriParti: 128 (vs old value of 1). + static const rational PENALTY(1024); + static const rational PENALTY_SQ = PENALTY * PENALTY; // 1 048 576 + struct VarInfo { unsigned occ = 0; + unsigned deg = 1; // 1=linear, 2+=nonlinear bool has_lo = false, has_up = false; rational lo, up; bool lo_strict = false, up_strict = false; }; obj_map vars; - for (enode *n : ctx->enodes()) { - expr *e = n->get_expr(); + for (enode *en : ctx->enodes()) { + expr *e = en->get_expr(); if (!is_app(e)) continue; if (to_app(e)->get_num_args() != 0) @@ -439,11 +526,33 @@ namespace smt { if (!a.is_int_real(e)) continue; VarInfo &info = vars.insert_if_not_there(e, VarInfo{}); - for (enode *p : n->get_parents()) { + for (enode *p : en->get_parents()) { expr *pe = p->get_expr(); if (a.is_le(pe) || a.is_ge(pe) || a.is_lt(pe) || a.is_gt(pe) || m.is_eq(pe)) ++info.occ; + // Detect nonlinear context: variable appears inside a multiplication + // or power expression where at least one other argument is non-constant. + if (a.is_mul(pe) || a.is_power(pe)) { + app *ppe = to_app(pe); + for (unsigned i = 0; i < ppe->get_num_args(); ++i) { + expr *arg = ppe->get_arg(i); + if (arg != e && !a.is_numeral(arg)) { + info.deg = 2; // at least quadratic / bilinear + break; + } + } + // power: check second argument for the actual exponent + if (a.is_power(pe) && ppe->get_num_args() == 2) { + rational exp_val; + if (a.is_numeral(ppe->get_arg(1), exp_val) && + exp_val > rational(1)) { + info.deg = std::max(info.deg, (unsigned)exp_val.get_unsigned()); + } else { + info.deg = std::max(info.deg, 2u); + } + } + } } if (info.occ == 0) info.occ = 1; @@ -453,55 +562,83 @@ namespace smt { if (vars.empty()) return expr_ref(m); - // Select the best variable using AriParti's heuristic (Section 4.2): - // 1. More occurrences in arithmetic atoms is better. - // 2. Among ties, wider interval is better. - // 3. Among ties, interval containing zero is better. - // Fully-bounded variables are preferred over half-bounded or unbounded - // because the interval width is more meaningful. + // Get the per-node key ordering (initialised from AriParti's rotation logic). + svector key_rank = b.get_node_key_rank(cur_node); + + // Encode width using AriParti penalty scheme. + auto encode_width = [&](VarInfo const &info) -> rational { + if (!info.has_lo && !info.has_up) + return PENALTY_SQ; + if (!info.has_lo) { + // (-∞, up] + if (info.up < rational::zero()) { + rational neg_up = -info.up; + if (neg_up < rational::one()) neg_up = rational::one(); + return PENALTY / neg_up; + } else { + return PENALTY + info.up; + } + } + if (!info.has_up) { + // [lo, +∞) + if (info.lo > rational::zero()) { + rational pos_lo = info.lo; + if (pos_lo < rational::one()) pos_lo = rational::one(); + return PENALTY / pos_lo; + } else { + return PENALTY + (-info.lo); + } + } + return info.up - info.lo; + }; + expr *best_term = nullptr; - rational best_width = rational(-1); - unsigned best_occ = 0; + unsigned best_split = 0; bool best_cz = false; - bool best_bounded = false; + unsigned best_deg = 0; + unsigned best_occ = 0; + rational best_width; for (auto const &[term, info] : vars) { if (info.has_lo && info.has_up && info.lo >= info.up) continue; // already fixed — nothing useful to split - bool fully_bounded = info.has_lo && info.has_up; - rational width = fully_bounded ? (info.up - info.lo) : rational(-1); + unsigned split = b.get_var_split_cnt(m_l2g(term)); bool cz; - if (fully_bounded) + if (info.has_lo && info.has_up) cz = (info.lo <= rational::zero() && rational::zero() <= info.up); else if (info.has_lo) cz = (info.lo <= rational::zero()); else if (info.has_up) cz = (rational::zero() <= info.up); else - cz = true; // unbounded: split at 0 is always valid + cz = true; + rational width = encode_width(info); bool prefer = !best_term; if (!prefer) { - if (info.occ > best_occ) - prefer = true; - else if (info.occ == best_occ) { - if (fully_bounded && !best_bounded) - prefer = true; - else if (fully_bounded == best_bounded) { - if (width > best_width) - prefer = true; - else if (width == best_width && cz && !best_cz) - prefer = true; + for (unsigned ki = 0; ki < 5; ++ki) { + unsigned key = key_rank[ki]; + bool lt = false, eq = false; + switch (key) { + case 0: lt = split < best_split; eq = split == best_split; break; + case 1: lt = cz && !best_cz; eq = cz == best_cz; break; + case 2: lt = info.deg > best_deg; eq = info.deg == best_deg; break; + case 3: lt = info.occ > best_occ; eq = info.occ == best_occ; break; + case 4: lt = width > best_width; eq = width == best_width; break; + default: break; } + if (lt) { prefer = true; break; } + if (!eq) break; // current term is strictly worse on this key } } if (prefer) { - best_term = term; - best_width = width; - best_occ = info.occ; - best_cz = cz; - best_bounded = fully_bounded; + best_term = term; + best_split = split; + best_cz = cz; + best_deg = info.deg; + best_occ = info.occ; + best_width = width; } } if (!best_term) @@ -511,6 +648,8 @@ namespace smt { vars.find(best_term, bi); // Compute split midpoint following AriParti Section 4.2. + // Delta of 128 (vs old value of 1) for half-bounded intervals. + static const rational DELTA(128); rational mid; if (best_cz) { mid = rational::zero(); @@ -519,22 +658,28 @@ namespace smt { if (a.is_int(best_term)) mid = floor(mid); } else if (bi.has_lo) { - mid = bi.lo; // split at known lower bound + // [lo, +∞): split at lo + delta + mid = bi.lo + DELTA; + if (a.is_int(best_term)) + mid = floor(mid); } else if (bi.has_up) { - mid = bi.up - rational(1); - if (!a.is_int(best_term)) - mid = bi.up - rational(1, 2); + // (-∞, up]: split at floor(up) - delta + mid = floor(bi.up) - DELTA; } else { mid = rational::zero(); } sort *srt = best_term->get_sort(); LOG_WORKER(2, " arith split on " << mk_bounded_pp(best_term, m, 2) - << " at " << mid << "\n"); + << " at " << mid + << " (split_cnt=" << best_split + << " cz=" << best_cz + << " deg=" << best_deg + << " occ=" << best_occ << ")\n"); return expr_ref(a.mk_le(best_term, a.mk_numeral(mid, srt)), m); } - expr_ref parallel::worker::get_split_atom() { + expr_ref parallel::worker::get_split_atom(node * cur_node) { expr_ref result(m); double score = 0; unsigned n = 0; @@ -544,7 +689,7 @@ namespace smt { // This is particularly effective for arithmetic theories (QF_LRA, QF_LIA, // QF_NRA, QF_NIA) where splitting at the midpoint of a variable's current // interval is more informative than a Boolean variable split. - expr_ref arith_atom = get_arith_split_atom(); + expr_ref arith_atom = get_arith_split_atom(cur_node); if (arith_atom) return arith_atom; @@ -569,6 +714,7 @@ namespace smt { return result; } + void parallel::batch_manager::set_sat(ast_translation &l2g, model &m) { std::scoped_lock lock(mux); IF_VERBOSE(1, verbose_stream() << "Batch manager setting SAT.\n"); diff --git a/src/smt/smt_parallel.h b/src/smt/smt_parallel.h index 91de8deb5..ebfe7e5e3 100644 --- a/src/smt/smt_parallel.h +++ b/src/smt/smt_parallel.h @@ -23,6 +23,7 @@ Revision History: #include "ast/sls/sls_smt_solver.h" #include #include +#include namespace smt { @@ -43,7 +44,8 @@ namespace smt { expr_ref clause; }; - class batch_manager { + class batch_manager { + friend class worker; // worker accesses AriParti tracking helpers enum state { is_running, @@ -66,7 +68,14 @@ namespace smt { stats m_stats; using node = search_tree::node; search_tree::tree m_search_tree; - + + // AriParti-style split tracking (protected by mux): + // How many times each arithmetic expression has been chosen as a split variable. + obj_map m_var_split_cnt; + // Per-node key ordering for 5-key AriParti variable selection. + // Key ordering for node n's children is derived from n's ordering via rotation. + std::unordered_map> m_node_key_rank; + unsigned m_exception_code = 0; std::string m_exception_msg; vector shared_clause_trail; // store all shared clauses with worker IDs @@ -108,6 +117,14 @@ namespace smt { void backtrack(ast_translation& l2g, expr_ref_vector const& core, node* n); void split(ast_translation& l2g, unsigned id, node* n, expr* atom); + // AriParti-style split-tracking helpers (thread-safe; called by workers). + // Returns the 5-element key ordering for node n (initialises to [0..4] if absent). + svector get_node_key_rank(node * n); + // Returns global split count for the given arithmetic expression. + unsigned get_var_split_cnt(expr * var); + // Computes the child key ordering from parent using AriParti's rotation rule. + static svector compute_child_key_rank(svector const & parent); + void collect_clause(ast_translation& l2g, unsigned source_worker_id, expr* clause); expr_ref_vector return_shared_clauses(ast_translation& g2l, unsigned& worker_limit, unsigned worker_id); @@ -146,8 +163,8 @@ namespace smt { unsigned m_num_initial_atoms = 0; unsigned m_shared_clause_limit = 0; // remembers the index into shared_clause_trail marking the boundary between "old" and "new" clauses to share - expr_ref get_split_atom(); - expr_ref get_arith_split_atom(); + expr_ref get_split_atom(node * n); + expr_ref get_arith_split_atom(node * n); lbool check_cube(expr_ref_vector const& cube); void share_units();