diff --git a/src/sat/smt/xor_gaussian.cpp b/src/sat/smt/xor_gaussian.cpp index 312074c84..1af5a0f20 100644 --- a/src/sat/smt/xor_gaussian.cpp +++ b/src/sat/smt/xor_gaussian.cpp @@ -30,9 +30,9 @@ static const unsigned unassigned_col = UINT32_MAX; ///returns popcnt unsigned PackedRow::find_watchVar( - sat::literal_vector& tmp_clause, + literal_vector& tmp_clause, const unsigned_vector& col_to_var, - char_vector &var_has_resp_row, + bool_vector &var_has_resp_row, unsigned& non_resp_var) { unsigned popcnt = 0; non_resp_var = UINT_MAX; @@ -98,7 +98,7 @@ void PackedRow::get_reason( gret PackedRow::propGause( const unsigned_vector& col_to_var, - char_vector &var_has_resp_row, + bool_vector &var_has_resp_row, unsigned& new_resp_var, PackedRow& tmp_col, PackedRow& tmp_col2, @@ -273,16 +273,13 @@ void EGaussian::fill_matrix() { } mat.resize(num_rows, num_cols); // initial gaussian matrix - bdd_matrix.clear(); for (unsigned row = 0; row < num_rows; row++) { const xor_clause& c = m_xorclauses[row]; mat[row].set(c, var_to_col, num_cols); char_vector line; line.resize(num_rows, 0); line[row] = 1; - bdd_matrix.push_back(line); } - SASSERT(bdd_matrix.size() == num_rows); // reset var_has_resp_row.clear(); @@ -294,7 +291,7 @@ void EGaussian::fill_matrix() { //reset satisfied_xor state SASSERT(m_solver.m_num_scopes == 0); satisfied_xors.clear(); - satisfied_xors.resize(num_rows, 0); + satisfied_xors.resize(num_rows, false); } void EGaussian::delete_gauss_watch_this_matrix() { @@ -346,7 +343,6 @@ bool EGaussian::full_init(bool& created) { case gret::prop: SASSERT(m_solver.m_num_scopes == 0); m_solver.s().propagate(false); // TODO: Can we really do this here? - // m_solver.ok = m_solver.propagate().isNull(); if (inconsistent()) { TRACE("xor", tout << "eliminate & adjust matrix during init lead to UNSAT\n";); return false; @@ -394,42 +390,38 @@ bool EGaussian::full_init(bool& created) { } void EGaussian::eliminate() { - PackedMatrix::iterator end_row_it = mat.begin() + num_rows; - PackedMatrix::iterator rowI = mat.begin(); + // TODO: Why twice? gauss_jordan_elim + const unsigned end_row = num_rows; + unsigned rowI = 0; unsigned row_i = 0; unsigned col = 0; // Gauss-Jordan Elimination while (row_i != num_rows && col != num_cols) { - PackedMatrix::iterator row_with_1_in_col = rowI; + unsigned row_with_1_in_col = rowI; unsigned row_with_1_in_col_n = row_i; - //Find first "1" in column. - for (; row_with_1_in_col != end_row_it; ++row_with_1_in_col, row_with_1_in_col_n++) { - if ((*row_with_1_in_col)[col]) + // Find first "1" in column. + for (; row_with_1_in_col < end_row; ++row_with_1_in_col, row_with_1_in_col_n++) { + if (mat[row_with_1_in_col][col]) break; } - //We have found a "1" in this column - if (row_with_1_in_col != end_row_it) { - var_has_resp_row[col_to_var[col]] = 1; + // We have found a "1" in this column + if (row_with_1_in_col < end_row) { + var_has_resp_row[col_to_var[col]] = true; // swap row row_with_1_in_col and rowIt - if (row_with_1_in_col != rowI) { - (*rowI).swapBoth(*row_with_1_in_col); - std::swap(bdd_matrix[row_i], bdd_matrix[row_with_1_in_col_n]); - } + if (row_with_1_in_col != rowI) + mat[rowI].swapBoth(mat[row_with_1_in_col]); // XOR into *all* rows that have a "1" in column COL // Since we XOR into *all*, this is Gauss-Jordan (and not just Gauss) unsigned k = 0; - for (PackedMatrix::iterator k_row = mat.begin() - ; k_row != end_row_it - ; ++k_row, k++ - ) { + for (unsigned k_row = 0; k_row < end_row; ++k_row, k++) { // xor rows K and I - if (k_row != rowI && (*k_row)[col]) { - (*k_row).xor_in(*rowI); + if (k_row != rowI && mat[k_row][col]) { + mat[k_row].xor_in(mat[rowI]); } } row_i++; @@ -439,7 +431,7 @@ void EGaussian::eliminate() { } } -sat::literal_vector* EGaussian::get_reason(const unsigned row, int& out_ID) { +literal_vector* EGaussian::get_reason(const unsigned row, int& out_ID) { if (!xor_reasons[row].m_must_recalc) { out_ID = xor_reasons[row].m_ID; return &(xor_reasons[row].m_reason); @@ -447,7 +439,7 @@ sat::literal_vector* EGaussian::get_reason(const unsigned row, int& out_ID) { // Clean up previous one - svector& to_fill = xor_reasons[row].m_reason; + literal_vector& to_fill = xor_reasons[row].m_reason; to_fill.clear(); mat[row].get_reason( @@ -493,7 +485,7 @@ gret EGaussian::init_adjust_matrix() { } TRACE("xor", tout << "-> empty on row: " << row_i;); TRACE("xor", tout << "-> Satisfied XORs set for row: " << row_i;); - satisfied_xors[row_i] = 1; + satisfied_xors[row_i] = true; break; //Unit (i.e. toplevel unit) @@ -508,13 +500,13 @@ gret EGaussian::init_adjust_matrix() { TRACE("xor", tout << "-> UNIT during adjust: " << tmp_clause[0];); TRACE("xor", tout << "-> Satisfied XORs set for row: " << row_i;); - satisfied_xors[row_i] = 1; + satisfied_xors[row_i] = true; SASSERT(check_row_satisfied(row_i)); //adjusting row.setZero(); // reset this row all zero row_to_var_non_resp.push_back(UINT32_MAX); - var_has_resp_row[tmp_clause[0].var()] = 0; + var_has_resp_row[tmp_clause[0].var()] = false; return gret::prop; } @@ -535,7 +527,7 @@ gret EGaussian::init_adjust_matrix() { row.setZero(); row_to_var_non_resp.push_back(UINT32_MAX); // delete non-basic value in this row - var_has_resp_row[tmp_clause[0].var()] = 0; // delete basic value in this row + var_has_resp_row[tmp_clause[0].var()] = false; // delete basic value in this row break; } @@ -579,7 +571,7 @@ void EGaussian::delete_gausswatch(const unsigned row_n) { unsigned EGaussian::get_max_level(const gauss_data& gqd, const unsigned row_n) { int ID; - auto cl = get_reason(row_n, ID); + literal_vector* cl = get_reason(row_n, ID); unsigned nMaxLevel = gqd.currLevel; unsigned nMaxInd = 1; @@ -600,8 +592,9 @@ unsigned EGaussian::get_max_level(const gauss_data& gqd, const unsigned row_n) { } bool EGaussian::find_truths( - gauss_watched*& i, - gauss_watched*& j, + svector& ws, + unsigned& i, + unsigned& j, const unsigned var, const unsigned row_n, gauss_data& gqd) { @@ -622,7 +615,7 @@ bool EGaussian::find_truths( if (satisfied_xors[row_n]) { TRACE("xor", tout << "-> xor satisfied as per satisfied_xors[row_n]";); SASSERT(check_row_satisfied(row_n)); - *j++ = *i; + ws[j++] = ws[i]; find_truth_ret_satisfied_precheck++; return true; } @@ -633,8 +626,8 @@ bool EGaussian::find_truths( //var has a responsible row, so THIS row must be it! //since if a var has a responsible row, only ONE row can have a 1 there was_resp_var = true; - var_has_resp_row[row_to_var_non_resp[row_n]] = 1; - var_has_resp_row[var] = 0; + var_has_resp_row[row_to_var_non_resp[row_n]] = true; + var_has_resp_row[var] = false; } unsigned new_resp_var; @@ -653,7 +646,7 @@ bool EGaussian::find_truths( switch (ret) { case gret::confl: { find_truth_ret_confl++; - *j++ = *i; + ws[j++] = ws[i]; xor_reasons[row_n].m_must_recalc = true; xor_reasons[row_n].m_propagated = sat::null_literal; @@ -662,8 +655,8 @@ bool EGaussian::find_truths( TRACE("xor", tout << "--> conflict";); if (was_resp_var) { // recover - var_has_resp_row[row_to_var_non_resp[row_n]] = 0; - var_has_resp_row[var] = 1; + var_has_resp_row[row_to_var_non_resp[row_n]] = false; + var_has_resp_row[var] = true; } return false; @@ -672,7 +665,7 @@ bool EGaussian::find_truths( case gret::prop: { find_truth_ret_prop++; TRACE("xor", tout << "--> propagation";); - *j++ = *i; + ws[j++] = ws[i]; xor_reasons[row_n].m_must_recalc = true; xor_reasons[row_n].m_propagated = ret_lit_prop; @@ -683,12 +676,12 @@ bool EGaussian::find_truths( gqd.status = gauss_res::prop; if (was_resp_var) { // recover - var_has_resp_row[row_to_var_non_resp[row_n]] = 0; - var_has_resp_row[var] = 1; + var_has_resp_row[row_to_var_non_resp[row_n]] = false; + var_has_resp_row[var] = true; } TRACE("xor", tout << "--> Satisfied XORs set for row: " << row_n;); - satisfied_xors[row_n] = 1; + satisfied_xors[row_n] = true; SASSERT(check_row_satisfied(row_n)); return true; } @@ -718,8 +711,8 @@ bool EGaussian::find_truths( //so elimination will be needed //clear old one, add new resp - var_has_resp_row[row_to_var_non_resp[row_n]] = 0; - var_has_resp_row[new_resp_var] = 1; + var_has_resp_row[row_to_var_non_resp[row_n]] = false; + var_has_resp_row[new_resp_var] = true; // store the eliminate variable & row gqd.new_resp_var = new_resp_var; @@ -739,15 +732,14 @@ bool EGaussian::find_truths( TRACE("xor", tout << "--> satisfied";); find_truth_ret_satisfied++; - // printf("%d:This row is nothing( maybe already true) n",row_n); - *j++ = *i; + ws[j++] = ws[i]; if (was_resp_var) { // recover - var_has_resp_row[row_to_var_non_resp[row_n]] = 0; - var_has_resp_row[var] = 1; + var_has_resp_row[row_to_var_non_resp[row_n]] = false; + var_has_resp_row[var] = true; } TRACE("xor", tout << "--> Satisfied XORs set for row: " << row_n;); - satisfied_xors[row_n] = 1; + satisfied_xors[row_n] = true; SASSERT(check_row_satisfied(row_n)); return true; @@ -843,7 +835,7 @@ void EGaussian::eliminate_col(unsigned p, gauss_data& gqd) { << " is being watched on var: " << orig_non_resp_var + 1 << " i.e. it must contain '1' for this var's column";); - SASSERT(satisfied_xors[row_i] == 0); + SASSERT(!satisfied_xors[row_i]); (*rowI).xor_in(*(mat.begin() + new_resp_row_n)); elim_xored_rows++; @@ -919,7 +911,7 @@ void EGaussian::eliminate_col(unsigned p, gauss_data& gqd) { gqd.status = gauss_res::prop; TRACE("xor", tout << "---> Satisfied XORs set for row: " << row_i;); - satisfied_xors[row_i] = 1; + satisfied_xors[row_i] = true; SASSERT(check_row_satisfied(row_i)); break; } @@ -946,7 +938,7 @@ void EGaussian::eliminate_col(unsigned p, gauss_data& gqd) { row_to_var_non_resp[row_i] = p; TRACE("xor", tout << "---> Satisfied XORs set for row: " << row_i;); - satisfied_xors[row_i] = 1; + satisfied_xors[row_i] = true; SASSERT(check_row_satisfied(row_i)); break; default: @@ -1092,4 +1084,4 @@ bool EGaussian::must_disable(gauss_data& gqd) { void EGaussian::move_back_xor_clauses() { for (const auto& x: m_xorclauses) m_solver.m_xorclauses.push_back(std::move(x)); -} \ No newline at end of file +} diff --git a/src/sat/smt/xor_gaussian.h b/src/sat/smt/xor_gaussian.h index ea92951b1..2598a7eb6 100644 --- a/src/sat/smt/xor_gaussian.h +++ b/src/sat/smt/xor_gaussian.h @@ -231,8 +231,8 @@ namespace xr { } // add all elements in other.m_clash_vars that are not yet in m_clash_vars: - void merge_clash(const xor_clause& other, visit_helper& visited) { - visited.init_visited(m_clash_vars.size()); + void merge_clash(const xor_clause& other, visit_helper& visited, unsigned num_vars) { + visited.init_visited(num_vars); for (const bool_var& v: m_clash_vars) visited.mark_visited(v); @@ -356,8 +356,8 @@ namespace xr { int64_t* __restrict mp1 = mp - 1; int64_t* __restrict mp2 = b.mp - 1; - unsigned i = size+1; - while(i != 0) { + unsigned i = size + 1; + while (i != 0) { std::swap(*mp1, *mp2); mp1++; mp2++; @@ -391,13 +391,13 @@ namespace xr { unsigned find_watchVar( sat::literal_vector& tmp_clause, const unsigned_vector& col_to_var, - char_vector &var_has_resp_row, + bool_vector &var_has_resp_row, unsigned& non_resp_var); // using find nonbasic value after watch list is enter gret propGause( const unsigned_vector& col_to_var, - char_vector &var_has_resp_row, + bool_vector &var_has_resp_row, unsigned& new_resp_var, PackedRow& tmp_col, PackedRow& tmp_col2, @@ -551,14 +551,15 @@ namespace xr { ///returns FALSE in case of conflict bool find_truths( - gauss_watched*& i, - gauss_watched*& j, + svector& ws, + unsigned& i, + unsigned& j, const unsigned var, const unsigned row_n, gauss_data& gqd ); - sat::literal_vector* get_reason(const unsigned row, int& out_ID); + literal_vector* get_reason(const unsigned row, int& out_ID); // when basic variable is touched , eliminate one col void eliminate_col( @@ -634,23 +635,22 @@ namespace xr { bool cancelled_since_val_update = true; unsigned last_val_update = 0; - //Is the clause at this ROW satisfied already? - //satisfied_xors[row] tells me that - // TODO: Are characters enough? - char_vector satisfied_xors; + // Is the clause at this ROW satisfied already? + // satisfied_xors[row] tells me that + // TODO: Maybe compress further + bool_vector satisfied_xors; // Someone is responsible for this column if TRUE - ///we always WATCH this variable - char_vector var_has_resp_row; + // we always WATCH this variable + bool_vector var_has_resp_row; - ///row_to_var_non_resp[ROW] gives VAR it's NOT responsible for - ///we always WATCH this variable + // row_to_var_non_resp[ROW] gives VAR it's NOT responsible for + // we always WATCH this variable unsigned_vector row_to_var_non_resp; PackedMatrix mat; - svector bdd_matrix; // TODO: we will probably not need it - unsigned_vector var_to_col; ///var->col mapping. Index with VAR + unsigned_vector var_to_col; ///var->col mapping. Index with VAR unsigned_vector col_to_var; ///col->var mapping. Index with COL unsigned num_rows = 0; unsigned num_cols = 0; @@ -669,7 +669,7 @@ namespace xr { inline void EGaussian::canceling() { cancelled_since_val_update = true; - memset(satisfied_xors.data(), 0, satisfied_xors.size()); + memset(satisfied_xors.data(), false, satisfied_xors.size()); } inline double EGaussian::get_density() { @@ -689,4 +689,4 @@ namespace xr { inline bool EGaussian::is_initialized() const { return initialized; } -} \ No newline at end of file +} diff --git a/src/sat/smt/xor_matrix_finder.h b/src/sat/smt/xor_matrix_finder.h index ddaade4df..6e1095dab 100644 --- a/src/sat/smt/xor_matrix_finder.h +++ b/src/sat/smt/xor_matrix_finder.h @@ -28,7 +28,7 @@ namespace xr { class xor_matrix_finder { struct matrix_shape { - matrix_shape(uint32_t matrix_num) : m_num(matrix_num) {} + matrix_shape(unsigned matrix_num) : m_num(matrix_num) {} matrix_shape() {} @@ -39,12 +39,12 @@ namespace xr { double m_density = 0; uint64_t tot_size() const { - return (uint64_t)m_rows*(uint64_t)m_cols; + return (uint64_t)m_rows * (uint64_t)m_cols; } }; struct sorter { - bool operator () (const matrix_shape& left, const matrix_shape& right) { + bool operator()(const matrix_shape& left, const matrix_shape& right) { return left.m_sum_xor_sizes < right.m_sum_xor_sizes; } }; diff --git a/src/sat/smt/xor_solver.cpp b/src/sat/smt/xor_solver.cpp index 9131ca4d2..08bc872f4 100644 --- a/src/sat/smt/xor_solver.cpp +++ b/src/sat/smt/xor_solver.cpp @@ -151,31 +151,32 @@ namespace xr { bool confl_in_gauss = false; SASSERT(m_gwatches.size() > p.var()); svector& ws = m_gwatches[p.var()]; - gauss_watched* i = ws.begin(); - gauss_watched* j = i; - const gauss_watched* end = ws.end(); + unsigned i = 0, j = 0; + const unsigned end = ws.size(); - for (; i != end; i++) { - if (m_gqueuedata[i->matrix_num].disabled || !m_gmatrices[i->matrix_num]->is_initialized()) + for (; i < end; i++) { + const unsigned matrix_num = ws[i].matrix_num; + const unsigned row_n = ws[i].row_n; + if (m_gqueuedata[matrix_num].disabled || !m_gmatrices[matrix_num]->is_initialized()) continue; //remove watch and continue - m_gqueuedata[i->matrix_num].new_resp_var = UINT_MAX; - m_gqueuedata[i->matrix_num].new_resp_row = UINT_MAX; - m_gqueuedata[i->matrix_num].do_eliminate = false; - m_gqueuedata[i->matrix_num].currLevel = currLevel; + m_gqueuedata[matrix_num].new_resp_var = UINT_MAX; + m_gqueuedata[matrix_num].new_resp_row = UINT_MAX; + m_gqueuedata[matrix_num].do_eliminate = false; + m_gqueuedata[matrix_num].currLevel = currLevel; - if (m_gmatrices[i->matrix_num]->find_truths(i, j, p.var(), i->row_n, m_gqueuedata[i->matrix_num])) { + if (m_gmatrices[matrix_num]->find_truths(ws, i, j, p.var(), row_n, m_gqueuedata[matrix_num])) { continue; } else { confl_in_gauss = true; - i++; + i++; // TODO: That's strange, but this is really written this was in CMS break; } } - for (; i != end; i++) - *j++ = *i; + for (; i < end; i++) + ws[j++] = ws[i]; ws.shrink((unsigned)(i - j)); for (unsigned g = 0; g < m_gqueuedata.size(); g++) { @@ -299,7 +300,7 @@ namespace xr { xors[j++] = x; } else { - for (const auto& v : x.m_clash_vars) + for (const bool_var& v : x.m_clash_vars) m_removed_xorclauses_clash_vars.insert(v); } } @@ -316,14 +317,14 @@ namespace xr { bool solver::clean_one_xor(xor_clause& x) { unsigned j = 0; - for (auto const& v : x.m_clash_vars) + for (const bool_var & v : x.m_clash_vars) if (s().value(v) == l_undef) x.m_clash_vars[j++] = v; x.m_clash_vars.shrink(j); j = 0; - for (auto const& v : x) { + for (const bool_var& v : x) { if (s().value(v) != l_undef) x.m_rhs ^= s().value(v) == l_true; else @@ -335,11 +336,6 @@ namespace xr { case 0: if (x.m_rhs) s().set_conflict(); - /*TODO: Implement - if (inconsistent()) { - SASSERT(m_solver.unsat_cl_ID == 0); - m_solver.unsat_cl_ID = solver->clauseID; - }*/ return false; case 1: { s().assign_scoped(sat::literal(x[0], !x.m_rhs)); @@ -347,9 +343,9 @@ namespace xr { return false; } case 2: { - sat::literal_vector vec(x.size()); + literal_vector vec(x.size()); for (const auto& v : x.m_vars) - vec.push_back(sat::literal(v)); + vec.push_back(literal(v)); add_xor_clause(vec, x.m_rhs, true); return false; } @@ -490,23 +486,24 @@ namespace xr { void solver::clean_equivalent_xors(vector& txors){ if (!txors.empty()) { - size_t orig_size = txors.size(); for (xor_clause& x: txors) std::sort(x.begin(), x.end()); std::sort(txors.begin(), txors.end()); + m_visited.init_visited(s().num_vars()); + unsigned sz = 1; unsigned j = 0; for (unsigned i = 1; i < txors.size(); i++) { auto& jd = txors[j]; auto& id = txors[i]; if (jd.m_vars == id.m_vars && jd.m_rhs == id.m_rhs) { - jd.merge_clash(id, m_visited); + jd.merge_clash(id, m_visited, s().num_vars()); jd.m_detached |= id.m_detached; } else { j++; - j = i; + txors[j] = txors[i]; sz++; } } @@ -566,8 +563,8 @@ namespace xr { unsigned xored = 0; SASSERT(m_occurrences.empty()); - #if 0 - //Link in xors into watchlist + + // Link in xors into watchlist for (unsigned i = 0; i < xors.size(); i++) { const xor_clause& x = xors[i]; for (bool_var v: x) { @@ -577,22 +574,18 @@ namespace xr { m_occ_cnt[v]++; sat::literal l(v, false); - SASSERT(s()->watches.size() > l.toInt()); - m_watches[l].push(Watched(i, WatchType::watch_idx_t)); - m_watches.smudge(l); + watch_neg_literal(l, i); } } - //Don't XOR together over variables that are in regular clauses + // Don't XOR together over variables that are in regular clauses s().init_visited(); for (unsigned i = 0; i < 2 * s().num_vars(); i++) { const auto& ws = s().get_wlist(i); for (const auto& w: ws) { - if (w.is_binary_clause()/* TODO: Does redundancy information exist in Z3? Can we use learned instead of "!w.red()"?*/ && !w.is_learned()) { - sat::bool_var v = w.get_literal().var(); - s().mark_visited(v); - } + if (w.is_binary_clause()/* TODO: Does redundancy information exist in Z3? Can we use learned instead of "!w.red()"?*/ && !w.is_learned()) + s().mark_visited(w.get_literal().var()); } } @@ -601,19 +594,19 @@ namespace xr { if (cl->red() || cl->used_in_xor()) { continue; }*/ - // TODO: maybe again instead + // TODO: maybe again this instead if (cl->is_learned()) continue; for (literal l: *cl) s().mark_visited(l.var()); } - //until fixedpoint + // until fixedpoint bool changed = true; while (changed) { changed = false; m_interesting.clear(); - for (const unsigned l : m_occurrences) { + for (const bool_var l : m_occurrences) { if (m_occ_cnt[l] == 2 && !s().is_visited(l)) { m_interesting.push_back(l); } @@ -621,7 +614,7 @@ namespace xr { while (!m_interesting.empty()) { - //Pop and check if it can be XOR-ed together + // Pop and check if it can be XOR-ed together const unsigned v = m_interesting.back(); m_interesting.resize(m_interesting.size()-1); if (m_occ_cnt[v] != 2) @@ -630,48 +623,49 @@ namespace xr { unsigned indexes[2]; unsigned at = 0; size_t i2 = 0; - //SASSERT(watches.size() > literal(v, false).index()); - vector ws = s().get_wlist(literal(v, false)); + sat::watch_list& ws = s().get_wlist(literal(v, false)); //Remove the 2 indexes from the watchlist for (unsigned i = 0; i < ws.size(); i++) { const sat::watched& w = ws[i]; - if (!w.isIdx()) { + if (!w.is_ext_constraint()) { + // TODO: Check!!! Is this fine? ws[i2++] = ws[i]; - } else if (!xors[w.get_idx()].empty()) { + } + else if (!xors[w.get_ext_constraint_idx()].empty()) { SASSERT(at < 2); - indexes[at] = w.get_idx(); + indexes[at] = w.get_ext_constraint_idx(); at++; } } SASSERT(at == 2); - ws.resize(i2); + ws.shrink(i2); xor_clause& x0 = xors[indexes[0]]; xor_clause& x1 = xors[indexes[1]]; unsigned clash_var; unsigned clash_num = xor_two(&x0, &x1, clash_var); - //If they are equivalent + // If they are equivalent if (x0.size() == x1.size() && x0.m_rhs == x1.m_rhs - && clash_num == x0.size() - ) { + && clash_num == x0.size()) { + TRACE("xor", tout - << "x1: " << x0 << " -- at idx: " << indexes[0] - << "and x2: " << x1 << " -- at idx: " << indexes[1] - << "are equivalent.\n"); + << "x1: " << x0 << " -- at idx: " << indexes[0] + << "and x2: " << x1 << " -- at idx: " << indexes[1] + << "are equivalent.\n"); - //Update clash values & detached values - x1.merge_clash(x0, m_visited); + // Update clash values & detached values + x1.merge_clash(x0, m_visited, s().num_vars()); x1.m_detached |= x0.m_detached; TRACE("xor", tout << "after merge: " << x1 << " -- at idx: " << indexes[1] << "\n";); x0 = xor_clause(); - //Re-attach the other, remove the occur of the one we deleted - s().m_watches[Lit(v, false)].push(Watched(indexes[1], WatchType::watch_idx_t)); + // Re-attach the other, remove the occurrence of the one we deleted + watch_neg_literal(ws, indexes[1]); for (unsigned v2: x1) { sat::literal l(v2, false); @@ -682,29 +676,29 @@ namespace xr { } } } else if (clash_num > 1 || x0.m_detached || x1.m_detached) { - //add back to ws, can't do much - ws.push(Watched(indexes[0], WatchType::watch_idx_t)); - ws.push(Watched(indexes[1], WatchType::watch_idx_t)); + // add back to watch-list, can't do much + watch_neg_literal(ws, indexes[0]); + watch_neg_literal(ws, indexes[1]); continue; } else { m_occ_cnt[v] -= 2; SASSERT(m_occ_cnt[v] == 0); xor_clause x_new(m_tmp_vars_xor_two, x0.m_rhs ^ x1.m_rhs, clash_var); - x_new.merge_clash(x0, m_visited); - x_new.merge_clash(x1, m_visited); + x_new.merge_clash(x0, m_visited, s().num_vars()); + x_new.merge_clash(x1, m_visited, s().num_vars()); TRACE("xor", tout - << "x1: " << x0 << " -- at idx: " << indexes[0] << "\n" - << "x2: " << x1 << " -- at idx: " << indexes[1] << "\n" - << "clashed on var: " << clash_var+1 << "\n" - << "final: " << x_new << " -- at idx: " << xors.size() << "\n";); + << "x1: " << x0 << " -- at idx: " << indexes[0] << "\n" + << "x2: " << x1 << " -- at idx: " << indexes[1] << "\n" + << "clashed on var: " << clash_var+1 << "\n" + << "final: " << x_new << " -- at idx: " << xors.size() << "\n";); changed = true; xors.push_back(x_new); - for(uint32_t v2: x_new) { - sat::literal l(v2, false); - s().watches[l].push(Watched(xors.size()-1, WatchType::watch_idx_t)); + for (bool_var v2 : x_new) { + literal l(v2, false); + watch_neg_literal(l, xors.size() - 1); SASSERT(m_occ_cnt[l.var()] >= 1); if (m_occ_cnt[l.var()] == 2 && !s().is_visited(l.var())) { m_interesting.push_back(l.var()); @@ -717,19 +711,103 @@ namespace xr { } } - //Clear + // Clear for (const bool_var l : m_occurrences) { m_occ_cnt[l] = 0; + // Caution: Merged smudge- (from watched literals) and occurrences-list + clean_occur_from_idx(literal(l, false)); } m_occurrences.clear(); - clean_occur_from_idx_types_only_smudged(); clean_xors_from_empty(xors); - #endif return !s().inconsistent(); } + + // Remove all watches coming from xor solver + // TODO: Differentiate if the watch came from another theory (not xor)!! + void solver::clean_occur_from_idx(const literal l) { + vector& ws = s().get_wlist(~l); // the same polarity that was added + unsigned i = 0, j = 0; + const unsigned end = ws.size(); + for (; i < end; i++) { + if (!ws[i].is_ext_constraint()) { + ws[j++] = ws[i]; + } + } + ws.shrink(i - j); + } + + // Removes all xor clauses that do not contain any variables + // (and have rhs = false; i.e., are trivially satisfied) and move them to unused + void solver::clean_xors_from_empty(vector& thisxors) { + unsigned j = 0; + for (unsigned i = 0; i < thisxors.size(); i++) { + xor_clause& x = thisxors[i]; + if (x.empty() && !x.m_rhs) { + if (!x.m_clash_vars.empty()) { + m_xorclauses_unused.push_back(x); + } + } else { + thisxors[j++] = thisxors[i]; + } + } + thisxors.shrink(j); + } + + // Merge two xor clauses; the resulting clause is in m_tmp_vars_xor_two and the variable where it was glued is in clash_var + // returns 0 if no common variable was found, 1 if there was exactly one and 2 if there are more + // only 1 is successful + unsigned solver::xor_two(xor_clause const* x1_p, xor_clause const* x2_p, bool_var& clash_var) { + m_tmp_vars_xor_two.clear(); + if (x1_p->size() > x2_p->size()) + std::swap(x1_p, x2_p); + + const xor_clause& x1 = *x1_p; + const xor_clause& x2 = *x2_p; + + m_visited.init_visited(s().num_vars(), 2); + + unsigned clash_num = 0; + for (bool_var v : x1) { + SASSERT(!m_visited.is_visited(v)); + m_visited.inc_visited(v); + } + + bool_var i_x2; + bool early_abort = false; + for (i_x2 = 0; i_x2 < x2.size(); i_x2++) { + bool_var v = x2[i_x2]; + SASSERT(m_visited.num_visited(v) < 2); + if (!m_visited.is_visited(v)) { + m_tmp_vars_xor_two.push_back(v); + } + else { + clash_var = v; + if (clash_num > 0 && clash_num != i_x2) { + //early abort, it's never gonna be good + clash_num++; + early_abort = true; + break; + } + clash_num++; + } + + m_visited.inc_visited(v, 2); + } + + if (!early_abort) { + for (bool_var v: x1) { + if (m_visited.num_visited(v) < 2) { + m_tmp_vars_xor_two.push_back(v); + } + } + } + + return clash_num; + } + std::ostream& solver::display_justification(std::ostream& out, sat::ext_justification_idx idx) const { return out; } diff --git a/src/sat/smt/xor_solver.h b/src/sat/smt/xor_solver.h index 845446d58..ddec48df6 100644 --- a/src/sat/smt/xor_solver.h +++ b/src/sat/smt/xor_solver.h @@ -58,6 +58,7 @@ namespace xr { // and we need the list of occurrences unsigned_vector m_occ_cnt; bool_var_vector m_interesting; + bool_var_vector m_tmp_vars_xor_two; void force_push(); void push_core(); @@ -70,8 +71,36 @@ namespace xr { void add_xor_clause(const sat::literal_vector& lits, bool rhs, const bool attach); + void clean_occur_from_idx(const literal l); + void clean_xors_from_empty(vector& thisxors); + unsigned xor_two(xor_clause const* x1_p, xor_clause const* x2_p, bool_var& clash_var); + bool inconsistent() const { return s().inconsistent(); } + // TODO: CMS watches the literals directly; Z3 their negation. "_neg_" just for now to avoid confusion + bool is_neg_watched(sat::watch_list& l, size_t idx) const { + return l.contains(sat::watched((sat::ext_constraint_idx)idx)); + } + + bool is_neg_watched(literal lit, size_t idx) const { + return s().get_wlist(~lit).contains(sat::watched((sat::ext_constraint_idx)idx)); + } + + void unwatch_neg_literal(literal lit, size_t idx) { + s().get_wlist(~lit).erase(sat::watched(idx)); + SASSERT(!is_neg_watched(lit, idx)); + } + + void watch_neg_literal(sat::watch_list& l, size_t idx) { + SASSERT(!is_neg_watched(l, idx)); + l.push_back(sat::watched(idx)); + } + + void watch_neg_literal(literal lit, size_t idx) { + watch_neg_literal(s().get_wlist(~lit), idx); + } + + public: solver(euf::solver& ctx); solver(ast_manager& m, euf::theory_id id); diff --git a/src/util/visit_helper.h b/src/util/visit_helper.h index 5f4591828..6f77fe09e 100644 --- a/src/util/visit_helper.h +++ b/src/util/visit_helper.h @@ -41,8 +41,9 @@ public: } void mark_visited(unsigned v) { m_visited[v] = m_visited_begin + 1; } - void inc_visited(unsigned v) { - m_visited[v] = std::min(m_visited_end, std::max(m_visited_begin, m_visited[v]) + 1); + void inc_visited(unsigned v) { inc_visited(v, 1); } + void inc_visited(unsigned v, unsigned by) { + m_visited[v] = std::min(m_visited_end, std::max(m_visited_begin, m_visited[v]) + by); } bool is_visited(unsigned v) const { return m_visited[v] > m_visited_begin; } unsigned num_visited(unsigned v) const { return std::max(m_visited_begin, m_visited[v]) - m_visited_begin; }