diff --git a/src/math/interval/mod_interval.h b/src/math/interval/mod_interval.h index 88dfff485..50b6efe94 100644 --- a/src/math/interval/mod_interval.h +++ b/src/math/interval/mod_interval.h @@ -59,6 +59,7 @@ public: bool is_empty() const { return emp; } bool is_singleton() const { return !is_empty() && (lo + 1 == hi || (hi == 0 && is_max(lo))); } bool contains(Numeral const& n) const; + bool contains(mod_interval const& other) const; virtual bool is_max(Numeral const& n) const { return (Numeral)(n + 1) == 0; } void set_free() { lo = hi = 0; emp = false; } diff --git a/src/math/interval/mod_interval_def.h b/src/math/interval/mod_interval_def.h index 06b12b56c..cc5a78d1b 100644 --- a/src/math/interval/mod_interval_def.h +++ b/src/math/interval/mod_interval_def.h @@ -32,6 +32,27 @@ bool mod_interval::contains(Numeral const& n) const { return lo <= n || n < hi; } +template +bool mod_interval::contains(mod_interval const& other) const { + if (is_empty()) + return other.is_empty(); + if (is_free()) + return true; + if (hi == 0) + return lo <= other.lo && (other.lo < other.hi || other.hi == 0); + if (lo < hi) + return lo <= other.lo && other.hi <= hi; + if (other.lo < other.hi && other.hi <= hi) + return true; + if (other.lo < other.hi && lo <= other.lo) + return true; + if (other.hi == 0) + return lo <= other.lo; + SASSERT(other.hi < other.lo && other.hi != 0); + SASSERT(hi < lo && hi != 0); + return lo <= other.lo && other.hi <= hi; +} + template mod_interval mod_interval::operator+(mod_interval const& other) const { if (is_empty()) diff --git a/src/math/polysat/fixplex.h b/src/math/polysat/fixplex.h index 34571b897..dd5dfe57a 100644 --- a/src/math/polysat/fixplex.h +++ b/src/math/polysat/fixplex.h @@ -29,6 +29,7 @@ Author: #include "util/dependency.h" #include "util/ref.h" #include "util/params.h" +#include "util/union_find.h" inline rational to_rational(uint64_t n) { return rational(n, rational::ui64()); } inline unsigned trailing_zeros(unsigned short n) { return trailing_zeros((uint32_t)n); } @@ -183,8 +184,7 @@ namespace polysat { var_heap m_to_patch; vector m_vars; vector m_rows; - vector m_var_eqs; - vector m_fixed_vals; + bool m_bland = false ; unsigned m_blands_rule_threshold = 1000; unsigned m_num_repeated = 0; @@ -198,7 +198,14 @@ namespace polysat { u_dependency_manager m_deps; svector m_trail; svector m_row_trail; + + // euqality propagation + union_find_default_ctx m_union_find_ctx; + union_find<> m_union_find; + vector m_var_eqs; + vector m_fixed_vals; map m_value2fixed_var; + uint_set m_eq_rows; // inequalities svector m_ineqs; @@ -206,11 +213,15 @@ namespace polysat { uint_set m_touched_vars; vector m_var2ineqs; + // bound propagation + uint_set m_bound_rows; + public: fixplex(params_ref const& p, reslimit& lim): m_limit(lim), M(m), - m_to_patch(1024) { + m_to_patch(1024), + m_union_find(m_union_find_ctx) { updt_params(p); } @@ -249,7 +260,6 @@ namespace polysat { svector> stack; uint_set on_stack; lbool propagate_ineqs(unsigned idx); - void propagate_eqs(); vector const& var_eqs() const { return m_var_eqs; } void add_row(var_t base, unsigned num_vars, var_t const* vars, numeral const* coeffs); @@ -266,6 +276,9 @@ namespace polysat { bool patch(); bool propagate(); + bool propagate_ineqs(); + bool propagate_row_eqs(); + bool propagate_row_bounds(); bool is_satisfied(); var_t select_smallest_var() { return m_to_patch.empty()?null_var:m_to_patch.erase_min(); } @@ -276,6 +289,8 @@ namespace polysat { void lookahead_eq(row const& r1, numeral const& cx, var_t x, numeral const& cy, var_t y); void get_offset_eqs(row const& r); void fixed_var_eh(u_dependency* dep, var_t x); + var_t find(var_t x) { return m_union_find.find(x); } + void merge(var_t x, var_t y) { m_union_find.merge(x, y); } void eq_eh(var_t x, var_t y, u_dependency* dep); bool propagate_row(row const& r); bool propagate_ineq(ineq const& i); diff --git a/src/math/polysat/fixplex_def.h b/src/math/polysat/fixplex_def.h index 6eaf17b03..1223fe783 100644 --- a/src/math/polysat/fixplex_def.h +++ b/src/math/polysat/fixplex_def.h @@ -91,11 +91,13 @@ namespace polysat { void fixplex::push() { m_trail.push_back(trail_i::inc_level_i); m_deps.push_scope(); + m_union_find_ctx.get_trail_stack().push_scope(); } template void fixplex::pop(unsigned n) { m_deps.pop_scope(n); + m_union_find_ctx.get_trail_stack().pop_scope(n); while (n > 0) { switch (m_trail.back()) { case trail_i::inc_level_i: @@ -135,6 +137,7 @@ namespace polysat { while (v >= m_vars.size()) { M.ensure_var(m_vars.size()); m_vars.push_back(var_info()); + m_union_find.mk_var(); } if (m_to_patch.get_bounds() <= v) m_to_patch.set_bounds(2 * v + 1); @@ -242,6 +245,8 @@ namespace polysat { SASSERT(well_formed()); m_trail.push_back(trail_i::add_row_i); m_row_trail.push_back(base_var); + m_eq_rows.insert(r.id()); + m_bound_rows.insert(r.id()); } template @@ -286,6 +291,8 @@ namespace polysat { m_rows[r.id()].m_base = null_var; m_non_integral.remove(r.id()); M.del(r); + m_eq_rows.remove(r.id()); + m_bound_rows.remove(r.id()); SASSERT(M.col_begin(var) == M.col_end(var)); SASSERT(well_formed()); } @@ -671,11 +678,21 @@ namespace polysat { */ template bool fixplex::propagate() { - lbool r; - while (!m_ineqs_to_propagate.empty()) { - unsigned idx = *m_ineqs_to_propagate.begin(); - if (idx < m_ineqs.size() && (r = propagate_ineqs(idx), r == l_false)) - return false; + return propagate_ineqs() && propagate_row_bounds() && propagate_row_eqs(); + } + + template + bool fixplex::propagate_ineqs() { + lbool r = l_true; + while (!m_ineqs_to_propagate.empty() && r == l_true) { + unsigned idx = *m_ineqs_to_propagate.begin(); + if (idx >= m_ineqs.size()) { + m_ineqs_to_propagate.remove(idx); + continue; + } + r = propagate_ineqs(idx); + if (r == l_undef) + return true; m_ineqs_to_propagate.remove(idx); } return true; @@ -845,6 +862,9 @@ namespace polysat { * c - coefficient of y in r_z * * returns true if elimination preserves equivalence (is lossless). + * + * TBD: add r_z.id() to m_eq_rows, m_bound_rows with some frequency? + * */ template bool fixplex::eliminate_var( @@ -871,6 +891,7 @@ namespace polysat { return tz_b <= tz_c; } +#if 0 template bool fixplex::is_feasible() const { for (unsigned i = m_vars.size(); i-- > 0; ) @@ -878,6 +899,7 @@ namespace polysat { return false; return true; } +#endif /*** * Record an infeasible row. @@ -1055,9 +1077,11 @@ namespace polysat { */ template - void fixplex::propagate_eqs() { - for (unsigned i = 0; i < m_rows.size(); ++i) - get_offset_eqs(row(i)); + bool fixplex::propagate_row_eqs() { + for (unsigned i : m_eq_rows) + get_offset_eqs(row(i)); + m_eq_rows.reset(); + return !inconsistent(); } @@ -1113,11 +1137,10 @@ namespace polysat { std::swap(z, u); std::swap(cz, cu); } - if (z == x && u != y && cx == cz && cu == cy && value(u) == value(y)) + if (z == x && find(u) != find(y) && cx == cz && cu == cy && value(u) == value(y)) eq_eh(u, y, m_deps.mk_join(row2dep(r1), row2dep(r2))); - if (z == x && u != y && cx + cz == 0 && cu + cy == 0 && value(u) == value(y)) + if (z == x && find(u) != find(y) && cx + cz == 0 && cu + cy == 0 && value(u) == value(y)) eq_eh(u, y, m_deps.mk_join(row2dep(r1), row2dep(r2))); - } } @@ -1129,8 +1152,8 @@ namespace polysat { numeral val = value(x); fix_entry e; if (m_value2fixed_var.find(val, e)) { - SASSERT(x != e.x); - eq_eh(x, e.x, m_deps.mk_join(e.dep, dep)); + if (find(x) != find(e.x)) + eq_eh(x, e.x, m_deps.mk_join(e.dep, dep)); } else { m_value2fixed_var.insert(val, fix_entry(x, dep)); @@ -1141,23 +1164,20 @@ namespace polysat { template void fixplex::eq_eh(var_t x, var_t y, u_dependency* dep) { + SASSERT(find(x) != find(y)); + merge(x, y); m_var_eqs.push_back(var_eq(x, y, dep)); m_trail.push_back(trail_i::add_eq_i); - } + } -#if 0 template - lbool fixplex::propagate_bounds() { - lbool r = l_true; - for (unsigned i = 0; i < m_rows.size(); ++i) + bool fixplex::propagate_row_bounds() { + for (unsigned i : m_bound_rows) if (!propagate_row(row(i))) - return l_false; - for (auto ineq : m_ineqs) - if (r = propagate_ineqs(ineq), r != l_true) - return r; - return l_true; + return false; + m_bound_rows.reset(); + return true; } -#endif // // DFS search propagating inequalities diff --git a/src/test/fixplex.cpp b/src/test/fixplex.cpp index 79048553b..e28cec4b8 100644 --- a/src/test/fixplex.cpp +++ b/src/test/fixplex.cpp @@ -114,7 +114,6 @@ namespace polysat { fp.add_row(z, 3, ys2, coeffs2); fp.set_bounds(u, 1, 2, 1); fp.run(); - fp.propagate_eqs(); for (auto e : fp.var_eqs()) std::cout << e.x << " == " << e.y << "\n"; diff --git a/src/util/union_find.h b/src/util/union_find.h index c82d25857..7e42e1bba 100644 --- a/src/util/union_find.h +++ b/src/util/union_find.h @@ -93,6 +93,11 @@ public: return r; } + void reserve(unsigned v) { + while (get_num_vars() <= v) + mk_var(); + } + unsigned get_num_vars() const { return m_find.size(); }