diff --git a/src/ast/rewriter/pb_rewriter_def.h b/src/ast/rewriter/pb_rewriter_def.h index 3c60babce..7e2a13779 100644 --- a/src/ast/rewriter/pb_rewriter_def.h +++ b/src/ast/rewriter/pb_rewriter_def.h @@ -247,7 +247,7 @@ lbool pb_rewriter_util::normalize(typename PBU::args_t& args, typename PBU: // example: k = 5, min = 3, max = 4: 5/3 -> 2 5/4 -> 1, n = 2 // replace all coefficients by 1, and k by 2. // - if (!k.is_one()) { + if (false && !k.is_one()) { PBU::numeral min = args[0].second, max = args[0].second; for (unsigned i = 1; i < args.size(); ++i) { if (args[i].second < min) min = args[i].second; diff --git a/src/smt/theory_pb.cpp b/src/smt/theory_pb.cpp index fc2f26986..cd2891474 100644 --- a/src/smt/theory_pb.cpp +++ b/src/smt/theory_pb.cpp @@ -424,7 +424,7 @@ namespace smt { bool_var abv = ctx.mk_bool_var(atom); ctx.set_var_theory(abv, get_id()); - ineq* c = alloc(ineq, literal(abv), m_util.is_eq(atom)); + ineq* c = alloc(ineq, m_mpz_mgr, literal(abv), m_util.is_eq(atom)); c->m_args[0].m_k = m_util.get_k(atom); numeral& k = c->m_args[0].m_k; arg_t& args = c->m_args[0]; @@ -446,6 +446,7 @@ namespace smt { else { SASSERT(m_util.is_at_least_k(atom) || m_util.is_ge(atom) || m_util.is_eq(atom)); } + TRACE("pb", display(tout, *c);); c->unique(); lbool is_true = c->normalize(); c->prune(); @@ -479,10 +480,13 @@ namespace smt { } // maximal coefficient: - numeral& max_watch = c->m_max_watch; - max_watch = numeral::zero(); + scoped_mpz& max_watch = c->m_max_watch; + max_watch.reset(); for (unsigned i = 0; i < args.size(); ++i) { - max_watch = std::max(max_watch, args[i].second); + mpz const& num = args[i].second.to_mpq().numerator(); + if (m_mpz_mgr.lt(max_watch, num)) { + max_watch = num; + } } // pre-compile threshold for cardinality @@ -633,16 +637,17 @@ namespace smt { watch.pop_back(); SASSERT(ineq_index < c.watch_size()); - numeral coeff = c.coeff(ineq_index); + scoped_mpz coeff(m_mpz_mgr); + coeff = c.ncoeff(ineq_index); if (ineq_index + 1 < c.watch_size()) { std::swap(c.args()[ineq_index], c.args()[c.watch_size()-1]); } --c.m_watch_sz; c.m_watch_sum -= coeff; - if (c.max_watch() == coeff) { - coeff = c.coeff(0); + if (coeff == c.max_watch()) { + coeff = c.ncoeff(0); for (unsigned i = 1; coeff != c.max_watch() && i < c.watch_size(); ++i) { - if (coeff < c.coeff(i)) coeff = c.coeff(i); + if (coeff < c.ncoeff(i)) coeff = c.ncoeff(i); } c.set_max_watch(coeff); } @@ -653,7 +658,8 @@ namespace smt { void theory_pb::add_watch(ineq& c, unsigned i) { SASSERT(c.is_ge()); literal lit = c.lit(i); - numeral coeff = c.coeff(i); + scoped_mpz coeff(m_mpz_mgr); + coeff = c.ncoeff(i); c.m_watch_sum += coeff; SASSERT(i >= c.watch_size()); @@ -834,18 +840,19 @@ namespace smt { } literal_vector& theory_pb::get_helpful_literals(ineq& c, bool negate) { - numeral sum = numeral::zero(); + scoped_mpz sum(m_mpz_mgr); + mpz const& k = c.mpz_k(); context& ctx = get_context(); literal_vector& lits = get_lits(); - for (unsigned i = 0; sum < c.k() && i < c.size(); ++i) { + for (unsigned i = 0; sum < k && i < c.size(); ++i) { literal l = c.lit(i); if (ctx.get_assignment(l) == l_true) { - sum += c.coeff(i); + sum += c.ncoeff(i); if (negate) l = ~l; lits.push_back(l); } } - SASSERT(sum >= c.k()); + SASSERT(sum >= k); return lits; } @@ -891,8 +898,8 @@ namespace smt { */ void theory_pb::assign_ineq(ineq& c, bool is_true) { context& ctx = get_context(); - ctx.push_trail(value_trail(c.m_max_sum)); - ctx.push_trail(value_trail(c.m_min_sum)); + ctx.push_trail(value_trail(c.m_max_sum)); + ctx.push_trail(value_trail(c.m_min_sum)); ctx.push_trail(value_trail(c.m_nfixed)); ctx.push_trail(rewatch_vars(*this, c)); @@ -904,15 +911,14 @@ namespace smt { ctx.push_trail(negate_ineq(c)); } - numeral maxsum = numeral::zero(); - numeral mininc = numeral::zero(); + scoped_mpz maxsum(m_mpz_mgr), mininc(m_mpz_mgr); for (unsigned i = 0; i < sz; ++i) { lbool asgn = ctx.get_assignment(c.lit(i)); if (asgn != l_false) { - maxsum += c.coeff(i); + maxsum += c.ncoeff(i); } - if (asgn == l_undef && (mininc.is_zero() || mininc > c.coeff(i))) { - mininc = c.coeff(i); + if (asgn == l_undef && (mininc.is_zero() || mininc > c.ncoeff(i))) { + mininc = c.ncoeff(i); } } @@ -920,19 +926,19 @@ namespace smt { tout << "assign: " << c.lit() << "\n"; display(tout, c); ); - if (maxsum < c.k()) { + if (maxsum < c.mpz_k()) { literal_vector& lits = get_unhelpful_literals(c, false); lits.push_back(~c.lit()); add_clause(c, lits); } else { init_watch_literal(c); - SASSERT(c.watch_sum() >= c.k()); + SASSERT(c.m_watch_sum >= c.mpz_k()); DEBUG_CODE(validate_watch(c);); } // perform unit propagation - if (maxsum >= c.k() && maxsum - mininc < c.k()) { + if (maxsum >= c.mpz_k() && maxsum - mininc < c.mpz_k()) { literal_vector& lits = get_unhelpful_literals(c, true); lits.push_back(c.lit()); for (unsigned i = 0; i < sz; ++i) { @@ -989,12 +995,12 @@ namespace smt { SASSERT(i < c.size()); if (c.lit(i).sign() == is_true) { - ctx.push_trail(value_trail(c.m_max_sum)); - c.m_max_sum -= c.coeff(i); + ctx.push_trail(value_trail(c.m_max_sum)); + c.m_max_sum -= c.ncoeff(i); } else { - ctx.push_trail(value_trail(c.m_min_sum)); - c.m_min_sum += c.coeff(i); + ctx.push_trail(value_trail(c.m_min_sum)); + c.m_min_sum += c.ncoeff(i); } DEBUG_CODE( numeral sum(0); @@ -1336,14 +1342,17 @@ namespace smt { void theory_pb::init_watch_literal(ineq& c) { context& ctx = get_context(); - c.m_watch_sum = numeral::zero(); + scoped_mpz max_k(m_mpz_mgr); + c.m_watch_sum.reset(); c.m_watch_sz = 0; - c.m_max_watch = numeral::zero(); - bool watch_more = c.watch_sum() < c.k() + c.max_watch(); + c.m_max_watch.reset(); + bool watch_more = true; for (unsigned i = 0; watch_more && i < c.size(); ++i) { if (ctx.get_assignment(c.lit(i)) != l_false) { add_watch(c, i); - watch_more = c.watch_sum() < c.k() + c.max_watch(); + max_k = c.mpz_k(); + max_k += c.max_watch(); + watch_more = c.m_watch_sum < max_k; } } ctx.push_trail(unwatch_ge(*this, c)); @@ -1358,7 +1367,7 @@ namespace smt { c.m_watch_sz = 0; for (unsigned i = 0; i < c.size(); ++i) { watch_var(c.lit(i).var(), &c); - c.m_max_sum += c.coeff(i); + c.m_max_sum += c.ncoeff(i); } } @@ -1561,6 +1570,7 @@ namespace smt { unset_marks(); m_num_marks = 0; m_lemma.reset(); + m_lemma.m_k.reset(); m_ineq_literals.reset(); process_ineq(c, null_literal, numeral::one()); // add consequent to lemma. @@ -1918,12 +1928,12 @@ namespace smt { } out << (c.is_ge()?" >= ":" = ") << c.k() << "\n"; if (c.m_num_propagations) out << "propagations: " << c.m_num_propagations << " "; - if (c.max_watch().is_pos()) out << "max_watch: " << c.max_watch() << " "; + if (c.m_max_watch.is_pos()) out << "max_watch: " << c.max_watch() << " "; if (c.watch_size()) out << "watch size: " << c.watch_size() << " "; - if (c.watch_sum().is_pos()) out << "watch-sum: " << c.watch_sum() << " "; - if (!c.max_sum().is_zero()) out << "sum: [" << c.min_sum() << ":" << c.max_sum() << "] "; - if (c.m_num_propagations || c.max_watch().is_pos() || c.watch_size() || - c.watch_sum().is_pos() || !c.max_sum().is_zero()) out << "\n"; + if (c.m_watch_sum.is_pos()) out << "watch-sum: " << c.watch_sum() << " "; + if (!c.m_max_sum.is_zero()) out << "sum: [" << c.min_sum() << ":" << c.max_sum() << "] "; + if (c.m_num_propagations || c.m_max_watch.is_pos() || c.watch_size() || + c.m_watch_sum.is_pos() || !c.m_max_sum.is_zero()) out << "\n"; return out; } diff --git a/src/smt/theory_pb.h b/src/smt/theory_pb.h index be2158970..28257713e 100644 --- a/src/smt/theory_pb.h +++ b/src/smt/theory_pb.h @@ -105,23 +105,27 @@ namespace smt { struct ineq { + unsynch_mpz_manager& m_mpz; // mpz manager. literal m_lit; // literal repesenting predicate bool m_is_eq; // is this an = or >=. arg_t m_args[2]; // encode args[0]*coeffs[0]+...+args[n-1]*coeffs[n-1] >= k(); // Watch the first few positions until the sum satisfies: // sum coeffs[i] >= m_lower + max_watch - numeral m_max_watch; // maximal coefficient. + scoped_mpz m_max_watch; // maximal coefficient. unsigned m_watch_sz; // number of literals being watched. - numeral m_watch_sum; // maximal sum of watch literals. + scoped_mpz m_watch_sum; // maximal sum of watch literals. // Watch infrastructure for = and unassigned >=: unsigned m_nfixed; // number of variables that are fixed. - numeral m_max_sum; // maximal possible sum. - numeral m_min_sum; // minimal possible sum. + scoped_mpz m_max_sum; // maximal possible sum. + scoped_mpz m_min_sum; // minimal possible sum. unsigned m_num_propagations; unsigned m_compilation_threshold; lbool m_compiled; - ineq(literal l, bool is_eq) : m_lit(l), m_is_eq(is_eq) { + ineq(unsynch_mpz_manager& m, literal l, bool is_eq) : + m_mpz(m), m_lit(l), m_is_eq(is_eq), + m_max_watch(m), m_watch_sum(m), + m_max_sum(m), m_min_sum(m) { reset(); } @@ -130,22 +134,24 @@ namespace smt { literal lit() const { return m_lit; } numeral const & k() const { return args().m_k; } + mpz const & mpz_k() const { return k().to_mpq().numerator(); } literal lit(unsigned i) const { return args()[i].first; } numeral const & coeff(unsigned i) const { return args()[i].second; } + class mpz const& ncoeff(unsigned i) const { return coeff(i).to_mpq().numerator(); } unsigned size() const { return args().size(); } - numeral const& watch_sum() const { return m_watch_sum; } - numeral const& max_watch() const { return m_max_watch; } - void set_max_watch(numeral const& n) { m_max_watch = n; } + class mpz const& watch_sum() const { return m_watch_sum; } + class mpz const& max_watch() const { return m_max_watch.get(); } + void set_max_watch(mpz const& n) { m_max_watch = n; } unsigned watch_size() const { return m_watch_sz; } // variable watch infrastructure - numeral const& min_sum() const { return m_min_sum; } - numeral const& max_sum() const { return m_max_sum; } + class mpz const& min_sum() const { return m_min_sum; } + class mpz const& max_sum() const { return m_max_sum; } unsigned nfixed() const { return m_nfixed; } - bool vwatch_initialized() const { return !max_sum().is_zero(); } + bool vwatch_initialized() const { return !m_mpz.is_zero(max_sum()); } void vwatch_reset() { m_min_sum.reset(); m_max_sum.reset(); m_nfixed = 0; } unsigned find_lit(bool_var v, unsigned begin, unsigned end) { diff --git a/src/util/scoped_numeral.h b/src/util/scoped_numeral.h index 54d55f827..0023be7e2 100644 --- a/src/util/scoped_numeral.h +++ b/src/util/scoped_numeral.h @@ -97,6 +97,10 @@ public: return a.m().eq(a, b); } + friend bool operator!=(_scoped_numeral const & a, numeral const & b) { + return !a.m().eq(a, b); + } + friend bool operator<(_scoped_numeral const & a, numeral const & b) { return a.m().lt(a, b); } @@ -113,6 +117,26 @@ public: return a.m().ge(a, b); } + bool is_zero() const { + return m().is_zero(*this); + } + + bool is_pos() const { + return m().is_pos(*this); + } + + bool is_neg() const { + return m().is_neg(*this); + } + + bool is_nonpos() const { + return m().is_nonpos(*this); + } + + bool is_nonneg() const { + return m().is_nonneg(*this); + } + friend bool is_zero(_scoped_numeral const & a) { return a.m().is_zero(a); }