diff --git a/src/ast/rewriter/rewriter_def.h b/src/ast/rewriter/rewriter_def.h index 61d177809..dddb02dfd 100644 --- a/src/ast/rewriter/rewriter_def.h +++ b/src/ast/rewriter/rewriter_def.h @@ -184,9 +184,20 @@ void rewriter_tpl::process_app(app * t, frame & fr) { unsigned num_args = t->get_num_args(); while (fr.m_i < num_args) { expr * arg = t->get_arg(fr.m_i); + if (fr.m_i >= 1 && m().is_ite(t) && !ProofGen) { + expr * cond = result_stack()[fr.m_spos].get(); + if (m().is_true(cond)) { + arg = t->get_arg(1); + } + else if (m().is_false(cond)) { + arg = t->get_arg(2); + } + } fr.m_i++; if (!visit(arg, fr.m_max_depth)) return; + + } func_decl * f = t->get_decl(); diff --git a/src/smt/smt_context.cpp b/src/smt/smt_context.cpp index 3c88a75a2..d6a837d37 100644 --- a/src/smt/smt_context.cpp +++ b/src/smt/smt_context.cpp @@ -4123,6 +4123,7 @@ namespace smt { if (fcs == FC_DONE) { mk_proto_model(l_true); m_model = m_proto_model->mk_model(); + add_rec_funs_to_model(); } return fcs == FC_DONE; @@ -4175,8 +4176,52 @@ namespace smt { return m_last_search_failure; } + void context::add_rec_funs_to_model() { + ast_manager& m = m_manager; + SASSERT(m_model); + for (unsigned i = 0; i < m_asserted_formulas.get_num_formulas(); ++i) { + expr* e = m_asserted_formulas.get_formula(i); + if (is_quantifier(e)) { + quantifier* q = to_quantifier(e); + std::cout << mk_pp(q, m) << "\n"; + if (!m.is_rec_fun_def(q)) continue; + SASSERT(q->get_num_patterns() == 1); + expr* fn = to_app(q->get_pattern(0))->get_arg(0); + SASSERT(is_app(fn)); + func_decl* f = to_app(fn)->get_decl(); + expr* eq = q->get_expr(); + expr_ref body(m); + if (is_fun_def(fn, q->get_expr(), body)) { + func_interp* fi = alloc(func_interp, m, f->get_arity()); + fi->set_else(body); + m_model->register_decl(f, fi); + } + } + } + } + + bool context::is_fun_def(expr* f, expr* body, expr_ref& result) { + expr* t1, *t2, *t3; + if (m_manager.is_eq(body, t1, t2) || m_manager.is_iff(body, t1, t2)) { + if (t1 == f) return result = t2, true; + if (t2 == f) return result = t1, true; + return false; + } + if (m_manager.is_ite(body, t1, t2, t3)) { + expr_ref body1(m_manager), body2(m_manager); + if (is_fun_def(f, t2, body1) && is_fun_def(f, t3, body2)) { + // f is not free in t1 + result = m_manager.mk_ite(t1, body1, body2); + return true; + } + } + return false; + } + + }; + #ifdef Z3DEBUG void pp(smt::context & c) { c.display(std::cout); diff --git a/src/smt/smt_context.h b/src/smt/smt_context.h index b9b068442..c1c684fec 100644 --- a/src/smt/smt_context.h +++ b/src/smt/smt_context.h @@ -209,6 +209,7 @@ namespace smt { ~scoped_mk_model() { if (m_ctx.m_proto_model.get() != 0) { m_ctx.m_model = m_ctx.m_proto_model->mk_model(); + m_ctx.add_rec_funs_to_model(); m_ctx.m_proto_model = 0; // proto_model is not needed anymore. } } @@ -1156,6 +1157,10 @@ namespace smt { bool propagate(); + void add_rec_funs_to_model(); + + bool is_fun_def(expr* f, expr* q, expr_ref& body); + public: bool can_propagate() const; diff --git a/src/smt/smt_quantifier.h b/src/smt/smt_quantifier.h index bc731bf9c..bc249ed1a 100644 --- a/src/smt/smt_quantifier.h +++ b/src/smt/smt_quantifier.h @@ -165,6 +165,8 @@ namespace smt { virtual void push() = 0; virtual void pop(unsigned num_scopes) = 0; + + }; };