diff --git a/src/math/polysat/constraint.h b/src/math/polysat/constraint.h index d6a38304c..ac42f8cdd 100644 --- a/src/math/polysat/constraint.h +++ b/src/math/polysat/constraint.h @@ -255,6 +255,9 @@ namespace polysat { void add_to_univariate_solver(solver& s, univariate_solver& us, unsigned dep) const { get()->add_to_univariate_solver(s, us, dep, is_positive()); } + unsigned var(unsigned idx) const { return m_constraint->var(idx); } + bool contains_var(pvar v) const { return m_constraint->contains_var(v); } + sat::bool_var bvar() const { return m_constraint->bvar(); } sat::literal blit() const { return sat::literal(bvar(), is_negative()); } constraint* get() const { return m_constraint; } diff --git a/src/math/polysat/interval.h b/src/math/polysat/interval.h index 952d53a1b..9bea56d41 100644 --- a/src/math/polysat/interval.h +++ b/src/math/polysat/interval.h @@ -115,6 +115,27 @@ namespace polysat { return true; return false; } + + eval_interval intersect(eval_interval const& other) const { + if (is_full()) return other; + if (other.is_full()) return *this; + + pdd i_lo = lo(); + rational i_lo_val = lo_val(); + if (lo_val() < other.lo_val()) { + i_lo = other.lo(); + i_lo_val = other.lo_val(); + } + + pdd i_hi = hi(); + rational i_hi_val = hi_val(); + if (hi_val() > other.hi_val()) { + i_hi = other.hi(); + i_hi_val = other.hi_val(); + } + + return eval_interval::proper(i_lo, i_lo_val, i_hi, i_hi_val); + } }; inline std::ostream& operator<<(std::ostream& os, eval_interval const& i) { diff --git a/src/math/polysat/solver.cpp b/src/math/polysat/solver.cpp index 256f2c1f3..96deda234 100644 --- a/src/math/polysat/solver.cpp +++ b/src/math/polysat/solver.cpp @@ -1193,6 +1193,69 @@ namespace polysat { if (all_ok) LOG("All good!"); return all_ok; } + + + + // All variables of clause 'cl' except 'z' are assigned. + // Goal: a possibly weaker clause that implies a restriction on z around z_val + clause_ref solver::make_asserting(clause& cl, pvar z, rational z_val) { + signed_constraints cz; // constraints of 'cl' that contain 'z' + sat::literal_vector lits; // literals of the new clause + for (sat::literal lit : cl) { + signed_constraint c = lit2cnstr(lit); + if (c.contains_var(z)) + cz.push_back(c); + else + lits.push_back(lit); + } + SASSERT(!cz.empty()); + if (cz.size() == 1) { + // TODO: even in this case, if the constraint is non-linear in z, we might want to extract a maximal forbidden interval around z_val. + return nullptr; + } + else { + // multiple constraints that contain z + find_implied_constraint(cz, z, z_val, lits); + return clause::from_literals(std::move(lits)); + } + } + + // Each constraint in 'cz' is univariate in 'z' under the current assignment. + // Goal: a literal that is implied by the disjunction of cz and ensures z != z_val in viable. + // (plus side conditions that do not depend on z) + void solver::find_implied_constraint(signed_constraints const& cz, pvar z, rational z_val, sat::literal_vector& out_lits) + { + unsigned const out_lits_original_size = out_lits.size(); + + forbidden_intervals fi(*this); + fi_record entry({ eval_interval::full(), {}, {}, rational::one()}); + + auto intersection = eval_interval::full(); + bool all_unit = true; + + for (signed_constraint const& c : cz) { + if (fi.get_interval(c, z, entry) && entry.coeff == 1) { + intersection = intersection.intersect(entry.interval); + for (auto const& sc : entry.side_cond) + out_lits.push_back(sc.blit()); + } else { + all_unit = false; + break; + } + } + + if (all_unit) { + // Unit intervals from all constraints + // => build constraint from intersection of forbidden intervals + // z \not\in [l;u[ <=> z - l >= u - l + auto c_new = ule(intersection.hi() - intersection.lo(), z - intersection.lo()); + out_lits.push_back(c_new.blit()); + return; + } else { + out_lits.shrink(out_lits_original_size); + // TODO: SAT-based approach + } + } } diff --git a/src/math/polysat/solver.h b/src/math/polysat/solver.h index 5cec5a360..05f939067 100644 --- a/src/math/polysat/solver.h +++ b/src/math/polysat/solver.h @@ -234,6 +234,9 @@ namespace polysat { bool can_propagate(); void propagate(); + clause_ref make_asserting(clause& cl, pvar z, rational z_val); + void find_implied_constraint(signed_constraints const& cz, pvar z, rational z_val, sat::literal_vector& out_lits); + public: /**