From 412b05076ce3ebc28a1ef567c7d5841f51ad934a Mon Sep 17 00:00:00 2001 From: Clemens Eisenhofer <56730610+CEisenhofer@users.noreply.github.com> Date: Sat, 26 Feb 2022 18:21:01 +0100 Subject: [PATCH] User-functions fix (#5868) --- src/api/c++/z3++.h | 110 ++++++++++++++++++----------- src/smt/smt_context.cpp | 11 +-- src/smt/smt_context.h | 2 +- src/smt/theory_user_propagator.cpp | 36 ++++++++-- src/smt/theory_user_propagator.h | 2 +- 5 files changed, 107 insertions(+), 54 deletions(-) diff --git a/src/api/c++/z3++.h b/src/api/c++/z3++.h index 0bbab4185..970a96c72 100644 --- a/src/api/c++/z3++.h +++ b/src/api/c++/z3++.h @@ -25,6 +25,7 @@ Notes: #include #include #include +#include #include #include #include @@ -542,7 +543,7 @@ namespace z3 { ~ast_vector_tpl() { Z3_ast_vector_dec_ref(ctx(), m_vector); } operator Z3_ast_vector() const { return m_vector; } unsigned size() const { return Z3_ast_vector_size(ctx(), m_vector); } - T operator[](int i) const { assert(0 <= i); Z3_ast r = Z3_ast_vector_get(ctx(), m_vector, i); check_error(); return cast_ast()(ctx(), r); } + T operator[](unsigned i) const { Z3_ast r = Z3_ast_vector_get(ctx(), m_vector, i); check_error(); return cast_ast()(ctx(), r); } void push_back(T const & e) { Z3_ast_vector_push(ctx(), m_vector, e); check_error(); } void resize(unsigned sz) { Z3_ast_vector_resize(ctx(), m_vector, sz); check_error(); } T back() const { return operator[](size() - 1); } @@ -1149,6 +1150,19 @@ namespace z3 { \pre i < num_args() */ expr arg(unsigned i) const { Z3_ast r = Z3_get_app_arg(ctx(), *this, i); check_error(); return expr(ctx(), r); } + /** + \brief Return a vector of all the arguments of this application. + This method assumes the expression is an application. + + \pre is_app() + */ + expr_vector args() const { + expr_vector vec(ctx()); + unsigned argCnt = num_args(); + for (unsigned i = 0; i < argCnt; i++) + vec.push_back(arg(i)); + return vec; + } /** \brief Return the 'body' of this quantifier. @@ -3936,7 +3950,8 @@ namespace z3 { created_eh_t m_created_eh; solver* s; context* c; - + std::vector subcontexts; + Z3_solver_callback cb { nullptr }; struct scoped_cb { @@ -3944,8 +3959,8 @@ namespace z3 { scoped_cb(void* _p, Z3_solver_callback cb):p(*static_cast(_p)) { p.cb = cb; } - ~scoped_cb() { - p.cb = nullptr; + ~scoped_cb() { + p.cb = nullptr; } }; @@ -3958,7 +3973,9 @@ namespace z3 { } static void* fresh_eh(void* p, Z3_context ctx) { - return static_cast(p)->fresh(ctx); + context* c = new context(ctx); + static_cast(p)->subcontexts.push_back(c); + return static_cast(p)->fresh(*c); } static void fixed_eh(void* _p, Z3_solver_callback cb, Z3_ast _var, Z3_ast _value) { @@ -3993,60 +4010,69 @@ namespace z3 { user_propagator_base(context& c) : s(nullptr), c(&c) {} user_propagator_base(solver* s): s(s), c(nullptr) { - Z3_solver_propagate_init(ctx(), *s, this, push_eh, pop_eh, fresh_eh); + Z3_solver_propagate_init(ctx(), *s, this, push_eh, pop_eh, fresh_eh); } virtual void push() = 0; virtual void pop(unsigned num_scopes) = 0; - virtual ~user_propagator_base() = default; + virtual ~user_propagator_base() { + for (auto& subcontext : subcontexts) { + subcontext->detach(); // detach first; the subcontexts will be freed internally! + delete subcontext; + } + } context& ctx() { - return c ? *c : s->ctx(); + return c ? *c : s->ctx(); } /** - \brief user_propagators created using \c fresh() are created during + \brief user_propagators created using \c fresh() are created during search and their lifetimes are restricted to search time. They should be garbage collected by the propagator used to invoke \c fresh(). The life-time of the Z3_context object can only be assumed valid during callbacks, such as \c fixed(), which contains expressions based on the context. */ - virtual user_propagator_base* fresh(Z3_context ctx) = 0; + virtual user_propagator_base* fresh(context& ctx) = 0; /** \brief register callbacks. Callbacks can only be registered with user_propagators - that were created using a solver. + that were created using a solver. */ - void register_fixed(fixed_eh_t& f) { - assert(s); - m_fixed_eh = f; - Z3_solver_propagate_fixed(ctx(), *s, fixed_eh); + void register_fixed(fixed_eh_t& f) { + m_fixed_eh = f; + if (s) { + Z3_solver_propagate_fixed(ctx(), *s, fixed_eh); + } } void register_fixed() { - assert(s); - m_fixed_eh = [this](expr const& id, expr const& e) { + m_fixed_eh = [this](expr const &id, expr const &e) { fixed(id, e); }; - Z3_solver_propagate_fixed(ctx(), *s, fixed_eh); + if (s) { + Z3_solver_propagate_fixed(ctx(), *s, fixed_eh); + } } - void register_eq(eq_eh_t& f) { - assert(s); - m_eq_eh = f; - Z3_solver_propagate_eq(ctx(), *s, eq_eh); + void register_eq(eq_eh_t& f) { + m_eq_eh = f; + if (s) { + Z3_solver_propagate_eq(ctx(), *s, eq_eh); + } } void register_eq() { - assert(s); m_eq_eh = [this](expr const& x, expr const& y) { eq(x, y); }; - Z3_solver_propagate_eq(ctx(), *s, eq_eh); + if (s) { + Z3_solver_propagate_eq(ctx(), *s, eq_eh); + } } /** @@ -4054,34 +4080,39 @@ namespace z3 { During the final check stage, all propagations have been processed. This is an opportunity for the user-propagator to delay some analysis that could be expensive to perform incrementally. It is also an opportunity - for the propagator to implement branch and bound optimization. + for the propagator to implement branch and bound optimization. */ - void register_final(final_eh_t& f) { - assert(s); - m_final_eh = f; - Z3_solver_propagate_final(ctx(), *s, final_eh); + void register_final(final_eh_t& f) { + m_final_eh = f; + if (s) { + Z3_solver_propagate_final(ctx(), *s, final_eh); + } } - - void register_final() { - assert(s); + + void register_final() { m_final_eh = [this]() { final(); }; - Z3_solver_propagate_final(ctx(), *s, final_eh); + if (s) { + Z3_solver_propagate_final(ctx(), *s, final_eh); + } } void register_created(created_eh_t& c) { - assert(s); m_created_eh = c; - Z3_solver_propagate_created(ctx(), *s, created_eh); + if (s) { + Z3_solver_propagate_created(ctx(), *s, created_eh); + } } void register_created() { m_created_eh = [this](expr const& e) { created(e); }; - Z3_solver_propagate_created(ctx(), *s, created_eh); + if (s) { + Z3_solver_propagate_created(ctx(), *s, created_eh); + } } virtual void fixed(expr const& /*id*/, expr const& /*e*/) { } @@ -4095,10 +4126,10 @@ namespace z3 { /** \brief tracks \c e by a unique identifier that is returned by the call. - If the \c fixed() callback is registered and if \c e is a Boolean or Bit-vector, + If the \c fixed() callback is registered and if \c e is a Boolean or Bit-vector, the \c fixed() callback gets invoked when \c e is bound to a value. If the \c eq() callback is registered, then equalities between registered expressions - are reported. + are reported. A consumer can use the \c propagate or \c conflict functions to invoke propagations or conflicts as a consequence of these callbacks. These functions take a list of identifiers for registered expressions that have been fixed. The list of identifiers must correspond to @@ -4143,9 +4174,6 @@ namespace z3 { } }; - - - } /**@}*/ diff --git a/src/smt/smt_context.cpp b/src/smt/smt_context.cpp index 186d03dc3..2eba2c2a2 100644 --- a/src/smt/smt_context.cpp +++ b/src/smt/smt_context.cpp @@ -171,7 +171,7 @@ namespace smt { dst_ctx.setup_context(dst_ctx.m_fparams.m_auto_config); dst_ctx.internalize_assertions(); - dst_ctx.copy_user_propagator(src_ctx); + dst_ctx.copy_user_propagator(src_ctx, true); TRACE("smt_context", src_ctx.display(tout); @@ -193,13 +193,16 @@ namespace smt { } } - void context::copy_user_propagator(context& src_ctx) { + void context::copy_user_propagator(context& src_ctx, bool copy_registered) { if (!src_ctx.m_user_propagator) return; - ast_translation tr(src_ctx.m, m, false); auto* p = get_theory(m.mk_family_id("user_propagator")); m_user_propagator = reinterpret_cast(p); SASSERT(m_user_propagator); + if (!copy_registered) { + return; + } + ast_translation tr(src_ctx.m, m, false); for (unsigned i = 0; i < src_ctx.m_user_propagator->get_num_vars(); ++i) { app* e = src_ctx.m_user_propagator->get_expr(i); m_user_propagator->add_expr(tr(e)); @@ -211,7 +214,7 @@ namespace smt { new_ctx->m_is_auxiliary = true; new_ctx->set_logic(l == nullptr ? m_setup.get_logic() : *l); copy_plugins(*this, *new_ctx); - new_ctx->copy_user_propagator(*this); + new_ctx->copy_user_propagator(*this, false); return new_ctx; } diff --git a/src/smt/smt_context.h b/src/smt/smt_context.h index 12f0cc2ad..f1b2514b1 100644 --- a/src/smt/smt_context.h +++ b/src/smt/smt_context.h @@ -1576,7 +1576,7 @@ namespace smt { void log_stats(); - void copy_user_propagator(context& src); + void copy_user_propagator(context& src, bool copy_registered); public: context(ast_manager & m, smt_params & fp, params_ref const & p = params_ref()); diff --git a/src/smt/theory_user_propagator.cpp b/src/smt/theory_user_propagator.cpp index 85a02b154..751386ef3 100644 --- a/src/smt/theory_user_propagator.cpp +++ b/src/smt/theory_user_propagator.cpp @@ -94,8 +94,14 @@ void theory_user_propagator::register_cb(expr* e) { } theory * theory_user_propagator::mk_fresh(context * new_ctx) { - auto* th = alloc(theory_user_propagator, *new_ctx); - void* ctx = m_fresh_eh(m_user_context, new_ctx->get_manager(), th->m_api_context); + auto* th = alloc(theory_user_propagator, *new_ctx); + void* ctx; + try { + ctx = m_fresh_eh(m_user_context, new_ctx->get_manager(), th->m_api_context); + } + catch (...) { + throw default_exception("Exception thrown in \"fresh\"-callback"); + } th->add(ctx, m_push_eh, m_pop_eh, m_fresh_eh); if ((bool)m_fixed_eh) th->register_fixed(m_fixed_eh); if ((bool)m_final_eh) th->register_final(m_final_eh); @@ -110,7 +116,12 @@ final_check_status theory_user_propagator::final_check_eh() { return FC_DONE; force_push(); unsigned sz = m_prop.size(); - m_final_eh(m_user_context, this); + try { + m_final_eh(m_user_context, this); + } + catch (...) { + throw default_exception("Exception thrown in \"final\"-callback"); + } propagate(); bool done = (sz == m_prop.size()) && !ctx.inconsistent(); return done ? FC_DONE : FC_CONTINUE; @@ -125,7 +136,12 @@ void theory_user_propagator::new_fixed_eh(theory_var v, expr* value, unsigned nu m_fixed.insert(v); ctx.push_trail(insert_map(m_fixed, v)); m_id2justification.setx(v, literal_vector(num_lits, jlits), literal_vector()); - m_fixed_eh(m_user_context, this, var2expr(v), value); + try { + m_fixed_eh(m_user_context, this, var2expr(v), value); + } + catch (...) { + throw default_exception("Exception thrown in \"fixed\"-callback"); + } } void theory_user_propagator::push_scope_eh() { @@ -228,11 +244,17 @@ bool theory_user_propagator::internalize_term(app* term) { ctx.mk_enode(term, true, false, true); add_expr(term); + + if (!m_created_eh) + throw default_exception("You have to register a created event handler for new terms if you track them"); - if (!m_created_eh && (m_fixed_eh || m_eq_eh || m_diseq_eh)) - return true; - if (m_created_eh) + try { m_created_eh(m_user_context, this, term); + } + catch (...) { + throw default_exception("Exception thrown in \"created\"-callback"); + } + return true; } diff --git a/src/smt/theory_user_propagator.h b/src/smt/theory_user_propagator.h index e1aa33b8e..1045feb0a 100644 --- a/src/smt/theory_user_propagator.h +++ b/src/smt/theory_user_propagator.h @@ -142,7 +142,7 @@ namespace smt { void collect_statistics(::statistics & st) const override; model_value_proc * mk_value(enode * n, model_generator & mg) override { return nullptr; } void init_model(model_generator & m) override {} - bool include_func_interp(func_decl* f) override { return true; } + bool include_func_interp(func_decl* f) override { return false; } bool can_propagate() override; void propagate() override; void display(std::ostream& out) const override {}