diff --git a/src/ast/sls/sls_arith_clausal.cpp b/src/ast/sls/sls_arith_clausal.cpp index a0dcad290..bd3f0df99 100644 --- a/src/ast/sls/sls_arith_clausal.cpp +++ b/src/ast/sls/sls_arith_clausal.cpp @@ -119,7 +119,6 @@ namespace sls { tout << "\n";); for (auto v : ctx.unsat_vars()) { - auto* ineq = a.get_ineq(v); if (!ineq) continue; @@ -140,6 +139,45 @@ namespace sls { a.m_updates.reset(); a.m_fixed_atoms.reset(); + unsigned sz = a.m_bool_var_atoms.size(); + bool is_big = sz > 45u; + sat::bool_var bv; + + auto occurs_negative = [&](sat::bool_var bv) { + if (ctx.unsat_vars().contains(bv)) + return false; + auto* ineq = a.get_ineq(bv); + if (!ineq) + return false; + sat::literal lit(bv, !ineq->is_true()); + auto const& ul = ctx.get_use_list(~lit); + return ul.begin() != ul.end(); + }; + + unsigned idx = 0; + //unsigned num_sampled = 0; + for (unsigned i = std::min(sz, 45u); i-- > 0;) { + if (is_big) { + idx = ctx.rand(sz); + bv = a.m_bool_var_atoms[idx]; + } + else + bv = a.m_bool_var_atoms[i]; + + if (occurs_negative(bv)) { + auto e = ctx.atom(bv); + auto& i = a.get_bool_info(e); + a.add_lookahead(i, bv); + //++num_sampled; + } + + if (is_big) { + --sz; + a.m_bool_var_atoms.swap_elems(idx, sz); + } + } + +#if 0 for (auto bv : a.m_bool_var_atoms) { if (ctx.unsat_vars().contains(bv)) continue; @@ -150,13 +188,14 @@ namespace sls { auto const& ul = ctx.get_use_list(~lit); if (ul.begin() == ul.end()) continue; - auto v = lit.var(); // literal is false in some clause but none of the clauses where it occurs false are unsat. - auto e = ctx.atom(v); + auto e = ctx.atom(bv); auto& i = a.get_bool_info(e); - a.add_lookahead(i, v); + + a.add_lookahead(i, bv); } +#endif } template @@ -244,7 +283,8 @@ namespace sls { template double arith_clausal::get_score(var_t v, num_t const& delta) { auto& vi = a.m_vars[v]; - VERIFY(a.update_num(v, delta)); + if (!a.update_num(v, delta)) + return -1; double score = 0; for (auto ci : vi.m_clauses_of) { auto const& c = ctx.get_clause(ci); @@ -273,8 +313,10 @@ namespace sls { else if (c.m_num_trues == 0 && num_true > 0) score += c.m_weight; } + // revert the update - VERIFY(a.update_num(v, -delta)); + a.update_args_value(v, vi.value() - delta); + return score; } diff --git a/src/test/sls_test.cpp b/src/test/sls_test.cpp index 27fbfee12..f08491c60 100644 --- a/src/test/sls_test.cpp +++ b/src/test/sls_test.cpp @@ -19,13 +19,17 @@ namespace bv { sat::clause_info const& get_clause(unsigned idx) const override { return m_clauses[idx]; } ptr_iterator get_use_list(sat::literal lit) override { return ptr_iterator(nullptr, nullptr); } void flip(sat::bool_var v) override { } + sat::bool_var bool_flip() override { return sat::null_bool_var; } double reward(sat::bool_var v) override { return 0; } double get_weigth(unsigned clause_idx) override { return 0; } bool is_true(sat::literal lit) override { return true; } bool try_rotate(sat::bool_var v, sat::bool_var_set& rotated, unsigned& bound) override { return false; } unsigned num_vars() const override { return 0; } indexed_uint_set const& unsat() const override { return s; } + indexed_uint_set const& unsat_vars() const override { return s; } + void shift_weights() override {} void on_model(model_ref& mdl) override {} + unsigned num_external_in_unsat_vars() const override { return 0; } sat::bool_var add_var() override { return sat::null_bool_var;} void add_clause(unsigned n, sat::literal const* lits) override {} // void collect_statistics(statistics& st) const override {} diff --git a/src/util/uint_set.h b/src/util/uint_set.h index 196930d17..fc1f508d7 100644 --- a/src/util/uint_set.h +++ b/src/util/uint_set.h @@ -354,6 +354,15 @@ public: return m_elems[index]; } + void swap_elems(unsigned i, unsigned j) { + if (i == j) + return; + SASSERT(i < m_size && j < m_size); + unsigned x = m_elems[i], y = m_elems[j]; + m_elems[i] = y; m_elems[j] = x; + m_index[x] = j; m_index[y] = i; + } + bool contains(unsigned x) const { return x < m_index.size() && m_index[x] < m_size && m_elems[m_index[x]] == x; } void reset() { m_size = 0; } bool empty() const { return m_size == 0; }