From b8d18c6c6d150ad73747e90ccaa272b66a1f7b89 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Fri, 11 Jan 2019 20:52:19 -0800 Subject: [PATCH] speed-up handling of cnf input to inc_sat_solver Signed-off-by: Nikolaj Bjorner --- src/ast/expr2var.cpp | 47 +++++++++++++++++------- src/ast/expr2var.h | 10 +++-- src/sat/sat_cleaner.cpp | 53 ++++++++++++++------------- src/sat/sat_simplifier.cpp | 16 +++++--- src/sat/sat_solver/inc_sat_solver.cpp | 46 +++++++++++++++++++++-- src/sat/tactic/atom2bool_var.cpp | 10 +++-- src/smt/theory_pb.cpp | 8 ++-- 7 files changed, 133 insertions(+), 57 deletions(-) diff --git a/src/ast/expr2var.cpp b/src/ast/expr2var.cpp index b1fdba2b5..61adcfc3a 100644 --- a/src/ast/expr2var.cpp +++ b/src/ast/expr2var.cpp @@ -29,8 +29,17 @@ void expr2var::insert(expr * n, var v) { TRACE("expr2var", tout << "interpreted:\n" << mk_ismt2_pp(n, m()) << "\n";); m_interpreted_vars = true; } - m().inc_ref(n); - m_mapping.insert(n, v); + unsigned idx = m_id2map.get(n->get_id(), UINT_MAX); + if (idx == UINT_MAX) { + m().inc_ref(n); + idx = m_mapping.size(); + m_mapping.push_back(key_value(n, v)); + m_id2map.setx(n->get_id(), idx, UINT_MAX); + } + else { + m_mapping[idx] = key_value(n, v); + } + m_recent_exprs.push_back(n); } @@ -40,20 +49,22 @@ expr2var::expr2var(ast_manager & m): } expr2var::~expr2var() { - dec_ref_map_keys(m(), m_mapping); + for (auto & kv : m_mapping) { + m().dec_ref(kv.m_key); + } } expr2var::var expr2var::to_var(expr * n) const { - var v = UINT_MAX; - m_mapping.find(n, v); + var v = m_id2map.get(n->get_id(), UINT_MAX); + if (v != UINT_MAX) { + v = m_mapping[v].m_value; + } return v; } void expr2var::display(std::ostream & out) const { - obj_map::iterator it = m_mapping.begin(); - obj_map::iterator end = m_mapping.end(); - for (; it != end; ++it) { - out << mk_ismt2_pp(it->m_key, m()) << " -> " << it->m_value << "\n"; + for (auto const& kv : m_mapping) { + out << mk_ismt2_pp(kv.m_key, m()) << " -> " << kv.m_value << "\n"; } } @@ -68,8 +79,11 @@ void expr2var::mk_inv(expr_ref_vector & var2expr) const { } void expr2var::reset() { - dec_ref_map_keys(m(), m_mapping); - SASSERT(m_mapping.empty()); + for (auto & kv : m_mapping) { + m().dec_ref(kv.m_key); + } + m_mapping.reset(); + m_id2map.reset(); m_recent_exprs.reset(); m_recent_lim.reset(); m_interpreted_vars = false; @@ -83,8 +97,15 @@ void expr2var::pop(unsigned num_scopes) { if (num_scopes > 0) { unsigned sz = m_recent_lim[m_recent_lim.size() - num_scopes]; for (unsigned i = sz; i < m_recent_exprs.size(); ++i) { - m_mapping.erase(m_recent_exprs[i]); - m().dec_ref(m_recent_exprs[i]); + expr* n = m_recent_exprs[i]; + unsigned idx = m_id2map[n->get_id()]; + if (idx + 1 != m_mapping.size()) { + m_id2map[m_mapping.back().m_key->get_id()] = idx; + m_mapping[idx] = m_mapping.back(); + } + m_id2map[n->get_id()] = UINT_MAX; + m_mapping.pop_back(); + m().dec_ref(n); } m_recent_exprs.shrink(sz); m_recent_lim.shrink(m_recent_lim.size() - num_scopes); diff --git a/src/ast/expr2var.h b/src/ast/expr2var.h index 2b4d8c3fe..2bf2fe160 100644 --- a/src/ast/expr2var.h +++ b/src/ast/expr2var.h @@ -32,12 +32,14 @@ Notes: class expr2var { public: typedef unsigned var; - typedef obj_map expr2var_mapping; - typedef expr2var_mapping::iterator iterator; + typedef obj_map::key_data key_value; + typedef key_value const* iterator; typedef ptr_vector::const_iterator recent_iterator; protected: ast_manager & m_manager; - expr2var_mapping m_mapping; + + unsigned_vector m_id2map; + svector m_mapping; ptr_vector m_recent_exprs; unsigned_vector m_recent_lim; bool m_interpreted_vars; @@ -51,7 +53,7 @@ public: var to_var(expr * n) const; - bool is_var(expr * n) const { return m_mapping.contains(n); } + bool is_var(expr * n) const { return m_id2map.get(n->get_id(), UINT_MAX) != UINT_MAX; } void display(std::ostream & out) const; diff --git a/src/sat/sat_cleaner.cpp b/src/sat/sat_cleaner.cpp index e13a117fd..4a3fb82b4 100644 --- a/src/sat/sat_cleaner.cpp +++ b/src/sat/sat_cleaner.cpp @@ -78,6 +78,7 @@ namespace sat { } void cleaner::cleanup_clauses(clause_vector & cs) { + tmp_clause tmp; clause_vector::iterator it = cs.begin(); clause_vector::iterator it2 = it; clause_vector::iterator end = cs.end(); @@ -88,12 +89,10 @@ namespace sat { CTRACE("sat_cleaner_frozen", c.frozen(), tout << c << "\n";); unsigned sz = c.size(); unsigned i = 0, j = 0; - bool sat = false; m_cleanup_counter += sz; for (; i < sz; i++) { switch (s.value(c[i])) { case l_true: - sat = true; goto end_loop; case l_false: m_elim_literals++; @@ -108,9 +107,9 @@ namespace sat { } end_loop: CTRACE("sat_cleaner_frozen", c.frozen(), - tout << "sat: " << sat << ", new_size: " << j << "\n"; + tout << "sat: " << (i < sz) << ", new_size: " << j << "\n"; tout << mk_lits_pp(j, c.begin()) << "\n";); - if (sat) { + if (i < sz) { m_elim_clauses++; s.del_clause(c); } @@ -119,33 +118,37 @@ namespace sat { CTRACE("sat_cleaner_bug", new_sz < 2, tout << "new_sz: " << new_sz << "\n"; if (c.size() > 0) tout << "unit: " << c[0] << "\n"; s.display_watches(tout);); - if (new_sz == 0) { + switch (new_sz) { + case 0: s.set_conflict(justification()); s.del_clause(c); - } - else if (new_sz == 1) { + break; + case 1: s.assign(c[0], justification()); s.del_clause(c); - } - else { + break; + case 2: SASSERT(s.value(c[0]) == l_undef && s.value(c[1]) == l_undef); - if (new_sz == 2) { - TRACE("cleanup_bug", tout << "clause became binary: " << c[0] << " " << c[1] << "\n";); - s.mk_bin_clause(c[0], c[1], c.is_learned()); - s.del_clause(c); + TRACE("cleanup_bug", tout << "clause became binary: " << c[0] << " " << c[1] << "\n";); + s.mk_bin_clause(c[0], c[1], c.is_learned()); + s.del_clause(c); + break; + default: + SASSERT(s.value(c[0]) == l_undef && s.value(c[1]) == l_undef); + if (s.m_config.m_drat && new_sz < i) { + tmp.set(c.size(), c.begin(), c.is_learned()); } - else { - c.shrink(new_sz); - *it2 = *it; - it2++; - if (!c.frozen()) { - s.attach_clause(c); - } - if (s.m_config.m_drat) { - // for optimization, could also report deletion - // of previous version of clause. - s.m_drat.add(c, true); - } + c.shrink(new_sz); + *it2 = *it; + it2++; + if (!c.frozen()) { + s.attach_clause(c); + } + if (s.m_config.m_drat && new_sz < i) { + // for optimization, could also report deletion + // of previous version of clause. + s.m_drat.add(c, true); + s.m_drat.del(*tmp.get()); } } } diff --git a/src/sat/sat_simplifier.cpp b/src/sat/sat_simplifier.cpp index ca9b71a2d..4feeb4b05 100644 --- a/src/sat/sat_simplifier.cpp +++ b/src/sat/sat_simplifier.cpp @@ -349,7 +349,7 @@ namespace sat { } if (sz == 2) { s.mk_bin_clause(c[0], c[1], c.is_learned()); - s.del_clause(c); + s.del_clause(c, false); continue; } *it2 = *it; @@ -611,10 +611,15 @@ namespace sat { break; } } - if (j < sz) { - if (s.m_config.m_drat) s.m_drat.del(c); + if (j < sz && !r) { + if (s.m_config.m_drat) { + m_dummy.set(c.size(), c.begin(), c.is_learned()); + } c.shrink(j); - if (s.m_config.m_drat) s.m_drat.add(c, true); + if (s.m_config.m_drat) { + s.m_drat.add(c, true); + s.m_drat.del(*m_dummy.get()); + } } return r; } @@ -2020,8 +2025,7 @@ namespace sat { for (auto & c2 : m_neg_cls) { m_new_cls.reset(); if (!resolve(c1, c2, pos_l, m_new_cls)) - continue; - if (false && v == 767) IF_VERBOSE(0, verbose_stream() << "elim: " << c1 << " + " << c2 << " -> " << m_new_cls << "\n"); + continue; TRACE("resolution_new_cls", tout << c1 << "\n" << c2 << "\n-->\n" << m_new_cls << "\n";); if (cleanup_clause(m_new_cls)) continue; // clause is already satisfied. diff --git a/src/sat/sat_solver/inc_sat_solver.cpp b/src/sat/sat_solver/inc_sat_solver.cpp index 13780529a..5eb613ac2 100644 --- a/src/sat/sat_solver/inc_sat_solver.cpp +++ b/src/sat/sat_solver/inc_sat_solver.cpp @@ -60,6 +60,7 @@ class inc_sat_solver : public solver { atom2bool_var m_map; scoped_ptr m_bb_rewriter; tactic_ref m_preprocess; + bool m_is_cnf; unsigned m_num_scopes; sat::literal_vector m_asms; goal_ref_buffer m_subgoals; @@ -88,6 +89,7 @@ public: m_fmls_head(0), m_core(m), m_map(m), + m_is_cnf(true), m_num_scopes(0), m_unknown("no reason given"), m_internalized_converted(false), @@ -262,9 +264,22 @@ public: void assert_expr_core2(expr * t, expr * a) override { if (a) { m_asmsf.push_back(a); - assert_expr_core(m.mk_implies(a, t)); + if (m_is_cnf && is_literal(t) && is_literal(a)) { + assert_expr_core(m.mk_or(::mk_not(m, a), t)); + } + else if (m_is_cnf && m.is_or(t) && is_clause(t) && is_literal(a)) { + expr_ref_vector args(m); + args.push_back(::mk_not(m, a)); + args.append(to_app(t)->get_num_args(), to_app(t)->get_args()); + assert_expr_core(m.mk_or(args.size(), args.c_ptr())); + } + else { + m_is_cnf = false; + assert_expr_core(m.mk_implies(a, t)); + } } else { + m_is_cnf &= is_clause(t); assert_expr_core(t); } } @@ -545,7 +560,12 @@ private: SASSERT(!g->proofs_enabled()); TRACE("sat", m_solver.display(tout); g->display(tout);); try { - (*m_preprocess)(g, m_subgoals); + if (m_is_cnf) { + m_subgoals.push_back(g.get()); + } + else { + (*m_preprocess)(g, m_subgoals); + } } catch (tactic_exception & ex) { IF_VERBOSE(0, verbose_stream() << "exception in tactic " << ex.msg() << "\n";); @@ -705,6 +725,25 @@ private: } } + bool is_literal(expr* n) { + return is_uninterp_const(n) || (m.is_not(n, n) && is_uninterp_const(n)); + } + + bool is_clause(expr* fml) { + if (is_literal(fml)) { + return true; + } + if (!m.is_or(fml)) { + return false; + } + for (expr* n : *to_app(fml)) { + if (!is_literal(n)) { + return false; + } + } + return true; + } + lbool internalize_formulas() { if (m_fmls_head == m_fmls.size()) { return l_true; @@ -712,7 +751,8 @@ private: dep2asm_t dep2asm; goal_ref g = alloc(goal, m, true, false); // models, maybe cores are enabled for (unsigned i = m_fmls_head ; i < m_fmls.size(); ++i) { - g->assert_expr(m_fmls[i].get()); + expr* fml = m_fmls.get(i); + g->assert_expr(fml); } lbool res = internalize_goal(g, dep2asm, false); if (res != l_undef) { diff --git a/src/sat/tactic/atom2bool_var.cpp b/src/sat/tactic/atom2bool_var.cpp index b79eaa251..cd6ab776d 100644 --- a/src/sat/tactic/atom2bool_var.cpp +++ b/src/sat/tactic/atom2bool_var.cpp @@ -38,9 +38,13 @@ void atom2bool_var::mk_var_inv(app_ref_vector & var2expr) const { } sat::bool_var atom2bool_var::to_bool_var(expr * n) const { - sat::bool_var v = sat::null_bool_var; - m_mapping.find(n, v); - return v; + unsigned idx = m_id2map.get(n->get_id(), UINT_MAX); + if (idx == UINT_MAX) { + return sat::null_bool_var; + } + else { + return m_mapping[idx].m_value; + } } struct collect_boolean_interface_proc { diff --git a/src/smt/theory_pb.cpp b/src/smt/theory_pb.cpp index 182324832..dea698f3f 100644 --- a/src/smt/theory_pb.cpp +++ b/src/smt/theory_pb.cpp @@ -966,9 +966,9 @@ namespace smt { justification* js = nullptr; c.inc_propagations(*this); if (!resolve_conflict(c, lits)) { - if (proofs_enabled()) { - js = alloc(theory_lemma_justification, get_id(), ctx, lits.size(), lits.c_ptr()); - } + if (proofs_enabled()) { + js = alloc(theory_lemma_justification, get_id(), ctx, lits.size(), lits.c_ptr()); + } ctx.mk_clause(lits.size(), lits.c_ptr(), js, CLS_AUX_LEMMA, nullptr); } SASSERT(ctx.inconsistent()); @@ -1195,7 +1195,9 @@ namespace smt { // perform unit propagation if (maxsum >= c.mpz_k() && maxsum - mininc < c.mpz_k()) { literal_vector& lits = get_unhelpful_literals(c, true); + // for (literal lit : lits) SASSERT(ctx.get_assignment(lit) == l_true); lits.push_back(c.lit()); + // SASSERT(ctx.get_assignment(c.lit()) == l_true); for (unsigned i = 0; i < sz; ++i) { literal lit = c.lit(i); if (ctx.get_assignment(lit) == l_undef) {