diff --git a/scripts/update_api.py b/scripts/update_api.py index 7a0f857c3..59811bf4c 100755 --- a/scripts/update_api.py +++ b/scripts/update_api.py @@ -337,6 +337,10 @@ def Z3_set_error_handler(ctx, hndlr, _elems=Elementaries(_lib.Z3_set_error_handl _elems.Check(ctx) return ceh +def Z3_solver_propagate_init(ctx, s, user_ctx, push_eh, pop_eh, fixed_eh, fresh_eh, _elems = Elementaries(_lib.Z3_solver_propagate_init)): + _elems.f(ctx, s, user_ctx, push_eh, pop_eh, fixed_eh, fresh_eh) + _elems.Check(ctx) + """) for sig in _API2PY: @@ -967,6 +971,9 @@ def def_API(name, result, params): elif ty == BOOL: log_c.write(" I(a%s);\n" % i) exe_c.write("in.get_bool(%s)" % i) + elif ty == VOID_PTR: + log_c.write(" P(0);\n") + exe_c.write("in.get_obj_addr(%s)" % i) elif ty == PRINT_MODE or ty == ERROR_CODE: log_c.write(" U(static_cast(a%s));\n" % i) exe_c.write("static_cast<%s>(in.get_uint(%s))" % (type2str(ty), i)) @@ -1817,6 +1824,15 @@ _error_handler_type = ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.c_uint) _lib.Z3_set_error_handler.restype = None _lib.Z3_set_error_handler.argtypes = [ContextObj, _error_handler_type] +push_eh_type = ctypes.CFUNCTYPE(None, ctypes.c_void_p) +pop_eh_type = ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.c_uint) +fixed_eh_type = ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.c_uint, ctypes.c_void_p) +fresh_eh_type = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.c_void_p) + +_lib.Z3_solver_propagate_init.restype = None +_lib.Z3_solver_propagate_init.argtypes = [ContextObj, SolverObj, ctypes.c_void_p, push_eh_type, pop_eh_type, fixed_eh_type, fresh_eh_type] + + """ ) diff --git a/src/api/api_solver.cpp b/src/api/api_solver.cpp index 4cd4462b9..afcaa29db 100644 --- a/src/api/api_solver.cpp +++ b/src/api/api_solver.cpp @@ -892,14 +892,16 @@ extern "C" { void* user_context, Z3_push_eh push_eh, Z3_pop_eh pop_eh, - Z3_fixed_eh fixed_eh) { + Z3_fixed_eh fixed_eh, + Z3_fresh_eh fresh_eh) { Z3_TRY; RESET_ERROR_CODE(); init_solver(c, s); std::function _push = push_eh; std::function _pop = pop_eh; - std::function _fixed = [&](void* ctx, unsigned id, expr* e) { fixed_eh(ctx, id, of_ast(e)); }; - to_solver_ref(s)->user_propagate_init(user_context, _fixed, _push, _pop); + std::function _fixed = (void(*)(void*,unsigned,expr*))fixed_eh; + std::function _fresh = fresh_eh; + to_solver_ref(s)->user_propagate_init(user_context, _fixed, _push, _pop, _fresh); Z3_CATCH; } diff --git a/src/api/python/z3/z3.py b/src/api/python/z3/z3.py index 04c954a41..e7f719fe9 100644 --- a/src/api/python/z3/z3.py +++ b/src/api/python/z3/z3.py @@ -10505,17 +10505,44 @@ def TransitiveClosure(f): return FuncDeclRef(Z3_mk_transitive_closure(f.ctx_ref(), f.ast), f.ctx) -""" +_user_propagate_bases = {} + +def user_prop_push(ctx): + _user_propagate_bases[ctx].push(); + +def user_prop_pop(ctx, num_scopes): + _user_propagate_bases[ctx].pop(num_scopes) + +def user_prop_fixed(ctx, id, value): + prop = _user_propagate_bases[ctx] + prop.fixed(id, _to_expr_ref(ctypes.c_void_p(value), prop.ctx)) + +def user_prop_fresh(ctx): + prop = _user_propagate_bases[ctx] + new_prop = prop.copy() + return ctypes.c_void_p(new_prop.id) + + +_user_prop_push = push_eh_type(user_prop_push) +_user_prop_pop = pop_eh_type(user_prop_pop) +_user_prop_fixed = fixed_eh_type(user_prop_fixed) +_user_prop_fresh = fresh_eh_type(user_prop_fresh) + class UserPropagateBase: def __init__(self, s): + self.id = len(_user_propagate_bases) + 3 self.solver = s self.ctx = s.ctx - Z3_user_propagate_init(self, - ctypes.CFUNCTYPE(None, ctypes.c_void_p)(_user_prop_push), - ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.c_uint())(_user_prop_pop), - ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.c_uint(), ctypes.c_void_p)(_user_prop_fixed)) - + _user_propagate_bases[self.id] = self + Z3_solver_propagate_init(s.ctx.ref(), + s.solver, + ctypes.c_void_p(self.id), + _user_prop_push, + _user_prop_pop, + _user_prop_fixed, + _user_prop_fresh) + def push(self): raise Z3Exception("push has not been overwritten") @@ -10526,23 +10553,8 @@ class UserPropagateBase: raise Z3Exception("fixed has not been overwritten") def add(self, e): - return Z3_user_propagate_register(self.ctx.ref(), s.solver, e.ast) + return Z3_solver_propagate_register(self.ctx.ref(), self.solver.solver, e.ast) def propagate(self, ids, e): - Z3_user_propagate_consequence(self.ctx.ref(), s.solver, ids, e.ast) - -def _user_prop_push(ctx): - user_prop = ctx # need to access as python object. - user_prop.push() - -def _user_prop_pop(ctx, num_scopes): - user_prop = ctx # need to access as python object - user_prop.pop(num_scopes) - -def _user_prop_fixed(ctx, id, value): - user_prop = ctx # need to access as python object - user_prop.fixed(id, _to_expr_ref(value, user_prop.ctx)) - - -""" + Z3_solver_propagate_consequence(self.ctx.ref(), self.solver.solver, ids, e.ast) diff --git a/src/api/python/z3/z3types.py b/src/api/python/z3/z3types.py index 7cf61f49e..d52a7914e 100644 --- a/src/api/python/z3/z3types.py +++ b/src/api/python/z3/z3types.py @@ -121,3 +121,6 @@ class FuncEntryObj(ctypes.c_void_p): class RCFNumObj(ctypes.c_void_p): def __init__(self, e): self._as_parameter_ = e def from_param(obj): return obj + + + diff --git a/src/api/z3_api.h b/src/api/z3_api.h index b012ef66d..127bb90b1 100644 --- a/src/api/z3_api.h +++ b/src/api/z3_api.h @@ -1412,6 +1412,15 @@ typedef enum */ typedef void Z3_error_handler(Z3_context c, Z3_error_code e); + +/** + \brief callback functions for user propagator. +*/ +typedef void Z3_push_eh(void* ctx); +typedef void Z3_pop_eh(void* ctx, unsigned num_scopes); +typedef void Z3_fixed_eh(void* ctx, unsigned id, Z3_ast value); +typedef void* Z3_fresh_eh(void* ctx); + /** \brief A Goal is essentially a set of formulas. Z3 provide APIs for building strategies/tactics for solving and transforming Goals. @@ -6515,13 +6524,10 @@ extern "C" { Z3_ast Z3_API Z3_solver_get_implied_upper(Z3_context c, Z3_solver s, Z3_ast e); + /** \brief register a user-properator with the solver. - */ - - typedef void Z3_push_eh(void* ctx); - typedef void Z3_pop_eh(void* ctx, unsigned num_scopes); - typedef void Z3_fixed_eh(void* ctx, unsigned id, Z3_ast value); + */ void Z3_API Z3_solver_propagate_init( Z3_context c, @@ -6529,7 +6535,8 @@ extern "C" { void* user_context, Z3_push_eh push_eh, Z3_pop_eh pop_eh, - Z3_fixed_eh fixed_eh); + Z3_fixed_eh fixed_eh, + Z3_fresh_eh fresh_eh); /** \brief register an expression to propagate on with the solver. diff --git a/src/smt/smt_context.cpp b/src/smt/smt_context.cpp index 49c645028..bb935a61a 100644 --- a/src/smt/smt_context.cpp +++ b/src/smt/smt_context.cpp @@ -54,6 +54,7 @@ namespace smt { m_qmanager(alloc(quantifier_manager, *this, p, _p)), m_model_generator(alloc(model_generator, m)), m_relevancy_propagator(mk_relevancy_propagator(*this)), + m_user_propagator(nullptr), m_random(p.m_random_seed), m_flushing(false), m_lemma_id(0), @@ -514,6 +515,7 @@ namespace smt { m_qmanager->add_eq_eh(r1, r2); + merge_theory_vars(n2, n1, js); // 'Proof' tree @@ -528,7 +530,6 @@ namespace smt { // --------------- // r1 -> .. -> n1 -> n2 -> ... -> r2 - remove_parents_from_cg_table(r1); enode * curr = r1; @@ -1304,7 +1305,7 @@ namespace smt { bool_var v = l.var(); bool_var_data & d = get_bdata(v); lbool val = get_assignment(v); - CTRACE("propagate_atoms", v == 13, tout << "propagating atom, #" << bool_var2expr(v)->get_id() << ", is_enode(): " << d.is_enode() + TRACE("propagate_atoms", tout << "propagating atom, #" << bool_var2expr(v)->get_id() << ", is_enode(): " << d.is_enode() << " tag: " << (d.is_eq()?"eq":"") << (d.is_theory_atom()?"th":"") << (d.is_quantifier()?"q":"") << " " << l << "\n";); SASSERT(val != l_undef); if (d.is_enode()) @@ -1926,8 +1927,6 @@ namespace smt { for (theory* t : m_theory_set) t->push_scope_eh(); - if (m_user_propagator) - m_user_propagator->push_scope_eh(); CASSERT("context", check_invariant()); } @@ -2425,9 +2424,6 @@ namespace smt { for (theory* th : m_theory_set) th->pop_scope_eh(num_scopes); - if (m_user_propagator) - m_user_propagator->pop_scope_eh(num_scopes); - del_justifications(m_justifications, s.m_justifications_lim); m_asserted_formulas.pop_scope(num_scopes); @@ -2882,11 +2878,13 @@ namespace smt { void* ctx, std::function& fixed_eh, std::function& push_eh, - std::function& pop_eh) { + std::function& pop_eh, + std::function& fresh_eh) { m_user_propagator = alloc(user_propagator, *this); - m_user_propagator->add(ctx, fixed_eh, push_eh, pop_eh); + m_user_propagator->add(ctx, fixed_eh, push_eh, pop_eh, fresh_eh); for (unsigned i = m_scopes.size(); i-- > 0; ) m_user_propagator->push_scope_eh(); + register_plugin(m_user_propagator); } bool context::watches_fixed(enode* n) const { diff --git a/src/smt/smt_context.h b/src/smt/smt_context.h index c62341de3..6addf1479 100644 --- a/src/smt/smt_context.h +++ b/src/smt/smt_context.h @@ -93,7 +93,7 @@ namespace smt { scoped_ptr m_qmanager; scoped_ptr m_model_generator; scoped_ptr m_relevancy_propagator; - scoped_ptr m_user_propagator; + user_propagator* m_user_propagator; random_gen m_random; bool m_flushing; // (debug support) true when flushing mutable unsigned m_lemma_id; @@ -1686,7 +1686,8 @@ namespace smt { void* ctx, std::function& fixed_eh, std::function& push_eh, - std::function& pop_eh); + std::function& pop_eh, + std::function& fresh_eh); unsigned user_propagate_register(expr* e) { if (!m_user_propagator) diff --git a/src/smt/smt_kernel.cpp b/src/smt/smt_kernel.cpp index 40fe2cfd3..ca8ea5a37 100644 --- a/src/smt/smt_kernel.cpp +++ b/src/smt/smt_kernel.cpp @@ -237,8 +237,9 @@ namespace smt { void* ctx, std::function& fixed_eh, std::function& push_eh, - std::function& pop_eh) { - m_kernel.user_propagate_init(ctx, fixed_eh, push_eh, pop_eh); + std::function& pop_eh, + std::function& fresh_eh) { + m_kernel.user_propagate_init(ctx, fixed_eh, push_eh, pop_eh, fresh_eh); } unsigned user_propagate_register(expr* e) { @@ -464,8 +465,9 @@ namespace smt { void* ctx, std::function& fixed_eh, std::function& push_eh, - std::function& pop_eh) { - m_imp->user_propagate_init(ctx, fixed_eh, push_eh, pop_eh); + std::function& pop_eh, + std::function& fresh_eh) { + m_imp->user_propagate_init(ctx, fixed_eh, push_eh, pop_eh, fresh_eh); } unsigned kernel::user_propagate_register(expr* e) { diff --git a/src/smt/smt_kernel.h b/src/smt/smt_kernel.h index 7fdd14fc1..126e4c20f 100644 --- a/src/smt/smt_kernel.h +++ b/src/smt/smt_kernel.h @@ -291,7 +291,8 @@ namespace smt { void* ctx, std::function& fixed_eh, std::function& push_eh, - std::function& pop_eh); + std::function& pop_eh, + std::function& fresh_eh); /** \brief register an expression to be tracked fro user propagation. diff --git a/src/smt/smt_solver.cpp b/src/smt/smt_solver.cpp index 9868840e6..52cce7437 100644 --- a/src/smt/smt_solver.cpp +++ b/src/smt/smt_solver.cpp @@ -212,8 +212,9 @@ namespace { void* ctx, std::function& fixed_eh, std::function& push_eh, - std::function& pop_eh) override { - m_context.user_propagate_init(ctx, fixed_eh, push_eh, pop_eh); + std::function& pop_eh, + std::function& fresh_eh) override { + m_context.user_propagate_init(ctx, fixed_eh, push_eh, pop_eh, fresh_eh); } unsigned user_propagate_register(expr* e) override { diff --git a/src/smt/user_propagator.cpp b/src/smt/user_propagator.cpp index 7bcc551a3..59691db63 100644 --- a/src/smt/user_propagator.cpp +++ b/src/smt/user_propagator.cpp @@ -16,6 +16,7 @@ Author: --*/ +#include "ast/ast_pp.h" #include "smt/user_propagator.h" #include "smt/smt_context.h" @@ -23,26 +24,47 @@ using namespace smt; user_propagator::user_propagator(context& ctx): theory(ctx, ctx.get_manager().mk_family_id("user_propagator")), - m_qhead(0) + m_qhead(0), + m_num_scopes(0) {} +void user_propagator::force_push() { + for (; m_num_scopes > 0; --m_num_scopes) { + theory::push_scope_eh(); + m_push_eh(m_user_context); + m_prop_lim.push_back(m_prop.size()); + } +} + +// TODO: check type of 'e', either Bool or Bit-vector. +// + unsigned user_propagator::add_expr(expr* e) { - // TODO: check type of 'e', either Bool or Bit-vector. - return mk_var(ensure_enode(e)); + force_push(); + enode* n = ensure_enode(e); + if (is_attached_to_var(n)) + return n->get_th_var(get_id()); + theory_var v = mk_var(n); + ctx.attach_th_var(n, this, v); + return v; } void user_propagator::new_fixed_eh(theory_var v, expr* value, unsigned num_lits, literal const* jlits) { + force_push(); m_id2justification.setx(v, literal_vector(num_lits, jlits), literal_vector()); m_fixed_eh(m_user_context, v, value); } void user_propagator::push_scope_eh() { - theory::push_scope_eh(); - m_push_eh(m_user_context); - m_prop_lim.push_back(m_prop.size()); + ++m_num_scopes; } void user_propagator::pop_scope_eh(unsigned num_scopes) { + unsigned n = std::min(num_scopes, m_num_scopes); + m_num_scopes -= n; + num_scopes -= n; + if (num_scopes == 0) + return; m_pop_eh(m_user_context, num_scopes); theory::pop_scope_eh(num_scopes); unsigned old_sz = m_prop_lim.size() - num_scopes; @@ -55,6 +77,7 @@ bool user_propagator::can_propagate() { } void user_propagator::propagate() { + force_push(); unsigned qhead = m_qhead; literal_vector lits; enode_pair_vector eqs; diff --git a/src/smt/user_propagator.h b/src/smt/user_propagator.h index 23f8350bc..1320cca7d 100644 --- a/src/smt/user_propagator.h +++ b/src/smt/user_propagator.h @@ -32,6 +32,7 @@ namespace smt { std::function m_fixed_eh; std::function m_push_eh; std::function m_pop_eh; + std::function m_fresh_eh; struct prop_info { unsigned_vector m_ids; expr_ref m_conseq; @@ -44,6 +45,9 @@ namespace smt { vector m_prop; unsigned_vector m_prop_lim; vector m_id2justification; + unsigned m_num_scopes; + + void force_push(); public: user_propagator(context& ctx); @@ -57,11 +61,13 @@ namespace smt { void* ctx, std::function& fixed_eh, std::function& push_eh, - std::function& pop_eh) { + std::function& pop_eh, + std::function& fresh_eh) { m_user_context = ctx; m_fixed_eh = fixed_eh; m_push_eh = push_eh; m_pop_eh = pop_eh; + m_fresh_eh = fresh_eh; } unsigned add_expr(expr* e); @@ -72,16 +78,21 @@ namespace smt { void new_fixed_eh(theory_var v, expr* value, unsigned num_lits, literal const* jlits); - theory * mk_fresh(context * new_ctx) override { UNREACHABLE(); return alloc(user_propagator, *new_ctx); } + theory * mk_fresh(context * new_ctx) override { + auto* th = alloc(user_propagator, *new_ctx); + void* ctx = m_fresh_eh(m_user_context); + th->add(ctx, m_fixed_eh, m_push_eh, m_pop_eh, m_fresh_eh); + return th; + } bool internalize_atom(app * atom, bool gate_ctx) override { UNREACHABLE(); return false; } bool internalize_term(app * term) override { UNREACHABLE(); return false; } - void new_eq_eh(theory_var v1, theory_var v2) override { UNREACHABLE(); } - void new_diseq_eh(theory_var v1, theory_var v2) override { UNREACHABLE(); } + void new_eq_eh(theory_var v1, theory_var v2) override { } + void new_diseq_eh(theory_var v1, theory_var v2) override { } bool use_diseqs() const override { return false; } bool build_models() const override { return false; } - final_check_status final_check_eh() override { UNREACHABLE(); return FC_DONE; } + final_check_status final_check_eh() override { return FC_DONE; } void reset_eh() override {} - void assign_eh(bool_var v, bool is_true) override { UNREACHABLE(); } + void assign_eh(bool_var v, bool is_true) override { } void init_search_eh() override {} void push_scope_eh() override; void pop_scope_eh(unsigned num_scopes) override; diff --git a/src/solver/solver.h b/src/solver/solver.h index 659bf249b..14379e2f6 100644 --- a/src/solver/solver.h +++ b/src/solver/solver.h @@ -242,7 +242,8 @@ public: void* ctx, std::function& fixed_eh, std::function& push_eh, - std::function& pop_eh) { + std::function& pop_eh, + std::function& fresh_eh) { throw default_exception("user-propagators are only supported on the SMT solver"); }