diff --git a/src/sat/smt/polysat/core.cpp b/src/sat/smt/polysat/core.cpp index 619b1fac7..835130380 100644 --- a/src/sat/smt/polysat/core.cpp +++ b/src/sat/smt/polysat/core.cpp @@ -369,6 +369,9 @@ namespace polysat { s.set_conflict(deps, "non-viable assignment"); return; } + s.propagate_eq(v, value, dep); + if (s.inconsistent()) + return; m_values[v] = value; m_justification[v] = dep; m_assignment.push(v , value); diff --git a/src/sat/smt/polysat/types.h b/src/sat/smt/polysat/types.h index 5258fd31b..02881c20d 100644 --- a/src/sat/smt/polysat/types.h +++ b/src/sat/smt/polysat/types.h @@ -141,6 +141,7 @@ namespace polysat { virtual void set_conflict(dependency_vector const& core, char const* hint) = 0; virtual dependency propagate(signed_constraint sc, dependency_vector const& deps, char const* hint) = 0; virtual void propagate(dependency const& d, bool sign, dependency_vector const& deps, char const* hint) = 0; + virtual void propagate_eq(pvar v, rational const& val, dependency const& d) = 0; virtual trail_stack& trail() = 0; virtual bool inconsistent() const = 0; virtual void get_bitvector_suffixes(pvar v, offset_slices& out) = 0; diff --git a/src/sat/smt/polysat_solver.cpp b/src/sat/smt/polysat_solver.cpp index 73f6d1bc7..5a167d166 100644 --- a/src/sat/smt/polysat_solver.cpp +++ b/src/sat/smt/polysat_solver.cpp @@ -268,6 +268,24 @@ namespace polysat { return dependency(lit.var()); } + void solver::propagate_eq(pvar pv, rational const& val, dependency const& d) { + auto v = m_pddvar2var[pv]; + auto a = var2enode(v); + auto bval = bv.mk_numeral(val, get_bv_size(v)); + ctx.internalize(bval); + auto b = ctx.get_enode(bval); + if (a->get_root() == b->get_root()) + return; + proof_hint* hint = nullptr; + sat::literal_vector core; + euf::enode_pair_vector eqs; + explain_dep(d, eqs, core); + if (ctx.use_drat()) + hint = mk_proof_hint("propagate-eq", core, eqs); + auto exp = euf::th_explain::propagate(*this, core, eqs, a, b, hint); + ctx.propagate(a, b, exp); + } + unsigned solver::level(dependency const& d) { if (d.is_bool_var()) return s().lvl(d.bool_var()); diff --git a/src/sat/smt/polysat_solver.h b/src/sat/smt/polysat_solver.h index 5b45533cb..b64e8fb7c 100644 --- a/src/sat/smt/polysat_solver.h +++ b/src/sat/smt/polysat_solver.h @@ -215,6 +215,7 @@ namespace polysat { bool add_axiom(char const* name, constraint_or_dependency const* begin, constraint_or_dependency const* end, bool redundant) override; dependency propagate(signed_constraint sc, dependency_vector const& deps, char const* hint) override; void propagate(dependency const& d, bool sign, dependency_vector const& deps, char const* hint) override; + void propagate_eq(pvar v, rational const& val, dependency const& d) override; trail_stack& trail() override; bool inconsistent() const override; void get_bitvector_sub_slices(pvar v, offset_slices& out) override;