diff --git a/src/api/api_seq.cpp b/src/api/api_seq.cpp index 9dfac4a74..a79a63b4d 100644 --- a/src/api/api_seq.cpp +++ b/src/api/api_seq.cpp @@ -55,7 +55,6 @@ extern "C" { Z3_CATCH_RETURN(nullptr); } - Z3_ast Z3_API Z3_mk_lstring(Z3_context c, unsigned sz, Z3_string str) { Z3_TRY; LOG_Z3_mk_lstring(c, sz, str); @@ -80,6 +79,16 @@ extern "C" { Z3_CATCH_RETURN(nullptr); } + Z3_ast Z3_API Z3_mk_char(Z3_context c, unsigned ch) { + Z3_TRY; + LOG_Z3_mk_char(c, ch); + RESET_ERROR_CODE(); + app* a = mk_c(c)->sutil().str.mk_char(ch); + mk_c(c)->save_ast_trail(a); + RETURN_Z3(of_ast(a)); + Z3_CATCH_RETURN(nullptr); + } + Z3_sort Z3_API Z3_mk_string_sort(Z3_context c) { Z3_TRY; LOG_Z3_mk_string_sort(c); diff --git a/src/api/python/z3/z3.py b/src/api/python/z3/z3.py index 2b0286608..6cea34963 100644 --- a/src/api/python/z3/z3.py +++ b/src/api/python/z3/z3.py @@ -1150,6 +1150,8 @@ def _to_expr_ref(a, ctx): return FPRMRef(a, ctx) if sk == Z3_SEQ_SORT: return SeqRef(a, ctx) + if sk == Z3_CHAR_SORT: + return CharRef(a, ctx) if sk == Z3_RE_SORT: return ReRef(a, ctx) return ExprRef(a, ctx) @@ -10584,7 +10586,6 @@ class CharSortRef(SortRef): """Character sort.""" - def StringSort(ctx=None): """Create a string sort >>> s = StringSort() @@ -10650,18 +10651,67 @@ class SeqRef(ExprRef): return Z3_ast_to_string(self.ctx_ref(), self.as_ast()) def __le__(self, other): - return SeqRef(Z3_mk_str_le(self.ctx_ref(), self.as_ast(), other.as_ast()), self.ctx) + return _to_expr_ref(Z3_mk_str_le(self.ctx_ref(), self.as_ast(), other.as_ast()), self.ctx) def __lt__(self, other): - return SeqRef(Z3_mk_str_lt(self.ctx_ref(), self.as_ast(), other.as_ast()), self.ctx) + return _to_expr_ref(Z3_mk_str_lt(self.ctx_ref(), self.as_ast(), other.as_ast()), self.ctx) def __ge__(self, other): - return SeqRef(Z3_mk_str_le(self.ctx_ref(), other.as_ast(), self.as_ast()), self.ctx) + return _to_expr_ref(Z3_mk_str_le(self.ctx_ref(), other.as_ast(), self.as_ast()), self.ctx) def __gt__(self, other): - return SeqRef(Z3_mk_str_lt(self.ctx_ref(), other.as_ast(), self.as_ast()), self.ctx) + return _to_expr_ref(Z3_mk_str_lt(self.ctx_ref(), other.as_ast(), self.as_ast()), self.ctx) +def _coerce_char(ch, ctx=None): + if isinstance(ch, str): + ctx = _get_ctx(ctx) + ch = CharVal(ch, ctx) + if not is_expr(ch): + raise Z3Exception("Character expression expected") + return ch + +class CharRef(ExprRef): + """Character expression.""" + + def __le__(self, other): + other = _coerce_char(other, self.ctx) + return _to_expr_ref(Z3_mk_char_le(self.ctx_ref(), self.as_ast(), other.as_ast()), self.ctx) + + def to_int(self): + return _to_expr_ref(Z3_mk_char_to_int(self.ctx_ref(), self.as_ast()), self.ctx) + + def to_bv(self): + return _to_expr_ref(Z3_mk_char_to_bv(self.ctx_ref(), self.as_ast()), self.ctx) + + def is_digit(self): + return _to_expr_ref(Z3_mk_char_is_digit(self.ctx_ref(), self.as_ast()), self.ctx) + + +def CharVal(ch, ctx=None): + ctx = _get_ctx(ctx) + if isinstance(ch, str): + ch = ord(ch) + if not isinstance(ch, int): + raise Z3Exception("character value should be an ordinal") + return _to_expr_ref(Z3_mk_char(ctx.ref(), ch), ctx) + +def CharFromBv(ch, ctx=None): + ch = _coerce_char(ch, ctx) + return _to_expr_ref(Z3_mk_char_from_bv(ch.ctx_ref(), ch.as_ast()), ch.ctx) + +def CharToBv(ch, ctx=None): + ch = _coerce_char(ch, ctx) + return ch.to_bv() + +def CharToInt(ch, ctx=None): + ch = _coerce_char(ch, ctx) + return ch.to_int() + +def CharIsDigit(ch, ctx=None): + ch = _coerce_char(ch, ctx) + return ch.is_digit() + def _coerce_seq(s, ctx=None): if isinstance(s, str): ctx = _get_ctx(ctx) @@ -10778,6 +10828,7 @@ def Full(s): raise Z3Exception("Non-sequence, non-regular expression sort passed to Full") + def Unit(a): """Create a singleton sequence""" return SeqRef(Z3_mk_seq_unit(a.ctx_ref(), a.as_ast()), a.ctx) diff --git a/src/api/z3_api.h b/src/api/z3_api.h index 3f964adc6..16b22a991 100644 --- a/src/api/z3_api.h +++ b/src/api/z3_api.h @@ -3843,6 +3843,13 @@ extern "C" { */ Z3_ast Z3_API Z3_mk_re_full(Z3_context c, Z3_sort re); + + /** + \brief Create a character literal + def_API('Z3_mk_char', AST, (_in(CONTEXT), _in(UINT))) + */ + Z3_ast Z3_API Z3_mk_char(Z3_context c, unsigned ch); + /** \brief Create less than or equal to between two characters.