From b72575148ff59c10d9a584b7558f84754bdf61fe Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Mon, 11 Dec 2023 15:45:54 -0800 Subject: [PATCH] axioms for b-and Signed-off-by: Nikolaj Bjorner --- src/ast/arith_decl_plugin.h | 8 ++++ src/sat/smt/arith_axioms.cpp | 65 +++++++++++++++++++++++++++++++ src/sat/smt/arith_internalize.cpp | 4 +- src/sat/smt/arith_solver.cpp | 2 + src/sat/smt/arith_solver.h | 4 ++ src/sat/smt/intblast_solver.cpp | 27 +++++++++++-- src/sat/smt/intblast_solver.h | 4 +- 7 files changed, 106 insertions(+), 8 deletions(-) diff --git a/src/ast/arith_decl_plugin.h b/src/ast/arith_decl_plugin.h index a5ab60731..b073e205e 100644 --- a/src/ast/arith_decl_plugin.h +++ b/src/ast/arith_decl_plugin.h @@ -312,6 +312,14 @@ public: bool is_int_real(expr const * n) const { return is_int_real(n->get_sort()); } bool is_band(expr const* n) const { return is_app_of(n, arith_family_id, OP_ARITH_BAND); } + bool is_band(expr const* n, unsigned& sz, expr*& x, expr*& y) { + if (!is_band(n)) + return false; + x = to_app(n)->get_arg(0); + y = to_app(n)->get_arg(1); + sz = to_app(n)->get_parameter(0).get_int(); + return true; + } bool is_sin(expr const* n) const { return is_app_of(n, arith_family_id, OP_SIN); } bool is_cos(expr const* n) const { return is_app_of(n, arith_family_id, OP_COS); } diff --git a/src/sat/smt/arith_axioms.cpp b/src/sat/smt/arith_axioms.cpp index 173ae28c8..b8bffa5f2 100644 --- a/src/sat/smt/arith_axioms.cpp +++ b/src/sat/smt/arith_axioms.cpp @@ -205,6 +205,71 @@ namespace arith { add_clause(dgez, neg); } + bool solver::check_band_term(app* n) { + unsigned sz; + expr* x, * y; + VERIFY(a.is_band(n, sz, x, y)); + if (use_nra_model()) { + found_unsupported(n); + return true; + } + theory_var vx = expr2enode(x)->get_th_var(get_id()); + theory_var vy = expr2enode(y)->get_th_var(get_id()); + theory_var xn = expr2enode(n)->get_th_var(get_id()); + rational valx = get_value(vx); + rational valy = get_value(vy); + rational valn = get_value(xn); + + // x mod 2^{i + 1} >= 2^i means the i'th bit is 1. + auto bitof = [&](expr* x, unsigned i) { + expr_ref r(m); + r = a.mk_ge(a.mk_mod(x, a.mk_int(rational::power_of_two(i+1))), a.mk_int(rational::power_of_two(i))); + return mk_literal(r); + }; + for (unsigned i = 0; i < sz; ++i) { + bool xb = valx.get_bit(i); + bool yb = valy.get_bit(i); + bool nb = valn.get_bit(i); + if (xb && yb && !nb) { + add_clause(~bitof(x, i), ~bitof(y, i), bitof(n, i)); + return false; + } + if (nb && !xb) { + add_clause(~bitof(n, i), bitof(x, i)); + return false; + } + if (nb && !yb) { + add_clause(~bitof(n, i), bitof(y, i)); + return false; + } + } + return true; + } + + bool solver::check_band_terms() { + for (app* n : m_band_terms) { + if (!check_band_term(n)) + return false; + } + return true; + } + + /* + * 0 <= x&y < 2^sz + * x&y <= x + * x&y <= y + */ + void solver::mk_band_axiom(app* n) { + unsigned sz; + expr* x, * y; + VERIFY(a.is_band(n, sz, x, y)); + rational N = rational::power_of_two(sz); + add_clause(mk_literal(a.mk_ge(n, a.mk_int(0)))); + add_clause(mk_literal(a.mk_le(n, a.mk_int(N - 1)))); + add_clause(mk_literal(a.mk_le(n, a.mk_mod(x, a.mk_int(N))))); + add_clause(mk_literal(a.mk_le(n, a.mk_mod(y, a.mk_int(N))))); + } + void solver::mk_bound_axioms(api_bound& b) { theory_var v = b.get_var(); lp_api::bound_kind kind1 = b.get_bound_kind(); diff --git a/src/sat/smt/arith_internalize.cpp b/src/sat/smt/arith_internalize.cpp index 66b25cb34..4d0943d65 100644 --- a/src/sat/smt/arith_internalize.cpp +++ b/src/sat/smt/arith_internalize.cpp @@ -254,7 +254,9 @@ namespace arith { } else if (a.is_band(n)) { // unsupported for now. - found_unsupported(n); + m_band_terms.push_back(to_app(n)); + mk_band_axiom(to_app(n)); + ctx.push(push_back_vector(m_band_terms)); ensure_arg_vars(to_app(n)); } else if (!a.is_div0(n) && !a.is_mod0(n) && !a.is_idiv0(n) && !a.is_rem0(n) && !a.is_power0(n)) { diff --git a/src/sat/smt/arith_solver.cpp b/src/sat/smt/arith_solver.cpp index 2be9b6b60..9e03bbee4 100644 --- a/src/sat/smt/arith_solver.cpp +++ b/src/sat/smt/arith_solver.cpp @@ -1197,6 +1197,8 @@ namespace arith { default: UNREACHABLE(); } + if (lia_check == l_true && !check_band_terms()) + lia_check = l_false; return lia_check; } diff --git a/src/sat/smt/arith_solver.h b/src/sat/smt/arith_solver.h index 20ae599c2..50cdc63ef 100644 --- a/src/sat/smt/arith_solver.h +++ b/src/sat/smt/arith_solver.h @@ -214,6 +214,7 @@ namespace arith { expr* m_not_handled = nullptr; ptr_vector m_underspecified; ptr_vector m_idiv_terms; + ptr_vector m_band_terms; vector > m_use_list; // bounds where variables are used. // attributes for incremental version: @@ -317,6 +318,7 @@ namespace arith { void mk_bound_axioms(api_bound& b); void mk_bound_axiom(api_bound& b1, api_bound& b2); void mk_power0_axioms(app* t, app* n); + void mk_band_axiom(app* n); void flush_bound_axioms(); void add_farkas_clause(sat::literal l1, sat::literal l2); @@ -408,6 +410,8 @@ namespace arith { bool check_delayed_eqs(); lbool check_lia(); lbool check_nla(); + bool check_band_terms(); + bool check_band_term(app* n); void add_lemmas(); void propagate_nla(); void add_equality(lpvar v, rational const& k, lp::explanation const& exp); diff --git a/src/sat/smt/intblast_solver.cpp b/src/sat/smt/intblast_solver.cpp index 8a3738cec..65dc56e00 100644 --- a/src/sat/smt/intblast_solver.cpp +++ b/src/sat/smt/intblast_solver.cpp @@ -13,6 +13,7 @@ Author: #include "ast/ast_util.h" #include "ast/for_each_expr.h" +#include "params/bv_rewriter_params.hpp" #include "sat/smt/intblast_solver.h" #include "sat/smt/euf_solver.h" @@ -25,7 +26,8 @@ namespace intblast { m(ctx.get_manager()), bv(m), a(m), - m_trail(m) + m_trail(m), + m_pinned(m) {} lbool solver::check() { @@ -82,7 +84,6 @@ namespace intblast { m_core.reset(); m_vars.reset(); m_trail.reset(); - m_new_funs.reset(); m_solver = mk_smt2_solver(m, s.params(), symbol::null); expr_ref_vector es(m); @@ -284,6 +285,8 @@ namespace intblast { if (!m_new_funs.find(f, g)) { g = m.mk_fresh_func_decl(ap->get_decl()->get_name(), symbol("bv"), domain.size(), domain.data(), range); m_new_funs.insert(f, g); + m_pinned.push_back(f); + m_pinned.push_back(g); } f = g; } @@ -452,6 +455,24 @@ namespace intblast { m_trail.push_back(p); break; } + case OP_BUDIV: { + bv_rewriter_params p(ctx.s().params()); + expr* x = args.get(0), * y = args.get(1); + if (p.hi_div0()) + m_trail.push_back(m.mk_ite(m.mk_eq(y, a.mk_int(0)), a.mk_int(0), a.mk_idiv(x, y))); + else + m_trail.push_back(a.mk_idiv(x, y)); + break; + } + case OP_BUREM: { + bv_rewriter_params p(ctx.s().params()); + expr* x = args.get(0), * y = args.get(1); + if (p.hi_div0()) + m_trail.push_back(m.mk_ite(m.mk_eq(y, a.mk_int(0)), a.mk_int(0), a.mk_mod(x, y))); + else + m_trail.push_back(a.mk_mod(x, y)); + break; + } case OP_BCOMP: case OP_BASHR: case OP_ROTATE_LEFT: @@ -463,9 +484,7 @@ namespace intblast { case OP_SIGN_EXT: case OP_BREDOR: case OP_BREDAND: - case OP_BUDIV: case OP_BSDIV: - case OP_BUREM: case OP_BSREM: case OP_BSMOD: verbose_stream() << mk_pp(e, m) << "\n"; diff --git a/src/sat/smt/intblast_solver.h b/src/sat/smt/intblast_solver.h index c165e1562..a093713bb 100644 --- a/src/sat/smt/intblast_solver.h +++ b/src/sat/smt/intblast_solver.h @@ -47,13 +47,11 @@ namespace intblast { obj_map m_vars; obj_map m_new_funs; expr_ref_vector m_trail; + ast_ref_vector m_pinned; sat::literal_vector m_core; - - bool is_bv(sat::literal lit); void translate(expr_ref_vector& es); - void add_root_equations(expr_ref_vector& es, ptr_vector& sorted); void sorted_subterms(expr_ref_vector& es, ptr_vector& sorted); public: