From c4963f4381baaaed29fd991542e702ec108def21 Mon Sep 17 00:00:00 2001 From: Jakob Rath Date: Fri, 18 Jun 2021 17:48:50 +0200 Subject: [PATCH] Polysat: add two more prototype rules (#5355) * Add try_div to PDDs * x>y is false when x==y * First version of the other two prototype rules * More band-aid fixes... --- src/math/dd/dd_pdd.cpp | 34 ++++-- src/math/dd/dd_pdd.h | 2 + src/math/polysat/explain.cpp | 179 +++++++++++++++++++++++++++- src/math/polysat/solver.cpp | 19 +++ src/math/polysat/solver.h | 1 + src/math/polysat/ule_constraint.cpp | 8 +- src/test/pdd.cpp | 13 ++ 7 files changed, 239 insertions(+), 17 deletions(-) diff --git a/src/math/dd/dd_pdd.cpp b/src/math/dd/dd_pdd.cpp index d70c090b4..24b93817a 100644 --- a/src/math/dd/dd_pdd.cpp +++ b/src/math/dd/dd_pdd.cpp @@ -461,10 +461,17 @@ namespace dd { * But such a multiplication would create nodes with non-integral coefficients. */ pdd pdd_manager::div(pdd const& a, rational const& c) { + pdd res(zero_pdd, this); + VERIFY(try_div(a, c, res)); + return res; + } + + bool pdd_manager::try_div(pdd const& a, rational const& c, pdd& out_result) { if (m_semantics == free_e) { // Don't cache separately for the free semantics; // use 'mul' so we can share results for a/c and a*(1/c). - return mul(inv(c), a); + out_result = mul(inv(c), a); + return true; } SASSERT(c.is_int()); bool first = true; @@ -472,7 +479,11 @@ namespace dd { scoped_push _sp(*this); while (true) { try { - return pdd(div_rec(a.root, c, null_pdd), this); + PDD res = div_rec(a.root, c, null_pdd); + if (res != null_pdd) + out_result = pdd(res, this); + SASSERT(well_formed()); + return res != null_pdd; } catch (const mem_out &) { try_gc(); @@ -480,10 +491,9 @@ namespace dd { first = false; } } - SASSERT(well_formed()); - return pdd(zero_pdd, this); } + /// Returns null_pdd if one of the coefficients is not divisible by c. pdd_manager::PDD pdd_manager::div_rec(PDD a, rational const& c, PDD c_pdd) { SASSERT(m_semantics != free_e); SASSERT(c.is_int()); @@ -491,8 +501,10 @@ namespace dd { return zero_pdd; if (is_val(a)) { rational r = val(a) / c; - SASSERT(r.is_int()); - return imk_val(r); + if (r.is_int()) + return imk_val(r); + else + return null_pdd; } if (c_pdd == null_pdd) c_pdd = imk_val(c); @@ -502,10 +514,14 @@ namespace dd { return e2->m_result; push(div_rec(lo(a), c, c_pdd)); push(div_rec(hi(a), c, c_pdd)); - PDD r = make_node(level(a), read(2), read(1)); + PDD l = read(2); + PDD h = read(1); + PDD res = null_pdd; + if (l != null_pdd && h != null_pdd) + res = make_node(level(a), l, h); pop(2); - e1->m_result = r; - return r; + e1->m_result = res; + return res; } pdd pdd_manager::pow(pdd const &p, unsigned j) { diff --git a/src/math/dd/dd_pdd.h b/src/math/dd/dd_pdd.h index e5024e0c2..553ad004b 100644 --- a/src/math/dd/dd_pdd.h +++ b/src/math/dd/dd_pdd.h @@ -322,6 +322,7 @@ namespace dd { pdd mul(pdd const& a, pdd const& b); pdd mul(rational const& c, pdd const& b); pdd div(pdd const& a, rational const& c); + bool try_div(pdd const& a, rational const& c, pdd& out_result); pdd mk_or(pdd const& p, pdd const& q); pdd mk_xor(pdd const& p, pdd const& q); pdd mk_xor(pdd const& p, unsigned q); @@ -408,6 +409,7 @@ namespace dd { pdd operator~() const { return m.mk_not(*this); } pdd rev_sub(rational const& r) const { return m.sub(m.mk_val(r), *this); } pdd div(rational const& other) const { return m.div(*this, other); } + bool try_div(rational const& other, pdd& out_result) const { return m.try_div(*this, other, out_result); } pdd pow(unsigned j) const { return m.pow(*this, j); } pdd reduce(pdd const& other) const { return m.reduce(*this, other); } bool different_leading_term(pdd const& other) const { return m.different_leading_term(*this, other); } diff --git a/src/math/polysat/explain.cpp b/src/math/polysat/explain.cpp index 01866fd5a..6280f5708 100644 --- a/src/math/polysat/explain.cpp +++ b/src/math/polysat/explain.cpp @@ -40,6 +40,9 @@ namespace polysat { LOG("New constraint: " << show_deref(c)); } } + else { + LOG("No lemma"); + } m_var = null_var; m_cjust_v.reset(); @@ -72,7 +75,7 @@ namespace polysat { /// [x] zx > yx ==> ... clause_ref conflict_explainer::by_ugt_x() { - LOG_H3("Try zx > yx"); + LOG_H3("Try zx > yx where x := v" << m_var); for (auto* c : m_conflict.units()) LOG("Constraint: " << show_deref(c)); for (auto* c : m_conflict.clauses()) @@ -82,8 +85,8 @@ namespace polysat { for (auto* c : m_conflict.units()) { if (!c->is_ule()) continue; - pdd lhs = c->to_ule().lhs(); - pdd rhs = c->to_ule().rhs(); + pdd const& lhs = c->to_ule().lhs(); + pdd const& rhs = c->to_ule().rhs(); if (lhs.degree(m_var) != 1) continue; if (rhs.degree(m_var) != 1) @@ -127,13 +130,179 @@ namespace polysat { return nullptr; } - /// [y] y >= z' /\ zx > yx ==> ... + /// [y] z' <= y /\ zx > yx ==> ... clause_ref conflict_explainer::by_ugt_y() { + LOG_H3("Try z' <= y && zx > yx where y := v" << m_var); + for (auto* c : m_conflict.units()) + LOG("Constraint: " << show_deref(c)); + for (auto* c : m_conflict.clauses()) + LOG("Clause: " << show_deref(c)); + + pdd const y = m_solver.var(m_var); + + // Collect constraints of shape "_ <= y" + ptr_vector ds; + for (auto* d : m_conflict.units()) { + if (!d->is_ule()) + continue; + if (!d->is_positive()) + continue; + pdd const& rhs = d->to_ule().rhs(); + // TODO: a*y where 'a' divides 'x' should also be easy to handle (assuming for now they're numbers) + // TODO: also z' < y should follow the same pattern. + if (rhs != y) + continue; + LOG("z' <= y candidate: " << show_deref(d)); + ds.push_back(d); + } + if (ds.empty()) + return nullptr; + + // Find constraint of shape: zx > yx + for (auto* c : m_conflict.units()) { + if (!c->is_ule()) + continue; + pdd const& lhs = c->to_ule().lhs(); + pdd const& rhs = c->to_ule().rhs(); + if (rhs.degree(m_var) != 1) + continue; + pdd x = lhs; + pdd rest = lhs; + rhs.factor(m_var, 1, x, rest); + if (!rest.is_zero()) + continue; + // TODO: in principle, 'x' could be any polynomial. However, we need to divide the lhs by x, and we don't have general polynomial division yet. + // so for now we just allow the form 'value*variable'. + // (extension to arbitrary monomials for 'x' should be fairly easy too) + if (!x.is_unary()) + continue; + unsigned x_var = x.var(); + rational x_coeff = x.hi().val(); + pdd xz = lhs; + if (!lhs.try_div(x_coeff, xz)) + continue; + pdd z = lhs; + xz.factor(x_var, 1, z, rest); + if (!rest.is_zero()) + continue; + + unsigned const lvl = c->level(); + if (c->is_positive()) { + // zx <= yx + NOT_IMPLEMENTED_YET(); + } + else { + SASSERT(c->is_negative()); + // zx > yx + + LOG("zx > yx: " << show_deref(c)); + + // TODO: for now, we just choose the first of the other constraints + constraint* d = ds[0]; + SASSERT(d->is_ule() && d->is_positive()); + pdd const& z_prime = d->to_ule().lhs(); + + unsigned const p = m_solver.size(m_var); + + clause_builder clause(m_solver); + // Omega^*(x, y) + push_omega_mul(clause, lvl, p, x, y); + // zx > z'x + constraint_ref zx_gt_zpx = m_solver.m_constraints.ult(lvl, pos_t, z*x, z_prime*x, null_dep()); + LOG("zx>z'x: " << show_deref(zx_gt_zpx)); + clause.push_new_constraint(std::move(zx_gt_zpx)); + + return clause.build(lvl, {c->dep(), m_solver.m_dm}); + } + } return nullptr; } - /// [z] y' >= z /\ zx > yx ==> ... + /// [z] z <= y' /\ zx > yx ==> ... clause_ref conflict_explainer::by_ugt_z() { + LOG_H3("Try z <= y' && zx > yx where z := v" << m_var); + for (auto* c : m_conflict.units()) + LOG("Constraint: " << show_deref(c)); + for (auto* c : m_conflict.clauses()) + LOG("Clause: " << show_deref(c)); + + pdd const z = m_solver.var(m_var); + + // Collect constraints of shape "z <= _" + ptr_vector ds; + for (auto* d : m_conflict.units()) { + if (!d->is_ule()) + continue; + if (!d->is_positive()) + continue; + pdd const& lhs = d->to_ule().lhs(); + // TODO: a*y where 'a' divides 'x' should also be easy to handle (assuming for now they're numbers) + // TODO: also z < y' should follow the same pattern. + if (lhs != z) + continue; + LOG("z <= y' candidate: " << show_deref(d)); + ds.push_back(d); + } + if (ds.empty()) + return nullptr; + + // Find constraint of shape: zx > yx + for (auto* c : m_conflict.units()) { + if (!c->is_ule()) + continue; + pdd const& lhs = c->to_ule().lhs(); + pdd const& rhs = c->to_ule().rhs(); + if (lhs.degree(m_var) != 1) + continue; + pdd x = lhs; + pdd rest = lhs; + lhs.factor(m_var, 1, x, rest); + if (!rest.is_zero()) + continue; + // TODO: in principle, 'x' could be any polynomial. However, we need to divide the lhs by x, and we don't have general polynomial division yet. + // so for now we just allow the form 'value*variable'. + // (extension to arbitrary monomials for 'x' should be fairly easy too) + if (!x.is_unary()) + continue; + unsigned x_var = x.var(); + rational x_coeff = x.hi().val(); + pdd xy = lhs; + if (!rhs.try_div(x_coeff, xy)) + continue; + pdd y = lhs; + xy.factor(x_var, 1, y, rest); + if (!rest.is_zero()) + continue; + + unsigned const lvl = c->level(); + if (c->is_positive()) { + // zx <= yx + NOT_IMPLEMENTED_YET(); + } + else { + SASSERT(c->is_negative()); + // zx > yx + + LOG("zx > yx: " << show_deref(c)); + + // TODO: for now, we just choose the first of the other constraints + constraint* d = ds[0]; + SASSERT(d->is_ule() && d->is_positive()); + pdd const& y_prime = d->to_ule().rhs(); + + unsigned const p = m_solver.size(m_var); + + clause_builder clause(m_solver); + // Omega^*(x, y') + push_omega_mul(clause, lvl, p, x, y_prime); + // y'x > yx + constraint_ref ypx_gt_yx = m_solver.m_constraints.ult(lvl, pos_t, y_prime*x, y*x, null_dep()); + LOG("y'x>yx: " << show_deref(ypx_gt_yx)); + clause.push_new_constraint(std::move(ypx_gt_yx)); + + return clause.build(lvl, {c->dep(), m_solver.m_dm}); + } + } return nullptr; } diff --git a/src/math/polysat/solver.cpp b/src/math/polysat/solver.cpp index 8203a0737..0711a32d6 100644 --- a/src/math/polysat/solver.cpp +++ b/src/math/polysat/solver.cpp @@ -121,6 +121,7 @@ namespace polysat { LOG("Starting"); m_disjunctive_lemma.reset(); while (m_lim.inc()) { + m_stats.m_num_iterations++; LOG_H1("Next solving loop iteration"); LOG("Free variables: " << m_free_vars); LOG("Assignments: " << assignment()); @@ -622,6 +623,10 @@ namespace polysat { } SASSERT(m_bvars.is_propagation(var)); clause_ref new_lemma = resolve_bool(lit); + if (!new_lemma) { + backtrack(i, lemma); + return; + } SASSERT(new_lemma); LOG("new_lemma: " << show_deref(new_lemma)); LOG("new_lemma is always false: " << new_lemma->is_always_false(*this)); @@ -905,6 +910,7 @@ namespace polysat { // - We have a conflict but we don't know. It will be discovered when y and z are assigned, // and then may lead to an assertion failure through this call to narrow. // TODO: what to do with "unassigned" constraints at this point? (we probably should have resolved those away, even in the 'backtrack' case.) + // NOTE: they are constraints from clauses that were added to cjust… how to deal with that? should we add the whole clause to cjust? if (!c->is_undef()) // TODO: this check to be removed once this is fixed properly. c->narrow(*this); if (is_conflict()) { @@ -963,6 +969,8 @@ namespace polysat { for (auto c : m_conflict.units()) { if (c->bvar() == var) continue; + if (c->is_undef()) // TODO: see revert_decision for a note on this. + continue; reason_lvl = std::max(reason_lvl, c->level()); reason_dep = m_dm.mk_join(reason_dep, c->dep()); reason_lits.push_back(c->blit()); @@ -981,13 +989,23 @@ namespace polysat { continue; // NOTE: in general, narrow may change the conflict. // But since we just backjumped, narrowing should not result in an additional conflict. + if (c->is_undef()) // TODO: see revert_decision for a note on this. + continue; c->narrow(*this); + if (is_conflict()) { + LOG_H1("Conflict during revert_bool_decision/narrow!"); + return; + } } m_conflict.reset(); clause* reason_cl = reason.get(); add_lemma_clause(std::move(reason)); propagate_bool(~lit, reason_cl); + if (is_conflict()) { + LOG_H1("Conflict during revert_bool_decision/propagate_bool!"); + return; + } decide_bool(*lemma); } @@ -1202,6 +1220,7 @@ namespace polysat { } void solver::collect_statistics(statistics& st) const { + st.update("polysat iterations", m_stats.m_num_iterations); st.update("polysat decisions", m_stats.m_num_decisions); st.update("polysat conflicts", m_stats.m_num_conflicts); st.update("polysat propagations", m_stats.m_num_propagations); diff --git a/src/math/polysat/solver.h b/src/math/polysat/solver.h index 11042f764..730fd3e27 100644 --- a/src/math/polysat/solver.h +++ b/src/math/polysat/solver.h @@ -35,6 +35,7 @@ namespace polysat { class solver { struct stats { + unsigned m_num_iterations; unsigned m_num_decisions; unsigned m_num_propagations; unsigned m_num_conflicts; diff --git a/src/math/polysat/ule_constraint.cpp b/src/math/polysat/ule_constraint.cpp index d73d73c57..d798b5287 100644 --- a/src/math/polysat/ule_constraint.cpp +++ b/src/math/polysat/ule_constraint.cpp @@ -28,13 +28,15 @@ namespace polysat { } void ule_constraint::narrow(solver& s) { - SASSERT(!is_undef()); + LOG_H3("Narrowing " << *this); LOG("Assignment: " << s.assignment()); auto p = lhs().subst_val(s.assignment()); LOG("Substituted LHS: " << lhs() << " := " << p); auto q = rhs().subst_val(s.assignment()); LOG("Substituted RHS: " << rhs() << " := " << q); + SASSERT(!is_undef()); + if (is_always_false(p, q)) { s.set_conflict(*this); return; @@ -95,8 +97,8 @@ namespace polysat { VERIFY(!is_undef()); if (is_positive()) return lhs.is_val() && rhs.is_val() && lhs.val() > rhs.val(); - else - return lhs.is_val() && rhs.is_val() && lhs.val() <= rhs.val(); + else + return (lhs.is_val() && rhs.is_val() && lhs.val() <= rhs.val()) || (lhs == rhs); } bool ule_constraint::is_always_false() { diff --git a/src/test/pdd.cpp b/src/test/pdd.cpp index 11ddd4f54..bc5ecbd5f 100644 --- a/src/test/pdd.cpp +++ b/src/test/pdd.cpp @@ -448,6 +448,18 @@ public : SASSERT((p + b*b*b).max_pow2_divisor() == 0); } + static void try_div() { + std::cout << "try_div\n"; + pdd_manager m(4, pdd_manager::mod2N_e, 256); + pdd const a = m.mk_var(0); + pdd const b = m.mk_var(1); + + pdd const p = 5*a + 15*a*b; + SASSERT_EQ(p.div(rational(5)), a + 3*a*b); + pdd res = a; + SASSERT(!p.try_div(rational(3), res)); + } + static void binary_resolve() { std::cout << "binary resolve\n"; pdd_manager m(4, pdd_manager::mod2N_e, 4); @@ -575,6 +587,7 @@ void tst_pdd() { dd::test::degree_of_variables(); dd::test::factor(); dd::test::max_pow2_divisor(); + dd::test::try_div(); dd::test::binary_resolve(); dd::test::pow(); dd::test::subst_val();