3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2026-01-20 17:14:43 +00:00

Replace std::vector with Z3's vector/svector types in theory_finite_set_lattice_refutation

Co-authored-by: NikolajBjorner <3085284+NikolajBjorner@users.noreply.github.com>
This commit is contained in:
copilot-swe-agent[bot] 2026-01-19 02:16:55 +00:00
parent ad50f62194
commit a051ba31f9
2 changed files with 43 additions and 61 deletions

View file

@ -12,33 +12,31 @@ Module Name:
#include "smt/smt_theory.h"
#include "smt/theory_finite_set.h"
#include "smt/smt_context.h"
#include "iostream"
#include "util/uint_set.h"
const int NUM_WORDS = 5;
// some example have shown, the introduction of large conflict clauses can severely slow down refutation
// some examples have shown, the introduction of large conflict clauses can severely slow down refutation
const int MAX_DECISION_LITERALS = 10;
const int MAX_VARS = 320;
namespace smt {
reachability_matrix::reachability_matrix(context &ctx, theory_finite_set_lattice_refutation &t_lattice)
: reachable(NUM_WORDS * NUM_WORDS * 64, 0), links(NUM_WORDS * NUM_WORDS * 64 * 64, {nullptr, nullptr}),
link_dls(NUM_WORDS * NUM_WORDS * 64 * 64, 0), non_links(NUM_WORDS * NUM_WORDS * 64),
non_link_justifications(NUM_WORDS * NUM_WORDS * 64 * 64, {nullptr, nullptr}), largest_var(0),
max_size(NUM_WORDS * 64), ctx(ctx), t_lattice_refutation(t_lattice) {}
: reachable(MAX_VARS), links(MAX_VARS * MAX_VARS, {nullptr, nullptr}),
link_dls(MAX_VARS * MAX_VARS, 0u), non_links(MAX_VARS),
non_link_justifications(MAX_VARS * MAX_VARS, {nullptr, nullptr}), largest_var(0),
max_size(MAX_VARS), ctx(ctx), t_lattice_refutation(t_lattice) {
// Initialize the uint_sets for each row
for (int i = 0; i < MAX_VARS; i++) {
reachable[i].reset();
non_links[i].reset();
}
}
int reachability_matrix::get_max_var() {
return largest_var;
}
inline int reachability_matrix::get_word_index(int row, int col) const {
return (row * NUM_WORDS) + (col / 64);
};
inline uint64_t reachability_matrix::get_bitmask(int col) const {
return 1ull << (col % 64);
};
bool reachability_matrix::is_reachability_forbidden(theory_var source, theory_var dest) {
return non_links[get_word_index(source, dest)] & get_bitmask(dest);
return non_links[source].contains(dest);
}
bool reachability_matrix::in_bounds(theory_var source, theory_var dest) {
@ -46,7 +44,7 @@ namespace smt {
}
bool reachability_matrix::is_reachable(theory_var source, theory_var dest) {
return reachable[get_word_index(source, dest)] & get_bitmask(dest);
return reachable[source].contains(dest);
}
bool reachability_matrix::is_linked(theory_var source, theory_var dest) {
@ -54,17 +52,22 @@ namespace smt {
}
bool reachability_matrix::bitwise_or_rows(int source_dest, int source) {
bool changes = false;
for (int i = 0; i < NUM_WORDS; i++) {
uint64_t old_value = reachable[source_dest * NUM_WORDS + i];
uint64_t new_value = reachable[source_dest * NUM_WORDS + i] | reachable[source * NUM_WORDS + i];
if (old_value == new_value) {
continue;
// Save old state for potential rollback
uint_set old_reachable = reachable[source_dest];
// Compute union: reachable[source_dest] |= reachable[source]
reachable[source_dest] |= reachable[source];
// Check if anything changed
bool changes = !(old_reachable == reachable[source_dest]);
if (changes) {
ctx.push_trail(value_trail(reachable[source_dest]));
// Check for conflicts with newly added reachabilities
for (unsigned dest : reachable[source]) {
if (!old_reachable.contains(dest)) {
check_reachability_conflict(source_dest, dest);
}
}
ctx.push_trail(value_trail(reachable[source_dest * NUM_WORDS + i]));
reachable[source_dest * NUM_WORDS + i] = new_value;
changes = true;
check_reachability_conflict_word(source_dest, i);
}
return changes;
}
@ -76,9 +79,8 @@ namespace smt {
ctx.push_trail(value_trail(largest_var));
largest_var = std::max({largest_var, source, dest});
int word_idx = get_word_index(source, dest);
ctx.push_trail(value_trail(reachable[word_idx]));
reachable[word_idx] |= get_bitmask(dest);
ctx.push_trail(value_trail(reachable[source]));
reachable[source].insert(dest);
ctx.push_trail(value_trail(links[source * max_size + dest]));
links[source * max_size + dest] = reachability_witness;
ctx.push_trail(value_trail(link_dls[source * max_size + dest]));
@ -95,13 +97,6 @@ namespace smt {
}
bitwise_or_rows(i, source);
}
if (conflict_word >= 0 && conflict_row >= 0) {
for (int i = conflict_word * 64; i < conflict_word * 64 + 64; i++) {
check_reachability_conflict(conflict_row, i);
}
conflict_word = -1;
conflict_row = -1;
}
return true;
}
@ -112,8 +107,8 @@ namespace smt {
}
ctx.push_trail(value_trail(largest_var));
largest_var = std::max({largest_var, source, dest});
ctx.push_trail(value_trail(non_links[get_word_index(source, dest)]));
non_links[get_word_index(source, dest)] |= get_bitmask(dest);
ctx.push_trail(value_trail(non_links[source]));
non_links[source].insert(dest);
ctx.push_trail(value_trail(non_link_justifications[source * max_size + dest]));
non_link_justifications[source * max_size + dest] = non_reachability_witness;
check_reachability_conflict(source, dest);
@ -162,7 +157,7 @@ namespace smt {
void reachability_matrix::get_path(theory_var source, theory_var dest, vector<enode_pair> &path,
int &num_decisions) {
SASSERT(is_reachable(source, dest));
vector<bool> visited(max_size, false);
bool_vector visited(max_size, false);
if (source != dest) {
visited[source] = true;
}
@ -207,22 +202,12 @@ namespace smt {
return false;
}
bool reachability_matrix::check_reachability_conflict_word(int row, int word) {
if (reachable[row * NUM_WORDS + word] & non_links[row * NUM_WORDS + word]) {
// somewhere in this word there is a conflict
conflict_row = row;
conflict_word = word;
return true;
}
return false;
}
void reachability_matrix::print_relations() {
TRACE(finite_set, tout << "largest_var: " << largest_var);
for (size_t i = 0; i < max_size; i++) {
for (size_t j = 0; j < max_size; j++) {
if ((reachable[get_word_index(i, j)] & get_bitmask(j)) || is_reachable(i, j)) {
TRACE(finite_set, tout << "reachable: " << i << "->" << j << " :" << is_reachable(i, j));
if (is_reachable(i, j)) {
TRACE(finite_set, tout << "reachable: " << i << "->" << j);
}
}
}

View file

@ -12,6 +12,7 @@ Module Name:
#include "ast/finite_set_decl_plugin.h"
#include "ast/rewriter/finite_set_axioms.h"
#include "smt/smt_theory.h"
#include "util/uint_set.h"
namespace smt {
class context;
@ -19,11 +20,11 @@ namespace smt {
class theory_finite_set_lattice_refutation;
class reachability_matrix {
std::vector<uint64_t> reachable;
std::vector<enode_pair> links;
std::vector<uint64_t> link_dls;
std::vector<uint64_t> non_links;
std::vector<enode_pair> non_link_justifications;
vector<uint_set> reachable;
vector<enode_pair> links;
vector<unsigned> link_dls;
vector<uint_set> non_links;
vector<enode_pair> non_link_justifications;
int largest_var;
@ -32,12 +33,9 @@ namespace smt {
context &ctx;
theory_finite_set_lattice_refutation &t_lattice_refutation;
int conflict_row = -1;
int conflict_word = -1;
// sets source_dest |= dest, and pushing the changed words to the trail
bool bitwise_or_rows(int source_dest, int source);
inline int get_word_index(int row, int col) const;
inline uint64_t get_bitmask(int col) const;
public:
void get_path(theory_var source, theory_var dest, vector<enode_pair> &path, int &num_decisions);
@ -48,7 +46,6 @@ namespace smt {
bool is_linked(theory_var source, theory_var dest);
bool check_reachability_conflict(theory_var source, theory_var dest);
bool check_reachability_conflict_word(int row, int word);
bool set_reachability(theory_var source, theory_var dest, enode_pair reachability_witness);
bool set_non_reachability(theory_var source, theory_var dest, enode_pair non_reachability_witness);