From 082936bca6fbd6d363cf09c20f60d113691c2ed8 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Tue, 8 Aug 2017 09:21:06 +0200 Subject: [PATCH] enable overloading resolution on define-fun declarations, fix #1199 Signed-off-by: Nikolaj Bjorner --- .../simplifier/basic_simplifier_plugin.cpp | 7 +- src/cmd_context/basic_cmds.cpp | 40 ++-- src/cmd_context/cmd_context.cpp | 178 +++++++++++++----- src/cmd_context/cmd_context.h | 44 ++++- src/parsers/smt2/smt2parser.cpp | 6 +- src/smt/smt_context.h | 10 +- 6 files changed, 205 insertions(+), 80 deletions(-) diff --git a/src/ast/simplifier/basic_simplifier_plugin.cpp b/src/ast/simplifier/basic_simplifier_plugin.cpp index 25998d832..be51bc291 100644 --- a/src/ast/simplifier/basic_simplifier_plugin.cpp +++ b/src/ast/simplifier/basic_simplifier_plugin.cpp @@ -59,7 +59,12 @@ bool basic_simplifier_plugin::reduce(func_decl * f, unsigned num_args, expr * co mk_iff(args[0], args[1], result); return true; case OP_XOR: - mk_xor(args[0], args[1], result); + switch (num_args) { + case 0: result = m_manager.mk_true(); break; + case 1: result = args[0]; break; + case 2: mk_xor(args[0], args[1], result); break; + default: UNREACHABLE(); break; + } return true; case OP_NOT: SASSERT(num_args == 1); diff --git a/src/cmd_context/basic_cmds.cpp b/src/cmd_context/basic_cmds.cpp index 8830358af..db522b82b 100644 --- a/src/cmd_context/basic_cmds.cpp +++ b/src/cmd_context/basic_cmds.cpp @@ -135,28 +135,30 @@ ATOMIC_CMD(get_assignment_cmd, "get-assignment", "retrieve assignment", { model_ref m; ctx.get_check_sat_result()->get_model(m); ctx.regular_stream() << "("; - dictionary const & macros = ctx.get_macros(); - dictionary::iterator it = macros.begin(); - dictionary::iterator end = macros.end(); + dictionary const & macros = ctx.get_macros(); + dictionary::iterator it = macros.begin(); + dictionary::iterator end = macros.end(); for (bool first = true; it != end; ++it) { symbol const & name = (*it).m_key; - cmd_context::macro const & _m = (*it).m_value; - if (_m.first == 0 && ctx.m().is_bool(_m.second)) { - expr_ref val(ctx.m()); - m->eval(_m.second, val, true); - if (ctx.m().is_true(val) || ctx.m().is_false(val)) { - if (first) - first = false; - else - ctx.regular_stream() << " "; - ctx.regular_stream() << "("; - if (is_smt2_quoted_symbol(name)) { - ctx.regular_stream() << mk_smt2_quoted_symbol(name); + macro_decls const & _m = (*it).m_value; + for (auto md : _m) { + if (md.m_domain.size() == 0 && ctx.m().is_bool(md.m_body)) { + expr_ref val(ctx.m()); + m->eval(md.m_body, val, true); + if (ctx.m().is_true(val) || ctx.m().is_false(val)) { + if (first) + first = false; + else + ctx.regular_stream() << " "; + ctx.regular_stream() << "("; + if (is_smt2_quoted_symbol(name)) { + ctx.regular_stream() << mk_smt2_quoted_symbol(name); + } + else { + ctx.regular_stream() << name; + } + ctx.regular_stream() << " " << (ctx.m().is_true(val) ? "true" : "false") << ")"; } - else { - ctx.regular_stream() << name; - } - ctx.regular_stream() << " " << (ctx.m().is_true(val) ? "true" : "false") << ")"; } } } diff --git a/src/cmd_context/cmd_context.cpp b/src/cmd_context/cmd_context.cpp index a893a1637..f172e5e93 100644 --- a/src/cmd_context/cmd_context.cpp +++ b/src/cmd_context/cmd_context.cpp @@ -76,6 +76,15 @@ bool func_decls::signatures_collide(func_decl* f, func_decl* g) const { return f == g; } +bool func_decls::signatures_collide(unsigned n, sort* const* domain, sort* range, func_decl* g) const { + if (g->get_range() != range) return false; + if (n != g->get_arity()) return false; + for (unsigned i = 0; i < n; ++i) { + if (domain[i] != g->get_domain(i)) return false; + } + return true; +} + bool func_decls::contains(func_decl * f) const { if (GET_TAG(m_decls) == 0) { func_decl* g = UNTAG(func_decl*, m_decls); @@ -90,6 +99,21 @@ bool func_decls::contains(func_decl * f) const { return false; } + +bool func_decls::contains(unsigned n, sort* const* domain, sort* range) const { + if (GET_TAG(m_decls) == 0) { + func_decl* g = UNTAG(func_decl*, m_decls); + return g && signatures_collide(n, domain, range, g); + } + else { + func_decl_set * fs = UNTAG(func_decl_set *, m_decls); + for (func_decl* g : *fs) { + if (signatures_collide(n, domain, range, g)) return true; + } + } + return false; +} + bool func_decls::insert(ast_manager & m, func_decl * f) { if (contains(f)) return false; @@ -205,6 +229,94 @@ func_decl * func_decls::find(ast_manager & m, unsigned num_args, expr * const * return find(num_args, sorts.c_ptr(), range); } +void macro_decls::finalize(ast_manager& m) { + for (auto v : *m_decls) m.dec_ref(v.m_body); + dealloc(m_decls); +} + +bool macro_decls::insert(ast_manager& m, unsigned arity, sort *const* domain, expr* body) { + if (find(arity, domain)) return false; + m.inc_ref(body); + if (!m_decls) m_decls = alloc(vector); + m_decls->push_back(macro_decl(arity, domain, body)); + return true; +} + +expr* macro_decls::find(unsigned arity, sort *const* domain) const { + if (!m_decls) return 0; + for (auto v : *m_decls) { + if (v.m_domain.size() != arity) continue; + bool eq = true; + for (unsigned i = 0; eq && i < arity; ++i) { + eq = domain[i] == v.m_domain[i]; + } + if (eq) return v.m_body; + } + return 0; +} + +void macro_decls::erase_last(ast_manager& m) { + SASSERT(m_decls); + SASSERT(!m_decls->empty()); + m.dec_ref(m_decls->back().m_body); + m_decls->pop_back(); +} + +bool cmd_context::contains_func_decl(symbol const& s, unsigned n, sort* const* domain, sort* range) const { + func_decls fs; + return m_func_decls.find(s, fs) && fs.contains(n, domain, range); +} + +bool cmd_context::contains_macro(symbol const& s) const { + return m_macros.contains(s); +} + +bool cmd_context::contains_macro(symbol const& s, func_decl* f) const { + return contains_macro(s, f->get_arity(), f->get_domain()); +} + +bool cmd_context::contains_macro(symbol const& s, unsigned arity, sort *const* domain) const { + macro_decls decls; + return m_macros.find(s, decls) && 0 != decls.find(arity, domain); +} + +void cmd_context::insert_macro(symbol const& s, unsigned arity, sort*const* domain, expr* t) { + macro_decls decls; + if (!m_macros.find(s, decls)) { + VERIFY(decls.insert(m(), arity, domain, t)); + m_macros.insert(s, decls); + } + else { + VERIFY(decls.insert(m(), arity, domain, t)); + } +} + +void cmd_context::erase_macro(symbol const& s) { + macro_decls decls; + VERIFY(m_macros.find(s, decls)); + decls.erase_last(m()); +} + +bool cmd_context::macros_find(symbol const& s, unsigned n, expr*const* args, expr*& t) const { + macro_decls decls; + if (!m_macros.find(s, decls)) { + return false; + } + for (macro_decl const& d : decls) { + if (d.m_domain.size() != n) continue; + bool eq = true; + for (unsigned i = 0; eq && i < n; ++i) { + eq = d.m_domain[i] == m().get_sort(args[i]); + } + if (eq) { + t = d.m_body; + return true; + } + } + return false; +} + + ast_object_ref::ast_object_ref(cmd_context & ctx, ast * a):m_ast(a) { ctx.m().inc_ref(a); } @@ -658,7 +770,7 @@ void cmd_context::insert(symbol const & s, func_decl * f) { if (!m_check_logic(f)) { throw cmd_exception(m_check_logic.get_last_error()); } - if (m_macros.contains(s)) { + if (contains_macro(s, f)) { throw cmd_exception("invalid declaration, named expression already defined with this name ", s); } if (m_builtin_decls.contains(s)) { @@ -697,20 +809,20 @@ void cmd_context::insert(symbol const & s, psort_decl * p) { TRACE("cmd_context", tout << "new sort decl\n"; p->display(tout); tout << "\n";); } -void cmd_context::insert(symbol const & s, unsigned arity, expr * t) { +void cmd_context::insert(symbol const & s, unsigned arity, sort *const* domain, expr * t) { + expr_ref _t(t, m()); m_check_sat_result = 0; if (m_builtin_decls.contains(s)) { throw cmd_exception("invalid macro/named expression, builtin symbol ", s); } - if (m_macros.contains(s)) { + if (contains_macro(s, arity, domain)) { throw cmd_exception("named expression already defined"); } - if (m_func_decls.contains(s)) { + if (contains_func_decl(s, arity, domain, m().get_sort(t))) { throw cmd_exception("invalid named expression, declaration already defined with this name ", s); } - m().inc_ref(t); TRACE("insert_macro", tout << "new macro " << arity << "\n" << mk_pp(t, m()) << "\n";); - m_macros.insert(s, macro(arity, t)); + insert_macro(s, arity, domain, t); if (!m_global_decls) { m_macros_stack.push_back(s); } @@ -783,8 +895,9 @@ func_decl * cmd_context::find_func_decl(symbol const & s) const { } throw cmd_exception("invalid function declaration reference, must provide signature for builtin symbol ", s); } - if (m_macros.contains(s)) + if (contains_macro(s)) { throw cmd_exception("invalid function declaration reference, named expressions (aka macros) cannot be referenced ", s); + } func_decls fs; if (m_func_decls.find(s, fs)) { if (fs.more_than_one()) @@ -840,7 +953,7 @@ func_decl * cmd_context::find_func_decl(symbol const & s, unsigned num_indices, return f; } - if (m_macros.contains(s)) + if (contains_macro(s, arity, domain)) throw cmd_exception("invalid function declaration reference, named expressions (aka macros) cannot be referenced ", s); if (num_indices > 0) @@ -862,11 +975,6 @@ psort_decl * cmd_context::find_psort_decl(symbol const & s) const { return p; } -cmd_context::macro cmd_context::find_macro(symbol const & s) const { - macro m; - m_macros.find(s, m); - return m; -} cmd * cmd_context::find_cmd(symbol const & s) const { cmd * c = 0; @@ -918,21 +1026,14 @@ void cmd_context::mk_app(symbol const & s, unsigned num_args, expr * const * arg } if (num_indices > 0) throw cmd_exception("invalid use of indexed indentifier, unknown builtin function ", s); - macro _m; - if (m_macros.find(s, _m)) { - if (num_args != _m.first) - throw cmd_exception("invalid defined function application, incorrect number of arguments ", s); - if (num_args == 0) { - result = _m.second; - return; - } - SASSERT(num_args > 0); + expr* _t; + if (macros_find(s, num_args, args, _t)) { TRACE("macro_bug", tout << "well_sorted_check_enabled(): " << well_sorted_check_enabled() << "\n"; tout << "s: " << s << "\n"; - tout << "body:\n" << mk_ismt2_pp(_m.second, m()) << "\n"; + tout << "body:\n" << mk_ismt2_pp(_t, m()) << "\n"; tout << "args:\n"; for (unsigned i = 0; i < num_args; i++) tout << mk_ismt2_pp(args[i], m()) << "\n" << mk_pp(m().get_sort(args[i]), m()) << "\n";); var_subst subst(m()); - subst(_m.second, num_args, args, result); + subst(_t, num_args, args, result); if (well_sorted_check_enabled() && !is_well_sorted(m(), result)) throw cmd_exception("invalid macro application, sort mismatch ", s); return; @@ -956,7 +1057,6 @@ void cmd_context::mk_app(symbol const & s, unsigned num_args, expr * const * arg if (f->get_arity() != 0) throw cmd_exception("invalid function application, missing arguments ", s); result = m().mk_const(f); - return; } else { func_decl * f = fs.find(m(), num_args, args, range); @@ -965,7 +1065,6 @@ void cmd_context::mk_app(symbol const & s, unsigned num_args, expr * const * arg if (well_sorted_check_enabled()) m().check_sort(f, num_args, args); result = m().mk_app(f, num_args, args); - return; } } @@ -1023,21 +1122,6 @@ void cmd_context::erase_psort_decl(symbol const & s) { erase_psort_decl_core(s); } -void cmd_context::erase_macro_core(symbol const & s) { - macro _m; - if (m_macros.find(s, _m)) { - m().dec_ref(_m.second); - m_macros.erase(s); - } -} - -void cmd_context::erase_macro(symbol const & s) { - if (!global_decls()) { - throw cmd_exception("macros (aka named expressions) can only be erased when global (instead of scoped) declarations are used"); - } - erase_macro_core(s); -} - void cmd_context::erase_cmd(symbol const & s) { cmd * c; if (m_cmds.find(s, c)) { @@ -1087,11 +1171,8 @@ void cmd_context::reset_psort_decls() { } void cmd_context::reset_macros() { - dictionary::iterator it = m_macros.begin(); - dictionary::iterator end = m_macros.end(); - for (; it != end; ++it) { - expr * t = (*it).m_value.second; - m().dec_ref(t); + for (auto & kv : m_macros) { + kv.m_value.finalize(m()); } m_macros.reset(); m_macros_stack.reset(); @@ -1274,10 +1355,7 @@ void cmd_context::restore_macros(unsigned old_sz) { svector::iterator end = m_macros_stack.end(); for (; it != end; ++it) { symbol const & s = *it; - macro _m; - VERIFY (m_macros.find(s, _m)); - m().dec_ref(_m.second); - m_macros.erase(s); + erase_macro(s); } m_macros_stack.shrink(old_sz); } diff --git a/src/cmd_context/cmd_context.h b/src/cmd_context/cmd_context.h index ca4883e74..189863e58 100644 --- a/src/cmd_context/cmd_context.h +++ b/src/cmd_context/cmd_context.h @@ -43,11 +43,13 @@ Notes: class func_decls { func_decl * m_decls; bool signatures_collide(func_decl* f, func_decl* g) const; + bool signatures_collide(unsigned n, sort*const* domain, sort* range, func_decl* g) const; public: func_decls():m_decls(0) {} func_decls(ast_manager & m, func_decl * f); void finalize(ast_manager & m); bool contains(func_decl * f) const; + bool contains(unsigned n, sort* const* domain, sort* range) const; bool insert(ast_manager & m, func_decl * f); void erase(ast_manager & m, func_decl * f); bool more_than_one() const; @@ -58,6 +60,29 @@ public: func_decl * find(ast_manager & m, unsigned num_args, expr * const * args, sort * range) const; }; +struct macro_decl { + ptr_vector m_domain; + expr* m_body; + + macro_decl(unsigned arity, sort *const* domain, expr* body): + m_domain(arity, domain), m_body(body) {} + + void dec_ref(ast_manager& m) { m.dec_ref(m_body); } + +}; + +class macro_decls { + vector* m_decls; +public: + macro_decls() { m_decls = 0; } + void finalize(ast_manager& m); + bool insert(ast_manager& m, unsigned arity, sort *const* domain, expr* body); + expr* find(unsigned arity, sort *const* domain) const; + void erase_last(ast_manager& m); + vector::iterator begin() const { return m_decls->begin(); } + vector::iterator end() const { return m_decls->end(); } +}; + /** \brief Generic wrapper. */ @@ -184,7 +209,7 @@ protected: dictionary m_func_decls; obj_map m_func_decl2alias; dictionary m_psort_decls; - dictionary m_macros; + dictionary m_macros; // the following fields m_func_decls_stack, m_psort_decls_stack and m_exprs_stack are used when m_global_decls == false typedef std::pair sf_pair; svector m_func_decls_stack; @@ -253,7 +278,6 @@ protected: void erase_func_decl_core(symbol const & s, func_decl * f); void erase_psort_decl_core(symbol const & s); - void erase_macro_core(symbol const & s); bool logic_has_arith() const; bool logic_has_bv() const; @@ -268,6 +292,16 @@ protected: void mk_solver(); + bool contains_func_decl(symbol const& s, unsigned n, sort* const* domain, sort* range) const; + + bool contains_macro(symbol const& s) const; + bool contains_macro(symbol const& s, func_decl* f) const; + bool contains_macro(symbol const& s, unsigned arity, sort *const* domain) const; + void insert_macro(symbol const& s, unsigned arity, sort*const* domain, expr* t); + void erase_macro(symbol const& s); + bool macros_find(symbol const& s, unsigned n, expr*const* args, expr*& t) const; + + public: cmd_context(bool main_ctx = true, ast_manager * m = 0, symbol const & l = symbol::null); ~cmd_context(); @@ -337,7 +371,7 @@ public: void insert(func_decl * f) { insert(f->get_name(), f); } void insert(symbol const & s, psort_decl * p); void insert(psort_decl * p) { insert(p->get_name(), p); } - void insert(symbol const & s, unsigned arity, expr * t); + void insert(symbol const & s, unsigned arity, sort *const* domain, expr * t); void insert(symbol const & s, object_ref *); void insert(tactic_cmd * c) { tactic_manager::insert(c); } void insert(probe_info * p) { tactic_manager::insert(p); } @@ -348,7 +382,6 @@ public: func_decl * find_func_decl(symbol const & s, unsigned num_indices, unsigned const * indices, unsigned arity, sort * const * domain, sort * range) const; psort_decl * find_psort_decl(symbol const & s) const; - macro find_macro(symbol const & s) const; cmd * find_cmd(symbol const & s) const; sexpr * find_user_tactic(symbol const & s) const; object_ref * find_object_ref(symbol const & s) const; @@ -360,7 +393,6 @@ public: void erase_func_decl(symbol const & s, func_decl * f); void erase_func_decl(func_decl * f) { erase_func_decl(f->get_name(), f); } void erase_psort_decl(symbol const & s); - void erase_macro(symbol const & s); void erase_object_ref(symbol const & s); void erase_user_tactic(symbol const & s); void reset_func_decls(); @@ -400,7 +432,7 @@ public: void validate_check_sat_result(lbool r); unsigned num_scopes() const { return m_scopes.size(); } - dictionary const & get_macros() const { return m_macros; } + dictionary const & get_macros() const { return m_macros; } bool is_model_available() const; diff --git a/src/parsers/smt2/smt2parser.cpp b/src/parsers/smt2/smt2parser.cpp index b86a663a7..3d895668b 100644 --- a/src/parsers/smt2/smt2parser.cpp +++ b/src/parsers/smt2/smt2parser.cpp @@ -992,7 +992,7 @@ namespace smt2 { TRACE("name_expr", tout << "naming: " << s << " ->\n" << mk_pp(n, m()) << "\n";); if (!is_ground(n) && has_free_vars(n)) throw parser_exception("invalid named expression, expression contains free variables"); - m_ctx.insert(s, 0, n); + m_ctx.insert(s, 0, 0, n); m_last_named_expr.first = s; m_last_named_expr.second = n; } @@ -1984,7 +1984,7 @@ namespace smt2 { parse_expr(); if (m().get_sort(expr_stack().back()) != sort_stack().back()) throw parser_exception("invalid function/constant definition, sort mismatch"); - m_ctx.insert(id, num_vars, expr_stack().back()); + m_ctx.insert(id, num_vars, sort_stack().c_ptr() + sort_spos, expr_stack().back()); check_rparen("invalid function/constant definition, ')' expected"); // restore stacks & env symbol_stack().shrink(sym_spos); @@ -2135,7 +2135,7 @@ namespace smt2 { parse_expr(); if (m().get_sort(expr_stack().back()) != sort_stack().back()) throw parser_exception("invalid constant definition, sort mismatch"); - m_ctx.insert(id, 0, expr_stack().back()); + m_ctx.insert(id, 0, 0, expr_stack().back()); check_rparen("invalid constant definition, ')' expected"); expr_stack().pop_back(); sort_stack().pop_back(); diff --git a/src/smt/smt_context.h b/src/smt/smt_context.h index 84bf9f62a..1aa3b385b 100644 --- a/src/smt/smt_context.h +++ b/src/smt/smt_context.h @@ -257,7 +257,15 @@ namespace smt { return m_params; } - bool get_cancel_flag() { return !m_manager.limit().inc(); } + bool get_cancel_flag() { + if (m_manager.limit().inc()) { + // get_simplifier().reset(); + return false; + } + else { + return true; + } + } region & get_region() { return m_region;