From e7cf78996982f4b6cbcbdc7c2a45ebd2e8011361 Mon Sep 17 00:00:00 2001 From: Clemens Eisenhofer Date: Fri, 11 Nov 2022 10:28:41 +0100 Subject: [PATCH] Use Z3's watch-list --- src/sat/smt/xor_gaussian.h | 4 +- src/sat/smt/xor_solver.cpp | 134 +++++++++++++++++++++++++------------ src/sat/smt/xor_solver.h | 27 ++++++++ src/util/visit_helper.h | 5 +- 4 files changed, 125 insertions(+), 45 deletions(-) diff --git a/src/sat/smt/xor_gaussian.h b/src/sat/smt/xor_gaussian.h index ea92951b1..52d7e5783 100644 --- a/src/sat/smt/xor_gaussian.h +++ b/src/sat/smt/xor_gaussian.h @@ -231,8 +231,8 @@ namespace xr { } // add all elements in other.m_clash_vars that are not yet in m_clash_vars: - void merge_clash(const xor_clause& other, visit_helper& visited) { - visited.init_visited(m_clash_vars.size()); + void merge_clash(const xor_clause& other, visit_helper& visited, unsigned num_vars) { + visited.init_visited(num_vars); for (const bool_var& v: m_clash_vars) visited.mark_visited(v); diff --git a/src/sat/smt/xor_solver.cpp b/src/sat/smt/xor_solver.cpp index 9131ca4d2..6d596966e 100644 --- a/src/sat/smt/xor_solver.cpp +++ b/src/sat/smt/xor_solver.cpp @@ -151,7 +151,7 @@ namespace xr { bool confl_in_gauss = false; SASSERT(m_gwatches.size() > p.var()); svector& ws = m_gwatches[p.var()]; - gauss_watched* i = ws.begin(); + gauss_watched* i = ws.begin(); // TODO: Convert to index or iterator-for gauss_watched* j = i; const gauss_watched* end = ws.end(); @@ -169,7 +169,7 @@ namespace xr { } else { confl_in_gauss = true; - i++; + i++; // TODO: That's strange, but this is really written this was in CMS break; } } @@ -495,13 +495,15 @@ namespace xr { std::sort(x.begin(), x.end()); std::sort(txors.begin(), txors.end()); + m_visited.init_visited(s().num_vars()); + unsigned sz = 1; unsigned j = 0; for (unsigned i = 1; i < txors.size(); i++) { auto& jd = txors[j]; auto& id = txors[i]; if (jd.m_vars == id.m_vars && jd.m_rhs == id.m_rhs) { - jd.merge_clash(id, m_visited); + jd.merge_clash(id, m_visited, s().num_vars()); jd.m_detached |= id.m_detached; } else { @@ -566,8 +568,8 @@ namespace xr { unsigned xored = 0; SASSERT(m_occurrences.empty()); - #if 0 - //Link in xors into watchlist + + // Link in xors into watchlist for (unsigned i = 0; i < xors.size(); i++) { const xor_clause& x = xors[i]; for (bool_var v: x) { @@ -577,13 +579,13 @@ namespace xr { m_occ_cnt[v]++; sat::literal l(v, false); - SASSERT(s()->watches.size() > l.toInt()); - m_watches[l].push(Watched(i, WatchType::watch_idx_t)); - m_watches.smudge(l); + watch_neg_literal(l, i); + // TODO: What's that for? + // m_watches.smudge(l); } } - //Don't XOR together over variables that are in regular clauses + // Don't XOR together over variables that are in regular clauses s().init_visited(); for (unsigned i = 0; i < 2 * s().num_vars(); i++) { @@ -608,7 +610,7 @@ namespace xr { s().mark_visited(l.var()); } - //until fixedpoint + // until fixedpoint bool changed = true; while (changed) { changed = false; @@ -621,7 +623,7 @@ namespace xr { while (!m_interesting.empty()) { - //Pop and check if it can be XOR-ed together + // Pop and check if it can be XOR-ed together const unsigned v = m_interesting.back(); m_interesting.resize(m_interesting.size()-1); if (m_occ_cnt[v] != 2) @@ -630,17 +632,18 @@ namespace xr { unsigned indexes[2]; unsigned at = 0; size_t i2 = 0; - //SASSERT(watches.size() > literal(v, false).index()); - vector ws = s().get_wlist(literal(v, false)); + sat::watch_list& ws = s().get_wlist(literal(v, false)); //Remove the 2 indexes from the watchlist for (unsigned i = 0; i < ws.size(); i++) { const sat::watched& w = ws[i]; - if (!w.isIdx()) { + if (!w.is_ext_constraint()) { + // TODO: Check!!! Is this fine? ws[i2++] = ws[i]; - } else if (!xors[w.get_idx()].empty()) { + } + else if (!xors[w.get_ext_constraint_idx()].empty()) { SASSERT(at < 2); - indexes[at] = w.get_idx(); + indexes[at] = w.get_ext_constraint_idx(); at++; } } @@ -652,26 +655,26 @@ namespace xr { unsigned clash_var; unsigned clash_num = xor_two(&x0, &x1, clash_var); - //If they are equivalent + // If they are equivalent if (x0.size() == x1.size() && x0.m_rhs == x1.m_rhs - && clash_num == x0.size() - ) { + && clash_num == x0.size()) { + TRACE("xor", tout - << "x1: " << x0 << " -- at idx: " << indexes[0] - << "and x2: " << x1 << " -- at idx: " << indexes[1] - << "are equivalent.\n"); + << "x1: " << x0 << " -- at idx: " << indexes[0] + << "and x2: " << x1 << " -- at idx: " << indexes[1] + << "are equivalent.\n"); - //Update clash values & detached values - x1.merge_clash(x0, m_visited); + // Update clash values & detached values + x1.merge_clash(x0, m_visited, s().num_vars()); x1.m_detached |= x0.m_detached; TRACE("xor", tout << "after merge: " << x1 << " -- at idx: " << indexes[1] << "\n";); x0 = xor_clause(); - //Re-attach the other, remove the occur of the one we deleted - s().m_watches[Lit(v, false)].push(Watched(indexes[1], WatchType::watch_idx_t)); + // Re-attach the other, remove the occurrence of the one we deleted + watch_neg_literal(ws, indexes[1]); for (unsigned v2: x1) { sat::literal l(v2, false); @@ -682,29 +685,29 @@ namespace xr { } } } else if (clash_num > 1 || x0.m_detached || x1.m_detached) { - //add back to ws, can't do much - ws.push(Watched(indexes[0], WatchType::watch_idx_t)); - ws.push(Watched(indexes[1], WatchType::watch_idx_t)); + // add back to watch-list, can't do much + watch_neg_literal(ws, indexes[0]); + watch_neg_literal(ws, indexes[1]); continue; } else { m_occ_cnt[v] -= 2; SASSERT(m_occ_cnt[v] == 0); xor_clause x_new(m_tmp_vars_xor_two, x0.m_rhs ^ x1.m_rhs, clash_var); - x_new.merge_clash(x0, m_visited); - x_new.merge_clash(x1, m_visited); + x_new.merge_clash(x0, m_visited, s().num_vars()); + x_new.merge_clash(x1, m_visited, s().num_vars()); TRACE("xor", tout - << "x1: " << x0 << " -- at idx: " << indexes[0] << "\n" - << "x2: " << x1 << " -- at idx: " << indexes[1] << "\n" - << "clashed on var: " << clash_var+1 << "\n" - << "final: " << x_new << " -- at idx: " << xors.size() << "\n";); + << "x1: " << x0 << " -- at idx: " << indexes[0] << "\n" + << "x2: " << x1 << " -- at idx: " << indexes[1] << "\n" + << "clashed on var: " << clash_var+1 << "\n" + << "final: " << x_new << " -- at idx: " << xors.size() << "\n";); changed = true; xors.push_back(x_new); - for(uint32_t v2: x_new) { + for (bool_var v2 : x_new) { sat::literal l(v2, false); - s().watches[l].push(Watched(xors.size()-1, WatchType::watch_idx_t)); + watch_neg_literal(l, xors.size() - 1); SASSERT(m_occ_cnt[l.var()] >= 1); if (m_occ_cnt[l.var()] == 2 && !s().is_visited(l.var())) { m_interesting.push_back(l.var()); @@ -717,19 +720,68 @@ namespace xr { } } - //Clear + // Clear for (const bool_var l : m_occurrences) { m_occ_cnt[l] = 0; } m_occurrences.clear(); - clean_occur_from_idx_types_only_smudged(); - clean_xors_from_empty(xors); - #endif + // TODO: Implement + //clean_occur_from_idx_types_only_smudged(); + //clean_xors_from_empty(xors); return !s().inconsistent(); } + unsigned solver::xor_two(xor_clause const* x1_p, xor_clause const* x2_p, bool_var& clash_var) { + m_tmp_vars_xor_two.clear(); + if (x1_p->size() > x2_p->size()) + std::swap(x1_p, x2_p); + + const xor_clause& x1 = *x1_p; + const xor_clause& x2 = *x2_p; + + m_visited.init_visited(s().num_vars(), 2); + + unsigned clash_num = 0; + for (bool_var v : x1) { + SASSERT(!m_visited.is_visited(v)); + m_visited.inc_visited(v); + } + + bool_var i_x2; + bool early_abort = false; + for (i_x2 = 0; i_x2 < x2.size(); i_x2++) { + bool_var v = x2[i_x2]; + SASSERT(m_visited.num_visited(v) < 2); + if (!m_visited.is_visited(v)) { + m_tmp_vars_xor_two.push_back(v); + } + else { + clash_var = v; + if (clash_num > 0 && clash_num != i_x2) { + //early abort, it's never gonna be good + clash_num++; + early_abort = true; + break; + } + clash_num++; + } + + m_visited.inc_visited(v, 2); + } + + if (!early_abort) { + for (bool_var v: x1) { + if (m_visited.num_visited(v) < 2) { + m_tmp_vars_xor_two.push_back(v); + } + } + } + + return clash_num; + } + std::ostream& solver::display_justification(std::ostream& out, sat::ext_justification_idx idx) const { return out; } diff --git a/src/sat/smt/xor_solver.h b/src/sat/smt/xor_solver.h index 845446d58..eeec78274 100644 --- a/src/sat/smt/xor_solver.h +++ b/src/sat/smt/xor_solver.h @@ -58,6 +58,7 @@ namespace xr { // and we need the list of occurrences unsigned_vector m_occ_cnt; bool_var_vector m_interesting; + bool_var_vector m_tmp_vars_xor_two; void force_push(); void push_core(); @@ -70,8 +71,34 @@ namespace xr { void add_xor_clause(const sat::literal_vector& lits, bool rhs, const bool attach); + unsigned xor_two(xor_clause const* x1_p, xor_clause const* x2_p, bool_var& clash_var); + bool inconsistent() const { return s().inconsistent(); } + // TODO: CMS watches the literals directly; Z3 their negation. "_neg_" just for now to avoid confusion + bool is_neg_watched(sat::watch_list& l, size_t idx) const { + return l.contains(sat::watched((sat::ext_constraint_idx)idx)); + } + + bool is_neg_watched(literal lit, size_t idx) const { + return s().get_wlist(lit).contains(sat::watched((sat::ext_constraint_idx)idx)); + } + + void unwatch_neg_literal(literal lit, size_t idx) { + s().get_wlist(lit).erase(sat::watched(idx)); + SASSERT(!is_neg_watched(lit, idx)); + } + + void watch_neg_literal(sat::watch_list& l, size_t idx) { + SASSERT(!is_neg_watched(l, idx)); + l.push_back(sat::watched(idx)); + } + + void watch_neg_literal(literal lit, size_t idx) { + watch_neg_literal(s().get_wlist(lit), idx); + } + + public: solver(euf::solver& ctx); solver(ast_manager& m, euf::theory_id id); diff --git a/src/util/visit_helper.h b/src/util/visit_helper.h index 5f4591828..6f77fe09e 100644 --- a/src/util/visit_helper.h +++ b/src/util/visit_helper.h @@ -41,8 +41,9 @@ public: } void mark_visited(unsigned v) { m_visited[v] = m_visited_begin + 1; } - void inc_visited(unsigned v) { - m_visited[v] = std::min(m_visited_end, std::max(m_visited_begin, m_visited[v]) + 1); + void inc_visited(unsigned v) { inc_visited(v, 1); } + void inc_visited(unsigned v, unsigned by) { + m_visited[v] = std::min(m_visited_end, std::max(m_visited_begin, m_visited[v]) + by); } bool is_visited(unsigned v) const { return m_visited[v] > m_visited_begin; } unsigned num_visited(unsigned v) const { return std::max(m_visited_begin, m_visited[v]) - m_visited_begin; }