From 06ebf9a02af5720ba6d55cd65f9c6462dce5522c Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Tue, 12 Dec 2023 14:41:31 -0800 Subject: [PATCH] n/a --- src/ast/bv_decl_plugin.cpp | 5 + src/ast/bv_decl_plugin.h | 1 + src/sat/smt/intblast_solver.cpp | 314 +++++++++++++------ src/sat/smt/intblast_solver.h | 31 +- src/sat/smt/polysat/umul_ovfl_constraint.cpp | 2 + src/sat/smt/polysat_solver.cpp | 26 ++ 6 files changed, 272 insertions(+), 107 deletions(-) diff --git a/src/ast/bv_decl_plugin.cpp b/src/ast/bv_decl_plugin.cpp index f725fefc5..30cfe4cdb 100644 --- a/src/ast/bv_decl_plugin.cpp +++ b/src/ast/bv_decl_plugin.cpp @@ -942,3 +942,8 @@ app * bv_util::mk_bv2int(expr* e) { parameter p(s); return m_manager.mk_app(get_fid(), OP_BV2INT, 1, &p, 1, &e); } + +app* bv_util::mk_int2bv(unsigned sz, expr* e) { + parameter p(sz); + return m_manager.mk_app(get_fid(), OP_INT2BV, 1, &p, 1, &e); +} diff --git a/src/ast/bv_decl_plugin.h b/src/ast/bv_decl_plugin.h index 4eeac49ee..cb1f63881 100644 --- a/src/ast/bv_decl_plugin.h +++ b/src/ast/bv_decl_plugin.h @@ -522,6 +522,7 @@ public: app * mk_bv_lshr(expr* arg1, expr* arg2) { return m_manager.mk_app(get_fid(), OP_BLSHR, arg1, arg2); } app * mk_bv2int(expr* e); + app * mk_int2bv(unsigned sz, expr* e); // TODO: all these binary ops commute (right?) but it'd be more logical to swap `n` & `m` in the `return` app * mk_bvsmul_no_ovfl(expr* m, expr* n) { return m_manager.mk_app(get_fid(), OP_BSMUL_NO_OVFL, n, m); } diff --git a/src/sat/smt/intblast_solver.cpp b/src/sat/smt/intblast_solver.cpp index 0505eaa92..250e279cd 100644 --- a/src/sat/smt/intblast_solver.cpp +++ b/src/sat/smt/intblast_solver.cpp @@ -29,8 +29,7 @@ namespace intblast { bv(m), a(m), m_args(m), - m_translate(m), - m_pinned(m) + m_translate(m) {} euf::theory_var solver::mk_var(euf::enode* n) { @@ -85,14 +84,38 @@ namespace intblast { return true; } + void solver::eq_internalized(euf::enode* n) { + expr* e = n->get_expr(); + expr* x, * y; + VERIFY(m.is_eq(n->get_expr(), x, y)); + m_args.reset(); + m_args.push_back(translated(x)); + m_args.push_back(translated(y)); + add_equiv(expr2literal(e), eq_internalize(umod(x, 0), umod(x, 1))); + } + void solver::internalize_bv(app* e) { ensure_args(e); m_args.reset(); for (auto arg : *e) m_args.push_back(translated(arg)); translate_bv(e); - if (m.is_bool(e)) - add_equiv(expr2literal(e), mk_literal(translated(e))); + if (m.is_bool(e)) + add_equiv(expr2literal(e), mk_literal(translated(e))); + add_bound_axioms(); + } + + void solver::add_bound_axioms() { + if (m_vars_qhead == m_vars.size()) + return; + ctx.push(value_trail(m_vars_qhead)); + for (; m_vars_qhead < m_vars.size(); ++m_vars_qhead) { + auto v = m_vars[m_vars_qhead]; + auto w = translated(v); + auto sz = rational::power_of_two(bv.get_bv_size(v->get_sort())); + add_unit(ctx.mk_literal(a.mk_ge(w, a.mk_int(0)))); + add_unit(ctx.mk_literal(a.mk_le(w, a.mk_int(sz - 1)))); + } } void solver::ensure_args(app* e) { @@ -106,17 +129,17 @@ namespace intblast { return; for (unsigned i = 0; i < todo.size(); ++i) { expr* e = todo[i]; - if (is_app(e)) { + if (m.is_bool(e)) + continue; + else if (is_app(e)) { for (auto arg : *to_app(e)) if (!visited.is_marked(arg)) { visited.mark(arg); todo.push_back(arg); } } - else if (is_quantifier(e) && !visited.is_marked(to_quantifier(e)->get_expr())) { - visited.mark(to_quantifier(e)->get_expr()); - todo.push_back(to_quantifier(e)->get_expr()); - } + else if (is_lambda(e)) + throw default_exception("lambdas are not supported in intblaster"); } std::stable_sort(todo.begin(), todo.end(), [&](expr* a, expr* b) { return get_depth(a) < get_depth(b); }); @@ -176,8 +199,8 @@ namespace intblast { } m_core.reset(); - m_vars.reset(); m_translate.reset(); + m_is_plugin = false; m_solver = mk_smt2_solver(m, s.params(), symbol::null); expr_ref_vector es(m); @@ -186,8 +209,9 @@ namespace intblast { translate(es); - for (auto const& [src, vi] : m_vars) { - auto const& [v, b] = vi; + for (auto e : m_vars) { + auto v = translated(e); + auto b = rational::power_of_two(bv.get_bv_size(e)); m_solver->assert_expr(a.mk_le(a.mk_int(0), v)); m_solver->assert_expr(a.mk_lt(v, a.mk_int(b))); } @@ -296,19 +320,68 @@ namespace intblast { es[i] = translated(es.get(i)); } - expr* solver::mk_mod(expr* x) { - if (m_vars.contains(x)) + sat::check_result solver::check() { + // ensure that bv2int is injective + for (auto e : m_bv2int) { + euf::enode* n = expr2enode(e); + euf::enode* r1 = n->get_arg(0)->get_root(); + for (auto sib : euf::enode_class(n)) { + if (sib == n) + continue; + if (!bv.is_bv2int(sib->get_expr())) + continue; + if (sib->get_arg(0)->get_root() == r1) + continue; + add_clause(~eq_internalize(n, sib), eq_internalize(sib->get_arg(0), n->get_arg(0)), nullptr); + return sat::check_result::CR_CONTINUE; + } + } + // ensure that int2bv respects values + // bv2int(int2bv(x)) = x mod N + for (auto e : m_int2bv) { + auto n = expr2enode(e); + auto x = n->get_arg(0)->get_expr(); + auto bv2int = bv.mk_bv2int(e); + ctx.internalize(bv2int); + auto N = rational::power_of_two(bv.get_bv_size(e)); + auto xModN = a.mk_mod(x, a.mk_int(N)); + ctx.internalize(xModN); + auto nBv2int = ctx.get_enode(bv2int); + auto nxModN = ctx.get_enode(xModN); + if (nBv2int->get_root() != nxModN->get_root()) { + add_unit(eq_internalize(nBv2int, nxModN)); + return sat::check_result::CR_CONTINUE; + } + } + return sat::check_result::CR_DONE; + } + + expr* solver::umod(expr* bv_expr, unsigned i) { + expr* x = arg(i); + rational r; + rational N = bv_size(bv_expr); + if (a.is_numeral(x, r)) { + if (0 <= r && r < N) + return x; + return a.mk_int(mod(r, N)); + } + if (any_of(m_vars, [&](expr* v) { return translated(v) == x; })) return x; - return to_expr(a.mk_mod(x, a.mk_int(bv_size()))); + return to_expr(a.mk_mod(x, a.mk_int(N))); } - expr* solver::mk_smod(expr* x) { - auto shift = bv_size() / 2; - return a.mk_mod(a.mk_add(x, a.mk_int(shift)), a.mk_int(bv_size())); + expr* solver::smod(expr* bv_expr, unsigned i) { + expr* x = arg(i); + auto N = bv_size(bv_expr); + auto shift = N / 2; + rational r; + if (a.is_numeral(x, r)) + return a.mk_int(mod(r + shift, N)); + return a.mk_mod(a.mk_add(x, a.mk_int(shift)), a.mk_int(N)); } - rational solver::bv_size() { - return rational::power_of_two(bv.get_bv_size(bv_expr->get_sort())); + rational solver::bv_size(expr* bv_expr) { + return rational::power_of_two(bv.get_bv_size(bv_expr->get_sort())); } void solver::translate_expr(expr* e) { @@ -318,7 +391,6 @@ namespace intblast { translate_var(to_var(e)); else { app* ap = to_app(e); - bv_expr = e; m_args.reset(); for (auto arg : *ap) m_args.push_back(translated(arg)); @@ -333,6 +405,12 @@ namespace intblast { } void solver::translate_quantifier(quantifier* q) { + if (is_lambda(q)) + throw default_exception("lambdas are not supported in intblaster"); + if (m_is_plugin) { + set_translated(q, q); + return; + } expr* b = q->get_expr(); unsigned nd = q->get_num_decls(); ptr_vector sorts; @@ -357,37 +435,47 @@ namespace intblast { set_translated(v, v); } + // Translate functions that are not built-in or bit-vectors. + // Base method uses fresh functions. + // Other method could use bv2int, int2bv axioms and coercions. + // f(args) = bv2int(f(int2bv(args')) + // + void solver::translate_app(app* e) { - bool has_bv_arg = any_of(*e, [&](expr* arg) { return bv.is_bv(arg); }); - bool has_bv_sort = bv.is_bv(e); - func_decl* f = e->get_decl(); - if (has_bv_arg) { - verbose_stream() << mk_pp(e, m) << "\n"; - // need to update args with mod where they are bit-vectors. - NOT_IMPLEMENTED_YET(); + + if (m_is_plugin && m.is_bool(e)) { + set_translated(e, e); + return; } - if (has_bv_arg || has_bv_sort) { - ptr_vector domain; - for (auto* arg : *e) { - sort* s = arg->get_sort(); - domain.push_back(bv.is_bv_sort(s) ? a.mk_int() : s); + bool has_bv_sort = bv.is_bv(e); + func_decl* f = e->get_decl(); + + for (unsigned i = 0; i < m_args.size(); ++i) + if (bv.is_bv(e->get_arg(i))) + m_args[i] = bv.mk_int2bv(bv.get_bv_size(e->get_arg(i)), m_args.get(i)); + + if (has_bv_sort) + m_vars.push_back(e); + + if (m_is_plugin) { + expr* r = m.mk_app(f, m_args); + if (has_bv_sort) { + ctx.push(push_back_vector(m_vars)); + r = bv.mk_bv2int(r); } - sort* range = bv.is_bv(e) ? a.mk_int() : e->get_sort(); + set_translated(e, r); + return; + } + else if (has_bv_sort) { func_decl* g = nullptr; if (!m_new_funs.find(f, g)) { - g = m.mk_fresh_func_decl(e->get_decl()->get_name(), symbol("bv"), domain.size(), domain.data(), range); + g = m.mk_fresh_func_decl(e->get_decl()->get_name(), symbol("bv"), f->get_arity(), f->get_domain(), a.mk_int()); m_new_funs.insert(f, g); - m_pinned.push_back(f); - m_pinned.push_back(g); } f = g; } - - set_translated(e, m.mk_app(f, m_args)); - - if (has_bv_sort) - m_vars.insert(e, { translated(e), bv_size()}); + set_translated(e, m.mk_app(f, m_args)); } void solver::translate_bv(app* e) { @@ -403,61 +491,59 @@ namespace intblast { return r; }; - bv_expr = e; + expr* bv_expr = e; expr* r = nullptr; auto const& args = m_args; switch (e->get_decl_kind()) { case OP_BADD: - r = (a.mk_add(args)); + r = a.mk_add(args); break; case OP_BSUB: - r = (a.mk_sub(args.size(), args.data())); + r = a.mk_sub(args.size(), args.data()); break; case OP_BMUL: - r = (a.mk_mul(args)); + r = a.mk_mul(args); break; case OP_ULEQ: bv_expr = e->get_arg(0); - r = (a.mk_le(mk_mod(arg(0)), mk_mod(arg(1)))); + r = a.mk_le(umod(bv_expr, 0), umod(bv_expr, 1)); break; case OP_UGEQ: bv_expr = e->get_arg(0); - r = (a.mk_ge(mk_mod(arg(0)), mk_mod(arg(1)))); + r = a.mk_ge(umod(bv_expr, 0), umod(bv_expr, 1)); break; case OP_ULT: bv_expr = e->get_arg(0); - r = (a.mk_lt(mk_mod(arg(0)), mk_mod(arg(1)))); + r = a.mk_lt(umod(bv_expr, 0), umod(bv_expr, 1)); break; case OP_UGT: bv_expr = e->get_arg(0); - r = (a.mk_gt(mk_mod(arg(0)), mk_mod(arg(1)))); + r = a.mk_gt(umod(bv_expr, 0), umod(bv_expr, 1)); break; case OP_SLEQ: bv_expr = e->get_arg(0); - r = (a.mk_le(mk_smod(arg(0)), mk_smod(arg(1)))); + r = a.mk_le(smod(bv_expr, 0), smod(bv_expr, 1)); break; case OP_SGEQ: - r = (a.mk_ge(mk_smod(arg(0)), mk_smod(arg(1)))); + r = a.mk_ge(smod(bv_expr, 0), smod(bv_expr, 1)); break; case OP_SLT: bv_expr = e->get_arg(0); - r = (a.mk_lt(mk_smod(arg(0)), mk_smod(arg(1)))); + r = a.mk_lt(smod(bv_expr, 0), smod(bv_expr, 1)); break; case OP_SGT: bv_expr = e->get_arg(0); - r = (a.mk_gt(mk_smod(arg(0)), mk_smod(arg(1)))); + r = a.mk_gt(smod(bv_expr, 0), smod(bv_expr, 1)); break; case OP_BNEG: - r = (a.mk_uminus(arg(0))); + r = a.mk_uminus(arg(0)); break; case OP_CONCAT: { r = a.mk_int(0); unsigned sz = 0; for (unsigned i = 0; i < args.size(); ++i) { expr* old_arg = e->get_arg(i); - expr* new_arg = arg(i); - bv_expr = old_arg; - new_arg = mk_mod(new_arg); + expr* new_arg = umod(old_arg, i); if (sz > 0) { new_arg = a.mk_mul(new_arg, a.mk_int(rational::power_of_two(sz))); r = a.mk_add(r, new_arg); @@ -482,23 +568,22 @@ namespace intblast { rational val; unsigned sz; VERIFY(bv.is_numeral(e, val, sz)); - r = (a.mk_int(val)); + r = a.mk_int(val); break; } case OP_BUREM_I: { expr* x = arg(0), * y = arg(1); - r = (m.mk_ite(m.mk_eq(y, a.mk_int(0)), a.mk_int(0), a.mk_mod(x, y))); + r = m.mk_ite(m.mk_eq(y, a.mk_int(0)), a.mk_int(0), a.mk_mod(x, y)); break; } case OP_BUDIV_I: { expr* x = arg(0), * y = arg(1); - r = (m.mk_ite(m.mk_eq(y, a.mk_int(0)), a.mk_int(0), a.mk_idiv(x, y))); + r = m.mk_ite(m.mk_eq(y, a.mk_int(0)), a.mk_int(0), a.mk_idiv(x, umod(bv_expr, 1))); break; } case OP_BUMUL_NO_OVFL: { - expr* x = arg(0), * y = arg(1); bv_expr = e->get_arg(0); - r = (a.mk_lt(a.mk_mul(mk_mod(x), mk_mod(y)), a.mk_int(bv_size()))); + r = a.mk_lt(a.mk_mul(umod(bv_expr, 0), umod(bv_expr, 1)), a.mk_int(bv_size(bv_expr))); break; } case OP_BSHL: { @@ -509,7 +594,7 @@ namespace intblast { break; } case OP_BNOT: - r = (bnot(arg(0))); + r = bnot(arg(0)); break; case OP_BLSHR: { expr* x = arg(0), * y = arg(1); @@ -518,7 +603,7 @@ namespace intblast { r = m.mk_ite(a.mk_eq(y, a.mk_int(i)), a.mk_idiv(x, a.mk_int(rational::power_of_two(i))), r); break; } - // Or use (p + q) - band(p, q)? + // Or use (p + q) - band(p, q)? case OP_BOR: { r = arg(0); for (unsigned i = 1; i < args.size(); ++i) @@ -537,46 +622,43 @@ namespace intblast { case OP_BXNOR: case OP_BXOR: { unsigned sz = bv.get_bv_size(e); - expr* p = arg(0); + r = arg(0); for (unsigned i = 1; i < args.size(); ++i) { expr* q = arg(i); - p = a.mk_sub(a.mk_add(p, q), a.mk_mul(a.mk_int(2), a.mk_band(sz, p, q))); + r = a.mk_sub(a.mk_add(r, q), a.mk_mul(a.mk_int(2), a.mk_band(sz, r, q))); } if (e->get_decl_kind() == OP_BXNOR) - p = bnot(p); - r = (p); + r = bnot(r); break; } case OP_BUDIV: { bv_rewriter_params p(ctx.s().params()); expr* x = arg(0), * y = arg(1); if (p.hi_div0()) - r = (m.mk_ite(m.mk_eq(y, a.mk_int(0)), a.mk_int(0), a.mk_idiv(x, y))); + r = m.mk_ite(m.mk_eq(y, a.mk_int(0)), a.mk_int(0), a.mk_idiv(x, y)); else - r = (a.mk_idiv(x, y)); + r = a.mk_idiv(x, y); break; } case OP_BUREM: { bv_rewriter_params p(ctx.s().params()); expr* x = arg(0), * y = arg(1); if (p.hi_div0()) - r = (m.mk_ite(m.mk_eq(y, a.mk_int(0)), a.mk_int(0), a.mk_mod(x, y))); + r = m.mk_ite(m.mk_eq(y, a.mk_int(0)), a.mk_int(0), a.mk_mod(x, y)); else - r = (a.mk_mod(x, y)); + r = a.mk_mod(x, y); break; } + // // ashr(x, y) // if y = k & x >= 0 -> x / 2^k // if y = k & x < 0 -> - (x / 2^k) // - case OP_BASHR: { - expr* x = arg(0), * y = arg(1); rational N = rational::power_of_two(bv.get_bv_size(e)); - bv_expr = e; - x = mk_mod(x); - y = mk_mod(y); + expr* x = umod(e, 0); + expr* y = umod(e, 1); expr* signbit = a.mk_ge(x, a.mk_int(N / 2)); r = m.mk_ite(signbit, a.mk_int(N - 1), a.mk_int(0)); for (unsigned i = 0; i < bv.get_bv_size(e); ++i) { @@ -587,15 +669,39 @@ namespace intblast { } break; } + case OP_ZERO_EXT: + bv_expr = e->get_arg(0); + r = umod(bv_expr, 0); + SASSERT(bv.get_bv_size(e) >= bv.get_bv_size(bv_expr)); + break; + case OP_SIGN_EXT: { + bv_expr = e->get_arg(0); + r = umod(bv_expr, 0); + SASSERT(bv.get_bv_size(e) >= bv.get_bv_size(bv_expr)); + unsigned arg_sz = bv.get_bv_size(bv_expr); + unsigned sz = bv.get_bv_size(e); + rational N = rational::power_of_two(sz); + rational M = rational::power_of_two(arg_sz); + expr* signbit = a.mk_ge(r, a.mk_int(M / 2)); + r = m.mk_ite(signbit, a.mk_uminus(r), r); + break; + } + case OP_INT2BV: + m_int2bv.push_back(e); + ctx.push(push_back_vector(m_int2bv)); + r = arg(0); + break; + case OP_BV2INT: + m_bv2int.push_back(e); + ctx.push(push_back_vector(m_bv2int)); + r = arg(0); + break; case OP_BCOMP: - case OP_ROTATE_LEFT: case OP_ROTATE_RIGHT: case OP_EXT_ROTATE_LEFT: case OP_EXT_ROTATE_RIGHT: case OP_REPEAT: - case OP_ZERO_EXT: - case OP_SIGN_EXT: case OP_BREDOR: case OP_BREDAND: case OP_BSDIV: @@ -610,19 +716,19 @@ namespace intblast { } set_translated(e, r); } - + void solver::translate_basic(app* e) { if (m.is_eq(e)) { bool has_bv_arg = any_of(*e, [&](expr* arg) { return bv.is_bv(arg); }); if (has_bv_arg) { - bv_expr = e->get_arg(0); - set_translated(e, m.mk_eq(mk_mod(arg(0)), mk_mod(arg(1)))); + expr* bv_expr = e->get_arg(0); + set_translated(e, m.mk_eq(umod(bv_expr, 0), umod(bv_expr, 1))); } - else - set_translated(e, m.mk_eq(arg(0), arg(1))); + else + set_translated(e, m.mk_eq(arg(0), arg(1))); } - else - set_translated(e, m.mk_app(e->get_decl(), m_args)); + else + set_translated(e, m.mk_app(e->get_decl(), m_args)); } rational solver::get_value(expr* e) const { @@ -630,11 +736,9 @@ namespace intblast { model_ref mdl; m_solver->get_model(mdl); expr_ref r(m); - var_info vi; + r = translated(e); rational val; - if (!m_vars.find(e, vi)) - return rational::zero(); - if (!mdl->eval_expr(vi.dst, r, true)) + if (!mdl->eval_expr(r, r, true)) return rational::zero(); if (!a.is_numeral(r, val)) return rational::zero(); @@ -642,6 +746,32 @@ namespace intblast { } void solver::add_value(euf::enode* n, model& mdl, expr_ref_vector& values) { + if (m_is_plugin) + add_value_plugin(n, mdl, values); + else + add_value_solver(n, mdl, values); + } + + bool solver::add_dep(euf::enode* n, top_sort& dep) { + // bv2int + auto e = ctx.get_enode(translated(n->get_expr())); + if (!e) + return false; + dep.add(n, e); + } + + // TODO: handle dependencies properly by using arithmetical model to retrieve values of translated + // bit-vectors directly. + void solver::add_value_plugin(euf::enode* n, model& mdl, expr_ref_vector& values) { + SASSERT(bv.is_bv(n->get_expr())); + rational N = rational::power_of_two(bv.get_bv_size(n->get_expr())); + auto e = ctx.get_enode(translated(n->get_expr())); + expr_ref value(m); + value = values.get(e->get_root_id()); + values.setx(n->get_root_id(), value); + } + + void solver::add_value_solver(euf::enode* n, model& mdl, expr_ref_vector& values) { expr_ref value(m); if (n->interpreted()) value = n->get_expr(); diff --git a/src/sat/smt/intblast_solver.h b/src/sat/smt/intblast_solver.h index 037b009a3..707f53832 100644 --- a/src/sat/smt/intblast_solver.h +++ b/src/sat/smt/intblast_solver.h @@ -46,23 +46,18 @@ namespace euf { namespace intblast { class solver : public euf::th_euf_solver { - struct var_info { - expr* dst; - rational sz; - }; - euf::solver& ctx; sat::solver& s; ast_manager& m; bv_util bv; arith_util a; scoped_ptr<::solver> m_solver; - obj_map m_vars; obj_map m_new_funs; expr_ref_vector m_translate, m_args; - ast_ref_vector m_pinned; sat::literal_vector m_core; + ptr_vector m_bv2int, m_int2bv; statistics m_stats; + bool m_is_plugin = true; // when the solver is used as a plugin, then do not translate below quantifiers. bool is_bv(sat::literal lit); void translate(expr_ref_vector& es); @@ -70,14 +65,13 @@ namespace intblast { rational get_value(expr* e) const; - expr* translated(expr* e) { expr* r = m_translate.get(e->get_id(), nullptr); SASSERT(r); return r; } + expr* translated(expr* e) const { expr* r = m_translate.get(e->get_id(), nullptr); SASSERT(r); return r; } void set_translated(expr* e, expr* r) { m_translate.setx(e->get_id(), r); } expr* arg(unsigned i) { return m_args.get(i); } - expr* mk_mod(expr* x); - expr* mk_smod(expr* x); - expr* bv_expr = nullptr; - rational bv_size(); + expr* umod(expr* bv_expr, unsigned i); + expr* smod(expr* bv_expr, unsigned i); + rational bv_size(expr* bv_expr); void translate_expr(expr* e); void translate_bv(app* e); @@ -89,8 +83,15 @@ namespace intblast { void ensure_args(app* e); void internalize_bv(app* e); + unsigned m_vars_qhead = 0; + ptr_vector m_vars; + void add_bound_axioms(); + euf::theory_var mk_var(euf::enode* n) override; + void add_value_plugin(euf::enode* n, model& mdl, expr_ref_vector& values); + void add_value_solver(euf::enode* n, model& mdl, expr_ref_vector& values); + public: solver(euf::solver& ctx); @@ -102,12 +103,12 @@ namespace intblast { void add_value(euf::enode* n, model& mdl, expr_ref_vector& values) override; + bool add_dep(euf::enode* n, top_sort& dep) override; + std::ostream& display(std::ostream& out) const override; void collect_statistics(statistics& st) const override; - - bool unit_propagate() override { return false; } void get_antecedents(sat::literal l, sat::ext_justification_idx idx, sat::literal_vector& r, bool probing) override {} @@ -130,7 +131,7 @@ namespace intblast { sat::literal internalize(expr* e, bool, bool) override; - void eq_internalized(euf::enode* n) override {} + void eq_internalized(euf::enode* n) override; }; diff --git a/src/sat/smt/polysat/umul_ovfl_constraint.cpp b/src/sat/smt/polysat/umul_ovfl_constraint.cpp index 5d185e7ee..445169c2f 100644 --- a/src/sat/smt/polysat/umul_ovfl_constraint.cpp +++ b/src/sat/smt/polysat/umul_ovfl_constraint.cpp @@ -77,6 +77,8 @@ namespace polysat { } void umul_ovfl_constraint::propagate(core& c, lbool value, dependency const& dep) { + if (value == l_undef) + return; auto& C = c.cs(); auto p1 = c.subst(p()); auto q1 = c.subst(q()); diff --git a/src/sat/smt/polysat_solver.cpp b/src/sat/smt/polysat_solver.cpp index 690548aaa..c14ca1d14 100644 --- a/src/sat/smt/polysat_solver.cpp +++ b/src/sat/smt/polysat_solver.cpp @@ -256,6 +256,32 @@ namespace polysat { void solver::add_polysat_clause(char const* name, core_vector cs, bool is_redundant) { sat::literal_vector lits; + signed_constraint sc; + unsigned constraint_count = 0; + for (auto e : cs) { + if (std::holds_alternative(e)) { + sc = *std::get_if(&e); + constraint_count++; + } + } + if (constraint_count == 1) { + auto lit = ctx.mk_literal(constraint2expr(sc)); + svector eqs; + for (auto e : cs) { + if (std::holds_alternative(e)) { + auto d = *std::get_if(&e); + if (d.is_literal()) + lits.push_back(d.literal()); + else if (d.is_eq()) { + auto [v1, v2] = d.eq(); + eqs.push_back({ var2enode(v1), var2enode(v2) }); + } + } + } + ctx.propagate(lit, euf::th_explain::propagate(*this, lits, eqs, lit, nullptr)); + return; + } + for (auto e : cs) { if (std::holds_alternative(e)) { auto d = *std::get_if(&e);