diff --git a/src/sat/smt/polysat/saturation.cpp b/src/sat/smt/polysat/saturation.cpp index 4b9445e1c..8f26ca732 100644 --- a/src/sat/smt/polysat/saturation.cpp +++ b/src/sat/smt/polysat/saturation.cpp @@ -175,12 +175,12 @@ namespace polysat { } // Ovfl(x, y) & ~Ovfl(y, z) ==> x > z - void saturation::try_umul_ovfl(pvar v, umul_ovfl const& sc) { + void saturation::try_umul_monotonicity(umul_ovfl const& sc) { auto p = sc.p(), q = sc.q(); - auto match_mul_arg = [&](auto const& sc2) { - auto p2 = sc2.to_umul_ovfl().p(), q2 = sc2.to_umul_ovfl().q(); + auto match_mul_arg = [&](auto const& sc2) { + auto p2 = sc2.to_umul_ovfl().p(), q2 = sc2.to_umul_ovfl().q(); return p == p2 || p == q2 || q == p2 || q == q2; - }; + }; auto extract_mul_args = [&](auto const& sc2) -> std::pair { auto p2 = sc2.to_umul_ovfl().p(), q2 = sc2.to_umul_ovfl().q(); if (p == p2) @@ -193,7 +193,7 @@ namespace polysat { SASSERT(q == q2); return { p, p2 }; } - }; + }; for (auto id : constraint_iterator(c, [&](auto const& sc2) { return sc2.is_umul_ovfl() && sc.sign() != sc2.sign() && match_mul_arg(sc2); })) { auto sc2 = c.get_constraint(id); auto d1 = c.get_dependency(sc.id()); @@ -201,9 +201,106 @@ namespace polysat { auto [q1, q2] = extract_mul_args(sc2); if (sc.sign()) add_clause("[y] ~ovfl(p, q1) & ovfl(p, q2) ==> q1 < q2", { d1, d2, C.ult(q1, q2) }, true); - else - add_clause("[y] ovfl(p, q1) & ~ovfl(p, q2) ==> q1 > q2", { d1, d2, C.ult(q2, q1)}, true); - } + else + add_clause("[y] ovfl(p, q1) & ~ovfl(p, q2) ==> q1 > q2", { d1, d2, C.ult(q2, q1) }, true); + } + } + + /** + * Expand the following axioms: + * Ovfl(x, y) & msb(x) <= k => msb(y) >= N - k - 1 + * Ovfl(x, y) & msb(x) <= k & msb(y) <= N - k - 1 => 0x * 0y >= 2^{N-1} + * + * ~Ovfl(x,y) & msb(x) >= k => msb(y) <= N - k - 1 + * ~Ovfl(x,y) & msb(x) >= k & msb(y) >= N - k - 1 => 0x * 0y < 2^{N-1} + */ + void saturation::try_umul_blast(umul_ovfl const& sc) { + auto x = sc.p(); + auto y = sc.q(); + if (!x.is_val()) + return; + if (!y.is_val()) + return; + auto N = x.manager().power_of_2(); + auto d = c.get_dependency(sc.id()); + + auto vx = x.val(); + auto vy = y.val(); + auto bx = vx.get_num_bits(); + auto by = vy.get_num_bits(); + if (bx > by) { + std::swap(bx, by); + std::swap(x, y); + } + + // Keep in mind that + // num-bits(0) = 1 + // num-bits(1) = 1 + // num-bits(2) = 2 + // num-bits(4) = 3 + // msb(0) = 0 + // msb(1) = 0 + // msb(2) = 1 + // msb(3) = 1 + // msb(2^N - 1) = N-1 + // msb(x) >= k <=> x >= 2^k, for k > 0 + // msb(x) <= k <=> x < 2^{k+1}, for k + 1 < N + + auto msb_ge = [&](pdd const& x, unsigned k) { + SASSERT(k > 0); + return C.uge(x, rational::power_of_two(k)); + }; + + auto msb_le = [&](pdd const& x, unsigned k) { + SASSERT(k + 1 < N); + return C.ult(x, rational::power_of_two(k + 1)); + }; + + if (sc.sign()) { + // Case ~Ovfl(x,y) is asserted by current assignment x * y is overflow + SASSERT(bx > 1 && by > 1); + SASSERT(bx + by >= N + 1); + auto k = bx - 1; + if (bx + by > N + 1) + add_clause("~Ovfl(x, y) & msb(x) >= k => msb(y) <= N - k - 1", + {d, ~msb_ge(x, k), msb_le(y, N - k - 1)}, + true); + else { + auto x1 = c.mk_zero_extend(1, x); + auto y1 = c.mk_zero_extend(1, y); + add_clause("~Ovfl(x, y) & msb(x) >= k & msb(y) >= N - k - 1 => 0x * 0y < 2^{N-1}", + { d, ~msb_ge(x, k), + ~msb_ge(y, N - k - 1), + C.ult(x1 * y1, rational::power_of_two(N - 1)) + }, true); + } + } + else { + // Case Ovfl(x,y) + if (bx == 0) { + add_clause("Ovfl(x, y) => x > 1", { d, C.ugt(x, 1) }, true); + } + else if (bx + by < N + 1) { + SASSERT(bx <= by); + auto k = bx - 1; + add_clause("Ovfl(x, y) & msb(x) <= k => msb(y) >= N - k - 1", + { d, ~msb_le(x, k), msb_ge(y, N - k - 1) }, true); + } + else { + auto k = bx - 1; + auto x1 = c.mk_zero_extend(1, x); + auto y1 = c.mk_zero_extend(1, y); + add_clause("Ovfl(x, y) & msb(x) <= k & msb(y) <= N - k - 1 = > 0x * 0y >= 2 ^ {N - 1}", + { d, ~msb_le(x, k), ~msb_le(y, N - k - 1), C.uge(x1 * y1, rational::power_of_two(N - 1)) }, + true); + } + } + } + + + void saturation::try_umul_ovfl(pvar v, umul_ovfl const& sc) { + try_umul_monotonicity(sc); + try_umul_blast(sc); } void saturation::try_eq_resolve(pvar v, inequality const& i) { diff --git a/src/sat/smt/polysat/saturation.h b/src/sat/smt/polysat/saturation.h index 74527c7c6..49af28275 100644 --- a/src/sat/smt/polysat/saturation.h +++ b/src/sat/smt/polysat/saturation.h @@ -65,6 +65,8 @@ namespace polysat { void try_ugt_y(pvar v, inequality const& i); void try_ugt_z(pvar z, inequality const& i); void try_umul_ovfl(pvar v, umul_ovfl const& sc); + void try_umul_monotonicity(umul_ovfl const& sc); + void try_umul_blast(umul_ovfl const& sc); void try_op(pvar v, signed_constraint& sc, dependency const& d); signed_constraint ineq(bool is_strict, pdd const& x, pdd const& y); diff --git a/src/sat/smt/polysat/types.h b/src/sat/smt/polysat/types.h index 02881c20d..4ccdb8077 100644 --- a/src/sat/smt/polysat/types.h +++ b/src/sat/smt/polysat/types.h @@ -149,6 +149,7 @@ namespace polysat { virtual void get_bitvector_super_slices(pvar v, offset_slices& out) = 0; virtual void get_fixed_bits(pvar v, fixed_bits_vector& fixed_slice) = 0; virtual pdd mk_ite(signed_constraint const& sc, pdd const& p, pdd const& q) = 0; + virtual pdd mk_zero_extend(unsigned sz, pdd const& p) = 0; virtual unsigned level(dependency const& d) = 0; }; diff --git a/src/sat/smt/polysat_internalize.cpp b/src/sat/smt/polysat_internalize.cpp index 530ee3340..535cae8f9 100644 --- a/src/sat/smt/polysat_internalize.cpp +++ b/src/sat/smt/polysat_internalize.cpp @@ -858,6 +858,13 @@ namespace polysat { return expr2pdd(ite); } + pdd solver::mk_zero_extend(unsigned n, pdd const& p) { + expr_ref pe = pdd2expr(p); + auto ze = bv.mk_zero_extend(n, pe); + ctx.internalize(ze); + return expr2pdd(ze); + } + dd::pdd solver::expr2pdd(expr* e) { return var2pdd(get_th_var(e)); } diff --git a/src/sat/smt/polysat_solver.h b/src/sat/smt/polysat_solver.h index c94918ff9..3f42381c6 100644 --- a/src/sat/smt/polysat_solver.h +++ b/src/sat/smt/polysat_solver.h @@ -225,6 +225,7 @@ namespace polysat { void get_bitvector_suffixes(pvar v, offset_slices& out) override; void get_fixed_bits(pvar v, fixed_bits_vector& fixed_bits) override; pdd mk_ite(signed_constraint const& sc, pdd const& p, pdd const& q) override; + pdd mk_zero_extend(unsigned sz, pdd const& p) override; unsigned level(dependency const& d) override; dependency explain_slice(pvar v, pvar w, unsigned offset);