3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-08 02:15:19 +00:00

User-functions fix (#5868)

This commit is contained in:
Clemens Eisenhofer 2022-02-26 18:21:01 +01:00 committed by GitHub
parent 689e2d41de
commit 412b05076c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 107 additions and 54 deletions

View file

@ -25,6 +25,7 @@ Notes:
#include<string>
#include<sstream>
#include<memory>
#include<vector>
#include<z3.h>
#include<limits.h>
#include<functional>
@ -542,7 +543,7 @@ namespace z3 {
~ast_vector_tpl() { Z3_ast_vector_dec_ref(ctx(), m_vector); }
operator Z3_ast_vector() const { return m_vector; }
unsigned size() const { return Z3_ast_vector_size(ctx(), m_vector); }
T operator[](int i) const { assert(0 <= i); Z3_ast r = Z3_ast_vector_get(ctx(), m_vector, i); check_error(); return cast_ast<T>()(ctx(), r); }
T operator[](unsigned i) const { Z3_ast r = Z3_ast_vector_get(ctx(), m_vector, i); check_error(); return cast_ast<T>()(ctx(), r); }
void push_back(T const & e) { Z3_ast_vector_push(ctx(), m_vector, e); check_error(); }
void resize(unsigned sz) { Z3_ast_vector_resize(ctx(), m_vector, sz); check_error(); }
T back() const { return operator[](size() - 1); }
@ -1149,6 +1150,19 @@ namespace z3 {
\pre i < num_args()
*/
expr arg(unsigned i) const { Z3_ast r = Z3_get_app_arg(ctx(), *this, i); check_error(); return expr(ctx(), r); }
/**
\brief Return a vector of all the arguments of this application.
This method assumes the expression is an application.
\pre is_app()
*/
expr_vector args() const {
expr_vector vec(ctx());
unsigned argCnt = num_args();
for (unsigned i = 0; i < argCnt; i++)
vec.push_back(arg(i));
return vec;
}
/**
\brief Return the 'body' of this quantifier.
@ -3936,7 +3950,8 @@ namespace z3 {
created_eh_t m_created_eh;
solver* s;
context* c;
std::vector<z3::context*> subcontexts;
Z3_solver_callback cb { nullptr };
struct scoped_cb {
@ -3944,8 +3959,8 @@ namespace z3 {
scoped_cb(void* _p, Z3_solver_callback cb):p(*static_cast<user_propagator_base*>(_p)) {
p.cb = cb;
}
~scoped_cb() {
p.cb = nullptr;
~scoped_cb() {
p.cb = nullptr;
}
};
@ -3958,7 +3973,9 @@ namespace z3 {
}
static void* fresh_eh(void* p, Z3_context ctx) {
return static_cast<user_propagator_base*>(p)->fresh(ctx);
context* c = new context(ctx);
static_cast<user_propagator_base*>(p)->subcontexts.push_back(c);
return static_cast<user_propagator_base*>(p)->fresh(*c);
}
static void fixed_eh(void* _p, Z3_solver_callback cb, Z3_ast _var, Z3_ast _value) {
@ -3993,60 +4010,69 @@ namespace z3 {
user_propagator_base(context& c) : s(nullptr), c(&c) {}
user_propagator_base(solver* s): s(s), c(nullptr) {
Z3_solver_propagate_init(ctx(), *s, this, push_eh, pop_eh, fresh_eh);
Z3_solver_propagate_init(ctx(), *s, this, push_eh, pop_eh, fresh_eh);
}
virtual void push() = 0;
virtual void pop(unsigned num_scopes) = 0;
virtual ~user_propagator_base() = default;
virtual ~user_propagator_base() {
for (auto& subcontext : subcontexts) {
subcontext->detach(); // detach first; the subcontexts will be freed internally!
delete subcontext;
}
}
context& ctx() {
return c ? *c : s->ctx();
return c ? *c : s->ctx();
}
/**
\brief user_propagators created using \c fresh() are created during
\brief user_propagators created using \c fresh() are created during
search and their lifetimes are restricted to search time. They should
be garbage collected by the propagator used to invoke \c fresh().
The life-time of the Z3_context object can only be assumed valid during
callbacks, such as \c fixed(), which contains expressions based on the
context.
*/
virtual user_propagator_base* fresh(Z3_context ctx) = 0;
virtual user_propagator_base* fresh(context& ctx) = 0;
/**
\brief register callbacks.
Callbacks can only be registered with user_propagators
that were created using a solver.
that were created using a solver.
*/
void register_fixed(fixed_eh_t& f) {
assert(s);
m_fixed_eh = f;
Z3_solver_propagate_fixed(ctx(), *s, fixed_eh);
void register_fixed(fixed_eh_t& f) {
m_fixed_eh = f;
if (s) {
Z3_solver_propagate_fixed(ctx(), *s, fixed_eh);
}
}
void register_fixed() {
assert(s);
m_fixed_eh = [this](expr const& id, expr const& e) {
m_fixed_eh = [this](expr const &id, expr const &e) {
fixed(id, e);
};
Z3_solver_propagate_fixed(ctx(), *s, fixed_eh);
if (s) {
Z3_solver_propagate_fixed(ctx(), *s, fixed_eh);
}
}
void register_eq(eq_eh_t& f) {
assert(s);
m_eq_eh = f;
Z3_solver_propagate_eq(ctx(), *s, eq_eh);
void register_eq(eq_eh_t& f) {
m_eq_eh = f;
if (s) {
Z3_solver_propagate_eq(ctx(), *s, eq_eh);
}
}
void register_eq() {
assert(s);
m_eq_eh = [this](expr const& x, expr const& y) {
eq(x, y);
};
Z3_solver_propagate_eq(ctx(), *s, eq_eh);
if (s) {
Z3_solver_propagate_eq(ctx(), *s, eq_eh);
}
}
/**
@ -4054,34 +4080,39 @@ namespace z3 {
During the final check stage, all propagations have been processed.
This is an opportunity for the user-propagator to delay some analysis
that could be expensive to perform incrementally. It is also an opportunity
for the propagator to implement branch and bound optimization.
for the propagator to implement branch and bound optimization.
*/
void register_final(final_eh_t& f) {
assert(s);
m_final_eh = f;
Z3_solver_propagate_final(ctx(), *s, final_eh);
void register_final(final_eh_t& f) {
m_final_eh = f;
if (s) {
Z3_solver_propagate_final(ctx(), *s, final_eh);
}
}
void register_final() {
assert(s);
void register_final() {
m_final_eh = [this]() {
final();
};
Z3_solver_propagate_final(ctx(), *s, final_eh);
if (s) {
Z3_solver_propagate_final(ctx(), *s, final_eh);
}
}
void register_created(created_eh_t& c) {
assert(s);
m_created_eh = c;
Z3_solver_propagate_created(ctx(), *s, created_eh);
if (s) {
Z3_solver_propagate_created(ctx(), *s, created_eh);
}
}
void register_created() {
m_created_eh = [this](expr const& e) {
created(e);
};
Z3_solver_propagate_created(ctx(), *s, created_eh);
if (s) {
Z3_solver_propagate_created(ctx(), *s, created_eh);
}
}
virtual void fixed(expr const& /*id*/, expr const& /*e*/) { }
@ -4095,10 +4126,10 @@ namespace z3 {
/**
\brief tracks \c e by a unique identifier that is returned by the call.
If the \c fixed() callback is registered and if \c e is a Boolean or Bit-vector,
If the \c fixed() callback is registered and if \c e is a Boolean or Bit-vector,
the \c fixed() callback gets invoked when \c e is bound to a value.
If the \c eq() callback is registered, then equalities between registered expressions
are reported.
are reported.
A consumer can use the \c propagate or \c conflict functions to invoke propagations
or conflicts as a consequence of these callbacks. These functions take a list of identifiers
for registered expressions that have been fixed. The list of identifiers must correspond to
@ -4143,9 +4174,6 @@ namespace z3 {
}
};
}
/**@}*/

View file

@ -171,7 +171,7 @@ namespace smt {
dst_ctx.setup_context(dst_ctx.m_fparams.m_auto_config);
dst_ctx.internalize_assertions();
dst_ctx.copy_user_propagator(src_ctx);
dst_ctx.copy_user_propagator(src_ctx, true);
TRACE("smt_context",
src_ctx.display(tout);
@ -193,13 +193,16 @@ namespace smt {
}
}
void context::copy_user_propagator(context& src_ctx) {
void context::copy_user_propagator(context& src_ctx, bool copy_registered) {
if (!src_ctx.m_user_propagator)
return;
ast_translation tr(src_ctx.m, m, false);
auto* p = get_theory(m.mk_family_id("user_propagator"));
m_user_propagator = reinterpret_cast<theory_user_propagator*>(p);
SASSERT(m_user_propagator);
if (!copy_registered) {
return;
}
ast_translation tr(src_ctx.m, m, false);
for (unsigned i = 0; i < src_ctx.m_user_propagator->get_num_vars(); ++i) {
app* e = src_ctx.m_user_propagator->get_expr(i);
m_user_propagator->add_expr(tr(e));
@ -211,7 +214,7 @@ namespace smt {
new_ctx->m_is_auxiliary = true;
new_ctx->set_logic(l == nullptr ? m_setup.get_logic() : *l);
copy_plugins(*this, *new_ctx);
new_ctx->copy_user_propagator(*this);
new_ctx->copy_user_propagator(*this, false);
return new_ctx;
}

View file

@ -1576,7 +1576,7 @@ namespace smt {
void log_stats();
void copy_user_propagator(context& src);
void copy_user_propagator(context& src, bool copy_registered);
public:
context(ast_manager & m, smt_params & fp, params_ref const & p = params_ref());

View file

@ -94,8 +94,14 @@ void theory_user_propagator::register_cb(expr* e) {
}
theory * theory_user_propagator::mk_fresh(context * new_ctx) {
auto* th = alloc(theory_user_propagator, *new_ctx);
void* ctx = m_fresh_eh(m_user_context, new_ctx->get_manager(), th->m_api_context);
auto* th = alloc(theory_user_propagator, *new_ctx);
void* ctx;
try {
ctx = m_fresh_eh(m_user_context, new_ctx->get_manager(), th->m_api_context);
}
catch (...) {
throw default_exception("Exception thrown in \"fresh\"-callback");
}
th->add(ctx, m_push_eh, m_pop_eh, m_fresh_eh);
if ((bool)m_fixed_eh) th->register_fixed(m_fixed_eh);
if ((bool)m_final_eh) th->register_final(m_final_eh);
@ -110,7 +116,12 @@ final_check_status theory_user_propagator::final_check_eh() {
return FC_DONE;
force_push();
unsigned sz = m_prop.size();
m_final_eh(m_user_context, this);
try {
m_final_eh(m_user_context, this);
}
catch (...) {
throw default_exception("Exception thrown in \"final\"-callback");
}
propagate();
bool done = (sz == m_prop.size()) && !ctx.inconsistent();
return done ? FC_DONE : FC_CONTINUE;
@ -125,7 +136,12 @@ void theory_user_propagator::new_fixed_eh(theory_var v, expr* value, unsigned nu
m_fixed.insert(v);
ctx.push_trail(insert_map<uint_set, unsigned>(m_fixed, v));
m_id2justification.setx(v, literal_vector(num_lits, jlits), literal_vector());
m_fixed_eh(m_user_context, this, var2expr(v), value);
try {
m_fixed_eh(m_user_context, this, var2expr(v), value);
}
catch (...) {
throw default_exception("Exception thrown in \"fixed\"-callback");
}
}
void theory_user_propagator::push_scope_eh() {
@ -228,11 +244,17 @@ bool theory_user_propagator::internalize_term(app* term) {
ctx.mk_enode(term, true, false, true);
add_expr(term);
if (!m_created_eh)
throw default_exception("You have to register a created event handler for new terms if you track them");
if (!m_created_eh && (m_fixed_eh || m_eq_eh || m_diseq_eh))
return true;
if (m_created_eh)
try {
m_created_eh(m_user_context, this, term);
}
catch (...) {
throw default_exception("Exception thrown in \"created\"-callback");
}
return true;
}

View file

@ -142,7 +142,7 @@ namespace smt {
void collect_statistics(::statistics & st) const override;
model_value_proc * mk_value(enode * n, model_generator & mg) override { return nullptr; }
void init_model(model_generator & m) override {}
bool include_func_interp(func_decl* f) override { return true; }
bool include_func_interp(func_decl* f) override { return false; }
bool can_propagate() override;
void propagate() override;
void display(std::ostream& out) const override {}