3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-24 01:25:31 +00:00

add ite-finder, profile

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
This commit is contained in:
Nikolaj Bjorner 2020-01-05 13:35:14 -08:00
parent a6c3c18e74
commit e1fb74edc5
17 changed files with 321 additions and 168 deletions

View file

@ -3802,7 +3802,8 @@ namespace sat {
xor_finder xf(s());
std::function<void (literal_vector const&)> f = [this](literal_vector const& l) { add_xr(l, false); };
xf.set(f);
xf.extract_xors(s().m_clauses);
clause_vector clauses(s().clauses());
xf(clauses);
for (clause* cp : xf.removed_clauses()) {
cp->set_removed(true);
m_clause_removed = true;

View file

@ -21,15 +21,10 @@
namespace sat {
void aig_finder::operator()(clause_vector const& clauses) {
void aig_finder::operator()(clause_vector& clauses) {
m_big.init(s, true);
for (clause* cp : clauses) {
clause& c = *cp;
if (c.size() <= 2) continue;
if (find_aig(c)) continue;
if (find_if(c)) continue;
}
find_aigs(clauses);
find_ifs(clauses);
}
bool aig_finder::implies(literal a, literal b) {
@ -42,11 +37,29 @@ namespace sat {
return false;
}
void aig_finder::find_aigs(clause_vector& clauses) {
if (!m_on_aig) {
return;
}
unsigned j = 0;
for (clause* cp : clauses) {
clause& c = *cp;
if (!find_aig(c)) {
clauses[j++] = cp;
}
}
clauses.shrink(j);
}
// a = ~b & ~c
// if (~a | ~b) (~a | ~c), (b | c | a)
bool aig_finder::find_aig(clause& c) {
bool is_aig = false;
if (c.size() <= 2) {
return false;
}
for (literal head : c) {
is_aig = true;
for (literal tail : c) {
@ -62,7 +75,7 @@ namespace sat {
for (literal tail : c)
if (tail != head)
m_ands.push_back(~tail);
m_aig_def(head, m_ands, c);
m_on_aig(head, m_ands);
break;
}
}
@ -76,78 +89,139 @@ namespace sat {
// y, z -> x
// ~y, u -> x
//
// So there are clauses
// y -> (x = z)
// u -> (x = ~y)
//
// from clause x, y, z
// then ~x, ~y -> z
// look for ~y, z -> ~x - contains ternary(y, ~z, ~x)
// look for ~x, y -> u - u is used in a ternary claues (~y, x)
// look for y, u -> ~x - contains ternary(~u, ~x, ~y)
// then ~x = if ~y then z else u
bool aig_finder::find_if(clause& c) {
return false;
#if 0
if (c.size() != 3) return false;
void aig_finder::find_ifs(clause_vector& clauses) {
literal x = c[0], y = c[1], z = c[2];
if (find_if(~x, ~y, z, c)) return true;
if (find_if(~x, ~z, y, c)) return true;
if (find_if(~y, ~x, z, c)) return true;
if (find_if(~y, ~z, x, c)) return true;
if (find_if(~z, ~x, y, c)) return true;
if (find_if(~z, ~y, x, c)) return true;
return false;
#endif
}
#if 0
// x, y -> z
// x, ~y -> u
// y, z -> x
// ~y, u -> x
// x + yz + (1 + y)u = 0
bool aig_finder::check_if(literal x, literal y, literal z, clause& c) {
clause* c2 = find_clause(~y, ~z, x);
if (!c2) {
return false;
if (!m_on_if) {
return;
}
for (clause* c3 : ternay_clauses_with(~x, y)) {
literal u = third_literal(~x, y, *c3);
clause* c4 = find_clause(y, ~u, x);
if (c4) {
m_if_def(x, y, z, u, c, *c2, *c3, *c4);
for (clause* cp : clauses) cp->unmark_used();
typedef svector<std::pair<literal, clause*>> use_list_t;
struct binary {
literal x, y;
use_list_t* use_list;
binary(literal x, literal y, use_list_t* u): x(x), y(y), use_list(u) {
if (x.index() > y.index()) std::swap(x, y);
}
binary():x(null_literal), y(null_literal), use_list(nullptr) {}
struct hash {
unsigned operator()(binary const& t) const { return t.x.hash() + 2* t.y.hash(); }
};
struct eq {
bool operator()(binary const& a, binary const& b) const {
return a.x == b.x && a.y == b.y;
}
};
};
hashtable<binary, binary::hash, binary::eq> binaries;
scoped_ptr_vector<use_list_t> use_lists;
auto insert_binary = [&](literal x, literal y, literal z, clause* c) {
binary b(x, y, nullptr);
auto* e = binaries.insert_if_not_there2(b);
if (e->get_data().use_list == nullptr) {
use_list_t* use_list = alloc(use_list_t);
use_lists.push_back(use_list);
e->get_data().use_list = use_list;
}
e->get_data().use_list->push_back(std::make_pair(z, c));
};
struct ternary {
literal x, y, z;
clause* orig;
ternary(literal x, literal y, literal z, clause* c):
x(x), y(y), z(z), orig(c) {
if (x.index() > y.index()) std::swap(x, y);
if (y.index() > z.index()) std::swap(y, z);
if (x.index() > y.index()) std::swap(x, y);
}
ternary():x(null_literal), y(null_literal), z(null_literal), orig(nullptr) {}
struct hash {
unsigned operator()(ternary const& t) const { return mk_mix(t.x.hash(), t.y.hash(), t.z.hash()); }
};
struct eq {
bool operator()(ternary const& a, ternary const& b) const {
return a.x == b.x && a.y == b.y && a.z == b.z;
}
};
};
hashtable<ternary, ternary::hash, ternary::eq> ternaries;
auto has_ternary = [&](literal x, literal y, literal z, clause*& c) {
ternary t(x, y, z, nullptr);
if (ternaries.find(t, t)) {
c = t.orig;
return true;
}
if (implies(~y, z) || implies(~x, y) || implies(~x, z)) {
c = nullptr;
return true;
}
return false;
};
auto insert_ternary = [&](clause& c) {
if (c.size() == 3) {
ternaries.insert(ternary(c[0], c[1], c[2], &c));
insert_binary(c[0], c[1], c[2], &c);
insert_binary(c[0], c[2], c[1], &c);
insert_binary(c[2], c[1], c[0], &c);
}
};
for (clause* cp : s.learned()) {
insert_ternary(*cp);
}
}
literal aig_finder::third_literal(literal a, literal b, clause const& c) {
for (literal lit : c)
if (lit != a && lit != b)
return lit;
return null_literal;
}
clause* aig_finder::find_clause(literal a, literal b, literal c) {
for (auto const& w : s.get_wlist(~a)) {
if (w.is_ternary() &&
(b == w.get_literal1() && c == w.get_literal2()) ||
(c == w.get_literal1() && b == w.get_literal2())) {
for (clause* cp : s.clauses()) {
clause& cl = *cp;
#define pair_eq(a, b, x, y) ((a == x && b == y) || (a == y && b == x))
#define tern_eq(a, b, c, cl) \
cl.size() == 3 && \
((cl[0] == a && pair_eq(b, c, c1[1], c1[2])) || \
(cl[0] == b && pair_eq(a, c, cl[1], cl[2])) || \
(cl[0] == c && pair_eq(a, b, cl[1], cl[2]))))
if (tern_eq(a, b, c, *cp)) return cp;
for (clause* cp : s.clauses()) {
insert_ternary(*cp);
}
auto try_ite = [&,this](literal x, literal y, literal z, clause& c) {
clause* c1, *c3;
if (has_ternary(y, ~z, ~x, c1)) {
binary b(~y, x, nullptr);
if (!binaries.find(b, b)) {
return false;
}
for (auto p : *b.use_list) {
literal u = p.first;
clause* c2 = p.second;
if (has_ternary(~u, ~x, ~y, c3)) {
c.mark_used();
if (c1) c1->mark_used();
if (c2) c2->mark_used();
if (c3) c3->mark_used();
m_on_if(~x, ~y, z, u);
return true;
}
}
}
if (w.is_clause() && tern_eq(a, b, c, s.get_clause(w)))
return &s.get_clause(w);
return false;
};
for (clause* cp : clauses) {
clause& c = *cp;
if (c.size() != 3 || c.was_used()) continue;
literal x = c[0], y = c[1], z = c[2];
if (try_ite(x, y, z, c)) continue;
if (try_ite(y, x, z, c)) continue;
if (try_ite(z, y, x, c)) continue;
}
return nullptr;
std::function<bool(clause*)> not_used = [](clause* cp) { return !cp->was_used(); };
clauses.filter_update(not_used);
}
#endif
}

View file

@ -38,15 +38,18 @@ namespace sat {
solver& s;
big m_big;
literal_vector m_ands;
std::function<void (literal head, literal_vector const& ands, clause& orig)> m_aig_def;
std::function<void (literal head, literal_vector const& ands)> m_on_aig;
std::function<void (literal head, literal cond, literal th, literal el)> m_on_if;
bool implies(literal a, literal b);
bool find_aig(clause& c);
bool find_if(clause& c);
void find_ifs(clause_vector& clauses);
void find_aigs(clause_vector& clauses);
public:
aig_finder(solver& s) : s(s), m_big(s.rand()) {}
~aig_finder() {}
void set(std::function<void (literal head, literal_vector const& ands, clause& orig)>& f) { m_aig_def = f; }
void operator()(clause_vector const& clauses);
void set(std::function<void (literal head, literal_vector const& ands)>& f) { m_on_aig = f; }
void set(std::function<void (literal head, literal cond, literal th, literal el)>& f) { m_on_if = f; }
void operator()(clause_vector& clauses);
};
}

View file

@ -30,10 +30,11 @@ namespace sat {
~report() {
m_watch.stop();
IF_VERBOSE(2,
verbose_stream() << " (sat.anf.simplifier "
verbose_stream() << " (sat.anf.simplifier"
<< " :num-units " << s.m_stats.m_num_units
<< " :num-eqs " << s.m_stats.m_num_eq
<< mem_stat() << m_watch << ")\n");
<< " :num-eqs " << s.m_stats.m_num_eq
<< " :mb " << mem_stat()
<< m_watch << ")\n");
}
};
@ -289,11 +290,7 @@ namespace sat {
};
xor_finder xf(s);
xf.set(f);
xf.extract_xors(clauses);
for (clause* cp : clauses) cp->unmark_used();
for (clause* cp : xf.removed_clauses()) cp->mark_used();
std::function<bool(clause*)> not_used = [](clause* cp) { return !cp->was_used(); };
clauses.filter_update(not_used);
xf(clauses);
}
static solver::bin_clause normalize(solver::bin_clause const& b) {
@ -314,25 +311,28 @@ namespace sat {
if (!m_config.m_compile_aig) {
return;
}
for (clause* cp : clauses) cp->unmark_used();
hashtable<solver::bin_clause, solver::bin_clause_hash, default_eq<solver::bin_clause>> seen_bin;
std::function<void(literal head, literal_vector const& tail, clause& c)> f =
[&,this](literal head, literal_vector const& tail, clause& c) {
c.mark_used();
std::function<void(literal head, literal_vector const& tail)> on_aig =
[&,this](literal head, literal_vector const& tail) {
add_aig(head, tail, ps);
for (literal l : tail) {
seen_bin.insert(normalize(solver::bin_clause(~l, head)));
}
m_stats.m_num_aigs++;
};
std::function<void(literal head, literal c, literal th, literal el)> on_if =
[&,this](literal head, literal c, literal th, literal el) {
add_if(head, c, th, el, ps);
m_stats.m_num_ifs++;
};
aig_finder af(s);
af.set(f);
af.set(on_aig);
af.set(on_if);
af(clauses);
std::function<bool(clause*)> not_used = [](clause* cp) { return !cp->was_used(); };
std::function<bool(solver::bin_clause b)> not_seen = [&](solver::bin_clause b) { return !seen_bin.contains(normalize(b)); };
clauses.filter_update(not_used);
std::function<bool(solver::bin_clause b)> not_seen =
[&](solver::bin_clause b) { return !seen_bin.contains(normalize(b)); };
bins.filter_update(not_seen);
}
@ -361,14 +361,14 @@ namespace sat {
cfg.m_expr_size_limit = 1000;
cfg.m_max_steps = 1000;
cfg.m_random_seed = s.rand()();
cfg.m_enable_exlin = true;
cfg.m_enable_exlin = m_config.m_enable_exlin;
unsigned max_num_nodes = 1 << 18;
ps.get_manager().set_max_num_nodes(max_num_nodes);
ps.set(cfg);
}
#define lit2pdd(_l_) _l_.sign() ? ~m.mk_var(_l_.var()) : m.mk_var(_l_.var())
#define lit2pdd(_l_) (_l_.sign() ? ~m.mk_var(_l_.var()) : m.mk_var(_l_.var()))
void anf_simplifier::add_bin(solver::bin_clause const& b, pdd_solver& ps) {
auto& m = ps.get_manager();
@ -380,6 +380,7 @@ namespace sat {
}
void anf_simplifier::add_clause(clause const& c, pdd_solver& ps) {
if (c.size() > m_config.m_max_clause_size) return;
auto& m = ps.get_manager();
dd::pdd p = m.zero();
for (literal l : c) p |= lit2pdd(l);
@ -401,15 +402,23 @@ namespace sat {
for (literal l : ands) q &= lit2pdd(l);
dd::pdd p = lit2pdd(head) ^ q;
ps.add(p);
TRACE("anf_simplifier", tout << "aig: " << head << " == " << ands << " : " << p << "\n";);
TRACE("anf_simplifier", tout << "aig: " << head << " == " << ands << " poly : " << p << "\n";);
}
void anf_simplifier::add_if(literal head, literal c, literal th, literal el, pdd_solver& ps) {
auto& m = ps.get_manager();
dd::pdd p = lit2pdd(head) ^ (lit2pdd(c) & lit2pdd(th)) ^ (~lit2pdd(c) & lit2pdd(el));
ps.add(p);
TRACE("anf_simplifier", tout << "ite: " << head << " == " << c << "?" << th << ":" << el << " poly : " << p << "\n";);
}
void anf_simplifier::save_statistics(pdd_solver& solver) {
solver.collect_statistics(m_st);
m_st.update("anf.num-units", m_stats.m_num_units);
m_st.update("anf.num-eqs", m_stats.m_num_eq);
m_st.update("anf.num-aigs", m_stats.m_num_aigs);
m_st.update("anf.num-xors", m_stats.m_num_xors);
m_st.update("sat-anf.units", m_stats.m_num_units);
m_st.update("sat-anf.eqs", m_stats.m_num_eq);
m_st.update("sat-anf.ands", m_stats.m_num_aigs);
m_st.update("sat-anf.ites", m_stats.m_num_ifs);
m_st.update("sat-anf.xors", m_stats.m_num_xors);
}
}

View file

@ -42,12 +42,14 @@ namespace sat {
bool m_compile_xor;
bool m_compile_aig;
bool m_anf2phase;
bool m_enable_exlin;
config():
m_max_clause_size(10),
m_max_clause_size(3),
m_max_clauses(10000),
m_compile_xor(true),
m_compile_aig(true),
m_anf2phase(false)
m_anf2phase(false),
m_enable_exlin(false)
{}
};
@ -56,7 +58,7 @@ namespace sat {
struct stats {
unsigned m_num_units, m_num_eq;
unsigned m_num_aigs, m_num_xors;
unsigned m_num_aigs, m_num_xors, m_num_ifs;
stats() { reset(); }
void reset() { memset(this, 0, sizeof(*this)); }
};
@ -84,6 +86,7 @@ namespace sat {
void add_clause(clause const& c, pdd_solver& ps);
void add_bin(solver::bin_clause const& b, pdd_solver& ps);
void add_xor(literal_vector const& x, pdd_solver& ps);
void add_if(literal head, literal c, literal t, literal e, pdd_solver& ps);
void add_aig(literal head, literal_vector const& ands, pdd_solver& ps);
void save_statistics(pdd_solver& ps);

View file

@ -101,7 +101,10 @@ namespace sat {
m_unit_walk_threads = p.unit_walk_threads();
m_binspr = p.binspr();
m_anf_simplify = p.anf();
m_anf_delay = p.anf_delay();
m_anf_exlin = p.anf_exlin();
m_aig_simplify = p.aig();
m_aig_delay = p.aig_delay();
m_lookahead_simplify = p.lookahead_simplify();
m_lookahead_double = p.lookahead_double();
m_lookahead_simplify_bca = p.lookahead_simplify_bca();

View file

@ -121,7 +121,10 @@ namespace sat {
bool m_unit_walk;
bool m_binspr;
bool m_aig_simplify;
unsigned m_aig_delay;
bool m_anf_simplify;
unsigned m_anf_delay;
bool m_anf_exlin;
bool m_lookahead_simplify;
bool m_lookahead_simplify_bca;
cutoff_t m_lookahead_cube_cutoff;

View file

@ -71,7 +71,10 @@ def_module_params('sat',
('unit_walk_threads', UINT, 0, 'number of unit-walk search threads to find satisfiable solution'),
('binspr', BOOL, False, 'enable SPR inferences of binary propagation redundant clauses. This inprocessing step eliminates models'),
('anf', BOOL, False, 'enable ANF based simplification in-processing'),
('anf.delay', UINT, 2, 'delay ANF simplification by in-processing round'),
('anf.exlin', BOOL, False, 'enable extended linear simplification'),
('aig', BOOL, False, 'enable AIG based simplification in-processing'),
('aig.delay', UINT, 2, 'delay AIG simplification by in-processing round'),
('lookahead.cube.cutoff', SYMBOL, 'depth', 'cutoff type used to create lookahead cubes: depth, freevars, psat, adaptive_freevars, adaptive_psat'),
# - depth: the maximal cutoff is fixed to the value of lookahead.cube.depth.
# So if the value is 10, at most 1024 cubes will be generated of length 10.

View file

@ -1907,18 +1907,6 @@ namespace sat {
lh.collect_statistics(m_aux_stats);
}
if (m_config.m_anf_simplify) {
anf_simplifier anf(*this);
anf();
anf.collect_statistics(m_aux_stats);
}
if (m_config.m_aig_simplify) {
aig_simplifier aig(*this);
aig();
aig.collect_statistics(m_aux_stats);
}
reinit_assumptions();
if (inconsistent()) return;
@ -1942,18 +1930,21 @@ namespace sat {
m_binspr();
}
#if 0
static unsigned file_no = 0;
#pragma omp critical (print_sat)
{
++file_no;
std::ostringstream ostrm;
ostrm << "s" << file_no << ".txt";
std::ofstream ous(ostrm.str());
display(ous);
if (m_config.m_anf_simplify && m_simplifications > m_config.m_anf_delay && !inconsistent()) {
anf_simplifier anf(*this);
anf_simplifier::config cfg;
cfg.m_enable_exlin = m_config.m_anf_exlin;
anf();
anf.collect_statistics(m_aux_stats);
// TBD: throttle anf_delay based on yield
}
#endif
if (m_config.m_aig_simplify && m_simplifications > m_config.m_aig_delay && !inconsistent()) {
aig_simplifier aig(*this);
aig();
aig.collect_statistics(m_aux_stats);
// TBD: throttle aig_delay based on yield
}
}
bool solver::set_root(literal l, literal r) {

View file

@ -225,10 +225,6 @@ namespace sat {
return out << std::fixed << std::setprecision(2) << mem;
}
inline std::ostream& operator<<(std::ostream& out, stopwatch const& sw) {
return out << " :time " << std::fixed << std::setprecision(2) << sw.get_seconds();
}
struct dimacs_lit {
literal m_lit;
dimacs_lit(literal l):m_lit(l) {}

View file

@ -24,7 +24,7 @@
namespace sat {
void xor_finder::extract_xors(clause_vector& clauses) {
void xor_finder::operator()(clause_vector& clauses) {
m_removed_clauses.reset();
if (!s.get_config().m_xor_solver) {
return;
@ -49,6 +49,11 @@ namespace sat {
}
}
m_clause_filters.clear();
for (clause* cp : clauses) cp->unmark_used();
for (clause* cp : m_removed_clauses) cp->mark_used();
std::function<bool(clause*)> not_used = [](clause* cp) { return !cp->was_used(); };
clauses.filter_update(not_used);
}
void xor_finder::extract_xor(clause& c) {
@ -108,7 +113,7 @@ namespace sat {
s.set_external(l.var());
}
if (parity) lits[0].neg();
m_add_xr(lits);
m_on_xor(lits);
}
bool xor_finder::extract_xor(bool parity, clause& c, literal l1, literal l2) {

View file

@ -47,7 +47,7 @@ namespace sat {
literal_vector m_clause; // reference clause with literals sorted according to main clause
unsigned_vector m_missing; // set of indices not occurring in clause.
clause_vector m_removed_clauses;
std::function<void (literal_vector const& lits)> m_add_xr;
std::function<void (literal_vector const& lits)> m_on_xor;
inline void set_combination(unsigned mask) { m_combination |= (1 << mask); }
inline bool get_combination(unsigned mask) const { return (m_combination & (1 << mask)) != 0; }
@ -65,12 +65,12 @@ namespace sat {
xor_finder(solver& s) : s(s), m_max_xor_size(5) { init_parity(); }
~xor_finder() {}
void set(std::function<void (literal_vector const& lits)>& f) { m_add_xr = f; }
void set(std::function<void (literal_vector const& lits)>& f) { m_on_xor = f; }
bool parity(unsigned i, unsigned j) const { return m_parity[i][j]; }
unsigned max_xor_size() const { return m_max_xor_size; }
void extract_xors(clause_vector& clauses);
clause_vector& removed_clauses() { return m_removed_clauses; }
void operator()(clause_vector& clauses);
clause_vector const& removed_clauses() const { return m_removed_clauses; }
};
}