diff --git a/src/cmd_context/tptp_frontend.cpp b/src/cmd_context/tptp_frontend.cpp
index e787fc29f..f55fd53f5 100644
--- a/src/cmd_context/tptp_frontend.cpp
+++ b/src/cmd_context/tptp_frontend.cpp
@@ -259,8 +259,10 @@ class tptp_parser {
     arith_util m_arith;
     sort* m_univ;
     bool m_has_conjecture = false;
+    bool m_last_name_quoted = false;
     unsigned m_lambda_counter = 0;
     std::unordered_map<std::string, sort*> m_sorts;
+    std::unordered_map<sort*, std::pair<std::vector<sort*>, sort*>> m_ho_sort_info; // ho_sort → (domain, range)
     std::unordered_map<std::string, func_decl*> m_decls;
     func_decl_ref_vector m_pinned_decls; // prevents cached func_decls from being freed
     std::unordered_map<std::string, std::pair<std::vector<sort*>, sort*>> m_typed_decls;
@@ -315,6 +317,7 @@ class tptp_parser {
 
     std::string parse_name() {
         if (is(token_kind::id) || is(token_kind::str)) {
+            m_last_name_quoted = is(token_kind::str);
             std::string r = m_curr.text;
             next();
             return r;
@@ -351,6 +354,7 @@ class tptp_parser {
         if (it != m_sorts.end()) return it->second;
         sort* s = m.mk_uninterpreted_sort(symbol(key));
         m_sorts.emplace(key, s);
+        m_ho_sort_info.emplace(s, std::make_pair(domain, range));
         return s;
     }
 
@@ -459,6 +463,18 @@ class tptp_parser {
         return false;
     }
 
+    // Returns true when some map in m_bound, visited from the back of the
+    // container, maps a's declaration name to this very expression object.
+    bool is_bound_var(app* a) const {
+        std::string name = a->get_decl()->get_name().str();
+        for (auto it = m_bound.rbegin(); it != m_bound.rend(); ++it) {
+            auto jt = it->find(name);
+            if (jt != it->end() && jt->second == a)
+                return true;
+        }
+        return false;
+    }
+
     bool should_create_implicit_var(std::string const& n) const {
         return is_var_name(n) && m_implicit_scope;
     }
@@ -643,9 +659,9 @@ class tptp_parser {
         }
 
         expr_ref b(m);
-        if (is_var_name(n) && find_bound(n, b))
+        if (!m_last_name_quoted && is_var_name(n) && find_bound(n, b))
             return b;
-        if (should_create_implicit_var(n))
+        if (!m_last_name_quoted && should_create_implicit_var(n))
             return expr_ref(get_or_create_implicit_var(n), m);
 
         expr_ref_vector args(m);
@@ -665,6 +681,20 @@ class tptp_parser {
             return expr_ref(m_arith.mk_uminus(a), m);
         }
 
+        // Interpreted arithmetic functions are lowered by mk_arith_expr, the same
+        // routine the formula-position path uses, so both paths validate alike.
+        if (is_arith_function_name(n))
+            return mk_arith_expr(n, args);
+
+        if (n == "$is_int") {
+            if (args.size() != 1)
+                throw parse_error("arithmetic predicate '$is_int' expects arity 1");
+            expr_ref a(args.get(0), m);
+            if (!m_arith.is_int_real(a))
+                throw parse_error("arithmetic predicate '$is_int' expects arithmetic argument");
+            return expr_ref(m_arith.mk_is_int(a), m);
+        }
+
         func_decl* f = mk_decl(n, args.size(), false);
         return expr_ref(args.empty() ? m.mk_const(f) : m.mk_app(f, args.size(), args.data()), m);
     }
@@ -692,6 +722,53 @@ class tptp_parser {
 
         unsigned new_arity = args.size();
         std::string name = orig_d->get_name().str();
+
+        // Check if the LHS has a ho_sort (function type as value).
+        // If so, use its type info to determine domain/range for the application.
+        sort* lhs_sort = orig_d->get_range(); // sort of the 0-arity constant
+        auto ho_it = m_ho_sort_info.find(lhs_sort);
+        if (ho_it != m_ho_sort_info.end() && a->get_num_args() == 0) {
+            auto& [ho_domain, ho_range] = ho_it->second;
+            sort* range = ho_range;
+            // If we're partially applying (fewer args than domain), result is a ho_sort
+            if (new_arity < ho_domain.size()) {
+                std::vector<sort*> remaining(ho_domain.begin() + new_arity, ho_domain.end());
+                range = get_ho_sort(remaining, ho_range);
+            }
+            // For bound variables, use $apply so that the variable remains in the
+            // expression body and mk_forall can properly abstract it.
+            if (is_bound_var(a)) {
+                std::vector<sort*> apply_dom;
+                apply_dom.push_back(lhs_sort); // first arg is the function variable
+                for (unsigned i = 0; i < new_arity; ++i) {
+                    if (i < ho_domain.size())
+                        apply_dom.push_back(ho_domain[i]);
+                    else
+                        apply_dom.push_back(args.get(i)->get_sort());
+                }
+                // Create unique apply function per signature
+                std::ostringstream apn;
+                apn << "$apply_" << lhs_sort->get_name() << "_" << new_arity;
+                func_decl* apply_d = m.mk_func_decl(symbol(apn.str()), apply_dom.size(), apply_dom.data(), range);
+                m_pinned_decls.push_back(apply_d);
+                expr_ref_vector all_args(m);
+                all_args.push_back(e); // the bound variable itself
+                for (unsigned i = 0; i < new_arity; ++i) all_args.push_back(args.get(i));
+                return expr_ref(m.mk_app(apply_d, all_args.size(), all_args.data()), m);
+            }
+            // For non-bound ho-typed constants, use the name directly
+            std::vector<sort*> dom;
+            for (unsigned i = 0; i < new_arity; ++i) {
+                if (i < ho_domain.size())
+                    dom.push_back(ho_domain[i]);
+                else
+                    dom.push_back(args.get(i)->get_sort());
+            }
+            func_decl* new_d = m.mk_func_decl(symbol(name), new_arity, dom.data(), range);
+            m_pinned_decls.push_back(new_d);
+            return expr_ref(m.mk_app(new_d, new_arity, args.data()), m);
+        }
+
         func_decl* new_d = mk_decl(name, new_arity, false);
         // Verify argument sorts match; if not, create untyped func_decl with actual arg sorts
         bool sorts_ok = (new_d->get_arity() == new_arity);
@@ -759,6 +836,97 @@ class tptp_parser {
         return parse_term_primary();
     }
 
+    // True for the interpreted TPTP arithmetic functions that mk_arith_expr lowers.
+    static bool is_arith_function_name(std::string const& n) {
+        return n == "$uminus" || n == "$sum" || n == "$difference" || n == "$product" ||
+               n == "$quotient" || n == "$quotient_e" || n == "$quotient_t" || n == "$quotient_f" ||
+               n == "$remainder_e" || n == "$remainder_t" || n == "$remainder_f" ||
+               n == "$floor" || n == "$ceiling" || n == "$truncate" || n == "$round" ||
+               n == "$to_int" || n == "$to_rat" || n == "$to_real";
+    }
+
+    // Throws parse_error unless args holds exactly arity terms of sort Int or Real.
+    void check_arith_args(std::string const& n, expr_ref_vector const& args, unsigned arity) {
+        if (args.size() != arity)
+            throw parse_error(n + " expects arity " + std::to_string(arity));
+        for (unsigned i = 0; i < arity; ++i) {
+            if (!m_arith.is_int_real(args.get(i)))
+                throw parse_error(n + " expects arithmetic arguments");
+        }
+    }
+
+    // Build an arithmetic expression from a TPTP function name and arguments.
+    // Arity and argument sorts are validated before any term is built, so this
+    // is safe to call from both the term and the formula parsing paths.
+    // Throws parse_error for a name outside is_arith_function_name.
+    expr_ref mk_arith_expr(std::string const& n, expr_ref_vector const& args) {
+        if (n == "$uminus") {
+            check_arith_args(n, args, 1);
+            return expr_ref(m_arith.mk_uminus(args.get(0)), m);
+        }
+        if (n == "$sum" || n == "$difference" || n == "$product") {
+            check_arith_args(n, args, 2);
+            expr_ref a(args.get(0), m), b(args.get(1), m);
+            // Mixed Int/Real operands are promoted to Real.
+            if (m_arith.is_real(a) || m_arith.is_real(b)) {
+                if (m_arith.is_int(a)) a = expr_ref(m_arith.mk_to_real(a), m);
+                if (m_arith.is_int(b)) b = expr_ref(m_arith.mk_to_real(b), m);
+            }
+            if (n == "$sum") return expr_ref(m_arith.mk_add(a, b), m);
+            if (n == "$difference") return expr_ref(m_arith.mk_sub(a, b), m);
+            return expr_ref(m_arith.mk_mul(a, b), m);
+        }
+        if (n == "$quotient" || n == "$quotient_e" || n == "$quotient_t" || n == "$quotient_f") {
+            // NOTE(review): all four variants share one lowering (mk_idiv for two
+            // Int operands, mk_div otherwise); the _t/_f rounding modes are not
+            // distinguished - confirm this approximation is acceptable.
+            check_arith_args(n, args, 2);
+            expr_ref a(args.get(0), m), b(args.get(1), m);
+            if (m_arith.is_int(a) && m_arith.is_int(b))
+                return expr_ref(m_arith.mk_idiv(a, b), m);
+            if (m_arith.is_int(a)) a = expr_ref(m_arith.mk_to_real(a), m);
+            if (m_arith.is_int(b)) b = expr_ref(m_arith.mk_to_real(b), m);
+            return expr_ref(m_arith.mk_div(a, b), m);
+        }
+        if (n == "$remainder_e" || n == "$remainder_t" || n == "$remainder_f") {
+            // NOTE(review): all three variants are lowered to mk_mod.
+            check_arith_args(n, args, 2);
+            expr_ref a(args.get(0), m), b(args.get(1), m);
+            return expr_ref(m_arith.mk_mod(a, b), m);
+        }
+        if (n == "$floor" || n == "$ceiling" || n == "$truncate" || n == "$round" || n == "$to_int") {
+            check_arith_args(n, args, 1);
+            expr_ref a(args.get(0), m);
+            if (m_arith.is_int(a)) return a; // an Int is already integral
+            // mk_to_int rounds toward minus infinity; the other rounding modes
+            // are expressed through it instead of being treated as floor.
+            auto ceiling_of = [&](expr* x) {
+                return expr_ref(m_arith.mk_uminus(m_arith.mk_to_int(m_arith.mk_uminus(x))), m);
+            };
+            if (n == "$ceiling") return ceiling_of(a);
+            if (n == "$truncate") {
+                // Toward zero: floor for x >= 0, ceiling for x < 0.
+                expr_ref zero(m_arith.mk_numeral(rational(0), false), m);
+                expr_ref down(m_arith.mk_to_int(a), m);
+                expr_ref up = ceiling_of(a);
+                return expr_ref(m.mk_ite(m_arith.mk_ge(a, zero), down, up), m);
+            }
+            if (n == "$round") {
+                // Nearest integer with halves rounded up: floor(x + 1/2).
+                expr_ref half(m_arith.mk_numeral(rational(1, 2), false), m);
+                return expr_ref(m_arith.mk_to_int(m_arith.mk_add(a, half)), m);
+            }
+            return expr_ref(m_arith.mk_to_int(a), m); // $floor and $to_int
+        }
+        if (n == "$to_rat" || n == "$to_real") {
+            check_arith_args(n, args, 1);
+            expr_ref a(args.get(0), m);
+            if (m_arith.is_real(a)) return a;
+            return expr_ref(m_arith.mk_to_real(a), m);
+        }
+        throw parse_error("unknown arithmetic function: " + n);
+    }
+
     // Coerce two expressions to have the same sort for equality.
     // If sorts already match, returns lhs unchanged. Otherwise coerces the
     // 0-arity constant side to match the other's sort. If that's not possible,
@@ -795,11 +963,11 @@ class tptp_parser {
         if (accept(token_kind::lparen)) {
             expr_ref e = parse_formula();
             // Handle equality/inequality inside parenthesized expressions
-            // e.g., ( (f @ Y) = (g @ X @ Z) ) where f(Y) is non-Bool
-            if (!m.is_bool(e) && (is(token_kind::equal_tok) || is(token_kind::neq_tok))) {
+            // In THF, (A = B) is used even for Bool-sorted expressions (meaning iff)
+            if (is(token_kind::equal_tok) || is(token_kind::neq_tok)) {
                 bool neq = accept(token_kind::neq_tok);
                 if (!neq) expect(token_kind::equal_tok, "'='");
-                expr_ref rhs = parse_term();
+                expr_ref rhs = m.is_bool(e) ? parse_formula() : parse_term();
                 e = coerce_eq(e, rhs);
                 expr_ref eq(m.mk_eq(e, rhs), m);
                 e = neq ? expr_ref(m.mk_not(eq), m) : eq;
@@ -855,15 +1023,30 @@ class tptp_parser {
             /* n == "$greatereq"*/ return expr_ref(m_arith.mk_ge(lhs, rhs), m);
         }
 
+        // Arithmetic terms that may appear before = or != in formula position
+        if (is_arith_function_name(n)) {
+            // Lower the application, then consume an optional (in)equality
+            expr_ref arith_expr = mk_arith_expr(n, args);
+            if (is(token_kind::equal_tok) || is(token_kind::neq_tok)) {
+                bool neq = accept(token_kind::neq_tok);
+                if (!neq) expect(token_kind::equal_tok, "'='");
+                expr_ref rhs = parse_term();
+                arith_expr = coerce_eq(arith_expr, rhs);
+                expr_ref eq(m.mk_eq(arith_expr, rhs), m);
+                return neq ? expr_ref(m.mk_not(eq), m) : eq;
+            }
+            return arith_expr;
+        }
+
         expr_ref lhs(m);
         bool has_lhs = false;
         if (args.empty()) {
             expr_ref b(m);
-            if (is_var_name(n) && find_bound(n, b)) {
+            if (!m_last_name_quoted && is_var_name(n) && find_bound(n, b)) {
                 lhs = b;
                 has_lhs = true;
             }
-            else if (should_create_implicit_var(n)) {
+            else if (!m_last_name_quoted && should_create_implicit_var(n)) {
                 lhs = expr_ref(get_or_create_implicit_var(n), m);
                 has_lhs = true;
             }
@@ -953,9 +1136,25 @@ class tptp_parser {
         oss << "$lambda_" << m_lambda_counter++;
         std::string lname = oss.str();
         // Lambda type: param sorts → body sort
+        // If the body sort is itself a ho_sort, flatten it into the lambda's type
+        // e.g., ^ [X: mu] : (body with sort (U > Bool)) → sort (mu * U > Bool)
         std::vector<sort*> param_sorts;
         for (auto* v : vars) param_sorts.push_back(v->get_sort());
-        sort* lambda_sort = param_sorts.empty() ? body->get_sort() : get_ho_sort(param_sorts, body->get_sort());
+        sort* body_sort = body->get_sort();
+        sort* lambda_sort;
+        if (param_sorts.empty()) {
+            lambda_sort = body_sort;
+        } else {
+            auto body_ho = m_ho_sort_info.find(body_sort);
+            if (body_ho != m_ho_sort_info.end()) {
+                // Flatten: params + body's domain → body's range
+                std::vector<sort*> full_domain = param_sorts;
+                full_domain.insert(full_domain.end(), body_ho->second.first.begin(), body_ho->second.first.end());
+                lambda_sort = get_ho_sort(full_domain, body_ho->second.second);
+            } else {
+                lambda_sort = get_ho_sort(param_sorts, body_sort);
+            }
+        }
         func_decl* f = m.mk_func_decl(symbol(lname), 0, static_cast<sort* const*>(nullptr), lambda_sort);
         m_pinned_decls.push_back(f);
         return expr_ref(m.mk_const(f), m);
@@ -1128,9 +1327,13 @@ class tptp_parser {
             }
             if (role == "conjecture") {
                 m_has_conjecture = true;
-                f = m.mk_not(f);
+                if (m.is_bool(f))
+                    f = m.mk_not(f);
             }
-            m_cmd.assert_expr(f);
+            // Only assert Bool-sorted formulas; non-Bool results from
+            // incomplete higher-order approximation are silently skipped.
+            if (m.is_bool(f))
+                m_cmd.assert_expr(f);
         }
 
         if (accept(token_kind::comma)) {