From 4cb158a79b2f3c8afdf96e12e2873de815b66c74 Mon Sep 17 00:00:00 2001 From: Clemens Eisenhofer <56730610+CEisenhofer@users.noreply.github.com> Date: Fri, 7 Jul 2023 18:58:41 +0200 Subject: [PATCH] User Propagator: Return if propagated lemma is redundant (#6791) * Give users ability to see if propagation failed * Skip propagations in the new core if they are already satisfied --- src/api/api_solver.cpp | 6 +++--- src/api/c++/z3++.h | 8 ++++---- src/api/dotnet/UserPropagator.cs | 16 ++++++++++++---- src/api/python/z3/z3.py | 2 +- src/api/z3_api.h | 16 ++++++++++------ src/sat/smt/user_solver.cpp | 14 +++++++++----- src/sat/smt/user_solver.h | 2 +- src/smt/theory_user_propagator.cpp | 11 ++++++----- src/smt/theory_user_propagator.h | 2 +- src/tactic/user_propagator_base.h | 2 +- 10 files changed, 48 insertions(+), 31 deletions(-) diff --git a/src/api/api_solver.cpp b/src/api/api_solver.cpp index 08f864226..ae77cb4ea 100644 --- a/src/api/api_solver.cpp +++ b/src/api/api_solver.cpp @@ -1092,15 +1092,15 @@ extern "C" { Z3_CATCH; } - void Z3_API Z3_solver_propagate_consequence(Z3_context c, Z3_solver_callback s, unsigned num_fixed, Z3_ast const* fixed_ids, unsigned num_eqs, Z3_ast const* eq_lhs, Z3_ast const* eq_rhs, Z3_ast conseq) { + bool Z3_API Z3_solver_propagate_consequence(Z3_context c, Z3_solver_callback s, unsigned num_fixed, Z3_ast const* fixed_ids, unsigned num_eqs, Z3_ast const* eq_lhs, Z3_ast const* eq_rhs, Z3_ast conseq) { Z3_TRY; LOG_Z3_solver_propagate_consequence(c, s, num_fixed, fixed_ids, num_eqs, eq_lhs, eq_rhs, conseq); RESET_ERROR_CODE(); expr* const * _fixed_ids = (expr* const*) fixed_ids; expr* const * _eq_lhs = (expr*const*) eq_lhs; expr* const * _eq_rhs = (expr*const*) eq_rhs; - reinterpret_cast(s)->propagate_cb(num_fixed, _fixed_ids, num_eqs, _eq_lhs, _eq_rhs, to_expr(conseq)); - Z3_CATCH; + return reinterpret_cast(s)->propagate_cb(num_fixed, _fixed_ids, num_eqs, _eq_lhs, _eq_rhs, to_expr(conseq)); + Z3_CATCH_RETURN(false); } void Z3_API Z3_solver_propagate_created(Z3_context c, Z3_solver s, Z3_created_eh created_eh) { diff --git a/src/api/c++/z3++.h b/src/api/c++/z3++.h index 88bbd2dcc..799644970 100644 --- a/src/api/c++/z3++.h +++ b/src/api/c++/z3++.h @@ -4496,14 +4496,14 @@ namespace z3 { Z3_solver_propagate_consequence(ctx(), cb, fixed.size(), _fixed.ptr(), lhs.size(), _lhs.ptr(), _rhs.ptr(), conseq); } - void propagate(expr_vector const& fixed, expr const& conseq) { + bool propagate(expr_vector const& fixed, expr const& conseq) { assert(cb); assert((Z3_context)conseq.ctx() == (Z3_context)ctx()); array _fixed(fixed); - Z3_solver_propagate_consequence(ctx(), cb, _fixed.size(), _fixed.ptr(), 0, nullptr, nullptr, conseq); + return Z3_solver_propagate_consequence(ctx(), cb, _fixed.size(), _fixed.ptr(), 0, nullptr, nullptr, conseq); } - void propagate(expr_vector const& fixed, + bool propagate(expr_vector const& fixed, expr_vector const& lhs, expr_vector const& rhs, expr const& conseq) { assert(cb); @@ -4513,7 +4513,7 @@ namespace z3 { array _lhs(lhs); array _rhs(rhs); - Z3_solver_propagate_consequence(ctx(), cb, _fixed.size(), _fixed.ptr(), lhs.size(), _lhs.ptr(), _rhs.ptr(), conseq); + return Z3_solver_propagate_consequence(ctx(), cb, _fixed.size(), _fixed.ptr(), lhs.size(), _lhs.ptr(), _rhs.ptr(), conseq); } }; diff --git a/src/api/dotnet/UserPropagator.cs b/src/api/dotnet/UserPropagator.cs index 68f2b0127..e591c3354 100644 --- a/src/api/dotnet/UserPropagator.cs +++ b/src/api/dotnet/UserPropagator.cs @@ -252,21 +252,29 @@ namespace Microsoft.Z3 /// /// Propagate consequence + /// + /// if the propagated expression is new for the solver; + /// if the propagation was ignored + /// /// - public void Propagate(IEnumerable terms, Expr conseq) + public bool Propagate(IEnumerable terms, Expr conseq) { - Propagate(terms, new EqualityPairs(), conseq); + return Propagate(terms, new EqualityPairs(), conseq); } /// /// Propagate consequence + /// + /// if the propagated expression is new for the solver; + /// if the propagation was ignored + /// /// - public void Propagate(IEnumerable terms, EqualityPairs equalities, Expr conseq) + public bool Propagate(IEnumerable terms, EqualityPairs equalities, Expr conseq) { var nTerms = Z3Object.ArrayToNative(terms.ToArray()); var nLHS = Z3Object.ArrayToNative(equalities.LHS.ToArray()); var nRHS = Z3Object.ArrayToNative(equalities.RHS.ToArray()); - Native.Z3_solver_propagate_consequence(ctx.nCtx, this.callback, (uint)nTerms.Length, nTerms, (uint)equalities.Count, nLHS, nRHS, conseq.NativeObject); + return Native.Z3_solver_propagate_consequence(ctx.nCtx, this.callback, (uint)nTerms.Length, nTerms, (uint)equalities.Count, nLHS, nRHS, conseq.NativeObject) != 0; } diff --git a/src/api/python/z3/z3.py b/src/api/python/z3/z3.py index 20871d36e..5c067f4d3 100644 --- a/src/api/python/z3/z3.py +++ b/src/api/python/z3/z3.py @@ -11704,7 +11704,7 @@ class UserPropagateBase: num_eqs = len(eqs) _lhs, _num_lhs = _to_ast_array([x for x, y in eqs]) _rhs, _num_rhs = _to_ast_array([y for x, y in eqs]) - Z3_solver_propagate_consequence(e.ctx.ref(), ctypes.c_void_p( + return Z3_solver_propagate_consequence(e.ctx.ref(), ctypes.c_void_p( self.cb), num_fixed, _ids, num_eqs, _lhs, _rhs, e.ast) def conflict(self, deps = [], eqs = []): diff --git a/src/api/z3_api.h b/src/api/z3_api.h index 54974d57b..b73c71912 100644 --- a/src/api/z3_api.h +++ b/src/api/z3_api.h @@ -7147,14 +7147,18 @@ extern "C" { /** \brief propagate a consequence based on fixed values. - This is a callback a client may invoke during the fixed_eh callback. + This is a callback a client may invoke during the fixed_eh callback. The callback adds a propagation consequence based on the fixed values of the - \c ids. - - def_API('Z3_solver_propagate_consequence', VOID, (_in(CONTEXT), _in(SOLVER_CALLBACK), _in(UINT), _in_array(2, AST), _in(UINT), _in_array(4, AST), _in_array(4, AST), _in(AST))) + \c ids. + The solver might discard the propagation in case it is true in the current state. + The function returns false in this case; otw. the function returns true. + At least one propagation in the final callback has to return true in order to + prevent the solver from finishing. + + def_API('Z3_solver_propagate_consequence', BOOL, (_in(CONTEXT), _in(SOLVER_CALLBACK), _in(UINT), _in_array(2, AST), _in(UINT), _in_array(4, AST), _in_array(4, AST), _in(AST))) */ - - void Z3_API Z3_solver_propagate_consequence(Z3_context c, Z3_solver_callback cb, unsigned num_fixed, Z3_ast const* fixed, unsigned num_eqs, Z3_ast const* eq_lhs, Z3_ast const* eq_rhs, Z3_ast conseq); + + bool Z3_API Z3_solver_propagate_consequence(Z3_context c, Z3_solver_callback cb, unsigned num_fixed, Z3_ast const* fixed, unsigned num_eqs, Z3_ast const* eq_lhs, Z3_ast const* eq_rhs, Z3_ast conseq); /** \brief Check whether the assertions in a given solver are consistent or not. diff --git a/src/sat/smt/user_solver.cpp b/src/sat/smt/user_solver.cpp index 1e8897b8c..2823c81f8 100644 --- a/src/sat/smt/user_solver.cpp +++ b/src/sat/smt/user_solver.cpp @@ -43,15 +43,19 @@ namespace user_solver { m_prop.push_back(prop_info(explain, v, r)); } - void solver::propagate_cb( - unsigned num_fixed, expr* const* fixed_ids, - unsigned num_eqs, expr* const* eq_lhs, expr* const* eq_rhs, - expr* conseq) { + bool solver::propagate_cb( + unsigned num_fixed, expr* const* fixed_ids, + unsigned num_eqs, expr* const* eq_lhs, expr* const* eq_rhs, + expr* conseq) { + auto* n = ctx.get_enode(conseq); + if (n && s().value(ctx.enode2literal(n)) == l_true) + return false; m_fixed_ids.reset(); for (unsigned i = 0; i < num_fixed; ++i) m_fixed_ids.push_back(get_th_var(fixed_ids[i])); m_prop.push_back(prop_info(num_fixed, m_fixed_ids.data(), num_eqs, eq_lhs, eq_rhs, expr_ref(conseq, m))); DEBUG_CODE(validate_propagation();); + return true; } void solver::register_cb(expr* e) { @@ -76,7 +80,7 @@ namespace user_solver { sat::check_result solver::check() { if (!(bool)m_final_eh) - return sat::check_result::CR_DONE; + return sat::check_result::CR_DONE; unsigned sz = m_prop.size(); m_final_eh(m_user_context, this); return sz == m_prop.size() ? sat::check_result::CR_DONE : sat::check_result::CR_CONTINUE; diff --git a/src/sat/smt/user_solver.h b/src/sat/smt/user_solver.h index bd1b703e0..cd94441ea 100644 --- a/src/sat/smt/user_solver.h +++ b/src/sat/smt/user_solver.h @@ -135,7 +135,7 @@ namespace user_solver { bool has_fixed() const { return (bool)m_fixed_eh; } - void propagate_cb(unsigned num_fixed, expr* const* fixed_ids, unsigned num_eqs, expr* const* lhs, expr* const* rhs, expr* conseq) override; + bool propagate_cb(unsigned num_fixed, expr* const* fixed_ids, unsigned num_eqs, expr* const* lhs, expr* const* rhs, expr* conseq) override; void register_cb(expr* e) override; bool next_split_cb(expr* e, unsigned idx, lbool phase) override; diff --git a/src/smt/theory_user_propagator.cpp b/src/smt/theory_user_propagator.cpp index 2d5b4917d..7c72419c1 100644 --- a/src/smt/theory_user_propagator.cpp +++ b/src/smt/theory_user_propagator.cpp @@ -83,7 +83,7 @@ void theory_user_propagator::add_expr(expr* term, bool ensure_enode) { } -void theory_user_propagator::propagate_cb( +bool theory_user_propagator::propagate_cb( unsigned num_fixed, expr* const* fixed_ids, unsigned num_eqs, expr* const* eq_lhs, expr* const* eq_rhs, expr* conseq) { @@ -95,9 +95,10 @@ void theory_user_propagator::propagate_cb( if (!ctx.get_manager().is_true(_conseq) && !ctx.get_manager().is_false(_conseq)) ctx.mark_as_relevant((expr*)_conseq); - if (ctx.lit_internalized(_conseq) && ctx.get_assignment(ctx.get_literal(_conseq)) == l_true) - return; - m_prop.push_back(prop_info(num_fixed, fixed_ids, num_eqs, eq_lhs, eq_rhs, _conseq)); + if (ctx.lit_internalized(_conseq) && ctx.get_assignment(ctx.get_literal(_conseq)) == l_true) + return false; + m_prop.push_back(prop_info(num_fixed, fixed_ids, num_eqs, eq_lhs, eq_rhs, _conseq)); + return true; } void theory_user_propagator::register_cb(expr* e) { @@ -386,7 +387,7 @@ bool theory_user_propagator::internalize_atom(app* atom, bool gate_ctx) { return internalize_term(atom); } -bool theory_user_propagator::internalize_term(app* term) { +bool theory_user_propagator::internalize_term(app* term) { for (auto arg : *term) ensure_enode(arg); if (term->get_family_id() == get_id() && !ctx.e_internalized(term)) diff --git a/src/smt/theory_user_propagator.h b/src/smt/theory_user_propagator.h index 5a6eafc0a..2f045b0ba 100644 --- a/src/smt/theory_user_propagator.h +++ b/src/smt/theory_user_propagator.h @@ -130,7 +130,7 @@ namespace smt { bool has_fixed() const { return (bool)m_fixed_eh; } - void propagate_cb(unsigned num_fixed, expr* const* fixed_ids, unsigned num_eqs, expr* const* lhs, expr* const* rhs, expr* conseq) override; + bool propagate_cb(unsigned num_fixed, expr* const* fixed_ids, unsigned num_eqs, expr* const* lhs, expr* const* rhs, expr* conseq) override; void register_cb(expr* e) override; bool next_split_cb(expr* e, unsigned idx, lbool phase) override; diff --git a/src/tactic/user_propagator_base.h b/src/tactic/user_propagator_base.h index d4dae5166..40c0fa8fc 100644 --- a/src/tactic/user_propagator_base.h +++ b/src/tactic/user_propagator_base.h @@ -9,7 +9,7 @@ namespace user_propagator { class callback { public: virtual ~callback() = default; - virtual void propagate_cb(unsigned num_fixed, expr* const* fixed_ids, unsigned num_eqs, expr* const* eq_lhs, expr* const* eq_rhs, expr* conseq) = 0; + virtual bool propagate_cb(unsigned num_fixed, expr* const* fixed_ids, unsigned num_eqs, expr* const* eq_lhs, expr* const* eq_rhs, expr* conseq) = 0; virtual void register_cb(expr* e) = 0; virtual bool next_split_cb(expr* e, unsigned idx, lbool phase) = 0; };