diff --git a/src/sat/smt/intblast_solver.cpp b/src/sat/smt/intblast_solver.cpp index db5798236..0505eaa92 100644 --- a/src/sat/smt/intblast_solver.cpp +++ b/src/sat/smt/intblast_solver.cpp @@ -22,16 +22,109 @@ Author: namespace intblast { solver::solver(euf::solver& ctx) : + th_euf_solver(ctx, symbol("intblast"), ctx.get_manager().get_family_id("bv")), ctx(ctx), s(ctx.s()), m(ctx.get_manager()), bv(m), a(m), - m_trail(m), + m_args(m), + m_translate(m), m_pinned(m) {} - lbool solver::check() { + euf::theory_var solver::mk_var(euf::enode* n) { + auto r = euf::th_euf_solver::mk_var(n); + ctx.attach_th_var(n, this, r); + TRACE("bv", tout << "mk-var: v" << r << " " << ctx.bpp(n) << "\n";); + return r; + } + + sat::literal solver::internalize(expr* e, bool sign, bool root) { + force_push(); + SASSERT(m.is_bool(e)); + if (!visit_rec(m, e, sign, root)) + return sat::null_literal; + sat::literal lit = expr2literal(e); + if (sign) + lit.neg(); + return lit; + } + + void solver::internalize(expr* e) { + force_push(); + visit_rec(m, e, false, false); + } + + bool solver::visit(expr* e) { + if (!is_app(e) || to_app(e)->get_family_id() != get_id()) { + ctx.internalize(e); + return true; + } + m_stack.push_back(sat::eframe(e)); + return false; + } + + bool solver::visited(expr* e) { + euf::enode* n = expr2enode(e); + return n && n->is_attached_to(get_id()); + } + + bool solver::post_visit(expr* e, bool sign, bool root) { + euf::enode* n = expr2enode(e); + app* a = to_app(e); + if (visited(e)) + return true; + SASSERT(!n || !n->is_attached_to(get_id())); + if (!n) + n = mk_enode(e, false); + SASSERT(!n->is_attached_to(get_id())); + mk_var(n); + SASSERT(n->is_attached_to(get_id())); + internalize_bv(a); + return true; + } + + void solver::internalize_bv(app* e) { + ensure_args(e); + m_args.reset(); + for (auto arg : *e) + m_args.push_back(translated(arg)); + translate_bv(e); + if (m.is_bool(e)) + add_equiv(expr2literal(e), mk_literal(translated(e))); + } + + void solver::ensure_args(app* e) { + ptr_vector todo; + ast_fast_mark1 visited; + for (auto arg : *e) { + if (!m_translate.get(arg->get_id(), nullptr)) + todo.push_back(arg); + } + if (todo.empty()) + return; + for (unsigned i = 0; i < todo.size(); ++i) { + expr* e = todo[i]; + if (is_app(e)) { + for (auto arg : *to_app(e)) + if (!visited.is_marked(arg)) { + visited.mark(arg); + todo.push_back(arg); + } + } + else if (is_quantifier(e) && !visited.is_marked(to_quantifier(e)->get_expr())) { + visited.mark(to_quantifier(e)->get_expr()); + todo.push_back(to_quantifier(e)->get_expr()); + } + } + + std::stable_sort(todo.begin(), todo.end(), [&](expr* a, expr* b) { return get_depth(a) < get_depth(b); }); + for (expr* e : todo) + translate_expr(e); + } + + lbool solver::check_solver_state() { sat::literal_vector literals; uint_set selected; for (auto const& clause : s.clauses()) { @@ -84,7 +177,7 @@ namespace intblast { m_core.reset(); m_vars.reset(); - m_trail.reset(); + m_translate.reset(); m_solver = mk_smt2_solver(m, s.params(), symbol::null); expr_ref_vector es(m); @@ -123,7 +216,6 @@ namespace intblast { m_core.push_back(ctx.mk_literal(e)); } } - return r; }; @@ -189,349 +281,348 @@ namespace intblast { void solver::translate(expr_ref_vector& es) { ptr_vector todo; - obj_map translated; - expr_ref_vector args(m); sorted_subterms(es, todo); - for (expr* e : todo) { - if (is_quantifier(e)) { - quantifier* q = to_quantifier(e); - expr* b = q->get_expr(); - - unsigned nd = q->get_num_decls(); - ptr_vector sorts; - for (unsigned i = 0; i < nd; ++i) { - auto s = q->get_decl_sort(i); - if (bv.is_bv_sort(s)) { - NOT_IMPLEMENTED_YET(); - sorts.push_back(a.mk_int()); - } - else - sorts.push_back(s); - } - b = translated[b]; - // TODO if sorts contain integer, then created bounds variables. - m_trail.push_back(m.update_quantifier(q, b)); - translated.insert(e, m_trail.back()); - continue; - } - if (is_var(e)) { - if (bv.is_bv_sort(e->get_sort())) { - expr* v = m.mk_var(to_var(e)->get_idx(), a.mk_int()); - m_trail.push_back(v); - translated.insert(e, m_trail.back()); - } - else { - m_trail.push_back(e); - translated.insert(e, m_trail.back()); - } - continue; - } - app* ap = to_app(e); - expr* bv_expr = e; - args.reset(); - for (auto arg : *ap) - args.push_back(translated[arg]); - - auto bv_size = [&]() { return rational::power_of_two(bv.get_bv_size(bv_expr->get_sort())); }; - - auto mk_mod = [&](expr* x) { - if (m_vars.contains(x)) - return x; - return to_expr(a.mk_mod(x, a.mk_int(bv_size()))); - }; - - auto mk_smod = [&](expr* x) { - auto shift = bv_size() / 2; - return a.mk_mod(a.mk_add(x, a.mk_int(shift)), a.mk_int(bv_size())); - }; - - if (m.is_eq(e)) { - bool has_bv_arg = any_of(*ap, [&](expr* arg) { return bv.is_bv(arg); }); - if (has_bv_arg) { - bv_expr = ap->get_arg(0); - m_trail.push_back(m.mk_eq(mk_mod(args.get(0)), mk_mod(args.get(1)))); - translated.insert(e, m_trail.back()); - } - else { - m_trail.push_back(m.mk_eq(args.get(0), args.get(1))); - translated.insert(e, m_trail.back()); - } - continue; - } - - if (m.is_ite(e)) { - m_trail.push_back(m.mk_ite(args.get(0), args.get(1), args.get(2))); - translated.insert(e, m_trail.back()); - continue; - } - - if (ap->get_family_id() != bv.get_family_id()) { - bool has_bv_arg = any_of(*ap, [&](expr* arg) { return bv.is_bv(arg); }); - bool has_bv_sort = bv.is_bv(e); - func_decl* f = ap->get_decl(); - if (has_bv_arg) { - verbose_stream() << mk_pp(ap, m) << "\n"; - // need to update args with mod where they are bit-vectors. - NOT_IMPLEMENTED_YET(); - } - - if (has_bv_arg || has_bv_sort) { - ptr_vector domain; - for (auto* arg : *ap) { - sort* s = arg->get_sort(); - domain.push_back(bv.is_bv_sort(s) ? a.mk_int() : s); - } - sort* range = bv.is_bv(e) ? a.mk_int() : e->get_sort(); - func_decl* g = nullptr; - if (!m_new_funs.find(f, g)) { - g = m.mk_fresh_func_decl(ap->get_decl()->get_name(), symbol("bv"), domain.size(), domain.data(), range); - m_new_funs.insert(f, g); - m_pinned.push_back(f); - m_pinned.push_back(g); - } - f = g; - } - - m_trail.push_back(m.mk_app(f, args)); - translated.insert(e, m_trail.back()); - - if (has_bv_sort) - m_vars.insert(e, { m_trail.back(), bv_size() }); - - continue; - } - - auto bnot = [&](expr* e) { - return a.mk_sub(a.mk_int(-1), e); - }; - - auto band = [&](expr_ref_vector const& args) { - expr * r = args.get(0); - for (unsigned i = 1; i < args.size(); ++i) - r = a.mk_band(bv.get_bv_size(e), r, args.get(i)); - return r; - }; - - switch (ap->get_decl_kind()) { - case OP_BADD: - m_trail.push_back(a.mk_add(args)); - break; - case OP_BSUB: - m_trail.push_back(a.mk_sub(args.size(), args.data())); - break; - case OP_BMUL: - m_trail.push_back(a.mk_mul(args)); - break; - case OP_ULEQ: - bv_expr = ap->get_arg(0); - m_trail.push_back(a.mk_le(mk_mod(args.get(0)), mk_mod(args.get(1)))); - break; - case OP_UGEQ: - bv_expr = ap->get_arg(0); - m_trail.push_back(a.mk_ge(mk_mod(args.get(0)), mk_mod(args.get(1)))); - break; - case OP_ULT: - bv_expr = ap->get_arg(0); - m_trail.push_back(a.mk_lt(mk_mod(args.get(0)), mk_mod(args.get(1)))); - break; - case OP_UGT: - bv_expr = ap->get_arg(0); - m_trail.push_back(a.mk_gt(mk_mod(args.get(0)), mk_mod(args.get(1)))); - break; - case OP_SLEQ: - bv_expr = ap->get_arg(0); - m_trail.push_back(a.mk_le(mk_smod(args.get(0)), mk_smod(args.get(1)))); - break; - case OP_SGEQ: - m_trail.push_back(a.mk_ge(mk_smod(args.get(0)), mk_smod(args.get(1)))); - break; - case OP_SLT: - bv_expr = ap->get_arg(0); - m_trail.push_back(a.mk_lt(mk_smod(args.get(0)), mk_smod(args.get(1)))); - break; - case OP_SGT: - bv_expr = ap->get_arg(0); - m_trail.push_back(a.mk_gt(mk_smod(args.get(0)), mk_smod(args.get(1)))); - break; - case OP_BNEG: - m_trail.push_back(a.mk_uminus(args.get(0))); - break; - case OP_CONCAT: { - expr_ref r(a.mk_int(0), m); - unsigned sz = 0; - for (unsigned i = 0; i < args.size(); ++i) { - expr* old_arg = ap->get_arg(i); - expr* new_arg = args.get(i); - bv_expr = old_arg; - new_arg = mk_mod(new_arg); - if (sz > 0) { - new_arg = a.mk_mul(new_arg, a.mk_int(rational::power_of_two(sz))); - r = a.mk_add(r, new_arg); - } - else - r = new_arg; - sz += bv.get_bv_size(old_arg->get_sort()); - } - m_trail.push_back(r); - break; - } - case OP_EXTRACT: { - unsigned lo, hi; - expr* old_arg; - VERIFY(bv.is_extract(e, lo, hi, old_arg)); - unsigned sz = hi - lo + 1; - expr* new_arg = args.get(0); - if (lo > 0) - new_arg = a.mk_idiv(new_arg, a.mk_int(rational::power_of_two(lo))); - m_trail.push_back(new_arg); - break; - } - case OP_BV_NUM: { - rational val; - unsigned sz; - VERIFY(bv.is_numeral(e, val, sz)); - m_trail.push_back(a.mk_int(val)); - break; - } - case OP_BUREM_I: { - expr* x = args.get(0), * y = args.get(1); - m_trail.push_back(m.mk_ite(m.mk_eq(y, a.mk_int(0)), a.mk_int(0), a.mk_mod(x, y))); - break; - } - case OP_BUDIV_I: { - expr* x = args.get(0), * y = args.get(1); - m_trail.push_back(m.mk_ite(m.mk_eq(y, a.mk_int(0)), a.mk_int(0), a.mk_idiv(x, y))); - break; - } - case OP_BUMUL_NO_OVFL: { - expr* x = args.get(0), * y = args.get(1); - bv_expr = ap->get_arg(0); - m_trail.push_back(a.mk_lt(a.mk_mul(mk_mod(x), mk_mod(y)), a.mk_int(bv_size()))); - break; - } - case OP_BSHL: { - expr* x = args.get(0), * y = args.get(1); - expr* r = a.mk_int(0); - for (unsigned i = 0; i < bv.get_bv_size(e); ++i) - r = m.mk_ite(a.mk_eq(y, a.mk_int(i)), a.mk_mul(x, a.mk_int(rational::power_of_two(i))), r); - m_trail.push_back(r); - break; - } - case OP_BNOT: - m_trail.push_back(bnot(args.get(0))); - break; - case OP_BLSHR: { - expr* x = args.get(0), * y = args.get(1); - expr* r = a.mk_int(0); - for (unsigned i = 0; i < bv.get_bv_size(e); ++i) - r = m.mk_ite(a.mk_eq(y, a.mk_int(i)), a.mk_idiv(x, a.mk_int(rational::power_of_two(i))), r); - m_trail.push_back(r); - break; - } - // Or use (p + q) - band(p, q)? - case OP_BOR: - for (unsigned i = 0; i < args.size(); ++i) - args[i] = bnot(args.get(i)); - m_trail.push_back(bnot(band(args))); - break; - case OP_BNAND: - m_trail.push_back(bnot(band(args))); - break; - case OP_BAND: - m_trail.push_back(band(args)); - break; - // From "Hacker's Delight", section 2-2. Addition Combined with Logical Operations; - // found via Int-Blasting paper; see https://doi.org/10.1007/978-3-030-94583-1_24 - // (p + q) - 2*band(p, q); - case OP_BXNOR: - case OP_BXOR: { - unsigned sz = bv.get_bv_size(e); - expr* p = args.get(0); - for (unsigned i = 1; i < args.size(); ++i) { - expr* q = args.get(i); - p = a.mk_sub(a.mk_add(p, q), a.mk_mul(a.mk_int(2), a.mk_band(sz, p, q))); - } - if (ap->get_decl_kind() == OP_BXNOR) - p = bnot(p); - m_trail.push_back(p); - break; - } - case OP_BUDIV: { - bv_rewriter_params p(ctx.s().params()); - expr* x = args.get(0), * y = args.get(1); - if (p.hi_div0()) - m_trail.push_back(m.mk_ite(m.mk_eq(y, a.mk_int(0)), a.mk_int(0), a.mk_idiv(x, y))); - else - m_trail.push_back(a.mk_idiv(x, y)); - break; - } - case OP_BUREM: { - bv_rewriter_params p(ctx.s().params()); - expr* x = args.get(0), * y = args.get(1); - if (p.hi_div0()) - m_trail.push_back(m.mk_ite(m.mk_eq(y, a.mk_int(0)), a.mk_int(0), a.mk_mod(x, y))); - else - m_trail.push_back(a.mk_mod(x, y)); - break; - } - // - // ashr(x, y) - // if y = k & x >= 0 -> x / 2^k - // if y = k & x < 0 -> - (x / 2^k) - // - - case OP_BASHR: { - expr* x = args.get(0), * y = args.get(1); - rational N = rational::power_of_two(bv.get_bv_size(e)); - bv_expr = ap; - x = mk_mod(x); - y = mk_mod(y); - expr* signbit = a.mk_ge(x, a.mk_int(N/2)); - expr* r = m.mk_ite(signbit, a.mk_int(N - 1), a.mk_int(0)); - for (unsigned i = 0; i < bv.get_bv_size(e); ++i) { - expr* d = a.mk_idiv(x, a.mk_int(rational::power_of_two(i))); - r = m.mk_ite(a.mk_eq(y, a.mk_int(i)), - m.mk_ite(signbit, a.mk_uminus(d), d), - r); - } - m_trail.push_back(r); - break; - } - case OP_BCOMP: - - case OP_ROTATE_LEFT: - case OP_ROTATE_RIGHT: - case OP_EXT_ROTATE_LEFT: - case OP_EXT_ROTATE_RIGHT: - case OP_REPEAT: - case OP_ZERO_EXT: - case OP_SIGN_EXT: - case OP_BREDOR: - case OP_BREDAND: - case OP_BSDIV: - case OP_BSREM: - case OP_BSMOD: - verbose_stream() << mk_pp(e, m) << "\n"; - NOT_IMPLEMENTED_YET(); - break; - default: - verbose_stream() << mk_pp(e, m) << "\n"; - NOT_IMPLEMENTED_YET(); - } - translated.insert(e, m_trail.back()); - } + for (expr* e : todo) + translate_expr(e); TRACE("bv", for (expr* e : es) - tout << mk_pp(e, m) << "\n->\n" << mk_pp(translated[e], m) << "\n"; + tout << mk_pp(e, m) << "\n->\n" << mk_pp(translated(e), m) << "\n"; ); for (unsigned i = 0; i < es.size(); ++i) - es[i] = translated[es.get(i)]; + es[i] = translated(es.get(i)); + } + expr* solver::mk_mod(expr* x) { + if (m_vars.contains(x)) + return x; + return to_expr(a.mk_mod(x, a.mk_int(bv_size()))); + } + expr* solver::mk_smod(expr* x) { + auto shift = bv_size() / 2; + return a.mk_mod(a.mk_add(x, a.mk_int(shift)), a.mk_int(bv_size())); + } + + rational solver::bv_size() { + return rational::power_of_two(bv.get_bv_size(bv_expr->get_sort())); + } + + void solver::translate_expr(expr* e) { + if (is_quantifier(e)) + translate_quantifier(to_quantifier(e)); + else if (is_var(e)) + translate_var(to_var(e)); + else { + app* ap = to_app(e); + bv_expr = e; + m_args.reset(); + for (auto arg : *ap) + m_args.push_back(translated(arg)); + + if (ap->get_family_id() == basic_family_id) + translate_basic(ap); + else if (ap->get_family_id() == bv.get_family_id()) + translate_bv(ap); + else + translate_app(ap); + } + } + + void solver::translate_quantifier(quantifier* q) { + expr* b = q->get_expr(); + unsigned nd = q->get_num_decls(); + ptr_vector sorts; + for (unsigned i = 0; i < nd; ++i) { + auto s = q->get_decl_sort(i); + if (bv.is_bv_sort(s)) { + NOT_IMPLEMENTED_YET(); + sorts.push_back(a.mk_int()); + } + else + sorts.push_back(s); + } + b = translated(b); + // TODO if sorts contain integer, then created bounds variables. + set_translated(q, m.update_quantifier(q, b)); + } + + void solver::translate_var(var* v) { + if (bv.is_bv_sort(v->get_sort())) + set_translated(v, m.mk_var(v->get_idx(), a.mk_int())); + else + set_translated(v, v); + } + + void solver::translate_app(app* e) { + bool has_bv_arg = any_of(*e, [&](expr* arg) { return bv.is_bv(arg); }); + bool has_bv_sort = bv.is_bv(e); + func_decl* f = e->get_decl(); + if (has_bv_arg) { + verbose_stream() << mk_pp(e, m) << "\n"; + // need to update args with mod where they are bit-vectors. + NOT_IMPLEMENTED_YET(); + } + + if (has_bv_arg || has_bv_sort) { + ptr_vector domain; + for (auto* arg : *e) { + sort* s = arg->get_sort(); + domain.push_back(bv.is_bv_sort(s) ? a.mk_int() : s); + } + sort* range = bv.is_bv(e) ? a.mk_int() : e->get_sort(); + func_decl* g = nullptr; + if (!m_new_funs.find(f, g)) { + g = m.mk_fresh_func_decl(e->get_decl()->get_name(), symbol("bv"), domain.size(), domain.data(), range); + m_new_funs.insert(f, g); + m_pinned.push_back(f); + m_pinned.push_back(g); + } + f = g; + } + + set_translated(e, m.mk_app(f, m_args)); + + if (has_bv_sort) + m_vars.insert(e, { translated(e), bv_size()}); + } + + void solver::translate_bv(app* e) { + + auto bnot = [&](expr* e) { + return a.mk_sub(a.mk_int(-1), e); + }; + + auto band = [&](expr_ref_vector const& args) { + expr* r = arg(0); + for (unsigned i = 1; i < args.size(); ++i) + r = a.mk_band(bv.get_bv_size(e), r, arg(i)); + return r; + }; + + bv_expr = e; + expr* r = nullptr; + auto const& args = m_args; + switch (e->get_decl_kind()) { + case OP_BADD: + r = (a.mk_add(args)); + break; + case OP_BSUB: + r = (a.mk_sub(args.size(), args.data())); + break; + case OP_BMUL: + r = (a.mk_mul(args)); + break; + case OP_ULEQ: + bv_expr = e->get_arg(0); + r = (a.mk_le(mk_mod(arg(0)), mk_mod(arg(1)))); + break; + case OP_UGEQ: + bv_expr = e->get_arg(0); + r = (a.mk_ge(mk_mod(arg(0)), mk_mod(arg(1)))); + break; + case OP_ULT: + bv_expr = e->get_arg(0); + r = (a.mk_lt(mk_mod(arg(0)), mk_mod(arg(1)))); + break; + case OP_UGT: + bv_expr = e->get_arg(0); + r = (a.mk_gt(mk_mod(arg(0)), mk_mod(arg(1)))); + break; + case OP_SLEQ: + bv_expr = e->get_arg(0); + r = (a.mk_le(mk_smod(arg(0)), mk_smod(arg(1)))); + break; + case OP_SGEQ: + r = (a.mk_ge(mk_smod(arg(0)), mk_smod(arg(1)))); + break; + case OP_SLT: + bv_expr = e->get_arg(0); + r = (a.mk_lt(mk_smod(arg(0)), mk_smod(arg(1)))); + break; + case OP_SGT: + bv_expr = e->get_arg(0); + r = (a.mk_gt(mk_smod(arg(0)), mk_smod(arg(1)))); + break; + case OP_BNEG: + r = (a.mk_uminus(arg(0))); + break; + case OP_CONCAT: { + r = a.mk_int(0); + unsigned sz = 0; + for (unsigned i = 0; i < args.size(); ++i) { + expr* old_arg = e->get_arg(i); + expr* new_arg = arg(i); + bv_expr = old_arg; + new_arg = mk_mod(new_arg); + if (sz > 0) { + new_arg = a.mk_mul(new_arg, a.mk_int(rational::power_of_two(sz))); + r = a.mk_add(r, new_arg); + } + else + r = new_arg; + sz += bv.get_bv_size(old_arg->get_sort()); + } + break; + } + case OP_EXTRACT: { + unsigned lo, hi; + expr* old_arg; + VERIFY(bv.is_extract(e, lo, hi, old_arg)); + unsigned sz = hi - lo + 1; + expr* r = arg(0); + if (lo > 0) + r = a.mk_idiv(r, a.mk_int(rational::power_of_two(lo))); + break; + } + case OP_BV_NUM: { + rational val; + unsigned sz; + VERIFY(bv.is_numeral(e, val, sz)); + r = (a.mk_int(val)); + break; + } + case OP_BUREM_I: { + expr* x = arg(0), * y = arg(1); + r = (m.mk_ite(m.mk_eq(y, a.mk_int(0)), a.mk_int(0), a.mk_mod(x, y))); + break; + } + case OP_BUDIV_I: { + expr* x = arg(0), * y = arg(1); + r = (m.mk_ite(m.mk_eq(y, a.mk_int(0)), a.mk_int(0), a.mk_idiv(x, y))); + break; + } + case OP_BUMUL_NO_OVFL: { + expr* x = arg(0), * y = arg(1); + bv_expr = e->get_arg(0); + r = (a.mk_lt(a.mk_mul(mk_mod(x), mk_mod(y)), a.mk_int(bv_size()))); + break; + } + case OP_BSHL: { + expr* x = arg(0), * y = arg(1); + r = a.mk_int(0); + for (unsigned i = 0; i < bv.get_bv_size(e); ++i) + r = m.mk_ite(a.mk_eq(y, a.mk_int(i)), a.mk_mul(x, a.mk_int(rational::power_of_two(i))), r); + break; + } + case OP_BNOT: + r = (bnot(arg(0))); + break; + case OP_BLSHR: { + expr* x = arg(0), * y = arg(1); + r = a.mk_int(0); + for (unsigned i = 0; i < bv.get_bv_size(e); ++i) + r = m.mk_ite(a.mk_eq(y, a.mk_int(i)), a.mk_idiv(x, a.mk_int(rational::power_of_two(i))), r); + break; + } + // Or use (p + q) - band(p, q)? + case OP_BOR: { + r = arg(0); + for (unsigned i = 1; i < args.size(); ++i) + r = a.mk_sub(a.mk_add(r, arg(i)), a.mk_band(bv.get_bv_size(e), r, arg(i))); + break; + } + case OP_BNAND: + r = (bnot(band(args))); + break; + case OP_BAND: + r = (band(args)); + break; + // From "Hacker's Delight", section 2-2. Addition Combined with Logical Operations; + // found via Int-Blasting paper; see https://doi.org/10.1007/978-3-030-94583-1_24 + // (p + q) - 2*band(p, q); + case OP_BXNOR: + case OP_BXOR: { + unsigned sz = bv.get_bv_size(e); + expr* p = arg(0); + for (unsigned i = 1; i < args.size(); ++i) { + expr* q = arg(i); + p = a.mk_sub(a.mk_add(p, q), a.mk_mul(a.mk_int(2), a.mk_band(sz, p, q))); + } + if (e->get_decl_kind() == OP_BXNOR) + p = bnot(p); + r = (p); + break; + } + case OP_BUDIV: { + bv_rewriter_params p(ctx.s().params()); + expr* x = arg(0), * y = arg(1); + if (p.hi_div0()) + r = (m.mk_ite(m.mk_eq(y, a.mk_int(0)), a.mk_int(0), a.mk_idiv(x, y))); + else + r = (a.mk_idiv(x, y)); + break; + } + case OP_BUREM: { + bv_rewriter_params p(ctx.s().params()); + expr* x = arg(0), * y = arg(1); + if (p.hi_div0()) + r = (m.mk_ite(m.mk_eq(y, a.mk_int(0)), a.mk_int(0), a.mk_mod(x, y))); + else + r = (a.mk_mod(x, y)); + break; + } + // + // ashr(x, y) + // if y = k & x >= 0 -> x / 2^k + // if y = k & x < 0 -> - (x / 2^k) + // + + case OP_BASHR: { + expr* x = arg(0), * y = arg(1); + rational N = rational::power_of_two(bv.get_bv_size(e)); + bv_expr = e; + x = mk_mod(x); + y = mk_mod(y); + expr* signbit = a.mk_ge(x, a.mk_int(N / 2)); + r = m.mk_ite(signbit, a.mk_int(N - 1), a.mk_int(0)); + for (unsigned i = 0; i < bv.get_bv_size(e); ++i) { + expr* d = a.mk_idiv(x, a.mk_int(rational::power_of_two(i))); + r = m.mk_ite(a.mk_eq(y, a.mk_int(i)), + m.mk_ite(signbit, a.mk_uminus(d), d), + r); + } + break; + } + case OP_BCOMP: + + case OP_ROTATE_LEFT: + case OP_ROTATE_RIGHT: + case OP_EXT_ROTATE_LEFT: + case OP_EXT_ROTATE_RIGHT: + case OP_REPEAT: + case OP_ZERO_EXT: + case OP_SIGN_EXT: + case OP_BREDOR: + case OP_BREDAND: + case OP_BSDIV: + case OP_BSREM: + case OP_BSMOD: + verbose_stream() << mk_pp(e, m) << "\n"; + NOT_IMPLEMENTED_YET(); + break; + default: + verbose_stream() << mk_pp(e, m) << "\n"; + NOT_IMPLEMENTED_YET(); + } + set_translated(e, r); + } + + void solver::translate_basic(app* e) { + if (m.is_eq(e)) { + bool has_bv_arg = any_of(*e, [&](expr* arg) { return bv.is_bv(arg); }); + if (has_bv_arg) { + bv_expr = e->get_arg(0); + set_translated(e, m.mk_eq(mk_mod(arg(0)), mk_mod(arg(1)))); + } + else + set_translated(e, m.mk_eq(arg(0), arg(1))); + } + else + set_translated(e, m.mk_app(e->get_decl(), m_args)); } rational solver::get_value(expr* e) const { diff --git a/src/sat/smt/intblast_solver.h b/src/sat/smt/intblast_solver.h index b87724cc8..037b009a3 100644 --- a/src/sat/smt/intblast_solver.h +++ b/src/sat/smt/intblast_solver.h @@ -8,12 +8,24 @@ Module Name: Abstract: Int-blast solver. - It assumes a full assignemnt to literals in + + check_solver_state assumes a full assignment to literals in irredundant clauses. It picks a satisfying Boolean assignment and checks if it is feasible for bit-vectors using an arithmetic solver. + The solver plugin is self-contained. + + Internalize: + - internalize bit-vector terms bottom-up by updating m_translate. + - add axioms of the form: + - ule(b,a) <=> translate(ule(b, a)) + - let arithmetic solver handle bit-vector constraints. + - For shared b + - Ensure: int2bv(translate(b)) = b + - but avoid bit-blasting by ensuring int2bv is injective (mod N) during final check + Author: Nikolaj Bjorner (nbjorner) 2023-12-10 @@ -33,7 +45,7 @@ namespace euf { namespace intblast { - class solver { + class solver : public euf::th_euf_solver { struct var_info { expr* dst; rational sz; @@ -47,7 +59,7 @@ namespace intblast { scoped_ptr<::solver> m_solver; obj_map m_vars; obj_map m_new_funs; - expr_ref_vector m_trail; + expr_ref_vector m_translate, m_args; ast_ref_vector m_pinned; sat::literal_vector m_core; statistics m_stats; @@ -58,18 +70,68 @@ namespace intblast { rational get_value(expr* e) const; + expr* translated(expr* e) { expr* r = m_translate.get(e->get_id(), nullptr); SASSERT(r); return r; } + void set_translated(expr* e, expr* r) { m_translate.setx(e->get_id(), r); } + expr* arg(unsigned i) { return m_args.get(i); } + + expr* mk_mod(expr* x); + expr* mk_smod(expr* x); + expr* bv_expr = nullptr; + rational bv_size(); + + void translate_expr(expr* e); + void translate_bv(app* e); + void translate_basic(app* e); + void translate_app(app* e); + void translate_quantifier(quantifier* q); + void translate_var(var* v); + + void ensure_args(app* e); + void internalize_bv(app* e); + + euf::theory_var mk_var(euf::enode* n) override; + public: solver(euf::solver& ctx); - lbool check(); + ~solver() override {} + + lbool check_solver_state(); sat::literal_vector const& unsat_core(); - void add_value(euf::enode* n, model& mdl, expr_ref_vector& values); + void add_value(euf::enode* n, model& mdl, expr_ref_vector& values) override; - std::ostream& display(std::ostream& out) const; + std::ostream& display(std::ostream& out) const override; + + void collect_statistics(statistics& st) const override; + + + + bool unit_propagate() override { return false; } + + void get_antecedents(sat::literal l, sat::ext_justification_idx idx, sat::literal_vector& r, bool probing) override {} + + sat::check_result check() override { return sat::check_result::CR_DONE; } + + std::ostream& display_justification(std::ostream& out, sat::ext_justification_idx idx) const override { return out; } + + std::ostream& display_constraint(std::ostream& out, sat::ext_constraint_idx idx) const override { return out; } + + euf::th_solver* clone(euf::solver& ctx) override { return alloc(solver, ctx); } + + void internalize(expr* e) override; + + bool visited(expr* e) override; + + bool post_visit(expr* e, bool sign, bool root) override; + + bool visit(expr* e) override; + + sat::literal internalize(expr* e, bool, bool) override; + + void eq_internalized(euf::enode* n) override {} - void collect_statistics(statistics& st) const; }; } diff --git a/src/sat/smt/polysat/constraints.h b/src/sat/smt/polysat/constraints.h index 15d8dfa09..a9ec63165 100644 --- a/src/sat/smt/polysat/constraints.h +++ b/src/sat/smt/polysat/constraints.h @@ -41,6 +41,8 @@ namespace polysat { virtual std::ostream& display(std::ostream& out) const = 0; virtual lbool eval() const = 0; virtual lbool eval(assignment const& a) const = 0; + virtual void activate(core& c, bool sign, dependency const& d) = 0; + virtual void propagate(core& c, lbool value, dependency const& d) = 0; }; inline std::ostream& operator<<(std::ostream& out, constraint const& c) { return c.display(out); } @@ -61,6 +63,8 @@ namespace polysat { unsigned_vector const& vars() const { return m_constraint->vars(); } unsigned var(unsigned idx) const { return m_constraint->var(idx); } bool contains_var(pvar v) const { return m_constraint->contains_var(v); } + void activate(core& c, dependency const& d) { m_constraint->activate(c, m_sign, d); } + void propagate(core& c, lbool value, dependency const& d) { m_constraint->propagate(c, value, d); } bool is_always_true() const { return eval() == l_true; } bool is_always_false() const { return eval() == l_false; } lbool eval(assignment& a) const; @@ -84,6 +88,8 @@ namespace polysat { signed_constraint eq(pdd const& p) { return ule(p, p.manager().mk_val(0)); } signed_constraint eq(pdd const& p, rational const& v) { return eq(p - p.manager().mk_val(v)); } + signed_constraint eq(pdd const& p, unsigned v) { return eq(p - p.manager().mk_val(v)); } + signed_constraint eq(pdd const& p, pdd const& q) { return eq(p - q); } signed_constraint ule(pdd const& p, pdd const& q); signed_constraint sle(pdd const& p, pdd const& q) { auto sh = rational::power_of_two(p.power_of_2() - 1); return ule(p + sh, q + sh); } signed_constraint ult(pdd const& p, pdd const& q) { return ~ule(q, p); } diff --git a/src/sat/smt/polysat/core.cpp b/src/sat/smt/polysat/core.cpp index a552bb9ab..dd30e8227 100644 --- a/src/sat/smt/polysat/core.cpp +++ b/src/sat/smt/polysat/core.cpp @@ -187,12 +187,13 @@ namespace polysat { return sc; } - void core::propagate_assignment(prop_item& dc) { auto [idx, sign, dep] = dc; auto sc = get_constraint(idx, sign); if (sc.is_eq(m_var, m_value)) propagate_assignment(m_var, m_value, dep); + else + sc.activate(*this, dep); } void core::add_watch(unsigned idx, unsigned var) { @@ -216,7 +217,7 @@ namespace polysat { // for entries where there is only one free variable left add to viable set unsigned j = 0; for (auto idx : m_watch[v]) { - auto [sc, as, value] = m_constraint_index[idx]; + auto [sc, dep, value] = m_constraint_index[idx]; auto& vars = sc.vars(); if (vars[0] != v) std::swap(vars[0], vars[1]); @@ -231,6 +232,8 @@ namespace polysat { } } + sc.propagate(*this, value, dep); + SASSERT(!swapped || vars.size() <= 1 || (!is_assigned(vars[0]) && !is_assigned(vars[1]))); if (swapped) continue; @@ -262,6 +265,11 @@ namespace polysat { default: break; } + // propagate current assignment for sc + sc.propagate(*this, to_lbool(!sign), dep); + if (s.inconsistent()) + return; + // if sc is v == value, then check the watch list for v to propagate truth assignments if (sc.is_eq(m_var, m_value)) { for (auto idx1 : m_watch[m_var]) { @@ -360,4 +368,13 @@ namespace polysat { } + void core::add_axiom(signed_constraint sc) { + auto idx = register_constraint(sc, dependency::axiom()); + assign_eh(idx, false, dependency::axiom()); + } + + void core::add_clause(char const* name, core_vector const& cs, bool is_redundant) { + s.add_polysat_clause(name, cs, is_redundant); + } + } diff --git a/src/sat/smt/polysat/core.h b/src/sat/smt/polysat/core.h index c3dddfece..6297e567e 100644 --- a/src/sat/smt/polysat/core.h +++ b/src/sat/smt/polysat/core.h @@ -84,7 +84,7 @@ namespace polysat { void get_bitvector_prefixes(pvar v, pvar_vector& out); void get_fixed_bits(pvar v, svector& fixed_bits); bool inconsistent() const; - void add_clause(char const* name, std::initializer_list cs, bool is_redundant); + void add_watch(unsigned idx, unsigned var); @@ -94,6 +94,8 @@ namespace polysat { lbool eval(signed_constraint const& sc); dependency_vector explain_eval(signed_constraint const& sc); + void add_axiom(signed_constraint sc); + public: core(solver_interface& s); @@ -118,13 +120,20 @@ namespace polysat { signed_constraint bit(pdd const& p, unsigned i) { return m_constraints.bit(p, i); } - signed_constraint lshr(pdd const& a, pdd const& b, pdd const& r) { return m_constraints.lshr(a, b, r); } - signed_constraint ashr(pdd const& a, pdd const& b, pdd const& r) { return m_constraints.ashr(a, b, r); } - signed_constraint shl(pdd const& a, pdd const& b, pdd const& r) { return m_constraints.shl(a, b, r); } - signed_constraint band(pdd const& a, pdd const& b, pdd const& r) { return m_constraints.band(a, b, r); } + void lshr(pdd const& a, pdd const& b, pdd const& r) { add_axiom(m_constraints.lshr(a, b, r)); } + void ashr(pdd const& a, pdd const& b, pdd const& r) { add_axiom(m_constraints.ashr(a, b, r)); } + void shl(pdd const& a, pdd const& b, pdd const& r) { add_axiom(m_constraints.shl(a, b, r)); } + void band(pdd const& a, pdd const& b, pdd const& r) { add_axiom(m_constraints.band(a, b, r)); } pdd bnot(pdd p) { return -p - 1; } + + /* + * Add a named clause. Dependencies are assumed, signed constraints are guaranteeed. + * In other words, the clause represents the formula /\ d_i -> \/ sc_j + * Where d_i are logical interpretations of dependencies and sc_j are signed constraints. + */ + void add_clause(char const* name, core_vector const& cs, bool is_redundant); pvar add_var(unsigned sz); pdd var(pvar p) { return m_vars[p]; } diff --git a/src/sat/smt/polysat/op_constraint.cpp b/src/sat/smt/polysat/op_constraint.cpp index 401b1ca52..234ddea04 100644 --- a/src/sat/smt/polysat/op_constraint.cpp +++ b/src/sat/smt/polysat/op_constraint.cpp @@ -13,13 +13,13 @@ Notes: Additional possible functionality on constraints: -- activate - when operation is first activated. It may be created and only activated later. - bit-wise assignments - narrow based on bit assignment, not entire word assignment. - integration with congruence tables - integration with conflict resolution --*/ +#include "util/log.h" #include "sat/smt/polysat/op_constraint.h" #include "sat/smt/polysat/core.h" @@ -157,7 +157,6 @@ namespace polysat { return out << "&"; case op_constraint::code::inv_op: return out << "inv"; - default: UNREACHABLE(); return out; @@ -176,96 +175,95 @@ namespace polysat { return out << r() << " " << eq << " " << p() << " " << m_op << " " << q(); } -#if 0 - /** - * Produce lemmas that contradict the given assignment. - * - * We can assume that op_constraint is only asserted positive. - */ - clause_ref op_constraint::produce_lemma(solver& s, assignment const& a, bool is_positive) { - SASSERT(is_positive); - - if (is_currently_true(a, is_positive)) - return {}; - - return produce_lemma(s, a); - } - - clause_ref op_constraint::produce_lemma(solver& s, assignment const& a) { + void op_constraint::activate(core& c, bool sign, dependency const& dep) { + SASSERT(!sign); switch (m_op) { - case code::lshr_op: - return lemma_lshr(s, a); - case code::shl_op: - return lemma_shl(s, a); case code::and_op: - return lemma_and(s, a); - case code::inv_op: - return lemma_inv(s, a); + activate_and(c, dep); + break; default: - NOT_IMPLEMENTED_YET(); - return {}; + break; } } + void op_constraint::propagate(core& c, lbool value, dependency const& dep) { + SASSERT(value == l_true); + switch (m_op) { + case code::lshr_op: + propagate_lshr(c, dep); + break; + case code::shl_op: + propagate_shl(c, dep); + break; + case code::and_op: + propagate_and(c, dep); + break; + case code::inv_op: + propagate_inv(c, dep); + break; + default: + NOT_IMPLEMENTED_YET(); + break; + } + } + + void op_constraint::propagate_inv(core& s, dependency const& dep) { + + } + /** - * Enforce basic axioms for r == p >> q: - * - * q >= N -> r = 0 - * q >= k -> r[i] = 0 for N - k <= i < N (bit indices range from 0 to N-1, inclusive) - * q >= k -> r <= 2^{N-k} - 1 - * q = k -> r[i] = p[i+k] for 0 <= i < N - k - * r <= p - * q != 0 -> r <= p (subsumed by previous axiom) - * q != 0 /\ p > 0 -> r < p - * q = 0 -> r = p - * p = q -> r = 0 - * - * when q is a constant, several axioms can be enforced at activation time. - * - * Enforce also inferences and bounds - * - * TODO: use also - * s.m_viable.min_viable(); - * s.m_viable.max_viable() - * when r, q are variables. - */ - clause_ref op_constraint::lemma_lshr(solver& s, assignment const& a) { + * Enforce basic axioms for r == p >> q: + * + * q >= N -> r = 0 + * q >= k -> r[i] = 0 for N - k <= i < N (bit indices range from 0 to N-1, inclusive) + * q >= k -> r <= 2^{N-k} - 1 + * q = k -> r[i] = p[i+k] for 0 <= i < N - k + * r <= p + * q != 0 -> r <= p (subsumed by previous axiom) + * q != 0 /\ p > 0 -> r < p + * q = 0 -> r = p + * p = q -> r = 0 + * + * when q is a constant, several axioms can be enforced at activation time. + * + * Enforce also inferences and bounds + * + * TODO: use also + * s.m_viable.min_viable(); + * s.m_viable.max_viable() + * when r, q are variables. + */ + void op_constraint::propagate_lshr(core& c, dependency const& d) { auto& m = p().manager(); - auto const pv = a.apply_to(p()); - auto const qv = a.apply_to(q()); - auto const rv = a.apply_to(r()); + auto const pv = c.subst(p()); + auto const qv = c.subst(q()); + auto const rv = c.subst(r()); unsigned const N = m.power_of_2(); - signed_constraint const lshr(this, true); + + signed_constraint const lshr(polysat::ckind_t::op_t, this); + auto& C = c.cs(); if (pv.is_val() && rv.is_val() && rv.val() > pv.val()) - // r <= p - return s.mk_clause(~lshr, s.ule(r(), p()), true); + c.add_clause("lshr 1", { d, C.ule(r(), p()) }, false); + else if (qv.is_val() && qv.val() >= N && rv.is_val() && !rv.is_zero()) // TODO: instead of rv.is_val() && !rv.is_zero(), we should use !is_forced_zero(r) which checks whether eval(r) = 0 or bvalue(r=0) = true; see saturation.cpp - // q >= N -> r = 0 - return s.mk_clause(~lshr, ~s.ule(N, q()), s.eq(r()), true); + c.add_clause("q >= N -> r = 0", { d, ~C.ule(N, q()), C.eq(r()) }, true); else if (qv.is_zero() && pv.is_val() && rv.is_val() && pv != rv) - // q = 0 -> p = r - return s.mk_clause(~lshr, ~s.eq(q()), s.eq(p(), r()), true); + c.add_clause("q = 0 -> p = r", { d, ~C.eq(q()), C.eq(p(), r()) } , true); else if (qv.is_val() && !qv.is_zero() && pv.is_val() && rv.is_val() && !pv.is_zero() && rv.val() >= pv.val()) - // q != 0 & p > 0 -> r < p - return s.mk_clause(~lshr, s.eq(q()), s.ule(p(), 0), s.ult(r(), p()), true); + c.add_clause("q != 0 & p > 0 -> r < p", { d, C.eq(q()), C.ule(p(), 0), C.ult(r(), p()) }, true); else if (qv.is_val() && !qv.is_zero() && qv.val() < N && rv.is_val() && rv.val() > rational::power_of_two(N - qv.val().get_unsigned()) - 1) - // q >= k -> r <= 2^{N-k} - 1 - return s.mk_clause(~lshr, ~s.ule(qv.val(), q()), s.ule(r(), rational::power_of_two(N - qv.val().get_unsigned()) - 1), true); - // else if (pv == qv && !rv.is_zero()) - // return s.mk_clause(~lshr, ~s.eq(p(), q()), s.eq(r()), true); + c.add_clause("q >= k -> r <= 2^{N-k} - 1", { d, ~C.ule(qv.val(), q()), C.ule(r(), rational::power_of_two(N - qv.val().get_unsigned()) - 1)}, true); else if (pv.is_val() && rv.is_val() && qv.is_val() && !qv.is_zero()) { unsigned k = qv.val().get_unsigned(); - // q = k -> r[i] = p[i+k] for 0 <= i < N - k for (unsigned i = 0; i < N - k; ++i) { - if (rv.val().get_bit(i) && !pv.val().get_bit(i + k)) { - return s.mk_clause(~lshr, ~s.eq(q(), k), ~s.bit(r(), i), s.bit(p(), i + k), true); - } - if (!rv.val().get_bit(i) && pv.val().get_bit(i + k)) { - return s.mk_clause(~lshr, ~s.eq(q(), k), s.bit(r(), i), ~s.bit(p(), i + k), true); - } + if (rv.val().get_bit(i) && !pv.val().get_bit(i + k)) + c.add_clause("q = k -> r[i] = p[i+k] for 0 <= i < N - k", { d, ~C.eq(q(), k), ~C.bit(r(), i), C.bit(p(), i + k) }, true); + + if (!rv.val().get_bit(i) && pv.val().get_bit(i + k)) + c.add_clause("q = k -> r[i] = p[i+k] for 0 <= i < N - k", { d, ~C.eq(q(), k), C.bit(r(), i), ~C.bit(p(), i + k) }, true); } } else { @@ -276,19 +274,44 @@ namespace polysat { rational const& q_val = qv.val(); if (q_val >= N) // q >= N ==> r = 0 - return s.mk_clause(~lshr, ~s.ule(N, q()), s.eq(r()), true); - if (pv.is_val()) { + c.add_clause("q >= N ==> r = 0", { d, ~C.ule(N, q()), C.eq(r()) }, true); + else if (pv.is_val()) { SASSERT(q_val.is_unsigned()); - // p = p_val & q = q_val ==> r = p_val / 2^q_val + // rational const r_val = machine_div2k(pv.val(), q_val.get_unsigned()); - return s.mk_clause(~lshr, ~s.eq(p(), pv), ~s.eq(q(), qv), s.eq(r(), r_val), true); + c.add_clause("p = p_val & q = q_val ==> r = p_val / 2^q_val", { d, ~C.eq(p(), pv), ~C.eq(q(), qv), C.eq(r(), r_val) }, true); } } } - return {}; } + void op_constraint::activate_and(core& c, dependency const& d) { + auto x = p(), y = q(); + auto& C = c.cs(); + if (x.is_val()) + std::swap(x, y); + if (!y.is_val()) + return; + auto& m = x.manager(); + auto yv = y.val(); + if (!(yv + 1).is_power_of_two()) + return; + if (yv == m.max_value()) + c.add_clause("band-mask-true", { d, C.eq(x, r()) }, false); + else if (yv == 0) + c.add_clause("band-mask-false", { d, C.eq(r()) }, false); + else { + unsigned N = m.power_of_2(); + unsigned k = yv.get_num_bits(); + SASSERT(k < N); + rational exp = rational::power_of_two(N - k); + c.add_clause("band-mask 1", { d, C.eq(x * exp, r() * exp) }, false); + c.add_clause("band-mask 2", { d, C.ule(r(), y) }, false); // maybe always activate these constraints regardless? + } + } + + /** * Enforce axioms for constraint: r == p << q * @@ -298,35 +321,33 @@ namespace polysat { * q = k -> r[i+k] = p[i] for 0 <= i < N - k * q = 0 -> r = p */ - clause_ref op_constraint::lemma_shl(solver& s, assignment const& a) { + void op_constraint::propagate_shl(core& c, dependency const& d) { auto& m = p().manager(); - auto const pv = a.apply_to(p()); - auto const qv = a.apply_to(q()); - auto const rv = a.apply_to(r()); + auto const pv = c.subst(p()); + auto const qv = c.subst(q()); + auto const rv = c.subst(r()); unsigned const N = m.power_of_2(); + auto& C = c.cs(); - signed_constraint const shl(this, true); - - if (qv.is_val() && qv.val() >= N && rv.is_val() && !rv.is_zero()) - // q >= N -> r = 0 - return s.mk_clause(~shl, ~s.ule(N, q()), s.eq(r()), true); + if (qv.is_val() && qv.val() >= N && rv.is_val() && !rv.is_zero()) + c.add_clause("q >= N -> r = 0", { d, ~C.ule(N, q()), C.eq(r()) }, true); else if (qv.is_zero() && pv.is_val() && rv.is_val() && rv != pv) - // q = 0 -> r = p - return s.mk_clause(~shl, ~s.eq(q()), s.eq(r(), p()), true); + // + c.add_clause("q = 0 -> r = p", { d, ~C.eq(q()), C.eq(r(), p()) }, true); else if (qv.is_val() && !qv.is_zero() && qv.val() < N && rv.is_val() && !rv.is_zero() && rv.val() < rational::power_of_two(qv.val().get_unsigned())) // q >= k -> r = 0 \/ r >= 2^k (intuitive version) // q >= k -> r - 1 >= 2^k - 1 (equivalent unit constraint to better support narrowing) - return s.mk_clause(~shl, ~s.ule(qv.val(), q()), s.ule(rational::power_of_two(qv.val().get_unsigned()) - 1, r() - 1), true); + c.add_clause("q >= k -> r - 1 >= 2^k - 1", { d, ~C.ule(qv.val(), q()), C.ule(rational::power_of_two(qv.val().get_unsigned()) - 1, r() - 1) }, true); else if (pv.is_val() && rv.is_val() && qv.is_val() && !qv.is_zero()) { unsigned k = qv.val().get_unsigned(); // q = k -> r[i+k] = p[i] for 0 <= i < N - k for (unsigned i = 0; i < N - k; ++i) { if (rv.val().get_bit(i + k) && !pv.val().get_bit(i)) { - return s.mk_clause(~shl, ~s.eq(q(), k), ~s.bit(r(), i + k), s.bit(p(), i), true); + c.add_clause("q = k -> r[i+k] = p[i] for 0 <= i < N - k", { d, ~C.eq(q(), k), ~C.bit(r(), i + k), C.bit(p(), i) }, true); } if (!rv.val().get_bit(i + k) && pv.val().get_bit(i)) { - return s.mk_clause(~shl, ~s.eq(q(), k), s.bit(r(), i + k), ~s.bit(p(), i), true); + c.add_clause("q = k -> r[i+k] = p[i] for 0 <= i < N - k", { d, ~C.eq(q(), k), C.bit(r(), i + k), ~C.bit(p(), i) }, true); } } } @@ -338,43 +359,15 @@ namespace polysat { rational const& q_val = qv.val(); if (q_val >= N) // q >= N ==> r = 0 - return s.mk_clause("shl forward 1", {~shl, ~s.ule(N, q()), s.eq(r())}, true); + c.add_clause("shl forward 1", {d, ~C.ule(N, q()), C.eq(r())}, true); if (pv.is_val()) { SASSERT(q_val.is_unsigned()); // p = p_val & q = q_val ==> r = p_val * 2^q_val rational const r_val = pv.val() * rational::power_of_two(q_val.get_unsigned()); - return s.mk_clause("shl forward 2", {~shl, ~s.eq(p(), pv), ~s.eq(q(), qv), s.eq(r(), r_val)}, true); + c.add_clause("shl forward 2", {d, ~C.eq(p(), pv), ~C.eq(q(), qv), C.eq(r(), r_val)}, true); } } } - return {}; - } - - - - void op_constraint::activate_and(solver& s) { - auto x = p(), y = q(); - if (x.is_val()) - std::swap(x, y); - if (!y.is_val()) - return; - auto& m = x.manager(); - auto yv = y.val(); - if (!(yv + 1).is_power_of_two()) - return; - signed_constraint const andc(this, true); - if (yv == m.max_value()) - s.add_clause(~andc, s.eq(x, r()), false); - else if (yv == 0) - s.add_clause(~andc, s.eq(r()), false); - else { - unsigned N = m.power_of_2(); - unsigned k = yv.get_num_bits(); - SASSERT(k < N); - rational exp = rational::power_of_two(N - k); - s.add_clause(~andc, s.eq(x * exp, r() * exp), false); - s.add_clause(~andc, s.ule(r(), y), false); // maybe always activate these constraints regardless? - } } /** @@ -390,48 +383,39 @@ namespace polysat { * p = 2^k - 1 && r = 0 && q != 0 => q >= 2^k * q = 2^k - 1 && r = 0 && p != 0 => p >= 2^k */ - clause_ref op_constraint::lemma_and(solver& s, assignment const& a) { + void op_constraint::propagate_and(core& c, dependency const& d) { auto& m = p().manager(); - auto pv = a.apply_to(p()); - auto qv = a.apply_to(q()); - auto rv = a.apply_to(r()); + auto pv = c.subst(p()); + auto qv = c.subst(q()); + auto rv = c.subst(r()); + auto& C = c.cs(); - signed_constraint const andc(this, true); // op_constraints are always true - - // r <= p if (pv.is_val() && rv.is_val() && rv.val() > pv.val()) - return s.mk_clause(~andc, s.ule(r(), p()), true); - // r <= q - if (qv.is_val() && rv.is_val() && rv.val() > qv.val()) - return s.mk_clause(~andc, s.ule(r(), q()), true); - // p = q => r = p - if (pv.is_val() && qv.is_val() && rv.is_val() && pv == qv && rv != pv) - return s.mk_clause(~andc, ~s.eq(p(), q()), s.eq(r(), p()), true); - if (pv.is_val() && qv.is_val() && rv.is_val()) { - // p = -1 => r = q + c.add_clause("p&q <= p", { d, C.ule(r(), p()) }, true); + else if (qv.is_val() && rv.is_val() && rv.val() > qv.val()) + c.add_clause("p&q <= q", { d, C.ule(r(), q()) }, true); + else if (pv.is_val() && qv.is_val() && rv.is_val() && pv == qv && rv != pv) + c.add_clause("p = q => r = p", { d, ~C.eq(p(), q()), C.eq(r(), p()) }, true); + else if (pv.is_val() && qv.is_val() && rv.is_val()) { if (pv.is_max() && qv != rv) - return s.mk_clause(~andc, ~s.eq(p(), m.max_value()), s.eq(q(), r()), true); - // q = -1 => r = p + c.add_clause("p = -1 => r = q", { d, ~C.eq(p(), m.max_value()), C.eq(q(), r()) }, true); if (qv.is_max() && pv != rv) - return s.mk_clause(~andc, ~s.eq(q(), m.max_value()), s.eq(p(), r()), true); + c.add_clause("q = -1 => r = p", { d, ~C.eq(q(), m.max_value()), C.eq(p(), r()) }, true); unsigned const N = m.power_of_2(); unsigned pow; if ((pv.val() + 1).is_power_of_two(pow)) { - // p = 2^k - 1 && r = 0 && q != 0 => q >= 2^k if (rv.is_zero() && !qv.is_zero() && qv.val() <= pv.val()) - return s.mk_clause(~andc, ~s.eq(p(), pv), ~s.eq(r()), s.eq(q()), s.ule(pv + 1, q()), true); - // p = 2^k - 1 ==> r*2^{N - k} = q*2^{N - k} + c.add_clause("p = 2^k - 1 && r = 0 && q != 0 => q >= 2^k", { d, ~C.eq(p(), pv), ~C.eq(r()), C.eq(q()), C.ule(pv + 1, q()) }, true); if (rv != qv) - return s.mk_clause(~andc, ~s.eq(p(), pv), s.eq(r() * rational::power_of_two(N - pow), q() * rational::power_of_two(N - pow)), true); + c.add_clause("p = 2^k - 1 ==> r*2^{N - k} = q*2^{N - k}", { d, ~C.eq(p(), pv), C.eq(r() * rational::power_of_two(N - pow), q() * rational::power_of_two(N - pow)) }, true); } if ((qv.val() + 1).is_power_of_two(pow)) { - // q = 2^k - 1 && r = 0 && p != 0 ==> p >= 2^k if (rv.is_zero() && !pv.is_zero() && pv.val() <= qv.val()) - return s.mk_clause(~andc, ~s.eq(q(), qv), ~s.eq(r()), s.eq(p()), s.ule(qv + 1, p()), true); - // q = 2^k - 1 ==> r*2^{N - k} = p*2^{N - k} + c.add_clause("q = 2^k - 1 && r = 0 && p != 0 ==> p >= 2^k", { d, ~C.eq(q(), qv), ~C.eq(r()), C.eq(p()), C.ule(qv + 1, p()) }, true); + // if (rv != pv) - return s.mk_clause(~andc, ~s.eq(q(), qv), s.eq(r() * rational::power_of_two(N - pow), p() * rational::power_of_two(N - pow)), true); + c.add_clause("q = 2^k - 1 ==> r*2^{N - k} = p*2^{N - k}", { d, ~C.eq(q(), qv), C.eq(r() * rational::power_of_two(N - pow), p() * rational::power_of_two(N - pow)) }, true); } for (unsigned i = 0; i < N; ++i) { @@ -441,33 +425,31 @@ namespace polysat { if (rb == (pb && qb)) continue; if (pb && qb && !rb) - return s.mk_clause(~andc, ~s.bit(p(), i), ~s.bit(q(), i), s.bit(r(), i), true); + c.add_clause("p&q[i] = p[i]&q[i]", { d, ~C.bit(p(), i), ~C.bit(q(), i), C.bit(r(), i) }, true); else if (!pb && rb) - return s.mk_clause(~andc, s.bit(p(), i), ~s.bit(r(), i), true); + c.add_clause("p&q[i] = p[i]&q[i]", { d, C.bit(p(), i), ~C.bit(r(), i) }, true); else if (!qb && rb) - return s.mk_clause(~andc, s.bit(q(), i), ~s.bit(r(), i), true); + c.add_clause("p&q[i] = p[i]&q[i]", { d, C.bit(q(), i), ~C.bit(r(), i) }, true); else UNREACHABLE(); } - return {}; + return; } // Propagate r if p or q are 0 - if (pv.is_zero() && !rv.is_zero()) // rv not necessarily fully evaluated - return s.mk_clause(~andc, s.ule(r(), p()), true); - if (qv.is_zero() && !rv.is_zero()) // rv not necessarily fully evaluated - return s.mk_clause(~andc, s.ule(r(), q()), true); + else if (pv.is_zero() && !rv.is_zero()) // rv not necessarily fully evaluated + c.add_clause("p = 0 -> p&q = 0", { d, C.ule(r(), p()) }, true); + else if (qv.is_zero() && !rv.is_zero()) // rv not necessarily fully evaluated + c.add_clause("q = 0 -> p&q = 0", { d, C.ule(r(), q()) }, true); // p = a && q = b ==> r = a & b - if (pv.is_val() && qv.is_val() && !rv.is_val()) { + else if (pv.is_val() && qv.is_val() && !rv.is_val()) { // Just assign by this very weak justification. It will be strengthened in saturation in case of a conflict LOG(p() << " = " << pv << " and " << q() << " = " << qv << " yields [band] " << r() << " = " << bitwise_and(pv.val(), qv.val())); - return s.mk_clause(~andc, ~s.eq(p(), pv), ~s.eq(q(), qv), s.eq(r(), bitwise_and(pv.val(), qv.val())), true); + c.add_clause("p = a & q = b => r = a&b", { d, ~C.eq(p(), pv), ~C.eq(q(), qv), C.eq(r(), bitwise_and(pv.val(), qv.val())) }, true); } - - return {}; } - +#if 0 /** * Produce lemmas for constraint: r == inv p @@ -490,15 +472,15 @@ namespace polysat { // p = 0 ==> r = 0 if (pv.is_zero()) - return s.mk_clause(~invc, ~s.eq(p()), s.eq(r()), true); + c.add_clause(~invc, ~C.eq(p()), C.eq(r()), true); // r = 0 ==> p = 0 if (rv.is_zero()) - return s.mk_clause(~invc, ~s.eq(r()), s.eq(p()), true); + c.add_clause(~invc, ~C.eq(r()), C.eq(p()), true); // forward propagation: p assigned ==> r = pseudo_inverse(eval(p)) // TODO: (later) this should be propagated instead of adding a clause /*if (pv.is_val() && !rv.is_val()) - return s.mk_clause(~invc, ~s.eq(p(), pv), s.eq(r(), pv.val().pseudo_inverse(m.power_of_2())), true);*/ + c.add_clause(~invc, ~C.eq(p(), pv), C.eq(r(), pv.val().pseudo_inverse(m.power_of_2())), true);*/ if (!pv.is_val() || !rv.is_val()) return {}; @@ -511,7 +493,7 @@ namespace polysat { // p != 0 ==> odd(r) if (parity_rv != 0) - return s.mk_clause("r = inv p & p != 0 ==> odd(r)", {~invc, s.eq(p()), s.odd(r())}, true); + c.add_clause("r = inv p & p != 0 ==> odd(r)", {~invc, C.eq(p()), s.odd(r())}, true); pdd prod = p() * r(); rational prodv = (pv * rv).val(); @@ -527,13 +509,13 @@ namespace polysat { LOG("Its in [" << lower << "; " << upper << ")"); // parity(p) >= k ==> p * r >= 2^k if (prodv < rational::power_of_two(middle)) - return s.mk_clause("r = inv p & parity(p) >= k ==> p*r >= 2^k", + c.add_clause("r = inv p & parity(p) >= k ==> p*r >= 2^k", {~invc, ~s.parity_at_least(p(), middle), s.uge(prod, rational::power_of_two(middle))}, false); // parity(p) >= k ==> r <= 2^(N - k) - 1 (because r is the smallest pseudo-inverse) rational const max_rv = rational::power_of_two(m.power_of_2() - middle) - 1; if (rv.val() > max_rv) - return s.mk_clause("r = inv p & parity(p) >= k ==> r <= 2^(N - k) - 1", - {~invc, ~s.parity_at_least(p(), middle), s.ule(r(), max_rv)}, false); + c.add_clause("r = inv p & parity(p) >= k ==> r <= 2^(N - k) - 1", + {~invc, ~s.parity_at_least(p(), middle), C.ule(r(), max_rv)}, false); } else { // parity less than middle SASSERT(parity_pv < middle); @@ -541,8 +523,8 @@ namespace polysat { LOG("Its in [" << lower << "; " << upper << ")"); // parity(p) < k ==> p * r <= 2^k - 1 if (prodv > rational::power_of_two(middle)) - return s.mk_clause("r = inv p & parity(p) < k ==> p*r <= 2^k - 1", - {~invc, s.parity_at_least(p(), middle), s.ule(prod, rational::power_of_two(middle) - 1)}, false); + c.add_clause("r = inv p & parity(p) < k ==> p*r <= 2^k - 1", + {~invc, s.parity_at_least(p(), middle), C.ule(prod, rational::power_of_two(middle) - 1)}, false); } } // Why did it evaluate to false in this case? @@ -550,114 +532,5 @@ namespace polysat { return {}; } - - - void op_constraint::activate_udiv(solver& s) { - // signed_constraint const udivc(this, true); Do we really need this premiss? We anyway assert these constraints as unit clauses - - pdd const& quot = r(); - pdd const& rem = m_linked->r(); - - // Axioms for quotient/remainder: - // a = b*q + r - // multiplication does not overflow in b*q - // addition does not overflow in (b*q) + r; for now expressed as: r <= bq+r - // b ≠ 0 ==> r < b - // b = 0 ==> q = -1 - // TODO: when a,b become evaluable, can we actually propagate q,r? doesn't seem like it. - // Maybe we need something like an op_constraint for better propagation. - s.add_clause(s.eq(q() * quot + rem - p()), false); - s.add_clause(~s.umul_ovfl(q(), quot), false); - // r <= b*q+r - // { apply equivalence: p <= q <=> q-p <= -p-1 } - // b*q <= -r-1 - s.add_clause(s.ule(q() * quot, -rem - 1), false); - - auto c_eq = s.eq(q()); - s.add_clause(c_eq, s.ult(rem, q()), false); - s.add_clause(~c_eq, s.eq(quot + 1), false); - } - - /** - * Produce lemmas for constraint: r == p / q - * q = 0 ==> r = max_value - * p = 0 ==> r = 0 || r = max_value - * q = 1 ==> r = p - */ - clause_ref op_constraint::lemma_udiv(solver& s, assignment const& a) { - auto pv = a.apply_to(p()); - auto qv = a.apply_to(q()); - auto rv = a.apply_to(r()); - - if (eval_udiv(pv, qv, rv) == l_true) - return {}; - - signed_constraint const udivc(this, true); - - if (qv.is_zero() && !rv.is_val()) - return s.mk_clause(~udivc, ~s.eq(q()), s.eq(r(), r().manager().max_value()), true); - if (pv.is_zero() && !rv.is_val()) - return s.mk_clause(~udivc, ~s.eq(p()), s.eq(r()), s.eq(r(), r().manager().max_value()), true); - if (qv.is_one()) - return s.mk_clause(~udivc, ~s.eq(q(), 1), s.eq(r(), p()), true); - - if (pv.is_val() && qv.is_val() && !rv.is_val()) { - SASSERT(!qv.is_zero()); - // TODO: We could actually propagate an interval. Instead of p = 9 & q = 4 => r = 2 we could do p >= 8 && p < 12 && q = 4 => r = 2 - return s.mk_clause(~udivc, ~s.eq(p(), pv.val()), ~s.eq(q(), qv.val()), s.eq(r(), div(pv.val(), qv.val())), true); - } - - return {}; - } - - - /** - * Produce lemmas for constraint: r == p % q - * p = 0 ==> r = 0 - * q = 1 ==> r = 0 - * q = 0 ==> r = p - */ - clause_ref op_constraint::lemma_urem(solver& s, assignment const& a) { - auto pv = a.apply_to(p()); - auto qv = a.apply_to(q()); - auto rv = a.apply_to(r()); - - if (eval_urem(pv, qv, rv) == l_true) - return {}; - - signed_constraint const urem(this, true); - - if (pv.is_zero() && !rv.is_val()) - return s.mk_clause(~urem, ~s.eq(p()), s.eq(r()), true); - if (qv.is_one() && !rv.is_val()) - return s.mk_clause(~urem, ~s.eq(q(), 1), s.eq(r()), true); - if (qv.is_zero()) - return s.mk_clause(~urem, ~s.eq(q()), s.eq(r(), p()), true); - - if (pv.is_val() && qv.is_val() && !rv.is_val()) { - SASSERT(!qv.is_zero()); - return s.mk_clause(~urem, ~s.eq(p(), pv.val()), ~s.eq(q(), qv.val()), s.eq(r(), mod(pv.val(), qv.val())), true); - } - - return {}; - } - - /** Evaluate constraint: r == p % q */ - lbool op_constraint::eval_urem(pdd const& p, pdd const& q, pdd const& r) { - - if (q.is_one() && r.is_val()) { - return r.val().is_zero() ? l_true : l_false; - } - if (q.is_zero()) { - if (r == p) - return l_true; - } - - if (!p.is_val() || !q.is_val() || !r.is_val()) - return l_undef; - - return r.val() == mod(p.val(), q.val()) ? l_true : l_false; // mod == rem as we know hat q > 0 - } - #endif } diff --git a/src/sat/smt/polysat/op_constraint.h b/src/sat/smt/polysat/op_constraint.h index a33f1b705..d7b6be392 100644 --- a/src/sat/smt/polysat/op_constraint.h +++ b/src/sat/smt/polysat/op_constraint.h @@ -50,32 +50,22 @@ namespace polysat { op_constraint(code c, pdd const& r, pdd const& p, pdd const& q); lbool eval(pdd const& r, pdd const& p, pdd const& q) const; -// clause_ref produce_lemma(core& s, assignment const& a); - // clause_ref lemma_lshr(core& s, assignment const& a); static lbool eval_lshr(pdd const& p, pdd const& q, pdd const& r); - - // clause_ref lemma_shl(core& s, assignment const& a); static lbool eval_shl(pdd const& p, pdd const& q, pdd const& r); - - // clause_ref lemma_and(core& s, assignment const& a); static lbool eval_and(pdd const& p, pdd const& q, pdd const& r); - - // clause_ref lemma_inv(core& s, assignment const& a); static lbool eval_inv(pdd const& p, pdd const& r); + + void propagate_lshr(core& s, dependency const& dep); + void propagate_shl(core& s, dependency const& dep); + void propagate_and(core& s, dependency const& dep); + void propagate_inv(core& s, dependency const& dep); + - // clause_ref lemma_udiv(core& s, assignment const& a); - static lbool eval_udiv(pdd const& p, pdd const& q, pdd const& r); - - // clause_ref lemma_urem(core& s, assignment const& a); - static lbool eval_urem(pdd const& p, pdd const& q, pdd const& r); std::ostream& display(std::ostream& out, char const* eq) const; - void activate(core& s); - - void activate_and(core& s); - void activate_udiv(core& s); + void activate_and(core& s, dependency const& d); public: ~op_constraint() override {} @@ -89,6 +79,8 @@ namespace polysat { lbool eval(assignment const& a) const override; bool is_always_true() const { return false; } bool is_always_false() const { return false; } + void activate(core& c, bool sign, dependency const& dep) override; + void propagate(core& c, lbool value, dependency const& dep) override; }; } diff --git a/src/sat/smt/polysat/types.h b/src/sat/smt/polysat/types.h index 207ea091e..6f855f98b 100644 --- a/src/sat/smt/polysat/types.h +++ b/src/sat/smt/polysat/types.h @@ -29,12 +29,17 @@ namespace polysat { class signed_constraint; class dependency { - std::variant> m_data; + struct axiom_t {}; + std::variant> m_data; unsigned m_level; + dependency(): m_data(axiom_t()), m_level(0) {} public: dependency(sat::literal lit, unsigned level) : m_data(lit), m_level(level) {} dependency(theory_var v1, theory_var v2, unsigned level) : m_data(std::make_pair(v1, v2)), m_level(level) {} + static dependency axiom() { return dependency(); } bool is_null() const { return is_literal() && *std::get_if(&m_data) == sat::null_literal; } + bool is_axiom() const { return std::holds_alternative(m_data); } + bool is_eq() const { return std::holds_alternative>(m_data); } bool is_literal() const { return std::holds_alternative(m_data); } sat::literal literal() const { SASSERT(is_literal()); return *std::get_if(&m_data); } std::pair eq() const { SASSERT(!is_literal()); return *std::get_if>(&m_data); } @@ -46,6 +51,8 @@ namespace polysat { inline std::ostream& operator<<(std::ostream& out, dependency d) { if (d.is_null()) return out << "null"; + else if (d.is_axiom()) + return out << "axiom@" << d.level(); else if (d.is_literal()) return out << d.literal() << "@" << d.level(); else @@ -87,7 +94,7 @@ namespace polysat { using dependency_vector = vector; - using core_vector = vector>; + using core_vector = std::initializer_list>; @@ -101,6 +108,7 @@ namespace polysat { virtual void add_eq_literal(pvar v, rational const& val) = 0; virtual void set_conflict(dependency_vector const& core) = 0; virtual void set_lemma(core_vector const& aux_core, unsigned level, dependency_vector const& core) = 0; + virtual void add_polysat_clause(char const* name, core_vector cs, bool redundant) = 0; virtual dependency propagate(signed_constraint sc, dependency_vector const& deps) = 0; virtual void propagate(dependency const& d, bool sign, dependency_vector const& deps) = 0; virtual trail_stack& trail() = 0; diff --git a/src/sat/smt/polysat/ule_constraint.cpp b/src/sat/smt/polysat/ule_constraint.cpp index 185dad0ee..bdfcb7c5f 100644 --- a/src/sat/smt/polysat/ule_constraint.cpp +++ b/src/sat/smt/polysat/ule_constraint.cpp @@ -70,6 +70,8 @@ Useful lemmas: --*/ +#include "util/log.h" +#include "sat/smt/polysat/core.h" #include "sat/smt/polysat/constraints.h" #include "sat/smt/polysat/ule_constraint.h" @@ -314,8 +316,6 @@ namespace polysat { return display(out, l_true, m_lhs, m_rhs); } - - // Evaluate lhs <= rhs lbool ule_constraint::eval(pdd const& lhs, pdd const& rhs) { // NOTE: don't assume simplifications here because we also call this on partially substituted constraints @@ -343,4 +343,15 @@ namespace polysat { return eval(a.apply_to(lhs()), a.apply_to(rhs())); } + void ule_constraint::activate(core& c, bool sign, dependency const& d) { + auto p = c.subst(lhs()); + auto q = c.subst(rhs()); + auto& C = c.cs(); + if (sign && !lhs().is_val() && !rhs().is_val()) { + c.add_clause("lhs > rhs ==> -1 > rhs", { d, C.ult(rhs(), -1) }, false); + c.add_clause("lhs > rhs ==> lhs > 0", { d, C.ult(0, lhs()) }, false); + } + } + + } diff --git a/src/sat/smt/polysat/ule_constraint.h b/src/sat/smt/polysat/ule_constraint.h index aa53e6a4f..81a0b64c5 100644 --- a/src/sat/smt/polysat/ule_constraint.h +++ b/src/sat/smt/polysat/ule_constraint.h @@ -35,6 +35,8 @@ namespace polysat { std::ostream& display(std::ostream& out) const override; lbool eval() const override; lbool eval(assignment const& a) const override; + void activate(core& c, bool sign, dependency const& dep); + void propagate(core& c, lbool value, dependency const& dep) {} bool is_eq() const { return m_rhs.is_zero(); } unsigned power_of_2() const { return m_lhs.power_of_2(); } diff --git a/src/sat/smt/polysat/umul_ovfl_constraint.cpp b/src/sat/smt/polysat/umul_ovfl_constraint.cpp index e7dc5801c..5d185e7ee 100644 --- a/src/sat/smt/polysat/umul_ovfl_constraint.cpp +++ b/src/sat/smt/polysat/umul_ovfl_constraint.cpp @@ -10,6 +10,8 @@ Author: Jakob Rath, Nikolaj Bjorner (nbjorner) 2021-12-09 --*/ +#include "util/log.h" +#include "sat/smt/polysat/core.h" #include "sat/smt/polysat/constraints.h" #include "sat/smt/polysat/assignment.h" #include "sat/smt/polysat/umul_ovfl_constraint.h" @@ -70,4 +72,84 @@ namespace polysat { return eval(a.apply_to(p()), a.apply_to(q())); } + void umul_ovfl_constraint::activate(core& c, bool sign, dependency const& dep) { + + } + + void umul_ovfl_constraint::propagate(core& c, lbool value, dependency const& dep) { + auto& C = c.cs(); + auto p1 = c.subst(p()); + auto q1 = c.subst(q()); + if (narrow_bound(c, value == l_true, p(), q(), p1, q1, dep)) + return; + if (narrow_bound(c, value == l_true, q(), p(), q1, p1, dep)) + return; + } + + /** + * if p constant, q, propagate inequality + */ + bool umul_ovfl_constraint::narrow_bound(core& c, bool is_positive, pdd const& p0, pdd const& q0, pdd const& p, pdd const& q, dependency const& d) { + LOG("p: " << p0 << " := " << p); + LOG("q: " << q0 << " := " << q); + + if (!p.is_val()) + return false; + VERIFY(!p.is_zero() && !p.is_one()); // evaluation should catch this case + + rational const& M = p.manager().two_to_N(); + auto& C = c.cs(); + + // q_bound + // = min q . Ovfl(p_val, q) + // = min q . p_val * q >= M + // = min q . q >= M / p_val + // = ceil(M / p_val) + rational const q_bound = ceil(M / p.val()); + SASSERT(2 <= q_bound && q_bound <= M / 2); + SASSERT(p.val() * q_bound >= M); + SASSERT(p.val() * (q_bound - 1) < M); + // LOG("q_bound: " << q.manager().mk_val(q_bound)); + + // We need the following properties for the bounds: + // + // p_bound * (q_bound - 1) < M + // p_bound * q_bound >= M + // + // With these properties we get: + // + // p <= p_bound & q < q_bound ==> ~Ovfl(p, q) + // p >= p_bound & q >= q_bound ==> Ovfl(p, q) + // + // Written as lemmas: + // + // Ovfl(p, q) & p <= p_bound ==> q >= q_bound + // ~Ovfl(p, q) & p >= p_bound ==> q < q_bound + // + if (is_positive) { + // Find largest bound for p such that q_bound is still correct. + // p_bound = max p . (q_bound - 1)*p < M + // = max p . p < M / (q_bound - 1) + // = ceil(M / (q_bound - 1)) - 1 + rational const p_bound = ceil(M / (q_bound - 1)) - 1; + SASSERT(p.val() <= p_bound); + SASSERT(p_bound * q_bound >= M); + SASSERT(p_bound * (q_bound - 1) < M); + // LOG("p_bound: " << p.manager().mk_val(p_bound)); + c.add_clause("~Ovfl(p, q) & p <= p_bound ==> q < q_bound", { d, ~C.ule(p0, p_bound), C.ule(q_bound, q0) }, false); + } + else { + // Find lowest bound for p such that q_bound is still correct. + // p_bound = min p . Ovfl(p, q_bound) = ceil(M / q_bound) + rational const p_bound = ceil(M / q_bound); + SASSERT(p_bound <= p.val()); + SASSERT(p_bound * q_bound >= M); + SASSERT(p_bound * (q_bound - 1) < M); + // LOG("p_bound: " << p.manager().mk_val(p_bound)); + c.add_clause("~Ovfl(p, q) & p >= p_bound ==> q < q_bound", { d, ~C.ule(p_bound, p0), C.ult(q0, q_bound) }, false); + } + return true; + } + + } diff --git a/src/sat/smt/polysat/umul_ovfl_constraint.h b/src/sat/smt/polysat/umul_ovfl_constraint.h index c9d03fb01..374d346f7 100644 --- a/src/sat/smt/polysat/umul_ovfl_constraint.h +++ b/src/sat/smt/polysat/umul_ovfl_constraint.h @@ -25,6 +25,8 @@ namespace polysat { static bool is_always_false(bool is_positive, pdd const& p, pdd const& q) { return is_always_true(!is_positive, p, q); } static lbool eval(pdd const& p, pdd const& q); + bool narrow_bound(core& c, bool is_positive, pdd const& p0, pdd const& q0, pdd const& p, pdd const& q, dependency const& d); + public: umul_ovfl_constraint(pdd const& p, pdd const& q); ~umul_ovfl_constraint() override {} @@ -34,6 +36,8 @@ namespace polysat { std::ostream& display(std::ostream& out) const override; lbool eval() const override; lbool eval(assignment const& a) const override; + void activate(core& c, bool sign, dependency const& dep) override; + void propagate(core& c, lbool value, dependency const& dep) override; }; } diff --git a/src/sat/smt/polysat_internalize.cpp b/src/sat/smt/polysat_internalize.cpp index 5e5647bd3..ef469fe6f 100644 --- a/src/sat/smt/polysat_internalize.cpp +++ b/src/sat/smt/polysat_internalize.cpp @@ -235,9 +235,7 @@ namespace polysat { if (n->get_num_args() == 2) { expr* x, * y; VERIFY(bv.is_bv_and(n, x, y)); - auto sc = m_core.band(expr2pdd(x), expr2pdd(y), expr2pdd(n)); - // auto index = m_core.register_constraint(sc, dependency::axiom()); - // + m_core.band(expr2pdd(x), expr2pdd(y), expr2pdd(n)); } else { expr_ref z(n->get_arg(0), m); @@ -252,19 +250,19 @@ namespace polysat { void solver::internalize_lshr(app* n) { expr* x, * y; VERIFY(bv.is_bv_lshr(n, x, y)); - auto sc = m_core.lshr(expr2pdd(x), expr2pdd(y), expr2pdd(n)); + m_core.lshr(expr2pdd(x), expr2pdd(y), expr2pdd(n)); } void solver::internalize_ashr(app* n) { expr* x, * y; VERIFY(bv.is_bv_ashr(n, x, y)); - auto sc = m_core.ashr(expr2pdd(x), expr2pdd(y), expr2pdd(n)); + m_core.ashr(expr2pdd(x), expr2pdd(y), expr2pdd(n)); } void solver::internalize_shl(app* n) { expr* x, * y; VERIFY(bv.is_bv_shl(n, x, y)); - auto sc = m_core.shl(expr2pdd(x), expr2pdd(y), expr2pdd(n)); + m_core.shl(expr2pdd(x), expr2pdd(y), expr2pdd(n)); } void solver::internalize_urem_i(app* rem) { diff --git a/src/sat/smt/polysat_solver.cpp b/src/sat/smt/polysat_solver.cpp index 9f185b22d..690548aaa 100644 --- a/src/sat/smt/polysat_solver.cpp +++ b/src/sat/smt/polysat_solver.cpp @@ -64,7 +64,7 @@ namespace polysat { case sat::check_result::CR_GIVEUP: { if (!m.inc()) return sat::check_result::CR_GIVEUP; - switch (m_intblast.check()) { + switch (m_intblast.check_solver_state()) { case l_true: trail().push(value_trail(m_use_intblast_model)); m_use_intblast_model = true; @@ -254,10 +254,25 @@ namespace polysat { return ctx.get_trail_stack(); } - void solver::add_polysat_clause(char const* name, std::initializer_list cs, bool is_redundant) { + void solver::add_polysat_clause(char const* name, core_vector cs, bool is_redundant) { sat::literal_vector lits; - for (auto sc : cs) - lits.push_back(ctx.mk_literal(constraint2expr(sc))); + for (auto e : cs) { + if (std::holds_alternative(e)) { + auto d = *std::get_if(&e); + SASSERT(!d.is_null()); + if (d.is_literal()) + lits.push_back(~d.literal()); + else if (d.is_eq()) { + auto [v1, v2] = d.eq(); + lits.push_back(~eq_internalize(var2enode(v1), var2enode(v2))); + } + else { + SASSERT(d.is_axiom()); + } + } + else if (std::holds_alternative(e)) + lits.push_back(ctx.mk_literal(constraint2expr(*std::get_if(&e)))); + } s().add_clause(lits.size(), lits.data(), sat::status::th(is_redundant, get_id(), nullptr)); } diff --git a/src/sat/smt/polysat_solver.h b/src/sat/smt/polysat_solver.h index f54bafb1c..a04c76618 100644 --- a/src/sat/smt/polysat_solver.h +++ b/src/sat/smt/polysat_solver.h @@ -151,7 +151,7 @@ namespace polysat { bool inconsistent() const override; void get_bitvector_prefixes(pvar v, pvar_vector& out) override; void get_fixed_bits(pvar v, svector& fixed_bits) override; - void add_polysat_clause(char const* name, std::initializer_list cs, bool is_redundant); + void add_polysat_clause(char const* name, core_vector cs, bool redundant) override; std::pair explain_deps(dependency_vector const& deps);