From 6f63f8761c9272e1374224707a7790a341e57c86 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Sun, 20 Sep 2020 06:47:27 -0700 Subject: [PATCH] optimizations to bv-solver and euf-egraph (#4698) * additional bit-vector propagators Signed-off-by: Nikolaj Bjorner * rename restrict (not a keyword, but well) #4694, tune euf Signed-off-by: Nikolaj Bjorner * merge Signed-off-by: Nikolaj Bjorner * add pb rewriting to pb2bv #4697 Signed-off-by: Nikolaj Bjorner --- src/ast/ast.cpp | 2 +- src/ast/euf/euf_egraph.cpp | 98 +++++++++++++++++----------- src/ast/euf/euf_egraph.h | 16 +++-- src/ast/euf/euf_enode.h | 11 ++++ src/ast/rewriter/pb2bv_rewriter.h | 2 +- src/muz/base/dl_rule_set.cpp | 8 +-- src/muz/base/dl_rule_set.h | 2 +- src/muz/rel/dl_compiler.cpp | 2 +- src/sat/sat_asymm_branch.cpp | 1 + src/sat/sat_probing.cpp | 2 +- src/sat/sat_probing.h | 2 - src/sat/sat_solver.cpp | 7 +- src/sat/sat_solver.h | 4 +- src/sat/sat_solver_core.h | 1 + src/sat/smt/bv_internalize.cpp | 2 +- src/sat/smt/bv_solver.cpp | 30 ++++++--- src/sat/smt/bv_solver.h | 7 +- src/sat/smt/euf_internalize.cpp | 26 ++++++-- src/sat/smt/euf_solver.cpp | 40 ++++++------ src/sat/smt/sat_th.cpp | 4 +- src/sat/tactic/goal2sat.cpp | 11 +++- src/smt/smt_context.cpp | 2 + src/tactic/arith/pb2bv_tactic.cpp | 39 +++++++---- src/tactic/bv/bv1_blaster_tactic.cpp | 3 +- 24 files changed, 206 insertions(+), 116 deletions(-) diff --git a/src/ast/ast.cpp b/src/ast/ast.cpp index b27d8166e..cec8fea8f 100644 --- a/src/ast/ast.cpp +++ b/src/ast/ast.cpp @@ -1825,7 +1825,7 @@ ast * ast_manager::register_node_core(ast * n) { n->m_id = is_decl(n) ? m_decl_id_gen.mk() : m_expr_id_gen.mk(); // track_id(*this, n, 3); - + TRACE("ast", tout << (s_count++) << " Object " << n->m_id << " was created.\n";); TRACE("mk_var_bug", tout << "mk_ast: " << n->m_id << "\n";); // increment reference counters diff --git a/src/ast/euf/euf_egraph.cpp b/src/ast/euf/euf_egraph.cpp index c2ff7939d..0045d7efa 100644 --- a/src/ast/euf/euf_egraph.cpp +++ b/src/ast/euf/euf_egraph.cpp @@ -27,12 +27,16 @@ namespace euf { r2->dec_class_size(r1->class_size()); std::swap(r1->m_next, r2->m_next); auto begin = r2->begin_parents() + r2_num_parents, end = r2->end_parents(); + // DEBUG_CODE(for (auto it = begin; it != end; ++it) VERIFY(((*it)->merge_enabled()) == m_table.contains(*it));); for (auto it = begin; it != end; ++it) - m_table.erase(*it); + if ((*it)->merge_enabled()) + m_table.erase(*it); for (enode* c : enode_class(r1)) c->m_root = r1; for (auto it = begin; it != end; ++it) - m_table.insert(*it); + if ((*it)->merge_enabled()) + m_table.insert(*it); + r2->m_parents.shrink(r2_num_parents); unmerge_justification(n1); } @@ -48,33 +52,22 @@ namespace euf { return n; } - void egraph::reinsert(enode* n) { - unsigned num_parents = n->m_parents.size(); - for (unsigned i = 0; i < num_parents; ++i) { - enode* p = n->m_parents[i]; - if (is_equality(p)) { - reinsert_equality(p); - } - else { - auto rc = m_table.insert(p); - merge(rc.first, p, justification::congruence(rc.second)); - if (inconsistent()) - break; - } + void egraph::reinsert(enode* p) { + if (p->merge_enabled()) { + auto rc = m_table.insert(p); + merge(rc.first, p, justification::congruence(rc.second)); } + else if (p->is_equality()) + reinsert_equality(p); } void egraph::reinsert_equality(enode* p) { - SASSERT(is_equality(p)); - if (p->get_arg(0)->get_root() == p->get_arg(1)->get_root() && m_value(p) != l_true) { + SASSERT(p->is_equality()); + if (p->value() != l_true && p->get_arg(0)->get_root() == p->get_arg(1)->get_root()) { add_literal(p, true); } } - bool egraph::is_equality(enode* p) const { - return m.is_eq(p->get_expr()); - } - void egraph::force_push() { if (m_num_scopes == 0) return; @@ -103,7 +96,8 @@ namespace euf { n->mark_interpreted(); if (num_args == 0) return n; - if (is_equality(n)) { + if (m.is_eq(f)) { + n->set_is_equality(); update_children(n); reinsert_equality(n); return n; @@ -150,13 +144,15 @@ namespace euf { } void egraph::new_diseq(enode* n1) { - SASSERT(m.is_eq(n1->get_expr())); + SASSERT(n1->is_equality()); enode* arg1 = n1->get_arg(0), * arg2 = n1->get_arg(1); enode* r1 = arg1->get_root(); enode* r2 = arg2->get_root(); TRACE("euf", tout << "new-diseq: " << mk_pp(r1->get_expr(), m) << " " << mk_pp(r2->get_expr(), m) << ": " << r1->has_th_vars() << " " << r2->has_th_vars() << "\n";); - if (r1 == r2) + if (r1 == r2) { + add_literal(n1, true); return; + } if (!r1->has_th_vars()) return; if (!r2->has_th_vars()) @@ -189,7 +185,7 @@ namespace euf { if (!th_propagates_diseqs(id)) return; for (enode* p : enode_parents(r)) { - if (m.is_eq(p->get_expr()) && m.is_false(p->get_root()->get_expr())) { + if (p->is_equality() && p->value() == l_false) { enode* n = nullptr; n = (r == p->get_arg(0)->get_root()) ? p->get_arg(1) : p->get_arg(0); n = n->get_root(); @@ -254,6 +250,13 @@ namespace euf { } } + void egraph::set_value(enode* n, lbool value) { + force_push(); + VERIFY(n->value() == l_undef); + n->set_value(value); + m_updates.push_back(update_record(n, update_record::value_assignment())); + } + void egraph::pop(unsigned num_scopes) { if (num_scopes <= m_num_scopes) { m_num_scopes -= num_scopes; @@ -309,6 +312,10 @@ namespace euf { case update_record::tag_t::is_inconsistent: m_inconsistent = p.m_inconsistent; break; + case update_record::tag_t::is_value_assignment: + VERIFY(p.r1->value() != l_undef); + p.r1->set_value(l_undef); + break; default: UNREACHABLE(); break; @@ -324,12 +331,16 @@ namespace euf { } void egraph::merge(enode* n1, enode* n2, justification j) { + if (!n1->merge_enabled() && !n2->merge_enabled()) { + return; + } SASSERT(m.get_sort(n1->get_expr()) == m.get_sort(n2->get_expr())); enode* r1 = n1->get_root(); enode* r2 = n2->get_root(); if (r1 == r2) return; TRACE("euf", j.display(tout << "merge: " << mk_bounded_pp(n1->get_expr(), m) << " == " << mk_bounded_pp(n2->get_expr(), m) << " ", m_display_justification) << "\n";); + IF_VERBOSE(20, j.display(verbose_stream() << "merge: " << mk_bounded_pp(n1->get_expr(), m) << " == " << mk_bounded_pp(n2->get_expr(), m) << " ", m_display_justification) << "\n";); force_push(); SASSERT(m_num_scopes == 0); ++m_stats.m_num_merge; @@ -337,18 +348,29 @@ namespace euf { set_conflict(n1, n2, j); return; } - if ((r1->class_size() > r2->class_size() && !r2->interpreted()) || r1->interpreted()) { + if ((r1->class_size() > r2->class_size() && !r2->interpreted()) || r1->interpreted() || r1->value() != l_undef) { std::swap(r1, r2); std::swap(n1, n2); } - if ((m.is_true(r2->get_expr()) || m.is_false(r2->get_expr())) && j.is_congruence()) - add_literal(n1, false); - if (m.is_false(r2->get_expr()) && m.is_eq(n1->get_expr())) - new_diseq(n1); - for (enode* p : enode_parents(n1)) - m_table.erase(p); - for (enode* p : enode_parents(n2)) - m_table.erase(p); + if (r1->value() != l_undef) + return; + if (j.is_congruence() && (m.is_false(r2->get_expr()) || m.is_true(r2->get_expr()))) { + add_literal(n1, false); + } + if (n1->is_equality() && r2->value() == l_false) + new_diseq(n1); + unsigned num_merge = 0, num_eqs = 0; + for (enode* p : enode_parents(n1)) { + if (p->merge_enabled()) { + m_table.erase(p); + m_worklist.push_back(p); + ++num_merge; + } + else if (p->is_equality()) { + m_worklist.push_back(p); + ++num_eqs; + } + } push_eq(r1, n1, r2->num_parents()); merge_justification(n1, n2, j); for (enode* c : enode_class(n1)) @@ -357,7 +379,6 @@ namespace euf { r2->inc_class_size(r1->class_size()); r2->m_parents.append(r1->m_parents); merge_th_eq(r1, r2); - m_worklist.push_back(r2); } void egraph::merge_th_eq(enode* n, enode* root) { @@ -383,14 +404,13 @@ namespace euf { unsigned head = 0, tail = m_worklist.size(); while (head < tail && m.limit().inc() && !inconsistent()) { for (unsigned i = head; i < tail && !inconsistent(); ++i) { - enode* n = m_worklist[i]->get_root(); + enode* n = m_worklist[i]; if (!n->is_marked1()) { n->mark1(); - m_worklist[i] = n; reinsert(n); } } - for (unsigned i = head; i < tail; ++i) + for (unsigned i = head; i < tail; ++i) m_worklist[i]->unmark1(); head = tail; tail = m_worklist.size(); @@ -460,7 +480,7 @@ namespace euf { m_tmp_eq->m_expr = eq; SASSERT(m_tmp_eq->num_args() == 2); enode* r = m_table.find(m_tmp_eq); - if (r && m_value(r->get_root()) == l_false) + if (r && r->get_root()->value() == l_false) return true; return false; } diff --git a/src/ast/euf/euf_egraph.h b/src/ast/euf/euf_egraph.h index 52421f43d..164bc4597 100644 --- a/src/ast/euf/euf_egraph.h +++ b/src/ast/euf/euf_egraph.h @@ -90,9 +90,10 @@ namespace euf { struct new_th_eq_qhead {}; struct new_lits_qhead {}; struct inconsistent {}; + struct value_assignment {}; enum class tag_t { is_set_parent, is_add_node, is_toggle_merge, is_add_th_var, is_replace_th_var, is_new_lit, is_new_th_eq, - is_new_th_eq_qhead, is_new_lits_qhead, is_inconsistent }; + is_new_th_eq_qhead, is_new_lits_qhead, is_inconsistent, is_value_assignment }; tag_t tag; enode* r1; enode* n1; @@ -124,7 +125,9 @@ namespace euf { update_record(unsigned qh, new_lits_qhead): tag(tag_t::is_new_lits_qhead), r1(nullptr), n1(nullptr), qhead(qh) {} update_record(bool inc, inconsistent) : - tag(tag_t::is_inconsistent), m_inconsistent(inc) {} + tag(tag_t::is_inconsistent), r1(nullptr), n1(nullptr), m_inconsistent(inc) {} + update_record(enode* n, value_assignment) : + tag(tag_t::is_value_assignment), r1(n), n1(nullptr), qhead(0) {} }; ast_manager& m; enode_vector m_worklist; @@ -151,7 +154,6 @@ namespace euf { std::function m_used_eq; std::function m_used_cc; std::function m_display_justification; - std::function m_value; void push_eq(enode* r1, enode* n1, unsigned r2_num_parents) { m_updates.push_back(update_record(r1, n1, r2_num_parents)); @@ -160,7 +162,6 @@ namespace euf { void add_th_eq(theory_id id, theory_var v1, theory_var v2, enode* c, enode* r); - void new_diseq(enode* n1); void add_th_diseqs(theory_id id, theory_var v1, enode* r); bool th_propagates_diseqs(theory_id id) const; void add_literal(enode* n, bool is_eq); @@ -202,12 +203,12 @@ namespace euf { void push() { ++m_num_scopes; } void pop(unsigned num_scopes); - bool is_equality(enode* n) const; - /** \brief merge nodes, all effects are deferred to the propagation step. */ void merge(enode* n1, enode* n2, void* reason) { merge(n1, n2, justification::external(reason)); } + void new_diseq(enode* n1); + /** \brief propagate set of merges. @@ -243,11 +244,12 @@ namespace euf { void add_th_var(enode* n, theory_var v, theory_id id); void set_th_propagates_diseqs(theory_id id); void set_merge_enabled(enode* n, bool enable_merge); + void set_value(enode* n, lbool value); + void set_bool_var(enode* n, unsigned v) { n->set_bool_var(v); } void set_used_eq(std::function& used_eq) { m_used_eq = used_eq; } void set_used_cc(std::function& used_cc) { m_used_cc = used_cc; } void set_display_justification(std::function & d) { m_display_justification = d; } - void set_eval(std::function& eval) { m_value = eval; } void begin_explain(); void end_explain(); diff --git a/src/ast/euf/euf_enode.h b/src/ast/euf/euf_enode.h index f5a709d0f..b1beacf4e 100644 --- a/src/ast/euf/euf_enode.h +++ b/src/ast/euf/euf_enode.h @@ -17,6 +17,7 @@ Author: #include "util/vector.h" #include "util/id_var_list.h" +#include "util/lbool.h" #include "ast/ast.h" #include "ast/euf/euf_justification.h" @@ -43,6 +44,9 @@ namespace euf { bool m_update_children{ false }; bool m_interpreted{ false }; bool m_merge_enabled{ true }; + bool m_is_equality{ false }; + lbool m_value; + unsigned m_bool_var { UINT_MAX }; unsigned m_class_size{ 1 }; unsigned m_table_id{ UINT_MAX }; enode_vector m_parents; @@ -104,6 +108,9 @@ namespace euf { void replace_th_var(theory_var v, theory_id id) { m_th_vars.replace(v, id); } void del_th_var(theory_id id) { m_th_vars.del_var(id); } void set_merge_enabled(bool m) { m_merge_enabled = m; } + void set_value(lbool v) { m_value = v; } + void set_is_equality() { m_is_equality = true; } + void set_bool_var(unsigned v) { m_bool_var = v; } public: ~enode() { @@ -121,6 +128,10 @@ namespace euf { unsigned num_args() const { return m_num_args; } unsigned num_parents() const { return m_parents.size(); } bool interpreted() const { return m_interpreted; } + bool is_equality() const { return m_is_equality; } + lbool value() const { return m_value; } + unsigned bool_var() const { return m_bool_var; } + bool commutative() const { return m_commutative; } void mark_interpreted() { SASSERT(num_args() == 0); m_interpreted = true; } bool merge_enabled() { return m_merge_enabled; } diff --git a/src/ast/rewriter/pb2bv_rewriter.h b/src/ast/rewriter/pb2bv_rewriter.h index 2c637fa5e..342439402 100644 --- a/src/ast/rewriter/pb2bv_rewriter.h +++ b/src/ast/rewriter/pb2bv_rewriter.h @@ -40,6 +40,6 @@ public: void pop(unsigned num_scopes); void flush_side_constraints(expr_ref_vector& side_constraints); unsigned num_translated() const; - void collect_statistics(statistics & st) const; + void collect_statistics(::statistics & st) const; }; diff --git a/src/muz/base/dl_rule_set.cpp b/src/muz/base/dl_rule_set.cpp index 3a3326903..e8234120a 100644 --- a/src/muz/base/dl_rule_set.cpp +++ b/src/muz/base/dl_rule_set.cpp @@ -137,7 +137,7 @@ namespace datalog { return *e->get_data().get_value(); } - void rule_dependencies::restrict(const item_set & allowed) { + void rule_dependencies::restrict_dependencies(const item_set & allowed) { ptr_vector to_remove; for (auto const& kv : *this) { func_decl * pred = kv.m_key; @@ -154,10 +154,8 @@ namespace datalog { void rule_dependencies::remove(func_decl * itm) { remove_m_data_entry(itm); - for (auto const& kv : *this) { - item_set & itms = *kv.get_value(); - itms.remove(itm); - } + for (auto const& kv : *this) + kv.get_value()->remove(itm); } void rule_dependencies::remove(const item_set & to_remove) { diff --git a/src/muz/base/dl_rule_set.h b/src/muz/base/dl_rule_set.h index e8fdafe0f..7c71ff946 100644 --- a/src/muz/base/dl_rule_set.h +++ b/src/muz/base/dl_rule_set.h @@ -62,7 +62,7 @@ namespace datalog { void populate(const rule_set & rules); void populate(unsigned n, rule * const * rules); - void restrict(const item_set & allowed); + void restrict_dependencies(const item_set & allowed); void remove(func_decl * itm); void remove(const item_set & to_remove); diff --git a/src/muz/rel/dl_compiler.cpp b/src/muz/rel/dl_compiler.cpp index 8604f345d..387331353 100644 --- a/src/muz/rel/dl_compiler.cpp +++ b/src/muz/rel/dl_compiler.cpp @@ -1006,7 +1006,7 @@ namespace datalog { SASSERT(global_deltas.empty()); rule_dependencies deps(m_rule_set.get_dependencies()); - deps.restrict(preds); + deps.restrict_dependencies(preds); cycle_breaker(deps, global_deltas)(); VERIFY( deps.sort_deps(ordered_preds) ); diff --git a/src/sat/sat_asymm_branch.cpp b/src/sat/sat_asymm_branch.cpp index 91bfde2a5..21e6a008d 100644 --- a/src/sat/sat_asymm_branch.cpp +++ b/src/sat/sat_asymm_branch.cpp @@ -171,6 +171,7 @@ namespace sat { TRACE("asymm_branch_detail", s.display(tout);); report rpt(*this); bool_vector saved_phase(s.m_phase); + flet _is_probing(s.m_is_probing, true); bool change = true; unsigned counter = 0; diff --git a/src/sat/sat_probing.cpp b/src/sat/sat_probing.cpp index 510593bbf..f6fac6db7 100644 --- a/src/sat/sat_probing.cpp +++ b/src/sat/sat_probing.cpp @@ -237,7 +237,7 @@ namespace sat { if (m_probing_cache && memory::get_allocation_size() > m_probing_cache_limit) m_cached_bins.finalize(); - flet _probing(m_active, true); + flet _is_probing(s.m_is_probing, true); report rpt(*this); bool r = true; m_counter = 0; diff --git a/src/sat/sat_probing.h b/src/sat/sat_probing.h index dd75d21bf..4b13d6afd 100644 --- a/src/sat/sat_probing.h +++ b/src/sat/sat_probing.h @@ -31,7 +31,6 @@ namespace sat { unsigned m_stopped_at; // where did it stop literal_set m_assigned; // literals assigned in the first branch literal_vector m_to_assert; - bool m_active { false }; // counters int m_counter; // track cost @@ -78,7 +77,6 @@ namespace sat { void collect_statistics(statistics & st) const; void reset_statistics(); - bool active() const { return m_active; } // return the literals implied by l. // return 0, if the cache is not available diff --git a/src/sat/sat_solver.cpp b/src/sat/sat_solver.cpp index 3219d7d60..16c1471c4 100644 --- a/src/sat/sat_solver.cpp +++ b/src/sat/sat_solver.cpp @@ -955,9 +955,11 @@ namespace sat { m_assigned_since_gc[v] = true; m_trail.push_back(l); - if (m_ext && m_external[v]) + if (m_ext && m_external[v] && (!is_probing() || at_base_lvl())) m_ext->asserted(l); - +// else +// std::cout << "assert " << l << "\n"; + switch (m_config.m_branching_heuristic) { case BH_VSIDS: break; @@ -1339,6 +1341,7 @@ namespace sat { m_conflicts_since_restart = 0; m_restart_threshold = m_config.m_restart_initial; } + log_stats(); lbool is_sat = l_undef; while (is_sat == l_undef && !should_cancel()) { if (inconsistent()) is_sat = resolve_conflict_core(); diff --git a/src/sat/sat_solver.h b/src/sat/sat_solver.h index abd4a66d2..cc400201a 100644 --- a/src/sat/sat_solver.h +++ b/src/sat/sat_solver.h @@ -103,6 +103,7 @@ namespace sat { scc m_scc; asymm_branch m_asymm_branch; probing m_probing; + bool m_is_probing { false }; mus m_mus; // MUS for minimal core extraction binspr m_binspr; bool m_inconsistent; @@ -350,6 +351,7 @@ namespace sat { bool was_eliminated(bool_var v) const { return m_eliminated[v]; } void set_eliminated(bool_var v, bool f) override; bool was_eliminated(literal l) const { return was_eliminated(l.var()); } + void set_phase(literal l) override { m_phase[l.var()] = !l.sign(); } unsigned scope_lvl() const { return m_scope_lvl; } unsigned search_lvl() const { return m_search_lvl; } bool at_search_lvl() const { return m_scope_lvl == m_search_lvl; } @@ -662,7 +664,7 @@ namespace sat { public: void set_should_simplify() { m_next_simplify = m_conflicts_since_init; } bool_var_vector const& get_vars_to_reinit() const { return m_vars_to_reinit; } - bool is_probing() const { return m_probing.active(); } + bool is_probing() const { return m_is_probing; } public: void user_push() override; diff --git a/src/sat/sat_solver_core.h b/src/sat/sat_solver_core.h index f6beb72ec..40ac4cbc4 100644 --- a/src/sat/sat_solver_core.h +++ b/src/sat/sat_solver_core.h @@ -81,6 +81,7 @@ namespace sat { virtual void set_external(bool_var v) {} virtual void set_non_external(bool_var v) {} virtual void set_eliminated(bool_var v, bool f) {} + virtual void set_phase(literal l) { } // optional support for user-scopes. Not relevant for sat_tactic integration. // it is only relevant for incremental mode SAT, which isn't wrapped (yet) diff --git a/src/sat/smt/bv_internalize.cpp b/src/sat/smt/bv_internalize.cpp index 376d4c69e..b1d7a5467 100644 --- a/src/sat/smt/bv_internalize.cpp +++ b/src/sat/smt/bv_internalize.cpp @@ -597,7 +597,7 @@ namespace bv { if (a) { if (!a->is_fresh()) ctx.push(add_eq_occurs_trail(a)); - a->m_eqs = new (get_region()) eq_occurs(idx, v1, v2, n, a->m_eqs); + a->m_eqs = new (get_region()) eq_occurs(idx, v1, v2, expr2literal(n->get_expr()), n, a->m_eqs); } } diff --git a/src/sat/smt/bv_solver.cpp b/src/sat/smt/bv_solver.cpp index e022fcb91..aa5295d30 100644 --- a/src/sat/smt/bv_solver.cpp +++ b/src/sat/smt/bv_solver.cpp @@ -152,7 +152,7 @@ namespace bv { SASSERT(m_bits[v1][idx] == ~m_bits[v2][idx]); TRACE("bv", tout << "found new diseq axiom\n" << pp(v1) << pp(v2);); m_stats.m_num_diseq_static++; - expr_ref eq(m.mk_eq(var2expr(v1), var2expr(v2)), m); + expr_ref eq = mk_var_eq(v1, v2); add_unit(~ctx.internalize(eq, false, false, m_is_redundant)); } @@ -252,7 +252,7 @@ namespace bv { force_push(); assert_ackerman(v1, v2); } - else + else m_ackerman.used_diseq_eh(v1, v2); } @@ -410,7 +410,7 @@ namespace bv { } void solver::propagate_eq_occurs(eq_occurs const& occ) { - auto lit = expr2literal(occ.m_node->get_expr()); + auto lit = occ.m_literal; if (s().value(lit) != l_undef) return; lbool val1 = s().value(m_bits[occ.m_v1][occ.m_idx]); @@ -438,16 +438,24 @@ namespace bv { if (val == l_false) bit1.neg(); - for (theory_var v2 = m_find.next(v1); v2 != v1 && !s().inconsistent(); v2 = m_find.next(v2)) { + unsigned num_bits = 0, num_assigned = 0; + for (theory_var v2 = m_find.next(v1); v2 != v1; v2 = m_find.next(v2)) { literal bit2 = m_bits[v2][idx]; SASSERT(m_bits[v1][idx] != ~m_bits[v2][idx]); TRACE("bv", tout << "propagating #" << var2enode(v2)->get_expr_id() << "[" << idx << "] = " << s().value(bit2) << "\n";); if (val == l_false) bit2.neg(); - if (l_true != s().value(bit2)) - assign_bit(bit2, v1, v2, idx, bit1, false); + ++num_bits; + if (num_bits > 4 && num_assigned == 0) + break; + if (s().value(bit2) == l_true) + continue; + ++num_assigned; + if (!assign_bit(bit2, v1, v2, idx, bit1, false)) + break; } + // std::cout << num_bits << " " << num_assigned << "\n"; } sat::check_result solver::check() { @@ -718,7 +726,9 @@ namespace bv { return sat::justification::mk_ext_justification(s().scope_lvl(), constraint->to_index()); } - void solver::assign_bit(literal consequent, theory_var v1, theory_var v2, unsigned idx, literal antecedent, bool propagate_eqc) { + + bool solver::assign_bit(literal consequent, theory_var v1, theory_var v2, unsigned idx, literal antecedent, bool propagate_eqc) { + m_stats.m_num_bit2core++; SASSERT(ctx.s().value(antecedent) == l_true); SASSERT(m_bits[v2][idx].var() == consequent.var()); @@ -727,10 +737,11 @@ namespace bv { if (s().value(consequent) == l_false) { m_stats.m_num_conflicts++; SASSERT(s().inconsistent()); + return false; } else { if (false && get_config().m_bv_eq_axioms) { - expr_ref eq(m.mk_eq(var2expr(v1), var2expr(v2)), m); + expr_ref eq = mk_var_eq(v1, v2); flet _is_redundant(m_is_redundant, true); literal eq_lit = ctx.internalize(eq, false, false, m_is_redundant); add_clause(~antecedent, ~eq_lit, consequent); @@ -744,7 +755,8 @@ namespace bv { if (a && a->is_bit()) for (auto curr : a->to_bit()) if (propagate_eqc || find(curr.first) != find(v2) || curr.second != idx) - m_prop_queue.push_back(propagation_item(curr)); + m_prop_queue.push_back(propagation_item(curr)); + return true; } } diff --git a/src/sat/smt/bv_solver.h b/src/sat/smt/bv_solver.h index 76fe52075..8b4153783 100644 --- a/src/sat/smt/bv_solver.h +++ b/src/sat/smt/bv_solver.h @@ -137,10 +137,11 @@ namespace bv { unsigned m_idx; theory_var m_v1; theory_var m_v2; + sat::literal m_literal; euf::enode* m_node; eq_occurs* m_next; - eq_occurs(unsigned idx, theory_var v1, theory_var v2, euf::enode* n, eq_occurs* next = nullptr): - m_idx(idx), m_v1(v1), m_v2(v2), m_node(n), m_next(next) {} + eq_occurs(unsigned idx, theory_var v1, theory_var v2, sat::literal lit, euf::enode* n, eq_occurs* next = nullptr): + m_idx(idx), m_v1(v1), m_v2(v2), m_literal(lit), m_node(n), m_next(next) {} }; class eq_occurs_it { @@ -278,7 +279,7 @@ namespace bv { void add_fixed_eq(theory_var v1, theory_var v2); svector m_merge_aux[2]; //!< auxiliary vector used in merge_zero_one_bits bool merge_zero_one_bits(theory_var r1, theory_var r2); - void assign_bit(literal consequent, theory_var v1, theory_var v2, unsigned idx, literal antecedent, bool propagate_eqc); + bool assign_bit(literal consequent, theory_var v1, theory_var v2, unsigned idx, literal antecedent, bool propagate_eqc); void propagate_bits(var_pos entry); void propagate_eq_occurs(eq_occurs const& occ); numeral const& power2(unsigned i) const; diff --git a/src/sat/smt/euf_internalize.cpp b/src/sat/smt/euf_internalize.cpp index fd1ba52f4..b91e5c1a3 100644 --- a/src/sat/smt/euf_internalize.cpp +++ b/src/sat/smt/euf_internalize.cpp @@ -82,10 +82,11 @@ namespace euf { void solver::attach_node(euf::enode* n) { expr* e = n->get_expr(); + sat::literal lit; if (!m.is_bool(e)) drat_log_node(e); else - attach_lit(literal(si.add_bool_var(e), false), e); + lit = attach_lit(literal(si.add_bool_var(e), false), e); if (!m.is_bool(e) && m.get_sort(e)->get_family_id() != null_family_id) { auto* e_ext = expr2solver(e); @@ -93,7 +94,7 @@ namespace euf { if (s_ext && s_ext != e_ext) s_ext->apply_sort_cnstr(n, m.get_sort(e)); } - expr* a = nullptr, * b = nullptr; + expr* a = nullptr, * b = nullptr; if (m.is_eq(e, a, b) && m.get_sort(a)->get_family_id() != null_family_id) { auto* s_ext = sort2solver(m.get_sort(a)); if (s_ext) @@ -121,10 +122,12 @@ namespace euf { return lit; m_var2expr[v] = e; m_var_trail.push_back(v); - if (!m_egraph.find(e)) { - enode* n = m_egraph.mk(e, 0, nullptr); + enode* n = m_egraph.find(e); + if (!n) + n = m_egraph.mk(e, 0, nullptr); + m_egraph.set_bool_var(n, v); + if (!m.is_true(e) && !m.is_false(e)) m_egraph.set_merge_enabled(n, false); - } return lit; } @@ -262,6 +265,19 @@ namespace euf { s().add_clause(2, lits1, st); s().add_clause(2, lits2, st); } + else if (m.is_eq(e, th, el) && !m.is_iff(e)) { + sat::literal lit1 = expr2literal(e); + s().set_phase(lit1); + expr_ref e2(m.mk_eq(el, th), m); + enode* n2 = m_egraph.find(e2); + if (n2) { + sat::literal lit2 = expr2literal(e2); + sat::literal lits1[2] = { ~lit1, lit2 }; + sat::literal lits2[2] = { lit1, ~lit2 }; + s().add_clause(2, lits1, st); + s().add_clause(2, lits2, st); + } + } } diff --git a/src/sat/smt/euf_solver.cpp b/src/sat/smt/euf_solver.cpp index 23dff1f5a..60840591a 100644 --- a/src/sat/smt/euf_solver.cpp +++ b/src/sat/smt/euf_solver.cpp @@ -45,12 +45,7 @@ namespace euf { [&](std::ostream& out, void* j) { display_justification_ptr(out, reinterpret_cast(j)); }; - std::function eval = [&](enode* n) { - sat::literal lit = expr2literal(n->get_expr()); - return (lit == sat::null_literal) ? l_undef : s().value(lit); - }; m_egraph.set_display_justification(disp); - m_egraph.set_eval(eval); } void solver::updt_params(params_ref const& p) { @@ -197,7 +192,7 @@ namespace euf { e = m_var2expr[l.var()]; n = m_egraph.find(e); SASSERT(n); - SASSERT(m_egraph.is_equality(n)); + SASSERT(n->is_equality()); SASSERT(!l.sign()); m_egraph.explain_eq(m_explain, n->get_arg(0), n->get_arg(1)); break; @@ -219,28 +214,32 @@ namespace euf { if (!e) return; - bool sign = l.sign(); - + TRACE("euf", tout << "asserted: " << mk_bounded_pp(e, m) << " := " << l << "@" << s().scope_lvl() << "\n";); euf::enode* n = m_egraph.find(e); - TRACE("euf", tout << "asserted: " << l << "@" << s().scope_lvl() << "\n";); if (!n) return; - for (auto th : enode_th_vars(n)) + bool sign = l.sign(); + m_egraph.set_value(n, sign ? l_false : l_true); + auto const & j = s().get_justification(l); + for (auto th : enode_th_vars(n)) m_id2solver[th.get_id()]->asserted(l); - if (!n->merge_enabled()) - return; + size_t* c = to_ptr(l); SASSERT(is_literal(c)); SASSERT(l == get_literal(c)); if (m.is_eq(e) && n->num_args() == 2 && !sign) { + SASSERT(!m.is_iff(e)); euf::enode* na = n->get_arg(0); euf::enode* nb = n->get_arg(1); m_egraph.merge(na, nb, c); } - else { + else if (n->merge_enabled()) { euf::enode* nb = sign ? mk_false() : mk_true(); m_egraph.merge(n, nb, c); } + else if (m.is_eq(e) && n->num_args() == 2 && sign) { + m_egraph.new_diseq(n); + } } @@ -278,7 +277,7 @@ namespace euf { bool is_eq = p.second; expr* e = n->get_expr(); expr* a = nullptr, *b = nullptr; - bool_var v = si.to_bool_var(e); + bool_var v = n->bool_var(); SASSERT(m.is_bool(e)); size_t cnstr; literal lit; @@ -288,10 +287,12 @@ namespace euf { lit = literal(v, false); } else { - a = e, b = n->get_root()->get_expr(); - SASSERT(m.is_true(b) || m.is_false(b)); + lbool val = n->get_root()->value(); + a = e; + b = (val == l_true) ? m.mk_true() : m.mk_false(); + SASSERT(val != l_undef); cnstr = lit_constraint().to_index(); - lit = literal(v, m.is_false(b)); + lit = literal(v, val == l_false); } unsigned lvl = s().scope_lvl(); @@ -457,10 +458,7 @@ namespace euf { auto* ext = bool_var2solver(v); if (ext) return ext->get_phase(v); - expr* e = bool_var2expr(v); - if (e && m.is_eq(e)) - return l_true; - return l_undef; + return l_undef; } bool solver::set_root(literal l, literal r) { diff --git a/src/sat/smt/sat_th.cpp b/src/sat/smt/sat_th.cpp index e0450cd89..ddb65a5b4 100644 --- a/src/sat/smt/sat_th.cpp +++ b/src/sat/smt/sat_th.cpp @@ -135,7 +135,9 @@ namespace euf { return !is_true(a, b, c, d) && (ctx.s().add_clause(4, lits, sat::status::th(m_is_redundant, get_id())), true); } - bool th_euf_solver::is_true(sat::literal lit) { return ctx.s().value(lit) == l_true; } + bool th_euf_solver::is_true(sat::literal lit) { + return ctx.s().value(lit) == l_true; + } euf::enode* th_euf_solver::mk_enode(expr* e, bool suppress_args) { m_args.reset(); diff --git a/src/sat/tactic/goal2sat.cpp b/src/sat/tactic/goal2sat.cpp index 59fba823c..3c001dbbe 100644 --- a/src/sat/tactic/goal2sat.cpp +++ b/src/sat/tactic/goal2sat.cpp @@ -136,7 +136,7 @@ struct goal2sat::imp : public sat::sat_internalizer { void mk_root_clause(sat::literal l) { TRACE("goal2sat", tout << "mk_clause: " << l << "\n";); - m_solver.add_clause(1, &l, m_is_redundant ? mk_status() : sat::status::input()); + m_solver.add_clause(1, &l, m_is_redundant ? mk_status() : sat::status::input()); } void mk_root_clause(sat::literal l1, sat::literal l2) { @@ -191,9 +191,12 @@ struct goal2sat::imp : public sat::sat_internalizer { sat::bool_var to_bool_var(expr* e) override { sat::literal l; + sat::bool_var v = m_map.to_bool_var(e); + if (v != sat::null_bool_var) + return v; if (is_app(e) && m_cache.find(to_app(e), l) && !l.sign()) return l.var(); - return m_map.to_bool_var(e); + return sat::null_bool_var; } @@ -399,6 +402,7 @@ struct goal2sat::imp : public sat::sat_internalizer { if (m_aig) m_aig->add_or(l, num, aig_lits.c_ptr()); + m_solver.set_phase(~l); m_result_stack.shrink(old_sz); if (sign) l.neg(); @@ -450,7 +454,8 @@ struct goal2sat::imp : public sat::sat_internalizer { mk_clause(num+1, lits); if (m_aig) { m_aig->add_and(l, num, aig_lits.c_ptr()); - } + } + m_solver.set_phase(l); if (sign) l.neg(); diff --git a/src/smt/smt_context.cpp b/src/smt/smt_context.cpp index a5d8e6caf..cfe9f01d3 100644 --- a/src/smt/smt_context.cpp +++ b/src/smt/smt_context.cpp @@ -485,6 +485,7 @@ namespace smt { TRACE("add_eq", tout << "redundant constraint.\n";); return; } + IF_VERBOSE(20, verbose_stream() << "merge " << mk_bounded_pp(n1->get_owner(), m) << " " << mk_bounded_pp(n2->get_owner(), m) << "\n"); if (r1->is_interpreted() && r2->is_interpreted()) { TRACE("add_eq", tout << "interpreted roots conflict.\n";); @@ -1408,6 +1409,7 @@ namespace smt { TRACE("propagate_bool_var_enode_bug", tout << "var: " << v << " #" << bool_var2expr(v)->get_id() << "\n";); SASSERT(v < static_cast(m_b_internalized_stack.size())); enode * n = bool_var2enode(v); + CTRACE("mk_bool_var", !n, tout << "No enode for " << v << "\n";); bool sign = val == l_false; if (n->merge_tf()) diff --git a/src/tactic/arith/pb2bv_tactic.cpp b/src/tactic/arith/pb2bv_tactic.cpp index 3cdf4ebf8..d0d9df586 100644 --- a/src/tactic/arith/pb2bv_tactic.cpp +++ b/src/tactic/arith/pb2bv_tactic.cpp @@ -16,19 +16,22 @@ Author: Notes: --*/ -#include "tactic/tactical.h" -#include "tactic/arith/bound_manager.h" -#include "ast/rewriter/bool_rewriter.h" -#include "ast/rewriter/rewriter_def.h" + #include "util/ref_util.h" -#include "ast/arith_decl_plugin.h" #include "util/trace.h" +#include "util/statistics.h" +#include "ast/arith_decl_plugin.h" #include "ast/ast_smt2_pp.h" #include "ast/expr_substitution.h" +#include "ast/ast_pp.h" +#include "ast/rewriter/bool_rewriter.h" +#include "ast/rewriter/rewriter_def.h" +#include "ast/rewriter/pb2bv_rewriter.h" +#include "tactic/tactical.h" +#include "tactic/arith/bound_manager.h" #include "tactic/generic_model_converter.h" #include "tactic/arith/pb2bv_model_converter.h" #include "tactic/arith/pb2bv_tactic.h" -#include "ast/ast_pp.h" class pb2bv_tactic : public tactic { public: @@ -38,11 +41,13 @@ public: typedef rational numeral; ast_manager & m; arith_util & m_util; - bound_manager & m_bm; + pb_util & m_pb; + bound_manager & m_bm; - only_01_visitor(arith_util & u, bound_manager & bm): + only_01_visitor(arith_util & u, pb_util& pb, bound_manager & bm): m(u.get_manager()), m_util(u), + m_pb(pb), m_bm(bm) { } @@ -80,7 +85,10 @@ public: throw_non_pb(n); } } - + + if (fid == m_pb.get_family_id()) + return; + if (is_uninterp_const(n)) { if (m.is_bool(n)) return; // boolean variables are ok @@ -109,8 +117,10 @@ private: ast_manager & m; bound_manager m_bm; bool_rewriter m_b_rw; + pb2bv_rewriter m_pb_rw; arith_util m_arith_util; bv_util m_bv_util; + pb_util m_pb; expr_dependency_ref_vector m_new_deps; bool m_produce_models; @@ -187,7 +197,7 @@ private: void quick_pb_check(goal_ref const & g) { expr_fast_mark1 visited; - only_01_visitor proc(m_arith_util, m_bm); + only_01_visitor proc(m_arith_util, m_pb, m_bm); unsigned sz = g->size(); for (unsigned i = 0; i < sz; i++) { expr * f = g->form(i); @@ -846,8 +856,10 @@ private: m(_m), m_bm(m), m_b_rw(m, p), + m_pb_rw(m, p), m_arith_util(m), m_bv_util(m), + m_pb(m), m_new_deps(m), m_temporary_ints(m), m_used_dependencies(m), @@ -870,6 +882,7 @@ private: m_all_clauses_limit = p.get_uint("pb2bv_all_clauses_limit", 8); m_cardinality_limit = p.get_uint("pb2bv_cardinality_limit", UINT_MAX); m_b_rw.updt_params(p); + m_pb_rw.updt_params(p); } void collect_param_descrs(param_descrs & r) { @@ -878,6 +891,7 @@ private: r.insert("pb2bv_cardinality_limit", CPK_UINT, "(default: inf) limit for using arc-consistent cardinality constraint encoding."); m_b_rw.get_param_descrs(r); + m_pb_rw.collect_param_descrs(r); r.erase("flat"); r.erase("elim_and"); } @@ -925,7 +939,9 @@ private: TRACE("pb2bv_convert", tout << "pos: " << pos << "\n" << mk_ismt2_pp(atom, m) << "\n--->\n" << mk_ismt2_pp(new_f, m) << "\n";); } else { + proof_ref pr(m); m_rw(curr, new_f); + m_pb_rw(true, new_f, new_f, pr); } if (m_produce_unsat_cores) { new_deps.push_back(m.mk_join(m_used_dependencies, g->dep(idx))); @@ -1017,8 +1033,9 @@ struct is_pb_probe : public probe { bound_manager bm(m); bm(g); arith_util a_util(m); + pb_util pb(m); expr_fast_mark1 visited; - pb2bv_tactic::only_01_visitor proc(a_util, bm); + pb2bv_tactic::only_01_visitor proc(a_util, pb, bm); unsigned sz = g.size(); for (unsigned i = 0; i < sz; i++) { diff --git a/src/tactic/bv/bv1_blaster_tactic.cpp b/src/tactic/bv/bv1_blaster_tactic.cpp index 8613b8a9b..bc33b9f76 100644 --- a/src/tactic/bv/bv1_blaster_tactic.cpp +++ b/src/tactic/bv/bv1_blaster_tactic.cpp @@ -27,6 +27,7 @@ Notes: #include "ast/bv_decl_plugin.h" #include "ast/rewriter/rewriter_def.h" #include "ast/for_each_expr.h" +#include "ast/ast_util.h" #include "ast/rewriter/bv_rewriter.h" class bv1_blaster_tactic : public tactic { @@ -141,7 +142,7 @@ class bv1_blaster_tactic : public tactic { --i; new_eqs.push_back(m().mk_eq(bits1[i], bits2[i])); } - result = m().mk_and(new_eqs.size(), new_eqs.c_ptr()); + result = mk_and(m(), new_eqs.size(), new_eqs.c_ptr()); } void reduce_ite(expr * c, expr * t, expr * e, expr_ref & result) {