diff --git a/src/cmd_context/tptp_frontend.cpp b/src/cmd_context/tptp_frontend.cpp index 7ab534ad5..fc03218aa 100644 --- a/src/cmd_context/tptp_frontend.cpp +++ b/src/cmd_context/tptp_frontend.cpp @@ -10,6 +10,7 @@ #include #include "ast/arith_decl_plugin.h" +#include "ast/array_decl_plugin.h" #include "ast/expr_abstract.h" #include "ast/ast_util.h" #include "cmd_context/cmd_context.h" @@ -47,6 +48,8 @@ enum class token_kind { not_tok, forall_tok, exists_tok, + type_forall_tok, // !> + type_exists_tok, // ?* equal_tok, neq_tok, iff_tok, @@ -226,8 +229,12 @@ public: case '&': t.kind = token_kind::and_tok; return t; case '|': t.kind = token_kind::or_tok; return t; case '~': t.kind = token_kind::not_tok; return t; - case '!': t.kind = token_kind::forall_tok; return t; - case '?': t.kind = token_kind::exists_tok; return t; + case '!': + if (peek() == '>') { get(); t.kind = token_kind::type_forall_tok; return t; } + t.kind = token_kind::forall_tok; return t; + case '?': + if (peek() == '*') { get(); t.kind = token_kind::type_exists_tok; return t; } + t.kind = token_kind::exists_tok; return t; case '=': t.kind = token_kind::equal_tok; return t; case '>': t.kind = token_kind::gt_tok; return t; case '*': t.kind = token_kind::star_tok; return t; @@ -264,13 +271,12 @@ class tptp_parser { cmd_context& m_cmd; ast_manager& m; arith_util m_arith; + array_util m_array; sort* m_univ; bool m_has_conjecture = false; bool m_last_name_quoted = false; - unsigned m_lambda_counter = 0; std::unordered_map m_sorts; sort_ref_vector m_pinned_sorts; // prevents cached sorts from being freed - std::unordered_map, sort*>> m_ho_sort_info; // ho_sort → (domain, range) std::unordered_map m_decls; func_decl_ref_vector m_pinned_decls; // prevents cached func_decls from being freed expr_ref_vector m_pinned_exprs; // prevents bound variable apps from being freed @@ -283,10 +289,70 @@ class tptp_parser { implicit_var_scope* m_implicit_scope = nullptr; std::unordered_set m_seen_files; + // Table-driven operator dispatch + using op_builder = std::function; + struct op_entry { + bool is_infix; + unsigned precedence; // only meaningful for infix; higher = tighter binding + bool right_assoc; + op_builder builder; + }; + std::unordered_map m_ops; + + // Infix precedence levels: + static constexpr unsigned PREC_IFF = 1; // <=> <~> + static constexpr unsigned PREC_IMPLIES = 2; // => <= + static constexpr unsigned PREC_OR = 3; // | ~| + static constexpr unsigned PREC_AND = 4; // & ~& + static constexpr unsigned PREC_EQ = 5; // = != + std::string m_input; std::unique_ptr m_lex; token m_curr; + // Helper: check arity for arithmetic operators + void check_arith_arity(expr_ref_vector const& args, unsigned expected, char const* name) { + if (args.size() != expected) { + std::ostringstream out; + out << "'" << name << "' expects arity " << expected; + throw parse_error(out.str()); + } + } + + // Helper: coerce two arithmetic args to same sort (promote int to real if needed) + std::pair coerce_arith2(expr_ref_vector const& args) { + expr_ref a(args[0], m), b(args[1], m); + if (m_arith.is_real(a) || m_arith.is_real(b)) { + if (m_arith.is_int(a)) a = expr_ref(m_arith.mk_to_real(a), m); + if (m_arith.is_int(b)) b = expr_ref(m_arith.mk_to_real(b), m); + } + return { a, b }; + } + + // Helper: quotient dispatch (integer division for int/int, real division otherwise) + expr_ref mk_quotient(expr_ref_vector const& args) { + expr_ref a(args[0], m), b(args[1], m); + if (m_arith.is_int(a) && m_arith.is_int(b)) + return expr_ref(m_arith.mk_idiv(a, b), m); + if (m_arith.is_int(a)) a = expr_ref(m_arith.mk_to_real(a), m); + if (m_arith.is_int(b)) b = expr_ref(m_arith.mk_to_real(b), m); + return expr_ref(m_arith.mk_div(a, b), m); + } + + // Map infix token to operator name (returns nullptr if not an infix op token) + char const* token_to_op_name() const { + switch (m_curr.kind) { + case token_kind::iff_tok: return "<=>"; + case token_kind::xor_tok: return "<~>"; + case token_kind::implies_tok: return "=>"; + case token_kind::implied_tok: return "<="; + case token_kind::or_tok: return "|"; + case token_kind::nor_tok: return "~|"; + case token_kind::and_tok: return "&"; + case token_kind::nand_tok: return "~&"; + default: return nullptr; + } + } static std::string to_lower(std::string s) { for (char& c : s) c = static_cast(std::tolower(static_cast(c))); return s; @@ -350,22 +416,12 @@ class tptp_parser { } // For higher-order types like ($i > $o), create an uninterpreted sort - // that represents the function type. This is a first-order approximation. + // Function type A > B is represented as Array(A, B). + // Multi-argument A * B > C is represented as Array(A, Array(B, C)) (curried). sort* get_ho_sort(std::vector const& domain, sort* range) { - std::ostringstream oss; - oss << "("; - for (size_t i = 0; i < domain.size(); ++i) { - if (i > 0) oss << " * "; - oss << domain[i]->get_name(); - } - oss << " > " << range->get_name() << ")"; - std::string key = oss.str(); - auto it = m_sorts.find(key); - if (it != m_sorts.end()) return it->second; - sort* s = m.mk_uninterpreted_sort(symbol(key)); - m_sorts.emplace(key, s); - m_pinned_sorts.push_back(s); - m_ho_sort_info.emplace(s, std::make_pair(domain, range)); + sort* s = range; + for (int i = (int)domain.size() - 1; i >= 0; --i) + s = m_array.mk_array_sort(domain[i], s); return s; } @@ -463,6 +519,42 @@ class tptp_parser { return mk_decl(name, arity, pred); } + // Coerce an expression to a target sort using boxing/unboxing functions + expr_ref coerce_arg(expr_ref const& e, sort* target) { + sort* actual = e->get_sort(); + if (actual == target) return e; + // Create a boxing function from actual sort to target sort + std::string box_name = std::string("$box_") + actual->get_name().str() + "_to_" + target->get_name().str(); + std::string key = mk_decl_key(box_name, 1, 'f'); + auto it = m_decls.find(key); + func_decl* f; + if (it != m_decls.end()) { + f = it->second; + } else { + f = m.mk_func_decl(symbol(box_name), 1, &actual, target); + m_pinned_decls.push_back(f); + m_decls.emplace(key, f); + } + return expr_ref(m.mk_app(f, e.get()), m); + } + + // Coerce arguments of a function application to match declared sorts + void coerce_args(func_decl* f, expr_ref_vector& args) { + for (unsigned i = 0; i < args.size() && i < f->get_arity(); ++i) { + sort* expected = f->get_domain(i); + sort* actual = args.get(i)->get_sort(); + if (expected != actual) { + args[i] = coerce_arg(expr_ref(args.get(i), m), expected); + } + } + } + + // Coerce result to expected sort if needed + expr_ref coerce_result(expr_ref const& e, sort* expected) { + if (!expected || e->get_sort() == expected) return e; + return coerce_arg(e, expected); + } + bool find_bound(std::string const& n, expr_ref& e) const { for (auto it = m_bound.rbegin(); it != m_bound.rend(); ++it) { auto jt = it->find(n); @@ -524,6 +616,25 @@ class tptp_parser { return is_forall ? ::mk_forall(m, bound.size(), bound.data(), body.get()) : ::mk_exists(m, bound.size(), bound.data(), body.get()); } + // $is_rat(x) ≡ exists a:Int, b:Int. b != 0 && x = a/b + expr_ref mk_is_rat(expr_ref const& x) { + sort* int_sort = m_arith.mk_int(); + app* a = m.mk_fresh_const("a", int_sort); + app* b = m.mk_fresh_const("b", int_sort); + expr_ref ar(m_arith.mk_to_real(a), m); + expr_ref br(m_arith.mk_to_real(b), m); + expr_ref xr(x); + if (m_arith.is_int(x)) + xr = expr_ref(m_arith.mk_to_real(x), m); + expr_ref b_ne_zero(m.mk_not(m.mk_eq(b, m_arith.mk_int(0))), m); + expr_ref x_eq_div(m.mk_eq(xr, m_arith.mk_div(ar, br)), m); + expr_ref body(m.mk_and(b_ne_zero, x_eq_div), m); + ptr_vector bound; + bound.push_back(a); + bound.push_back(b); + return expr_ref(::mk_exists(m, bound.size(), bound.data(), body.get()), m); + } + parsed_type parse_type_atom() { if (accept(token_kind::lparen)) { std::vector prod = parse_type_product_raw(); @@ -536,9 +647,8 @@ class tptp_parser { full_domain.insert(full_domain.end(), rhs.domain.begin(), rhs.domain.end()); } expect(token_kind::rparen, "')'"); - // Return as a higher-order sort (uninterpreted) - sort* ho = get_ho_sort(full_domain, rhs.range); - return parsed_type(ho); + // Return with domain/range preserved for proper flattening + return parsed_type(full_domain, rhs.range); } expect(token_kind::rparen, "')'"); if (prod.size() == 1) @@ -547,6 +657,17 @@ class tptp_parser { return parsed_type(prod, nullptr); } std::string n = parse_name(); + // Handle parameterized type constructors: fun(A, B), product_prod(A, B), etc. + if (accept(token_kind::lparen)) { + // Consume type arguments — for monomorphization, we ignore them + // and return the base sort (or m_univ if the constructor result is $tType) + if (!accept(token_kind::rparen)) { + do { parse_type_expr(); } while (accept(token_kind::comma)); + expect(token_kind::rparen, "')'"); + } + // Return m_univ as the monomorphized result of any type constructor application + return parsed_type(m_univ); + } return parsed_type(get_sort(n)); } @@ -597,24 +718,93 @@ class tptp_parser { return args; } - std::vector parse_type_product() { - return parse_type_product_raw(); + parsed_type parse_type_product() { + parsed_type first = parse_type_atom(); + // If atom returned a function type and no '*' follows, return it directly + if (!first.domain.empty() && first.range != nullptr && !is(token_kind::star_tok)) { + return first; + } + // Build product vector + std::vector args; + if (!first.domain.empty() && first.range != nullptr) { + // Function type used as element in a product + args.push_back(get_ho_sort(first.domain, first.range)); + } else if (!first.domain.empty() && first.range == nullptr) { + // Parenthesized product: flatten + args = first.domain; + } else { + args.push_back(first.range); + } + while (accept(token_kind::star_tok)) { + parsed_type t = parse_type_atom(); + if (!t.domain.empty() && t.range != nullptr) { + args.push_back(get_ho_sort(t.domain, t.range)); + } else if (!t.domain.empty()) { + args.insert(args.end(), t.domain.begin(), t.domain.end()); + } else { + args.push_back(t.range); + } + } + return parsed_type(args, nullptr); } parsed_type parse_type_expr() { - std::vector prod = parse_type_product(); + // Handle type quantification at the expression level for proper domain/range preservation + if (is(token_kind::type_forall_tok) || is(token_kind::type_exists_tok)) { + next(); + expect(token_kind::lbrack, "'['"); + std::vector type_params; + if (!accept(token_kind::rbrack)) { + do { + std::string tv = parse_name(); + if (accept(token_kind::colon)) + parse_type_expr(); // consume $tType annotation + m_sorts.insert_or_assign(tv, m_univ); + type_params.push_back(m_univ); + } while (accept(token_kind::comma)); + expect(token_kind::rbrack, "']'"); + } + expect(token_kind::colon, "':'"); + parsed_type inner = parse_type_expr(); + // Prepend type params to domain + if (!type_params.empty()) { + std::vector full_domain = type_params; + full_domain.insert(full_domain.end(), inner.domain.begin(), inner.domain.end()); + return parsed_type(full_domain, inner.range); + } + return inner; + } + parsed_type prod = parse_type_product(); if (accept(token_kind::gt_tok)) { parsed_type rhs = parse_type_expr(); + // prod is either a product (domain non-empty, range==nullptr) or a single sort (domain empty) + std::vector domain; + if (!prod.domain.empty() && prod.range == nullptr) { + domain = prod.domain; + } else if (!prod.domain.empty() && prod.range != nullptr) { + // A function type as domain element — wrap it + domain.push_back(get_ho_sort(prod.domain, prod.range)); + } else { + domain.push_back(prod.range); + } if (!rhs.domain.empty()) { // Higher-order result type: A > (B > C) flattened to (A, B) > C - prod.insert(prod.end(), rhs.domain.begin(), rhs.domain.end()); - return parsed_type(prod, rhs.range); + domain.insert(domain.end(), rhs.domain.begin(), rhs.domain.end()); + return parsed_type(domain, rhs.range); } - return parsed_type(prod, rhs.range); + return parsed_type(domain, rhs.range); } - if (prod.size() != 1) - throw parse_error("type product must be followed by '>'"); - return parsed_type(prod[0]); + // No '>' follows — must be a single type or a function type from parens + if (!prod.domain.empty() && prod.range != nullptr) { + // Function type from parenthesized expression + return prod; + } + if (!prod.domain.empty() && prod.range == nullptr) { + if (prod.domain.size() != 1) + throw parse_error("type product must be followed by '>'"); + return parsed_type(prod.domain[0]); + } + return parsed_type(prod.range); } void skip_annotations_until_rparen() { @@ -647,7 +837,7 @@ class tptp_parser { expr_ref parse_term_primary() { if (accept(token_kind::lparen)) { - expr_ref e = parse_term(); + expr_ref e = parse_formula(); expect(token_kind::rparen, "')'"); return e; } @@ -675,90 +865,31 @@ class tptp_parser { return expr_ref(get_or_create_implicit_var(n), m); expr_ref_vector args(m); - if (accept(token_kind::lparen)) { + // $ite needs special parsing: first arg is formula, rest are formulas (branches can be equalities) + if (n == "$ite") { + expect(token_kind::lparen, "'('"); + args.push_back(parse_formula()); + expect(token_kind::comma, "','"); + args.push_back(parse_formula()); + expect(token_kind::comma, "','"); + args.push_back(parse_formula()); + expect(token_kind::rparen, "')'"); + } + else if (accept(token_kind::lparen)) { if (!accept(token_kind::rparen)) { do { args.push_back(parse_term()); } while (accept(token_kind::comma)); expect(token_kind::rparen, "')'"); } } - if (n == "$uminus") { - if (args.size() != 1) - throw parse_error("arithmetic function '$uminus' expects arity 1"); - expr_ref a(args.get(0), m); - if (!m_arith.is_int_real(a)) - throw parse_error("arithmetic function '$uminus' expects arithmetic argument"); - return expr_ref(m_arith.mk_uminus(a), m); + // Table-driven prefix operator dispatch + auto op_it = m_ops.find(n); + if (op_it != m_ops.end() && !op_it->second.is_infix) { + return op_it->second.builder(args); } - if (n == "$sum" || n == "$difference" || n == "$product") { - if (args.size() != 2) - throw parse_error("arithmetic function expects arity 2"); - expr_ref a(args.get(0), m), b(args.get(1), m); - if (!m_arith.is_int_real(a) || !m_arith.is_int_real(b)) - throw parse_error("arithmetic function expects arithmetic arguments"); - bool use_real = m_arith.is_real(a) || m_arith.is_real(b); - if (use_real) { - if (m_arith.is_int(a)) a = expr_ref(m_arith.mk_to_real(a), m); - if (m_arith.is_int(b)) b = expr_ref(m_arith.mk_to_real(b), m); - } - if (n == "$sum") return expr_ref(m_arith.mk_add(a, b), m); - if (n == "$difference") return expr_ref(m_arith.mk_sub(a, b), m); - /* $product */ return expr_ref(m_arith.mk_mul(a, b), m); - } - - if (n == "$quotient_e" || n == "$quotient_t" || n == "$quotient_f" || n == "$quotient") { - if (args.size() != 2) - throw parse_error("arithmetic function expects arity 2"); - expr_ref a(args.get(0), m), b(args.get(1), m); - if (!m_arith.is_int_real(a) || !m_arith.is_int_real(b)) - throw parse_error("arithmetic function expects arithmetic arguments"); - if (m_arith.is_int(a) && m_arith.is_int(b)) - return expr_ref(m_arith.mk_idiv(a, b), m); - if (m_arith.is_int(a)) a = expr_ref(m_arith.mk_to_real(a), m); - if (m_arith.is_int(b)) b = expr_ref(m_arith.mk_to_real(b), m); - return expr_ref(m_arith.mk_div(a, b), m); - } - - if (n == "$remainder_e" || n == "$remainder_t" || n == "$remainder_f") { - if (args.size() != 2) - throw parse_error("arithmetic function expects arity 2"); - expr_ref a(args.get(0), m), b(args.get(1), m); - if (!m_arith.is_int_real(a) || !m_arith.is_int_real(b)) - throw parse_error("arithmetic function expects arithmetic arguments"); - return expr_ref(m_arith.mk_mod(a, b), m); - } - - if (n == "$floor" || n == "$ceiling" || n == "$truncate" || n == "$round" || n == "$to_int") { - if (args.size() != 1) - throw parse_error("arithmetic function expects arity 1"); - expr_ref a(args.get(0), m); - if (!m_arith.is_int_real(a)) - throw parse_error("arithmetic function expects arithmetic argument"); - if (m_arith.is_int(a)) return a; - return expr_ref(m_arith.mk_to_int(a), m); - } - - if (n == "$to_rat" || n == "$to_real") { - if (args.size() != 1) - throw parse_error("arithmetic function expects arity 1"); - expr_ref a(args.get(0), m); - if (!m_arith.is_int_real(a)) - throw parse_error("arithmetic function expects arithmetic argument"); - if (m_arith.is_real(a)) return a; - return expr_ref(m_arith.mk_to_real(a), m); - } - - if (n == "$is_int") { - if (args.size() != 1) - throw parse_error("arithmetic predicate '$is_int' expects arity 1"); - expr_ref a(args.get(0), m); - if (!m_arith.is_int_real(a)) - throw parse_error("arithmetic predicate '$is_int' expects arithmetic argument"); - return expr_ref(m_arith.mk_is_int(a), m); - } - - func_decl* f = mk_decl(n, args.size(), false); + func_decl* f = mk_decl_or_ho_const(n, args.size(), false); + if (!args.empty()) coerce_args(f, args); return expr_ref(args.empty() ? m.mk_const(f) : m.mk_app(f, args.size(), args.data()), m); } @@ -766,91 +897,19 @@ class tptp_parser { expr_ref apply_at(expr_ref e) { if (!is(token_kind::at_tok)) return e; - if (!is_app(e)) - throw parse_error("application operator (@) requires function on left-hand side"); - - // Collect all @ arguments - app* a = to_app(e); - func_decl* orig_d = a->get_decl(); - if (orig_d->get_family_id() != null_family_id) - throw parse_error("application operator (@) requires uninterpreted function on left-hand side"); - - expr_ref_vector args(m); - for (unsigned i = 0; i < a->get_num_args(); ++i) args.push_back(a->get_arg(i)); + // @ corresponds to array select (function application) while (accept(token_kind::at_tok)) { expr_ref arg = parse_at_arg(); - args.push_back(arg); - } - - unsigned new_arity = args.size(); - std::string name = orig_d->get_name().str(); - - // Check if the LHS has a ho_sort (function type as value). - // If so, use its type info to determine domain/range for the application. - sort* lhs_sort = orig_d->get_range(); // sort of the 0-arity constant - auto ho_it = m_ho_sort_info.find(lhs_sort); - if (ho_it != m_ho_sort_info.end() && a->get_num_args() == 0) { - auto& [ho_domain, ho_range] = ho_it->second; - sort* range = ho_range; - // If we're partially applying (fewer args than domain), result is a ho_sort - if (new_arity < ho_domain.size()) { - std::vector remaining(ho_domain.begin() + new_arity, ho_domain.end()); - range = get_ho_sort(remaining, ho_range); + sort* e_sort = e->get_sort(); + if (!m_array.is_array(e_sort)) { + std::ostringstream out; + out << "application operator (@) requires array/function type on left-hand side, got sort " << e_sort->get_name(); + throw parse_error(out.str()); } - // For bound variables, use $apply so that the variable remains in the - // expression body and mk_forall can properly abstract it. - if (is_bound_var(a)) { - std::vector apply_dom; - apply_dom.push_back(lhs_sort); // first arg is the function variable - for (unsigned i = 0; i < new_arity; ++i) { - if (i < ho_domain.size()) - apply_dom.push_back(ho_domain[i]); - else - apply_dom.push_back(args.get(i)->get_sort()); - } - // Create unique apply function per signature - std::ostringstream apn; - apn << "$apply_" << lhs_sort->get_name() << "_" << new_arity; - func_decl* apply_d = m.mk_func_decl(symbol(apn.str()), apply_dom.size(), apply_dom.data(), range); - m_pinned_decls.push_back(apply_d); - expr_ref_vector all_args(m); - all_args.push_back(e); // the bound variable itself - for (unsigned i = 0; i < new_arity; ++i) all_args.push_back(args.get(i)); - return expr_ref(m.mk_app(apply_d, all_args.size(), all_args.data()), m); - } - // For non-bound ho-typed constants, use the name directly - std::vector dom; - for (unsigned i = 0; i < new_arity; ++i) { - if (i < ho_domain.size()) - dom.push_back(ho_domain[i]); - else - dom.push_back(args.get(i)->get_sort()); - } - func_decl* new_d = m.mk_func_decl(symbol(name), new_arity, dom.data(), range); - m_pinned_decls.push_back(new_d); - return expr_ref(m.mk_app(new_d, new_arity, args.data()), m); + e = expr_ref(m_array.mk_select(e, arg), m); } - - func_decl* new_d = mk_decl(name, new_arity, false); - // Verify argument sorts match; if not, create untyped func_decl with actual arg sorts - bool sorts_ok = (new_d->get_arity() == new_arity); - if (sorts_ok) { - for (unsigned i = 0; i < new_arity; ++i) { - if (new_d->get_domain(i) != args.get(i)->get_sort()) { - sorts_ok = false; - break; - } - } - } - if (!sorts_ok) { - std::vector dom; - for (unsigned i = 0; i < new_arity; ++i) dom.push_back(args.get(i)->get_sort()); - sort* range = new_d->get_range(); - new_d = m.mk_func_decl(symbol(name), new_arity, dom.data(), range); - m_pinned_decls.push_back(new_d); - } - return expr_ref(m.mk_app(new_d, new_arity, args.data()), m); + return e; } // Parse an argument to @ — can be a term, a formula (negation, quantifier, parens with connectives), or a lambda @@ -901,53 +960,6 @@ class tptp_parser { } // Build an arithmetic expression from a TPTP function name and arguments - expr_ref mk_arith_expr(std::string const& n, expr_ref_vector const& args) { - if (n == "$uminus") { - if (args.size() != 1) throw parse_error("$uminus expects arity 1"); - expr_ref a(args.get(0), m); - return expr_ref(m_arith.mk_uminus(a), m); - } - if (n == "$sum" || n == "$difference" || n == "$product") { - if (args.size() != 2) throw parse_error("arithmetic binary function expects arity 2"); - expr_ref a(args.get(0), m), b(args.get(1), m); - bool use_real = m_arith.is_real(a) || m_arith.is_real(b); - if (use_real) { - if (m_arith.is_int(a)) a = expr_ref(m_arith.mk_to_real(a), m); - if (m_arith.is_int(b)) b = expr_ref(m_arith.mk_to_real(b), m); - } - if (n == "$sum") return expr_ref(m_arith.mk_add(a, b), m); - if (n == "$difference") return expr_ref(m_arith.mk_sub(a, b), m); - return expr_ref(m_arith.mk_mul(a, b), m); - } - if (n == "$quotient" || n == "$quotient_e" || n == "$quotient_t" || n == "$quotient_f") { - if (args.size() != 2) throw parse_error("quotient expects arity 2"); - expr_ref a(args.get(0), m), b(args.get(1), m); - if (m_arith.is_int(a) && m_arith.is_int(b)) - return expr_ref(m_arith.mk_idiv(a, b), m); - if (m_arith.is_int(a)) a = expr_ref(m_arith.mk_to_real(a), m); - if (m_arith.is_int(b)) b = expr_ref(m_arith.mk_to_real(b), m); - return expr_ref(m_arith.mk_div(a, b), m); - } - if (n == "$remainder_e" || n == "$remainder_t" || n == "$remainder_f") { - if (args.size() != 2) throw parse_error("remainder expects arity 2"); - expr_ref a(args.get(0), m), b(args.get(1), m); - return expr_ref(m_arith.mk_mod(a, b), m); - } - if (n == "$floor" || n == "$ceiling" || n == "$truncate" || n == "$round" || n == "$to_int") { - if (args.size() != 1) throw parse_error("$to_int expects arity 1"); - expr_ref a(args.get(0), m); - if (m_arith.is_int(a)) return a; - return expr_ref(m_arith.mk_to_int(a), m); - } - if (n == "$to_rat" || n == "$to_real") { - if (args.size() != 1) throw parse_error("$to_real expects arity 1"); - expr_ref a(args.get(0), m); - if (m_arith.is_real(a)) return a; - return expr_ref(m_arith.mk_to_real(a), m); - } - throw parse_error("unknown arithmetic function: " + n); - } - // Coerce two expressions to have the same sort for equality. // If sorts already match, returns lhs unchanged. Otherwise coerces the // 0-arity constant side to match the other's sort. If that's not possible, @@ -997,6 +1009,21 @@ class tptp_parser { return apply_at(e); } + // Handle negative numerals in formula position: -2 = $uminus(2) + if (accept(token_kind::minus_tok)) { + expr_ref t = parse_term(); + expr_ref lhs(m_arith.mk_uminus(t), m); + if (is(token_kind::equal_tok) || is(token_kind::neq_tok)) { + bool neq = accept(token_kind::neq_tok); + if (!neq) expect(token_kind::equal_tok, "'='"); + expr_ref rhs = parse_term(); + lhs = coerce_eq(lhs, rhs); + expr_ref eq(m.mk_eq(lhs, rhs), m); + return neq ? expr_ref(m.mk_not(eq), m) : eq; + } + return lhs; + } + std::string n = parse_name(); if (n == "$true") return expr_ref(m.mk_true(), m); if (n == "$false") return expr_ref(m.mk_false(), m); @@ -1014,53 +1041,39 @@ class tptp_parser { } expr_ref_vector args(m); - if (accept(token_kind::lparen)) { + // $ite needs special parsing: first arg is formula, rest are formulas (branches can be equalities) + if (n == "$ite") { + expect(token_kind::lparen, "'('"); + args.push_back(parse_formula()); + expect(token_kind::comma, "','"); + args.push_back(parse_formula()); + expect(token_kind::comma, "','"); + args.push_back(parse_formula()); + expect(token_kind::rparen, "')'"); + } + else if (accept(token_kind::lparen)) { if (!accept(token_kind::rparen)) { do { args.push_back(parse_term()); } while (accept(token_kind::comma)); expect(token_kind::rparen, "')'"); } } - if (n == "$less" || n == "$lesseq" || n == "$greater" || n == "$greatereq") { - if (args.size() != 2) { - std::ostringstream out; - out << "arithmetic predicate '" << n << "' expects arity 2"; - throw parse_error(out.str()); + // Table-driven prefix operator dispatch + auto op_it = m_ops.find(n); + if (op_it != m_ops.end() && !op_it->second.is_infix) { + expr_ref result = op_it->second.builder(args); + // If result is non-Bool, check for trailing = or != + if (!m.is_bool(result)) { + if (is(token_kind::equal_tok) || is(token_kind::neq_tok)) { + bool neq = accept(token_kind::neq_tok); + if (!neq) expect(token_kind::equal_tok, "'='"); + expr_ref rhs = parse_term(); + result = coerce_eq(result, rhs); + expr_ref eq(m.mk_eq(result, rhs), m); + return neq ? expr_ref(m.mk_not(eq), m) : eq; + } } - expr_ref lhs(args.get(0), m), rhs(args.get(1), m); - if (!m_arith.is_int_real(lhs) || !m_arith.is_int_real(rhs)) { - std::ostringstream out; - out << "arithmetic predicate '" << n << "' expects arithmetic arguments"; - throw parse_error(out.str()); - } - bool use_real = m_arith.is_real(lhs) || m_arith.is_real(rhs); - if (use_real) { - if (m_arith.is_int(lhs)) lhs = expr_ref(m_arith.mk_to_real(lhs), m); - if (m_arith.is_int(rhs)) rhs = expr_ref(m_arith.mk_to_real(rhs), m); - } - if (n == "$less") return expr_ref(m_arith.mk_lt(lhs, rhs), m); - if (n == "$lesseq") return expr_ref(m_arith.mk_le(lhs, rhs), m); - if (n == "$greater") return expr_ref(m_arith.mk_gt(lhs, rhs), m); - /* n == "$greatereq"*/ return expr_ref(m_arith.mk_ge(lhs, rhs), m); - } - - // Arithmetic terms that may appear before = or != in formula position - if (n == "$sum" || n == "$difference" || n == "$product" || - n == "$quotient" || n == "$quotient_e" || n == "$quotient_t" || n == "$quotient_f" || - n == "$remainder_e" || n == "$remainder_t" || n == "$remainder_f" || - n == "$uminus" || n == "$floor" || n == "$ceiling" || n == "$truncate" || - n == "$round" || n == "$to_int" || n == "$to_rat" || n == "$to_real") { - // Re-parse via the term path by constructing the arithmetic expression - expr_ref arith_expr = mk_arith_expr(n, args); - if (is(token_kind::equal_tok) || is(token_kind::neq_tok)) { - bool neq = accept(token_kind::neq_tok); - if (!neq) expect(token_kind::equal_tok, "'='"); - expr_ref rhs = parse_term(); - arith_expr = coerce_eq(arith_expr, rhs); - expr_ref eq(m.mk_eq(arith_expr, rhs), m); - return neq ? expr_ref(m.mk_not(eq), m) : eq; - } - return arith_expr; + return result; } expr_ref lhs(m); @@ -1080,6 +1093,7 @@ class tptp_parser { if (is(token_kind::equal_tok) || is(token_kind::neq_tok)) { if (!has_lhs) { func_decl* f = mk_decl_or_ho_const(n, args.size(), false); + if (!args.empty()) coerce_args(f, args); lhs = args.empty() ? m.mk_const(f) : m.mk_app(f, args.size(), args.data()); } bool neq = accept(token_kind::neq_tok); @@ -1105,7 +1119,8 @@ class tptp_parser { auto typed = m_typed_decls.find(mk_typed_key(n, args.size())); if (typed != m_typed_decls.end()) { - func_decl* f = mk_decl(n, args.size(), false); + func_decl* f = args.empty() ? mk_decl_or_ho_const(n, 0, false) : mk_decl(n, args.size(), false); + if (!args.empty()) coerce_args(f, args); expr_ref e(args.empty() ? m.mk_const(f) : m.mk_app(f, args.size(), args.data()), m); e = apply_at(e); if (is(token_kind::equal_tok) || is(token_kind::neq_tok)) { @@ -1121,15 +1136,14 @@ class tptp_parser { return e; } - func_decl* pred = mk_decl(n, args.size(), true); + func_decl* pred = mk_decl_or_ho_const(n, args.size(), true); + if (!args.empty()) coerce_args(pred, args); expr_ref result(args.empty() ? m.mk_const(pred) : m.mk_app(pred, args.size(), args.data()), m); return apply_at(result); } // Parse THF lambda expression: ^ [X: T, ...] : body - // Since Z3 doesn't support lambdas natively, we approximate: - // - Parse the bound variables and body - // - Return the body with variables universally quantified (over-approximation) + // Uses Z3's native lambda construct, which produces array terms. expr_ref parse_lambda_expr() { expect(token_kind::lbrack, "'['"); ptr_vector vars; @@ -1157,33 +1171,19 @@ class tptp_parser { m_bound.push_back(scope); expr_ref body = parse_formula(); m_bound.pop_back(); - // For first-order approximation, create a fresh constant with the lambda's function type sort - std::ostringstream oss; - oss << "$lambda_" << m_lambda_counter++; - std::string lname = oss.str(); - // Lambda type: param sorts → body sort - // If the body sort is itself a ho_sort, flatten it into the lambda's type - // e.g., ^ [X: mu] : (body with sort (U > Bool)) → sort (mu * U > Bool) - std::vector param_sorts; - for (auto* v : vars) param_sorts.push_back(v->get_sort()); - sort* body_sort = body->get_sort(); - sort* lambda_sort; - if (param_sorts.empty()) { - lambda_sort = body_sort; - } else { - auto body_ho = m_ho_sort_info.find(body_sort); - if (body_ho != m_ho_sort_info.end()) { - // Flatten: params + body's domain → body's range - std::vector full_domain = param_sorts; - full_domain.insert(full_domain.end(), body_ho->second.first.begin(), body_ho->second.first.end()); - lambda_sort = get_ho_sort(full_domain, body_ho->second.second); - } else { - lambda_sort = get_ho_sort(param_sorts, body_sort); - } + if (vars.empty()) + return body; + // Use expr_abstract to replace named constants with de Bruijn indices, + // then wrap with mk_lambda. + expr_ref abs_body(m); + expr_abstract(m, 0, vars.size(), (expr* const*)vars.data(), body, abs_body); + ptr_vector sorts; + svector names; + for (unsigned i = 0; i < vars.size(); ++i) { + sorts.push_back(vars[i]->get_sort()); + names.push_back(vars[i]->get_decl()->get_name()); } - func_decl* f = m.mk_func_decl(symbol(lname), 0, static_cast(nullptr), lambda_sort); - m_pinned_decls.push_back(f); - return expr_ref(m.mk_const(f), m); + return expr_ref(m.mk_lambda(sorts.size(), sorts.data(), names.data(), abs_body), m); } expr_ref parse_unary_formula() { @@ -1218,6 +1218,12 @@ class tptp_parser { s = t.range; } } + // Monomorphize: $tType-sorted variables become U-sorted + // and register them as sorts for subsequent type references + if (is_ttype(s)) { + s = m_univ; + m_sorts.insert_or_assign(v, m_univ); + } app* c = m.mk_const(symbol(v), s); m_pinned_exprs.push_back(c); vars.push_back(c); @@ -1233,42 +1239,42 @@ class tptp_parser { return mk_quantifier(is_forall, vars, body); } + // Type quantification in formula context: !>[A: $tType, ...] : body + // Erase type variables and parse body as formula + if (is(token_kind::type_forall_tok) || is(token_kind::type_exists_tok)) { + next(); + expect(token_kind::lbrack, "'['"); + if (!accept(token_kind::rbrack)) { + do { + std::string tv = parse_name(); + if (accept(token_kind::colon)) + parse_type_expr(); // consume $tType annotation + m_sorts.insert_or_assign(tv, m_univ); + } while (accept(token_kind::comma)); + expect(token_kind::rbrack, "']'"); + } + expect(token_kind::colon, "':'"); + return parse_formula(); + } + return parse_atomic_formula(); } - expr_ref parse_and_formula() { + expr_ref parse_expr(unsigned min_prec) { expr_ref e = parse_unary_formula(); - while (is(token_kind::and_tok) || is(token_kind::nand_tok)) { - bool is_nand = accept(token_kind::nand_tok); - if (!is_nand) expect(token_kind::and_tok, "'&'"); - expr_ref rhs = parse_unary_formula(); - expr_ref conj(::mk_and(m, e, rhs), m); - e = is_nand ? expr_ref(m.mk_not(conj), m) : conj; - } - return e; - } - - expr_ref parse_or_formula() { - expr_ref e = parse_and_formula(); - while (is(token_kind::or_tok) || is(token_kind::nor_tok)) { - bool is_nor = accept(token_kind::nor_tok); - if (!is_nor) expect(token_kind::or_tok, "'|'"); - expr_ref rhs = parse_and_formula(); - expr_ref disj(::mk_or(m, e, rhs), m); - e = is_nor ? expr_ref(m.mk_not(disj), m) : disj; - } - return e; - } - - expr_ref parse_implies_formula() { - expr_ref e = parse_or_formula(); - if (accept(token_kind::implies_tok)) { - expr_ref rhs = parse_implies_formula(); - return expr_ref(m.mk_implies(e, rhs), m); - } - if (accept(token_kind::implied_tok)) { - expr_ref rhs = parse_implies_formula(); - return expr_ref(m.mk_implies(rhs, e), m); + for (;;) { + char const* op_name = token_to_op_name(); + if (!op_name) break; + auto it = m_ops.find(op_name); + if (it == m_ops.end() || !it->second.is_infix) break; + if (it->second.precedence < min_prec) break; + next(); // consume the operator token + unsigned next_prec = it->second.right_assoc ? it->second.precedence : it->second.precedence + 1; + expr_ref rhs = parse_expr(next_prec); + expr_ref_vector args(m); + args.push_back(e); + args.push_back(rhs); + e = it->second.builder(args); } return e; } @@ -1283,10 +1289,17 @@ class tptp_parser { expect(token_kind::rparen, "')'"); if (t.domain.empty() && is_ttype(t.range)) { - m_sorts.insert_or_assign(name, m.mk_uninterpreted_sort(symbol(name))); + // Sort declaration: monomorphize to m_univ + m_sorts.insert_or_assign(name, m_univ); return; } + // Monomorphize: replace $tType in domain/range with m_univ + for (auto& s : t.domain) { + if (is_ttype(s)) s = m_univ; + } + if (t.range && is_ttype(t.range)) t.range = m_univ; + m_typed_decls.insert_or_assign(mk_typed_key(name, t.domain.size()), std::make_pair(t.domain, t.range)); } @@ -1405,6 +1418,7 @@ public: m_cmd(cmd), m(m_cmd.m()), m_arith(m), + m_array(m), m_pinned_sorts(m), m_pinned_decls(m), m_pinned_exprs(m), @@ -1418,6 +1432,196 @@ public: m_sorts.emplace("$int", m_arith.mk_int()); m_sorts.emplace("$rat", m_arith.mk_real()); m_sorts.emplace("$real", m_arith.mk_real()); + init_op_table(); + } + + void init_op_table() { + // Prefix arithmetic predicates (is_infix=false, precedence=0) + m_ops["$less"] = { false, 0, false, [&](expr_ref_vector const& args) -> expr_ref { + check_arith_arity(args, 2, "$less"); + auto [a, b] = coerce_arith2(args); + return expr_ref(m_arith.mk_lt(a, b), m); + }}; + m_ops["$lesseq"] = { false, 0, false, [&](expr_ref_vector const& args) -> expr_ref { + check_arith_arity(args, 2, "$lesseq"); + auto [a, b] = coerce_arith2(args); + return expr_ref(m_arith.mk_le(a, b), m); + }}; + m_ops["$greater"] = { false, 0, false, [&](expr_ref_vector const& args) -> expr_ref { + check_arith_arity(args, 2, "$greater"); + auto [a, b] = coerce_arith2(args); + return expr_ref(m_arith.mk_gt(a, b), m); + }}; + m_ops["$greatereq"] = { false, 0, false, [&](expr_ref_vector const& args) -> expr_ref { + check_arith_arity(args, 2, "$greatereq"); + auto [a, b] = coerce_arith2(args); + return expr_ref(m_arith.mk_ge(a, b), m); + }}; + m_ops["$uminus"] = { false, 0, false, [&](expr_ref_vector const& args) -> expr_ref { + check_arith_arity(args, 1, "$uminus"); + return expr_ref(m_arith.mk_uminus(args[0]), m); + }}; + m_ops["$sum"] = { false, 0, false, [&](expr_ref_vector const& args) -> expr_ref { + check_arith_arity(args, 2, "$sum"); + auto [a, b] = coerce_arith2(args); + return expr_ref(m_arith.mk_add(a, b), m); + }}; + m_ops["$plus"] = { false, 0, false, [&](expr_ref_vector const& args) -> expr_ref { + check_arith_arity(args, 2, "$plus"); + auto [a, b] = coerce_arith2(args); + return expr_ref(m_arith.mk_add(a, b), m); + }}; + m_ops["$difference"] = { false, 0, false, [&](expr_ref_vector const& args) -> expr_ref { + check_arith_arity(args, 2, "$difference"); + auto [a, b] = coerce_arith2(args); + return expr_ref(m_arith.mk_sub(a, b), m); + }}; + m_ops["$product"] = { false, 0, false, [&](expr_ref_vector const& args) -> expr_ref { + check_arith_arity(args, 2, "$product"); + auto [a, b] = coerce_arith2(args); + return expr_ref(m_arith.mk_mul(a, b), m); + }}; + m_ops["$quotient"] = { false, 0, false, [&](expr_ref_vector const& args) -> expr_ref { + check_arith_arity(args, 2, "$quotient"); + return mk_quotient(args); + }}; + m_ops["$quotient_e"] = { false, 0, false, [&](expr_ref_vector const& args) -> expr_ref { + check_arith_arity(args, 2, "$quotient_e"); + return mk_quotient(args); + }}; + m_ops["$quotient_t"] = { false, 0, false, [&](expr_ref_vector const& args) -> expr_ref { + check_arith_arity(args, 2, "$quotient_t"); + return mk_quotient(args); + }}; + m_ops["$quotient_f"] = { false, 0, false, [&](expr_ref_vector const& args) -> expr_ref { + check_arith_arity(args, 2, "$quotient_f"); + return mk_quotient(args); + }}; + m_ops["$remainder_e"] = { false, 0, false, [&](expr_ref_vector const& args) -> expr_ref { + check_arith_arity(args, 2, "$remainder_e"); + return expr_ref(m_arith.mk_mod(args[0], args[1]), m); + }}; + m_ops["$remainder_t"] = { false, 0, false, [&](expr_ref_vector const& args) -> expr_ref { + check_arith_arity(args, 2, "$remainder_t"); + return expr_ref(m_arith.mk_mod(args[0], args[1]), m); + }}; + m_ops["$remainder_f"] = { false, 0, false, [&](expr_ref_vector const& args) -> expr_ref { + check_arith_arity(args, 2, "$remainder_f"); + return expr_ref(m_arith.mk_mod(args[0], args[1]), m); + }}; + m_ops["$floor"] = { false, 0, false, [&](expr_ref_vector const& args) -> expr_ref { + check_arith_arity(args, 1, "$floor"); + expr_ref a(args[0], m); + if (m_arith.is_int(a)) return a; + return expr_ref(m_arith.mk_to_int(a), m); + }}; + m_ops["$ceiling"] = { false, 0, false, [&](expr_ref_vector const& args) -> expr_ref { + check_arith_arity(args, 1, "$ceiling"); + expr_ref a(args[0], m); + if (m_arith.is_int(a)) return a; + // ceiling(x) = -floor(-x) + return expr_ref(m_arith.mk_uminus(m_arith.mk_to_int(m_arith.mk_uminus(a))), m); + }}; + m_ops["$truncate"] = { false, 0, false, [&](expr_ref_vector const& args) -> expr_ref { + check_arith_arity(args, 1, "$truncate"); + expr_ref a(args[0], m); + if (m_arith.is_int(a)) return a; + // truncate(x) = if x >= 0 then floor(x) else ceiling(x) + expr_ref zero(m_arith.mk_real(0), m); + expr_ref fl(m_arith.mk_to_int(a), m); + expr_ref neg_fl(m_arith.mk_uminus(m_arith.mk_to_int(m_arith.mk_uminus(a))), m); + return expr_ref(m.mk_ite(m_arith.mk_ge(a, zero), fl, neg_fl), m); + }}; + m_ops["$round"] = { false, 0, false, [&](expr_ref_vector const& args) -> expr_ref { + check_arith_arity(args, 1, "$round"); + expr_ref a(args[0], m); + if (m_arith.is_int(a)) return a; + // round to nearest even + expr_ref i(m_arith.mk_to_int(a), m); + expr_ref half(m_arith.mk_add(m_arith.mk_to_real(i), m_arith.mk_numeral(rational(1, 2), false)), m); + expr_ref i1(m_arith.mk_add(i, m_arith.mk_int(1)), m); + expr_ref is_even(m.mk_eq(m_arith.mk_mod(i, m_arith.mk_int(2)), m_arith.mk_int(0)), m); + return expr_ref(m.mk_ite(m_arith.mk_gt(a, half), i1, + m.mk_ite(m.mk_eq(a, half), m.mk_ite(is_even, i, i1), i)), m); + }}; + m_ops["$to_int"] = { false, 0, false, [&](expr_ref_vector const& args) -> expr_ref { + check_arith_arity(args, 1, "$to_int"); + expr_ref a(args[0], m); + if (m_arith.is_int(a)) return a; + return expr_ref(m_arith.mk_to_int(a), m); + }}; + m_ops["$to_real"] = { false, 0, false, [&](expr_ref_vector const& args) -> expr_ref { + check_arith_arity(args, 1, "$to_real"); + expr_ref a(args[0], m); + if (m_arith.is_real(a)) return a; + return expr_ref(m_arith.mk_to_real(a), m); + }}; + m_ops["$to_rat"] = { false, 0, false, [&](expr_ref_vector const& args) -> expr_ref { + check_arith_arity(args, 1, "$to_rat"); + expr_ref a(args[0], m); + if (m_arith.is_real(a)) return a; + return expr_ref(m_arith.mk_to_real(a), m); + }}; + m_ops["$is_int"] = { false, 0, false, [&](expr_ref_vector const& args) -> expr_ref { + check_arith_arity(args, 1, "$is_int"); + return expr_ref(m_arith.mk_is_int(args[0]), m); + }}; + m_ops["$is_rat"] = { false, 0, false, [&](expr_ref_vector const& args) -> expr_ref { + check_arith_arity(args, 1, "$is_rat"); + expr_ref a(args[0], m); + return mk_is_rat(a); + }}; + m_ops["$distinct"] = { false, 0, false, [&](expr_ref_vector const& args) -> expr_ref { + if (args.size() == 2) return expr_ref(m.mk_not(m.mk_eq(args[0], args[1])), m); + return expr_ref(m.mk_distinct(args.size(), args.data()), m); + }}; + m_ops["$ite"] = { false, 0, false, [&](expr_ref_vector const& args) -> expr_ref { + check_arith_arity(args, 3, "$ite"); + expr_ref cond(args[0], m), t(args[1], m), f(args[2], m); + if (!m.is_bool(cond)) + throw parse_error("$ite expects Bool condition as first argument"); + return expr_ref(m.mk_ite(cond, t, f), m); + }}; + m_ops["$abs"] = { false, 0, false, [&](expr_ref_vector const& args) -> expr_ref { + check_arith_arity(args, 1, "$abs"); + expr_ref a(args[0], m); + if (!m_arith.is_int_real(a)) + throw parse_error("$abs expects arithmetic argument"); + expr_ref zero(m_arith.is_int(a) ? m_arith.mk_int(0) : m_arith.mk_numeral(rational(0), false), m); + return expr_ref(m.mk_ite(m_arith.mk_ge(a, zero), a, expr_ref(m_arith.mk_uminus(a), m)), m); + }}; + m_ops["$true"] = { false, 0, false, [&](expr_ref_vector const&) -> expr_ref { + return expr_ref(m.mk_true(), m); + }}; + m_ops["$false"] = { false, 0, false, [&](expr_ref_vector const&) -> expr_ref { + return expr_ref(m.mk_false(), m); + }}; + + // Infix logical operators (token-based, matched by token_to_op_name) + m_ops["<=>"] = { true, PREC_IFF, false, [&](expr_ref_vector const& args) -> expr_ref { + return expr_ref(m.mk_iff(args[0], args[1]), m); + }}; + m_ops["<~>"] = { true, PREC_IFF, false, [&](expr_ref_vector const& args) -> expr_ref { + return expr_ref(m.mk_not(m.mk_iff(args[0], args[1])), m); + }}; + m_ops["=>"] = { true, PREC_IMPLIES, true, [&](expr_ref_vector const& args) -> expr_ref { + return expr_ref(m.mk_implies(args[0], args[1]), m); + }}; + m_ops["<="] = { true, PREC_IMPLIES, false, [&](expr_ref_vector const& args) -> expr_ref { + return expr_ref(m.mk_implies(args[1], args[0]), m); + }}; + m_ops["|"] = { true, PREC_OR, false, [&](expr_ref_vector const& args) -> expr_ref { + return expr_ref(m.mk_or(args[0], args[1]), m); + }}; + m_ops["~|"] = { true, PREC_OR, false, [&](expr_ref_vector const& args) -> expr_ref { + return expr_ref(m.mk_not(m.mk_or(args[0], args[1])), m); + }}; + m_ops["&"] = { true, PREC_AND, false, [&](expr_ref_vector const& args) -> expr_ref { + return expr_ref(m.mk_and(args[0], args[1]), m); + }}; + m_ops["~&"] = { true, PREC_AND, false, [&](expr_ref_vector const& args) -> expr_ref { + return expr_ref(m.mk_not(m.mk_and(args[0], args[1])), m); + }}; } void parse_input(std::istream& in, std::string const& current_file) { @@ -1460,51 +1664,22 @@ public: expr_ref tptp_parser::parse_term() { expr_ref e = parse_term_primary(); if (!is(token_kind::at_tok)) return e; - if (!is_app(e)) - throw parse_error("application operator (@) requires uninterpreted function on left-hand side"); - app* a = to_app(e); - func_decl* orig_d = a->get_decl(); - if (orig_d->get_family_id() != null_family_id) - throw parse_error("application operator (@) requires uninterpreted function on left-hand side"); - expr_ref_vector args(m); - for (unsigned i = 0; i < a->get_num_args(); ++i) args.push_back(a->get_arg(i)); + // @ corresponds to array select (function application) while (accept(token_kind::at_tok)) { expr_ref arg = parse_at_arg(); - args.push_back(arg); - } - unsigned new_arity = args.size(); - std::string name = orig_d->get_name().str(); - func_decl* new_d = mk_decl(name, new_arity, false); - // Verify argument sorts match; if not, create func_decl with actual arg sorts - bool sorts_ok = (new_d->get_arity() == new_arity); - if (sorts_ok) { - for (unsigned i = 0; i < new_arity; ++i) { - if (new_d->get_domain(i) != args.get(i)->get_sort()) { - sorts_ok = false; - break; - } + sort* e_sort = e->get_sort(); + if (!m_array.is_array(e_sort)) { + std::ostringstream out; + out << "application operator (@) requires array/function type on left-hand side, got sort " << e_sort->get_name(); + throw parse_error(out.str()); } + e = expr_ref(m_array.mk_select(e, arg), m); } - if (!sorts_ok) { - std::vector dom; - for (unsigned i = 0; i < new_arity; ++i) dom.push_back(args.get(i)->get_sort()); - sort* range = new_d->get_range(); - new_d = m.mk_func_decl(symbol(name), new_arity, dom.data(), range); - m_pinned_decls.push_back(new_d); - } - return expr_ref(m.mk_app(new_d, new_arity, args.data()), m); + return e; } expr_ref tptp_parser::parse_formula() { - expr_ref e = parse_implies_formula(); - while (is(token_kind::iff_tok) || is(token_kind::xor_tok)) { - bool is_xor = accept(token_kind::xor_tok); - if (!is_xor) expect(token_kind::iff_tok, "'<=>'"); - expr_ref rhs = parse_implies_formula(); - expr_ref iff(m.mk_iff(e, rhs), m); - e = is_xor ? expr_ref(m.mk_not(iff), m) : iff; - } - return e; + return parse_expr(PREC_IFF); } } diff --git a/src/shell/main.cpp b/src/shell/main.cpp index 9a4cd6f27..0703f731f 100644 --- a/src/shell/main.cpp +++ b/src/shell/main.cpp @@ -337,10 +337,35 @@ static void parse_cmd_line_args(std::string& input_file, int argc, char ** argv) } } else if (argv[i][0] != '"' && (eq_pos = strchr(argv[i], '='))) { - char * key = argv[i]; - *eq_pos = 0; - char * value = eq_pos+1; - gparams::set(key, value); + // If the argument looks like a file path (contains path separators + // or has a file extension), treat it as a filename rather than + // a parameter assignment. This handles files with '=' in their names. + bool is_filepath = strchr(argv[i], '/') || strchr(argv[i], '\\'); + if (!is_filepath) { + char const * ext = get_extension(argv[i]); + if (ext && (strcmp(ext, "smt2") == 0 || strcmp(ext, "smt") == 0 || + strcmp(ext, "dimacs") == 0 || strcmp(ext, "cnf") == 0 || + strcmp(ext, "wcnf") == 0 || strcmp(ext, "opb") == 0 || + strcmp(ext, "lp") == 0 || strcmp(ext, "log") == 0 || + strcmp(ext, "drat") == 0 || strcmp(ext, "p") == 0)) + is_filepath = true; + } + if (is_filepath) { + if (get_extension(arg) && strcmp(get_extension(arg), "drat") == 0) { + g_input_kind = IN_DRAT; + g_drat_input_file = arg; + } + else if (g_input_file) + warning_msg("input file was already specified."); + else + g_input_file = arg; + } + else { + char * key = argv[i]; + *eq_pos = 0; + char * value = eq_pos+1; + gparams::set(key, value); + } } else { if (get_extension(arg) && strcmp(get_extension(arg), "drat") == 0) { diff --git a/src/solver/smt_logics.cpp b/src/solver/smt_logics.cpp index a02b90880..b47669c6e 100644 --- a/src/solver/smt_logics.cpp +++ b/src/solver/smt_logics.cpp @@ -50,10 +50,7 @@ bool smt_logics::logic_has_arith(symbol const & s) { str.find("IDL") != std::string::npos || str.find("RDL") != std::string::npos || str == "QF_BVRE" || - str == "QF_FP" || - str == "FP" || - str == "QF_FPBV" || - str == "QF_BVFP" || + logic_has_fpa(s) || str == "QF_S" || logic_is_all(s) || str == "QF_FD" || @@ -102,11 +99,7 @@ bool smt_logics::logic_has_str(symbol const & s) { bool smt_logics::logic_has_fpa(symbol const & s) { auto str = s.str(); - return str == "FP" || - str == "QF_FP" || - str == "QF_FPBV" || - str == "QF_BVFP" || - str == "QF_FPLRA" || + return str.find("FP") != std::string::npos || logic_is_all(s); }