From 367e5fdd5258eacfc18e1089b305d9004382578d Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Mon, 28 Sep 2020 19:24:16 -0700 Subject: [PATCH] delay internalize (#4714) * adding array solver Signed-off-by: Nikolaj Bjorner * use default in model construction Signed-off-by: Nikolaj Bjorner * debug delay internalization Signed-off-by: Nikolaj Bjorner * bv Signed-off-by: Nikolaj Bjorner * arrays Signed-off-by: Nikolaj Bjorner * get rid of implied values and bounds Signed-off-by: Nikolaj Bjorner * redo egraph * remove out Signed-off-by: Nikolaj Bjorner * remove files Signed-off-by: Nikolaj Bjorner --- src/api/api_solver.cpp | 33 -- src/api/c++/z3++.h | 10 - src/api/python/z3/z3.py | 15 - src/api/z3_api.h | 29 -- src/ast/bv_decl_plugin.cpp | 8 + src/ast/bv_decl_plugin.h | 24 +- src/ast/euf/euf_egraph.cpp | 232 +++++++--- src/ast/euf/euf_egraph.h | 31 +- src/ast/euf/euf_enode.cpp | 18 +- src/ast/euf/euf_enode.h | 18 +- src/ast/euf/euf_etable.cpp | 87 ++-- src/ast/euf/euf_etable.h | 124 +++--- src/ast/rewriter/bv_rewriter.cpp | 38 +- src/math/lp/core_solver_pretty_printer_def.h | 12 +- src/math/lp/lp_core_solver_base.h | 6 +- src/muz/spacer/spacer_iuc_solver.h | 3 - src/opt/opt_solver.h | 3 - src/sat/sat_solver.cpp | 11 +- src/sat/sat_solver.h | 6 +- src/sat/sat_solver/inc_sat_solver.cpp | 28 +- src/sat/smt/array_axioms.cpp | 140 ++++-- src/sat/smt/array_internalize.cpp | 44 +- src/sat/smt/array_model.cpp | 76 +++- src/sat/smt/array_solver.cpp | 11 +- src/sat/smt/array_solver.h | 10 +- src/sat/smt/ba_solver.cpp | 8 +- src/sat/smt/ba_solver.h | 4 +- src/sat/smt/bv_delay_internalize.cpp | 404 +++++++++++++++--- src/sat/smt/bv_internalize.cpp | 162 ++++--- src/sat/smt/bv_invariant.cpp | 4 +- src/sat/smt/bv_solver.cpp | 107 ++--- src/sat/smt/bv_solver.h | 94 ++-- src/sat/smt/euf_internalize.cpp | 31 +- src/sat/smt/euf_model.cpp | 66 ++- src/sat/smt/euf_relevancy.cpp | 24 +- src/sat/smt/euf_solver.cpp | 37 +- src/sat/smt/euf_solver.h | 5 +- src/sat/smt/sat_dual_solver.cpp | 9 +- src/sat/smt/sat_dual_solver.h | 2 +- src/sat/smt/sat_th.cpp | 28 +- src/sat/smt/sat_th.h | 4 +- src/sat/smt/user_solver.cpp | 2 +- src/sat/smt/user_solver.h | 2 +- src/sat/tactic/goal2sat.cpp | 2 +- src/smt/smt_context.cpp | 41 -- src/smt/smt_context.h | 7 - src/smt/smt_context_pp.cpp | 23 +- src/smt/smt_kernel.cpp | 24 -- src/smt/smt_kernel.h | 10 - src/smt/smt_solver.cpp | 12 - src/smt/theory_fpa.cpp | 6 +- src/solver/combined_solver.cpp | 22 - src/solver/solver.h | 11 - src/solver/solver_pool.cpp | 12 - src/solver/tactic2solver.cpp | 13 - .../fd_solver/bounded_int2bv_solver.cpp | 13 - src/tactic/fd_solver/enum2bv_solver.cpp | 12 - src/tactic/fd_solver/pb2bv_solver.cpp | 12 - src/tactic/fd_solver/smtfd_solver.cpp | 12 - src/util/statistics.cpp | 25 +- 60 files changed, 1343 insertions(+), 924 deletions(-) diff --git a/src/api/api_solver.cpp b/src/api/api_solver.cpp index 049ed9f59..e50ee3aad 100644 --- a/src/api/api_solver.cpp +++ b/src/api/api_solver.cpp @@ -559,39 +559,6 @@ extern "C" { Z3_CATCH_RETURN(nullptr); } - Z3_ast Z3_API Z3_solver_get_implied_value(Z3_context c, Z3_solver s, Z3_ast e) { - Z3_TRY; - LOG_Z3_solver_get_implied_value(c, s, e); - RESET_ERROR_CODE(); - init_solver(c, s); - expr_ref v = to_solver_ref(s)->get_implied_value(to_expr(e)); - mk_c(c)->save_ast_trail(v); - RETURN_Z3(of_ast(v)); - Z3_CATCH_RETURN(nullptr); - } - - Z3_ast Z3_API Z3_solver_get_implied_lower(Z3_context c, Z3_solver s, Z3_ast e) { - Z3_TRY; - LOG_Z3_solver_get_implied_lower(c, s, e); - RESET_ERROR_CODE(); - init_solver(c, s); - expr_ref v = to_solver_ref(s)->get_implied_lower_bound(to_expr(e)); - mk_c(c)->save_ast_trail(v); - RETURN_Z3(of_ast(v)); - Z3_CATCH_RETURN(nullptr); - } - - Z3_ast Z3_API Z3_solver_get_implied_upper(Z3_context c, Z3_solver s, Z3_ast e) { - Z3_TRY; - LOG_Z3_solver_get_implied_upper(c, s, e); - RESET_ERROR_CODE(); - init_solver(c, s); - expr_ref v = to_solver_ref(s)->get_implied_upper_bound(to_expr(e)); - mk_c(c)->save_ast_trail(v); - RETURN_Z3(of_ast(v)); - Z3_CATCH_RETURN(nullptr); - } - static Z3_lbool _solver_check(Z3_context c, Z3_solver s, unsigned num_assumptions, Z3_ast const assumptions[]) { for (unsigned i = 0; i < num_assumptions; i++) { if (!is_expr(to_ast(assumptions[i]))) { diff --git a/src/api/c++/z3++.h b/src/api/c++/z3++.h index e3291df51..da94ee195 100644 --- a/src/api/c++/z3++.h +++ b/src/api/c++/z3++.h @@ -2402,16 +2402,6 @@ namespace z3 { void from_file(char const* file) { Z3_solver_from_file(ctx(), m_solver, file); ctx().check_parser_error(); } void from_string(char const* s) { Z3_solver_from_string(ctx(), m_solver, s); ctx().check_parser_error(); } - expr lower(expr const& e) { - Z3_ast r = Z3_solver_get_implied_lower(ctx(), m_solver, e); check_error(); return expr(ctx(), r); - } - expr upper(expr const& e) { - Z3_ast r = Z3_solver_get_implied_upper(ctx(), m_solver, e); check_error(); return expr(ctx(), r); - } - expr value(expr const& e) { - Z3_ast r = Z3_solver_get_implied_value(ctx(), m_solver, e); check_error(); return expr(ctx(), r); - } - check_result check() { Z3_lbool r = Z3_solver_check(ctx(), m_solver); check_error(); return to_check_result(r); } check_result check(unsigned n, expr * const assumptions) { array _assumptions(n); diff --git a/src/api/python/z3/z3.py b/src/api/python/z3/z3.py index 2e2de4e23..fdeb424bf 100644 --- a/src/api/python/z3/z3.py +++ b/src/api/python/z3/z3.py @@ -6857,21 +6857,6 @@ class Solver(Z3PPObject): """ return AstVector(Z3_solver_get_trail(self.ctx.ref(), self.solver), self.ctx) - def value(self, e): - """Return value of term in solver, if any is given. - """ - return _to_expr_ref(Z3_solver_get_implied_value(self.ctx.ref(), self.solver, e.as_ast()), self.ctx) - - def lower(self, e): - """Return lower bound known to solver based on the last call. - """ - return _to_expr_ref(Z3_solver_get_implied_lower(self.ctx.ref(), self.solver, e.as_ast()), self.ctx) - - def upper(self, e): - """Return upper bound known to solver based on the last call. - """ - return _to_expr_ref(Z3_solver_get_implied_upper(self.ctx.ref(), self.solver, e.as_ast()), self.ctx) - def statistics(self): """Return statistics for the last `check()`. diff --git a/src/api/z3_api.h b/src/api/z3_api.h index 4311a0bdb..05c9dfb0c 100644 --- a/src/api/z3_api.h +++ b/src/api/z3_api.h @@ -6499,35 +6499,6 @@ extern "C" { */ void Z3_API Z3_solver_get_levels(Z3_context c, Z3_solver s, Z3_ast_vector literals, unsigned sz, unsigned levels[]); - /** - \brief retrieve implied value for expression, if any is implied by solver at search level. - The method works for expressions that are known to the solver state, such as Boolean and - arithmetical variables. - - def_API('Z3_solver_get_implied_value', AST, (_in(CONTEXT), _in(SOLVER), _in(AST))) - */ - Z3_ast Z3_API Z3_solver_get_implied_value(Z3_context c, Z3_solver s, Z3_ast e); - - /** - \brief retrieve implied lower bound value for arithmetic expression. - If a lower bound is implied at search level, the arithmetic expression returned - is a constant representing the bound. - - def_API('Z3_solver_get_implied_lower', AST, (_in(CONTEXT), _in(SOLVER), _in(AST))) - */ - Z3_ast Z3_API Z3_solver_get_implied_lower(Z3_context c, Z3_solver s, Z3_ast e); - - /** - \brief retrieve implied upper bound value for arithmetic expression. - If an upper bound is implied at search level, the arithmetic expression returned - is a constant representing the bound. - - def_API('Z3_solver_get_implied_upper', AST, (_in(CONTEXT), _in(SOLVER), _in(AST))) - */ - - Z3_ast Z3_API Z3_solver_get_implied_upper(Z3_context c, Z3_solver s, Z3_ast e); - - /** \brief register a user-properator with the solver. diff --git a/src/ast/bv_decl_plugin.cpp b/src/ast/bv_decl_plugin.cpp index cf53d59bf..672fa8b9c 100644 --- a/src/ast/bv_decl_plugin.cpp +++ b/src/ast/bv_decl_plugin.cpp @@ -811,6 +811,14 @@ bool bv_recognizers::is_zero(expr const * n) const { return decl->get_parameter(0).get_rational().is_zero(); } +bool bv_recognizers::is_one(expr const* n) const { + if (!is_app_of(n, get_fid(), OP_BV_NUM)) { + return false; + } + func_decl* decl = to_app(n)->get_decl(); + return decl->get_parameter(0).get_rational().is_one(); +} + bool bv_recognizers::is_extract(expr const* e, unsigned& low, unsigned& high, expr*& b) const { if (!is_extract(e)) return false; low = get_extract_low(e); diff --git a/src/ast/bv_decl_plugin.h b/src/ast/bv_decl_plugin.h index 2f3242dff..252a2d458 100644 --- a/src/ast/bv_decl_plugin.h +++ b/src/ast/bv_decl_plugin.h @@ -298,6 +298,7 @@ public: bool is_numeral(expr const * n) const { return is_app_of(n, get_fid(), OP_BV_NUM); } bool is_allone(expr const * e) const; bool is_zero(expr const * e) const; + bool is_one(expr const* e) const; bool is_bv_sort(sort const * s) const; bool is_bv(expr const* e) const { return is_bv_sort(get_sort(e)); } @@ -349,6 +350,7 @@ public: bool is_bv_lshr(expr const * e) const { return is_app_of(e, get_fid(), OP_BLSHR); } bool is_bv_shl(expr const * e) const { return is_app_of(e, get_fid(), OP_BSHL); } bool is_sign_ext(expr const * e) const { return is_app_of(e, get_fid(), OP_SIGN_EXT); } + bool is_bv_umul_no_ovfl(expr const* e) const { return is_app_of(e, get_fid(), OP_BUMUL_NO_OVFL); } MATCH_BINARY(is_bv_add); MATCH_BINARY(is_bv_mul); @@ -407,10 +409,16 @@ public: return m_manager.mk_app(get_fid(), OP_EXTRACT, 2, params, 1, &n); } app * mk_concat(unsigned num, expr * const * args) { return m_manager.mk_app(get_fid(), OP_CONCAT, num, args); } - app * mk_concat(expr * arg1, expr * arg2) { expr * args[2] = { arg1, arg2 }; return mk_concat(2, args); } app * mk_bv_or(unsigned num, expr * const * args) { return m_manager.mk_app(get_fid(), OP_BOR, num, args); } - app * mk_bv_not(expr * arg) { return m_manager.mk_app(get_fid(), OP_BNOT, arg); } + app * mk_bv_and(unsigned num, expr * const * args) { return m_manager.mk_app(get_fid(), OP_BAND, num, args); } app * mk_bv_xor(unsigned num, expr * const * args) { return m_manager.mk_app(get_fid(), OP_BXOR, num, args); } + + app * mk_concat(expr * arg1, expr * arg2) { expr * args[2] = { arg1, arg2 }; return mk_concat(2, args); } + app * mk_bv_and(expr* x, expr* y) { expr* args[2] = { x, y }; return mk_bv_and(2, args); } + app * mk_bv_or(expr* x, expr* y) { expr* args[2] = { x, y }; return mk_bv_or(2, args); } + app * mk_bv_xor(expr* x, expr* y) { expr* args[2] = { x, y }; return mk_bv_xor(2, args); } + + app * mk_bv_not(expr * arg) { return m_manager.mk_app(get_fid(), OP_BNOT, arg); } app * mk_bv_neg(expr * arg) { return m_manager.mk_app(get_fid(), OP_BNEG, arg); } app * mk_bv_urem(expr * arg1, expr * arg2) const { return m_manager.mk_app(get_fid(), OP_BUREM, arg1, arg2); } app * mk_bv_srem(expr * arg1, expr * arg2) const { return m_manager.mk_app(get_fid(), OP_BSREM, arg1, arg2); } @@ -418,6 +426,18 @@ public: app * mk_bv_add(expr * arg1, expr * arg2) const { return m_manager.mk_app(get_fid(), OP_BADD, arg1, arg2); } app * mk_bv_sub(expr * arg1, expr * arg2) const { return m_manager.mk_app(get_fid(), OP_BSUB, arg1, arg2); } app * mk_bv_mul(expr * arg1, expr * arg2) const { return m_manager.mk_app(get_fid(), OP_BMUL, arg1, arg2); } + app * mk_bv_udiv(expr * arg1, expr * arg2) const { return m_manager.mk_app(get_fid(), OP_BUDIV, arg1, arg2); } + app * mk_bv_udiv_i(expr * arg1, expr * arg2) const { return m_manager.mk_app(get_fid(), OP_BUDIV_I, arg1, arg2); } + app * mk_bv_udiv0(expr * arg) const { return m_manager.mk_app(get_fid(), OP_BUDIV0, arg); } + app * mk_bv_sdiv(expr * arg1, expr * arg2) const { return m_manager.mk_app(get_fid(), OP_BSDIV, arg1, arg2); } + app * mk_bv_sdiv_i(expr * arg1, expr * arg2) const { return m_manager.mk_app(get_fid(), OP_BSDIV_I, arg1, arg2); } + app * mk_bv_sdiv0(expr * arg) const { return m_manager.mk_app(get_fid(), OP_BSDIV0, arg); } + app * mk_bv_srem_i(expr * arg1, expr * arg2) const { return m_manager.mk_app(get_fid(), OP_BSREM_I, arg1, arg2); } + app * mk_bv_srem0(expr * arg) const { return m_manager.mk_app(get_fid(), OP_BSREM0, arg); } + app * mk_bv_urem_i(expr * arg1, expr * arg2) const { return m_manager.mk_app(get_fid(), OP_BUREM_I, arg1, arg2); } + app * mk_bv_urem0(expr * arg) const { return m_manager.mk_app(get_fid(), OP_BUREM0, arg); } + app * mk_bv_smod_i(expr * arg1, expr * arg2) const { return m_manager.mk_app(get_fid(), OP_BSMOD_I, arg1, arg2); } + app * mk_bv_smod0(expr * arg) const { return m_manager.mk_app(get_fid(), OP_BSMOD0, arg); } app * mk_zero_extend(unsigned n, expr* e) { parameter p(n); return m_manager.mk_app(get_fid(), OP_ZERO_EXT, 1, &p, 1, &e); diff --git a/src/ast/euf/euf_egraph.cpp b/src/ast/euf/euf_egraph.cpp index 595a0f644..6963079f5 100644 --- a/src/ast/euf/euf_egraph.cpp +++ b/src/ast/euf/euf_egraph.cpp @@ -13,35 +13,69 @@ Author: Nikolaj Bjorner (nbjorner) 2020-08-23 +Notes: + +Each node has a congruence closure root, cg. +cg is set to the representative in the cc table +(first insertion of congruent node). +Each node n has a set of parents, denoted n.P. + +set r2 to the root of r1: + +Merge: Erase: + for each p r1.P such that p.cg == p: + erase from table + Update root: + r1.root := r2 + Insert: + for each p in r1.P: + p.cg = insert p in table + if p.cg == p: + append p to r2.P + else + add p.cg, p to worklist + +Unmerge: Erase: + for each p in added nodes: + erase p from table + Revert root: + r1.root := r1 + Insert: + for each p in r1.P: + insert p if n was cc root before merge + +condition for being cc root before merge: + p->cg == p + congruent(p, p->cg) + +congruent(p,q) := roots of p.children = roots of q.children + +Example: + +Initially: + n1 := f(a,b) has root n1 + n2 := f(a',b) has root n2 + table = [f(a,b) |-> n1, f(a',b) |-> n2] + +merge(a,a') (a' becomes root) + table = [f(a',b) |-> n2] + n1.cg = n2 + a'.P = [n2] + n1 is not added as parent because it is not a cc root after the assignment a.root := a' + +unmerge(a,a') +- nothing is erased +- n1 is reinserted. It used to be a root. + + --*/ #include "ast/euf/euf_egraph.h" #include "ast/ast_pp.h" -#include "ast/ast_ll_pp.h" #include "ast/ast_translation.h" namespace euf { - - void egraph::undo_eq(enode* r1, enode* n1, unsigned r2_num_parents) { - enode* r2 = r1->get_root(); - r2->dec_class_size(r1->class_size()); - std::swap(r1->m_next, r2->m_next); - auto begin = r2->begin_parents() + r2_num_parents, end = r2->end_parents(); - // DEBUG_CODE(for (auto it = begin; it != end; ++it) VERIFY(((*it)->merge_enabled()) == m_table.contains(*it));); - for (auto it = begin; it != end; ++it) - if ((*it)->merge_enabled()) - m_table.erase(*it); - for (enode* c : enode_class(r1)) - c->m_root = r1; - for (auto it = begin; it != end; ++it) - if ((*it)->merge_enabled()) - m_table.insert(*it); - - r2->m_parents.shrink(r2_num_parents); - unmerge_justification(n1); - } - enode* egraph::mk_enode(expr* f, unsigned num_args, enode * const* args) { enode* n = enode::mk(m_region, f, num_args, args); m_nodes.push_back(n); @@ -53,13 +87,11 @@ namespace euf { return n; } - void egraph::reinsert(enode* p) { - if (p->merge_enabled()) { - auto rc = m_table.insert(p); - merge(rc.first, p, justification::congruence(rc.second)); - } - else if (p->is_equality()) - reinsert_equality(p); + enode_bool_pair egraph::insert_table(enode* p) { + auto rc = m_table.insert(p); + enode* p_other = rc.first; + p->m_cg = rc.first; + return rc; } void egraph::reinsert_equality(enode* p) { @@ -72,6 +104,7 @@ namespace euf { void egraph::force_push() { if (m_num_scopes == 0) return; + // DEBUG_CODE(invariant();); for (; m_num_scopes > 0; --m_num_scopes) { m_scopes.push_back(m_updates.size()); m_region.push_scope(); @@ -103,7 +136,7 @@ namespace euf { reinsert_equality(n); return n; } - enode_bool_pair p = m_table.insert(n); + enode_bool_pair p = insert_table(n); enode* n2 = p.first; if (n2 == n) update_children(n); @@ -151,7 +184,7 @@ namespace euf { enode* arg1 = n->get_arg(0), * arg2 = n->get_arg(1); enode* r1 = arg1->get_root(); enode* r2 = arg2->get_root(); - TRACE("euf", tout << "new-diseq: " << mk_pp(r1->get_expr(), m) << " " << mk_pp(r2->get_expr(), m) << ": " << r1->has_th_vars() << " " << r2->has_th_vars() << "\n";); + TRACE("euf", tout << "new-diseq: " << bpp(r1) << " " << bpp(r2) << ": " << r1->has_th_vars() << " " << r2->has_th_vars() << "\n";); if (r1 == r2) { add_literal(n, true); return; @@ -208,8 +241,6 @@ namespace euf { return m_th_propagates_diseqs.get(id, false); } - - void egraph::add_th_var(enode* n, theory_var v, theory_id id) { force_push(); theory_var w = n->get_th_var(id); @@ -267,6 +298,7 @@ namespace euf { } num_scopes -= m_num_scopes; m_num_scopes = 0; + SASSERT(m_new_lits_qhead <= m_new_lits.size()); unsigned old_lim = m_scopes.size() - num_scopes; @@ -274,8 +306,9 @@ namespace euf { auto undo_node = [&]() { enode* n = m_nodes.back(); expr* e = m_exprs.back(); - if (n->num_args() > 0) + if (n->num_args() > 0 && n->is_cgr()) m_table.erase(n); + m_expr2enode[e->get_id()] = nullptr; n->~enode(); m_nodes.pop_back(); @@ -328,21 +361,24 @@ namespace euf { m_updates.shrink(num_updates); m_scopes.shrink(old_lim); m_region.pop_scope(num_scopes); - m_worklist.reset(); + m_to_merge.reset(); SASSERT(m_new_lits_qhead <= m_new_lits.size()); SASSERT(m_new_th_eqs_qhead <= m_new_th_eqs.size()); + // DEBUG_CODE(invariant();); } void egraph::merge(enode* n1, enode* n2, justification j) { - if (!n1->merge_enabled() && !n2->merge_enabled()) - return; + + if (!n1->merge_enabled() && !n2->merge_enabled()) + return; SASSERT(m.get_sort(n1->get_expr()) == m.get_sort(n2->get_expr())); enode* r1 = n1->get_root(); enode* r2 = n2->get_root(); if (r1 == r2) return; - TRACE("euf", j.display(tout << "merge: " << mk_bounded_pp(n1->get_expr(), m) << " == " << mk_bounded_pp(n2->get_expr(), m) << " ", m_display_justification) << "\n";); - IF_VERBOSE(20, j.display(verbose_stream() << "merge: " << mk_bounded_pp(n1->get_expr(), m) << " == " << mk_bounded_pp(n2->get_expr(), m) << " ", m_display_justification) << "\n";); + + TRACE("euf", j.display(tout << "merge: " << bpp(n1) << " == " << bpp(n2) << " ", m_display_justification) << "\n";); + IF_VERBOSE(20, j.display(verbose_stream() << "merge: " << bpp(n1) << " == " << bpp(n2) << " ", m_display_justification) << "\n";); force_push(); SASSERT(m_num_scopes == 0); ++m_stats.m_num_merge; @@ -356,31 +392,57 @@ namespace euf { } if (r1->value() != l_undef) return; - if (j.is_congruence() && (m.is_false(r2->get_expr()) || m.is_true(r2->get_expr()))) { + if (j.is_congruence() && (m.is_false(r2->get_expr()) || m.is_true(r2->get_expr()))) add_literal(n1, false); - } if (n1->is_equality() && n1->value() == l_false) - new_diseq(n1); - unsigned num_merge = 0, num_eqs = 0; - for (enode* p : enode_parents(n1)) { - if (p->merge_enabled()) { - m_table.erase(p); - m_worklist.push_back(p); - ++num_merge; - } - else if (p->is_equality()) { - m_worklist.push_back(p); - ++num_eqs; - } - } + new_diseq(n1); + remove_parents(r1, r2); push_eq(r1, n1, r2->num_parents()); merge_justification(n1, n2, j); for (enode* c : enode_class(n1)) c->m_root = r2; std::swap(r1->m_next, r2->m_next); r2->inc_class_size(r1->class_size()); - r2->m_parents.append(r1->m_parents); merge_th_eq(r1, r2); + reinsert_parents(r1, r2); + } + + void egraph::remove_parents(enode* r1, enode* r2) { + for (enode* p : enode_parents(r1)) { + if (p->is_marked1()) + continue; + if (p->merge_enabled()) { + if (!p->is_cgr()) + continue; + SASSERT(m_table.contains_ptr(p)); + p->mark1(); + m_table.erase(p); + SASSERT(!m_table.contains_ptr(p)); + } + else if (p->is_equality()) + p->mark1(); + } + } + + void egraph::reinsert_parents(enode* r1, enode* r2) { + for (enode* p : enode_parents(r1)) { + if (!p->is_marked1()) + continue; + p->unmark1(); + if (p->merge_enabled()) { + auto rc = insert_table(p); + enode* p_other = rc.first; + SASSERT(m_table.contains_ptr(p) == (p_other == p)); + if (p_other != p) + m_to_merge.push_back(to_merge(p_other, p, rc.second)); + else + r2->m_parents.push_back(p); + } + else if (p->is_equality()) { + r2->m_parents.push_back(p); + reinsert_equality(p); + } + } } void egraph::merge_th_eq(enode* n, enode* root) { @@ -400,24 +462,45 @@ namespace euf { } } + void egraph::undo_eq(enode* r1, enode* n1, unsigned r2_num_parents) { + enode* r2 = r1->get_root(); + TRACE("euf", tout << "undo-eq old-root: " << bpp(r1) << " current-root " << bpp(r2) << " node: " << bpp(n1) << "\n";); + r2->dec_class_size(r1->class_size()); + std::swap(r1->m_next, r2->m_next); + auto begin = r2->begin_parents() + r2_num_parents, end = r2->end_parents(); + for (auto it = begin; it != end; ++it) { + enode* p = *it; + TRACE("euf", tout << "erase " << bpp(p) << "\n";); + SASSERT(!p->merge_enabled() || m_table.contains_ptr(p)); + SASSERT(!p->merge_enabled() || p->is_cgr()); + if (p->merge_enabled()) + m_table.erase(p); + } + + for (enode* c : enode_class(r1)) + c->m_root = r1; + + for (enode* p : enode_parents(r1)) + if (p->merge_enabled() && (p->is_cgr() || !p->congruent(p->m_cg))) + insert_table(p); + r2->m_parents.shrink(r2_num_parents); + unmerge_justification(n1); + } + + bool egraph::propagate() { SASSERT(m_new_lits_qhead <= m_new_lits.size()); - SASSERT(m_num_scopes == 0 || m_worklist.empty()); - unsigned head = 0, tail = m_worklist.size(); + SASSERT(m_num_scopes == 0 || m_to_merge.empty()); + unsigned head = 0, tail = m_to_merge.size(); while (head < tail && m.limit().inc() && !inconsistent()) { for (unsigned i = head; i < tail && !inconsistent(); ++i) { - enode* n = m_worklist[i]; - if (!n->is_marked1()) { - n->mark1(); - reinsert(n); - } + auto const& w = m_to_merge[i]; + merge(w.a, w.b, justification::congruence(w.commutativity)); } - for (unsigned i = head; i < tail; ++i) - m_worklist[i]->unmark1(); head = tail; - tail = m_worklist.size(); + tail = m_to_merge.size(); } - m_worklist.reset(); + m_to_merge.reset(); force_push(); return (m_new_lits_qhead < m_new_lits.size()) || @@ -565,10 +648,7 @@ namespace euf { SASSERT(a->get_root() == b->get_root()); enode* lca = find_lca(a, b); - TRACE("euf_verbose", tout << "explain-eq: " << a->get_expr_id() << " = " << b->get_expr_id() - << ": " << mk_bounded_pp(a->get_expr(), m) - << " == " << mk_bounded_pp(b->get_expr(), m) - << " lca: " << mk_bounded_pp(lca->get_expr(), m) << "\n";); + TRACE("euf_verbose", tout << "explain-eq: " << bpp(a) << " == " << bpp(b) << " lca: " << bpp(lca) << "\n";); push_to_lca(a, lca); push_to_lca(b, lca); if (m_used_eq) @@ -590,7 +670,16 @@ namespace euf { void egraph::invariant() { for (enode* n : m_nodes) - n->invariant(); + n->invariant(*this); + for (enode* n : m_nodes) + if (n->merge_enabled() && n->num_args() > 0 && (!m_table.find(n) || n->get_root() != m_table.find(n)->get_root())) { + CTRACE("euf", !m_table.find(n), tout << "node is not in table\n";); + CTRACE("euf", m_table.find(n), tout << "root " << bpp(n->get_root()) << " table root " << bpp(m_table.find(n)->get_root()) << "\n";); + TRACE("euf", display(tout << bpp(n) << " is not closed under congruence\n");); + UNREACHABLE(); + } + + } std::ostream& egraph::display(std::ostream& out, unsigned max_args, enode* n) const { @@ -653,13 +742,14 @@ namespace euf { for (unsigned i = 0; i < src.m_nodes.size(); ++i) { enode* n1 = src.m_nodes[i]; expr* e1 = src.m_exprs[i]; - SASSERT(!n1->has_th_vars()); args.reset(); for (unsigned j = 0; j < n1->num_args(); ++j) args.push_back(old_expr2new_enode[n1->get_arg(j)->get_expr_id()]); expr* e2 = tr(e1); enode* n2 = mk(e2, args.size(), args.c_ptr()); old_expr2new_enode.setx(e1->get_id(), n2, nullptr); + n2->set_value(n2->value()); + n2->m_bool_var = n1->m_bool_var; } for (unsigned i = 0; i < src.m_nodes.size(); ++i) { enode* n1 = src.m_nodes[i]; diff --git a/src/ast/euf/euf_egraph.h b/src/ast/euf/euf_egraph.h index 44b77ded3..8e1f79864 100644 --- a/src/ast/euf/euf_egraph.h +++ b/src/ast/euf/euf_egraph.h @@ -18,9 +18,11 @@ Notes: It relies on - data structures form the (legacy) SMT solver. - it still uses eager path compression. - - delayed congruence table reconstruction from egg. - - it does not deduplicate parents. + NB. The worklist is in reality inheritied from the legacy SMT solver. + It is claimed to have the same effect as delayed congruence table reconstruction from egg. + Similar to the legacy solver, parents are partially deduplicated. + --*/ #pragma once @@ -29,6 +31,7 @@ Notes: #include "util/lbool.h" #include "ast/euf/euf_enode.h" #include "ast/euf/euf_etable.h" +#include "ast/ast_ll_pp.h" namespace euf { @@ -70,7 +73,15 @@ namespace euf { }; class egraph { + typedef ptr_vector > trail_stack; + + struct to_merge { + enode* a, * b; + bool commutativity; + to_merge(enode* a, enode* b, bool c) : a(a), b(b), commutativity(c) {} + }; + struct stats { unsigned m_num_merge; unsigned m_num_th_eqs; @@ -130,7 +141,7 @@ namespace euf { tag(tag_t::is_value_assignment), r1(n), n1(nullptr), qhead(0) {} }; ast_manager& m; - enode_vector m_worklist; + svector m_to_merge; etable m_table; region m_region; svector m_updates; @@ -168,12 +179,13 @@ namespace euf { void undo_eq(enode* r1, enode* n1, unsigned r2_num_parents); void undo_add_th_var(enode* n, theory_id id); enode* mk_enode(expr* f, unsigned num_args, enode * const* args); - void reinsert(enode* n); void force_push(); void set_conflict(enode* n1, enode* n2, justification j); void merge(enode* n1, enode* n2, justification j); void merge_th_eq(enode* n, enode* root); void merge_justification(enode* n1, enode* n2, justification j); + void reinsert_parents(enode* r1, enode* r2); + void remove_parents(enode* r1, enode* r2); void unmerge_justification(enode* n1); void reinsert_equality(enode* p); void update_children(enode* n); @@ -183,6 +195,8 @@ namespace euf { void push_congruence(enode* n1, enode* n2, bool commutative); void push_todo(enode* n); + enode_bool_pair insert_table(enode* p); + template void explain_eq(ptr_vector& justifications, enode* a, enode* b, justification const& j) { if (j.is_external()) @@ -267,7 +281,15 @@ namespace euf { std::ostream& display(std::ostream& out) const { return g.display(out, 0, n); } }; e_pp pp(enode* n) const { return e_pp(*this, n); } + struct b_pp { + egraph const& g; + enode* n; + b_pp(egraph const& g, enode* n) : g(g), n(n) {} + std::ostream& display(std::ostream& out) const { return out << n->get_expr_id() << ": " << mk_bounded_pp(n->get_expr(), g.m); } + }; + b_pp bpp(enode* n) const { return b_pp(*this, n); } std::ostream& display(std::ostream& out) const; + void collect_statistics(statistics& st) const; unsigned num_scopes() const { return m_scopes.size() + m_num_scopes; } @@ -275,4 +297,5 @@ namespace euf { inline std::ostream& operator<<(std::ostream& out, egraph const& g) { return g.display(out); } inline std::ostream& operator<<(std::ostream& out, egraph::e_pp const& p) { return p.display(out); } + inline std::ostream& operator<<(std::ostream& out, egraph::b_pp const& p) { return p.display(out); } } diff --git a/src/ast/euf/euf_enode.cpp b/src/ast/euf/euf_enode.cpp index c88d34eef..0a495cb20 100644 --- a/src/ast/euf/euf_enode.cpp +++ b/src/ast/euf/euf_enode.cpp @@ -16,10 +16,11 @@ Author: --*/ #include "ast/euf/euf_enode.h" +#include "ast/euf/euf_egraph.h" namespace euf { - void enode::invariant() { + void enode::invariant(egraph& g) { unsigned class_size = 0; bool found_root = false; bool found_this = false; @@ -27,6 +28,7 @@ namespace euf { VERIFY(c->m_root == m_root); found_root |= c == m_root; found_this |= c == this; + ++class_size; } VERIFY(found_root); VERIFY(found_this); @@ -34,20 +36,26 @@ namespace euf { if (is_root()) { VERIFY(!m_target); for (enode* p : enode_parents(this)) { + if (!p->merge_enabled()) + continue; bool found = false; for (enode* arg : enode_args(p)) { found |= arg->get_root() == this; } + CTRACE("euf", !found, tout << g.bpp(p) << " does not have a child with root: " << g.bpp(this) << "\n";); VERIFY(found); } for (enode* c : enode_class(this)) { if (c == this) continue; for (enode* p : enode_parents(c)) { + if (!p->merge_enabled()) + continue; bool found = false; for (enode* q : enode_parents(this)) { found |= p->congruent(q); } + CTRACE("euf", !found, tout << "parent " << g.bpp(p) << " of " << g.bpp(c) << " is not congruent to a parent of " << g.bpp(this) << "\n";); VERIFY(found); } } @@ -118,7 +126,6 @@ namespace euf { prev->m_target = nullptr; prev->m_justification = justification::axiom(); while (curr != nullptr) { - enode* new_curr = curr->m_target; justification new_js = curr->m_justification; curr->m_target = prev; @@ -128,4 +135,11 @@ namespace euf { curr = new_curr; } } + + bool enode::children_are_roots() const { + for (auto* child : enode_args(this)) + if (!child->is_root()) + return false; + return true; + } } diff --git a/src/ast/euf/euf_enode.h b/src/ast/euf/euf_enode.h index b1beacf4e..39e56611f 100644 --- a/src/ast/euf/euf_enode.h +++ b/src/ast/euf/euf_enode.h @@ -30,6 +30,8 @@ namespace euf { typedef ptr_vector enode_vector; typedef std::pair enode_pair; typedef svector enode_pair_vector; + typedef std::pair enode_bool_pair; + typedef svector enode_bool_pair_vector; typedef id_var_list<> th_var_list; typedef int theory_var; typedef int theory_id; @@ -48,11 +50,12 @@ namespace euf { lbool m_value; unsigned m_bool_var { UINT_MAX }; unsigned m_class_size{ 1 }; - unsigned m_table_id{ UINT_MAX }; + unsigned m_table_id{ UINT_MAX }; enode_vector m_parents; enode* m_next{ nullptr }; enode* m_root{ nullptr }; enode* m_target{ nullptr }; + enode* m_cg { nullptr }; th_var_list m_th_vars; justification m_justification; unsigned m_num_args{ 0 }; @@ -102,6 +105,7 @@ namespace euf { void set_update_children() { m_update_children = true; } + friend class add_th_var_trail; friend class replace_th_var_trail; void add_th_var(theory_var v, theory_id id, region & r) { m_th_vars.add_var(v, id, r); } @@ -131,7 +135,7 @@ namespace euf { bool is_equality() const { return m_is_equality; } lbool value() const { return m_value; } unsigned bool_var() const { return m_bool_var; } - + bool is_cgr() const { return this == m_cg; } bool commutative() const { return m_commutative; } void mark_interpreted() { SASSERT(num_args() == 0); m_interpreted = true; } bool merge_enabled() { return m_merge_enabled; } @@ -172,6 +176,8 @@ namespace euf { func_decl* get_decl() const { return is_app(m_expr) ? to_app(m_expr)->get_decl() : nullptr; } unsigned get_expr_id() const { return m_expr->get_id(); } unsigned get_root_id() const { return m_root->m_expr->get_id(); } + bool children_are_roots() const; + theory_var get_th_var(theory_id id) const { return m_th_vars.find(id); } theory_var get_closest_th_var(theory_id id) const; bool is_attached_to(theory_id id) const { return get_th_var(id) != null_theory_var; } @@ -190,15 +196,15 @@ namespace euf { enode* const* begin_parents() const { return m_parents.begin(); } enode* const* end_parents() const { return m_parents.end(); } - void invariant(); + void invariant(class egraph& g); bool congruent(enode* n) const; }; class enode_args { - enode& n; + enode const& n; public: - enode_args(enode& _n):n(_n) {} - enode_args(enode* _n):n(*_n) {} + enode_args(enode const& _n):n(_n) {} + enode_args(enode const* _n):n(*_n) {} enode* const* begin() const { return n.m_args; } enode* const* end() const { return n.m_args + n.num_args(); } }; diff --git a/src/ast/euf/euf_etable.cpp b/src/ast/euf/euf_etable.cpp index 83446c024..5ddba7dd5 100644 --- a/src/ast/euf/euf_etable.cpp +++ b/src/ast/euf/euf_etable.cpp @@ -22,7 +22,7 @@ namespace euf { // one table per func_decl implementation unsigned etable::cg_hash::operator()(enode * n) const { - SASSERT(n->get_decl()->is_flat_associative() || n->num_args() >= 3); + SASSERT(decl(n)->is_flat_associative() || num_args(n) >= 3); unsigned a, b, c; a = b = 0x9e3779b9; c = 11; @@ -30,33 +30,33 @@ namespace euf { unsigned i = n->num_args(); while (i >= 3) { i--; - a += n->get_arg(i)->get_root()->hash(); + a += get_root(n, i)->hash(); i--; - b += n->get_arg(i)->get_root()->hash(); + b += get_root(n, i)->hash(); i--; - c += n->get_arg(i)->get_root()->hash(); + c += get_root(n, i)->hash(); mix(a, b, c); } switch (i) { case 2: - b += n->get_arg(1)->get_root()->hash(); + b += get_root(n, 1)->hash(); Z3_fallthrough; case 1: - c += n->get_arg(0)->get_root()->hash(); + c += get_root(n, 0)->hash(); } mix(a, b, c); return c; } bool etable::cg_eq::operator()(enode * n1, enode * n2) const { - SASSERT(n1->get_decl() == n2->get_decl()); - unsigned num = n1->num_args(); - if (num != n2->num_args()) { + SASSERT(decl(n1) == decl(n2)); + unsigned num = num_args(n1); + if (num != num_args(n2)) { return false; } for (unsigned i = 0; i < num; i++) - if (n1->get_arg(i)->get_root() != n2->get_arg(i)->get_root()) + if (get_root(n1, i) != get_root(n2, i)) return false; return true; } @@ -69,31 +69,25 @@ namespace euf { reset(); } - void * etable::mk_table_for(func_decl * d) { + void * etable::mk_table_for(unsigned arity, func_decl * d) { void * r; SASSERT(d->get_arity() >= 1); - switch (d->get_arity()) { + SASSERT(arity >= d->get_arity()); + switch (arity) { case 1: r = TAG(void*, alloc(unary_table), UNARY); SASSERT(GET_TAG(r) == UNARY); return r; case 2: - if (d->is_flat_associative()) { - // applications of declarations that are flat-assoc (e.g., +) may have many arguments. - r = TAG(void*, alloc(table), NARY); - SASSERT(GET_TAG(r) == NARY); - return r; - } - else if (d->is_commutative()) { + if (d->is_commutative()) { r = TAG(void*, alloc(comm_table, cg_comm_hash(), cg_comm_eq(m_commutativity)), BINARY_COMM); SASSERT(GET_TAG(r) == BINARY_COMM); - return r; } else { r = TAG(void*, alloc(binary_table), BINARY); SASSERT(GET_TAG(r) == BINARY); - return r; } + return r; default: r = TAG(void*, alloc(table), NARY); SASSERT(GET_TAG(r) == NARY); @@ -104,18 +98,20 @@ namespace euf { unsigned etable::set_table_id(enode * n) { func_decl * f = n->get_decl(); unsigned tid; - if (!m_func_decl2id.find(f, tid)) { + decl_info d(f, n->num_args()); + if (!m_func_decl2id.find(d, tid)) { tid = m_tables.size(); - m_func_decl2id.insert(f, tid); + m_func_decl2id.insert(d, tid); m_manager.inc_ref(f); SASSERT(tid <= m_tables.size()); - m_tables.push_back(mk_table_for(f)); + m_tables.push_back(mk_table_for(n->num_args(), f)); } SASSERT(tid < m_tables.size()); n->set_table_id(tid); DEBUG_CODE({ - unsigned tid_prime; - SASSERT(m_func_decl2id.find(n->get_decl(), tid_prime) && tid == tid_prime); + decl_info d(n->get_decl(), n->num_args()); + SASSERT(m_func_decl2id.contains(d)); + SASSERT(m_func_decl2id[d] == tid); }); return tid; } @@ -139,7 +135,7 @@ namespace euf { } m_tables.reset(); for (auto const& kv : m_func_decl2id) { - m_manager.dec_ref(kv.m_key); + m_manager.dec_ref(kv.m_key.first); } m_func_decl2id.reset(); } @@ -147,7 +143,7 @@ namespace euf { void etable::display(std::ostream & out) const { for (auto const& kv : m_func_decl2id) { void * t = m_tables[kv.m_value]; - out << mk_pp(kv.m_key, m_manager) << ": "; + out << mk_pp(kv.m_key.first, m_manager) << ": "; switch (GET_TAG(t)) { case UNARY: display_unary(out, t); @@ -245,5 +241,40 @@ namespace euf { } } + bool etable::contains(enode* n) const { + SASSERT(n->num_args() > 0); + void* t = const_cast(this)->get_table(n); + switch (static_cast(GET_TAG(t))) { + case UNARY: + return UNTAG(unary_table*, t)->contains(n); + case BINARY: + return UNTAG(binary_table*, t)->contains(n); + case BINARY_COMM: + return UNTAG(comm_table*, t)->contains(n); + default: + return UNTAG(table*, t)->contains(n); + } + } + + enode* etable::find(enode* n) const { + SASSERT(n->num_args() > 0); + enode* r = nullptr; + void* t = const_cast(this)->get_table(n); + switch (static_cast(GET_TAG(t))) { + case UNARY: + return UNTAG(unary_table*, t)->find(n, r) ? r : nullptr; + case BINARY: + return UNTAG(binary_table*, t)->find(n, r) ? r : nullptr; + case BINARY_COMM: + return UNTAG(comm_table*, t)->find(n, r) ? r : nullptr; + default: + return UNTAG(table*, t)->find(n, r) ? r : nullptr; + } + } + + bool etable::contains_ptr(enode* n) const { + return find(n) == n; + } + }; diff --git a/src/ast/euf/euf_etable.h b/src/ast/euf/euf_etable.h index 5981f7aec..8ec27a814 100644 --- a/src/ast/euf/euf_etable.h +++ b/src/ast/euf/euf_etable.h @@ -1,5 +1,4 @@ -/*++ -Copyright (c) 2006 Microsoft Corporation +/*++ Copyright (c) 2006 Microsoft Corporation Module Name: @@ -22,28 +21,33 @@ Revision History: #include "util/chashtable.h" namespace euf { - - typedef std::pair enode_bool_pair; // one table per function symbol + static unsigned num_args(enode* n) { return n->num_args(); } + static func_decl* decl(enode* n) { return n->get_decl(); } + + /** \brief Congruence table. */ class etable { + static enode* get_root(enode* n, unsigned idx) { return n->get_arg(idx)->get_root(); } + struct cg_unary_hash { unsigned operator()(enode * n) const { - SASSERT(n->num_args() == 1); - return n->get_arg(0)->get_root()->hash(); + SASSERT(num_args(n) == 1); + return get_root(n, 0)->hash(); } }; struct cg_unary_eq { + bool operator()(enode * n1, enode * n2) const { - SASSERT(n1->num_args() == 1); - SASSERT(n2->num_args() == 1); - SASSERT(n1->get_decl() == n2->get_decl()); - return n1->get_arg(0)->get_root() == n2->get_arg(0)->get_root(); + SASSERT(num_args(n1) == 1); + SASSERT(num_args(n2) == 1); + SASSERT(decl(n1) == decl(n2)); + return get_root(n1, 0) == get_root(n2, 0); } }; @@ -51,19 +55,19 @@ namespace euf { struct cg_binary_hash { unsigned operator()(enode * n) const { - SASSERT(n->num_args() == 2); - return combine_hash(n->get_arg(0)->get_root()->hash(), n->get_arg(1)->get_root()->hash()); + SASSERT(num_args(n) == 2); + return combine_hash(get_root(n, 0)->hash(), get_root(n, 1)->hash()); } }; struct cg_binary_eq { bool operator()(enode * n1, enode * n2) const { - SASSERT(n1->num_args() == 2); - SASSERT(n2->num_args() == 2); - SASSERT(n1->get_decl() == n2->get_decl()); - return - n1->get_arg(0)->get_root() == n2->get_arg(0)->get_root() && - n1->get_arg(1)->get_root() == n2->get_arg(1)->get_root(); + SASSERT(num_args(n1) == 2); + SASSERT(num_args(n2) == 2); + SASSERT(decl(n1) == decl(n2)); + return + get_root(n1, 0) == get_root(n2, 0) && + get_root(n1, 1) == get_root(n2, 1); } }; @@ -71,9 +75,9 @@ namespace euf { struct cg_comm_hash { unsigned operator()(enode * n) const { - SASSERT(n->num_args() == 2); - unsigned h1 = n->get_arg(0)->get_root()->hash(); - unsigned h2 = n->get_arg(1)->get_root()->hash(); + SASSERT(num_args(n) == 2); + unsigned h1 = get_root(n, 0)->hash(); + unsigned h2 = get_root(n, 1)->hash(); if (h1 > h2) std::swap(h1, h2); return hash_u((h1 << 16) | (h2 & 0xFFFF)); @@ -82,15 +86,16 @@ namespace euf { struct cg_comm_eq { bool & m_commutativity; - cg_comm_eq(bool & c):m_commutativity(c) {} + cg_comm_eq( bool & c): m_commutativity(c) {} bool operator()(enode * n1, enode * n2) const { - SASSERT(n1->num_args() == 2); - SASSERT(n2->num_args() == 2); - SASSERT(n1->get_decl() == n2->get_decl()); - enode * c1_1 = n1->get_arg(0)->get_root(); - enode * c1_2 = n1->get_arg(1)->get_root(); - enode * c2_1 = n2->get_arg(0)->get_root(); - enode * c2_2 = n2->get_arg(1)->get_root(); + SASSERT(num_args(n1) == 2); + SASSERT(num_args(n2) == 2); + + SASSERT(decl(n1) == decl(n2)); + enode* c1_1 = get_root(n1, 0); + enode* c1_2 = get_root(n1, 1); + enode* c2_1 = get_root(n2, 0); + enode* c2_2 = get_root(n2, 1); if (c1_1 == c2_1 && c1_2 == c2_2) { return true; } @@ -113,11 +118,19 @@ namespace euf { }; typedef chashtable table; + typedef std::pair decl_info; + struct decl_hash { + unsigned operator()(decl_info const& d) const { return d.first->hash(); } + }; + struct decl_eq { + bool operator()(decl_info const& a, decl_info const& b) const { return a == b; } + }; + ast_manager & m_manager; - bool m_commutativity; //!< true if the last found congruence used commutativity + bool m_commutativity{ false }; //!< true if the last found congruence used commutativity ptr_vector m_tables; - obj_map m_func_decl2id; + map m_func_decl2id; enum table_kind { UNARY, @@ -126,7 +139,7 @@ namespace euf { NARY }; - void * mk_table_for(func_decl * d); + void * mk_table_for(unsigned n, func_decl * d); unsigned set_table_id(enode * n); void * get_table(enode * n) { @@ -157,52 +170,11 @@ namespace euf { void erase(enode * n); - bool contains(enode * n) const { - SASSERT(n->num_args() > 0); - void * t = const_cast(this)->get_table(n); - switch (static_cast(GET_TAG(t))) { - case UNARY: - return UNTAG(unary_table*, t)->contains(n); - case BINARY: - return UNTAG(binary_table*, t)->contains(n); - case BINARY_COMM: - return UNTAG(comm_table*, t)->contains(n); - default: - return UNTAG(table*, t)->contains(n); - } - } + bool contains(enode* n) const; - enode * find(enode * n) const { - SASSERT(n->num_args() > 0); - enode * r = nullptr; - void * t = const_cast(this)->get_table(n); - switch (static_cast(GET_TAG(t))) { - case UNARY: - return UNTAG(unary_table*, t)->find(n, r) ? r : nullptr; - case BINARY: - return UNTAG(binary_table*, t)->find(n, r) ? r : nullptr; - case BINARY_COMM: - return UNTAG(comm_table*, t)->find(n, r) ? r : nullptr; - default: - return UNTAG(table*, t)->find(n, r) ? r : nullptr; - } - } + enode* find(enode* n) const; - bool contains_ptr(enode * n) const { - enode * r; - SASSERT(n->num_args() > 0); - void * t = const_cast(this)->get_table(n); - switch (static_cast(GET_TAG(t))) { - case UNARY: - return UNTAG(unary_table*, t)->find(n, r) && n == r; - case BINARY: - return UNTAG(binary_table*, t)->find(n, r) && n == r; - case BINARY_COMM: - return UNTAG(comm_table*, t)->find(n, r) && n == r; - default: - return UNTAG(table*, t)->find(n, r) && n == r; - } - } + bool contains_ptr(enode* n) const; void reset(); diff --git a/src/ast/rewriter/bv_rewriter.cpp b/src/ast/rewriter/bv_rewriter.cpp index 3e2b2c614..a0115cc9b 100644 --- a/src/ast/rewriter/bv_rewriter.cpp +++ b/src/ast/rewriter/bv_rewriter.cpp @@ -1012,7 +1012,7 @@ br_status bv_rewriter::mk_bv_sdiv_core(expr * arg1, expr * arg2, bool hi_div0, e r2 = m_util.norm(r2, bv_size, true); if (r2.is_zero()) { if (!hi_div0) { - result = m().mk_app(get_fid(), OP_BSDIV0, arg1); + result = m_util.mk_bv_sdiv0(arg1); return BR_REWRITE1; } else { @@ -1035,19 +1035,19 @@ br_status bv_rewriter::mk_bv_sdiv_core(expr * arg1, expr * arg2, bool hi_div0, e return BR_DONE; } - result = m().mk_app(get_fid(), OP_BSDIV_I, arg1, arg2); + result = m_util.mk_bv_sdiv_i(arg1, arg2); return BR_DONE; } if (hi_div0) { - result = m().mk_app(get_fid(), OP_BSDIV_I, arg1, arg2); + result = m_util.mk_bv_sdiv_i(arg1, arg2); return BR_DONE; } bv_size = get_bv_size(arg2); result = m().mk_ite(m().mk_eq(arg2, mk_numeral(0, bv_size)), - m().mk_app(get_fid(), OP_BSDIV0, arg1), - m().mk_app(get_fid(), OP_BSDIV_I, arg1, arg2)); + m_util.mk_bv_sdiv0(arg1), + m_util.mk_bv_sdiv_i(arg1, arg2)); return BR_REWRITE2; } @@ -1061,7 +1061,7 @@ br_status bv_rewriter::mk_bv_udiv_core(expr * arg1, expr * arg2, bool hi_div0, e r2 = m_util.norm(r2, bv_size); if (r2.is_zero()) { if (!hi_div0) { - result = m().mk_app(get_fid(), OP_BUDIV0, arg1); + result = m_util.mk_bv_udiv0(arg1); return BR_REWRITE1; } else { @@ -1090,19 +1090,19 @@ br_status bv_rewriter::mk_bv_udiv_core(expr * arg1, expr * arg2, bool hi_div0, e } - result = m().mk_app(get_fid(), OP_BUDIV_I, arg1, arg2); + result = m_util.mk_bv_udiv_i(arg1, arg2); return BR_DONE; } if (hi_div0) { - result = m().mk_app(get_fid(), OP_BUDIV_I, arg1, arg2); + result = m_util.mk_bv_udiv_i(arg1, arg2); return BR_DONE; } bv_size = get_bv_size(arg2); result = m().mk_ite(m().mk_eq(arg2, mk_numeral(0, bv_size)), - m().mk_app(get_fid(), OP_BUDIV0, arg1), - m().mk_app(get_fid(), OP_BUDIV_I, arg1, arg2)); + m_util.mk_bv_udiv0(arg1), + m_util.mk_bv_udiv_i(arg1, arg2)); TRACE("bv_udiv", tout << mk_ismt2_pp(arg1, m()) << "\n" << mk_ismt2_pp(arg2, m()) << "\n---->\n" << mk_ismt2_pp(result, m()) << "\n";); return BR_REWRITE2; @@ -1201,7 +1201,7 @@ br_status bv_rewriter::mk_bv_urem_core(expr * arg1, expr * arg2, bool hi_div0, e r2 = m_util.norm(r2, bv_size); if (r2.is_zero()) { if (!hi_div0) { - result = m().mk_app(get_fid(), OP_BUREM0, arg1); + result = m_util.mk_bv_urem0(arg1); return BR_REWRITE1; } else { @@ -1233,7 +1233,7 @@ br_status bv_rewriter::mk_bv_urem_core(expr * arg1, expr * arg2, bool hi_div0, e return BR_REWRITE2; } - result = m().mk_app(get_fid(), OP_BUREM_I, arg1, arg2); + result = m_util.mk_bv_urem_i(arg1, arg2); return BR_DONE; } @@ -1242,7 +1242,7 @@ br_status bv_rewriter::mk_bv_urem_core(expr * arg1, expr * arg2, bool hi_div0, e if (is_num1 && r1.is_zero()) { expr * zero = arg1; result = m().mk_ite(m().mk_eq(arg2, zero), - m().mk_app(get_fid(), OP_BUREM0, zero), + m_util.mk_bv_urem0(zero), zero); return BR_REWRITE2; } @@ -1254,7 +1254,7 @@ br_status bv_rewriter::mk_bv_urem_core(expr * arg1, expr * arg2, bool hi_div0, e expr * x_minus_1 = arg1; expr * minus_one = mk_numeral(rational::power_of_two(bv_size) - numeral(1), bv_size); result = m().mk_ite(m().mk_eq(x, mk_numeral(0, bv_size)), - m().mk_app(get_fid(), OP_BUREM0, minus_one), + m_util.mk_bv_urem0(minus_one), x_minus_1); return BR_REWRITE2; } @@ -1278,14 +1278,14 @@ br_status bv_rewriter::mk_bv_urem_core(expr * arg1, expr * arg2, bool hi_div0, e } if (hi_div0) { - result = m().mk_app(get_fid(), OP_BUREM_I, arg1, arg2); + result = m_util.mk_bv_urem_i(arg1, arg2); return BR_DONE; } bv_size = get_bv_size(arg2); result = m().mk_ite(m().mk_eq(arg2, mk_numeral(0, bv_size)), - m().mk_app(get_fid(), OP_BUREM0, arg1), - m().mk_app(get_fid(), OP_BUREM_I, arg1, arg2)); + m_util.mk_bv_urem0(arg1), + m_util.mk_bv_urem_i(arg1, arg2)); return BR_REWRITE2; } @@ -1297,7 +1297,7 @@ br_status bv_rewriter::mk_bv_smod_core(expr * arg1, expr * arg2, bool hi_div0, e if (is_num1) { r1 = m_util.norm(r1, bv_size, true); if (r1.is_zero()) { - result = m().mk_app(get_fid(), OP_BUREM, arg1, arg2); + result = m_util.mk_bv_urem(arg1, arg2); return BR_REWRITE1; } } @@ -1306,7 +1306,7 @@ br_status bv_rewriter::mk_bv_smod_core(expr * arg1, expr * arg2, bool hi_div0, e r2 = m_util.norm(r2, bv_size, true); if (r2.is_zero()) { if (!hi_div0) - result = m().mk_app(get_fid(), OP_BSMOD0, arg1); + result = m_util.mk_bv_smod0(arg1); else result = arg1; return BR_DONE; diff --git a/src/math/lp/core_solver_pretty_printer_def.h b/src/math/lp/core_solver_pretty_printer_def.h index 63520f014..8980c1c28 100644 --- a/src/math/lp/core_solver_pretty_printer_def.h +++ b/src/math/lp/core_solver_pretty_printer_def.h @@ -101,8 +101,16 @@ template void core_solver_pretty_printer::init_m_ for (const auto & c : m_core_solver.m_A.m_columns[column]){ t[c.var()] = m_core_solver.m_A.get_val(c); } - - string name = m_core_solver.column_name(column); + + auto const& value = m_core_solver.get_var_value(column); + + if (m_core_solver.column_is_fixed(column) && is_zero(value)) + continue; + string name; + if (m_core_solver.column_is_fixed(column)) + name = "*" + T_to_string(value); + else + name = m_core_solver.column_name(column); for (unsigned row = 0; row < nrows(); row ++) { m_A[row].resize(ncols(), ""); m_signs[row].resize(ncols(),""); diff --git a/src/math/lp/lp_core_solver_base.h b/src/math/lp/lp_core_solver_base.h index d911daa36..1bd8fe2df 100644 --- a/src/math/lp/lp_core_solver_base.h +++ b/src/math/lp/lp_core_solver_base.h @@ -191,7 +191,7 @@ public: void add_delta_to_entering(unsigned entering, const X & delta); - const T & get_var_value(unsigned j) const { + const X & get_var_value(unsigned j) const { return m_x[j]; } @@ -618,9 +618,9 @@ public: return out; } - bool column_is_free(unsigned j) const { return this->m_column_type[j] == column_type::free_column; } + bool column_is_free(unsigned j) const { return this->m_column_types[j] == column_type::free_column; } - bool column_is_fixed(unsigned j) const { return this->m_column_type[j] == column_type::fixed; } + bool column_is_fixed(unsigned j) const { return this->m_column_types[j] == column_type::fixed; } bool column_has_upper_bound(unsigned j) const { diff --git a/src/muz/spacer/spacer_iuc_solver.h b/src/muz/spacer/spacer_iuc_solver.h index 409ca3049..a153e6af6 100644 --- a/src/muz/spacer/spacer_iuc_solver.h +++ b/src/muz/spacer/spacer_iuc_solver.h @@ -123,9 +123,6 @@ public: expr_ref_vector cube(expr_ref_vector&, unsigned) override { return expr_ref_vector(m); } void get_levels(ptr_vector const& vars, unsigned_vector& depth) override { m_solver.get_levels(vars, depth); } expr_ref_vector get_trail() override { return m_solver.get_trail(); } - expr_ref get_implied_value(expr* e) override { return m_solver.get_implied_value(e); } - expr_ref get_implied_lower_bound(expr* e) override { return m_solver.get_implied_lower_bound(e); } - expr_ref get_implied_upper_bound(expr* e) override { return m_solver.get_implied_upper_bound(e); } void push() override; void pop(unsigned n) override; diff --git a/src/opt/opt_solver.h b/src/opt/opt_solver.h index a7951a50b..bdef80765 100644 --- a/src/opt/opt_solver.h +++ b/src/opt/opt_solver.h @@ -110,9 +110,6 @@ namespace opt { void get_levels(ptr_vector const& vars, unsigned_vector& depth) override; expr_ref_vector get_trail() override { return m_context.get_trail(); } expr_ref_vector cube(expr_ref_vector&, unsigned) override { return expr_ref_vector(m); } - expr_ref get_implied_value(expr* e) override { return m_context.get_implied_value(e); } - expr_ref get_implied_lower_bound(expr* e) override { return m_context.get_implied_lower_bound(e); } - expr_ref get_implied_upper_bound(expr* e) override { return m_context.get_implied_upper_bound(e); } void set_logic(symbol const& logic); diff --git a/src/sat/sat_solver.cpp b/src/sat/sat_solver.cpp index 9ce23799e..b22debbc6 100644 --- a/src/sat/sat_solver.cpp +++ b/src/sat/sat_solver.cpp @@ -940,11 +940,6 @@ namespace sat { m_phase[v] = !l.sign(); m_assigned_since_gc[v] = true; m_trail.push_back(l); - - if (m_ext && m_external[v] && (!is_probing() || at_base_lvl())) - m_ext->asserted(l); -// else -// std::cout << "assert " << l << "\n"; switch (m_config.m_branching_heuristic) { case BH_VSIDS: @@ -1042,7 +1037,7 @@ namespace sat { lbool val1, val2; bool keep; unsigned curr_level = lvl(l); - TRACE("sat_propagate", tout << "propagating: " << l << " " << m_justification[l.var()] << "\n"; ); + TRACE("sat_propagate", tout << "propagating: " << l << "@" << curr_level << " " << m_justification[l.var()] << "\n"; ); literal not_l = ~l; SASSERT(value(l) == l_true); @@ -1204,6 +1199,9 @@ namespace sat { } } wlist.set_end(it2); + if (m_ext && m_external[l.var()] && (!is_probing() || at_base_lvl())) + m_ext->asserted(l); + return true; } @@ -3575,6 +3573,7 @@ namespace sat { m_trail.shrink(old_sz); m_qhead = m_trail.size(); if (!m_replay_assign.empty()) IF_VERBOSE(20, verbose_stream() << "replay assign: " << m_replay_assign.size() << "\n"); + CTRACE("sat", !m_replay_assign.empty(), tout << "replay-assign: " << m_replay_assign << "\n";); for (unsigned i = m_replay_assign.size(); i-- > 0; ) { literal lit = m_replay_assign[i]; m_trail.push_back(lit); diff --git a/src/sat/sat_solver.h b/src/sat/sat_solver.h index 2fb281554..c8b7e11f2 100644 --- a/src/sat/sat_solver.h +++ b/src/sat/sat_solver.h @@ -371,9 +371,13 @@ namespace sat { switch (value(l)) { case l_false: set_conflict(j, ~l); break; case l_undef: assign_core(l, j); break; - case l_true: return; + case l_true: update_assign(l, j); break; } } + void update_assign(literal l, justification j) { + if (lvl(l) > j.level()) + m_justification[l.var()] = j; + } void assign_unit(literal l) { assign(l, justification(0)); } void assign_scoped(literal l) { assign(l, justification(scope_lvl())); } void assign_core(literal l, justification jst); diff --git a/src/sat/sat_solver/inc_sat_solver.cpp b/src/sat/sat_solver/inc_sat_solver.cpp index e507c9a5f..57e968b45 100644 --- a/src/sat/sat_solver/inc_sat_solver.cpp +++ b/src/sat/sat_solver/inc_sat_solver.cpp @@ -123,7 +123,7 @@ public: ast_translation tr(m, dst_m); m_solver.pop_to_base_level(); inc_sat_solver* result = alloc(inc_sat_solver, dst_m, p, is_incremental()); - auto* ext = dynamic_cast(m_solver.get_extension()); + auto* ext = get_euf(); if (ext) { auto& si = result->m_goal2sat.si(dst_m, m_params, result->m_solver, result->m_map, result->m_dep2asm, is_incremental()); euf::solver::scoped_set_translate st(*ext, dst_m, si); @@ -258,6 +258,8 @@ public: void push_internal() { m_solver.user_push(); + if (get_euf()) + get_euf()->user_push(); ++m_num_scopes; m_mcs.push_back(m_mcs.back()); m_fmls_lim.push_back(m_fmls.size()); @@ -280,6 +282,8 @@ public: m_num_scopes -= n; // ? m_internalized_converted = false; m_has_uninterpreted.pop(n); + if (get_euf()) + get_euf()->user_pop(n); while (n > 0) { m_mcs.pop_back(); m_fmls_head = m_fmls_head_lim.back(); @@ -337,6 +341,11 @@ public: m_params.set_sym("pb.solver", p1.pb_solver()); m_solver.updt_params(m_params); m_solver.set_incremental(is_incremental() && !override_incremental()); + if (p1.euf() && !get_euf()) { + ensure_euf(); + for (unsigned i = 0; i < m_num_scopes; ++i) + get_euf()->user_push(); + } } void collect_statistics(statistics & st) const override { @@ -374,19 +383,6 @@ public: return nullptr; } - // TODO - expr_ref get_implied_value(expr* e) override { - return expr_ref(e, m); - } - - expr_ref get_implied_lower_bound(expr* e) override { - return expr_ref(e, m); - } - - expr_ref get_implied_upper_bound(expr* e) override { - return expr_ref(e, m); - } - expr_ref_vector last_cube(bool is_sat) { expr_ref_vector result(m); result.push_back(is_sat ? m.mk_true() : m.mk_false()); @@ -624,6 +620,10 @@ public: m_preprocess->reset(); } + euf::solver* get_euf() { + return dynamic_cast(m_solver.get_extension()); + } + euf::solver* ensure_euf() { auto* ext = dynamic_cast(m_solver.get_extension()); return ext; diff --git a/src/sat/smt/array_axioms.cpp b/src/sat/smt/array_axioms.cpp index 5c48b22cf..dce3b5249 100644 --- a/src/sat/smt/array_axioms.cpp +++ b/src/sat/smt/array_axioms.cpp @@ -31,57 +31,27 @@ namespace array { ctx.push(push_back_vector>(m_axiom_trail)); } - bool solver::assert_axiom(unsigned idx) { - axiom_record const& r = m_axiom_trail[idx]; + bool solver::propagate_axiom(unsigned idx) { if (m_axioms.contains(idx)) return false; m_axioms.insert(idx); ctx.push(insert_map(m_axioms, idx)); - expr* child = r.n->get_expr(); - app* select; + return assert_axiom(idx); + } + + bool solver::assert_axiom(unsigned idx) { + axiom_record& r = m_axiom_trail[idx]; switch (r.m_kind) { case axiom_record::kind_t::is_store: - TRACE("array", tout << "store-axiom: " << mk_bounded_pp(child, m, 2) << "\n";); - return assert_store_axiom(to_app(child)); + return assert_store_axiom(to_app(r.n->get_expr())); case axiom_record::kind_t::is_select: - select = r.select->get_app(); - SASSERT(a.is_select(select)); - SASSERT(can_beta_reduce(r.n)); - TRACE("array", tout << "select-axiom: " << mk_bounded_pp(select, m, 2) << " " << mk_bounded_pp(child, m, 2) << "\n";); - if (r.select->get_arg(0)->get_root() != r.n->get_root()) { - IF_VERBOSE(0, verbose_stream() << "could delay " << mk_pp(select, m) << " " << mk_pp(child, m) << "\n"); - } - if (a.is_const(child)) - return assert_select_const_axiom(select, to_app(child)); - else if (a.is_as_array(child)) - return assert_select_as_array_axiom(select, to_app(child)); - else if (a.is_store(child)) - return assert_select_store_axiom(select, to_app(child)); - else if (a.is_map(child)) - return assert_select_map_axiom(select, to_app(child)); - else if (is_lambda(child)) - return assert_select_lambda_axiom(select, child); - else - UNREACHABLE(); - break; + return assert_select(idx, r); case axiom_record::kind_t::is_default: - SASSERT(can_beta_reduce(r.n)); - TRACE("array", tout << "default-axiom: " << mk_bounded_pp(child, m, 2) << "\n";); - if (a.is_const(child)) - return assert_default_const_axiom(to_app(child)); - else if (a.is_store(child)) - return assert_default_store_axiom(to_app(child)); - else if (a.is_map(child)) - return assert_default_map_axiom(to_app(child)); - else - return true; - break; + return assert_default(r); case axiom_record::kind_t::is_extensionality: - TRACE("array", tout << "extensionality-axiom: " << mk_bounded_pp(child, m, 2) << "\n";); return assert_extensionality(r.n->get_arg(0)->get_expr(), r.n->get_arg(1)->get_expr()); case axiom_record::kind_t::is_congruence: - TRACE("array", tout << "congruence-axiom: " << mk_bounded_pp(child, m, 2) << " " << mk_bounded_pp(r.select->get_expr(), m, 2) << "\n";); - return assert_congruent_axiom(child, r.select->get_expr()); + return assert_congruent_axiom(r.n->get_expr(), r.select->get_expr()); default: UNREACHABLE(); break; @@ -89,6 +59,68 @@ namespace array { return false; } + bool solver::assert_default(axiom_record& r) { + expr* child = r.n->get_expr(); + SASSERT(can_beta_reduce(r.n)); + if (!ctx.is_relevant(child)) + return false; + TRACE("array", tout << "default-axiom: " << mk_bounded_pp(child, m, 2) << "\n";); + if (a.is_const(child)) + return assert_default_const_axiom(to_app(child)); + else if (a.is_store(child)) + return assert_default_store_axiom(to_app(child)); + else if (a.is_map(child)) + return assert_default_map_axiom(to_app(child)); + else + return false; + } + + struct solver::set_delay_bit : trail { + solver& s; + unsigned m_idx; + set_delay_bit(solver& s, unsigned idx) : s(s), m_idx(idx) {} + void undo(euf::solver& euf) override { + s.m_axiom_trail[m_idx].m_delayed = false; + } + }; + + bool solver::assert_select(unsigned idx, axiom_record& r) { + expr* child = r.n->get_expr(); + app* select = r.select->get_app(); + SASSERT(a.is_select(select)); + SASSERT(can_beta_reduce(r.n)); + //std::cout << mk_bounded_pp(child, m) << " " << ctx.is_relevant(child) << " " << mk_bounded_pp(select, m) << "\n"; + if (!ctx.is_relevant(child)) + return false; + for (unsigned i = 1; i < select->get_num_args(); ++i) + if (!ctx.is_relevant(select->get_arg(i))) + return false; + TRACE("array", tout << "select-axiom: " << mk_bounded_pp(select, m, 2) << " " << mk_bounded_pp(child, m, 2) << "\n";); +// if (r.select->get_arg(0)->get_root() != r.n->get_root()) +// std::cout << "delayed: " << r.m_delayed << "\n"; + if (get_config().m_array_delay_exp_axiom && r.select->get_arg(0)->get_root() != r.n->get_root() && !r.m_delayed) { + IF_VERBOSE(11, verbose_stream() << "delay: " << mk_bounded_pp(child, m) << " " << mk_bounded_pp(select, m) << "\n"); + ctx.push(set_delay_bit(*this, idx)); + r.m_delayed = true; + return false; + } + if (r.select->get_arg(0)->get_root() != r.n->get_root() && r.m_delayed) + return false; + if (a.is_const(child)) + return assert_select_const_axiom(select, to_app(child)); + else if (a.is_as_array(child)) + return assert_select_as_array_axiom(select, to_app(child)); + else if (a.is_store(child)) + return assert_select_store_axiom(select, to_app(child)); + else if (a.is_map(child)) + return assert_select_map_axiom(select, to_app(child)); + else if (is_lambda(child)) + return assert_select_lambda_axiom(select, child); + else + UNREACHABLE(); + return false; + } + /** * Assert * select(n, i) = v @@ -96,6 +128,7 @@ namespace array { * n := store(a, i, v) */ bool solver::assert_store_axiom(app* e) { + TRACE("array", tout << "store-axiom: " << mk_bounded_pp(e, m) << "\n";); ++m_stats.m_num_store_axiom; SASSERT(a.is_store(e)); unsigned num_args = e->get_num_args(); @@ -182,6 +215,7 @@ namespace array { * e1 = e2 or select(e1, diff(e1,e2)) != select(e2, diff(e1, e2)) */ bool solver::assert_extensionality(expr* e1, expr* e2) { + TRACE("array", tout << "extensionality-axiom: " << mk_bounded_pp(e1, m) << " == " << mk_bounded_pp(e2, m) << "\n";); ++m_stats.m_num_extensionality_axiom; func_decl_ref_vector* funcs = nullptr; VERIFY(m_sort2diff.find(m.get_sort(e1), funcs)); @@ -289,7 +323,6 @@ namespace array { return ctx.propagate(expr2enode(val), e_internalize(def), array_axiom()); } - /** * let n := store(a, i, v) * Assert: @@ -368,6 +401,7 @@ namespace array { \brief assert n1 = n2 => forall vars . (n1 vars) = (n2 vars) */ bool solver::assert_congruent_axiom(expr* e1, expr* e2) { + TRACE("array", tout << "congruence-axiom: " << mk_bounded_pp(e1, m) << " " << mk_bounded_pp(e2, m) << "\n";); ++m_stats.m_num_congruence_axiom; sort* srt = m.get_sort(e1); unsigned dimension = get_array_arity(srt); @@ -446,10 +480,24 @@ namespace array { for (unsigned v = 0; v < num_vars; v++) { propagate_parent_select_axioms(v); auto& d = get_var_data(v); - if (d.m_prop_upward) + if (!d.m_prop_upward) + continue; + euf::enode* n = var2enode(v); + bool has_default = false; + for (euf::enode* p : euf::enode_parents(n)) + has_default |= a.is_default(p->get_expr()); + if (has_default) propagate_parent_default(v); } - return unit_propagate(); + bool change = false; + unsigned sz = m_axiom_trail.size(); + m_delay_qhead = 0; + for (; m_delay_qhead < sz; ++m_delay_qhead) + if (m_axiom_trail[m_delay_qhead].m_delayed && assert_axiom(m_delay_qhead)) + change = true; + if (unit_propagate()) + change = true; + return change; } bool solver::add_interface_equalities() { @@ -466,9 +514,11 @@ namespace array { continue; if (have_different_model_values(v1, v2)) continue; - expr_ref eq(m.mk_eq(e1, e2), m); + if (ctx.get_egraph().are_diseq(var2enode(v1), var2enode(v2))) + continue; + expr_ref eq(m.mk_eq(e1, e2), m); sat::literal lit = b_internalize(eq); - if (s().value(lit) == l_undef) + if (s().value(lit) == l_undef) prop = true; } } diff --git a/src/sat/smt/array_internalize.cpp b/src/sat/smt/array_internalize.cpp index c0d40727d..f612be3c5 100644 --- a/src/sat/smt/array_internalize.cpp +++ b/src/sat/smt/array_internalize.cpp @@ -71,7 +71,8 @@ namespace array { void solver::internalize_lambda(euf::enode* n) { set_prop_upward(n); - push_axiom(default_axiom(n)); + if (!a.is_store(n->get_expr())) + push_axiom(default_axiom(n)); add_lambda(n->get_th_var(get_id()), n); } @@ -150,5 +151,46 @@ namespace array { return true; } + /** + \brief Return true if v is shared between two different "instances" of the array theory. + It is shared if it is used in more than one role. The possible roles are: array, index, and value. + Example: + (store v i j) <--- v is used as an array + (select A v) <--- v is used as an index + (store A i v) <--- v is used as an value + */ + bool solver::is_shared(theory_var v) const { + euf::enode* n = var2enode(v); + euf::enode* r = n->get_root(); + bool is_array = false; + bool is_index = false; + bool is_value = false; + auto set_array = [&](euf::enode* arg) { if (arg->get_root() == r) is_array = true; }; + auto set_index = [&](euf::enode* arg) { if (arg->get_root() == r) is_index = true; }; + auto set_value = [&](euf::enode* arg) { if (arg->get_root() == r) is_value = true; }; + + for (euf::enode* parent : euf::enode_parents(r)) { + app* p = parent->get_app(); + unsigned num_args = parent->num_args(); + if (a.is_store(p)) { + set_array(parent->get_arg(0)); + for (unsigned i = 1; i < num_args - 1; i++) + set_index(parent->get_arg(i)); + set_value(parent->get_arg(num_args - 1)); + } + else if (a.is_select(p)) { + set_array(parent->get_arg(0)); + for (unsigned i = 1; i < num_args - 1; i++) + set_index(parent->get_arg(i)); + } + else if (a.is_const(p)) { + set_value(parent->get_arg(0)); + } + if (is_array + is_index + is_value > 1) + return true; + } + return false; + } + } diff --git a/src/sat/smt/array_model.cpp b/src/sat/smt/array_model.cpp index 5b56a0476..a722c459e 100644 --- a/src/sat/smt/array_model.cpp +++ b/src/sat/smt/array_model.cpp @@ -15,7 +15,6 @@ Author: --*/ -#include "ast/ast_ll_pp.h" #include "model/array_factory.h" #include "sat/smt/array_solver.h" #include "sat/smt/euf_solver.h" @@ -29,6 +28,10 @@ namespace array { return; } for (euf::enode* p : euf::enode_parents(n)) { + if (a.is_default(p->get_expr())) { + dep.add(n, p); + continue; + } if (!a.is_select(p->get_expr())) continue; dep.add(n, p); @@ -37,9 +40,7 @@ namespace array { } for (euf::enode* k : euf::enode_class(n)) if (a.is_const(k->get_expr())) - dep.add(n, k); - else if (a.is_default(k->get_expr())) - dep.add(n, k); + dep.add(n, k->get_arg(0)); } @@ -52,23 +53,58 @@ namespace array { func_interp * fi = alloc(func_interp, m, arity); mdl.register_decl(f, fi); - for (euf::enode* p : euf::enode_parents(n)) { - if (!a.is_select(p->get_expr()) || p->get_arg(0)->get_root() != n->get_root()) - continue; - args.reset(); - for (unsigned i = 1; i < p->num_args(); ++i) - args.push_back(values.get(p->get_arg(i)->get_root_id())); - expr* value = values.get(p->get_root_id()); - fi->insert_entry(args.c_ptr(), value); - } - if (!fi->get_else()) - for (euf::enode* k : euf::enode_class(n)) - if (a.is_const(k->get_expr())) - fi->set_else(k->get_arg(0)->get_root()->get_expr()); if (!fi->get_else()) - for (euf::enode* k : euf::enode_parents(n)) - if (a.is_default(k->get_expr())) - fi->set_else(k->get_root()->get_expr()); + for (euf::enode* k : euf::enode_class(n)) + if (a.is_const(k->get_expr())) + fi->set_else(values.get(k->get_arg(0)->get_root_id())); + + if (!fi->get_else()) + for (euf::enode* p : euf::enode_parents(n)) + if (a.is_default(p->get_expr())) + fi->set_else(values.get(p->get_root_id())); + + if (!fi->get_else()) { + expr* else_value = nullptr; + unsigned max_occ_num = 0; + obj_map num_occ; + for (euf::enode* p : euf::enode_parents(n)) { + if (a.is_select(p->get_expr()) && p->get_arg(0)->get_root() == n->get_root()) { + expr* v = values.get(p->get_root_id()); + if (!v) + continue; + unsigned no = 0; + num_occ.find(v, no); + ++no; + num_occ.insert(v, no); + if (no > max_occ_num) { + else_value = v; + max_occ_num = no; + } + } + } + if (else_value) + fi->set_else(else_value); + } + + for (euf::enode* p : euf::enode_parents(n)) { + if (a.is_select(p->get_expr()) && p->get_arg(0)->get_root() == n->get_root()) { +// std::cout << "parent " << mk_bounded_pp(p->get_expr(), m) << "\n"; + expr* value = values.get(p->get_root_id()); + if (!value || value == fi->get_else()) + continue; + args.reset(); + bool relevant = true; + for (unsigned i = 1; relevant && i < p->num_args(); ++i) + relevant = ctx.is_relevant(p->get_arg(i)->get_root()); + if (!relevant) + continue; + for (unsigned i = 1; i < p->num_args(); ++i) + args.push_back(values.get(p->get_arg(i)->get_root_id())); +// for (expr* arg : args) +// std::cout << "arg " << mk_bounded_pp(arg, m) << "\n"; + fi->insert_entry(args.c_ptr(), value); + } + } parameter p(f); values.set(n->get_root_id(), m.mk_app(get_id(), OP_AS_ARRAY, 1, &p)); diff --git a/src/sat/smt/array_solver.cpp b/src/sat/smt/array_solver.cpp index bf83c4061..3d339080d 100644 --- a/src/sat/smt/array_solver.cpp +++ b/src/sat/smt/array_solver.cpp @@ -107,8 +107,8 @@ namespace array { auto& d = get_var_data(i); out << var2enode(i)->get_expr_id() << " " << mk_bounded_pp(var2expr(i), m, 2) << "\n"; display_info(out, "parent lambdas", d.m_parent_lambdas); - display_info(out, "parent select", d.m_parent_selects); - display_info(out, "b ", d.m_lambdas); + display_info(out, "parent select", d.m_parent_selects); + display_info(out, "lambdas", d.m_lambdas); } return out; } @@ -141,7 +141,7 @@ namespace array { st.update("array splits", m_stats.m_num_eq_splits); } - euf::th_solver* solver::fresh(sat::solver* s, euf::solver& ctx) { + euf::th_solver* solver::clone(sat::solver* s, euf::solver& ctx) { auto* result = alloc(solver, ctx, get_id()); ast_translation tr(m, ctx.get_manager()); for (unsigned i = 0; i < get_num_vars(); ++i) { @@ -165,7 +165,7 @@ namespace array { bool prop = false; ctx.push(value_trail(m_qhead)); for (; m_qhead < m_axiom_trail.size() && !s().inconsistent(); ++m_qhead) - if (assert_axiom(m_qhead)) + if (propagate_axiom(m_qhead)) prop = true; return prop; } @@ -174,7 +174,6 @@ namespace array { euf::enode* n1 = var2enode(v1); euf::enode* n2 = var2enode(v2); SASSERT(n1->get_root() == n2->get_root()); - SASSERT(n1->is_root() || n2->is_root()); SASSERT(v1 == find(v1)); expr* e1 = n1->get_expr(); expr* e2 = n2->get_expr(); @@ -204,7 +203,7 @@ namespace array { v_child = find(v_child); tracked_push(get_var_data(v_child).m_parent_selects, select); euf::enode* child = var2enode(v_child); - if (can_beta_reduce(child)) + if (can_beta_reduce(child) && child != select->get_arg(0)) push_axiom(select_axiom(select, child)); } diff --git a/src/sat/smt/array_solver.h b/src/sat/smt/array_solver.h index afff61219..d9c93e03f 100644 --- a/src/sat/smt/array_solver.h +++ b/src/sat/smt/array_solver.h @@ -92,6 +92,7 @@ namespace array { kind_t m_kind; euf::enode* n; euf::enode* select; + bool m_delayed { false }; axiom_record(kind_t k, euf::enode* n, euf::enode* select = nullptr) : m_kind(k), n(n), select(select) {} struct hash { @@ -119,8 +120,13 @@ namespace array { axiom_table_t m_axioms; svector m_axiom_trail; unsigned m_qhead { 0 }; + unsigned m_delay_qhead { 0 }; + struct set_delay_bit; void push_axiom(axiom_record const& r); + bool propagate_axiom(unsigned idx); bool assert_axiom(unsigned idx); + bool assert_select(unsigned idx, axiom_record & r); + bool assert_default(axiom_record & r); axiom_record select_axiom(euf::enode* s, euf::enode* n) { return axiom_record(axiom_record::kind_t::is_select, n, s); } axiom_record default_axiom(euf::enode* n) { return axiom_record(axiom_record::kind_t::is_default, n); } @@ -180,7 +186,6 @@ namespace array { bool have_different_model_values(theory_var v1, theory_var v2); // diagnostics - std::ostream& display_info(std::ostream& out, char const* id, euf::enode_vector const& v) const; public: solver(euf::solver& ctx, theory_id id); @@ -195,7 +200,7 @@ namespace array { std::ostream& display_justification(std::ostream& out, sat::ext_justification_idx idx) const override; std::ostream& display_constraint(std::ostream& out, sat::ext_constraint_idx idx) const override; void collect_statistics(statistics& st) const override; - euf::th_solver* fresh(sat::solver* s, euf::solver& ctx) override; + euf::th_solver* clone(sat::solver* s, euf::solver& ctx) override; void new_eq_eh(euf::th_eq const& eq) override; bool unit_propagate() override; void add_value(euf::enode* n, model& mdl, expr_ref_vector& values) override; @@ -204,6 +209,7 @@ namespace array { void internalize(expr* e, bool redundant) override; euf::theory_var mk_var(euf::enode* n) override; void apply_sort_cnstr(euf::enode* n, sort* s) override; + bool is_shared(theory_var v) const override; void tracked_push(euf::enode_vector& v, euf::enode* n); diff --git a/src/sat/smt/ba_solver.cpp b/src/sat/smt/ba_solver.cpp index 2155a7c39..ba0553aa0 100644 --- a/src/sat/smt/ba_solver.cpp +++ b/src/sat/smt/ba_solver.cpp @@ -3137,14 +3137,14 @@ namespace sat { } extension* ba_solver::copy(solver* s) { - return fresh(s, m, si, m_id); + return clone_aux(s, m, si, m_id); } - euf::th_solver* ba_solver::fresh(solver* new_s, euf::solver& new_ctx) { - return fresh(new_s, new_ctx.get_manager(), new_ctx.get_si(), get_id()); + euf::th_solver* ba_solver::clone(solver* new_s, euf::solver& new_ctx) { + return clone_aux(new_s, new_ctx.get_manager(), new_ctx.get_si(), get_id()); } - euf::th_solver* ba_solver::fresh(solver* new_s, ast_manager& m, sat::sat_internalizer& si, euf::theory_id id) { + euf::th_solver* ba_solver::clone_aux(solver* new_s, ast_manager& m, sat::sat_internalizer& si, euf::theory_id id) { ba_solver* result = alloc(ba_solver, m, si, id); result->set_solver(new_s); copy_constraints(result, m_constraints); diff --git a/src/sat/smt/ba_solver.h b/src/sat/smt/ba_solver.h index 81132ecec..703aa9438 100644 --- a/src/sat/smt/ba_solver.h +++ b/src/sat/smt/ba_solver.h @@ -151,7 +151,7 @@ namespace sat { unsigned_vector m_weights; svector m_wlits; - euf::th_solver* fresh(sat::solver* new_s, ast_manager& m, sat::sat_internalizer& si, euf::theory_id id); + euf::th_solver* clone_aux(sat::solver* new_s, ast_manager& m, sat::sat_internalizer& si, euf::theory_id id); bool subsumes(card& c1, card& c2, literal_vector& comp); bool subsumes(card& c1, clause& c2, bool& self); @@ -433,7 +433,7 @@ namespace sat { literal internalize(expr* e, bool sign, bool root, bool redundant) override; void internalize(expr* e, bool redundant) override; bool to_formulas(std::function& l2e, expr_ref_vector& fmls) override; - euf::th_solver* fresh(solver* s, euf::solver& ctx) override; + euf::th_solver* clone(solver* s, euf::solver& ctx) override; ptr_vector const & constraints() const { return m_constraints; } std::ostream& display(std::ostream& out, constraint const& c, bool values) const; diff --git a/src/sat/smt/bv_delay_internalize.cpp b/src/sat/smt/bv_delay_internalize.cpp index 8b0c0ba69..6af9ffcf9 100644 --- a/src/sat/smt/bv_delay_internalize.cpp +++ b/src/sat/smt/bv_delay_internalize.cpp @@ -20,53 +20,77 @@ Author: namespace bv { - bool solver::check_delay_internalized(euf::enode* n) { - expr* e = n->get_expr(); + bool solver::check_delay_internalized(expr* e) { + if (!ctx.is_relevant(e)) + return true; + if (get_internalize_mode(e) != internalize_mode::delay_i) + return true; SASSERT(bv.is_bv(e)); - SASSERT(get_internalize_mode(e) != internalize_mode::no_delay_i); switch (to_app(e)->get_decl_kind()) { case OP_BMUL: - return check_mul(n); + return check_mul(to_app(e)); + case OP_BSMUL_NO_OVFL: + case OP_BSMUL_NO_UDFL: + case OP_BUMUL_NO_OVFL: + return check_bool_eval(expr2enode(e)); default: - return check_eval(n); + return check_bv_eval(expr2enode(e)); } return true; } bool solver::should_bit_blast(expr* e) { - return bv.get_bv_size(e) <= 10; + return bv.get_bv_size(e) <= 12; } - void solver::eval_args(euf::enode* n, vector& args) { - rational val; - for (euf::enode* arg : euf::enode_args(n)) { - theory_var v = arg->get_th_var(get_id()); - VERIFY(get_fixed_value(v, val)); - args.push_back(val); - } + expr_ref solver::eval_args(euf::enode* n, expr_ref_vector& args) { + for (euf::enode* arg : euf::enode_args(n)) + args.push_back(eval_bv(arg)); + expr_ref r(m.mk_app(n->get_decl(), args), m); + ctx.get_rewriter()(r); + return r; } - bool solver::check_mul(euf::enode* n) { - SASSERT(n->num_args() >= 2); - app* e = to_app(n->get_expr()); - rational val, val_mul(1); - vector args; - eval_args(n, args); - for (rational const& val_arg : args) - val_mul *= val_arg; + expr_ref solver::eval_bv(euf::enode* n) { + rational val; theory_var v = n->get_th_var(get_id()); + SASSERT(get_fixed_value(v, val)); VERIFY(get_fixed_value(v, val)); - val_mul = mod(val_mul, power2(get_bv_size(v))); - IF_VERBOSE(12, verbose_stream() << "check_mul " << mk_bounded_pp(n->get_expr(), m) << " " << args << " = " << val_mul << " =? " << val << "\n"); - if (val_mul == val) + return expr_ref(bv.mk_numeral(val, get_bv_size(v)), m); + } + + bool solver::check_mul(app* e) { + SASSERT(e->get_num_args() >= 2); + expr_ref_vector args(m); + euf::enode* n = expr2enode(e); + auto r1 = eval_bv(n); + auto r2 = eval_args(n, args); + if (r1 == r2) return true; - // Some possible approaches: + TRACE("bv", tout << mk_bounded_pp(e, m) << " evaluates to " << r1 << " arguments: " << args << "\n";); + // check x*0 = 0 + if (!check_mul_zero(e, args, r1, r2)) + return false; - // check base cases: val_mul = 0 or val = 0, some values in product are 1, + // check x*1 = x + if (!check_mul_one(e, args, r1, r2)) + return false; + + // Add propagation axiom for arguments + if (!check_mul_invertibility(e, args, r1)) + return false; // check discrepancies in low-order bits - // Add axioms for multiplication when fixing high-order bits to 0 + // Add axioms for multiplication when fixing high-order bits + if (!check_mul_low_bits(e, args, r1, r2)) + return false; + + + // Some other possible approaches: + // algebraic rules: + // x*(y+z), and there are nodes for x*y or x*z -> x*(y+z) = x*y + x*z + // compute S-polys for a set of constraints. // Hensel lifting: // The idea is dual to fixing high-order bits. Fix the low order bits where multiplication @@ -78,40 +102,316 @@ namespace bv { // check tangets hi >= y >= y0 and hi' >= x => x*y >= x*y0 - // compute S-polys for a set of constraints. + + if (m_cheap_axioms) + return true; set_delay_internalize(e, internalize_mode::no_delay_i); - internalize_circuit(e, v); + internalize_circuit(e); return false; } - bool solver::check_eval(euf::enode* n) { - expr_ref_vector args(m); - expr_ref r1(m), r2(m); - rational val; - app* a = to_app(n->get_expr()); - theory_var v = n->get_th_var(get_id()); - VERIFY(get_fixed_value(v, val)); - r1 = bv.mk_numeral(val, get_bv_size(v)); - SASSERT(bv.is_bv(a)); - for (euf::enode* arg : euf::enode_args(n)) { - SASSERT(bv.is_bv(arg->get_expr())); - theory_var v_arg = arg->get_th_var(get_id()); - VERIFY(get_fixed_value(v_arg, val)); - args.push_back(bv.mk_numeral(val, get_bv_size(v_arg))); + /** + * Add invertibility condition for multiplication + * + * x * y = z => (y | -y) & z = z + * + * This propagator relates to Niemetz and Preiner's consistency and invertibility conditions. + * The idea is that the side-conditions for ensuring invertibility are valid + * and in some cases are cheap to bit-blast. For multiplication, we include only + * the _consistency_ condition because the side-constraints for invertibility + * appear expensive (to paraphrase FMCAD 2020 paper): + * x * s = t => (s = 0 or mcb(x << c, y << c)) + * + * for c = ctz(s) and y = (t >> c) / (s >> c) + * + * mcb(x,t/s) just mean that the bit-vectors are compatible as ternary bit-vectors, + * which for propagation means that they are the same. + */ + + bool solver::check_mul_invertibility(app* n, expr_ref_vector const& arg_values, expr* value) { + + expr_ref inv(m), eq(m); + + auto invert = [&](expr* s, expr* t) { + return bv.mk_bv_and(bv.mk_bv_or(s, bv.mk_bv_neg(s)), t); + }; + auto check_invert = [&](expr* s) { + inv = invert(s, value); + ctx.get_rewriter()(inv); + return inv == value; + }; + auto add_inv = [&](expr* s) { + inv = invert(s, n); + expr_ref eq(m.mk_eq(inv, n), m); + TRACE("bv", tout << "enforce " << eq << "\n";); + add_unit(b_internalize(eq)); + }; + bool ok = true; + for (unsigned i = 0; i < arg_values.size(); ++i) { + if (!check_invert(arg_values[i])) { + add_inv(n->get_arg(i)); + ok = false; + } } - r2 = m.mk_app(a->get_decl(), args); - ctx.get_rewriter()(r2); + return ok; + } + + /* + * Check that multiplication with 0 is correctly propagated. + * If not, create algebraic axioms enforcing 0*x = 0 and x*0 = 0 + * + * z = 0, then lsb(x) + 1 + lsb(y) + 1 >= sz + */ + bool solver::check_mul_zero(app* n, expr_ref_vector const& arg_values, expr* mul_value, expr* arg_value) { + SASSERT(mul_value != arg_value); + SASSERT(!(bv.is_zero(mul_value) && bv.is_zero(arg_value))); + if (bv.is_zero(arg_value)) { + unsigned sz = n->get_num_args(); + expr_ref_vector args(m, sz, n->get_args()); + for (unsigned i = 0; i < sz && !s().inconsistent(); ++i) { + args[i] = arg_value; + expr_ref r(m.mk_app(n->get_decl(), args), m); + set_delay_internalize(r, internalize_mode::init_bits_only_i); // do not bit-blast this multiplier. + expr_ref eq(m.mk_eq(r, arg_value), m); + args[i] = n->get_arg(i); + std::cout << eq << "@" << s().scope_lvl() << "\n"; + add_unit(b_internalize(eq)); + } + return false; + } + if (bv.is_zero(mul_value)) { + return true; +#if 0 + vector lsb_bits; + for (expr* arg : *n) { + expr_ref_vector bits(m); + encode_lsb_tail(arg, bits); + lsb_bits.push_back(bits); + } + expr_ref_vector zs(m); + literal_vector lits; + expr_ref eq(m.mk_eq(n, mul_value), m); + lits.push_back(~b_internalize(eq)); + + for (unsigned i = 0; i < lsb_bits.size(); ++i) { + } + expr_ref z(m.mk_or(zs), m); + add_clause(lits); + // sum of lsb should be at least sz + return true; +#endif + } + return true; + } + + /*** + * check that 1*y = y, x*1 = x + */ + bool solver::check_mul_one(app* n, expr_ref_vector const& arg_values, expr* mul_value, expr* arg_value) { + if (arg_values.size() != 2) + return true; + if (bv.is_one(arg_values[0])) { + expr_ref mul1(m.mk_app(n->get_decl(), arg_values[0], n->get_arg(1)), m); + set_delay_internalize(mul1, internalize_mode::init_bits_only_i); + expr_ref eq(m.mk_eq(mul1, n->get_arg(1)), m); + add_unit(b_internalize(eq)); + TRACE("bv", tout << eq << "\n";); + return false; + } + if (bv.is_one(arg_values[1])) { + expr_ref mul1(m.mk_app(n->get_decl(), n->get_arg(0), arg_values[1]), m); + set_delay_internalize(mul1, internalize_mode::init_bits_only_i); + expr_ref eq(m.mk_eq(mul1, n->get_arg(0)), m); + add_unit(b_internalize(eq)); + TRACE("bv", tout << eq << "\n";); + return false; + } + return true; + } + + /** + * Check for discrepancies in low-order bits. + * Add bit-blasting axioms if there are discrepancies within low order bits. + */ + bool solver::check_mul_low_bits(app* n, expr_ref_vector const& arg_values, expr* value1, expr* value2) { + rational v0, v1, two(2); + unsigned sz; + VERIFY(bv.is_numeral(value1, v0, sz)); + VERIFY(bv.is_numeral(value2, v1)); + unsigned num_bits = 10; + if (sz <= num_bits) + return true; + bool diff = false; + for (unsigned i = 0; !diff && i < num_bits; ++i) { + rational b0 = mod(v0, two); + rational b1 = mod(v1, two); + diff = b0 != b1; + div(v0, two, v0); + div(v1, two, v1); + } + if (!diff) + return true; + + auto safe_for_fixing_bits = [&](expr* e) { + euf::enode* n = expr2enode(e); + theory_var v = n->get_th_var(get_id()); + for (unsigned i = num_bits; i < sz; ++i) { + sat::literal lit = m_bits[v][i]; + if (s().value(lit) == l_true && s().lvl(lit) > s().search_lvl()) + return false; + } + return true; + }; + for (expr* arg : *n) + if (!safe_for_fixing_bits(arg)) + return true; + if (!safe_for_fixing_bits(n)) + return true; + + auto value_for_bv = [&](expr* e) { + euf::enode* n = expr2enode(e); + theory_var v = n->get_th_var(get_id()); + rational val(0); + for (unsigned i = num_bits; i < sz; ++i) { + sat::literal lit = m_bits[v][i]; + if (s().value(lit) == l_true && s().lvl(lit) <= s().search_lvl()) + val += power2(i - num_bits); + } + return val; + }; + auto extract_low_bits = [&](expr* e) { + rational val = value_for_bv(e); + expr_ref lo(bv.mk_extract(num_bits - 1, 0, e), m); + expr_ref hi(bv.mk_numeral(val, sz - num_bits), m); + return expr_ref(bv.mk_concat(lo, hi), m); + }; + expr_ref_vector args(m); + for (expr* arg : *n) + args.push_back(extract_low_bits(arg)); + expr_ref lhs(extract_low_bits(n), m); + expr_ref rhs(m.mk_app(n->get_decl(), args), m); + set_delay_internalize(rhs, internalize_mode::no_delay_i); + expr_ref eq(m.mk_eq(lhs, rhs), m); + add_unit(b_internalize(eq)); + TRACE("bv", tout << "low-bits: " << eq << "\n";); + std::cout << "low bits\n"; + return false; + } + + /** + * The i'th bit in xs is 1 if the most significant bit of x is i or higher. + */ + void solver::encode_msb_tail(expr* x, expr_ref_vector& xs) { + theory_var v = expr2enode(x)->get_th_var(get_id()); + sat::literal_vector const& bits = m_bits[v]; + if (bits.empty()) + return; + expr_ref tmp = literal2expr(bits.back()); + for (unsigned i = bits.size() - 1; i-- > 0; ) { + auto b = bits[i]; + tmp = m.mk_or(literal2expr(b), tmp); + xs.push_back(tmp); + } + }; + + /** + * The i'th bit in xs is 1 if the least significant bit of x is i or lower. + */ + void solver::encode_lsb_tail(expr* x, expr_ref_vector& xs) { + theory_var v = expr2enode(x)->get_th_var(get_id()); + sat::literal_vector const& bits = m_bits[v]; + if (bits.empty()) + return; + expr_ref tmp = literal2expr(bits[0]); + for (unsigned i = 1; i < bits.size(); ++i) { + auto b = bits[i]; + tmp = m.mk_or(literal2expr(b), tmp); + xs.push_back(tmp); + } + }; + + /** + * Check non-overflow of unsigned multiplication. + * + * no_overflow(x, y) = > msb(x) + msb(y) <= sz; + * msb(x) + msb(y) < sz => no_overflow(x,y) + */ + bool solver::check_umul_no_overflow(app* n, expr_ref_vector const& arg_values, expr* value) { + SASSERT(arg_values.size() == 2); + SASSERT(m.is_true(value) || m.is_false(value)); + rational v0, v1; + unsigned sz; + VERIFY(bv.is_numeral(arg_values[0], v0, sz)); + VERIFY(bv.is_numeral(arg_values[1], v1)); + unsigned msb0 = v0.get_num_bits(); + unsigned msb1 = v1.get_num_bits(); + expr_ref_vector xs(m), ys(m), zs(m); + + if (m.is_true(value) && msb0 + msb1 > sz && !v0.is_zero() && !v1.is_zero()) { + sat::literal no_overflow = expr2literal(n); + encode_msb_tail(n->get_arg(0), xs); + encode_msb_tail(n->get_arg(1), ys); + for (unsigned i = 1; i <= sz; ++i) { + sat::literal bit0 = b_internalize(xs.get(i - 1)); + sat::literal bit1 = b_internalize(ys.get(sz - i)); + add_clause(~no_overflow, ~bit0, ~bit1); + } + return false; + } + else if (m.is_false(value) && msb0 + msb1 < sz) { + encode_msb_tail(n->get_arg(0), xs); + encode_msb_tail(n->get_arg(1), ys); + sat::literal_vector lits; + lits.push_back(expr2literal(n)); + for (unsigned i = 1; i < sz; ++i) { + expr_ref msb_ge_sz(m.mk_and(xs.get(i - 1), ys.get(sz - i - 1)), m); + lits.push_back(b_internalize(msb_ge_sz)); + } + add_clause(lits); + return false; + } + return true; + } + + bool solver::check_bv_eval(euf::enode* n) { + expr_ref_vector args(m); + app* a = n->get_app(); + SASSERT(bv.is_bv(a)); + auto r1 = eval_bv(n); + auto r2 = eval_args(n, args); if (r1 == r2) return true; + if (m_cheap_axioms) + return true; set_delay_internalize(a, internalize_mode::no_delay_i); - internalize_circuit(a, v); + internalize_circuit(a); + return false; + } + + bool solver::check_bool_eval(euf::enode* n) { + expr_ref_vector args(m); + SASSERT(m.is_bool(n->get_expr())); + sat::literal lit = expr2literal(n->get_expr()); + expr* r1 = m.mk_bool_val(s().value(lit) == l_true); + auto r2 = eval_args(n, args); + if (r1 == r2) + return true; + app* a = n->get_app(); + if (bv.is_bv_umul_no_ovfl(a) && !check_umul_no_overflow(a, args, r1)) + return false; + if (m_cheap_axioms) + return true; + set_delay_internalize(a, internalize_mode::no_delay_i); + internalize_circuit(a); return false; } void solver::set_delay_internalize(expr* e, internalize_mode mode) { if (!m_delay_internalize.contains(e)) ctx.push(insert_obj_map(m_delay_internalize, e)); + else + ctx.push(remove_obj_map(m_delay_internalize, e, m_delay_internalize[e])); m_delay_internalize.insert(e, mode); } @@ -120,6 +420,7 @@ namespace bv { return internalize_mode::no_delay_i; if (!get_config().m_bv_delay) return internalize_mode::no_delay_i; + internalize_mode mode; switch (to_app(e)->get_decl_kind()) { case OP_BMUL: case OP_BSMUL_NO_OVFL: @@ -129,18 +430,15 @@ namespace bv { case OP_BUREM_I: case OP_BSREM_I: case OP_BUDIV_I: - case OP_BSDIV_I: { + case OP_BSDIV_I: if (should_bit_blast(e)) return internalize_mode::no_delay_i; - internalize_mode mode = internalize_mode::init_bits_i; + mode = internalize_mode::delay_i; if (!m_delay_internalize.find(e, mode)) - set_delay_internalize(e, mode); - return mode; - } + m_delay_internalize.insert(e, mode); + return mode; default: return internalize_mode::no_delay_i; } } - - } diff --git a/src/sat/smt/bv_internalize.cpp b/src/sat/smt/bv_internalize.cpp index 0af6cc1cd..5b3fd0e4a 100644 --- a/src/sat/smt/bv_internalize.cpp +++ b/src/sat/smt/bv_internalize.cpp @@ -15,6 +15,7 @@ Author: --*/ +#include "params/bv_rewriter_params.hpp" #include "sat/smt/bv_solver.h" #include "sat/smt/euf_solver.h" #include "sat/smt/sat_th.h" @@ -22,26 +23,10 @@ Author: namespace bv { - solver::bit_atom& solver::atom::to_bit() { - SASSERT(is_bit()); - return dynamic_cast(*this); - } - - solver::def_atom& solver::atom::to_def() { - SASSERT(!is_bit()); - SASSERT(!is_eq()); - return dynamic_cast(*this); - } - - solver::eq_atom& solver::atom::to_eq() { - SASSERT(is_eq()); - return dynamic_cast(*this); - } - class solver::add_var_pos_trail : public trail { - solver::bit_atom* m_atom; + solver::atom* m_atom; public: - add_var_pos_trail(solver::bit_atom* a) :m_atom(a) {} + add_var_pos_trail(solver::atom* a) :m_atom(a) {} void undo(euf::solver& euf) override { SASSERT(m_atom->m_occs); m_atom->m_occs = m_atom->m_occs->m_next; @@ -49,9 +34,9 @@ namespace bv { }; class solver::add_eq_occurs_trail : public trail { - bit_atom* m_atom; + atom* m_atom; public: - add_eq_occurs_trail(bit_atom* a) :m_atom(a) {} + add_eq_occurs_trail(atom* a) :m_atom(a) {} void undo(euf::solver& euf) override { SASSERT(m_atom->m_eqs); m_atom->m_eqs = m_atom->m_eqs->m_next; @@ -61,10 +46,10 @@ namespace bv { }; class solver::del_eq_occurs_trail : public trail { - bit_atom* m_atom; + atom* m_atom; eq_occurs* m_node; public: - del_eq_occurs_trail(bit_atom* a, eq_occurs* n) : m_atom(a), m_node(n) {} + del_eq_occurs_trail(atom* a, eq_occurs* n) : m_atom(a), m_node(n) {} void undo(euf::solver& euf) override { if (m_node->m_next) m_node->m_next->m_prev = m_node; @@ -75,7 +60,7 @@ namespace bv { } }; - void solver::del_eq_occurs(bit_atom* a, eq_occurs* occ) { + void solver::del_eq_occurs(atom* a, eq_occurs* occ) { eq_occurs* prev = occ->m_prev; if (prev) prev->m_next = occ->m_next; @@ -104,8 +89,7 @@ namespace bv { m_bits.push_back(sat::literal_vector()); m_wpos.push_back(0); m_zero_one_bits.push_back(zero_one_bits()); - ctx.attach_th_var(n, this, r); - + ctx.attach_th_var(n, this, r); TRACE("bv", tout << "mk-var: " << r << " " << n->get_expr_id() << " " << mk_bounded_pp(n->get_expr(), m) << "\n";); return r; } @@ -157,14 +141,14 @@ namespace bv { SASSERT(!n->is_attached_to(get_id())); theory_var v = mk_var(n); SASSERT(n->is_attached_to(get_id())); - if (internalize_mode::init_bits_i == get_internalize_mode(a)) + if (internalize_mode::no_delay_i != get_internalize_mode(a)) mk_bits(n->get_th_var(get_id())); else - internalize_circuit(a, v); + internalize_circuit(a); return true; } - bool solver::internalize_circuit(app* a, theory_var v) { + bool solver::internalize_circuit(app* a) { std::function bin; std::function ebin; @@ -177,8 +161,9 @@ namespace bv { #define internalize_nfl(F) ebin = [&](unsigned sz, expr* const* xs, expr* const* ys, expr_ref& out) { m_bb.F(sz, xs, ys, out);}; internalize_novfl(a, ebin); switch (a->get_decl_kind()) { - case OP_BV_NUM: internalize_num(a, v); break; + case OP_BV_NUM: internalize_num(a); break; case OP_BNOT: internalize_un(mk_not); break; + case OP_BNEG: internalize_un(mk_neg); break; case OP_BREDAND: internalize_un(mk_redand); break; case OP_BREDOR: internalize_un(mk_redor); break; case OP_BSDIV_I: internalize_bin(mk_sdiv); break; @@ -199,7 +184,7 @@ namespace bv { case OP_BNAND: internalize_bin(mk_nand); break; case OP_BNOR: internalize_bin(mk_nor); break; case OP_BXNOR: internalize_bin(mk_xnor); break; - case OP_BCOMP: internalize_bin(mk_comp); break; + case OP_BCOMP: internalize_bin(mk_comp); break; case OP_SIGN_EXT: internalize_pun(mk_sign_extend); break; case OP_ZERO_EXT: internalize_pun(mk_zero_extend); break; case OP_ROTATE_LEFT: internalize_pun(mk_rotate_left); break; @@ -208,8 +193,14 @@ namespace bv { case OP_BSMUL_NO_OVFL: internalize_nfl(mk_smul_no_overflow); break; case OP_BSMUL_NO_UDFL: internalize_nfl(mk_smul_no_underflow); break; case OP_BIT2BOOL: internalize_bit2bool(a); break; - case OP_ULEQ: internalize_le(a); break; - case OP_SLEQ: internalize_le(a); break; + case OP_ULEQ: internalize_le(a); break; + case OP_SLEQ: internalize_le(a); break; + case OP_UGEQ: internalize_le(a); break; + case OP_SGEQ: internalize_le(a); break; + case OP_ULT: internalize_le(a); break; + case OP_SLT: internalize_le(a); break; + case OP_UGT: internalize_le(a); break; + case OP_SGT: internalize_le(a); break; case OP_XOR3: internalize_xor3(a); break; case OP_CARRY: internalize_carry(a); break; case OP_BSUB: internalize_sub(a); break; @@ -218,12 +209,14 @@ namespace bv { case OP_MKBV: internalize_mkbv(a); break; case OP_INT2BV: internalize_int2bv(a); break; case OP_BV2INT: internalize_bv2int(a); break; + case OP_BUDIV: internalize_udiv(a); break; case OP_BSDIV0: break; case OP_BUDIV0: break; case OP_BSREM0: break; case OP_BUREM0: break; case OP_BSMOD0: break; default: + IF_VERBOSE(0, verbose_stream() << mk_bounded_pp(a, m) << "\n"); UNREACHABLE(); break; } @@ -231,7 +224,7 @@ namespace bv { } void solver::mk_bits(theory_var v) { - TRACE("bv", tout << "v" << v << "\n";); + TRACE("bv", tout << "v" << v << "@" << s().scope_lvl() << "\n";); expr* e = var2expr(v); unsigned bv_size = get_bv_size(v); m_bits[v].reset(); @@ -239,6 +232,7 @@ namespace bv { expr_ref b2b(bv.mk_bit2bool(e, i), m); m_bits[v].push_back(sat::null_literal); sat::literal lit = ctx.internalize(b2b, false, false, m_is_redundant); + TRACE("bv", tout << "add-bit: " << lit << " " << literal2expr(lit) << "\n";); SASSERT(m_bits[v].back() == lit); } } @@ -263,6 +257,12 @@ namespace bv { } void solver::get_bits(theory_var v, expr_ref_vector& r) { + for (literal lit : m_bits[v]) { + if (!literal2expr(lit)) + ctx.display(std::cout << "Missing expression: " << lit << "\n"); + SASSERT(literal2expr(lit)); + } + for (literal lit : m_bits[v]) r.push_back(literal2expr(lit)); } @@ -289,22 +289,23 @@ namespace bv { void solver::add_bit(theory_var v, literal l) { unsigned idx = m_bits[v].size(); m_bits[v].push_back(l); + TRACE("bv", tout << "add-bit: v" << v << "[" << idx << "] " << l << " " << literal2expr(l) << "@" << s().scope_lvl() << "\n";); + SASSERT(m_num_scopes == 0); s().set_external(l.var()); set_bit_eh(v, l, idx); } - solver::bit_atom* solver::mk_bit_atom(sat::bool_var bv) { + solver::atom* solver::mk_atom(sat::bool_var bv) { atom* a = get_bv2a(bv); - if (a) - return a->is_bit() ? &a->to_bit() : nullptr; - else { - bit_atom* b = new (get_region()) bit_atom(); - insert_bv2a(bv, b); - ctx.push(mk_atom_trail(bv, *this)); - return b; - } + if (a) + return a; + a = new (get_region()) atom(); + insert_bv2a(bv, a); + ctx.push(mk_atom_trail(bv, *this)); + return a; } +#if 0 solver::eq_atom* solver::mk_eq_atom(sat::bool_var bv) { atom* a = get_bv2a(bv); if (a) @@ -316,6 +317,7 @@ namespace bv { return b; } } +#endif void solver::set_bit_eh(theory_var v, literal l, unsigned idx) { @@ -323,14 +325,12 @@ namespace bv { if (s().value(l) != l_undef && s().lvl(l) == 0) register_true_false_bit(v, idx); else if (m_bits[v].size() > 1) { - bit_atom* b = mk_bit_atom(l.var()); - if (b) { - if (b->m_occs) - find_new_diseq_axioms(*b, v, idx); - if (!b->is_fresh()) - ctx.push(add_var_pos_trail(b)); - b->m_occs = new (get_region()) var_pos_occ(v, idx, b->m_occs); - } + atom* b = mk_atom(l.var()); + if (b->m_occs) + find_new_diseq_axioms(*b, v, idx); + if (!b->is_fresh()) + ctx.push(add_var_pos_trail(b)); + b->m_occs = new (get_region()) var_pos_occ(v, idx, b->m_occs); } } @@ -339,7 +339,19 @@ namespace bv { SASSERT(get_bv_size(n) == bits.size()); SASSERT(euf::null_theory_var != n->get_th_var(get_id())); theory_var v = n->get_th_var(get_id()); - m_bits[v].reset(); + + if (!m_bits[v].empty()) { + SASSERT(bits.size() == m_bits[v].size()); + unsigned i = 0; + for (expr* bit : bits) { + sat::literal lit = ctx.internalize(bit, false, false, m_is_redundant); + TRACE("bv", tout << "set " << m_bits[v][i] << " == " << lit << "\n";); + add_clause(~lit, m_bits[v][i]); + add_clause(lit, ~m_bits[v][i]); + ++i; + } + return; + } for (expr* bit : bits) add_bit(v, ctx.internalize(bit, false, false, m_is_redundant)); for (expr* bit : bits) @@ -356,11 +368,13 @@ namespace bv { return get_bv_size(var2enode(v)); } - void solver::internalize_num(app* n, theory_var v) { + void solver::internalize_num(app* a) { numeral val; unsigned sz = 0; - SASSERT(expr2enode(n)->interpreted()); - VERIFY(bv.is_numeral(n, val, sz)); + euf::enode* n = expr2enode(a); + theory_var v = n->get_th_var(get_id()); + SASSERT(n->interpreted()); + VERIFY(bv.is_numeral(a, val, sz)); expr_ref_vector bits(m); m_bb.num2bits(val, sz, bits); SASSERT(bits.size() == sz); @@ -453,18 +467,20 @@ namespace bv { } } - template + template void solver::internalize_le(app* n) { SASSERT(n->get_num_args() == 2); expr_ref_vector arg1_bits(m), arg2_bits(m); - get_arg_bits(n, 0, arg1_bits); - get_arg_bits(n, 1, arg2_bits); + get_arg_bits(n, Rev ? 1 : 0, arg1_bits); + get_arg_bits(n, Rev ? 0 : 1, arg2_bits); expr_ref le(m); if (Signed) m_bb.mk_sle(arg1_bits.size(), arg1_bits.c_ptr(), arg2_bits.c_ptr(), le); else m_bb.mk_ule(arg1_bits.size(), arg1_bits.c_ptr(), arg2_bits.c_ptr(), le); literal def = ctx.internalize(le, false, false, m_is_redundant); + if (Negated) + def.neg(); add_def(def, expr2literal(n)); } @@ -498,6 +514,29 @@ namespace bv { add_clause(r, ~l1, ~l2, ~l3); } + void solver::internalize_udiv_i(app* a) { + std::function bin; + bin = [&](unsigned sz, expr* const* xs, expr* const* ys, expr_ref_vector& bits) { m_bb.mk_udiv(sz, xs, ys, bits); }; + internalize_binary(a, bin); + } + + void solver::internalize_udiv(app* n) { + bv_rewriter_params p(s().params()); + expr* arg1 = n->get_arg(0); + expr* arg2 = n->get_arg(1); + if (p.hi_div0()) { + expr_ref eq(m.mk_eq(n, bv.mk_bv_udiv_i(arg1, arg2)), m); + add_unit(b_internalize(eq)); + return; + } + unsigned sz = bv.get_bv_size(n); + expr_ref zero(bv.mk_numeral(0, sz), m); + expr_ref eq(m.mk_eq(arg2, zero), m); + expr_ref udiv(m.mk_ite(eq, bv.mk_bv_udiv0(arg1), bv.mk_bv_udiv_i(arg1, arg2)), m); + expr_ref eq2(m.mk_eq(n, udiv), m); + add_unit(b_internalize(eq2)); + } + void solver::internalize_unary(app* n, std::function& fn) { SASSERT(n->get_num_args() == 1); expr_ref_vector arg1_bits(m), bits(m); @@ -554,7 +593,9 @@ namespace bv { } void solver::add_def(sat::literal def, sat::literal l) { - def_atom* a = new (get_region()) def_atom(l, def); + atom* a = new (get_region()) atom(); + a->m_var = l; + a->m_def = def; insert_bv2a(l.var(), a); ctx.push(mk_atom_trail(l.var(), *this)); add_clause(l, ~def); @@ -612,8 +653,9 @@ namespace bv { sat::literal lit0 = m_bits[v_arg][idx]; if (lit0 == sat::null_literal) { m_bits[v_arg][idx] = lit; + TRACE("bv", tout << "add-bit: " << lit << " " << literal2expr(lit) << "\n";); if (arg_sz > 1) { - bit_atom* a = new (get_region()) bit_atom(); + atom* a = new (get_region()) atom(); a->m_occs = new (get_region()) var_pos_occ(v_arg, idx); insert_bv2a(lit.var(), a); ctx.push(mk_atom_trail(lit.var(), *this)); @@ -677,7 +719,7 @@ namespace bv { } void solver::eq_internalized(sat::bool_var b1, sat::bool_var b2, unsigned idx, theory_var v1, theory_var v2, literal lit, euf::enode* n) { - bit_atom* a = mk_bit_atom(b1); + atom* a = mk_atom(b1); // eq_atom* b = mk_eq_atom(lit.var()); if (a) { if (!a->is_fresh()) diff --git a/src/sat/smt/bv_invariant.cpp b/src/sat/smt/bv_invariant.cpp index 82fde885e..4af785460 100644 --- a/src/sat/smt/bv_invariant.cpp +++ b/src/sat/smt/bv_invariant.cpp @@ -23,8 +23,8 @@ namespace bv { void solver::validate_atoms() const { sat::bool_var v = 0; for (auto* a : m_bool_var2atom) { - if (a && a->is_bit()) { - for (auto vp : a->to_bit()) { + if (a) { + for (auto vp : *a) { SASSERT(m_bits[vp.first][vp.second].var() == v); VERIFY(m_bits[vp.first][vp.second].var() == v); } diff --git a/src/sat/smt/bv_solver.cpp b/src/sat/smt/bv_solver.cpp index e292e52f3..d09affb91 100644 --- a/src/sat/smt/bv_solver.cpp +++ b/src/sat/smt/bv_solver.cpp @@ -37,11 +37,11 @@ namespace bv { }; class solver::bit_occs_trail : public trail { - bit_atom& a; + atom& a; var_pos_occ* m_occs; public: - bit_occs_trail(solver& s, bit_atom& a): a(a), m_occs(a.m_occs) {} + bit_occs_trail(solver& s, atom& a): a(a), m_occs(a.m_occs) {} virtual void undo(euf::solver& euf) { IF_VERBOSE(1, verbose_stream() << "add back occurrences " << & a << "\n"); @@ -134,7 +134,7 @@ namespace bv { /** *\brief v[idx] = ~v'[idx], then v /= v' is a theory axiom. */ - void solver::find_new_diseq_axioms(bit_atom& a, theory_var v, unsigned idx) { + void solver::find_new_diseq_axioms(atom& a, theory_var v, unsigned idx) { if (!get_config().m_bv_eq_axioms) return; literal l = m_bits[v][idx]; @@ -180,12 +180,8 @@ namespace bv { } } else if (m.is_bool(e) && (a = m_bool_var2atom.get(expr2literal(e).var(), nullptr))) { - if (a->is_bit()) { - for (var_pos vp : a->to_bit()) - out << " " << var2enode(vp.first)->get_expr_id() << "[" << vp.second << "]"; - } - else - out << "def-atom"; + for (var_pos vp : *a) + out << " " << var2enode(vp.first)->get_expr_id() << "[" << vp.second << "]"; } else out << " " << mk_bounded_pp(e, m, 1); @@ -269,7 +265,7 @@ namespace bv { void solver::get_antecedents(literal l, sat::ext_justification_idx idx, literal_vector& r, bool probing) { auto& c = bv_justification::from_index(idx); - TRACE("bv", display_constraint(tout, idx);); + TRACE("bv", display_constraint(tout, idx) << "\n";); switch (c.m_kind) { case bv_justification::kind_t::eq2bit: SASSERT(s().value(c.m_antecedent) == l_true); @@ -389,12 +385,10 @@ namespace bv { void solver::asserted(literal l) { atom* a = get_bv2a(l.var()); TRACE("bv", tout << "asserted: " << l << "\n";); - if (a && a->is_bit()) { + if (a) { force_push(); - m_prop_queue.push_back(propagation_item(&a->to_bit())); - } - else if (a && a->is_eq()) { - for (auto p : a->to_eq().m_eqs) { + m_prop_queue.push_back(propagation_item(a)); + for (auto p : a->m_bit2occ) { del_eq_occurs(p.first, p.second); } } @@ -422,7 +416,6 @@ namespace bv { ++num_eq_assigned; } } - IF_VERBOSE(20, verbose_stream() << "atoms: " << num_atoms << " eqs: " << num_eqs << " atoms-assigned:" << num_assigned << " eqs-assigned: " << num_eq_assigned << " lits: " << num_lit_assigned << "\n"); } else propagate_bits(p.m_vp); @@ -495,20 +488,39 @@ namespace bv { return num_assigned > 0; } + /** + * Check each delay internalized bit-vector operation for compliance. + * + * TBD: add model-repair attempt after cheap propagation axioms have been added + */ sat::check_result solver::check() { force_push(); SASSERT(m_prop_queue.size() == m_prop_queue_head); bool ok = true; - for (auto kv : m_delay_internalize) { - if (ctx.is_relevant(kv.m_key) && - kv.m_value == internalize_mode::init_bits_i && - !check_delay_internalized(expr2enode(kv.m_key))) + svector> delay; + for (auto kv : m_delay_internalize) + delay.push_back(std::make_pair(kv.m_key, kv.m_value)); + flet _cheap1(m_cheap_axioms, true); + for (auto kv : delay) + if (!check_delay_internalized(kv.first)) ok = false; - } - return ok ? sat::check_result::CR_DONE : sat::check_result::CR_CONTINUE; + if (!ok) + return sat::check_result::CR_CONTINUE; + + // if (repair_model()) return sat::check_result::DONE; + + flet _cheap2(m_cheap_axioms, false); + for (auto kv : delay) + if (!check_delay_internalized(kv.first)) + ok = false; + + if (!ok) + return sat::check_result::CR_CONTINUE; + return sat::check_result::CR_DONE; } void solver::push_core() { + TRACE("bv", tout << "push: " << get_num_vars() << "@" << m_prop_queue_lim.size() << "\n";); th_euf_solver::push_core(); m_prop_queue_lim.push_back(m_prop_queue.size()); } @@ -523,22 +535,18 @@ namespace bv { m_bits.shrink(old_sz); m_wpos.shrink(old_sz); m_zero_one_bits.shrink(old_sz); + TRACE("bv", tout << "num vars " << old_sz << "@" << m_prop_queue_lim.size() << "\n";); } - void solver::pre_simplify() {} - void solver::simplify() { m_ackerman.propagate(); } bool solver::set_root(literal l, literal r) { atom* a = get_bv2a(l.var()); - atom* b = get_bv2a(r.var()); - if (!a || !a->is_bit()) + if (!a) return true; - if (b && !b->is_bit()) - return false; - for (auto vp : a->to_bit()) { + for (auto vp : *a) { sat::literal l2 = m_bits[vp.first][vp.second]; if (l2.var() == r.var()) continue; @@ -549,8 +557,8 @@ namespace bv { m_bits[vp.first][vp.second] = r2; set_bit_eh(vp.first, r2, vp.second); } - ctx.push(bit_occs_trail(*this, a->to_bit())); - a->to_bit().m_occs = nullptr; + ctx.push(bit_occs_trail(*this, *a)); + a->m_occs = nullptr; // validate_atoms(); return true; } @@ -621,8 +629,7 @@ namespace bv { return out << "bv <- v" << v1 << "[" << cidx << "] != v" << v2 << "[" << cidx << "] " << m_bits[v1][cidx] << " != " << m_bits[v2][cidx]; } case bv_justification::kind_t::ne2bit: - return out << "bv <- " << m_bits[v1] << " != " << m_bits[v2] << " @" << cidx; - break; + return out << "bv <- " << m_bits[v1] << " != " << m_bits[v2] << " @" << cidx; default: UNREACHABLE(); break; @@ -643,7 +650,7 @@ namespace bv { sat::extension* solver::copy(sat::solver* s) { UNREACHABLE(); return nullptr; } - euf::th_solver* solver::fresh(sat::solver* s, euf::solver& ctx) { + euf::th_solver* solver::clone(sat::solver* s, euf::solver& ctx) { bv::solver* result = alloc(bv::solver, ctx, get_id()); ast_translation tr(m, ctx.get_manager()); for (unsigned i = 0; i < get_num_vars(); ++i) { @@ -664,22 +671,18 @@ namespace bv { if (!a) continue; - if (a->is_bit()) { - bit_atom* new_a = new (result->get_region()) bit_atom(); - m_bool_var2atom.setx(i, new_a, nullptr); - for (auto vp : a->to_bit()) - new_a->m_occs = new (result->get_region()) var_pos_occ(vp.first, vp.second, new_a->m_occs); - for (auto const& occ : a->to_bit().eqs()) { - expr* e = occ.m_node->get_expr(); - expr_ref e2(tr(e), tr.to()); - euf::enode* n = ctx.get_enode(e2); - new_a->m_eqs = new (result->get_region()) eq_occurs(occ.m_bv1, occ.m_bv2, occ.m_idx, occ.m_v1, occ.m_v2, occ.m_literal, n, new_a->m_eqs); - } - } - else { - def_atom* new_a = new (result->get_region()) def_atom(a->to_def().m_var, a->to_def().m_def); - m_bool_var2atom.setx(i, new_a, nullptr); + atom* new_a = new (result->get_region()) atom(); + m_bool_var2atom.setx(i, new_a, nullptr); + for (auto vp : *a) + new_a->m_occs = new (result->get_region()) var_pos_occ(vp.first, vp.second, new_a->m_occs); + for (auto const& occ : a->eqs()) { + expr* e = occ.m_node->get_expr(); + expr_ref e2(tr(e), tr.to()); + euf::enode* n = ctx.get_enode(e2); + new_a->m_eqs = new (result->get_region()) eq_occurs(occ.m_bv1, occ.m_bv2, occ.m_idx, occ.m_v1, occ.m_v2, occ.m_literal, n, new_a->m_eqs); } + new_a->m_def = a->m_def; + new_a->m_var = a->m_var; validate_atoms(); } return result; @@ -783,9 +786,7 @@ namespace bv { return sat::justification::mk_ext_justification(s().scope_lvl(), constraint->to_index()); } - bool solver::assign_bit(literal consequent, theory_var v1, theory_var v2, unsigned idx, literal antecedent, bool propagate_eqc) { - m_stats.m_num_eq2bit++; SASSERT(ctx.s().value(antecedent) == l_true); SASSERT(m_bits[v2][idx].var() == consequent.var()); @@ -801,8 +802,8 @@ namespace bv { find_wpos(v2); bool_var cv = consequent.var(); atom* a = get_bv2a(cv); - if (a && a->is_bit()) - for (auto curr : a->to_bit()) + if (a) + for (auto curr : *a) if (propagate_eqc || find(curr.first) != find(v2) || curr.second != idx) m_prop_queue.push_back(propagation_item(curr)); return true; diff --git a/src/sat/smt/bv_solver.h b/src/sat/smt/bv_solver.h index 82c72ca1d..d3deea85d 100644 --- a/src/sat/smt/bv_solver.h +++ b/src/sat/smt/bv_solver.h @@ -139,26 +139,10 @@ namespace bv { eq_occurs_it end() const { return eq_occurs_it(nullptr); } }; - struct bit_atom; - struct def_atom; - struct eq_atom; - class atom { - public: - - atom() {} - virtual ~atom() {} - virtual bool is_bit() const { return false; } - virtual bool is_eq() const { return false; } - bit_atom& to_bit(); - def_atom& to_def(); - eq_atom& to_eq(); - - }; - struct var_pos_occ { var_pos m_vp; - var_pos_occ * m_next; - var_pos_occ(theory_var v = euf::null_theory_var, unsigned idx = 0, var_pos_occ * next = nullptr):m_vp(v, idx), m_next(next) {} + var_pos_occ* m_next; + var_pos_occ(theory_var v = euf::null_theory_var, unsigned idx = 0, var_pos_occ* next = nullptr) :m_vp(v, idx), m_next(next) {} }; class var_pos_it { @@ -172,37 +156,24 @@ namespace bv { bool operator!=(var_pos_it const& other) const { return !(*this == other); } }; - struct bit_atom : public atom { + struct atom { eq_occurs* m_eqs; var_pos_occ * m_occs; - bit_atom() :m_eqs(nullptr), m_occs(nullptr) {} - ~bit_atom() override {} - bool is_bit() const override { return true; } + svector> m_bit2occ; + literal m_var { sat::null_literal }; + literal m_def { sat::null_literal }; + atom() :m_eqs(nullptr), m_occs(nullptr) {} + ~atom() { m_bit2occ.clear(); } var_pos_it begin() const { return var_pos_it(m_occs); } var_pos_it end() const { return var_pos_it(nullptr); } bool is_fresh() const { return !m_occs && !m_eqs; } - eqs_iterator eqs() const { return eqs_iterator(m_eqs); } - }; - - struct eq_atom : public atom { - eq_atom(){} - ~eq_atom() override { m_eqs.clear(); } - bool is_bit() const override { return false; } - bool is_eq() const override { return true; } - svector> m_eqs; - }; - - struct def_atom : public atom { - literal m_var; - literal m_def; - def_atom(literal v, literal d):m_var(v), m_def(d) {} - ~def_atom() override {} + eqs_iterator eqs() const { return eqs_iterator(m_eqs); } }; struct propagation_item { var_pos m_vp { var_pos(0, 0) }; - bit_atom* m_atom{ nullptr }; - explicit propagation_item(bit_atom* a) : m_atom(a) {} + atom* m_atom{ nullptr }; + explicit propagation_item(atom* a) : m_atom(a) {} explicit propagation_item(var_pos const& vp) : m_vp(vp) {} }; @@ -253,22 +224,21 @@ namespace bv { sat::status status() const { return sat::status::th(m_is_redundant, get_id()); } void register_true_false_bit(theory_var v, unsigned i); void add_bit(theory_var v, sat::literal lit); - bit_atom* mk_bit_atom(sat::bool_var b); - eq_atom* mk_eq_atom(sat::bool_var b); + atom* mk_atom(sat::bool_var b); void eq_internalized(sat::bool_var b1, sat::bool_var b2, unsigned idx, theory_var v1, theory_var v2, sat::literal eq, euf::enode* n); - void del_eq_occurs(bit_atom* a, eq_occurs* occ); + void del_eq_occurs(atom* a, eq_occurs* occ); void set_bit_eh(theory_var v, literal l, unsigned idx); void init_bits(expr* e, expr_ref_vector const & bits); void mk_bits(theory_var v); void add_def(sat::literal def, sat::literal l); - bool internalize_circuit(app* a, theory_var v); + bool internalize_circuit(app* a); void internalize_unary(app* n, std::function& fn); void internalize_binary(app* n, std::function& fn); void internalize_ac_binary(app* n, std::function& fn); void internalize_par_unary(app* n, std::function& fn); void internalize_novfl(app* n, std::function& fn); - void internalize_num(app * n, theory_var v); + void internalize_num(app * n); void internalize_concat(app * n); void internalize_bv2int(app* n); void internalize_int2bv(app* n); @@ -278,7 +248,9 @@ namespace bv { void internalize_sub(app* n); void internalize_extract(app* n); void internalize_bit2bool(app* n); - template + void internalize_udiv(app* n); + void internalize_udiv_i(app* n); + template void internalize_le(app* n); void assert_bv2int_axiom(app * n); void assert_int2bv_axiom(app* n); @@ -286,23 +258,34 @@ namespace bv { // delay internalize enum class internalize_mode { + delay_i, no_delay_i, - init_bits_i + init_bits_only_i }; obj_map m_delay_internalize; + bool m_cheap_axioms{ true }; bool should_bit_blast(expr * n); - bool check_delay_internalized(euf::enode* n); - bool check_mul(euf::enode* n); - bool check_eval(euf::enode* n); + bool check_delay_internalized(expr* e); + bool check_mul(app* e); + bool check_mul_invertibility(app* n, expr_ref_vector const& arg_values, expr* value); + bool check_mul_zero(app* n, expr_ref_vector const& arg_values, expr* value1, expr* value2); + bool check_mul_one(app* n, expr_ref_vector const& arg_values, expr* value1, expr* value2); + bool check_mul_low_bits(app* n, expr_ref_vector const& arg_values, expr* value1, expr* value2); + bool check_umul_no_overflow(app* n, expr_ref_vector const& arg_values, expr* value); + bool check_bv_eval(euf::enode* n); + bool check_bool_eval(euf::enode* n); + void encode_msb_tail(expr* x, expr_ref_vector& xs); + void encode_lsb_tail(expr* x, expr_ref_vector& xs); internalize_mode get_internalize_mode(expr* e); void set_delay_internalize(expr* e, internalize_mode mode); - void eval_args(euf::enode* n, vector& args); - + expr_ref eval_args(euf::enode* n, expr_ref_vector& eargs); + expr_ref eval_bv(euf::enode* n); + // solving theory_var find(theory_var v) const { return m_find.find(v); } void find_wpos(theory_var v); - void find_new_diseq_axioms(bit_atom& a, theory_var v, unsigned idx); + void find_new_diseq_axioms(atom& a, theory_var v, unsigned idx); void mk_new_diseq_axiom(theory_var v1, theory_var v2, unsigned idx); bool get_fixed_value(theory_var v, numeral& result) const; void add_fixed_eq(theory_var v1, theory_var v2); @@ -332,8 +315,7 @@ namespace bv { void asserted(literal l) override; sat::check_result check() override; void push_core() override; - void pop_core(unsigned n) override; - void pre_simplify() override; + void pop_core(unsigned n) override; void simplify() override; bool set_root(literal l, literal r) override; void flush_roots() override; @@ -343,7 +325,7 @@ namespace bv { std::ostream& display_justification(std::ostream& out, sat::ext_justification_idx idx) const override; std::ostream& display_constraint(std::ostream& out, sat::ext_constraint_idx idx) const override; void collect_statistics(statistics& st) const override; - euf::th_solver* fresh(sat::solver* s, euf::solver& ctx) override; + euf::th_solver* clone(sat::solver* s, euf::solver& ctx) override; extension* copy(sat::solver* s) override; void find_mutexes(literal_vector& lits, vector & mutexes) override {} void gc() override {} diff --git a/src/sat/smt/euf_internalize.cpp b/src/sat/smt/euf_internalize.cpp index 144494c33..263910669 100644 --- a/src/sat/smt/euf_internalize.cpp +++ b/src/sat/smt/euf_internalize.cpp @@ -132,6 +132,7 @@ namespace euf { SASSERT(m_egraph.find(e)->bool_var() == v); return lit; } + TRACE("euf", tout << "attach " << v << " " << mk_bounded_pp(e, m) << "\n";); m_var2expr[v] = e; m_var_trail.push_back(v); enode* n = m_egraph.find(e); @@ -299,20 +300,18 @@ namespace euf { if (m.is_ite(n->get_expr())) return true; - theory_id th_id = null_theory_id; - for (auto p : euf::enode_th_vars(n)) { - if (th_id == null_theory_id) - th_id = p.get_id(); - else - return true; - } - if (th_id == null_theory_id) - return false; - // the variable is shared if the equivalence class of n // contains a parent application. - for (euf::enode* parent : euf::enode_parents(n)) { + family_id th_id = m.get_basic_family_id(); + for (auto p : euf::enode_th_vars(n)) { + if (m.get_basic_family_id() != p.get_id()) { + th_id = p.get_id(); + break; + } + } + + for (enode* parent : euf::enode_parents(n)) { app* p = to_app(parent->get_expr()); family_id fid = p->get_family_id(); if (fid != th_id && fid != m.get_basic_family_id()) @@ -345,9 +344,13 @@ namespace euf { // the theories of (array int int) and (array (array int int) int). // Remark: The inconsistency is not going to be detected if they are // not marked as shared. - return true; - // TODO - // return get_theory(th_id)->is_shared(l->get_var()); + + for (auto p : euf::enode_th_vars(n)) + if (fid2solver(p.get_id())->is_shared(p.get_var())) + return true; + + return false; + } expr_ref solver::mk_eq(expr* e1, expr* e2) { diff --git a/src/sat/smt/euf_model.cpp b/src/sat/smt/euf_model.cpp index faab2145f..84aa98e3d 100644 --- a/src/sat/smt/euf_model.cpp +++ b/src/sat/smt/euf_model.cpp @@ -44,20 +44,55 @@ namespace euf { void solver::collect_dependencies(deps_t& deps) { for (enode* n : m_egraph.nodes()) { - if (n->num_args() == 0) { - deps.insert(n, nullptr); - continue; - } - auto* mb = expr2solver(n->get_expr()); + auto* mb = sort2solver(m.get_sort(n->get_expr())); if (mb) mb->add_dep(n, deps); else deps.insert(n, nullptr); } + + TRACE("euf", + for (auto const& d : deps.deps()) + if (d.m_value) { + tout << mk_bounded_pp(d.m_key->get_expr(), m) << ":\n"; + for (auto* n : *d.m_value) + tout << " " << mk_bounded_pp(n->get_expr(), m) << "\n"; + } + ); } + class solver::user_sort { + solver& s; + ast_manager& m; + model_ref& mdl; + expr_ref_vector& values; + user_sort_factory factory; + scoped_ptr_vector sort_values; + obj_map sort2values; + public: + user_sort(solver& s, expr_ref_vector& values, model_ref& mdl): + s(s), m(s.m), mdl(mdl), values(values), factory(m) {} + + ~user_sort() { + for (auto kv : sort2values) + mdl->register_usort(kv.m_key, kv.m_value->size(), kv.m_value->c_ptr()); + } + + void add(unsigned id, sort* srt) { + expr_ref value(factory.get_fresh_value(srt), m); + values.set(id, value); + expr_ref_vector* vals = nullptr; + if (!sort2values.find(srt, vals)) { + vals = alloc(expr_ref_vector, m); + sort2values.insert(srt, vals); + sort_values.push_back(vals); + } + vals->push_back(value); + } + }; + void solver::dependencies2values(deps_t& deps, expr_ref_vector& values, model_ref& mdl) { - user_sort_factory user_sort(m); + user_sort user_sort(*this, values, mdl); for (enode* n : deps.top_sorted()) { unsigned id = n->get_root_id(); if (values.get(id, nullptr)) @@ -94,16 +129,15 @@ namespace euf { } continue; } - auto* mb = expr2solver(e); - if (mb) - mb->add_value(n, *mdl, values); - else if (m.is_uninterp(m.get_sort(e))) { - expr* v = user_sort.get_fresh_value(m.get_sort(e)); - values.set(id, v); - } - else if ((mb = sort2solver(m.get_sort(e)))) - mb->add_value(n, *mdl, values); - else { + TRACE("euf", tout << "value for " << mk_bounded_pp(e, m) << "\n";); + sort* srt = m.get_sort(e); + if (m.is_uninterp(srt)) + user_sort.add(id, srt); + else if (auto* mbS = sort2solver(srt)) + mbS->add_value(n, *mdl, values); + else if (auto* mbE = expr2solver(e)) + mbE->add_value(n, *mdl, values); + else { IF_VERBOSE(1, verbose_stream() << "no model values created for " << mk_pp(e, m) << "\n"); } } diff --git a/src/sat/smt/euf_relevancy.cpp b/src/sat/smt/euf_relevancy.cpp index 65726a150..5f0d9010c 100644 --- a/src/sat/smt/euf_relevancy.cpp +++ b/src/sat/smt/euf_relevancy.cpp @@ -28,9 +28,8 @@ namespace euf { } void solver::add_root(unsigned n, sat::literal const* lits) { - bool_var v = s().add_var(false); ensure_dual_solver(); - m_dual_solver->add_root(sat::literal(v, false), n, lits); + m_dual_solver->add_root(n, lits); } void solver::add_aux(unsigned n, sat::literal const* lits) { @@ -70,9 +69,26 @@ namespace euf { m_relevant_expr_ids.setx(e->get_id(), true, false); if (!is_app(e)) continue; + expr* c = nullptr, *th = nullptr, *el = nullptr; + if (m.is_ite(e, c, th, el)) { + sat::literal lit = expr2literal(c); + todo.push_back(c); + switch (s().value(lit)) { + case l_true: + todo.push_back(th); + break; + case l_false: + todo.push_back(el); + break; + default: + todo.push_back(th); + todo.push_back(el); + break; + } + continue; + } for (expr* arg : *to_app(e)) - if (!visited.get(arg->get_id(), false)) - todo.push_back(arg); + todo.push_back(arg); } TRACE("euf", diff --git a/src/sat/smt/euf_solver.cpp b/src/sat/smt/euf_solver.cpp index d7c08a14d..0f6d69a8c 100644 --- a/src/sat/smt/euf_solver.cpp +++ b/src/sat/smt/euf_solver.cpp @@ -22,6 +22,7 @@ Author: #include "sat/smt/ba_solver.h" #include "sat/smt/bv_solver.h" #include "sat/smt/euf_solver.h" +#include "sat/smt/array_solver.h" namespace euf { @@ -79,6 +80,7 @@ namespace euf { return nullptr; pb_util pb(m); bv_util bvu(m); + array_util au(m); if (pb.get_family_id() == fid) { ext = alloc(sat::ba_solver, *this, fid); if (use_drat()) @@ -89,6 +91,11 @@ namespace euf { if (use_drat()) s().get_drat().add_theory(fid, symbol("bv")); } + else if (au.get_family_id() == fid) { + ext = alloc(array::solver, *this, fid); + if (use_drat()) + s().get_drat().add_theory(fid, symbol("array")); + } if (ext) { ext->set_solver(m_solver); ext->push_scopes(s().num_scopes()); @@ -210,10 +217,11 @@ namespace euf { void solver::asserted(literal l) { expr* e = m_var2expr.get(l.var(), nullptr); - if (!e) + if (!e) { + TRACE("euf", tout << "asserted: " << l << "@" << s().scope_lvl() << "\n";); return; - - TRACE("euf", tout << "asserted: " << mk_bounded_pp(e, m) << " := " << l << "@" << s().scope_lvl() << "\n";); + } + TRACE("euf", tout << "asserted: " << l << "@" << s().scope_lvl() << " := " << mk_bounded_pp(e, m) << "\n";); euf::enode* n = m_egraph.find(e); if (!n) return; @@ -370,18 +378,30 @@ namespace euf { m_egraph.push(); } + void solver::user_push() { + push(); + if (m_dual_solver) + m_dual_solver->push(); + } + + void solver::user_pop(unsigned n) { + pop(n); + if (m_dual_solver) + m_dual_solver->pop(n); + } + void solver::pop(unsigned n) { start_reinit(n); - m_egraph.pop(n); + m_trail.pop_scope(n); for (auto* e : m_solvers) e->pop(n); + si.pop(n); + m_egraph.pop(n); scope const & s = m_scopes[m_scopes.size() - n]; for (unsigned i = m_var_trail.size(); i-- > s.m_var_lim; ) m_var2expr[m_var_trail[i]] = nullptr; m_var_trail.shrink(s.m_var_lim); - m_trail.pop_scope(n); m_scopes.shrink(m_scopes.size() - n); - si.pop(n); SASSERT(m_egraph.num_scopes() == m_scopes.size()); TRACE("euf", tout << "pop to: " << m_scopes.size() << "\n";); } @@ -424,8 +444,9 @@ namespace euf { if (replay.m.empty()) return; - TRACE("euf", for (auto const& kv : replay.m) tout << "replay: " << kv.m_value << " " << mk_bounded_pp(kv.m_key, m) << "\n";); + TRACE("euf", for (auto const& kv : replay.m) tout << kv.m_value << "\n";); for (auto const& kv : replay.m) { + TRACE("euf", tout << "replay: " << kv.m_value << " " << mk_bounded_pp(kv.m_key, m) << "\n";); sat::literal lit; expr* e = kv.m_key; if (si.is_bool_op(e)) @@ -557,7 +578,7 @@ namespace euf { for (unsigned i = 0; i < m_id2solver.size(); ++i) { auto* e = m_id2solver[i]; if (e) - r->add_solver(i, e->fresh(s, *r)); + r->add_solver(i, e->clone(s, *r)); } return r; } diff --git a/src/sat/smt/euf_solver.h b/src/sat/smt/euf_solver.h index ace5b383e..5f33f743c 100644 --- a/src/sat/smt/euf_solver.h +++ b/src/sat/smt/euf_solver.h @@ -54,6 +54,7 @@ namespace euf { class solver : public sat::extension, public th_internalizer, public th_decompile { typedef top_sort deps_t; friend class ackerman; + class user_sort; // friend class sat::ba_solver; struct stats { unsigned m_ackerman; @@ -129,7 +130,7 @@ namespace euf { th_solver* func_decl2solver(func_decl* f) { return get_solver(f->get_family_id(), f); } th_solver* expr2solver(expr* e); th_solver* bool_var2solver(sat::bool_var v); - th_solver* fid2solver(family_id fid) { return m_id2solver.get(fid, nullptr); } + th_solver* fid2solver(family_id fid) const { return m_id2solver.get(fid, nullptr); } void add_solver(family_id fid, th_solver* th); void init_ackerman(); @@ -234,6 +235,8 @@ namespace euf { sat::check_result check() override; void push() override; void pop(unsigned n) override; + void user_push(); + void user_pop(unsigned n); void pre_simplify() override; void simplify() override; // have a way to replace l by r in all constraints diff --git a/src/sat/smt/sat_dual_solver.cpp b/src/sat/smt/sat_dual_solver.cpp index 680cfd865..deb1fe4db 100644 --- a/src/sat/smt/sat_dual_solver.cpp +++ b/src/sat/smt/sat_dual_solver.cpp @@ -65,10 +65,11 @@ namespace sat { return literal(m_var2ext[lit.var()], lit.sign()); } - void dual_solver::add_root(literal lit, unsigned sz, literal const* clause) { + void dual_solver::add_root(unsigned sz, literal const* clause) { + literal root(m_solver.mk_var(), false); for (unsigned i = 0; i < sz; ++i) - m_solver.mk_clause(ext2lit(lit), ~ext2lit(clause[i]), status::input()); - m_roots.push_back(~ext2lit(lit)); + m_solver.mk_clause(root, ~ext2lit(clause[i]), status::input()); + m_roots.push_back(~root); } void dual_solver::add_aux(unsigned sz, literal const* clause) { @@ -89,7 +90,7 @@ namespace sat { if (is_sat == l_false) for (literal lit : m_solver.get_core()) m_core.push_back(lit2ext(lit)); - TRACE("euf", m_solver.display(tout << m_core << "\n");); + TRACE("euf", m_solver.display(tout << "is-sat: " << is_sat << " core: " << m_core << "\n");); m_solver.user_pop(1); return is_sat == l_false; } diff --git a/src/sat/smt/sat_dual_solver.h b/src/sat/smt/sat_dual_solver.h index 374ecbf19..76197c234 100644 --- a/src/sat/smt/sat_dual_solver.h +++ b/src/sat/smt/sat_dual_solver.h @@ -47,7 +47,7 @@ namespace sat { * Add a root clause from the input problem. * At least one literal has to be satisfied in every root. */ - void add_root(literal lit, unsigned sz, literal const* clause); + void add_root(unsigned sz, literal const* clause); /* * Add auxiliary clauses that originate from compiling definitions. diff --git a/src/sat/smt/sat_th.cpp b/src/sat/smt/sat_th.cpp index ddb65a5b4..61fa091ff 100644 --- a/src/sat/smt/sat_th.cpp +++ b/src/sat/smt/sat_th.cpp @@ -116,23 +116,43 @@ namespace euf { pop_core(n); } + sat::status th_euf_solver::mk_status() { + return sat::status::th(m_is_redundant, get_id()); + } + bool th_euf_solver::add_unit(sat::literal lit) { - return !is_true(lit) && (ctx.s().add_clause(1, &lit, sat::status::th(m_is_redundant, get_id())), true); + bool was_true = is_true(lit); + ctx.s().add_clause(1, &lit, mk_status()); + return !was_true; } bool th_euf_solver::add_clause(sat::literal a, sat::literal b) { + bool was_true = is_true(a, b); sat::literal lits[2] = { a, b }; - return !is_true(a, b) && (ctx.s().add_clause(2, lits, sat::status::th(m_is_redundant, get_id())), true); + ctx.s().add_clause(2, lits, mk_status()); + return !was_true; } bool th_euf_solver::add_clause(sat::literal a, sat::literal b, sat::literal c) { + bool was_true = is_true(a, b, c); sat::literal lits[3] = { a, b, c }; - return !is_true(a, b, c) && (ctx.s().add_clause(3, lits, sat::status::th(m_is_redundant, get_id())), true); + ctx.s().add_clause(3, lits, mk_status()); + return !was_true; } bool th_euf_solver::add_clause(sat::literal a, sat::literal b, sat::literal c, sat::literal d) { + bool was_true = is_true(a, b, c, d); sat::literal lits[4] = { a, b, c, d }; - return !is_true(a, b, c, d) && (ctx.s().add_clause(4, lits, sat::status::th(m_is_redundant, get_id())), true); + ctx.s().add_clause(4, lits, mk_status()); + return !was_true; + } + + bool th_euf_solver::add_clause(sat::literal_vector const& lits) { + bool was_true = false; + for (auto lit : lits) + was_true |= is_true(lit); + s().add_clause(lits.size(), lits.c_ptr(), mk_status()); + return !was_true; } bool th_euf_solver::is_true(sat::literal lit) { diff --git a/src/sat/smt/sat_th.h b/src/sat/smt/sat_th.h index 8f2bad6fd..b384700d2 100644 --- a/src/sat/smt/sat_th.h +++ b/src/sat/smt/sat_th.h @@ -94,7 +94,7 @@ namespace euf { public: th_solver(ast_manager& m, euf::theory_id id): extension(id), m(m) {} - virtual th_solver* fresh(sat::solver* s, euf::solver& ctx) = 0; + virtual th_solver* clone(sat::solver* s, euf::solver& ctx) = 0; virtual void new_eq_eh(euf::th_eq const& eq) {} @@ -121,11 +121,13 @@ namespace euf { region& get_region(); + sat::status mk_status(); bool add_unit(sat::literal lit); bool add_clause(sat::literal lit) { return add_unit(lit); } bool add_clause(sat::literal a, sat::literal b); bool add_clause(sat::literal a, sat::literal b, sat::literal c); bool add_clause(sat::literal a, sat::literal b, sat::literal c, sat::literal d); + bool add_clause(sat::literal_vector const& lits); bool is_true(sat::literal lit); bool is_true(sat::literal a, sat::literal b) { return is_true(a) || is_true(b); } diff --git a/src/sat/smt/user_solver.cpp b/src/sat/smt/user_solver.cpp index e12e57a6d..f7fdb9364 100644 --- a/src/sat/smt/user_solver.cpp +++ b/src/sat/smt/user_solver.cpp @@ -146,7 +146,7 @@ namespace user { return display_justification(out, idx); } - euf::th_solver* solver::fresh(sat::solver* dst_s, euf::solver& dst_ctx) { + euf::th_solver* solver::clone(sat::solver* dst_s, euf::solver& dst_ctx) { auto* result = alloc(solver, dst_ctx); result->set_solver(dst_s); ast_translation tr(m, dst_ctx.get_manager(), false); diff --git a/src/sat/smt/user_solver.h b/src/sat/smt/user_solver.h index c7df05ad0..2d8b530e1 100644 --- a/src/sat/smt/user_solver.h +++ b/src/sat/smt/user_solver.h @@ -124,7 +124,7 @@ namespace user { std::ostream& display(std::ostream& out) const override; std::ostream& display_justification(std::ostream& out, sat::ext_justification_idx idx) const override; std::ostream& display_constraint(std::ostream& out, sat::ext_constraint_idx idx) const override; - euf::th_solver* fresh(sat::solver* s, euf::solver& ctx) override; + euf::th_solver* clone(sat::solver* s, euf::solver& ctx) override; }; }; diff --git a/src/sat/tactic/goal2sat.cpp b/src/sat/tactic/goal2sat.cpp index 174dc06fc..ad0f3adad 100644 --- a/src/sat/tactic/goal2sat.cpp +++ b/src/sat/tactic/goal2sat.cpp @@ -791,7 +791,7 @@ struct goal2sat::imp : public sat::sat_internalizer { sat::literal result = m_result_stack.back(); m_result_stack.pop_back(); if (!result.sign() && m_map.to_bool_var(n) == sat::null_bool_var) - m_map.insert(n, result.var()); + m_map.insert(n, result.var()); return result; } diff --git a/src/smt/smt_context.cpp b/src/smt/smt_context.cpp index cfe9f01d3..abfa8bcd6 100644 --- a/src/smt/smt_context.cpp +++ b/src/smt/smt_context.cpp @@ -4618,47 +4618,6 @@ namespace smt { TRACE("model", tout << *m_model << "\n";); } - expr_ref context::get_implied_value(expr* e) { - pop_to_search_lvl(); - if (m.is_bool(e)) { - if (b_internalized(e)) { - switch (get_assignment(get_bool_var(e))) { - case l_true: e = m.mk_true(); break; - case l_false: e = m.mk_false(); break; - default: break; - } - } - return expr_ref(e, m); - } - - if (e_internalized(e)) { - enode* n = get_enode(e); - for (enode* r : *n) { - if (m.is_value(r->get_owner())) { - return expr_ref(r->get_owner(), m); - } - } - } - - arith_value av(m); - av.init(this); - return av.get_fixed(e); - } - - expr_ref context::get_implied_lower_bound(expr* e) { - pop_to_search_lvl(); - arith_value av(m); - av.init(this); - return av.get_lo(e); - } - - expr_ref context::get_implied_upper_bound(expr* e) { - pop_to_search_lvl(); - arith_value av(m); - av.init(this); - return av.get_up(e); - } - }; diff --git a/src/smt/smt_context.h b/src/smt/smt_context.h index c499ae7ee..1f1bfb381 100644 --- a/src/smt/smt_context.h +++ b/src/smt/smt_context.h @@ -585,13 +585,6 @@ namespace smt { return get_bdata(v).get_theory(); } - expr_ref get_implied_value(expr* e); - - expr_ref get_implied_lower_bound(expr* e); - - expr_ref get_implied_upper_bound(expr* e); - - friend class set_var_theory_trail; void set_var_theory(bool_var v, theory_id tid); diff --git a/src/smt/smt_context_pp.cpp b/src/smt/smt_context_pp.cpp index 0d84ede00..492346c90 100644 --- a/src/smt/smt_context_pp.cpp +++ b/src/smt/smt_context_pp.cpp @@ -179,8 +179,12 @@ namespace smt { std::ostream& context::display_clauses(std::ostream & out, ptr_vector const & v) const { for (clause* cp : v) { out << "("; - for (auto lit : *cp) - out << lit << " "; + bool first = true; + for (auto lit : *cp) { + if (!first) out << " "; + first = false; + out << lit; + } out << ")\n"; } return out; @@ -385,20 +389,7 @@ namespace smt { st.update("max generation", m_stats.m_max_generation); st.update("minimized lits", m_stats.m_num_minimized_lits); st.update("num checks", m_stats.m_num_checks); - st.update("mk bool var", m_stats.m_num_mk_bool_var); - -#if 0 - // missing? - st.update("mk lit", m_stats.m_num_mk_lits); - st.update("sat conflicts", m_stats.m_num_sat_conflicts); - st.update("del bool var", m_stats.m_num_del_bool_var); - st.update("mk enode", m_stats.m_num_mk_enode); - st.update("del enode", m_stats.m_num_del_enode); - st.update("mk bin clause", m_stats.m_num_mk_bin_clause); - st.update("backwd subs", m_stats.m_num_bs); - st.update("backwd subs res", m_stats.m_num_bsr); - st.update("frwrd subs res", m_stats.m_num_fsr); -#endif + st.update("mk bool var", m_stats.m_num_mk_bool_var ? m_stats.m_num_mk_bool_var - 1 : 0); m_qmanager->collect_statistics(st); m_asserted_formulas.collect_statistics(st); for (theory* th : m_theory_set) { diff --git a/src/smt/smt_kernel.cpp b/src/smt/smt_kernel.cpp index 4d4cf8000..ecd443a55 100644 --- a/src/smt/smt_kernel.cpp +++ b/src/smt/smt_kernel.cpp @@ -131,18 +131,6 @@ namespace smt { lbool find_mutexes(expr_ref_vector const& vars, vector& mutexes) { return m_kernel.find_mutexes(vars, mutexes); } - - expr_ref get_implied_value(expr* e) { - return m_kernel.get_implied_value(e); - } - - expr_ref get_implied_lower_bound(expr* e) { - return m_kernel.get_implied_lower_bound(e); - } - - expr_ref get_implied_upper_bound(expr* e) { - return m_kernel.get_implied_upper_bound(e); - } void get_model(model_ref & m) { m_kernel.get_model(m); @@ -461,18 +449,6 @@ namespace smt { return m_imp->get_trail(); } - expr_ref kernel::get_implied_value(expr* e) { - return m_imp->get_implied_value(e); - } - - expr_ref kernel::get_implied_lower_bound(expr* e) { - return m_imp->get_implied_lower_bound(e); - } - - expr_ref kernel::get_implied_upper_bound(expr* e) { - return m_imp->get_implied_upper_bound(e); - } - void kernel::user_propagate_init( void* ctx, solver::push_eh_t& push_eh, diff --git a/src/smt/smt_kernel.h b/src/smt/smt_kernel.h index 2590b488e..cb64f7d13 100644 --- a/src/smt/smt_kernel.h +++ b/src/smt/smt_kernel.h @@ -224,16 +224,6 @@ namespace smt { */ expr_ref_vector cubes(unsigned depth); - /** - \brief retrieve upper/lower bound for arithmetic term, if it is implied. - retrieve implied values if terms are fixed to a value. - */ - - expr_ref get_implied_value(expr* e); - - expr_ref get_implied_lower_bound(expr* e); - - expr_ref get_implied_upper_bound(expr* e); /** \brief retrieve depth of variables from decision stack. diff --git a/src/smt/smt_solver.cpp b/src/smt/smt_solver.cpp index 208b51f0a..b78dd5b6c 100644 --- a/src/smt/smt_solver.cpp +++ b/src/smt/smt_solver.cpp @@ -390,18 +390,6 @@ namespace { } } - expr_ref get_implied_value(expr* e) override { - return m_context.get_implied_value(e); - } - - expr_ref get_implied_lower_bound(expr* e) override { - return m_context.get_implied_lower_bound(e); - } - - expr_ref get_implied_upper_bound(expr* e) override { - return m_context.get_implied_upper_bound(e); - } - bool fds_intersect(func_decl_set & pattern_fds, func_decl_set & assrtn_fds) { for (func_decl * fd : pattern_fds) { if (assrtn_fds.contains(fd)) diff --git a/src/smt/theory_fpa.cpp b/src/smt/theory_fpa.cpp index 11c92b265..aa458b672 100644 --- a/src/smt/theory_fpa.cpp +++ b/src/smt/theory_fpa.cpp @@ -809,10 +809,8 @@ namespace smt { bv2fp.convert_min_max_specials(&mdl, &new_model, seen); bv2fp.convert_uf2bvuf(&mdl, &new_model, seen); - for (obj_hashtable::iterator it = seen.begin(); - it != seen.end(); - it++) - mdl.unregister_decl(*it); + for (func_decl* f : seen) + mdl.unregister_decl(f); for (unsigned i = 0; i < new_model.get_num_constants(); i++) { func_decl * f = new_model.get_constant(i); diff --git a/src/solver/combined_solver.cpp b/src/solver/combined_solver.cpp index 19b1c9351..267a22de0 100644 --- a/src/solver/combined_solver.cpp +++ b/src/solver/combined_solver.cpp @@ -184,28 +184,6 @@ public: return m_solver1->get_scope_level(); } - expr_ref get_implied_value(expr* e) override { - if (m_use_solver1_results) - return m_solver1->get_implied_value(e); - else - return m_solver2->get_implied_value(e); - } - - expr_ref get_implied_lower_bound(expr* e) override { - if (m_use_solver1_results) - return m_solver1->get_implied_lower_bound(e); - else - return m_solver2->get_implied_lower_bound(e); - } - - expr_ref get_implied_upper_bound(expr* e) override { - if (m_use_solver1_results) - return m_solver1->get_implied_upper_bound(e); - else - return m_solver2->get_implied_upper_bound(e); - } - - lbool get_consequences(expr_ref_vector const& asms, expr_ref_vector const& vars, expr_ref_vector& consequences) override { switch_inc_mode(); m_use_solver1_results = false; diff --git a/src/solver/solver.h b/src/solver/solver.h index 513572ab2..9d7ecd690 100644 --- a/src/solver/solver.h +++ b/src/solver/solver.h @@ -227,17 +227,6 @@ public: virtual expr_ref_vector cube(expr_ref_vector& vars, unsigned backtrack_level) = 0; - /** - \brief retrieve fixed value assignment in current solver state, if it is implied. - */ - virtual expr_ref get_implied_value(expr* e) = 0; - - /** - \brief retrieve upper/lower bound for arithmetic term, if it is implied. - */ - virtual expr_ref get_implied_lower_bound(expr* e) = 0; - - virtual expr_ref get_implied_upper_bound(expr* e) = 0; class propagate_callback { public: diff --git a/src/solver/solver_pool.cpp b/src/solver/solver_pool.cpp index 4cd6a5258..16764ba56 100644 --- a/src/solver/solver_pool.cpp +++ b/src/solver/solver_pool.cpp @@ -127,18 +127,6 @@ public: return m_base->get_trail(); } - expr_ref get_implied_value(expr* e) override { - return expr_ref(e, m); - } - - expr_ref get_implied_lower_bound(expr* e) override { - return expr_ref(e, m); - } - - expr_ref get_implied_upper_bound(expr* e) override { - return expr_ref(e, m); - } - lbool check_sat_core2(unsigned num_assumptions, expr * const * assumptions) override { SASSERT(!m_pushed || get_scope_level() > 0); m_proof.reset(); diff --git a/src/solver/tactic2solver.cpp b/src/solver/tactic2solver.cpp index c4bd6b419..f68131bab 100644 --- a/src/solver/tactic2solver.cpp +++ b/src/solver/tactic2solver.cpp @@ -96,19 +96,6 @@ public: expr_ref_vector get_trail() override { throw default_exception("cannot retrieve trail from solvers created using tactics"); } - - expr_ref get_implied_value(expr* e) override { - return expr_ref(e, m); - } - - expr_ref get_implied_lower_bound(expr* e) override { - return expr_ref(e, m); - } - - expr_ref get_implied_upper_bound(expr* e) override { - return expr_ref(e, m); - } - }; ast_manager& tactic2solver::get_manager() const { return m_assertions.get_manager(); } diff --git a/src/tactic/fd_solver/bounded_int2bv_solver.cpp b/src/tactic/fd_solver/bounded_int2bv_solver.cpp index f872d1f5c..6c749f694 100644 --- a/src/tactic/fd_solver/bounded_int2bv_solver.cpp +++ b/src/tactic/fd_solver/bounded_int2bv_solver.cpp @@ -360,19 +360,6 @@ private: return m_assertions.get(idx); } } - - expr_ref get_implied_value(expr* e) override { - return expr_ref(e, m); - } - - expr_ref get_implied_lower_bound(expr* e) override { - return expr_ref(e, m); - } - - expr_ref get_implied_upper_bound(expr* e) override { - return expr_ref(e, m); - } - }; solver * mk_bounded_int2bv_solver(ast_manager & m, params_ref const & p, solver* s) { diff --git a/src/tactic/fd_solver/enum2bv_solver.cpp b/src/tactic/fd_solver/enum2bv_solver.cpp index aa7888b5c..aae6ea820 100644 --- a/src/tactic/fd_solver/enum2bv_solver.cpp +++ b/src/tactic/fd_solver/enum2bv_solver.cpp @@ -197,18 +197,6 @@ public: return m_solver->get_assertion(idx); } - expr_ref get_implied_value(expr* e) override { - return expr_ref(e, m); - } - - expr_ref get_implied_lower_bound(expr* e) override { - return expr_ref(e, m); - } - - expr_ref get_implied_upper_bound(expr* e) override { - return expr_ref(e, m); - } - }; solver * mk_enum2bv_solver(ast_manager & m, params_ref const & p, solver* s) { diff --git a/src/tactic/fd_solver/pb2bv_solver.cpp b/src/tactic/fd_solver/pb2bv_solver.cpp index 20e0bc545..9e17dd711 100644 --- a/src/tactic/fd_solver/pb2bv_solver.cpp +++ b/src/tactic/fd_solver/pb2bv_solver.cpp @@ -147,18 +147,6 @@ public: return m_solver->get_assertion(idx); } - expr_ref get_implied_value(expr* e) override { - return expr_ref(e, m); - } - - expr_ref get_implied_lower_bound(expr* e) override { - return expr_ref(e, m); - } - - expr_ref get_implied_upper_bound(expr* e) override { - return expr_ref(e, m); - } - private: void flush_assertions() const { diff --git a/src/tactic/fd_solver/smtfd_solver.cpp b/src/tactic/fd_solver/smtfd_solver.cpp index 84c10c28b..89ec15b05 100644 --- a/src/tactic/fd_solver/smtfd_solver.cpp +++ b/src/tactic/fd_solver/smtfd_solver.cpp @@ -2104,18 +2104,6 @@ namespace smtfd { return m_assertions.get(idx); } - expr_ref get_implied_value(expr* e) override { - return expr_ref(e, m); - } - - expr_ref get_implied_lower_bound(expr* e) override { - return expr_ref(e, m); - } - - expr_ref get_implied_upper_bound(expr* e) override { - return expr_ref(e, m); - } - }; } diff --git a/src/util/statistics.cpp b/src/util/statistics.cpp index 148dec78b..af4ba5d61 100644 --- a/src/util/statistics.cpp +++ b/src/util/statistics.cpp @@ -45,24 +45,19 @@ void statistics::reset() { template static void mk_map(V const & v, M & m) { - typename V::const_iterator it = v.begin(); - typename V::const_iterator end = v.end(); - for (; it != end; ++it) { + for (auto const& p : v) { typename V::data::second_type val; - if (m.find(it->first, val)) - m.insert(it->first, it->second + val); + if (m.find(p.first, val)) + m.insert(p.first, p.second + val); else - m.insert(it->first, it->second); + m.insert(p.first, p.second); } } template static void get_keys(M const & m, ptr_buffer & keys) { - typename M::iterator it = m.begin(); - typename M::iterator end = m.end(); - for (; it != end; ++it) { - keys.push_back(const_cast(it->m_key)); - } + for (auto const& kv : m) + keys.push_back(const_cast(kv.m_key)); } static void display_smt2_key(std::ostream & out, char const * key) { @@ -175,10 +170,8 @@ std::ostream& statistics::display(std::ostream & out) const { template static void display_internal(std::ostream & out, M const & m) { - typename M::iterator it = m.begin(); - typename M::iterator end = m.end(); - for (; it != end; it++) { - char const * key = it->m_key; + for (auto const& kv : m) { + char const * key = kv.m_key; if (*key == ':') key++; while (*key) { if ('a' <= *key && *key <= 'z') @@ -188,7 +181,7 @@ static void display_internal(std::ostream & out, M const & m) { else out << *key; } - out << " " << it->m_value << "\n"; + out << " " << kv.m_value << "\n"; } }