diff --git a/src/sat/smt/polysat/core.h b/src/sat/smt/polysat/core.h index 32ad84fa3..b303ff8b8 100644 --- a/src/sat/smt/polysat/core.h +++ b/src/sat/smt/polysat/core.h @@ -125,10 +125,6 @@ namespace polysat { pdd bxnor(pdd a, pdd b) { NOT_IMPLEMENTED_YET(); throw default_exception("bxnor nyi"); } pdd bnor(pdd a, pdd b) { NOT_IMPLEMENTED_YET(); throw default_exception("bnotr nyi"); } pdd bnot(pdd a) { NOT_IMPLEMENTED_YET(); throw default_exception("bnot nyi"); } - pdd zero_ext(pdd a, unsigned sz) { NOT_IMPLEMENTED_YET(); throw default_exception("zero ext nyi"); } - pdd sign_ext(pdd a, unsigned sz) { NOT_IMPLEMENTED_YET(); throw default_exception("sign ext nyi"); } - pdd extract(pdd src, unsigned hi, unsigned lo) { NOT_IMPLEMENTED_YET(); throw default_exception("extract nyi"); } - pdd concat(unsigned n, pdd const* args) { NOT_IMPLEMENTED_YET(); throw default_exception("concat nyi"); } pvar add_var(unsigned sz); pdd var(pvar p) { return m_vars[p]; } unsigned size(pvar v) const { return m_vars[v].power_of_2(); } diff --git a/src/sat/smt/polysat_internalize.cpp b/src/sat/smt/polysat_internalize.cpp index 1b3cb1c04..387b2e5a7 100644 --- a/src/sat/smt/polysat_internalize.cpp +++ b/src/sat/smt/polysat_internalize.cpp @@ -135,8 +135,8 @@ namespace polysat { case OP_EXTRACT: internalize_extract(a); break; case OP_CONCAT: internalize_concat(a); break; - case OP_ZERO_EXT: internalize_par_unary(a, [&](pdd const& p, unsigned sz) { return m_core.zero_ext(p, sz); }); break; - case OP_SIGN_EXT: internalize_par_unary(a, [&](pdd const& p, unsigned sz) { return m_core.sign_ext(p, sz); }); break; + case OP_ZERO_EXT: internalize_zero_extend(a); break; + case OP_SIGN_EXT: internalize_sign_extend(a); break; // polysat::solver should also support at least: case OP_BREDAND: // x == 2^K - 1 unary, return single bit, 1 if all input bits are set. @@ -282,7 +282,38 @@ namespace polysat { add_polysat_clause("[axiom] quot_rem 4", { c_eq, ~m_core.ule(b, r) }, false); if (!c_eq.is_always_false()) add_polysat_clause("[axiom] quot_rem 5", { ~c_eq, m_core.eq(q + 1) }, false); + } + void solver::internalize_sign_extend(app* e) { + expr* arg = e->get_arg(0); + unsigned sz = bv.get_bv_size(e); + unsigned arg_sz = bv.get_bv_size(arg); + unsigned sz2 = sz - arg_sz; + + var2pdd(expr2enode(e)->get_th_var(get_id())); + + if (arg_sz == sz) + add_clause(eq_internalize(e, arg), false); + else { + sat::literal lt0 = ctx.mk_literal(bv.mk_slt(arg, bv.mk_numeral(0, arg_sz))); + // arg < 0 ==> e = concat(arg, 1...1) + // arg >= 0 ==> e = concat(arg, 0...0) + add_clause(lt0, eq_internalize(e, bv.mk_concat(arg, bv.mk_numeral(rational::power_of_two(sz2) - 1, sz2))), false); + add_clause(~lt0, eq_internalize(e, bv.mk_concat(arg, bv.mk_numeral(0, sz - arg_sz))), false); + } + } + + void solver::internalize_zero_extend(app* e) { + expr* arg = e->get_arg(0); + unsigned sz = bv.get_bv_size(e); + unsigned arg_sz = bv.get_bv_size(arg); + unsigned sz2 = sz - arg_sz; + var2pdd(expr2enode(e)->get_th_var(get_id())); + if (arg_sz == sz) + add_clause(eq_internalize(e, arg), false); + else + // e = concat(arg, 0...0) + add_clause(eq_internalize(e, bv.mk_concat(arg, bv.mk_numeral(0, sz - arg_sz))), false); } void solver::internalize_div_rem(app* e, bool is_div) { @@ -332,20 +363,13 @@ namespace polysat { } void solver::internalize_extract(app* e) { - unsigned const hi = bv.get_extract_high(e); - unsigned const lo = bv.get_extract_low(e); - auto const src = expr2pdd(e->get_arg(0)); - auto const p = m_core.extract(src, hi, lo); - SASSERT_EQ(p.power_of_2(), hi - lo + 1); + auto p = var2pdd(expr2enode(e)->get_th_var(get_id())); internalize_set(e, p); } void solver::internalize_concat(app* e) { SASSERT(bv.is_concat(e)); - vector args; - for (expr* arg : *e) - args.push_back(expr2pdd(arg)); - auto const p = m_core.concat(args.size(), args.data()); + auto p = var2pdd(expr2enode(e)->get_th_var(get_id())); internalize_set(e, p); } diff --git a/src/sat/smt/polysat_solver.cpp b/src/sat/smt/polysat_solver.cpp index 82a61486a..ad4beb561 100644 --- a/src/sat/smt/polysat_solver.cpp +++ b/src/sat/smt/polysat_solver.cpp @@ -30,9 +30,6 @@ The result of polysat::core::check is one of: #include "sat/smt/polysat/ule_constraint.h" #include "sat/smt/polysat/umul_ovfl_constraint.h" - - - namespace polysat { solver::solver(euf::solver& ctx, theory_id id): @@ -288,11 +285,13 @@ namespace polysat { expr* n = bv.mk_numeral(p.val(), p.power_of_2()); return expr_ref(n, m); } - auto lo = pdd2expr(p.lo()); - auto hi = pdd2expr(p.hi()); auto v = var2enode(m_pddvar2var[p.var()]); - hi = bv.mk_bv_mul(v->get_expr(), hi); - return expr_ref(bv.mk_bv_add(lo, hi), m); + expr* r = v->get_expr(); + if (!p.hi().is_one()) + r = bv.mk_bv_mul(r, pdd2expr(p.hi())); + if (!p.lo().is_zero()) + r = bv.mk_bv_add(r, pdd2expr(p.lo())); + return expr_ref(r, m); } // walk the egraph starting with pvar for overlaps. diff --git a/src/sat/smt/polysat_solver.h b/src/sat/smt/polysat_solver.h index 5a9c26cb8..3b5eb27ba 100644 --- a/src/sat/smt/polysat_solver.h +++ b/src/sat/smt/polysat_solver.h @@ -115,6 +115,8 @@ namespace polysat { void internalize_bit2bool(app* n); template void internalize_le(app* n); + void internalize_zero_extend(app* n); + void internalize_sign_extend(app* n); void internalize_udiv_i(app* e); void internalize_urem_i(app* e); void internalize_div_rem(app* e, bool is_div);