diff --git a/src/math/lp/nla_core.h b/src/math/lp/nla_core.h index bf9252450..af8adcbc6 100644 --- a/src/math/lp/nla_core.h +++ b/src/math/lp/nla_core.h @@ -218,6 +218,7 @@ public: void add_idivision(lpvar q, lpvar x, lpvar y, lpvar r) { m_divisions.add_idivision(q, x, y, r); } void add_rdivision(lpvar q, lpvar x, lpvar y, lpvar r) { m_divisions.add_rdivision(q, x, y, r); } void add_bounded_division(lpvar q, lpvar x, lpvar y, lpvar r) { m_divisions.add_bounded_division(q, x, y, r); } + void add_mod_division(lpvar x, lpvar y, lpvar r) { m_divisions.add_mod_division(x, y, r); } void set_add_mul_def_hook(std::function const& f) { m_add_mul_def_hook = f; } lpvar add_mul_def(unsigned sz, lpvar const* vs) { SASSERT(m_add_mul_def_hook); lpvar v = m_add_mul_def_hook(sz, vs); add_monic(v, sz, vs); return v; } diff --git a/src/math/lp/nla_divisions.cpp b/src/math/lp/nla_divisions.cpp index 5b4501e4e..20e0a9b4c 100644 --- a/src/math/lp/nla_divisions.cpp +++ b/src/math/lp/nla_divisions.cpp @@ -41,6 +41,13 @@ namespace nla { m_core.trail().push(push_back_vector(m_bounded_divisions)); } + void divisions::add_mod_division(lpvar x, lpvar y, lpvar r) { + if (x == null_lpvar || y == null_lpvar || r == null_lpvar) + return; + m_mod_divisions.push_back({ x, y, r }); + m_core.trail().push(push_back_vector(m_mod_divisions)); + } + typedef lp::lar_term term; // y1 >= y2 > 0 & x1 <= x2 => x1/y1 <= x2/y2 @@ -205,16 +212,16 @@ namespace nla { } // mod(factor, p) = 0 => mod(factor * k, p) = 0 - // For each division (q, x, y, r) where x is a monic m = f1 * f2 * ... * fk, + // For each mod division (x, y, r) where x is a monic m = f1 * f2 * ... * fk, // if some factor fi has mod(fi, p) = 0 (fixed), then mod(x, p) = 0. void divisions::check_mod_mult() { core& c = m_core; - unsigned offset = c.random(), sz = m_bounded_divisions.size(); + unsigned offset = c.random(), sz = m_mod_divisions.size(); for (unsigned j = 0; j < sz; ++j) { unsigned i = (offset + j) % sz; - auto [q, x, y, r] = m_bounded_divisions[i]; - if (!c.is_relevant(q)) + auto [x, y, r] = m_mod_divisions[i]; + if (!c.is_relevant(r)) continue; if (c.var_is_fixed_to_zero(r)) continue; @@ -227,7 +234,7 @@ namespace nla { continue; auto const& m = c.emons()[x]; for (lpvar f : m.vars()) { - for (auto const& [q2, x2, y2, r2] : m_bounded_divisions) { + for (auto const& [x2, y2, r2] : m_mod_divisions) { if (x2 != f) continue; if (c.val(y2) != yv) diff --git a/src/math/lp/nla_divisions.h b/src/math/lp/nla_divisions.h index 96a50c05a..386ae7673 100644 --- a/src/math/lp/nla_divisions.h +++ b/src/math/lp/nla_divisions.h @@ -25,12 +25,15 @@ namespace nla { vector> m_idivisions; vector> m_rdivisions; vector> m_bounded_divisions; + // mod divisions: (x, y, r) where r = mod(x, y), used by check_mod_mult + vector> m_mod_divisions; public: divisions(core& c):m_core(c) {} void add_idivision(lpvar q, lpvar x, lpvar y, lpvar r); void add_rdivision(lpvar q, lpvar x, lpvar y, lpvar r); void add_bounded_division(lpvar q, lpvar x, lpvar y, lpvar r); + void add_mod_division(lpvar x, lpvar y, lpvar r); void check(); void check_bounded_divisions(); void check_mod_mult(); diff --git a/src/math/lp/nla_solver.cpp b/src/math/lp/nla_solver.cpp index 562143459..688196fed 100644 --- a/src/math/lp/nla_solver.cpp +++ b/src/math/lp/nla_solver.cpp @@ -32,6 +32,10 @@ namespace nla { m_core->add_bounded_division(q, x, y, r); } + void solver::add_mod_division(lpvar x, lpvar y, lpvar r) { + m_core->add_mod_division(x, y, r); + } + void solver::set_relevant(std::function& is_relevant) { m_core->set_relevant(is_relevant); } diff --git a/src/math/lp/nla_solver.h b/src/math/lp/nla_solver.h index 36d136d38..d53acd0c2 100644 --- a/src/math/lp/nla_solver.h +++ b/src/math/lp/nla_solver.h @@ -31,6 +31,7 @@ namespace nla { void add_idivision(lpvar q, lpvar x, lpvar y, lpvar r); void add_rdivision(lpvar q, lpvar x, lpvar y, lpvar r); void add_bounded_division(lpvar q, lpvar x, lpvar y, lpvar r); + void add_mod_division(lpvar x, lpvar y, lpvar r); void check_bounded_divisions(); void set_relevant(std::function& is_relevant); void updt_params(params_ref const& p); diff --git a/src/smt/theory_lra.cpp b/src/smt/theory_lra.cpp index 1fc06b079..06f9fd247 100644 --- a/src/smt/theory_lra.cpp +++ b/src/smt/theory_lra.cpp @@ -175,6 +175,8 @@ class theory_lra::imp { // non-linear arithmetic scoped_ptr m_nla; + // pending mod divisions to register when NLA is created + svector> m_pending_mod_divisions; // integer arithmetic scoped_ptr m_lia; @@ -269,6 +271,9 @@ class theory_lra::imp { m_nla->set_relevant(is_relevant); m_nla->updt_params(ctx().get_params()); m_nla->get_core().set_add_mul_def_hook([&](unsigned sz, lpvar const* vs) { return add_mul_def(sz, vs); }); + for (auto const& [x, y, rv] : m_pending_mod_divisions) + m_nla->add_mod_division(register_theory_var_in_lar_solver(x), register_theory_var_in_lar_solver(y), register_theory_var_in_lar_solver(rv)); + m_pending_mod_divisions.reset(); } } @@ -473,18 +478,16 @@ class theory_lra::imp { if (!a.is_numeral(n2, r) || r.is_zero()) found_underspecified(n); if (!ctx().relevancy()) mk_idiv_mod_axioms(n1, n2); if (a.is_numeral(n2) && !r.is_zero()) { - ensure_nla(); - app_ref div(a.mk_idiv(n1, n2), m); - ctx().internalize(div, false); - internalize_term(to_app(div)); internalize_term(to_app(n1)); internalize_term(to_app(n2)); internalize_term(t); - theory_var q = mk_var(div); theory_var x = mk_var(n1); theory_var y = mk_var(n2); theory_var rv = mk_var(n); - m_nla->add_bounded_division(register_theory_var_in_lar_solver(q), register_theory_var_in_lar_solver(x), register_theory_var_in_lar_solver(y), register_theory_var_in_lar_solver(rv)); + if (m_nla) + m_nla->add_mod_division(register_theory_var_in_lar_solver(x), register_theory_var_in_lar_solver(y), register_theory_var_in_lar_solver(rv)); + else + m_pending_mod_divisions.push_back({x, y, rv}); } } else if (a.is_rem(n, n1, n2)) { diff --git a/src/test/api.cpp b/src/test/api.cpp index 9413a5ada..a7e7f329b 100644 --- a/src/test/api.cpp +++ b/src/test/api.cpp @@ -274,8 +274,8 @@ void test_max_reg() { } #endif - std::cout << "BNH: " << num_sat << "/6 optimizations returned sat" << std::endl; - ENSURE(num_sat == 6); + std::cout << "BNH: " << num_sat << "/2 optimizations returned sat" << std::endl; + ENSURE(num_sat == 2); Z3_del_context(ctx); std::cout << "BNH optimization test done" << std::endl; }