diff --git a/src/math/dd/dd_pdd.cpp b/src/math/dd/dd_pdd.cpp index d0acd82c1..0c310322a 100644 --- a/src/math/dd/dd_pdd.cpp +++ b/src/math/dd/dd_pdd.cpp @@ -171,32 +171,24 @@ namespace dd { } break; case pdd_sub_op: - if (is_val(p)) { + if (is_val(p) || (!is_val(q) && level_p < level_q)) { + // p - (ax + b) = -ax + (p - b) push(apply_rec(p, lo(q), op)); - r = make_node(level_q, read(1), hi(q)); - npop = 1; + push(minus_rec(hi(q))); + r = make_node(level_q, read(2), read(1)); } - else if (is_val(q)) { - push(apply_rec(lo(p), q, op)); - r = make_node(level_p, read(1), hi(p)); - npop = 1; - } - else if (level_p == level_q) { - push(apply_rec(lo(p), lo(q), op)); - push(apply_rec(hi(p), hi(q), op)); - r = make_node(level_p, read(2), read(1)); - } - else if (level_p > level_q) { - // x*hi(p) + (lo(p) - q) + else if (is_val(q) || (level_p > level_q)) { + // (ax + b) - k = ax + (b - k) push(apply_rec(lo(p), q, op)); r = make_node(level_p, read(1), hi(p)); npop = 1; } else { - // x*hi(q) + (p - lo(q)) - push(apply_rec(p, lo(q), op)); - r = make_node(level_q, read(1), hi(q)); - npop = 1; + SASSERT(level_p == level_q); + // (ax + b) - (cx + d) = (a - c)x + (b - d) + push(apply_rec(lo(p), lo(q), op)); + push(apply_rec(hi(p), hi(q), op)); + r = make_node(level_p, read(2), read(1)); } break; case pdd_mul_op: diff --git a/src/math/grobner/pdd_grobner.cpp b/src/math/grobner/pdd_grobner.cpp index 3a102c044..4cb4c1779 100644 --- a/src/math/grobner/pdd_grobner.cpp +++ b/src/math/grobner/pdd_grobner.cpp @@ -12,6 +12,7 @@ --*/ #include "math/grobner/pdd_grobner.h" +#include "util/uint_set.h" namespace dd { @@ -125,6 +126,7 @@ namespace dd { } void grobner::saturate() { + simplify(); if (is_tuned()) tuned_init(); try { while (!done() && step()) { @@ -214,7 +216,10 @@ namespace dd { equation_vector linear; for (equation* e : m_to_simplify) { pdd p = e->poly(); - if (p.is_linear() && (!binary || p.is_binary())) { + if (binary) { + if (p.is_binary()) linear.push_back(e); + } + else if (p.is_linear()) { linear.push_back(e); } } @@ -798,10 +803,17 @@ namespace dd { } i = 0; + uint_set head_vars; for (auto* e : m_solved) { VERIFY(e->state() == solved); VERIFY(e->idx() == i); ++i; + pdd p = e->poly(); + if (!p.is_val() && p.hi().is_val()) { + unsigned v = p.var(); + SASSERT(!head_vars.contains(v)); + head_vars.insert(v); + } } // equations in to_simplify have correct indices @@ -811,11 +823,9 @@ namespace dd { for (auto* e : m_to_simplify) { VERIFY(e->idx() == i); VERIFY(e->state() == to_simplify); - VERIFY(!e->poly().is_val()); - if (is_tuned()) { - pdd const& p = e->poly(); - VERIFY(p.is_val() || m_watch[p.var()].contains(e)); - } + pdd const& p = e->poly(); + VERIFY(!p.is_val()); + VERIFY(!is_tuned() || m_watch[p.var()].contains(e)); ++i; } // the watch list consists of equations in to_simplify diff --git a/src/test/pdd_grobner.cpp b/src/test/pdd_grobner.cpp index 7a60426c1..0e3687e01 100644 --- a/src/test/pdd_grobner.cpp +++ b/src/test/pdd_grobner.cpp @@ -28,21 +28,23 @@ namespace dd { grobner gb(lim, m); gb.add(v1*v2 + v1*v3); gb.add(v1 - 1); - print_eqs(gb.equations()); + gb.display(std::cout); gb.saturate(); - print_eqs(gb.equations()); + gb.display(std::cout); gb.reset(); gb.add(v1*v1*v2 + v2*v3); gb.add(v1*v1*v2 + v2*v1); + gb.display(std::cout); gb.saturate(); - print_eqs(gb.equations()); + gb.display(std::cout); gb.reset(); gb.add(v1*v1*v2 + v2*v1); gb.add(v1*v1*v2 + v2*v1); + gb.display(std::cout); gb.saturate(); - print_eqs(gb.equations()); + gb.display(std::cout); gb.reset(); // stop early on contradiction @@ -52,7 +54,7 @@ namespace dd { gb.add(v3*v1 + v1*v2 + v2*v3 + v1); gb.add(v3*v1 + v1*v2 + v2*v3 + v2); gb.saturate(); - print_eqs(gb.equations()); + gb.display(std::cout << "early contradiction\n"); gb.reset(); // result is v1 = 0, v2 = 0. @@ -61,15 +63,15 @@ namespace dd { gb.add(v3*v1 + v1*v2 + v2*v3 + v1); gb.add(v3*v1 + v1*v2 + v2*v3 + v2); gb.saturate(); - print_eqs(gb.equations()); + gb.display(std::cout << "v1 = v2 = 0\n"); gb.reset(); // everything rewrites to a multiple of v0 gb.add(v3 - 2*v2); gb.add(v2 - 2*v1); gb.add(v1 - 2*v0); - gb.saturate(); - print_eqs(gb.equations()); + gb.saturate(); + gb.display(std::cout << "multiple of v0\n"); gb.reset(); //