diff --git a/src/sat/sat_extension.h b/src/sat/sat_extension.h index d6a956a32..653ea4039 100644 --- a/src/sat/sat_extension.h +++ b/src/sat/sat_extension.h @@ -127,6 +127,8 @@ namespace sat { virtual bool tracking_assumptions() { return false; } virtual bool enable_self_propagate() const { return false; } + virtual void add_xor(literal_vector const& lits) { throw default_exception("solver does not support adding xor clauses"); } + virtual bool extract_pb(std::function& card, std::function& pb) { return false; diff --git a/src/sat/smt/euf_internalize.cpp b/src/sat/smt/euf_internalize.cpp index 9f747090a..6f49fb35d 100644 --- a/src/sat/smt/euf_internalize.cpp +++ b/src/sat/smt/euf_internalize.cpp @@ -31,6 +31,7 @@ Notes: #include "ast/pb_decl_plugin.h" #include "sat/smt/euf_solver.h" +#include "sat/smt/xor_solver.h" namespace euf { @@ -498,4 +499,17 @@ namespace euf { } } + void solver::add_xor(sat::literal_vector const& lits) { + family_id fid = m.mk_family_id("xor"); + auto* ext = m_id2solver.get(fid, nullptr); + th_solver* xr; + if (!ext) { + xr = alloc(xr::solver, *this); + add_solver(xr); + ext = xr; + } + xr = dynamic_cast(ext); + xr->add_xor(lits); + } + } diff --git a/src/sat/smt/euf_solver.h b/src/sat/smt/euf_solver.h index beb0809fb..64577430c 100644 --- a/src/sat/smt/euf_solver.h +++ b/src/sat/smt/euf_solver.h @@ -451,6 +451,7 @@ namespace euf { void set_bool_var2expr(sat::bool_var v, expr* e) { m_var_trail.push_back(v); m_bool_var2expr.setx(v, e, nullptr); } expr* bool_var2expr(sat::bool_var v) const { return m_bool_var2expr.get(v, nullptr); } expr_ref literal2expr(sat::literal lit) const { expr* e = bool_var2expr(lit.var()); return (e && lit.sign()) ? expr_ref(mk_not(m, e), m) : expr_ref(e, m); } + void add_xor(sat::literal_vector const& lits) override; unsigned generation() const { return m_generation; } sat::literal attach_lit(sat::literal lit, expr* e); diff --git a/src/sat/smt/xor_solver.h b/src/sat/smt/xor_solver.h index 76c8f1ee8..337b679b9 100644 --- a/src/sat/smt/xor_solver.h +++ b/src/sat/smt/xor_solver.h @@ -177,6 +177,10 @@ namespace xr { solver(ast_manager& m, sat::sat_internalizer& si, euf::theory_id id); th_solver* clone(euf::solver& ctx) override; + + void add_xor(sat::literal_vector const& lits) override { NOT_IMPLEMENTED_YET(); } + + sat::literal internalize(expr* e, bool sign, bool root) override { UNREACHABLE(); return sat::null_literal; } void internalize(expr* e) override { UNREACHABLE(); } diff --git a/src/sat/tactic/goal2sat.cpp b/src/sat/tactic/goal2sat.cpp index 0ca46c2dd..09b720383 100644 --- a/src/sat/tactic/goal2sat.cpp +++ b/src/sat/tactic/goal2sat.cpp @@ -612,7 +612,93 @@ struct goal2sat::imp : public sat::sat_internalizer { } } + unsigned get_num_args(app* t) { + if (!m_xor || (!m.is_iff(t) && !m.is_xor(t))) + return t->get_num_args(); + + unsigned n = 2; + while (m.is_iff(t->get_arg(1)) || m.is_xor(t->get_arg(1))) { + ++n; + t = to_app(t->get_arg(1)); + } + return n; + } + + expr* get_arg(app* t, unsigned idx) { + if (!m_xor || (!m.is_iff(t) && !m.is_xor(t))) + return t->get_arg(idx); + + while (idx >= 1) { + SASSERT(m.is_iff(t) || m.is_xor(t)); + t = to_app(t->get_arg(1)); + --idx; + } + if (m.is_iff(t) || m.is_xor(t)) + return t->get_arg(idx); + else + return t; + } + + bool is_iff(app* t) { + bool r = true; + expr* e = t, *x = nullptr; + do { + if (m.is_iff(e, x, e)) + continue; + if (m.is_xor(e, x, e)) { + r = !r; + continue; + } + break; + } + while (true); + return r; + } + + /** + * Convert xor expressions to native xor solver directly. + */ + + void convert_iff_native(app * t, bool root, bool sign) { + unsigned sz = m_result_stack.size(); + unsigned num_args = get_num_args(t); + ptr_buffer args; + SASSERT(sz >= num_args && num_args >= 2); + sat::literal_vector lits; + sat::bool_var v = add_var(true, t); + lits.push_back(sat::literal(v, is_iff(t))); + for (unsigned i = 0; i < num_args; ++i) { + sat::literal lit(m_result_stack[sz - num_args + i]); + m_solver.set_external(lit.var()); + lits.push_back(lit); + } + // ensure that = is converted to xor + for (unsigned i = 1; i + 1 < lits.size(); ++i) + lits[i].neg(); + TRACE("goal2sat", tout << "convert-xor " << mk_bounded_pp(t, m, 8) << " " << lits << "\n"); + + ensure_xor(); + m_solver.get_extension()->add_xor(lits); + if (aig()) + aig()->add_xor(~lits.back(), lits.size() - 1, lits.data() + 1); + sat::literal lit(v, sign); + if (root) { + m_result_stack.reset(); + mk_root_clause(lit); + } + else { + m_result_stack.shrink(sz - num_args); + m_result_stack.push_back(lit); + } + } + + void convert_iff(app * t, bool root, bool sign) { + if (m_xor) { + convert_iff_native(t, root, sign); + return; + } + if (t->get_num_args() != 2) throw default_exception("unexpected number of arguments to " + mk_pp(t, m)); SASSERT(t->get_num_args() == 2); @@ -824,9 +910,9 @@ struct goal2sat::imp : public sat::sat_internalizer { visit(t->get_arg(0), root, !sign); continue; } - unsigned num = t->get_num_args(); + unsigned num = get_num_args(t); while (m_frame_stack[fsz-1].m_idx < num) { - expr * arg = t->get_arg(m_frame_stack[fsz-1].m_idx); + expr * arg = get_arg(t, m_frame_stack[fsz-1].m_idx); m_frame_stack[fsz - 1].m_idx++; if (!visit(arg, false, false)) goto loop;