diff --git a/src/api/api_solver.cpp b/src/api/api_solver.cpp index 6f38d0246..1d3643edf 100644 --- a/src/api/api_solver.cpp +++ b/src/api/api_solver.cpp @@ -948,6 +948,14 @@ extern "C" { Z3_CATCH_RETURN(0); } + unsigned Z3_API Z3_solver_propagate_register_cb(Z3_context c, Z3_solver_callback s, Z3_ast e) { + Z3_TRY; + Z3_solver_propagate_register_cb(c, s, e); + RESET_ERROR_CODE(); + return reinterpret_cast(s)->register_cb(to_expr(e)); + Z3_CATCH_RETURN(0); + } + void Z3_API Z3_solver_propagate_consequence(Z3_context c, Z3_solver_callback s, unsigned num_fixed, unsigned const* fixed_ids, unsigned num_eqs, unsigned const* eq_lhs, unsigned const* eq_rhs, Z3_ast conseq) { Z3_TRY; LOG_Z3_solver_propagate_consequence(c, s, num_fixed, fixed_ids, num_eqs, eq_lhs, eq_rhs, conseq); diff --git a/src/api/c++/z3++.h b/src/api/c++/z3++.h index 3e86a3d80..95c09e47a 100644 --- a/src/api/c++/z3++.h +++ b/src/api/c++/z3++.h @@ -4030,8 +4030,12 @@ namespace z3 { */ unsigned add(expr const& e) { - assert(s); - return Z3_solver_propagate_register(ctx(), *s, e); + if (cb) + return Z3_solver_propagate_register_cb(ctx(), cb, e); + if (s) + return Z3_solver_propagate_register(ctx(), *s, e); + assert(false); + return 0; } void conflict(unsigned num_fixed, unsigned const* fixed) { diff --git a/src/api/z3_api.h b/src/api/z3_api.h index a63008051..767eebc43 100644 --- a/src/api/z3_api.h +++ b/src/api/z3_api.h @@ -6685,6 +6685,16 @@ extern "C" { unsigned Z3_API Z3_solver_propagate_register(Z3_context c, Z3_solver s, Z3_ast e); + /** + \brief register an expression to propagate on with the solver. + Only expressions of type Bool and type Bit-Vector can be registered for propagation. + Unlike \ref Z3_solver_propagate_register, this function takes a solver callback context + as argument. It can be invoked during a callback to register new expressions. + + def_API('Z3_solver_propagate_register_cb', UINT, (_in(CONTEXT), _in(SOLVER_CALLBACK), _in(AST))) + */ + unsigned Z3_API Z3_solver_propagate_register_cb(Z3_context c, Z3_solver_callback cb, Z3_ast e); + /** \brief propagate a consequence based on fixed values. This is a callback a client may invoke during the fixed_eh callback. diff --git a/src/sat/smt/user_solver.cpp b/src/sat/smt/user_solver.cpp index 5a1a07f11..edbaf6d8d 100644 --- a/src/sat/smt/user_solver.cpp +++ b/src/sat/smt/user_solver.cpp @@ -47,6 +47,10 @@ namespace user_solver { DEBUG_CODE(validate_propagation();); } + unsigned solver::register_cb(expr* e) { + return add_expr(e); + } + sat::check_result solver::check() { if (!(bool)m_final_eh) return sat::check_result::CR_DONE; diff --git a/src/sat/smt/user_solver.h b/src/sat/smt/user_solver.h index 275e33bbc..b11742608 100644 --- a/src/sat/smt/user_solver.h +++ b/src/sat/smt/user_solver.h @@ -111,6 +111,7 @@ namespace user_solver { bool has_fixed() const { return (bool)m_fixed_eh; } void propagate_cb(unsigned num_fixed, unsigned const* fixed_ids, unsigned num_eqs, unsigned const* lhs, unsigned const* rhs, expr* conseq) override; + unsigned register_cb(expr* e) override; void new_fixed_eh(euf::theory_var v, expr* value, unsigned num_lits, sat::literal const* jlits); diff --git a/src/smt/tactic/smt_tactic_core.cpp b/src/smt/tactic/smt_tactic_core.cpp index 90c580b67..9e9d3a9de 100644 --- a/src/smt/tactic/smt_tactic_core.cpp +++ b/src/smt/tactic/smt_tactic_core.cpp @@ -346,6 +346,15 @@ public: } cb->propagate_cb(num_fixed, fixed.data(), num_eqs, lhs.data(), rhs.data(), conseq); } + + unsigned register_cb(expr* e) override { + unsigned j = t->m_vars.size(); + t->m_vars.push_back(e); + unsigned i = cb->register_cb(e); + t->m_var2internal.setx(j, i, 0); + t->m_internal2var.setx(i, j, 0); + return j; + } }; callback i_cb; diff --git a/src/smt/theory_user_propagator.cpp b/src/smt/theory_user_propagator.cpp index e70e75c8c..2b50e07ab 100644 --- a/src/smt/theory_user_propagator.cpp +++ b/src/smt/theory_user_propagator.cpp @@ -59,6 +59,10 @@ void theory_user_propagator::propagate_cb( m_prop.push_back(prop_info(num_fixed, fixed_ids, num_eqs, eq_lhs, eq_rhs, expr_ref(conseq, m))); } +unsigned theory_user_propagator::register_cb(expr* e) { + return add_expr(e); +} + theory * theory_user_propagator::mk_fresh(context * new_ctx) { auto* th = alloc(theory_user_propagator, *new_ctx); void* ctx = m_fresh_eh(m_user_context, new_ctx->get_manager(), th->m_api_context); diff --git a/src/smt/theory_user_propagator.h b/src/smt/theory_user_propagator.h index 3e9db8cdd..d007de6a0 100644 --- a/src/smt/theory_user_propagator.h +++ b/src/smt/theory_user_propagator.h @@ -98,6 +98,7 @@ namespace smt { bool has_fixed() const { return (bool)m_fixed_eh; } void propagate_cb(unsigned num_fixed, unsigned const* fixed_ids, unsigned num_eqs, unsigned const* lhs, unsigned const* rhs, expr* conseq) override; + unsigned register_cb(expr* e) override; void new_fixed_eh(theory_var v, expr* value, unsigned num_lits, literal const* jlits); diff --git a/src/tactic/user_propagator_base.h b/src/tactic/user_propagator_base.h index d9645069d..899722c2a 100644 --- a/src/tactic/user_propagator_base.h +++ b/src/tactic/user_propagator_base.h @@ -9,6 +9,7 @@ namespace user_propagator { public: virtual ~callback() = default; virtual void propagate_cb(unsigned num_fixed, unsigned const* fixed_ids, unsigned num_eqs, unsigned const* eq_lhs, unsigned const* eq_rhs, expr* conseq) = 0; + virtual unsigned register_cb(expr* e) = 0; }; class context_obj {