Skip to content

Commit 81d6bcc

Browse files
authored
Merge pull request #2189 from pytorch/py38_compatibility
Py38 compatibility
2 parents b3089bf + f53a823 commit 81d6bcc

File tree

15 files changed

+47
-15
lines changed

15 files changed

+47
-15
lines changed

py/torch_tensorrt/_Input.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
from enum import Enum
24
from typing import Any, Dict, List, Optional, Sequence, Tuple
35

@@ -32,11 +34,11 @@ class _ShapeMode(Enum):
3234
shape: Optional[
3335
Tuple[int, ...] | Dict[str, Tuple[int, ...]]
3436
] = None #: Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form ``{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }``
35-
dtype: _enums.dtype = ( # type: ignore[name-defined]
37+
dtype: _enums.dtype = (
3638
_enums.dtype.unknown
3739
) #: The expected data type of the input tensor (default: torch_tensorrt.dtype.float32)
3840
_explicit_set_dtype: bool = False
39-
format: _enums.TensorFormat = ( # type: ignore[name-defined]
41+
format: _enums.TensorFormat = (
4042
_enums.TensorFormat.contiguous
4143
) #: The expected format of the input tensor (default: torch_tensorrt.TensorFormat.NCHW)
4244

@@ -208,7 +210,7 @@ def _supported_input_size_type(input_size: Any) -> bool:
208210
return False
209211

210212
@staticmethod
211-
def _parse_dtype(dtype: Any) -> _enums.dtype: # type: ignore[name-defined]
213+
def _parse_dtype(dtype: Any) -> _enums.dtype:
212214
if isinstance(dtype, torch.dtype):
213215
if dtype == torch.long:
214216
return _enums.dtype.long
@@ -236,7 +238,7 @@ def _parse_dtype(dtype: Any) -> _enums.dtype: # type: ignore[name-defined]
236238
)
237239

238240
@staticmethod
239-
def _to_torch_dtype(dtype: _enums.dtype) -> torch.dtype: # type: ignore[name-defined]
241+
def _to_torch_dtype(dtype: _enums.dtype) -> torch.dtype:
240242
if dtype == _enums.dtype.long:
241243
return torch.long
242244
elif dtype == _enums.dtype.int32:
@@ -255,7 +257,7 @@ def is_trt_dtype(self) -> bool:
255257
return bool(self.dtype != _enums.dtype.long)
256258

257259
@staticmethod
258-
def _parse_format(format: Any) -> _enums.TensorFormat: # type: ignore[name-defined]
260+
def _parse_format(format: Any) -> _enums.TensorFormat:
259261
if isinstance(format, torch.memory_format):
260262
if format == torch.contiguous_format:
261263
return _enums.TensorFormat.contiguous

py/torch_tensorrt/_compile.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
from __future__ import annotations
2+
13
from enum import Enum
2-
from typing import Any, Callable, List, Optional, Sequence, Set, TypeGuard
4+
from typing import Any, Callable, List, Optional, Sequence, Set
35

46
import torch
57
import torch.fx
@@ -12,6 +14,7 @@
1214
from torch_tensorrt.fx.lower import compile as fx_compile
1315
from torch_tensorrt.fx.utils import LowerPrecision
1416
from torch_tensorrt.ts._compiler import compile as torchscript_compile
17+
from typing_extensions import TypeGuard
1518

1619

1720
def _non_fx_input_interface(

py/torch_tensorrt/dynamo/aten_tracer.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1+
from __future__ import annotations
2+
13
import copy
24
import sys
35
from contextlib import contextmanager
4-
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple
6+
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
57

68
import torch
79
import torch._dynamo as torchdynamo
@@ -22,7 +24,7 @@
2224
)
2325
from typing_extensions import TypeAlias
2426

25-
Value: TypeAlias = Tuple["Value", ...] | List["Value"] | Dict[str, "Value"]
27+
Value: TypeAlias = Union[Tuple["Value", ...], List["Value"], Dict[str, "Value"]]
2628

2729

2830
class DynamoConfig:

py/torch_tensorrt/dynamo/backend/backends.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
import logging
24
from functools import partial
35
from typing import Any, Callable, Sequence

py/torch_tensorrt/dynamo/compile.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
import collections.abc
24
import logging
35
from typing import Any, List, Optional, Set, Tuple

py/torch_tensorrt/dynamo/conversion/conversion.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
import io
24
from typing import Sequence
35

py/torch_tensorrt/dynamo/conversion/converter_registry.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
import logging
24
from dataclasses import dataclass, field
35
from enum import Enum, auto
@@ -28,7 +30,7 @@
2830
Dict[str, Argument],
2931
str,
3032
],
31-
TRTTensor | Sequence[TRTTensor],
33+
Union[TRTTensor, Sequence[TRTTensor]],
3234
]
3335

3436

py/torch_tensorrt/dynamo/conversion/impl/shape.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
from typing import List, Optional, Tuple
24

35
import numpy as np

py/torch_tensorrt/dynamo/conversion/truncate_long_and_double.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
from typing import Optional, Sequence, Set
24

35
import torch

py/torch_tensorrt/dynamo/lowering/_pre_aot_lowering.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
1+
from __future__ import annotations
2+
13
import logging
24
from dataclasses import dataclass
3-
from typing import Any, Callable, Dict, Optional, Type, TypeAlias
5+
from typing import Any, Callable, Dict, Optional, Type
46

57
import torch
68
from torch._ops import OpOverload
79
from torch.fx import GraphModule, Node
10+
from typing_extensions import TypeAlias
811

912
logger = logging.getLogger(__name__)
1013

0 commit comments

Comments
 (0)