diff --git a/src/muz_qe/hilbert_basis.cpp b/src/muz_qe/hilbert_basis.cpp index 49bf20746..59b513d8d 100644 --- a/src/muz_qe/hilbert_basis.cpp +++ b/src/muz_qe/hilbert_basis.cpp @@ -192,19 +192,9 @@ class hilbert_basis::value_index2 { checker m_checker; vector m_keys; -#if 1 numeral const* get_keys(values const& vs) { return vs()-1; } -#else - numeral const* get_keys(values const& vs) { - unsigned sz = m_keys.size(); - for (unsigned i = 0; i < sz; ++i) { - m_keys[sz-i-1] = vs()[i-1]; - } - return m_keys.c_ptr(); - } -#endif public: value_index2(hilbert_basis& hb): hb(hb), m_init(false) { @@ -506,7 +496,11 @@ class hilbert_basis::passive2 { } }; hilbert_basis& hb; - svector const& m_sos; + svector m_pos_sos; + svector m_neg_sos; + vector m_pos_sos_sum; + vector m_neg_sos_sum; + vector m_sum_abs; unsigned_vector m_psos; svector m_pas; vector m_weight; @@ -527,40 +521,59 @@ class hilbert_basis::passive2 { public: passive2(hilbert_basis& hb): hb(hb), - m_sos(hb.m_sos), m_lt(&m_this), m_heap(10, m_lt) { m_this = this; } + void init(svector const& I) { + for (unsigned i = 0; i < I.size(); ++i) { + numeral const& w = hb.vec(I[i]).weight(); + if (w.is_pos()) { + m_pos_sos.push_back(I[i]); + m_pos_sos_sum.push_back(sum_abs(I[i])); + } + else { + m_neg_sos.push_back(I[i]); + m_neg_sos_sum.push_back(sum_abs(I[i])); + } + } + } + void reset() { m_heap.reset(); m_free_list.reset(); m_psos.reset(); m_pas.reset(); + m_sum_abs.reset(); + m_pos_sos.reset(); + m_neg_sos.reset(); + m_pos_sos_sum.reset(); + m_neg_sos_sum.reset(); m_weight.reset(); } void insert(offset_t idx, unsigned offset) { - SASSERT(!m_sos.empty()); + SASSERT(!m_pos_sos.empty()); unsigned v; - numeral w = sum_abs(idx) + sum_abs(m_sos[0]); if (m_free_list.empty()) { v = m_pas.size(); m_pas.push_back(idx); m_psos.push_back(offset); - m_weight.push_back(w); + m_weight.push_back(numeral(0)); m_heap.set_bounds(v+1); + m_sum_abs.push_back(sum_abs(idx)); } else { v = m_free_list.back(); m_free_list.pop_back(); m_pas[v] = idx; m_psos[v] = offset; - m_weight[v] = w; + m_weight[v] = numeral(0); + m_sum_abs[v] = sum_abs(idx); } - next_resolvable(v); + next_resolvable(hb.vec(idx).weight().is_pos(), v); } bool empty() const { @@ -570,12 +583,13 @@ public: unsigned pop(offset_t& sos, offset_t& pas) { SASSERT (!empty()); unsigned val = static_cast(m_heap.erase_min()); - unsigned psos = m_psos[val]; - sos = m_sos[psos]; pas = m_pas[val]; - m_psos[val]++; - next_resolvable(val); numeral old_weight = hb.vec(pas).weight(); + bool is_positive = old_weight.is_pos(); + unsigned psos = m_psos[val]; + sos = is_positive?m_neg_sos[psos]:m_pos_sos[psos]; + m_psos[val]++; + next_resolvable(is_positive, val); numeral new_weight = hb.vec(sos).weight() + old_weight; if (new_weight.is_pos() != old_weight.is_pos()) { psos = 0; @@ -599,7 +613,7 @@ public: public: iterator(passive2& p, unsigned i): p(p), m_idx(i) { fwd(); } offset_t pas() const { return p.m_pas[m_idx]; } - offset_t sos() const { return p.m_sos[p.m_psos[m_idx]]; } + offset_t sos() const { return (p.hb.vec(pas()).weight().is_pos()?p.m_neg_sos:p.m_pos_sos)[p.m_psos[m_idx]]; } iterator& operator++() { ++m_idx; fwd(); return *this; } iterator operator++(int) { iterator tmp = *this; ++*this; return tmp; } bool operator==(iterator const& it) const {return m_idx == it.m_idx; } @@ -614,12 +628,14 @@ public: return iterator(*this, m_pas.size()); } private: - void next_resolvable(unsigned v) { + void next_resolvable(bool is_positive, unsigned v) { offset_t pas = m_pas[v]; - while (m_psos[v] < m_sos.size()) { - offset_t sos = m_sos[m_psos[v]]; - if (hb.can_resolve(sos, pas)) { - m_weight[v] = sum_abs(pas) + sum_abs(sos); + svector const& soss = is_positive?m_neg_sos:m_pos_sos; + while (m_psos[v] < soss.size()) { + unsigned psos = m_psos[v]; + offset_t sos = soss[psos]; + if (hb.can_resolve(sos, pas, false)) { + m_weight[v] = m_sum_abs[v] + (is_positive?m_neg_sos_sum[psos]:m_pos_sos_sum[psos]); m_heap.insert(v); return; } @@ -745,7 +761,7 @@ unsigned hilbert_basis::get_num_vars() const { } hilbert_basis::values hilbert_basis::vec(offset_t offs) const { - return values(m_store.c_ptr() + (get_num_vars() + 1)*offs.m_offset); + return values(m_ineqs.size(), m_store.c_ptr() + offs.m_offset); } void hilbert_basis::init_basis() { @@ -804,6 +820,9 @@ lbool hilbert_basis::saturate_orig(num_vector const& ineq, bool is_eq) { offset_t idx = *it; values v = vec(idx); v.weight() = get_weight(v, ineq); + for (unsigned k = 0; k < m_current_ineq; ++k) { + v.weight(k) = get_weight(v, m_ineqs[k]); + } add_goal(idx); if (m_use_support) { support.insert(idx.m_offset); @@ -823,7 +842,7 @@ lbool hilbert_basis::saturate_orig(num_vector const& ineq, bool is_eq) { continue; } for (unsigned i = 0; !m_cancel && i < m_active.size(); ++i) { - if ((!m_use_support || support.contains(m_active[i].m_offset)) && can_resolve(idx, m_active[i])) { + if ((!m_use_support || support.contains(m_active[i].m_offset)) && can_resolve(idx, m_active[i], true)) { resolve(idx, m_active[i], j); if (add_goal(j)) { j = alloc_vector(); @@ -874,6 +893,9 @@ lbool hilbert_basis::saturate(num_vector const& ineq, bool is_eq) { offset_t idx = m_basis[i]; values v = vec(idx); v.weight() = get_weight(v, ineq); + for (unsigned k = 0; k < m_current_ineq; ++k) { + v.weight(k) = get_weight(v, m_ineqs[k]); + } m_index->insert(idx, v); if (v.weight().is_zero()) { m_zero.push_back(idx); @@ -886,6 +908,7 @@ lbool hilbert_basis::saturate(num_vector const& ineq, bool is_eq) { } } m_basis.resize(init_basis_size); + m_passive2->init(m_sos); // ASSERT basis is sorted by weight. // initialize passive @@ -902,7 +925,7 @@ lbool hilbert_basis::saturate(num_vector const& ineq, bool is_eq) { offset_t sos, pas; TRACE("hilbert_basis", display(tout); ); unsigned offset = m_passive2->pop(sos, pas); - SASSERT(can_resolve(sos, pas)); + SASSERT(can_resolve(sos, pas, true)); resolve(sos, pas, idx); if (is_subsumed(idx)) { continue; @@ -933,14 +956,6 @@ lbool hilbert_basis::saturate(num_vector const& ineq, bool is_eq) { m_free_list.push_back(m_basis.back()); m_basis.pop_back(); } - for (unsigned i = 0; i < init_basis_size; ++i) { - offset_t idx = m_basis[i]; - if (vec(idx).weight().is_neg()) { - m_basis[i] = m_basis.back(); - m_basis.pop_back(); - - } - } m_basis.append(m_zero); std::sort(m_basis.begin(), m_basis.end(), vector_lt_t(*this)); m_zero.reset(); @@ -1051,6 +1066,9 @@ void hilbert_basis::resolve(offset_t i, offset_t j, offset_t r) { u[k] = v[k] + w[k]; } u.weight() = v.weight() + w.weight(); + for (unsigned k = 0; k < m_current_ineq; ++k) { + u.weight(k) = v.weight(k) + w.weight(k); + } TRACE("hilbert_basis_verbose", display(tout, i); display(tout, j); @@ -1061,10 +1079,11 @@ void hilbert_basis::resolve(offset_t i, offset_t j, offset_t r) { hilbert_basis::offset_t hilbert_basis::alloc_vector() { if (m_free_list.empty()) { - unsigned num_vars = get_num_vars(); - unsigned idx = m_store.size(); - m_store.resize(idx + 1 + num_vars); - return offset_t(idx/(1+num_vars)); + unsigned sz = m_ineqs.size() + get_num_vars(); + unsigned idx = m_store.size(); + m_store.resize(idx + sz); + // std::cout << "alloc vector: " << idx << " " << sz << " " << m_store.c_ptr() + idx << " " << m_ineqs.size() << "\n"; + return offset_t(idx); } else { offset_t result = m_free_list.back(); @@ -1099,10 +1118,11 @@ bool hilbert_basis::is_subsumed(offset_t idx) { return false; } -bool hilbert_basis::can_resolve(offset_t i, offset_t j) const { - if (get_sign(i) == get_sign(j)) { +bool hilbert_basis::can_resolve(offset_t i, offset_t j, bool check_sign) const { + if (check_sign && get_sign(i) == get_sign(j)) { return false; } + SASSERT(get_sign(i) != get_sign(j)); values const& v1 = vec(i); values const& v2 = vec(j); if (v1[0].is_one() && v2[0].is_one()) { @@ -1121,7 +1141,7 @@ bool hilbert_basis::can_resolve(offset_t i, offset_t j) const { } hilbert_basis::sign_t hilbert_basis::get_sign(offset_t idx) const { - numeral val = vec(idx).weight(); + numeral const& val = vec(idx).weight(); if (val.is_pos()) { return pos; } @@ -1265,7 +1285,7 @@ bool hilbert_basis::is_subsumed(offset_t i, offset_t j) const { n >= m && (!m.is_neg() || n == m) && is_geq(v, w); for (unsigned k = 0; r && k < m_current_ineq; ++k) { - r = get_weight(vec(i), m_ineqs[k]) >= get_weight(vec(j), m_ineqs[k]); + r = v.weight(k) >= w.weight(k); } CTRACE("hilbert_basis", r, display(tout, i); diff --git a/src/muz_qe/hilbert_basis.h b/src/muz_qe/hilbert_basis.h index 78d4d7cec..bad4b1fbd 100644 --- a/src/muz_qe/hilbert_basis.h +++ b/src/muz_qe/hilbert_basis.h @@ -56,12 +56,14 @@ private: class values { numeral* m_values; public: - values(numeral* v):m_values(v) {} - numeral& weight() { return m_values[0]; } // value of a*x - numeral& operator[](unsigned i) { return m_values[i+1]; } // value of x_i - numeral const& weight() const { return m_values[0]; } // value of a*x - numeral const& operator[](unsigned i) const { return m_values[i+1]; } // value of x_i - numeral const* operator()() const { return m_values + 1; } + values(unsigned offset, numeral* v): m_values(v+offset) { } + numeral& weight() { return m_values[-1]; } // value of a*x + numeral const& weight() const { return m_values[-1]; } // value of a*x + numeral& weight(int i) { return m_values[-2-i]; } // value of b_i*x for 0 <= i < current inequality. + numeral const& weight(int i) const { return m_values[-2-i]; } // value of b_i*x + numeral& operator[](unsigned i) { return m_values[i]; } // value of x_i + numeral const& operator[](unsigned i) const { return m_values[i]; } // value of x_i + numeral const* operator()() const { return m_values; } }; vector m_ineqs; // set of asserted inequalities @@ -114,7 +116,7 @@ private: bool is_subsumed(offset_t idx); bool is_subsumed(offset_t i, offset_t j) const; void recycle(offset_t idx); - bool can_resolve(offset_t i, offset_t j) const; + bool can_resolve(offset_t i, offset_t j, bool check_sign) const; sign_t get_sign(offset_t idx) const; bool add_goal(offset_t idx); offset_t alloc_vector(); diff --git a/src/test/hilbert_basis.cpp b/src/test/hilbert_basis.cpp index e2d1d337e..69d733f68 100644 --- a/src/test/hilbert_basis.cpp +++ b/src/test/hilbert_basis.cpp @@ -521,6 +521,8 @@ void tst_hilbert_basis() { tst2(); tst3(); tst4(); + tst4(); + tst4(); tst5(); tst6(); tst7();