3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-08 10:25:18 +00:00

make cutset maintainance incremental, expose option for goal2sat to populate aig

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
This commit is contained in:
Nikolaj Bjorner 2020-01-08 16:39:49 -08:00
parent 301f9598a4
commit ca243428f8
7 changed files with 188 additions and 156 deletions

View file

@ -16,9 +16,12 @@
--*/
#include "sat/sat_aig_finder.h"
#include "sat/sat_solver.h"
namespace sat {
aig_finder::aig_finder(solver& s): s(s), m_big(s.rand()) {}
void aig_finder::operator()(clause_vector& clauses) {
m_big.init(s, true);
find_aigs(clauses);

View file

@ -29,11 +29,12 @@
#include "util/statistics.h"
#include "sat/sat_clause.h"
#include "sat/sat_types.h"
#include "sat/sat_solver.h"
#include "sat/sat_big.h"
namespace sat {
class solver;
class aig_finder {
solver& s;
big m_big;
@ -50,7 +51,7 @@ namespace sat {
void validate_clause(literal_vector const& clause, vector<literal_vector> const& clauses);
public:
aig_finder(solver& s) : s(s), m_big(s.rand()) {}
aig_finder(solver& s);
~aig_finder() {}
void set(std::function<void (literal head, literal_vector const& ands)>& f) { m_on_aig = f; }
void set(std::function<void (literal head, literal cond, literal th, literal el)>& f) { m_on_if = f; }

View file

@ -25,9 +25,8 @@ namespace sat {
struct aig_simplifier::report {
aig_simplifier& s;
aig_cuts& c;
stopwatch m_watch;
report(aig_simplifier& s, aig_cuts& c): s(s), c(c) { m_watch.start(); }
report(aig_simplifier& s): s(s) { m_watch.start(); }
~report() {
IF_VERBOSE(2,
verbose_stream() << "(sat.aig-simplifier"
@ -39,39 +38,53 @@ namespace sat {
}
};
aig_simplifier::aig_simplifier(solver& s):s(s), m_aig_cuts(m_config.m_max_cut_size, m_config.m_max_cutset_size) {
}
void aig_simplifier::add_and(literal head, unsigned sz, literal const* lits) {
m_aig_cuts.add_node(head, and_op, sz, lits);
}
void aig_simplifier::add_or(literal head, unsigned sz, literal const* lits) {
m_aig_cuts.add_node(head, and_op, sz, lits);
}
void aig_simplifier::add_xor(literal head, unsigned sz, literal const* lits) {
m_aig_cuts.add_node(head, xor_op, sz, lits);
}
void aig_simplifier::add_ite(literal head, literal c, literal t, literal e) {
literal lits[3] = { c, t, e };
m_aig_cuts.add_node(head, ite_op, 3, lits);
}
void aig_simplifier::add_iff(literal head, literal l1, literal l2) {
literal lits[2] = { l1, ~l2 };
m_aig_cuts.add_node(head, xor_op, 2, lits);
}
void aig_simplifier::operator()() {
aig_cuts aigc;
report _report(*this, aigc);
report _report(*this);
TRACE("aig_simplifier", s.display(tout););
clauses2aig(aigc);
aig2clauses(aigc);
clauses2aig();
aig2clauses();
}
/**
\brief extract AIG definitions from clauses
Ensure that they are sorted and variables have unique definitions.
*/
void aig_simplifier::clauses2aig(aig_cuts& aigc) {
struct aig_def {
literal head;
bool_op op;
unsigned sz;
unsigned offset;
aig_def(literal h, bool_op op, unsigned sz, unsigned o): head(h), op(op), sz(sz), offset(o) {}
};
svector<aig_def> aig_defs;
void aig_simplifier::clauses2aig() {
literal_vector literals;
std::function<void (literal head, literal_vector const& ands)> on_and =
[&,this](literal head, literal_vector const& ands) {
aig_defs.push_back(aig_def(head, and_op, ands.size(), literals.size()));
literals.append(ands);
[&,this](literal head, literal_vector const& ands) {
m_aig_cuts.add_node(head, and_op, ands.size(), ands.c_ptr());
m_stats.m_num_ands++;
};
std::function<void (literal head, literal c, literal t, literal e)> on_ite =
[&,this](literal head, literal c, literal t, literal e) {
aig_defs.push_back(aig_def(head, ite_op, 3, literals.size()));
literal args[3] = { c, t, e };
literals.append(3, args);
literal args[3] = { c, t, e };
m_aig_cuts.add_node(head, ite_op, 3, args);
m_stats.m_num_ites++;
};
aig_finder af(s);
@ -97,88 +110,25 @@ namespace sat {
// ~head = t1 + t2 + ..
literal head = ~xors[index];
unsigned sz = xors.size() - 1;
aig_defs.push_back(aig_def(head, xor_op, sz, literals.size()));
for (unsigned i = xors.size(); i-- > 0; ) {
if (i != index)
literals.push_back(xors[i]);
}
m_aig_cuts.add_node(head, xor_op, sz, literals.c_ptr());
literals.reset();
m_stats.m_num_xors++;
};
xor_finder xf(s);
xf.set(on_xor);
xf(clauses);
svector<bool> outs(s.num_vars(), false);
svector<bool> ins(s.num_vars(), false);
for (auto a : aig_defs) {
outs[a.head.var()] = true;
}
for (auto a : aig_defs) {
for (unsigned i = 0; i < a.sz; ++i) {
unsigned v = literals[a.offset+i].var();
if (!outs[v]) ins[v] = true;
}
}
std::function<void(aig_def)> force_var = [&] (aig_def a) {
for (unsigned i = 0; i < a.sz; ++i) {
unsigned v = literals[a.offset + i].var();
if (!ins[v]) {
aigc.add_var(v);
ins[v] = true;
}
}
};
std::function<void(unsigned)> add_var = [&] (unsigned v) {
if (!outs[v] && ins[v]) {
aigc.add_var(v);
outs[v] = true;
}
};
for (auto a : aig_defs) {
for (unsigned i = 0; i < a.sz; ++i) {
add_var(literals[a.offset+i].var());
}
}
while (true) {
unsigned j = 0;
for (auto a : aig_defs) {
bool visited = true;
for (unsigned i = 0; visited && i < a.sz; ++i) {
visited &= ins[literals[a.offset + i].var()];
}
unsigned h = a.head.var();
if (!ins[h] && visited) {
ins[h] = true;
aigc.add_node(a.head, a.op, a.sz, literals.c_ptr() + a.offset);
}
else if (!ins[h]) {
aig_defs[j++] = a;
}
else {
TRACE("aig_simplifier", tout << "skip " << a.head << " == .. \n";);
force_var(a);
}
}
if (j == 0) {
break;
}
if (j == aig_defs.size()) {
IF_VERBOSE(2, verbose_stream() << "break cycle " << j << "\n");
force_var(aig_defs.back());
}
aig_defs.shrink(j);
}
xf(clauses);
}
void aig_simplifier::aig2clauses(aig_cuts& aigc) {
vector<cut_set> cuts = aigc.get_cuts(m_config.m_max_cut_size, m_config.m_max_cutset_size);
void aig_simplifier::aig2clauses() {
vector<cut_set> const& cuts = m_aig_cuts.get_cuts();
map<cut const*, unsigned, cut::hash_proc, cut::eq_proc> cut2id;
union_find_default_ctx ctx;
union_find<> uf(ctx);
union_find<> uf(ctx), uf2(ctx);
for (unsigned i = 2*s.num_vars(); i--> 0; ) uf.mk_var();
auto add_eq = [&](literal l1, literal l2) {
uf.merge(l1.index(), l2.index());
@ -212,8 +162,31 @@ namespace sat {
}
}
if (old_num_eqs < m_stats.m_num_eqs) {
elim_eqs elim(s);
elim(uf);
// extract equivalences over non-eliminated literals.
bool new_eq = false;
for (unsigned idx = 0; idx < uf.get_num_vars(); ++idx) {
if (!uf.is_root(idx) || 1 == uf.size(idx)) continue;
literal root = null_literal;
unsigned first = idx;
do {
literal lit = to_literal(idx);
if (!s.was_eliminated(lit)) {
if (root == null_literal) {
root = lit;
}
else {
uf2.merge(lit.index(), root.index());
new_eq = true;
}
}
idx = uf.next(idx);
}
while (first != idx);
}
if (new_eq) {
elim_eqs elim(s);
elim(uf2);
}
}
}
@ -224,28 +197,18 @@ namespace sat {
st.update("sat-aig.ites", m_stats.m_num_ites);
st.update("sat-aig.xors", m_stats.m_num_xors);
}
vector<cut_set> aig_cuts::get_cuts(unsigned max_cut_size, unsigned max_cutset_size) {
unsigned_vector sorted = top_sort();
vector<cut_set> cuts(m_aig.size());
aig_cuts::aig_cuts(unsigned max_cut_size, unsigned max_cutset_size) {
m_max_cut_size = std::min(cut().max_cut_size, max_cut_size);
m_max_cutset_size = max_cutset_size;
}
vector<cut_set> const& aig_cuts::get_cuts() {
unsigned_vector node_ids = filter_valid_nodes();
m_cut_set1.init(m_region, m_max_cutset_size + 1);
m_cut_set2.init(m_region, m_max_cutset_size + 1);
unsigned j = 0;
for (unsigned id : sorted) {
node const& n = m_aig[id];
if (n.is_valid()) {
auto& cut_set = cuts[id];
cut_set.init(m_region, m_max_cutset_size + 1);
cut_set.push_back(cut(id));
sorted[j++] = id;
}
}
sorted.shrink(j);
augment(sorted, cuts);
return cuts;
augment(node_ids, m_cuts);
return m_cuts;
}
void aig_cuts::augment(unsigned_vector const& ids, vector<cut_set>& cuts) {
@ -259,6 +222,12 @@ namespace sat {
else if (n.is_ite()) {
augment_ite(n, cut_set, cuts);
}
else if (n.num_children() == 0) {
augment_aig0(n, cut_set, cuts);
}
else if (n.num_children() == 1) {
augment_aig1(n, cut_set, cuts);
}
else if (n.num_children() == 2) {
augment_aig2(n, cut_set, cuts);
}
@ -299,6 +268,28 @@ namespace sat {
}
}
void aig_cuts::augment_aig0(node const& n, cut_set& cs, vector<cut_set>& cuts) {
SASSERT(n.is_and());
cut c;
cs.reset();
if (!n.sign()) {
c.m_table = 3;
}
cs.insert(c);
}
void aig_cuts::augment_aig1(node const& n, cut_set& cs, vector<cut_set>& cuts) {
SASSERT(n.is_and());
literal lit = child(n, 0);
for (auto const& a : cuts[lit.var()]) {
if (cs.size() >= m_max_cutset_size) break;
cut c;
c.set_table(a.m_table);
if (n.sign()) c.negate();
cs.insert(c);
}
}
void aig_cuts::augment_aig2(node const& n, cut_set& cs, vector<cut_set>& cuts) {
SASSERT(n.is_and() || n.is_xor());
literal l1 = child(n, 0);
@ -363,7 +354,11 @@ namespace sat {
void aig_cuts::add_var(unsigned v) {
m_aig.reserve(v + 1);
m_aig[v] = node(v);
m_cuts.reserve(v + 1);
if (!m_aig[v].is_valid()) {
m_aig[v] = node(v);
init_cut_set(v);
}
SASSERT(m_aig[v].is_valid());
}
@ -371,49 +366,45 @@ namespace sat {
TRACE("aig_simplifier", tout << head << " == " << op << " " << literal_vector(sz, args) << "\n";);
unsigned v = head.var();
m_aig.reserve(v + 1);
m_aig[v] = node(head.sign(), op, sz, m_literals.size());
unsigned offset = m_literals.size();
node n(head.sign(), op, sz, offset);
m_literals.append(sz, args);
DEBUG_CODE(
for (unsigned i = 0; i < sz; ++i) {
SASSERT(m_aig[args[i].var()].is_valid());
});
for (unsigned i = 0; i < sz; ++i) {
if (!m_aig[args[i].var()].is_valid()) {
add_var(args[i].var());
}
}
if (!m_aig[v].is_valid() || m_aig[v].is_var()) {
m_aig[v] = n;
init_cut_set(v);
}
else {
insert_aux(v, n);
}
SASSERT(m_aig[v].is_valid());
}
unsigned_vector aig_cuts::top_sort() {
unsigned_vector result;
svector<bool> visit;
visit.reserve(m_aig.size(), false);
unsigned_vector todo;
void aig_cuts::init_cut_set(unsigned id) {
node const& n = m_aig[id];
SASSERT(n.is_valid());
auto& cut_set = m_cuts[id];
cut_set.init(m_region, m_max_cutset_size + 1);
cut_set.push_back(cut(id));
}
void aig_cuts::insert_aux(unsigned v, node const& n) {
// TBD: throttle and replacement strategy
m_aux_aig.reserve(v + 1);
m_aux_aig[v].push_back(n);
}
unsigned_vector aig_cuts::filter_valid_nodes() {
unsigned id = 0;
unsigned_vector result;
for (node const& n : m_aig) {
if (n.is_valid()) todo.push_back(id);
if (n.is_valid()) result.push_back(id);
++id;
}
while (!todo.empty()) {
unsigned id = todo.back();
if (visit[id]) {
todo.pop_back();
continue;
}
bool all_visit = true;
node const& n = m_aig[id];
SASSERT(n.is_valid());
if (!n.is_var()) {
for (unsigned i = 0; i < n.num_children(); ++i) {
bool_var v = child(n, i).var();
if (!visit[v]) {
todo.push_back(v);
all_visit = false;
}
}
}
if (all_visit) {
visit[id] = true;
result.push_back(id);
todo.pop_back();
}
}
return result;
}
}

View file

@ -61,18 +61,26 @@ namespace sat {
unsigned m_max_cut_size;
unsigned m_max_cutset_size;
cut_set m_cut_set1, m_cut_set2;
vector<cut_set> m_cuts;
unsigned_vector top_sort();
void insert_aux(unsigned v, node const& n);
void init_cut_set(unsigned id);
unsigned_vector filter_valid_nodes();
void augment(unsigned_vector const& ids, vector<cut_set>& cuts);
void augment_ite(node const& n, cut_set& cs, vector<cut_set>& cuts);
void augment_aig0(node const& n, cut_set& cs, vector<cut_set>& cuts);
void augment_aig1(node const& n, cut_set& cs, vector<cut_set>& cuts);
void augment_aig2(node const& n, cut_set& cs, vector<cut_set>& cuts);
void augment_aigN(node const& n, cut_set& cs, vector<cut_set>& cuts);
public:
aig_cuts(unsigned max_cut_size, unsigned max_cutset_size);
void add_var(unsigned v);
void add_node(literal head, bool_op op, unsigned sz, literal const* args);
literal child(node const& n, unsigned idx) const { SASSERT(!n.is_var()); SASSERT(idx < n.num_children()); return m_literals[n.offset() + idx]; }
vector<cut_set> get_cuts(unsigned max_cut_size, unsigned max_cutset_size);
vector<cut_set> const & get_cuts();
};
class aig_simplifier {
@ -91,14 +99,22 @@ namespace sat {
solver& s;
stats m_stats;
config m_config;
aig_cuts m_aig_cuts;
struct report;
void clauses2aig(aig_cuts& aigc);
void aig2clauses(aig_cuts& aigc);
void clauses2aig();
void aig2clauses();
public:
aig_simplifier(solver& s) : s(s) {}
aig_simplifier(solver& s);
~aig_simplifier() {}
void operator()();
void collect_statistics(statistics& st) const;
void add_and(literal head, unsigned sz, literal const* args);
void add_or(literal head, unsigned sz, literal const* args);
void add_xor(literal head, unsigned sz, literal const* args);
void add_ite(literal head, literal c, literal t, literal e);
void add_iff(literal head, literal l1, literal l2);
};
}

View file

@ -31,6 +31,7 @@ Revision History:
#include "sat/sat_simplifier.h"
#include "sat/sat_scc.h"
#include "sat/sat_asymm_branch.h"
#include "sat/sat_aig_simplifier.h"
#include "sat/sat_iff3_finder.h"
#include "sat/sat_probing.h"
#include "sat/sat_mus.h"
@ -89,6 +90,7 @@ namespace sat {
config m_config;
stats m_stats;
scoped_ptr<extension> m_ext;
scoped_ptr<aig_simplifier> m_aig_simplifier;
parallel* m_par;
drat m_drat; // DRAT for generating proofs
clause_allocator m_cls_allocator[2];
@ -398,6 +400,7 @@ namespace sat {
bool is_incremental() const { return m_config.m_incremental; }
extension* get_extension() const override { return m_ext.get(); }
void set_extension(extension* e) override;
aig_simplifier* get_aig_simplifier() override { return m_aig_simplifier.get(); }
bool set_root(literal l, literal r);
void flush_roots();
typedef std::pair<literal, literal> bin_clause;

View file

@ -23,7 +23,10 @@ Revision History:
#include "sat/sat_types.h"
namespace sat {
class aig_simplifier;
class extension;
class solver_core {
protected:
reslimit& m_rlimit;
@ -89,6 +92,8 @@ namespace sat {
virtual extension* get_extension() const { return nullptr; }
virtual void set_extension(extension* e) { if (e) throw default_exception("optional API not supported"); }
virtual aig_simplifier* get_aig_simplifier() { return nullptr; }
// The following methods are used when converting the state from the SAT solver back
// to a set of assertions.

View file

@ -35,6 +35,7 @@ Notes:
#include "ast/for_each_expr.h"
#include "sat/tactic/goal2sat.h"
#include "sat/ba_solver.h"
#include "sat/sat_aig_simplifier.h"
#include "model/model_evaluator.h"
#include "model/model_v2_pp.h"
#include "tactic/tactic.h"
@ -53,6 +54,7 @@ struct goal2sat::imp {
ast_manager & m;
pb_util pb;
sat::ba_solver* m_ext;
sat::aig_simplifier* m_aig;
svector<frame> m_frame_stack;
svector<sat::literal> m_result_stack;
obj_map<app, sat::literal> m_cache;
@ -73,6 +75,7 @@ struct goal2sat::imp {
m(_m),
pb(m),
m_ext(nullptr),
m_aig(nullptr),
m_solver(s),
m_map(map),
m_dep2asm(dep2asm),
@ -82,6 +85,7 @@ struct goal2sat::imp {
m_is_lemma(false) {
updt_params(p);
m_true = sat::null_literal;
m_aig = s.get_aig_simplifier();
}
void updt_params(params_ref const & p) {
@ -252,6 +256,9 @@ struct goal2sat::imp {
sat::literal l(k, false);
m_cache.insert(t, l);
sat::literal * lits = m_result_stack.end() - num;
if (m_aig) m_aig->add_or(l, num, lits);
for (unsigned i = 0; i < num; i++) {
mk_clause(~lits[i], l);
}
@ -290,8 +297,11 @@ struct goal2sat::imp {
sat::bool_var k = m_solver.add_var(false);
sat::literal l(k, false);
m_cache.insert(t, l);
// l => /\ lits
sat::literal * lits = m_result_stack.end() - num;
if (m_aig) m_aig->add_and(l, num, lits);
// l => /\ lits
for (unsigned i = 0; i < num; i++) {
mk_clause(~l, lits[i]);
}
@ -341,6 +351,7 @@ struct goal2sat::imp {
mk_clause(~t, ~e, l, false);
mk_clause(t, e, ~l, false);
}
if (m_aig) m_aig->add_ite(l, c, t, e);
m_result_stack.shrink(sz-3);
if (sign)
l.neg();
@ -374,6 +385,7 @@ struct goal2sat::imp {
mk_clause(~l, ~l1, l2);
mk_clause(l, l1, l2);
mk_clause(l, ~l1, ~l2);
if (m_aig) m_aig->add_iff(l, l1, l2);
m_result_stack.shrink(sz-2);
if (sign)
l.neg();
@ -400,6 +412,7 @@ struct goal2sat::imp {
}
ensure_extension();
m_ext->add_xr(lits);
if (m_aig) m_aig->add_xor(~lits.back(), lits.size() - 1, lits.c_ptr() + 1);
sat::literal lit(v, sign);
if (root) {
m_result_stack.reset();
@ -634,7 +647,7 @@ struct goal2sat::imp {
m_ext = alloc(sat::ba_solver);
m_solver.set_extension(m_ext);
}
}
}
}
void convert(app * t, bool root, bool sign) {