3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-23 09:05:31 +00:00

optimize theory pb

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
This commit is contained in:
Nikolaj Bjorner 2014-02-25 18:06:54 -08:00
parent e180cfe256
commit 478b3160ac
4 changed files with 89 additions and 49 deletions

View file

@ -424,7 +424,7 @@ namespace smt {
bool_var abv = ctx.mk_bool_var(atom);
ctx.set_var_theory(abv, get_id());
ineq* c = alloc(ineq, literal(abv), m_util.is_eq(atom));
ineq* c = alloc(ineq, m_mpz_mgr, literal(abv), m_util.is_eq(atom));
c->m_args[0].m_k = m_util.get_k(atom);
numeral& k = c->m_args[0].m_k;
arg_t& args = c->m_args[0];
@ -446,6 +446,7 @@ namespace smt {
else {
SASSERT(m_util.is_at_least_k(atom) || m_util.is_ge(atom) || m_util.is_eq(atom));
}
TRACE("pb", display(tout, *c););
c->unique();
lbool is_true = c->normalize();
c->prune();
@ -479,10 +480,13 @@ namespace smt {
}
// maximal coefficient:
numeral& max_watch = c->m_max_watch;
max_watch = numeral::zero();
scoped_mpz& max_watch = c->m_max_watch;
max_watch.reset();
for (unsigned i = 0; i < args.size(); ++i) {
max_watch = std::max(max_watch, args[i].second);
mpz const& num = args[i].second.to_mpq().numerator();
if (m_mpz_mgr.lt(max_watch, num)) {
max_watch = num;
}
}
// pre-compile threshold for cardinality
@ -633,16 +637,17 @@ namespace smt {
watch.pop_back();
SASSERT(ineq_index < c.watch_size());
numeral coeff = c.coeff(ineq_index);
scoped_mpz coeff(m_mpz_mgr);
coeff = c.ncoeff(ineq_index);
if (ineq_index + 1 < c.watch_size()) {
std::swap(c.args()[ineq_index], c.args()[c.watch_size()-1]);
}
--c.m_watch_sz;
c.m_watch_sum -= coeff;
if (c.max_watch() == coeff) {
coeff = c.coeff(0);
if (coeff == c.max_watch()) {
coeff = c.ncoeff(0);
for (unsigned i = 1; coeff != c.max_watch() && i < c.watch_size(); ++i) {
if (coeff < c.coeff(i)) coeff = c.coeff(i);
if (coeff < c.ncoeff(i)) coeff = c.ncoeff(i);
}
c.set_max_watch(coeff);
}
@ -653,7 +658,8 @@ namespace smt {
void theory_pb::add_watch(ineq& c, unsigned i) {
SASSERT(c.is_ge());
literal lit = c.lit(i);
numeral coeff = c.coeff(i);
scoped_mpz coeff(m_mpz_mgr);
coeff = c.ncoeff(i);
c.m_watch_sum += coeff;
SASSERT(i >= c.watch_size());
@ -834,18 +840,19 @@ namespace smt {
}
literal_vector& theory_pb::get_helpful_literals(ineq& c, bool negate) {
numeral sum = numeral::zero();
scoped_mpz sum(m_mpz_mgr);
mpz const& k = c.mpz_k();
context& ctx = get_context();
literal_vector& lits = get_lits();
for (unsigned i = 0; sum < c.k() && i < c.size(); ++i) {
for (unsigned i = 0; sum < k && i < c.size(); ++i) {
literal l = c.lit(i);
if (ctx.get_assignment(l) == l_true) {
sum += c.coeff(i);
sum += c.ncoeff(i);
if (negate) l = ~l;
lits.push_back(l);
}
}
SASSERT(sum >= c.k());
SASSERT(sum >= k);
return lits;
}
@ -891,8 +898,8 @@ namespace smt {
*/
void theory_pb::assign_ineq(ineq& c, bool is_true) {
context& ctx = get_context();
ctx.push_trail(value_trail<context, numeral>(c.m_max_sum));
ctx.push_trail(value_trail<context, numeral>(c.m_min_sum));
ctx.push_trail(value_trail<context, scoped_mpz>(c.m_max_sum));
ctx.push_trail(value_trail<context, scoped_mpz>(c.m_min_sum));
ctx.push_trail(value_trail<context, unsigned>(c.m_nfixed));
ctx.push_trail(rewatch_vars(*this, c));
@ -904,15 +911,14 @@ namespace smt {
ctx.push_trail(negate_ineq(c));
}
numeral maxsum = numeral::zero();
numeral mininc = numeral::zero();
scoped_mpz maxsum(m_mpz_mgr), mininc(m_mpz_mgr);
for (unsigned i = 0; i < sz; ++i) {
lbool asgn = ctx.get_assignment(c.lit(i));
if (asgn != l_false) {
maxsum += c.coeff(i);
maxsum += c.ncoeff(i);
}
if (asgn == l_undef && (mininc.is_zero() || mininc > c.coeff(i))) {
mininc = c.coeff(i);
if (asgn == l_undef && (mininc.is_zero() || mininc > c.ncoeff(i))) {
mininc = c.ncoeff(i);
}
}
@ -920,19 +926,19 @@ namespace smt {
tout << "assign: " << c.lit() << "\n";
display(tout, c); );
if (maxsum < c.k()) {
if (maxsum < c.mpz_k()) {
literal_vector& lits = get_unhelpful_literals(c, false);
lits.push_back(~c.lit());
add_clause(c, lits);
}
else {
init_watch_literal(c);
SASSERT(c.watch_sum() >= c.k());
SASSERT(c.m_watch_sum >= c.mpz_k());
DEBUG_CODE(validate_watch(c););
}
// perform unit propagation
if (maxsum >= c.k() && maxsum - mininc < c.k()) {
if (maxsum >= c.mpz_k() && maxsum - mininc < c.mpz_k()) {
literal_vector& lits = get_unhelpful_literals(c, true);
lits.push_back(c.lit());
for (unsigned i = 0; i < sz; ++i) {
@ -989,12 +995,12 @@ namespace smt {
SASSERT(i < c.size());
if (c.lit(i).sign() == is_true) {
ctx.push_trail(value_trail<context, numeral>(c.m_max_sum));
c.m_max_sum -= c.coeff(i);
ctx.push_trail(value_trail<context, scoped_mpz>(c.m_max_sum));
c.m_max_sum -= c.ncoeff(i);
}
else {
ctx.push_trail(value_trail<context, numeral>(c.m_min_sum));
c.m_min_sum += c.coeff(i);
ctx.push_trail(value_trail<context, scoped_mpz>(c.m_min_sum));
c.m_min_sum += c.ncoeff(i);
}
DEBUG_CODE(
numeral sum(0);
@ -1336,14 +1342,17 @@ namespace smt {
void theory_pb::init_watch_literal(ineq& c) {
context& ctx = get_context();
c.m_watch_sum = numeral::zero();
scoped_mpz max_k(m_mpz_mgr);
c.m_watch_sum.reset();
c.m_watch_sz = 0;
c.m_max_watch = numeral::zero();
bool watch_more = c.watch_sum() < c.k() + c.max_watch();
c.m_max_watch.reset();
bool watch_more = true;
for (unsigned i = 0; watch_more && i < c.size(); ++i) {
if (ctx.get_assignment(c.lit(i)) != l_false) {
add_watch(c, i);
watch_more = c.watch_sum() < c.k() + c.max_watch();
max_k = c.mpz_k();
max_k += c.max_watch();
watch_more = c.m_watch_sum < max_k;
}
}
ctx.push_trail(unwatch_ge(*this, c));
@ -1358,7 +1367,7 @@ namespace smt {
c.m_watch_sz = 0;
for (unsigned i = 0; i < c.size(); ++i) {
watch_var(c.lit(i).var(), &c);
c.m_max_sum += c.coeff(i);
c.m_max_sum += c.ncoeff(i);
}
}
@ -1561,6 +1570,7 @@ namespace smt {
unset_marks();
m_num_marks = 0;
m_lemma.reset();
m_lemma.m_k.reset();
m_ineq_literals.reset();
process_ineq(c, null_literal, numeral::one()); // add consequent to lemma.
@ -1918,12 +1928,12 @@ namespace smt {
}
out << (c.is_ge()?" >= ":" = ") << c.k() << "\n";
if (c.m_num_propagations) out << "propagations: " << c.m_num_propagations << " ";
if (c.max_watch().is_pos()) out << "max_watch: " << c.max_watch() << " ";
if (c.m_max_watch.is_pos()) out << "max_watch: " << c.max_watch() << " ";
if (c.watch_size()) out << "watch size: " << c.watch_size() << " ";
if (c.watch_sum().is_pos()) out << "watch-sum: " << c.watch_sum() << " ";
if (!c.max_sum().is_zero()) out << "sum: [" << c.min_sum() << ":" << c.max_sum() << "] ";
if (c.m_num_propagations || c.max_watch().is_pos() || c.watch_size() ||
c.watch_sum().is_pos() || !c.max_sum().is_zero()) out << "\n";
if (c.m_watch_sum.is_pos()) out << "watch-sum: " << c.watch_sum() << " ";
if (!c.m_max_sum.is_zero()) out << "sum: [" << c.min_sum() << ":" << c.max_sum() << "] ";
if (c.m_num_propagations || c.m_max_watch.is_pos() || c.watch_size() ||
c.m_watch_sum.is_pos() || !c.m_max_sum.is_zero()) out << "\n";
return out;
}

View file

@ -105,23 +105,27 @@ namespace smt {
struct ineq {
unsynch_mpz_manager& m_mpz; // mpz manager.
literal m_lit; // literal repesenting predicate
bool m_is_eq; // is this an = or >=.
arg_t m_args[2]; // encode args[0]*coeffs[0]+...+args[n-1]*coeffs[n-1] >= k();
// Watch the first few positions until the sum satisfies:
// sum coeffs[i] >= m_lower + max_watch
numeral m_max_watch; // maximal coefficient.
scoped_mpz m_max_watch; // maximal coefficient.
unsigned m_watch_sz; // number of literals being watched.
numeral m_watch_sum; // maximal sum of watch literals.
scoped_mpz m_watch_sum; // maximal sum of watch literals.
// Watch infrastructure for = and unassigned >=:
unsigned m_nfixed; // number of variables that are fixed.
numeral m_max_sum; // maximal possible sum.
numeral m_min_sum; // minimal possible sum.
scoped_mpz m_max_sum; // maximal possible sum.
scoped_mpz m_min_sum; // minimal possible sum.
unsigned m_num_propagations;
unsigned m_compilation_threshold;
lbool m_compiled;
ineq(literal l, bool is_eq) : m_lit(l), m_is_eq(is_eq) {
ineq(unsynch_mpz_manager& m, literal l, bool is_eq) :
m_mpz(m), m_lit(l), m_is_eq(is_eq),
m_max_watch(m), m_watch_sum(m),
m_max_sum(m), m_min_sum(m) {
reset();
}
@ -130,22 +134,24 @@ namespace smt {
literal lit() const { return m_lit; }
numeral const & k() const { return args().m_k; }
mpz const & mpz_k() const { return k().to_mpq().numerator(); }
literal lit(unsigned i) const { return args()[i].first; }
numeral const & coeff(unsigned i) const { return args()[i].second; }
class mpz const& ncoeff(unsigned i) const { return coeff(i).to_mpq().numerator(); }
unsigned size() const { return args().size(); }
numeral const& watch_sum() const { return m_watch_sum; }
numeral const& max_watch() const { return m_max_watch; }
void set_max_watch(numeral const& n) { m_max_watch = n; }
class mpz const& watch_sum() const { return m_watch_sum; }
class mpz const& max_watch() const { return m_max_watch.get(); }
void set_max_watch(mpz const& n) { m_max_watch = n; }
unsigned watch_size() const { return m_watch_sz; }
// variable watch infrastructure
numeral const& min_sum() const { return m_min_sum; }
numeral const& max_sum() const { return m_max_sum; }
class mpz const& min_sum() const { return m_min_sum; }
class mpz const& max_sum() const { return m_max_sum; }
unsigned nfixed() const { return m_nfixed; }
bool vwatch_initialized() const { return !max_sum().is_zero(); }
bool vwatch_initialized() const { return !m_mpz.is_zero(max_sum()); }
void vwatch_reset() { m_min_sum.reset(); m_max_sum.reset(); m_nfixed = 0; }
unsigned find_lit(bool_var v, unsigned begin, unsigned end) {