diff --git a/src/api/api_solver.cpp b/src/api/api_solver.cpp index 08f864226..ae77cb4ea 100644 --- a/src/api/api_solver.cpp +++ b/src/api/api_solver.cpp @@ -1092,15 +1092,15 @@ extern "C" { Z3_CATCH; } - void Z3_API Z3_solver_propagate_consequence(Z3_context c, Z3_solver_callback s, unsigned num_fixed, Z3_ast const* fixed_ids, unsigned num_eqs, Z3_ast const* eq_lhs, Z3_ast const* eq_rhs, Z3_ast conseq) { + bool Z3_API Z3_solver_propagate_consequence(Z3_context c, Z3_solver_callback s, unsigned num_fixed, Z3_ast const* fixed_ids, unsigned num_eqs, Z3_ast const* eq_lhs, Z3_ast const* eq_rhs, Z3_ast conseq) { Z3_TRY; LOG_Z3_solver_propagate_consequence(c, s, num_fixed, fixed_ids, num_eqs, eq_lhs, eq_rhs, conseq); RESET_ERROR_CODE(); expr* const * _fixed_ids = (expr* const*) fixed_ids; expr* const * _eq_lhs = (expr*const*) eq_lhs; expr* const * _eq_rhs = (expr*const*) eq_rhs; - reinterpret_cast(s)->propagate_cb(num_fixed, _fixed_ids, num_eqs, _eq_lhs, _eq_rhs, to_expr(conseq)); - Z3_CATCH; + return reinterpret_cast(s)->propagate_cb(num_fixed, _fixed_ids, num_eqs, _eq_lhs, _eq_rhs, to_expr(conseq)); + Z3_CATCH_RETURN(false); } void Z3_API Z3_solver_propagate_created(Z3_context c, Z3_solver s, Z3_created_eh created_eh) { diff --git a/src/api/c++/z3++.h b/src/api/c++/z3++.h index 88bbd2dcc..799644970 100644 --- a/src/api/c++/z3++.h +++ b/src/api/c++/z3++.h @@ -4496,14 +4496,14 @@ namespace z3 { Z3_solver_propagate_consequence(ctx(), cb, fixed.size(), _fixed.ptr(), lhs.size(), _lhs.ptr(), _rhs.ptr(), conseq); } - void propagate(expr_vector const& fixed, expr const& conseq) { + bool propagate(expr_vector const& fixed, expr const& conseq) { assert(cb); assert((Z3_context)conseq.ctx() == (Z3_context)ctx()); array _fixed(fixed); - Z3_solver_propagate_consequence(ctx(), cb, _fixed.size(), _fixed.ptr(), 0, nullptr, nullptr, conseq); + return Z3_solver_propagate_consequence(ctx(), cb, _fixed.size(), _fixed.ptr(), 0, nullptr, nullptr, conseq); } - void propagate(expr_vector const& fixed, + bool propagate(expr_vector const& fixed, expr_vector const& lhs, expr_vector const& rhs, expr const& conseq) { assert(cb); @@ -4513,7 +4513,7 @@ namespace z3 { array _lhs(lhs); array _rhs(rhs); - Z3_solver_propagate_consequence(ctx(), cb, _fixed.size(), _fixed.ptr(), lhs.size(), _lhs.ptr(), _rhs.ptr(), conseq); + return Z3_solver_propagate_consequence(ctx(), cb, _fixed.size(), _fixed.ptr(), lhs.size(), _lhs.ptr(), _rhs.ptr(), conseq); } }; diff --git a/src/api/dotnet/UserPropagator.cs b/src/api/dotnet/UserPropagator.cs index 68f2b0127..e591c3354 100644 --- a/src/api/dotnet/UserPropagator.cs +++ b/src/api/dotnet/UserPropagator.cs @@ -252,21 +252,29 @@ namespace Microsoft.Z3 /// /// Propagate consequence + /// + /// if the propagated expression is new for the solver; + /// if the propagation was ignored + /// /// - public void Propagate(IEnumerable terms, Expr conseq) + public bool Propagate(IEnumerable terms, Expr conseq) { - Propagate(terms, new EqualityPairs(), conseq); + return Propagate(terms, new EqualityPairs(), conseq); } /// /// Propagate consequence + /// + /// if the propagated expression is new for the solver; + /// if the propagation was ignored + /// /// - public void Propagate(IEnumerable terms, EqualityPairs equalities, Expr conseq) + public bool Propagate(IEnumerable terms, EqualityPairs equalities, Expr conseq) { var nTerms = Z3Object.ArrayToNative(terms.ToArray()); var nLHS = Z3Object.ArrayToNative(equalities.LHS.ToArray()); var nRHS = Z3Object.ArrayToNative(equalities.RHS.ToArray()); - Native.Z3_solver_propagate_consequence(ctx.nCtx, this.callback, (uint)nTerms.Length, nTerms, (uint)equalities.Count, nLHS, nRHS, conseq.NativeObject); + return Native.Z3_solver_propagate_consequence(ctx.nCtx, this.callback, (uint)nTerms.Length, nTerms, (uint)equalities.Count, nLHS, nRHS, conseq.NativeObject) != 0; } diff --git a/src/api/python/z3/z3.py b/src/api/python/z3/z3.py index 20871d36e..5c067f4d3 100644 --- a/src/api/python/z3/z3.py +++ b/src/api/python/z3/z3.py @@ -11704,7 +11704,7 @@ class UserPropagateBase: num_eqs = len(eqs) _lhs, _num_lhs = _to_ast_array([x for x, y in eqs]) _rhs, _num_rhs = _to_ast_array([y for x, y in eqs]) - Z3_solver_propagate_consequence(e.ctx.ref(), ctypes.c_void_p( + return Z3_solver_propagate_consequence(e.ctx.ref(), ctypes.c_void_p( self.cb), num_fixed, _ids, num_eqs, _lhs, _rhs, e.ast) def conflict(self, deps = [], eqs = []): diff --git a/src/api/z3_api.h b/src/api/z3_api.h index 54974d57b..b73c71912 100644 --- a/src/api/z3_api.h +++ b/src/api/z3_api.h @@ -7147,14 +7147,18 @@ extern "C" { /** \brief propagate a consequence based on fixed values. - This is a callback a client may invoke during the fixed_eh callback. + This is a callback a client may invoke during the fixed_eh callback. The callback adds a propagation consequence based on the fixed values of the - \c ids. - - def_API('Z3_solver_propagate_consequence', VOID, (_in(CONTEXT), _in(SOLVER_CALLBACK), _in(UINT), _in_array(2, AST), _in(UINT), _in_array(4, AST), _in_array(4, AST), _in(AST))) + \c ids. + The solver might discard the propagation in case it is true in the current state. + The function returns false in this case; otw. the function returns true. + At least one propagation in the final callback has to return true in order to + prevent the solver from finishing. + + def_API('Z3_solver_propagate_consequence', BOOL, (_in(CONTEXT), _in(SOLVER_CALLBACK), _in(UINT), _in_array(2, AST), _in(UINT), _in_array(4, AST), _in_array(4, AST), _in(AST))) */ - - void Z3_API Z3_solver_propagate_consequence(Z3_context c, Z3_solver_callback cb, unsigned num_fixed, Z3_ast const* fixed, unsigned num_eqs, Z3_ast const* eq_lhs, Z3_ast const* eq_rhs, Z3_ast conseq); + + bool Z3_API Z3_solver_propagate_consequence(Z3_context c, Z3_solver_callback cb, unsigned num_fixed, Z3_ast const* fixed, unsigned num_eqs, Z3_ast const* eq_lhs, Z3_ast const* eq_rhs, Z3_ast conseq); /** \brief Check whether the assertions in a given solver are consistent or not. diff --git a/src/math/lp/lar_solver.cpp b/src/math/lp/lar_solver.cpp index b6e9b63fe..07e3472c5 100644 --- a/src/math/lp/lar_solver.cpp +++ b/src/math/lp/lar_solver.cpp @@ -213,7 +213,7 @@ namespace lp { void lar_solver::fill_explanation_from_crossed_bounds_column(explanation& evidence) const { lp_assert(static_cast(get_column_type(m_crossed_bounds_column)) >= static_cast(column_type::boxed)); - lp_assert(!m_mpq_lar_core_solver.m_r_solver.column_is_feasible(m_crossed_bounds_column)); + lp_assert(!column_is_feasible(m_crossed_bounds_column)); // this is the case when the lower bound is in conflict with the upper one const ul_pair& ul = m_columns_to_ul_pairs[m_crossed_bounds_column]; @@ -673,7 +673,7 @@ namespace lp { m_mpq_lar_core_solver.m_r_solver.add_delta_to_x_and_track_feasibility(bj, -A_r().get_val(c) * delta); TRACE("change_x_del", tout << "changed basis column " << bj << ", it is " << - (m_mpq_lar_core_solver.m_r_solver.column_is_feasible(bj) ? "feas" : "inf") << std::endl;); + (column_is_feasible(bj) ? "feas" : "inf") << std::endl;); } } @@ -1327,7 +1327,7 @@ namespace lp { became_feas.clear(); for (unsigned j : m_mpq_lar_core_solver.m_r_solver.inf_heap()) { lp_assert(m_mpq_lar_core_solver.m_r_heading[j] >= 0); - if (m_mpq_lar_core_solver.m_r_solver.column_is_feasible(j)) + if (column_is_feasible(j)) became_feas.push_back(j); } for (unsigned j : became_feas) @@ -1738,16 +1738,18 @@ namespace lp { lconstraint_kind kind, const mpq& right_side, constraint_index constr_index) { + TRACE("lar_solver_feas", tout << "j = " << j << " was " << (this->column_is_feasible(j)?"feas":"non-feas") << std::endl;); m_constraints.activate(constr_index); if (column_has_upper_bound(j)) update_column_type_and_bound_with_ub(j, kind, right_side, constr_index); else update_column_type_and_bound_with_no_ub(j, kind, right_side, constr_index); + TRACE("lar_solver_feas", tout << "j = " << j << " became " << (this->column_is_feasible(j)?"feas":"non-feas") << ", and " << (this->column_is_bounded(j)? "bounded":"non-bounded") << std::endl;); } // clang-format on void lar_solver::insert_to_columns_with_changed_bounds(unsigned j) { m_columns_with_changed_bounds.insert(j); - TRACE("lar_solver", tout << "column " << j << (m_mpq_lar_core_solver.m_r_solver.column_is_feasible(j) ? " feas" : " non-feas") << "\n";); + TRACE("lar_solver", tout << "column " << j << (column_is_feasible(j) ? " feas" : " non-feas") << "\n";); } // clang-format off void lar_solver::update_column_type_and_bound_check_on_equal(unsigned j, diff --git a/src/math/lp/lar_solver.h b/src/math/lp/lar_solver.h index fd6fef8fd..b130c198e 100644 --- a/src/math/lp/lar_solver.h +++ b/src/math/lp/lar_solver.h @@ -481,6 +481,7 @@ class lar_solver : public column_namer { unsigned map_term_index_to_column_index(unsigned j) const; bool column_is_fixed(unsigned j) const; bool column_is_free(unsigned j) const; + bool column_is_feasible(unsigned j) const { return m_mpq_lar_core_solver.m_r_solver.column_is_feasible(j);} unsigned column_to_reported_index(unsigned j) const; lp_settings& settings(); lp_settings const& settings() const; diff --git a/src/math/lp/lp_core_solver_base.h b/src/math/lp/lp_core_solver_base.h index e058100ab..232119b77 100644 --- a/src/math/lp/lp_core_solver_base.h +++ b/src/math/lp/lp_core_solver_base.h @@ -539,31 +539,23 @@ public: return m_basis_heading[j] >= 0; } - - void update_x_with_feasibility_tracking(unsigned j, const X & v) { - TRACE("lar_solver", tout << "j = " << j << ", v = " << v << "\n";); - m_x[j] = v; - track_column_feasibility(j); - } - void add_delta_to_x_and_track_feasibility(unsigned j, const X & del) { - TRACE("lar_solver", tout << "del = " << del << ", was x[" << j << "] = " << m_x[j] << "\n";); + TRACE("lar_solver_feas_bug", tout << "del = " << del << ", was x[" << j << "] = " << m_x[j] << "\n";); m_x[j] += del; - TRACE("lar_solver", tout << "became x[" << j << "] = " << m_x[j] << "\n";); + TRACE("lar_solver_feas_bug", tout << "became x[" << j << "] = " << m_x[j] << "\n";); track_column_feasibility(j); } void update_x(unsigned j, const X & v) { m_x[j] = v; - TRACE("lar_solver", tout << "j = " << j << ", v = " << v << (column_is_feasible(j)? " feas":" non-feas") << "\n";); + TRACE("lar_solver_feas", tout << "not tracking feas j = " << j << ", v = " << v << (column_is_feasible(j)? " feas":" non-feas") << "\n";); } - // clang-format on - void add_delta_to_x(unsigned j, const X& delta) { - m_x[j] += delta; - TRACE("lar_solver", tout << "j = " << j << " v = " << m_x[j] << " delta = " << delta << (column_is_feasible(j) ? " feas" : " non-feas") << "\n";); - } - // clang-format off - + + void add_delta_to_x(unsigned j, const X& delta) { + m_x[j] += delta; + TRACE("lar_solver_feas", tout << "not tracking feas j = " << j << " v = " << m_x[j] << " delta = " << delta << (column_is_feasible(j) ? " feas" : " non-feas") << "\n";); + } + void track_column_feasibility(unsigned j) { if (column_is_feasible(j)) remove_column_from_inf_heap(j); @@ -573,7 +565,7 @@ public: void insert_column_into_inf_heap(unsigned j) { if (!m_inf_heap.contains(j)) { m_inf_heap.insert(j); - TRACE("lar_solver_inf_heap", tout << "insert into heap j = " << j << "\n";); + TRACE("lar_solver_inf_heap", tout << "insert into inf_heap j = " << j << "\n";); } lp_assert(!column_is_feasible(j)); } @@ -586,7 +578,7 @@ public: } void clear_inf_heap() { - TRACE("lar_solver",); + TRACE("lar_solver_feas",); m_inf_heap.clear(); } diff --git a/src/sat/smt/user_solver.cpp b/src/sat/smt/user_solver.cpp index 1e8897b8c..2823c81f8 100644 --- a/src/sat/smt/user_solver.cpp +++ b/src/sat/smt/user_solver.cpp @@ -43,15 +43,19 @@ namespace user_solver { m_prop.push_back(prop_info(explain, v, r)); } - void solver::propagate_cb( - unsigned num_fixed, expr* const* fixed_ids, - unsigned num_eqs, expr* const* eq_lhs, expr* const* eq_rhs, - expr* conseq) { + bool solver::propagate_cb( + unsigned num_fixed, expr* const* fixed_ids, + unsigned num_eqs, expr* const* eq_lhs, expr* const* eq_rhs, + expr* conseq) { + auto* n = ctx.get_enode(conseq); + if (n && s().value(ctx.enode2literal(n)) == l_true) + return false; m_fixed_ids.reset(); for (unsigned i = 0; i < num_fixed; ++i) m_fixed_ids.push_back(get_th_var(fixed_ids[i])); m_prop.push_back(prop_info(num_fixed, m_fixed_ids.data(), num_eqs, eq_lhs, eq_rhs, expr_ref(conseq, m))); DEBUG_CODE(validate_propagation();); + return true; } void solver::register_cb(expr* e) { @@ -76,7 +80,7 @@ namespace user_solver { sat::check_result solver::check() { if (!(bool)m_final_eh) - return sat::check_result::CR_DONE; + return sat::check_result::CR_DONE; unsigned sz = m_prop.size(); m_final_eh(m_user_context, this); return sz == m_prop.size() ? sat::check_result::CR_DONE : sat::check_result::CR_CONTINUE; diff --git a/src/sat/smt/user_solver.h b/src/sat/smt/user_solver.h index bd1b703e0..cd94441ea 100644 --- a/src/sat/smt/user_solver.h +++ b/src/sat/smt/user_solver.h @@ -135,7 +135,7 @@ namespace user_solver { bool has_fixed() const { return (bool)m_fixed_eh; } - void propagate_cb(unsigned num_fixed, expr* const* fixed_ids, unsigned num_eqs, expr* const* lhs, expr* const* rhs, expr* conseq) override; + bool propagate_cb(unsigned num_fixed, expr* const* fixed_ids, unsigned num_eqs, expr* const* lhs, expr* const* rhs, expr* conseq) override; void register_cb(expr* e) override; bool next_split_cb(expr* e, unsigned idx, lbool phase) override; diff --git a/src/smt/theory_user_propagator.cpp b/src/smt/theory_user_propagator.cpp index 2d5b4917d..7c72419c1 100644 --- a/src/smt/theory_user_propagator.cpp +++ b/src/smt/theory_user_propagator.cpp @@ -83,7 +83,7 @@ void theory_user_propagator::add_expr(expr* term, bool ensure_enode) { } -void theory_user_propagator::propagate_cb( +bool theory_user_propagator::propagate_cb( unsigned num_fixed, expr* const* fixed_ids, unsigned num_eqs, expr* const* eq_lhs, expr* const* eq_rhs, expr* conseq) { @@ -95,9 +95,10 @@ void theory_user_propagator::propagate_cb( if (!ctx.get_manager().is_true(_conseq) && !ctx.get_manager().is_false(_conseq)) ctx.mark_as_relevant((expr*)_conseq); - if (ctx.lit_internalized(_conseq) && ctx.get_assignment(ctx.get_literal(_conseq)) == l_true) - return; - m_prop.push_back(prop_info(num_fixed, fixed_ids, num_eqs, eq_lhs, eq_rhs, _conseq)); + if (ctx.lit_internalized(_conseq) && ctx.get_assignment(ctx.get_literal(_conseq)) == l_true) + return false; + m_prop.push_back(prop_info(num_fixed, fixed_ids, num_eqs, eq_lhs, eq_rhs, _conseq)); + return true; } void theory_user_propagator::register_cb(expr* e) { @@ -386,7 +387,7 @@ bool theory_user_propagator::internalize_atom(app* atom, bool gate_ctx) { return internalize_term(atom); } -bool theory_user_propagator::internalize_term(app* term) { +bool theory_user_propagator::internalize_term(app* term) { for (auto arg : *term) ensure_enode(arg); if (term->get_family_id() == get_id() && !ctx.e_internalized(term)) diff --git a/src/smt/theory_user_propagator.h b/src/smt/theory_user_propagator.h index 5a6eafc0a..2f045b0ba 100644 --- a/src/smt/theory_user_propagator.h +++ b/src/smt/theory_user_propagator.h @@ -130,7 +130,7 @@ namespace smt { bool has_fixed() const { return (bool)m_fixed_eh; } - void propagate_cb(unsigned num_fixed, expr* const* fixed_ids, unsigned num_eqs, expr* const* lhs, expr* const* rhs, expr* conseq) override; + bool propagate_cb(unsigned num_fixed, expr* const* fixed_ids, unsigned num_eqs, expr* const* lhs, expr* const* rhs, expr* conseq) override; void register_cb(expr* e) override; bool next_split_cb(expr* e, unsigned idx, lbool phase) override; diff --git a/src/tactic/user_propagator_base.h b/src/tactic/user_propagator_base.h index d4dae5166..40c0fa8fc 100644 --- a/src/tactic/user_propagator_base.h +++ b/src/tactic/user_propagator_base.h @@ -9,7 +9,7 @@ namespace user_propagator { class callback { public: virtual ~callback() = default; - virtual void propagate_cb(unsigned num_fixed, expr* const* fixed_ids, unsigned num_eqs, expr* const* eq_lhs, expr* const* eq_rhs, expr* conseq) = 0; + virtual bool propagate_cb(unsigned num_fixed, expr* const* fixed_ids, unsigned num_eqs, expr* const* eq_lhs, expr* const* eq_rhs, expr* conseq) = 0; virtual void register_cb(expr* e) = 0; virtual bool next_split_cb(expr* e, unsigned idx, lbool phase) = 0; };