diff --git a/src/math/polysat/fixed_bits.cpp b/src/math/polysat/fixed_bits.cpp index 8b94c1a59..63643dfc4 100644 --- a/src/math/polysat/fixed_bits.cpp +++ b/src/math/polysat/fixed_bits.cpp @@ -17,7 +17,7 @@ Abstract: namespace polysat { bit_justication* bit_justication::get_other_justification(const fixed_bits& fixed, const pdd& p, unsigned idx) { - return fixed.m_tbv_to_justification[{ p, idx }]; + return fixed.m_bvpos_to_justification[{ p, idx }].m_justification; } const tbv_ref* bit_justication::get_tbv(fixed_bits& fixed, const pdd& p) { @@ -25,23 +25,23 @@ namespace polysat { } // returns: Is it consistent - bool bit_justication::fix_value_core(solver& s, fixed_bits& fixed, const pdd& p, tbv_ref& tbv, unsigned idx, tbit val, bit_justication** j) { + bool bit_justication::fix_bit(solver& s, fixed_bits& fixed, const pdd& p, tbv_ref& tbv, unsigned idx, tbit val, bit_justication** j) { SASSERT(j && *j); - if (!fixed.fix_value(s, p, tbv, idx, val, *j) && (*j)->can_dealloc()) { - // TODO: Potential double deallocation + if (!fixed.fix_bit(s, p, tbv, idx, val, *j) && (*j)->can_dealloc()) { + // TODO: Potential double deallocation; Check! dealloc(*j); *j = nullptr; } return fixed.m_consistent; } - bool bit_justication::fix_value_core(solver& s, fixed_bits& fixed, const pdd& p, tbv_ref& tbv, unsigned idx, tbit val, bit_justication* j) { - return fix_value_core(s, fixed, p, tbv, idx, val, &j); + bool bit_justication::fix_bit(solver& s, fixed_bits& fixed, const pdd& p, tbv_ref& tbv, unsigned idx, tbit val, bit_justication* j) { + return fix_bit(s, fixed, p, tbv, idx, val, &j); } void bit_justication_constraint::get_dependencies(fixed_bits& fixed, bit_dependencies& to_process) { for (const auto& dep : this->m_dependencies) { - LOG("Dependency: pdd: " << dep.pdd() << " idx: " << dep.idx()); + LOG("Dependency: pdd: " << dep.m_pdd << " idx: " << dep.m_bit_idx); to_process.push_back(dep); } } @@ -158,7 +158,7 @@ namespace polysat { // We know the value of this bit // forward propagation // this might add a conflict if the value is already set to another value - if (!fix_value_core(s, fixed, r, r_tbv, i, min_val & 1 ? BIT_1 : BIT_0, alloc(bit_justication_mul, i, p, q))) + if (!fix_bit(s, fixed, r, r_tbv, i, min_val & 1 ? BIT_1 : BIT_0, alloc(bit_justication_mul, i, p, q))) return; } else if (r_tbv[i] != BIT_z && min_val == max_val - 1) { @@ -188,13 +188,13 @@ namespace polysat { // Does not set: (1, 1), (0, 0), (0, z), (z, 0) // Also does not set: (z, z) [because we don't know which one. We only know that it has to be 0 => we can still set max_val = min_val] if (bit1 == BIT_1) { - if (!fix_value_core(s, fixed, q, q_tbv, y, BIT_0, SHARED_JUSTIFICATION)) { + if (!fix_bit(s, fixed, q, q_tbv, y, BIT_0, SHARED_JUSTIFICATION)) { VERIFY(false); } set_bits.push_back(y << 1 | 1); } else if (bit2 == BIT_1) { - if (!fix_value_core(s, fixed, p, p_tbv, x, BIT_0, SHARED_JUSTIFICATION)) { + if (!fix_bit(s, fixed, p, p_tbv, x, BIT_0, SHARED_JUSTIFICATION)) { VERIFY(false); } set_bits.push_back(x << 1 | 0); @@ -204,20 +204,20 @@ namespace polysat { // Sets: (1, z), (z, 1), (1, 1), (z, z) // Does not set: (0, 0), (0, z), (z, 0), (0, 1), (1, 0) if (bit1 == BIT_1) { - if (!fix_value_core(s, fixed, q, q_tbv, y, BIT_1, SHARED_JUSTIFICATION)) { + if (!fix_bit(s, fixed, q, q_tbv, y, BIT_1, SHARED_JUSTIFICATION)) { VERIFY(false); } set_bits.push_back(y << 1 | 1); } if (bit2 == BIT_1) { - if (!fix_value_core(s, fixed, p, p_tbv, x, BIT_1, SHARED_JUSTIFICATION)) { + if (!fix_bit(s, fixed, p, p_tbv, x, BIT_1, SHARED_JUSTIFICATION)) { VERIFY(false); } set_bits.push_back(x << 1 | 0); } if (bit1 == BIT_z && bit2 == BIT_z) { - if (!fix_value_core(s, fixed, p, p_tbv, i, BIT_1, SHARED_JUSTIFICATION) || - !fix_value_core(s, fixed, q, q_tbv, i, BIT_1, SHARED_JUSTIFICATION)) { + if (!fix_bit(s, fixed, p, p_tbv, i, BIT_1, SHARED_JUSTIFICATION) || + !fix_bit(s, fixed, q, q_tbv, i, BIT_1, SHARED_JUSTIFICATION)) { VERIFY(false); } set_bits.push_back(y << 1 | 1); @@ -387,7 +387,7 @@ namespace polysat { max_bit_value++; if (min_bit_value == max_bit_value) - if (!fix_value_core(s, fixed, r, r_tbv, i, min_bit_value & 1 ? BIT_1 : BIT_0, alloc(bit_justication_add))) + if (!fix_bit(s, fixed, r, r_tbv, i, min_bit_value & 1 ? BIT_1 : BIT_0, alloc(bit_justication_add))) return; min_bit_value >>= 1; @@ -484,14 +484,14 @@ namespace polysat { // In principle, the dependencies should be acyclic so this should terminate. If there are cycles it is for sure a bug while (!to_process.empty()) { - bit_dependency& curr = to_process.back(); - if (curr.pdd().is_val()) { + bvpos& curr = to_process.back(); + if (curr.m_pdd->is_val()) { to_process.pop_back(); continue; // We don't need an explanation for bits of constants } - SASSERT(m_tbv_to_justification.contains(curr)); + SASSERT(m_bvpos_to_justification.contains(curr)); - bit_justication* j = m_tbv_to_justification[curr]; + bit_justication* j = m_bvpos_to_justification[curr].m_justification; to_process.pop_back(); insert_constraint(j); GET_DEPENDENCY(j); @@ -505,7 +505,7 @@ namespace polysat { return (*get_tbv(p))[idx]; } - // True iff the justification changed? Alternatively: true if the justification was not used (can be deallocated). + // True iff the justification was stored (must not be deallocated!) bool fixed_bits::fix_value_core(const pdd& p, tbv_ref& tbv, unsigned idx, tbit val, bit_justication* j) { LOG("Fixing bit " << idx << " in " << p << " (" << tbv << ")"); constraint* c; @@ -514,36 +514,52 @@ namespace polysat { } SASSERT(val != BIT_x); // We don't use don't-cares - if (val == BIT_z) + if (val == BIT_z) // just ignore this case return false; tbit curr_val = tbv[idx]; + bvpos pos(p, idx); - if (val == curr_val) - return false; // TODO: Take the new justification if it has a lower decision level + if (val == curr_val) { // we already have the "correct" value there + if (p.is_val()) + return false; // no justification because it is trivial + SASSERT(m_bvpos_to_justification.contains(pos)); + justified_bvpos& old_j = m_bvpos_to_justification[pos]; + if (old_j.m_justification->m_decision_level > j->m_decision_level) + return false; + replace_justification(old_j, j); // the new justification is better. Let's take it + return true; + } auto& m = tbv.manager(); if (curr_val == BIT_z) { m.set(*tbv, idx, val); - auto jstfc = m_tbv_to_justification.get({ p, idx }, nullptr); - if (jstfc && jstfc->can_dealloc()) - dealloc(jstfc); - m_tbv_to_justification.insert({ p, idx }, j); + justified_bvpos jpos(pos, j, m_trail.size()); + + auto jstfc = m_bvpos_to_justification.get(pos, {}); + if (jstfc.m_justification && jstfc.m_justification->can_dealloc()) + dealloc(jstfc.m_justification); + + m_bvpos_to_justification.insert(pos, jpos); + m_trail.push_back(jpos); return true; } SASSERT((curr_val == BIT_1 && val == BIT_0) || (curr_val == BIT_0 && val == BIT_1)); - SASSERT(m_tbv_to_justification.contains({ p, idx })); + SASSERT(p.is_val() || m_bvpos_to_justification.contains(pos)); m_consistent = false; return false; } - bool fixed_bits::fix_value(solver& s, const pdd& p, tbv_ref& tbv, unsigned idx, tbit val, bit_justication* j) { + bool fixed_bits::fix_bit(solver& s, const pdd& p, tbv_ref& tbv, unsigned idx, tbit val, bit_justication* j) { + SASSERT(m_trail.size() == s.m_level); + bool changed = fix_value_core(p, tbv, idx, val, j); if (changed) return true; if (!m_consistent) { - clause_ref explanation = get_explanation(s, j, m_tbv_to_justification[{ p, idx }]); + bvpos pos(p, idx); + clause_ref explanation = get_explanation(s, j, m_bvpos_to_justification[pos].m_justification); s.set_conflict(*explanation); } return false; @@ -560,7 +576,7 @@ namespace polysat { // TODO: Propagate equality if everything is set if (!m_consistent) { LOG("Adding conflict on bit " << idx << " on pdd " << p); - clause_ref explanation = get_explanation(s, j, m_tbv_to_justification[{ p, idx }]); + clause_ref explanation = get_explanation(s, j, m_bvpos_to_justification[{ p, idx }].m_justification); s.set_conflict(*explanation); return false; // get_explanation will dealloc the justification } @@ -570,17 +586,62 @@ namespace polysat { } void fixed_bits::clear_value(const pdd& p, unsigned idx) { - // TODO: Use during backtracking - SASSERT(p.is_var()); tbv_ref& tbv = *get_tbv(p); auto& m = tbv.manager(); m.set(*tbv, idx, BIT_z); - - SASSERT(m_tbv_to_justification.contains({ p, idx })); - auto& jstfc = m_tbv_to_justification[{ p, idx }]; - if (jstfc->can_dealloc()) - dealloc(jstfc); - jstfc = nullptr; + bvpos pos(p, idx); + SASSERT(m_bvpos_to_justification.contains(pos)); + const auto& jstfc = m_bvpos_to_justification[pos]; + if (jstfc.m_justification->can_dealloc()) + dealloc(jstfc.m_justification); + m_bvpos_to_justification.remove(pos); + } + + void fixed_bits::replace_justification(const justified_bvpos& old_j, bit_justication* new_j) { + SASSERT(old_j.m_justification->m_decision_level > new_j->m_decision_level); + SASSERT(m_trail[old_j.m_trail_pos] == old_j); + + if (old_j.m_justification->can_dealloc()) + dealloc(old_j.m_justification); + m_trail[old_j.m_trail_pos].m_justification = new_j; // We only overwrite with smaller decision-levels. This way we preserve some kind of "order" + } + + void fixed_bits::push() { + m_trail_size.push_back(m_trail.size()); + } + + void fixed_bits::pop(unsigned pop_cnt) { + SASSERT(!m_consistent); // Why do we backtrack if this is true? We might remove this for (random) restarts + SASSERT(pop_cnt > 0); + + unsigned old_lvl = m_trail_size.size(); + unsigned new_lvl = old_lvl - pop_cnt; + SASSERT(pop_cnt <= old_lvl); + + unsigned prev_cnt = m_trail_size[new_lvl]; + m_trail_size.resize(new_lvl); + + unsigned last_to_keep = -1; + + for (unsigned i = m_trail.size(); i > prev_cnt; i--) { + // all elements m_trail[j] with (j > i) have higher decision levels than new_lvl + justified_bvpos& curr = m_trail[i - 1]; + SASSERT(curr.m_justification->m_decision_level <= old_lvl); + + if (curr.m_justification->m_decision_level > new_lvl) { + clear_value(curr.get_pdd(), curr.get_idx()); + if (last_to_keep != -1) + std::swap(curr, m_trail[--last_to_keep]); + } + else { + if (last_to_keep == -1) + last_to_keep = i; + } + } + if (last_to_keep == -1) + m_trail.resize(prev_cnt); + else + m_trail.resize(last_to_keep); } #define COUNT(DOWN, TO_COUNT) \ diff --git a/src/math/polysat/fixed_bits.h b/src/math/polysat/fixed_bits.h index 122eb2cc2..940f2a961 100644 --- a/src/math/polysat/fixed_bits.h +++ b/src/math/polysat/fixed_bits.h @@ -24,53 +24,57 @@ namespace polysat { class constraint; class fixed_bits; - class bit_dependency { + struct bvpos { optional m_pdd; unsigned m_bit_idx; public: - bit_dependency() : m_pdd(optional::undef()), m_bit_idx(0) {} - bit_dependency(const bit_dependency& v) = default; - bit_dependency(bit_dependency&& v) = default; + bvpos() : m_pdd(optional::undef()), m_bit_idx(0) {} + bvpos(const bvpos& v) = default; + bvpos(bvpos&& v) = default; - bit_dependency(const pdd& pdd, unsigned bit_idx) : m_pdd(pdd), m_bit_idx(bit_idx) {} + bvpos(const pdd& pdd, unsigned bit_idx) : m_pdd(pdd), m_bit_idx(bit_idx) {} - bool operator==(const bit_dependency& other) const { + bool operator==(const bvpos& other) const { return m_pdd == other.m_pdd && m_bit_idx == other.m_bit_idx; } - bit_dependency& operator=(bit_dependency&& other) { + bvpos& operator=(bvpos&& other) { m_pdd = other.m_pdd; m_bit_idx = other.m_bit_idx; return *this; } - bit_dependency& operator=(bit_dependency& other) { + bvpos& operator=(bvpos& other) { m_pdd = other.m_pdd; m_bit_idx = other.m_bit_idx; return *this; } - unsigned idx() const { return m_bit_idx; } - const pdd& pdd() const { return *m_pdd; } + unsigned get_idx() const { return m_bit_idx; } + const pdd& get_pdd() const { return *m_pdd; } }; - using bit_dependencies = vector; + using bit_dependencies = vector; class bit_justication { protected: static bit_justication* get_other_justification(const fixed_bits& fixed, const pdd& p, unsigned idx); static const tbv_ref* get_tbv(fixed_bits& fixed, const pdd& p); - static bool fix_value_core(solver& s, fixed_bits& fixed, const pdd& p, tbv_ref& tbv, unsigned idx, tbit val, bit_justication** j); - static bool fix_value_core(solver& s, fixed_bits& fixed, const pdd& p, tbv_ref& tbv, unsigned idx, tbit val, bit_justication* j); + static bool fix_bit(solver& s, fixed_bits& fixed, const pdd& p, tbv_ref& tbv, unsigned idx, tbit val, bit_justication** j); + static bool fix_bit(solver& s, fixed_bits& fixed, const pdd& p, tbv_ref& tbv, unsigned idx, tbit val, bit_justication* j); public: - virtual bool can_dealloc() { return true; } + + unsigned m_decision_level; + + virtual bool can_dealloc() { return true; } // we may not dealloc if the justification is used for multiple bits virtual bool has_constraint(constraint*& constr) { return false; } virtual void get_dependencies(fixed_bits& fixed, bit_dependencies& to_process) = 0; // returns if element may be deallocated after call (usually true) }; // if multiple bits are justified by the same justification + // All elements have to be in the same decision-level class bit_justication_shared : public bit_justication { bit_justication* m_justification; unsigned m_references = 0; @@ -121,12 +125,12 @@ namespace polysat { void get_dependencies(fixed_bits& fixed, bit_dependencies& to_process) override; static bit_justication_constraint* mk_assignment(constraint* c) { return alloc(bit_justication_constraint, c ); } - static bit_justication_constraint* mk_unary(constraint* c, bit_dependency v) { + static bit_justication_constraint* mk_unary(constraint* c, bvpos v) { bit_dependencies dep; dep.push_back(std::move(v)); return alloc(bit_justication_constraint, c, std::move(dep)); } - static bit_justication_constraint* mk_binary(constraint* c, bit_dependency v1, bit_dependency v2) { + static bit_justication_constraint* mk_binary(constraint* c, bvpos v1, bvpos v2) { bit_dependencies dep; dep.push_back(std::move(v1)); dep.push_back(std::move(v2)); @@ -177,6 +181,18 @@ namespace polysat { void get_dependencies(fixed_bits& fixed, bit_dependencies& to_process) override; }; + struct justified_bvpos : public bvpos { + bit_justication* m_justification; + unsigned m_trail_pos; + + justified_bvpos() = default; + + justified_bvpos(const pdd & pdd, unsigned idx, bit_justication* jstfc, unsigned int trail_pos) : + bvpos(pdd, idx), m_justification(jstfc), m_trail_pos(trail_pos) {} + + justified_bvpos(const bvpos& pos, bit_justication* jstfc, unsigned int trail_pos) : + bvpos(pos), m_justification(jstfc), m_trail_pos(trail_pos) {} + }; class fixed_bits final { @@ -194,18 +210,20 @@ namespace polysat { }; using pdd_to_tbv_map = map; - using tbv_to_justification_key = bit_dependency; - using tbv_to_justification_eq = default_eq; - struct tbv_to_justification_hash { - unsigned operator()(tbv_to_justification_key const& args) const { - return combine_hash(args.pdd().hash(), args.idx()); + using bvpos_to_justification_eq = default_eq; + struct bvpos_to_justification_hash { + unsigned operator()(bvpos const& args) const { + return combine_hash(args.get_pdd().hash(), args.get_idx()); } }; - using tbv_to_justification_map = map; + using bvpos_to_justification_map = map; //vector> m_var_to_tbv; - pdd_to_tbv_map m_var_to_tbv; // TODO: free tbv_ref pointers - tbv_to_justification_map m_tbv_to_justification; // the elements are pointers. Deallocate them before replacing them + pdd_to_tbv_map m_var_to_tbv; + bvpos_to_justification_map m_bvpos_to_justification; + + svector m_trail; + unsigned_vector m_trail_size; bool m_consistent = true; // in case evaluating results in a bit-conflict @@ -214,7 +232,9 @@ namespace polysat { clause_ref get_explanation(solver& s, bit_justication* j1, bit_justication* j2); bool fix_value_core(const pdd& p, tbv_ref& tbv, unsigned idx, tbit val, bit_justication* j); - bool fix_value(solver& s, const pdd& p, tbv_ref& tbv, unsigned idx, tbit val, bit_justication* j); + bool fix_bit(solver& s, const pdd& p, tbv_ref& tbv, unsigned idx, tbit val, bit_justication* j); + void clear_value(const pdd& p, unsigned idx); + void replace_justification(const justified_bvpos& old_j, bit_justication* new_j); void propagate_to_subterm(solver& s, const pdd& p); @@ -223,8 +243,11 @@ namespace polysat { fixed_bits(solver& s) : m_solver(s) {} ~fixed_bits() { - for (auto& tbv : m_var_to_tbv) { + for (auto& tbv : m_var_to_tbv) dealloc(tbv.m_value); + for (justified_bvpos& just : m_trail) { + if (just.m_justification->can_dealloc()) + dealloc(just.m_justification); } } @@ -240,8 +263,9 @@ namespace polysat { tbit get_value(const pdd& p, unsigned idx); // More efficient than calling "eval" and accessing the returned tbv elements // call this function also if we already know that the correct value is written there. We might decrease the decision level (for "replay") bool fix_value(solver& s, const pdd& p, unsigned idx, tbit val, bit_justication* j); - void clear_value(const pdd& p, unsigned idx); - + void push(); + void pop(unsigned pop_cnt = 1); + tbv_ref* eval(solver& s, const pdd& p); }; diff --git a/src/math/polysat/solver.cpp b/src/math/polysat/solver.cpp index 4d2fa1ff6..9d67e94c5 100644 --- a/src/math/polysat/solver.cpp +++ b/src/math/polysat/solver.cpp @@ -600,6 +600,7 @@ namespace polysat { } } #endif + m_fixed_bits.push(); if (can_bdecide()) bdecide(); else @@ -822,6 +823,7 @@ namespace polysat { continue; } if (j.is_decision()) { + m_fixed_bits.pop(); m_conflict.revert_pvar(v); revert_decision(v); return; @@ -850,6 +852,7 @@ namespace polysat { } SASSERT(!m_bvars.is_assumption(var)); // TODO: "assumption" is basically "propagated by unit clause" (or "at base level"); except we do not explicitly store the unit clause. if (m_bvars.is_decision(var)) { + m_fixed_bits.pop(); revert_bool_decision(lit); return; } diff --git a/src/util/tbv.h b/src/util/tbv.h index b4725a1a0..b010dc388 100644 --- a/src/util/tbv.h +++ b/src/util/tbv.h @@ -27,10 +27,10 @@ Revision History: class tbv; enum tbit { - BIT_z = 0x0, - BIT_0 = 0x1, - BIT_1 = 0x2, - BIT_x = 0x3 + BIT_z = 0x0, // unknown + BIT_0 = 0x1, // for sure 0 + BIT_1 = 0x2, // for sure 1 + BIT_x = 0x3 // don't care }; inline tbit neg(tbit t) {