diff --git a/src/sat/smt/arith_axioms.cpp b/src/sat/smt/arith_axioms.cpp index 997de9ae2..c037e56cc 100644 --- a/src/sat/smt/arith_axioms.cpp +++ b/src/sat/smt/arith_axioms.cpp @@ -399,7 +399,7 @@ namespace arith { theory_var v = mk_evar(n); theory_var v1 = mk_evar(p); - if (!can_get_ivalue(v1)) + if (!is_registered_var(v1)) continue; lp::impq r1 = get_ivalue(v1); rational r2; @@ -419,7 +419,7 @@ namespace arith { TRACE("arith", tout << "unbounded " << expr_ref(n, m) << "\n";); continue; } - if (!can_get_ivalue(v)) + if (!is_registered_var(v)) continue; lp::impq val_v = get_ivalue(v); if (val_v.y.is_zero() && val_v.x == div(r1.x, r2)) continue; diff --git a/src/sat/smt/arith_diagnostics.cpp b/src/sat/smt/arith_diagnostics.cpp index 026568e08..eda7bc744 100644 --- a/src/sat/smt/arith_diagnostics.cpp +++ b/src/sat/smt/arith_diagnostics.cpp @@ -52,13 +52,16 @@ namespace arith { out << "null"; else out << (t.is_term() ? "t" : "j") << vi; - if (m_nla && m_nla->use_nra_model() && can_get_ivalue(v)) { + if (m_nla && m_nla->use_nra_model() && is_registered_var(v)) { scoped_anum an(m_nla->am()); m_nla->am().display(out << " = ", nl_value(v, an)); } - else if (can_get_value(v)) out << " = " << get_value(v); - if (is_int(v)) out << ", int"; - if (ctx.is_shared(var2enode(v))) out << ", shared"; + else if (m_model_is_initialized && is_registered_var(v)) + out << " = " << get_value(v); + if (is_int(v)) + out << ", int"; + if (ctx.is_shared(var2enode(v))) + out << ", shared"; } out << " := " << mk_bounded_pp(var2expr(v), m) << "\n"; } diff --git a/src/sat/smt/arith_solver.cpp b/src/sat/smt/arith_solver.cpp index 712e7d719..36f547eb2 100644 --- a/src/sat/smt/arith_solver.cpp +++ b/src/sat/smt/arith_solver.cpp @@ -30,7 +30,6 @@ namespace arith { m_bound_terms(m), m_bound_predicate(m) { - reset_variable_values(); m_solver = alloc(lp::lar_solver); smt_params_helper lpar(ctx.s().params()); @@ -751,18 +750,15 @@ namespace arith { bound_prop_mode::BP_NONE; } - void solver::init_variable_values() { - reset_variable_values(); + void solver::init_model() { if (m.inc() && m_solver.get() && get_num_vars() > 0) { TRACE("arith", display(tout << "update variable values\n");); - lp().get_model(m_variable_values); + ctx.push(value_trail(m_model_is_initialized)); + m_model_is_initialized = true; + lp().init_model(); } } - void solver::reset_variable_values() { - m_variable_values.clear(); - } - lbool solver::get_phase(bool_var v) { api_bound* b; if (!m_bool_var2bound.find(v, b)) { @@ -786,18 +782,10 @@ namespace arith { return lp().compare_values(vi, k, b->get_value()) ? l_true : l_false; } - bool solver::can_get_value(theory_var v) const { - return can_get_bound(v) && !m_variable_values.empty(); - } - - bool solver::can_get_bound(theory_var v) const { + bool solver::is_registered_var(theory_var v) const { return v != euf::null_theory_var && lp().external_is_used(v); } - bool solver::can_get_ivalue(theory_var v) const { - return can_get_bound(v); - } - void solver::ensure_column(theory_var v) { SASSERT(!is_bool(v)); if (!lp().external_is_used(v)) @@ -805,68 +793,14 @@ namespace arith { } lp::impq solver::get_ivalue(theory_var v) const { - SASSERT(can_get_ivalue(v)); - auto t = get_tv(v); - if (!t.is_term()) - return lp().get_column_value(t.id()); - m_todo_terms.push_back(std::make_pair(t, rational::one())); - lp::impq result(0); - while (!m_todo_terms.empty()) { - t = m_todo_terms.back().first; - rational coeff = m_todo_terms.back().second; - m_todo_terms.pop_back(); - if (t.is_term()) { - const lp::lar_term& term = lp().get_term(t); - for (const auto& i : term) { - m_todo_terms.push_back(std::make_pair(lp().column2tv(i.column()), coeff * i.coeff())); - } - } - else { - result += lp().get_column_value(t.id()) * coeff; - } - } - return result; + SASSERT(is_registered_var(v)); + return m_solver->get_ivalue(get_tv(v)); } rational solver::get_value(theory_var v) const { - if (v == euf::null_theory_var || !lp().external_is_used(v)) { + if (v == euf::null_theory_var || !lp().external_is_used(v)) return rational::zero(); - } - - auto t = get_tv(v); - if (m_variable_values.count(t.index()) > 0) - return m_variable_values[t.index()]; - - if (!t.is_term() && lp().is_fixed(t.id())) - return lp().column_lower_bound(t.id()).x; - - if (!t.is_term()) - return rational::zero(); - - m_todo_terms.push_back(std::make_pair(t, rational::one())); - rational result(0); - while (!m_todo_terms.empty()) { - auto t2 = m_todo_terms.back().first; - rational coeff = m_todo_terms.back().second; - m_todo_terms.pop_back(); - if (t2.is_term()) { - const lp::lar_term& term = lp().get_term(t2); - for (const auto& i : term) { - auto tv = lp().column2tv(i.column()); - if (m_variable_values.count(tv.index()) > 0) { - result += m_variable_values[tv.index()] * coeff * i.coeff(); - } - else { - m_todo_terms.push_back(std::make_pair(tv, coeff * i.coeff())); - } - } - } - else { - result += m_variable_values[t2.index()] * coeff; - } - } - m_variable_values[t.index()] = result; - return result; + return m_solver->get_value(get_tv(v)); } void solver::random_update() { @@ -915,7 +849,7 @@ namespace arith { if (!ctx.is_shared(var2enode(v))) continue; ensure_column(v); - if (!can_get_ivalue(v)) + if (!is_registered_var(v)) continue; theory_var other = m_model_eqs.insert_if_not_there(v); TRACE("arith", tout << "insert: v" << v << " := " << get_value(v) << " found: v" << other << "\n";); @@ -977,8 +911,8 @@ namespace arith { sat::check_result solver::check() { force_push(); + m_model_is_initialized = false; flet _is_learned(m_is_redundant, true); - reset_variable_values(); IF_VERBOSE(12, verbose_stream() << "final-check " << lp().get_status() << "\n"); SASSERT(lp().ax_is_correct()); diff --git a/src/sat/smt/arith_solver.h b/src/sat/smt/arith_solver.h index 6aff1528a..044d76989 100644 --- a/src/sat/smt/arith_solver.h +++ b/src/sat/smt/arith_solver.h @@ -149,7 +149,6 @@ namespace arith { vector m_columns; var_coeffs m_left_side; // constraint left side - mutable std::unordered_map m_variable_values; // current model lpvar m_one_var { UINT_MAX }; lpvar m_zero_var { UINT_MAX }; lpvar m_rone_var { UINT_MAX }; @@ -332,12 +331,8 @@ namespace arith { bool all_zeros(vector const& v) const; bound_prop_mode propagation_mode() const; - void init_variable_values(); - void reset_variable_values(); - bool can_get_value(theory_var v) const; - bool can_get_bound(theory_var v) const; - bool can_get_ivalue(theory_var v) const; + bool is_registered_var(theory_var v) const; void ensure_column(theory_var v); lp::impq get_ivalue(theory_var v) const; rational get_value(theory_var v) const; @@ -378,6 +373,7 @@ namespace arith { obj_map m_predicate2term; obj_map m_term2bound_info; + bool m_model_is_initialized{ false }; bool use_bounded_expansion() const { return get_config().m_arith_bounded_expansion; } unsigned small_lemma_size() const { return get_config().m_arith_small_lemma_size; } @@ -428,7 +424,7 @@ namespace arith { void new_eq_eh(euf::th_eq const& eq) override { mk_eq_axiom(true, eq); } void new_diseq_eh(euf::th_eq const& de) override { mk_eq_axiom(false, de); } bool unit_propagate() override; - void init_model() override { init_variable_values(); } + void init_model() override; void finalize_model(model& mdl) override { DEBUG_CODE(dbg_finalize_model(mdl);); } void add_value(euf::enode* n, model& mdl, expr_ref_vector& values) override; sat::literal internalize(expr* e, bool sign, bool root, bool learned) override;