diff --git a/src/math/lp/lp_settings.h b/src/math/lp/lp_settings.h index 727bc3531..7ffffe5c5 100644 --- a/src/math/lp/lp_settings.h +++ b/src/math/lp/lp_settings.h @@ -115,6 +115,7 @@ struct statistics { unsigned m_hnf_cutter_calls; unsigned m_hnf_cuts; unsigned m_nla_calls; + unsigned m_nla_bounds; unsigned m_horner_calls; unsigned m_horner_conflicts; unsigned m_cross_nested_forms; @@ -144,7 +145,7 @@ struct statistics { st.update("arith-grobner-conflicts", m_grobner_conflicts); st.update("arith-offset-eqs", m_offset_eqs); st.update("arith-fixed-eqs", m_fixed_eqs); - + st.update("arith-nla-bounds", m_nla_bounds); } }; diff --git a/src/math/lp/nla_core.cpp b/src/math/lp/nla_core.cpp index 54f955760..d06815410 100644 --- a/src/math/lp/nla_core.cpp +++ b/src/math/lp/nla_core.cpp @@ -811,6 +811,7 @@ void core::print_stats(std::ostream& out) { void core::clear() { m_lemma_vec->clear(); + m_literal_vec->clear(); } void core::init_search() { @@ -1504,11 +1505,62 @@ void core::check_bounded_divisions(vector& l_vec) { m_divisions.check_bounded_divisions(); } -lbool core::check(vector& l_vec) { +bool core::can_add_bound(unsigned j, u_map& bounds) { + unsigned count = 1; + if (bounds.find(j, count)) { + if (count >= 2) + return false; + ++count; + } + bounds.insert(j, count); + struct decrement : public trail { + u_map& bounds; + unsigned j; + decrement(u_map& bounds, unsigned j): + bounds(bounds), + j(j) + {} + void undo() override { + --bounds[j]; + } + }; + trail().push(decrement(bounds, j)); + return true; +} + +void core::add_bounds() { + unsigned r = random(), sz = m_to_refine.size(); + for (unsigned k = 0; k < sz; k++) { + lpvar i = m_to_refine[(k + r) % sz]; + auto const& m = m_emons[i]; + for (lpvar j : m.vars()) { + //m_lar_solver.print_column_info(j, verbose_stream() << "check variable " << j << " ") << "\n"; + if (var_is_free(j)) + m_literal_vec->push_back(ineq(j, lp::lconstraint_kind::EQ, rational::zero())); +#if 0 + else if (has_lower_bound(j) && can_add_bound(j, m_lower_bounds_added)) { + m_literal_vec->push_back(ineq(j, lp::lconstraint_kind::LE, get_lower_bound(j))); + std::cout << "called lower\n"; + } + else if (has_upper_bound(j) && can_add_bound(j, m_upper_bounds_added)) { + m_literal_vec->push_back(ineq(j, lp::lconstraint_kind::GE, get_upper_bound(j))); + std::cout << "called upper\n"; + } +#endif + else + continue; + ++lp_settings().stats().m_nla_bounds; + return; + } + } +} + +lbool core::check(vector& lits, vector& l_vec) { lp_settings().stats().m_nla_calls++; TRACE("nla_solver", tout << "calls = " << lp_settings().stats().m_nla_calls << "\n";); m_lar_solver.get_rid_of_inf_eps(); m_lemma_vec = &l_vec; + m_literal_vec = &lits; if (!(m_lar_solver.get_status() == lp::lp_status::OPTIMAL || m_lar_solver.get_status() == lp::lp_status::FEASIBLE)) { TRACE("nla_solver", tout << "unknown because of the m_lar_solver.m_status = " << m_lar_solver.get_status() << "\n";); @@ -1518,40 +1570,44 @@ lbool core::check(vector& l_vec) { init_to_refine(); patch_monomials(); set_use_nra_model(false); - if (m_to_refine.empty()) { return l_true; } + if (m_to_refine.empty()) + return l_true; init_search(); lbool ret = l_undef; bool run_grobner = need_run_grobner(); bool run_horner = need_run_horner(); bool run_bounded_nlsat = should_run_bounded_nlsat(); + bool run_bounds = params().arith_nl_branching(); - if (l_vec.empty() && !done()) + auto no_effect = [&]() { return !done() && l_vec.empty() && lits.empty(); }; + + if (no_effect()) m_monomial_bounds(); - if (l_vec.empty() && !done() && run_horner) - m_horner.horner_lemmas(); + { + std::function check1 = [&]() { if (no_effect() && run_horner) m_horner.horner_lemmas(); }; + std::function check2 = [&]() { if (no_effect() && run_grobner) m_grobner(); }; + std::function check3 = [&]() { if (no_effect() && run_bounds) add_bounds(); }; - if (l_vec.empty() && !done() && run_grobner) - m_grobner(); - - if (l_vec.empty() && !done()) + std::pair> checks[] = + { {1, check1}, + {1, check2}, + {1, check3} }; + check_weighted(3, checks); + if (!l_vec.empty() || !lits.empty()) + return l_false; + } + + if (no_effect()) m_basics.basic_lemma(true); - if (l_vec.empty() && !done()) + if (no_effect()) m_basics.basic_lemma(false); - if (l_vec.empty() && !done()) + if (no_effect()) m_divisions.check(); -#if 0 - if (l_vec.empty() && !done() && !run_horner) - m_horner.horner_lemmas(); - - if (l_vec.empty() && !done() && !run_grobner) - m_grobner(); -#endif - if (!conflict_found() && !done() && run_bounded_nlsat) ret = bounded_nlsat(); @@ -1635,8 +1691,9 @@ bool core::no_lemmas_hold() const { } lbool core::test_check(vector& l) { + vector lits; m_lar_solver.set_status(lp::lp_status::OPTIMAL); - return check(l); + return check(lits, l); } std::ostream& core::print_terms(std::ostream& out) const { diff --git a/src/math/lp/nla_core.h b/src/math/lp/nla_core.h index 78f46bb41..530e08c8d 100644 --- a/src/math/lp/nla_core.h +++ b/src/math/lp/nla_core.h @@ -84,6 +84,7 @@ class core { smt_params_helper m_params; std::function m_relevant; vector * m_lemma_vec; + vector * m_literal_vec = nullptr; indexed_uint_set m_to_refine; tangents m_tangents; basics m_basics; @@ -110,6 +111,10 @@ class core { void check_weighted(unsigned sz, std::pair>* checks); + u_map m_lower_bounds_added, m_upper_bounds_added; + bool can_add_bound(unsigned j, u_map& bounds); + void add_bounds(); + public: // constructor core(lp::lar_solver& s, params_ref const& p, reslimit&); @@ -380,7 +385,7 @@ public: bool conflict_found() const; - lbool check(vector& l_vec); + lbool check(vector& ineqs, vector& l_vec); lbool check_power(lpvar r, lpvar x, lpvar y, vector& l_vec); void check_bounded_divisions(vector&); diff --git a/src/math/lp/nla_solver.cpp b/src/math/lp/nla_solver.cpp index 0e6efd526..ccc7b6073 100644 --- a/src/math/lp/nla_solver.cpp +++ b/src/math/lp/nla_solver.cpp @@ -42,8 +42,8 @@ namespace nla { bool solver::need_check() { return m_core->has_relevant_monomial(); } - lbool solver::check(vector& l) { - return m_core->check(l); + lbool solver::check(vector& lits, vector& lemmas) { + return m_core->check(lits, lemmas); } void solver::push(){ diff --git a/src/math/lp/nla_solver.h b/src/math/lp/nla_solver.h index d61b0593b..c1ad5f32a 100644 --- a/src/math/lp/nla_solver.h +++ b/src/math/lp/nla_solver.h @@ -36,7 +36,7 @@ namespace nla { void push(); void pop(unsigned scopes); bool need_check(); - lbool check(vector&); + lbool check(vector& lits, vector&); lbool check_power(lpvar r, lpvar x, lpvar y, vector&); bool is_monic_var(lpvar) const; bool influences_nl_var(lpvar) const; diff --git a/src/sat/smt/arith_solver.cpp b/src/sat/smt/arith_solver.cpp index 0e97c3503..77f10c000 100644 --- a/src/sat/smt/arith_solver.cpp +++ b/src/sat/smt/arith_solver.cpp @@ -1406,30 +1406,43 @@ namespace arith { m_lemma = l; //todo avoid the copy m_explanation = l.expl(); literal_vector core; - for (auto const& ineq : m_lemma.ineqs()) { - bool is_lower = true, pos = true, is_eq = false; - switch (ineq.cmp()) { - case lp::LE: is_lower = false; pos = false; break; - case lp::LT: is_lower = true; pos = true; break; - case lp::GE: is_lower = true; pos = false; break; - case lp::GT: is_lower = false; pos = true; break; - case lp::EQ: is_eq = true; pos = false; break; - case lp::NE: is_eq = true; pos = true; break; - default: UNREACHABLE(); - } - TRACE("arith", tout << "is_lower: " << is_lower << " pos " << pos << "\n";); - // TBD utility: lp::lar_term term = mk_term(ineq.m_poly); - // then term is used instead of ineq.m_term - sat::literal lit; - if (is_eq) - lit = mk_eq(ineq.term(), ineq.rs()); - else - lit = ctx.expr2literal(mk_bound(ineq.term(), ineq.rs(), is_lower)); - core.push_back(pos ? lit : ~lit); - } + for (auto const& ineq : m_lemma.ineqs()) + core.push_back(mk_ineq_literal(ineq)); set_conflict_or_lemma(hint_type::nla_h, core, false); } + void solver::assume_literals() { + for (auto const& ineq : m_nla_literals) { + auto lit = mk_ineq_literal(ineq); + ctx.mark_relevant(lit); + s().set_phase(lit); + } + } + + sat::literal solver::mk_ineq_literal(nla::ineq const& ineq) { + bool is_lower = true, pos = true, is_eq = false; + switch (ineq.cmp()) { + case lp::LE: is_lower = false; pos = false; break; + case lp::LT: is_lower = true; pos = true; break; + case lp::GE: is_lower = true; pos = false; break; + case lp::GT: is_lower = false; pos = true; break; + case lp::EQ: is_eq = true; pos = false; break; + case lp::NE: is_eq = true; pos = true; break; + default: UNREACHABLE(); + } + TRACE("arith", tout << "is_lower: " << is_lower << " pos " << pos << "\n";); + // TBD utility: lp::lar_term term = mk_term(ineq.m_poly); + // then term is used instead of ineq.m_term + sat::literal lit; + if (is_eq) + lit = mk_eq(ineq.term(), ineq.rs()); + else + lit = ctx.expr2literal(mk_bound(ineq.term(), ineq.rs(), is_lower)); + + return pos ? lit : ~lit; + } + + lbool solver::check_nla() { if (!m.inc()) { TRACE("arith", tout << "canceled\n";); @@ -1442,9 +1455,10 @@ namespace arith { return l_true; m_a1 = nullptr; m_a2 = nullptr; - lbool r = m_nla->check(m_nla_lemma_vector); + lbool r = m_nla->check(m_nla_literals, m_nla_lemma_vector); switch (r) { - case l_false: + case l_false: + assume_literals(); for (const nla::lemma& l : m_nla_lemma_vector) false_case_of_check_nla(l); break; diff --git a/src/sat/smt/arith_solver.h b/src/sat/smt/arith_solver.h index 20252ede9..87dfb1e57 100644 --- a/src/sat/smt/arith_solver.h +++ b/src/sat/smt/arith_solver.h @@ -249,6 +249,7 @@ namespace arith { // lemmas lp::explanation m_explanation; vector m_nla_lemma_vector; + vector m_nla_literals; literal_vector m_core, m_core2; vector m_coeffs; svector m_eqs; @@ -463,6 +464,8 @@ namespace arith { void set_evidence(lp::constraint_index idx); void assign(literal lit, literal_vector const& core, svector const& eqs, euf::th_proof_hint const* pma); + void assume_literals(); + sat::literal mk_ineq_literal(nla::ineq const& ineq); void false_case_of_check_nla(const nla::lemma& l); void dbg_finalize_model(model& mdl); diff --git a/src/smt/theory_lra.cpp b/src/smt/theory_lra.cpp index 6dc14fb13..cf4262422 100644 --- a/src/smt/theory_lra.cpp +++ b/src/smt/theory_lra.cpp @@ -1979,44 +1979,55 @@ public: } nla::lemma m_lemma; - + + literal mk_literal(nla::ineq const& ineq) { + bool is_lower = true, pos = true, is_eq = false; + switch (ineq.cmp()) { + case lp::LE: is_lower = false; pos = false; break; + case lp::LT: is_lower = true; pos = true; break; + case lp::GE: is_lower = true; pos = false; break; + case lp::GT: is_lower = false; pos = true; break; + case lp::EQ: is_eq = true; pos = false; break; + case lp::NE: is_eq = true; pos = true; break; + default: UNREACHABLE(); + } + TRACE("arith", tout << "is_lower: " << is_lower << " pos " << pos << "\n";); + app_ref atom(m); + // TBD utility: lp::lar_term term = mk_term(ineq.m_poly); + // then term is used instead of ineq.m_term + if (is_eq) + atom = mk_eq(ineq.term(), ineq.rs()); + else + // create term >= 0 (or term <= 0) + atom = mk_bound(ineq.term(), ineq.rs(), is_lower); + return literal(ctx().get_bool_var(atom), pos); + } + void false_case_of_check_nla(const nla::lemma & l) { m_lemma = l; //todo avoid the copy m_explanation = l.expl(); literal_vector core; for (auto const& ineq : m_lemma.ineqs()) { - bool is_lower = true, pos = true, is_eq = false; - switch (ineq.cmp()) { - case lp::LE: is_lower = false; pos = false; break; - case lp::LT: is_lower = true; pos = true; break; - case lp::GE: is_lower = true; pos = false; break; - case lp::GT: is_lower = false; pos = true; break; - case lp::EQ: is_eq = true; pos = false; break; - case lp::NE: is_eq = true; pos = true; break; - default: UNREACHABLE(); - } - TRACE("arith", tout << "is_lower: " << is_lower << " pos " << pos << "\n";); - app_ref atom(m); - // TBD utility: lp::lar_term term = mk_term(ineq.m_poly); - // then term is used instead of ineq.m_term - if (is_eq) { - atom = mk_eq(ineq.term(), ineq.rs()); - } - else { - // create term >= 0 (or term <= 0) - atom = mk_bound(ineq.term(), ineq.rs(), is_lower); - } - literal lit(ctx().get_bool_var(atom), pos); + auto lit = mk_literal(ineq); core.push_back(~lit); } set_conflict_or_lemma(core, false); } + + void assume_literal(nla::ineq const& i) { + auto lit = mk_literal(i); + ctx().mark_as_relevant(lit); + ctx().set_true_first_flag(lit.var()); + } final_check_status check_nla_continue() { m_a1 = nullptr; m_a2 = nullptr; - lbool r = m_nla->check(m_nla_lemma_vector); + lbool r = m_nla->check(m_nla_literals, m_nla_lemma_vector); + switch (r) { - case l_false: + case l_false: + for (const nla::ineq& i : m_nla_literals) + assume_literal(i); for (const nla::lemma & l : m_nla_lemma_vector) false_case_of_check_nla(l); return FC_CONTINUE; @@ -3170,6 +3181,7 @@ public: lp::explanation m_explanation; vector m_nla_lemma_vector; + vector m_nla_literals; literal_vector m_core; svector m_eqs; vector m_params;