3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-08 18:31:49 +00:00

adding uhle/uhte for faster asymmetric branching

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
This commit is contained in:
Nikolaj Bjorner 2017-11-29 21:21:56 -08:00
parent 26bd784b1f
commit da0aa71082
4 changed files with 221 additions and 57 deletions

View file

@ -19,6 +19,7 @@ Revision History:
#include "sat/sat_asymm_branch.h"
#include "sat/sat_asymm_branch_params.hpp"
#include "sat/sat_solver.h"
#include "sat/sat_scc.h"
#include "util/stopwatch.h"
#include "util/trace.h"
@ -26,6 +27,7 @@ namespace sat {
asymm_branch::asymm_branch(solver & _s, params_ref const & p):
s(_s),
m_params(p),
m_counter(0) {
updt_params(p);
reset_statistics();
@ -59,12 +61,12 @@ namespace sat {
void asymm_branch::process(clause_vector& clauses) {
int64 limit = -m_asymm_branch_limit;
std::stable_sort(s.m_clauses.begin(), s.m_clauses.end(), clause_size_lt());
m_counter -= s.m_clauses.size();
std::stable_sort(clauses.begin(), clauses.end(), clause_size_lt());
m_counter -= clauses.size();
SASSERT(s.m_qhead == s.m_trail.size());
clause_vector::iterator it = s.m_clauses.begin();
clause_vector::iterator it = clauses.begin();
clause_vector::iterator it2 = it;
clause_vector::iterator end = s.m_clauses.end();
clause_vector::iterator end = clauses.end();
try {
for (; it != end; ++it) {
if (s.inconsistent()) {
@ -86,14 +88,14 @@ namespace sat {
*it2 = *it;
++it2;
}
s.m_clauses.set_end(it2);
clauses.set_end(it2);
}
catch (solver_exception & ex) {
// put m_clauses in a consistent state...
for (; it != end; ++it, ++it2) {
*it2 = *it;
}
s.m_clauses.set_end(it2);
clauses.set_end(it2);
m_counter = -m_counter;
throw ex;
}
@ -143,6 +145,99 @@ namespace sat {
return true;
}
void asymm_branch::setup_big() {
scc scc(s, m_params);
vector<literal_vector> const& big = scc.get_big(true); // include learned binary clauses
}
struct asymm_branch::compare_left {
scc& s;
compare_left(scc& s): s(s) {}
bool operator()(literal u, literal v) const {
return s.get_left(u) < s.get_left(v);
}
};
void asymm_branch::sort(scc& scc, clause const& c) {
m_pos.reset(); m_neg.reset();
for (literal l : c) {
m_pos.push_back(l);
m_neg.push_back(~l);
}
compare_left cmp(scc);
std::sort(m_pos.begin(), m_pos.end(), cmp);
std::sort(m_neg.begin(), m_neg.end(), cmp);
}
bool asymm_branch::uhte(scc& scc, clause & c) {
unsigned pindex = 0, nindex = 0;
literal lpos = m_pos[pindex++];
literal lneg = m_neg[nindex++];
while (true) {
if (scc.get_left(lneg) > scc.get_left(lpos)) {
if (pindex == m_pos.size()) return false;
lpos = m_pos[pindex++];
}
else if (scc.get_right(lneg) < scc.get_right(lpos) ||
(m_pos.size() == 2 && (lpos == ~lneg || scc.get_parent(lpos) == lneg))) {
if (nindex == m_neg.size()) return false;
lneg = m_neg[nindex++];
}
else {
return true;
}
}
return false;
}
bool asymm_branch::uhle(scoped_detach& scoped_d, scc& scc, clause & c) {
int right = scc.get_right(m_pos.back());
m_to_delete.reset();
for (unsigned i = m_pos.size() - 1; i-- > 0; ) {
literal lit = m_pos[i];
SASSERT(scc.get_left(lit) < scc.get_left(last));
int right2 = scc.get_right(lit);
if (right2 > right) {
// lit => last, so lit can be deleted
m_to_delete.push_back(lit);
}
else {
right = right2;
}
}
right = scc.get_right(m_neg[0]);
for (unsigned i = 1; i < m_neg.size(); ++i) {
literal lit = m_neg[i];
int right2 = scc.get_right(lit);
if (right > right2) {
// ~first => ~lit
m_to_delete.push_back(~lit);
}
else {
right = right2;
}
}
if (!m_to_delete.empty()) {
unsigned j = 0;
for (unsigned i = 0; i < c.size(); ++i) {
if (!m_to_delete.contains(c[i])) {
c[j] = c[i];
++j;
}
else {
m_pos.erase(c[i]);
m_neg.erase(~c[i]);
}
}
return re_attach(scoped_d, c, j);
}
else {
return true;
}
}
bool asymm_branch::propagate_literal(clause const& c, literal l) {
SASSERT(!s.inconsistent());
TRACE("asymm_branch_detail", tout << "assigning: " << l << "\n";);
@ -190,8 +285,12 @@ namespace sat {
new_sz = j;
m_elim_literals += c.size() - new_sz;
// std::cout << "cleanup: " << c.id() << ": " << literal_vector(new_sz, c.begin()) << " delta: " << (c.size() - new_sz) << " " << skip_idx << " " << new_sz << "\n";
switch(new_sz) {
case 0:
return re_attach(scoped_d, c, new_sz);
}
bool asymm_branch::re_attach(scoped_detach& scoped_d, clause& c, unsigned new_sz) {
switch(new_sz) {
case 0:
s.set_conflict(justification());
return false;
case 1:
@ -216,6 +315,15 @@ namespace sat {
}
}
bool asymm_branch::process2(scc& scc, clause & c) {
scoped_detach scoped_d(s, c);
if (uhte(scc, c)) {
scoped_d.del_clause();
return false;
}
return uhle(scoped_d, scc, c);
}
bool asymm_branch::process(clause & c) {
if (c.is_blocked()) return true;
TRACE("asymm_branch_detail", tout << "processing: " << c << "\n";);

View file

@ -20,6 +20,7 @@ Revision History:
#define SAT_ASYMM_BRANCH_H_
#include "sat/sat_types.h"
#include "sat/sat_scc.h"
#include "util/statistics.h"
#include "util/params.h"
@ -30,21 +31,37 @@ namespace sat {
class asymm_branch {
struct report;
solver & s;
solver & s;
params_ref m_params;
int64 m_counter;
random_gen m_rand;
unsigned m_calls;
// config
bool m_asymm_branch;
bool m_asymm_branch_all;
int64 m_asymm_branch_limit;
bool m_asymm_branch;
bool m_asymm_branch_all;
int64 m_asymm_branch_limit;
// stats
unsigned m_elim_literals;
unsigned m_elim_literals;
literal_vector m_pos, m_neg; // literals (complements of literals) in clauses sorted by discovery time (m_left in scc).
literal_vector m_to_delete;
struct compare_left;
void sort(scc & scc, clause const& c);
bool uhle(scoped_detach& scoped_d, scc & scc, clause & c);
bool uhte(scc & scc, clause & c);
bool re_attach(scoped_detach& scoped_d, clause& c, unsigned new_sz);
bool process(clause & c);
bool process2(scc& scc, clause & c);
void process(clause_vector & c);
bool process_all(clause & c);
@ -55,6 +72,8 @@ namespace sat {
bool propagate_literal(clause const& c, literal l);
void setup_big();
public:
asymm_branch(solver & s, params_ref const & p);

View file

@ -234,58 +234,70 @@ namespace sat {
return to_elim.size();
}
void scc::get_dfs_num(svector<int>& dfs, bool learned) {
unsigned num_lits = m_solver.num_vars() * 2;
vector<literal_vector> dag(num_lits);
svector<bool> roots(num_lits, true);
literal_vector todo;
SASSERT(dfs.size() == num_lits);
unsigned num_edges = 0;
// shuffle vertices to obtain different DAG traversal each time
void scc::shuffle(literal_vector& lits) {
unsigned sz = lits.size();
if (sz > 1) {
for (unsigned i = sz; i-- > 0; ) {
std::swap(lits[i], lits[m_rand(i+1)]);
}
}
}
// retrieve DAG
vector<literal_vector> const& scc::get_big(bool learned) {
unsigned num_lits = m_solver.num_vars() * 2;
m_dag.reset();
m_roots.reset();
m_dag.resize(num_lits, 0);
m_roots.resize(num_lits, true);
SASSERT(num_lits == m_dag.size() && num_lits == m_roots.size());
for (unsigned l_idx = 0; l_idx < num_lits; l_idx++) {
literal u(to_literal(l_idx));
literal u = to_literal(l_idx);
if (m_solver.was_eliminated(u.var()))
continue;
auto& edges = dag[u.index()];
auto& edges = m_dag[l_idx];
for (watched const& w : m_solver.m_watches[l_idx]) {
if (learned ? w.is_binary_clause() : w.is_binary_unblocked_clause()) {
literal v = w.get_literal();
roots[v.index()] = false;
m_roots[v.index()] = false;
edges.push_back(v);
++num_edges;
}
}
unsigned sz = edges.size();
// shuffle vertices to obtain different DAG traversal each time
if (sz > 1) {
for (unsigned i = sz; i-- > 0; ) {
std::swap(edges[i], edges[m_rand(i+1)]);
}
}
shuffle(edges);
}
// std::cout << "dag num edges: " << num_edges << "\n";
return m_dag;
}
void scc::get_dfs_num(bool learned) {
unsigned num_lits = m_solver.num_vars() * 2;
SASSERT(m_left.size() == num_lits);
SASSERT(m_right.size() == num_lits);
literal_vector todo;
// retrieve literals that have no predecessors
for (unsigned l_idx = 0; l_idx < num_lits; l_idx++) {
literal u(to_literal(l_idx));
if (roots[u.index()]) todo.push_back(u);
if (m_roots[u.index()]) todo.push_back(u);
}
// std::cout << "num roots: " << roots.size() << "\n";
// traverse DAG, annotate nodes with DFS number
shuffle(todo);
int dfs_num = 0;
while (!todo.empty()) {
literal u = todo.back();
int& d = dfs[u.index()];
int& d = m_left[u.index()];
// already visited
if (d > 0) {
if (m_right[u.index()] < 0) {
m_right[u.index()] = dfs_num;
}
todo.pop_back();
}
// visited as child:
else if (d < 0) {
d = -d;
for (literal v : dag[u.index()]) {
if (dfs[v.index()] == 0) {
dfs[v.index()] = - d - 1;
for (literal v : m_dag[u.index()]) {
if (m_left[v.index()] == 0) {
m_left[v.index()] = - d - 1;
m_root[v.index()] = m_root[u.index()];
m_parent[v.index()] = u;
todo.push_back(v);
}
}
@ -297,9 +309,21 @@ namespace sat {
}
}
bool scc::reduce_tr(svector<int> const& dfs, bool learned) {
unsigned scc::reduce_tr(bool learned) {
unsigned num_lits = m_solver.num_vars() * 2;
m_left.reset();
m_right.reset();
m_root.reset();
m_parent.reset();
m_left.resize(num_lits, 0);
m_right.resize(num_lits, -1);
for (unsigned i = 0; i < num_lits; ++i) {
m_root[i] = to_literal(i);
m_parent[i] = to_literal(i);
}
get_dfs_num(learned);
unsigned idx = 0;
bool reduced = false;
unsigned elim = m_num_elim_bin;
for (watch_list & wlist : m_solver.m_watches) {
literal u = to_literal(idx++);
watch_list::iterator it = wlist.begin();
@ -309,9 +333,8 @@ namespace sat {
watched& w = *it;
if (learned ? w.is_binary_learned_clause() : w.is_binary_unblocked_clause()) {
literal v = w.get_literal();
if (dfs[u.index()] + 1 < dfs[v.index()]) {
if (m_left[u.index()] + 1 < m_left[v.index()]) {
++m_num_elim_bin;
reduced = true;
}
else {
*itprev = *it;
@ -325,19 +348,13 @@ namespace sat {
}
wlist.set_end(itprev);
}
return reduced;
}
bool scc::reduce_tr(bool learned) {
unsigned num_lits = m_solver.num_vars() * 2;
svector<int> dfs(num_lits);
get_dfs_num(dfs, learned);
return reduce_tr(dfs, learned);
return m_num_elim_bin - elim;
}
void scc::reduce_tr() {
while (reduce_tr(false)) {}
while (reduce_tr(true)) {}
unsigned quota = 0, num_reduced = 0;
while ((num_reduced = reduce_tr(false)) > quota) { quota = std::max(100u, num_reduced / 2); }
while ((num_reduced = reduce_tr(true)) > quota) { quota = std::max(100u, num_reduced / 2); }
}
void scc::collect_statistics(statistics & st) const {

View file

@ -37,12 +37,19 @@ namespace sat {
unsigned m_num_elim_bin;
random_gen m_rand;
void get_dfs_num(svector<int>& dfs, bool learned);
// BIG state:
vector<literal_vector> m_dag;
svector<bool> m_roots;
svector<int> m_left, m_right;
literal_vector m_root, m_parent;
void shuffle(literal_vector& lits);
void reduce_tr();
bool reduce_tr(bool learned);
bool reduce_tr(svector<int> const& dfs, bool learned);
unsigned reduce_tr(bool learned);
public:
scc(solver & s, params_ref const & p);
unsigned operator()();
@ -51,6 +58,19 @@ namespace sat {
void collect_statistics(statistics & st) const;
void reset_statistics();
/*
\brief retrieve binary implication graph
*/
vector<literal_vector> const& get_big(bool learned);
int get_left(literal l) const { return m_left[l.index()]; }
int get_right(literal l) const { return m_right[l.index()]; }
literal get_parent(literal l) const { return m_parent[l.index()]; }
literal get_root(literal l) const { return m_root[l.index()]; }
void get_dfs_num(bool learned);
};
};