diff --git a/src/api/python/z3/z3.py b/src/api/python/z3/z3.py index 9ce7a0531..724980339 100644 --- a/src/api/python/z3/z3.py +++ b/src/api/python/z3/z3.py @@ -11364,12 +11364,12 @@ def to_ContextObj(ptr,): return ctx -def user_prop_fresh(ctx, new_ctx): +def user_prop_fresh(ctx, _new_ctx): _prop_closures.set_threaded() prop = _prop_closures.get(ctx) nctx = Context() Z3_del_context(nctx.ctx) - new_ctx = to_ContextObj(new_ctx) + new_ctx = to_ContextObj(_new_ctx) nctx.ctx = new_ctx nctx.eh = Z3_set_error_handler(new_ctx, z3_error_handler) nctx.owner = False @@ -11390,6 +11390,13 @@ def user_prop_fixed(ctx, cb, id, value): prop.fixed(id, value) prop.cb = None +def user_prop_created(ctx, cb, id): + prop = _prop_closures.get(ctx) + prop.cb = cb + id = _to_expr_ref(to_Ast(id), prop.ctx()) + prop.created(id) + prop.cb = None + def user_prop_final(ctx, cb): prop = _prop_closures.get(ctx) prop.cb = cb @@ -11417,10 +11424,32 @@ _user_prop_push = Z3_push_eh(user_prop_push) _user_prop_pop = Z3_pop_eh(user_prop_pop) _user_prop_fresh = Z3_fresh_eh(user_prop_fresh) _user_prop_fixed = Z3_fixed_eh(user_prop_fixed) +_user_prop_created = Z3_created_eh(user_prop_created) _user_prop_final = Z3_final_eh(user_prop_final) _user_prop_eq = Z3_eq_eh(user_prop_eq) _user_prop_diseq = Z3_eq_eh(user_prop_diseq) +def PropagateFunction(name, *sig): + """Create a function that gets tracked by user propagator. + Every term headed by this function symbol is tracked. + If a term is fixed and the fixed callback is registered a + callback is invoked that the term headed by this function is fixed. + """ + sig = _get_args(sig) + if z3_debug(): + _z3_assert(len(sig) > 0, "At least two arguments expected") + arity = len(sig) - 1 + rng = sig[arity] + if z3_debug(): + _z3_assert(is_sort(rng), "Z3 sort expected") + dom = (Sort * arity)() + for i in range(arity): + if z3_debug(): + _z3_assert(is_sort(sig[i]), "Z3 sort expected") + dom[i] = sig[i].ast + ctx = rng.ctx + return FuncDeclRef(Z3_solver_propagate_declare(ctx.ref(), to_symbol(name, ctx), arity, dom, rng.ast), ctx) + class UserPropagateBase: @@ -11443,6 +11472,7 @@ class UserPropagateBase: self.final = None self.eq = None self.diseq = None + self.created = None if ctx: self.fresh_ctx = ctx if s: @@ -11473,6 +11503,13 @@ class UserPropagateBase: Z3_solver_propagate_fixed(self.ctx_ref(), self.solver.solver, _user_prop_fixed) self.fixed = fixed + def add_created(self, created): + assert not self.created + assert not self._ctx + if self.solver: + Z3_solver_propagate_created(self.ctx_ref(), self.solver.solver, _user_prop_created) + self.created = created + def add_final(self, final): assert not self.final assert not self._ctx @@ -11504,9 +11541,12 @@ class UserPropagateBase: raise Z3Exception("fresh needs to be overwritten") def add(self, e): - assert self.solver assert not self._ctx - Z3_solver_propagate_register(self.ctx_ref(), self.solver.solver, e.ast) + if self.solver: + Z3_solver_propagate_register(self.ctx_ref(), self.solver.solver, e.ast) + else: + Z3_solver_propagate_register_cb(self.ctx_ref(), ctypes.c_void_p(self.cb), e.ast) + # # Propagation can only be invoked as during a fixed or final callback. @@ -11519,5 +11559,5 @@ class UserPropagateBase: Z3_solver_propagate_consequence(e.ctx.ref(), ctypes.c_void_p( self.cb), num_fixed, _ids, num_eqs, _lhs, _rhs, e.ast) - def conflict(self, deps): - self.propagate(BoolVal(False, self.ctx()), deps, eqs=[]) + def conflict(self, deps = [], eqs = []): + self.propagate(BoolVal(False, self.ctx()), deps, eqs) diff --git a/src/ast/ast.cpp b/src/ast/ast.cpp index 473bd82b5..c51ea4e32 100644 --- a/src/ast/ast.cpp +++ b/src/ast/ast.cpp @@ -1429,7 +1429,7 @@ ast_manager::~ast_manager() { } m_plugins.reset(); while (!m_ast_table.empty()) { - DEBUG_CODE(IF_VERBOSE(0, verbose_stream() << "ast_manager LEAKED: " << m_ast_table.size() << std::endl);); + DEBUG_CODE(IF_VERBOSE(1, verbose_stream() << "ast_manager LEAKED: " << m_ast_table.size() << std::endl);); ptr_vector roots; ast_mark mark; for (ast * n : m_ast_table) { @@ -1465,22 +1465,21 @@ ast_manager::~ast_manager() { break; } } - for (ast * n : m_ast_table) { - if (!mark.is_marked(n)) { + for (ast * n : m_ast_table) + if (!mark.is_marked(n)) roots.push_back(n); - } - } + SASSERT(!roots.empty()); for (unsigned i = 0; i < roots.size(); ++i) { ast* a = roots[i]; DEBUG_CODE( - std::cout << "Leaked: "; - if (is_sort(a)) { - std::cout << to_sort(a)->get_name() << "\n"; - } - else { - std::cout << mk_ll_pp(a, *this, false) << "id: " << a->get_id() << "\n"; - }); + IF_VERBOSE(1, + verbose_stream() << "Leaked: "; + if (is_sort(a)) + verbose_stream() << to_sort(a)->get_name() << "\n"; + else + verbose_stream() << mk_ll_pp(a, *this, false) << "id: " << a->get_id() << "\n"; + );); a->m_ref_count = 0; delete_node(a); } diff --git a/src/sat/smt/user_solver.cpp b/src/sat/smt/user_solver.cpp index 9e2ea3eab..49d760177 100644 --- a/src/sat/smt/user_solver.cpp +++ b/src/sat/smt/user_solver.cpp @@ -131,6 +131,21 @@ namespace user_solver { m_id2justification.setx(v, lits, sat::literal_vector()); m_fixed_eh(m_user_context, this, var2expr(v), lit.sign() ? m.mk_false() : m.mk_true()); } + + void solver::new_eq_eh(euf::th_eq const& eq) { + if (!m_eq_eh) + return; + force_push(); + m_eq_eh(m_user_context, this, var2expr(eq.v1()), var2expr(eq.v2())); + } + + void solver::new_diseq_eh(euf::th_eq const& de) { + if (!m_diseq_eh) + return; + force_push(); + m_diseq_eh(m_user_context, this, var2expr(eq.v1()), var2expr(eq.v2())); + } + void solver::push_core() { th_euf_solver::push_core(); diff --git a/src/sat/smt/user_solver.h b/src/sat/smt/user_solver.h index 951b97fb6..28528b9a1 100644 --- a/src/sat/smt/user_solver.h +++ b/src/sat/smt/user_solver.h @@ -144,6 +144,10 @@ namespace user_solver { bool get_case_split(sat::bool_var& var, lbool &phase) override; void asserted(sat::literal lit) override; + bool use_diseqs() const override { return (bool)m_diseq_eh; } + void new_eq_eh(euf::th_eq const& eq) override; + void new_diseq_eh(euf::th_eq const& de) override; + sat::check_result check() override; void push_core() override; void pop_core(unsigned n) override; diff --git a/src/smt/smt_context.cpp b/src/smt/smt_context.cpp index 6652804ea..420147256 100644 --- a/src/smt/smt_context.cpp +++ b/src/smt/smt_context.cpp @@ -821,6 +821,8 @@ namespace smt { SASSERT(t2 != null_theory_id); theory_var v1 = m_fparams.m_new_core2th_eq ? get_closest_var(n1, t2) : r1->get_th_var(t2); + TRACE("merge_theory_vars", tout << get_theory(t2)->get_name() << ": " << v2 << " == " << v1 << "\n"); + if (v1 != null_theory_var) { // only send the equality to the theory, if the equality was not propagated by it. if (t2 != from_th) @@ -839,6 +841,7 @@ namespace smt { SASSERT(v1 != null_theory_var); SASSERT(t1 != null_theory_id); theory_var v2 = r2->get_th_var(t1); + TRACE("merge_theory_vars", tout << get_theory(t1)->get_name() << ": " << v2 << " == " << v1 << "\n"); if (v2 == null_theory_var) { r2->add_th_var(v1, t1, m_region); push_new_th_diseqs(r2, v1, get_theory(t1)); diff --git a/src/smt/theory_user_propagator.cpp b/src/smt/theory_user_propagator.cpp index 911d1715e..ec0acd903 100644 --- a/src/smt/theory_user_propagator.cpp +++ b/src/smt/theory_user_propagator.cpp @@ -146,7 +146,9 @@ final_check_status theory_user_propagator::final_check_eh() { catch (...) { throw default_exception("Exception thrown in \"final\"-callback"); } + CTRACE("user_propagate", can_propagate(), tout << "can propagate\n"); propagate(); + CTRACE("user_propagate", ctx.inconsistent(), tout << "inconsistent\n"); // check if it became inconsistent or something new was propagated/registered bool done = (sz1 == m_prop.size()) && (sz2 == m_expr2var.size()) && !ctx.inconsistent(); return done ? FC_DONE : FC_CONTINUE; @@ -298,13 +300,17 @@ void theory_user_propagator::propagate_consequence(prop_info const& prop) { m_eqs.reset(); for (expr* id : prop.m_ids) m_lits.append(m_id2justification[expr2var(id)]); - for (auto const& p : prop.m_eqs) - m_eqs.push_back(enode_pair(get_enode(expr2var(p.first)), get_enode(expr2var(p.second)))); - DEBUG_CODE(for (auto const& p : m_eqs) VERIFY(p.first->get_root() == p.second->get_root());); + for (auto const& [a,b] : prop.m_eqs) + if (a != b) + m_eqs.push_back(enode_pair(get_enode(expr2var(a)), get_enode(expr2var(b)))); + DEBUG_CODE(for (auto const& [a, b] : m_eqs) VERIFY(a->get_root() == b->get_root());); DEBUG_CODE(for (expr* e : prop.m_ids) VERIFY(m_fixed.contains(expr2var(e)));); DEBUG_CODE(for (literal lit : m_lits) VERIFY(ctx.get_assignment(lit) == l_true);); - TRACE("user_propagate", tout << "propagating #" << prop.m_conseq->get_id() << ": " << prop.m_conseq << "\n"); + TRACE("user_propagate", tout << "propagating #" << prop.m_conseq->get_id() << ": " << prop.m_conseq << "\n"; + for (auto const& [a,b] : m_eqs) tout << enode_pp(a, ctx) << " == " << enode_pp(b, ctx) << "\n"; + for (expr* e : prop.m_ids) tout << mk_pp(e, m) << "\n"; + for (literal lit : m_lits) tout << lit << "\n"); if (m.is_false(prop.m_conseq)) { js = ctx.mk_justification( @@ -341,9 +347,9 @@ void theory_user_propagator::propagate_new_fixed(prop_info const& prop) { void theory_user_propagator::propagate() { - TRACE("user_propagate", tout << "propagating queue head: " << m_qhead << " prop queue: " << m_prop.size() << "\n"); if (m_qhead == m_prop.size() && m_to_add_qhead == m_to_add.size()) return; + TRACE("user_propagate", tout << "propagating queue head: " << m_qhead << " prop queue: " << m_prop.size() << "\n"); force_push(); unsigned qhead = m_to_add_qhead; diff --git a/src/smt/theory_user_propagator.h b/src/smt/theory_user_propagator.h index ba9900848..73fc5bb45 100644 --- a/src/smt/theory_user_propagator.h +++ b/src/smt/theory_user_propagator.h @@ -140,10 +140,11 @@ namespace smt { bool get_case_split(bool_var& var, bool& is_pos); theory * mk_fresh(context * new_ctx) override; + char const* get_name() const override { return "user_propagate"; } bool internalize_atom(app* atom, bool gate_ctx) override; bool internalize_term(app* term) override; - void new_eq_eh(theory_var v1, theory_var v2) override { if (m_eq_eh) m_eq_eh(m_user_context, this, var2expr(v1), var2expr(v2)); } - void new_diseq_eh(theory_var v1, theory_var v2) override { if (m_diseq_eh) m_diseq_eh(m_user_context, this, var2expr(v1), var2expr(v2)); } + void new_eq_eh(theory_var v1, theory_var v2) override { force_push(); if (m_eq_eh) m_eq_eh(m_user_context, this, var2expr(v1), var2expr(v2)); } + void new_diseq_eh(theory_var v1, theory_var v2) override { force_push(); if (m_diseq_eh) m_diseq_eh(m_user_context, this, var2expr(v1), var2expr(v2)); } bool use_diseqs() const override { return ((bool)m_diseq_eh); } bool build_models() const override { return false; } final_check_status final_check_eh() override;