3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-24 01:25:31 +00:00

Propagator (#5845)

* user propagator without ids

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* user propagator without ids

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* fix signature

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* references #5818

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* fix c++ build

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* switch to vs 2022

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* switch 2022

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* Update propagator example (I) (#5835)

* fix #5829

* na

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* switch to vs 2022

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* Adapted the example to the changes in the propagator

Co-authored-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* context goes out of scope in stack allocation, so can't used scoped context when passing objects around

* parameter check

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* add rewriter

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* Fixed bug in user-propagator "created" (#5843)

Co-authored-by: Clemens Eisenhofer <56730610+CEisenhofer@users.noreply.github.com>
This commit is contained in:
Nikolaj Bjorner 2022-02-17 09:21:41 +02:00 committed by GitHub
parent 2e15e2aa4d
commit 2e00f2f32d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
22 changed files with 261 additions and 328 deletions

View file

@ -902,7 +902,7 @@ extern "C" {
Z3_fixed_eh fixed_eh) {
Z3_TRY;
RESET_ERROR_CODE();
user_propagator::fixed_eh_t _fixed = (void(*)(void*,user_propagator::callback*,unsigned,expr*))fixed_eh;
user_propagator::fixed_eh_t _fixed = (void(*)(void*,user_propagator::callback*,expr*,expr*))fixed_eh;
to_solver_ref(s)->user_propagate_register_fixed(_fixed);
Z3_CATCH;
}
@ -924,7 +924,7 @@ extern "C" {
Z3_eq_eh eq_eh) {
Z3_TRY;
RESET_ERROR_CODE();
user_propagator::eq_eh_t _eq = (void(*)(void*,user_propagator::callback*,unsigned,unsigned))eq_eh;
user_propagator::eq_eh_t _eq = (void(*)(void*,user_propagator::callback*,expr*,expr*))eq_eh;
to_solver_ref(s)->user_propagate_register_eq(_eq);
Z3_CATCH;
}
@ -935,39 +935,42 @@ extern "C" {
Z3_eq_eh diseq_eh) {
Z3_TRY;
RESET_ERROR_CODE();
user_propagator::eq_eh_t _diseq = (void(*)(void*,user_propagator::callback*,unsigned,unsigned))diseq_eh;
user_propagator::eq_eh_t _diseq = (void(*)(void*,user_propagator::callback*,expr*,expr*))diseq_eh;
to_solver_ref(s)->user_propagate_register_diseq(_diseq);
Z3_CATCH;
}
unsigned Z3_API Z3_solver_propagate_register(Z3_context c, Z3_solver s, Z3_ast e) {
void Z3_API Z3_solver_propagate_register(Z3_context c, Z3_solver s, Z3_ast e) {
Z3_TRY;
LOG_Z3_solver_propagate_register(c, s, e);
RESET_ERROR_CODE();
return to_solver_ref(s)->user_propagate_register_expr(to_expr(e));
Z3_CATCH_RETURN(0);
to_solver_ref(s)->user_propagate_register_expr(to_expr(e));
Z3_CATCH;
}
unsigned Z3_API Z3_solver_propagate_register_cb(Z3_context c, Z3_solver_callback s, Z3_ast e) {
void Z3_API Z3_solver_propagate_register_cb(Z3_context c, Z3_solver_callback s, Z3_ast e) {
Z3_TRY;
LOG_Z3_solver_propagate_register_cb(c, s, e);
RESET_ERROR_CODE();
return reinterpret_cast<user_propagator::callback*>(s)->register_cb(to_expr(e));
Z3_CATCH_RETURN(0);
reinterpret_cast<user_propagator::callback*>(s)->register_cb(to_expr(e));
Z3_CATCH;
}
void Z3_API Z3_solver_propagate_consequence(Z3_context c, Z3_solver_callback s, unsigned num_fixed, unsigned const* fixed_ids, unsigned num_eqs, unsigned const* eq_lhs, unsigned const* eq_rhs, Z3_ast conseq) {
void Z3_API Z3_solver_propagate_consequence(Z3_context c, Z3_solver_callback s, unsigned num_fixed, Z3_ast const* fixed_ids, unsigned num_eqs, Z3_ast const* eq_lhs, Z3_ast const* eq_rhs, Z3_ast conseq) {
Z3_TRY;
LOG_Z3_solver_propagate_consequence(c, s, num_fixed, fixed_ids, num_eqs, eq_lhs, eq_rhs, conseq);
RESET_ERROR_CODE();
reinterpret_cast<user_propagator::callback*>(s)->propagate_cb(num_fixed, fixed_ids, num_eqs, eq_lhs, eq_rhs, to_expr(conseq));
expr* const * _fixed_ids = (expr* const*) fixed_ids;
expr* const * _eq_lhs = (expr*const*) eq_lhs;
expr* const * _eq_rhs = (expr*const*) eq_rhs;
reinterpret_cast<user_propagator::callback*>(s)->propagate_cb(num_fixed, _fixed_ids, num_eqs, _eq_lhs, _eq_rhs, to_expr(conseq));
Z3_CATCH;
}
void Z3_API Z3_solver_propagate_created(Z3_context c, Z3_solver s, Z3_created_eh created_eh) {
Z3_TRY;
RESET_ERROR_CODE();
user_propagator::created_eh_t c = (void(*)(void*, user_propagator::callback*, expr*, unsigned))created_eh;
user_propagator::created_eh_t c = (void(*)(void*, user_propagator::callback*, expr*))created_eh;
to_solver_ref(s)->user_propagate_register_created(c);
Z3_CATCH;
}

View file

@ -155,9 +155,10 @@ namespace z3 {
class context {
private:
friend class user_propagator_base;
bool m_enable_exceptions;
rounding_mode m_rounding_mode;
Z3_context m_ctx;
Z3_context m_ctx = nullptr;
void init(config & c) {
set_context(Z3_mk_context_rc(c));
}
@ -173,7 +174,6 @@ namespace z3 {
context(context const &) = delete;
context & operator=(context const &) = delete;
friend class scoped_context;
context(Z3_context c) { set_context(c); }
void detach() { m_ctx = nullptr; }
public:
@ -394,14 +394,6 @@ namespace z3 {
expr_vector parse_file(char const* s, sort_vector const& sorts, func_decl_vector const& decls);
};
class scoped_context final {
context m_ctx;
public:
scoped_context(Z3_context c): m_ctx(c) {}
~scoped_context() { m_ctx.detach(); }
context& operator()() { return m_ctx; }
};
template<typename T>
class array {
@ -509,7 +501,7 @@ namespace z3 {
ast(context & c):object(c), m_ast(0) {}
ast(context & c, Z3_ast n):object(c), m_ast(n) { Z3_inc_ref(ctx(), m_ast); }
ast(ast const & s) :object(s), m_ast(s.m_ast) { Z3_inc_ref(ctx(), m_ast); }
~ast() { if (m_ast) Z3_dec_ref(*m_ctx, m_ast); }
~ast() { if (m_ast) { Z3_dec_ref(*m_ctx, m_ast); } }
operator Z3_ast() const { return m_ast; }
operator bool() const { return m_ast != 0; }
ast & operator=(ast const & s) {
@ -3933,23 +3925,20 @@ namespace z3 {
class user_propagator_base {
typedef std::function<void(unsigned, expr const&)> fixed_eh_t;
typedef std::function<void(expr const&, expr const&)> fixed_eh_t;
typedef std::function<void(void)> final_eh_t;
typedef std::function<void(unsigned, unsigned)> eq_eh_t;
typedef std::function<void(unsigned, expr const&)> created_eh_t;
typedef std::function<void(expr const&, expr const&)> eq_eh_t;
typedef std::function<void(expr const&)> created_eh_t;
final_eh_t m_final_eh;
eq_eh_t m_eq_eh;
fixed_eh_t m_fixed_eh;
created_eh_t m_created_eh;
solver* s;
Z3_context c;
context* c;
Z3_solver_callback cb { nullptr };
Z3_context ctx() {
return c ? c : (Z3_context)s->ctx();
}
struct scoped_cb {
user_propagator_base& p;
scoped_cb(void* _p, Z3_solver_callback cb):p(*static_cast<user_propagator_base*>(_p)) {
@ -3972,17 +3961,19 @@ namespace z3 {
return static_cast<user_propagator_base*>(p)->fresh(ctx);
}
static void fixed_eh(void* _p, Z3_solver_callback cb, unsigned id, Z3_ast _value) {
static void fixed_eh(void* _p, Z3_solver_callback cb, Z3_ast _var, Z3_ast _value) {
user_propagator_base* p = static_cast<user_propagator_base*>(_p);
scoped_cb _cb(p, cb);
scoped_context ctx(p->ctx());
expr value(ctx(), _value);
static_cast<user_propagator_base*>(p)->m_fixed_eh(id, value);
expr value(p->ctx(), _value);
expr var(p->ctx(), _var);
p->m_fixed_eh(var, value);
}
static void eq_eh(void* p, Z3_solver_callback cb, unsigned x, unsigned y) {
static void eq_eh(void* _p, Z3_solver_callback cb, Z3_ast _x, Z3_ast _y) {
user_propagator_base* p = static_cast<user_propagator_base*>(_p);
scoped_cb _cb(p, cb);
static_cast<user_propagator_base*>(p)->m_eq_eh(x, y);
expr x(p->ctx(), _x), y(p->ctx(), _y);
p->m_eq_eh(x, y);
}
static void final_eh(void* p, Z3_solver_callback cb) {
@ -3990,17 +3981,16 @@ namespace z3 {
static_cast<user_propagator_base*>(p)->m_final_eh();
}
static void created_eh(void* _p, Z3_solver_callback cb, Z3_ast _e, unsigned id) {
static void created_eh(void* _p, Z3_solver_callback cb, Z3_ast _e) {
user_propagator_base* p = static_cast<user_propagator_base*>(_p);
scoped_cb _cb(p, cb);
scoped_context ctx(p->ctx());
expr e(ctx(), _e);
static_cast<user_propagator_base*>(p)->m_created_eh(id, e);
expr e(p->ctx(), _e);
p->m_created_eh(e);
}
public:
user_propagator_base(Z3_context c) : s(nullptr), c(c) {}
user_propagator_base(context& c) : s(nullptr), c(&c) {}
user_propagator_base(solver* s): s(s), c(nullptr) {
Z3_solver_propagate_init(ctx(), *s, this, push_eh, pop_eh, fresh_eh);
@ -4011,6 +4001,10 @@ namespace z3 {
virtual ~user_propagator_base() = default;
context& ctx() {
return c ? *c : s->ctx();
}
/**
\brief user_propagators created using \c fresh() are created during
search and their lifetimes are restricted to search time. They should
@ -4035,7 +4029,7 @@ namespace z3 {
void register_fixed() {
assert(s);
m_fixed_eh = [this](unsigned id, expr const& e) {
m_fixed_eh = [this](expr const& id, expr const& e) {
fixed(id, e);
};
Z3_solver_propagate_fixed(ctx(), *s, fixed_eh);
@ -4049,7 +4043,7 @@ namespace z3 {
void register_eq() {
assert(s);
m_eq_eh = [this](unsigned x, unsigned y) {
m_eq_eh = [this](expr const& x, expr const& y) {
eq(x, y);
};
Z3_solver_propagate_eq(ctx(), *s, eq_eh);
@ -4084,19 +4078,19 @@ namespace z3 {
}
void register_created() {
m_created_eh = [this](unsigned id, expr const& e) {
created(id, e);
m_created_eh = [this](expr const& e) {
created(e);
};
Z3_solver_propagate_created(ctx(), *s, created_eh);
}
virtual void fixed(unsigned /*id*/, expr const& /*e*/) { }
virtual void fixed(expr const& /*id*/, expr const& /*e*/) { }
virtual void eq(unsigned /*x*/, unsigned /*y*/) { }
virtual void eq(expr const& /*x*/, expr const& /*y*/) { }
virtual void final() { }
virtual void created(unsigned /*id*/, expr const& /*e*/) {}
virtual void created(expr const& /*e*/) {}
/**
\brief tracks \c e by a unique identifier that is returned by the call.
@ -4112,34 +4106,40 @@ namespace z3 {
correspond to equalities that have been registered during a callback.
*/
unsigned add(expr const& e) {
void add(expr const& e) {
if (cb)
return Z3_solver_propagate_register_cb(ctx(), cb, e);
if (s)
return Z3_solver_propagate_register(ctx(), *s, e);
assert(false);
return 0;
Z3_solver_propagate_register_cb(ctx(), cb, e);
else if (s)
Z3_solver_propagate_register(ctx(), *s, e);
else
assert(false);
}
void conflict(unsigned num_fixed, unsigned const* fixed) {
void conflict(expr_vector const& fixed) {
assert(cb);
scoped_context _ctx(ctx());
expr conseq = _ctx().bool_val(false);
Z3_solver_propagate_consequence(ctx(), cb, num_fixed, fixed, 0, nullptr, nullptr, conseq);
expr conseq = ctx().bool_val(false);
array<Z3_ast> _fixed(fixed);
Z3_solver_propagate_consequence(ctx(), cb, fixed.size(), _fixed.ptr(), 0, nullptr, nullptr, conseq);
}
void propagate(unsigned num_fixed, unsigned const* fixed, expr const& conseq) {
void propagate(expr_vector const& fixed, expr const& conseq) {
assert(cb);
assert(conseq.ctx() == ctx());
Z3_solver_propagate_consequence(ctx(), cb, num_fixed, fixed, 0, nullptr, nullptr, conseq);
assert((Z3_context)conseq.ctx() == (Z3_context)ctx());
array<Z3_ast> _fixed(fixed);
Z3_solver_propagate_consequence(ctx(), cb, _fixed.size(), _fixed.ptr(), 0, nullptr, nullptr, conseq);
}
void propagate(unsigned num_fixed, unsigned const* fixed,
unsigned num_eqs, unsigned const* lhs, unsigned const * rhs,
void propagate(expr_vector const& fixed,
expr_vector const& lhs, expr_vector const& rhs,
expr const& conseq) {
assert(cb);
assert(conseq.ctx() == ctx());
Z3_solver_propagate_consequence(ctx(), cb, num_fixed, fixed, num_eqs, lhs, rhs, conseq);
assert((Z3_context)conseq.ctx() == (Z3_context)ctx());
assert(lhs.size() == rhs.size());
array<Z3_ast> _fixed(fixed);
array<Z3_ast> _lhs(lhs);
array<Z3_ast> _rhs(rhs);
Z3_solver_propagate_consequence(ctx(), cb, _fixed.size(), _fixed.ptr(), lhs.size(), _lhs.ptr(), _rhs.ptr(), conseq);
}
};

View file

@ -11261,7 +11261,7 @@ def user_prop_fresh(id, ctx):
def user_prop_fixed(ctx, cb, id, value):
prop = _prop_closures.get(ctx)
prop.cb = cb
prop.fixed(id, _to_expr_ref(ctypes.c_void_p(value), prop.ctx()))
prop.fixed(_to_expr_ref(ctypes.c_void_p(id), prop.ctx()), _to_expr_ref(ctypes.c_void_p(value), prop.ctx()))
prop.cb = None
@ -11275,6 +11275,8 @@ def user_prop_final(ctx, cb):
def user_prop_eq(ctx, cb, x, y):
prop = _prop_closures.get(ctx)
prop.cb = cb
x = _to_expr_ref(ctypes.c_void_p(x), prop.ctx())
y = _to_expr_ref(ctypes.c_void_p(y), prop.ctx())
prop.eq(x, y)
prop.cb = None
@ -11282,6 +11284,8 @@ def user_prop_eq(ctx, cb, x, y):
def user_prop_diseq(ctx, cb, x, y):
prop = _prop_closures.get(ctx)
prop.cb = cb
x = _to_expr_ref(ctypes.c_void_p(x), prop.ctx())
y = _to_expr_ref(ctypes.c_void_p(y), prop.ctx())
prop.diseq(x, y)
prop.cb = None
@ -11385,18 +11389,12 @@ class UserPropagateBase:
# Propagation can only be invoked as during a fixed or final callback.
#
def propagate(self, e, ids, eqs=[]):
num_fixed = len(ids)
_ids = (ctypes.c_uint * num_fixed)()
for i in range(num_fixed):
_ids[i] = ids[i]
_ids, num_fixed = _to_ast_array(ids)
num_eqs = len(eqs)
_lhs = (ctypes.c_uint * num_eqs)()
_rhs = (ctypes.c_uint * num_eqs)()
for i in range(num_eqs):
_lhs[i] = eqs[i][0]
_rhs[i] = eqs[i][1]
_lhs, _num_lhs = _to_ast_array([x for x, y in eqs])
_rhs, _num_lhs = _to_ast_array([y for x, y in eqs])
Z3_solver_propagate_consequence(e.ctx.ref(), ctypes.c_void_p(
self.cb), num_fixed, _ids, num_eqs, _lhs, _rhs, e.ast)
def conflict(self, ids):
self.propagate(BoolVal(False, self.ctx()), ids, eqs=[])
def conflict(self, deps):
self.propagate(BoolVal(False, self.ctx()), deps, eqs=[])

View file

@ -1433,10 +1433,10 @@ Z3_DECLARE_CLOSURE(Z3_error_handler, void, (Z3_context c, Z3_error_code e));
Z3_DECLARE_CLOSURE(Z3_push_eh, void, (void* ctx));
Z3_DECLARE_CLOSURE(Z3_pop_eh, void, (void* ctx, unsigned num_scopes));
Z3_DECLARE_CLOSURE(Z3_fresh_eh, void*, (void* ctx, Z3_context new_context));
Z3_DECLARE_CLOSURE(Z3_fixed_eh, void, (void* ctx, Z3_solver_callback cb, unsigned id, Z3_ast value));
Z3_DECLARE_CLOSURE(Z3_eq_eh, void, (void* ctx, Z3_solver_callback cb, unsigned x, unsigned y));
Z3_DECLARE_CLOSURE(Z3_fixed_eh, void, (void* ctx, Z3_solver_callback cb, Z3_ast t, Z3_ast value));
Z3_DECLARE_CLOSURE(Z3_eq_eh, void, (void* ctx, Z3_solver_callback cb, Z3_ast s, Z3_ast t));
Z3_DECLARE_CLOSURE(Z3_final_eh, void, (void* ctx, Z3_solver_callback cb));
Z3_DECLARE_CLOSURE(Z3_created_eh, void, (void* ctx, Z3_solver_callback cb, Z3_ast e, unsigned id));
Z3_DECLARE_CLOSURE(Z3_created_eh, void, (void* ctx, Z3_solver_callback cb, Z3_ast t));
/**
@ -6761,10 +6761,10 @@ extern "C" {
\brief register an expression to propagate on with the solver.
Only expressions of type Bool and type Bit-Vector can be registered for propagation.
def_API('Z3_solver_propagate_register', UINT, (_in(CONTEXT), _in(SOLVER), _in(AST)))
def_API('Z3_solver_propagate_register', VOID, (_in(CONTEXT), _in(SOLVER), _in(AST)))
*/
unsigned Z3_API Z3_solver_propagate_register(Z3_context c, Z3_solver s, Z3_ast e);
void Z3_API Z3_solver_propagate_register(Z3_context c, Z3_solver s, Z3_ast e);
/**
\brief register an expression to propagate on with the solver.
@ -6772,9 +6772,9 @@ extern "C" {
Unlike \ref Z3_solver_propagate_register, this function takes a solver callback context
as argument. It can be invoked during a callback to register new expressions.
def_API('Z3_solver_propagate_register_cb', UINT, (_in(CONTEXT), _in(SOLVER_CALLBACK), _in(AST)))
def_API('Z3_solver_propagate_register_cb', VOID, (_in(CONTEXT), _in(SOLVER_CALLBACK), _in(AST)))
*/
unsigned Z3_API Z3_solver_propagate_register_cb(Z3_context c, Z3_solver_callback cb, Z3_ast e);
void Z3_API Z3_solver_propagate_register_cb(Z3_context c, Z3_solver_callback cb, Z3_ast e);
/**
\brief propagate a consequence based on fixed values.
@ -6782,10 +6782,10 @@ extern "C" {
The callback adds a propagation consequence based on the fixed values of the
\c ids.
def_API('Z3_solver_propagate_consequence', VOID, (_in(CONTEXT), _in(SOLVER_CALLBACK), _in(UINT), _in_array(2, UINT), _in(UINT), _in_array(4, UINT), _in_array(4, UINT), _in(AST)))
def_API('Z3_solver_propagate_consequence', VOID, (_in(CONTEXT), _in(SOLVER_CALLBACK), _in(UINT), _in_array(2, AST), _in(UINT), _in_array(4, AST), _in_array(4, AST), _in(AST)))
*/
void Z3_API Z3_solver_propagate_consequence(Z3_context c, Z3_solver_callback, unsigned num_fixed, unsigned const* fixed_ids, unsigned num_eqs, unsigned const* eq_lhs, unsigned const* eq_rhs, Z3_ast conseq);
void Z3_API Z3_solver_propagate_consequence(Z3_context c, Z3_solver_callback, unsigned num_fixed, Z3_ast const* fixed, unsigned num_eqs, Z3_ast const* eq_lhs, Z3_ast const* eq_rhs, Z3_ast conseq);
/**
\brief Check whether the assertions in a given solver are consistent or not.