diff --git a/src/ast/sls/sat_ddfw.cpp b/src/ast/sls/sat_ddfw.cpp index bc89ca28b..5f6d9fa95 100644 --- a/src/ast/sls/sat_ddfw.cpp +++ b/src/ast/sls/sat_ddfw.cpp @@ -255,12 +255,14 @@ namespace sat { m_use_list_clauses = m_clauses.size(); m_use_list_index.reset(); m_flat_use_list.reset(); + m_use_list.reserve(2 * num_vars()); for (auto const& ul : m_use_list) { m_use_list_index.push_back(m_flat_use_list.size()); m_flat_use_list.append(ul); } m_use_list_index.push_back(m_flat_use_list.size()); init_clause_data(); + SASSERT(2 * num_vars() + 1 == m_use_list_index.size()); return true; } diff --git a/src/ast/sls/sls_arith_base.cpp b/src/ast/sls/sls_arith_base.cpp index fafd336a0..f1e4eb856 100644 --- a/src/ast/sls/sls_arith_base.cpp +++ b/src/ast/sls/sls_arith_base.cpp @@ -1145,6 +1145,8 @@ namespace sls { else { SASSERT(!a.is_arith_expr(e)); } + initialize_of_bool_var(bv); + add_new_terms(); } @@ -1527,52 +1529,72 @@ namespace sls { throw default_exception("repair is not supported for " + mk_pp(e, m)); } } - for (unsigned v = 0; v < m_vars.size(); ++v) - initialize_bool_vars_of(v); } template - void arith_base::initialize_bool_vars_of(var_t v) { - if (!m_vars[v].m_bool_vars_of.empty()) + void arith_base::initialize_of_bool_var(sat::bool_var bv) { + auto* ineq = get_ineq(bv); + if (!ineq) return; buffer todo; - todo.push_back(v); - auto& vi = m_vars[v]; + for (auto const& [coeff, v] : ineq->m_args) + todo.push_back(v); m_tmp_set.reset(); for (unsigned i = 0; i < todo.size(); ++i) { var_t u = todo[i]; auto& ui = m_vars[u]; - for (auto const& idx : ui.m_muls) { - auto& [x, monomial] = m_muls[idx]; - bool found = false; - for (auto u : todo) found |= u == x; - if (!found) - todo.push_back(x); + if (m_tmp_set.contains(u)) + continue; + m_tmp_set.insert(u); + ui.m_bool_vars_of.push_back(bv); + if (is_add(u)) { + auto const& ad = get_add(u); + for (auto const& [c, w] : ad.m_args) + todo.push_back(w); } - for (auto const& idx : ui.m_adds) { - auto x = m_adds[idx].m_var; - bool found = false; - for (auto u : todo) found |= u == x; - if (!found) - todo.push_back(x); + if (is_mul(u)) { + auto const& [w, monomial] = get_mul(u); + for (auto [w, p] : monomial) + todo.push_back(w); + } + if (is_op(u)) { + auto const& op = m_ops[ui.m_def_idx]; + todo.push_back(op.m_arg1); + todo.push_back(op.m_arg2); } - for (auto const& [coeff, bv] : ui.m_linear_occurs) - m_tmp_set.insert(bv); } - for (auto bv : m_tmp_set) - vi.m_bool_vars_of.push_back(bv); + } - m_tmp_nat_set.reset(); - m_tmp_nat_set.assure_domain(ctx.clauses().size() + 1); - - for (auto bv : vi.m_bool_vars_of) { - for (auto lit : { sat::literal(bv, false), sat::literal(bv, true) }) { - for (auto ci : ctx.get_use_list(lit)) { - if (m_tmp_nat_set.contains(ci)) - continue; - m_tmp_nat_set.insert(ci); - vi.m_clauses_of.push_back(ci); - } + template + void arith_base::initialize_clauses_of(sat::bool_var bv, unsigned ci) { + auto* ineq = get_ineq(bv); + if (!ineq) + return; + buffer todo; + for (auto const& [coeff, v] : ineq->m_args) + todo.push_back(v); + m_tmp_set.reset(); + for (unsigned i = 0; i < todo.size(); ++i) { + var_t u = todo[i]; + auto& ui = m_vars[u]; + if (m_tmp_set.contains(u)) + continue; + m_tmp_set.insert(u); + ui.m_clauses_of.push_back(ci); + if (is_add(u)) { + auto const& ad = get_add(u); + for (auto const& [c, w] : ad.m_args) + todo.push_back(w); + } + if (is_mul(u)) { + auto const& [w, monomial] = get_mul(u); + for (auto [w, p] : monomial) + todo.push_back(w); + } + if (is_op(u)) { + auto const& op = m_ops[ui.m_def_idx]; + todo.push_back(op.m_arg1); + todo.push_back(op.m_arg2); } } } diff --git a/src/ast/sls/sls_arith_base.h b/src/ast/sls/sls_arith_base.h index 86d45afd0..b7c1ab872 100644 --- a/src/ast/sls/sls_arith_base.h +++ b/src/ast/sls/sls_arith_base.h @@ -295,7 +295,8 @@ namespace sls { double compute_score(var_t x, num_t const& delta); void save_best_values(); - void initialize_bool_vars_of(var_t v); + void initialize_of_bool_var(sat::bool_var v); + void initialize_clauses_of(sat::bool_var v, unsigned cl); var_t mk_var(expr* e); var_t mk_term(expr* e); var_t mk_op(arith_op_kind k, expr* e, expr* x, expr* y); diff --git a/src/ast/sls/sls_arith_clausal.cpp b/src/ast/sls/sls_arith_clausal.cpp index 1f6157787..3de698afd 100644 --- a/src/ast/sls/sls_arith_clausal.cpp +++ b/src/ast/sls/sls_arith_clausal.cpp @@ -275,8 +275,8 @@ namespace sls { DEBUG_CODE( for (sat::bool_var bv = 0; bv < ctx.num_bool_vars(); ++bv) { if (a.get_ineq(bv) && a.get_ineq(bv)->is_true() != ctx.is_true(bv)) { - TRACE("arith", tout << bv << " " << *a.get_ineq(bv) << "\n"; - tout << a.m_vars[v].m_bool_vars_of << "\n"); + TRACE("arith", tout << "bv:" << bv << " " << *a.get_ineq(bv) << ctx.is_true(bv) << "\n"; + tout << "bool vars: " << a.m_vars[v].m_bool_vars_of << "\n"); } VERIFY(!a.get_ineq(bv) || a.get_ineq(bv)->is_true() == ctx.is_true(bv)); }); @@ -349,7 +349,6 @@ namespace sls { vi.set_value(vi.m_hi->value); else vi.set_value(num_t(0)); - vi.m_bool_vars_of.reset(); } initialize(); } @@ -365,6 +364,14 @@ namespace sls { m_no_improve = 0; m_no_improve_bool = 0; m_no_improve_arith = 0; + for (; m_num_clauses < ctx.clauses().size(); ++m_num_clauses) { + auto const& c = ctx.get_clause(m_num_clauses); + for (auto lit : c) { + auto bv = lit.var(); + if (a.get_ineq(bv)) + a.initialize_clauses_of(bv, m_num_clauses); + } + } } diff --git a/src/ast/sls/sls_arith_clausal.h b/src/ast/sls/sls_arith_clausal.h index b8d1c4294..06b70d5d6 100644 --- a/src/ast/sls/sls_arith_clausal.h +++ b/src/ast/sls/sls_arith_clausal.h @@ -81,6 +81,7 @@ namespace sls { unsigned m_best_last_step = 0; unsigned m_num_lookaheads = 0; double m_best_score = 0; + unsigned m_num_clauses = 0; // avoid checking the same updates twice var_t m_last_var = UINT_MAX; diff --git a/src/ast/sls/sls_context.cpp b/src/ast/sls/sls_context.cpp index 33ff74f22..09657500d 100644 --- a/src/ast/sls/sls_context.cpp +++ b/src/ast/sls/sls_context.cpp @@ -450,7 +450,7 @@ namespace sls { sat::literal context::mk_literal(expr* e) { expr_ref _e(e, m); - SASSERT(!m_input_assertions.contains(e)); + sat::literal lit; bool neg = false; expr* a, * b, * c; @@ -459,6 +459,7 @@ namespace sls { auto v = m_atom2bool_var.get(e->get_id(), sat::null_bool_var); if (v != sat::null_bool_var) return sat::literal(v, neg); + SASSERT(!m_input_assertions.contains(e)); sat::literal_vector clause; lit = mk_literal(); register_atom(lit.var(), e);