mirror of
https://github.com/Z3Prover/z3
synced 2025-04-24 01:25:31 +00:00
updates to sorting networks
Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
This commit is contained in:
parent
3c4ac9aee5
commit
edb3569599
18 changed files with 2070 additions and 170 deletions
|
@ -87,6 +87,7 @@ namespace sat {
|
|||
m_local_search_threads = p.local_search_threads();
|
||||
m_lookahead_simplify = p.lookahead_simplify();
|
||||
m_lookahead_search = p.lookahead_search();
|
||||
m_lookahead_reward = p.lookahead_reward();
|
||||
m_ccc = p.ccc();
|
||||
|
||||
// These parameters are not exposed
|
||||
|
@ -163,7 +164,7 @@ namespace sat {
|
|||
m_pb_solver = PB_SOLVER;
|
||||
}
|
||||
else {
|
||||
throw sat_param_exception("invalid PB solver");
|
||||
throw sat_param_exception("invalid PB solver: solver, totalizer, circuit, sorting");
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -75,6 +75,7 @@ namespace sat {
|
|||
bool m_local_search;
|
||||
bool m_lookahead_search;
|
||||
bool m_lookahead_simplify;
|
||||
symbol m_lookahead_reward;
|
||||
bool m_ccc;
|
||||
|
||||
unsigned m_simplify_mult1;
|
||||
|
|
|
@ -41,7 +41,6 @@ namespace sat {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
void lookahead::flip_prefix() {
|
||||
if (m_trail_lim.size() < 64) {
|
||||
uint64 mask = (1ull << m_trail_lim.size());
|
||||
|
@ -228,11 +227,6 @@ namespace sat {
|
|||
if (is_sat()) {
|
||||
return false;
|
||||
}
|
||||
if (newbies) {
|
||||
enable_trace("sat");
|
||||
TRACE("sat", display(tout););
|
||||
TRACE("sat", tout << sum << "\n";);
|
||||
}
|
||||
}
|
||||
SASSERT(!m_candidates.empty());
|
||||
// cut number of candidates down to max_num_cand.
|
||||
|
@ -292,9 +286,8 @@ namespace sat {
|
|||
double lookahead::init_candidates(unsigned level, bool newbies) {
|
||||
m_candidates.reset();
|
||||
double sum = 0;
|
||||
for (bool_var const* it = m_freevars.begin(), * end = m_freevars.end(); it != end; ++it) {
|
||||
SASSERT(is_undef(*it));
|
||||
bool_var x = *it;
|
||||
for (bool_var x : m_freevars) {
|
||||
SASSERT(is_undef(x));
|
||||
if (!m_select_lookahead_vars.empty()) {
|
||||
if (m_select_lookahead_vars.contains(x)) {
|
||||
m_candidates.push_back(candidate(x, m_rating[x]));
|
||||
|
@ -333,20 +326,20 @@ namespace sat {
|
|||
}
|
||||
|
||||
bool lookahead::is_sat() const {
|
||||
for (bool_var const* it = m_freevars.begin(), * end = m_freevars.end(); it != end; ++it) {
|
||||
literal l(*it, false);
|
||||
for (bool_var x : m_freevars) {
|
||||
literal l(x, false);
|
||||
literal_vector const& lits1 = m_binary[l.index()];
|
||||
for (unsigned i = 0; i < lits1.size(); ++i) {
|
||||
if (!is_true(lits1[i])) {
|
||||
TRACE("sat", tout << l << " " << lits1[i] << "\n";);
|
||||
for (literal lit1 : lits1) {
|
||||
if (!is_true(lit1)) {
|
||||
TRACE("sat", tout << l << " " << lit1 << "\n";);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
l.neg();
|
||||
literal_vector const& lits2 = m_binary[l.index()];
|
||||
for (unsigned i = 0; i < lits2.size(); ++i) {
|
||||
if (!is_true(lits2[i])) {
|
||||
TRACE("sat", tout << l << " " << lits2[i] << "\n";);
|
||||
for (literal lit2 : lits2) {
|
||||
if (!is_true(lit2)) {
|
||||
TRACE("sat", tout << l << " " << lit2 << "\n";);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
@ -389,12 +382,114 @@ namespace sat {
|
|||
break;
|
||||
}
|
||||
case heule_schur_reward:
|
||||
heule_schur_scores();
|
||||
break;
|
||||
case heule_unit_reward:
|
||||
heule_unit_scores();
|
||||
break;
|
||||
case unit_literal_reward:
|
||||
heule_schur_scores();
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
static unsigned counter = 0;
|
||||
void lookahead::heule_schur_scores() {
|
||||
if (counter % 10 != 0) return;
|
||||
++counter;
|
||||
for (bool_var x : m_freevars) {
|
||||
literal l(x, false);
|
||||
m_rating[l.var()] = heule_schur_score(l) * heule_schur_score(~l);
|
||||
}
|
||||
}
|
||||
|
||||
double lookahead::heule_schur_score(literal l) {
|
||||
double sum = 0;
|
||||
for (literal lit : m_binary[l.index()]) {
|
||||
if (is_undef(lit)) sum += literal_occs(lit) / 4.0;
|
||||
}
|
||||
watch_list& wlist = m_watches[l.index()];
|
||||
for (auto & w : wlist) {
|
||||
switch (w.get_kind()) {
|
||||
case watched::BINARY:
|
||||
UNREACHABLE();
|
||||
break;
|
||||
case watched::TERNARY: {
|
||||
literal l1 = w.get_literal1();
|
||||
literal l2 = w.get_literal2();
|
||||
if (is_undef(l1) && is_undef(l2)) {
|
||||
sum += (literal_occs(l1) + literal_occs(l2)) / 8.0;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case watched::CLAUSE: {
|
||||
clause_offset cls_off = w.get_clause_offset();
|
||||
clause & c = *(m_cls_allocator.get_clause(cls_off));
|
||||
unsigned sz = 0;
|
||||
double to_add = 0;
|
||||
for (literal lit : c) {
|
||||
if (is_undef(lit) && lit != ~l) {
|
||||
to_add += literal_occs(lit);
|
||||
++sz;
|
||||
}
|
||||
}
|
||||
sum += pow(0.5, sz) * to_add / sz;
|
||||
break;
|
||||
}
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
return sum;
|
||||
}
|
||||
|
||||
void lookahead::heule_unit_scores() {
|
||||
if (counter % 10 != 0) return;
|
||||
++counter;
|
||||
for (bool_var x : m_freevars) {
|
||||
literal l(x, false);
|
||||
m_rating[l.var()] = heule_unit_score(l) * heule_unit_score(~l);
|
||||
}
|
||||
}
|
||||
|
||||
double lookahead::heule_unit_score(literal l) {
|
||||
double sum = 0;
|
||||
for (literal lit : m_binary[l.index()]) {
|
||||
if (is_undef(lit)) sum += 0.25;
|
||||
}
|
||||
watch_list& wlist = m_watches[l.index()];
|
||||
for (auto & w : wlist) {
|
||||
switch (w.get_kind()) {
|
||||
case watched::BINARY:
|
||||
UNREACHABLE();
|
||||
break;
|
||||
case watched::TERNARY: {
|
||||
literal l1 = w.get_literal1();
|
||||
literal l2 = w.get_literal2();
|
||||
if (is_undef(l1) && is_undef(l2)) {
|
||||
sum += 0.25;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case watched::CLAUSE: {
|
||||
clause_offset cls_off = w.get_clause_offset();
|
||||
clause & c = *(m_cls_allocator.get_clause(cls_off));
|
||||
unsigned sz = 0;
|
||||
for (literal lit : c) {
|
||||
if (is_undef(lit) && lit != ~l) {
|
||||
++sz;
|
||||
}
|
||||
}
|
||||
sum += pow(0.5, sz);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
return sum;
|
||||
}
|
||||
|
||||
void lookahead::ensure_H(unsigned level) {
|
||||
while (m_H.size() <= level) {
|
||||
m_H.push_back(svector<double>());
|
||||
|
@ -404,16 +499,16 @@ namespace sat {
|
|||
|
||||
void lookahead::h_scores(svector<double>& h, svector<double>& hp) {
|
||||
double sum = 0;
|
||||
for (bool_var const* it = m_freevars.begin(), * end = m_freevars.end(); it != end; ++it) {
|
||||
literal l(*it, false);
|
||||
for (bool_var x : m_freevars) {
|
||||
literal l(x, false);
|
||||
sum += h[l.index()] + h[(~l).index()];
|
||||
}
|
||||
if (sum == 0) sum = 0.0001;
|
||||
double factor = 2 * m_freevars.size() / sum;
|
||||
double sqfactor = factor * factor;
|
||||
double afactor = factor * m_config.m_alpha;
|
||||
for (bool_var const* it = m_freevars.begin(), * end = m_freevars.end(); it != end; ++it) {
|
||||
literal l(*it, false);
|
||||
for (bool_var x : m_freevars) {
|
||||
literal l(x, false);
|
||||
double pos = l_score(l, h, factor, sqfactor, afactor);
|
||||
double neg = l_score(~l, h, factor, sqfactor, afactor);
|
||||
hp[l.index()] = pos;
|
||||
|
@ -425,28 +520,25 @@ namespace sat {
|
|||
|
||||
double lookahead::l_score(literal l, svector<double> const& h, double factor, double sqfactor, double afactor) {
|
||||
double sum = 0, tsum = 0;
|
||||
literal_vector::iterator it = m_binary[l.index()].begin(), end = m_binary[l.index()].end();
|
||||
for (; it != end; ++it) {
|
||||
bool_var v = it->var();
|
||||
if (is_undef(*it)) sum += h[it->index()];
|
||||
// if (m_freevars.contains(it->var())) sum += h[it->index()];
|
||||
for (literal lit : m_binary[l.index()]) {
|
||||
if (is_undef(lit)) sum += h[lit.index()];
|
||||
// if (m_freevars.contains(lit.var())) sum += h[lit.index()];
|
||||
}
|
||||
watch_list& wlist = m_watches[l.index()];
|
||||
watch_list::iterator wit = wlist.begin(), wend = wlist.end();
|
||||
for (; wit != wend; ++wit) {
|
||||
switch (wit->get_kind()) {
|
||||
for (auto & w : wlist) {
|
||||
switch (w.get_kind()) {
|
||||
case watched::BINARY:
|
||||
UNREACHABLE();
|
||||
break;
|
||||
case watched::TERNARY: {
|
||||
literal l1 = wit->get_literal1();
|
||||
literal l2 = wit->get_literal2();
|
||||
literal l1 = w.get_literal1();
|
||||
literal l2 = w.get_literal2();
|
||||
// if (is_undef(l1) && is_undef(l2))
|
||||
tsum += h[l1.index()] * h[l2.index()];
|
||||
break;
|
||||
}
|
||||
case watched::CLAUSE: {
|
||||
clause_offset cls_off = wit->get_clause_offset();
|
||||
clause_offset cls_off = w.get_clause_offset();
|
||||
clause & c = *(m_cls_allocator.get_clause(cls_off));
|
||||
// approximation compared to ternary clause case:
|
||||
// we pick two other literals from the clause.
|
||||
|
@ -865,8 +957,6 @@ namespace sat {
|
|||
copy_clauses(m_s.m_clauses);
|
||||
copy_clauses(m_s.m_learned);
|
||||
|
||||
// m_config.m_use_ternary_reward &= !m_s.m_ext;
|
||||
|
||||
// copy units
|
||||
unsigned trail_sz = m_s.init_trail_size();
|
||||
for (unsigned i = 0; i < trail_sz; ++i) {
|
||||
|
@ -995,15 +1085,17 @@ namespace sat {
|
|||
return unsat;
|
||||
}
|
||||
|
||||
void lookahead::push_lookahead1(literal lit, unsigned level) {
|
||||
unsigned lookahead::push_lookahead1(literal lit, unsigned level) {
|
||||
SASSERT(m_search_mode == lookahead_mode::searching);
|
||||
m_search_mode = lookahead_mode::lookahead1;
|
||||
scoped_level _sl(*this, level);
|
||||
unsigned old_sz = m_trail.size();
|
||||
assign(lit);
|
||||
propagate();
|
||||
return m_trail.size() - old_sz;
|
||||
}
|
||||
|
||||
void lookahead::pop_lookahead1(literal lit) {
|
||||
void lookahead::pop_lookahead1(literal lit, unsigned num_units) {
|
||||
bool unsat = inconsistent();
|
||||
SASSERT(m_search_mode == lookahead_mode::lookahead1);
|
||||
m_inconsistent = false;
|
||||
|
@ -1025,8 +1117,15 @@ namespace sat {
|
|||
}
|
||||
m_stats.m_windfall_binaries += m_wstack.size();
|
||||
}
|
||||
if (m_config.m_reward_type == unit_literal_reward) {
|
||||
m_lookahead_reward += m_wstack.size();
|
||||
switch (m_config.m_reward_type) {
|
||||
case unit_literal_reward:
|
||||
m_lookahead_reward += num_units;
|
||||
break;
|
||||
case heule_unit_reward:
|
||||
case heule_schur_reward:
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
m_wstack.reset();
|
||||
}
|
||||
|
@ -1226,6 +1325,9 @@ namespace sat {
|
|||
case heule_schur_reward:
|
||||
m_lookahead_reward += (literal_occs(l1) + literal_occs(l2)) / 8.0;
|
||||
break;
|
||||
case heule_unit_reward:
|
||||
m_lookahead_reward += 0.25;
|
||||
break;
|
||||
case unit_literal_reward:
|
||||
break;
|
||||
}
|
||||
|
@ -1253,6 +1355,9 @@ namespace sat {
|
|||
m_lookahead_reward += pow(0.5, sz) * to_add / sz;
|
||||
break;
|
||||
}
|
||||
case heule_unit_reward:
|
||||
m_lookahead_reward += pow(0.5, sz);
|
||||
break;
|
||||
case ternary_reward:
|
||||
m_lookahead_reward = (double)0.001;
|
||||
break;
|
||||
|
@ -1326,11 +1431,13 @@ namespace sat {
|
|||
IF_VERBOSE(30, verbose_stream() << scope_lvl() << " " << lit << " binary: " << m_binary_trail.size() << " trail: " << m_trail_lim.back() << "\n";);
|
||||
}
|
||||
TRACE("sat", tout << "lookahead: " << lit << " @ " << m_lookahead[i].m_offset << "\n";);
|
||||
unsigned old_trail_sz = m_trail.size();
|
||||
reset_lookahead_reward(lit);
|
||||
push_lookahead1(lit, level);
|
||||
if (!first) do_double(lit, base);
|
||||
bool unsat = inconsistent();
|
||||
pop_lookahead1(lit);
|
||||
bool unsat = inconsistent();
|
||||
unsigned num_units = m_trail.size() - old_trail_sz;
|
||||
pop_lookahead1(lit, num_units);
|
||||
if (unsat) {
|
||||
TRACE("sat", tout << "backtracking and settting " << ~lit << "\n";);
|
||||
reset_lookahead_reward();
|
||||
|
@ -1407,6 +1514,7 @@ namespace sat {
|
|||
switch (m_config.m_reward_type) {
|
||||
case ternary_reward: return l + r + (1 << 10) * l * r;
|
||||
case heule_schur_reward: return l * r;
|
||||
case heule_unit_reward: return l * r;
|
||||
case unit_literal_reward: return l * r;
|
||||
default: UNREACHABLE(); return l * r;
|
||||
}
|
||||
|
@ -1489,7 +1597,7 @@ namespace sat {
|
|||
}
|
||||
}
|
||||
}
|
||||
else {
|
||||
else {
|
||||
inc_lookahead_reward(l, m_lookahead_reward);
|
||||
}
|
||||
}
|
||||
|
@ -1625,7 +1733,13 @@ namespace sat {
|
|||
}
|
||||
TRACE("sat", tout << "choose: " << l << " " << trail << "\n";);
|
||||
++m_stats.m_decisions;
|
||||
IF_VERBOSE(1, verbose_stream() << "select " << pp_prefix(m_prefix, m_trail_lim.size()) << ": " << l << " " << m_trail.size() << "\n";);
|
||||
IF_VERBOSE(1, printf("\r");
|
||||
std::stringstream strm;
|
||||
strm << pp_prefix(m_prefix, m_trail_lim.size());
|
||||
for (unsigned i = 0; i < 50; ++i) strm << " ";
|
||||
printf(strm.str().c_str());
|
||||
fflush(stdout);
|
||||
);
|
||||
push(l, c_fixed_truth);
|
||||
trail.push_back(l);
|
||||
SASSERT(inconsistent() || !is_unsat());
|
||||
|
@ -1851,16 +1965,20 @@ namespace sat {
|
|||
m_lookahead.reset();
|
||||
}
|
||||
|
||||
std::ostream& lookahead::display(std::ostream& out) const {
|
||||
std::ostream& lookahead::display_summary(std::ostream& out) const {
|
||||
out << "Prefix: " << pp_prefix(m_prefix, m_trail_lim.size()) << "\n";
|
||||
out << "Level: " << m_level << "\n";
|
||||
out << "Free vars: " << m_freevars.size() << "\n";
|
||||
return out;
|
||||
}
|
||||
|
||||
std::ostream& lookahead::display(std::ostream& out) const {
|
||||
display_summary(out);
|
||||
display_values(out);
|
||||
display_binary(out);
|
||||
display_clauses(out);
|
||||
out << "free vars: ";
|
||||
for (bool_var const* it = m_freevars.begin(), * end = m_freevars.end(); it != end; ++it) {
|
||||
out << *it << " ";
|
||||
}
|
||||
for (bool_var b : m_freevars) out << b << " ";
|
||||
out << "\n";
|
||||
for (unsigned i = 0; i < m_watches.size(); ++i) {
|
||||
watch_list const& wl = m_watches[i];
|
||||
|
@ -1879,6 +1997,24 @@ namespace sat {
|
|||
return m_model;
|
||||
}
|
||||
|
||||
void lookahead::init_config() {
|
||||
if (m_s.m_config.m_lookahead_reward == symbol("hs")) {
|
||||
m_config.m_reward_type = heule_schur_reward;
|
||||
}
|
||||
else if (m_s.m_config.m_lookahead_reward == symbol("heuleu")) {
|
||||
m_config.m_reward_type = heule_unit_reward;
|
||||
}
|
||||
else if (m_s.m_config.m_lookahead_reward == symbol("ternary")) {
|
||||
m_config.m_reward_type = ternary_reward;
|
||||
}
|
||||
else if (m_s.m_config.m_lookahead_reward == symbol("unit")) {
|
||||
m_config.m_reward_type = unit_literal_reward;
|
||||
}
|
||||
else {
|
||||
warning_msg("Reward type not recognized");
|
||||
}
|
||||
}
|
||||
|
||||
void lookahead::collect_statistics(statistics& st) const {
|
||||
st.update("lh bool var", m_vprefix.size());
|
||||
st.update("lh clauses", m_clauses.size());
|
||||
|
|
|
@ -69,7 +69,8 @@ namespace sat {
|
|||
enum reward_t {
|
||||
ternary_reward,
|
||||
unit_literal_reward,
|
||||
heule_schur_reward
|
||||
heule_schur_reward,
|
||||
heule_unit_reward
|
||||
};
|
||||
|
||||
struct config {
|
||||
|
@ -277,6 +278,10 @@ namespace sat {
|
|||
void init_pre_selection(unsigned level);
|
||||
void ensure_H(unsigned level);
|
||||
void h_scores(svector<double>& h, svector<double>& hp);
|
||||
void heule_schur_scores();
|
||||
double heule_schur_score(literal l);
|
||||
void heule_unit_scores();
|
||||
double heule_unit_score(literal l);
|
||||
double l_score(literal l, svector<double> const& h, double factor, double sqfactor, double afactor);
|
||||
|
||||
// ------------------------------------
|
||||
|
@ -393,8 +398,8 @@ namespace sat {
|
|||
void push(literal lit, unsigned level);
|
||||
void pop();
|
||||
bool push_lookahead2(literal lit, unsigned level);
|
||||
void push_lookahead1(literal lit, unsigned level);
|
||||
void pop_lookahead1(literal lit);
|
||||
unsigned push_lookahead1(literal lit, unsigned level);
|
||||
void pop_lookahead1(literal lit, unsigned num_units);
|
||||
double mix_diff(double l, double r) const;
|
||||
clause const& get_clause(watch_list::iterator it) const;
|
||||
bool is_nary_propagation(clause const& c, literal l) const;
|
||||
|
@ -444,6 +449,8 @@ namespace sat {
|
|||
void init_search();
|
||||
void checkpoint();
|
||||
|
||||
void init_config();
|
||||
|
||||
public:
|
||||
lookahead(solver& s) :
|
||||
m_s(s),
|
||||
|
@ -453,6 +460,7 @@ namespace sat {
|
|||
m_level(2),
|
||||
m_prefix(0) {
|
||||
m_s.rlimit().push_child(&m_rlimit);
|
||||
init_config();
|
||||
}
|
||||
|
||||
~lookahead() {
|
||||
|
@ -488,6 +496,7 @@ namespace sat {
|
|||
void scc();
|
||||
|
||||
std::ostream& display(std::ostream& out) const;
|
||||
std::ostream& display_summary(std::ostream& out) const;
|
||||
model const& get_model();
|
||||
|
||||
void collect_statistics(statistics& st) const;
|
||||
|
|
|
@ -31,9 +31,11 @@ def_module_params('sat',
|
|||
('cardinality.solver', BOOL, False, 'use cardinality solver'),
|
||||
('pb.solver', SYMBOL, 'circuit', 'method for handling Pseudo-Boolean constraints: circuit (arithmetical circuit), sorting (sorting circuit), totalizer (use totalizer encoding), solver (use SMT solver)'),
|
||||
('xor.solver', BOOL, False, 'use xor solver'),
|
||||
('atmost1_encoding', SYMBOL, 'grouped', 'encoding used for at-most-1 constraints grouped, bimander, ordered'),
|
||||
('local_search_threads', UINT, 0, 'number of local search threads to find satisfiable solution'),
|
||||
('local_search', BOOL, False, 'use local search instead of CDCL'),
|
||||
('lookahead_search', BOOL, False, 'use lookahead solver'),
|
||||
('lookahead_simplify', BOOL, False, 'use lookahead solver during simplification'),
|
||||
('lookahead.reward', SYMBOL, 'heuleu', 'select lookahead heuristic: ternary, hs (Heule Schur), heuleu (Heule Unit), or unit'),
|
||||
('ccc', BOOL, False, 'use Concurrent Cube and Conquer solver')
|
||||
))
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue