From 2e00f2f32db7a05dbecbdb75ca7a1e1918a6a0a8 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Thu, 17 Feb 2022 09:21:41 +0200 Subject: [PATCH] Propagator (#5845) * user propagator without ids Signed-off-by: Nikolaj Bjorner * user propagator without ids Signed-off-by: Nikolaj Bjorner * fix signature Signed-off-by: Nikolaj Bjorner * references #5818 Signed-off-by: Nikolaj Bjorner * fix c++ build Signed-off-by: Nikolaj Bjorner * switch to vs 2022 Signed-off-by: Nikolaj Bjorner * switch 2022 Signed-off-by: Nikolaj Bjorner * Update propagator example (I) (#5835) * fix #5829 * na Signed-off-by: Nikolaj Bjorner * switch to vs 2022 Signed-off-by: Nikolaj Bjorner * Adapted the example to the changes in the propagator Co-authored-by: Nikolaj Bjorner * context goes out of scope in stack allocation, so can't used scoped context when passing objects around * parameter check Signed-off-by: Nikolaj Bjorner * add rewriter Signed-off-by: Nikolaj Bjorner * Fixed bug in user-propagator "created" (#5843) Co-authored-by: Clemens Eisenhofer <56730610+CEisenhofer@users.noreply.github.com> --- examples/userPropagator/example.cpp | 84 +++++++++----- src/api/api_solver.cpp | 27 +++-- src/api/c++/z3++.h | 108 +++++++++--------- src/api/python/z3/z3.py | 22 ++-- src/api/z3_api.h | 18 +-- src/sat/sat_solver/inc_sat_solver.cpp | 4 +- src/sat/smt/euf_solver.h | 4 +- src/sat/smt/user_solver.cpp | 36 +++--- src/sat/smt/user_solver.h | 17 +-- src/smt/smt_context.h | 4 +- src/smt/smt_kernel.cpp | 4 +- src/smt/smt_kernel.h | 2 +- src/smt/smt_solver.cpp | 4 +- src/smt/tactic/smt_tactic_core.cpp | 141 ++---------------------- src/smt/theory_user_propagator.cpp | 49 ++++---- src/smt/theory_user_propagator.h | 33 ++++-- src/solver/tactic2solver.cpp | 4 +- src/tactic/core/elim_uncnstr_tactic.cpp | 3 +- src/tactic/core/reduce_args_tactic.cpp | 5 +- src/tactic/tactic.h | 2 +- src/tactic/tactical.cpp | 6 +- src/tactic/user_propagator_base.h | 12 +- 22 files changed, 261 insertions(+), 328 deletions(-) diff --git a/examples/userPropagator/example.cpp b/examples/userPropagator/example.cpp index b66e3bb0e..1b6888798 100644 --- a/examples/userPropagator/example.cpp +++ b/examples/userPropagator/example.cpp @@ -50,15 +50,36 @@ struct model_hash_function { } }; +namespace std { + + template<> + struct hash { + std::size_t operator()(const z3::expr &k) const { + return k.hash(); + } + }; +} + +// Do not use Z3's == operator in the hash table +namespace std { + + template<> + struct equal_to { + bool operator()(const z3::expr &lhs, const z3::expr &rhs) const { + return z3::eq(lhs, rhs); + } + }; +} + class user_propagator : public z3::user_propagator_base { protected: unsigned board; - std::unordered_map& id_mapping; + std::unordered_map& id_mapping; model currentModel; std::unordered_set modelSet; - std::vector fixedValues; + std::vector fixedValues; std::stack fixedCnt; int solutionId = 1; @@ -70,7 +91,10 @@ public: } void final() final { - this->conflict((unsigned) fixedValues.size(), fixedValues.data()); + z3::expr_vector conflicting(fixedValues[0].ctx()); + for (auto&& v : fixedValues) + conflicting.push_back(v); + this->conflict(conflicting); if (modelSet.find(currentModel) != modelSet.end()) { WriteLine("Got already computed model"); return; @@ -91,20 +115,20 @@ public: return (unsigned)e.get_numeral_int(); } - void fixed(unsigned id, z3::expr const &e) override { - fixedValues.push_back(id); - unsigned value = bvToInt(e); - currentModel[id_mapping[id]] = value; + void fixed(z3::expr const &ast, z3::expr const &value) override { + fixedValues.push_back(ast); + unsigned valueBv = bvToInt(value); + currentModel[id_mapping[ast]] = valueBv; } - user_propagator(z3::solver *s, std::unordered_map& idMapping, unsigned board) + user_propagator(z3::solver *s, std::unordered_map& idMapping, unsigned board) : user_propagator_base(s), board(board), id_mapping(idMapping), currentModel(board, (unsigned)-1) { this->register_fixed(); this->register_final(); } - virtual ~user_propagator() = default; + ~user_propagator() = default; void push() override { fixedCnt.push((unsigned) fixedValues.size()); @@ -117,50 +141,58 @@ public: for (auto j = fixedValues.size(); j > lastCnt; j--) { currentModel[fixedValues[j - 1]] = (unsigned)-1; } - fixedValues.resize(lastCnt); + fixedValues.erase(fixedValues.cbegin() + lastCnt, fixedValues.cend()); } } - user_propagator_base *fresh(Z3_context) override { return this; } + user_propagator_base *fresh(Z3_context) override { + return this; + } }; class user_propagator_with_theory : public user_propagator { public: - void fixed(unsigned id, z3::expr const &e) override { - unsigned queenId = id_mapping[id]; - unsigned queenPos = bvToInt(e); + void fixed(z3::expr const &ast, z3::expr const &value) override { + unsigned queenId = id_mapping[ast]; + unsigned queenPos = bvToInt(value); if (queenPos >= board) { - this->conflict(1, &id); + z3::expr_vector conflicting(ast.ctx()); + conflicting.push_back(ast); + this->conflict(conflicting); return; } - for (unsigned fixed : fixedValues) { + for (z3::expr fixed : fixedValues) { unsigned otherId = id_mapping[fixed]; unsigned otherPos = currentModel[fixed]; if (queenPos == otherPos) { - const unsigned conflicting[] = {id, fixed}; - this->conflict(2, conflicting); + z3::expr_vector conflicting(ast.ctx()); + conflicting.push_back(ast); + conflicting.push_back(fixed); + this->conflict(conflicting); continue; } #ifdef QUEEN int diffY = abs((int)queenId - (int)otherId); int diffX = abs((int)queenPos - (int)otherPos); if (diffX == diffY) { - const unsigned conflicting[] = {id, fixed}; - this->conflict(2, conflicting); + z3::expr_vector conflicting(ast.ctx()); + conflicting.push_back(ast); + conflicting.push_back(fixed); + this->conflict(conflicting); } #endif } - fixedValues.push_back(id); - currentModel[id_mapping[id]] = queenPos; + fixedValues.push_back(ast); + currentModel[id_mapping[ast]] = queenPos; } - user_propagator_with_theory(z3::solver *s, std::unordered_map& idMapping, unsigned board) + user_propagator_with_theory(z3::solver *s, std::unordered_map& idMapping, unsigned board) : user_propagator(s, idMapping, board) {} }; @@ -261,7 +293,7 @@ inline int test1(unsigned num) { int test23(unsigned num, bool withTheory) { z3::context context; z3::solver solver(context, Z3_mk_simple_solver(context)); - std::unordered_map idMapping; + std::unordered_map idMapping; user_propagator *propagator; if (!withTheory) { @@ -274,8 +306,8 @@ int test23(unsigned num, bool withTheory) { std::vector queens = createQueens(context, num); for (unsigned i = 0; i < queens.size(); i++) { - unsigned id = propagator->add(queens[i]); - idMapping[id] = i; + propagator->add(queens[i]); + idMapping[queens[i]] = i; } if (!withTheory) { diff --git a/src/api/api_solver.cpp b/src/api/api_solver.cpp index 43211b3b8..99d332724 100644 --- a/src/api/api_solver.cpp +++ b/src/api/api_solver.cpp @@ -902,7 +902,7 @@ extern "C" { Z3_fixed_eh fixed_eh) { Z3_TRY; RESET_ERROR_CODE(); - user_propagator::fixed_eh_t _fixed = (void(*)(void*,user_propagator::callback*,unsigned,expr*))fixed_eh; + user_propagator::fixed_eh_t _fixed = (void(*)(void*,user_propagator::callback*,expr*,expr*))fixed_eh; to_solver_ref(s)->user_propagate_register_fixed(_fixed); Z3_CATCH; } @@ -924,7 +924,7 @@ extern "C" { Z3_eq_eh eq_eh) { Z3_TRY; RESET_ERROR_CODE(); - user_propagator::eq_eh_t _eq = (void(*)(void*,user_propagator::callback*,unsigned,unsigned))eq_eh; + user_propagator::eq_eh_t _eq = (void(*)(void*,user_propagator::callback*,expr*,expr*))eq_eh; to_solver_ref(s)->user_propagate_register_eq(_eq); Z3_CATCH; } @@ -935,39 +935,42 @@ extern "C" { Z3_eq_eh diseq_eh) { Z3_TRY; RESET_ERROR_CODE(); - user_propagator::eq_eh_t _diseq = (void(*)(void*,user_propagator::callback*,unsigned,unsigned))diseq_eh; + user_propagator::eq_eh_t _diseq = (void(*)(void*,user_propagator::callback*,expr*,expr*))diseq_eh; to_solver_ref(s)->user_propagate_register_diseq(_diseq); Z3_CATCH; } - unsigned Z3_API Z3_solver_propagate_register(Z3_context c, Z3_solver s, Z3_ast e) { + void Z3_API Z3_solver_propagate_register(Z3_context c, Z3_solver s, Z3_ast e) { Z3_TRY; LOG_Z3_solver_propagate_register(c, s, e); RESET_ERROR_CODE(); - return to_solver_ref(s)->user_propagate_register_expr(to_expr(e)); - Z3_CATCH_RETURN(0); + to_solver_ref(s)->user_propagate_register_expr(to_expr(e)); + Z3_CATCH; } - unsigned Z3_API Z3_solver_propagate_register_cb(Z3_context c, Z3_solver_callback s, Z3_ast e) { + void Z3_API Z3_solver_propagate_register_cb(Z3_context c, Z3_solver_callback s, Z3_ast e) { Z3_TRY; LOG_Z3_solver_propagate_register_cb(c, s, e); RESET_ERROR_CODE(); - return reinterpret_cast(s)->register_cb(to_expr(e)); - Z3_CATCH_RETURN(0); + reinterpret_cast(s)->register_cb(to_expr(e)); + Z3_CATCH; } - void Z3_API Z3_solver_propagate_consequence(Z3_context c, Z3_solver_callback s, unsigned num_fixed, unsigned const* fixed_ids, unsigned num_eqs, unsigned const* eq_lhs, unsigned const* eq_rhs, Z3_ast conseq) { + void Z3_API Z3_solver_propagate_consequence(Z3_context c, Z3_solver_callback s, unsigned num_fixed, Z3_ast const* fixed_ids, unsigned num_eqs, Z3_ast const* eq_lhs, Z3_ast const* eq_rhs, Z3_ast conseq) { Z3_TRY; LOG_Z3_solver_propagate_consequence(c, s, num_fixed, fixed_ids, num_eqs, eq_lhs, eq_rhs, conseq); RESET_ERROR_CODE(); - reinterpret_cast(s)->propagate_cb(num_fixed, fixed_ids, num_eqs, eq_lhs, eq_rhs, to_expr(conseq)); + expr* const * _fixed_ids = (expr* const*) fixed_ids; + expr* const * _eq_lhs = (expr*const*) eq_lhs; + expr* const * _eq_rhs = (expr*const*) eq_rhs; + reinterpret_cast(s)->propagate_cb(num_fixed, _fixed_ids, num_eqs, _eq_lhs, _eq_rhs, to_expr(conseq)); Z3_CATCH; } void Z3_API Z3_solver_propagate_created(Z3_context c, Z3_solver s, Z3_created_eh created_eh) { Z3_TRY; RESET_ERROR_CODE(); - user_propagator::created_eh_t c = (void(*)(void*, user_propagator::callback*, expr*, unsigned))created_eh; + user_propagator::created_eh_t c = (void(*)(void*, user_propagator::callback*, expr*))created_eh; to_solver_ref(s)->user_propagate_register_created(c); Z3_CATCH; } diff --git a/src/api/c++/z3++.h b/src/api/c++/z3++.h index b7547b990..0bbab4185 100644 --- a/src/api/c++/z3++.h +++ b/src/api/c++/z3++.h @@ -155,9 +155,10 @@ namespace z3 { class context { private: + friend class user_propagator_base; bool m_enable_exceptions; rounding_mode m_rounding_mode; - Z3_context m_ctx; + Z3_context m_ctx = nullptr; void init(config & c) { set_context(Z3_mk_context_rc(c)); } @@ -173,7 +174,6 @@ namespace z3 { context(context const &) = delete; context & operator=(context const &) = delete; - friend class scoped_context; context(Z3_context c) { set_context(c); } void detach() { m_ctx = nullptr; } public: @@ -394,14 +394,6 @@ namespace z3 { expr_vector parse_file(char const* s, sort_vector const& sorts, func_decl_vector const& decls); }; - class scoped_context final { - context m_ctx; - public: - scoped_context(Z3_context c): m_ctx(c) {} - ~scoped_context() { m_ctx.detach(); } - context& operator()() { return m_ctx; } - }; - template class array { @@ -509,7 +501,7 @@ namespace z3 { ast(context & c):object(c), m_ast(0) {} ast(context & c, Z3_ast n):object(c), m_ast(n) { Z3_inc_ref(ctx(), m_ast); } ast(ast const & s) :object(s), m_ast(s.m_ast) { Z3_inc_ref(ctx(), m_ast); } - ~ast() { if (m_ast) Z3_dec_ref(*m_ctx, m_ast); } + ~ast() { if (m_ast) { Z3_dec_ref(*m_ctx, m_ast); } } operator Z3_ast() const { return m_ast; } operator bool() const { return m_ast != 0; } ast & operator=(ast const & s) { @@ -3933,23 +3925,20 @@ namespace z3 { class user_propagator_base { - typedef std::function fixed_eh_t; + typedef std::function fixed_eh_t; typedef std::function final_eh_t; - typedef std::function eq_eh_t; - typedef std::function created_eh_t; + typedef std::function eq_eh_t; + typedef std::function created_eh_t; final_eh_t m_final_eh; eq_eh_t m_eq_eh; fixed_eh_t m_fixed_eh; created_eh_t m_created_eh; solver* s; - Z3_context c; + context* c; + Z3_solver_callback cb { nullptr }; - Z3_context ctx() { - return c ? c : (Z3_context)s->ctx(); - } - struct scoped_cb { user_propagator_base& p; scoped_cb(void* _p, Z3_solver_callback cb):p(*static_cast(_p)) { @@ -3972,17 +3961,19 @@ namespace z3 { return static_cast(p)->fresh(ctx); } - static void fixed_eh(void* _p, Z3_solver_callback cb, unsigned id, Z3_ast _value) { + static void fixed_eh(void* _p, Z3_solver_callback cb, Z3_ast _var, Z3_ast _value) { user_propagator_base* p = static_cast(_p); scoped_cb _cb(p, cb); - scoped_context ctx(p->ctx()); - expr value(ctx(), _value); - static_cast(p)->m_fixed_eh(id, value); + expr value(p->ctx(), _value); + expr var(p->ctx(), _var); + p->m_fixed_eh(var, value); } - static void eq_eh(void* p, Z3_solver_callback cb, unsigned x, unsigned y) { + static void eq_eh(void* _p, Z3_solver_callback cb, Z3_ast _x, Z3_ast _y) { + user_propagator_base* p = static_cast(_p); scoped_cb _cb(p, cb); - static_cast(p)->m_eq_eh(x, y); + expr x(p->ctx(), _x), y(p->ctx(), _y); + p->m_eq_eh(x, y); } static void final_eh(void* p, Z3_solver_callback cb) { @@ -3990,17 +3981,16 @@ namespace z3 { static_cast(p)->m_final_eh(); } - static void created_eh(void* _p, Z3_solver_callback cb, Z3_ast _e, unsigned id) { + static void created_eh(void* _p, Z3_solver_callback cb, Z3_ast _e) { user_propagator_base* p = static_cast(_p); scoped_cb _cb(p, cb); - scoped_context ctx(p->ctx()); - expr e(ctx(), _e); - static_cast(p)->m_created_eh(id, e); + expr e(p->ctx(), _e); + p->m_created_eh(e); } public: - user_propagator_base(Z3_context c) : s(nullptr), c(c) {} + user_propagator_base(context& c) : s(nullptr), c(&c) {} user_propagator_base(solver* s): s(s), c(nullptr) { Z3_solver_propagate_init(ctx(), *s, this, push_eh, pop_eh, fresh_eh); @@ -4011,6 +4001,10 @@ namespace z3 { virtual ~user_propagator_base() = default; + context& ctx() { + return c ? *c : s->ctx(); + } + /** \brief user_propagators created using \c fresh() are created during search and their lifetimes are restricted to search time. They should @@ -4035,7 +4029,7 @@ namespace z3 { void register_fixed() { assert(s); - m_fixed_eh = [this](unsigned id, expr const& e) { + m_fixed_eh = [this](expr const& id, expr const& e) { fixed(id, e); }; Z3_solver_propagate_fixed(ctx(), *s, fixed_eh); @@ -4049,7 +4043,7 @@ namespace z3 { void register_eq() { assert(s); - m_eq_eh = [this](unsigned x, unsigned y) { + m_eq_eh = [this](expr const& x, expr const& y) { eq(x, y); }; Z3_solver_propagate_eq(ctx(), *s, eq_eh); @@ -4084,19 +4078,19 @@ namespace z3 { } void register_created() { - m_created_eh = [this](unsigned id, expr const& e) { - created(id, e); + m_created_eh = [this](expr const& e) { + created(e); }; Z3_solver_propagate_created(ctx(), *s, created_eh); } - virtual void fixed(unsigned /*id*/, expr const& /*e*/) { } + virtual void fixed(expr const& /*id*/, expr const& /*e*/) { } - virtual void eq(unsigned /*x*/, unsigned /*y*/) { } + virtual void eq(expr const& /*x*/, expr const& /*y*/) { } virtual void final() { } - virtual void created(unsigned /*id*/, expr const& /*e*/) {} + virtual void created(expr const& /*e*/) {} /** \brief tracks \c e by a unique identifier that is returned by the call. @@ -4112,34 +4106,40 @@ namespace z3 { correspond to equalities that have been registered during a callback. */ - unsigned add(expr const& e) { + void add(expr const& e) { if (cb) - return Z3_solver_propagate_register_cb(ctx(), cb, e); - if (s) - return Z3_solver_propagate_register(ctx(), *s, e); - assert(false); - return 0; + Z3_solver_propagate_register_cb(ctx(), cb, e); + else if (s) + Z3_solver_propagate_register(ctx(), *s, e); + else + assert(false); } - void conflict(unsigned num_fixed, unsigned const* fixed) { + void conflict(expr_vector const& fixed) { assert(cb); - scoped_context _ctx(ctx()); - expr conseq = _ctx().bool_val(false); - Z3_solver_propagate_consequence(ctx(), cb, num_fixed, fixed, 0, nullptr, nullptr, conseq); + expr conseq = ctx().bool_val(false); + array _fixed(fixed); + Z3_solver_propagate_consequence(ctx(), cb, fixed.size(), _fixed.ptr(), 0, nullptr, nullptr, conseq); } - void propagate(unsigned num_fixed, unsigned const* fixed, expr const& conseq) { + void propagate(expr_vector const& fixed, expr const& conseq) { assert(cb); - assert(conseq.ctx() == ctx()); - Z3_solver_propagate_consequence(ctx(), cb, num_fixed, fixed, 0, nullptr, nullptr, conseq); + assert((Z3_context)conseq.ctx() == (Z3_context)ctx()); + array _fixed(fixed); + Z3_solver_propagate_consequence(ctx(), cb, _fixed.size(), _fixed.ptr(), 0, nullptr, nullptr, conseq); } - void propagate(unsigned num_fixed, unsigned const* fixed, - unsigned num_eqs, unsigned const* lhs, unsigned const * rhs, + void propagate(expr_vector const& fixed, + expr_vector const& lhs, expr_vector const& rhs, expr const& conseq) { assert(cb); - assert(conseq.ctx() == ctx()); - Z3_solver_propagate_consequence(ctx(), cb, num_fixed, fixed, num_eqs, lhs, rhs, conseq); + assert((Z3_context)conseq.ctx() == (Z3_context)ctx()); + assert(lhs.size() == rhs.size()); + array _fixed(fixed); + array _lhs(lhs); + array _rhs(rhs); + + Z3_solver_propagate_consequence(ctx(), cb, _fixed.size(), _fixed.ptr(), lhs.size(), _lhs.ptr(), _rhs.ptr(), conseq); } }; diff --git a/src/api/python/z3/z3.py b/src/api/python/z3/z3.py index 727f289a1..1a35865c9 100644 --- a/src/api/python/z3/z3.py +++ b/src/api/python/z3/z3.py @@ -11261,7 +11261,7 @@ def user_prop_fresh(id, ctx): def user_prop_fixed(ctx, cb, id, value): prop = _prop_closures.get(ctx) prop.cb = cb - prop.fixed(id, _to_expr_ref(ctypes.c_void_p(value), prop.ctx())) + prop.fixed(_to_expr_ref(ctypes.c_void_p(id), prop.ctx()), _to_expr_ref(ctypes.c_void_p(value), prop.ctx())) prop.cb = None @@ -11275,6 +11275,8 @@ def user_prop_final(ctx, cb): def user_prop_eq(ctx, cb, x, y): prop = _prop_closures.get(ctx) prop.cb = cb + x = _to_expr_ref(ctypes.c_void_p(x), prop.ctx()) + y = _to_expr_ref(ctypes.c_void_p(y), prop.ctx()) prop.eq(x, y) prop.cb = None @@ -11282,6 +11284,8 @@ def user_prop_eq(ctx, cb, x, y): def user_prop_diseq(ctx, cb, x, y): prop = _prop_closures.get(ctx) prop.cb = cb + x = _to_expr_ref(ctypes.c_void_p(x), prop.ctx()) + y = _to_expr_ref(ctypes.c_void_p(y), prop.ctx()) prop.diseq(x, y) prop.cb = None @@ -11385,18 +11389,12 @@ class UserPropagateBase: # Propagation can only be invoked as during a fixed or final callback. # def propagate(self, e, ids, eqs=[]): - num_fixed = len(ids) - _ids = (ctypes.c_uint * num_fixed)() - for i in range(num_fixed): - _ids[i] = ids[i] + _ids, num_fixed = _to_ast_array(ids) num_eqs = len(eqs) - _lhs = (ctypes.c_uint * num_eqs)() - _rhs = (ctypes.c_uint * num_eqs)() - for i in range(num_eqs): - _lhs[i] = eqs[i][0] - _rhs[i] = eqs[i][1] + _lhs, _num_lhs = _to_ast_array([x for x, y in eqs]) + _rhs, _num_lhs = _to_ast_array([y for x, y in eqs]) Z3_solver_propagate_consequence(e.ctx.ref(), ctypes.c_void_p( self.cb), num_fixed, _ids, num_eqs, _lhs, _rhs, e.ast) - def conflict(self, ids): - self.propagate(BoolVal(False, self.ctx()), ids, eqs=[]) + def conflict(self, deps): + self.propagate(BoolVal(False, self.ctx()), deps, eqs=[]) diff --git a/src/api/z3_api.h b/src/api/z3_api.h index 16d3a292c..95c04ea17 100644 --- a/src/api/z3_api.h +++ b/src/api/z3_api.h @@ -1433,10 +1433,10 @@ Z3_DECLARE_CLOSURE(Z3_error_handler, void, (Z3_context c, Z3_error_code e)); Z3_DECLARE_CLOSURE(Z3_push_eh, void, (void* ctx)); Z3_DECLARE_CLOSURE(Z3_pop_eh, void, (void* ctx, unsigned num_scopes)); Z3_DECLARE_CLOSURE(Z3_fresh_eh, void*, (void* ctx, Z3_context new_context)); -Z3_DECLARE_CLOSURE(Z3_fixed_eh, void, (void* ctx, Z3_solver_callback cb, unsigned id, Z3_ast value)); -Z3_DECLARE_CLOSURE(Z3_eq_eh, void, (void* ctx, Z3_solver_callback cb, unsigned x, unsigned y)); +Z3_DECLARE_CLOSURE(Z3_fixed_eh, void, (void* ctx, Z3_solver_callback cb, Z3_ast t, Z3_ast value)); +Z3_DECLARE_CLOSURE(Z3_eq_eh, void, (void* ctx, Z3_solver_callback cb, Z3_ast s, Z3_ast t)); Z3_DECLARE_CLOSURE(Z3_final_eh, void, (void* ctx, Z3_solver_callback cb)); -Z3_DECLARE_CLOSURE(Z3_created_eh, void, (void* ctx, Z3_solver_callback cb, Z3_ast e, unsigned id)); +Z3_DECLARE_CLOSURE(Z3_created_eh, void, (void* ctx, Z3_solver_callback cb, Z3_ast t)); /** @@ -6761,10 +6761,10 @@ extern "C" { \brief register an expression to propagate on with the solver. Only expressions of type Bool and type Bit-Vector can be registered for propagation. - def_API('Z3_solver_propagate_register', UINT, (_in(CONTEXT), _in(SOLVER), _in(AST))) + def_API('Z3_solver_propagate_register', VOID, (_in(CONTEXT), _in(SOLVER), _in(AST))) */ - unsigned Z3_API Z3_solver_propagate_register(Z3_context c, Z3_solver s, Z3_ast e); + void Z3_API Z3_solver_propagate_register(Z3_context c, Z3_solver s, Z3_ast e); /** \brief register an expression to propagate on with the solver. @@ -6772,9 +6772,9 @@ extern "C" { Unlike \ref Z3_solver_propagate_register, this function takes a solver callback context as argument. It can be invoked during a callback to register new expressions. - def_API('Z3_solver_propagate_register_cb', UINT, (_in(CONTEXT), _in(SOLVER_CALLBACK), _in(AST))) + def_API('Z3_solver_propagate_register_cb', VOID, (_in(CONTEXT), _in(SOLVER_CALLBACK), _in(AST))) */ - unsigned Z3_API Z3_solver_propagate_register_cb(Z3_context c, Z3_solver_callback cb, Z3_ast e); + void Z3_API Z3_solver_propagate_register_cb(Z3_context c, Z3_solver_callback cb, Z3_ast e); /** \brief propagate a consequence based on fixed values. @@ -6782,10 +6782,10 @@ extern "C" { The callback adds a propagation consequence based on the fixed values of the \c ids. - def_API('Z3_solver_propagate_consequence', VOID, (_in(CONTEXT), _in(SOLVER_CALLBACK), _in(UINT), _in_array(2, UINT), _in(UINT), _in_array(4, UINT), _in_array(4, UINT), _in(AST))) + def_API('Z3_solver_propagate_consequence', VOID, (_in(CONTEXT), _in(SOLVER_CALLBACK), _in(UINT), _in_array(2, AST), _in(UINT), _in_array(4, AST), _in_array(4, AST), _in(AST))) */ - void Z3_API Z3_solver_propagate_consequence(Z3_context c, Z3_solver_callback, unsigned num_fixed, unsigned const* fixed_ids, unsigned num_eqs, unsigned const* eq_lhs, unsigned const* eq_rhs, Z3_ast conseq); + void Z3_API Z3_solver_propagate_consequence(Z3_context c, Z3_solver_callback, unsigned num_fixed, Z3_ast const* fixed, unsigned num_eqs, Z3_ast const* eq_lhs, Z3_ast const* eq_rhs, Z3_ast conseq); /** \brief Check whether the assertions in a given solver are consistent or not. diff --git a/src/sat/sat_solver/inc_sat_solver.cpp b/src/sat/sat_solver/inc_sat_solver.cpp index db1f28743..61dc53cd8 100644 --- a/src/sat/sat_solver/inc_sat_solver.cpp +++ b/src/sat/sat_solver/inc_sat_solver.cpp @@ -683,8 +683,8 @@ public: ensure_euf()->user_propagate_register_diseq(diseq_eh); } - unsigned user_propagate_register_expr(expr* e) override { - return ensure_euf()->user_propagate_register_expr(e); + void user_propagate_register_expr(expr* e) override { + ensure_euf()->user_propagate_register_expr(e); } void user_propagate_register_created(user_propagator::created_eh_t& r) override { diff --git a/src/sat/smt/euf_solver.h b/src/sat/smt/euf_solver.h index d299c92d0..669eb1616 100644 --- a/src/sat/smt/euf_solver.h +++ b/src/sat/smt/euf_solver.h @@ -434,9 +434,9 @@ namespace euf { check_for_user_propagator(); m_user_propagator->register_created(ceh); } - unsigned user_propagate_register_expr(expr* e) { + void user_propagate_register_expr(expr* e) { check_for_user_propagator(); - return m_user_propagator->add_expr(e); + m_user_propagator->add_expr(e); } // solver factory diff --git a/src/sat/smt/user_solver.cpp b/src/sat/smt/user_solver.cpp index febbe9383..6b3eb6718 100644 --- a/src/sat/smt/user_solver.cpp +++ b/src/sat/smt/user_solver.cpp @@ -28,31 +28,33 @@ namespace user_solver { dealloc(m_api_context); } - unsigned solver::add_expr(expr* e) { + void solver::add_expr(expr* e) { force_push(); ctx.internalize(e, false); euf::enode* n = expr2enode(e); if (is_attached_to_var(n)) - return n->get_th_var(get_id()); + return; euf::theory_var v = mk_var(n); ctx.attach_th_var(n, this, v); expr_ref r(m); sat::literal_vector explain; if (ctx.is_fixed(n, r, explain)) - m_prop.push_back(prop_info(explain, v, r)); - return v; + m_prop.push_back(prop_info(explain, v, r)); } void solver::propagate_cb( - unsigned num_fixed, unsigned const* fixed_ids, - unsigned num_eqs, unsigned const* eq_lhs, unsigned const* eq_rhs, + unsigned num_fixed, expr* const* fixed_ids, + unsigned num_eqs, expr* const* eq_lhs, expr* const* eq_rhs, expr* conseq) { - m_prop.push_back(prop_info(num_fixed, fixed_ids, num_eqs, eq_lhs, eq_rhs, expr_ref(conseq, m))); + m_fixed_ids.reset(); + for (unsigned i = 0; i < num_fixed; ++i) + m_fixed_ids.push_back(get_th_var(fixed_ids[i])); + m_prop.push_back(prop_info(num_fixed, m_fixed_ids.data(), num_eqs, eq_lhs, eq_rhs, expr_ref(conseq, m))); DEBUG_CODE(validate_propagation();); } - unsigned solver::register_cb(expr* e) { - return add_expr(e); + void solver::register_cb(expr* e) { + add_expr(e); } sat::check_result solver::check() { @@ -68,7 +70,7 @@ namespace user_solver { return; force_push(); m_id2justification.setx(v, sat::literal_vector(num_lits, jlits), sat::literal_vector()); - m_fixed_eh(m_user_context, this, v, value); + m_fixed_eh(m_user_context, this, var2expr(v), value); } void solver::asserted(sat::literal lit) { @@ -80,7 +82,7 @@ namespace user_solver { sat::literal_vector lits; lits.push_back(lit); m_id2justification.setx(v, lits, sat::literal_vector()); - m_fixed_eh(m_user_context, this, v, lit.sign() ? m.mk_false() : m.mk_true()); + m_fixed_eh(m_user_context, this, var2expr(v), lit.sign() ? m.mk_false() : m.mk_true()); } void solver::push_core() { @@ -141,9 +143,9 @@ namespace user_solver { auto& j = justification::from_index(idx); auto const& prop = m_prop[j.m_propagation_index]; for (unsigned id : prop.m_ids) - r.append(m_id2justification[id]); + r.append(m_id2justification[id]); for (auto const& p : prop.m_eqs) - ctx.add_antecedent(var2enode(p.first), var2enode(p.second)); + ctx.add_antecedent(expr2enode(p.first), expr2enode(p.second)); } /* @@ -156,7 +158,7 @@ namespace user_solver { for (auto lit: m_id2justification[id]) VERIFY(s().value(lit) == l_true); for (auto const& p : prop.m_eqs) - VERIFY(var2enode(p.first)->get_root() == var2enode(p.second)->get_root()); + VERIFY(expr2enode(p.first)->get_root() == expr2enode(p.second)->get_root()); } std::ostream& solver::display(std::ostream& out) const { @@ -171,7 +173,7 @@ namespace user_solver { for (unsigned id : prop.m_ids) out << id << ": " << m_id2justification[id]; for (auto const& p : prop.m_eqs) - out << "v" << p.first << " == v" << p.second << " "; + out << "v" << mk_pp(p.first, m) << " == v" << mk_pp(p.second, m) << " "; return out; } @@ -224,9 +226,9 @@ namespace user_solver { SASSERT(!n || !n->is_attached_to(get_id())); if (!n) n = mk_enode(e, false); - auto v = add_expr(e); + add_expr(e); if (m_created_eh) - m_created_eh(m_user_context, this, e, v); + m_created_eh(m_user_context, this, e); return true; } diff --git a/src/sat/smt/user_solver.h b/src/sat/smt/user_solver.h index a30bc6a6d..13948db81 100644 --- a/src/sat/smt/user_solver.h +++ b/src/sat/smt/user_solver.h @@ -29,13 +29,13 @@ namespace user_solver { class solver : public euf::th_euf_solver, public user_propagator::callback { struct prop_info { - unsigned_vector m_ids; - expr_ref m_conseq; - svector> m_eqs; + unsigned_vector m_ids; + expr_ref m_conseq; + svector> m_eqs; sat::literal_vector m_lits; - euf::theory_var m_var = euf::null_theory_var; + euf::theory_var m_var = euf::null_theory_var; - prop_info(unsigned num_fixed, unsigned const* fixed_ids, unsigned num_eqs, unsigned const* eq_lhs, unsigned const* eq_rhs, expr_ref const& c): + prop_info(unsigned num_fixed, unsigned const* fixed_ids, unsigned num_eqs, expr* const* eq_lhs, expr* const* eq_rhs, expr_ref const& c): m_ids(num_fixed, fixed_ids), m_conseq(c) { @@ -72,6 +72,7 @@ namespace user_solver { vector m_id2justification; sat::literal_vector m_lits; euf::enode_pair_vector m_eqs; + unsigned_vector m_fixed_ids; stats m_stats; struct justification { @@ -118,7 +119,7 @@ namespace user_solver { m_fresh_eh = fresh_eh; } - unsigned add_expr(expr* e); + void add_expr(expr* e); void register_final(user_propagator::final_eh_t& final_eh) { m_final_eh = final_eh; } void register_fixed(user_propagator::fixed_eh_t& fixed_eh) { m_fixed_eh = fixed_eh; } @@ -128,8 +129,8 @@ namespace user_solver { bool has_fixed() const { return (bool)m_fixed_eh; } - void propagate_cb(unsigned num_fixed, unsigned const* fixed_ids, unsigned num_eqs, unsigned const* lhs, unsigned const* rhs, expr* conseq) override; - unsigned register_cb(expr* e) override; + void propagate_cb(unsigned num_fixed, expr* const* fixed_ids, unsigned num_eqs, expr* const* lhs, expr* const* rhs, expr* conseq) override; + void register_cb(expr* e) override; void new_fixed_eh(euf::theory_var v, expr* value, unsigned num_lits, sat::literal const* jlits); diff --git a/src/smt/smt_context.h b/src/smt/smt_context.h index b8c8a7ee9..12f0cc2ad 100644 --- a/src/smt/smt_context.h +++ b/src/smt/smt_context.h @@ -1726,10 +1726,10 @@ namespace smt { m_user_propagator->register_diseq(diseq_eh); } - unsigned user_propagate_register_expr(expr* e) { + void user_propagate_register_expr(expr* e) { if (!m_user_propagator) throw default_exception("user propagator must be initialized"); - return m_user_propagator->add_expr(e); + m_user_propagator->add_expr(e); } void user_propagate_register_created(user_propagator::created_eh_t& r) { diff --git a/src/smt/smt_kernel.cpp b/src/smt/smt_kernel.cpp index 347260c84..87e5fd36d 100644 --- a/src/smt/smt_kernel.cpp +++ b/src/smt/smt_kernel.cpp @@ -276,8 +276,8 @@ namespace smt { m_imp->m_kernel.user_propagate_register_diseq(diseq_eh); } - unsigned kernel::user_propagate_register_expr(expr* e) { - return m_imp->m_kernel.user_propagate_register_expr(e); + void kernel::user_propagate_register_expr(expr* e) { + m_imp->m_kernel.user_propagate_register_expr(e); } void kernel::user_propagate_register_created(user_propagator::created_eh_t& r) { diff --git a/src/smt/smt_kernel.h b/src/smt/smt_kernel.h index 28680e7a6..77ad2559c 100644 --- a/src/smt/smt_kernel.h +++ b/src/smt/smt_kernel.h @@ -307,7 +307,7 @@ namespace smt { void user_propagate_register_diseq(user_propagator::eq_eh_t& diseq_eh); - unsigned user_propagate_register_expr(expr* e); + void user_propagate_register_expr(expr* e); void user_propagate_register_created(user_propagator::created_eh_t& r); diff --git a/src/smt/smt_solver.cpp b/src/smt/smt_solver.cpp index 1cb9b26e1..ad67c19d1 100644 --- a/src/smt/smt_solver.cpp +++ b/src/smt/smt_solver.cpp @@ -236,8 +236,8 @@ namespace { m_context.user_propagate_register_diseq(diseq_eh); } - unsigned user_propagate_register_expr(expr* e) override { - return m_context.user_propagate_register_expr(e); + void user_propagate_register_expr(expr* e) override { + m_context.user_propagate_register_expr(e); } void user_propagate_register_created(user_propagator::created_eh_t& c) override { diff --git a/src/smt/tactic/smt_tactic_core.cpp b/src/smt/tactic/smt_tactic_core.cpp index bf2ea9bd6..072e1ed24 100644 --- a/src/smt/tactic/smt_tactic_core.cpp +++ b/src/smt/tactic/smt_tactic_core.cpp @@ -40,6 +40,7 @@ class smt_tactic : public tactic { ast_manager& m; smt_params m_params; params_ref m_params_ref; + expr_ref_vector m_vars; statistics m_stats; smt::kernel* m_ctx = nullptr; symbol m_logic; @@ -321,141 +322,20 @@ public: user_propagator::eq_eh_t m_eq_eh; user_propagator::eq_eh_t m_diseq_eh; user_propagator::created_eh_t m_created_eh; - - expr_ref_vector m_vars; - unsigned_vector m_var2internal; - unsigned_vector m_internal2var; - unsigned_vector m_limit; - user_propagator::push_eh_t i_push_eh; - user_propagator::pop_eh_t i_pop_eh; - user_propagator::fixed_eh_t i_fixed_eh; - user_propagator::final_eh_t i_final_eh; - user_propagator::eq_eh_t i_eq_eh; - user_propagator::eq_eh_t i_diseq_eh; - user_propagator::created_eh_t i_created_eh; - - - struct callback : public user_propagator::callback { - smt_tactic* t = nullptr; - user_propagator::callback* cb = nullptr; - unsigned_vector fixed, lhs, rhs; - void propagate_cb(unsigned num_fixed, unsigned const* fixed_ids, unsigned num_eqs, unsigned const* eq_lhs, unsigned const* eq_rhs, expr* conseq) override { - fixed.reset(); - lhs.reset(); - rhs.reset(); - for (unsigned i = 0; i < num_fixed; ++i) - fixed.push_back(t->m_var2internal[fixed_ids[i]]); - for (unsigned i = 0; i < num_eqs; ++i) { - lhs.push_back(t->m_var2internal[eq_lhs[i]]); - rhs.push_back(t->m_var2internal[eq_rhs[i]]); - } - cb->propagate_cb(num_fixed, fixed.data(), num_eqs, lhs.data(), rhs.data(), conseq); - } - - unsigned register_cb(expr* e) override { - unsigned j = t->m_vars.size(); - t->m_vars.push_back(e); - unsigned i = cb->register_cb(e); - t->m_var2internal.setx(j, i, 0); - t->m_internal2var.setx(i, j, 0); - return j; - } - }; - - callback i_cb; - - void init_i_fixed_eh() { - if (!m_fixed_eh) - return; - i_fixed_eh = [this](void* ctx, user_propagator::callback* cb, unsigned id, expr* value) { - i_cb.t = this; - i_cb.cb = cb; - m_fixed_eh(ctx, &i_cb, m_internal2var[id], value); - }; - m_ctx->user_propagate_register_fixed(i_fixed_eh); - } - - void init_i_final_eh() { - if (!m_final_eh) - return; - i_final_eh = [this](void* ctx, user_propagator::callback* cb) { - i_cb.t = this; - i_cb.cb = cb; - m_final_eh(ctx, &i_cb); - }; - m_ctx->user_propagate_register_final(i_final_eh); - } - - void init_i_eq_eh() { - if (!m_eq_eh) - return; - i_eq_eh = [this](void* ctx, user_propagator::callback* cb, unsigned u, unsigned v) { - i_cb.t = this; - i_cb.cb = cb; - m_eq_eh(ctx, &i_cb, m_internal2var[u], m_internal2var[v]); - }; - m_ctx->user_propagate_register_eq(i_eq_eh); - } - - void init_i_diseq_eh() { - if (!m_diseq_eh) - return; - i_diseq_eh = [this](void* ctx, user_propagator::callback* cb, unsigned u, unsigned v) { - i_cb.t = this; - i_cb.cb = cb; - m_diseq_eh(ctx, &i_cb, m_internal2var[u], m_internal2var[v]); - }; - m_ctx->user_propagate_register_diseq(i_diseq_eh); - } - - void init_i_created_eh() { - if (!m_created_eh) - return; - i_created_eh = [this](void* ctx, user_propagator::callback* cb, expr* e, unsigned i) { - unsigned j = m_vars.size(); - m_vars.push_back(e); - m_internal2var.setx(i, j, 0); - m_var2internal.setx(j, i, 0); - m_created_eh(ctx, cb, e, j); - }; - m_ctx->user_propagate_register_created(i_created_eh); - } - - void init_i_push_pop() { - i_push_eh = [this](void* ctx) { - m_limit.push_back(m_vars.size()); - m_push_eh(ctx); - }; - i_pop_eh = [this](void* ctx, unsigned n) { - unsigned old_sz = m_limit.size() - n; - unsigned num_vars = m_limit[old_sz]; - m_vars.shrink(num_vars); - m_limit.shrink(old_sz); - m_pop_eh(ctx, n); - }; - } - - void user_propagate_delay_init() { if (!m_user_ctx) return; - init_i_push_pop(); - m_ctx->user_propagate_init(m_user_ctx, i_push_eh, i_pop_eh, m_fresh_eh); - init_i_fixed_eh(); - init_i_final_eh(); - init_i_eq_eh(); - init_i_diseq_eh(); - init_i_created_eh(); + m_ctx->user_propagate_init(m_user_ctx, m_push_eh, m_pop_eh, m_fresh_eh); + if (m_fixed_eh) m_ctx->user_propagate_register_fixed(m_fixed_eh); + if (m_final_eh) m_ctx->user_propagate_register_final(m_final_eh); + if (m_eq_eh) m_ctx->user_propagate_register_eq(m_eq_eh); + if (m_diseq_eh) m_ctx->user_propagate_register_diseq(m_diseq_eh); + if (m_created_eh) m_ctx->user_propagate_register_created(m_created_eh); - unsigned i = 0; - for (expr* v : m_vars) { - unsigned j = m_ctx->user_propagate_register_expr(v); - m_var2internal.setx(i, j, 0); - m_internal2var.setx(j, i, 0); - ++i; - } + for (expr* v : m_vars) + m_ctx->user_propagate_register_expr(v); } void user_propagate_clear() override { @@ -496,9 +376,8 @@ public: m_diseq_eh = diseq_eh; } - unsigned user_propagate_register_expr(expr* e) override { + void user_propagate_register_expr(expr* e) override { m_vars.push_back(e); - return m_vars.size() - 1; } void user_propagate_register_created(user_propagator::created_eh_t& created_eh) override { diff --git a/src/smt/theory_user_propagator.cpp b/src/smt/theory_user_propagator.cpp index 5e5bd1e1d..85a02b154 100644 --- a/src/smt/theory_user_propagator.cpp +++ b/src/smt/theory_user_propagator.cpp @@ -23,7 +23,8 @@ Author: using namespace smt; theory_user_propagator::theory_user_propagator(context& ctx): - theory(ctx, ctx.get_manager().mk_family_id(user_propagator::plugin::name())) + theory(ctx, ctx.get_manager().mk_family_id(user_propagator::plugin::name())), + m_var2expr(ctx.get_manager()) {} theory_user_propagator::~theory_user_propagator() { @@ -38,9 +39,10 @@ void theory_user_propagator::force_push() { } } -unsigned theory_user_propagator::add_expr(expr* e) { +void theory_user_propagator::add_expr(expr* term) { force_push(); expr_ref r(m); + expr* e = term; ctx.get_rewriter()(e, r); if (r != e) { r = m.mk_fresh_const("aux-expr", e->get_sort()); @@ -52,8 +54,14 @@ unsigned theory_user_propagator::add_expr(expr* e) { } enode* n = ensure_enode(e); if (is_attached_to_var(n)) - return n->get_th_var(get_id()); + return; + + theory_var v = mk_var(n); + m_var2expr.reserve(v + 1); + m_var2expr[v] = term; + m_expr2var.setx(term->get_id(), v, null_theory_var); + if (m.is_bool(e) && !ctx.b_internalized(e)) { bool_var bv = ctx.mk_bool_var(e); ctx.set_var_theory(bv, get_id()); @@ -65,22 +73,24 @@ unsigned theory_user_propagator::add_expr(expr* e) { literal_vector explain; if (ctx.is_fixed(n, r, explain)) m_prop.push_back(prop_info(explain, v, r)); - return v; + } void theory_user_propagator::propagate_cb( - unsigned num_fixed, unsigned const* fixed_ids, - unsigned num_eqs, unsigned const* eq_lhs, unsigned const* eq_rhs, + unsigned num_fixed, expr* const* fixed_ids, + unsigned num_eqs, expr* const* eq_lhs, expr* const* eq_rhs, expr* conseq) { CTRACE("user_propagate", ctx.lit_internalized(conseq) && ctx.get_assignment(ctx.get_literal(conseq)) == l_true, ctx.display(tout << "redundant consequence: " << mk_pp(conseq, m) << "\n")); - if (ctx.lit_internalized(conseq) && ctx.get_assignment(ctx.get_literal(conseq)) == l_true) + expr_ref _conseq(conseq, m); + ctx.get_rewriter()(conseq, _conseq); + if (ctx.lit_internalized(_conseq) && ctx.get_assignment(ctx.get_literal(_conseq)) == l_true) return; - m_prop.push_back(prop_info(num_fixed, fixed_ids, num_eqs, eq_lhs, eq_rhs, expr_ref(conseq, m))); + m_prop.push_back(prop_info(num_fixed, fixed_ids, num_eqs, eq_lhs, eq_rhs, _conseq)); } -unsigned theory_user_propagator::register_cb(expr* e) { - return add_expr(e); +void theory_user_propagator::register_cb(expr* e) { + add_expr(e); } theory * theory_user_propagator::mk_fresh(context * new_ctx) { @@ -91,6 +101,7 @@ theory * theory_user_propagator::mk_fresh(context * new_ctx) { if ((bool)m_final_eh) th->register_final(m_final_eh); if ((bool)m_eq_eh) th->register_eq(m_eq_eh); if ((bool)m_diseq_eh) th->register_diseq(m_diseq_eh); + if ((bool)m_created_eh) th->register_created(m_created_eh); return th; } @@ -114,7 +125,7 @@ void theory_user_propagator::new_fixed_eh(theory_var v, expr* value, unsigned nu m_fixed.insert(v); ctx.push_trail(insert_map(m_fixed, v)); m_id2justification.setx(v, literal_vector(num_lits, jlits), literal_vector()); - m_fixed_eh(m_user_context, this, v, value); + m_fixed_eh(m_user_context, this, var2expr(v), value); } void theory_user_propagator::push_scope_eh() { @@ -142,12 +153,12 @@ void theory_user_propagator::propagate_consequence(prop_info const& prop) { justification* js; m_lits.reset(); m_eqs.reset(); - for (unsigned id : prop.m_ids) - m_lits.append(m_id2justification[id]); + for (expr* id : prop.m_ids) + m_lits.append(m_id2justification[expr2var(id)]); for (auto const& p : prop.m_eqs) - m_eqs.push_back(enode_pair(get_enode(p.first), get_enode(p.second))); + m_eqs.push_back(enode_pair(get_enode(expr2var(p.first)), get_enode(expr2var(p.second)))); DEBUG_CODE(for (auto const& p : m_eqs) VERIFY(p.first->get_root() == p.second->get_root());); - DEBUG_CODE(for (unsigned id : prop.m_ids) VERIFY(m_fixed.contains(id));); + DEBUG_CODE(for (expr* e : prop.m_ids) VERIFY(m_fixed.contains(expr2var(e)));); DEBUG_CODE(for (literal lit : m_lits) VERIFY(ctx.get_assignment(lit) == l_true);); TRACE("user_propagate", tout << "propagating #" << prop.m_conseq->get_id() << ": " << prop.m_conseq << "\n"); @@ -216,12 +227,12 @@ bool theory_user_propagator::internalize_term(app* term) { if (term->get_family_id() == get_id() && !ctx.e_internalized(term)) ctx.mk_enode(term, true, false, true); - unsigned v = add_expr(term); + add_expr(term); - if (!m_created_eh && (m_fixed_eh || m_eq_eh || m_diseq_eh)) - throw default_exception("You have to register a created event handler for new terms if you track them"); + if (!m_created_eh && (m_fixed_eh || m_eq_eh || m_diseq_eh)) + return true; if (m_created_eh) - m_created_eh(m_user_context, this, term, v); + m_created_eh(m_user_context, this, term); return true; } diff --git a/src/smt/theory_user_propagator.h b/src/smt/theory_user_propagator.h index f1e558256..e1aa33b8e 100644 --- a/src/smt/theory_user_propagator.h +++ b/src/smt/theory_user_propagator.h @@ -30,13 +30,13 @@ namespace smt { class theory_user_propagator : public theory, public user_propagator::callback { struct prop_info { - unsigned_vector m_ids; + ptr_vector m_ids; expr_ref m_conseq; - svector> m_eqs; + svector> m_eqs; literal_vector m_lits; - theory_var m_var = null_theory_var; - prop_info(unsigned num_fixed, unsigned const* fixed_ids, - unsigned num_eqs, unsigned const* eq_lhs, unsigned const* eq_rhs, expr_ref const& c): + theory_var m_var = null_theory_var; + prop_info(unsigned num_fixed, expr* const* fixed_ids, + unsigned num_eqs, expr* const* eq_lhs, expr* const* eq_rhs, expr_ref const& c): m_ids(num_fixed, fixed_ids), m_conseq(c) { for (unsigned i = 0; i < num_eqs; ++i) @@ -64,7 +64,7 @@ namespace smt { user_propagator::fixed_eh_t m_fixed_eh; user_propagator::eq_eh_t m_eq_eh; user_propagator::eq_eh_t m_diseq_eh; - user_propagator::created_eh_t m_created_eh; + user_propagator::created_eh_t m_created_eh; user_propagator::context_obj* m_api_context = nullptr; unsigned m_qhead = 0; @@ -76,6 +76,15 @@ namespace smt { literal_vector m_lits; enode_pair_vector m_eqs; stats m_stats; + expr_ref_vector m_var2expr; + unsigned_vector m_expr2var; + + expr* var2expr(theory_var v) { return m_var2expr.get(v); } + theory_var expr2var(expr* e) { check_defined(e); return m_expr2var[e->get_id()]; } + void check_defined(expr* e) { + if (e->get_id() >= m_expr2var.size() || get_num_vars() <= m_expr2var[e->get_id()]) + throw default_exception("expression is not registered"); + } void force_push(); @@ -101,7 +110,7 @@ namespace smt { m_fresh_eh = fresh_eh; } - unsigned add_expr(expr* e); + void add_expr(expr* e); void register_final(user_propagator::final_eh_t& final_eh) { m_final_eh = final_eh; } void register_fixed(user_propagator::fixed_eh_t& fixed_eh) { m_fixed_eh = fixed_eh; } @@ -110,17 +119,17 @@ namespace smt { void register_created(user_propagator::created_eh_t& created_eh) { m_created_eh = created_eh; } bool has_fixed() const { return (bool)m_fixed_eh; } - - void propagate_cb(unsigned num_fixed, unsigned const* fixed_ids, unsigned num_eqs, unsigned const* lhs, unsigned const* rhs, expr* conseq) override; - unsigned register_cb(expr* e) override; + + void propagate_cb(unsigned num_fixed, expr* const* fixed_ids, unsigned num_eqs, expr* const* lhs, expr* const* rhs, expr* conseq) override; + void register_cb(expr* e) override; void new_fixed_eh(theory_var v, expr* value, unsigned num_lits, literal const* jlits); theory * mk_fresh(context * new_ctx) override; bool internalize_atom(app* atom, bool gate_ctx) override; bool internalize_term(app* term) override; - void new_eq_eh(theory_var v1, theory_var v2) override { if (m_eq_eh) m_eq_eh(m_user_context, this, v1, v2); } - void new_diseq_eh(theory_var v1, theory_var v2) override { if (m_diseq_eh) m_diseq_eh(m_user_context, this, v1, v2); } + void new_eq_eh(theory_var v1, theory_var v2) override { if (m_eq_eh) m_eq_eh(m_user_context, this, var2expr(v1), var2expr(v2)); } + void new_diseq_eh(theory_var v1, theory_var v2) override { if (m_diseq_eh) m_diseq_eh(m_user_context, this, var2expr(v1), var2expr(v2)); } bool use_diseqs() const override { return ((bool)m_diseq_eh); } bool build_models() const override { return false; } final_check_status final_check_eh() override; diff --git a/src/solver/tactic2solver.cpp b/src/solver/tactic2solver.cpp index 6ed570297..a2909fd7b 100644 --- a/src/solver/tactic2solver.cpp +++ b/src/solver/tactic2solver.cpp @@ -108,8 +108,8 @@ public: m_tactic->user_propagate_register_diseq(diseq_eh); } - unsigned user_propagate_register_expr(expr* e) override { - return m_tactic->user_propagate_register_expr(e); + void user_propagate_register_expr(expr* e) override { + m_tactic->user_propagate_register_expr(e); } void user_propagate_register_created(user_propagator::created_eh_t& created_eh) override { diff --git a/src/tactic/core/elim_uncnstr_tactic.cpp b/src/tactic/core/elim_uncnstr_tactic.cpp index 26a69fd4a..c97fa670e 100644 --- a/src/tactic/core/elim_uncnstr_tactic.cpp +++ b/src/tactic/core/elim_uncnstr_tactic.cpp @@ -892,9 +892,8 @@ public: m_num_elim_apps = 0; } - unsigned user_propagate_register_expr(expr* e) override { + void user_propagate_register_expr(expr* e) override { m_nonvars.insert(e); - return 0; } void user_propagate_clear() override { diff --git a/src/tactic/core/reduce_args_tactic.cpp b/src/tactic/core/reduce_args_tactic.cpp index 607928f64..7f0d82f2e 100644 --- a/src/tactic/core/reduce_args_tactic.cpp +++ b/src/tactic/core/reduce_args_tactic.cpp @@ -78,7 +78,7 @@ public: void operator()(goal_ref const & g, goal_ref_buffer & result) override; void cleanup() override; - unsigned user_propagate_register_expr(expr* e) override; + void user_propagate_register_expr(expr* e) override; void user_propagate_clear() override; }; @@ -502,9 +502,8 @@ void reduce_args_tactic::cleanup() { m_imp->m_vars.append(vars); } -unsigned reduce_args_tactic::user_propagate_register_expr(expr* e) { +void reduce_args_tactic::user_propagate_register_expr(expr* e) { m_imp->m_vars.push_back(e); - return 0; } void reduce_args_tactic::user_propagate_clear() { diff --git a/src/tactic/tactic.h b/src/tactic/tactic.h index 437c7f804..af8b24b2f 100644 --- a/src/tactic/tactic.h +++ b/src/tactic/tactic.h @@ -85,7 +85,7 @@ public: throw default_exception("tactic does not support user propagation"); } - unsigned user_propagate_register_expr(expr* e) override { return 0; } + void user_propagate_register_expr(expr* e) override { } virtual char const* name() const = 0; protected: diff --git a/src/tactic/tactical.cpp b/src/tactic/tactical.cpp index c5a9e6984..2d14a5eaa 100644 --- a/src/tactic/tactical.cpp +++ b/src/tactic/tactical.cpp @@ -190,9 +190,9 @@ public: m_t2->user_propagate_register_diseq(diseq_eh); } - unsigned user_propagate_register_expr(expr* e) override { + void user_propagate_register_expr(expr* e) override { m_t1->user_propagate_register_expr(e); - return m_t2->user_propagate_register_expr(e); + m_t2->user_propagate_register_expr(e); } void user_propagate_clear() override { @@ -848,7 +848,7 @@ public: void reset() override { m_t->reset(); } void set_logic(symbol const& l) override { m_t->set_logic(l); } void set_progress_callback(progress_callback * callback) override { m_t->set_progress_callback(callback); } - unsigned user_propagate_register_expr(expr* e) override { return m_t->user_propagate_register_expr(e); } + void user_propagate_register_expr(expr* e) override { m_t->user_propagate_register_expr(e); } void user_propagate_clear() override { m_t->user_propagate_clear(); } protected: diff --git a/src/tactic/user_propagator_base.h b/src/tactic/user_propagator_base.h index 403df8af5..07270ffe6 100644 --- a/src/tactic/user_propagator_base.h +++ b/src/tactic/user_propagator_base.h @@ -8,8 +8,8 @@ namespace user_propagator { class callback { public: virtual ~callback() = default; - virtual void propagate_cb(unsigned num_fixed, unsigned const* fixed_ids, unsigned num_eqs, unsigned const* eq_lhs, unsigned const* eq_rhs, expr* conseq) = 0; - virtual unsigned register_cb(expr* e) = 0; + virtual void propagate_cb(unsigned num_fixed, expr* const* fixed_ids, unsigned num_eqs, expr* const* eq_lhs, expr* const* eq_rhs, expr* conseq) = 0; + virtual void register_cb(expr* e) = 0; }; class context_obj { @@ -18,12 +18,12 @@ namespace user_propagator { }; typedef std::function final_eh_t; - typedef std::function fixed_eh_t; - typedef std::function eq_eh_t; + typedef std::function fixed_eh_t; + typedef std::function eq_eh_t; typedef std::function fresh_eh_t; typedef std::function push_eh_t; typedef std::function pop_eh_t; - typedef std::function created_eh_t; + typedef std::function created_eh_t; class plugin : public decl_plugin { @@ -77,7 +77,7 @@ namespace user_propagator { throw default_exception("user-propagators are only supported on the SMT solver"); } - virtual unsigned user_propagate_register_expr(expr* e) { + virtual void user_propagate_register_expr(expr* e) { throw default_exception("user-propagators are only supported on the SMT solver"); }