3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-08 18:31:49 +00:00

pb theory

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
This commit is contained in:
Nikolaj Bjorner 2013-11-17 10:39:33 -08:00
parent f6c5088cc9
commit f3721e5a15
2 changed files with 144 additions and 167 deletions

View file

@ -22,8 +22,38 @@ Notes:
#include "smt_context.h"
#include "ast_pp.h"
#include "sorting_network.h"
#include "uint_set.h"
namespace smt {
void theory_pb::ineq::negate() {
m_lit.neg();
numeral sum = 0;
for (unsigned i = 0; i < size(); ++i) {
m_args[i].first.neg();
sum += coeff(i);
}
m_k = sum - m_k + 1;
SASSERT(well_formed());
}
bool theory_pb::ineq::well_formed() const {
SASSERT(k() > 0);
uint_set vars;
numeral sum = 0;
for (unsigned i = 0; i < size(); ++i) {
SASSERT(coeff(i) <= k());
SASSERT(1 <= coeff(i));
SASSERT(lit(i) != true_literal);
SASSERT(lit(i) != false_literal);
SASSERT(lit(i) != null_literal);
SASSERT(!vars.contains(lit(i).var()));
vars.insert(lit(i).var());
sum += coeff(i);
}
SASSERT(sum >= k());
return true;
}
theory_pb::theory_pb(ast_manager& m):
theory(m.mk_family_id("card")),
@ -54,7 +84,7 @@ namespace smt {
bool_var abv = ctx.mk_bool_var(atom);
ctx.set_var_theory(abv, get_id());
ineq* c = alloc(ineq, atom, literal(abv));
ineq* c = alloc(ineq, literal(abv));
c->m_k = m_util.get_k(atom);
numeral& k = c->m_k;
arg_t& args = c->m_args;
@ -96,17 +126,6 @@ namespace smt {
max_coeff = std::max(max_coeff, args[i].second);
}
// compute watch literals:
numeral sum = 0;
unsigned wsz = 0;
while (sum < k + max_coeff && wsz < args.size()) {
sum += args[wsz].second;
wsz++;
}
for (unsigned i = 0; i < wsz; ++i) {
add_watch(*c, i);
}
// pre-compile threshold for cardinality
bool is_cardinality = true;
@ -359,6 +378,8 @@ namespace smt {
m_ineqs.reset();
m_ineqs_trail.reset();
m_ineqs_lim.reset();
m_assign_ineqs_trail.reset();
m_assign_ineqs_lim.reset();
m_stats.reset();
m_to_compile.reset();
}
@ -420,36 +441,38 @@ namespace smt {
*/
void theory_pb::assign_ineq(ineq& c, bool is_true) {
if (c.lit().sign() == is_true) {
c.negate();
}
context& ctx = get_context();
numeral sum = 0, maxsum = 0;
numeral maxsum = 0;
for (unsigned i = 0; i < c.size(); ++i) {
switch (ctx.get_assignment(c.lit(i))) {
case l_true:
sum += c.coeff(i);
// falll through
case l_undef:
if (ctx.get_assignment(c.lit(i)) != l_false) {
maxsum += c.coeff(i);
break;
default:
break;
}
}
lbool lit_assignment = ctx.get_assignment(c.lit());
TRACE("card",
tout << "assign: " << c.lit() << " <- " << is_true << "\n";
display(tout, c); );
if (sum >= c.k() && !is_true) {
literal_vector& lits = get_helpful_literals(c, true);
lits.push_back(c.lit());
add_clause(c, lits);
}
else if (maxsum < c.k() && is_true) {
if (maxsum < c.k()) {
literal_vector& lits = get_unhelpful_literals(c, true);
lits.push_back(~c.lit());
add_clause(c, lits);
return;
}
c.m_max_sum = 0;
c.m_watch_sz = 0;
for (unsigned i = 0; c.max_sum() < c.k() + c.max_coeff() && i < c.size(); ++i) {
if (ctx.get_assignment(c.lit(i)) != l_false) {
add_watch(c, i);
}
}
SASSERT(c.max_sum() >= c.k());
m_assign_ineqs_trail.push_back(&c);
}
/**
@ -463,136 +486,70 @@ namespace smt {
bool removed = false;
context& ctx = get_context();
ineq& c = *watch[watch_index];
numeral k = c.m_k;
unsigned w = c.find_lit(v, 0, c.watch_size());
numeral coeff = c.coeff(w);
SASSERT(ctx.get_assignment(c.lit()) == l_true);
if (is_true == c.lit(w).sign()) {
//
// max_sum is decreased.
// Adjust set of watched literals.
//
numeral k = c.k();
del_watch(watch, watch_index, c, w);
removed = true;
for (unsigned i = c.watch_size(); c.max_sum() < k + c.max_coeff() && i < c.size(); ++i) {
if (ctx.get_assignment(c.lit(i)) != l_false) {
add_watch(c, i);
}
}
if (c.max_sum() < k) {
//
// L: 3*x1 + 2*x2 + x4 >= 3, but x1 <- 0, x2 <- 0
// create clause x1 or x2 or ~L
//
literal_vector& lits = get_unhelpful_literals(c, false);
lits.push_back(~c.lit());
add_clause(c, lits);
}
else if (c.max_sum() < k + c.max_coeff()) {
//
// opportunities for unit propagation for unassigned
// literals whose coefficients satisfy
// c.max_sum() - coeff < k
//
// L: 3*x1 + 2*x2 + x4 >= 3, but x1 <- 0
// Create clauses x1 or ~L or x2
// x1 or ~L or x4
//
literal_vector& lits = get_unhelpful_literals(c, true);
lits.push_back(c.lit());
for (unsigned i = 0; i < c.size(); ++i) {
if (c.max_sum() - c.coeff(i) < k && ctx.get_assignment(c.lit(i)) == l_undef) {
add_assign(c, lits, c.lit(i));
}
}
}
//
// else: c.max_sum() >= k + c.max_coeff()
// we might miss opportunities for unit propagation if
// max_coeff is not the maximal coefficient
// of the current unassigned literal, but we can
// rely on eventually learning this from propagations.
//
}
//
// else: the current set of watch remain a potentially feasible assignment.
//
TRACE("card",
tout << "assign: " << literal(v) << " <- " << is_true << "\n";
display(tout, c); );
if (is_true == c.lit(w).sign()) {
//
// sum is not increased.
// Adjust set of watched literals.
//
numeral tmp_sum = c.sum();
for (unsigned i = c.watch_size(); c.max_sum() < k + c.max_coeff() + coeff && i < c.size(); ++i) {
lbool lit_assignment = ctx.get_assignment(c.lit(i));
switch(lit_assignment) {
case l_true:
tmp_sum += c.coeff(i);
// fall-through
case l_undef:
add_watch(c, i);
break;
case l_false:
break;
}
}
if (c.max_sum() >= k + coeff) {
del_watch(watch, watch_index, c, w);
SASSERT(c.max_sum() >= k);
removed = true;
}
SASSERT(tmp_sum <= c.max_sum());
TRACE("card", tout << "tmp_sum: " << tmp_sum << "\n"; display(tout, c); );
if (c.max_sum() < k) {
//
// c.lit() <- false
//
switch(ctx.get_assignment(c.lit())) {
case l_false:
break;
case l_true: {
literal_vector& lits = get_unhelpful_literals(c, true);
lits.push_back(~c.lit());
add_clause(c, lits);
break;
}
case l_undef: {
add_assign(c, get_unhelpful_literals(c, false), ~c.lit());
break;
}
}
}
else if (tmp_sum >= k) {
//
// c.lit() <- true
//
switch(ctx.get_assignment(c.lit())) {
case l_true:
break;
case l_false: {
literal_vector& lits = get_helpful_literals(c, true);
lits.push_back(c.lit());
add_clause(c, lits);
break;
}
case l_undef: {
add_assign(c, get_helpful_literals(c, false), c.lit());
break;
}
}
}
else if (c.max_sum() < k + c.max_coeff()) {
// tmp_sum < k <= c.max_sum()
// opportunities for unit propagation for unassigned
// literals whose coefficients satisfy
// c.max_sum() - coeff < k
if (l_true == ctx.get_assignment(c.lit())) {
literal_vector& lits = get_unhelpful_literals(c, true);
lits.push_back(c.lit());
numeral max_sum = c.max_sum() - coeff;
for (unsigned i = 0; i < c.size(); ++i) {
if (max_sum - c.coeff(i) < k && ctx.get_assignment(c.lit(i)) == l_undef) {
add_assign(c, lits, c.lit(i));
}
}
}
}
else {
// c.max_sum() >= k + c.max_coeff()
// tmp_sum < k <= c.max_sum() - c.max_coeff()
// we might miss opportunities for unit propagation if
// max_coeff is not the maximal coefficient
// of the current unassigned literal, but we can
// rely on eventually learning this from propagations.
}
}
else {
// sum is increased the current set of watch
// literals represent a potentially feasible assignment.
//
ctx.push_trail(value_trail<context, numeral>(c.m_sum));
c.m_sum += coeff;
}
if (c.sum() >= k) {
lbool ineq_assignment = ctx.get_assignment(c.lit());
switch(ineq_assignment) {
case l_true:
break;
case l_undef: {
add_assign(c, get_helpful_literals(c, false), c.lit());
break;
}
case l_false: {
literal_vector& lits = get_helpful_literals(c, true);
lits.push_back(c.lit());
add_clause(c, lits);
break;
}
}
}
// else if c.sum() < k and lit(w) was assigned to true:
// Progress was made.
// The watch list contains at least enough
// literals to force the assignment.
return removed;
}
@ -771,7 +728,7 @@ namespace smt {
literal thl = c.lit();
se.add_clause(~thl, at_least_k);
se.add_clause(thl, ~at_least_k);
TRACE("card", tout << mk_pp(c.m_app, m) << "\n";);
TRACE("card", tout << c.lit() << "\n";);
// auxiliary clauses get removed when popping scopes.
// we have to recompile the circuit after back-tracking.
c.m_compiled = l_false;
@ -781,21 +738,21 @@ namespace smt {
void theory_pb::init_search_eh() {
m_to_compile.reset();
}
void theory_pb::push_scope_eh() {
m_ineqs_lim.push_back(m_ineqs_trail.size());
m_assign_ineqs_lim.push_back(m_assign_ineqs_trail.size());
}
void theory_pb::pop_scope_eh(unsigned num_scopes) {
unsigned sz = m_ineqs_lim[m_ineqs_lim.size()-num_scopes];
while (m_ineqs_trail.size() > sz) {
bool_var v = m_ineqs_trail.back();
ineq* c = 0;
VERIFY(m_ineqs.find(v, c));
m_ineqs.remove(v);
m_ineqs_trail.pop_back();
// remove watched literals.
unsigned new_lim = m_assign_ineqs_lim.size()-num_scopes;
unsigned sz = m_assign_ineqs_lim[new_lim];
while (m_assign_ineqs_trail.size() > sz) {
ineq* c = m_assign_ineqs_trail.back();
for (unsigned i = 0; i < c->watch_size(); ++i) {
bool_var w = c->lit(i).var();
ptr_vector<ineq>* ineqs = 0;
@ -808,9 +765,22 @@ namespace smt {
}
}
}
m_assign_ineqs_trail.pop_back();
}
m_assign_ineqs_lim.resize(new_lim);
// remove inequalities.
new_lim = m_ineqs_lim.size()-num_scopes;
sz = m_ineqs_lim[new_lim];
while (m_ineqs_trail.size() > sz) {
bool_var v = m_ineqs_trail.back();
ineq* c = 0;
VERIFY(m_ineqs.find(v, c));
m_ineqs.remove(v);
m_ineqs_trail.pop_back();
dealloc(c);
}
m_ineqs_lim.resize(m_ineqs_lim.size()-num_scopes);
m_ineqs_lim.resize(new_lim);
}
void theory_pb::display(std::ostream& out) const {
@ -832,7 +802,10 @@ namespace smt {
std::ostream& theory_pb::display(std::ostream& out, ineq& c) const {
ast_manager& m = get_manager();
out << mk_pp(c.m_app, m) << "\n";
context& ctx = get_context();
expr_ref tmp(m);
ctx.literal2expr(c.lit(), tmp);
out << tmp << "\n";
for (unsigned i = 0; i < c.size(); ++i) {
out << c.coeff(i) << "*" << c.lit(i);
if (i + 1 < c.size()) {

View file

@ -42,7 +42,6 @@ namespace smt {
struct ineq {
app* m_app;
literal m_lit; // literal repesenting predicate
arg_t m_args; // encode args[0]*coeffs[0]+...+args[n-1]*coeffs[n-1] >= m_k;
numeral m_k; // invariants: m_k > 0, coeffs[i] > 0
@ -58,8 +57,7 @@ namespace smt {
unsigned m_compilation_threshold;
lbool m_compiled;
ineq(app* a, literal l):
m_app(a),
ineq(literal l):
m_lit(l),
m_max_coeff(0),
m_watch_sz(0),
@ -91,6 +89,10 @@ namespace smt {
}
return begin;
}
void negate();
bool well_formed() const;
};
typedef ptr_vector<ineq> watch_list;
@ -99,6 +101,8 @@ namespace smt {
u_map<ineq*> m_ineqs; // per inequality.
unsigned_vector m_ineqs_trail;
unsigned_vector m_ineqs_lim;
ptr_vector<ineq> m_assign_ineqs_trail;
unsigned_vector m_assign_ineqs_lim;
literal_vector m_literals; // temporary vector
card_util m_util;
stats m_stats;