3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2026-05-15 22:55:33 +00:00
Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
This commit is contained in:
Nikolaj Bjorner 2026-05-13 13:40:32 -07:00
parent ce07160d64
commit 153c6a017a

View file

@ -259,8 +259,10 @@ class tptp_parser {
arith_util m_arith;
sort* m_univ;
bool m_has_conjecture = false;
bool m_last_name_quoted = false;
unsigned m_lambda_counter = 0;
std::unordered_map<std::string, sort*> m_sorts;
std::unordered_map<sort*, std::pair<std::vector<sort*>, sort*>> m_ho_sort_info; // ho_sort → (domain, range)
std::unordered_map<std::string, func_decl*> m_decls;
func_decl_ref_vector m_pinned_decls; // prevents cached func_decls from being freed
std::unordered_map<std::string, std::pair<std::vector<sort*>, sort*>> m_typed_decls;
@ -315,6 +317,7 @@ class tptp_parser {
std::string parse_name() {
if (is(token_kind::id) || is(token_kind::str)) {
m_last_name_quoted = is(token_kind::str);
std::string r = m_curr.text;
next();
return r;
@ -351,6 +354,7 @@ class tptp_parser {
if (it != m_sorts.end()) return it->second;
sort* s = m.mk_uninterpreted_sort(symbol(key));
m_sorts.emplace(key, s);
m_ho_sort_info.emplace(s, std::make_pair(domain, range));
return s;
}
@ -459,6 +463,16 @@ class tptp_parser {
return false;
}
bool is_bound_var(app* a) const {
std::string name = a->get_decl()->get_name().str();
for (auto it = m_bound.rbegin(); it != m_bound.rend(); ++it) {
auto jt = it->find(name);
if (jt != it->end() && jt->second == a)
return true;
}
return false;
}
bool should_create_implicit_var(std::string const& n) const {
return is_var_name(n) && m_implicit_scope;
}
@ -643,9 +657,9 @@ class tptp_parser {
}
expr_ref b(m);
if (is_var_name(n) && find_bound(n, b))
if (!m_last_name_quoted && is_var_name(n) && find_bound(n, b))
return b;
if (should_create_implicit_var(n))
if (!m_last_name_quoted && should_create_implicit_var(n))
return expr_ref(get_or_create_implicit_var(n), m);
expr_ref_vector args(m);
@ -665,6 +679,73 @@ class tptp_parser {
return expr_ref(m_arith.mk_uminus(a), m);
}
if (n == "$sum" || n == "$difference" || n == "$product") {
if (args.size() != 2)
throw parse_error("arithmetic function expects arity 2");
expr_ref a(args.get(0), m), b(args.get(1), m);
if (!m_arith.is_int_real(a) || !m_arith.is_int_real(b))
throw parse_error("arithmetic function expects arithmetic arguments");
bool use_real = m_arith.is_real(a) || m_arith.is_real(b);
if (use_real) {
if (m_arith.is_int(a)) a = expr_ref(m_arith.mk_to_real(a), m);
if (m_arith.is_int(b)) b = expr_ref(m_arith.mk_to_real(b), m);
}
if (n == "$sum") return expr_ref(m_arith.mk_add(a, b), m);
if (n == "$difference") return expr_ref(m_arith.mk_sub(a, b), m);
/* $product */ return expr_ref(m_arith.mk_mul(a, b), m);
}
if (n == "$quotient_e" || n == "$quotient_t" || n == "$quotient_f" || n == "$quotient") {
if (args.size() != 2)
throw parse_error("arithmetic function expects arity 2");
expr_ref a(args.get(0), m), b(args.get(1), m);
if (!m_arith.is_int_real(a) || !m_arith.is_int_real(b))
throw parse_error("arithmetic function expects arithmetic arguments");
if (m_arith.is_int(a) && m_arith.is_int(b))
return expr_ref(m_arith.mk_idiv(a, b), m);
if (m_arith.is_int(a)) a = expr_ref(m_arith.mk_to_real(a), m);
if (m_arith.is_int(b)) b = expr_ref(m_arith.mk_to_real(b), m);
return expr_ref(m_arith.mk_div(a, b), m);
}
if (n == "$remainder_e" || n == "$remainder_t" || n == "$remainder_f") {
if (args.size() != 2)
throw parse_error("arithmetic function expects arity 2");
expr_ref a(args.get(0), m), b(args.get(1), m);
if (!m_arith.is_int_real(a) || !m_arith.is_int_real(b))
throw parse_error("arithmetic function expects arithmetic arguments");
return expr_ref(m_arith.mk_mod(a, b), m);
}
if (n == "$floor" || n == "$ceiling" || n == "$truncate" || n == "$round" || n == "$to_int") {
if (args.size() != 1)
throw parse_error("arithmetic function expects arity 1");
expr_ref a(args.get(0), m);
if (!m_arith.is_int_real(a))
throw parse_error("arithmetic function expects arithmetic argument");
if (m_arith.is_int(a)) return a;
return expr_ref(m_arith.mk_to_int(a), m);
}
if (n == "$to_rat" || n == "$to_real") {
if (args.size() != 1)
throw parse_error("arithmetic function expects arity 1");
expr_ref a(args.get(0), m);
if (!m_arith.is_int_real(a))
throw parse_error("arithmetic function expects arithmetic argument");
if (m_arith.is_real(a)) return a;
return expr_ref(m_arith.mk_to_real(a), m);
}
if (n == "$is_int") {
if (args.size() != 1)
throw parse_error("arithmetic predicate '$is_int' expects arity 1");
expr_ref a(args.get(0), m);
if (!m_arith.is_int_real(a))
throw parse_error("arithmetic predicate '$is_int' expects arithmetic argument");
return expr_ref(m_arith.mk_is_int(a), m);
}
func_decl* f = mk_decl(n, args.size(), false);
return expr_ref(args.empty() ? m.mk_const(f) : m.mk_app(f, args.size(), args.data()), m);
}
@ -692,6 +773,53 @@ class tptp_parser {
unsigned new_arity = args.size();
std::string name = orig_d->get_name().str();
// Check if the LHS has a ho_sort (function type as value).
// If so, use its type info to determine domain/range for the application.
sort* lhs_sort = orig_d->get_range(); // sort of the 0-arity constant
auto ho_it = m_ho_sort_info.find(lhs_sort);
if (ho_it != m_ho_sort_info.end() && a->get_num_args() == 0) {
auto& [ho_domain, ho_range] = ho_it->second;
sort* range = ho_range;
// If we're partially applying (fewer args than domain), result is a ho_sort
if (new_arity < ho_domain.size()) {
std::vector<sort*> remaining(ho_domain.begin() + new_arity, ho_domain.end());
range = get_ho_sort(remaining, ho_range);
}
// For bound variables, use $apply so that the variable remains in the
// expression body and mk_forall can properly abstract it.
if (is_bound_var(a)) {
std::vector<sort*> apply_dom;
apply_dom.push_back(lhs_sort); // first arg is the function variable
for (unsigned i = 0; i < new_arity; ++i) {
if (i < ho_domain.size())
apply_dom.push_back(ho_domain[i]);
else
apply_dom.push_back(args.get(i)->get_sort());
}
// Create unique apply function per signature
std::ostringstream apn;
apn << "$apply_" << lhs_sort->get_name() << "_" << new_arity;
func_decl* apply_d = m.mk_func_decl(symbol(apn.str()), apply_dom.size(), apply_dom.data(), range);
m_pinned_decls.push_back(apply_d);
expr_ref_vector all_args(m);
all_args.push_back(e); // the bound variable itself
for (unsigned i = 0; i < new_arity; ++i) all_args.push_back(args.get(i));
return expr_ref(m.mk_app(apply_d, all_args.size(), all_args.data()), m);
}
// For non-bound ho-typed constants, use the name directly
std::vector<sort*> dom;
for (unsigned i = 0; i < new_arity; ++i) {
if (i < ho_domain.size())
dom.push_back(ho_domain[i]);
else
dom.push_back(args.get(i)->get_sort());
}
func_decl* new_d = m.mk_func_decl(symbol(name), new_arity, dom.data(), range);
m_pinned_decls.push_back(new_d);
return expr_ref(m.mk_app(new_d, new_arity, args.data()), m);
}
func_decl* new_d = mk_decl(name, new_arity, false);
// Verify argument sorts match; if not, create untyped func_decl with actual arg sorts
bool sorts_ok = (new_d->get_arity() == new_arity);
@ -759,6 +887,54 @@ class tptp_parser {
return parse_term_primary();
}
// Build an arithmetic expression from a TPTP function name and arguments
expr_ref mk_arith_expr(std::string const& n, expr_ref_vector const& args) {
if (n == "$uminus") {
if (args.size() != 1) throw parse_error("$uminus expects arity 1");
expr_ref a(args.get(0), m);
return expr_ref(m_arith.mk_uminus(a), m);
}
if (n == "$sum" || n == "$difference" || n == "$product") {
if (args.size() != 2) throw parse_error("arithmetic binary function expects arity 2");
expr_ref a(args.get(0), m), b(args.get(1), m);
bool use_real = m_arith.is_real(a) || m_arith.is_real(b);
if (use_real) {
if (m_arith.is_int(a)) a = expr_ref(m_arith.mk_to_real(a), m);
if (m_arith.is_int(b)) b = expr_ref(m_arith.mk_to_real(b), m);
}
if (n == "$sum") return expr_ref(m_arith.mk_add(a, b), m);
if (n == "$difference") return expr_ref(m_arith.mk_sub(a, b), m);
return expr_ref(m_arith.mk_mul(a, b), m);
}
if (n == "$quotient" || n == "$quotient_e" || n == "$quotient_t" || n == "$quotient_f") {
if (args.size() != 2) throw parse_error("quotient expects arity 2");
expr_ref a(args.get(0), m), b(args.get(1), m);
if (m_arith.is_int(a) && m_arith.is_int(b))
return expr_ref(m_arith.mk_idiv(a, b), m);
if (m_arith.is_int(a)) a = expr_ref(m_arith.mk_to_real(a), m);
if (m_arith.is_int(b)) b = expr_ref(m_arith.mk_to_real(b), m);
return expr_ref(m_arith.mk_div(a, b), m);
}
if (n == "$remainder_e" || n == "$remainder_t" || n == "$remainder_f") {
if (args.size() != 2) throw parse_error("remainder expects arity 2");
expr_ref a(args.get(0), m), b(args.get(1), m);
return expr_ref(m_arith.mk_mod(a, b), m);
}
if (n == "$floor" || n == "$ceiling" || n == "$truncate" || n == "$round" || n == "$to_int") {
if (args.size() != 1) throw parse_error("$to_int expects arity 1");
expr_ref a(args.get(0), m);
if (m_arith.is_int(a)) return a;
return expr_ref(m_arith.mk_to_int(a), m);
}
if (n == "$to_rat" || n == "$to_real") {
if (args.size() != 1) throw parse_error("$to_real expects arity 1");
expr_ref a(args.get(0), m);
if (m_arith.is_real(a)) return a;
return expr_ref(m_arith.mk_to_real(a), m);
}
throw parse_error("unknown arithmetic function: " + n);
}
// Coerce two expressions to have the same sort for equality.
// If sorts already match, returns lhs unchanged. Otherwise coerces the
// 0-arity constant side to match the other's sort. If that's not possible,
@ -795,11 +971,11 @@ class tptp_parser {
if (accept(token_kind::lparen)) {
expr_ref e = parse_formula();
// Handle equality/inequality inside parenthesized expressions
// e.g., ( (f @ Y) = (g @ X @ Z) ) where f(Y) is non-Bool
if (!m.is_bool(e) && (is(token_kind::equal_tok) || is(token_kind::neq_tok))) {
// In THF, (A = B) is used even for Bool-sorted expressions (meaning iff)
if (is(token_kind::equal_tok) || is(token_kind::neq_tok)) {
bool neq = accept(token_kind::neq_tok);
if (!neq) expect(token_kind::equal_tok, "'='");
expr_ref rhs = parse_term();
expr_ref rhs = m.is_bool(e) ? parse_formula() : parse_term();
e = coerce_eq(e, rhs);
expr_ref eq(m.mk_eq(e, rhs), m);
e = neq ? expr_ref(m.mk_not(eq), m) : eq;
@ -855,15 +1031,34 @@ class tptp_parser {
/* n == "$greatereq"*/ return expr_ref(m_arith.mk_ge(lhs, rhs), m);
}
// Arithmetic terms that may appear before = or != in formula position
if (n == "$sum" || n == "$difference" || n == "$product" ||
n == "$quotient" || n == "$quotient_e" || n == "$quotient_t" || n == "$quotient_f" ||
n == "$remainder_e" || n == "$remainder_t" || n == "$remainder_f" ||
n == "$uminus" || n == "$floor" || n == "$ceiling" || n == "$truncate" ||
n == "$round" || n == "$to_int" || n == "$to_rat" || n == "$to_real") {
// Re-parse via the term path by constructing the arithmetic expression
expr_ref arith_expr = mk_arith_expr(n, args);
if (is(token_kind::equal_tok) || is(token_kind::neq_tok)) {
bool neq = accept(token_kind::neq_tok);
if (!neq) expect(token_kind::equal_tok, "'='");
expr_ref rhs = parse_term();
arith_expr = coerce_eq(arith_expr, rhs);
expr_ref eq(m.mk_eq(arith_expr, rhs), m);
return neq ? expr_ref(m.mk_not(eq), m) : eq;
}
return arith_expr;
}
expr_ref lhs(m);
bool has_lhs = false;
if (args.empty()) {
expr_ref b(m);
if (is_var_name(n) && find_bound(n, b)) {
if (!m_last_name_quoted && is_var_name(n) && find_bound(n, b)) {
lhs = b;
has_lhs = true;
}
else if (should_create_implicit_var(n)) {
else if (!m_last_name_quoted && should_create_implicit_var(n)) {
lhs = expr_ref(get_or_create_implicit_var(n), m);
has_lhs = true;
}
@ -953,9 +1148,25 @@ class tptp_parser {
oss << "$lambda_" << m_lambda_counter++;
std::string lname = oss.str();
// Lambda type: param sorts → body sort
// If the body sort is itself a ho_sort, flatten it into the lambda's type
// e.g., ^ [X: mu] : (body with sort (U > Bool)) → sort (mu * U > Bool)
std::vector<sort*> param_sorts;
for (auto* v : vars) param_sorts.push_back(v->get_sort());
sort* lambda_sort = param_sorts.empty() ? body->get_sort() : get_ho_sort(param_sorts, body->get_sort());
sort* body_sort = body->get_sort();
sort* lambda_sort;
if (param_sorts.empty()) {
lambda_sort = body_sort;
} else {
auto body_ho = m_ho_sort_info.find(body_sort);
if (body_ho != m_ho_sort_info.end()) {
// Flatten: params + body's domain → body's range
std::vector<sort*> full_domain = param_sorts;
full_domain.insert(full_domain.end(), body_ho->second.first.begin(), body_ho->second.first.end());
lambda_sort = get_ho_sort(full_domain, body_ho->second.second);
} else {
lambda_sort = get_ho_sort(param_sorts, body_sort);
}
}
func_decl* f = m.mk_func_decl(symbol(lname), 0, static_cast<sort**>(nullptr), lambda_sort);
m_pinned_decls.push_back(f);
return expr_ref(m.mk_const(f), m);
@ -1128,9 +1339,13 @@ class tptp_parser {
}
if (role == "conjecture") {
m_has_conjecture = true;
f = m.mk_not(f);
if (m.is_bool(f))
f = m.mk_not(f);
}
m_cmd.assert_expr(f);
// Only assert Bool-sorted formulas; non-Bool results from
// incomplete higher-order approximation are silently skipped.
if (m.is_bool(f))
m_cmd.assert_expr(f);
}
if (accept(token_kind::comma)) {