From 49a071988cec7aae927acd2084f7da79b89d70a5 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Wed, 1 Nov 2023 03:52:20 -0700 Subject: [PATCH] remove temporary algebraic numbers from upper layers, move to owner module --- src/math/lp/nla_solver.cpp | 11 ++++++ src/math/lp/nla_solver.h | 2 ++ src/math/lp/nra_solver.cpp | 68 ++++++++++++++++++++++-------------- src/math/lp/nra_solver.h | 5 +++ src/sat/smt/arith_solver.cpp | 17 +++------ src/sat/smt/arith_solver.h | 1 - src/smt/theory_lra.cpp | 23 ++++-------- 7 files changed, 70 insertions(+), 57 deletions(-) diff --git a/src/math/lp/nla_solver.cpp b/src/math/lp/nla_solver.cpp index 4b501a39e..5ed9b4538 100644 --- a/src/math/lp/nla_solver.cpp +++ b/src/math/lp/nla_solver.cpp @@ -88,6 +88,17 @@ namespace nla { return m_core->m_nra.value(v); } + scoped_anum& solver::tmp1() { + SASSERT(use_nra_model()); + return m_core->m_nra.tmp1(); + } + + scoped_anum& solver::tmp2() { + SASSERT(use_nra_model()); + return m_core->m_nra.tmp1(); + } + + // ensure r = x^y, add abstraction/refinement lemmas lbool solver::check_power(lpvar r, lpvar x, lpvar y) { return m_core->check_power(r, x, y); diff --git a/src/math/lp/nla_solver.h b/src/math/lp/nla_solver.h index 7da05c3e4..1fbafdf6b 100644 --- a/src/math/lp/nla_solver.h +++ b/src/math/lp/nla_solver.h @@ -47,6 +47,8 @@ namespace nla { core& get_core(); nlsat::anum_manager& am(); nlsat::anum const& am_value(lp::var_index v) const; + scoped_anum& tmp1(); + scoped_anum& tmp2(); vector const& lemmas() const; vector const& literals() const; vector const& fixed_equalities() const; diff --git a/src/math/lp/nra_solver.cpp b/src/math/lp/nra_solver.cpp index 74d6f7187..5661f2e89 100644 --- a/src/math/lp/nra_solver.cpp +++ b/src/math/lp/nra_solver.cpp @@ -27,6 +27,7 @@ struct solver::imp { indexed_uint_set m_term_set; scoped_ptr m_nlsat; scoped_ptr m_values; // values provided by LRA solver + scoped_ptr m_tmp1, m_tmp2; nla::core& m_nla_core; imp(lp::lar_solver& s, reslimit& lim, params_ref const& p, nla::core& nla_core): @@ -102,6 +103,15 @@ struct solver::imp { } } + void reset() { + m_values = nullptr; + m_tmp1 = nullptr; m_tmp2 = nullptr; + m_nlsat = alloc(nlsat::solver, m_limit, m_params, false); + m_values = alloc(scoped_anum_vector, am()); + m_term_set.reset(); + m_lp2nl.reset(); + } + /** \brief one-shot nlsat check. A one shot checker is the least functionality that can @@ -115,11 +125,7 @@ struct solver::imp { */ lbool check() { SASSERT(need_check()); - m_values = nullptr; - m_nlsat = alloc(nlsat::solver, m_limit, m_params, false); - m_values = alloc(scoped_anum_vector, am()); - m_term_set.reset(); - m_lp2nl.reset(); + reset(); vector core; init_cone_of_influence(); @@ -316,28 +322,24 @@ struct solver::imp { } lbool check(dd::solver::equation_vector const& eqs) { - m_values = nullptr; - m_nlsat = alloc(nlsat::solver, m_limit, m_params, false); - m_values = alloc(scoped_anum_vector, am()); - m_lp2nl.reset(); - m_term_set.reset(); + reset(); for (auto const& eq : eqs) - add_eq(*eq); + add_eq(*eq); for (auto const& m : m_nla_core.emons()) - if (any_of(m.vars(), [&](lp::lpvar v) { return m_lp2nl.contains(v); })) - add_monic_eq_bound(m); + if (any_of(m.vars(), [&](lp::lpvar v) { return m_lp2nl.contains(v); })) + add_monic_eq_bound(m); for (unsigned i : m_term_set) - add_term(i); + add_term(i); for (auto const& [v, w] : m_lp2nl) { - if (lra.column_has_lower_bound(v)) - add_lb(lra.get_lower_bound(v), w, lra.get_column_lower_bound_witness(v)); - if (lra.column_has_upper_bound(v)) - add_ub(lra.get_upper_bound(v), w, lra.get_column_upper_bound_witness(v)); + if (lra.column_has_lower_bound(v)) + add_lb(lra.get_lower_bound(v), w, lra.get_column_lower_bound_witness(v)); + if (lra.column_has_upper_bound(v)) + add_ub(lra.get_upper_bound(v), w, lra.get_column_upper_bound_witness(v)); } - + lbool r = l_undef; try { - r = m_nlsat->check(); + r = m_nlsat->check(); } catch (z3_exception&) { if (m_limit.is_canceled()) { @@ -347,7 +349,7 @@ struct solver::imp { throw; } } - + switch (r) { case l_true: m_nla_core.set_use_nra_model(true); @@ -380,11 +382,7 @@ struct solver::imp { } lbool check(vector const& eqs) { - m_values = nullptr; - m_nlsat = alloc(nlsat::solver, m_limit, m_params, false); - m_values = alloc(scoped_anum_vector, am()); - m_lp2nl.reset(); - m_term_set.reset(); + reset(); for (auto const& eq : eqs) add_eq(eq); for (auto const& m : m_nla_core.emons()) @@ -562,6 +560,19 @@ struct solver::imp { return m_nlsat->am(); } + scoped_anum& tmp1() { + if (!m_tmp1) + m_tmp1 = alloc(scoped_anum, am()); + return *m_tmp1; + } + + scoped_anum& tmp2() { + if (!m_tmp2) + m_tmp2 = alloc(scoped_anum, am()); + return *m_tmp2; + } + + void updt_params(params_ref& p) { m_params.append(p); } @@ -616,6 +627,11 @@ nlsat::anum_manager& solver::am() { return m_imp->am(); } +scoped_anum& solver::tmp1() { return m_imp->tmp1(); } + +scoped_anum& solver::tmp2() { return m_imp->tmp2(); } + + void solver::updt_params(params_ref& p) { m_imp->updt_params(p); } diff --git a/src/math/lp/nra_solver.h b/src/math/lp/nra_solver.h index db1311dad..747b4cee3 100644 --- a/src/math/lp/nra_solver.h +++ b/src/math/lp/nra_solver.h @@ -6,6 +6,7 @@ #pragma once #include "util/vector.h" #include "math/lp/lp_settings.h" +#include "math/polynomial/algebraic_numbers.h" #include "util/rlimit.h" #include "util/params.h" #include "nlsat/nlsat_solver.h" @@ -58,6 +59,10 @@ namespace nra { nlsat::anum_manager& am(); + scoped_anum& tmp1(); + + scoped_anum& tmp2(); + void updt_params(params_ref& p); /* diff --git a/src/sat/smt/arith_solver.cpp b/src/sat/smt/arith_solver.cpp index 74cce593d..cb2ac53ba 100644 --- a/src/sat/smt/arith_solver.cpp +++ b/src/sat/smt/arith_solver.cpp @@ -619,11 +619,11 @@ namespace arith { value = n->get_root()->get_expr(); } else if (use_nra_model() && lp().external_to_local(v) != lp::null_lpvar) { - anum const& an = nl_value(v, *m_a1); + anum const& an = nl_value(v, m_nla->tmp1()); if (a.is_int(o) && !m_nla->am().is_int(an)) value = a.mk_numeral(rational::zero(), a.is_int(o)); else - value = a.mk_numeral(m_nla->am(), nl_value(v, *m_a1), a.is_int(o)); + value = a.mk_numeral(m_nla->am(), nl_value(v, m_nla->tmp1()), a.is_int(o)); } else if (v != euf::null_theory_var) { rational r = get_value(v); @@ -961,19 +961,12 @@ namespace arith { } bool solver::use_nra_model() { - if (m_nla && m_nla->use_nra_model()) { - if (!m_a1) { - m_a1 = alloc(scoped_anum, m_nla->am()); - m_a2 = alloc(scoped_anum, m_nla->am()); - } - return true; - } - return false; + return m_nla && m_nla->use_nra_model(); } bool solver::is_eq(theory_var v1, theory_var v2) { if (use_nra_model()) { - return m_nla->am().eq(nl_value(v1, *m_a1), nl_value(v2, *m_a2)); + return m_nla->am().eq(nl_value(v1, m_nla->tmp1()), nl_value(v2, m_nla->tmp2())); } else { return get_ivalue(v1) == get_ivalue(v2); @@ -1471,7 +1464,6 @@ namespace arith { if (!m_nla->need_check()) return l_true; - m_a1 = nullptr; m_a2 = nullptr; lbool r = m_nla->check(); switch (r) { case l_false: @@ -1518,7 +1510,6 @@ namespace arith { void solver::propagate_nla() { if (m_nla) { - m_a1 = nullptr; m_a2 = nullptr; m_nla->propagate(); add_lemmas(); lp().collect_more_rows_for_lp_propagation(); diff --git a/src/sat/smt/arith_solver.h b/src/sat/smt/arith_solver.h index f3e8d1407..ddaaa6164 100644 --- a/src/sat/smt/arith_solver.h +++ b/src/sat/smt/arith_solver.h @@ -233,7 +233,6 @@ namespace arith { // non-linear arithmetic scoped_ptr m_nla; - scoped_ptr m_a1, m_a2; // integer arithmetic scoped_ptr m_lia; diff --git a/src/smt/theory_lra.cpp b/src/smt/theory_lra.cpp index d6c14311e..936efc459 100644 --- a/src/smt/theory_lra.cpp +++ b/src/smt/theory_lra.cpp @@ -174,7 +174,6 @@ class theory_lra::imp { // non-linear arithmetic scoped_ptr m_nla; - mutable scoped_ptr m_a1, m_a2; // integer arithmetic scoped_ptr m_lia; @@ -192,14 +191,7 @@ class theory_lra::imp { }; bool use_nra_model() const { - if (m_nla && m_nla->use_nra_model()) { - if (!m_a1) { - m_a1 = alloc(scoped_anum, m_nla->am()); - m_a2 = alloc(scoped_anum, m_nla->am()); - } - return true; - } - return false; + return m_nla && m_nla->use_nra_model(); } struct var_value_hash { @@ -1604,7 +1596,7 @@ public: bool is_eq(theory_var v1, theory_var v2) { if (use_nra_model()) - return m_nla->am().eq(nl_value(v1, *m_a1), nl_value(v2, *m_a2)); + return m_nla->am().eq(nl_value(v1, m_nla->tmp1()), nl_value(v2, m_nla->tmp2())); else return get_ivalue(v1) == get_ivalue(v2); } @@ -2038,7 +2030,6 @@ public: } final_check_status check_nla_continue() { - m_a1 = nullptr; m_a2 = nullptr; lbool r = m_nla->check(); switch (r) { case l_false: @@ -2178,8 +2169,6 @@ public: void propagate_nla() { if (m_nla) { - m_a1 = nullptr; - m_a2 = nullptr; m_nla->propagate(); add_lemmas(); lp().collect_more_rows_for_lp_propagation(); @@ -3387,7 +3376,7 @@ public: } nlsat::anum const& nl_value(theory_var v, scoped_anum& r) const { - SASSERT(m_nla && m_nla->use_nra_model()); + SASSERT(use_nra_model()); auto t = get_tv(v); if (t.is_term()) { @@ -3431,11 +3420,11 @@ public: theory_var v = n->get_th_var(get_id()); expr* o = n->get_expr(); if (use_nra_model() && lp().external_to_local(v) != lp::null_lpvar) { - anum const& an = nl_value(v, *m_a1); + anum const& an = nl_value(v, m_nla->tmp1()); if (a.is_int(o) && !m_nla->am().is_int(an)) { return alloc(expr_wrapper_proc, a.mk_numeral(rational::zero(), a.is_int(o))); } - return alloc(expr_wrapper_proc, a.mk_numeral(m_nla->am(), nl_value(v, *m_a1), a.is_int(o))); + return alloc(expr_wrapper_proc, a.mk_numeral(m_nla->am(), nl_value(v, m_nla->tmp1()), a.is_int(o))); } else { rational r = get_value(v); @@ -3818,7 +3807,7 @@ public: if (!ctx().is_relevant(get_enode(v))) out << "irr: "; out << "v" << v << " "; if (t.is_null()) out << "null"; else out << (t.is_term() ? "t":"j") << vi; - if (use_nra_model() && is_registered_var(v)) m_nla->am().display(out << " = ", nl_value(v, *m_a1)); + if (use_nra_model() && is_registered_var(v)) m_nla->am().display(out << " = ", nl_value(v, m_nla->tmp1())); else if (can_get_value(v)) out << " = " << get_value(v); if (is_int(v)) out << ", int"; if (ctx().is_shared(get_enode(v))) out << ", shared";