3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-24 01:25:31 +00:00

updates to sorting networks

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
This commit is contained in:
Nikolaj Bjorner 2017-09-23 22:36:19 -05:00
parent 3c4ac9aee5
commit edb3569599
18 changed files with 2070 additions and 170 deletions

View file

@ -87,6 +87,7 @@ namespace sat {
m_local_search_threads = p.local_search_threads();
m_lookahead_simplify = p.lookahead_simplify();
m_lookahead_search = p.lookahead_search();
m_lookahead_reward = p.lookahead_reward();
m_ccc = p.ccc();
// These parameters are not exposed
@ -163,7 +164,7 @@ namespace sat {
m_pb_solver = PB_SOLVER;
}
else {
throw sat_param_exception("invalid PB solver");
throw sat_param_exception("invalid PB solver: solver, totalizer, circuit, sorting");
}
}

View file

@ -75,6 +75,7 @@ namespace sat {
bool m_local_search;
bool m_lookahead_search;
bool m_lookahead_simplify;
symbol m_lookahead_reward;
bool m_ccc;
unsigned m_simplify_mult1;

View file

@ -41,7 +41,6 @@ namespace sat {
}
}
void lookahead::flip_prefix() {
if (m_trail_lim.size() < 64) {
uint64 mask = (1ull << m_trail_lim.size());
@ -228,11 +227,6 @@ namespace sat {
if (is_sat()) {
return false;
}
if (newbies) {
enable_trace("sat");
TRACE("sat", display(tout););
TRACE("sat", tout << sum << "\n";);
}
}
SASSERT(!m_candidates.empty());
// cut number of candidates down to max_num_cand.
@ -292,9 +286,8 @@ namespace sat {
double lookahead::init_candidates(unsigned level, bool newbies) {
m_candidates.reset();
double sum = 0;
for (bool_var const* it = m_freevars.begin(), * end = m_freevars.end(); it != end; ++it) {
SASSERT(is_undef(*it));
bool_var x = *it;
for (bool_var x : m_freevars) {
SASSERT(is_undef(x));
if (!m_select_lookahead_vars.empty()) {
if (m_select_lookahead_vars.contains(x)) {
m_candidates.push_back(candidate(x, m_rating[x]));
@ -333,20 +326,20 @@ namespace sat {
}
bool lookahead::is_sat() const {
for (bool_var const* it = m_freevars.begin(), * end = m_freevars.end(); it != end; ++it) {
literal l(*it, false);
for (bool_var x : m_freevars) {
literal l(x, false);
literal_vector const& lits1 = m_binary[l.index()];
for (unsigned i = 0; i < lits1.size(); ++i) {
if (!is_true(lits1[i])) {
TRACE("sat", tout << l << " " << lits1[i] << "\n";);
for (literal lit1 : lits1) {
if (!is_true(lit1)) {
TRACE("sat", tout << l << " " << lit1 << "\n";);
return false;
}
}
l.neg();
literal_vector const& lits2 = m_binary[l.index()];
for (unsigned i = 0; i < lits2.size(); ++i) {
if (!is_true(lits2[i])) {
TRACE("sat", tout << l << " " << lits2[i] << "\n";);
for (literal lit2 : lits2) {
if (!is_true(lit2)) {
TRACE("sat", tout << l << " " << lit2 << "\n";);
return false;
}
}
@ -389,12 +382,114 @@ namespace sat {
break;
}
case heule_schur_reward:
heule_schur_scores();
break;
case heule_unit_reward:
heule_unit_scores();
break;
case unit_literal_reward:
heule_schur_scores();
break;
}
}
static unsigned counter = 0;
void lookahead::heule_schur_scores() {
if (counter % 10 != 0) return;
++counter;
for (bool_var x : m_freevars) {
literal l(x, false);
m_rating[l.var()] = heule_schur_score(l) * heule_schur_score(~l);
}
}
double lookahead::heule_schur_score(literal l) {
double sum = 0;
for (literal lit : m_binary[l.index()]) {
if (is_undef(lit)) sum += literal_occs(lit) / 4.0;
}
watch_list& wlist = m_watches[l.index()];
for (auto & w : wlist) {
switch (w.get_kind()) {
case watched::BINARY:
UNREACHABLE();
break;
case watched::TERNARY: {
literal l1 = w.get_literal1();
literal l2 = w.get_literal2();
if (is_undef(l1) && is_undef(l2)) {
sum += (literal_occs(l1) + literal_occs(l2)) / 8.0;
}
break;
}
case watched::CLAUSE: {
clause_offset cls_off = w.get_clause_offset();
clause & c = *(m_cls_allocator.get_clause(cls_off));
unsigned sz = 0;
double to_add = 0;
for (literal lit : c) {
if (is_undef(lit) && lit != ~l) {
to_add += literal_occs(lit);
++sz;
}
}
sum += pow(0.5, sz) * to_add / sz;
break;
}
default:
break;
}
}
return sum;
}
void lookahead::heule_unit_scores() {
if (counter % 10 != 0) return;
++counter;
for (bool_var x : m_freevars) {
literal l(x, false);
m_rating[l.var()] = heule_unit_score(l) * heule_unit_score(~l);
}
}
double lookahead::heule_unit_score(literal l) {
double sum = 0;
for (literal lit : m_binary[l.index()]) {
if (is_undef(lit)) sum += 0.25;
}
watch_list& wlist = m_watches[l.index()];
for (auto & w : wlist) {
switch (w.get_kind()) {
case watched::BINARY:
UNREACHABLE();
break;
case watched::TERNARY: {
literal l1 = w.get_literal1();
literal l2 = w.get_literal2();
if (is_undef(l1) && is_undef(l2)) {
sum += 0.25;
}
break;
}
case watched::CLAUSE: {
clause_offset cls_off = w.get_clause_offset();
clause & c = *(m_cls_allocator.get_clause(cls_off));
unsigned sz = 0;
for (literal lit : c) {
if (is_undef(lit) && lit != ~l) {
++sz;
}
}
sum += pow(0.5, sz);
break;
}
default:
break;
}
}
return sum;
}
void lookahead::ensure_H(unsigned level) {
while (m_H.size() <= level) {
m_H.push_back(svector<double>());
@ -404,16 +499,16 @@ namespace sat {
void lookahead::h_scores(svector<double>& h, svector<double>& hp) {
double sum = 0;
for (bool_var const* it = m_freevars.begin(), * end = m_freevars.end(); it != end; ++it) {
literal l(*it, false);
for (bool_var x : m_freevars) {
literal l(x, false);
sum += h[l.index()] + h[(~l).index()];
}
if (sum == 0) sum = 0.0001;
double factor = 2 * m_freevars.size() / sum;
double sqfactor = factor * factor;
double afactor = factor * m_config.m_alpha;
for (bool_var const* it = m_freevars.begin(), * end = m_freevars.end(); it != end; ++it) {
literal l(*it, false);
for (bool_var x : m_freevars) {
literal l(x, false);
double pos = l_score(l, h, factor, sqfactor, afactor);
double neg = l_score(~l, h, factor, sqfactor, afactor);
hp[l.index()] = pos;
@ -425,28 +520,25 @@ namespace sat {
double lookahead::l_score(literal l, svector<double> const& h, double factor, double sqfactor, double afactor) {
double sum = 0, tsum = 0;
literal_vector::iterator it = m_binary[l.index()].begin(), end = m_binary[l.index()].end();
for (; it != end; ++it) {
bool_var v = it->var();
if (is_undef(*it)) sum += h[it->index()];
// if (m_freevars.contains(it->var())) sum += h[it->index()];
for (literal lit : m_binary[l.index()]) {
if (is_undef(lit)) sum += h[lit.index()];
// if (m_freevars.contains(lit.var())) sum += h[lit.index()];
}
watch_list& wlist = m_watches[l.index()];
watch_list::iterator wit = wlist.begin(), wend = wlist.end();
for (; wit != wend; ++wit) {
switch (wit->get_kind()) {
for (auto & w : wlist) {
switch (w.get_kind()) {
case watched::BINARY:
UNREACHABLE();
break;
case watched::TERNARY: {
literal l1 = wit->get_literal1();
literal l2 = wit->get_literal2();
literal l1 = w.get_literal1();
literal l2 = w.get_literal2();
// if (is_undef(l1) && is_undef(l2))
tsum += h[l1.index()] * h[l2.index()];
break;
}
case watched::CLAUSE: {
clause_offset cls_off = wit->get_clause_offset();
clause_offset cls_off = w.get_clause_offset();
clause & c = *(m_cls_allocator.get_clause(cls_off));
// approximation compared to ternary clause case:
// we pick two other literals from the clause.
@ -865,8 +957,6 @@ namespace sat {
copy_clauses(m_s.m_clauses);
copy_clauses(m_s.m_learned);
// m_config.m_use_ternary_reward &= !m_s.m_ext;
// copy units
unsigned trail_sz = m_s.init_trail_size();
for (unsigned i = 0; i < trail_sz; ++i) {
@ -995,15 +1085,17 @@ namespace sat {
return unsat;
}
void lookahead::push_lookahead1(literal lit, unsigned level) {
unsigned lookahead::push_lookahead1(literal lit, unsigned level) {
SASSERT(m_search_mode == lookahead_mode::searching);
m_search_mode = lookahead_mode::lookahead1;
scoped_level _sl(*this, level);
unsigned old_sz = m_trail.size();
assign(lit);
propagate();
return m_trail.size() - old_sz;
}
void lookahead::pop_lookahead1(literal lit) {
void lookahead::pop_lookahead1(literal lit, unsigned num_units) {
bool unsat = inconsistent();
SASSERT(m_search_mode == lookahead_mode::lookahead1);
m_inconsistent = false;
@ -1025,8 +1117,15 @@ namespace sat {
}
m_stats.m_windfall_binaries += m_wstack.size();
}
if (m_config.m_reward_type == unit_literal_reward) {
m_lookahead_reward += m_wstack.size();
switch (m_config.m_reward_type) {
case unit_literal_reward:
m_lookahead_reward += num_units;
break;
case heule_unit_reward:
case heule_schur_reward:
break;
default:
break;
}
m_wstack.reset();
}
@ -1226,6 +1325,9 @@ namespace sat {
case heule_schur_reward:
m_lookahead_reward += (literal_occs(l1) + literal_occs(l2)) / 8.0;
break;
case heule_unit_reward:
m_lookahead_reward += 0.25;
break;
case unit_literal_reward:
break;
}
@ -1253,6 +1355,9 @@ namespace sat {
m_lookahead_reward += pow(0.5, sz) * to_add / sz;
break;
}
case heule_unit_reward:
m_lookahead_reward += pow(0.5, sz);
break;
case ternary_reward:
m_lookahead_reward = (double)0.001;
break;
@ -1326,11 +1431,13 @@ namespace sat {
IF_VERBOSE(30, verbose_stream() << scope_lvl() << " " << lit << " binary: " << m_binary_trail.size() << " trail: " << m_trail_lim.back() << "\n";);
}
TRACE("sat", tout << "lookahead: " << lit << " @ " << m_lookahead[i].m_offset << "\n";);
unsigned old_trail_sz = m_trail.size();
reset_lookahead_reward(lit);
push_lookahead1(lit, level);
if (!first) do_double(lit, base);
bool unsat = inconsistent();
pop_lookahead1(lit);
bool unsat = inconsistent();
unsigned num_units = m_trail.size() - old_trail_sz;
pop_lookahead1(lit, num_units);
if (unsat) {
TRACE("sat", tout << "backtracking and settting " << ~lit << "\n";);
reset_lookahead_reward();
@ -1407,6 +1514,7 @@ namespace sat {
switch (m_config.m_reward_type) {
case ternary_reward: return l + r + (1 << 10) * l * r;
case heule_schur_reward: return l * r;
case heule_unit_reward: return l * r;
case unit_literal_reward: return l * r;
default: UNREACHABLE(); return l * r;
}
@ -1489,7 +1597,7 @@ namespace sat {
}
}
}
else {
else {
inc_lookahead_reward(l, m_lookahead_reward);
}
}
@ -1625,7 +1733,13 @@ namespace sat {
}
TRACE("sat", tout << "choose: " << l << " " << trail << "\n";);
++m_stats.m_decisions;
IF_VERBOSE(1, verbose_stream() << "select " << pp_prefix(m_prefix, m_trail_lim.size()) << ": " << l << " " << m_trail.size() << "\n";);
IF_VERBOSE(1, printf("\r");
std::stringstream strm;
strm << pp_prefix(m_prefix, m_trail_lim.size());
for (unsigned i = 0; i < 50; ++i) strm << " ";
printf(strm.str().c_str());
fflush(stdout);
);
push(l, c_fixed_truth);
trail.push_back(l);
SASSERT(inconsistent() || !is_unsat());
@ -1851,16 +1965,20 @@ namespace sat {
m_lookahead.reset();
}
std::ostream& lookahead::display(std::ostream& out) const {
std::ostream& lookahead::display_summary(std::ostream& out) const {
out << "Prefix: " << pp_prefix(m_prefix, m_trail_lim.size()) << "\n";
out << "Level: " << m_level << "\n";
out << "Free vars: " << m_freevars.size() << "\n";
return out;
}
std::ostream& lookahead::display(std::ostream& out) const {
display_summary(out);
display_values(out);
display_binary(out);
display_clauses(out);
out << "free vars: ";
for (bool_var const* it = m_freevars.begin(), * end = m_freevars.end(); it != end; ++it) {
out << *it << " ";
}
for (bool_var b : m_freevars) out << b << " ";
out << "\n";
for (unsigned i = 0; i < m_watches.size(); ++i) {
watch_list const& wl = m_watches[i];
@ -1879,6 +1997,24 @@ namespace sat {
return m_model;
}
void lookahead::init_config() {
if (m_s.m_config.m_lookahead_reward == symbol("hs")) {
m_config.m_reward_type = heule_schur_reward;
}
else if (m_s.m_config.m_lookahead_reward == symbol("heuleu")) {
m_config.m_reward_type = heule_unit_reward;
}
else if (m_s.m_config.m_lookahead_reward == symbol("ternary")) {
m_config.m_reward_type = ternary_reward;
}
else if (m_s.m_config.m_lookahead_reward == symbol("unit")) {
m_config.m_reward_type = unit_literal_reward;
}
else {
warning_msg("Reward type not recognized");
}
}
void lookahead::collect_statistics(statistics& st) const {
st.update("lh bool var", m_vprefix.size());
st.update("lh clauses", m_clauses.size());

View file

@ -69,7 +69,8 @@ namespace sat {
enum reward_t {
ternary_reward,
unit_literal_reward,
heule_schur_reward
heule_schur_reward,
heule_unit_reward
};
struct config {
@ -277,6 +278,10 @@ namespace sat {
void init_pre_selection(unsigned level);
void ensure_H(unsigned level);
void h_scores(svector<double>& h, svector<double>& hp);
void heule_schur_scores();
double heule_schur_score(literal l);
void heule_unit_scores();
double heule_unit_score(literal l);
double l_score(literal l, svector<double> const& h, double factor, double sqfactor, double afactor);
// ------------------------------------
@ -393,8 +398,8 @@ namespace sat {
void push(literal lit, unsigned level);
void pop();
bool push_lookahead2(literal lit, unsigned level);
void push_lookahead1(literal lit, unsigned level);
void pop_lookahead1(literal lit);
unsigned push_lookahead1(literal lit, unsigned level);
void pop_lookahead1(literal lit, unsigned num_units);
double mix_diff(double l, double r) const;
clause const& get_clause(watch_list::iterator it) const;
bool is_nary_propagation(clause const& c, literal l) const;
@ -444,6 +449,8 @@ namespace sat {
void init_search();
void checkpoint();
void init_config();
public:
lookahead(solver& s) :
m_s(s),
@ -453,6 +460,7 @@ namespace sat {
m_level(2),
m_prefix(0) {
m_s.rlimit().push_child(&m_rlimit);
init_config();
}
~lookahead() {
@ -488,6 +496,7 @@ namespace sat {
void scc();
std::ostream& display(std::ostream& out) const;
std::ostream& display_summary(std::ostream& out) const;
model const& get_model();
void collect_statistics(statistics& st) const;

View file

@ -31,9 +31,11 @@ def_module_params('sat',
('cardinality.solver', BOOL, False, 'use cardinality solver'),
('pb.solver', SYMBOL, 'circuit', 'method for handling Pseudo-Boolean constraints: circuit (arithmetical circuit), sorting (sorting circuit), totalizer (use totalizer encoding), solver (use SMT solver)'),
('xor.solver', BOOL, False, 'use xor solver'),
('atmost1_encoding', SYMBOL, 'grouped', 'encoding used for at-most-1 constraints grouped, bimander, ordered'),
('local_search_threads', UINT, 0, 'number of local search threads to find satisfiable solution'),
('local_search', BOOL, False, 'use local search instead of CDCL'),
('lookahead_search', BOOL, False, 'use lookahead solver'),
('lookahead_simplify', BOOL, False, 'use lookahead solver during simplification'),
('lookahead.reward', SYMBOL, 'heuleu', 'select lookahead heuristic: ternary, hs (Heule Schur), heuleu (Heule Unit), or unit'),
('ccc', BOOL, False, 'use Concurrent Cube and Conquer solver')
))