diff --git a/src/math/polysat/univariate/univariate_solver.cpp b/src/math/polysat/univariate/univariate_solver.cpp index 6bf1d88c1..dd54e262c 100644 --- a/src/math/polysat/univariate/univariate_solver.cpp +++ b/src/math/polysat/univariate/univariate_solver.cpp @@ -32,13 +32,70 @@ namespace polysat { return deps; } + bool univariate_solver::find_min(rational& val) { + val = model(); + push(); + // try reducing val by setting bits to 0, starting at the msb. + for (unsigned k = bit_width; k-- > 0; ) { + if (!val.get_bit(k)) { + add_bit0(k, null_dep); + continue; + } + // try decreasing k-th bit + push(); + add_bit0(k, 0); + lbool result = check(); + if (result == l_true) { + SASSERT(model() < val); + val = model(); + } + pop(1); + if (result == l_true) + add_bit0(k, null_dep); + else if (result == l_false) + add_bit1(k, null_dep); + else + return false; + } + pop(1); + return true; + } + + bool univariate_solver::find_max(rational& val) { + val = model(); + push(); + // try increasing val by setting bits to 1, starting at the msb. + for (unsigned k = bit_width; k-- > 0; ) { + if (val.get_bit(k)) { + add_bit1(k, 0); + continue; + } + // try increasing k-th bit + push(); + add_bit1(k, 0); + lbool result = check(); + if (result == l_true) { + SASSERT(model() > val); + val = model(); + } + pop(1); + if (result == l_true) + add_bit1(k, null_dep); + else if (result == l_false) + add_bit0(k, null_dep); + else + return false; + } + pop(1); + return true; + } + class univariate_bitblast_solver : public univariate_solver { // TODO: does it make sense to share m and bv between different solver instances? // TODO: consider pooling solvers to save setup overhead, see if solver/solver_pool.h can be used ast_manager m; scoped_ptr bv; scoped_ptr s; - unsigned bit_width; unsigned m_scope_level = 0; func_decl_ref x_decl; expr_ref x; @@ -46,7 +103,7 @@ namespace polysat { public: univariate_bitblast_solver(solver_factory& mk_solver, unsigned bit_width) : - bit_width(bit_width), + univariate_solver(bit_width), x_decl(m), x(m) { reg_decl_plugins(m); @@ -192,7 +249,7 @@ namespace polysat { } template - void add_ule_impl(lhs_t const& lhs, rhs_t const& rhs, bool sign, dep_t dep) { + void add_ule_impl(lhs_t const& lhs, rhs_t const& rhs, bool sign, dep_t dep) { if (is_zero(rhs)) add(m.mk_eq(mk_poly(lhs), mk_poly(rhs)), sign, dep); else @@ -329,71 +386,289 @@ namespace polysat { s->get_model(model); SASSERT(model); app* val = to_app(model->get_const_interp(x_decl)); - SASSERT(val->get_decl_kind() == OP_BV_NUM); - SASSERT(val->get_num_parameters() == 2); - auto const& p = val->get_parameter(0); - SASSERT(p.is_rational()); - cached_model = p.get_rational(); + unsigned sz; + VERIFY(bv->is_numeral(val, cached_model, sz)); } return cached_model; } - bool find_min(rational& val) override { - val = model(); + + + bool find_two(rational& out1, rational& out2) override { + out1 = model(); + bool ok = true; push(); - // try reducing val by setting bits to 0, starting at the msb. - for (unsigned k = bit_width; k-- > 0; ) { - if (!val.get_bit(k)) { - add_bit0(k, null_dep); - continue; - } - // try decreasing k-th bit - push(); - add_bit0(k, 0); - lbool result = check(); - if (result == l_true) { - SASSERT(model() < val); - val = model(); - } - pop(1); - if (result == l_true) - add_bit0(k, null_dep); - else if (result == l_false) - add_bit1(k, null_dep); - else - return false; + add(m.mk_eq(mk_numeral(out1), x), true, null_dep); + switch (check()) { + case l_true: + out2 = model(); + break; + case l_false: + out2 = out1; + break; + default: + ok = false; + break; } pop(1); + IF_VERBOSE(10, verbose_stream() << "viable " << out1 << " " << out2 << "\n"); + return ok; + } + + + std::ostream& display(std::ostream& out) const override { + return out << *s; + } + }; + + // stub for alternative int-blast solver. + class univariate_intblast_solver : public univariate_solver { + ast_manager m; + scoped_ptr a; + scoped_ptr s; + rational m_mod; + unsigned m_scope_level = 0; + func_decl_ref x_decl; + expr_ref x; + vector model_cache; + + void add(expr* e, bool sign, dep_t dep) { + reset_cache(); + if (sign) + e = m.mk_not(e); + if (dep == null_dep) { + s->assert_expr(e); + IF_VERBOSE(10, verbose_stream() << "(assert " << expr_ref(e, m) << ")\n"); + } + else { + expr* a = m.mk_const(m.mk_const_decl(symbol(dep), m.mk_bool_sort())); + s->assert_expr(e, a); + IF_VERBOSE(10, verbose_stream() << "(assert (! " << expr_ref(e, m) << " :named " << expr_ref(a, m) << "))\n"); + } + } + + bool is_zero(univariate const& p) const { + for (auto n : p) + if (n != 0) + return false; return true; } - bool find_max(rational& val) override { - val = model(); - push(); - // try increasing val by setting bits to 1, starting at the msb. - for (unsigned k = bit_width; k-- > 0; ) { - if (val.get_bit(k)) { - add_bit1(k, 0); - continue; + bool is_zero(rational const& p) const { + return p.is_zero(); + } + + public: + univariate_intblast_solver(solver_factory& mk_solver, unsigned bit_width) : + univariate_solver(bit_width), + m_mod(rational::power_of_two(bit_width)), + x_decl(m), + x(m) { + reg_decl_plugins(m); + a = alloc(arith_util, m); + params_ref p; + p.set_bool("bv.polysat", false); + // p.set_bool("smt", true); + s = mk_solver(m, p, false, true, true, symbol::null); + x_decl = m.mk_const_decl("x", a->mk_int()); + x = m.mk_const(x_decl); + model_cache.push_back(rational(-1)); + s->assert_expr(a->mk_le(mk_numeral(0), x)); + s->assert_expr(a->mk_lt(x, mk_numeral(m_mod))); + } + + + ~univariate_intblast_solver() override = default; + + void reset_cache() { + model_cache.back() = -1; + } + + void push_cache() { + rational v = model_cache.back(); + model_cache.push_back(v); + } + + void pop_cache() { + model_cache.pop_back(); + } + + void push() override { + m_scope_level++; + push_cache(); + s->push(); + } + + void pop(unsigned n) override { + SASSERT(scope_level() >= n); + m_scope_level -= n; + pop_cache(); + s->pop(n); + } + + unsigned scope_level() const override { + return m_scope_level; + } + + expr* mk_numeral(rational const& r) const { + // assert 0 <= r < 2^bit-width + return a->mk_int(r); + } + + expr* mk_numeral(uint64_t u) const { + // assert u < 2^bit-width + return a->mk_int(rational(u, rational::ui64())); + } + + lbool check() override { + return s->check_sat(); + } + + expr_ref mk_poly(rational const& p) { + return {mk_numeral(p), m}; + } + + + // 2^k*x --> x << k + // n*x --> n * x + expr* mk_poly_term(rational const& coeff, expr* xpow) const { + SASSERT(!coeff.is_zero()); + if (coeff.is_one()) + return xpow; + return a->mk_mul(mk_numeral(coeff), xpow); + } + + // [d,c,b,a] --> d + c*x + b*(x*x) + a*(x*x*x) + expr_ref mk_poly(univariate const& p) { + expr_ref e(m); + if (p.empty()) + e = mk_numeral(rational::zero()); + else { + if (!p[0].is_zero()) + e = mk_numeral(p[0]); + expr_ref xpow = x; + for (unsigned i = 1; i < p.size(); ++i) { + if (!p[i].is_zero()) { + expr* t = mk_poly_term(p[i], xpow); + e = e ? a->mk_add(e, t) : t; + } + if (i + 1 < p.size()) + xpow = a->mk_mul(xpow, x); } - // try increasing k-th bit - push(); - add_bit1(k, 0); - lbool result = check(); - if (result == l_true) { - SASSERT(model() > val); - val = model(); - } - pop(1); - if (result == l_true) - add_bit1(k, null_dep); - else if (result == l_false) - add_bit0(k, null_dep); - else - return false; + if (!e) + e = mk_numeral(p[0]); } - pop(1); - return true; + if (!a->is_numeral(e) && e != x) + e = a->mk_mod(e, mk_numeral(m_mod)); + return e; + } + + + template + void add_ule_impl(lhs_t const& lhs, rhs_t const& rhs, bool sign, dep_t dep) { + // todo: simplify x - k == 0 into x = k + // or ensure that bounds simplification tactic is enabled. + // without bounds simplification, int-blasting doesnt work. + if (is_zero(rhs)) + add(m.mk_eq(mk_poly(lhs), mk_poly(rhs)), sign, dep); + else + add(a->mk_le(mk_poly(lhs), mk_poly(rhs)), sign, dep); + } + + void add_ule(univariate const& lhs, univariate const& rhs, bool sign, dep_t dep) override { add_ule_impl(lhs, rhs, sign, dep); } + void add_ule(univariate const& lhs, rational const& rhs, bool sign, dep_t dep) override { add_ule_impl(lhs, rhs, sign, dep); } + void add_ule(rational const& lhs, univariate const& rhs, bool sign, dep_t dep) override { add_ule_impl(lhs, rhs, sign, dep); } + + void add_umul_ovfl(univariate const& lhs, univariate const& rhs, bool sign, dep_t dep) override { + auto c = a->mk_mul(mk_poly(lhs), mk_poly(rhs)); + if (sign) // or the other way around? + add(a->mk_lt(c, mk_numeral(m_mod)), sign, dep); + else + add(a->mk_ge(c, mk_numeral(m_mod)), sign, dep); + } + + void add_smul_ovfl(univariate const& lhs, univariate const& rhs, bool sign, dep_t dep) override { + NOT_IMPLEMENTED_YET(); + } + + void add_smul_udfl(univariate const& lhs, univariate const& rhs, bool sign, dep_t dep) override { + NOT_IMPLEMENTED_YET(); + } + + void add_lshr(univariate const& in1, univariate const& in2, univariate const& out, bool sign, dep_t dep) override { + NOT_IMPLEMENTED_YET(); + } + + void add_ashr(univariate const& in1, univariate const& in2, univariate const& out, bool sign, dep_t dep) override { + NOT_IMPLEMENTED_YET(); + } + + void add_shl(univariate const& in1, univariate const& in2, univariate const& out, bool sign, dep_t dep) override { + NOT_IMPLEMENTED_YET(); + } + + void add_and(univariate const& in1, univariate const& in2, univariate const& out, bool sign, dep_t dep) override { + NOT_IMPLEMENTED_YET(); + } + + void add_or(univariate const& in1, univariate const& in2, univariate const& out, bool sign, dep_t dep) override { + NOT_IMPLEMENTED_YET(); + } + + void add_xor(univariate const& in1, univariate const& in2, univariate const& out, bool sign, dep_t dep) override { + NOT_IMPLEMENTED_YET(); + } + + void add_not(univariate const& in, univariate const& out, bool sign, dep_t dep) override { + NOT_IMPLEMENTED_YET(); + } + + void add_inv(univariate const& in, univariate const& out, bool sign, dep_t dep) override { + NOT_IMPLEMENTED_YET(); + } + + void add_udiv(univariate const& in1, univariate const& in2, univariate const& out, bool sign, dep_t dep) override { + NOT_IMPLEMENTED_YET(); + } + + void add_urem(univariate const& in1, univariate const& in2, univariate const& out, bool sign, dep_t dep) override { + NOT_IMPLEMENTED_YET(); + } + + void add_ule_const(rational const& val, bool sign, dep_t dep) override { + NOT_IMPLEMENTED_YET(); + } + + void add_uge_const(rational const& val, bool sign, dep_t dep) override { + NOT_IMPLEMENTED_YET(); + } + + void add_bit(unsigned idx, bool sign, dep_t dep) override { + add(m.mk_eq(mk_numeral(m_mod/2), a->mk_mul(mk_numeral(rational::power_of_two(bit_width - idx - 1)), x)), sign, dep); + } + + void unsat_core(dep_vector& deps) override { + deps.reset(); + expr_ref_vector core(m); + s->get_unsat_core(core); + for (expr* a : core) { + unsigned dep = to_app(a)->get_decl()->get_name().get_num(); + deps.push_back(dep); + } + IF_VERBOSE(10, verbose_stream() << "core " << deps << "\n"); + SASSERT(deps.size() > 0); + } + + rational model() override { + rational& cached_model = model_cache.back(); + if (cached_model.is_neg()) { + model_ref model; + s->get_model(model); + SASSERT(model); + app* val = to_app(model->get_const_interp(x_decl)); + VERIFY(a->is_numeral(val, cached_model)); + } + return cached_model; } bool find_two(rational& out1, rational& out2) override { @@ -419,7 +694,7 @@ namespace polysat { std::ostream& display(std::ostream& out) const override { return out << *s; - } + } }; class univariate_bitblast_factory : public univariate_solver_factory { diff --git a/src/math/polysat/univariate/univariate_solver.h b/src/math/polysat/univariate/univariate_solver.h index 2baaca623..59c82b15e 100644 --- a/src/math/polysat/univariate/univariate_solver.h +++ b/src/math/polysat/univariate/univariate_solver.h @@ -24,6 +24,8 @@ Author: namespace polysat { class univariate_solver { + protected: + unsigned bit_width; public: using dep_t = unsigned; using dep_vector = svector; @@ -34,6 +36,8 @@ namespace polysat { const dep_t null_dep = UINT_MAX; + univariate_solver(unsigned bit_width) : bit_width(bit_width) {} + virtual ~univariate_solver() = default; virtual void push() = 0; @@ -59,7 +63,7 @@ namespace polysat { * Precondition: check() returned l_true * Returns: true on success, false on resource out. */ - virtual bool find_min(rational& out_min) = 0; + bool find_min(rational& out_min); /** * Find maximal model. @@ -67,7 +71,7 @@ namespace polysat { * Precondition: check() returned l_true * Returns: true on success, false on resource out. */ - virtual bool find_max(rational& out_max) = 0; + bool find_max(rational& out_max); /** * Find up to two viable values.