diff --git a/src/ast/ast.h b/src/ast/ast.h index 1c25ff9ec..baefc4685 100644 --- a/src/ast/ast.h +++ b/src/ast/ast.h @@ -2122,6 +2122,7 @@ public: app * mk_or(expr * arg1, expr * arg2) { return mk_app(m_basic_family_id, OP_OR, arg1, arg2); } app * mk_and(expr * arg1, expr * arg2) { return mk_app(m_basic_family_id, OP_AND, arg1, arg2); } app * mk_or(expr * arg1, expr * arg2, expr * arg3) { return mk_app(m_basic_family_id, OP_OR, arg1, arg2, arg3); } + app * mk_or(expr* a, expr* b, expr* c, expr* d) { expr* args[4] = { a, b, c, d }; return mk_app(m_basic_family_id, OP_OR, 4, args); } app * mk_and(expr * arg1, expr * arg2, expr * arg3) { return mk_app(m_basic_family_id, OP_AND, arg1, arg2, arg3); } app * mk_implies(expr * arg1, expr * arg2) { return mk_app(m_basic_family_id, OP_IMPLIES, arg1, arg2); } app * mk_not(expr * n) { return mk_app(m_basic_family_id, OP_NOT, n); } diff --git a/src/ast/pattern/expr_pattern_match.cpp b/src/ast/pattern/expr_pattern_match.cpp index 5cd3542df..5a1f20c3b 100644 --- a/src/ast/pattern/expr_pattern_match.cpp +++ b/src/ast/pattern/expr_pattern_match.cpp @@ -52,28 +52,50 @@ expr_pattern_match::match_quantifier(quantifier* qf, app_ref_vector& patterns, u } m_regs[0] = qf->get_expr(); for (unsigned i = 0; i < m_precompiled.size(); ++i) { - quantifier* qf2 = m_precompiled[i].get(); - if (qf2->get_kind() != qf->get_kind() || is_lambda(qf)) { - continue; - } - if (qf2->get_num_decls() != qf->get_num_decls()) { - continue; - } - subst s; - if (match(qf->get_expr(), m_first_instrs[i], s)) { - for (unsigned j = 0; j < qf2->get_num_patterns(); ++j) { - app* p = static_cast(qf2->get_pattern(j)); - expr_ref p_result(m_manager); - instantiate(p, qf->get_num_decls(), s, p_result); - patterns.push_back(to_app(p_result.get())); - } - weight = qf2->get_weight(); - return true; + if (match_quantifier(i, qf, patterns, weight)) + return true; + } + return false; +} + +bool +expr_pattern_match::match_quantifier(unsigned i, quantifier* qf, app_ref_vector& patterns, unsigned& weight) { + quantifier* qf2 = m_precompiled[i].get(); + if (qf2->get_kind() != qf->get_kind() || is_lambda(qf)) { + return false; + } + if (qf2->get_num_decls() != qf->get_num_decls()) { + return false; + } + subst s; + if (match(qf->get_expr(), m_first_instrs[i], s)) { + for (unsigned j = 0; j < qf2->get_num_patterns(); ++j) { + app* p = static_cast(qf2->get_pattern(j)); + expr_ref p_result(m_manager); + instantiate(p, qf->get_num_decls(), s, p_result); + patterns.push_back(to_app(p_result.get())); } + weight = qf2->get_weight(); + return true; } return false; } +bool expr_pattern_match::match_quantifier_index(quantifier* qf, app_ref_vector& patterns, unsigned& index) { + if (m_regs.empty()) return false; + m_regs[0] = qf->get_expr(); + + for (unsigned i = 0; i < m_precompiled.size(); ++i) { + unsigned weight = 0; + if (match_quantifier(i, qf, patterns, weight)) { + index = i; + return true; + } + } + return false; +} + + void expr_pattern_match::instantiate(expr* a, unsigned num_bound, subst& s, expr_ref& result) { bound b; @@ -399,8 +421,16 @@ expr_pattern_match::initialize(char const * spec_string) { TRACE("expr_pattern_match", display(tout); ); } -void -expr_pattern_match::display(std::ostream& out) const { +unsigned expr_pattern_match::initialize(quantifier* q) { + if (m_instrs.empty()) { + m_instrs.push_back(instr(BACKTRACK)); + } + compile(q); + return m_precompiled.size() - 1; +} + + +void expr_pattern_match::display(std::ostream& out) const { for (unsigned i = 0; i < m_instrs.size(); ++i) { display(out, m_instrs[i]); } diff --git a/src/ast/pattern/expr_pattern_match.h b/src/ast/pattern/expr_pattern_match.h index d1388b43f..679d69fbb 100644 --- a/src/ast/pattern/expr_pattern_match.h +++ b/src/ast/pattern/expr_pattern_match.h @@ -131,11 +131,14 @@ class expr_pattern_match { public: expr_pattern_match(ast_manager & manager); ~expr_pattern_match(); - virtual bool match_quantifier(quantifier * qf, app_ref_vector & patterns, unsigned & weight); - virtual void initialize(char const * database); + bool match_quantifier(quantifier * qf, app_ref_vector & patterns, unsigned & weight); + bool match_quantifier_index(quantifier* qf, app_ref_vector & patterns, unsigned& index); + unsigned initialize(quantifier* qf); + void initialize(char const * database); void display(std::ostream& out) const; private: + bool match_quantifier(unsigned i, quantifier * qf, app_ref_vector & patterns, unsigned & weight); void instantiate(expr* a, unsigned num_bound, subst& s, expr_ref& result); void compile(expr* q); bool match(expr* a, unsigned init, subst& s); diff --git a/src/ast/rewriter/func_decl_replace.cpp b/src/ast/rewriter/func_decl_replace.cpp index dcde9044f..97451cb58 100644 --- a/src/ast/rewriter/func_decl_replace.cpp +++ b/src/ast/rewriter/func_decl_replace.cpp @@ -93,4 +93,5 @@ void func_decl_replace::reset() { m_cache.reset(); m_subst.reset(); m_refs.reset(); + m_funs.reset(); } diff --git a/src/ast/rewriter/func_decl_replace.h b/src/ast/rewriter/func_decl_replace.h index a553ed999..7f37e8753 100644 --- a/src/ast/rewriter/func_decl_replace.h +++ b/src/ast/rewriter/func_decl_replace.h @@ -27,13 +27,14 @@ class func_decl_replace { ast_manager& m; obj_map m_subst; obj_map m_cache; - ptr_vector m_todo, m_args; - expr_ref_vector m_refs; + ptr_vector m_todo, m_args; + expr_ref_vector m_refs; + func_decl_ref_vector m_funs; public: - func_decl_replace(ast_manager& m): m(m), m_refs(m) {} + func_decl_replace(ast_manager& m): m(m), m_refs(m), m_funs(m) {} - void insert(func_decl* src, func_decl* dst) { m_subst.insert(src, dst); } + void insert(func_decl* src, func_decl* dst) { m_subst.insert(src, dst); m_funs.push_back(src), m_funs.push_back(dst); } expr_ref operator()(expr* e); diff --git a/src/tactic/core/special_relations_tactic.cpp b/src/tactic/core/special_relations_tactic.cpp index f256bdbef..be0706c5c 100644 --- a/src/tactic/core/special_relations_tactic.cpp +++ b/src/tactic/core/special_relations_tactic.cpp @@ -24,23 +24,12 @@ void special_relations_tactic::collect_feature(goal const& g, unsigned idx, obj_map& goal_features) { expr* f = g.form(idx); func_decl_ref p(m); - if (is_transitivity(f, p)) { - insert(goal_features, p, idx, sr_transitive); - } - else if (is_anti_symmetry(f, p)) { - insert(goal_features, p, idx, sr_antisymmetric); - } - else if (is_left_tree(f, p)) { - insert(goal_features, p, idx, sr_lefttree); - } - else if (is_right_tree(f, p)) { - insert(goal_features, p, idx, sr_righttree); - } - else if (is_reflexive(f, p)) { - insert(goal_features, p, idx, sr_reflexive); - } - else if (is_total(f, p)) { - insert(goal_features, p, idx, sr_total); + if (!is_quantifier(f)) return; + unsigned index = 0; + app_ref_vector patterns(m); + if (m_pm.match_quantifier_index(to_quantifier(f), patterns, index)) { + p = to_app(patterns.get(0)->get_arg(0))->get_decl(); + insert(goal_features, p, idx, m_properties[index]); } } @@ -53,32 +42,68 @@ void special_relations_tactic::insert(obj_map& goal_featur } -bool special_relations_tactic::is_transitivity(expr* fml, func_decl_ref& p) { - // match Forall x, y, z . p(x,y) & p(y,z) -> p(x,z) - return false; +void special_relations_tactic::initialize() { + if (!m_properties.empty()) return; + sort_ref A(m); + func_decl_ref R(m.mk_func_decl(symbol("R"), A, A, m.mk_bool_sort()), m); + var_ref x(m.mk_var(0, A), m); + var_ref y(m.mk_var(1, A), m); + var_ref z(m.mk_var(2, A), m); + expr* _x = x, *_y = y, *_z = z; + + expr_ref Rxy(m.mk_app(R, _x, y), m); + expr_ref Ryz(m.mk_app(R, _y, z), m); + expr_ref Rxz(m.mk_app(R, _x, z), m); + expr_ref Rxx(m.mk_app(R, _x, x), m); + expr_ref Ryx(m.mk_app(R, _y, x), m); + expr_ref Rzy(m.mk_app(R, _z, y), m); + expr_ref Rzx(m.mk_app(R, _z, x), m); + expr_ref nRxy(m.mk_not(Rxy), m); + expr_ref nRyx(m.mk_not(Ryx), m); + expr_ref nRzx(m.mk_not(Rzx), m); + expr_ref nRxz(m.mk_not(Rxz), m); + + sort* As[3] = { A, A, A}; + symbol xyz[3] = { symbol("x"), symbol("y"), symbol("z") }; + expr_ref fml(m); + quantifier_ref q(m); + expr_ref pat(m.mk_pattern(to_app(Rxy)), m); + expr* pats[1] = { pat }; + fml = m.mk_or(m.mk_not(Rxy), m.mk_not(Ryz), Rxz); + q = m.mk_forall(3, As, xyz, fml, 0, symbol::null, symbol::null, 1, pats); + register_pattern(m_pm.initialize(q), sr_transitive); + + fml = Rxx; + q = m.mk_forall(1, As, xyz, fml, 0, symbol::null, symbol::null, 1, pats); + register_pattern(m_pm.initialize(q), sr_reflexive); + + fml = m.mk_or(nRxy, nRyx, m.mk_eq(x, y)); + q = m.mk_forall(2, As, xyz, fml, 0, symbol::null, symbol::null, 1, pats); + register_pattern(m_pm.initialize(q), sr_antisymmetric); + + fml = m.mk_or(nRyx, nRzx, Ryz, Rzy); + q = m.mk_forall(3, As, xyz, fml, 0, symbol::null, symbol::null, 1, pats); + register_pattern(m_pm.initialize(q), sr_lefttree); + + fml = m.mk_or(nRxy, nRxz, Ryx, Rzy); + q = m.mk_forall(3, As, xyz, fml, 0, symbol::null, symbol::null, 1, pats); + register_pattern(m_pm.initialize(q), sr_righttree); + + fml = m.mk_or(Rxy, Ryx); + q = m.mk_forall(2, As, xyz, fml, 0, symbol::null, symbol::null, 1, pats); + register_pattern(m_pm.initialize(q), sr_total); } -bool special_relations_tactic::is_anti_symmetry(expr* fml, func_decl_ref& p) { - return false; -} -bool special_relations_tactic::is_left_tree(expr* fml, func_decl_ref& p) { - return false; -} -bool special_relations_tactic::is_right_tree(expr* fml, func_decl_ref& p) { - return false; -} -bool special_relations_tactic::is_reflexive(expr* fml, func_decl_ref& p) { - return false; -} -bool special_relations_tactic::is_total(expr* fml, func_decl_ref& p) { - return false; -} -bool special_relations_tactic::is_symmetric(expr* fml, func_decl_ref& p) { - return false; + +void special_relations_tactic::register_pattern(unsigned index, sr_property p) { + SASSERT(index == m_properties.size() + 1); + m_properties.push_back(p); } + void special_relations_tactic::operator()(goal_ref const & g, goal_ref_buffer & result) { tactic_report report("special_relations", *g); + initialize(); obj_map goal_features; unsigned size = g->size(); for (unsigned idx = 0; idx < size; idx++) { diff --git a/src/tactic/core/special_relations_tactic.h b/src/tactic/core/special_relations_tactic.h index 9d7be3717..58da2efbc 100644 --- a/src/tactic/core/special_relations_tactic.h +++ b/src/tactic/core/special_relations_tactic.h @@ -23,10 +23,13 @@ Notes: #include "tactic/tactic.h" #include "tactic/tactical.h" #include "ast/special_relations_decl_plugin.h" +#include "ast/pattern/expr_pattern_match.h" class special_relations_tactic : public tactic { ast_manager& m; params_ref m_params; + expr_pattern_match m_pm; + svector m_properties; struct sp_axioms { unsigned_vector m_goal_indices; @@ -37,6 +40,9 @@ class special_relations_tactic : public tactic { void collect_feature(goal const& g, unsigned idx, obj_map& goal_features); void insert(obj_map& goal_features, func_decl* f, unsigned idx, sr_property p); + void initialize(); + void register_pattern(unsigned index, sr_property); + bool is_transitivity(expr* fml, func_decl_ref& p); bool is_anti_symmetry(expr* fml, func_decl_ref& p); bool is_left_tree(expr* fml, func_decl_ref& p); @@ -47,7 +53,7 @@ class special_relations_tactic : public tactic { public: - special_relations_tactic(ast_manager & m, params_ref const & ref = params_ref()): m(m), m_params(ref) {} + special_relations_tactic(ast_manager & m, params_ref const & ref = params_ref()): m(m), m_params(ref), m_pm(m) {} ~special_relations_tactic() override {}