3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-15 21:38:44 +00:00

optimize theory pb

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
This commit is contained in:
Nikolaj Bjorner 2014-02-25 18:06:54 -08:00
parent e180cfe256
commit 478b3160ac
4 changed files with 89 additions and 49 deletions

View file

@ -247,7 +247,7 @@ lbool pb_rewriter_util<PBU>::normalize(typename PBU::args_t& args, typename PBU:
// example: k = 5, min = 3, max = 4: 5/3 -> 2 5/4 -> 1, n = 2 // example: k = 5, min = 3, max = 4: 5/3 -> 2 5/4 -> 1, n = 2
// replace all coefficients by 1, and k by 2. // replace all coefficients by 1, and k by 2.
// //
if (!k.is_one()) { if (false && !k.is_one()) {
PBU::numeral min = args[0].second, max = args[0].second; PBU::numeral min = args[0].second, max = args[0].second;
for (unsigned i = 1; i < args.size(); ++i) { for (unsigned i = 1; i < args.size(); ++i) {
if (args[i].second < min) min = args[i].second; if (args[i].second < min) min = args[i].second;

View file

@ -424,7 +424,7 @@ namespace smt {
bool_var abv = ctx.mk_bool_var(atom); bool_var abv = ctx.mk_bool_var(atom);
ctx.set_var_theory(abv, get_id()); ctx.set_var_theory(abv, get_id());
ineq* c = alloc(ineq, literal(abv), m_util.is_eq(atom)); ineq* c = alloc(ineq, m_mpz_mgr, literal(abv), m_util.is_eq(atom));
c->m_args[0].m_k = m_util.get_k(atom); c->m_args[0].m_k = m_util.get_k(atom);
numeral& k = c->m_args[0].m_k; numeral& k = c->m_args[0].m_k;
arg_t& args = c->m_args[0]; arg_t& args = c->m_args[0];
@ -446,6 +446,7 @@ namespace smt {
else { else {
SASSERT(m_util.is_at_least_k(atom) || m_util.is_ge(atom) || m_util.is_eq(atom)); SASSERT(m_util.is_at_least_k(atom) || m_util.is_ge(atom) || m_util.is_eq(atom));
} }
TRACE("pb", display(tout, *c););
c->unique(); c->unique();
lbool is_true = c->normalize(); lbool is_true = c->normalize();
c->prune(); c->prune();
@ -479,10 +480,13 @@ namespace smt {
} }
// maximal coefficient: // maximal coefficient:
numeral& max_watch = c->m_max_watch; scoped_mpz& max_watch = c->m_max_watch;
max_watch = numeral::zero(); max_watch.reset();
for (unsigned i = 0; i < args.size(); ++i) { for (unsigned i = 0; i < args.size(); ++i) {
max_watch = std::max(max_watch, args[i].second); mpz const& num = args[i].second.to_mpq().numerator();
if (m_mpz_mgr.lt(max_watch, num)) {
max_watch = num;
}
} }
// pre-compile threshold for cardinality // pre-compile threshold for cardinality
@ -633,16 +637,17 @@ namespace smt {
watch.pop_back(); watch.pop_back();
SASSERT(ineq_index < c.watch_size()); SASSERT(ineq_index < c.watch_size());
numeral coeff = c.coeff(ineq_index); scoped_mpz coeff(m_mpz_mgr);
coeff = c.ncoeff(ineq_index);
if (ineq_index + 1 < c.watch_size()) { if (ineq_index + 1 < c.watch_size()) {
std::swap(c.args()[ineq_index], c.args()[c.watch_size()-1]); std::swap(c.args()[ineq_index], c.args()[c.watch_size()-1]);
} }
--c.m_watch_sz; --c.m_watch_sz;
c.m_watch_sum -= coeff; c.m_watch_sum -= coeff;
if (c.max_watch() == coeff) { if (coeff == c.max_watch()) {
coeff = c.coeff(0); coeff = c.ncoeff(0);
for (unsigned i = 1; coeff != c.max_watch() && i < c.watch_size(); ++i) { for (unsigned i = 1; coeff != c.max_watch() && i < c.watch_size(); ++i) {
if (coeff < c.coeff(i)) coeff = c.coeff(i); if (coeff < c.ncoeff(i)) coeff = c.ncoeff(i);
} }
c.set_max_watch(coeff); c.set_max_watch(coeff);
} }
@ -653,7 +658,8 @@ namespace smt {
void theory_pb::add_watch(ineq& c, unsigned i) { void theory_pb::add_watch(ineq& c, unsigned i) {
SASSERT(c.is_ge()); SASSERT(c.is_ge());
literal lit = c.lit(i); literal lit = c.lit(i);
numeral coeff = c.coeff(i); scoped_mpz coeff(m_mpz_mgr);
coeff = c.ncoeff(i);
c.m_watch_sum += coeff; c.m_watch_sum += coeff;
SASSERT(i >= c.watch_size()); SASSERT(i >= c.watch_size());
@ -834,18 +840,19 @@ namespace smt {
} }
literal_vector& theory_pb::get_helpful_literals(ineq& c, bool negate) { literal_vector& theory_pb::get_helpful_literals(ineq& c, bool negate) {
numeral sum = numeral::zero(); scoped_mpz sum(m_mpz_mgr);
mpz const& k = c.mpz_k();
context& ctx = get_context(); context& ctx = get_context();
literal_vector& lits = get_lits(); literal_vector& lits = get_lits();
for (unsigned i = 0; sum < c.k() && i < c.size(); ++i) { for (unsigned i = 0; sum < k && i < c.size(); ++i) {
literal l = c.lit(i); literal l = c.lit(i);
if (ctx.get_assignment(l) == l_true) { if (ctx.get_assignment(l) == l_true) {
sum += c.coeff(i); sum += c.ncoeff(i);
if (negate) l = ~l; if (negate) l = ~l;
lits.push_back(l); lits.push_back(l);
} }
} }
SASSERT(sum >= c.k()); SASSERT(sum >= k);
return lits; return lits;
} }
@ -891,8 +898,8 @@ namespace smt {
*/ */
void theory_pb::assign_ineq(ineq& c, bool is_true) { void theory_pb::assign_ineq(ineq& c, bool is_true) {
context& ctx = get_context(); context& ctx = get_context();
ctx.push_trail(value_trail<context, numeral>(c.m_max_sum)); ctx.push_trail(value_trail<context, scoped_mpz>(c.m_max_sum));
ctx.push_trail(value_trail<context, numeral>(c.m_min_sum)); ctx.push_trail(value_trail<context, scoped_mpz>(c.m_min_sum));
ctx.push_trail(value_trail<context, unsigned>(c.m_nfixed)); ctx.push_trail(value_trail<context, unsigned>(c.m_nfixed));
ctx.push_trail(rewatch_vars(*this, c)); ctx.push_trail(rewatch_vars(*this, c));
@ -904,15 +911,14 @@ namespace smt {
ctx.push_trail(negate_ineq(c)); ctx.push_trail(negate_ineq(c));
} }
numeral maxsum = numeral::zero(); scoped_mpz maxsum(m_mpz_mgr), mininc(m_mpz_mgr);
numeral mininc = numeral::zero();
for (unsigned i = 0; i < sz; ++i) { for (unsigned i = 0; i < sz; ++i) {
lbool asgn = ctx.get_assignment(c.lit(i)); lbool asgn = ctx.get_assignment(c.lit(i));
if (asgn != l_false) { if (asgn != l_false) {
maxsum += c.coeff(i); maxsum += c.ncoeff(i);
} }
if (asgn == l_undef && (mininc.is_zero() || mininc > c.coeff(i))) { if (asgn == l_undef && (mininc.is_zero() || mininc > c.ncoeff(i))) {
mininc = c.coeff(i); mininc = c.ncoeff(i);
} }
} }
@ -920,19 +926,19 @@ namespace smt {
tout << "assign: " << c.lit() << "\n"; tout << "assign: " << c.lit() << "\n";
display(tout, c); ); display(tout, c); );
if (maxsum < c.k()) { if (maxsum < c.mpz_k()) {
literal_vector& lits = get_unhelpful_literals(c, false); literal_vector& lits = get_unhelpful_literals(c, false);
lits.push_back(~c.lit()); lits.push_back(~c.lit());
add_clause(c, lits); add_clause(c, lits);
} }
else { else {
init_watch_literal(c); init_watch_literal(c);
SASSERT(c.watch_sum() >= c.k()); SASSERT(c.m_watch_sum >= c.mpz_k());
DEBUG_CODE(validate_watch(c);); DEBUG_CODE(validate_watch(c););
} }
// perform unit propagation // perform unit propagation
if (maxsum >= c.k() && maxsum - mininc < c.k()) { if (maxsum >= c.mpz_k() && maxsum - mininc < c.mpz_k()) {
literal_vector& lits = get_unhelpful_literals(c, true); literal_vector& lits = get_unhelpful_literals(c, true);
lits.push_back(c.lit()); lits.push_back(c.lit());
for (unsigned i = 0; i < sz; ++i) { for (unsigned i = 0; i < sz; ++i) {
@ -989,12 +995,12 @@ namespace smt {
SASSERT(i < c.size()); SASSERT(i < c.size());
if (c.lit(i).sign() == is_true) { if (c.lit(i).sign() == is_true) {
ctx.push_trail(value_trail<context, numeral>(c.m_max_sum)); ctx.push_trail(value_trail<context, scoped_mpz>(c.m_max_sum));
c.m_max_sum -= c.coeff(i); c.m_max_sum -= c.ncoeff(i);
} }
else { else {
ctx.push_trail(value_trail<context, numeral>(c.m_min_sum)); ctx.push_trail(value_trail<context, scoped_mpz>(c.m_min_sum));
c.m_min_sum += c.coeff(i); c.m_min_sum += c.ncoeff(i);
} }
DEBUG_CODE( DEBUG_CODE(
numeral sum(0); numeral sum(0);
@ -1336,14 +1342,17 @@ namespace smt {
void theory_pb::init_watch_literal(ineq& c) { void theory_pb::init_watch_literal(ineq& c) {
context& ctx = get_context(); context& ctx = get_context();
c.m_watch_sum = numeral::zero(); scoped_mpz max_k(m_mpz_mgr);
c.m_watch_sum.reset();
c.m_watch_sz = 0; c.m_watch_sz = 0;
c.m_max_watch = numeral::zero(); c.m_max_watch.reset();
bool watch_more = c.watch_sum() < c.k() + c.max_watch(); bool watch_more = true;
for (unsigned i = 0; watch_more && i < c.size(); ++i) { for (unsigned i = 0; watch_more && i < c.size(); ++i) {
if (ctx.get_assignment(c.lit(i)) != l_false) { if (ctx.get_assignment(c.lit(i)) != l_false) {
add_watch(c, i); add_watch(c, i);
watch_more = c.watch_sum() < c.k() + c.max_watch(); max_k = c.mpz_k();
max_k += c.max_watch();
watch_more = c.m_watch_sum < max_k;
} }
} }
ctx.push_trail(unwatch_ge(*this, c)); ctx.push_trail(unwatch_ge(*this, c));
@ -1358,7 +1367,7 @@ namespace smt {
c.m_watch_sz = 0; c.m_watch_sz = 0;
for (unsigned i = 0; i < c.size(); ++i) { for (unsigned i = 0; i < c.size(); ++i) {
watch_var(c.lit(i).var(), &c); watch_var(c.lit(i).var(), &c);
c.m_max_sum += c.coeff(i); c.m_max_sum += c.ncoeff(i);
} }
} }
@ -1561,6 +1570,7 @@ namespace smt {
unset_marks(); unset_marks();
m_num_marks = 0; m_num_marks = 0;
m_lemma.reset(); m_lemma.reset();
m_lemma.m_k.reset();
m_ineq_literals.reset(); m_ineq_literals.reset();
process_ineq(c, null_literal, numeral::one()); // add consequent to lemma. process_ineq(c, null_literal, numeral::one()); // add consequent to lemma.
@ -1918,12 +1928,12 @@ namespace smt {
} }
out << (c.is_ge()?" >= ":" = ") << c.k() << "\n"; out << (c.is_ge()?" >= ":" = ") << c.k() << "\n";
if (c.m_num_propagations) out << "propagations: " << c.m_num_propagations << " "; if (c.m_num_propagations) out << "propagations: " << c.m_num_propagations << " ";
if (c.max_watch().is_pos()) out << "max_watch: " << c.max_watch() << " "; if (c.m_max_watch.is_pos()) out << "max_watch: " << c.max_watch() << " ";
if (c.watch_size()) out << "watch size: " << c.watch_size() << " "; if (c.watch_size()) out << "watch size: " << c.watch_size() << " ";
if (c.watch_sum().is_pos()) out << "watch-sum: " << c.watch_sum() << " "; if (c.m_watch_sum.is_pos()) out << "watch-sum: " << c.watch_sum() << " ";
if (!c.max_sum().is_zero()) out << "sum: [" << c.min_sum() << ":" << c.max_sum() << "] "; if (!c.m_max_sum.is_zero()) out << "sum: [" << c.min_sum() << ":" << c.max_sum() << "] ";
if (c.m_num_propagations || c.max_watch().is_pos() || c.watch_size() || if (c.m_num_propagations || c.m_max_watch.is_pos() || c.watch_size() ||
c.watch_sum().is_pos() || !c.max_sum().is_zero()) out << "\n"; c.m_watch_sum.is_pos() || !c.m_max_sum.is_zero()) out << "\n";
return out; return out;
} }

View file

@ -105,23 +105,27 @@ namespace smt {
struct ineq { struct ineq {
unsynch_mpz_manager& m_mpz; // mpz manager.
literal m_lit; // literal repesenting predicate literal m_lit; // literal repesenting predicate
bool m_is_eq; // is this an = or >=. bool m_is_eq; // is this an = or >=.
arg_t m_args[2]; // encode args[0]*coeffs[0]+...+args[n-1]*coeffs[n-1] >= k(); arg_t m_args[2]; // encode args[0]*coeffs[0]+...+args[n-1]*coeffs[n-1] >= k();
// Watch the first few positions until the sum satisfies: // Watch the first few positions until the sum satisfies:
// sum coeffs[i] >= m_lower + max_watch // sum coeffs[i] >= m_lower + max_watch
numeral m_max_watch; // maximal coefficient. scoped_mpz m_max_watch; // maximal coefficient.
unsigned m_watch_sz; // number of literals being watched. unsigned m_watch_sz; // number of literals being watched.
numeral m_watch_sum; // maximal sum of watch literals. scoped_mpz m_watch_sum; // maximal sum of watch literals.
// Watch infrastructure for = and unassigned >=: // Watch infrastructure for = and unassigned >=:
unsigned m_nfixed; // number of variables that are fixed. unsigned m_nfixed; // number of variables that are fixed.
numeral m_max_sum; // maximal possible sum. scoped_mpz m_max_sum; // maximal possible sum.
numeral m_min_sum; // minimal possible sum. scoped_mpz m_min_sum; // minimal possible sum.
unsigned m_num_propagations; unsigned m_num_propagations;
unsigned m_compilation_threshold; unsigned m_compilation_threshold;
lbool m_compiled; lbool m_compiled;
ineq(literal l, bool is_eq) : m_lit(l), m_is_eq(is_eq) { ineq(unsynch_mpz_manager& m, literal l, bool is_eq) :
m_mpz(m), m_lit(l), m_is_eq(is_eq),
m_max_watch(m), m_watch_sum(m),
m_max_sum(m), m_min_sum(m) {
reset(); reset();
} }
@ -130,22 +134,24 @@ namespace smt {
literal lit() const { return m_lit; } literal lit() const { return m_lit; }
numeral const & k() const { return args().m_k; } numeral const & k() const { return args().m_k; }
mpz const & mpz_k() const { return k().to_mpq().numerator(); }
literal lit(unsigned i) const { return args()[i].first; } literal lit(unsigned i) const { return args()[i].first; }
numeral const & coeff(unsigned i) const { return args()[i].second; } numeral const & coeff(unsigned i) const { return args()[i].second; }
class mpz const& ncoeff(unsigned i) const { return coeff(i).to_mpq().numerator(); }
unsigned size() const { return args().size(); } unsigned size() const { return args().size(); }
numeral const& watch_sum() const { return m_watch_sum; } class mpz const& watch_sum() const { return m_watch_sum; }
numeral const& max_watch() const { return m_max_watch; } class mpz const& max_watch() const { return m_max_watch.get(); }
void set_max_watch(numeral const& n) { m_max_watch = n; } void set_max_watch(mpz const& n) { m_max_watch = n; }
unsigned watch_size() const { return m_watch_sz; } unsigned watch_size() const { return m_watch_sz; }
// variable watch infrastructure // variable watch infrastructure
numeral const& min_sum() const { return m_min_sum; } class mpz const& min_sum() const { return m_min_sum; }
numeral const& max_sum() const { return m_max_sum; } class mpz const& max_sum() const { return m_max_sum; }
unsigned nfixed() const { return m_nfixed; } unsigned nfixed() const { return m_nfixed; }
bool vwatch_initialized() const { return !max_sum().is_zero(); } bool vwatch_initialized() const { return !m_mpz.is_zero(max_sum()); }
void vwatch_reset() { m_min_sum.reset(); m_max_sum.reset(); m_nfixed = 0; } void vwatch_reset() { m_min_sum.reset(); m_max_sum.reset(); m_nfixed = 0; }
unsigned find_lit(bool_var v, unsigned begin, unsigned end) { unsigned find_lit(bool_var v, unsigned begin, unsigned end) {

View file

@ -97,6 +97,10 @@ public:
return a.m().eq(a, b); return a.m().eq(a, b);
} }
friend bool operator!=(_scoped_numeral const & a, numeral const & b) {
return !a.m().eq(a, b);
}
friend bool operator<(_scoped_numeral const & a, numeral const & b) { friend bool operator<(_scoped_numeral const & a, numeral const & b) {
return a.m().lt(a, b); return a.m().lt(a, b);
} }
@ -113,6 +117,26 @@ public:
return a.m().ge(a, b); return a.m().ge(a, b);
} }
bool is_zero() const {
return m().is_zero(*this);
}
bool is_pos() const {
return m().is_pos(*this);
}
bool is_neg() const {
return m().is_neg(*this);
}
bool is_nonpos() const {
return m().is_nonpos(*this);
}
bool is_nonneg() const {
return m().is_nonneg(*this);
}
friend bool is_zero(_scoped_numeral const & a) { friend bool is_zero(_scoped_numeral const & a) {
return a.m().is_zero(a); return a.m().is_zero(a);
} }