From 796e2fd9eb8891a04a70fc0879a231a508479e13 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Sun, 13 Sep 2020 19:29:59 -0700 Subject: [PATCH] arrays (#4684) * arrays Signed-off-by: Nikolaj Bjorner * arrays Signed-off-by: Nikolaj Bjorner * na Signed-off-by: Nikolaj Bjorner * arrays Signed-off-by: Nikolaj Bjorner * na Signed-off-by: Nikolaj Bjorner * fill Signed-off-by: Nikolaj Bjorner * update drat and fix euf bugs Signed-off-by: Nikolaj Bjorner * na Signed-off-by: Nikolaj Bjorner * na Signed-off-by: Nikolaj Bjorner * na Signed-off-by: Nikolaj Bjorner * const qualifiers Signed-off-by: Nikolaj Bjorner * na Signed-off-by: Nikolaj Bjorner * reorg ba Signed-off-by: Nikolaj Bjorner * reorg Signed-off-by: Nikolaj Bjorner * build warnings Signed-off-by: Nikolaj Bjorner --- src/ast/ast.cpp | 9 +- src/ast/euf/euf_egraph.cpp | 73 +- src/ast/euf/euf_egraph.h | 13 +- src/ast/euf/euf_enode.cpp | 19 +- src/ast/euf/euf_enode.h | 1 + src/ast/euf/euf_etable.cpp | 3 - src/ast/euf/euf_justification.h | 37 +- src/muz/base/dl_rule.h | 10 +- src/muz/spacer/spacer_context.cpp | 2 +- src/opt/opt_context.cpp | 2 +- src/sat/dimacs.cpp | 17 +- src/sat/dimacs.h | 20 +- src/sat/sat_binspr.h | 5 +- src/sat/sat_drat.cpp | 96 ++- src/sat/sat_drat.h | 9 +- src/sat/sat_elim_eqs.cpp | 15 +- src/sat/sat_extension.h | 14 +- src/sat/sat_lookahead.cpp | 2 +- src/sat/sat_lookahead.h | 6 +- src/sat/sat_lut_finder.h | 4 +- src/sat/sat_parallel.h | 4 +- src/sat/sat_probing.cpp | 7 + src/sat/sat_simplifier.h | 4 +- src/sat/sat_solver.cpp | 75 +- src/sat/sat_solver.h | 5 +- src/sat/sat_types.h | 12 + src/sat/smt/CMakeLists.txt | 4 + src/sat/smt/array_axioms.cpp | 129 ++-- src/sat/smt/array_internalize.cpp | 71 +- src/sat/smt/array_model.cpp | 71 +- src/sat/smt/array_solver.cpp | 265 ++++--- src/sat/smt/array_solver.h | 45 +- src/sat/smt/ba_card.cpp | 290 ++++++++ src/sat/smt/ba_card.h | 70 ++ src/sat/smt/ba_constraint.cpp | 58 ++ src/sat/smt/ba_constraint.h | 143 ++++ src/sat/smt/ba_internalize.cpp | 12 +- src/sat/smt/ba_pb.cpp | 308 ++++++++ src/sat/smt/ba_pb.h | 67 ++ src/sat/smt/ba_solver.cpp | 954 ++++--------------------- src/sat/smt/ba_solver.h | 273 ++----- src/sat/smt/ba_solver_interface.h | 52 ++ src/sat/smt/ba_xor.cpp | 192 +++++ src/sat/smt/ba_xor.h | 53 ++ src/sat/smt/bv_solver.cpp | 4 +- src/sat/smt/euf_ackerman.cpp | 17 +- src/sat/smt/euf_internalize.cpp | 22 +- src/sat/smt/euf_invariant.cpp | 11 + src/sat/smt/euf_proof.cpp | 28 +- src/sat/smt/euf_solver.cpp | 69 +- src/sat/smt/euf_solver.h | 27 +- src/sat/smt/sat_th.cpp | 34 +- src/sat/smt/sat_th.h | 20 +- src/sat/smt/xor_solver.cpp | 157 +--- src/sat/tactic/goal2sat.cpp | 160 +++-- src/shell/drat_frontend.cpp | 163 +++-- src/smt/dyn_ack.cpp | 6 +- src/smt/dyn_ack.h | 4 +- src/smt/fingerprints.h | 10 +- src/smt/params/dyn_ack_params.cpp | 2 +- src/smt/params/dyn_ack_params.h | 4 +- src/smt/params/smt_params.h | 10 +- src/smt/params/theory_arith_params.cpp | 6 +- src/smt/params/theory_arith_params.h | 12 +- src/smt/params/theory_array_params.h | 2 +- src/smt/params/theory_bv_params.h | 2 +- src/smt/seq_regex.h | 2 +- src/smt/smt_setup.cpp | 30 +- src/smt/theory_arith.h | 2 +- src/smt/theory_arith_aux.h | 4 +- src/smt/theory_arith_core.h | 8 +- src/smt/theory_diff_logic_def.h | 4 +- src/smt/theory_lra.cpp | 20 +- src/smt/theory_seq.cpp | 2 +- src/test/egraph.cpp | 2 + src/test/tbv.cpp | 4 +- src/util/cmd_context_types.h | 2 +- src/util/sexpr.cpp | 36 +- src/util/sexpr.h | 14 +- 79 files changed, 2571 insertions(+), 1850 deletions(-) create mode 100644 src/sat/smt/ba_card.cpp create mode 100644 src/sat/smt/ba_card.h create mode 100644 src/sat/smt/ba_constraint.cpp create mode 100644 src/sat/smt/ba_constraint.h create mode 100644 src/sat/smt/ba_pb.cpp create mode 100644 src/sat/smt/ba_pb.h create mode 100644 src/sat/smt/ba_solver_interface.h create mode 100644 src/sat/smt/ba_xor.cpp create mode 100644 src/sat/smt/ba_xor.h diff --git a/src/ast/ast.cpp b/src/ast/ast.cpp index 3e6646299..6d210c92d 100644 --- a/src/ast/ast.cpp +++ b/src/ast/ast.cpp @@ -1785,13 +1785,12 @@ bool ast_manager::slow_not_contains(ast const * n) { } #endif -#if 0 static unsigned s_count = 0; +#if 0 static void track_id(ast_manager& m, ast* n, unsigned id) { if (n->get_id() != id) return; ++s_count; - std::cout << &m << " " << s_count << "\n"; - SASSERT(s_count != 240); + TRACE("ast", tout << s_count << "\n";); } #endif @@ -1825,9 +1824,9 @@ ast * ast_manager::register_node_core(ast * n) { n->m_id = is_decl(n) ? m_decl_id_gen.mk() : m_expr_id_gen.mk(); - // track_id(*this, n, 254); +// track_id(*this, n, 3); - TRACE("ast", tout << "Object " << n->m_id << " was created.\n";); + TRACE("ast", tout << (s_count++) << " Object " << n->m_id << " was created.\n";); TRACE("mk_var_bug", tout << "mk_ast: " << n->m_id << "\n";); // increment reference counters switch (n->get_kind()) { diff --git a/src/ast/euf/euf_egraph.cpp b/src/ast/euf/euf_egraph.cpp index 39e97d7c2..ab3dd0b65 100644 --- a/src/ast/euf/euf_egraph.cpp +++ b/src/ast/euf/euf_egraph.cpp @@ -81,8 +81,10 @@ namespace euf { m_scopes.push_back(m_updates.size()); m_region.push_scope(); } - m_updates.push_back(update_record(m_new_lits_qhead, update_record::new_lits_qhead())); m_updates.push_back(update_record(m_new_th_eqs_qhead, update_record::new_th_eq_qhead())); + m_updates.push_back(update_record(m_new_lits_qhead, update_record::new_lits_qhead())); + SASSERT(m_new_lits_qhead <= m_new_lits.size()); + SASSERT(m_new_th_eqs_qhead <= m_new_th_eqs.size()); } void egraph::update_children(enode* n) { @@ -107,17 +109,23 @@ namespace euf { return n; } enode_bool_pair p = m_table.insert(n); - enode* r = p.first; - if (r == n) { + enode* n2 = p.first; + if (n2 == n) { update_children(n); } else { - SASSERT(r->get_expr() != n->get_expr()); - merge_justification(n, r, justification::congruence(p.second)); - std::swap(n->m_next, r->m_next); - n->m_root = r; - r->inc_class_size(n->class_size()); - push_eq(n, n, r->num_parents()); + merge(n, n2, justification::congruence(p.second)); +#if 0 + SASSERT(n2->get_expr() != n->get_expr()); + SASSERT(n->class_size() == 1); + SASSERT(n->is_root()); + merge_justification(n, n2, justification::congruence(p.second)); + enode* r2 = n2->get_root(); + std::swap(n->m_next, r2->m_next); + n->m_root = r2; + r2->inc_class_size(n->class_size()); + push_eq(n, n, r2->num_parents()); +#endif } return n; } @@ -128,12 +136,14 @@ namespace euf { } void egraph::add_th_eq(theory_id id, theory_var v1, theory_var v2, enode* c, enode* r) { + TRACE("euf_verbose", tout << "eq: " << v1 << " == " << v2 << "\n";); m_new_th_eqs.push_back(th_eq(id, v1, v2, c, r)); m_updates.push_back(update_record(update_record::new_th_eq())); ++m_stats.m_num_th_eqs; } void egraph::add_literal(enode* n, bool is_eq) { + TRACE("euf_verbose", tout << "lit: " << n->get_expr_id() << "\n";); m_new_lits.push_back(enode_bool_pair(n, is_eq)); m_updates.push_back(update_record(update_record::new_lit())); if (is_eq) ++m_stats.m_num_eqs; else ++m_stats.m_num_lits; @@ -186,6 +196,9 @@ namespace euf { return; } num_scopes -= m_num_scopes; + m_num_scopes = 0; + + SASSERT(m_new_lits_qhead <= m_new_lits.size()); unsigned old_lim = m_scopes.size() - num_scopes; unsigned num_updates = m_scopes[old_lim]; auto undo_node = [&](enode* n) { @@ -240,16 +253,19 @@ namespace euf { m_scopes.shrink(old_lim); m_region.pop_scope(num_scopes); m_worklist.reset(); + SASSERT(m_new_lits_qhead <= m_new_lits.size()); + SASSERT(m_new_th_eqs_qhead <= m_new_th_eqs.size()); } - void egraph::merge(enode* n1, enode* n2, justification j) { + void egraph::merge(enode* n1, enode* n2, justification j) { SASSERT(m.get_sort(n1->get_expr()) == m.get_sort(n2->get_expr())); - TRACE("euf", tout << n1->get_expr_id() << " == " << n2->get_expr_id() << "\n";); - force_push(); enode* r1 = n1->get_root(); enode* r2 = n2->get_root(); if (r1 == r2) return; + TRACE("euf", j.display(tout << n1->get_expr_id() << " == " << n2->get_expr_id() << " ", m_display_justification) << "\n";); + force_push(); + SASSERT(m_num_scopes == 0); ++m_stats.m_num_merge; if (r1->interpreted() && r2->interpreted()) { set_conflict(n1, n2, j); @@ -294,6 +310,7 @@ namespace euf { } bool egraph::propagate() { + SASSERT(m_new_lits_qhead <= m_new_lits.size()); SASSERT(m_num_scopes == 0 || m_worklist.empty()); unsigned head = 0, tail = m_worklist.size(); while (head < tail && m.limit().inc() && !inconsistent()) { @@ -311,6 +328,7 @@ namespace euf { tail = m_worklist.size(); } m_worklist.reset(); + force_push(); return (m_new_lits_qhead < m_new_lits.size()) || (m_new_th_eqs_qhead < m_new_th_eqs.size()) || @@ -329,17 +347,27 @@ namespace euf { } void egraph::merge_justification(enode* n1, enode* n2, justification j) { + SASSERT(!n1->get_root()->m_target); + SASSERT(!n2->get_root()->m_target); SASSERT(n1->reaches(n1->get_root())); + SASSERT(!n2->reaches(n1->get_root())); + SASSERT(!n2->reaches(n1)); n1->reverse_justification(); n1->m_target = n2; n1->m_justification = j; + SASSERT(n1->acyclic()); + SASSERT(n2->acyclic()); SASSERT(n1->get_root()->reaches(n1)); + SASSERT(!n2->get_root()->m_target); + TRACE("euf_verbose", tout << "merge " << n1->get_expr_id() << " " << n2->get_expr_id() << " updates: " << m_updates.size() << "\n";); } void egraph::unmerge_justification(enode* n1) { + TRACE("euf_verbose", tout << "unmerge " << n1->get_expr_id() << " " << n1->m_target->get_expr_id() << "\n";); // r1 -> .. -> n1 -> n2 -> ... -> r2 // where n2 = n1->m_target SASSERT(n1->get_root()->reaches(n1)); + SASSERT(n1->m_target); n1->m_target = nullptr; n1->m_justification = justification::axiom(); n1->get_root()->reverse_justification(); @@ -347,7 +375,7 @@ namespace euf { // n1 -> ... -> r1 // n2 -> ... -> r2 SASSERT(n1->reaches(n1->get_root())); - SASSERT(n1->get_root()->m_target == nullptr); + SASSERT(!n1->get_root()->m_target); } /** @@ -359,8 +387,9 @@ namespace euf { void egraph::push_congruence(enode* n1, enode* n2, bool comm) { SASSERT(is_app(n1->get_expr())); SASSERT(n1->get_decl() == n2->get_decl()); - if (m_used_cc) + if (m_used_cc && !comm) { m_used_cc(to_app(n1->get_expr()), to_app(n2->get_expr())); + } if (comm && n1->get_arg(0)->get_root() == n2->get_arg(1)->get_root() && n1->get_arg(1)->get_root() == n2->get_arg(0)->get_root()) { @@ -409,6 +438,7 @@ namespace euf { void egraph::end_explain() { for (enode* n : m_todo) n->unmark1(); + DEBUG_CODE(for (enode* n : m_nodes) SASSERT(!n->is_marked1());); m_todo.reset(); } @@ -424,7 +454,12 @@ namespace euf { template void egraph::explain_eq(ptr_vector& justifications, enode* a, enode* b) { SASSERT(a->get_root() == b->get_root()); + enode* lca = find_lca(a, b); + TRACE("euf_verbose", tout << "explain-eq: " << a->get_expr_id() << " = " << b->get_expr_id() + << ": " << mk_bounded_pp(a->get_expr(), m) + << " == " << mk_bounded_pp(b->get_expr(), m) + << " lca: " << mk_bounded_pp(lca->get_expr(), m) << "\n";); push_to_lca(a, lca); push_to_lca(b, lca); if (m_used_eq) @@ -438,6 +473,7 @@ namespace euf { enode* n = m_todo[i]; if (n->m_target && !n->is_marked1()) { n->mark1(); + CTRACE("euf", m_display_justification, n->m_justification.display(tout << n->get_expr_id() << " = " << n->m_target->get_expr_id() << " ", m_display_justification) << "\n";); explain_eq(justifications, n, n->m_target, n->m_justification); } } @@ -461,21 +497,26 @@ namespace euf { out << "v:" << f->get_id(); out << "\n"; if (!n->m_parents.empty()) { - out << " "; + out << " parents "; for (enode* p : enode_parents(n)) out << p->get_expr_id() << " "; out << "\n"; } if (n->has_th_vars()) { - out << " "; + out << " theories "; for (auto v : enode_th_vars(n)) out << v.get_id() << ":" << v.get_var() << " "; out << "\n"; } + if (n->m_target && m_display_justification) + n->m_justification.display(out << " = " << n->m_target->get_expr_id() << " j: ", m_display_justification) << "\n"; return out; } std::ostream& egraph::display(std::ostream& out) const { + out << "updates " << m_updates.size() << "\n"; + out << "newlits " << m_new_lits.size() << " qhead: " << m_new_lits_qhead << "\n"; + out << "neweqs " << m_new_th_eqs.size() << " qhead: " << m_new_th_eqs_qhead << "\n"; m_table.display(out); unsigned max_args = 0; for (enode* n : m_nodes) diff --git a/src/ast/euf/euf_egraph.h b/src/ast/euf/euf_egraph.h index 581ddba94..e47f3d304 100644 --- a/src/ast/euf/euf_egraph.h +++ b/src/ast/euf/euf_egraph.h @@ -125,7 +125,8 @@ namespace euf { enode_vector m_todo; stats m_stats; std::function m_used_eq; - std::function m_used_cc; + std::function m_used_cc; + std::function m_display_justification; void push_eq(enode* r1, enode* n1, unsigned r2_num_parents) { m_updates.push_back(update_record(r1, n1, r2_num_parents)); @@ -151,6 +152,7 @@ namespace euf { void push_to_lca(enode* a, enode* lca); void push_congruence(enode* n1, enode* n2, bool commutative); void push_todo(enode* n); + template void explain_eq(ptr_vector& justifications, enode* a, enode* b, justification const& j) { if (j.is_external()) @@ -200,8 +202,8 @@ namespace euf { bool has_th_eq() const { return m_new_th_eqs_qhead < m_new_th_eqs.size(); } enode_bool_pair get_literal() const { return m_new_lits[m_new_lits_qhead]; } th_eq get_th_eq() const { return m_new_th_eqs[m_new_th_eqs_qhead]; } - void next_literal() { SASSERT(m_new_lits_qhead < m_new_lits.size()); m_new_lits_qhead++; } - void next_th_eq() { SASSERT(m_new_th_eqs_qhead < m_new_th_eqs.size()); m_new_th_eqs_qhead++; } + void next_literal() { force_push(); SASSERT(m_new_lits_qhead < m_new_lits.size()); m_new_lits_qhead++; } + void next_th_eq() { force_push(); SASSERT(m_new_th_eqs_qhead < m_new_th_eqs.size()); m_new_th_eqs_qhead++; } void add_th_var(enode* n, theory_var v, theory_id id); @@ -209,7 +211,8 @@ namespace euf { void set_used_eq(std::function& used_eq) { m_used_eq = used_eq; } void set_used_cc(std::function& used_cc) { m_used_cc = used_cc; } - + void set_display_justification(std::function & d) { m_display_justification = d; } + void begin_explain(); void end_explain(); template @@ -228,6 +231,8 @@ namespace euf { e_pp pp(enode* n) const { return e_pp(*this, n); } std::ostream& display(std::ostream& out) const; void collect_statistics(statistics& st) const; + + unsigned num_scopes() const { return m_scopes.size() + m_num_scopes; } }; inline std::ostream& operator<<(std::ostream& out, egraph const& g) { return g.display(out); } diff --git a/src/ast/euf/euf_enode.cpp b/src/ast/euf/euf_enode.cpp index b481749fe..538f2174c 100644 --- a/src/ast/euf/euf_enode.cpp +++ b/src/ast/euf/euf_enode.cpp @@ -31,7 +31,8 @@ namespace euf { VERIFY(found_root); VERIFY(found_this); VERIFY(this != m_root || class_size == m_class_size); - if (this == m_root) { + if (is_root()) { + VERIFY(!m_target); for (enode* p : enode_parents(this)) { bool found = false; for (enode* arg : enode_args(p)) { @@ -69,8 +70,24 @@ namespace euf { return true; } + bool enode::acyclic() const { + enode const* n = this; + enode const* p = this; + while (n) { + n = n->m_target; + if (n) { + p = p->m_target; + n = n->m_target; + } + if (n == p) + return false; + } + return true; + } + bool enode::reaches(enode* n) const { enode const* r = this; + SASSERT(acyclic()); while (r) { if (r == n) return true; diff --git a/src/ast/euf/euf_enode.h b/src/ast/euf/euf_enode.h index 68a2048dc..5bc4fff6e 100644 --- a/src/ast/euf/euf_enode.h +++ b/src/ast/euf/euf_enode.h @@ -156,6 +156,7 @@ namespace euf { void reverse_justification(); bool reaches(enode* n) const; + bool acyclic() const; enode* const* begin_parents() const { return m_parents.begin(); } enode* const* end_parents() const { return m_parents.end(); } diff --git a/src/ast/euf/euf_etable.cpp b/src/ast/euf/euf_etable.cpp index 704c0c398..83446c024 100644 --- a/src/ast/euf/euf_etable.cpp +++ b/src/ast/euf/euf_etable.cpp @@ -215,8 +215,6 @@ namespace euf { return enode_bool_pair(n_prime, false); case BINARY: n_prime = UNTAG(binary_table*, t)->insert_if_not_there(n); - TRACE("euf", tout << "insert: " << n->get_expr_id() << " " << cg_binary_hash()(n) << " inserted: " << (n == n_prime) << " " << n_prime->get_expr_id() << "\n"; - display_binary(tout, t); tout << "contains_ptr: " << contains_ptr(n) << "\n";); return enode_bool_pair(n_prime, false); case BINARY_COMM: m_commutativity = false; @@ -236,7 +234,6 @@ namespace euf { UNTAG(unary_table*, t)->erase(n); break; case BINARY: - TRACE("euf", tout << "erase: " << n->get_expr_id() << " " << cg_binary_hash()(n) << " contains: " << contains_ptr(n) << "\n";); UNTAG(binary_table*, t)->erase(n); break; case BINARY_COMM: diff --git a/src/ast/euf/euf_justification.h b/src/ast/euf/euf_justification.h index 20e733de0..2241ff0b6 100644 --- a/src/ast/euf/euf_justification.h +++ b/src/ast/euf/euf_justification.h @@ -20,7 +20,7 @@ Author: namespace euf { class justification { - enum kind_t { + enum class kind_t { axiom_t, congruence_t, external_t @@ -29,20 +29,20 @@ namespace euf { bool m_comm; void* m_external; justification(bool comm): - m_kind(congruence_t), + m_kind(kind_t::congruence_t), m_comm(comm), m_external(nullptr) {} justification(void* ext): - m_kind(external_t), + m_kind(kind_t::external_t), m_comm(false), m_external(ext) {} public: justification(): - m_kind(axiom_t), + m_kind(kind_t::axiom_t), m_comm(false), m_external(nullptr) {} @@ -51,24 +51,43 @@ namespace euf { static justification congruence(bool c) { return justification(c); } static justification external(void* ext) { return justification(ext); } - bool is_external() const { return m_kind == external_t; } - bool is_congruence() const { return m_kind == congruence_t; } + bool is_external() const { return m_kind == kind_t::external_t; } + bool is_congruence() const { return m_kind == kind_t::congruence_t; } bool is_commutative() const { return m_comm; } template T* ext() const { SASSERT(is_external()); return static_cast(m_external); } justification copy(std::function& copy_justification) const { switch (m_kind) { - case external_t: + case kind_t::external_t: return external(copy_justification(m_external)); - case axiom_t: + case kind_t::axiom_t: return axiom(); - case congruence_t: + case kind_t::congruence_t: return congruence(m_comm); default: UNREACHABLE(); return axiom(); } } + + std::ostream& display(std::ostream& out, std::function const& ext) const { + switch (m_kind) { + case kind_t::external_t: + if (ext) + ext(out, m_external); + else + out << "external"; + return out; + case kind_t::axiom_t: + return out << "axiom"; + case kind_t::congruence_t: + return out << "congruence"; + default: + UNREACHABLE(); + return out; + } + return out; + } }; } diff --git a/src/muz/base/dl_rule.h b/src/muz/base/dl_rule.h index c3768899a..7d29b1362 100644 --- a/src/muz/base/dl_rule.h +++ b/src/muz/base/dl_rule.h @@ -287,13 +287,13 @@ namespace datalog { class rule : public accounted_object { friend class rule_manager; - app * m_head; - proof* m_proof; + app* m_head{ nullptr }; + proof* m_proof{ nullptr }; unsigned m_tail_size:20; // unsigned m_reserve:12; - unsigned m_ref_cnt; - unsigned m_positive_cnt; - unsigned m_uninterp_cnt; + unsigned m_ref_cnt{ 0 }; + unsigned m_positive_cnt{ 0 }; + unsigned m_uninterp_cnt{ 0 }; symbol m_name; /** The following field is an array of tagged pointers. diff --git a/src/muz/spacer/spacer_context.cpp b/src/muz/spacer/spacer_context.cpp index 423f0a71c..9e348519b 100644 --- a/src/muz/spacer/spacer_context.cpp +++ b/src/muz/spacer/spacer_context.cpp @@ -2613,7 +2613,7 @@ void context::init_global_smt_params() { m.toggle_proof_mode(PGM_ENABLED); params_ref p; if (!m_use_eq_prop) { - p.set_uint("arith.propagation_mode", BP_NONE); + p.set_uint("arith.propagation_mode", (unsigned)bound_prop_mode::BP_NONE); p.set_bool("arith.auto_config_simplex", true); p.set_bool("arith.propagate_eqs", false); p.set_bool("arith.eager_eq_axioms", false); diff --git a/src/opt/opt_context.cpp b/src/opt/opt_context.cpp index 0e6a4cf55..f354cbb03 100644 --- a/src/opt/opt_context.cpp +++ b/src/opt/opt_context.cpp @@ -666,7 +666,7 @@ namespace opt { opt_params p(m_params); if (p.optsmt_engine() == symbol("symba") || p.optsmt_engine() == symbol("farkas")) { - auto str = std::to_string(AS_OPTINF); + auto str = std::to_string((unsigned)(arith_solver_id::AS_OPTINF)); gparams::set("smt.arith.solver", str.c_str()); } } diff --git a/src/sat/dimacs.cpp b/src/sat/dimacs.cpp index 027186fb7..eb276b9e0 100644 --- a/src/sat/dimacs.cpp +++ b/src/sat/dimacs.cpp @@ -20,7 +20,6 @@ Revision History: #undef max #undef min #include "sat/sat_solver.h" -#include "sat/sat_drat.h" template static bool is_whitespace(Buffer & in) { @@ -147,13 +146,21 @@ bool parse_dimacs(std::istream & in, std::ostream& err, sat::solver & solver) { namespace dimacs { std::ostream& operator<<(std::ostream& out, drat_record const& r) { + std::function fn = [&](int th) { return symbol(th); }; + drat_pp pp(r, fn); + return out << pp; + } + + std::ostream& operator<<(std::ostream& out, drat_pp const& p) { + auto const& r = p.r; + sat::status_pp pp(r.m_status, p.th); switch (r.m_tag) { case drat_record::tag_t::is_clause: - return out << r.m_status << " " << r.m_lits << "\n"; + return out << pp << " " << r.m_lits << " 0\n"; case drat_record::tag_t::is_node: - return out << "e " << r.m_node_id << " " << r.m_name << " " << r.m_args << "\n"; - case drat_record::is_bool_def: - return out << "b " << r.m_node_id << " " << r.m_args << "\n"; + return out << "e " << r.m_node_id << " " << r.m_name << " " << r.m_args << "0\n"; + case drat_record::tag_t::is_bool_def: + return out << "b " << r.m_node_id << " " << r.m_args << "0\n"; } return out; } diff --git a/src/sat/dimacs.h b/src/sat/dimacs.h index eed385556..681f65b1f 100644 --- a/src/sat/dimacs.h +++ b/src/sat/dimacs.h @@ -53,23 +53,27 @@ namespace dimacs { }; struct drat_record { - enum tag_t { is_clause, is_node, is_bool_def }; - tag_t m_tag; + enum class tag_t { is_clause, is_node, is_bool_def }; + tag_t m_tag{ tag_t::is_clause }; // a clause populates m_lits and m_status // a node populates m_node_id, m_name, m_args // a bool def populates m_node_id and one element in m_args sat::literal_vector m_lits; - sat::status m_status; - unsigned m_node_id; + sat::status m_status{ sat::status::redundant() }; + unsigned m_node_id{ 0 }; std::string m_name; unsigned_vector m_args; - drat_record(): - m_tag(is_clause), - m_status(sat::status::redundant()) - {} + drat_record() {} + }; + + struct drat_pp { + drat_record const& r; + std::function& th; + drat_pp(drat_record const& r, std::function& th) : r(r), th(th) {} }; std::ostream& operator<<(std::ostream& out, drat_record const& r); + std::ostream& operator<<(std::ostream& out, drat_pp const& r); class drat_parser { dimacs::stream_buffer in; diff --git a/src/sat/sat_binspr.h b/src/sat/sat_binspr.h index 874d8687f..82a5a128c 100644 --- a/src/sat/sat_binspr.h +++ b/src/sat/sat_binspr.h @@ -92,7 +92,10 @@ namespace sat { public: - binspr(solver& s): m_solver(s), m_stopped_at(0), m_limit1(1000), m_limit2(300) {} + binspr(solver& s): m_solver(s), m_stopped_at(0), m_limit1(1000), m_limit2(300) { + memset(m_true, 0, sizeof(unsigned) * max_lits); + memset(m_false, 0, sizeof(unsigned) * max_lits); + } ~binspr() {} diff --git a/src/sat/sat_drat.cpp b/src/sat/sat_drat.cpp index b924f378a..ede975b50 100644 --- a/src/sat/sat_drat.cpp +++ b/src/sat/sat_drat.cpp @@ -35,7 +35,6 @@ namespace sat { m_activity(false) { if (s.get_config().m_drat && s.get_config().m_drat_file.is_non_empty_string()) { - std::cout << "DRAT " << s.get_config().m_drat_file << "\n"; auto mode = s.get_config().m_drat_binary ? (std::ios_base::binary | std::ios_base::out | std::ios_base::trunc) : std::ios_base::out; m_out = alloc(std::ofstream, s.get_config().m_drat_file.str(), mode); if (s.get_config().m_drat_binary) { @@ -87,17 +86,14 @@ namespace sat { return; if (m_activity && ((m_stats.m_num_add % 1000) == 0)) dump_activity(); - + char buffer[10000]; char digits[20]; // enough for storing unsigned char* lastd = digits + sizeof(digits); unsigned len = 0; - if (st.is_asserted()) { - buffer[len++] = 'a'; - buffer[len++] = ' '; - } - else if (st.is_deleted()) { + + if (st.is_deleted()) { buffer[len++] = 'd'; buffer[len++] = ' '; } @@ -105,9 +101,15 @@ namespace sat { buffer[len++] = 'i'; buffer[len++] = ' '; } - else if (st.is_redundant() && !st.is_sat()) { - buffer[len++] = 'r'; - buffer[len++] = ' '; + else if (!st.is_sat()) { + if (st.is_redundant()) { + buffer[len++] = 'r'; + buffer[len++] = ' '; + } + else if (st.is_asserted()) { + buffer[len++] = 'a'; + buffer[len++] = ' '; + } } if (!st.is_sat()) { @@ -261,7 +263,7 @@ namespace sat { } void drat::def_begin(unsigned n, std::string const& name) { - if (m_out) + if (m_out) (*m_out) << "e " << n << " " << name; } @@ -373,8 +375,33 @@ namespace sat { } } + bool drat::is_drup(unsigned n, literal const* c, literal_vector& units) { + if (m_inconsistent) + return true; + if (n == 0) + return false; + + unsigned num_units = m_units.size(); + for (unsigned i = 0; !m_inconsistent && i < n; ++i) { + declare(c[i]); + assign_propagate(~c[i]); + } + + for (unsigned i = num_units; i < m_units.size(); ++i) { + m_assignment[m_units[i].var()] = l_undef; + } + units.append(m_units.size() - num_units, m_units.c_ptr() + num_units); + m_units.shrink(num_units); + bool ok = m_inconsistent; + m_inconsistent = false; + return ok; + } + bool drat::is_drup(unsigned n, literal const* c) { - if (m_inconsistent || n == 0) return true; + if (m_inconsistent) + return true; + if (n == 0) + return false; unsigned num_units = m_units.size(); for (unsigned i = 0; !m_inconsistent && i < n; ++i) { assign_propagate(~c[i]); @@ -448,6 +475,7 @@ namespace sat { } bool drat::is_drat(unsigned n, literal const* c) { + return false; if (m_inconsistent || n == 0) return true; for (unsigned i = 0; i < n; ++i) @@ -486,7 +514,7 @@ namespace sat { clause& c = *m_proof[i]; unsigned j = 0; for (; j < c.size() && c[j] != ~l; ++j) {} - if (st.is_sat() && j != c.size()) { + if (j != c.size()) { lits.append(j, c.begin()); lits.append(c.size() - j - 1, c.begin() + j + 1); if (!is_drup(lits.size(), lits.c_ptr())) @@ -520,6 +548,8 @@ namespace sat { // s.display(std::cout); std::string line; std::getline(std::cin, line); + exit(0); +#if 0 SASSERT(false); INVOKE_DEBUGGER(); exit(0); @@ -530,6 +560,7 @@ namespace sat { display(tout); s.display(tout);); UNREACHABLE(); +#endif } bool drat::contains(literal c, justification const& j) { @@ -723,6 +754,7 @@ namespace sat { if (m_out) (*m_out) << "0\n"; if (m_bout) bdump(0, nullptr, status::redundant()); if (m_check_unsat) { + verify(0, nullptr); SASSERT(m_inconsistent); } } @@ -756,24 +788,29 @@ namespace sat { } } void drat::add(literal_vector const& lits, status st) { + add(lits.size(), lits.c_ptr(), st); + } + + void drat::add(unsigned sz, literal const* lits, status st) { if (st.is_deleted()) ++m_stats.m_num_del; else ++m_stats.m_num_add; if (m_check) { - switch (lits.size()) { + switch (sz) { case 0: add(); break; case 1: append(lits[0], st); break; default: { - clause* c = m_alloc.mk_clause(lits.size(), lits.c_ptr(), st.is_redundant()); + clause* c = m_alloc.mk_clause(sz, lits, st.is_redundant()); append(*c, st); break; } } } if (m_out) - dump(lits.size(), lits.c_ptr(), st); + dump(sz, lits, st); } + void drat::add(literal_vector const& c) { ++m_stats.m_num_add; if (m_out) dump(c.size(), c.begin(), status::redundant()); @@ -842,7 +879,21 @@ namespace sat { void drat::check_model(model const& m) { } - std::ostream& operator<<(std::ostream& out, status const& st) { + void drat::collect_statistics(statistics& st) const { + st.update("num-drup", m_stats.m_num_drup); + st.update("num-drat", m_stats.m_num_drat); + st.update("num-add", m_stats.m_num_add); + st.update("num-del", m_stats.m_num_del); + } + + + std::ostream& operator<<(std::ostream& out, sat::status const& st) { + std::function th = [&](int id) { return symbol(id); }; + return out << sat::status_pp(st, th); + } + + std::ostream& operator<<(std::ostream& out, sat::status_pp const& p) { + auto st = p.st; if (st.is_deleted()) out << "d"; else if (st.is_input()) @@ -852,15 +903,8 @@ namespace sat { else if (st.is_redundant() && !st.is_sat()) out << "r"; if (!st.is_sat()) - out << " th" << st.m_orig; + out << " " << p.th(st.get_th()); return out; - } - - void drat::collect_statistics(statistics& st) const { - st.update("num-drup", m_stats.m_num_drup); - st.update("num-drat", m_stats.m_num_drat); - st.update("num-add", m_stats.m_num_add); - st.update("num-del", m_stats.m_num_del); - } + } } diff --git a/src/sat/sat_drat.h b/src/sat/sat_drat.h index 57ffab472..d7d6477f2 100644 --- a/src/sat/sat_drat.h +++ b/src/sat/sat_drat.h @@ -123,6 +123,7 @@ namespace sat { void add(clause& c, status st); void add(literal_vector const& c, status st); void add(literal_vector const& c); // add learned clause + void add(unsigned sz, literal const* lits, status st); // support for SMT - connect Boolean variables with AST nodes // associate AST node id with Boolean variable v @@ -156,9 +157,13 @@ namespace sat { void check_model(model const& m); void collect_statistics(statistics& st) const; + + bool inconsistent() const { return m_inconsistent; } + literal_vector const& units() { return m_units; } + bool is_drup(unsigned n, literal const* c, literal_vector& units); + solver& get_solver() { return s; } }; - std::ostream& operator<<(std::ostream& out, status const& st); +} -}; diff --git a/src/sat/sat_elim_eqs.cpp b/src/sat/sat_elim_eqs.cpp index 94aeea996..1c243867d 100644 --- a/src/sat/sat_elim_eqs.cpp +++ b/src/sat/sat_elim_eqs.cpp @@ -63,22 +63,14 @@ namespace sat { // consume tautology continue; } -#if 0 - if (l1 != r1) { - // add half r1 => r2, the other half ~r2 => ~r1 is added when traversing l2 - m_solver.m_watches[(~r1).index()].push_back(watched(r2, it->is_learned())); - continue; - } - it->set_literal(r2); // keep it. -#else if (l1 != r1 || l2 != r2) { if (r1.index() < r2.index()) { + TRACE("elim_eqs", tout << l1 << " " << l2 << " " << r1 << " " << r2 << "\n";); m_new_bin.push_back(bin(r1, r2, it->is_learned())); } continue; } // keep it -#endif } *itprev = *it; itprev++; @@ -233,9 +225,10 @@ namespace sat { if (m_solver.m_cut_simplifier) m_solver.m_cut_simplifier->set_root(v, r); bool set_root = m_solver.set_root(l, r); bool root_ok = !m_solver.is_external(v) || set_root; + TRACE("elim_eqs", tout << l << " " << r << "\n";); if (m_solver.is_assumption(v) || (m_solver.is_external(v) && (m_solver.is_incremental() || !root_ok))) { // cannot really eliminate v, since we have to notify extension of future assignments - if (m_solver.m_config.m_drat && m_solver.m_config.m_drat_file.is_null()) { + if (m_solver.m_config.m_drat) { m_solver.m_drat.add(~l, r, sat::status::redundant()); m_solver.m_drat.add(l, ~r, sat::status::redundant()); } @@ -291,6 +284,7 @@ namespace sat { } void elim_eqs::operator()(union_find<>& uf) { + TRACE("elim_eqs", tout << "before union-find bin\n";); literal_vector roots(m_solver.num_vars(), null_literal); bool_var_vector to_elim; for (unsigned i = m_solver.num_vars(); i-- > 0; ) { @@ -299,6 +293,7 @@ namespace sat { if (idx != l1.index()) { roots[i] = to_literal(idx); to_elim.push_back(i); + TRACE("elim_eqs", tout << "remove " << roots[i] << "\n";); } else { roots[i] = l1; diff --git a/src/sat/sat_extension.h b/src/sat/sat_extension.h index 6db3816f0..71214055d 100644 --- a/src/sat/sat_extension.h +++ b/src/sat/sat_extension.h @@ -25,7 +25,7 @@ Revision History: namespace sat { - enum check_result { + enum class check_result { CR_DONE, CR_CONTINUE, CR_GIVEUP }; @@ -54,11 +54,21 @@ namespace sat { }; class extension { + protected: + bool m_drating { false }; + int m_id { 0 }; public: + extension(int id): m_id(id) {} virtual ~extension() {} - virtual unsigned get_id() const { return 0; } + virtual int get_id() const { return m_id; } virtual void set_solver(solver* s) = 0; virtual void set_lookahead(lookahead* s) {}; + class scoped_drating { + extension& ext; + public: + scoped_drating(extension& e) :ext(e) { ext.m_drating = true; } + ~scoped_drating() { ext.m_drating = false; } + }; virtual void init_search() {} virtual bool propagate(literal l, ext_constraint_idx idx) = 0; virtual bool unit_propagate() = 0; diff --git a/src/sat/sat_lookahead.cpp b/src/sat/sat_lookahead.cpp index 4884f0797..e8fb74051 100644 --- a/src/sat/sat_lookahead.cpp +++ b/src/sat/sat_lookahead.cpp @@ -1308,7 +1308,7 @@ namespace sat { } } break; - case lookahead2: + case lookahead_mode::lookahead2: // this could create a conflict from propagation, but we complete the loop. for (binary const& b : m_ternary[(~l).index()]) { if (sz-- == 0) break; diff --git a/src/sat/sat_lookahead.h b/src/sat/sat_lookahead.h index 1be2e91d5..4b11b5917 100644 --- a/src/sat/sat_lookahead.h +++ b/src/sat/sat_lookahead.h @@ -42,7 +42,7 @@ namespace sat { return out; } - enum lookahead_mode { + enum class lookahead_mode { searching, // normal search lookahead1, // lookahead mode lookahead2 // double lookahead @@ -100,8 +100,8 @@ namespace sat { m_delta_rho = (double)0.7; m_dl_max_iterations = 2; m_tc1_limit = 10000000; - m_reward_type = ternary_reward; - m_cube_cutoff = adaptive_freevars_cutoff; + m_reward_type = reward_t::ternary_reward; + m_cube_cutoff = cutoff_t::adaptive_freevars_cutoff; m_cube_depth = 10; m_cube_fraction = 0.4; m_cube_freevars = 0.8; diff --git a/src/sat/sat_lut_finder.h b/src/sat/sat_lut_finder.h index 2e227e0b5..62f99f7f7 100644 --- a/src/sat/sat_lut_finder.h +++ b/src/sat/sat_lut_finder.h @@ -66,7 +66,9 @@ namespace sat { std::ostream& display_mask(std::ostream& out, uint64_t mask, unsigned sz) const; public: - lut_finder(solver& s) : s(s), m_max_lut_size(5) { } + lut_finder(solver& s) : s(s), m_max_lut_size(5) { + memset(m_masks, 0, sizeof(uint64_t)*7); + } ~lut_finder() {} void set(std::function& f) { m_on_lut = f; } diff --git a/src/sat/sat_parallel.h b/src/sat/sat_parallel.h index 78befb4c0..a162f8f7b 100644 --- a/src/sat/sat_parallel.h +++ b/src/sat/sat_parallel.h @@ -32,8 +32,8 @@ namespace sat { // shared pool of learned clauses. class vector_pool { unsigned_vector m_vectors; - unsigned m_size; - unsigned m_tail; + unsigned m_size{ 0 }; + unsigned m_tail{ 0 }; unsigned_vector m_heads; bool_vector m_at_end; void next(unsigned& index); diff --git a/src/sat/sat_probing.cpp b/src/sat/sat_probing.cpp index c56fee1e0..a485b2a1c 100644 --- a/src/sat/sat_probing.cpp +++ b/src/sat/sat_probing.cpp @@ -82,12 +82,15 @@ namespace sat { else { m_to_assert.reset(); s.push(); + TRACE("sat", tout << "probing " << l << "\n";); s.assign_scoped(l); m_counter--; unsigned old_tr_sz = s.m_trail.size(); s.propagate(false); if (s.inconsistent()) { + TRACE("sat", tout << "probe failed: " << ~l << "\n";); // ~l must be true + s.drat_explain_conflict(); s.pop(1); s.assign_scoped(~l); s.propagate(false); @@ -125,10 +128,14 @@ namespace sat { s.push(); literal l(v, false); s.assign_scoped(l); + TRACE("sat", tout << "probing " << l << "\n";); unsigned old_tr_sz = s.m_trail.size(); s.propagate(false); if (s.inconsistent()) { // ~l must be true + TRACE("sat", tout << "probe failed: " << ~l << "\n"; + s.display(tout);); + s.drat_explain_conflict(); s.pop(1); s.assign_scoped(~l); s.propagate(false); diff --git a/src/sat/sat_simplifier.h b/src/sat/sat_simplifier.h index a0001d74e..8c847a658 100644 --- a/src/sat/sat_simplifier.h +++ b/src/sat/sat_simplifier.h @@ -126,7 +126,7 @@ namespace sat { void init_visited(); void mark_visited(literal l) { m_visited[l.index()] = true; } void unmark_visited(literal l) { m_visited[l.index()] = false; } - bool is_marked(literal l) const { return m_visited[l.index()] != 0; } + void mark_all_but(clause const & c, literal l); void unmark_all(clause const & c); @@ -239,6 +239,8 @@ namespace sat { void propagate_unit(literal l); void subsume(); + bool is_marked(literal l) const { return m_visited[l.index()] != 0; } + }; }; diff --git a/src/sat/sat_solver.cpp b/src/sat/sat_solver.cpp index e67c34b11..967761e31 100644 --- a/src/sat/sat_solver.cpp +++ b/src/sat/sat_solver.cpp @@ -249,7 +249,7 @@ namespace sat { m_watches[2*v+1].reset(); m_assignment[2*v] = l_undef; m_assignment[2*v+1] = l_undef; - m_justification[2*v] = justification(UINT_MAX); + m_justification[v] = justification(UINT_MAX); m_decision[v] = dvar; m_eliminated[v] = false; m_external[v] = ext; @@ -337,6 +337,8 @@ namespace sat { } void solver::set_eliminated(bool_var v, bool f) { + if (m_eliminated[v] && !f) + reset_var(v, m_external[v], m_decision[v]); m_eliminated[v] = f; } @@ -386,9 +388,29 @@ namespace sat { m_stats.m_del_clause++; } + void solver::drat_explain_conflict() { + if (m_config.m_drat && m_ext) { + extension::scoped_drating _sd(*m_ext); + bool unique_max; + m_conflict_lvl = get_max_lvl(m_not_l, m_conflict, unique_max); + resolve_conflict_for_unsat_core(); + } + } + + void solver::drat_log_unit(literal lit, justification j) { + extension::scoped_drating _sd(*m_ext.get()); + if (j.get_kind() == justification::EXT_JUSTIFICATION) + fill_ext_antecedents(lit, j, false); + m_drat.add(lit, m_searching); + } + + void solver::drat_log_clause(unsigned num_lits, literal const* lits, sat::status st) { + m_drat.add(num_lits, lits, st); + } + clause * solver::mk_clause_core(unsigned num_lits, literal * lits, sat::status st) { bool redundant = st.is_redundant(); - TRACE("sat", tout << "mk_clause: " << mk_lits_pp(num_lits, lits) << (redundant?" learned":" aux") << "\n";); + TRACE("sat", tout << "mk_clause: " << mk_lits_pp(num_lits, lits) << (redundant?" learned":" aux") << "\n";); if (!redundant || !st.is_sat()) { unsigned old_sz = num_lits; bool keep = simplify_clause(num_lits, lits); @@ -397,11 +419,9 @@ namespace sat { return nullptr; // clause is equivalent to true. } // if an input clause is simplified, then log the simplified version as learned - if (old_sz > num_lits && m_config.m_drat) { - m_lemma.reset(); - m_lemma.append(num_lits, lits); - m_drat.add(m_lemma, st); - } + if (m_config.m_drat && old_sz > num_lits) + drat_log_clause(num_lits, lits, st); + ++m_stats.m_non_learned_generation; if (!m_searching) { m_mc.add_clause(num_lits, lits); @@ -413,6 +433,8 @@ namespace sat { set_conflict(); return nullptr; case 1: + if (m_config.m_drat && (!st.is_sat() || st.is_input())) + drat_log_clause(num_lits, lits, st); assign_unit(lits[0]); return nullptr; case 2: @@ -493,7 +515,7 @@ namespace sat { VERIFY(ENABLE_TERNARY); m_stats.m_mk_ter_clause++; clause * r = alloc_clause(3, lits, st.is_redundant()); - bool reinit = attach_ter_clause(*r); + bool reinit = attach_ter_clause(*r, st); if (reinit && !st.is_redundant()) push_reinit_stack(*r); if (st.is_redundant()) m_learned.push_back(r); @@ -505,10 +527,10 @@ namespace sat { return r; } - bool solver::attach_ter_clause(clause & c) { + bool solver::attach_ter_clause(clause & c, sat::status st) { VERIFY(ENABLE_TERNARY); bool reinit = false; - if (m_config.m_drat) m_drat.add(c, c.is_learned() ? status::redundant() : status::asserted()); + if (m_config.m_drat) m_drat.add(c, st); TRACE("sat_verbose", tout << c << "\n";); SASSERT(!c.was_removed()); m_watches[(~c[0]).index()].push_back(watched(c[1], c[2])); @@ -604,7 +626,7 @@ namespace sat { SASSERT(c.size() > 2); reinit = false; if (ENABLE_TERNARY && c.size() == 3) - reinit = attach_ter_clause(c); + reinit = attach_ter_clause(c, c.is_learned() ? sat::status::redundant() : sat::status::asserted()); else reinit = attach_nary_clause(c); } @@ -890,7 +912,9 @@ namespace sat { SASSERT(value(l) == l_undef); TRACE("sat_assign_core", tout << l << " " << j << "\n";); if (j.level() == 0) { - if (m_config.m_drat) m_drat.add(l, m_searching); + if (m_config.m_drat) + drat_log_unit(l, j); + j = justification(0); // erase justification for level 0 } else { @@ -1666,12 +1690,12 @@ namespace sat { lbool solver::final_check() { if (m_ext) { switch (m_ext->check()) { - case CR_DONE: + case check_result::CR_DONE: mk_model(); return l_true; - case CR_CONTINUE: + case check_result::CR_CONTINUE: break; - case CR_GIVEUP: + case check_result::CR_GIVEUP: throw abort_solver(); } return l_undef; @@ -2630,8 +2654,9 @@ namespace sat { } if (m_conflict_lvl == 0) { - if (m_config.m_drat && m_ext) - resolve_conflict_for_unsat_core(); + drat_explain_conflict(); + if (m_config.m_drat) + drat_log_clause(0, nullptr, sat::status::redundant()); TRACE("sat", tout << "conflict level is 0\n";); return l_false; } @@ -2883,7 +2908,7 @@ namespace sat { break; } case justification::EXT_JUSTIFICATION: { - fill_ext_antecedents(consequent, js, true); + fill_ext_antecedents(consequent, js, false); for (literal l : m_ext_antecedents) { process_antecedent_for_unsat_core(l); } @@ -2896,7 +2921,7 @@ namespace sat { } void solver::resolve_conflict_for_unsat_core() { - TRACE("sat", display(tout); + TRACE("sat_verbose", display(tout); unsigned level = 0; for (literal l : m_trail) { if (level != lvl(l)) { @@ -2914,7 +2939,7 @@ namespace sat { ); m_core.reset(); - if (m_conflict_lvl == 0) { + if (!m_config.m_drat && m_conflict_lvl == 0) { return; } SASSERT(m_unmark.empty()); @@ -3044,6 +3069,7 @@ namespace sat { bool_var var = antecedent.var(); unsigned var_lvl = lvl(var); SASSERT(var < num_vars()); + TRACE("sat", tout << "process " << var << "@" << var_lvl << " marked " << is_marked(var) << " conflict " << m_conflict_lvl << "\n";); if (!is_marked(var) && var_lvl > 0) { mark(var); switch (m_config.m_branching_heuristic) { @@ -3717,11 +3743,14 @@ namespace sat { } for (literal lit : m_lemma) mark_visited(lit); + auto is_active = [&](bool_var v) { + return value(v) != l_undef && lvl(v) <= new_lvl; + }; unsigned sz = m_active_vars.size(), j = old_num_vars; for (unsigned i = old_num_vars; i < sz; ++i) { bool_var v = m_active_vars[i]; - if (is_visited(v)) { + if (is_visited(v) || is_active(v)) { m_vars_to_reinit.push_back(v); m_active_vars[j++] = v; } @@ -3731,7 +3760,9 @@ namespace sat { } } m_active_vars.shrink(j); - IF_VERBOSE(0, verbose_stream() << "vars to reinit: " << m_vars_to_reinit << " free vars " << m_free_vars << "\n"); + IF_VERBOSE(11, verbose_stream() << "vars to reinit: " << m_vars_to_reinit << " free vars " << m_free_vars << "\n"; + display(verbose_stream());); + } void solver::shrink_vars(unsigned v) { diff --git a/src/sat/sat_solver.h b/src/sat/sat_solver.h index 8ebef6497..1cc88bb91 100644 --- a/src/sat/sat_solver.h +++ b/src/sat/sat_solver.h @@ -284,7 +284,7 @@ namespace sat { void mk_bin_clause(literal l1, literal l2, bool learned) { mk_bin_clause(l1, l2, learned ? sat::status::redundant() : sat::status::asserted()); } bool propagate_bin_clause(literal l1, literal l2); clause * mk_ter_clause(literal * lits, status st); - bool attach_ter_clause(clause & c); + bool attach_ter_clause(clause & c, status st); clause * mk_nary_clause(unsigned num_lits, literal * lits, status st); bool attach_nary_clause(clause & c); void attach_clause(clause & c, bool & reinit); @@ -296,6 +296,9 @@ namespace sat { void add_ate(clause& c) { m_mc.add_ate(c); } void add_ate(literal l1, literal l2) { m_mc.add_ate(l1, l2); } void add_ate(literal_vector const& lits) { m_mc.add_ate(lits); } + void drat_log_unit(literal lit, justification j); + void drat_log_clause(unsigned sz, literal const* lits, status st); + void drat_explain_conflict(); class scoped_disable_checkpoint { solver& s; diff --git a/src/sat/sat_types.h b/src/sat/sat_types.h index 67863c8f3..fef9d74f1 100644 --- a/src/sat/sat_types.h +++ b/src/sat/sat_types.h @@ -26,6 +26,7 @@ Revision History: #include "util/vector.h" #include "util/uint_set.h" #include "util/stopwatch.h" +#include "util/symbol.h" class params_ref; class reslimit; @@ -278,5 +279,16 @@ namespace sat { int get_th() const { return m_orig; } }; + struct status_pp { + status const& st; + std::function& th; + status_pp(status const& st, std::function& th) : st(st), th(th) {} + }; + + std::ostream& operator<<(std::ostream& out, sat::status const& st); + std::ostream& operator<<(std::ostream& out, sat::status_pp const& p); + }; + + diff --git a/src/sat/smt/CMakeLists.txt b/src/sat/smt/CMakeLists.txt index 775a35186..747e7752c 100644 --- a/src/sat/smt/CMakeLists.txt +++ b/src/sat/smt/CMakeLists.txt @@ -5,8 +5,12 @@ z3_add_component(sat_smt array_model.cpp array_solver.cpp atom2bool_var.cpp + ba_card.cpp + ba_constraint.cpp ba_internalize.cpp + ba_pb.cpp ba_solver.cpp + ba_xor.cpp bv_ackerman.cpp bv_internalize.cpp bv_solver.cpp diff --git a/src/sat/smt/array_axioms.cpp b/src/sat/smt/array_axioms.cpp index 21326764e..5c48b22cf 100644 --- a/src/sat/smt/array_axioms.cpp +++ b/src/sat/smt/array_axioms.cpp @@ -27,6 +27,8 @@ namespace array { m_axiom_trail.push_back(r); if (m_axioms.contains(idx)) m_axiom_trail.pop_back(); + else + ctx.push(push_back_vector>(m_axiom_trail)); } bool solver::assert_axiom(unsigned idx) { @@ -39,10 +41,16 @@ namespace array { app* select; switch (r.m_kind) { case axiom_record::kind_t::is_store: + TRACE("array", tout << "store-axiom: " << mk_bounded_pp(child, m, 2) << "\n";); return assert_store_axiom(to_app(child)); case axiom_record::kind_t::is_select: select = r.select->get_app(); SASSERT(a.is_select(select)); + SASSERT(can_beta_reduce(r.n)); + TRACE("array", tout << "select-axiom: " << mk_bounded_pp(select, m, 2) << " " << mk_bounded_pp(child, m, 2) << "\n";); + if (r.select->get_arg(0)->get_root() != r.n->get_root()) { + IF_VERBOSE(0, verbose_stream() << "could delay " << mk_pp(select, m) << " " << mk_pp(child, m) << "\n"); + } if (a.is_const(child)) return assert_select_const_axiom(select, to_app(child)); else if (a.is_as_array(child)) @@ -57,20 +65,22 @@ namespace array { UNREACHABLE(); break; case axiom_record::kind_t::is_default: + SASSERT(can_beta_reduce(r.n)); + TRACE("array", tout << "default-axiom: " << mk_bounded_pp(child, m, 2) << "\n";); if (a.is_const(child)) return assert_default_const_axiom(to_app(child)); else if (a.is_store(child)) return assert_default_store_axiom(to_app(child)); else if (a.is_map(child)) return assert_default_map_axiom(to_app(child)); - else if (a.is_as_array(child)) - return assert_default_as_array_axiom(to_app(child)); else - UNREACHABLE(); + return true; break; case axiom_record::kind_t::is_extensionality: + TRACE("array", tout << "extensionality-axiom: " << mk_bounded_pp(child, m, 2) << "\n";); return assert_extensionality(r.n->get_arg(0)->get_expr(), r.n->get_arg(1)->get_expr()); case axiom_record::kind_t::is_congruence: + TRACE("array", tout << "congruence-axiom: " << mk_bounded_pp(child, m, 2) << " " << mk_bounded_pp(r.select->get_expr(), m, 2) << "\n";); return assert_congruent_axiom(child, r.select->get_expr()); default: UNREACHABLE(); @@ -86,7 +96,7 @@ namespace array { * n := store(a, i, v) */ bool solver::assert_store_axiom(app* e) { - m_stats.m_num_store_axiom++; + ++m_stats.m_num_store_axiom; SASSERT(a.is_store(e)); unsigned num_args = e->get_num_args(); ptr_vector sel_args(num_args - 1, e->get_args()); @@ -104,7 +114,7 @@ namespace array { * where i = (i_1, ..., i_n), j = (j_1, .., j_n), k in 1..n */ bool solver::assert_select_store_axiom(app* select, app* store) { - m_stats.m_num_select_store_axiom++; + ++m_stats.m_num_select_store_axiom; SASSERT(a.is_store(store)); SASSERT(a.is_select(select)); SASSERT(store->get_num_args() == 1 + select->get_num_args()); @@ -126,7 +136,10 @@ namespace array { if (s1->get_root() == s2->get_root()) return false; sat::literal sel_eq = b_internalize(sel_eq_e); + if (s().value(sel_eq) == l_true) + return false; + bool new_prop = false; for (unsigned i = 1; i < num_args; i++) { expr* idx1 = store->get_arg(i); expr* idx2 = select->get_arg(i); @@ -135,13 +148,15 @@ namespace array { if (r1 == r2) continue; if (m.are_distinct(r1->get_expr(), r2->get_expr())) { + new_prop = true; add_clause(sel_eq); break; } sat::literal idx_eq = b_internalize(m.mk_eq(idx1, idx2)); - add_clause(idx_eq, sel_eq); + if (add_clause(idx_eq, sel_eq)) + new_prop = true; } - return true; + return new_prop; } /** @@ -149,7 +164,7 @@ namespace array { * select(const(v), i) = v */ bool solver::assert_select_const_axiom(app* select, app* cnst) { - m_stats.m_num_select_const_axiom++; + ++m_stats.m_num_select_const_axiom; expr* val = nullptr; VERIFY(a.is_const(cnst, val)); SASSERT(a.is_select(select)); @@ -167,7 +182,7 @@ namespace array { * e1 = e2 or select(e1, diff(e1,e2)) != select(e2, diff(e1, e2)) */ bool solver::assert_extensionality(expr* e1, expr* e2) { - m_stats.m_num_extensionality_axiom++; + ++m_stats.m_num_extensionality_axiom; func_decl_ref_vector* funcs = nullptr; VERIFY(m_sort2diff.find(m.get_sort(e1), funcs)); expr_ref_vector args1(m), args2(m); @@ -184,10 +199,7 @@ namespace array { expr_ref sel1_eq_sel2(m.mk_eq(sel1, sel2), m); literal lit1 = b_internalize(n1_eq_n2); literal lit2 = b_internalize(sel1_eq_sel2); - if (s().value(lit1) == l_true || s().value(lit2) == l_false) - return false; - add_clause(lit1, ~lit2); - return true; + return add_clause(lit1, ~lit2); } /** @@ -195,17 +207,12 @@ namespace array { * select(map[f](a, ... d), i) = f(select(a,i),...,select(d,i)) */ bool solver::assert_select_map_axiom(app* select, app* map) { - m_stats.m_num_select_map_axiom++; + ++m_stats.m_num_select_map_axiom; SASSERT(a.is_map(map)); SASSERT(a.is_select(select)); SASSERT(map->get_num_args() > 0); func_decl* f = a.get_map_func_decl(map); - - TRACE("array", - tout << mk_bounded_pp(map, m) << "\n"; - tout << mk_bounded_pp(select, m) << "\n";); unsigned num_args = select->get_num_args(); - unsigned num_arrays = map->get_num_args(); ptr_buffer args1, args2; vector > args2l; args1.push_back(map); @@ -238,7 +245,7 @@ namespace array { * select(as-array f, i_1, ..., i_n) = (f i_1 ... i_n) */ bool solver::assert_select_as_array_axiom(app* select, app* arr) { - m_stats.m_num_select_as_array_axiom++; + ++m_stats.m_num_select_as_array_axiom; SASSERT(a.is_as_array(arr)); SASSERT(a.is_select(select)); unsigned num_args = select->get_num_args(); @@ -257,39 +264,31 @@ namespace array { * default(map[f](a,..,d)) = f(default(a),..,default(d)) */ bool solver::assert_default_map_axiom(app* map) { - m_stats.m_num_default_map_axiom++; + ++m_stats.m_num_default_map_axiom; SASSERT(a.is_map(map)); func_decl* f = a.get_map_func_decl(map); SASSERT(map->get_num_args() == f->get_arity()); - ptr_buffer args2; + expr_ref_vector args2(m); for (expr* arg : *map) args2.push_back(a.mk_default(arg)); - expr_ref def1(a.mk_default(map), m); expr_ref def2(m.mk_app(f, args2), m); rewrite(def2); return ctx.propagate(e_internalize(def1), e_internalize(def2), array_axiom()); } - /** * Assert: * default(const(e)) = e */ bool solver::assert_default_const_axiom(app* cnst) { - m_stats.m_num_default_const_axiom++; + ++m_stats.m_num_default_const_axiom; expr* val = nullptr; VERIFY(a.is_const(cnst, val)); - TRACE("array", tout << mk_bounded_pp(cnst, m) << "\n";); expr_ref def(a.mk_default(cnst), m); return ctx.propagate(expr2enode(val), e_internalize(def), array_axiom()); } - bool solver::assert_default_as_array_axiom(app* as_array) { - // no-op - return false; - } - /** * let n := store(a, i, v) @@ -303,19 +302,15 @@ namespace array { * default(n) = default(a) */ bool solver::assert_default_store_axiom(app* store) { - m_stats.m_num_default_store_axiom++; + ++m_stats.m_num_default_store_axiom; SASSERT(a.is_store(store)); SASSERT(store->get_num_args() >= 3); expr_ref def1(m), def2(m); bool prop = false; - unsigned num_args = store->get_num_args(); - def1 = a.mk_default(store); def2 = a.mk_default(store->get_arg(0)); - bool is_new = false; - if (has_unitary_domain(store)) { def2 = store->get_arg(num_args - 1); } @@ -357,6 +352,7 @@ namespace array { * Assert select(lambda xs . M, N1,.., Nk) -> M[N1/x1, ..., Nk/xk] */ bool solver::assert_select_lambda_axiom(app* select, expr* lambda) { + ++m_stats.m_num_select_lambda_axiom; SASSERT(is_lambda(lambda)); SASSERT(a.is_select(select)); SASSERT(m.get_sort(lambda) == m.get_sort(select->get_arg(0))); @@ -373,8 +369,8 @@ namespace array { */ bool solver::assert_congruent_axiom(expr* e1, expr* e2) { ++m_stats.m_num_congruence_axiom; - sort* s = m.get_sort(e1); - unsigned dimension = get_array_arity(s); + sort* srt = m.get_sort(e1); + unsigned dimension = get_array_arity(srt); expr_ref n1_eq_n2(m.mk_eq(e1, e2), m); expr_ref_vector args1(m), args2(m); args1.push_back(e1); @@ -382,10 +378,10 @@ namespace array { svector names; sort_ref_vector sorts(m); for (unsigned i = 0; i < dimension; i++) { - sort * srt = get_array_domain(s, i); - sorts.push_back(srt); + sort * asrt = get_array_domain(srt, i); + sorts.push_back(asrt); names.push_back(symbol(i)); - expr * k = m.mk_var(dimension - i - 1, srt); + expr * k = m.mk_var(dimension - i - 1, asrt); args1.push_back(k); args2.push_back(k); } @@ -395,8 +391,8 @@ namespace array { expr_ref q(m.mk_forall(dimension, sorts.c_ptr(), names.c_ptr(), eq), m); rewrite(q); sat::literal fa_eq = b_internalize(q); - add_clause(~b_internalize(n1_eq_n2), fa_eq); - return true; + sat::literal neq = b_internalize(n1_eq_n2); + return add_clause(~neq, fa_eq); } bool solver::has_unitary_domain(app* array_term) { @@ -411,7 +407,7 @@ namespace array { return true; } - bool solver::has_large_domain(app* array_term) { + bool solver::has_large_domain(expr* array_term) { SASSERT(a.is_array(array_term)); sort* s = m.get_sort(array_term); unsigned dim = get_array_arity(s); @@ -429,7 +425,6 @@ namespace array { return false; } - std::pair solver::mk_epsilon(sort* s) { app* eps = nullptr; func_decl* diag = nullptr; @@ -444,24 +439,16 @@ namespace array { return std::make_pair(eps, diag); } - void solver::push_parent_select_store_axioms(theory_var v) { - expr* e = var2expr(v); - if (!a.is_array(e)) - return; - auto& d = get_var_data(v); - for (euf::enode* store : d.m_parents) - if (a.is_store(store->get_expr())) - for (euf::enode* sel : d.m_parents) - if (a.is_select(sel->get_expr())) - push_axiom(select_axiom(sel, store)); - } - bool solver::add_delayed_axioms() { if (!get_config().m_array_delay_exp_axiom) return false; unsigned num_vars = get_num_vars(); - for (unsigned v = 0; v < num_vars; v++) - push_parent_select_store_axioms(v); + for (unsigned v = 0; v < num_vars; v++) { + propagate_parent_select_axioms(v); + auto& d = get_var_data(v); + if (d.m_prop_upward) + propagate_parent_default(v); + } return unit_propagate(); } @@ -471,13 +458,15 @@ namespace array { bool prop = false; for (unsigned i = roots.size(); i-- > 0; ) { theory_var v1 = roots[i]; - euf::enode* n1 = var2enode(v1); + expr* e1 = var2expr(v1); for (unsigned j = i; j-- > 0; ) { theory_var v2 = roots[j]; - euf::enode* n2 = var2enode(v2); - if (m.get_sort(n1->get_expr()) != m.get_sort(n2->get_expr())) + expr* e2 = var2expr(v2); + if (m.get_sort(e1) != m.get_sort(e2)) continue; - expr_ref eq(m.mk_eq(n1->get_expr(), n2->get_expr()), m); + if (have_different_model_values(v1, v2)) + continue; + expr_ref eq(m.mk_eq(e1, e2), m); sat::literal lit = b_internalize(eq); if (s().value(lit) == l_undef) prop = true; @@ -498,15 +487,10 @@ namespace array { if (r->is_marked1()) { continue; } - // arrays used as indices in other arrays have to be treated as shared. - // issue #3532, #3529 - // - if (ctx.is_shared(r) || is_select_arg(r)) { - TRACE("array", tout << "new shared var: #" << r->get_expr_id() << "\n";); - theory_var r_th_var = r->get_th_var(get_id()); - SASSERT(r_th_var != euf::null_theory_var); - roots.push_back(r_th_var); - } + // arrays used as indices in other arrays have to be treated as shared issue #3532, #3529 + if (ctx.is_shared(r) || is_select_arg(r)) + roots.push_back(r->get_th_var(get_id())); + r->mark1(); to_unmark.push_back(r); } @@ -516,6 +500,7 @@ namespace array { } bool solver::is_select_arg(euf::enode* r) { + SASSERT(r->is_root()); for (euf::enode* n : euf::enode_parents(r)) if (a.is_select(n->get_expr())) for (unsigned i = 1; i < n->num_args(); ++i) diff --git a/src/sat/smt/array_internalize.cpp b/src/sat/smt/array_internalize.cpp index 5006537af..c0d40727d 100644 --- a/src/sat/smt/array_internalize.cpp +++ b/src/sat/smt/array_internalize.cpp @@ -20,13 +20,15 @@ Author: namespace array { - sat::literal solver::internalize(expr* e, bool sign, bool root, bool learned) { - // TODO - return sat::null_literal; + sat::literal solver::internalize(expr* e, bool sign, bool root, bool redundant) { + SASSERT(m.is_bool(e)); + if (!visit_rec(m, e, sign, root, redundant)) + return sat::null_literal; + return expr2literal(e); } void solver::internalize(expr* e, bool redundant) { - // TODO + visit_rec(m, e, false, false, redundant); } euf::theory_var solver::mk_var(euf::enode* n) { @@ -39,8 +41,11 @@ namespace array { void solver::ensure_var(euf::enode* n) { theory_var v = n->get_th_var(get_id()); - if (v == euf::null_theory_var) + if (v == euf::null_theory_var) { mk_var(n); + if (is_lambda(n->get_expr())) + internalize_lambda(n); + } } void solver::apply_sort_cnstr(euf::enode * n, sort * s) { @@ -48,19 +53,30 @@ namespace array { } void solver::internalize_store(euf::enode* n) { - if (get_config().m_array_laziness == 0) - add_parent(n->get_arg(0), n); + add_parent_lambda(n->get_arg(0)->get_th_var(get_id()), n); push_axiom(store_axiom(n)); + add_lambda(n->get_th_var(get_id()), n); + SASSERT(!get_var_data(n->get_th_var(get_id())).m_prop_upward); + } + + void solver::internalize_map(euf::enode* n) { + for (auto* arg : euf::enode_args(n)) { + add_parent_lambda(arg->get_th_var(get_id()), n); + set_prop_upward(arg); + } + push_axiom(default_axiom(n)); + add_lambda(n->get_th_var(get_id()), n); + SASSERT(!get_var_data(n->get_th_var(get_id())).m_prop_upward); + } + + void solver::internalize_lambda(euf::enode* n) { + set_prop_upward(n); + push_axiom(default_axiom(n)); + add_lambda(n->get_th_var(get_id()), n); } void solver::internalize_select(euf::enode* n) { - if (get_config().m_array_laziness == 0) - add_parent(n->get_arg(0), n); - } - - void solver::internalize_const(euf::enode* n) { - push_axiom(default_axiom(n)); - set_prop_upward(n); + add_parent_select(n->get_arg(0)->get_th_var(get_id()), n); } void solver::internalize_ext(euf::enode* n) { @@ -68,24 +84,10 @@ namespace array { } void solver::internalize_default(euf::enode* n) { - add_parent(n->get_arg(0), n); + add_parent_default(n->get_arg(0)->get_th_var(get_id()), n); set_prop_upward(n); } - void solver::internalize_map(euf::enode* n) { - for (auto* arg : euf::enode_args(n)) { - add_parent(arg, n); - set_prop_upward(arg); - } - push_axiom(default_axiom(n)); - } - - void solver::internalize_as_array(euf::enode* n) { - // TBD: delay verdict whether model is undetermined - ctx.unhandled_function(n->get_decl()); - push_axiom(default_axiom(n)); - } - bool solver::visited(expr* e) { euf::enode* n = expr2enode(e); return n && n->is_attached_to(get_id()); @@ -94,7 +96,8 @@ namespace array { bool solver::visit(expr* e) { if (!is_app(e) || to_app(e)->get_family_id() != get_id()) { ctx.internalize(e, m_is_redundant); - ensure_var(expr2enode(e)); + euf::enode* n = expr2enode(e); + ensure_var(n); return true; } m_stack.push_back(sat::eframe(e)); @@ -108,7 +111,7 @@ namespace array { if (!n) n = mk_enode(e, false); SASSERT(!n->is_attached_to(get_id())); - theory_var v = mk_var(n); + mk_var(n); for (auto* arg : euf::enode_args(n)) ensure_var(arg); switch (a->get_decl_kind()) { @@ -118,8 +121,9 @@ namespace array { case OP_SELECT: internalize_select(n); break; + case OP_AS_ARRAY: case OP_CONST_ARRAY: - internalize_const(n); + internalize_lambda(n); break; case OP_ARRAY_EXT: internalize_ext(n); @@ -130,9 +134,6 @@ namespace array { case OP_ARRAY_MAP: internalize_map(n); break; - case OP_AS_ARRAY: - internalize_as_array(n); - break; case OP_SET_UNION: case OP_SET_INTERSECT: case OP_SET_DIFFERENCE: diff --git a/src/sat/smt/array_model.cpp b/src/sat/smt/array_model.cpp index 7e18dfa7f..5b56a0476 100644 --- a/src/sat/smt/array_model.cpp +++ b/src/sat/smt/array_model.cpp @@ -53,7 +53,7 @@ namespace array { mdl.register_decl(f, fi); for (euf::enode* p : euf::enode_parents(n)) { - if (!a.is_select(p->get_expr())) + if (!a.is_select(p->get_expr()) || p->get_arg(0)->get_root() != n->get_root()) continue; args.reset(); for (unsigned i = 1; i < p->num_args(); ++i) @@ -74,4 +74,73 @@ namespace array { values.set(n->get_root_id(), m.mk_app(get_id(), OP_AS_ARRAY, 1, &p)); } + + bool solver::have_different_model_values(theory_var v1, theory_var v2) { + euf::enode* else1 = nullptr, * else2 = nullptr; + euf::enode* n1 = var2enode(v1), *n2 = var2enode(v2); + euf::enode* r1 = n1->get_root(), * r2 = n2->get_root(); + expr* e1 = n1->get_expr(); + expr* e; + if (!a.is_array(e1)) + return true; + auto find_else = [&](theory_var v, euf::enode* r) { + var_data& d = get_var_data(find(v)); + for (euf::enode* c : d.m_lambdas) + if (a.is_const(c->get_expr(), e)) + return expr2enode(e)->get_root(); + for (euf::enode* p : euf::enode_parents(r)) + for (euf::enode* pe : euf::enode_class(p)) + if (a.is_default(pe->get_expr())) + return pe->get_root(); + return (euf::enode*)nullptr; + }; + else1 = find_else(v1, r1); + else2 = find_else(v2, r2); + if (else1 && else2 && else1->get_root() != else2->get_root() && has_large_domain(e1)) + return true; + struct eq { + solver& s; + eq(solver& s) :s(s) {} + bool operator()(euf::enode* n1, euf::enode* n2) const { + SASSERT(s.a.is_select(n1->get_expr())); + SASSERT(s.a.is_select(n2->get_expr())); + for (unsigned i = n1->num_args(); i-- > 1; ) + if (n1->get_arg(i)->get_root() != n2->get_arg(i)->get_root()) + return false; + return true; + } + }; + struct hash { + solver& s; + hash(solver& s) :s(s) {} + unsigned operator()(euf::enode* n) const { + SASSERT(s.a.is_select(n->get_expr())); + unsigned h = 33; + for (unsigned i = n->num_args(); i-- > 1; ) + h = hash_u_u(h, n->get_arg(i)->get_root_id()); + return h; + } + }; + eq eq_proc(*this); + hash hash_proc(*this); + hashtable table(DEFAULT_HASHTABLE_INITIAL_CAPACITY, hash_proc, eq_proc); + euf::enode* p2 = nullptr; + auto maps_diff = [&](euf::enode* p, euf::enode* else_, euf::enode* r) { + return table.find(p, p2) ? p2->get_root() != r : (else_ && else_ != r); + }; + auto table_diff = [&](euf::enode* r1, euf::enode* r2, euf::enode* else1) { + table.reset(); + for (euf::enode* p : euf::enode_parents(r1)) + if (a.is_select(p->get_expr()) && r1 == p->get_arg(0)->get_root()) + table.insert(p); + for (euf::enode* p : euf::enode_parents(r2)) + if (a.is_select(p->get_expr()) && r2 == p->get_arg(0)->get_root()) + if (maps_diff(p, else1, p->get_root())) + return true; + return false; + }; + + return table_diff(r1, r2, else1) || table_diff(r2, r1, else2); + } + } diff --git a/src/sat/smt/array_solver.cpp b/src/sat/smt/array_solver.cpp index 23a7c95d5..f2bbb909c 100644 --- a/src/sat/smt/array_solver.cpp +++ b/src/sat/smt/array_solver.cpp @@ -13,6 +13,54 @@ Author: Nikolaj Bjorner (nbjorner) 2020-09-08 +Notes: + +A node n has attribtes: + + parent_selects: { A[i] | A ~ n } + parent_lambdas: { store(A,i,v) | A ~ n } u { map(f, .., A, ..) | A ~ n } + lambdas: { const(v) | const(v) ~ n } + u { map(f,..) | map(f,..) ~ n } + u { store(A,i,v) | store(A,i,v) ~ n } + u { as-array(f) | as-array(f) ~ n } + +The attributes are used for propagation. +When n1 is merged with n2, and n1 is the new root, the attributes from n2 are added to n1. +The merge also looks for new redexes. + +Let A[j] in parent_selects(n2) : + + lambda in parent_lambdas(n1) + ------------------------------- + lambda[j] = beta-reduce(lambda[j]) + + lambda in lambdas(n1) + ------------------------------- + lambda[j] = beta-reduce(lambda[j]) + +Beta reduction rules are: + beta-reduce(store(A,j,v)[i]) = if(i = j, v, A[j]) + beta-reduce(map(f,A,B)[i]) = f(A[i],B[i]) + beta-reduce(as-array(f)[i]) = f(i) + beta-reduce(const(v)[i]) = v + beta-reduce((lambda x M[x])[i]) = M[i] + +For enforcing + store(A,j,v)[i] = beta-reduce(store(A,j,v)[i]) + + only the following axiom is instantiated: + - i = j or store(A,j,v)[i] = A[i] + +The other required axiom, store(A,j,v)[j] = v +is added eagerly whenever store(A,j,v) is created. + +Current setup: to enforce extensionality on lambdas, +also currently, as a base-line it is eager: + + A ~ B, A = lambda x. M[x] + ------------------------------- + A = B => forall i . M[i] = B[i] + --*/ #include "ast/ast_ll_pp.h" @@ -21,7 +69,7 @@ Author: namespace array { - solver::solver(euf::solver& ctx, theory_id id): + solver::solver(euf::solver& ctx, theory_id id) : th_euf_solver(ctx, id), a(m), m_sort2epsilon(m), @@ -36,20 +84,16 @@ namespace array { } sat::check_result solver::check() { - flet _is_redundant(m_is_redundant, true); + // flet _is_redundant(m_is_redundant, true); bool turn[2] = { false, false }; turn[s().rand()(2)] = true; for (unsigned idx = 0; idx < 2; ++idx) { - if (turn[idx]) { - if (add_delayed_axioms()) - return sat::CR_CONTINUE; - } - else { - if (add_interface_equalities()) - return sat::CR_CONTINUE; - } + if (turn[idx] && add_delayed_axioms()) + return sat::check_result::CR_CONTINUE; + else if (!turn[idx] && add_interface_equalities()) + return sat::check_result::CR_CONTINUE; } - return sat::CR_DONE; + return sat::check_result::CR_DONE; } void solver::push() { @@ -57,38 +101,52 @@ namespace array { } void solver::pop(unsigned n) { - n = lazy_pop(n); + n = lazy_pop(n); if (n == 0) return; m_var_data.resize(get_num_vars()); } - std::ostream& solver::display(std::ostream& out) const { + std::ostream& solver::display(std::ostream& out) const { for (unsigned i = 0; i < get_num_vars(); ++i) { + auto& d = get_var_data(i); out << var2enode(i)->get_expr_id() << " " << mk_bounded_pp(var2expr(i), m, 2) << "\n"; + display_info(out, "parent beta", d.m_parent_lambdas); + display_info(out, "parent select", d.m_parent_selects); + display_info(out, "beta ", d.m_lambdas); } - return out; + return out; + } + std::ostream& solver::display_info(std::ostream& out, char const* id, euf::enode_vector const& v) const { + if (v.empty()) + return out; + out << id << ": "; + for (euf::enode* p : v) + out << mk_bounded_pp(p->get_expr(), m, 2) << " "; + out << "\n"; + return out; } std::ostream& solver::display_justification(std::ostream& out, sat::ext_justification_idx idx) const { return out; } std::ostream& solver::display_constraint(std::ostream& out, sat::ext_constraint_idx idx) const { return out; } void solver::collect_statistics(statistics& st) const { - st.update("array store", m_stats.m_num_store_axiom); - st.update("array sel/store", m_stats.m_num_select_store_axiom); - st.update("array sel/const", m_stats.m_num_select_const_axiom); - st.update("array sel/map", m_stats.m_num_select_map_axiom); + st.update("array store", m_stats.m_num_store_axiom); + st.update("array sel/store", m_stats.m_num_select_store_axiom); + st.update("array sel/const", m_stats.m_num_select_const_axiom); + st.update("array sel/map", m_stats.m_num_select_map_axiom); st.update("array sel/as array", m_stats.m_num_select_as_array_axiom); - st.update("array def/map", m_stats.m_num_default_map_axiom); - st.update("array def/const", m_stats.m_num_default_const_axiom); - st.update("array def/store", m_stats.m_num_default_store_axiom); - st.update("array ext ax", m_stats.m_num_extensionality_axiom); - st.update("array cong ax", m_stats.m_num_congruence_axiom); - st.update("array exp ax2", m_stats.m_num_select_store_axiom_delayed); - st.update("array splits", m_stats.m_num_eq_splits); + st.update("array sel/lambda", m_stats.m_num_select_lambda_axiom); + st.update("array def/map", m_stats.m_num_default_map_axiom); + st.update("array def/const", m_stats.m_num_default_const_axiom); + st.update("array def/store", m_stats.m_num_default_store_axiom); + st.update("array ext ax", m_stats.m_num_extensionality_axiom); + st.update("array cong ax", m_stats.m_num_congruence_axiom); + st.update("array exp ax2", m_stats.m_num_select_store_axiom_delayed); + st.update("array splits", m_stats.m_num_eq_splits); } - euf::th_solver* solver::fresh(sat::solver* s, euf::solver& ctx) { + euf::th_solver* solver::fresh(sat::solver* s, euf::solver& ctx) { auto* result = alloc(solver, ctx, get_id()); ast_translation tr(m, ctx.get_manager()); for (unsigned i = 0; i < get_num_vars(); ++i) { @@ -97,21 +155,21 @@ namespace array { euf::enode* n = ctx.get_enode(e2); result->mk_var(n); } - return result; + return result; } void solver::new_eq_eh(euf::th_eq const& eq) { m_find.merge(eq.m_v1, eq.m_v2); } - bool solver::unit_propagate() { + bool solver::unit_propagate() { if (m_qhead == m_axiom_trail.size()) return false; bool prop = false; ctx.push(value_trail(m_qhead)); - for (; m_qhead < m_axiom_trail.size() && !s().inconsistent(); ++m_qhead) + for (; m_qhead < m_axiom_trail.size() && !s().inconsistent(); ++m_qhead) if (assert_axiom(m_qhead)) - prop = true; + prop = true; return prop; } @@ -121,76 +179,97 @@ namespace array { SASSERT(n1->get_root() == n2->get_root()); SASSERT(n1->is_root() || n2->is_root()); SASSERT(v1 == find(v1)); - expr* e1 = n1->get_expr(); expr* e2 = n2->get_expr(); auto& d1 = get_var_data(v1); auto& d2 = get_var_data(v2); - if (d2.m_prop_upward && !d1.m_prop_upward) + if (d2.m_prop_upward && !d1.m_prop_upward) set_prop_upward(v1); - if (a.is_array(e1)) - for (euf::enode* parent : d2.m_parents) { - add_parent(v1, parent); - if (a.is_store(parent->get_expr())) - add_store(v1, parent); - } + for (euf::enode* lambda : d2.m_lambdas) + add_lambda(v1, lambda); + for (euf::enode* lambda : d2.m_parent_lambdas) + add_parent_lambda(v1, lambda); + for (euf::enode* select : d2.m_parent_selects) + add_parent_select(v1, select); if (is_lambda(e1) || is_lambda(e2)) push_axiom(congruence_axiom(n1, n2)); } - void solver::unmerge_eh(theory_var v1, theory_var v2) { - auto& p1 = get_var_data(v1).m_parents; - auto& p2 = get_var_data(v2).m_parents; - p1.shrink(p1.size() - p2.size()); + void solver::tracked_push(euf::enode_vector& v, euf::enode* n) { + v.push_back(n); + ctx.push(push_back_trail(v)); } - void solver::add_store(theory_var v, euf::enode* store) { - SASSERT(a.is_store(store->get_expr())); - auto& d = get_var_data(v); - unsigned lambda_equiv_class_size = get_lambda_equiv_size(d); - if (get_config().m_array_always_prop_upward || lambda_equiv_class_size >= 1) - set_prop_upward(d); - for (euf::enode* n : d.m_parents) - if (a.is_select(n->get_expr())) - push_axiom(select_axiom(n, store)); - if (get_config().m_array_always_prop_upward || lambda_equiv_class_size >= 1) - set_prop_upward(store); - } + void solver::add_parent_select(theory_var v_child, euf::enode* select) { + SASSERT(a.is_select(select->get_expr())); + SASSERT(m.get_sort(select->get_arg(0)->get_expr()) == m.get_sort(var2expr(v_child))); - void solver::add_parent(theory_var v_child, euf::enode* parent) { - SASSERT(parent->is_root()); - get_var_data(v_child).m_parents.push_back(parent); + v_child = find(v_child); + tracked_push(get_var_data(v_child).m_parent_selects, select); euf::enode* child = var2enode(v_child); - euf::enode* r = child->get_root(); - expr* p = parent->get_expr(); - expr* c = child->get_expr(); - if (a.is_select(p) && parent->get_arg(0)->get_root() == r) { - if (a.is_const(c) || a.is_as_array(c) || a.is_store(c) || is_lambda(c)) - push_axiom(select_axiom(parent, child)); -#if 0 - if (!get_config().m_array_delay_exp_axiom && d.m_prop_upward) { - auto& d = get_var_data(v_child); - for (euf::enode* p2 : d.m_parents) - if (a.is_store(p2->get_expr())) - push_axiom(select_axiom(parent, p2)); - } -#endif - } - else if (a.mk_default(p)) { - if (a.is_const(c) || a.is_store(c) || a.is_map(c) || a.is_as_array(c)) - push_axiom(default_axiom(child)); + if (can_beta_reduce(child)) + push_axiom(select_axiom(select, child)); + } + + void solver::add_lambda(theory_var v, euf::enode* lambda) { + SASSERT(can_beta_reduce(lambda)); + auto& d = get_var_data(find(v)); + if (should_set_prop_upward(d)) + set_prop_upward(d); + tracked_push(d.m_lambdas, lambda); + if (should_set_prop_upward(d)) { + set_prop_upward(lambda); + propagate_select_axioms(d, lambda); } } + void solver::add_parent_lambda(theory_var v_child, euf::enode* lambda) { + SASSERT(can_beta_reduce(lambda)); + auto& d = get_var_data(find(v_child)); + tracked_push(d.m_parent_lambdas, lambda); + if (should_set_prop_upward(d)) + propagate_select_axioms(d, lambda); + } + + void solver::add_parent_default(theory_var v, euf::enode* def) { + SASSERT(a.is_default(def->get_expr())); + auto& d = get_var_data(find(v)); + for (euf::enode* lambda : d.m_lambdas) + push_axiom(default_axiom(lambda)); + if (should_prop_upward(d)) + propagate_parent_default(v); + } + + void solver::propagate_select_axioms(var_data const& d, euf::enode* lambda) { + for (euf::enode* select : d.m_parent_selects) + push_axiom(select_axiom(select, lambda)); + } + + void solver::propagate_parent_default(theory_var v) { + auto& d = get_var_data(find(v)); + for (euf::enode* lambda : d.m_parent_lambdas) + push_axiom(default_axiom(lambda)); + } + + void solver::propagate_parent_select_axioms(theory_var v) { + v = find(v); + expr* e = var2expr(v); + if (!a.is_array(e)) + return; + auto& d = get_var_data(v); + for (euf::enode* lambda : d.m_parent_lambdas) + propagate_select_axioms(d, lambda); + } + void solver::set_prop_upward(theory_var v) { auto& d = get_var_data(find(v)); - if (!d.m_prop_upward) { - ctx.push(reset_flag_trail(d.m_prop_upward)); - d.m_prop_upward = true; - if (!get_config().m_array_delay_exp_axiom) - push_parent_select_store_axioms(v); - set_prop_upward(d); - } + if (d.m_prop_upward) + return; + ctx.push(reset_flag_trail(d.m_prop_upward)); + d.m_prop_upward = true; + if (should_prop_upward(d)) + propagate_parent_select_axioms(v); + set_prop_upward(d); } void solver::set_prop_upward(euf::enode* n) { @@ -199,22 +278,28 @@ namespace array { } void solver::set_prop_upward(var_data& d) { - for (auto* p : d.m_parents) + for (auto* p : d.m_lambdas) set_prop_upward(p); } /** - \brief Return the size of the equivalence class for array terms + \brief Return the size of the equivalence class for array terms that can be expressed as \lambda i : Index . [.. (select a i) ..] */ - unsigned solver::get_lambda_equiv_size(var_data const& d) { - unsigned sz = 0; - for (auto* p : d.m_parents) - if (a.is_store(p->get_expr())) - ++sz; - return sz; + unsigned solver::get_lambda_equiv_size(var_data const& d) const { + return d.m_parent_selects.size() + 2 * d.m_lambdas.size(); } + bool solver::should_set_prop_upward(var_data const& d) const { + return get_config().m_array_always_prop_upward || get_lambda_equiv_size(d) >= 1; + } + bool solver::should_prop_upward(var_data const& d) const { + return !get_config().m_array_delay_exp_axiom && d.m_prop_upward; + } + bool solver::can_beta_reduce(euf::enode* n) const { + expr* c = n->get_expr(); + return a.is_const(c) || a.is_as_array(c) || a.is_store(c) || is_lambda(c); + } } diff --git a/src/sat/smt/array_solver.h b/src/sat/smt/array_solver.h index aa923edd3..184c44991 100644 --- a/src/sat/smt/array_solver.h +++ b/src/sat/smt/array_solver.h @@ -42,6 +42,7 @@ namespace array { unsigned m_num_select_const_axiom, m_num_select_store_axiom_delayed; unsigned m_num_default_store_axiom, m_num_default_map_axiom; unsigned m_num_default_const_axiom, m_num_default_as_array_axiom; + unsigned m_num_select_lambda_axiom; void reset() { memset(this, 0, sizeof(*this)); } stats() { reset(); } }; @@ -49,10 +50,10 @@ namespace array { // void log_drat(array_justification const& c); struct var_data { - bool m_prop_upward{ false }; - bool m_is_array{ false }; - bool m_is_select{ false }; - ptr_vector m_parents; + bool m_prop_upward{ false }; + euf::enode_vector m_lambdas; // equivalent nodes that have beta reduction properties + euf::enode_vector m_parent_lambdas; // parents that have beta reduction properties + euf::enode_vector m_parent_selects; // parents that use array in select position var_data() {} }; @@ -76,11 +77,10 @@ namespace array { void ensure_var(euf::enode* n); void internalize_store(euf::enode* n); void internalize_select(euf::enode* n); - void internalize_const(euf::enode* n); + void internalize_lambda(euf::enode* n); void internalize_ext(euf::enode* n); void internalize_default(euf::enode* n); void internalize_map(euf::enode* n); - void internalize_as_array(euf::enode* n); // axioms struct axiom_record { @@ -144,33 +144,44 @@ namespace array { bool assert_default_map_axiom(app* map); bool assert_default_const_axiom(app* cnst); bool assert_default_store_axiom(app* store); - bool assert_default_as_array_axiom(app* as_array); bool assert_congruent_axiom(expr* e1, expr* e2); bool add_delayed_axioms(); bool has_unitary_domain(app* array_term); - bool has_large_domain(app* array_term); + bool has_large_domain(expr* array_term); std::pair mk_epsilon(sort* s); void collect_shared_vars(sbuffer& roots); bool add_interface_equalities(); bool is_select_arg(euf::enode* r); - // solving - void add_parent(theory_var v_child, euf::enode* parent); - void add_parent(euf::enode* child, euf::enode* parent) { add_parent(child->get_th_var(get_id()), parent); } - void add_store(theory_var v, euf::enode* store); + // solving + void add_parent_select(theory_var v_child, euf::enode* select); + void add_parent_default(theory_var v_child, euf::enode* def); + void add_lambda(theory_var v, euf::enode* lambda); + void add_parent_lambda(theory_var v_child, euf::enode* lambda); + + void propagate_select_axioms(var_data const& d, euf::enode* a); + void propagate_parent_select_axioms(theory_var v); + void propagate_parent_default(theory_var v); + void set_prop_upward(theory_var v); void set_prop_upward(var_data& d); void set_prop_upward(euf::enode* n); - void push_parent_select_store_axioms(theory_var v); - unsigned get_lambda_equiv_size(var_data const& d); + unsigned get_lambda_equiv_size(var_data const& d) const; + bool should_set_prop_upward(var_data const& d) const; + bool should_prop_upward(var_data const& d) const; + bool can_beta_reduce(euf::enode* n) const; var_data& get_var_data(euf::enode* n) { return get_var_data(n->get_th_var(get_id())); } var_data& get_var_data(theory_var v) { return *m_var_data[v]; } + var_data const& get_var_data(theory_var v) const { return *m_var_data[v]; } + // models + bool have_different_model_values(theory_var v1, theory_var v2); - // invariants + // diagnostics + std::ostream& display_info(std::ostream& out, char const* id, euf::enode_vector const& v) const; public: solver(euf::solver& ctx, theory_id id); ~solver() override {} @@ -196,8 +207,10 @@ namespace array { euf::theory_var mk_var(euf::enode* n) override; void apply_sort_cnstr(euf::enode* n, sort* s) override; + void tracked_push(euf::enode_vector& v, euf::enode* n); + void merge_eh(theory_var, theory_var, theory_var v1, theory_var v2); void after_merge_eh(theory_var r1, theory_var r2, theory_var v1, theory_var v2) {} - void unmerge_eh(theory_var v1, theory_var v2); + void unmerge_eh(theory_var v1, theory_var v2) {} }; } diff --git a/src/sat/smt/ba_card.cpp b/src/sat/smt/ba_card.cpp new file mode 100644 index 000000000..51e996ce2 --- /dev/null +++ b/src/sat/smt/ba_card.cpp @@ -0,0 +1,290 @@ +/*++ +Copyright (c) 2017 Microsoft Corporation + +Module Name: + + ba_card.cpp + +Abstract: + + Interface for Cardinality constraints. + +Author: + + Nikolaj Bjorner (nbjorner) 2017-01-30 + +--*/ + +#include "sat/smt/ba_card.h" +#include "sat/smt/ba_solver.h" +#include "sat/sat_simplifier.h" + +namespace ba { + + // ----------------------- + // pb_base + + bool pb_base::well_formed() const { + uint_set vars; + if (lit() != sat::null_literal) vars.insert(lit().var()); + for (unsigned i = 0; i < size(); ++i) { + bool_var v = get_lit(i).var(); + if (vars.contains(v)) return false; + if (get_coeff(i) > k()) return false; + vars.insert(v); + } + return true; + } + + + // ---------------------- + // card + + card::card(unsigned id, literal lit, literal_vector const& lits, unsigned k) : + pb_base(tag_t::card_t, id, lit, lits.size(), get_obj_size(lits.size()), k) { + for (unsigned i = 0; i < size(); ++i) { + m_lits[i] = lits[i]; + } + } + + void card::negate() { + m_lit.neg(); + for (unsigned i = 0; i < m_size; ++i) { + m_lits[i].neg(); + } + m_k = m_size - m_k + 1; + SASSERT(m_size >= m_k && m_k > 0); + } + + bool card::is_watching(literal l) const { + unsigned sz = std::min(k() + 1, size()); + for (unsigned i = 0; i < sz; ++i) { + if ((*this)[i] == l) return true; + } + return false; + } + + double card::get_reward(ba::solver_interface const& s, sat::literal_occs_fun& literal_occs) const { + unsigned k = this->k(), slack = 0; + bool do_add = s.get_config().m_lookahead_reward == sat::heule_schur_reward; + double to_add = do_add ? 0 : 1; + for (literal l : *this) { + switch (s.value(l)) { + case l_true: --k; if (k == 0) return 0; + case l_undef: + if (do_add) to_add += literal_occs(l); + ++slack; break; + case l_false: break; + } + } + if (k >= slack) return 1; + return pow(0.5, slack - k + 1) * to_add; + } + + std::ostream& card::display(std::ostream& out) const { + for (literal l : *this) + out << l << " "; + return out << " >= " << k(); + } + + void constraint::display_lit(std::ostream& out, solver_interface const& s, literal lit, unsigned sz, bool values) const { + if (lit != sat::null_literal) { + if (values) { + out << lit << "[" << sz << "]"; + out << "@(" << s.value(lit); + if (s.value(lit) != l_undef) { + out << ":" << s.lvl(lit); + } + out << "): "; + } + else { + out << lit << " == "; + } + } + } + + std::ostream& card::display(std::ostream& out, solver_interface const& s, bool values) const { + auto const& c = *this; + display_lit(out, s, c.lit(), c.size(), values); + for (unsigned i = 0; i < c.size(); ++i) { + literal l = c[i]; + out << l; + if (values) { + out << "@(" << s.value(l); + if (s.value(l) != l_undef) { + out << ":" << s.lvl(l); + } + out << ") "; + } + else { + out << " "; + } + } + return out << ">= " << c.k() << "\n"; + } + + void card::clear_watch(solver_interface& s) { + if (is_clear()) return; + reset_watch(); + unsigned sz = std::min(k() + 1, size()); + for (unsigned i = 0; i < sz; ++i) + unwatch_literal(s, (*this)[i]); + } + + bool card::init_watch(solver_interface& s) { + auto& c = *this; + literal root = c.lit(); + if (root != sat::null_literal && s.value(root) == l_false) { + clear_watch(s); + negate(); + root.neg(); + } + if (root != sat::null_literal) { + if (!is_watched(s, root)) watch_literal(s, root); + if (!is_pure() && !is_watched(s, ~root)) watch_literal(s, ~root); + } + TRACE("ba", display(tout << "init watch: ", s, true);); + SASSERT(root == sat::null_literal || s.value(root) == l_true); + unsigned j = 0, sz = c.size(), bound = c.k(); + // put the non-false literals into the head. + + if (bound == sz) { + for (literal l : c) s.assign(c, l); + return false; + } + + for (unsigned i = 0; i < sz; ++i) { + if (s.value(c[i]) != l_false) { + if (j != i) { + if (c.is_watched() && j <= bound && i > bound) { + c.unwatch_literal(s, c[j]); + c.watch_literal(s, c[i]); + } + c.swap(i, j); + } + ++j; + } + } + DEBUG_CODE( + bool is_false = false; + for (literal l : c) { + SASSERT(!is_false || s.value(l) == l_false); + is_false = s.value(l) == l_false; + }); + + // j is the number of non-false, sz - j the number of false. + + if (j < bound) { + if (is_watched()) clear_watch(s); + SASSERT(0 < bound && bound < sz); + literal alit = c[j]; + + // + // we need the assignment level of the asserting literal to be maximal. + // such that conflict resolution can use the asserting literal as a starting + // point. + // + + for (unsigned i = bound; i < sz; ++i) { + if (s.lvl(alit) < s.lvl(c[i])) { + c.swap(i, j); + alit = c[j]; + } + } + s.set_conflict(c, alit); + return false; + } + else if (j == bound) { + for (unsigned i = 0; i < bound; ++i) { + s.assign(c, c[i]); + } + return false; + } + else { + if (c.is_watched()) return true; + clear_watch(s); + for (unsigned i = 0; i <= bound; ++i) { + c.watch_literal(s, c[i]); + } + c.set_watch(); + return true; + } + } + + + card& constraint::to_card() { + SASSERT(is_card()); + return static_cast(*this); + } + + card const& constraint::to_card() const { + SASSERT(is_card()); + return static_cast(*this); + } + + + bool card::is_extended_binary(literal_vector& r) const { + auto const& ca = *this; + if (ca.size() == ca.k() + 1 && ca.lit() == sat::null_literal) { + r.reset(); + for (literal l : ca) r.push_back(l); + return true; + } + else { + return false; + } + } + + bool card::validate_unit_propagation(solver_interface const& s, literal alit) const { + (void) alit; + if (lit() != sat::null_literal && s.value(lit()) != l_true) + return false; + for (unsigned i = k(); i < size(); ++i) + if (s.value((*this)[i]) != l_false) + return false; + return true; + } + + lbool card::eval(solver_interface const& s) const { + unsigned trues = 0, undefs = 0; + for (literal l : *this) { + switch (s.value(l)) { + case l_true: trues++; break; + case l_undef: undefs++; break; + default: break; + } + } + if (trues + undefs < k()) return l_false; + if (trues >= k()) return l_true; + return l_undef; + } + + lbool card::eval(sat::model const& m) const { + unsigned trues = 0, undefs = 0; + for (literal l : *this) { + switch (ba::value(m, l)) { + case l_true: trues++; break; + case l_undef: undefs++; break; + default: break; + } + } + if (trues + undefs < k()) return l_false; + if (trues >= k()) return l_true; + return l_undef; + } + + void card::init_use_list(sat::ext_use_list& ul) const { + auto idx = cindex(); + for (auto l : *this) + ul.insert(l, idx); + } + + bool card::is_blocked(sat::simplifier& sim, literal lit) const { + unsigned weight = 0; + for (literal l2 : *this) + if (sim.is_marked(~l2)) ++weight; + + return weight >= k(); + } + +} diff --git a/src/sat/smt/ba_card.h b/src/sat/smt/ba_card.h new file mode 100644 index 000000000..98d0456f9 --- /dev/null +++ b/src/sat/smt/ba_card.h @@ -0,0 +1,70 @@ +/*++ +Copyright (c) 2017 Microsoft Corporation + +Module Name: + + ba_card.h + +Abstract: + + Interface for Cardinality constraints. + +Author: + + Nikolaj Bjorner (nbjorner) 2017-01-30 + +--*/ + +#pragma once + +#include "sat/sat_types.h" +#include "sat/smt/ba_constraint.h" + + +namespace ba { + + // base class for pb and cardinality constraints + class pb_base : public constraint { + protected: + unsigned m_k; + public: + pb_base(ba::tag_t t, unsigned id, literal l, unsigned sz, size_t osz, unsigned k) : + constraint(t, id, l, sz, osz), m_k(k) { + VERIFY(k < 4000000000); + } + virtual void set_k(unsigned k) { VERIFY(k < 4000000000); m_k = k; } + virtual unsigned get_coeff(unsigned i) const { UNREACHABLE(); return 0; } + unsigned k() const { return m_k; } + bool well_formed() const override; + }; + + class card : public pb_base { + literal m_lits[0]; + public: + static size_t get_obj_size(unsigned num_lits) { return sat::constraint_base::obj_size(sizeof(card) + num_lits * sizeof(literal)); } + card(unsigned id, literal lit, literal_vector const& lits, unsigned k); + literal operator[](unsigned i) const { return m_lits[i]; } + literal& operator[](unsigned i) { return m_lits[i]; } + literal const* begin() const { return m_lits; } + literal const* end() const { return static_cast(m_lits) + m_size; } + void negate() override; + void swap(unsigned i, unsigned j) override { std::swap(m_lits[i], m_lits[j]); } + literal_vector literals() const override { return literal_vector(m_size, m_lits); } + bool is_watching(literal l) const override; + literal get_lit(unsigned i) const override { return m_lits[i]; } + void set_lit(unsigned i, literal l) override { m_lits[i] = l; } + unsigned get_coeff(unsigned i) const override { return 1; } + double get_reward(ba::solver_interface const& s, sat::literal_occs_fun& occs) const override; + std::ostream& display(std::ostream& out) const override; + std::ostream& display(std::ostream& out, solver_interface const& s, bool values) const override; + void clear_watch(solver_interface& s) override; + bool init_watch(solver_interface& s) override; + bool is_extended_binary(literal_vector& r) const override; + bool validate_unit_propagation(solver_interface const& s, literal alit) const override; + lbool eval(sat::model const& m) const override; + lbool eval(solver_interface const& s) const override; + void init_use_list(sat::ext_use_list& ul) const override; + bool is_blocked(sat::simplifier& s, literal lit) const override; + + }; +} diff --git a/src/sat/smt/ba_constraint.cpp b/src/sat/smt/ba_constraint.cpp new file mode 100644 index 000000000..a15039e45 --- /dev/null +++ b/src/sat/smt/ba_constraint.cpp @@ -0,0 +1,58 @@ +/*++ +Copyright (c) 2017 Microsoft Corporation + +Module Name: + + ba_constraint.cpp + +Abstract: + + Interface for constraints. + +Author: + + Nikolaj Bjorner (nbjorner) 2017-01-30 + +--*/ + +#include "sat/smt/ba_constraint.h" + +namespace ba { + + unsigned constraint::fold_max_var(unsigned w) const { + if (lit() != sat::null_literal) w = std::max(w, lit().var()); + for (unsigned i = 0; i < size(); ++i) w = std::max(w, get_lit(i).var()); + return w; + } + + std::ostream& operator<<(std::ostream& out, constraint const& cnstr) { + if (cnstr.lit() != sat::null_literal) out << cnstr.lit() << " == "; + return cnstr.display(out); + } + + bool constraint::is_watched(solver_interface const& s, literal lit) const { + return s.get_wlist(~lit).contains(sat::watched(cindex())); + } + + void constraint::unwatch_literal(solver_interface& s, literal lit) { + sat::watched w(cindex()); + s.get_wlist(~lit).erase(w); + SASSERT(!is_watched(s, lit)); + } + + void constraint::watch_literal(solver_interface& s, literal lit) { + if (is_pure() && lit == ~this->lit()) return; + SASSERT(!is_watched(s, lit)); + sat::watched w(cindex()); + s.get_wlist(~lit).push_back(w); + } + + void constraint::nullify_tracking_literal(solver_interface& s) { + if (lit() != sat::null_literal) { + unwatch_literal(s, lit()); + unwatch_literal(s, ~lit()); + nullify_literal(); + } + } + +} diff --git a/src/sat/smt/ba_constraint.h b/src/sat/smt/ba_constraint.h new file mode 100644 index 000000000..365f9284f --- /dev/null +++ b/src/sat/smt/ba_constraint.h @@ -0,0 +1,143 @@ +/*++ +Copyright (c) 2017 Microsoft Corporation + +Module Name: + + ba_constraint.h + +Abstract: + + Interface for Boolean constraints. + +Author: + + Nikolaj Bjorner (nbjorner) 2017-01-30 + +Revision History: + +--*/ + +#pragma once +#include "sat/smt/ba_solver_interface.h" + +namespace ba { + + enum class tag_t { + card_t, + pb_t, + xr_t + }; + + class card; + class pb; + class xr; + class pb_base; + + inline lbool value(sat::model const& m, literal l) { return l.sign() ? ~m[l.var()] : m[l.var()]; } + + class constraint { + protected: + tag_t m_tag; + bool m_removed; + literal m_lit; + literal m_watch; + unsigned m_glue; + unsigned m_psm; + unsigned m_size; + size_t m_obj_size; + bool m_learned; + unsigned m_id; + bool m_pure; // is the constraint pure (only positive occurrences) + + void display_lit(std::ostream& out, solver_interface const& s, literal lit, unsigned sz, bool values) const; + public: + constraint(tag_t t, unsigned id, literal l, unsigned sz, size_t osz): + m_tag(t), m_removed(false), m_lit(l), m_watch(sat::null_literal), m_glue(0), m_psm(0), m_size(sz), m_obj_size(osz), m_learned(false), m_id(id), m_pure(false) { + } + sat::ext_constraint_idx cindex() const { return sat::constraint_base::mem2base(this); } + void deallocate(small_object_allocator& a) { a.deallocate(obj_size(), sat::constraint_base::mem2base_ptr(this)); } + unsigned id() const { return m_id; } + tag_t tag() const { return m_tag; } + literal lit() const { return m_lit; } + unsigned size() const { return m_size; } + void set_size(unsigned sz) { SASSERT(sz <= m_size); m_size = sz; } + void update_literal(literal l) { m_lit = l; } + bool was_removed() const { return m_removed; } + void set_removed() { m_removed = true; } + void nullify_literal() { m_lit = sat::null_literal; } + unsigned glue() const { return m_glue; } + void set_glue(unsigned g) { m_glue = g; } + unsigned psm() const { return m_psm; } + void set_psm(unsigned p) { m_psm = p; } + void set_learned(bool f) { m_learned = f; } + bool learned() const { return m_learned; } + bool is_watched() const { return m_watch == m_lit && m_lit != sat::null_literal; } + void set_watch() { m_watch = m_lit; } + void reset_watch() { m_watch = sat::null_literal; } + bool is_clear() const { return m_watch == sat::null_literal && m_lit != sat::null_literal; } + bool is_pure() const { return m_pure; } + void set_pure() { m_pure = true; } + unsigned fold_max_var(unsigned w) const; + + size_t obj_size() const { return m_obj_size; } + card& to_card(); + pb& to_pb(); + xr& to_xr(); + card const& to_card() const; + pb const& to_pb() const; + xr const& to_xr() const; + pb_base const& to_pb_base() const; + bool is_card() const { return m_tag == tag_t::card_t; } + bool is_pb() const { return m_tag == tag_t::pb_t; } + bool is_xr() const { return m_tag == tag_t::xr_t; } + + bool is_watched(solver_interface const& s, literal lit) const; + void unwatch_literal(solver_interface& s, literal lit); + void nullify_tracking_literal(solver_interface& s); + void watch_literal(solver_interface& s, literal lit); + virtual void clear_watch(solver_interface& s) = 0; + virtual bool init_watch(solver_interface& s) = 0; + virtual lbool eval(sat::model const& m) const = 0; + virtual lbool eval(solver_interface const& s) const = 0; + virtual bool is_blocked(sat::simplifier& s, literal lit) const = 0; + + virtual bool validate_unit_propagation(solver_interface const& s, literal alit) const = 0; + + virtual bool is_watching(literal l) const { UNREACHABLE(); return false; }; + virtual literal_vector literals() const { UNREACHABLE(); return literal_vector(); } + virtual void swap(unsigned i, unsigned j) { UNREACHABLE(); } + virtual literal get_lit(unsigned i) const { UNREACHABLE(); return sat::null_literal; } + virtual void set_lit(unsigned i, literal l) { UNREACHABLE(); } + virtual bool well_formed() const { return true; } + virtual void negate() { UNREACHABLE(); } + virtual bool is_extended_binary(literal_vector& r) const { return false; } + + virtual double get_reward(solver_interface const& s, sat::literal_occs_fun& occs) const { return 0; } + virtual std::ostream& display(std::ostream& out) const = 0; + virtual std::ostream& display(std::ostream& out, solver_interface const& s, bool values) const = 0; + virtual void init_use_list(sat::ext_use_list& ul) const = 0; + + class iterator { + constraint const& c; + unsigned idx; + public: + iterator(constraint const& c, unsigned idx) : c(c), idx(idx) {} + literal operator*() { return c.get_lit(idx); } + iterator& operator++() { ++idx; return *this; } + bool operator==(iterator const& other) const { SASSERT(&c == &other.c); return idx == other.idx; } + bool operator!=(iterator const& other) const { SASSERT(&c == &other.c); return idx != other.idx; } + }; + + class literal_iterator { + constraint const& c; + public: + literal_iterator(constraint const& c):c(c) {} + iterator begin() const { return iterator(c, 0); } + iterator end() const { return iterator(c, c.size()); } + }; + }; + + std::ostream& operator<<(std::ostream& out, constraint const& c); + + +} diff --git a/src/sat/smt/ba_internalize.cpp b/src/sat/smt/ba_internalize.cpp index 2aed982a1..7ccb4b4ae 100644 --- a/src/sat/smt/ba_internalize.cpp +++ b/src/sat/smt/ba_internalize.cpp @@ -284,7 +284,7 @@ namespace sat { } } - expr_ref ba_solver::get_card(std::function& lit2expr, ba_solver::card const& c) { + expr_ref ba_solver::get_card(std::function& lit2expr, ba::card const& c) { ptr_buffer lits; for (sat::literal l : c) { lits.push_back(lit2expr(l)); @@ -297,7 +297,7 @@ namespace sat { return fml; } - expr_ref ba_solver::get_pb(std::function& lit2expr, ba_solver::pb const& p) { + expr_ref ba_solver::get_pb(std::function& lit2expr, pb const& p) { ptr_buffer lits; vector coeffs; for (auto const& wl : p) { @@ -313,7 +313,7 @@ namespace sat { return fml; } - expr_ref ba_solver::get_xor(std::function& lit2expr, ba_solver::xr const& x) { + expr_ref ba_solver::get_xor(std::function& lit2expr, xr const& x) { ptr_buffer lits; for (sat::literal l : x) { lits.push_back(lit2expr(l)); @@ -329,13 +329,13 @@ namespace sat { bool ba_solver::to_formulas(std::function& l2e, expr_ref_vector& fmls) { for (auto* c : constraints()) { switch (c->tag()) { - case ba_solver::card_t: + case ba::tag_t::card_t: fmls.push_back(get_card(l2e, c->to_card())); break; - case ba_solver::pb_t: + case ba::tag_t::pb_t: fmls.push_back(get_pb(l2e, c->to_pb())); break; - case ba_solver::xr_t: + case ba::tag_t::xr_t: fmls.push_back(get_xor(l2e, c->to_xr())); break; } diff --git a/src/sat/smt/ba_pb.cpp b/src/sat/smt/ba_pb.cpp new file mode 100644 index 000000000..3c31c0138 --- /dev/null +++ b/src/sat/smt/ba_pb.cpp @@ -0,0 +1,308 @@ +/*++ +Copyright (c) 2017 Microsoft Corporation + +Module Name: + + ba_pb.cpp + +Abstract: + + Interface for PB constraints. + +Author: + + Nikolaj Bjorner (nbjorner) 2017-01-30 + +--*/ + +#include "sat/smt/ba_pb.h" + +namespace ba { + + pb& constraint::to_pb() { + SASSERT(is_pb()); + return static_cast(*this); + } + + pb const& constraint::to_pb() const { + SASSERT(is_pb()); + return static_cast(*this); + } + + pb_base const& constraint::to_pb_base() const { + SASSERT(is_pb() || is_card()); + return static_cast(*this); + } + + // ----------------------------------- + // pb + + pb::pb(unsigned id, literal lit, svector const& wlits, unsigned k) : + pb_base(tag_t::pb_t, id, lit, wlits.size(), get_obj_size(wlits.size()), k), + m_slack(0), + m_num_watch(0), + m_max_sum(0) { + for (unsigned i = 0; i < size(); ++i) { + m_wlits[i] = wlits[i]; + } + update_max_sum(); + } + + void pb::update_max_sum() { + m_max_sum = 0; + for (unsigned i = 0; i < size(); ++i) { + m_wlits[i].first = std::min(k(), m_wlits[i].first); + if (m_max_sum + m_wlits[i].first < m_max_sum) { + throw default_exception("addition of pb coefficients overflows"); + } + m_max_sum += m_wlits[i].first; + } + } + + void pb::negate() { + m_lit.neg(); + unsigned w = 0; + for (unsigned i = 0; i < m_size; ++i) { + m_wlits[i].second.neg(); + VERIFY(w + m_wlits[i].first >= w); + w += m_wlits[i].first; + } + m_k = w - m_k + 1; + VERIFY(w >= m_k && m_k > 0); + } + + bool pb::is_watching(literal l) const { + for (unsigned i = 0; i < m_num_watch; ++i) { + if ((*this)[i].second == l) return true; + } + return false; + } + + bool pb::is_cardinality() const { + if (size() == 0) return false; + unsigned w = (*this)[0].first; + for (wliteral wl : *this) if (w != wl.first) return false; + return true; + } + + double pb::get_reward(ba::solver_interface const& s, sat::literal_occs_fun& occs) const { + unsigned k = this->k(), slack = 0; + bool do_add = s.get_config().m_lookahead_reward == sat::heule_schur_reward; + double to_add = do_add ? 0 : 1; + double undefs = 0; + for (wliteral wl : *this) { + literal l = wl.second; + unsigned w = wl.first; + switch (s.value(l)) { + case l_true: if (k <= w) return 0; + case l_undef: + if (do_add) to_add += occs(l); + ++undefs; + slack += w; + break; // TBD multiplier factor on this + case l_false: break; + } + } + if (k >= slack || 0 == undefs) return 0; + double avg = slack / undefs; + return pow(0.5, (slack - k + 1) / avg) * to_add; + } + + + void pb::clear_watch(solver_interface& s) { + reset_watch(); + for (unsigned i = 0; i < num_watch(); ++i) { + unwatch_literal(s, (*this)[i].second); + } + set_num_watch(0); + DEBUG_CODE(for (wliteral wl : *this) VERIFY(!is_watched(s, wl.second));); + } + + + // watch a prefix of literals, such that the slack of these is >= k + bool pb::init_watch(solver_interface& s) { + auto& p = *this; + clear_watch(s); + if (lit() != sat::null_literal && s.value(p.lit()) == l_false) { + negate(); + } + + VERIFY(lit() == sat::null_literal || s.value(lit()) == l_true); + unsigned sz = size(), bound = k(); + + // put the non-false literals into the head. + unsigned slack = 0, slack1 = 0, num_watch = 0, j = 0; + for (unsigned i = 0; i < sz; ++i) { + if (s.value(p[i].second) != l_false) { + if (j != i) { + swap(i, j); + } + if (slack <= bound) { + slack += p[j].first; + ++num_watch; + } + else { + slack1 += p[j].first; + } + ++j; + } + } + + DEBUG_CODE( + bool is_false = false; + for (unsigned k = 0; k < sz; ++k) { + SASSERT(!is_false || s.value(p[k].second) == l_false); + SASSERT((k < j) == (s.value(p[k].second) != l_false)); + is_false = s.value(p[k].second) == l_false; + }); + + if (slack < bound) { + literal lit = p[j].second; + VERIFY(s.value(lit) == l_false); + for (unsigned i = j + 1; i < sz; ++i) { + if (s.lvl(lit) < s.lvl(p[i].second)) { + lit = p[i].second; + } + } + s.set_conflict(p, lit); + return false; + } + else { + for (unsigned i = 0; i < num_watch; ++i) { + p.watch_literal(s, p[i].second); + } + p.set_slack(slack); + p.set_num_watch(num_watch); + + // SASSERT(validate_watch(p, sat::null_literal)); + + TRACE("ba", display(tout << "init watch: ", s, true);); + + // slack is tight: + if (slack + slack1 == bound) { + SASSERT(slack1 == 0); + SASSERT(j == num_watch); + for (unsigned i = 0; i < j; ++i) { + s.assign(p, p[i].second); + } + } + return true; + } + } + + + std::ostream& pb::display(std::ostream& out) const { + bool first = true; + for (wliteral wl : *this) { + if (!first) out << "+ "; + if (wl.first != 1) out << wl.first << " * "; + out << wl.second << " "; + first = false; + } + return out << " >= " << k(); + } + + std::ostream& pb::display(std::ostream& out, solver_interface const& s, bool values) const { + auto const& p = *this; + if (p.lit() != sat::null_literal) out << p.lit() << " == "; + if (values) { + out << "[watch: " << p.num_watch() << ", slack: " << p.slack() << "]"; + } + if (p.lit() != sat::null_literal && values) { + out << "@(" << s.value(p.lit()); + if (s.value(p.lit()) != l_undef) { + out << ":" << s.lvl(p.lit()); + } + out << "): "; + } + unsigned i = 0; + for (wliteral wl : p) { + literal l = wl.second; + unsigned w = wl.first; + if (i > 0) out << "+ "; + if (i++ == p.num_watch()) out << " | "; + if (w > 1) out << w << " * "; + out << l; + if (values) { + out << "@(" << s.value(l); + if (s.value(l) != l_undef) { + out << ":" << s.lvl(l); + } + out << ") "; + } + else { + out << " "; + } + } + return out << ">= " << p.k() << "\n"; + } + + bool pb::validate_unit_propagation(solver_interface const& s, literal alit) const { + if (lit() != sat::null_literal && s.value(lit()) != l_true) + return false; + + unsigned sum = 0; + TRACE("ba", display(tout << "validate: " << alit << "\n", s, true);); + for (wliteral wl : *this) { + literal l = wl.second; + lbool val = s.value(l); + if (val != l_false && l != alit) { + sum += wl.first; + } + } + return sum < k(); + } + + lbool pb::eval(sat::model const& m) const { + auto const& p = *this; + unsigned trues = 0, undefs = 0; + for (wliteral wl : p) { + switch (ba::value(m, wl.second)) { + case l_true: trues += wl.first; break; + case l_undef: undefs += wl.first; break; + default: break; + } + } + if (trues + undefs < p.k()) return l_false; + if (trues >= p.k()) return l_true; + return l_undef; + } + + lbool pb::eval(solver_interface const& s) const { + auto const& p = *this; + unsigned trues = 0, undefs = 0; + for (wliteral wl : p) { + switch (s.value(wl.second)) { + case l_true: trues += wl.first; break; + case l_undef: undefs += wl.first; break; + default: break; + } + } + if (trues + undefs < p.k()) return l_false; + if (trues >= p.k()) return l_true; + return l_undef; + } + + void pb::init_use_list(sat::ext_use_list& ul) const { + auto idx = cindex(); + for (auto l : *this) + ul.insert(l.second, idx); + } + + bool pb::is_blocked(sat::simplifier& sim, literal lit) const { + unsigned weight = 0, offset = 0; + for (wliteral l2 : *this) { + if (~l2.second == lit) { + offset = l2.first; + break; + } + } + SASSERT(offset != 0); + for (wliteral l2 : *this) { + if (sim.is_marked(~l2.second)) { + weight += std::min(offset, l2.first); + } + } + return weight >= k(); + } +} diff --git a/src/sat/smt/ba_pb.h b/src/sat/smt/ba_pb.h new file mode 100644 index 000000000..97b8c40dc --- /dev/null +++ b/src/sat/smt/ba_pb.h @@ -0,0 +1,67 @@ +/*++ +Copyright (c) 2017 Microsoft Corporation + +Module Name: + + ba_pb.h + +Abstract: + + Interface for PB constraints. + +Author: + + Nikolaj Bjorner (nbjorner) 2017-01-30 + +--*/ + +#pragma once + +#include "sat/sat_types.h" +#include "sat/smt/ba_card.h" + + +namespace ba { + + class pb : public pb_base { + unsigned m_slack; + unsigned m_num_watch; + unsigned m_max_sum; + wliteral m_wlits[0]; + public: + static size_t get_obj_size(unsigned num_lits) { return sat::constraint_base::obj_size(sizeof(pb) + num_lits * sizeof(wliteral)); } + pb(unsigned id, literal lit, svector const& wlits, unsigned k); + literal lit() const { return m_lit; } + wliteral operator[](unsigned i) const { return m_wlits[i]; } + wliteral& operator[](unsigned i) { return m_wlits[i]; } + wliteral const* begin() const { return m_wlits; } + wliteral const* end() const { return begin() + m_size; } + + unsigned slack() const { return m_slack; } + void set_slack(unsigned s) { m_slack = s; } + unsigned num_watch() const { return m_num_watch; } + unsigned max_sum() const { return m_max_sum; } + void update_max_sum(); + void set_num_watch(unsigned s) { m_num_watch = s; } + bool is_cardinality() const; + void negate() override; + void set_k(unsigned k) override { m_k = k; VERIFY(k < 4000000000); update_max_sum(); } + void swap(unsigned i, unsigned j) override { std::swap(m_wlits[i], m_wlits[j]); } + literal_vector literals() const override { literal_vector lits; for (auto wl : *this) lits.push_back(wl.second); return lits; } + bool is_watching(literal l) const override; + literal get_lit(unsigned i) const override { return m_wlits[i].second; } + void set_lit(unsigned i, literal l) override { m_wlits[i].second = l; } + unsigned get_coeff(unsigned i) const override { return m_wlits[i].first; } + double get_reward(ba::solver_interface const& s, sat::literal_occs_fun& occs) const override; + void clear_watch(solver_interface& s) override; + std::ostream& display(std::ostream& out) const override; + std::ostream& display(std::ostream& out, solver_interface const& s, bool values) const override; + bool init_watch(solver_interface& s) override; + bool validate_unit_propagation(solver_interface const& s, literal alit) const override; + lbool eval(sat::model const& m) const override; + lbool eval(solver_interface const& s) const override; + void init_use_list(sat::ext_use_list& ul) const override; + bool is_blocked(sat::simplifier& s, literal lit) const override; + }; + +} diff --git a/src/sat/smt/ba_solver.cpp b/src/sat/smt/ba_solver.cpp index 482834b76..d27726f5f 100644 --- a/src/sat/smt/ba_solver.cpp +++ b/src/sat/smt/ba_solver.cpp @@ -23,270 +23,11 @@ Author: #include "sat/sat_simplifier_params.hpp" #include "sat/sat_xor_finder.h" - namespace sat { static unsigned _bad_id = 11111111; // 2759; // #define BADLOG(_cmd_) if (p.id() == _bad_id) { _cmd_; } - ba_solver::card& ba_solver::constraint::to_card() { - SASSERT(is_card()); - return static_cast(*this); - } - - ba_solver::card const& ba_solver::constraint::to_card() const{ - SASSERT(is_card()); - return static_cast(*this); - } - - ba_solver::pb& ba_solver::constraint::to_pb() { - SASSERT(is_pb()); - return static_cast(*this); - } - - ba_solver::pb const& ba_solver::constraint::to_pb() const{ - SASSERT(is_pb()); - return static_cast(*this); - } - - ba_solver::pb_base const& ba_solver::constraint::to_pb_base() const{ - SASSERT(is_pb() || is_card()); - return static_cast(*this); - } - - - unsigned ba_solver::constraint::fold_max_var(unsigned w) const { - if (lit() != null_literal) w = std::max(w, lit().var()); - for (unsigned i = 0; i < size(); ++i) w = std::max(w, get_lit(i).var()); - return w; - } - - - std::ostream& operator<<(std::ostream& out, ba_solver::constraint const& cnstr) { - if (cnstr.lit() != null_literal) out << cnstr.lit() << " == "; - switch (cnstr.tag()) { - case ba_solver::card_t: { - ba_solver::card const& c = cnstr.to_card(); - for (literal l : c) { - out << l << " "; - } - out << " >= " << c.k(); - break; - } - case ba_solver::pb_t: { - ba_solver::pb const& p = cnstr.to_pb(); - bool first = true; - for (ba_solver::wliteral wl : p) { - if (!first) out << "+ "; - if (wl.first != 1) out << wl.first << " * "; - out << wl.second << " "; - first = false; - } - out << " >= " << p.k(); - break; - } - case ba_solver::xr_t: { - ba_solver::xr const& x = cnstr.to_xr(); - for (unsigned i = 0; i < x.size(); ++i) { - out << x[i] << " "; - if (i + 1 < x.size()) out << "x "; - } - break; - } - default: - UNREACHABLE(); - } - return out; - } - - - // ----------------------- - // pb_base - - bool ba_solver::pb_base::well_formed() const { - uint_set vars; - if (lit() != null_literal) vars.insert(lit().var()); - for (unsigned i = 0; i < size(); ++i) { - bool_var v = get_lit(i).var(); - if (vars.contains(v)) return false; - if (get_coeff(i) > k()) return false; - vars.insert(v); - } - return true; - } - - // ---------------------- - // card - - ba_solver::card::card(unsigned id, literal lit, literal_vector const& lits, unsigned k): - pb_base(card_t, id, lit, lits.size(), get_obj_size(lits.size()), k) { - for (unsigned i = 0; i < size(); ++i) { - m_lits[i] = lits[i]; - } - } - - void ba_solver::card::negate() { - m_lit.neg(); - for (unsigned i = 0; i < m_size; ++i) { - m_lits[i].neg(); - } - m_k = m_size - m_k + 1; - SASSERT(m_size >= m_k && m_k > 0); - } - - bool ba_solver::card::is_watching(literal l) const { - unsigned sz = std::min(k() + 1, size()); - for (unsigned i = 0; i < sz; ++i) { - if ((*this)[i] == l) return true; - } - return false; - } - - // ----------------------------------- - // pb - - ba_solver::pb::pb(unsigned id, literal lit, svector const& wlits, unsigned k): - pb_base(pb_t, id, lit, wlits.size(), get_obj_size(wlits.size()), k), - m_slack(0), - m_num_watch(0), - m_max_sum(0) { - for (unsigned i = 0; i < size(); ++i) { - m_wlits[i] = wlits[i]; - } - update_max_sum(); - } - - void ba_solver::pb::update_max_sum() { - m_max_sum = 0; - for (unsigned i = 0; i < size(); ++i) { - m_wlits[i].first = std::min(k(), m_wlits[i].first); - if (m_max_sum + m_wlits[i].first < m_max_sum) { - throw default_exception("addition of pb coefficients overflows"); - } - m_max_sum += m_wlits[i].first; - } - } - - void ba_solver::pb::negate() { - m_lit.neg(); - unsigned w = 0; - for (unsigned i = 0; i < m_size; ++i) { - m_wlits[i].second.neg(); - VERIFY(w + m_wlits[i].first >= w); - w += m_wlits[i].first; - } - m_k = w - m_k + 1; - VERIFY(w >= m_k && m_k > 0); - } - - bool ba_solver::pb::is_watching(literal l) const { - for (unsigned i = 0; i < m_num_watch; ++i) { - if ((*this)[i].second == l) return true; - } - return false; - } - - - bool ba_solver::pb::is_cardinality() const { - if (size() == 0) return false; - unsigned w = (*this)[0].first; - for (wliteral wl : *this) if (w != wl.first) return false; - return true; - } - - - - - // ---------------------------- - // card - - bool ba_solver::init_watch(card& c) { - literal root = c.lit(); - if (root != null_literal && value(root) == l_false) { - clear_watch(c); - c.negate(); - root.neg(); - } - if (root != null_literal) { - if (!is_watched(root, c)) watch_literal(root, c); - if (!c.is_pure() && !is_watched(~root, c)) watch_literal(~root, c); - } - TRACE("ba", display(tout << "init watch: ", c, true);); - SASSERT(root == null_literal || value(root) == l_true); - unsigned j = 0, sz = c.size(), bound = c.k(); - // put the non-false literals into the head. - - if (bound == sz) { - for (literal l : c) assign(c, l); - return false; - } - - for (unsigned i = 0; i < sz; ++i) { - if (value(c[i]) != l_false) { - if (j != i) { - if (c.is_watched() && j <= bound && i > bound) { - unwatch_literal(c[j], c); - watch_literal(c[i], c); - } - c.swap(i, j); - } - ++j; - } - } - DEBUG_CODE( - bool is_false = false; - for (literal l : c) { - SASSERT(!is_false || value(l) == l_false); - is_false = value(l) == l_false; - }); - - // j is the number of non-false, sz - j the number of false. - - if (j < bound) { - if (c.is_watched()) clear_watch(c); - SASSERT(0 < bound && bound < sz); - literal alit = c[j]; - - // - // we need the assignment level of the asserting literal to be maximal. - // such that conflict resolution can use the asserting literal as a starting - // point. - // - - for (unsigned i = bound; i < sz; ++i) { - if (lvl(alit) < lvl(c[i])) { - c.swap(i, j); - alit = c[j]; - } - } - set_conflict(c, alit); - return false; - } - else if (j == bound) { - for (unsigned i = 0; i < bound; ++i) { - assign(c, c[i]); - } - return false; - } - else { - if (c.is_watched()) return true; - clear_watch(c); - for (unsigned i = 0; i <= bound; ++i) { - watch_literal(c[i], c); - } - c.set_watch(); - return true; - } - } - - void ba_solver::clear_watch(card& c) { - if (c.is_clear()) return; - c.clear_watch(); - unsigned sz = std::min(c.k() + 1, c.size()); - for (unsigned i = 0; i < sz; ++i) { - unwatch_literal(c[i], c); - } - } // ----------------------- // constraint @@ -337,7 +78,7 @@ namespace sat { if (nullify) { IF_VERBOSE(100, display(verbose_stream() << "nullify tracking literal\n", p, true);); SASSERT(lvl(p.lit()) == 0); - nullify_tracking_literal(p); + p.nullify_tracking_literal(*this); init_watch(p); } @@ -497,77 +238,6 @@ namespace sat { // pb - // watch a prefix of literals, such that the slack of these is >= k - bool ba_solver::init_watch(pb& p) { - clear_watch(p); - if (p.lit() != null_literal && value(p.lit()) == l_false) { - p.negate(); - } - - VERIFY(p.lit() == null_literal || value(p.lit()) == l_true); - unsigned sz = p.size(), bound = p.k(); - - // put the non-false literals into the head. - unsigned slack = 0, slack1 = 0, num_watch = 0, j = 0; - for (unsigned i = 0; i < sz; ++i) { - if (value(p[i].second) != l_false) { - if (j != i) { - p.swap(i, j); - } - if (slack <= bound) { - slack += p[j].first; - ++num_watch; - } - else { - slack1 += p[j].first; - } - ++j; - } - } - BADLOG(verbose_stream() << "watch " << num_watch << " out of " << sz << "\n"); - - DEBUG_CODE( - bool is_false = false; - for (unsigned k = 0; k < sz; ++k) { - SASSERT(!is_false || value(p[k].second) == l_false); - SASSERT((k < j) == (value(p[k].second) != l_false)); - is_false = value(p[k].second) == l_false; - }); - - if (slack < bound) { - literal lit = p[j].second; - VERIFY(value(lit) == l_false); - for (unsigned i = j + 1; i < sz; ++i) { - if (lvl(lit) < lvl(p[i].second)) { - lit = p[i].second; - } - } - set_conflict(p, lit); - return false; - } - else { - for (unsigned i = 0; i < num_watch; ++i) { - watch_literal(p[i], p); - } - p.set_slack(slack); - p.set_num_watch(num_watch); - - SASSERT(validate_watch(p, null_literal)); - - TRACE("ba", display(tout << "init watch: ", p, true);); - - // slack is tight: - if (slack + slack1 == bound) { - SASSERT(slack1 == 0); - SASSERT(j == num_watch); - for (unsigned i = 0; i < j; ++i) { - assign(p, p[i].second); - } - } - return true; - } - } - /* Chai Kuhlmann: Lw - set of watched literals @@ -660,11 +330,10 @@ namespace sat { literal lit = p[j].second; if (value(lit) != l_false) { slack += p[j].first; - SASSERT(!is_watched(p[j].second, p)); - watch_literal(p[j], p); + SASSERT(!p.is_watched(*this, p[j].second)); + p.watch_literal(*this, p[j].second); p.swap(num_watch, j); - add_index(p, num_watch, lit); - BADLOG(verbose_stream() << "add watch: " << lit << " num watch: " << num_watch << " max: " << m_a_max << " slack " << slack << "\n"); + add_index(p, num_watch, lit); ++num_watch; } } @@ -727,19 +396,6 @@ namespace sat { return l_undef; } - void ba_solver::watch_literal(wliteral l, pb& p) { - watch_literal(l.second, p); - } - - void ba_solver::clear_watch(pb& p) { - p.clear_watch(); - for (unsigned i = 0; i < p.num_watch(); ++i) { - unwatch_literal(p[i].second, p); - } - p.set_num_watch(0); - DEBUG_CODE(for (wliteral wl : p) VERIFY(!is_watched(wl.second, p));); - } - void ba_solver::recompile(pb& p) { // IF_VERBOSE(2, verbose_stream() << "re: " << p << "\n";); SASSERT(p.num_watch() == 0); @@ -831,40 +487,6 @@ namespace sat { } } - void ba_solver::display(std::ostream& out, pb const& p, bool values) const { - if (p.lit() != null_literal) out << p.lit() << " == "; - if (values) { - out << "[watch: " << p.num_watch() << ", slack: " << p.slack() << "]"; - } - if (p.lit() != null_literal && values) { - out << "@(" << value(p.lit()); - if (value(p.lit()) != l_undef) { - out << ":" << lvl(p.lit()); - } - out << "): "; - } - unsigned i = 0; - for (wliteral wl : p) { - literal l = wl.second; - unsigned w = wl.first; - if (i > 0) out << "+ "; - if (i++ == p.num_watch()) out << " | "; - if (w > 1) out << w << " * "; - out << l; - if (values) { - out << "@(" << value(l); - if (value(l) != l_undef) { - out << ":" << lvl(l); - } - out << ") "; - } - else { - out << " "; - } - } - out << ">= " << p.k() << "\n"; - } - // --------------------------- // conflict resolution @@ -920,7 +542,7 @@ namespace sat { return static_cast(c); } - ba_solver::wliteral ba_solver::get_wliteral(bool_var v) { + wliteral ba_solver::get_wliteral(bool_var v) { int64_t c1 = get_coeff(v); literal l = literal(v, c1 < 0); c1 = std::abs(c1); @@ -1096,16 +718,24 @@ namespace sat { break; } case justification::EXT_JUSTIFICATION: { - constraint& cnstr = index2constraint(js.get_ext_justification_idx()); + auto cindex = js.get_ext_justification_idx(); + auto* ext = sat::constraint_base::to_extension(cindex); + if (ext != this) { + m_lemma.reset(); + ext->get_antecedents(consequent, idx, m_lemma, false); + for (literal l : m_lemma) process_antecedent(~l, offset); + break; + } + constraint& cnstr = index2constraint(cindex); ++m_stats.m_num_resolves; switch (cnstr.tag()) { - case card_t: { + case ba::tag_t::card_t: { card& c = cnstr.to_card(); inc_bound(static_cast(offset) * c.k()); process_card(c, offset); break; } - case pb_t: { + case ba::tag_t::pb_t: { pb& p = cnstr.to_pb(); m_lemma.reset(); inc_bound(offset); @@ -1119,13 +749,14 @@ namespace sat { for (literal l : m_lemma) process_antecedent(~l, offset); break; } - case xr_t: { + case ba::tag_t::xr_t: { // jus.push_back(js); m_lemma.reset(); inc_bound(offset); inc_coeff(consequent, offset); get_xr_antecedents(consequent, idx, js, m_lemma); - for (literal l : m_lemma) process_antecedent(~l, offset); + for (literal l : m_lemma) + process_antecedent(~l, offset); break; } default: @@ -1428,11 +1059,19 @@ namespace sat { case justification::EXT_JUSTIFICATION: { ++m_stats.m_num_resolves; ext_justification_idx index = js.get_ext_justification_idx(); + auto* ext = sat::constraint_base::to_extension(index); + if (ext != this) { + m_lemma.reset(); + ext->get_antecedents(consequent, index, m_lemma, false); + for (literal l : m_lemma) + process_antecedent(~l, 1); + break; + } constraint& cnstr = index2constraint(index); SASSERT(!cnstr.was_removed()); switch (cnstr.tag()) { - case card_t: - case pb_t: { + case ba::tag_t::card_t: + case ba::tag_t::pb_t: { pb_base const& p = cnstr.to_pb_base(); unsigned k = p.k(), sz = p.size(); m_A.reset(0); @@ -1747,7 +1386,7 @@ namespace sat { add_at_least(lit, lits, k, m_is_redundant); } - ba_solver::constraint* ba_solver::add_at_least(literal lit, literal_vector const& lits, unsigned k, bool learned) { + constraint* ba_solver::add_at_least(literal lit, literal_vector const& lits, unsigned k, bool learned) { if (k == 1 && lit == null_literal) { literal_vector _lits(lits); s().mk_clause(_lits.size(), _lits.c_ptr(), status::th(learned, get_id())); @@ -1784,8 +1423,8 @@ namespace sat { } else { if (m_solver) m_solver->set_external(lit.var()); - watch_literal(lit, *c); - watch_literal(~lit, *c); + c->watch_literal(*this, lit); + c->watch_literal(*this, ~lit); } SASSERT(c->well_formed()); if (m_solver && m_solver->get_config().m_drat) { @@ -1798,27 +1437,20 @@ namespace sat { bool ba_solver::init_watch(constraint& c) { - if (inconsistent()) return false; - switch (c.tag()) { - case card_t: return init_watch(c.to_card()); - case pb_t: return init_watch(c.to_pb()); - case xr_t: return init_watch(c.to_xr()); - } - UNREACHABLE(); - return false; - } + return !inconsistent() && c.init_watch(*this); + } lbool ba_solver::add_assign(constraint& c, literal l) { switch (c.tag()) { - case card_t: return add_assign(c.to_card(), l); - case pb_t: return add_assign(c.to_pb(), l); - case xr_t: return add_assign(c.to_xr(), l); + case ba::tag_t::card_t: return add_assign(c.to_card(), l); + case ba::tag_t::pb_t: return add_assign(c.to_pb(), l); + case ba::tag_t::xr_t: return add_assign(c.to_xr(), l); } UNREACHABLE(); return l_undef; } - ba_solver::constraint* ba_solver::add_pb_ge(literal lit, svector const& wlits, unsigned k, bool learned) { + constraint* ba_solver::add_pb_ge(literal lit, svector const& wlits, unsigned k, bool learned) { bool units = true; for (wliteral wl : wlits) units &= wl.first == 1; if (k == 0 && lit == null_literal) { @@ -1868,54 +1500,10 @@ namespace sat { } } - double ba_solver::get_reward(card const& c, literal_occs_fun& literal_occs) const { - unsigned k = c.k(), slack = 0; - bool do_add = get_config().m_lookahead_reward == heule_schur_reward; - double to_add = do_add ? 0: 1; - for (literal l : c) { - switch (value(l)) { - case l_true: --k; if (k == 0) return 0; - case l_undef: - if (do_add) to_add += literal_occs(l); - ++slack; break; - case l_false: break; - } - } - if (k >= slack) return 1; - return pow(0.5, slack - k + 1) * to_add; - } - - double ba_solver::get_reward(pb const& c, literal_occs_fun& occs) const { - unsigned k = c.k(), slack = 0; - bool do_add = get_config().m_lookahead_reward == heule_schur_reward; - double to_add = do_add ? 0 : 1; - double undefs = 0; - for (wliteral wl : c) { - literal l = wl.second; - unsigned w = wl.first; - switch (value(l)) { - case l_true: if (k <= w) return 0; - case l_undef: - if (do_add) to_add += occs(l); - ++undefs; - slack += w; - break; // TBD multiplier factor on this - case l_false: break; - } - } - if (k >= slack || 0 == undefs) return 0; - double avg = slack / undefs; - return pow(0.5, (slack - k + 1)/avg) * to_add; - } double ba_solver::get_reward(literal l, ext_justification_idx idx, literal_occs_fun& occs) const { constraint const& c = index2constraint(idx); - switch (c.tag()) { - case card_t: return get_reward(c.to_card(), occs); - case pb_t: return get_reward(c.to_pb(), occs); - case xr_t: return 0; - default: UNREACHABLE(); return 0; - } + return c.get_reward(*this, occs); } @@ -1959,12 +1547,6 @@ namespace sat { if (l == 0) return false; unsigned start = s().m_scopes[l-1].m_trail_lim; literal_vector const& lits = s().m_trail; - -#if 0 - IF_VERBOSE(10, verbose_stream() << "level " << l << " scope level " << s().scope_lvl() << " tail lim start: " - << start << " size of lits: " << lits.size() << " num scopes " << s().m_scopes.size() << "\n";); -#endif - for (unsigned sz = lits.size(); sz-- > start; ) { if (lits[sz] == above) return true; if (lits[sz] == below) return false; @@ -2060,22 +1642,7 @@ namespace sat { } bool ba_solver::is_extended_binary(ext_justification_idx idx, literal_vector & r) { - constraint const& c = index2constraint(idx); - switch (c.tag()) { - case card_t: { - card const& ca = c.to_card(); - if (ca.size() == ca.k() + 1 && ca.lit() == null_literal) { - r.reset(); - for (literal l : ca) r.push_back(l); - return true; - } - else { - return false; - } - } - default: - return false; - } + return index2constraint(idx).is_extended_binary(r); } @@ -2111,28 +1678,12 @@ namespace sat { get_antecedents(l, index2constraint(idx), r, probing); } - bool ba_solver::is_watched(literal lit, constraint const& c) const { - return get_wlist(~lit).contains(watched(c.cindex())); - } - - void ba_solver::unwatch_literal(literal lit, constraint& c) { - watched w(c.cindex()); - get_wlist(~lit).erase(w); - SASSERT(!is_watched(lit, c)); - } - - void ba_solver::watch_literal(literal lit, constraint& c) { - if (c.is_pure() && lit == ~c.lit()) return; - SASSERT(!is_watched(lit, c)); - watched w(c.cindex()); - get_wlist(~lit).push_back(w); - } void ba_solver::get_antecedents(literal l, constraint const& c, literal_vector& r, bool probing) { switch (c.tag()) { - case card_t: get_antecedents(l, c.to_card(), r); break; - case pb_t: get_antecedents(l, c.to_pb(), r); break; - case xr_t: get_antecedents(l, c.to_xr(), r); break; + case ba::tag_t::card_t: get_antecedents(l, c.to_card(), r); break; + case ba::tag_t::pb_t: get_antecedents(l, c.to_pb(), r); break; + case ba::tag_t::xr_t: get_antecedents(l, c.to_xr(), r); break; default: UNREACHABLE(); break; } if (get_config().m_drat && m_solver && !probing) { @@ -2144,34 +1695,14 @@ namespace sat { } } - void ba_solver::nullify_tracking_literal(constraint& c) { - if (c.lit() != null_literal) { - unwatch_literal(c.lit(), c); - unwatch_literal(~c.lit(), c); - c.nullify_literal(); - } - } - void ba_solver::clear_watch(constraint& c) { - switch (c.tag()) { - case card_t: - clear_watch(c.to_card()); - break; - case pb_t: - clear_watch(c.to_pb()); - break; - case xr_t: - clear_watch(c.to_xr()); - break; - default: - UNREACHABLE(); - } + c.clear_watch(*this); } void ba_solver::remove_constraint(constraint& c, char const* reason) { TRACE("ba", display(tout << "remove ", c, true) << " " << reason << "\n";); IF_VERBOSE(21, display(verbose_stream() << "remove " << reason << " ", c, true);); - nullify_tracking_literal(c); + c.nullify_tracking_literal(*this); clear_watch(c); c.set_removed(); m_constraint_removed = true; @@ -2182,13 +1713,7 @@ namespace sat { bool ba_solver::validate_unit_propagation(constraint const& c, literal l) const { return true; - switch (c.tag()) { - case card_t: return validate_unit_propagation(c.to_card(), l); - case pb_t: return validate_unit_propagation(c.to_pb(), l); - case xr_t: return true; - default: UNREACHABLE(); break; - } - return false; + return c.validate_unit_propagation(*this, l); } bool ba_solver::validate_conflict(constraint const& c) const { @@ -2197,24 +1722,12 @@ namespace sat { lbool ba_solver::eval(constraint const& c) const { lbool v1 = c.lit() == null_literal ? l_true : value(c.lit()); - switch (c.tag()) { - case card_t: return eval(v1, eval(c.to_card())); - case pb_t: return eval(v1, eval(c.to_pb())); - case xr_t: return eval(v1, eval(c.to_xr())); - default: UNREACHABLE(); break; - } - return l_undef; + return eval(v1, c.eval(*this)); } lbool ba_solver::eval(model const& m, constraint const& c) const { - lbool v1 = c.lit() == null_literal ? l_true : value(m, c.lit()); - switch (c.tag()) { - case card_t: return eval(v1, eval(m, c.to_card())); - case pb_t: return eval(v1, eval(m, c.to_pb())); - case xr_t: return eval(v1, eval(m, c.to_xr())); - default: UNREACHABLE(); break; - } - return l_undef; + lbool v1 = c.lit() == null_literal ? l_true : ba::value(m, c.lit()); + return eval(v1, c.eval(m)); } lbool ba_solver::eval(lbool a, lbool b) const { @@ -2222,63 +1735,6 @@ namespace sat { return (a == b) ? l_true : l_false; } - lbool ba_solver::eval(card const& c) const { - unsigned trues = 0, undefs = 0; - for (literal l : c) { - switch (value(l)) { - case l_true: trues++; break; - case l_undef: undefs++; break; - default: break; - } - } - if (trues + undefs < c.k()) return l_false; - if (trues >= c.k()) return l_true; - return l_undef; - } - - lbool ba_solver::eval(model const& m, card const& c) const { - unsigned trues = 0, undefs = 0; - for (literal l : c) { - switch (value(m, l)) { - case l_true: trues++; break; - case l_undef: undefs++; break; - default: break; - } - } - if (trues + undefs < c.k()) return l_false; - if (trues >= c.k()) return l_true; - return l_undef; - } - - lbool ba_solver::eval(model const& m, pb const& p) const { - unsigned trues = 0, undefs = 0; - for (wliteral wl : p) { - switch (value(m, wl.second)) { - case l_true: trues += wl.first; break; - case l_undef: undefs += wl.first; break; - default: break; - } - } - if (trues + undefs < p.k()) return l_false; - if (trues >= p.k()) return l_true; - return l_undef; - } - - lbool ba_solver::eval(pb const& p) const { - unsigned trues = 0, undefs = 0; - for (wliteral wl : p) { - switch (value(wl.second)) { - case l_true: trues += wl.first; break; - case l_undef: undefs += wl.first; break; - default: break; - } - } - if (trues + undefs < p.k()) return l_false; - if (trues >= p.k()) return l_true; - return l_undef; - } - - bool ba_solver::validate() { if (!validate_watch_literals()) { return false; @@ -2324,14 +1780,14 @@ namespace sat { return false; } if (c.lit() != null_literal && value(c.lit()) != l_true) return true; - SASSERT(c.lit() == null_literal || lvl(c.lit()) == 0 || (is_watched(c.lit(), c) && is_watched(~c.lit(), c))); + SASSERT(c.lit() == null_literal || lvl(c.lit()) == 0 || (c.is_watched(*this, c.lit()) && c.is_watched(*this, ~c.lit()))); if (eval(c) == l_true) { return true; } literal_vector lits(c.literals()); for (literal l : lits) { if (lvl(l) == 0) continue; - bool found = is_watched(l, c); + bool found = c.is_watched(*this, l); if (found != c.is_watching(l)) { IF_VERBOSE(0, @@ -2355,9 +1811,9 @@ namespace sat { bool ba_solver::validate_watch(pb const& p, literal alit) const { for (unsigned i = 0; i < p.size(); ++i) { literal l = p[i].second; - if (l != alit && lvl(l) != 0 && is_watched(l, p) != (i < p.num_watch())) { + if (l != alit && lvl(l) != 0 && p.is_watched(*this, l) != (i < p.num_watch())) { IF_VERBOSE(0, display(verbose_stream(), p, true); - verbose_stream() << "literal " << l << " at position " << i << " " << is_watched(l, p) << "\n";); + verbose_stream() << "literal " << l << " at position " << i << " " << p.is_watched(*this, l) << "\n";); UNREACHABLE(); return false; } @@ -2379,7 +1835,7 @@ namespace sat { \brief Lex on (glue, size) */ struct constraint_glue_psm_lt { - bool operator()(ba_solver::constraint const * c1, ba_solver::constraint const * c2) const { + bool operator()(constraint const * c1, constraint const * c2) const { return (c1->glue() < c2->glue()) || (c1->glue() == c2->glue() && @@ -2391,12 +1847,12 @@ namespace sat { void ba_solver::update_psm(constraint& c) const { unsigned r = 0; switch (c.tag()) { - case card_t: + case ba::tag_t::card_t: for (literal l : c.to_card()) { if (s().m_phase[l.var()] == !l.sign()) ++r; } break; - case pb_t: + case ba::tag_t::pb_t: for (wliteral l : c.to_pb()) { if (s().m_phase[l.second.var()] == !l.second.sign()) ++r; } @@ -2484,7 +1940,7 @@ namespace sat { literal lit2 = c[i]; if (value(lit2) != l_false) { c.swap(index, i); - watch_literal(lit2, c); + c.watch_literal(*this, lit2); return l_undef; } } @@ -2533,7 +1989,7 @@ namespace sat { } - check_result ba_solver::check() { return CR_DONE; } + check_result ba_solver::check() { return check_result::CR_DONE; } void ba_solver::push() { m_constraint_to_reinit_lim.push_back(m_constraint_to_reinit.size()); @@ -2562,13 +2018,13 @@ namespace sat { void ba_solver::simplify(constraint& c) { SASSERT(s().at_base_lvl()); switch (c.tag()) { - case card_t: + case ba::tag_t::card_t: simplify(c.to_card()); break; - case pb_t: + case ba::tag_t::pb_t: simplify(c.to_pb()); break; - case xr_t: + case ba::tag_t::xr_t: simplify(c.to_xr()); break; default: @@ -2802,23 +2258,8 @@ namespace sat { void ba_solver::validate_eliminated(ptr_vector const& cs) { for (constraint const* c : cs) { if (c->learned()) continue; - switch (c->tag()) { - case tag_t::card_t: - for (literal l : c->to_card()) { - VERIFY(!s().was_eliminated(l.var())); - } - break; - case tag_t::pb_t: - for (wliteral wl : c->to_pb()) { - VERIFY(!s().was_eliminated(wl.second.var())); - } - break; - case tag_t::xr_t: - for (literal l : c->to_xr()) { - VERIFY(!s().was_eliminated(l.var())); - } - break; - } + for (auto l : constraint::literal_iterator(*c)) + VERIFY(!s().was_eliminated(l.var())); } } @@ -2827,13 +2268,13 @@ namespace sat { IF_VERBOSE(0, display(verbose_stream() << "recompile\n", c, true);); } switch (c.tag()) { - case card_t: + case ba::tag_t::card_t: recompile(c.to_card()); break; - case pb_t: + case ba::tag_t::pb_t: recompile(c.to_pb()); break; - case xr_t: + case ba::tag_t::xr_t: add_xr(c.to_xr().literals(), c.learned()); remove_constraint(c, "recompile xor"); break; @@ -2843,7 +2284,7 @@ namespace sat { } void ba_solver::recompile(card& c) { - SASSERT(c.lit() == null_literal || is_watched(c.lit(), c)); + SASSERT(c.lit() == null_literal || c.is_watched(*this, c.lit())); // pre-condition is that the literals, except c.lit(), in c are unwatched. if (c.id() == _bad_id) std::cout << "recompile: " << c << "\n"; @@ -2954,7 +2395,7 @@ namespace sat { if (c.lit() == null_literal || value(c.lit()) == l_true) { init_watch(c); } - SASSERT(c.lit() == null_literal || is_watched(c.lit(), c)); + SASSERT(c.lit() == null_literal || c.is_watched(*this, c.lit())); SASSERT(c.well_formed()); } } @@ -3036,18 +2477,18 @@ namespace sat { void ba_solver::split_root(constraint& c) { switch (c.tag()) { - case card_t: split_root(c.to_card()); break; - case pb_t: split_root(c.to_pb()); break; - case xr_t: NOT_IMPLEMENTED_YET(); break; + case ba::tag_t::card_t: split_root(c.to_card()); break; + case ba::tag_t::pb_t: split_root(c.to_pb()); break; + case ba::tag_t::xr_t: NOT_IMPLEMENTED_YET(); break; } } void ba_solver::flush_roots(constraint& c) { - if (c.lit() != null_literal && !is_watched(c.lit(), c)) { - watch_literal(c.lit(), c); - watch_literal(~c.lit(), c); + if (c.lit() != null_literal && !c.is_watched(*this, c.lit())) { + c.watch_literal(*this, c.lit()); + c.watch_literal(*this, ~c.lit()); } - SASSERT(c.lit() == null_literal || is_watched(c.lit(), c)); + SASSERT(c.lit() == null_literal || c.is_watched(*this, c.lit())); bool found = c.lit() != null_literal && m_root_vars[c.lit().var()]; for (unsigned i = 0; !found && i < c.size(); ++i) { found = m_root_vars[c.get_lit(i).var()]; @@ -3064,10 +2505,10 @@ namespace sat { literal root = c.lit(); if (root != null_literal && m_roots[root.index()] != root) { root = m_roots[root.index()]; - nullify_tracking_literal(c); + c.nullify_tracking_literal(*this); c.update_literal(root); - watch_literal(root, c); - watch_literal(~root, c); + c.watch_literal(*this, root); + c.watch_literal(*this, ~root); } bool found_dup = false; @@ -3131,31 +2572,19 @@ namespace sat { m_cnstr_use_list[(~lit).index()].push_back(cp); } switch (cp->tag()) { - case card_t: { - card& c = cp->to_card(); - for (literal l : c) { - m_cnstr_use_list[l.index()].push_back(&c); - if (lit != null_literal) m_cnstr_use_list[(~l).index()].push_back(&c); + case ba::tag_t::card_t: + case ba::tag_t::pb_t: + for (literal l : constraint::literal_iterator(*cp)) { + m_cnstr_use_list[l.index()].push_back(cp); + if (lit != null_literal) m_cnstr_use_list[(~l).index()].push_back(cp); } - break; - } - case pb_t: { - pb& p = cp->to_pb(); - for (wliteral wl : p) { - literal l = wl.second; - m_cnstr_use_list[l.index()].push_back(&p); - if (lit != null_literal) m_cnstr_use_list[(~l).index()].push_back(&p); - } - break; - } - case xr_t: { - xr& x = cp->to_xr(); - for (literal l : x) { - m_cnstr_use_list[l.index()].push_back(&x); - m_cnstr_use_list[(~l).index()].push_back(&x); + break; + case ba::tag_t::xr_t: + for (literal l : cp->to_xr()) { + m_cnstr_use_list[l.index()].push_back(cp); + m_cnstr_use_list[(~l).index()].push_back(cp); } - break; - } + break; } } } @@ -3167,8 +2596,8 @@ namespace sat { constraint& c = *cp; literal lit = c.lit(); switch (c.tag()) { - case card_t: - case pb_t: { + case ba::tag_t::card_t: + case ba::tag_t::pb_t: { if (lit != null_literal && value(lit) == l_undef && use_count(lit) == 1 && @@ -3276,10 +2705,10 @@ namespace sat { void ba_solver::unit_strengthen(big& big, constraint& c) { if (c.was_removed()) return; switch (c.tag()) { - case card_t: + case ba::tag_t::card_t: unit_strengthen(big, c.to_card()); break; - case pb_t: + case ba::tag_t::pb_t: unit_strengthen(big, c.to_pb()); break; default: @@ -3346,12 +2775,12 @@ namespace sat { void ba_solver::subsumption(constraint& cnstr) { if (cnstr.was_removed()) return; switch (cnstr.tag()) { - case card_t: { + case ba::tag_t::card_t: { card& c = cnstr.to_card(); if (c.k() > 1) subsumption(c); break; } - case pb_t: { + case ba::tag_t::pb_t: { pb& p = cnstr.to_pb(); if (p.k() > 1) subsumption(p); break; @@ -3413,7 +2842,7 @@ namespace sat { constraint& c = *(*it); if (c.was_removed()) { clear_watch(c); - nullify_tracking_literal(c); + c.nullify_tracking_literal(*this); c.deallocate(m_allocator); } else if (learned && !c.learned()) { @@ -3510,10 +2939,10 @@ namespace sat { if (c == &p1 || c->was_removed()) continue; bool s = false; switch (c->tag()) { - case card_t: + case ba::tag_t::card_t: s = subsumes(p1, c->to_card()); break; - case pb_t: + case ba::tag_t::pb_t: s = subsumes(p1, c->to_pb()); break; default: @@ -3604,25 +3033,9 @@ namespace sat { if (lit != null_literal) { s().set_external(lit.var()); } - switch (c.tag()) { - case card_t: - for (literal lit : c.to_card()) { - s().set_external(lit.var()); - SASSERT(!s().was_eliminated(lit.var())); - } - break; - case pb_t: - for (wliteral wl : c.to_pb()) { - s().set_external(wl.second.var()); - SASSERT(!s().was_eliminated(wl.second.var())); - } - break; - default: - for (literal lit : c.to_xr()) { - s().set_external(lit.var()); - SASSERT(!s().was_eliminated(lit.var())); - } - break; + for (literal lit : constraint::literal_iterator(c)) { + s().set_external(lit.var()); + SASSERT(!s().was_eliminated(lit.var())); } c.set_learned(false); } @@ -3743,14 +3156,14 @@ namespace sat { svector wlits; for (constraint* cp : constraints) { switch (cp->tag()) { - case card_t: { + case ba::tag_t::card_t: { card const& c = cp->to_card(); lits.reset(); for (literal l : c) lits.push_back(l); result->add_at_least(c.lit(), lits, c.k(), c.learned()); break; } - case pb_t: { + case ba::tag_t::pb_t: { pb const& p = cp->to_pb(); wlits.reset(); for (wliteral w : p) { @@ -3759,7 +3172,7 @@ namespace sat { result->add_pb_ge(p.lit(), wlits, p.k(), p.learned()); break; } - case xr_t: { + case ba::tag_t::xr_t: { xr const& x = cp->to_xr(); lits.reset(); for (literal l : x) lits.push_back(l); @@ -3772,7 +3185,7 @@ namespace sat { } } - void ba_solver::init_use_list(ext_use_list& ul) { + void ba_solver::init_use_list(sat::ext_use_list& ul) { ul.init(s().num_vars()); for (constraint const* cp : m_constraints) { ext_constraint_idx idx = cp->cindex(); @@ -3780,32 +3193,7 @@ namespace sat { ul.insert(cp->lit(), idx); ul.insert(~cp->lit(), idx); } - switch (cp->tag()) { - case card_t: { - card const& c = cp->to_card(); - for (literal l : c) { - ul.insert(l, idx); - } - break; - } - case pb_t: { - pb const& p = cp->to_pb(); - for (wliteral w : p) { - ul.insert(w.second, idx); - } - break; - } - case xr_t: { - xr const& x = cp->to_xr(); - for (literal l : x) { - ul.insert(l, idx); - ul.insert(~l, idx); - } - break; - } - default: - UNREACHABLE(); - } + cp->init_use_list(ul); } } @@ -3818,36 +3206,7 @@ namespace sat { constraint const& c = index2constraint(idx); simplifier& sim = s().m_simplifier; if (c.lit() != null_literal) return false; - switch (c.tag()) { - case card_t: { - card const& ca = c.to_card(); - unsigned weight = 0; - for (literal l2 : ca) { - if (sim.is_marked(~l2)) ++weight; - } - return weight >= ca.k(); - } - case pb_t: { - pb const& p = c.to_pb(); - unsigned weight = 0, offset = 0; - for (wliteral l2 : p) { - if (~l2.second == l) { - offset = l2.first; - break; - } - } - SASSERT(offset != 0); - for (wliteral l2 : p) { - if (sim.is_marked(~l2.second)) { - weight += std::min(offset, l2.first); - } - } - return weight >= p.k(); - } - default: - break; - } - return false; + return c.is_blocked(sim, l); } @@ -3892,41 +3251,6 @@ namespace sat { } - void ba_solver::display_lit(std::ostream& out, literal lit, unsigned sz, bool values) const { - if (lit != null_literal) { - if (values) { - out << lit << "[" << sz << "]"; - out << "@(" << value(lit); - if (value(lit) != l_undef) { - out << ":" << lvl(lit); - } - out << "): "; - } - else { - out << lit << " == "; - } - } - } - - void ba_solver::display(std::ostream& out, card const& c, bool values) const { - display_lit(out, c.lit(), c.size(), values); - for (unsigned i = 0; i < c.size(); ++i) { - literal l = c[i]; - out << l; - if (values) { - out << "@(" << value(l); - if (value(l) != l_undef) { - out << ":" << lvl(l); - } - out << ") "; - } - else { - out << " "; - } - } - out << ">= " << c.k() << "\n"; - } - std::ostream& ba_solver::display(std::ostream& out) const { for (constraint const* c : m_constraints) { out << (*c) << "\n"; @@ -3949,13 +3273,7 @@ namespace sat { } std::ostream& ba_solver::display(std::ostream& out, constraint const& c, bool values) const { - switch (c.tag()) { - case card_t: display(out, c.to_card(), values); break; - case pb_t: display(out, c.to_pb(), values); break; - case xr_t: display(out, c.to_xr(), values); break; - default: UNREACHABLE(); break; - } - return out; + return c.display(out, *this, values); } void ba_solver::collect_statistics(statistics& st) const { @@ -3970,31 +3288,6 @@ namespace sat { st.update("ba subsumes", m_stats.m_num_bin_subsumes + m_stats.m_num_clause_subsumes + m_stats.m_num_pb_subsumes); } - bool ba_solver::validate_unit_propagation(card const& c, literal alit) const { - (void) alit; - if (c.lit() != null_literal && value(c.lit()) != l_true) return false; - for (unsigned i = c.k(); i < c.size(); ++i) { - if (value(c[i]) != l_false) return false; - } - return true; - } - - bool ba_solver::validate_unit_propagation(pb const& p, literal alit) const { - if (p.lit() != null_literal && value(p.lit()) != l_true) { - return false; - } - - unsigned sum = 0; - TRACE("ba", display(tout << "validate: " << alit << "\n", p, true);); - for (wliteral wl : p) { - literal lit = wl.second; - lbool val = value(lit); - if (val != l_false && lit != alit) { - sum += wl.first; - } - } - return sum < p.k(); - } bool ba_solver::validate_unit_propagation(pb const& p, literal_vector const& r, literal alit) const { // all elements of r are true, @@ -4116,7 +3409,7 @@ namespace sat { m_overflow |= sum >= UINT_MAX/2; } - ba_solver::constraint* ba_solver::active2lemma() { + constraint* ba_solver::active2lemma() { switch (s().m_config.m_pb_lemma_format) { case PB_LEMMA_CARDINALITY: return active2card(); @@ -4128,7 +3421,7 @@ namespace sat { } } - ba_solver::constraint* ba_solver::active2constraint() { + constraint* ba_solver::active2constraint() { active2wlits(); if (m_overflow) { return nullptr; @@ -4164,13 +3457,13 @@ namespace sat { */ struct compare_wlit { - bool operator()(ba_solver::wliteral l1, ba_solver::wliteral l2) const { + bool operator()(wliteral l1, wliteral l2) const { return l1.first > l2.first; } }; - ba_solver::constraint* ba_solver::active2card() { + constraint* ba_solver::active2card() { active2wlits(); if (m_overflow) { return nullptr; @@ -4271,6 +3564,7 @@ namespace sat { } case justification::EXT_JUSTIFICATION: { ext_justification_idx index = js.get_ext_justification_idx(); + VERIFY(this == sat::constraint_base::to_extension(index)); constraint& cnstr = index2constraint(index); constraint2pb(cnstr, lit, offset, ineq); break; @@ -4283,21 +3577,21 @@ namespace sat { void ba_solver::constraint2pb(constraint& cnstr, literal lit, unsigned offset, ineq& ineq) { switch (cnstr.tag()) { - case card_t: { + case ba::tag_t::card_t: { card& c = cnstr.to_card(); ineq.reset(offset*c.k()); for (literal l : c) ineq.push(l, offset); if (c.lit() != null_literal) ineq.push(~c.lit(), offset*c.k()); break; } - case pb_t: { + case ba::tag_t::pb_t: { pb& p = cnstr.to_pb(); ineq.reset(offset * p.k()); for (wliteral wl : p) ineq.push(wl.second, offset * wl.first); if (p.lit() != null_literal) ineq.push(~p.lit(), offset * p.k()); break; } - case xr_t: { + case ba::tag_t::xr_t: { xr& x = cnstr.to_xr(); literal_vector ls; SASSERT(lit != null_literal); @@ -4546,7 +3840,7 @@ namespace sat { literal_vector lits; for (constraint* cp : m_constraints) { switch (cp->tag()) { - case card_t: { + case ba::tag_t::card_t: { card const& c = cp->to_card(); unsigned n = c.size(); unsigned k = c.k(); @@ -4584,18 +3878,18 @@ namespace sat { } break; } - case ba_solver::pb_t: { - ba_solver::pb const& p = cp->to_pb(); + case ba::tag_t::pb_t: { + pb const& p = cp->to_pb(); lits.reset(); coeffs.reset(); unsigned sum = 0; - for (ba_solver::wliteral wl : p) sum += wl.first; + for (wliteral wl : p) sum += wl.first; if (p.lit() == null_literal) { // w1 + .. + w_n >= k // <=> // ~wl + ... + ~w_n <= sum_of_weights - k - for (ba_solver::wliteral wl : p) lits.push_back(~(wl.second)), coeffs.push_back(wl.first); + for (wliteral wl : p) lits.push_back(~(wl.second)), coeffs.push_back(wl.first); add_pb(lits.size(), lits.c_ptr(), coeffs.c_ptr(), sum - p.k()); } else { @@ -4607,18 +3901,18 @@ namespace sat { // (sum - k + 1)*~lit + w1 + .. + w_n <= sum // k*lit + ~wl + ... + ~w_n <= sum lits.push_back(p.lit()), coeffs.push_back(p.k()); - for (ba_solver::wliteral wl : p) lits.push_back(~(wl.second)), coeffs.push_back(wl.first); + for (wliteral wl : p) lits.push_back(~(wl.second)), coeffs.push_back(wl.first); add_pb(lits.size(), lits.c_ptr(), coeffs.c_ptr(), sum); lits.reset(); coeffs.reset(); lits.push_back(~p.lit()), coeffs.push_back(sum + 1 - p.k()); - for (ba_solver::wliteral wl : p) lits.push_back(wl.second), coeffs.push_back(wl.first); + for (wliteral wl : p) lits.push_back(wl.second), coeffs.push_back(wl.first); add_pb(lits.size(), lits.c_ptr(), coeffs.c_ptr(), sum); } break; } - case ba_solver::xr_t: + case ba::tag_t::xr_t: return false; } } diff --git a/src/sat/smt/ba_solver.h b/src/sat/smt/ba_solver.h index 0ae2210b5..855b16410 100644 --- a/src/sat/smt/ba_solver.h +++ b/src/sat/smt/ba_solver.h @@ -26,6 +26,10 @@ Revision History: #include "sat/sat_big.h" #include "sat/smt/sat_smt.h" #include "sat/smt/sat_th.h" +#include "sat/smt/ba_constraint.h" +#include "sat/smt/ba_card.h" +#include "sat/smt/ba_pb.h" +#include "sat/smt/ba_xor.h" #include "util/small_object_allocator.h" #include "util/scoped_ptr_vector.h" #include "util/sorting_network.h" @@ -33,9 +37,16 @@ Revision History: namespace sat { + typedef ba::constraint constraint; + typedef ba::wliteral wliteral; + typedef ba::card card; + typedef ba::xr xr; + typedef ba::pb_base pb_base; + typedef ba::pb pb; + class xor_finder; - class ba_solver : public euf::th_solver { + class ba_solver : public euf::th_solver, public ba::solver_interface { friend class local_search; @@ -54,167 +65,7 @@ namespace sat { stats() { reset(); } void reset() { memset(this, 0, sizeof(*this)); } }; - - public: - enum tag_t { - card_t, - pb_t, - xr_t - }; - - class card; - class pb; - class xr; - class pb_base; - - class constraint { - protected: - tag_t m_tag; - bool m_removed; - literal m_lit; - literal m_watch; - unsigned m_glue; - unsigned m_psm; - unsigned m_size; - size_t m_obj_size; - bool m_learned; - unsigned m_id; - bool m_pure; // is the constraint pure (only positive occurrences) - public: - constraint(tag_t t, unsigned id, literal l, unsigned sz, size_t osz): - m_tag(t), m_removed(false), m_lit(l), m_watch(null_literal), m_glue(0), m_psm(0), m_size(sz), m_obj_size(osz), m_learned(false), m_id(id), m_pure(false) { - } - ext_constraint_idx cindex() const { return constraint_base::mem2base(this); } - void deallocate(small_object_allocator& a) { a.deallocate(obj_size(), constraint_base::mem2base_ptr(this)); } - unsigned id() const { return m_id; } - tag_t tag() const { return m_tag; } - literal lit() const { return m_lit; } - unsigned size() const { return m_size; } - void set_size(unsigned sz) { SASSERT(sz <= m_size); m_size = sz; } - void update_literal(literal l) { m_lit = l; } - bool was_removed() const { return m_removed; } - void set_removed() { m_removed = true; } - void nullify_literal() { m_lit = null_literal; } - unsigned glue() const { return m_glue; } - void set_glue(unsigned g) { m_glue = g; } - unsigned psm() const { return m_psm; } - void set_psm(unsigned p) { m_psm = p; } - void set_learned(bool f) { m_learned = f; } - bool learned() const { return m_learned; } - bool is_watched() const { return m_watch == m_lit && m_lit != null_literal; } - void set_watch() { m_watch = m_lit; } - void clear_watch() { m_watch = null_literal; } - bool is_clear() const { return m_watch == null_literal && m_lit != null_literal; } - bool is_pure() const { return m_pure; } - void set_pure() { m_pure = true; } - unsigned fold_max_var(unsigned w) const; - - size_t obj_size() const { return m_obj_size; } - card& to_card(); - pb& to_pb(); - xr& to_xr(); - card const& to_card() const; - pb const& to_pb() const; - xr const& to_xr() const; - pb_base const& to_pb_base() const; - bool is_card() const { return m_tag == card_t; } - bool is_pb() const { return m_tag == pb_t; } - bool is_xr() const { return m_tag == xr_t; } - - virtual bool is_watching(literal l) const { UNREACHABLE(); return false; }; - virtual literal_vector literals() const { UNREACHABLE(); return literal_vector(); } - virtual void swap(unsigned i, unsigned j) { UNREACHABLE(); } - virtual literal get_lit(unsigned i) const { UNREACHABLE(); return null_literal; } - virtual void set_lit(unsigned i, literal l) { UNREACHABLE(); } - virtual bool well_formed() const { return true; } - virtual void negate() { UNREACHABLE(); } - }; - - friend std::ostream& operator<<(std::ostream& out, constraint const& c); - - // base class for pb and cardinality constraints - class pb_base : public constraint { - protected: - unsigned m_k; - public: - pb_base(tag_t t, unsigned id, literal l, unsigned sz, size_t osz, unsigned k): - constraint(t, id, l, sz, osz), m_k(k) { VERIFY(k < 4000000000); } - virtual void set_k(unsigned k) { VERIFY(k < 4000000000); m_k = k; } - virtual unsigned get_coeff(unsigned i) const { UNREACHABLE(); return 0; } - unsigned k() const { return m_k; } - bool well_formed() const override; - }; - - class card : public pb_base { - literal m_lits[0]; - public: - static size_t get_obj_size(unsigned num_lits) { return constraint_base::obj_size(sizeof(card) + num_lits * sizeof(literal)); } - card(unsigned id, literal lit, literal_vector const& lits, unsigned k); - literal operator[](unsigned i) const { return m_lits[i]; } - literal& operator[](unsigned i) { return m_lits[i]; } - literal const* begin() const { return m_lits; } - literal const* end() const { return static_cast(m_lits) + m_size; } - void negate() override; - void swap(unsigned i, unsigned j) override { std::swap(m_lits[i], m_lits[j]); } - literal_vector literals() const override { return literal_vector(m_size, m_lits); } - bool is_watching(literal l) const override; - literal get_lit(unsigned i) const override { return m_lits[i]; } - void set_lit(unsigned i, literal l) override { m_lits[i] = l; } - unsigned get_coeff(unsigned i) const override { return 1; } - }; - - - typedef std::pair wliteral; - - class pb : public pb_base { - unsigned m_slack; - unsigned m_num_watch; - unsigned m_max_sum; - wliteral m_wlits[0]; - public: - static size_t get_obj_size(unsigned num_lits) { return constraint_base::obj_size(sizeof(pb) + num_lits * sizeof(wliteral)); } - pb(unsigned id, literal lit, svector const& wlits, unsigned k); - literal lit() const { return m_lit; } - wliteral operator[](unsigned i) const { return m_wlits[i]; } - wliteral& operator[](unsigned i) { return m_wlits[i]; } - wliteral const* begin() const { return m_wlits; } - wliteral const* end() const { return begin() + m_size; } - - unsigned slack() const { return m_slack; } - void set_slack(unsigned s) { m_slack = s; } - unsigned num_watch() const { return m_num_watch; } - unsigned max_sum() const { return m_max_sum; } - void update_max_sum(); - void set_num_watch(unsigned s) { m_num_watch = s; } - bool is_cardinality() const; - void negate() override; - void set_k(unsigned k) override { m_k = k; VERIFY(k < 4000000000); update_max_sum(); } - void swap(unsigned i, unsigned j) override { std::swap(m_wlits[i], m_wlits[j]); } - literal_vector literals() const override { literal_vector lits; for (auto wl : *this) lits.push_back(wl.second); return lits; } - bool is_watching(literal l) const override; - literal get_lit(unsigned i) const override { return m_wlits[i].second; } - void set_lit(unsigned i, literal l) override { m_wlits[i].second = l; } - unsigned get_coeff(unsigned i) const override { return m_wlits[i].first; } - }; - - class xr : public constraint { - literal m_lits[0]; - public: - static size_t get_obj_size(unsigned num_lits) { return constraint_base::obj_size(sizeof(xr) + num_lits * sizeof(literal)); } - xr(unsigned id, literal_vector const& lits); - literal operator[](unsigned i) const { return m_lits[i]; } - literal const* begin() const { return m_lits; } - literal const* end() const { return begin() + m_size; } - void negate() override { m_lits[0].neg(); } - void swap(unsigned i, unsigned j) override { std::swap(m_lits[i], m_lits[j]); } - bool is_watching(literal l) const override; - literal_vector literals() const override { return literal_vector(size(), begin()); } - literal get_lit(unsigned i) const override { return m_lits[i]; } - void set_lit(unsigned i, literal l) override { m_lits[i] = l; } - bool well_formed() const override; - }; - - + protected: struct ineq { @@ -235,29 +86,28 @@ namespace sat { sat_internalizer& si; pb_util m_pb; - solver* m_solver; - lookahead* m_lookahead; + solver* m_solver{ nullptr }; + lookahead* m_lookahead{ nullptr }; stats m_stats; small_object_allocator m_allocator; - - ptr_vector m_constraints; - ptr_vector m_learned; - ptr_vector m_constraint_to_reinit; + ptr_vector m_constraints; + ptr_vector m_learned; + ptr_vector m_constraint_to_reinit; unsigned_vector m_constraint_to_reinit_lim; - unsigned m_constraint_to_reinit_last_sz; - unsigned m_constraint_id; + unsigned m_constraint_to_reinit_last_sz{ 0 }; + unsigned m_constraint_id{ 0 }; // conflict resolution - unsigned m_num_marks; - unsigned m_conflict_lvl; + unsigned m_num_marks{ 0 }; + unsigned m_conflict_lvl{ 0 }; svector m_coeffs; svector m_active_vars; - unsigned m_bound; + unsigned m_bound{ 0 }; tracked_uint_set m_active_var_set; literal_vector m_lemma; literal_vector m_skipped; - unsigned m_num_propagations_since_pop; + unsigned m_num_propagations_since_pop{ 0 }; unsigned_vector m_parity_marks; literal_vector m_parity_trail; @@ -297,11 +147,11 @@ namespace sat { vector> m_cnstr_use_list; use_list m_clause_use_list; - bool m_simplify_change; - bool m_clause_removed; - bool m_constraint_removed; + bool m_simplify_change{ false }; + bool m_clause_removed{ false }; + bool m_constraint_removed{ false }; literal_vector m_roots; - bool_vector m_root_vars; + bool_vector m_root_vars; unsigned_vector m_weights; svector m_wlits; @@ -324,9 +174,9 @@ namespace sat { unsigned elim_pure(); bool elim_pure(literal lit); void unit_strengthen(); - void unit_strengthen(big& big, constraint& cs); + void unit_strengthen(big& big, ba::constraint& cs); void unit_strengthen(big& big, pb_base& p); - void subsumption(constraint& c1); + void subsumption(ba::constraint& c1); void subsumption(card& c1); void gc_half(char const* _method); void update_psm(constraint& c) const; @@ -345,10 +195,7 @@ namespace sat { // constraints constraint& index2constraint(size_t idx) const { return *reinterpret_cast(constraint_base::from_index(idx)->mem()); } void pop_constraint(); - void unwatch_literal(literal w, constraint& c); - void watch_literal(literal w, constraint& c); - void watch_literal(wliteral w, pb& p); - bool is_watched(literal l, constraint const& c) const; + // void watch_literal(wliteral w, pb& p); void add_constraint(constraint* c); bool init_watch(constraint& c); void init_watch(bool_var v); @@ -357,9 +204,8 @@ namespace sat { bool incremental_mode() const; void simplify(constraint& c); void pre_simplify(xor_finder& xu, constraint& c); - void nullify_tracking_literal(constraint& c); - void set_conflict(constraint& c, literal lit); - void assign(constraint& c, literal lit); + void set_conflict(constraint& c, literal lit) override; + void assign(constraint& c, literal lit) override; bool assigned_above(literal above, literal below); void get_antecedents(literal l, constraint const& c, literal_vector & r, bool probing); bool validate_conflict(constraint const& c) const; @@ -377,12 +223,10 @@ namespace sat { void split_root(constraint& c); unsigned next_id() { return m_constraint_id++; } void set_non_learned(constraint& c); - + double get_reward(literal l, ext_justification_idx idx, literal_occs_fun& occs) const override; // cardinality - bool init_watch(card& c); lbool add_assign(card& c, literal lit); - void clear_watch(card& c); void reset_coeffs(); void reset_marked_literals(); void get_antecedents(literal l, card const& c, literal_vector & r); @@ -392,13 +236,9 @@ namespace sat { bool clausify(literal lit, unsigned n, literal const* lits, unsigned k); lbool eval(card const& c) const; lbool eval(model const& m, card const& c) const; - double get_reward(card const& c, literal_occs_fun& occs) const; // xr specific functionality - void clear_watch(xr& x); - bool init_watch(xr& x); - bool parity(xr const& x, unsigned offset) const; lbool add_assign(xr& x, literal alit); void get_xr_antecedents(literal l, unsigned index, justification js, literal_vector& r); void get_antecedents(literal l, xr const& x, literal_vector & r); @@ -411,11 +251,9 @@ namespace sat { lbool eval(model const& m, xr const& x) const; // pb functionality - unsigned m_a_max; - bool init_watch(pb& p); + unsigned m_a_max{ 0 }; lbool add_assign(pb& p, literal alit); void add_index(pb& p, unsigned index, literal lit); - void clear_watch(pb& p); void get_antecedents(literal l, pb const& p, literal_vector & r); void split_root(pb_base& p); void simplify(pb_base& p); @@ -427,7 +265,6 @@ namespace sat { bool is_cardinality(pb const& p, literal_vector& lits); lbool eval(pb const& p) const; lbool eval(model const& m, pb const& p) const; - double get_reward(pb const& p, literal_occs_fun& occs) const; // RoundingPb conflict resolution lbool resolve_conflict_rs(); @@ -449,31 +286,32 @@ namespace sat { // access solver - inline lbool value(bool_var v) const { return value(literal(v, false)); } - inline lbool value(literal lit) const { return m_lookahead ? m_lookahead->value(lit) : m_solver->value(lit); } - inline lbool value(model const& m, literal l) const { return l.sign() ? ~m[l.var()] : m[l.var()]; } - inline bool is_false(literal lit) const { return l_false == value(lit); } + inline lbool value(bool_var v) const override { return value(literal(v, false)); } + inline lbool value(literal lit) const override { return m_lookahead ? m_lookahead->value(lit) : m_solver->value(lit); } + inline bool is_false(literal lit) const override { return l_false == value(lit); } - inline unsigned lvl(literal lit) const { return m_lookahead ? 0 : m_solver->lvl(lit); } - inline unsigned lvl(bool_var v) const { return m_lookahead ? 0 : m_solver->lvl(v); } - inline bool inconsistent() const { + inline unsigned lvl(literal lit) const override { return m_lookahead ? 0 : m_solver->lvl(lit); } + inline unsigned lvl(bool_var v) const override { return m_lookahead ? 0 : m_solver->lvl(v); } + inline bool inconsistent() const override { if (m_lookahead) return m_lookahead->inconsistent(); return m_solver->inconsistent(); } - inline watch_list& get_wlist(literal l) { return m_lookahead ? m_lookahead->get_wlist(l) : m_solver->get_wlist(l); } - inline watch_list const& get_wlist(literal l) const { return m_lookahead ? m_lookahead->get_wlist(l) : m_solver->get_wlist(l); } - inline void assign(literal l, justification j) { + inline watch_list& get_wlist(literal l) override { return m_lookahead ? m_lookahead->get_wlist(l) : m_solver->get_wlist(l); } + inline watch_list const& get_wlist(literal l) const override { return m_lookahead ? m_lookahead->get_wlist(l) : m_solver->get_wlist(l); } + inline void assign(literal l, justification j) override { if (m_lookahead) m_lookahead->assign(l); else m_solver->assign(l, j); } - inline void set_conflict(justification j, literal l) { + inline void set_conflict(justification j, literal l) override { if (m_lookahead) m_lookahead->set_conflict(); else m_solver->set_conflict(j, l); } - inline config const& get_config() const { return m_lookahead ? m_lookahead->get_config() : m_solver->get_config(); } + inline config const& get_config() const override { + return m_lookahead ? m_lookahead->get_config() : m_solver->get_config(); + } - mutable bool m_overflow; + mutable bool m_overflow{ false }; void reset_active_var_set(); bool test_and_set_active(bool_var v); void inc_coeff(literal l, unsigned offset); @@ -499,10 +337,7 @@ namespace sat { bool validate_assign(literal_vector const& lits, literal lit); bool validate_lemma(); bool validate_ineq(ineq const& ineq) const; - bool validate_unit_propagation(card const& c, literal alit) const; - bool validate_unit_propagation(pb const& p, literal alit) const; bool validate_unit_propagation(pb const& p, literal_vector const& r, literal alit) const; - bool validate_unit_propagation(xr const& x, literal alit) const; bool validate_conflict(literal_vector const& lits, ineq& p); bool validate_watch_literals() const; bool validate_watch_literal(literal lit) const; @@ -528,9 +363,6 @@ namespace sat { unsigned get_coeff(ineq const& pb, literal lit); void display(std::ostream& out, ineq const& p, bool values = false) const; - void display(std::ostream& out, card const& c, bool values) const; - void display(std::ostream& out, pb const& p, bool values) const; - void display(std::ostream& out, xr const& c, bool values) const; void display_lit(std::ostream& out, literal l, unsigned sz, bool values) const; constraint* add_at_least(literal l, literal_vector const& lits, unsigned k, bool learned); @@ -560,9 +392,9 @@ namespace sat { literal internalize_xor(expr* e, bool sign, bool root); // Decompile - expr_ref get_card(std::function& l2e, ba_solver::card const& c); - expr_ref get_pb(std::function& l2e, ba_solver::pb const& p); - expr_ref get_xor(std::function& l2e, ba_solver::xr const& x); + expr_ref get_card(std::function& l2e, card const& c); + expr_ref get_pb(std::function& l2e, pb const& p); + expr_ref get_xor(std::function& l2e, xr const& x); public: ba_solver(euf::solver& ctx, euf::theory_id id); @@ -598,7 +430,6 @@ namespace sat { void pop_reinit() override; void gc() override; unsigned max_var(unsigned w) const override; - double get_reward(literal l, ext_justification_idx idx, literal_occs_fun& occs) const override; bool is_extended_binary(ext_justification_idx idx, literal_vector & r) override; void init_use_list(ext_use_list& ul) override; bool is_blocked(literal l, ext_constraint_idx idx) override; diff --git a/src/sat/smt/ba_solver_interface.h b/src/sat/smt/ba_solver_interface.h new file mode 100644 index 000000000..5c6da6d31 --- /dev/null +++ b/src/sat/smt/ba_solver_interface.h @@ -0,0 +1,52 @@ +/*++ +Copyright (c) 2017 Microsoft Corporation + +Module Name: + + ba_solver_interface.h + +Abstract: + + Abstract interface for a solver, + covers functionality exposed by the sat and lookahead solvers. + +Author: + + Nikolaj Bjorner (nbjorner) 2017-01-30 + +Revision History: + +--*/ + +#pragma once + +#include "sat/sat_types.h" +#include "sat/sat_solver.h" +#include "sat/smt/sat_smt.h" + + +namespace ba { + + typedef sat::literal literal; + typedef sat::bool_var bool_var; + typedef sat::literal_vector literal_vector; + typedef std::pair wliteral; + class constraint; + + class solver_interface { + public: + virtual lbool value(bool_var v) const = 0; + virtual lbool value(literal lit) const = 0; + virtual bool is_false(literal lit) const = 0; + virtual unsigned lvl(literal lit) const = 0; + virtual unsigned lvl(bool_var v) const = 0; + virtual bool inconsistent() const = 0; + virtual sat::watch_list& get_wlist(literal l) = 0; + virtual sat::watch_list const& get_wlist(literal l) const = 0; + virtual void assign(literal l, sat::justification j) = 0; + virtual void set_conflict(sat::justification j, literal l) = 0; + virtual sat::config const& get_config() const = 0; + virtual void assign(constraint& c, literal lit) = 0; + virtual void set_conflict(constraint& c, literal lit) = 0; + }; +} diff --git a/src/sat/smt/ba_xor.cpp b/src/sat/smt/ba_xor.cpp new file mode 100644 index 000000000..3914739ae --- /dev/null +++ b/src/sat/smt/ba_xor.cpp @@ -0,0 +1,192 @@ +/*++ +Copyright (c) 2017 Microsoft Corporation + +Module Name: + + ba_xor.cpp + +Abstract: + + Interface for Xor constraints. + +Author: + + Nikolaj Bjorner (nbjorner) 2017-01-30 + +--*/ + +#include "sat/smt/ba_xor.h" +#include "sat/smt/ba_solver.h" + +namespace ba { + + xr& constraint::to_xr() { + SASSERT(is_xr()); + return static_cast(*this); + } + + xr const& constraint::to_xr() const { + SASSERT(is_xr()); + return static_cast(*this); + } + + xr::xr(unsigned id, literal_vector const& lits) : + constraint(ba::tag_t::xr_t, id, sat::null_literal, lits.size(), get_obj_size(lits.size())) { + for (unsigned i = 0; i < size(); ++i) { + m_lits[i] = lits[i]; + } + } + + + bool xr::is_watching(literal l) const { + return + l == (*this)[0] || l == (*this)[1] || + ~l == (*this)[0] || ~l == (*this)[1]; + } + + bool xr::well_formed() const { + uint_set vars; + if (lit() != sat::null_literal) vars.insert(lit().var()); + for (literal l : *this) { + bool_var v = l.var(); + if (vars.contains(v)) return false; + vars.insert(v); + } + return true; + } + + std::ostream& xr::display(std::ostream& out) const { + for (unsigned i = 0; i < size(); ++i) { + out << (*this)[i] << " "; + if (i + 1 < size()) out << "x "; + } + return out; + } + + void xr::clear_watch(solver_interface& s) { + auto& x = *this; + x.reset_watch(); + x.unwatch_literal(s, x[0]); + x.unwatch_literal(s, x[1]); + x.unwatch_literal(s, ~x[0]); + x.unwatch_literal(s, ~x[1]); + } + + + bool xr::init_watch(solver_interface& s) { + auto& x = *this; + x.clear_watch(s); + VERIFY(x.lit() == sat::null_literal); + TRACE("ba", x.display(tout);); + unsigned sz = x.size(); + unsigned j = 0; + for (unsigned i = 0; i < sz && j < 2; ++i) { + if (s.value(x[i]) == l_undef) { + x.swap(i, j); + ++j; + } + } + switch (j) { + case 0: + if (!parity(s, 0)) { + unsigned l = s.lvl(x[0]); + j = 1; + for (unsigned i = 1; i < sz; ++i) { + if (s.lvl(x[i]) > l) { + j = i; + l = s.lvl(x[i]); + } + } + s.set_conflict(x, x[j]); + } + return false; + case 1: + SASSERT(x.lit() == sat::null_literal || s.value(x.lit()) == l_true); + s.assign(x, parity(s, 1) ? ~x[0] : x[0]); + return false; + default: + SASSERT(j == 2); + x.watch_literal(s, x[0]); + x.watch_literal(s, x[1]); + x.watch_literal(s, ~x[0]); + x.watch_literal(s, ~x[1]); + return true; + } + } + + bool xr::parity(solver_interface const& s, unsigned offset) const { + auto const& x = *this; + bool odd = false; + unsigned sz = x.size(); + for (unsigned i = offset; i < sz; ++i) { + SASSERT(s.value(x[i]) != l_undef); + if (s.value(x[i]) == l_true) { + odd = !odd; + } + } + return odd; + } + + + std::ostream& xr::display(std::ostream& out, solver_interface const& s, bool values) const { + auto const& x = *this; + out << "xr: "; + for (literal l : x) { + out << l; + if (values) { + out << "@(" << s.value(l); + if (s.value(l) != l_undef) { + out << ":" << s.lvl(l); + } + out << ") "; + } + else { + out << " "; + } + } + return out << "\n"; + } + + bool xr::validate_unit_propagation(solver_interface const& s, literal alit) const { + if (s.value(lit()) != l_true) return false; + for (unsigned i = 1; i < size(); ++i) { + if (s.value((*this)[i]) == l_undef) return false; + } + return true; + } + + lbool xr::eval(solver_interface const& s) const { + auto const& x = *this; + bool odd = false; + for (auto l : x) { + switch (s.value(l)) { + case l_true: odd = !odd; break; + case l_false: break; + default: return l_undef; + } + } + return odd ? l_true : l_false; + } + + lbool xr::eval(sat::model const& m) const { + auto const& x = *this; + bool odd = false; + for (auto l : x) { + switch (ba::value(m, l)) { + case l_true: odd = !odd; break; + case l_false: break; + default: return l_undef; + } + } + return odd ? l_true : l_false; + } + + void xr::init_use_list(sat::ext_use_list& ul) const { + auto idx = cindex(); + for (auto l : *this) { + ul.insert(l, idx); + ul.insert(~l, idx); + } + } + +} diff --git a/src/sat/smt/ba_xor.h b/src/sat/smt/ba_xor.h new file mode 100644 index 000000000..009534951 --- /dev/null +++ b/src/sat/smt/ba_xor.h @@ -0,0 +1,53 @@ +/*++ +Copyright (c) 2017 Microsoft Corporation + +Module Name: + + ba_xor.h + +Abstract: + + Interface for Xor constraints. + +Author: + + Nikolaj Bjorner (nbjorner) 2017-01-30 + +--*/ + +#pragma once + +#include "sat/sat_types.h" +#include "sat/smt/ba_constraint.h" + + +namespace ba { + + class xr : public constraint { + literal m_lits[0]; + public: + static size_t get_obj_size(unsigned num_lits) { return sat::constraint_base::obj_size(sizeof(xr) + num_lits * sizeof(literal)); } + xr(unsigned id, literal_vector const& lits); + literal operator[](unsigned i) const { return m_lits[i]; } + literal const* begin() const { return m_lits; } + literal const* end() const { return begin() + m_size; } + void negate() override { m_lits[0].neg(); } + void swap(unsigned i, unsigned j) override { std::swap(m_lits[i], m_lits[j]); } + bool is_watching(literal l) const override; + literal_vector literals() const override { return literal_vector(size(), begin()); } + literal get_lit(unsigned i) const override { return m_lits[i]; } + void set_lit(unsigned i, literal l) override { m_lits[i] = l; } + bool well_formed() const override; + void clear_watch(solver_interface& s) override; + bool init_watch(solver_interface& s) override; + std::ostream& display(std::ostream& out) const override; + std::ostream& display(std::ostream& out, solver_interface const& s, bool values) const override; + + bool parity(solver_interface const& s, unsigned offset) const; + bool validate_unit_propagation(solver_interface const& s, literal alit) const override; + lbool eval(sat::model const& m) const override; + lbool eval(solver_interface const& s) const override; + void init_use_list(sat::ext_use_list& ul) const override; + bool is_blocked(sat::simplifier& s, literal lit) const override { return false; } + }; +} diff --git a/src/sat/smt/bv_solver.cpp b/src/sat/smt/bv_solver.cpp index b2cfb283e..1c241e1c9 100644 --- a/src/sat/smt/bv_solver.cpp +++ b/src/sat/smt/bv_solver.cpp @@ -362,7 +362,7 @@ namespace bv { eq eq_proc(*this); hash hash_proc(*this); map table(hash_proc, eq_proc); - for (unsigned v = 0; v < get_num_vars(); ++v) { + for (theory_var v = 0; v < static_cast(get_num_vars()); ++v) { if (!m_bits[v].empty()) { theory_var w = table.insert_if_not_there(v, v); if (v != w && m_find.find(v) != m_find.find(w)) @@ -424,7 +424,7 @@ namespace bv { result->m_bits[i].append(m_bits[i]); result->m_zero_one_bits[i].append(m_zero_one_bits[i]); } - for (unsigned i = 0; i < get_num_vars(); ++i) + for (theory_var i = 0; i < static_cast(get_num_vars()); ++i) if (find(i) != i) result->m_find.merge(i, find(i)); result->m_prop_queue.append(m_prop_queue); diff --git a/src/sat/smt/euf_ackerman.cpp b/src/sat/smt/euf_ackerman.cpp index 259d8dd30..d65bc7510 100644 --- a/src/sat/smt/euf_ackerman.cpp +++ b/src/sat/smt/euf_ackerman.cpp @@ -96,10 +96,12 @@ namespace euf { void ackerman::cg_conflict_eh(expr * n1, expr * n2) { if (!is_app(n1) || !is_app(n2)) return; + SASSERT(!s.m_drating); app* a = to_app(n1); app* b = to_app(n2); if (a->get_decl() != b->get_decl() || a->get_num_args() != b->get_num_args()) return; + TRACE("ack", tout << "conflict eh: " << mk_pp(a, m) << " == " << mk_pp(b, m) << "\n";); insert(a, b); gc(); } @@ -107,13 +109,19 @@ namespace euf { void ackerman::used_eq_eh(expr* a, expr* b, expr* c) { if (a == b || a == c || b == c) return; + if (s.m_drating) + return; + TRACE("ack", tout << mk_pp(a, m) << " " << mk_pp(b, m) << " " << mk_pp(c, m) << "\n";); insert(a, b, c); gc(); } void ackerman::used_cc_eh(app* a, app* b) { + if (s.m_drating) + return; + TRACE("ack", tout << "used cc: " << mk_pp(a, m) << " == " << mk_pp(b, m) << "\n";); SASSERT(a->get_decl() == b->get_decl()); - SASSERT(a->get_num_args() == b->get_num_args()); + SASSERT(a->get_num_args() == b->get_num_args()); insert(a, b); gc(); } @@ -153,15 +161,15 @@ namespace euf { } } - void ackerman::add_cc(expr* _a, expr* _b) { + void ackerman::add_cc(expr* _a, expr* _b) { app* a = to_app(_a); app* b = to_app(_b); + TRACE("ack", tout << mk_pp(a, m) << " " << mk_pp(b, m) << "\n";); sat::literal_vector lits; unsigned sz = a->get_num_args(); for (unsigned i = 0; i < sz; ++i) { expr_ref eq(m.mk_eq(a->get_arg(i), b->get_arg(i)), m); - sat::literal lit = s.internalize(eq, true, false, true); - lits.push_back(~lit); + lits.push_back(s.internalize(eq, true, false, true)); } expr_ref eq(m.mk_eq(a, b), m); lits.push_back(s.internalize(eq, false, false, true)); @@ -173,6 +181,7 @@ namespace euf { expr_ref eq1(m.mk_eq(a, c), m); expr_ref eq2(m.mk_eq(b, c), m); expr_ref eq3(m.mk_eq(a, b), m); + TRACE("ack", tout << mk_pp(a, m) << " " << mk_pp(b, m) << " " << mk_pp(c, m) << "\n";); lits[0] = s.internalize(eq1, true, false, true); lits[1] = s.internalize(eq2, true, false, true); lits[2] = s.internalize(eq3, false, false, true); diff --git a/src/sat/smt/euf_internalize.cpp b/src/sat/smt/euf_internalize.cpp index 2b773b060..fd6954d7e 100644 --- a/src/sat/smt/euf_internalize.cpp +++ b/src/sat/smt/euf_internalize.cpp @@ -84,7 +84,7 @@ namespace euf { void solver::attach_node(euf::enode* n) { expr* e = n->get_expr(); if (!m.is_bool(e)) - log_node(e); + drat_log_node(e); else attach_lit(literal(si.add_bool_var(e), false), e); @@ -98,20 +98,23 @@ namespace euf { } sat::literal solver::attach_lit(literal lit, expr* e) { + sat::bool_var v = lit.var(); + s().set_external(v); + s().set_eliminated(v, false); + if (lit.sign()) { - sat::bool_var v = si.add_bool_var(e); + v = si.add_bool_var(e); s().set_external(v); + s().set_eliminated(v, false); sat::literal lit2 = literal(v, false); - s().mk_clause(~lit, lit2, sat::status::asserted()); - s().mk_clause(lit, ~lit2, sat::status::asserted()); + s().mk_clause(~lit, lit2, sat::status::th(m_is_redundant, m.get_basic_family_id())); + s().mk_clause(lit, ~lit2, sat::status::th(m_is_redundant, m.get_basic_family_id())); lit = lit2; } - sat::bool_var v = lit.var(); m_var2expr.reserve(v + 1, nullptr); SASSERT(m_var2expr[v] == nullptr); m_var2expr[v] = e; m_var_trail.push_back(v); - s().set_external(v); if (!m_egraph.find(e)) { enode* n = m_egraph.mk(e, 0, nullptr); m_egraph.set_merge_enabled(n, false); @@ -215,14 +218,11 @@ namespace euf { void solver::axiomatize_basic(enode* n) { expr* e = n->get_expr(); sat::status st = sat::status::th(m_is_redundant, m.get_basic_family_id()); - if (m.is_ite(e)) { + expr* c = nullptr, * th = nullptr, * el = nullptr; + if (!m.is_bool(e) && m.is_ite(e, c, th, el)) { app* a = to_app(e); - expr* c = a->get_arg(0); - expr* th = a->get_arg(1); - expr* el = a->get_arg(2); sat::bool_var v = si.to_bool_var(c); SASSERT(v != sat::null_bool_var); - SASSERT(!m.is_bool(e)); expr_ref eq_th(m.mk_eq(a, th), m); expr_ref eq_el(m.mk_eq(a, el), m); sat::literal lit_th = internalize(eq_th, false, false, m_is_redundant); diff --git a/src/sat/smt/euf_invariant.cpp b/src/sat/smt/euf_invariant.cpp index 97f7acb51..e7555a016 100644 --- a/src/sat/smt/euf_invariant.cpp +++ b/src/sat/smt/euf_invariant.cpp @@ -44,5 +44,16 @@ namespace euf { } } + void solver::check_missing_eq_propagation() const { + if (s().inconsistent()) + return; + for (enode* n : m_egraph.nodes()) + if (m.is_false(n->get_root()->get_expr()) && m.is_eq(n->get_expr()) && + n->get_arg(0)->get_root() == n->get_arg(1)->get_root()) { + TRACE("euf", display(tout << n->get_expr_id() << ": " << mk_pp(n->get_expr(), m) << "\n");); + UNREACHABLE(); + } + } + } diff --git a/src/sat/smt/euf_proof.cpp b/src/sat/smt/euf_proof.cpp index f591161d7..ab95c99dc 100644 --- a/src/sat/smt/euf_proof.cpp +++ b/src/sat/smt/euf_proof.cpp @@ -15,24 +15,32 @@ Author: --*/ +#include "ast/ast_ll_pp.h" #include "sat/smt/euf_solver.h" namespace euf { void solver::init_drat() { - if (!m_drat_initialized) - get_drat().add_theory(m.get_basic_family_id(), symbol("euf")); + if (!m_drat_initialized) { + get_drat().add_theory(get_id(), symbol("euf")); + get_drat().add_theory(m.get_basic_family_id(), symbol("bool")); + } m_drat_initialized = true; } - void solver::log_node(expr* e) { + void solver::drat_log_node(expr* e) { if (!use_drat()) return; if (is_app(e)) { - std::stringstream strm; - strm << mk_ismt2_func(to_app(e)->get_decl(), m); - get_drat().def_begin(e->get_id(), strm.str()); - for (expr* arg : *to_app(e)) + app* a = to_app(e); + if (a->get_num_parameters() == 0) + get_drat().def_begin(e->get_id(), a->get_decl()->get_name().str()); + else { + std::stringstream strm; + strm << mk_ismt2_func(a->get_decl(), m); + get_drat().def_begin(e->get_id(), strm.str()); + } + for (expr* arg : *a) get_drat().def_add_arg(arg->get_id()); get_drat().def_end(); } @@ -57,7 +65,7 @@ namespace euf { for (literal lit : r) lits.push_back(~lit); if (l != sat::null_literal) lits.push_back(l); - get_drat().add(lits, sat::status::th(true, m.get_basic_family_id())); + get_drat().add(lits, sat::status::th(true, get_id())); } void solver::log_antecedents(std::ostream& out, literal l, literal_vector const& r) { @@ -65,14 +73,14 @@ namespace euf { expr* n = m_var2expr[l.var()]; out << ~l << ": "; if (!l.sign()) out << "! "; - out << mk_pp(n, m) << "\n"; + out << mk_bounded_pp(n, m) << "\n"; SASSERT(s().value(l) == l_true); } if (l != sat::null_literal) { out << l << ": "; if (l.sign()) out << "! "; expr* n = m_var2expr[l.var()]; - out << mk_pp(n, m) << "\n"; + out << mk_bounded_pp(n, m) << "\n"; } } diff --git a/src/sat/smt/euf_solver.cpp b/src/sat/smt/euf_solver.cpp index cf968daed..b906adf86 100644 --- a/src/sat/smt/euf_solver.cpp +++ b/src/sat/smt/euf_solver.cpp @@ -25,6 +25,27 @@ Author: namespace euf { + solver::solver(ast_manager& m, sat::sat_internalizer& si, params_ref const& p) : + extension(m.mk_family_id("euf")), + m(m), + si(si), + m_egraph(m), + m_trail(*this), + m_rewriter(m), + m_unhandled_functions(m), + m_solver(nullptr), + m_lookahead(nullptr), + m_to_m(&m), + m_to_si(&si), + m_reinit_exprs(m) + { + updt_params(p); + + std::function disp = + [&](std::ostream& out, void* j) { display_justification_ptr(out, reinterpret_cast(j)); }; + m_egraph.set_display_justification(disp); + } + void solver::updt_params(params_ref const& p) { m_config.updt_params(p); } @@ -129,7 +150,9 @@ namespace euf { ext->get_antecedents(lit, idx, r, probing); } } - m_egraph.end_explain(); + m_egraph.end_explain(); + TRACE("euf", tout << "eplain " << l << " <- " << r << " " << probing << "\n";); + DEBUG_CODE(for (auto lit : r) SASSERT(s().value(lit) == l_true);); if (!probing) log_antecedents(l, r); } @@ -150,7 +173,8 @@ namespace euf { expr* e = nullptr; euf::enode* n = nullptr; - init_ackerman(); + if (!probing && !m_drating) + init_ackerman(); switch (j.kind()) { case constraint::kind_t::conflict: @@ -185,7 +209,7 @@ namespace euf { } bool sign = l.sign(); - TRACE("euf", tout << "asserted: " << l << "@" << s().scope_lvl() << " " << (sign ? "not ": " ") << e->get_id() << "\n";); + TRACE("euf", tout << "asserted: " << l << "@" << s().scope_lvl() << "\n";); euf::enode* n = m_egraph.find(e); if (!n) return; @@ -230,6 +254,7 @@ namespace euf { break; propagated = true; } + DEBUG_CODE(if (!s().inconsistent()) check_missing_eq_propagation();); return propagated; } @@ -255,11 +280,19 @@ namespace euf { cnstr = lit_constraint().to_index(); lit = literal(v, m.is_false(b)); } + unsigned lvl = s().scope_lvl(); + + CTRACE("euf", s().value(lit) != l_true, tout << lit << " " << s().value(lit) << "@" << lvl << " " << is_eq << " " << mk_bounded_pp(a, m) << " = " << mk_bounded_pp(b, m) << "\n";); if (s().value(lit) == l_false && m_ackerman) m_ackerman->cg_conflict_eh(a, b); - unsigned lvl = s().scope_lvl(); - if (s().value(lit) != l_true) + switch (s().value(lit)) { + case l_true: + break; + case l_undef: + case l_false: s().assign(lit, sat::justification::mk_ext_justification(lvl, cnstr)); + break; + } } } @@ -295,15 +328,15 @@ namespace euf { bool cont = false; for (auto* e : m_solvers) switch (e->check()) { - case sat::CR_CONTINUE: cont = true; break; - case sat::CR_GIVEUP: give_up = true; break; + case sat::check_result::CR_CONTINUE: cont = true; break; + case sat::check_result::CR_GIVEUP: give_up = true; break; default: break; } if (cont) - return sat::CR_CONTINUE; + return sat::check_result::CR_CONTINUE; if (give_up) - return sat::CR_GIVEUP; - return sat::CR_DONE; + return sat::check_result::CR_GIVEUP; + return sat::check_result::CR_DONE; } void solver::push() { @@ -329,6 +362,7 @@ namespace euf { m_trail.pop_scope(n); m_scopes.shrink(m_scopes.size() - n); si.pop(n); + SASSERT(m_egraph.num_scopes() == m_scopes.size()); } void solver::start_reinit(unsigned n) { @@ -356,8 +390,8 @@ namespace euf { return; si.set_expr2var_replay(&expr2var_replay); for (auto const& kv : expr2var_replay) - si.internalize(kv.m_key, true); - si.set_expr2var_replay(nullptr); + attach_lit(si.internalize(kv.m_key, true), kv.m_key); + si.set_expr2var_replay(nullptr); } void solver::pre_simplify() { @@ -397,6 +431,7 @@ namespace euf { if (n && n->merge_enabled()) ok = false; } + TRACE("euf", tout << ok << " " << l << " -> " << r << "\n";); return ok; } @@ -417,6 +452,13 @@ namespace euf { return out; } + std::ostream& solver::display_justification_ptr(std::ostream& out, size_t* j) const { + if (is_literal(j)) + return out << get_literal(j) << " "; + else + return display_justification(out, get_justification(j)) << " "; + } + std::ostream& solver::display_justification(std::ostream& out, ext_justification_idx idx) const { auto* ext = sat::constraint_base::to_extension(idx); if (ext != this) @@ -480,6 +522,7 @@ namespace euf { return false; check_eqc_bool_assignment(); check_missing_bool_enode_propagation(); + check_missing_eq_propagation(); m_egraph.invariant(); return true; } @@ -531,7 +574,7 @@ namespace euf { void solver::init_ackerman() { if (m_ackerman) return; - if (m_config.m_dack == DACK_DISABLED) + if (m_config.m_dack == dyn_ack_strategy::DACK_DISABLED) return; m_ackerman = alloc(ackerman, *this, m); std::function used_eq = [&](expr* a, expr* b, expr* lca) { diff --git a/src/sat/smt/euf_solver.h b/src/sat/smt/euf_solver.h index 42395f04d..101af94ba 100644 --- a/src/sat/smt/euf_solver.h +++ b/src/sat/smt/euf_solver.h @@ -70,7 +70,7 @@ namespace euf { size_t* to_ptr(size_t jst) { return TAG(size_t*, reinterpret_cast(jst), 2); } bool is_literal(size_t* p) const { return GET_TAG(p) == 1; } bool is_justification(size_t* p) const { return GET_TAG(p) == 2; } - sat::literal get_literal(size_t* p) { + sat::literal get_literal(size_t* p) const { unsigned idx = static_cast(reinterpret_cast(UNTAG(size_t*, p))); return sat::to_literal(idx >> 4); } @@ -86,7 +86,6 @@ namespace euf { stats m_stats; th_rewriter m_rewriter; func_decl_ref_vector m_unhandled_functions; - sat::solver* m_solver { nullptr }; sat::lookahead* m_lookahead { nullptr }; @@ -148,13 +147,16 @@ namespace euf { // proofs void log_antecedents(std::ostream& out, literal l, literal_vector const& r); void log_antecedents(literal l, literal_vector const& r); - void log_node(expr* n); + bool m_drat_initialized{ false }; void init_drat(); // invariant void check_eqc_bool_assignment() const; - void check_missing_bool_enode_propagation() const; + void check_missing_bool_enode_propagation() const; + void check_missing_eq_propagation() const; + + std::ostream& display_justification_ptr(std::ostream& out, size_t* j) const; constraint& mk_constraint(constraint*& c, constraint::kind_t k); constraint& conflict_constraint() { return mk_constraint(m_conflict, constraint::kind_t::conflict); } @@ -162,21 +164,7 @@ namespace euf { constraint& lit_constraint() { return mk_constraint(m_lit, constraint::kind_t::lit); } public: - solver(ast_manager& m, sat::sat_internalizer& si, params_ref const& p = params_ref()): - m(m), - si(si), - m_egraph(m), - m_trail(*this), - m_rewriter(m), - m_unhandled_functions(m), - m_solver(nullptr), - m_lookahead(nullptr), - m_to_m(&m), - m_to_si(&si), - m_reinit_exprs(m) - { - updt_params(p); - } + solver(ast_manager& m, sat::sat_internalizer& si, params_ref const& p = params_ref()); ~solver() override { if (m_conflict) dealloc(sat::constraint_base::mem2base_ptr(m_conflict)); @@ -267,6 +255,7 @@ namespace euf { void unhandled_function(func_decl* f); th_rewriter& get_rewriter() { return m_rewriter; } bool is_shared(euf::enode* n) const; + void drat_log_node(expr* n); void update_model(model_ref& mdl); diff --git a/src/sat/smt/sat_th.cpp b/src/sat/smt/sat_th.cpp index 0fc9c7875..c804ff29c 100644 --- a/src/sat/smt/sat_th.cpp +++ b/src/sat/smt/sat_th.cpp @@ -31,17 +31,17 @@ namespace euf { loop: if (!m.inc()) throw tactic_exception(m.limit().get_cancel_msg()); - sat::eframe& fr = m_stack.back(); - expr* e = fr.m_e; + unsigned fsz = m_stack.size(); + expr* e = m_stack[fsz-1].m_e; if (visited(e)) { m_stack.pop_back(); continue; } unsigned num = is_app(e) ? to_app(e)->get_num_args() : 0; - while (fr.m_idx < num) { - expr* arg = to_app(e)->get_arg(fr.m_idx); - fr.m_idx++; + while (m_stack[fsz - 1].m_idx < num) { + expr* arg = to_app(e)->get_arg(m_stack[fsz - 1].m_idx); + m_stack[fsz - 1].m_idx++; if (!visit(arg)) goto loop; } @@ -120,25 +120,27 @@ namespace euf { } } - void th_euf_solver::add_unit(sat::literal lit) { - ctx.s().add_clause(1, &lit, sat::status::th(m_is_redundant, get_id())); + bool th_euf_solver::add_unit(sat::literal lit) { + return !is_true(lit) && (ctx.s().add_clause(1, &lit, sat::status::th(m_is_redundant, get_id())), true); } - void th_euf_solver::add_clause(sat::literal a, sat::literal b) { + bool th_euf_solver::add_clause(sat::literal a, sat::literal b) { sat::literal lits[2] = { a, b }; - ctx.s().add_clause(2, lits, sat::status::th(m_is_redundant, get_id())); + return !is_true(a, b) && (ctx.s().add_clause(2, lits, sat::status::th(m_is_redundant, get_id())), true); } - void th_euf_solver::add_clause(sat::literal a, sat::literal b, sat::literal c) { + bool th_euf_solver::add_clause(sat::literal a, sat::literal b, sat::literal c) { sat::literal lits[3] = { a, b, c }; - ctx.s().add_clause(3, lits, sat::status::th(m_is_redundant, get_id())); + return !is_true(a, b, c) && (ctx.s().add_clause(3, lits, sat::status::th(m_is_redundant, get_id())), true); } - void th_euf_solver::add_clause(sat::literal a, sat::literal b, sat::literal c, sat::literal d) { + bool th_euf_solver::add_clause(sat::literal a, sat::literal b, sat::literal c, sat::literal d) { sat::literal lits[4] = { a, b, c, d }; - ctx.s().add_clause(4, lits, sat::status::th(m_is_redundant, get_id())); + return !is_true(a, b, c, d) && (ctx.s().add_clause(4, lits, sat::status::th(m_is_redundant, get_id())), true); } + bool th_euf_solver::is_true(sat::literal lit) { return ctx.s().value(lit) == l_true; } + euf::enode* th_euf_solver::mk_enode(expr* e, bool suppress_args) { m_args.reset(); if (!suppress_args) @@ -146,15 +148,9 @@ namespace euf { m_args.push_back(expr2enode(arg)); euf::enode* n = ctx.mk_enode(e, m_args.size(), m_args.c_ptr()); ctx.attach_node(n); - if (m.is_bool(e)) { - sat::bool_var v = ctx.get_si().add_bool_var(e); - NOT_IMPLEMENTED_YET(); - // TODO: ctx.attach_lit(literal(v, false), e); - } return n; } - void th_euf_solver::rewrite(expr_ref& a) { ctx.get_rewriter()(a); } diff --git a/src/sat/smt/sat_th.h b/src/sat/smt/sat_th.h index 33d49467a..0e44bdd8a 100644 --- a/src/sat/smt/sat_th.h +++ b/src/sat/smt/sat_th.h @@ -86,11 +86,8 @@ namespace euf { class th_solver : public sat::extension, public th_model_builder, public th_decompile, public th_internalizer { protected: ast_manager & m; - theory_id m_id; public: - th_solver(ast_manager& m, euf::theory_id id): m(m), m_id(id) {} - - unsigned get_id() const override { return m_id; } + th_solver(ast_manager& m, euf::theory_id id): extension(id), m(m) {} virtual th_solver* fresh(sat::solver* s, euf::solver& ctx) = 0; @@ -115,11 +112,16 @@ namespace euf { region& get_region(); - void add_unit(sat::literal lit); - void add_clause(sat::literal lit) { add_unit(lit); } - void add_clause(sat::literal a, sat::literal b); - void add_clause(sat::literal a, sat::literal b, sat::literal c); - void add_clause(sat::literal a, sat::literal b, sat::literal c, sat::literal d); + bool add_unit(sat::literal lit); + bool add_clause(sat::literal lit) { return add_unit(lit); } + bool add_clause(sat::literal a, sat::literal b); + bool add_clause(sat::literal a, sat::literal b, sat::literal c); + bool add_clause(sat::literal a, sat::literal b, sat::literal c, sat::literal d); + + bool is_true(sat::literal lit); + bool is_true(sat::literal a, sat::literal b) { return is_true(a) || is_true(b); } + bool is_true(sat::literal a, sat::literal b, sat::literal c) { return is_true(a) || is_true(b, c); } + bool is_true(sat::literal a, sat::literal b, sat::literal c, sat::literal d) { return is_true(a) || is_true(b, c, c); } euf::enode* e_internalize(expr* e) { internalize(e, m_is_redundant); return expr2enode(e); } euf::enode* mk_enode(expr* e, bool suppress_args); diff --git a/src/sat/smt/xor_solver.cpp b/src/sat/smt/xor_solver.cpp index 51b784ff4..c9a11b0a9 100644 --- a/src/sat/smt/xor_solver.cpp +++ b/src/sat/smt/xor_solver.cpp @@ -22,105 +22,12 @@ Revision History: #include "sat/sat_simplifier_params.hpp" #include "sat/sat_xor_finder.h" + namespace sat { - ba_solver::xr& ba_solver::constraint::to_xr() { - SASSERT(is_xr()); - return static_cast(*this); - } - - ba_solver::xr const& ba_solver::constraint::to_xr() const{ - SASSERT(is_xr()); - return static_cast(*this); - } - - ba_solver::xr::xr(unsigned id, literal_vector const& lits): - constraint(xr_t, id, null_literal, lits.size(), get_obj_size(lits.size())) { - for (unsigned i = 0; i < size(); ++i) { - m_lits[i] = lits[i]; - } - } - - bool ba_solver::xr::is_watching(literal l) const { - return - l == (*this)[0] || l == (*this)[1] || - ~l == (*this)[0] || ~l == (*this)[1]; - } - - bool ba_solver::xr::well_formed() const { - uint_set vars; - if (lit() != null_literal) vars.insert(lit().var()); - for (literal l : *this) { - bool_var v = l.var(); - if (vars.contains(v)) return false; - vars.insert(v); - } - return true; - } // -------------------- // xr: - void ba_solver::clear_watch(xr& x) { - x.clear_watch(); - unwatch_literal(x[0], x); - unwatch_literal(x[1], x); - unwatch_literal(~x[0], x); - unwatch_literal(~x[1], x); - } - - bool ba_solver::parity(xr const& x, unsigned offset) const { - bool odd = false; - unsigned sz = x.size(); - for (unsigned i = offset; i < sz; ++i) { - SASSERT(value(x[i]) != l_undef); - if (value(x[i]) == l_true) { - odd = !odd; - } - } - return odd; - } - - bool ba_solver::init_watch(xr& x) { - clear_watch(x); - VERIFY(x.lit() == null_literal); - TRACE("ba", display(tout, x, true);); - unsigned sz = x.size(); - unsigned j = 0; - for (unsigned i = 0; i < sz && j < 2; ++i) { - if (value(x[i]) == l_undef) { - x.swap(i, j); - ++j; - } - } - switch (j) { - case 0: - if (!parity(x, 0)) { - unsigned l = lvl(x[0]); - j = 1; - for (unsigned i = 1; i < sz; ++i) { - if (lvl(x[i]) > l) { - j = i; - l = lvl(x[i]); - } - } - set_conflict(x, x[j]); - } - return false; - case 1: - SASSERT(x.lit() == null_literal || value(x.lit()) == l_true); - assign(x, parity(x, 1) ? ~x[0] : x[0]); - return false; - default: - SASSERT(j == 2); - watch_literal(x[0], x); - watch_literal(x[1], x); - watch_literal(~x[0], x); - watch_literal(~x[1], x); - return true; - } - } - - lbool ba_solver::add_assign(xr& x, literal alit) { // literal is assigned unsigned sz = x.size(); @@ -136,10 +43,10 @@ namespace sat { literal lit = x[i]; if (value(lit) == l_undef) { x.swap(index, i); - unwatch_literal(~alit, x); + x.unwatch_literal(*this, ~alit); // alit gets unwatched by propagate_core because we return l_undef - watch_literal(lit, x); - watch_literal(~lit, x); + x.watch_literal(*this, lit); + x.watch_literal(*this, ~lit); TRACE("ba", tout << "swap in: " << lit << " " << x << "\n";); return l_undef; } @@ -150,10 +57,10 @@ namespace sat { // alit resides at index 1. VERIFY(x[1].var() == alit.var()); if (value(x[0]) == l_undef) { - bool p = parity(x, 1); + bool p = x.parity(*this, 1); assign(x, p ? ~x[0] : x[0]); } - else if (!parity(x, 0)) { + else if (!x.parity(*this, 0)) { set_conflict(x, ~x[1]); } return inconsistent() ? l_false : l_true; @@ -232,7 +139,7 @@ namespace sat { } - ba_solver::constraint* ba_solver::add_xr(literal_vector const& _lits, bool learned) { + constraint* ba_solver::add_xr(literal_vector const& _lits, bool learned) { literal_vector lits; u_map var2sign; bool sign = false, odd = false; @@ -400,31 +307,6 @@ namespace sat { } } - lbool ba_solver::eval(xr const& x) const { - bool odd = false; - - for (auto l : x) { - switch (value(l)) { - case l_true: odd = !odd; break; - case l_false: break; - default: return l_undef; - } - } - return odd ? l_true : l_false; - } - - lbool ba_solver::eval(model const& m, xr const& x) const { - bool odd = false; - - for (auto l : x) { - switch (value(m, l)) { - case l_true: odd = !odd; break; - case l_false: break; - default: return l_undef; - } - } - return odd ? l_true : l_false; - } void ba_solver::pre_simplify(xor_finder& xf, constraint& c) { if (c.is_xr() && c.size() <= xf.max_xor_size()) { @@ -532,30 +414,5 @@ namespace sat { } } - void ba_solver::display(std::ostream& out, xr const& x, bool values) const { - out << "xr: "; - for (literal l : x) { - out << l; - if (values) { - out << "@(" << value(l); - if (value(l) != l_undef) { - out << ":" << lvl(l); - } - out << ") "; - } - else { - out << " "; - } - } - out << "\n"; - } - - bool ba_solver::validate_unit_propagation(xr const& x, literal alit) const { - if (value(x.lit()) != l_true) return false; - for (unsigned i = 1; i < x.size(); ++i) { - if (value(x[i]) == l_undef) return false; - } - return true; - } } diff --git a/src/sat/tactic/goal2sat.cpp b/src/sat/tactic/goal2sat.cpp index d7feab9dc..e31d9f6a9 100644 --- a/src/sat/tactic/goal2sat.cpp +++ b/src/sat/tactic/goal2sat.cpp @@ -102,80 +102,80 @@ struct goal2sat::imp : public sat::sat_internalizer { m_max_memory = megabytes_to_bytes(p.get_uint("max_memory", UINT_MAX)); m_xor_solver = p.get_bool("xor_solver", false); m_euf = sp.euf(); - m_drat = sp.drat_file() != symbol(); + m_drat = sp.drat_file().is_non_empty_string(); } void throw_op_not_handled(std::string const& s) { std::string s0 = "operator " + s + " not supported, apply simplifier before invoking translator"; throw tactic_exception(std::move(s0)); } + + sat::status mk_status() const { + return sat::status::th(m_is_redundant, m.get_basic_family_id()); + } void mk_clause(sat::literal l) { TRACE("goal2sat", tout << "mk_clause: " << l << "\n";); - m_solver.add_clause(1, &l, m_is_redundant ? sat::status::redundant() : sat::status::asserted()); + m_solver.add_clause(1, &l, mk_status()); } void mk_clause(sat::literal l1, sat::literal l2) { TRACE("goal2sat", tout << "mk_clause: " << l1 << " " << l2 << "\n";); - m_solver.add_clause(l1, l2, m_is_redundant ? sat::status::redundant() : sat::status::asserted()); + m_solver.add_clause(l1, l2, mk_status()); } void mk_clause(sat::literal l1, sat::literal l2, sat::literal l3) { TRACE("goal2sat", tout << "mk_clause: " << l1 << " " << l2 << " " << l3 << "\n";); - m_solver.add_clause(l1, l2, l3, m_is_redundant ? sat::status::redundant() : sat::status::asserted()); + m_solver.add_clause(l1, l2, l3, mk_status()); } void mk_clause(unsigned num, sat::literal * lits) { TRACE("goal2sat", tout << "mk_clause: "; for (unsigned i = 0; i < num; i++) tout << lits[i] << " "; tout << "\n";); - m_solver.add_clause(num, lits, m_is_redundant ? sat::status::redundant() : sat::status::asserted()); + m_solver.add_clause(num, lits, mk_status()); } void mk_root_clause(sat::literal l) { TRACE("goal2sat", tout << "mk_clause: " << l << "\n";); - m_solver.add_clause(1, &l, m_is_redundant ? sat::status::redundant() : sat::status::input()); + m_solver.add_clause(1, &l, m_is_redundant ? mk_status() : sat::status::input()); } void mk_root_clause(sat::literal l1, sat::literal l2) { TRACE("goal2sat", tout << "mk_clause: " << l1 << " " << l2 << "\n";); - m_solver.add_clause(l1, l2, m_is_redundant ? sat::status::redundant() : sat::status::input()); + m_solver.add_clause(l1, l2, m_is_redundant ? mk_status() : sat::status::input()); } void mk_root_clause(sat::literal l1, sat::literal l2, sat::literal l3) { TRACE("goal2sat", tout << "mk_clause: " << l1 << " " << l2 << " " << l3 << "\n";); - m_solver.add_clause(l1, l2, l3, m_is_redundant ? sat::status::redundant() : sat::status::input()); + m_solver.add_clause(l1, l2, l3, m_is_redundant ? mk_status() : sat::status::input()); } void mk_root_clause(unsigned num, sat::literal * lits) { TRACE("goal2sat", tout << "mk_clause: "; for (unsigned i = 0; i < num; i++) tout << lits[i] << " "; tout << "\n";); - m_solver.add_clause(num, lits, m_is_redundant ? sat::status::redundant() : sat::status::input()); + m_solver.add_clause(num, lits, m_is_redundant ? mk_status() : sat::status::input()); } sat::bool_var add_var(bool is_ext, expr* n) { auto v = m_solver.add_var(is_ext); - log_node(v, n); + log_node(n); + log_def(v, n); return v; } - void log_node(sat::bool_var v, expr* n) { - if (m_drat && m_solver.get_drat_ptr()) { - sat::drat* drat = m_solver.get_drat_ptr(); - if (is_app(n)) { - app* a = to_app(n); - std::stringstream strm; - strm << mk_ismt2_func(a->get_decl(), m); - drat->def_begin(n->get_id(), strm.str()); - for (expr* arg : *a) - drat->def_add_arg(arg->get_id()); - drat->def_end(); - } - else { - IF_VERBOSE(0, verbose_stream() << "skipping DRAT of non-app\n"); - } - drat->bool_def(v, n->get_id()); - } + void log_def(sat::bool_var v, expr* n) { + if (m_drat && m_solver.get_drat_ptr()) + m_solver.get_drat_ptr()->bool_def(v, n->get_id()); } - + void log_node(expr* n) { + if (m_drat && m_solver.get_drat_ptr()) { + if (is_app(n)) { + for (expr* arg : *to_app(n)) + if (m.is_not(arg)) + log_node(arg); + } + ensure_euf()->drat_log_node(n); + } + } sat::literal mk_true() { if (m_true == sat::null_literal) { @@ -187,6 +187,9 @@ struct goal2sat::imp : public sat::sat_internalizer { } sat::bool_var to_bool_var(expr* e) override { + sat::literal l; + if (is_app(e) && m_cache.find(to_app(e), l) && !l.sign()) + return l.var(); return m_map.to_bool_var(e); } @@ -209,7 +212,7 @@ struct goal2sat::imp : public sat::sat_internalizer { if (v == sat::null_bool_var) v = mk_bool_var(t); else - m_solver.set_external(v); + m_solver.set_external(v); return v; } @@ -264,9 +267,11 @@ struct goal2sat::imp : public sat::sat_internalizer { else m_unhandled_funs.push_back(to_app(t)->get_decl()); } - bool ext = m_default_external || !is_uninterp_const(t) || m_interface_vars.contains(t); v = mk_bool_var(t); l = sat::literal(v, sign); + bool ext = m_default_external || !is_uninterp_const(t) || m_interface_vars.contains(t); + if (ext) + m_solver.set_external(v); TRACE("sat", tout << "new_var: " << v << ": " << mk_bounded_pp(t, m, 2) << " " << is_uninterp_const(t) << "\n";); } } @@ -275,8 +280,6 @@ struct goal2sat::imp : public sat::sat_internalizer { l = sat::literal(v, sign); m_solver.set_eliminated(v, false); } - if (root) - m_result_stack.reset(); SASSERT(l != sat::null_literal); if (root) mk_root_clause(l); @@ -354,8 +357,10 @@ struct goal2sat::imp : public sat::sat_internalizer { } void convert_or(app * t, bool root, bool sign) { - TRACE("goal2sat", tout << "convert_or:\n" << mk_bounded_pp(t, m, 2) << "\n";); + TRACE("goal2sat", tout << "convert_or:\n" << mk_bounded_pp(t, m, 2) << " root " << root << " stack " << m_result_stack.size() << "\n";); unsigned num = t->get_num_args(); + SASSERT(num <= m_result_stack.size()); + unsigned old_sz = m_result_stack.size() - num; if (root) { SASSERT(num == m_result_stack.size()); if (sign) { @@ -369,7 +374,7 @@ struct goal2sat::imp : public sat::sat_internalizer { else { mk_root_clause(m_result_stack.size(), m_result_stack.c_ptr()); } - m_result_stack.reset(); + m_result_stack.shrink(old_sz); } else { SASSERT(num <= m_result_stack.size()); @@ -392,8 +397,7 @@ struct goal2sat::imp : public sat::sat_internalizer { mk_clause(num+1, lits); if (m_aig) { m_aig->add_or(l, num, aig_lits.c_ptr()); - } - unsigned old_sz = m_result_stack.size() - num - 1; + } m_result_stack.shrink(old_sz); if (sign) l.neg(); @@ -402,8 +406,11 @@ struct goal2sat::imp : public sat::sat_internalizer { } void convert_and(app * t, bool root, bool sign) { - TRACE("goal2sat", tout << "convert_and:\n" << mk_ismt2_pp(t, m) << "\n";); + TRACE("goal2sat", tout << "convert_and:\n" << mk_bounded_pp(t, m, 2) << " root: " << root << " result stack: " << m_result_stack.size() << "\n";); + unsigned num = t->get_num_args(); + unsigned old_sz = m_result_stack.size() - num; + SASSERT(num <= m_result_stack.size()); if (root) { if (sign) { for (unsigned i = 0; i < num; ++i) { @@ -416,7 +423,7 @@ struct goal2sat::imp : public sat::sat_internalizer { mk_root_clause(m_result_stack[i]); } } - m_result_stack.reset(); + m_result_stack.shrink(old_sz); } else { SASSERT(num <= m_result_stack.size()); @@ -445,7 +452,7 @@ struct goal2sat::imp : public sat::sat_internalizer { } if (sign) l.neg(); - unsigned old_sz = m_result_stack.size() - num - 1; + m_result_stack.shrink(old_sz); m_result_stack.push_back(l); TRACE("goal2sat", tout << m_result_stack << "\n";); @@ -458,6 +465,7 @@ struct goal2sat::imp : public sat::sat_internalizer { sat::literal c = m_result_stack[sz-3]; sat::literal t = m_result_stack[sz-2]; sat::literal e = m_result_stack[sz-1]; + m_result_stack.shrink(sz - 3); if (root) { SASSERT(sz == 3); if (sign) { @@ -468,7 +476,6 @@ struct goal2sat::imp : public sat::sat_internalizer { mk_root_clause(~c, t); mk_root_clause(c, e); } - m_result_stack.reset(); } else { sat::bool_var k = add_var(false, n); @@ -485,7 +492,7 @@ struct goal2sat::imp : public sat::sat_internalizer { if (m_aig) m_aig->add_ite(l, c, t, e); if (sign) l.neg(); - m_result_stack.shrink(sz-3); + m_result_stack.push_back(l); } } @@ -496,6 +503,7 @@ struct goal2sat::imp : public sat::sat_internalizer { SASSERT(sz >= 2); sat::literal l2 = m_result_stack[sz - 1]; sat::literal l1 = m_result_stack[sz - 2]; + m_result_stack.shrink(sz - 2); if (root) { SASSERT(sz == 2); if (sign) { @@ -504,8 +512,7 @@ struct goal2sat::imp : public sat::sat_internalizer { } else { mk_root_clause(~l1, l2); - } - m_result_stack.reset(); + } } else { sat::bool_var k = add_var(false, t); @@ -517,7 +524,6 @@ struct goal2sat::imp : public sat::sat_internalizer { mk_clause(~l2, l); if (sign) l.neg(); - m_result_stack.shrink(sz - 2); m_result_stack.push_back(l); } } @@ -528,6 +534,7 @@ struct goal2sat::imp : public sat::sat_internalizer { SASSERT(sz >= 2); sat::literal l1 = m_result_stack[sz-1]; sat::literal l2 = m_result_stack[sz-2]; + m_result_stack.shrink(sz - 2); if (root) { SASSERT(sz == 2); if (sign) { @@ -537,8 +544,7 @@ struct goal2sat::imp : public sat::sat_internalizer { else { mk_root_clause(l1, ~l2); mk_root_clause(~l1, l2); - } - m_result_stack.reset(); + } } else { sat::bool_var k = add_var(false, t); @@ -551,7 +557,6 @@ struct goal2sat::imp : public sat::sat_internalizer { if (m_aig) m_aig->add_iff(l, l1, l2); if (sign) l.neg(); - m_result_stack.shrink(sz - 2); m_result_stack.push_back(l); } } @@ -570,7 +575,8 @@ struct goal2sat::imp : public sat::sat_internalizer { return m_unhandled_funs; } - void convert_euf(expr* e, bool root, bool sign) { + euf::solver* ensure_euf() { + SASSERT(m_euf); sat::extension* ext = m_solver.get_extension(); euf::solver* euf = nullptr; if (!ext) { @@ -584,9 +590,14 @@ struct goal2sat::imp : public sat::sat_internalizer { } if (!euf) throw default_exception("cannot convert to euf"); + return euf; + } + + void convert_euf(expr* e, bool root, bool sign) { + SASSERT(m_euf); + TRACE("goal2sat", tout << "convert-euf " << mk_bounded_pp(e, m, 2) << " root " << root << "\n";); + euf::solver* euf = ensure_euf(); sat::literal lit = euf->internalize(e, sign, root, m_is_redundant); - if (root) - m_result_stack.reset(); if (lit == sat::null_literal) return; if (root) @@ -597,7 +608,7 @@ struct goal2sat::imp : public sat::sat_internalizer { void convert_ba(app* t, bool root, bool sign) { SASSERT(!m_euf); - sat::extension* ext = m_solver.get_extension(); + sat::extension* ext = dynamic_cast(m_solver.get_extension()); euf::th_solver* th = nullptr; if (!ext) { th = alloc(sat::ba_solver, m, *this, pb.get_family_id()); @@ -609,10 +620,7 @@ struct goal2sat::imp : public sat::sat_internalizer { SASSERT(th); } auto lit = th->internalize(t, sign, root, m_is_redundant); - if (root) - m_result_stack.reset(); - else - m_result_stack.shrink(m_result_stack.size() - t->get_num_args()); + m_result_stack.shrink(m_result_stack.size() - t->get_num_args()); if (lit == sat::null_literal) return; if (root) @@ -681,6 +689,10 @@ struct goal2sat::imp : public sat::sat_internalizer { }; void process(expr* n, bool is_root, bool redundant) { + TRACE("goal2sat", tout << "process-begin " << mk_bounded_pp(n, m, 3) + << " root: " << is_root + << " result-stack: " << m_result_stack.size() + << " frame-stack: " << m_frame_stack.size() << "\n";); flet _is_redundant(m_is_redundant, redundant); scoped_stack _sc(*this, is_root); unsigned sz = m_frame_stack.size(); @@ -693,14 +705,16 @@ struct goal2sat::imp : public sat::sat_internalizer { throw tactic_exception(m.limit().get_cancel_msg()); if (memory::get_allocation_size() > m_max_memory) throw tactic_exception(TACTIC_MAX_MEMORY_MSG); - frame & fr = m_frame_stack.back(); - app * t = fr.m_t; - bool root = fr.m_root; - bool sign = fr.m_sign; + unsigned fsz = m_frame_stack.size(); + frame const& _fr = m_frame_stack[fsz-1]; + app * t = _fr.m_t; + bool root = _fr.m_root; + bool sign = _fr.m_sign; TRACE("goal2sat_bug", tout << "result stack\n"; - tout << mk_ismt2_pp(t, m) << " root: " << root << " sign: " << sign << "\n"; + tout << "ref-count: " << t->get_ref_count() << "\n"; + tout << mk_bounded_pp(t, m, 3) << " root: " << root << " sign: " << sign << "\n"; tout << m_result_stack << "\n";); - if (fr.m_idx == 0 && process_cached(t, root, sign)) { + if (_fr.m_idx == 0 && process_cached(t, root, sign)) { m_frame_stack.pop_back(); continue; } @@ -715,27 +729,37 @@ struct goal2sat::imp : public sat::sat_internalizer { continue; } unsigned num = t->get_num_args(); - while (fr.m_idx < num) { - expr * arg = t->get_arg(fr.m_idx); - fr.m_idx++; + while (m_frame_stack[fsz-1].m_idx < num) { + expr * arg = t->get_arg(m_frame_stack[fsz-1].m_idx); + m_frame_stack[fsz - 1].m_idx++; if (!visit(arg, false, false)) goto loop; + TRACE("goal2sat_bug", tout << "visit " << mk_bounded_pp(t, m, 2) << " result stack: " << m_result_stack.size() << "\n";); } TRACE("goal2sat_bug", tout << "converting\n"; - tout << mk_ismt2_pp(t, m) << " root: " << root << " sign: " << sign << "\n"; + tout << mk_bounded_pp(t, m, 2) << " root: " << root << " sign: " << sign << "\n"; tout << m_result_stack << "\n";); + SASSERT(m_frame_stack.size() > sz); convert(t, root, sign); - m_frame_stack.pop_back(); + m_frame_stack.pop_back(); } + TRACE("goal2sat", tout + << "done process: " << mk_bounded_pp(n, m, 3) + << " frame-stack: " << m_frame_stack.size() + << " result-stack: " << m_result_stack.size() << "\n";); } sat::literal internalize(expr* n, bool redundant) override { unsigned sz = m_result_stack.size(); (void)sz; + SASSERT(n->get_ref_count() > 0); + TRACE("goal2sat", tout << "internalize " << mk_bounded_pp(n, m, 2) << "\n";); process(n, false, redundant); SASSERT(m_result_stack.size() == sz + 1); sat::literal result = m_result_stack.back(); m_result_stack.pop_back(); + if (!result.sign() && m_map.to_bool_var(n) == sat::null_bool_var) + m_map.insert(n, result.var()); return result; } @@ -766,8 +790,8 @@ struct goal2sat::imp : public sat::sat_internalizer { } void process(expr * n) { - m_result_stack.reset(); - TRACE("goal2sat", tout << "assert: "<< mk_pp(n, m) << "\n";); + VERIFY(m_result_stack.empty()); + TRACE("goal2sat", tout << "assert: " << mk_bounded_pp(n, m, 3) << "\n";); process(n, true, m_is_redundant); CTRACE("goal2sat", !m_result_stack.empty(), tout << m_result_stack << "\n";); SASSERT(m_result_stack.empty()); diff --git a/src/shell/drat_frontend.cpp b/src/shell/drat_frontend.cpp index 801f0a58b..3ebf592c1 100644 --- a/src/shell/drat_frontend.cpp +++ b/src/shell/drat_frontend.cpp @@ -17,11 +17,14 @@ Copyright (c) 2020 Microsoft Corporation class smt_checker { ast_manager& m; + sat::drat& m_drat; expr_ref_vector const& m_b2e; expr_ref_vector m_fresh_exprs; expr_ref_vector m_core; + expr_ref_vector m_inputs; params_ref m_params; - scoped_ptr m_solver; + scoped_ptr m_lemma_solver, m_input_solver; + sat::literal_vector m_units; expr* fresh(expr* e) { unsigned i = e->get_id(); @@ -59,29 +62,103 @@ class smt_checker { m_core.push_back(fml); } } -public: - smt_checker(expr_ref_vector const& b2e): - m(b2e.m()), m_b2e(b2e), m_fresh_exprs(m), m_core(m) { - m_solver = mk_smt_solver(m, m_params, symbol()); + + expr_ref lit2expr(sat::literal lit) { + return expr_ref(lit.sign() ? m.mk_not(m_b2e[lit.var()]) : m_b2e[lit.var()], m); } - - void check_shallow(sat::literal_vector const& lits) { - unfold1(lits); - m_solver->push(); - for (auto* c : m_core) - m_solver->assert_expr(c); - lbool is_sat = m_solver->check_sat(); - m_solver->pop(1); - if (is_sat == l_true) { - std::cout << "did not verify: " << lits << "\n" << m_core << "\n"; + + void add_units() { + auto const& units = m_drat.units(); + for (unsigned i = m_units.size(); i < units.size(); ++i) { + sat::literal lit = units[i]; + m_lemma_solver->assert_expr(lit2expr(lit)); + } + m_units.append(units.size() - m_units.size(), units.c_ptr() + m_units.size()); + } + + void check_assertion_redundant(sat::literal_vector const& input) { + expr_ref_vector args(m); + for (auto lit : input) + args.push_back(lit2expr(lit)); + m_inputs.push_back(args.size() == 1 ? args.back() : m.mk_or(args)); + + m_input_solver->push(); + for (auto lit : input) { + m_input_solver->assert_expr(lit2expr(~lit)); + } + lbool is_sat = m_input_solver->check_sat(); + if (is_sat != l_false) { + std::cout << "Failed to verify input\n"; + exit(0); + } + m_input_solver->pop(1); + } + + + /** + * Validate a lemma using the following attempts: + * 1. check if it is propositional DRUP + * 2. establish the negation of literals is unsat using a limited unfolding. + * 3. check that it is DRUP modulo theories by taking propositional implicants from DRUP validation + */ + sat::literal_vector drup_units; + + void check_clause(sat::literal_vector const& lits) { + + add_units(); + drup_units.reset(); + if (m_drat.is_drup(lits.size(), lits.c_ptr(), drup_units)) { + std::cout << "drup\n"; + return; + } + m_lemma_solver->push(); + for (auto lit : drup_units) + m_lemma_solver->assert_expr(lit2expr(lit)); + lbool is_sat = m_lemma_solver->check_sat(); + if (is_sat != l_false) { + std::cout << "did not verify: " << lits << "\n"; for (sat::literal lit : lits) { - expr_ref e(m_b2e[lit.var()], m); - if (lit.sign()) - e = m.mk_not(e); - std::cout << e << " "; + std::cout << lit2expr(lit) << "\n"; } std::cout << "\n"; + m_lemma_solver->display(std::cout); + exit(0); } + m_lemma_solver->pop(1); + std::cout << "smt\n"; + check_assertion_redundant(lits); + } + +public: + smt_checker(sat::drat& drat, expr_ref_vector const& b2e): + m(b2e.m()), m_drat(drat), m_b2e(b2e), m_fresh_exprs(m), m_core(m), m_inputs(m) { + m_lemma_solver = mk_smt_solver(m, m_params, symbol()); + m_input_solver = mk_smt_solver(m, m_params, symbol()); + } + + void add(sat::literal_vector const& lits, sat::status const& st) { + for (sat::literal lit : lits) + while (lit.var() >= m_drat.get_solver().num_vars()) + m_drat.get_solver().mk_var(true); + if (st.is_input()) + check_assertion_redundant(lits); + else if (!st.is_sat() && !st.is_deleted()) + check_clause(lits); + m_drat.add(lits, st); + } + + /** + * Add an assertion from the source file + */ + void add_assertion(expr* a) { + m_input_solver->assert_expr(a); + } + + void display_input() { + scoped_ptr s = mk_smt_solver(m, m_params, symbol()); + for (auto* e : m_inputs) + s->assert_expr(e); + s->display(std::cout); } }; @@ -100,10 +177,12 @@ static void verify_smt(char const* drat_file, char const* smt_file) { std::ifstream ins(drat_file); dimacs::drat_parser drat(ins, std::cerr); + ast_manager& m = ctx.m(); std::function read_theory = [&](char const* r) { - if (strcmp(r, "euf") == 0) - return ctx.m().get_basic_family_id(); - return ctx.m().mk_family_id(symbol(r)); + return m.mk_family_id(symbol(r)); + }; + std::function write_theory = [&](int th) { + return m.get_family_name(th); }; drat.set_read_theory(read_theory); params_ref p; @@ -113,43 +192,24 @@ static void verify_smt(char const* drat_file, char const* smt_file) { sat::drat drat_checker(solver); drat_checker.updt_config(); - expr_ref_vector bool_var2expr(ctx.m()); - expr_ref_vector exprs(ctx.m()), args(ctx.m()); + expr_ref_vector bool_var2expr(m); + expr_ref_vector exprs(m), args(m), inputs(m); func_decl* f = nullptr; ptr_vector sorts; - smt_checker checker(bool_var2expr); + smt_checker checker(drat_checker, bool_var2expr); - auto check_smt = [&](dimacs::drat_record const& r) { - auto const& st = r.m_status; - if (st.is_input()) - ; - else if (st.is_sat() && st.is_asserted()) { - std::cout << "Tseitin tautology " << r; - checker.check_shallow(r.m_lits); - } - else if (st.is_sat()) - ; - else if (st.is_deleted()) - ; - else { - std::cout << "check smt " << r; - checker.check_shallow(r.m_lits); - // TBD: shallow check may fail because it doesn't include - // all RUP units, whish are sometimes required. - } - }; + for (expr* a : ctx.assertions()) + checker.add_assertion(a); for (auto const& r : drat) { - std::cout << r; + std::cout << dimacs::drat_pp(r, write_theory); std::cout.flush(); switch (r.m_tag) { case dimacs::drat_record::tag_t::is_clause: - for (sat::literal lit : r.m_lits) - while (lit.var() >= solver.num_vars()) - solver.mk_var(true); - drat_checker.add(r.m_lits, r.m_status); - check_smt(r); + checker.add(r.m_lits, r.m_status); + if (drat_checker.inconsistent()) + std::cout << "inconsistent\n"; break; case dimacs::drat_record::tag_t::is_node: args.reset(); @@ -170,7 +230,7 @@ static void verify_smt(char const* drat_file, char const* smt_file) { exprs.reserve(r.m_node_id+1); exprs.set(r.m_node_id, ctx.m().mk_app(f, args.size(), args.c_ptr())); break; - case dimacs::drat_record::is_bool_def: + case dimacs::drat_record::tag_t::is_bool_def: bool_var2expr.reserve(r.m_node_id+1); bool_var2expr.set(r.m_node_id, exprs.get(r.m_args[0])); break; @@ -182,6 +242,7 @@ static void verify_smt(char const* drat_file, char const* smt_file) { statistics st; drat_checker.collect_statistics(st); std::cout << st << "\n"; + } diff --git a/src/smt/dyn_ack.cpp b/src/smt/dyn_ack.cpp index dc0ae2551..5fdd53d70 100644 --- a/src/smt/dyn_ack.cpp +++ b/src/smt/dyn_ack.cpp @@ -375,7 +375,7 @@ namespace smt { } void dyn_ack_manager::propagate_eh() { - if (m_params.m_dack == DACK_DISABLED) + if (m_params.m_dack == dyn_ack_strategy::DACK_DISABLED) return; m_num_propagations_since_last_gc++; if (m_num_propagations_since_last_gc > m_params.m_dack_gc) { @@ -407,7 +407,7 @@ namespace smt { } void dyn_ack_manager::instantiate(app * n1, app * n2) { - SASSERT(m_params.m_dack != DACK_DISABLED); + SASSERT(m_params.m_dack != dyn_ack_strategy::DACK_DISABLED); SASSERT(n1->get_decl() == n2->get_decl()); SASSERT(n1->get_num_args() == n2->get_num_args()); SASSERT(n1 != n2); @@ -461,7 +461,7 @@ namespace smt { void dyn_ack_manager::instantiate(app * n1, app * n2, app* r) { context& ctx = m_context; - SASSERT(m_params.m_dack != DACK_DISABLED); + SASSERT(m_params.m_dack != dyn_ack_strategy::DACK_DISABLED); SASSERT(n1 != n2 && n1 != r && n2 != r); ctx.m_stats.m_num_dyn_ack++; TRACE("dyn_ack_inst", tout << "dyn_ack: " << n1->get_id() << " " << n2->get_id() << " " << r->get_id() << "\n";); diff --git a/src/smt/dyn_ack.h b/src/smt/dyn_ack.h index 627c23191..7f33e57e5 100644 --- a/src/smt/dyn_ack.h +++ b/src/smt/dyn_ack.h @@ -97,7 +97,7 @@ namespace smt { \brief This method is invoked when the congruence rule was used during conflict resolution. */ void used_cg_eh(app * n1, app * n2) { - if (m_params.m_dack == DACK_CR) + if (m_params.m_dack == dyn_ack_strategy::DACK_CR) cg_eh(n1, n2); } @@ -105,7 +105,7 @@ namespace smt { \brief This method is invoked when the congruence rule is the root of a conflict. */ void cg_conflict_eh(app * n1, app * n2) { - if (m_params.m_dack == DACK_ROOT) + if (m_params.m_dack == dyn_ack_strategy::DACK_ROOT) cg_eh(n1, n2); } diff --git a/src/smt/fingerprints.h b/src/smt/fingerprints.h index f3f96e057..b1308e9b0 100644 --- a/src/smt/fingerprints.h +++ b/src/smt/fingerprints.h @@ -25,11 +25,11 @@ namespace smt { class fingerprint { protected: - void * m_data; - unsigned m_data_hash; - expr* m_def; - unsigned m_num_args; - enode * * m_args; + void* m_data{ nullptr }; + unsigned m_data_hash{ 0 }; + expr* m_def{ nullptr }; + unsigned m_num_args{ 0 }; + enode** m_args{ nullptr }; friend class fingerprint_set; fingerprint() {} diff --git a/src/smt/params/dyn_ack_params.cpp b/src/smt/params/dyn_ack_params.cpp index 4ba230a47..b1e99cf21 100644 --- a/src/smt/params/dyn_ack_params.cpp +++ b/src/smt/params/dyn_ack_params.cpp @@ -32,7 +32,7 @@ void dyn_ack_params::updt_params(params_ref const & _p) { #define DISPLAY_PARAM(X) out << #X"=" << X << std::endl; void dyn_ack_params::display(std::ostream & out) const { - DISPLAY_PARAM(m_dack); + DISPLAY_PARAM((unsigned)m_dack); DISPLAY_PARAM(m_dack_eq); DISPLAY_PARAM(m_dack_factor); DISPLAY_PARAM(m_dack_threshold); diff --git a/src/smt/params/dyn_ack_params.h b/src/smt/params/dyn_ack_params.h index 71091ca94..ce5a685bf 100644 --- a/src/smt/params/dyn_ack_params.h +++ b/src/smt/params/dyn_ack_params.h @@ -20,7 +20,7 @@ Revision History: #include "util/params.h" -enum dyn_ack_strategy { +enum class dyn_ack_strategy { DACK_DISABLED, DACK_ROOT, // congruence is the root of the conflict DACK_CR // congruence used during conflict resolution @@ -36,7 +36,7 @@ struct dyn_ack_params { public: dyn_ack_params(params_ref const & p = params_ref()) : - m_dack(DACK_ROOT), + m_dack(dyn_ack_strategy::DACK_ROOT), m_dack_eq(false), m_dack_factor(0.1), m_dack_threshold(10), diff --git a/src/smt/params/smt_params.h b/src/smt/params/smt_params.h index dc3b4c0d7..b2f898d4c 100644 --- a/src/smt/params/smt_params.h +++ b/src/smt/params/smt_params.h @@ -250,8 +250,8 @@ struct smt_params : public preprocessor_params, m_random_var_freq(0.01), m_inv_decay(1.052), m_clause_decay(1), - m_random_initial_activity(IA_RANDOM_WHEN_SEARCHING), - m_phase_selection(PS_CACHING_CONSERVATIVE), + m_random_initial_activity(initial_activity::IA_RANDOM_WHEN_SEARCHING), + m_phase_selection(phase_selection::PS_CACHING_CONSERVATIVE), m_phase_caching_on(700), m_phase_caching_off(100), m_minimize_lemmas(true), @@ -267,7 +267,7 @@ struct smt_params : public preprocessor_params, m_ematching(true), m_induction(false), m_clause_proof(false), - m_case_split_strategy(CS_ACTIVITY_DELAY_NEW), + m_case_split_strategy(case_split_strategy::CS_ACTIVITY_DELAY_NEW), m_rel_case_split_order(0), m_lookahead_diseq(false), m_theory_case_split(false), @@ -275,13 +275,13 @@ struct smt_params : public preprocessor_params, m_delay_units(false), m_delay_units_threshold(32), m_theory_resolve(false), - m_restart_strategy(RS_IN_OUT_GEOMETRIC), + m_restart_strategy(restart_strategy::RS_IN_OUT_GEOMETRIC), m_restart_initial(100), m_restart_factor(1.1), m_restart_adaptive(true), m_agility_factor(0.9999), m_restart_agility_threshold(0.18), - m_lemma_gc_strategy(LGC_FIXED), + m_lemma_gc_strategy(lemma_gc_strategy::LGC_FIXED), m_lemma_gc_half(false), m_recent_lemmas_size(100), m_lemma_gc_initial(5000), diff --git a/src/smt/params/theory_arith_params.cpp b/src/smt/params/theory_arith_params.cpp index a2a497731..ce9ff5fbc 100644 --- a/src/smt/params/theory_arith_params.cpp +++ b/src/smt/params/theory_arith_params.cpp @@ -50,11 +50,11 @@ void theory_arith_params::updt_params(params_ref const & _p) { void theory_arith_params::display(std::ostream & out) const { DISPLAY_PARAM(m_arith_eq2ineq); DISPLAY_PARAM(m_arith_process_all_eqs); - DISPLAY_PARAM(m_arith_mode); + DISPLAY_PARAM((unsigned)m_arith_mode); DISPLAY_PARAM(m_arith_auto_config_simplex); //!< force simplex solver in auto_config DISPLAY_PARAM(m_arith_blands_rule_threshold); DISPLAY_PARAM(m_arith_propagate_eqs); - DISPLAY_PARAM(m_arith_bound_prop); + DISPLAY_PARAM((unsigned)m_arith_bound_prop); DISPLAY_PARAM(m_arith_stronger_lemmas); DISPLAY_PARAM(m_arith_skip_rows_with_big_coeffs); DISPLAY_PARAM(m_arith_max_lemma_size); @@ -81,7 +81,7 @@ void theory_arith_params::display(std::ostream & out) const { DISPLAY_PARAM(m_arith_pivot_strategy); DISPLAY_PARAM(m_arith_bounded_expansion); DISPLAY_PARAM(m_arith_add_binary_bounds); - DISPLAY_PARAM(m_arith_propagation_strategy); + DISPLAY_PARAM((unsigned)m_arith_propagation_strategy); DISPLAY_PARAM(m_arith_eq_bounds); DISPLAY_PARAM(m_arith_lazy_adapter); DISPLAY_PARAM(m_arith_fixnum); diff --git a/src/smt/params/theory_arith_params.h b/src/smt/params/theory_arith_params.h index c5f7c5e3a..6abb75b10 100644 --- a/src/smt/params/theory_arith_params.h +++ b/src/smt/params/theory_arith_params.h @@ -21,7 +21,7 @@ Revision History: #include #include "util/params.h" -enum arith_solver_id { +enum class arith_solver_id { AS_NO_ARITH, // 0 AS_DIFF_LOGIC, // 1 AS_OLD_ARITH, // 2 @@ -31,13 +31,13 @@ enum arith_solver_id { AS_NEW_ARITH // 6 }; -enum bound_prop_mode { +enum class bound_prop_mode { BP_NONE, BP_SIMPLE, // only used for implying literals BP_REFINE // adds new literals, but only refines finite bounds }; -enum arith_prop_strategy { +enum class arith_prop_strategy { ARITH_PROP_AGILITY, ARITH_PROP_PROPORTIONAL }; @@ -114,11 +114,11 @@ struct theory_arith_params { theory_arith_params(params_ref const & p = params_ref()): m_arith_eq2ineq(false), m_arith_process_all_eqs(false), - m_arith_mode(AS_NEW_ARITH), + m_arith_mode(arith_solver_id::AS_NEW_ARITH), m_arith_auto_config_simplex(false), m_arith_blands_rule_threshold(1000), m_arith_propagate_eqs(true), - m_arith_bound_prop(BP_REFINE), + m_arith_bound_prop(bound_prop_mode::BP_REFINE), m_arith_stronger_lemmas(true), m_arith_skip_rows_with_big_coeffs(true), m_arith_max_lemma_size(128), @@ -145,7 +145,7 @@ struct theory_arith_params { m_arith_bounded_expansion(false), m_arith_pivot_strategy(arith_pivot_strategy::ARITH_PIVOT_SMALLEST), m_arith_add_binary_bounds(false), - m_arith_propagation_strategy(ARITH_PROP_PROPORTIONAL), + m_arith_propagation_strategy(arith_prop_strategy::ARITH_PROP_PROPORTIONAL), m_arith_eq_bounds(false), m_arith_lazy_adapter(false), m_arith_fixnum(false), diff --git a/src/smt/params/theory_array_params.h b/src/smt/params/theory_array_params.h index 8c1bf1a55..7bc64ff0c 100644 --- a/src/smt/params/theory_array_params.h +++ b/src/smt/params/theory_array_params.h @@ -44,7 +44,7 @@ struct theory_array_params { theory_array_params(): m_array_canonize_simplify(false), m_array_simplify(true), - m_array_mode(AR_FULL), + m_array_mode(array_solver_id::AR_FULL), m_array_weak(false), m_array_extensional(true), m_array_laziness(1), diff --git a/src/smt/params/theory_bv_params.h b/src/smt/params/theory_bv_params.h index c0bb135de..339799c81 100644 --- a/src/smt/params/theory_bv_params.h +++ b/src/smt/params/theory_bv_params.h @@ -36,7 +36,7 @@ struct theory_bv_params { bool m_bv_enable_int2bv2int; bool m_bv_watch_diseq; theory_bv_params(params_ref const & p = params_ref()): - m_bv_mode(BS_BLASTER), + m_bv_mode(bv_solver_id::BS_BLASTER), m_hi_div0(false), m_bv_reflect(true), m_bv_lazy_le(false), diff --git a/src/smt/seq_regex.h b/src/smt/seq_regex.h index 3bbca7260..acd316cb5 100644 --- a/src/smt/seq_regex.h +++ b/src/smt/seq_regex.h @@ -111,9 +111,9 @@ namespace smt { /* state_graph for dead state detection, and associated methods */ - state_graph m_state_graph; ptr_addr_map m_expr_to_state; expr_ref_vector m_state_to_expr; + state_graph m_state_graph; /* map from uninterpreted regex constants to assigned regex expressions by EQ */ // expr_map m_const_to_expr; unsigned m_max_state_graph_size { 10000 }; diff --git a/src/smt/smt_setup.cpp b/src/smt/smt_setup.cpp index 8cf2f244d..1d58c3f11 100644 --- a/src/smt/smt_setup.cpp +++ b/src/smt/smt_setup.cpp @@ -313,8 +313,8 @@ namespace smt { // } } else { - m_params.m_arith_bound_prop = BP_NONE; - m_params.m_arith_propagation_strategy = ARITH_PROP_AGILITY; + m_params.m_arith_bound_prop = bound_prop_mode::BP_NONE; + m_params.m_arith_propagation_strategy = arith_prop_strategy::ARITH_PROP_AGILITY; m_params.m_arith_add_binary_bounds = true; if (!st.m_has_rational && !m_params.m_model && st.arith_k_sum_is_small()) m_context.register_plugin(alloc(smt::theory_frdl, m_context)); @@ -524,7 +524,7 @@ namespace smt { m_params.m_restart_factor = 1.5; } if (st.m_num_bin_clauses + st.m_num_units == st.m_num_clauses && st.m_cnf && st.m_arith_k_sum > rational(100000)) { - m_params.m_arith_bound_prop = BP_NONE; + m_params.m_arith_bound_prop = bound_prop_mode::BP_NONE; m_params.m_arith_stronger_lemmas = false; } setup_lra_arith(); @@ -736,7 +736,7 @@ namespace smt { } void setup::setup_i_arith() { - if (AS_OLD_ARITH == m_params.m_arith_mode) { + if (arith_solver_id::AS_OLD_ARITH == m_params.m_arith_mode) { m_context.register_plugin(alloc(smt::theory_i_arith, m_context)); } else { @@ -745,7 +745,7 @@ namespace smt { } void setup::setup_lra_arith() { - if (m_params.m_arith_mode == AS_OLD_ARITH) + if (m_params.m_arith_mode == arith_solver_id::AS_OLD_ARITH) m_context.register_plugin(alloc(smt::theory_mi_arith, m_context)); else m_context.register_plugin(alloc(smt::theory_lra, m_context)); @@ -753,10 +753,10 @@ namespace smt { void setup::setup_mi_arith() { switch (m_params.m_arith_mode) { - case AS_OPTINF: + case arith_solver_id::AS_OPTINF: m_context.register_plugin(alloc(smt::theory_inf_arith, m_context)); break; - case AS_NEW_ARITH: + case arith_solver_id::AS_NEW_ARITH: setup_lra_arith(); break; default: @@ -778,13 +778,13 @@ namespace smt { bool int_only = !st.m_has_rational && !st.m_has_real && m_params.m_arith_int_only; auto mode = m_params.m_arith_mode; if (m_logic == "QF_LIA") { - mode = AS_NEW_ARITH; + mode = arith_solver_id::AS_NEW_ARITH; } switch(mode) { - case AS_NO_ARITH: + case arith_solver_id::AS_NO_ARITH: m_context.register_plugin(alloc(smt::theory_dummy, m_context, m_manager.mk_family_id("arith"), "no arithmetic")); break; - case AS_DIFF_LOGIC: + case arith_solver_id::AS_DIFF_LOGIC: m_params.m_arith_eq2ineq = true; if (fixnum) { if (int_only) @@ -799,7 +799,7 @@ namespace smt { m_context.register_plugin(alloc(smt::theory_rdl, m_context)); } break; - case AS_DENSE_DIFF_LOGIC: + case arith_solver_id::AS_DENSE_DIFF_LOGIC: m_params.m_arith_eq2ineq = true; if (fixnum) { if (int_only) @@ -814,23 +814,23 @@ namespace smt { m_context.register_plugin(alloc(smt::theory_dense_mi, m_context)); } break; - case AS_UTVPI: + case arith_solver_id::AS_UTVPI: m_params.m_arith_eq2ineq = true; if (int_only) m_context.register_plugin(alloc(smt::theory_iutvpi, m_context)); else m_context.register_plugin(alloc(smt::theory_rutvpi, m_context)); break; - case AS_OPTINF: + case arith_solver_id::AS_OPTINF: m_context.register_plugin(alloc(smt::theory_inf_arith, m_context)); break; - case AS_OLD_ARITH: + case arith_solver_id::AS_OLD_ARITH: if (m_params.m_arith_int_only && int_only) m_context.register_plugin(alloc(smt::theory_i_arith, m_context)); else m_context.register_plugin(alloc(smt::theory_mi_arith, m_context)); break; - case AS_NEW_ARITH: + case arith_solver_id::AS_NEW_ARITH: setup_lra_arith(); break; default: diff --git a/src/smt/theory_arith.h b/src/smt/theory_arith.h index 41bdb1872..1bdd2ff15 100644 --- a/src/smt/theory_arith.h +++ b/src/smt/theory_arith.h @@ -541,7 +541,7 @@ namespace smt { int random_lower() const { return m_params.m_arith_random_lower; } int random_upper() const { return m_params.m_arith_random_upper; } unsigned blands_rule_threshold() const { return m_params.m_arith_blands_rule_threshold; } - bound_prop_mode propagation_mode() const { return m_num_conflicts < m_params.m_arith_propagation_threshold ? m_params.m_arith_bound_prop : BP_NONE; } + bound_prop_mode propagation_mode() const { return m_num_conflicts < m_params.m_arith_propagation_threshold ? m_params.m_arith_bound_prop : bound_prop_mode::BP_NONE; } bool adaptive() const { return m_params.m_arith_adaptive; } double adaptive_assertion_threshold() const { return m_params.m_arith_adaptive_assertion_threshold; } unsigned max_lemma_size() const { return m_params.m_arith_max_lemma_size; } diff --git a/src/smt/theory_arith_aux.h b/src/smt/theory_arith_aux.h index a5426b2a6..5367f6e1e 100644 --- a/src/smt/theory_arith_aux.h +++ b/src/smt/theory_arith_aux.h @@ -1151,8 +1151,8 @@ namespace smt { */ template void theory_arith::enable_record_conflict(expr* bound) { - m_params.m_arith_bound_prop = BP_NONE; - SASSERT(propagation_mode() == BP_NONE); // bound propagation rules are not (yet) handled. + m_params.m_arith_bound_prop = bound_prop_mode::BP_NONE; + SASSERT(propagation_mode() == bound_prop_mode::BP_NONE); // bound propagation rules are not (yet) handled. if (bound) { m_bound_watch = ctx.get_bool_var(bound); } diff --git a/src/smt/theory_arith_core.h b/src/smt/theory_arith_core.h index 9a12a6d2c..ade915fb6 100644 --- a/src/smt/theory_arith_core.h +++ b/src/smt/theory_arith_core.h @@ -852,7 +852,7 @@ namespace smt { SASSERT(!has_var_kind(get_var_row(s), BASE)); } TRACE("init_row_bug", tout << "after:\n"; display_row_info(tout, r);); - if (propagation_mode() != BP_NONE) + if (propagation_mode() != bound_prop_mode::BP_NONE) mark_row_for_bound_prop(r_id); SASSERT(r.is_coeff_of(s, numeral::one())); SASSERT(wf_row(r_id)); @@ -1728,7 +1728,7 @@ namespace smt { template void theory_arith::add_row(unsigned rid1, const numeral & coeff, unsigned rid2, bool apply_gcd_test) { m_stats.m_add_rows++; - if (propagation_mode() != BP_NONE) + if (propagation_mode() != bound_prop_mode::BP_NONE) mark_row_for_bound_prop(rid1); row & r1 = m_rows[rid1]; row & r2 = m_rows[rid2]; @@ -2442,7 +2442,7 @@ namespace smt { push_bound_trail(v, l, false); set_bound(b, false); - if (propagation_mode() != BP_NONE) + if (propagation_mode() != bound_prop_mode::BP_NONE) mark_rows_for_bound_prop(v); return true; @@ -2490,7 +2490,7 @@ namespace smt { push_bound_trail(v, u, true); set_bound(b, true); - if (propagation_mode() != BP_NONE) + if (propagation_mode() != bound_prop_mode::BP_NONE) mark_rows_for_bound_prop(v); return true; diff --git a/src/smt/theory_diff_logic_def.h b/src/smt/theory_diff_logic_def.h index da6b7a6d2..62eaef9a1 100644 --- a/src/smt/theory_diff_logic_def.h +++ b/src/smt/theory_diff_logic_def.h @@ -512,7 +512,7 @@ void theory_diff_logic::propagate() { switch (m_params.m_arith_propagation_strategy) { - case ARITH_PROP_PROPORTIONAL: { + case arith_prop_strategy::ARITH_PROP_PROPORTIONAL: { ++m_num_propagation_calls; if (m_num_propagation_calls * (m_stats.m_num_conflicts + 1) > @@ -526,7 +526,7 @@ void theory_diff_logic::propagate() { } break; } - case ARITH_PROP_AGILITY: { + case arith_prop_strategy::ARITH_PROP_AGILITY: { // update agility with factor generated by other conflicts. double g = m_params.m_arith_adaptive_propagation_threshold; diff --git a/src/smt/theory_lra.cpp b/src/smt/theory_lra.cpp index 2ebcb9689..16e4cd4f0 100644 --- a/src/smt/theory_lra.cpp +++ b/src/smt/theory_lra.cpp @@ -992,7 +992,7 @@ public: smt_params_helper lpar(ctx().get_params()); lp().settings().set_resource_limit(m_resource_limit); lp().settings().simplex_strategy() = static_cast(lpar.arith_simplex_strategy()); - lp().settings().bound_propagation() = BP_NONE != propagation_mode(); + lp().settings().bound_propagation() = bound_prop_mode::BP_NONE != propagation_mode(); lp().settings().enable_hnf() = lpar.arith_enable_hnf(); lp().settings().print_external_var_name() = lpar.arith_print_ext_var_names(); lp().set_track_pivoted_rows(lpar.arith_bprop_on_pivoted_rows()); @@ -2320,11 +2320,11 @@ public: } bool should_propagate() const { - return BP_NONE != propagation_mode(); + return bound_prop_mode::BP_NONE != propagation_mode(); } bool should_refine_bounds() const { - return BP_REFINE == propagation_mode() && ctx().at_search_level(); + return bound_prop_mode::BP_REFINE == propagation_mode() && ctx().at_search_level(); } void consume(rational const& v, lp::constraint_index j) { @@ -2807,7 +2807,7 @@ public: // x <= hi -> ~(x >= hi') void propagate_bound(bool_var bv, bool is_true, lp_api::bound& b) { - if (BP_NONE == propagation_mode()) { + if (bound_prop_mode::BP_NONE == propagation_mode()) { return; } lp_api::bound_kind k = b.get_bound_kind(); @@ -3113,7 +3113,7 @@ public: bool propagate_eqs() const { return params().m_arith_propagate_eqs && m_num_conflicts < params().m_arith_propagation_threshold; } - bound_prop_mode propagation_mode() const { return m_num_conflicts < params().m_arith_propagation_threshold ? params().m_arith_bound_prop : BP_NONE; } + bound_prop_mode propagation_mode() const { return m_num_conflicts < params().m_arith_propagation_threshold ? params().m_arith_bound_prop : bound_prop_mode::BP_NONE; } unsigned small_lemma_size() const { return params().m_arith_small_lemma_size; } @@ -3568,10 +3568,10 @@ public: struct scoped_arith_mode { smt_params& p; scoped_arith_mode(smt_params& p) : p(p) { - p.m_arith_mode = AS_OLD_ARITH; + p.m_arith_mode = arith_solver_id::AS_OLD_ARITH; } ~scoped_arith_mode() { - p.m_arith_mode = AS_NEW_ARITH; + p.m_arith_mode = arith_solver_id::AS_NEW_ARITH; } }; @@ -3582,7 +3582,7 @@ public: } bool validate_conflict(literal_vector const& core, svector const& eqs) { - if (params().m_arith_mode != AS_NEW_ARITH) return true; + if (params().m_arith_mode != arith_solver_id::AS_NEW_ARITH) return true; scoped_arith_mode _sa(ctx().get_fparams()); context nctx(m, ctx().get_fparams(), ctx().get_params()); add_background(nctx); @@ -3601,7 +3601,7 @@ public: } bool validate_assign(literal lit, literal_vector const& core, svector const& eqs) { - if (params().m_arith_mode != AS_NEW_ARITH) return true; + if (params().m_arith_mode != arith_solver_id::AS_NEW_ARITH) return true; scoped_arith_mode _sa(ctx().get_fparams()); context nctx(m, ctx().get_fparams(), ctx().get_params()); m_core.push_back(~lit); @@ -3616,7 +3616,7 @@ public: } bool validate_eq(enode* x, enode* y) { - if (params().m_arith_mode == AS_NEW_ARITH) return true; + if (params().m_arith_mode == arith_solver_id::AS_NEW_ARITH) return true; context nctx(m, ctx().get_fparams(), ctx().get_params()); add_background(nctx); nctx.assert_expr(m.mk_not(m.mk_eq(x->get_owner(), y->get_owner()))); diff --git a/src/smt/theory_seq.cpp b/src/smt/theory_seq.cpp index 2528f24ed..17875f80a 100644 --- a/src/smt/theory_seq.cpp +++ b/src/smt/theory_seq.cpp @@ -1791,7 +1791,7 @@ void theory_seq::collect_statistics(::statistics & st) const { void theory_seq::init_search_eh() { auto as = get_fparams().m_arith_mode; - if (m_has_seq && as != AS_OLD_ARITH && as != AS_NEW_ARITH) { + if (m_has_seq && as != arith_solver_id::AS_OLD_ARITH && as != arith_solver_id::AS_NEW_ARITH) { throw default_exception("illegal arithmetic solver used with string solver"); } } diff --git a/src/test/egraph.cpp b/src/test/egraph.cpp index e03340bbb..04e950959 100644 --- a/src/test/egraph.cpp +++ b/src/test/egraph.cpp @@ -20,11 +20,13 @@ static expr_ref mk_app(char const* name, expr_ref const& arg, sort* s) { return expr_ref(m.mk_app(f, arg.get()), m); } +#if 0 static expr_ref mk_app(char const* name, expr_ref const& arg1, expr_ref const& arg2, sort* s) { ast_manager& m = arg1.m(); func_decl_ref f(m.mk_func_decl(symbol(name), m.get_sort(arg1), m.get_sort(arg2), s), m); return expr_ref(m.mk_app(f, arg1.get(), arg2.get()), m); } +#endif static void test1() { ast_manager m; diff --git a/src/test/tbv.cpp b/src/test/tbv.cpp index 3b5c5e782..c921551bd 100644 --- a/src/test/tbv.cpp +++ b/src/test/tbv.cpp @@ -74,10 +74,9 @@ static void tst2(unsigned num_bits) { } } +#if 0 // prints all don't care pareto fronts for 8-bit multiplier. static void test_dc() { - unsigned a = 0; - unsigned b = 0; unsigned num_bits = 8; unsigned num_vals = 1 << num_bits; tbv_manager m(num_bits*2); @@ -134,6 +133,7 @@ static void test_dc() { m.deallocate(t); } +#endif void tst_tbv() { // test_dc(); diff --git a/src/util/cmd_context_types.h b/src/util/cmd_context_types.h index 450579d8d..f2d0d2a0c 100644 --- a/src/util/cmd_context_types.h +++ b/src/util/cmd_context_types.h @@ -97,7 +97,7 @@ public: // command invocation void set_line_pos(int line, int pos) { m_line = line; m_pos = pos; } virtual void prepare(cmd_context & ctx) {} - virtual cmd_arg_kind next_arg_kind(cmd_context & ctx) const { UNREACHABLE(); return CPK_UINT; } + virtual cmd_arg_kind next_arg_kind(cmd_context & ctx) const { UNREACHABLE(); return cmd_arg_kind::CPK_UINT; } virtual void set_next_arg(cmd_context & ctx, unsigned val) { UNREACHABLE(); } virtual void set_next_arg(cmd_context & ctx, bool val) { UNREACHABLE(); } virtual void set_next_arg(cmd_context & ctx, rational const & val) { UNREACHABLE(); } diff --git a/src/util/sexpr.cpp b/src/util/sexpr.cpp index 4757330c2..93acf7a45 100644 --- a/src/util/sexpr.cpp +++ b/src/util/sexpr.cpp @@ -29,7 +29,7 @@ struct sexpr_composite : public sexpr { unsigned m_num_chilren; sexpr * m_children[0]; sexpr_composite(unsigned num_children, sexpr * const * children, unsigned line, unsigned pos): - sexpr(COMPOSITE, line, pos), + sexpr(kind_t::COMPOSITE, line, pos), m_num_chilren(num_children) { for (unsigned i = 0; i < num_children; i++) { m_children[i] = children[i]; @@ -45,7 +45,7 @@ struct sexpr_numeral : public sexpr { m_val(val) { } sexpr_numeral(rational const & val, unsigned line, unsigned pos): - sexpr(NUMERAL, line, pos), + sexpr(kind_t::NUMERAL, line, pos), m_val(val) { } }; @@ -53,7 +53,7 @@ struct sexpr_numeral : public sexpr { struct sexpr_bv : public sexpr_numeral { unsigned m_size; sexpr_bv(rational const & val, unsigned size, unsigned line, unsigned pos): - sexpr_numeral(BV_NUMERAL, val, line, pos), + sexpr_numeral(kind_t::BV_NUMERAL, val, line, pos), m_size(size) { } }; @@ -61,11 +61,11 @@ struct sexpr_bv : public sexpr_numeral { struct sexpr_string : public sexpr { std::string m_val; sexpr_string(std::string const & val, unsigned line, unsigned pos): - sexpr(STRING, line, pos), + sexpr(kind_t::STRING, line, pos), m_val(val) { } sexpr_string(char const * val, unsigned line, unsigned pos): - sexpr(STRING, line, pos), + sexpr(kind_t::STRING, line, pos), m_val(val) { } }; @@ -73,7 +73,7 @@ struct sexpr_string : public sexpr { struct sexpr_symbol : public sexpr { symbol m_val; sexpr_symbol(bool keyword, symbol const & val, unsigned line, unsigned pos): - sexpr(keyword ? KEYWORD : SYMBOL, line, pos), + sexpr(keyword ? kind_t::KEYWORD : kind_t::SYMBOL, line, pos), m_val(val) { } }; @@ -122,12 +122,12 @@ sexpr * const * sexpr::get_children() const { void sexpr::display_atom(std::ostream & out) const { switch (get_kind()) { - case sexpr::COMPOSITE: + case sexpr::kind_t::COMPOSITE: UNREACHABLE(); - case sexpr::NUMERAL: + case sexpr::kind_t::NUMERAL: out << static_cast(this)->m_val; break; - case sexpr::BV_NUMERAL: { + case sexpr::kind_t::BV_NUMERAL: { out << '#'; unsigned bv_size = static_cast(this)->m_size; rational val = static_cast(this)->m_val; @@ -172,11 +172,11 @@ void sexpr::display_atom(std::ostream & out) const { out << buf.c_ptr(); break; } - case sexpr::STRING: + case sexpr::kind_t::STRING: out << "\"" << escaped(static_cast(this)->m_val.c_str()) << "\""; break; - case sexpr::SYMBOL: - case sexpr::KEYWORD: + case sexpr::kind_t::SYMBOL: + case sexpr::kind_t::KEYWORD: out << static_cast(this)->m_val; break; default: @@ -220,7 +220,7 @@ void sexpr_manager::del(sexpr * n) { sexpr * n = m_to_delete.back(); m_to_delete.pop_back(); switch (n->get_kind()) { - case sexpr::COMPOSITE: { + case sexpr::kind_t::COMPOSITE: { unsigned num = n->get_num_children(); for (unsigned i = 0; i < num; i++) { sexpr * child = n->get_child(i); @@ -233,20 +233,20 @@ void sexpr_manager::del(sexpr * n) { m_allocator.deallocate(sizeof(sexpr_composite) + num * sizeof(sexpr*), n); break; } - case sexpr::NUMERAL: + case sexpr::kind_t::NUMERAL: static_cast(n)->~sexpr_numeral(); m_allocator.deallocate(sizeof(sexpr_numeral), n); break; - case sexpr::BV_NUMERAL: + case sexpr::kind_t::BV_NUMERAL: static_cast(n)->~sexpr_bv(); m_allocator.deallocate(sizeof(sexpr_bv), n); break; - case sexpr::STRING: + case sexpr::kind_t::STRING: static_cast(n)->~sexpr_string(); m_allocator.deallocate(sizeof(sexpr_string), n); break; - case sexpr::SYMBOL: - case sexpr::KEYWORD: + case sexpr::kind_t::SYMBOL: + case sexpr::kind_t::KEYWORD: static_cast(n)->~sexpr_symbol(); m_allocator.deallocate(sizeof(sexpr_symbol), n); break; diff --git a/src/util/sexpr.h b/src/util/sexpr.h index 8b736441d..1df57288e 100644 --- a/src/util/sexpr.h +++ b/src/util/sexpr.h @@ -27,7 +27,7 @@ class sexpr_manager; class sexpr { public: - enum kind_t { + enum class kind_t { COMPOSITE, NUMERAL, BV_NUMERAL, STRING, KEYWORD, SYMBOL }; protected: @@ -44,12 +44,12 @@ public: unsigned get_line() const { return m_line; } unsigned get_pos() const { return m_pos; } kind_t get_kind() const { return m_kind; } - bool is_composite() const { return get_kind() == COMPOSITE; } - bool is_numeral() const { return get_kind() == NUMERAL; } - bool is_bv_numeral() const { return get_kind() == BV_NUMERAL; } - bool is_string() const { return get_kind() == STRING; } - bool is_keyword() const { return get_kind() == KEYWORD; } - bool is_symbol() const { return get_kind() == SYMBOL; } + bool is_composite() const { return get_kind() == kind_t::COMPOSITE; } + bool is_numeral() const { return get_kind() == kind_t::NUMERAL; } + bool is_bv_numeral() const { return get_kind() == kind_t::BV_NUMERAL; } + bool is_string() const { return get_kind() == kind_t::STRING; } + bool is_keyword() const { return get_kind() == kind_t::KEYWORD; } + bool is_symbol() const { return get_kind() == kind_t::SYMBOL; } rational const & get_numeral() const; unsigned get_bv_size() const; symbol get_symbol() const;