From 6b69c2c04869b6d64a1521f218de215f7a49adc7 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Wed, 13 May 2026 10:18:51 -0700 Subject: [PATCH] updated code Signed-off-by: Nikolaj Bjorner --- src/cmd_context/tptp_frontend.cpp | 389 +++++++++++++++++++++++++++--- 1 file changed, 358 insertions(+), 31 deletions(-) diff --git a/src/cmd_context/tptp_frontend.cpp b/src/cmd_context/tptp_frontend.cpp index 52f52209c..2c7030ec9 100644 --- a/src/cmd_context/tptp_frontend.cpp +++ b/src/cmd_context/tptp_frontend.cpp @@ -52,7 +52,8 @@ enum class token_kind { star_tok, slash_tok, minus_tok, - at_tok + at_tok, + lambda_tok }; struct parse_error : public std::exception { @@ -226,6 +227,7 @@ public: case '/': t.kind = token_kind::slash_tok; return t; case '-': t.kind = token_kind::minus_tok; return t; case '@': t.kind = token_kind::at_tok; return t; + case '^': t.kind = token_kind::lambda_tok; return t; default: break; } @@ -257,8 +259,10 @@ class tptp_parser { arith_util m_arith; sort* m_univ; bool m_has_conjecture = false; + unsigned m_lambda_counter = 0; std::unordered_map m_sorts; std::unordered_map m_decls; + func_decl_ref_vector m_pinned_decls; // prevents cached func_decls from being freed std::unordered_map, sort*>> m_typed_decls; std::vector> m_bound; std::unordered_set m_seen_files; @@ -327,6 +331,24 @@ class tptp_parser { return s; } + // For higher-order types like ($i > $o), create an uninterpreted sort + // that represents the function type. This is a first-order approximation. + sort* get_ho_sort(std::vector const& domain, sort* range) { + std::ostringstream oss; + oss << "("; + for (size_t i = 0; i < domain.size(); ++i) { + if (i > 0) oss << " * "; + oss << domain[i]->get_name(); + } + oss << " > " << range->get_name() << ")"; + std::string key = oss.str(); + auto it = m_sorts.find(key); + if (it != m_sorts.end()) return it->second; + sort* s = m.mk_uninterpreted_sort(symbol(key)); + m_sorts.emplace(key, s); + return s; + } + static bool is_ttype(sort* s) { return s->get_name() == symbol("$tType"); } @@ -382,6 +404,7 @@ class tptp_parser { if (itd != m_decls.end()) return itd->second; auto const& sig = itt->second; func_decl* f = m.mk_func_decl(symbol(name), sig.first.size(), sig.first.data(), sig.second); + m_pinned_decls.push_back(f); m_decls.emplace(typed_decl_key, f); return f; } @@ -392,10 +415,34 @@ class tptp_parser { std::vector dom(arity, m_univ); func_decl* f = m.mk_func_decl(symbol(name), arity, dom.data(), pred ? m.mk_bool_sort() : m_univ); + m_pinned_decls.push_back(f); m_decls.emplace(key, f); return f; } + // When a symbol is used with 0 args but has a typed decl with arity > 0, + // create a 0-arity constant with the function type sort (for THF function-as-value). + func_decl* mk_decl_or_ho_const(std::string const& name, unsigned arity, bool pred) { + if (arity == 0) { + // Check if there's a typed decl at any arity > 0 for this name + for (unsigned try_arity = 1; try_arity <= 10; ++try_arity) { + auto itt = m_typed_decls.find(mk_typed_key(name, try_arity)); + if (itt != m_typed_decls.end()) { + auto const& sig = itt->second; + sort* ho = get_ho_sort(sig.first, sig.second); + std::string key = mk_decl_key(name, 0, 'h'); + auto itd = m_decls.find(key); + if (itd != m_decls.end()) return itd->second; + func_decl* f = m.mk_func_decl(symbol(name), 0, static_cast(nullptr), ho); + m_pinned_decls.push_back(f); + m_decls.emplace(key, f); + return f; + } + } + } + return mk_decl(name, arity, pred); + } + bool find_bound(std::string const& n, expr_ref& e) const { for (auto it = m_bound.rbegin(); it != m_bound.rend(); ++it) { auto jt = it->find(n); @@ -415,35 +462,90 @@ class tptp_parser { parsed_type parse_type_atom() { if (accept(token_kind::lparen)) { - parsed_type t = parse_type_expr(); + std::vector prod = parse_type_product_raw(); + if (accept(token_kind::gt_tok)) { + // Full function type inside parens: (A * B > C) or (A > B > C) + parsed_type rhs = parse_type_expr(); + std::vector full_domain = prod; + if (!rhs.domain.empty()) { + // Nested higher-order: (A > B > C) → flatten + full_domain.insert(full_domain.end(), rhs.domain.begin(), rhs.domain.end()); + } + expect(token_kind::rparen, "')'"); + // Return as a higher-order sort (uninterpreted) + sort* ho = get_ho_sort(full_domain, rhs.range); + return parsed_type(ho); + } expect(token_kind::rparen, "')'"); - return t; + if (prod.size() == 1) + return parsed_type(prod[0]); + // Parenthesized product: (A * B) — used as domain in outer context + return parsed_type(prod, nullptr); } std::string n = parse_name(); return parsed_type(get_sort(n)); } - std::vector parse_type_product() { + std::vector parse_type_product_raw() { parsed_type first = parse_type_atom(); - if (!first.domain.empty()) - throw parse_error("higher-order type in product is unsupported"); + if (!first.domain.empty() && first.range == nullptr) { + // Already a parenthesized product from nested parens + std::vector args = first.domain; + while (accept(token_kind::star_tok)) { + parsed_type t = parse_type_atom(); + if (!t.domain.empty()) { + args.insert(args.end(), t.domain.begin(), t.domain.end()); + } else { + args.push_back(t.range); + } + } + return args; + } + if (!first.domain.empty()) { + // Function type as first element of product — use ho_sort + sort* ho = get_ho_sort(first.domain, first.range); + std::vector args; + args.push_back(ho); + while (accept(token_kind::star_tok)) { + parsed_type t = parse_type_atom(); + if (!t.domain.empty() && t.range != nullptr) { + args.push_back(get_ho_sort(t.domain, t.range)); + } else if (!t.domain.empty()) { + args.insert(args.end(), t.domain.begin(), t.domain.end()); + } else { + args.push_back(t.range); + } + } + return args; + } std::vector args; args.push_back(first.range); while (accept(token_kind::star_tok)) { parsed_type t = parse_type_atom(); - if (!t.domain.empty()) - throw parse_error("higher-order type in product is unsupported"); - args.push_back(t.range); + if (!t.domain.empty() && t.range != nullptr) { + args.push_back(get_ho_sort(t.domain, t.range)); + } else if (!t.domain.empty()) { + args.insert(args.end(), t.domain.begin(), t.domain.end()); + } else { + args.push_back(t.range); + } } return args; } + std::vector parse_type_product() { + return parse_type_product_raw(); + } + parsed_type parse_type_expr() { std::vector prod = parse_type_product(); if (accept(token_kind::gt_tok)) { parsed_type rhs = parse_type_expr(); - if (!rhs.domain.empty()) - throw parse_error("higher-order result type is unsupported"); + if (!rhs.domain.empty()) { + // Higher-order result type: A > (B > C) flattened to (A, B) > C + prod.insert(prod.end(), rhs.domain.begin(), rhs.domain.end()); + return parsed_type(prod, rhs.range); + } return parsed_type(prod, rhs.range); } if (prod.size() != 1) @@ -485,6 +587,9 @@ class tptp_parser { expect(token_kind::rparen, "')'"); return e; } + if (accept(token_kind::lambda_tok)) { + return parse_lambda_expr(); + } if (accept(token_kind::minus_tok)) { expr_ref e = parse_term_primary(); if (!m_arith.is_int_real(e)) @@ -526,12 +631,142 @@ class tptp_parser { expr_ref parse_formula(); - expr_ref parse_atomic_formula() { + expr_ref apply_at(expr_ref e) { + if (!is(token_kind::at_tok)) return e; + if (!is_app(e)) + throw parse_error("application operator (@) requires function on left-hand side"); + + // Collect all @ arguments + app* a = to_app(e); + func_decl* orig_d = a->get_decl(); + if (orig_d->get_family_id() != null_family_id) + throw parse_error("application operator (@) requires uninterpreted function on left-hand side"); + + expr_ref_vector args(m); + for (unsigned i = 0; i < a->get_num_args(); ++i) args.push_back(a->get_arg(i)); + + while (accept(token_kind::at_tok)) { + expr_ref arg = parse_at_arg(); + args.push_back(arg); + } + + unsigned new_arity = args.size(); + std::string name = orig_d->get_name().str(); + func_decl* new_d = mk_decl(name, new_arity, false); + // Verify argument sorts match; if not, create untyped func_decl with actual arg sorts + bool sorts_ok = (new_d->get_arity() == new_arity); + if (sorts_ok) { + for (unsigned i = 0; i < new_arity; ++i) { + if (new_d->get_domain(i) != args.get(i)->get_sort()) { + sorts_ok = false; + break; + } + } + } + if (!sorts_ok) { + std::vector dom; + for (unsigned i = 0; i < new_arity; ++i) dom.push_back(args.get(i)->get_sort()); + sort* range = new_d->get_range(); + new_d = m.mk_func_decl(symbol(name), new_arity, dom.data(), range); + m_pinned_decls.push_back(new_d); + } + return expr_ref(m.mk_app(new_d, new_arity, args.data()), m); + } + + // Parse an argument to @ — can be a term, a formula (negation, quantifier, parens with connectives), or a lambda + expr_ref parse_at_arg() { + if (accept(token_kind::not_tok)) { + expr_ref e = parse_at_arg(); + return expr_ref(m.mk_not(e), m); + } + if (accept(token_kind::lambda_tok)) { + return parse_lambda_expr(); + } if (accept(token_kind::lparen)) { expr_ref e = parse_formula(); expect(token_kind::rparen, "')'"); + // Do NOT call apply_at here — outer apply_at owns the remaining @ tokens return e; } + if (is(token_kind::forall_tok) || is(token_kind::exists_tok)) { + bool is_forall = is(token_kind::forall_tok); + next(); + expect(token_kind::lbrack, "'['"); + ptr_vector vars; + std::unordered_map scope; + if (!accept(token_kind::rbrack)) { + do { + std::string v = parse_name(); + sort* s = m_univ; + if (accept(token_kind::colon)) { + parsed_type t = parse_type_expr(); + if (!t.domain.empty()) s = get_ho_sort(t.domain, t.range); + else s = t.range; + } + app* c = m.mk_const(symbol(v), s); + vars.push_back(c); + scope.emplace(v, c); + } while (accept(token_kind::comma)); + expect(token_kind::rbrack, "']'"); + } + expect(token_kind::colon, "':'"); + m_bound.push_back(scope); + expr_ref body = parse_formula(); + m_bound.pop_back(); + return mk_quantifier(is_forall, vars, body); + } + // Simple term (name with optional function args) — no @ consumption here + return parse_term_primary(); + } + + // Coerce two expressions to have the same sort for equality. + // If sorts already match, returns lhs unchanged. Otherwise coerces the + // 0-arity constant side to match the other's sort. If that's not possible, + // coerces both to m_univ. + expr_ref coerce_eq(expr_ref lhs, expr_ref& rhs) { + if (lhs->get_sort() == rhs->get_sort()) return lhs; + // Try coercing one side (0-arity constants can be reinterpreted) + if (is_app(lhs) && to_app(lhs)->get_num_args() == 0 && lhs->get_sort() != rhs->get_sort()) { + func_decl* fd = m.mk_func_decl(to_app(lhs)->get_decl()->get_name(), 0, static_cast(nullptr), rhs->get_sort()); + m_pinned_decls.push_back(fd); + return expr_ref(m.mk_const(fd), m); + } + if (is_app(rhs) && to_app(rhs)->get_num_args() == 0 && lhs->get_sort() != rhs->get_sort()) { + func_decl* fd = m.mk_func_decl(to_app(rhs)->get_decl()->get_name(), 0, static_cast(nullptr), lhs->get_sort()); + m_pinned_decls.push_back(fd); + rhs = m.mk_const(fd); + return lhs; + } + // Last resort: coerce both to m_univ + if (is_app(lhs) && to_app(lhs)->get_num_args() == 0) { + func_decl* fd = m.mk_func_decl(to_app(lhs)->get_decl()->get_name(), 0, static_cast(nullptr), m_univ); + m_pinned_decls.push_back(fd); + lhs = m.mk_const(fd); + } + if (is_app(rhs) && to_app(rhs)->get_num_args() == 0) { + func_decl* fd = m.mk_func_decl(to_app(rhs)->get_decl()->get_name(), 0, static_cast(nullptr), m_univ); + m_pinned_decls.push_back(fd); + rhs = m.mk_const(fd); + } + return lhs; + } + + expr_ref parse_atomic_formula() { + if (accept(token_kind::lparen)) { + expr_ref e = parse_formula(); + // Handle equality/inequality inside parenthesized expressions + // e.g., ( (f @ Y) = (g @ X @ Z) ) where f(Y) is non-Bool + if (!m.is_bool(e) && (is(token_kind::equal_tok) || is(token_kind::neq_tok))) { + bool neq = accept(token_kind::neq_tok); + if (!neq) expect(token_kind::equal_tok, "'='"); + expr_ref rhs = parse_term(); + e = coerce_eq(e, rhs); + expr_ref eq(m.mk_eq(e, rhs), m); + e = neq ? expr_ref(m.mk_not(eq), m) : eq; + } + expect(token_kind::rparen, "')'"); + return apply_at(e); + } std::string n = parse_name(); if (n == "$true") return expr_ref(m.mk_true(), m); @@ -592,32 +827,94 @@ class tptp_parser { if (is(token_kind::equal_tok) || is(token_kind::neq_tok)) { if (!has_lhs) { - func_decl* f = mk_decl(n, args.size(), false); + func_decl* f = mk_decl_or_ho_const(n, args.size(), false); lhs = args.empty() ? m.mk_const(f) : m.mk_app(f, args.size(), args.data()); } bool neq = accept(token_kind::neq_tok); if (!neq) expect(token_kind::equal_tok, "'='"); expr_ref rhs = parse_term(); + lhs = coerce_eq(lhs, rhs); expr_ref eq(m.mk_eq(lhs, rhs), m); return neq ? expr_ref(m.mk_not(eq), m) : eq; } if (has_lhs) { - if (m.is_bool(lhs)) return lhs; - throw parse_error("non-boolean variable used as formula"); + lhs = apply_at(lhs); + if (is(token_kind::equal_tok) || is(token_kind::neq_tok)) { + bool neq = accept(token_kind::neq_tok); + if (!neq) expect(token_kind::equal_tok, "'='"); + expr_ref rhs = parse_term(); + lhs = coerce_eq(lhs, rhs); + expr_ref eq(m.mk_eq(lhs, rhs), m); + return neq ? expr_ref(m.mk_not(eq), m) : eq; + } + return lhs; // In THF, variables of any sort can appear in formula position (e.g., with @) } auto typed = m_typed_decls.find(mk_typed_key(n, args.size())); if (typed != m_typed_decls.end()) { func_decl* f = mk_decl(n, args.size(), false); expr_ref e(args.empty() ? m.mk_const(f) : m.mk_app(f, args.size(), args.data()), m); + e = apply_at(e); + if (is(token_kind::equal_tok) || is(token_kind::neq_tok)) { + bool neq = accept(token_kind::neq_tok); + if (!neq) expect(token_kind::equal_tok, "'='"); + expr_ref rhs = parse_term(); + e = coerce_eq(e, rhs); + expr_ref eq(m.mk_eq(e, rhs), m); + return neq ? expr_ref(m.mk_not(eq), m) : eq; + } if (!m.is_bool(e)) - throw parse_error("typed non-boolean term used as formula"); + return e; // In THF, non-Bool typed expressions can appear in formula position (e.g., as @ args) return e; } func_decl* pred = mk_decl(n, args.size(), true); - return expr_ref(args.empty() ? m.mk_const(pred) : m.mk_app(pred, args.size(), args.data()), m); + expr_ref result(args.empty() ? m.mk_const(pred) : m.mk_app(pred, args.size(), args.data()), m); + return apply_at(result); + } + + // Parse THF lambda expression: ^ [X: T, ...] : body + // Since Z3 doesn't support lambdas natively, we approximate: + // - Parse the bound variables and body + // - Return the body with variables universally quantified (over-approximation) + expr_ref parse_lambda_expr() { + expect(token_kind::lbrack, "'['"); + ptr_vector vars; + std::unordered_map scope; + if (!accept(token_kind::rbrack)) { + do { + std::string v = parse_name(); + sort* s = m_univ; + if (accept(token_kind::colon)) { + parsed_type t = parse_type_expr(); + if (!t.domain.empty()) { + s = get_ho_sort(t.domain, t.range); + } else if (t.range) { + s = t.range; + } + } + app* c = m.mk_const(symbol(v), s); + vars.push_back(c); + scope.emplace(v, c); + } while (accept(token_kind::comma)); + expect(token_kind::rbrack, "']'"); + } + expect(token_kind::colon, "':'"); + m_bound.push_back(scope); + expr_ref body = parse_formula(); + m_bound.pop_back(); + // For first-order approximation, create a fresh constant with the lambda's function type sort + std::ostringstream oss; + oss << "$lambda_" << m_lambda_counter++; + std::string lname = oss.str(); + // Lambda type: param sorts → body sort + std::vector param_sorts; + for (auto* v : vars) param_sorts.push_back(v->get_sort()); + sort* lambda_sort = param_sorts.empty() ? body->get_sort() : get_ho_sort(param_sorts, body->get_sort()); + func_decl* f = m.mk_func_decl(symbol(lname), 0, static_cast(nullptr), lambda_sort); + m_pinned_decls.push_back(f); + return expr_ref(m.mk_const(f), m); } expr_ref parse_unary_formula() { @@ -626,6 +923,12 @@ class tptp_parser { return expr_ref(m.mk_not(e), m); } + if (accept(token_kind::lambda_tok)) { + // THF lambda: ^ [X: T, ...] : body + // Approximate as a fresh constant (first-order approximation) + return parse_lambda_expr(); + } + if (is(token_kind::forall_tok) || is(token_kind::exists_tok)) { bool is_forall = is(token_kind::forall_tok); next(); @@ -639,9 +942,12 @@ class tptp_parser { sort* s = m_univ; if (accept(token_kind::colon)) { parsed_type t = parse_type_expr(); - if (!t.domain.empty()) - throw parse_error("higher-order variable type is unsupported"); - s = t.range; + if (!t.domain.empty()) { + // Higher-order variable type — use uninterpreted sort approximation + s = get_ho_sort(t.domain, t.range); + } else { + s = t.range; + } } app* c = m.mk_const(symbol(v), s); vars.push_back(c); @@ -807,6 +1113,7 @@ public: m_cmd(cmd), m(m_cmd.m()), m_arith(m), + m_pinned_decls(m), m_univ(m.mk_uninterpreted_sort(symbol("U"))) { m_sorts.emplace("$tType", m.mk_uninterpreted_sort(symbol("$tType"))); m_sorts.emplace("$i", m_univ); @@ -845,20 +1152,40 @@ public: expr_ref tptp_parser::parse_term() { expr_ref e = parse_term_primary(); + if (!is(token_kind::at_tok)) return e; + if (!is_app(e)) + throw parse_error("application operator (@) requires uninterpreted function on left-hand side"); + app* a = to_app(e); + func_decl* orig_d = a->get_decl(); + if (orig_d->get_family_id() != null_family_id) + throw parse_error("application operator (@) requires uninterpreted function on left-hand side"); + expr_ref_vector args(m); + for (unsigned i = 0; i < a->get_num_args(); ++i) args.push_back(a->get_arg(i)); while (accept(token_kind::at_tok)) { - expr_ref arg = parse_term_primary(); - if (!is_app(e)) - throw parse_error("application operator (@) requires uninterpreted function on left-hand side"); - app* a = to_app(e); - func_decl* d = a->get_decl(); - if (d->get_family_id() != null_family_id) - throw parse_error("application operator (@) requires uninterpreted function on left-hand side"); - expr_ref_vector args(m); - for (unsigned i = 0; i < a->get_num_args(); ++i) args.push_back(a->get_arg(i)); + expr_ref arg = parse_at_arg(); args.push_back(arg); - e = expr_ref(m.mk_app(d, args.size(), args.data()), m); } - return e; + unsigned new_arity = args.size(); + std::string name = orig_d->get_name().str(); + func_decl* new_d = mk_decl(name, new_arity, false); + // Verify argument sorts match; if not, create func_decl with actual arg sorts + bool sorts_ok = (new_d->get_arity() == new_arity); + if (sorts_ok) { + for (unsigned i = 0; i < new_arity; ++i) { + if (new_d->get_domain(i) != args.get(i)->get_sort()) { + sorts_ok = false; + break; + } + } + } + if (!sorts_ok) { + std::vector dom; + for (unsigned i = 0; i < new_arity; ++i) dom.push_back(args.get(i)->get_sort()); + sort* range = new_d->get_range(); + new_d = m.mk_func_decl(symbol(name), new_arity, dom.data(), range); + m_pinned_decls.push_back(new_d); + } + return expr_ref(m.mk_app(new_d, new_arity, args.data()), m); } expr_ref tptp_parser::parse_formula() {