diff --git a/src/sat/smt/arith_axioms.cpp b/src/sat/smt/arith_axioms.cpp index 046470000..09db74f75 100644 --- a/src/sat/smt/arith_axioms.cpp +++ b/src/sat/smt/arith_axioms.cpp @@ -208,6 +208,8 @@ namespace arith { bool solver::check_band_term(app* n) { unsigned sz; expr* x, * y; + if (!ctx.is_relevant(expr2enode(n))) + return true; VERIFY(a.is_band(n, sz, x, y)); if (use_nra_model()) { found_unsupported(n); @@ -217,6 +219,11 @@ namespace arith { theory_var vy = expr2enode(y)->get_th_var(get_id()); theory_var vn = expr2enode(n)->get_th_var(get_id()); rational N = rational::power_of_two(sz); + if (!get_value(vx).is_int() || !get_value(vy).is_int()) { + + s().display(verbose_stream()); + verbose_stream() << vx << " " << vy << " " << mk_pp(n, m) << "\n"; + } SASSERT(get_value(vx).is_int()); SASSERT(get_value(vy).is_int()); SASSERT(get_value(vn).is_int()); diff --git a/src/sat/smt/euf_solver.cpp b/src/sat/smt/euf_solver.cpp index b95e44d74..f1ccb6879 100644 --- a/src/sat/smt/euf_solver.cpp +++ b/src/sat/smt/euf_solver.cpp @@ -22,6 +22,7 @@ Author: #include "sat/smt/pb_solver.h" #include "sat/smt/bv_solver.h" #include "sat/smt/polysat_solver.h" +#include "sat/smt/intblast_solver.h" #include "sat/smt/euf_solver.h" #include "sat/smt/array_solver.h" #include "sat/smt/arith_solver.h" @@ -135,8 +136,16 @@ namespace euf { special_relations_util sp(m); if (pb.get_family_id() == fid) ext = alloc(pb::solver, *this, fid); - else if (bvu.get_family_id() == fid) - ext = alloc(polysat::solver, *this, fid); + else if (bvu.get_family_id() == fid) { + if (get_config().m_bv_solver == 0) + ext = alloc(bv::solver, *this, fid); + else if (get_config().m_bv_solver == 1) + ext = alloc(polysat::solver, *this, fid); + else if (get_config().m_bv_solver == 2) + ext = alloc(intblast::solver, *this); + else + throw default_exception("unknown bit-vector solver. Accepted values 0 (bit blast), 1 (polysat), 2 (int blast)"); + } else if (au.get_family_id() == fid) ext = alloc(array::solver, *this, fid); else if (fpa.get_family_id() == fid) diff --git a/src/sat/smt/intblast_solver.cpp b/src/sat/smt/intblast_solver.cpp index e87944a91..32bf52f79 100644 --- a/src/sat/smt/intblast_solver.cpp +++ b/src/sat/smt/intblast_solver.cpp @@ -88,22 +88,25 @@ namespace intblast { expr* e = n->get_expr(); expr* x, * y; VERIFY(m.is_eq(n->get_expr(), x, y)); + SASSERT(bv.is_bv(x)); + ensure_translated(x); + ensure_translated(y); m_args.reset(); - m_args.push_back(translated(x)); - m_args.push_back(translated(y)); - add_equiv(expr2literal(e), eq_internalize(umod(x, 0), umod(x, 1))); + m_args.push_back(a.mk_sub(translated(x), translated(y))); + expr_ref lhs(umod(x, 0), m); + ctx.get_rewriter()(lhs); + add_equiv(expr2literal(e), eq_internalize(lhs, a.mk_int(0))); } void solver::internalize_bv(app* e) { - ensure_args(e); - m_args.reset(); - for (auto arg : *e) - m_args.push_back(translated(arg)); - translate_bv(e); + ensure_translated(e); // possibly wait until propagation? - if (m.is_bool(e)) - add_equiv(expr2literal(e), mk_literal(translated(e))); + if (m.is_bool(e)) { + expr_ref r(translated(e), m); + ctx.get_rewriter()(r); + add_equiv(expr2literal(e), mk_literal(r)); + } add_bound_axioms(); } @@ -120,32 +123,28 @@ namespace intblast { } } - void solver::ensure_args(app* e) { + void solver::ensure_translated(expr* e) { + if (m_translate.get(e->get_id(), nullptr)) + return; ptr_vector todo; ast_fast_mark1 visited; - for (auto arg : *e) { - if (!m_translate.get(arg->get_id(), nullptr)) - todo.push_back(arg); - } - if (todo.empty()) - return; + todo.push_back(e); + visited.mark(e); for (unsigned i = 0; i < todo.size(); ++i) { expr* e = todo[i]; - if (m.is_bool(e)) + if (!is_app(e)) continue; - else if (is_app(e)) { - for (auto arg : *to_app(e)) - if (!visited.is_marked(arg)) { - visited.mark(arg); - todo.push_back(arg); - } - } - else if (is_lambda(e)) - throw default_exception("lambdas are not supported in intblaster"); + app* a = to_app(e); + if (m.is_bool(e) && a->get_family_id() != bv.get_family_id()) + continue; + for (auto arg : *a) + if (!visited.is_marked(arg) && !m_translate.get(arg->get_id(), nullptr)) { + visited.mark(arg); + todo.push_back(arg); + } } - std::stable_sort(todo.begin(), todo.end(), [&](expr* a, expr* b) { return get_depth(a) < get_depth(b); }); - for (expr* e : todo) + for (expr* e : todo) translate_expr(e); } @@ -369,7 +368,7 @@ namespace intblast { } if (any_of(m_vars, [&](expr* v) { return translated(v) == x; })) return x; - return to_expr(a.mk_mod(x, a.mk_int(N))); + return a.mk_mod(x, a.mk_int(N)); } expr* solver::smod(expr* bv_expr, unsigned i) { @@ -393,6 +392,10 @@ namespace intblast { translate_var(to_var(e)); else { app* ap = to_app(e); + if (m_is_plugin && ap->get_family_id() == basic_family_id && m.is_bool(ap)) { + set_translated(e, e); + return; + } m_args.reset(); for (auto arg : *ap) m_args.push_back(translated(arg)); @@ -543,7 +546,6 @@ namespace intblast { r = a.mk_uminus(arg(0)); break; case OP_CONCAT: { - r = a.mk_int(0); unsigned sz = 0; for (unsigned i = 0; i < args.size(); ++i) { expr* old_arg = e->get_arg(i); @@ -595,7 +597,7 @@ namespace intblast { expr* x = arg(0), * y = arg(1); r = a.mk_int(0); for (unsigned i = 0; i < bv.get_bv_size(e); ++i) - r = m.mk_ite(a.mk_eq(y, a.mk_int(i)), a.mk_mul(x, a.mk_int(rational::power_of_two(i))), r); + r = m.mk_ite(m.mk_eq(y, a.mk_int(i)), a.mk_mul(x, a.mk_int(rational::power_of_two(i))), r); break; } case OP_BNOT: @@ -605,7 +607,7 @@ namespace intblast { expr* x = arg(0), * y = arg(1); r = a.mk_int(0); for (unsigned i = 0; i < bv.get_bv_size(e); ++i) - r = m.mk_ite(a.mk_eq(y, a.mk_int(i)), a.mk_idiv(x, a.mk_int(rational::power_of_two(i))), r); + r = m.mk_ite(m.mk_eq(y, a.mk_int(i)), a.mk_idiv(x, a.mk_int(rational::power_of_two(i))), r); break; } // Or use (p + q) - band(p, q)? @@ -649,7 +651,7 @@ namespace intblast { r = m.mk_ite(signbit, a.mk_int(N - 1), a.mk_int(0)); for (unsigned i = 0; i < bv.get_bv_size(e); ++i) { expr* d = a.mk_idiv(x, a.mk_int(rational::power_of_two(i))); - r = m.mk_ite(a.mk_eq(y, a.mk_int(i)), + r = m.mk_ite(m.mk_eq(y, a.mk_int(i)), m.mk_ite(signbit, a.mk_uminus(d), d), r); } @@ -686,6 +688,7 @@ namespace intblast { bv_expr = e->get_arg(0); r = m.mk_ite(m.mk_eq(umod(bv_expr, 0), umod(bv_expr, 1)), a.mk_int(1), a.mk_int(0)); break; + case OP_BSMOD_I: case OP_BSMOD: { bv_expr = e; expr* x = umod(bv_expr, 0), *y = umod(bv_expr, 0); @@ -693,12 +696,12 @@ namespace intblast { expr* signx = a.mk_ge(x, a.mk_int(N/2)); expr* signy = a.mk_ge(y, a.mk_int(N/2)); expr* u = a.mk_mod(x, y); - // x < 0, y < 0 -> r = -u - // x < 0, y >= 0 -> r = y - u - // x >= 0, y < 0 -> r = y + u - // x >= 0, y >= 0 -> r = u - // u = 0 -> r = 0 - // y = 0 -> r = x + // u = 0 -> 0 + // y = 0 -> x + // x < 0, y < 0 -> -u + // x < 0, y >= 0 -> y - u + // x >= 0, y < 0 -> y + u + // x >= 0, y >= 0 -> u r = a.mk_uminus(u); r = m.mk_ite(m.mk_and(m.mk_not(signx), signy), a.mk_add(u, y), r); r = m.mk_ite(m.mk_and(signx, m.mk_not(signy)), a.mk_sub(y, u), r); @@ -707,15 +710,41 @@ namespace intblast { r = m.mk_ite(m.mk_eq(y, a.mk_int(0)), x, r); break; } + case OP_BSDIV_I: case OP_BSDIV: { + // d = udiv(x mod N, y mod N) // y = 0, x > 0 -> 1 // y = 0, x <= 0 -> -1 - // y != 0 -> machine_div(x, y) -#if 0 - -#endif + // x = 0, y != 0 -> 0 + // x < 0, y < 0 -> -d + // x < 0, y > 0 -> -d + // x > 0, y > 0 -> d + // x < 0, y < 0 -> d + bv_expr = e; + expr* x = umod(bv_expr, 0), * y = umod(bv_expr, 0); + rational N = rational::power_of_two(bv.get_bv_size(bv_expr)); + expr* signx = a.mk_ge(x, a.mk_int(N / 2)); + expr* signy = a.mk_ge(y, a.mk_int(N / 2)); + expr* d = a.mk_idiv(x, y); + r = m.mk_ite(m.mk_iff(signx, signy), d, a.mk_uminus(d)); + r = m.mk_ite(m.mk_eq(y, a.mk_int(0)), m.mk_ite(signx, a.mk_int(-1), a.mk_int(1)), r); + break; + } + case OP_BSREM_I: + case OP_BSREM: { + // y = 0 -> x + // else x - sdiv(x, y) * y + bv_expr = e; + expr* x = umod(bv_expr, 0), * y = umod(bv_expr, 0); + rational N = rational::power_of_two(bv.get_bv_size(bv_expr)); + expr* signx = a.mk_ge(x, a.mk_int(N / 2)); + expr* signy = a.mk_ge(y, a.mk_int(N / 2)); + expr* d = a.mk_idiv(x, y); + d = m.mk_ite(m.mk_iff(signx, signy), d, a.mk_uminus(d)); + r = a.mk_sub(x, a.mk_mul(d, y)); + r = m.mk_ite(m.mk_eq(y, a.mk_int(0)), x, r); + break; } - case OP_ROTATE_LEFT: case OP_ROTATE_RIGHT: case OP_EXT_ROTATE_LEFT: @@ -723,7 +752,7 @@ namespace intblast { case OP_REPEAT: case OP_BREDOR: case OP_BREDAND: - case OP_BSREM: + verbose_stream() << mk_pp(e, m) << "\n"; NOT_IMPLEMENTED_YET(); break; @@ -739,11 +768,16 @@ namespace intblast { bool has_bv_arg = any_of(*e, [&](expr* arg) { return bv.is_bv(arg); }); if (has_bv_arg) { expr* bv_expr = e->get_arg(0); - set_translated(e, m.mk_eq(umod(bv_expr, 0), umod(bv_expr, 1))); + m_args[0] = a.mk_sub(arg(0), arg(1)); + set_translated(e, m.mk_eq(umod(bv_expr, 0), a.mk_int(0))); } else set_translated(e, m.mk_eq(arg(0), arg(1))); } + else if (m.is_ite(e)) + set_translated(e, m.mk_ite(arg(0), arg(1), arg(2))); + else if (m_is_plugin) + set_translated(e, e); else set_translated(e, m.mk_app(e->get_decl(), m_args)); } diff --git a/src/sat/smt/intblast_solver.h b/src/sat/smt/intblast_solver.h index ee85ed6e2..7dd37d5a7 100644 --- a/src/sat/smt/intblast_solver.h +++ b/src/sat/smt/intblast_solver.h @@ -57,7 +57,7 @@ namespace intblast { sat::literal_vector m_core; ptr_vector m_bv2int, m_int2bv; statistics m_stats; - bool m_is_plugin = true; // when the solver is used as a plugin, then do not translate below quantifiers. + bool m_is_plugin = true; // when the solver is used as a plugin, then do not translate below quantifiers. bool is_bv(sat::literal lit); void translate(expr_ref_vector& es); @@ -80,7 +80,7 @@ namespace intblast { void translate_quantifier(quantifier* q); void translate_var(var* v); - void ensure_args(app* e); + void ensure_translated(expr* e); void internalize_bv(app* e); unsigned m_vars_qhead = 0; diff --git a/src/sat/smt/polysat/constraints.h b/src/sat/smt/polysat/constraints.h index a9ec63165..47c9beb49 100644 --- a/src/sat/smt/polysat/constraints.h +++ b/src/sat/smt/polysat/constraints.h @@ -31,12 +31,15 @@ namespace polysat { class constraint { unsigned_vector m_vars; + unsigned m_num_watch = 0; public: virtual ~constraint() {} unsigned_vector& vars() { return m_vars; } unsigned_vector const& vars() const { return m_vars; } unsigned var(unsigned idx) const { return m_vars[idx]; } bool contains_var(pvar v) const { return m_vars.contains(v); } + unsigned num_watch() const { return m_num_watch; } + void set_num_watch(unsigned n) { SASSERT(n <= 2); m_num_watch = n; } virtual std::ostream& display(std::ostream& out, lbool status) const = 0; virtual std::ostream& display(std::ostream& out) const = 0; virtual lbool eval() const = 0; @@ -63,6 +66,8 @@ namespace polysat { unsigned_vector const& vars() const { return m_constraint->vars(); } unsigned var(unsigned idx) const { return m_constraint->var(idx); } bool contains_var(pvar v) const { return m_constraint->contains_var(v); } + unsigned num_watch() const { return m_constraint->num_watch(); } + void set_num_watch(unsigned n) { m_constraint->set_num_watch(n); } void activate(core& c, dependency const& d) { m_constraint->activate(c, m_sign, d); } void propagate(core& c, lbool value, dependency const& d) { m_constraint->propagate(c, value, d); } bool is_always_true() const { return eval() == l_true; } diff --git a/src/sat/smt/polysat/core.cpp b/src/sat/smt/polysat/core.cpp index de2fedb5a..c0b56a3d8 100644 --- a/src/sat/smt/polysat/core.cpp +++ b/src/sat/smt/polysat/core.cpp @@ -74,11 +74,13 @@ namespace polysat { if (vars.size() > 0) verbose_stream() << "(" << c.m_constraint_index.size() -1 << ": " << c.m_watch[vars[0]] << ") "; if (vars.size() > 1) verbose_stream() << "(" << c.m_constraint_index.size() -1 << ": " << c.m_watch[vars[1]] << ") "; verbose_stream() << "\n"); - SASSERT(vars.size() <= 0 || c.m_watch[vars[0]].back() == c.m_constraint_index.size() - 1); - SASSERT(vars.size() <= 1 || c.m_watch[vars[1]].back() == c.m_constraint_index.size() - 1); - if (vars.size() > 0) + unsigned n = sc.num_watch(); + SASSERT(n <= vars.size()); + SASSERT(n <= 0 || c.m_watch[vars[0]].back() == c.m_constraint_index.size() - 1); + SASSERT(n <= 1 || c.m_watch[vars[1]].back() == c.m_constraint_index.size() - 1); + if (n > 0) c.m_watch[vars[0]].pop_back(); - if (vars.size() > 1) + if (n > 1) c.m_watch[vars[1]].pop_back(); c.m_constraint_index.pop_back(); } @@ -138,9 +140,10 @@ namespace polysat { for (; i < sz && j < 2; ++i) if (!is_assigned(vars[i])) std::swap(vars[i], vars[j++]); - if (vars.size() > 0) + sc.set_num_watch(i); + if (i > 0) add_watch(idx, vars[0]); - if (vars.size() > 1) + if (i > 1) add_watch(idx, vars[1]); IF_VERBOSE(10, verbose_stream() << "add watch " << sc << " " << vars << " "; if (vars.size() > 0) verbose_stream() << "( " << idx << " : " << m_watch[vars[0]] << ") "; @@ -225,19 +228,12 @@ namespace polysat { s.trail().push(mk_assign_var(v, *this)); return; - // to debug: - unsigned sz = m_watch[v].size(); - for (unsigned i = 0; i < sz; ++i) { - auto idx = m_watch[v][i]; - auto [sc, dep, value] = m_constraint_index[idx]; - sc.propagate(*this, value, dep); - } - // update the watch lists for pvars // remove constraints from m_watch[v] that have more than 2 free variables. // for entries where there is only one free variable left add to viable set - unsigned j = 0; - for (auto idx : m_watch[v]) { + unsigned j = 0, sz = m_watch[v].size(); + for (unsigned k = 0; k < sz; ++k) { + auto idx = m_watch[v][k]; auto [sc, dep, value] = m_constraint_index[idx]; auto& vars = sc.vars(); if (vars[0] != v) @@ -253,6 +249,11 @@ namespace polysat { break; } } + + // this can create fresh literals and update m_watch, but + // will not update m_watch[v] (other than copy constructor for m_watch) + // because v has been assigned a value. + sc.propagate(*this, value, dep); SASSERT(!swapped || vars.size() <= 1 || (!is_assigned(vars[0]) && !is_assigned(vars[1]))); if (swapped) @@ -267,6 +268,7 @@ namespace polysat { // detect unitary, add to viable, detect conflict? m_viable.add_unitary(v1, idx); } + SASSERT(m_watch[v].size() == sz && "size of watch list was not changed"); m_watch[v].shrink(j); verbose_stream() << "new watch " << v << ": " << m_watch[v] << "\n"; } diff --git a/src/smt/params/smt_params_helper.pyg b/src/smt/params/smt_params_helper.pyg index 300bef1fb..b882c1abf 100644 --- a/src/smt/params/smt_params_helper.pyg +++ b/src/smt/params/smt_params_helper.pyg @@ -54,6 +54,7 @@ def_module_params(module_name='smt', ('bv.watch_diseq', BOOL, False, 'use watch lists instead of eager axioms for bit-vectors'), ('bv.delay', BOOL, False, 'delay internalize expensive bit-vector operations'), ('bv.size_reduce', BOOL, False, 'pre-processing; turn assertions that set the upper bits of a bit-vector to constants into a substitution that replaces the bit-vector with constant bits. Useful for minimizing circuits as many input bits to circuits are constant'), + ('bv.solver', UINT, 1, 'bit-vector solver engine: 0 - bit-blasting, 1 - polysat, 2 - intblast, requires sat.smt=true'), ('arith.random_initial_value', BOOL, False, 'use random initial values in the simplex-based procedure for linear arithmetic'), ('arith.solver', UINT, 6, 'arithmetic solver: 0 - no solver, 1 - bellman-ford based solver (diff. logic only), 2 - simplex based solver, 3 - floyd-warshall based solver (diff. logic only) and no theory combination 4 - utvpi, 5 - infinitary lra, 6 - lra solver'), ('arith.nl', BOOL, True, '(incomplete) nonlinear arithmetic support based on Groebner basis and interval propagation, relevant only if smt.arith.solver=2'), diff --git a/src/smt/params/theory_bv_params.cpp b/src/smt/params/theory_bv_params.cpp index 734a983fb..8a3ddcf37 100644 --- a/src/smt/params/theory_bv_params.cpp +++ b/src/smt/params/theory_bv_params.cpp @@ -28,6 +28,7 @@ void theory_bv_params::updt_params(params_ref const & _p) { m_bv_enable_int2bv2int = p.bv_enable_int2bv(); m_bv_delay = p.bv_delay(); m_bv_size_reduce = p.bv_size_reduce(); + m_bv_solver = p.bv_solver(); } #define DISPLAY_PARAM(X) out << #X"=" << X << '\n'; @@ -42,4 +43,5 @@ void theory_bv_params::display(std::ostream & out) const { DISPLAY_PARAM(m_bv_enable_int2bv2int); DISPLAY_PARAM(m_bv_delay); DISPLAY_PARAM(m_bv_size_reduce); + DISPLAY_PARAM(m_bv_solver); } diff --git a/src/smt/params/theory_bv_params.h b/src/smt/params/theory_bv_params.h index 523459f09..97428c8ba 100644 --- a/src/smt/params/theory_bv_params.h +++ b/src/smt/params/theory_bv_params.h @@ -36,6 +36,7 @@ struct theory_bv_params { bool m_bv_watch_diseq = false; bool m_bv_delay = true; bool m_bv_size_reduce = false; + unsigned m_bv_solver = 0; theory_bv_params(params_ref const & p = params_ref()) { updt_params(p); }