diff --git a/src/api/api_solver.cpp b/src/api/api_solver.cpp index b93fb42af..2ca54a599 100644 --- a/src/api/api_solver.cpp +++ b/src/api/api_solver.cpp @@ -903,6 +903,26 @@ extern "C" { Z3_CATCH_RETURN(nullptr); } + Z3_ast Z3_API Z3_solver_congruence_root(Z3_context c, Z3_solver s, Z3_ast a) { + Z3_TRY; + LOG_Z3_solver_congruence_root(c, s, a); + RESET_ERROR_CODE(); + init_solver(c, s); + expr* r = to_solver_ref(s)->congruence_root(to_expr(a)); + RETURN_Z3(of_expr(r)); + Z3_CATCH_RETURN(nullptr); + } + + Z3_ast Z3_API Z3_solver_congruence_next(Z3_context c, Z3_solver s, Z3_ast a) { + Z3_TRY; + LOG_Z3_solver_congruence_next(c, s, a); + RESET_ERROR_CODE(); + init_solver(c, s); + expr* sib = to_solver_ref(s)->congruence_next(to_expr(a)); + RETURN_Z3(of_expr(sib)); + Z3_CATCH_RETURN(nullptr); + } + class api_context_obj : public user_propagator::context_obj { api::context* c; public: diff --git a/src/api/python/z3/z3.py b/src/api/python/z3/z3.py index 1f24222ad..695bb939f 100644 --- a/src/api/python/z3/z3.py +++ b/src/api/python/z3/z3.py @@ -7241,6 +7241,22 @@ class Solver(Z3PPObject): cube are likely more useful to cube on.""" return self.cube_vs + def root(self, t): + t = _py2expr(t, self.ctx) + """Retrieve congruence closure root of the term t relative to the current search state + The function primarily works for SimpleSolver. Terms and variables that are + eliminated during pre-processing are not visible to the congruence closure. + """ + return _to_expr_ref(Z3_solver_congruence_root(self.ctx.ref(), self.solver, t.ast), self.ctx) + + def next(self, t): + t = _py2expr(t, self.ctx) + """Retrieve congruence closure sibling of the term t relative to the current search state + The function primarily works for SimpleSolver. Terms and variables that are + eliminated during pre-processing are not visible to the congruence closure. + """ + return _to_expr_ref(Z3_solver_congruence_next(self.ctx.ref(), self.solver, t.ast), self.ctx) + def proof(self): """Return a proof for the last `check()`. Proof construction must be enabled.""" return _to_expr_ref(Z3_solver_get_proof(self.ctx.ref(), self.solver), self.ctx) diff --git a/src/api/z3_api.h b/src/api/z3_api.h index 7a0b47da0..ffa0d8665 100644 --- a/src/api/z3_api.h +++ b/src/api/z3_api.h @@ -6882,6 +6882,26 @@ extern "C" { */ void Z3_API Z3_solver_get_levels(Z3_context c, Z3_solver s, Z3_ast_vector literals, unsigned sz, unsigned levels[]); + /** + \brief retrieve the congruence closure root of an expression. + The root is retrieved relative to the state where the solver was in when it completed. + If it completed during a set of case splits, the congruence roots are relative to these case splits. + That is, the congruences are not consequences but they are true under the current state. + + def_API('Z3_solver_congruence_root', AST, (_in(CONTEXT), _in(SOLVER), _in(AST))) + */ + Z3_ast Z3_API Z3_solver_congruence_root(Z3_context c, Z3_solver s, Z3_ast a); + + + /** + \brief retrieve the next expression in the congruence class. The set of congruent siblings form a cyclic list. + Repeated calls on the siblings will result in returning to the original expression. + + def_API('Z3_solver_congruence_next', AST, (_in(CONTEXT), _in(SOLVER), _in(AST))) + */ + Z3_ast Z3_API Z3_solver_congruence_next(Z3_context c, Z3_solver s, Z3_ast a); + + /** \brief register a callback to that retrieves assumed, inferred and deleted clauses during search. diff --git a/src/ast/converters/expr_inverter.cpp b/src/ast/converters/expr_inverter.cpp index 5553420ad..41a60ccd5 100644 --- a/src/ast/converters/expr_inverter.cpp +++ b/src/ast/converters/expr_inverter.cpp @@ -81,7 +81,7 @@ public: * */ - bool operator()(func_decl* f, unsigned num, expr* const* args, expr_ref& r, expr_ref& side_cond) override { + bool operator()(func_decl* f, unsigned num, expr* const* args, expr_ref& r, proof_ref& pr) override { SASSERT(f->get_family_id() == m.get_basic_family_id()); switch (f->get_decl_kind()) { case OP_ITE: @@ -233,7 +233,7 @@ public: } - bool operator()(func_decl* f, unsigned num, expr* const* args, expr_ref& r, expr_ref& side_cond) override { + bool operator()(func_decl* f, unsigned num, expr* const* args, expr_ref& r, proof_ref& pr) override { SASSERT(f->get_family_id() == a.get_family_id()); switch (f->get_decl_kind()) { case OP_ADD: @@ -531,7 +531,7 @@ class bv_expr_inverter : public iexpr_inverter { * y := 0 * */ - bool operator()(func_decl* f, unsigned num, expr* const* args, expr_ref& r, expr_ref& side_cond) override { + bool operator()(func_decl* f, unsigned num, expr* const* args, expr_ref& r, proof_ref& pr) override { SASSERT(f->get_family_id() == bv.get_family_id()); switch (f->get_decl_kind()) { case OP_BADD: @@ -611,7 +611,7 @@ public: family_id get_fid() const override { return a.get_family_id(); } - bool operator()(func_decl* f, unsigned num, expr* const* args, expr_ref& r, expr_ref& side_cond) override { + bool operator()(func_decl* f, unsigned num, expr* const* args, expr_ref& r, proof_ref& pr) override { SASSERT(f->get_family_id() == a.get_family_id()); switch (f->get_decl_kind()) { case OP_SELECT: @@ -679,7 +679,7 @@ public: * head(x) -> fresh * x := cons(fresh, arb) */ - bool operator()(func_decl* f, unsigned num, expr* const* args, expr_ref& r, expr_ref& side_cond) override { + bool operator()(func_decl* f, unsigned num, expr* const* args, expr_ref& r, proof_ref& pr) override { if (dt.is_accessor(f)) { SASSERT(num == 1); if (uncnstr(args[0])) { @@ -799,7 +799,7 @@ expr_inverter::expr_inverter(ast_manager& m): iexpr_inverter(m) { } -bool expr_inverter::operator()(func_decl* f, unsigned num, expr* const* args, expr_ref& new_expr, expr_ref& side_cond) { +bool expr_inverter::operator()(func_decl* f, unsigned num, expr* const* args, expr_ref& new_expr, proof_ref& pr) { if (num == 0) return false; @@ -812,7 +812,7 @@ bool expr_inverter::operator()(func_decl* f, unsigned num, expr* const* args, ex return false; auto* p = m_inverters.get(fid, nullptr); - return p && (*p)(f, num, args, new_expr, side_cond); + return p && (*p)(f, num, args, new_expr, pr); } bool expr_inverter::mk_diff(expr* t, expr_ref& r) { @@ -849,3 +849,10 @@ void expr_inverter::set_model_converter(generic_model_converter* mc) { if (p) p->set_model_converter(mc); } + +void expr_inverter::set_produce_proofs(bool pr) { + m_produce_proofs = pr; + for (auto* p : m_inverters) + if (p) + p->set_produce_proofs(pr); +} diff --git a/src/ast/converters/expr_inverter.h b/src/ast/converters/expr_inverter.h index 5b7965478..60540aff3 100644 --- a/src/ast/converters/expr_inverter.h +++ b/src/ast/converters/expr_inverter.h @@ -24,6 +24,7 @@ protected: ast_manager& m; std::function m_is_var; generic_model_converter_ref m_mc; + bool m_produce_proofs = false; bool uncnstr(expr* e) const { return m_is_var(e); } bool uncnstr(unsigned num, expr * const * args) const; @@ -37,8 +38,9 @@ public: virtual ~iexpr_inverter() {} virtual void set_is_var(std::function& is_var) { m_is_var = is_var; } virtual void set_model_converter(generic_model_converter* mc) { m_mc = mc; } + virtual void set_produce_proofs(bool p) { m_produce_proofs = true; } - virtual bool operator()(func_decl* f, unsigned n, expr* const* args, expr_ref& new_expr, expr_ref& side_cond) = 0; + virtual bool operator()(func_decl* f, unsigned n, expr* const* args, expr_ref& new_expr, proof_ref& pr) = 0; virtual bool mk_diff(expr* t, expr_ref& r) = 0; virtual family_id get_fid() const = 0; }; @@ -49,9 +51,10 @@ class expr_inverter : public iexpr_inverter { public: expr_inverter(ast_manager& m); ~expr_inverter() override; - bool operator()(func_decl* f, unsigned n, expr* const* args, expr_ref& new_expr, expr_ref& side_cond) override; + bool operator()(func_decl* f, unsigned n, expr* const* args, expr_ref& new_expr, proof_ref& pr) override; bool mk_diff(expr* t, expr_ref& r) override; void set_is_var(std::function& is_var) override; void set_model_converter(generic_model_converter* mc) override; + void set_produce_proofs(bool p) override; family_id get_fid() const override { return null_family_id; } }; diff --git a/src/ast/simplifiers/elim_unconstrained.cpp b/src/ast/simplifiers/elim_unconstrained.cpp index df6f92251..41a905dea 100644 --- a/src/ast/simplifiers/elim_unconstrained.cpp +++ b/src/ast/simplifiers/elim_unconstrained.cpp @@ -62,7 +62,8 @@ bool elim_unconstrained::is_var_lt(int v1, int v2) const { void elim_unconstrained::eliminate() { while (!m_heap.empty()) { - expr_ref r(m), side_cond(m); + expr_ref r(m); + proof_ref pr(m); int v = m_heap.erase_min(); node& n = get_node(v); if (n.m_refcount == 0) @@ -84,7 +85,7 @@ void elim_unconstrained::eliminate() { unsigned sz = m_args.size(); for (expr* arg : *to_app(t)) m_args.push_back(reconstruct_term(get_node(arg))); - bool inverted = m_inverter(t->get_decl(), to_app(t)->get_num_args(), m_args.data() + sz, r, side_cond); + bool inverted = m_inverter(t->get_decl(), to_app(t)->get_num_args(), m_args.data() + sz, r, pr); n.m_refcount = 0; m_args.shrink(sz); if (!inverted) { @@ -113,7 +114,7 @@ void elim_unconstrained::eliminate() { IF_VERBOSE(11, verbose_stream() << mk_bounded_pp(get_node(v).m_orig, m) << " " << mk_bounded_pp(t, m) << " -> " << r << " " << get_node(e).m_refcount << "\n";); - SASSERT(!side_cond && "not implemented to add side conditions\n"); + SASSERT(!pr && "not implemented to add proofs\n"); } } diff --git a/src/muz/spacer/spacer_iuc_solver.h b/src/muz/spacer/spacer_iuc_solver.h index 0d4712215..e201a1fe1 100644 --- a/src/muz/spacer/spacer_iuc_solver.h +++ b/src/muz/spacer/spacer_iuc_solver.h @@ -122,6 +122,8 @@ public: void set_phase(phase* p) override { m_solver.set_phase(p); } void move_to_front(expr* e) override { m_solver.move_to_front(e); } expr_ref_vector cube(expr_ref_vector&, unsigned) override { return expr_ref_vector(m); } + expr* congruence_root(expr* e) override { return e; } + expr* congruence_next(expr* e) override { return e; } void get_levels(ptr_vector const& vars, unsigned_vector& depth) override { m_solver.get_levels(vars, depth); } expr_ref_vector get_trail(unsigned max_level) override { return m_solver.get_trail(max_level); } diff --git a/src/opt/opt_solver.h b/src/opt/opt_solver.h index 84d31ed0f..2682fca09 100644 --- a/src/opt/opt_solver.h +++ b/src/opt/opt_solver.h @@ -110,6 +110,8 @@ namespace opt { void get_levels(ptr_vector const& vars, unsigned_vector& depth) override; expr_ref_vector get_trail(unsigned max_level) override { return m_context.get_trail(max_level); } expr_ref_vector cube(expr_ref_vector&, unsigned) override { return expr_ref_vector(m); } + expr* congruence_root(expr* e) override { return e; } + expr* congruence_next(expr* e) override { return e; } void set_phase(expr* e) override { m_context.set_phase(e); } phase* get_phase() override { return m_context.get_phase(); } void set_phase(phase* p) override { m_context.set_phase(p); } diff --git a/src/sat/sat_solver/inc_sat_solver.cpp b/src/sat/sat_solver/inc_sat_solver.cpp index 41b8e609c..0362d8d3e 100644 --- a/src/sat/sat_solver/inc_sat_solver.cpp +++ b/src/sat/sat_solver/inc_sat_solver.cpp @@ -463,6 +463,10 @@ public: } return fmls; } + + expr* congruence_next(expr* e) override { return e; } + expr* congruence_root(expr* e) override { return e; } + lbool get_consequences_core(expr_ref_vector const& assumptions, expr_ref_vector const& vars, expr_ref_vector& conseq) override { init_preprocess(); diff --git a/src/sat/sat_solver/sat_smt_solver.cpp b/src/sat/sat_solver/sat_smt_solver.cpp index e37d513a0..f5872b05a 100644 --- a/src/sat/sat_solver/sat_smt_solver.cpp +++ b/src/sat/sat_solver/sat_smt_solver.cpp @@ -476,6 +476,9 @@ public: set_reason_unknown(m_solver.get_reason_unknown()); return fmls; } + + expr* congruence_next(expr* e) override { return e; } + expr* congruence_root(expr* e) override { return e; } lbool find_mutexes(expr_ref_vector const& vars, vector& mutexes) override { diff --git a/src/smt/smt_kernel.cpp b/src/smt/smt_kernel.cpp index ae3338f52..c4ecf6787 100644 --- a/src/smt/smt_kernel.cpp +++ b/src/smt/smt_kernel.cpp @@ -213,6 +213,20 @@ namespace smt { return out; } + expr* kernel::congruence_root(expr * e) { + smt::enode* n = m_imp->m_kernel.find_enode(e); + if (!n) + return e; + return n->get_root()->get_expr(); + } + + expr* kernel::congruence_next(expr * e) { + smt::enode* n = m_imp->m_kernel.find_enode(e); + if (!n) + return e; + return n->get_next()->get_expr(); + } + void kernel::collect_statistics(::statistics & st) const { m_imp->m_kernel.collect_statistics(st); } diff --git a/src/smt/smt_kernel.h b/src/smt/smt_kernel.h index fa4a48406..ccea5caf8 100644 --- a/src/smt/smt_kernel.h +++ b/src/smt/smt_kernel.h @@ -239,6 +239,13 @@ namespace smt { */ expr_ref_vector cubes(unsigned depth); + /** + \brief access congruence closure + */ + expr* congruence_next(expr* e); + + expr* congruence_root(expr* e); + /** \brief retrieve depth of variables from decision stack. diff --git a/src/smt/smt_solver.cpp b/src/smt/smt_solver.cpp index 61c7fdda7..4be78b20a 100644 --- a/src/smt/smt_solver.cpp +++ b/src/smt/smt_solver.cpp @@ -330,6 +330,10 @@ namespace { m_context.get_units(units); } + expr* congruence_next(expr* e) override { return m_context.congruence_next(e); } + expr* congruence_root(expr* e) override { return m_context.congruence_root(e); } + + expr_ref_vector cube(expr_ref_vector& vars, unsigned cutoff) override { ast_manager& m = get_manager(); if (!m_cuber) { diff --git a/src/solver/combined_solver.cpp b/src/solver/combined_solver.cpp index 7b1449637..53aa56753 100644 --- a/src/solver/combined_solver.cpp +++ b/src/solver/combined_solver.cpp @@ -275,6 +275,10 @@ public: return m_solver2->cube(vars, backtrack_level); } + expr* congruence_next(expr* e) override { switch_inc_mode(); return m_solver2->congruence_next(e); } + expr* congruence_root(expr* e) override { switch_inc_mode(); return m_solver2->congruence_root(e); } + + expr * get_assumption(unsigned idx) const override { unsigned c1 = m_solver1->get_num_assumptions(); if (idx < c1) return m_solver1->get_assumption(idx); diff --git a/src/solver/solver.h b/src/solver/solver.h index 957cb7c8e..7d7a3eec2 100644 --- a/src/solver/solver.h +++ b/src/solver/solver.h @@ -238,6 +238,15 @@ public: virtual expr_ref_vector cube(expr_ref_vector& vars, unsigned backtrack_level) = 0; + /** + \brief retrieve congruence closure root. + */ + virtual expr* congruence_root(expr* e) = 0; + + /** + \brief retrieve congruence closure sibling + */ + virtual expr* congruence_next(expr* e) = 0; /** \brief Display the content of this solver. diff --git a/src/solver/solver_pool.cpp b/src/solver/solver_pool.cpp index f5760bde3..411634162 100644 --- a/src/solver/solver_pool.cpp +++ b/src/solver/solver_pool.cpp @@ -262,6 +262,9 @@ public: expr_ref_vector cube(expr_ref_vector& vars, unsigned ) override { return expr_ref_vector(m); } + expr* congruence_next(expr* e) override { return e; } + expr* congruence_root(expr* e) override { return e; } + ast_manager& get_manager() const override { return m_base->get_manager(); } void refresh(solver* new_base) { diff --git a/src/solver/tactic2solver.cpp b/src/solver/tactic2solver.cpp index b65ffde57..cc3ac9336 100644 --- a/src/solver/tactic2solver.cpp +++ b/src/solver/tactic2solver.cpp @@ -136,6 +136,9 @@ public: return expr_ref_vector(get_manager()); } + expr* congruence_next(expr* e) override { return e; } + expr* congruence_root(expr* e) override { return e; } + model_converter_ref get_model_converter() const override { return m_mc; } void get_levels(ptr_vector const& vars, unsigned_vector& depth) override { diff --git a/src/tactic/fd_solver/bounded_int2bv_solver.cpp b/src/tactic/fd_solver/bounded_int2bv_solver.cpp index 7b7ca630e..4ac82c0c2 100644 --- a/src/tactic/fd_solver/bounded_int2bv_solver.cpp +++ b/src/tactic/fd_solver/bounded_int2bv_solver.cpp @@ -210,6 +210,8 @@ public: void set_reason_unknown(char const* msg) override { m_solver->set_reason_unknown(msg); } void get_labels(svector & r) override { m_solver->get_labels(r); } ast_manager& get_manager() const override { return m; } + expr* congruence_next(expr* e) override { return m_solver->congruence_next(e); } + expr* congruence_root(expr* e) override { return m_solver->congruence_root(e); } expr_ref_vector cube(expr_ref_vector& vars, unsigned backtrack_level) override { flush_assertions(); return m_solver->cube(vars, backtrack_level); } lbool find_mutexes(expr_ref_vector const& vars, vector& mutexes) override { return m_solver->find_mutexes(vars, mutexes); } lbool get_consequences_core(expr_ref_vector const& asms, expr_ref_vector const& vars, expr_ref_vector& consequences) override { diff --git a/src/tactic/fd_solver/enum2bv_solver.cpp b/src/tactic/fd_solver/enum2bv_solver.cpp index 7ec5243e7..2690e7033 100644 --- a/src/tactic/fd_solver/enum2bv_solver.cpp +++ b/src/tactic/fd_solver/enum2bv_solver.cpp @@ -131,6 +131,9 @@ public: expr_ref_vector cube(expr_ref_vector& vars, unsigned backtrack_level) override { return m_solver->cube(vars, backtrack_level); } + expr* congruence_next(expr* e) override { return m_solver->congruence_next(e); } + expr* congruence_root(expr* e) override { return m_solver->congruence_root(e); } + lbool get_consequences_core(expr_ref_vector const& asms, expr_ref_vector const& vars, expr_ref_vector& consequences) override { datatype_util dt(m); diff --git a/src/tactic/fd_solver/pb2bv_solver.cpp b/src/tactic/fd_solver/pb2bv_solver.cpp index 1a5f7d16a..19f2630f2 100644 --- a/src/tactic/fd_solver/pb2bv_solver.cpp +++ b/src/tactic/fd_solver/pb2bv_solver.cpp @@ -122,6 +122,8 @@ public: void get_labels(svector & r) override { m_solver->get_labels(r); } ast_manager& get_manager() const override { return m; } expr_ref_vector cube(expr_ref_vector& vars, unsigned backtrack_level) override { flush_assertions(); return m_solver->cube(vars, backtrack_level); } + expr* congruence_next(expr* e) override { return m_solver->congruence_next(e); } + expr* congruence_root(expr* e) override { return m_solver->congruence_root(e); } lbool find_mutexes(expr_ref_vector const& vars, vector& mutexes) override { return m_solver->find_mutexes(vars, mutexes); } lbool get_consequences_core(expr_ref_vector const& asms, expr_ref_vector const& vars, expr_ref_vector& consequences) override { flush_assertions(); diff --git a/src/tactic/fd_solver/smtfd_solver.cpp b/src/tactic/fd_solver/smtfd_solver.cpp index 3729a2ad1..5676c6ee8 100644 --- a/src/tactic/fd_solver/smtfd_solver.cpp +++ b/src/tactic/fd_solver/smtfd_solver.cpp @@ -2086,6 +2086,10 @@ namespace smtfd { expr_ref_vector cube(expr_ref_vector& vars, unsigned backtrack_level) override { return expr_ref_vector(m); } + + expr* congruence_root(expr* e) override { return e; } + + expr* congruence_next(expr* e) override { return e; } lbool get_consequences_core(expr_ref_vector const& asms, expr_ref_vector const& vars, expr_ref_vector& consequences) override { return l_undef;