3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-23 09:05:31 +00:00

adding roundingSat strategy

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
This commit is contained in:
Nikolaj Bjorner 2018-08-31 20:25:49 -05:00
parent 7230461671
commit 43807a7edc
3 changed files with 401 additions and 132 deletions

View file

@ -1009,21 +1009,6 @@ namespace sat {
// ---------------------------
// conflict resolution
void ba_solver::normalize_active_coeffs() {
reset_active_var_set();
unsigned i = 0, j = 0, sz = m_active_vars.size();
for (; i < sz; ++i) {
bool_var v = m_active_vars[i];
if (!m_active_var_set.contains(v) && get_coeff(v) != 0) {
m_active_var_set.insert(v);
if (j != i) {
m_active_vars[j] = m_active_vars[i];
}
++j;
}
}
m_active_vars.shrink(j);
}
void ba_solver::inc_coeff(literal l, unsigned offset) {
SASSERT(offset > 0);
@ -1066,44 +1051,46 @@ namespace sat {
return m_coeffs.get(v, 0);
}
uint64_t ba_solver::get_coeff(literal lit) const {
int64_t c1 = get_coeff(lit.var());
SASSERT(c1 < 0 == lit.sign());
uint64_t c = static_cast<uint64_t>(std::abs(c1));
m_overflow |= c != c1;
return c;
}
void ba_solver::get_coeff(bool_var v, literal& l, unsigned& c) {
int64_t c1 = get_coeff(v);
l = literal(v, c1 < 0);
c1 = std::abs(c1);
c = static_cast<unsigned>(c1);
m_overflow |= c != c1;
}
unsigned ba_solver::get_abs_coeff(bool_var v) const {
int64_t c = get_coeff(v);
if (c < INT_MIN+1 || c > UINT_MAX) {
m_overflow = true;
return UINT_MAX;
}
return static_cast<unsigned>(std::abs(c));
int64_t c1 = std::abs(get_coeff(v));
unsigned c = static_cast<unsigned>(c1);
m_overflow |= c != c1;
return c;
}
int ba_solver::get_int_coeff(bool_var v) const {
int64_t c = m_coeffs.get(v, 0);
if (c < INT_MIN || c > INT_MAX) {
m_overflow = true;
return 0;
}
return static_cast<int>(c);
int64_t c1 = m_coeffs.get(v, 0);
int c = static_cast<int>(c1);
m_overflow |= c != c1;
return c;
}
void ba_solver::inc_bound(int64_t i) {
if (i < INT_MIN || i > INT_MAX) {
m_overflow = true;
return;
}
int64_t new_bound = m_bound;
new_bound += i;
if (new_bound < 0) {
m_overflow = true;
}
else if (new_bound > UINT_MAX) {
m_overflow = true;
}
else {
m_bound = static_cast<unsigned>(new_bound);
}
unsigned nb = static_cast<unsigned>(new_bound);
m_overflow |= new_bound < 0 || nb != new_bound;
m_bound = nb;
}
void ba_solver::reset_coeffs() {
for (unsigned i = 0; i < m_active_vars.size(); ++i) {
for (unsigned i = m_active_vars.size(); i-- > 0; ) {
m_coeffs[m_active_vars[i]] = 0;
}
m_active_vars.reset();
@ -1115,7 +1102,47 @@ namespace sat {
// #define DEBUG_CODE(_x_) _x_
lbool ba_solver::resolve_conflict() {
void ba_solver::bail_resolve_conflict(unsigned idx) {
m_overflow = false;
literal_vector const& lits = s().m_trail;
while (m_num_marks > 0) {
bool_var v = lits[idx].var();
if (s().is_marked(v)) {
s().reset_mark(v);
--m_num_marks;
}
if (idx == 0 && !_debug_conflict) {
_debug_conflict = true;
_debug_var2position.reserve(s().num_vars());
for (unsigned i = 0; i < lits.size(); ++i) {
_debug_var2position[lits[i].var()] = i;
}
IF_VERBOSE(0,
active2pb(m_A);
uint64_t c = 0;
for (uint64_t c1 : m_A.m_coeffs) c += c1;
verbose_stream() << "sum of coefficients: " << c << "\n";
display(verbose_stream(), m_A, true);
verbose_stream() << "conflicting literal: " << s().m_not_l << "\n";);
for (literal l : lits) {
if (s().is_marked(l.var())) {
IF_VERBOSE(0, verbose_stream() << "missing mark: " << l << "\n";);
s().reset_mark(l.var());
}
}
m_num_marks = 0;
resolve_conflict();
}
--idx;
}
}
lbool ba_solver::resolve_conflict() {
#if 1
return resolve_conflict_rs();
#endif
if (0 == m_num_propagations_since_pop) {
return l_undef;
}
@ -1256,7 +1283,6 @@ namespace sat {
process_next_resolvent:
// find the next marked variable in the assignment stack
//
bool_var v;
while (true) {
consequent = lits[idx];
@ -1290,7 +1316,6 @@ namespace sat {
DEBUG_CODE(for (bool_var i = 0; i < static_cast<bool_var>(s().num_vars()); ++i) SASSERT(!s().is_marked(i)););
SASSERT(validate_lemma());
normalize_active_coeffs();
if (!create_asserting_lemma()) {
goto bail_out;
@ -1315,43 +1340,243 @@ namespace sat {
return l_true;
bail_out:
bail_resolve_conflict(idx);
return l_undef;
}
m_overflow = false;
uint64_t ba_solver::ineq::coeff(literal l) const {
bool_var v = l.var();
for (unsigned i = size(); i-- > 0; ) {
if (lit(i).var() == v) return coeff(i);
}
UNREACHABLE();
return 0;
}
void ba_solver::ineq::divide(uint64_t c) {
if (c == 1) return;
for (unsigned i = size(); i-- > 0; ) {
m_coeffs[i] = (m_coeffs[i] + c - 1) / c;
}
m_k = (m_k + c - 1) / c;
}
/**
* Remove literal at position i, subtract coefficient from bound.
*/
void ba_solver::ineq::weaken(unsigned i) {
uint64_t ci = coeff(i);
SASSERT(m_k >= ci);
m_k -= ci;
m_lits[i] = m_lits.back();
m_coeffs[i] = m_coeffs.back();
m_lits.pop_back();
m_coeffs.pop_back();
}
/**
* Round coefficient of inequality to 1.
*/
void ba_solver::round_to_one(ineq& ineq, literal lit) {
uint64_t c = ineq.coeff(lit);
if (c == 1) return;
unsigned sz = ineq.size();
for (unsigned i = 0; i < sz; ++i) {
uint64_t ci = ineq.coeff(i);
if (ci % c != 0 && !is_false(ineq.lit(i))) {
ineq.weaken(i);
--i;
--sz;
}
}
ineq.divide(c);
}
void ba_solver::round_to_one(literal lit) {
uint64_t c = get_coeff(lit);
if (c == 1) return;
for (bool_var v : m_active_vars) {
literal l;
unsigned ci;
get_coeff(v, l, ci);
if (ci > 0 && ci % c != 0 && !is_false(l)) {
m_coeffs[v] = 0;
}
}
divide(c);
}
void ba_solver::divide(uint64_t c) {
SASSERT(c != 0);
if (c == 1) return;
reset_active_var_set();
unsigned j = 0, sz = m_active_vars.size();
for (unsigned i = 0; i < sz; ++i) {
bool_var v = m_active_vars[i];
int ci = get_int_coeff(v);
if (m_active_var_set.contains(v) || ci == 0) continue;
m_active_var_set.insert(v);
if (ci > 0) {
m_coeffs[v] = (ci + c - 1) / c;
}
else {
m_coeffs[v] = -static_cast<int64_t>((-ci + c - 1) / c);
}
m_active_vars[j++] = v;
}
m_active_vars.shrink(j);
if (m_bound % c != 0) {
++m_stats.m_num_cut;
m_bound = static_cast<unsigned>((m_bound + c - 1) / c);
}
}
void ba_solver::resolve_on(literal consequent) {
round_to_one(consequent);
m_coeffs[consequent.var()] = 0;
}
void ba_solver::resolve_with(ineq const& ineq) {
TRACE("ba", display(tout, ineq, true););
inc_bound(1 + ineq.m_k);
for (unsigned i = ineq.size(); i-- > 0; ) {
literal l = ineq.lit(i);
inc_coeff(l, static_cast<unsigned>(ineq.coeff(i)));
}
}
void ba_solver::reset_marks(unsigned idx) {
while (m_num_marks > 0) {
bool_var v = lits[idx].var();
SASSERT(idx > 0);
bool_var v = s().m_trail[idx].var();
if (s().is_marked(v)) {
s().reset_mark(v);
--m_num_marks;
}
if (idx == 0 && !_debug_conflict) {
_debug_conflict = true;
_debug_var2position.reserve(s().num_vars());
for (unsigned i = 0; i < lits.size(); ++i) {
_debug_var2position[lits[i].var()] = i;
}
IF_VERBOSE(0,
active2pb(m_A);
uint64_t c = 0;
for (uint64_t c1 : m_A.m_coeffs) c += c1;
verbose_stream() << "sum of coefficients: " << c << "\n";
display(verbose_stream(), m_A, true);
verbose_stream() << "conflicting literal: " << s().m_not_l << "\n";);
for (literal l : lits) {
if (s().is_marked(l.var())) {
IF_VERBOSE(0, verbose_stream() << "missing mark: " << l << "\n";);
s().reset_mark(l.var());
}
}
m_num_marks = 0;
resolve_conflict();
}
--idx;
}
}
lbool ba_solver::resolve_conflict_rs() {
if (0 == m_num_propagations_since_pop) {
return l_undef;
}
m_overflow = false;
reset_coeffs();
m_num_marks = 0;
m_bound = 0;
literal consequent = s().m_not_l;
justification js = s().m_conflict;
TRACE("ba", tout << consequent << " " << js << "\n";);
m_conflict_lvl = s().get_max_lvl(consequent, js);
if (consequent != null_literal) {
consequent.neg();
process_antecedent(consequent, 1);
}
unsigned idx = s().m_trail.size() - 1;
do {
// TBD: termination condition
// if UIP is below m_conflict level
TRACE("ba", s().display_justification(tout << "process consequent: " << consequent << " : ", js) << "\n";
active2pb(m_A); display(tout, m_A, true);
);
switch (js.get_kind()) {
case justification::NONE:
SASSERT(consequent != null_literal);
resolve_on(consequent);
break;
case justification::BINARY:
SASSERT(consequent != null_literal);
resolve_on(consequent);
process_antecedent(js.get_literal());
break;
case justification::TERNARY:
SASSERT(consequent != null_literal);
resolve_on(consequent);
process_antecedent(js.get_literal1());
process_antecedent(js.get_literal2());
break;
case justification::CLAUSE: {
clause & c = s().get_clause(js);
unsigned i = 0;
if (consequent == null_literal) {
m_bound = 1;
}
else {
resolve_on(consequent);
if (c[0] == consequent) {
i = 1;
}
else {
SASSERT(c[1] == consequent);
process_antecedent(c[0]);
i = 2;
}
}
unsigned sz = c.size();
for (; i < sz; i++)
process_antecedent(c[i]);
break;
}
case justification::EXT_JUSTIFICATION: {
++m_stats.m_num_resolves;
ext_justification_idx index = js.get_ext_justification_idx();
constraint& cnstr = index2constraint(index);
constraint2pb(cnstr, consequent, 1, m_A);
if (consequent == null_literal) {
m_bound = static_cast<unsigned>(m_A.m_k);
for (unsigned i = m_A.size(); i-- > 0; ) {
inc_coeff(m_A.lit(i), static_cast<unsigned>(m_A.coeff(i)));
}
}
else {
round_to_one(consequent);
round_to_one(m_A, consequent);
resolve_with(m_A);
}
break;
}
default:
UNREACHABLE();
break;
}
cut();
// find the next marked variable in the assignment stack
bool_var v;
while (true) {
consequent = s().m_trail[idx];
v = consequent.var();
if (s().is_marked(v)) break;
if (idx == 0) {
goto bail_out;
}
--idx;
}
SASSERT(lvl(v) == m_conflict_lvl);
s().reset_mark(v);
--idx;
--m_num_marks;
js = s().m_justification[v];
}
while (m_num_marks > 0 && !m_overflow);
TRACE("ba", active2pb(m_A); display(tout, m_A, true););
active2constraint();
if (!m_overflow) {
return l_true;
}
bail_out:
m_overflow = false;
return l_undef;
}
bool ba_solver::create_asserting_lemma() {
bool adjusted = false;
@ -1461,10 +1686,17 @@ namespace sat {
}
if (g >= 2) {
normalize_active_coeffs();
for (bool_var v : m_active_vars) {
reset_active_var_set();
unsigned j = 0, sz = m_active_vars.size();
for (unsigned i = 0; i < sz; ++i) {
bool_var v = m_active_vars[i];
int64_t c = m_coeffs[v];
if (m_active_var_set.contains(v) || c == 0) continue;
m_active_var_set.insert(v);
m_coeffs[v] /= static_cast<int>(g);
m_active_vars[j++] = v;
}
m_active_vars.shrink(j);
m_bound = (m_bound + g - 1) / g;
++m_stats.m_num_cut;
}
@ -1502,7 +1734,6 @@ namespace sat {
if (level > 0 && !s().is_marked(v) && level == m_conflict_lvl) {
s().mark(v);
TRACE("sat_verbose", tout << "Mark: v" << v << "\n";);
++m_num_marks;
if (_debug_conflict && _debug_consequent != null_literal && _debug_var2position[_debug_consequent.var()] < _debug_var2position[l.var()]) {
IF_VERBOSE(0, verbose_stream() << "antecedent " << l << " is above consequent in stack\n";);
@ -1551,10 +1782,10 @@ namespace sat {
if (k == 1 && lit == null_literal) {
literal_vector _lits(lits);
s().mk_clause(_lits.size(), _lits.c_ptr(), learned);
return 0;
return nullptr;
}
if (!learned && clausify(lit, lits.size(), lits.c_ptr(), k)) {
return 0;
return nullptr;
}
void * mem = m_allocator.allocate(card::get_obj_size(lits.size()));
card* c = new (mem) card(next_id(), lit, lits, k);
@ -1615,7 +1846,7 @@ namespace sat {
bool units = true;
for (wliteral wl : wlits) units &= wl.first == 1;
if (k == 0 && lit == null_literal) {
return 0;
return nullptr;
}
if (units || k == 1) {
literal_vector lits;
@ -3612,7 +3843,7 @@ namespace sat {
}
}
void ba_solver::display(std::ostream& out, ineq& ineq, bool values) const {
void ba_solver::display(std::ostream& out, ineq const& ineq, bool values) const {
for (unsigned i = 0; i < ineq.m_lits.size(); ++i) {
out << ineq.m_coeffs[i] << "*" << ineq.m_lits[i] << " ";
if (values) out << value(ineq.m_lits[i]) << " ";
@ -3824,37 +4055,38 @@ namespace sat {
p.reset(m_bound);
for (bool_var v : m_active_vars) {
if (m_active_var_set.contains(v)) continue;
int64_t coeff = get_coeff(v);
unsigned coeff;
literal lit;
get_coeff(v, lit, coeff);
if (coeff == 0) continue;
m_active_var_set.insert(v);
literal lit(v, coeff < 0);
p.m_lits.push_back(lit);
p.m_coeffs.push_back(std::abs(coeff));
p.m_coeffs.push_back(coeff);
}
}
ba_solver::constraint* ba_solver::active2constraint() {
void ba_solver::active2wlits() {
reset_active_var_set();
m_wlits.reset();
uint64_t sum = 0;
if (m_bound == 1) return 0;
if (m_overflow) return 0;
uint64_t sum = 0;
for (bool_var v : m_active_vars) {
int coeff = get_int_coeff(v);
unsigned coeff;
literal lit;
get_coeff(v, lit, coeff);
if (m_active_var_set.contains(v) || coeff == 0) continue;
m_active_var_set.insert(v);
literal lit(v, coeff < 0);
m_wlits.push_back(wliteral(get_abs_coeff(v), lit));
sum += get_abs_coeff(v);
m_wlits.push_back(wliteral(static_cast<unsigned>(coeff), lit));
sum += coeff;
}
m_overflow |= sum >= UINT_MAX/2;
}
if (m_overflow || sum >= UINT_MAX/2) {
return 0;
ba_solver::constraint* ba_solver::active2constraint() {
active2wlits();
if (m_overflow) {
return nullptr;
}
else {
return add_pb_ge(null_literal, m_wlits, m_bound, true);
}
return add_pb_ge(null_literal, m_wlits, m_bound, true);
}
/*
@ -3889,11 +4121,9 @@ namespace sat {
ba_solver::constraint* ba_solver::active2card() {
normalize_active_coeffs();
m_wlits.reset();
for (bool_var v : m_active_vars) {
int coeff = get_int_coeff(v);
m_wlits.push_back(std::make_pair(get_abs_coeff(v), literal(v, coeff < 0)));
active2wlits();
if (m_overflow) {
return nullptr;
}
std::sort(m_wlits.begin(), m_wlits.end(), compare_wlit());
unsigned k = 0;
@ -3905,7 +4135,7 @@ namespace sat {
++k;
}
if (k == 1) {
return 0;
return nullptr;
}
while (!m_wlits.empty()) {
wliteral wl = m_wlits.back();
@ -3928,7 +4158,9 @@ namespace sat {
++num_max_level;
}
}
if (m_overflow) return 0;
if (m_overflow) {
return nullptr;
}
if (slack >= k) {
#if 0
@ -3963,15 +4195,18 @@ namespace sat {
void ba_solver::justification2pb(justification const& js, literal lit, unsigned offset, ineq& ineq) {
switch (js.get_kind()) {
case justification::NONE:
SASSERT(lit != null_literal);
ineq.reset(offset);
ineq.push(lit, offset);
break;
case justification::BINARY:
SASSERT(lit != null_literal);
ineq.reset(offset);
ineq.push(lit, offset);
ineq.push(js.get_literal(), offset);
break;
case justification::TERNARY:
SASSERT(lit != null_literal);
ineq.reset(offset);
ineq.push(lit, offset);
ineq.push(js.get_literal1(), offset);
@ -3986,35 +4221,7 @@ namespace sat {
case justification::EXT_JUSTIFICATION: {
ext_justification_idx index = js.get_ext_justification_idx();
constraint& cnstr = index2constraint(index);
switch (cnstr.tag()) {
case card_t: {
card& c = cnstr.to_card();
ineq.reset(offset*c.k());
for (literal l : c) ineq.push(l, offset);
if (c.lit() != null_literal) ineq.push(~c.lit(), offset*c.k());
break;
}
case pb_t: {
pb& p = cnstr.to_pb();
ineq.reset(p.k());
for (wliteral wl : p) ineq.push(wl.second, wl.first);
if (p.lit() != null_literal) ineq.push(~p.lit(), p.k());
break;
}
case xr_t: {
xr& x = cnstr.to_xr();
literal_vector ls;
get_antecedents(lit, x, ls);
ineq.reset(offset);
for (literal l : ls) ineq.push(~l, offset);
literal lxr = x.lit();
if (lxr != null_literal) ineq.push(~lxr, offset);
break;
}
default:
UNREACHABLE();
break;
}
constraint2pb(cnstr, lit, offset, ineq);
break;
}
default:
@ -4023,6 +4230,38 @@ namespace sat {
}
}
void ba_solver::constraint2pb(constraint& cnstr, literal lit, unsigned offset, ineq& ineq) {
switch (cnstr.tag()) {
case card_t: {
card& c = cnstr.to_card();
ineq.reset(offset*c.k());
for (literal l : c) ineq.push(l, offset);
if (c.lit() != null_literal) ineq.push(~c.lit(), offset*c.k());
break;
}
case pb_t: {
pb& p = cnstr.to_pb();
ineq.reset(offset * p.k());
for (wliteral wl : p) ineq.push(wl.second, offset * wl.first);
if (p.lit() != null_literal) ineq.push(~p.lit(), offset * p.k());
break;
}
case xr_t: {
xr& x = cnstr.to_xr();
literal_vector ls;
SASSERT(lit != null_literal);
get_antecedents(lit, x, ls);
ineq.reset(offset);
for (literal l : ls) ineq.push(~l, offset);
literal lxr = x.lit();
if (lxr != null_literal) ineq.push(~lxr, offset);
break;
}
default:
UNREACHABLE();
break;
}
}
// validate that m_A & m_B implies m_C

View file

@ -208,8 +208,14 @@ namespace sat {
svector<uint64_t> m_coeffs;
uint64_t m_k;
ineq(): m_k(0) {}
unsigned size() const { return m_lits.size(); }
literal lit(unsigned i) const { return m_lits[i]; }
uint64_t coeff(unsigned i) const { return m_coeffs[i]; }
void reset(uint64_t k) { m_lits.reset(); m_coeffs.reset(); m_k = k; }
void push(literal l, uint64_t c) { m_lits.push_back(l); m_coeffs.push_back(c); }
uint64_t coeff(literal lit) const;
void divide(uint64_t c);
void weaken(unsigned i);
};
solver* m_solver;
@ -396,10 +402,22 @@ namespace sat {
lbool eval(model const& m, pb const& p) const;
double get_reward(pb const& p, literal_occs_fun& occs) const;
// RoundingPb conflict resolution
lbool resolve_conflict_rs();
void round_to_one(ineq& ineq, literal lit);
void round_to_one(literal lit);
void divide(uint64_t c);
void resolve_on(literal lit);
void resolve_with(ineq const& ineq);
void reset_marks(unsigned idx);
void bail_resolve_conflict(unsigned idx);
// access solver
inline lbool value(bool_var v) const { return value(literal(v, false)); }
inline lbool value(literal lit) const { return m_lookahead ? m_lookahead->value(lit) : m_solver->value(lit); }
inline lbool value(model const& m, literal l) const { return l.sign() ? ~m[l.var()] : m[l.var()]; }
inline bool is_false(literal lit) const { return l_false == value(lit); }
inline unsigned lvl(literal lit) const { return m_lookahead || m_unit_walk ? 0 : m_solver->lvl(lit); }
inline unsigned lvl(bool_var v) const { return m_lookahead || m_unit_walk ? 0 : m_solver->lvl(v); }
@ -426,9 +444,10 @@ namespace sat {
mutable bool m_overflow;
void reset_active_var_set();
void normalize_active_coeffs();
void inc_coeff(literal l, unsigned offset);
int64_t get_coeff(bool_var v) const;
uint64_t get_coeff(literal lit) const;
void get_coeff(bool_var v, literal& l, unsigned& c);
unsigned get_abs_coeff(bool_var v) const;
int get_int_coeff(bool_var v) const;
unsigned get_bound() const;
@ -436,6 +455,7 @@ namespace sat {
literal get_asserting_literal(literal conseq);
void process_antecedent(literal l, unsigned offset);
void process_antecedent(literal l) { process_antecedent(l, 1); }
void process_card(card& c, unsigned offset);
void cut();
bool create_asserting_lemma();
@ -466,10 +486,13 @@ namespace sat {
void active2pb(ineq& p);
constraint* active2constraint();
constraint* active2card();
void active2wlits();
void justification2pb(justification const& j, literal lit, unsigned offset, ineq& p);
void constraint2pb(constraint& cnstr, literal lit, unsigned offset, ineq& p);
bool validate_resolvent();
unsigned get_coeff(ineq const& pb, literal lit);
void display(std::ostream& out, ineq& p, bool values = false) const;
void display(std::ostream& out, ineq const& p, bool values = false) const;
void display(std::ostream& out, card const& c, bool values) const;
void display(std::ostream& out, pb const& p, bool values) const;
void display(std::ostream& out, xr const& c, bool values) const;