diff --git a/src/sat/smt/euf_solver.cpp b/src/sat/smt/euf_solver.cpp index b6606d4f6..b95e44d74 100644 --- a/src/sat/smt/euf_solver.cpp +++ b/src/sat/smt/euf_solver.cpp @@ -291,6 +291,26 @@ namespace euf { } } + void solver::get_eq_antecedents(enode* a, enode* b, literal_vector& r) { + m_egraph.begin_explain(); + m_explain.reset(); + m_egraph.explain_eq(m_explain, nullptr, a, b); + for (unsigned qhead = 0; qhead < m_explain.size(); ++qhead) { + size_t* e = m_explain[qhead]; + if (is_literal(e)) + r.push_back(get_literal(e)); + else { + size_t idx = get_justification(e); + auto* ext = sat::constraint_base::to_extension(idx); + SASSERT(ext != this); + sat::literal lit = sat::null_literal; + ext->get_antecedents(lit, idx, r, true); + } + } + m_egraph.end_explain(); + } + + void solver::get_th_antecedents(literal l, th_explain& jst, literal_vector& r, bool probing) { for (auto lit : euf::th_explain::lits(jst)) r.push_back(lit); diff --git a/src/sat/smt/euf_solver.h b/src/sat/smt/euf_solver.h index 9cac6e02a..7d2d01473 100644 --- a/src/sat/smt/euf_solver.h +++ b/src/sat/smt/euf_solver.h @@ -369,6 +369,7 @@ namespace euf { void flush_roots() override; void get_antecedents(literal l, ext_justification_idx idx, literal_vector& r, bool probing) override; + void get_eq_antecedents(enode* a, enode* b, literal_vector& r); void get_th_antecedents(literal l, th_explain& jst, literal_vector& r, bool probing); void add_eq_antecedent(bool probing, enode* a, enode* b); void explain_diseq(ptr_vector& ex, cc_justification* cc, enode* a, enode* b); diff --git a/src/sat/smt/intblast_solver.cpp b/src/sat/smt/intblast_solver.cpp index 46a42b0d9..0b9f08ec8 100644 --- a/src/sat/smt/intblast_solver.cpp +++ b/src/sat/smt/intblast_solver.cpp @@ -647,7 +647,7 @@ namespace intblast { bv_rewriter_params p(ctx.s().params()); expr* x = arg(0), * y = umod(e, 1); if (p.hi_div0()) - r = m.mk_ite(m.mk_eq(y, a.mk_int(0)), a.mk_int(0), y)); + r = m.mk_ite(m.mk_eq(y, a.mk_int(0)), a.mk_int(0), y); else r = a.mk_mod(x, y); break; diff --git a/src/sat/smt/polysat/core.cpp b/src/sat/smt/polysat/core.cpp index dd30e8227..793db7392 100644 --- a/src/sat/smt/polysat/core.cpp +++ b/src/sat/smt/polysat/core.cpp @@ -149,13 +149,14 @@ namespace polysat { m_var = m_var_queue.next_var(); s.trail().push(mk_dqueue_var(m_var, *this)); switch (m_viable.find_viable(m_var, m_value)) { - case find_t::empty: - s.set_lemma(m_viable.get_core(), 0, m_viable.explain()); - // propagate_unsat_core(); + case find_t::empty: + s.set_lemma(m_viable.get_core(), m_viable.explain()); + // propagate_unsat_core(); return sat::check_result::CR_CONTINUE; - case find_t::singleton: + case find_t::singleton: { s.propagate(m_constraints.eq(var2pdd(m_var), m_value), m_viable.explain()); return sat::check_result::CR_CONTINUE; + } case find_t::multiple: s.add_eq_literal(m_var, m_value); return sat::check_result::CR_CONTINUE; diff --git a/src/sat/smt/polysat/types.h b/src/sat/smt/polysat/types.h index 6f855f98b..e7beb3eb1 100644 --- a/src/sat/smt/polysat/types.h +++ b/src/sat/smt/polysat/types.h @@ -107,7 +107,7 @@ namespace polysat { virtual ~solver_interface() {} virtual void add_eq_literal(pvar v, rational const& val) = 0; virtual void set_conflict(dependency_vector const& core) = 0; - virtual void set_lemma(core_vector const& aux_core, unsigned level, dependency_vector const& core) = 0; + virtual void set_lemma(core_vector const& aux_core, dependency_vector const& core) = 0; virtual void add_polysat_clause(char const* name, core_vector cs, bool redundant) = 0; virtual dependency propagate(signed_constraint sc, dependency_vector const& deps) = 0; virtual void propagate(dependency const& d, bool sign, dependency_vector const& deps) = 0; diff --git a/src/sat/smt/polysat_solver.cpp b/src/sat/smt/polysat_solver.cpp index c14ca1d14..a713cc4a4 100644 --- a/src/sat/smt/polysat_solver.cpp +++ b/src/sat/smt/polysat_solver.cpp @@ -130,8 +130,13 @@ namespace polysat { return { core, eqs }; } - void solver::set_lemma(core_vector const& aux_core, unsigned level, dependency_vector const& core) { + void solver::set_lemma(core_vector const& aux_core, dependency_vector const& core) { auto [lits, eqs] = explain_deps(core); + unsigned level = 0; + for (auto const& [n1, n2] : eqs) + ctx.get_eq_antecedents(n1, n2, lits); + for (auto lit : lits) + level = std::max(level, s().lvl(lit)); auto ex = euf::th_explain::conflict(*this, lits, eqs, nullptr); ctx.push(value_trail(m_has_lemma)); m_has_lemma = true; diff --git a/src/sat/smt/polysat_solver.h b/src/sat/smt/polysat_solver.h index a04c76618..e88eafdd2 100644 --- a/src/sat/smt/polysat_solver.h +++ b/src/sat/smt/polysat_solver.h @@ -144,7 +144,7 @@ namespace polysat { // callbacks from core void add_eq_literal(pvar v, rational const& val) override; void set_conflict(dependency_vector const& core) override; - void set_lemma(core_vector const& aux_core, unsigned level, dependency_vector const& core) override; + void set_lemma(core_vector const& aux_core, dependency_vector const& core) override; dependency propagate(signed_constraint sc, dependency_vector const& deps) override; void propagate(dependency const& d, bool sign, dependency_vector const& deps) override; trail_stack& trail() override;