diff --git a/src/math/lp/lp_settings.h b/src/math/lp/lp_settings.h index 727bc3531..7ffffe5c5 100644 --- a/src/math/lp/lp_settings.h +++ b/src/math/lp/lp_settings.h @@ -115,6 +115,7 @@ struct statistics { unsigned m_hnf_cutter_calls; unsigned m_hnf_cuts; unsigned m_nla_calls; + unsigned m_nla_bounds; unsigned m_horner_calls; unsigned m_horner_conflicts; unsigned m_cross_nested_forms; @@ -144,7 +145,7 @@ struct statistics { st.update("arith-grobner-conflicts", m_grobner_conflicts); st.update("arith-offset-eqs", m_offset_eqs); st.update("arith-fixed-eqs", m_fixed_eqs); - + st.update("arith-nla-bounds", m_nla_bounds); } }; diff --git a/src/math/lp/nla_core.cpp b/src/math/lp/nla_core.cpp index 49691d6b8..1b34272df 100644 --- a/src/math/lp/nla_core.cpp +++ b/src/math/lp/nla_core.cpp @@ -809,6 +809,7 @@ void core::print_stats(std::ostream& out) { void core::clear() { m_lemma_vec->clear(); + m_literal_vec->clear(); } void core::init_search() { @@ -1501,12 +1502,28 @@ void core::check_bounded_divisions(vector& l_vec) { m_lemma_vec = &l_vec; m_divisions.check_bounded_divisions(); } +// looking for a free variable inside of a monic to split +void core::add_bounds() { + unsigned r = random(), sz = m_to_refine.size(); + for (unsigned k = 0; k < sz; k++) { + lpvar i = m_to_refine[(k + r) % sz]; + auto const& m = m_emons[i]; + for (lpvar j : m.vars()) { + if (!var_is_free(j)) continue; + // split the free variable (j <= 0, or j > 0), and return + m_literal_vec->push_back(ineq(j, lp::lconstraint_kind::EQ, rational::zero())); + ++lp_settings().stats().m_nla_bounds; + return; + } + } +} -lbool core::check(vector& l_vec) { +lbool core::check(vector& lits, vector& l_vec) { lp_settings().stats().m_nla_calls++; TRACE("nla_solver", tout << "calls = " << lp_settings().stats().m_nla_calls << "\n";); lra.get_rid_of_inf_eps(); m_lemma_vec = &l_vec; + m_literal_vec = &lits; if (!(lra.get_status() == lp::lp_status::OPTIMAL || lra.get_status() == lp::lp_status::FEASIBLE)) { TRACE("nla_solver", tout << "unknown because of the lra.m_status = " << lra.get_status() << "\n";); @@ -1516,43 +1533,47 @@ lbool core::check(vector& l_vec) { init_to_refine(); patch_monomials(); set_use_nra_model(false); - if (m_to_refine.empty()) { return l_true; } + if (m_to_refine.empty()) + return l_true; init_search(); lbool ret = l_undef; bool run_grobner = need_run_grobner(); bool run_horner = need_run_horner(); bool run_bounded_nlsat = should_run_bounded_nlsat(); + bool run_bounds = params().arith_nl_branching(); - if (l_vec.empty() && !done()) + auto no_effect = [&]() { return !done() && l_vec.empty() && lits.empty(); }; + + if (no_effect()) m_monomial_bounds(); if (l_vec.empty() && !done() && improve_bounds()) return l_false; - if (l_vec.empty() && !done() && run_horner) - m_horner.horner_lemmas(); + { + std::function check1 = [&]() { if (no_effect() && run_horner) m_horner.horner_lemmas(); }; + std::function check2 = [&]() { if (no_effect() && run_grobner) m_grobner(); }; + std::function check3 = [&]() { if (no_effect() && run_bounds) add_bounds(); }; - if (l_vec.empty() && !done() && run_grobner) - m_grobner(); - - if (l_vec.empty() && !done()) + std::pair> checks[] = + { {1, check1}, + {1, check2}, + {1, check3} }; + check_weighted(3, checks); + if (!l_vec.empty() || !lits.empty()) + return l_false; + } + + if (no_effect()) m_basics.basic_lemma(true); - if (l_vec.empty() && !done()) + if (no_effect()) m_basics.basic_lemma(false); - if (l_vec.empty() && !done()) + if (no_effect()) m_divisions.check(); -#if 0 - if (l_vec.empty() && !done() && !run_horner) - m_horner.horner_lemmas(); - - if (l_vec.empty() && !done() && !run_grobner) - m_grobner(); -#endif - if (!conflict_found() && !done() && run_bounded_nlsat) ret = bounded_nlsat(); @@ -1636,8 +1657,9 @@ bool core::no_lemmas_hold() const { } lbool core::test_check(vector& l) { + vector lits; lra.set_status(lp::lp_status::OPTIMAL); - return check(l); + return check(lits, l); } std::ostream& core::print_terms(std::ostream& out) const { diff --git a/src/math/lp/nla_core.h b/src/math/lp/nla_core.h index 0bd4a5814..e2a66c321 100644 --- a/src/math/lp/nla_core.h +++ b/src/math/lp/nla_core.h @@ -85,6 +85,7 @@ class core { smt_params_helper m_params; std::function m_relevant; vector * m_lemma_vec; + vector * m_literal_vec = nullptr; indexed_uint_set m_to_refine; tangents m_tangents; basics m_basics; @@ -110,7 +111,7 @@ class core { monic const* m_patched_monic = nullptr; void check_weighted(unsigned sz, std::pair>* checks); - + void add_bounds(); // try to improve bounds for variables in monomials. bool improve_bounds(); @@ -384,7 +385,7 @@ public: bool conflict_found() const; - lbool check(vector& l_vec); + lbool check(vector& ineqs, vector& l_vec); lbool check_power(lpvar r, lpvar x, lpvar y, vector& l_vec); void check_bounded_divisions(vector&); diff --git a/src/math/lp/nla_solver.cpp b/src/math/lp/nla_solver.cpp index 0e6efd526..ccc7b6073 100644 --- a/src/math/lp/nla_solver.cpp +++ b/src/math/lp/nla_solver.cpp @@ -42,8 +42,8 @@ namespace nla { bool solver::need_check() { return m_core->has_relevant_monomial(); } - lbool solver::check(vector& l) { - return m_core->check(l); + lbool solver::check(vector& lits, vector& lemmas) { + return m_core->check(lits, lemmas); } void solver::push(){ diff --git a/src/math/lp/nla_solver.h b/src/math/lp/nla_solver.h index d61b0593b..c1ad5f32a 100644 --- a/src/math/lp/nla_solver.h +++ b/src/math/lp/nla_solver.h @@ -36,7 +36,7 @@ namespace nla { void push(); void pop(unsigned scopes); bool need_check(); - lbool check(vector&); + lbool check(vector& lits, vector&); lbool check_power(lpvar r, lpvar x, lpvar y, vector&); bool is_monic_var(lpvar) const; bool influences_nl_var(lpvar) const; diff --git a/src/sat/smt/arith_solver.cpp b/src/sat/smt/arith_solver.cpp index bda3bebcd..6d3768176 100644 --- a/src/sat/smt/arith_solver.cpp +++ b/src/sat/smt/arith_solver.cpp @@ -1410,30 +1410,43 @@ namespace arith { m_lemma = l; //todo avoid the copy m_explanation = l.expl(); literal_vector core; - for (auto const& ineq : m_lemma.ineqs()) { - bool is_lower = true, pos = true, is_eq = false; - switch (ineq.cmp()) { - case lp::LE: is_lower = false; pos = false; break; - case lp::LT: is_lower = true; pos = true; break; - case lp::GE: is_lower = true; pos = false; break; - case lp::GT: is_lower = false; pos = true; break; - case lp::EQ: is_eq = true; pos = false; break; - case lp::NE: is_eq = true; pos = true; break; - default: UNREACHABLE(); - } - TRACE("arith", tout << "is_lower: " << is_lower << " pos " << pos << "\n";); - // TBD utility: lp::lar_term term = mk_term(ineq.m_poly); - // then term is used instead of ineq.m_term - sat::literal lit; - if (is_eq) - lit = mk_eq(ineq.term(), ineq.rs()); - else - lit = ctx.expr2literal(mk_bound(ineq.term(), ineq.rs(), is_lower)); - core.push_back(pos ? lit : ~lit); - } + for (auto const& ineq : m_lemma.ineqs()) + core.push_back(mk_ineq_literal(ineq)); set_conflict_or_lemma(hint_type::nla_h, core, false); } + void solver::assume_literals() { + for (auto const& ineq : m_nla_literals) { + auto lit = mk_ineq_literal(ineq); + ctx.mark_relevant(lit); + s().set_phase(lit); + } + } + + sat::literal solver::mk_ineq_literal(nla::ineq const& ineq) { + bool is_lower = true, pos = true, is_eq = false; + switch (ineq.cmp()) { + case lp::LE: is_lower = false; pos = false; break; + case lp::LT: is_lower = true; pos = true; break; + case lp::GE: is_lower = true; pos = false; break; + case lp::GT: is_lower = false; pos = true; break; + case lp::EQ: is_eq = true; pos = false; break; + case lp::NE: is_eq = true; pos = true; break; + default: UNREACHABLE(); + } + TRACE("arith", tout << "is_lower: " << is_lower << " pos " << pos << "\n";); + // TBD utility: lp::lar_term term = mk_term(ineq.m_poly); + // then term is used instead of ineq.m_term + sat::literal lit; + if (is_eq) + lit = mk_eq(ineq.term(), ineq.rs()); + else + lit = ctx.expr2literal(mk_bound(ineq.term(), ineq.rs(), is_lower)); + + return pos ? lit : ~lit; + } + + lbool solver::check_nla() { if (!m.inc()) { TRACE("arith", tout << "canceled\n";); @@ -1446,9 +1459,10 @@ namespace arith { return l_true; m_a1 = nullptr; m_a2 = nullptr; - lbool r = m_nla->check(m_nla_lemma_vector); + lbool r = m_nla->check(m_nla_literals, m_nla_lemma_vector); switch (r) { - case l_false: + case l_false: + assume_literals(); for (const nla::lemma& l : m_nla_lemma_vector) false_case_of_check_nla(l); break; diff --git a/src/sat/smt/arith_solver.h b/src/sat/smt/arith_solver.h index 1b6f58782..e23162393 100644 --- a/src/sat/smt/arith_solver.h +++ b/src/sat/smt/arith_solver.h @@ -249,6 +249,7 @@ namespace arith { // lemmas lp::explanation m_explanation; vector m_nla_lemma_vector; + vector m_nla_literals; literal_vector m_core, m_core2; vector m_coeffs; svector m_eqs; @@ -463,6 +464,8 @@ namespace arith { void set_evidence(lp::constraint_index idx); void assign(literal lit, literal_vector const& core, svector const& eqs, euf::th_proof_hint const* pma); + void assume_literals(); + sat::literal mk_ineq_literal(nla::ineq const& ineq); void false_case_of_check_nla(const nla::lemma& l); void dbg_finalize_model(model& mdl); diff --git a/src/smt/theory_lra.cpp b/src/smt/theory_lra.cpp index 08fa39864..9e9e2d730 100644 --- a/src/smt/theory_lra.cpp +++ b/src/smt/theory_lra.cpp @@ -1979,44 +1979,55 @@ public: } nla::lemma m_lemma; - + + literal mk_literal(nla::ineq const& ineq) { + bool is_lower = true, pos = true, is_eq = false; + switch (ineq.cmp()) { + case lp::LE: is_lower = false; pos = false; break; + case lp::LT: is_lower = true; pos = true; break; + case lp::GE: is_lower = true; pos = false; break; + case lp::GT: is_lower = false; pos = true; break; + case lp::EQ: is_eq = true; pos = false; break; + case lp::NE: is_eq = true; pos = true; break; + default: UNREACHABLE(); + } + TRACE("arith", tout << "is_lower: " << is_lower << " pos " << pos << "\n";); + app_ref atom(m); + // TBD utility: lp::lar_term term = mk_term(ineq.m_poly); + // then term is used instead of ineq.m_term + if (is_eq) + atom = mk_eq(ineq.term(), ineq.rs()); + else + // create term >= 0 (or term <= 0) + atom = mk_bound(ineq.term(), ineq.rs(), is_lower); + return literal(ctx().get_bool_var(atom), pos); + } + void false_case_of_check_nla(const nla::lemma & l) { m_lemma = l; //todo avoid the copy m_explanation = l.expl(); literal_vector core; for (auto const& ineq : m_lemma.ineqs()) { - bool is_lower = true, pos = true, is_eq = false; - switch (ineq.cmp()) { - case lp::LE: is_lower = false; pos = false; break; - case lp::LT: is_lower = true; pos = true; break; - case lp::GE: is_lower = true; pos = false; break; - case lp::GT: is_lower = false; pos = true; break; - case lp::EQ: is_eq = true; pos = false; break; - case lp::NE: is_eq = true; pos = true; break; - default: UNREACHABLE(); - } - TRACE("arith", tout << "is_lower: " << is_lower << " pos " << pos << "\n";); - app_ref atom(m); - // TBD utility: lp::lar_term term = mk_term(ineq.m_poly); - // then term is used instead of ineq.m_term - if (is_eq) { - atom = mk_eq(ineq.term(), ineq.rs()); - } - else { - // create term >= 0 (or term <= 0) - atom = mk_bound(ineq.term(), ineq.rs(), is_lower); - } - literal lit(ctx().get_bool_var(atom), pos); + auto lit = mk_literal(ineq); core.push_back(~lit); } set_conflict_or_lemma(core, false); } + + void assume_literal(nla::ineq const& i) { + auto lit = mk_literal(i); + ctx().mark_as_relevant(lit); + ctx().set_true_first_flag(lit.var()); + } final_check_status check_nla_continue() { m_a1 = nullptr; m_a2 = nullptr; - lbool r = m_nla->check(m_nla_lemma_vector); + lbool r = m_nla->check(m_nla_literals, m_nla_lemma_vector); + switch (r) { - case l_false: + case l_false: + for (const nla::ineq& i : m_nla_literals) + assume_literal(i); for (const nla::lemma & l : m_nla_lemma_vector) false_case_of_check_nla(l); return FC_CONTINUE; @@ -3173,6 +3184,7 @@ public: lp::explanation m_explanation; vector m_nla_lemma_vector; + vector m_nla_literals; literal_vector m_core; svector m_eqs; vector m_params;