From a6da207b652384d3a0328846140aa169b54c2252 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Sat, 11 Nov 2017 11:25:43 -0800 Subject: [PATCH] fix crash bugs in sat solver Signed-off-by: Nikolaj Bjorner --- src/api/c++/z3++.h | 87 +++++++++++++++++++++++++++++++++++- src/sat/sat_asymm_branch.cpp | 1 - src/sat/sat_clause.cpp | 16 +++++-- src/sat/sat_clause.h | 1 + src/sat/sat_justification.h | 2 +- src/sat/sat_simplifier.cpp | 9 +--- src/sat/sat_solver.cpp | 36 ++++++++------- src/sat/sat_types.h | 2 +- src/sat/sat_watched.cpp | 20 ++++----- src/sat/sat_watched.h | 2 + 10 files changed, 133 insertions(+), 43 deletions(-) diff --git a/src/api/c++/z3++.h b/src/api/c++/z3++.h index c397271e3..52671a073 100644 --- a/src/api/c++/z3++.h +++ b/src/api/c++/z3++.h @@ -421,6 +421,7 @@ namespace z3 { void set(char const * k, unsigned n) { Z3_params_set_uint(ctx(), m_params, ctx().str_symbol(k), n); } void set(char const * k, double n) { Z3_params_set_double(ctx(), m_params, ctx().str_symbol(k), n); } void set(char const * k, symbol const & s) { Z3_params_set_symbol(ctx(), m_params, ctx().str_symbol(k), s); } + void set(char const * k, char const* s) { Z3_params_set_symbol(ctx(), m_params, ctx().str_symbol(k), ctx().str_symbol(s)); } friend std::ostream & operator<<(std::ostream & out, params const & p); }; @@ -1508,6 +1509,11 @@ namespace z3 { m_vector = s.m_vector; return *this; } + bool contains(T const& x) const { + for (auto y : *this) if (x == y) return true; + return false; + } + class iterator { ast_vector_tpl const* m_vector; unsigned m_index; @@ -1907,6 +1913,11 @@ namespace z3 { return *this; } void set(params const & p) { Z3_solver_set_params(ctx(), m_solver, p); check_error(); } + void set(char const * k, bool v) { params p(ctx()); p.set(k, v); set(p); } + void set(char const * k, unsigned v) { params p(ctx()); p.set(k, v); set(p); } + void set(char const * k, double v) { params p(ctx()); p.set(k, v); set(p); } + void set(char const * k, symbol const & v) { params p(ctx()); p.set(k, v); set(p); } + void set(char const * k, char const* v) { params p(ctx()); p.set(k, v); set(p); } void push() { Z3_solver_push(ctx(), m_solver); check_error(); } void pop(unsigned n = 1) { Z3_solver_pop(ctx(), m_solver, n); check_error(); } void reset() { Z3_solver_reset(ctx(), m_solver); check_error(); } @@ -1919,6 +1930,8 @@ namespace z3 { void add(expr const & e, char const * p) { add(e, ctx().bool_const(p)); } + void add(expr_vector const& v) { check_context(*this, v); for (expr e : v) add(e); } + void from_file(char const* file) { Z3_solver_from_file(ctx(), m_solver, file); check_error(); } check_result check() { Z3_lbool r = Z3_solver_check(ctx(), m_solver); check_error(); return to_check_result(r); } check_result check(unsigned n, expr * const assumptions) { array _assumptions(n); @@ -1976,6 +1989,78 @@ namespace z3 { param_descrs get_param_descrs() { return param_descrs(ctx(), Z3_solver_get_param_descrs(ctx(), m_solver)); } + + expr_vector cube(unsigned cutoff) { Z3_ast_vector r = Z3_solver_cube(ctx(), m_solver, cutoff); check_error(); return expr_vector(ctx(), r); } + + class cube_iterator { + solver& m_solver; + unsigned& m_cutoff; + expr_vector m_cube; + bool m_end; + + bool is_false() const { return m_cube.size() == 1 && Z3_OP_FALSE == m_cube[0].decl().decl_kind(); } + + void check_end() { + if (is_false()) { + m_cube = z3::expr_vector(m_solver.ctx()); + m_end = true; + } + else if (m_cube.empty()) { + m_end = true; + } + } + + void inc() { + assert(!m_end); + m_cube = m_solver.cube(m_cutoff); + m_cutoff = 0xFFFFFFFF; + check_end(); + } + public: + cube_iterator(solver& s, unsigned& cutoff, bool end): + m_solver(s), + m_cutoff(cutoff), + m_cube(s.ctx()), + m_end(end) { + if (!m_end) { + inc(); + } + } + + cube_iterator& operator++() { + assert(!m_end); + inc(); + return *this; + } + cube_iterator operator++(int) { assert(false); return *this; } + expr_vector const * operator->() const { return &(operator*()); } + expr_vector const& operator*() const { return m_cube; } + + bool operator==(cube_iterator const& other) { + return other.m_end == m_end; + }; + bool operator!=(cube_iterator const& other) { + return other.m_end != m_end; + }; + + }; + + class cube_generator { + solver& m_solver; + unsigned m_cutoff; + public: + cube_generator(solver& s): + m_solver(s), + m_cutoff(0xFFFFFFFF) + {} + + cube_iterator begin() { return cube_iterator(m_solver, m_cutoff, false); } + cube_iterator end() { return cube_iterator(m_solver, m_cutoff, true); } + void set_cutoff(unsigned c) { m_cutoff = c; } + }; + + cube_generator cubes() { return cube_generator(*this); } + }; inline std::ostream & operator<<(std::ostream & out, solver const & s) { out << Z3_solver_to_string(s.ctx(), s); return out; } @@ -1999,7 +2084,7 @@ namespace z3 { return *this; } void add(expr const & f) { check_context(*this, f); Z3_goal_assert(ctx(), m_goal, f); check_error(); } - void add(expr_vector const& v) { check_context(*this, v); for (expr e : v) add(v); } + void add(expr_vector const& v) { check_context(*this, v); for (expr e : v) add(e); } unsigned size() const { return Z3_goal_size(ctx(), m_goal); } expr operator[](int i) const { assert(0 <= i); Z3_ast r = Z3_goal_formula(ctx(), m_goal, i); check_error(); return expr(ctx(), r); } Z3_goal_prec precision() const { return Z3_goal_precision(ctx(), m_goal); } diff --git a/src/sat/sat_asymm_branch.cpp b/src/sat/sat_asymm_branch.cpp index 029274c97..83574749e 100644 --- a/src/sat/sat_asymm_branch.cpp +++ b/src/sat/sat_asymm_branch.cpp @@ -201,7 +201,6 @@ namespace sat { return false; default: c.shrink(new_sz); - s.attach_clause(c); if (s.m_config.m_drat) s.m_drat.add(c, true); // if (s.m_config.m_drat) s.m_drat.del(c0); // TBD SASSERT(s.m_qhead == s.m_trail.size()); diff --git a/src/sat/sat_clause.cpp b/src/sat/sat_clause.cpp index 43d614e38..9bfd1d38d 100644 --- a/src/sat/sat_clause.cpp +++ b/src/sat/sat_clause.cpp @@ -129,20 +129,23 @@ namespace sat { } clause * clause_allocator::get_clause(clause_offset cls_off) const { -#if defined(_AMD64_) +#if 0 +// defined(_AMD64_) if (((cls_off & c_alignment_mask) == c_last_segment)) { unsigned id = cls_off >> c_cls_alignment; return const_cast(m_last_seg_id2cls[id]); } return reinterpret_cast(m_segments[cls_off & c_alignment_mask] + (static_cast(cls_off) & ~c_alignment_mask)); #else + VERIFY(cls_off == reinterpret_cast(reinterpret_cast(cls_off))); return reinterpret_cast(cls_off); #endif } clause_offset clause_allocator::get_offset(clause const * cls) const { -#if defined(_AMD64_) +#if 0 +// defined(_AMD64_) size_t ptr = reinterpret_cast(cls); SASSERT((ptr & c_alignment_mask) == 0); @@ -163,6 +166,7 @@ namespace sat { return static_cast(reinterpret_cast(cls)) + i; } #else + VERIFY(cls == reinterpret_cast(reinterpret_cast(cls))); return reinterpret_cast(cls); #endif } @@ -178,9 +182,13 @@ namespace sat { void clause_allocator::del_clause(clause * cls) { TRACE("sat_clause", tout << "delete: " << cls->id() << " " << *cls << "\n";); + if (cls->id() == 62805 && cls->capacity() == 29) { + std::cout << "delete 62805\n"; + for (literal l : *cls) { + std::cout << l << "\n"; + } + } m_id_gen.recycle(cls->id()); -#if defined(_AMD64_) -#endif size_t size = clause::get_obj_size(cls->m_capacity); cls->~clause(); m_allocator.deallocate(size, cls); diff --git a/src/sat/sat_clause.h b/src/sat/sat_clause.h index dd7af64eb..08fff7adb 100644 --- a/src/sat/sat_clause.h +++ b/src/sat/sat_clause.h @@ -62,6 +62,7 @@ namespace sat { public: unsigned id() const { return m_id; } unsigned size() const { return m_size; } + unsigned capacity() const { return m_capacity; } literal & operator[](unsigned idx) { SASSERT(idx < m_size); return m_lits[idx]; } literal const & operator[](unsigned idx) const { SASSERT(idx < m_size); return m_lits[idx]; } bool is_learned() const { return m_learned; } diff --git a/src/sat/sat_justification.h b/src/sat/sat_justification.h index b8b3dcbdc..497d636c8 100644 --- a/src/sat/sat_justification.h +++ b/src/sat/sat_justification.h @@ -48,7 +48,7 @@ namespace sat { literal get_literal2() const { SASSERT(is_ternary_clause()); return to_literal(m_val2 >> 3); } bool is_clause() const { return m_val2 == CLAUSE; } - clause_offset get_clause_offset() const { return val1(); } + clause_offset get_clause_offset() const { return m_val1; } bool is_ext_justification() const { return m_val2 == EXT_JUSTIFICATION; } ext_justification_idx get_ext_justification_idx() const { return m_val1; } diff --git a/src/sat/sat_simplifier.cpp b/src/sat/sat_simplifier.cpp index 2cc5758d3..89b07ddd4 100644 --- a/src/sat/sat_simplifier.cpp +++ b/src/sat/sat_simplifier.cpp @@ -268,19 +268,12 @@ namespace sat { bool vars_eliminated = m_num_elim_vars > m_old_num_elim_vars; - if (m_need_cleanup) { + if (m_need_cleanup || vars_eliminated) { TRACE("after_simplifier", tout << "cleanning watches...\n";); cleanup_watches(); cleanup_clauses(s.m_learned, true, vars_eliminated, m_learned_in_use_lists); cleanup_clauses(s.m_clauses, false, vars_eliminated, true); } - else { - TRACE("after_simplifier", tout << "skipping cleanup...\n";); - if (vars_eliminated) { - // must remove learned clauses with eliminated variables - cleanup_clauses(s.m_learned, true, true, m_learned_in_use_lists); - } - } CASSERT("sat_solver", s.check_invariant()); TRACE("after_simplifier", s.display(tout); tout << "model_converter:\n"; s.m_mc.display(tout);); diff --git a/src/sat/sat_solver.cpp b/src/sat/sat_solver.cpp index 61dc1f6c2..d2a1b8b16 100644 --- a/src/sat/sat_solver.cpp +++ b/src/sat/sat_solver.cpp @@ -417,6 +417,9 @@ namespace sat { } unsigned some_idx = c.size() >> 1; literal block_lit = c[some_idx]; + DEBUG_CODE(for (auto const& w : m_watches[(~c[0]).index()]) VERIFY(!w.is_clause() || w.get_clause_offset() != cls_off);); + DEBUG_CODE(for (auto const& w : m_watches[(~c[1]).index()]) VERIFY(!w.is_clause() || w.get_clause_offset() != cls_off);); + VERIFY(c[0] != c[1]); m_watches[(~c[0]).index()].push_back(watched(block_lit, cls_off)); m_watches[(~c[1]).index()].push_back(watched(block_lit, cls_off)); return reinit; @@ -563,6 +566,9 @@ namespace sat { void solver::detach_nary_clause(clause & c) { clause_offset cls_off = get_offset(c); + if (c.id() == 62805 && c.capacity() == 29) { + std::cout << "detach: " << c[0] << " " << c[1] << " size: " << c.size() << " cap: " << c.capacity() << " id: " << c.id() << "\n"; + } erase_clause_watch(get_wlist(~c[0]), cls_off); erase_clause_watch(get_wlist(~c[1]), cls_off); } @@ -753,18 +759,19 @@ namespace sat { it2++; break; } - SASSERT(c[1] == not_l); if (value(c[0]) == l_true) { it2->set_clause(c[0], cls_off); it2++; break; } + VERIFY(c[1] == not_l); literal * l_it = c.begin() + 2; literal * l_end = c.end(); for (; l_it != l_end; ++l_it) { if (value(*l_it) != l_false) { c[1] = *l_it; *l_it = not_l; + DEBUG_CODE(for (auto const& w : m_watches[(~c[1]).index()]) VERIFY(!w.is_clause() || w.get_clause_offset() != cls_off);); m_watches[(~c[1]).index()].push_back(watched(c[0], cls_off)); goto end_clause_case; } @@ -1555,7 +1562,7 @@ namespace sat { if (!check_clauses(m_model)) { - std::cout << "failure checking clauses on transformed model\n"; + IF_VERBOSE(0, verbose_stream() << "failure checking clauses on transformed model\n";); UNREACHABLE(); throw solver_exception("check model failed"); } @@ -3267,29 +3274,24 @@ namespace sat { for (unsigned i = 0; i < m_trail.size(); i++) { out << max_weight << " " << dimacs_lit(m_trail[i]) << " 0\n"; } - vector::const_iterator it = m_watches.begin(); - vector::const_iterator end = m_watches.end(); - for (unsigned l_idx = 0; it != end; ++it, ++l_idx) { + unsigned l_idx = 0; + for (watch_list const& wlist : m_watches) { literal l = ~to_literal(l_idx); - watch_list const & wlist = *it; - watch_list::const_iterator it2 = wlist.begin(); - watch_list::const_iterator end2 = wlist.end(); - for (; it2 != end2; ++it2) { - if (it2->is_binary_clause() && l.index() < it2->get_literal().index()) - out << max_weight << " " << dimacs_lit(l) << " " << dimacs_lit(it2->get_literal()) << " 0\n"; + for (watched const& w : wlist) { + if (w.is_binary_clause() && l.index() < w.get_literal().index()) + out << max_weight << " " << dimacs_lit(l) << " " << dimacs_lit(w.get_literal()) << " 0\n"; } + ++l_idx; } clause_vector const * vs[2] = { &m_clauses, &m_learned }; for (unsigned i = 0; i < 2; i++) { clause_vector const & cs = *(vs[i]); - clause_vector::const_iterator it = cs.begin(); - clause_vector::const_iterator end = cs.end(); - for (; it != end; ++it) { - clause const & c = *(*it); + for (clause const* cp : cs) { + clause const & c = *cp; unsigned clsz = c.size(); out << max_weight << " "; - for (unsigned j = 0; j < clsz; j++) - out << dimacs_lit(c[j]) << " "; + for (literal l : c) + out << dimacs_lit(l) << " "; out << "0\n"; } } diff --git a/src/sat/sat_types.h b/src/sat/sat_types.h index b2d29495c..5eded92ec 100644 --- a/src/sat/sat_types.h +++ b/src/sat/sat_types.h @@ -109,7 +109,7 @@ namespace sat { typedef svector literal_vector; typedef std::pair literal_pair; - typedef unsigned clause_offset; + typedef size_t clause_offset; typedef size_t ext_constraint_idx; typedef size_t ext_justification_idx; diff --git a/src/sat/sat_watched.cpp b/src/sat/sat_watched.cpp index d890a8ff7..af4fd598e 100644 --- a/src/sat/sat_watched.cpp +++ b/src/sat/sat_watched.cpp @@ -27,8 +27,9 @@ namespace sat { for (; it != end; ++it) { if (it->is_clause() && it->get_clause_offset() == c) { watch_list::iterator it2 = it; - ++it; + ++it; for (; it != end; ++it) { + SASSERT(!((it->is_clause() && it->get_clause_offset() == c))); *it2 = *it; ++it2; } @@ -40,27 +41,26 @@ namespace sat { } std::ostream& display_watch_list(std::ostream & out, clause_allocator const & ca, watch_list const & wlist) { - watch_list::const_iterator it = wlist.begin(); - watch_list::const_iterator end = wlist.end(); - for (bool first = true; it != end; ++it) { + bool first = true; + for (watched const& w : wlist) { if (first) first = false; else out << " "; - switch (it->get_kind()) { + switch (w.get_kind()) { case watched::BINARY: - out << it->get_literal(); - if (it->is_learned()) + out << w.get_literal(); + if (w.is_learned()) out << "*"; break; case watched::TERNARY: - out << "(" << it->get_literal1() << " " << it->get_literal2() << ")"; + out << "(" << w.get_literal1() << " " << w.get_literal2() << ")"; break; case watched::CLAUSE: - out << "(" << it->get_blocked_literal() << " " << *(ca.get_clause(it->get_clause_offset())) << ")"; + out << "(" << w.get_blocked_literal() << " " << *(ca.get_clause(w.get_clause_offset())) << ")"; break; case watched::EXT_CONSTRAINT: - out << "ext: " << it->get_ext_constraint_idx(); + out << "ext: " << w.get_ext_constraint_idx(); break; default: UNREACHABLE(); diff --git a/src/sat/sat_watched.h b/src/sat/sat_watched.h index 8d7924c6f..e2d814f5b 100644 --- a/src/sat/sat_watched.h +++ b/src/sat/sat_watched.h @@ -64,6 +64,8 @@ namespace sat { SASSERT(get_literal2() == l2); } + unsigned val2() const { return m_val2; } + watched(literal blocked_lit, clause_offset cls_off): m_val1(cls_off), m_val2(static_cast(CLAUSE) + (blocked_lit.to_uint() << 2)) {