diff --git a/mypy/typeshed/stdlib/typing.pyi b/mypy/typeshed/stdlib/typing.pyi index aafb1fbdf1b3..0995c9412726 100644 --- a/mypy/typeshed/stdlib/typing.pyi +++ b/mypy/typeshed/stdlib/typing.pyi @@ -2,6 +2,7 @@ import collections # Needed by aliases like DefaultDict, see mypy issue 2986 import sys from abc import ABCMeta, abstractmethod from types import BuiltinFunctionType, CodeType, FrameType, FunctionType, MethodType, ModuleType, TracebackType +from typing_extensions import Literal as _Literal if sys.version_info >= (3, 7): from types import MethodDescriptorType, MethodWrapperType, WrapperDescriptorType @@ -563,19 +564,35 @@ class Match(Generic[AnyStr]): # this match instance. re: Pattern[AnyStr] def expand(self, template: AnyStr) -> AnyStr: ... - # TODO: The return for a group may be None, except if __group is 0 or not given. + # group() returns "AnyStr" or "AnyStr | None", depending on the pattern. @overload - def group(self, __group: Union[str, int] = ...) -> AnyStr: ... + def group(self, __group: _Literal[0] = ...) -> AnyStr: ... @overload - def group(self, __group1: Union[str, int], __group2: Union[str, int], *groups: Union[str, int]) -> Tuple[AnyStr, ...]: ... - def groups(self, default: AnyStr = ...) -> Sequence[AnyStr]: ... - def groupdict(self, default: AnyStr = ...) -> dict[str, AnyStr]: ... + def group(self, __group: str | int) -> AnyStr | Any: ... + @overload + def group(self, __group1: str | int, __group2: str | int, *groups: str | int) -> Tuple[AnyStr | Any, ...]: ... + # Each item of groups()'s return tuple is either "AnyStr" or + # "AnyStr | None", depending on the pattern. + @overload + def groups(self) -> Tuple[AnyStr | Any, ...]: ... + @overload + def groups(self, default: _T) -> Tuple[AnyStr | _T, ...]: ... + # Each value in groupdict()'s return dict is either "AnyStr" or + # "AnyStr | None", depending on the pattern. + @overload + def groupdict(self) -> dict[str, AnyStr | Any]: ... 
+ @overload + def groupdict(self, default: _T) -> dict[str, AnyStr | _T]: ... def start(self, __group: Union[int, str] = ...) -> int: ... def end(self, __group: Union[int, str] = ...) -> int: ... def span(self, __group: Union[int, str] = ...) -> Tuple[int, int]: ... @property def regs(self) -> Tuple[Tuple[int, int], ...]: ... # undocumented - def __getitem__(self, g: Union[int, str]) -> AnyStr: ... + # __getitem__() returns "AnyStr" or "AnyStr | None", depending on the pattern. + @overload + def __getitem__(self, __key: _Literal[0]) -> AnyStr: ... + @overload + def __getitem__(self, __key: int | str) -> AnyStr | Any: ... if sys.version_info >= (3, 9): def __class_getitem__(cls, item: Any) -> GenericAlias: ... diff --git a/mypyc/codegen/emit.py b/mypyc/codegen/emit.py index 9053fac1acff..0ef6dd57e441 100644 --- a/mypyc/codegen/emit.py +++ b/mypyc/codegen/emit.py @@ -15,7 +15,7 @@ is_list_rprimitive, is_dict_rprimitive, is_set_rprimitive, is_tuple_rprimitive, is_none_rprimitive, is_object_rprimitive, object_rprimitive, is_str_rprimitive, int_rprimitive, is_optional_type, optional_value_type, is_int32_rprimitive, - is_int64_rprimitive, is_bit_rprimitive, is_range_rprimitive + is_int64_rprimitive, is_bit_rprimitive, is_range_rprimitive, is_bytes_rprimitive ) from mypyc.ir.func_ir import FuncDecl from mypyc.ir.class_ir import ClassIR, all_concrete_classes @@ -451,8 +451,8 @@ def emit_cast(self, # TODO: Verify refcount handling. 
if (is_list_rprimitive(typ) or is_dict_rprimitive(typ) or is_set_rprimitive(typ) - or is_str_rprimitive(typ) or is_range_rprimitive(typ) or is_float_rprimitive(typ) - or is_int_rprimitive(typ) or is_bool_rprimitive(typ)): + or is_str_rprimitive(typ) or is_bytes_rprimitive(typ) or is_range_rprimitive(typ) + or is_float_rprimitive(typ) or is_int_rprimitive(typ) or is_bool_rprimitive(typ)): if declare_dest: self.emit_line('PyObject *{};'.format(dest)) if is_list_rprimitive(typ): @@ -463,6 +463,8 @@ def emit_cast(self, prefix = 'PySet' elif is_str_rprimitive(typ): prefix = 'PyUnicode' + elif is_bytes_rprimitive(typ): + prefix = 'PyBytes' elif is_range_rprimitive(typ): prefix = 'PyRange' elif is_float_rprimitive(typ): diff --git a/mypyc/ir/rtypes.py b/mypyc/ir/rtypes.py index 0aaa03c1e2d2..87f1bd1f2be3 100644 --- a/mypyc/ir/rtypes.py +++ b/mypyc/ir/rtypes.py @@ -328,6 +328,9 @@ def __hash__(self) -> int: # (PyUnicode). str_rprimitive: Final = RPrimitive("builtins.str", is_unboxed=False, is_refcounted=True) +# Python bytes object. +bytes_rprimitive: Final = RPrimitive('builtins.bytes', is_unboxed=False, is_refcounted=True) + # Tuple of an arbitrary length (corresponds to Tuple[t, ...], with # explicit '...'). 
tuple_rprimitive: Final = RPrimitive("builtins.tuple", is_unboxed=False, is_refcounted=True) @@ -410,6 +413,10 @@ def is_str_rprimitive(rtype: RType) -> bool: return isinstance(rtype, RPrimitive) and rtype.name == 'builtins.str' +def is_bytes_rprimitive(rtype: RType) -> bool: + return isinstance(rtype, RPrimitive) and rtype.name == 'builtins.bytes' + + def is_tuple_rprimitive(rtype: RType) -> bool: return isinstance(rtype, RPrimitive) and rtype.name == 'builtins.tuple' diff --git a/mypyc/irbuild/ll_builder.py b/mypyc/irbuild/ll_builder.py index 45a1ca647f76..e5d46bf5edd4 100644 --- a/mypyc/irbuild/ll_builder.py +++ b/mypyc/irbuild/ll_builder.py @@ -32,8 +32,8 @@ c_pyssize_t_rprimitive, is_short_int_rprimitive, is_tagged, PyVarObject, short_int_rprimitive, is_list_rprimitive, is_tuple_rprimitive, is_dict_rprimitive, is_set_rprimitive, PySetObject, none_rprimitive, RTuple, is_bool_rprimitive, is_str_rprimitive, c_int_rprimitive, - pointer_rprimitive, PyObject, bit_rprimitive, is_bit_rprimitive, - object_pointer_rprimitive, c_size_t_rprimitive, dict_rprimitive, PyListObject + pointer_rprimitive, PyObject, PyListObject, bit_rprimitive, is_bit_rprimitive, + object_pointer_rprimitive, c_size_t_rprimitive, dict_rprimitive, bytes_rprimitive ) from mypyc.ir.func_ir import FuncDecl, FuncSignature from mypyc.ir.class_ir import ClassIR, all_concrete_classes @@ -802,7 +802,7 @@ def load_str(self, value: str) -> Value: def load_bytes(self, value: bytes) -> Value: """Load a bytes literal value.""" - return self.add(LoadLiteral(value, object_rprimitive)) + return self.add(LoadLiteral(value, bytes_rprimitive)) def load_complex(self, value: complex) -> Value: """Load a complex literal value.""" diff --git a/mypyc/irbuild/mapper.py b/mypyc/irbuild/mapper.py index 37b3861643a2..901ea49fc2fa 100644 --- a/mypyc/irbuild/mapper.py +++ b/mypyc/irbuild/mapper.py @@ -12,7 +12,7 @@ from mypyc.ir.rtypes import ( RType, RUnion, RTuple, RInstance, object_rprimitive, dict_rprimitive, 
tuple_rprimitive, none_rprimitive, int_rprimitive, float_rprimitive, str_rprimitive, bool_rprimitive, - list_rprimitive, set_rprimitive, range_rprimitive + list_rprimitive, set_rprimitive, range_rprimitive, bytes_rprimitive ) from mypyc.ir.func_ir import FuncSignature, FuncDecl, RuntimeArg from mypyc.ir.class_ir import ClassIR @@ -43,10 +43,12 @@ def type_to_rtype(self, typ: Optional[Type]) -> RType: return int_rprimitive elif typ.type.fullname == 'builtins.float': return float_rprimitive - elif typ.type.fullname == 'builtins.str': - return str_rprimitive elif typ.type.fullname == 'builtins.bool': return bool_rprimitive + elif typ.type.fullname == 'builtins.str': + return str_rprimitive + elif typ.type.fullname == 'builtins.bytes': + return bytes_rprimitive elif typ.type.fullname == 'builtins.list': return list_rprimitive # Dict subclasses are at least somewhat common and we diff --git a/mypyc/primitives/bytes_ops.py b/mypyc/primitives/bytes_ops.py new file mode 100644 index 000000000000..5981284c9484 --- /dev/null +++ b/mypyc/primitives/bytes_ops.py @@ -0,0 +1,11 @@ +"""Primitive bytes ops.""" + +from mypyc.ir.rtypes import object_rprimitive +from mypyc.primitives.registry import load_address_op + + +# Get the 'bytes' type object. +load_address_op( + name='builtins.bytes', + type=object_rprimitive, + src='PyBytes_Type') diff --git a/mypyc/primitives/registry.py b/mypyc/primitives/registry.py index 4be29885903a..5ed910549f5a 100644 --- a/mypyc/primitives/registry.py +++ b/mypyc/primitives/registry.py @@ -239,6 +239,7 @@ def load_address_op(name: str, # Import various modules that set up global state. 
import mypyc.primitives.int_ops # noqa import mypyc.primitives.str_ops # noqa +import mypyc.primitives.bytes_ops # noqa import mypyc.primitives.list_ops # noqa import mypyc.primitives.dict_ops # noqa import mypyc.primitives.tuple_ops # noqa diff --git a/mypyc/test-data/driver/driver.py b/mypyc/test-data/driver/driver.py index 957b25a01987..6717f402f72d 100644 --- a/mypyc/test-data/driver/driver.py +++ b/mypyc/test-data/driver/driver.py @@ -26,7 +26,7 @@ def extract_line(tb): formatted = '\n'.join(format_tb(tb)) - m = re.search('File "native.py", line ([0-9]+), in test_', formatted) + m = re.search('File "(?:native|driver).py", line ([0-9]+), in (?:test_|)', formatted) if m is None: return "0" return m.group(1) diff --git a/mypyc/test-data/fixtures/ir.py b/mypyc/test-data/fixtures/ir.py index ad5c2b4aa87b..9e32a70dfd79 100644 --- a/mypyc/test-data/fixtures/ir.py +++ b/mypyc/test-data/fixtures/ir.py @@ -96,10 +96,14 @@ def __truediv__(self, n: complex) -> complex: pass def __neg__(self) -> complex: pass class bytes: + @overload + def __init__(self) -> None: pass + @overload def __init__(self, x: object) -> None: pass - def __add__(self, x: object) -> bytes: pass - def __eq__(self, x:object) -> bool:pass + def __add__(self, x: bytes) -> bytes: pass + def __eq__(self, x: object) -> bool: pass def __ne__(self, x: object) -> bool: pass + def __getitem__(self, i: int) -> int: pass def join(self, x: Iterable[object]) -> bytes: pass class bool(int): diff --git a/mypyc/test-data/irbuild-basic.test b/mypyc/test-data/irbuild-basic.test index 5b102ba2631f..070bf7e333c8 100644 --- a/mypyc/test-data/irbuild-basic.test +++ b/mypyc/test-data/irbuild-basic.test @@ -761,7 +761,7 @@ def f() -> bytes: return b'1234' [out] def f(): - r0, x, r1 :: object + r0, x, r1 :: bytes L0: r0 = b'\xf0' x = r0 diff --git a/mypyc/test-data/run-bytes.test b/mypyc/test-data/run-bytes.test new file mode 100644 index 000000000000..b88201fc896a --- /dev/null +++ b/mypyc/test-data/run-bytes.test @@ -0,0 +1,55
@@ +# Bytes test cases (compile and run) + +[case testBytesBasics] +# Note: Add tests for additional operations to testBytesOps or in a new test case + +def f(x: bytes) -> bytes: + return x + +def eq(a: bytes, b: bytes) -> bool: + return a == b + +def neq(a: bytes, b: bytes) -> bool: + return a != b +[file driver.py] +from native import f, eq, neq +assert f(b'123') == b'123' +assert f(b'\x07 \x0b " \t \x7f \xf0') == b'\x07 \x0b " \t \x7f \xf0' +assert eq(b'123', b'123') +assert not eq(b'123', b'1234') +assert neq(b'123', b'1234') +try: + f('x') + assert False +except TypeError: + pass + +[case testBytesOps] +def test_indexing() -> None: + # Use bytes() to avoid constant folding + b = b'asdf' + bytes() + assert b[0] == 97 + assert b[1] == 115 + assert b[3] == 102 + assert b[-1] == 102 + b = b'\xfe\x15' + bytes() + assert b[0] == 254 + assert b[1] == 21 + +def test_concat() -> None: + b1 = b'123' + bytes() + b2 = b'456' + bytes() + assert b1 + b2 == b'123456' + +def test_join() -> None: + seq = (b'1', b'"', b'\xf0') + assert b'\x07'.join(seq) == b'1\x07"\x07\xf0' + assert b', '.join(()) == b'' + assert b', '.join([bytes() + b'ab']) == b'ab' + assert b', '.join([bytes() + b'ab', b'cd']) == b'ab, cd' + +def test_len() -> None: + # Use bytes() to avoid constant folding + b = b'foo' + bytes() + assert len(b) == 3 + assert len(bytes()) == 0 \ No newline at end of file diff --git a/mypyc/test-data/run-primitives.test b/mypyc/test-data/run-primitives.test index efef3d86edaf..b95f742977be 100644 --- a/mypyc/test-data/run-primitives.test +++ b/mypyc/test-data/run-primitives.test @@ -255,32 +255,6 @@ assert str(to_int(3.14)) == '3' assert str(to_int(3)) == '3' assert get_complex() == 3.5 + 6.2j -[case testBytes] -def f(x: bytes) -> bytes: - return x - -def concat(a: bytes, b: bytes) -> bytes: - return a + b - -def eq(a: bytes, b: bytes) -> bool: - return a == b - -def neq(a: bytes, b: bytes) -> bool: - return a != b - -def join() -> bytes: - seq = (b'1', b'"', b'\xf0') - 
return b'\x07'.join(seq) -[file driver.py] -from native import f, concat, eq, neq, join -assert f(b'123') == b'123' -assert f(b'\x07 \x0b " \t \x7f \xf0') == b'\x07 \x0b " \t \x7f \xf0' -assert concat(b'123', b'456') == b'123456' -assert eq(b'123', b'123') -assert not eq(b'123', b'1234') -assert neq(b'123', b'1234') -assert join() == b'1\x07"\x07\xf0' - [case testDel] from typing import List from testutil import assertRaises diff --git a/mypyc/test/test_run.py b/mypyc/test/test_run.py index df0228c68d3b..8ed1ac3c3dc6 100644 --- a/mypyc/test/test_run.py +++ b/mypyc/test/test_run.py @@ -36,6 +36,7 @@ 'run-floats.test', 'run-bools.test', 'run-strings.test', + 'run-bytes.test', 'run-tuples.test', 'run-lists.test', 'run-dicts.test',