diff --git a/.github/workflows/build-test.yml b/.github/workflows/build-test.yml
index 5573bb8d28..f5ddce8924 100644
--- a/.github/workflows/build-test.yml
+++ b/.github/workflows/build-test.yml
@@ -227,3 +227,30 @@ jobs:
           ${CONDA_RUN} python -m pytest -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/tests_py_dynamo_core_partitioning_test_results.xml partitioning/
           ${CONDA_RUN} python -m pytest -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/tests_py_dynamo_core_lowering_test_results.xml lowering/
           popd
+
+  tests-py-core:
+    name: Test core [Python]
+    needs: [generate-matrix, build]
+    strategy:
+      fail-fast: false
+      matrix:
+        include:
+          - repository: pytorch/tensorrt
+            package-name: torch_tensorrt
+            pre-script: packaging/pre_build_script.sh
+    uses: pytorch/tensorrt/.github/workflows/linux-test.yml@main
+    with:
+      job-name: tests-py-core
+      repository: "pytorch/tensorrt"
+      ref: ""
+      test-infra-repository: pytorch/test-infra
+      test-infra-ref: main
+      build-matrix: ${{ needs.generate-matrix.outputs.matrix }}
+      pre-script: ${{ matrix.pre-script }}
+      script: |
+        export USE_HOST_DEPS=1
+        pushd .
+        cd tests/py/core
+        ${CONDA_RUN} python -m pip install --pre pytest-xdist timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
+        ${CONDA_RUN} python -m pytest -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/tests_py_core_test_results.xml .
+        popd
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index f4ac2ab9e2..61d97503a2 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,7 +1,7 @@
 exclude: ^.github/actions/assigner/dist
 repos:
   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v4.4.0
+    rev: v4.5.0
     hooks:
       - id: check-yaml
       - id: trailing-whitespace
@@ -16,38 +16,38 @@ repos:
           - --fix=lf
         exclude: ^docs
   - repo: https://github.com/pre-commit/mirrors-clang-format
-    rev: v16.0.6
+    rev: v18.1.1
     hooks:
       - id: clang-format
        types_or: [c++, c, cuda]
   - repo: https://github.com/keith/pre-commit-buildifier
-    rev: 6.1.0.2
+    rev: 6.4.0
    hooks:
      - id: buildifier
        args:
          - --warnings=all
      - id: buildifier-lint
   - repo: https://github.com/abravalheri/validate-pyproject
-    rev: v0.13
+    rev: v0.16
     hooks:
       - id: validate-pyproject
   - repo: https://github.com/pycqa/isort
-    rev: 5.12.0
+    rev: 5.13.2
     hooks:
       - id: isort
         name: isort (python)
   - repo: https://github.com/pre-commit/mirrors-mypy
-    rev: 'v1.4.1'
+    rev: 'v1.9.0'
     hooks:
       - id: mypy
         exclude: "^py/torch_tensorrt/fx|^examples|^tests|^py/torch_tensorrt/dynamo/_experimental|^tools|^docs|noxfile.py|setup.py|versions.py"
   - repo: https://github.com/astral-sh/ruff-pre-commit
     # Ruff version.
-    rev: v0.0.278
+    rev: v0.3.3
     hooks:
       - id: ruff
   - repo: https://github.com/psf/black
-    rev: 24.1.1
+    rev: 24.3.0
     hooks:
       - id: black
         exclude: ^examples/custom_converters/elu_converter/setup.py|^docs
diff --git a/BUILD b/BUILD
index c40d52e0f9..3138a5d021 100644
--- a/BUILD
+++ b/BUILD
@@ -33,6 +33,14 @@ pkg_tar(
     ],
 )
 
+pkg_tar(
+    name = "include_rt",
+    package_dir = "include/torch_tensorrt",
+    deps = [
+        "//core/runtime:include",
+    ],
+)
+
 pkg_tar(
     name = "include",
     srcs = [
@@ -55,6 +63,18 @@ pkg_tar(
     package_dir = "lib/",
 )
 
+pkg_tar(
+    name = "lib_rt",
+    srcs = select({
+        ":windows": ["//cpp/lib:torch_tensorrt_runtime.dll"],
+        "//conditions:default": [
+            "//cpp/lib:libtorchtrt_runtime.so",
+        ],
+    }),
+    mode = "0755",
+    package_dir = "lib/",
+)
+
 pkg_tar(
     name = "bin",
     srcs = [
@@ -82,3 +102,18 @@ pkg_tar(
         "//conditions:default": [":bin"],
     }),
 )
+
+pkg_tar(
+    name = "libtorchtrt_runtime",
+    srcs = [
+        "//:LICENSE",
+        "//bzl_def:BUILD",
+        "//bzl_def:WORKSPACE",
+    ],
+    extension = "tar.gz",
+    package_dir = "torch_tensorrt_runtime",
+    deps = [
+        ":include_rt",
+        ":lib_rt",
+    ],
+)
diff --git a/core/runtime/RTDevice.cpp b/core/runtime/RTDevice.cpp
index 34ecc22e97..f78ce306ad 100644
--- a/core/runtime/RTDevice.cpp
+++ b/core/runtime/RTDevice.cpp
@@ -7,8 +7,6 @@ namespace torch_tensorrt {
 namespace core {
 namespace runtime {
 
-const std::string DEVICE_INFO_DELIM = "%";
-
 typedef enum { ID_IDX = 0, SM_MAJOR_IDX, SM_MINOR_IDX, DEVICE_TYPE_IDX, DEVICE_NAME_IDX } SerializedDeviceInfoIndex;
 
 RTDevice::RTDevice() : id{-1}, major{-1}, minor{-1}, device_type{nvinfer1::DeviceType::kGPU} {}
diff --git a/core/runtime/RTDevice.h b/core/runtime/RTDevice.h
index bd1484d4b0..60963c36e1 100644
--- a/core/runtime/RTDevice.h
+++ b/core/runtime/RTDevice.h
@@ -6,6 +6,8 @@ namespace torch_tensorrt {
 namespace core {
 namespace runtime {
 
+const std::string DEVICE_INFO_DELIM = "%";
+
 struct RTDevice {
   int64_t id; // CUDA device id
   int64_t major; // CUDA compute major version
diff --git a/core/runtime/register_jit_hooks.cpp b/core/runtime/register_jit_hooks.cpp
index 4ae4f92337..c2a344a307 100644
--- a/core/runtime/register_jit_hooks.cpp
+++ b/core/runtime/register_jit_hooks.cpp
@@ -116,6 +116,7 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion =
 TORCH_LIBRARY(tensorrt, m) {
   m.def("execute_engine", execute_engine);
   m.def("SERIALIZED_ENGINE_BINDING_DELIM", []() -> std::string { return std::string(1, TRTEngine::BINDING_DELIM); });
+  m.def("SERIALIZED_RT_DEVICE_DELIM", []() -> std::string { return DEVICE_INFO_DELIM; });
   m.def("ABI_VERSION", []() -> std::string { return ABI_VERSION; });
   m.def("get_multi_device_safe_mode", []() -> bool { return MULTI_DEVICE_SAFE_MODE; });
   m.def("set_multi_device_safe_mode", [](bool multi_device_safe_mode) -> void {
diff --git a/docsrc/py_api/torch_tensorrt.rst b/docsrc/py_api/torch_tensorrt.rst
index 22fda13ba2..eb8285e103 100644
--- a/docsrc/py_api/torch_tensorrt.rst
+++ b/docsrc/py_api/torch_tensorrt.rst
@@ -37,10 +37,6 @@ Classes
    :members:
    :special-members: __init__
 
-.. autoclass:: TRTModuleNext
-   :members:
-   :special-members: __init__
-
 Enums
 -------
 
@@ -50,7 +46,7 @@ Enums
 
 .. autoclass:: EngineCapability
 
-.. autoclass:: TensorFormat
+.. autoclass:: memory_format
 
 Submodules
 ----------
diff --git a/examples/dynamo/torch_compile_stable_diffusion.py b/examples/dynamo/torch_compile_stable_diffusion.py
index 0511e5a363..a0b725572b 100644
--- a/examples/dynamo/torch_compile_stable_diffusion.py
+++ b/examples/dynamo/torch_compile_stable_diffusion.py
@@ -18,9 +18,8 @@
 # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
 import torch
-from diffusers import DiffusionPipeline
-
 import torch_tensorrt
+from diffusers import DiffusionPipeline
 
 model_id = "CompVis/stable-diffusion-v1-4"
 device = "cuda:0"
@@ -39,7 +38,7 @@
     backend=backend,
     options={
         "truncate_long_and_double": True,
-        "precision": torch.float16,
+        "enabled_precisions": {torch.float32, torch.float16},
     },
     dynamic=False,
 )
diff --git a/py/torch_tensorrt/_Device.py b/py/torch_tensorrt/_Device.py
index 6f20b6c84c..4c8c855943 100644
--- a/py/torch_tensorrt/_Device.py
+++ b/py/torch_tensorrt/_Device.py
@@ -1,3 +1,6 @@
+from __future__ import annotations
+
+import logging
 import sys
 from typing import Any, Optional, Tuple
 
@@ -6,19 +9,11 @@
 else:
     from typing_extensions import Self
 
-import warnings
-
-# from torch_tensorrt import _enums
-import tensorrt as trt
 import torch
-from torch_tensorrt import logging
+from torch_tensorrt._enums import DeviceType
+from torch_tensorrt._features import ENABLED_FEATURES
 
-try:
-    from torch_tensorrt import _C
-except ImportError:
-    warnings.warn(
-        "Unable to import torchscript frontend core and torch-tensorrt runtime. Some dependent features may be unavailable."
-    )
+import tensorrt as trt
 
 
 class Device(object):
@@ -32,9 +27,9 @@ class Device(object):
         allow_gpu_fallback (bool): Whether falling back to GPU if DLA cannot support an op should be allowed
     """
 
-    device_type: Optional[trt.DeviceType] = (
-        None  #: Target device type (GPU or DLA). Set implicitly based on if dla_core is specified.
-    )
+    device_type: DeviceType = (
+        DeviceType.UNKNOWN
+    )  #: Target device type (GPU or DLA). Set implicitly based on if dla_core is specified.
gpu_id: int = -1 #: Device ID for target GPU dla_core: int = -1 #: Core ID for target DLA core allow_gpu_fallback: bool = ( @@ -69,32 +64,31 @@ def __init__(self, *args: Any, **kwargs: Any): ) else: (self.device_type, id) = Device._parse_device_str(args[0]) - if self.device_type == trt.DeviceType.GPU: - self.gpu_id = id - else: + if self.device_type == DeviceType.DLA: self.dla_core = id self.gpu_id = 0 - logging.log( - logging.Level.Warning, - "Setting GPU id to 0 for device because device 0 manages DLA on Xavier", + logging.warning( + "Setting GPU id to 0 for device because device 0 manages DLA on AGX Devices", ) + else: + self.gpu_id = id elif len(args) == 0: if "gpu_id" in kwargs or "dla_core" in kwargs: if "dla_core" in kwargs: - self.device_type = trt.DeviceType.DLA self.dla_core = kwargs["dla_core"] - if "gpu_id" in kwargs: - self.gpu_id = kwargs["gpu_id"] - else: + if "gpu_id" in kwargs: + self.gpu_id = kwargs["gpu_id"] + + if self.dla_core >= 0: + self.device_type = DeviceType.DLA + if self.gpu_id != 0: self.gpu_id = 0 - logging.log( - logging.Level.Warning, - "Setting GPU id to 0 for device because device 0 manages DLA on Xavier", + logging.warning( + "Setting GPU id to 0 for device because device 0 manages DLA on AGX Platforms", ) else: - self.gpu_id = kwargs["gpu_id"] - self.device_type = trt.DeviceType.GPU + self.device_type = DeviceType.GPU else: raise ValueError( "Either gpu_id or dla_core or both must be defined if no string with device specs is provided as an arg" @@ -102,9 +96,7 @@ def __init__(self, *args: Any, **kwargs: Any): else: raise ValueError( - "Unexpected number of positional arguments for class Device \n Found {} arguments, expected either zero or a single positional arguments".format( - len(args) - ) + f"Unexpected number of positional arguments for class Device \n Found {len(args)} arguments, expected either zero or a single positional arguments" ) if "allow_gpu_fallback" in kwargs: @@ -112,58 +104,85 @@ def __init__(self, *args: Any, **kwargs: Any): raise TypeError("allow_gpu_fallback must be a bool") self.allow_gpu_fallback = kwargs["allow_gpu_fallback"] + if "device_type" in kwargs: + if isinstance(kwargs["device_type"], trt.DeviceType): + self.device_type = DeviceType._from(kwargs["device_type"]) + def __str__(self) -> str: - return ( - "Device(type={}, gpu_id={}".format(self.device_type, self.gpu_id) + ")" - if self.device_type == trt.DeviceType.GPU - else ", dla_core={}, allow_gpu_fallback={}".format( - self.dla_core, self.allow_gpu_fallback - ) + suffix = ( + ")" + if self.device_type == DeviceType.GPU + else f", dla_core={self.dla_core}, allow_gpu_fallback={self.allow_gpu_fallback})" ) + dev_str: str = f"Device(type={self.device_type}, gpu_id={self.gpu_id}{suffix}" + return dev_str def __repr__(self) -> str: return self.__str__() - def _to_internal(self) -> _C.Device: - internal_dev = _C.Device() - if self.device_type == trt.DeviceType.GPU: - internal_dev.device_type = _C.DeviceType.GPU - elif self.device_type == trt.DeviceType.DLA: - internal_dev.device_type = _C.DeviceType.DLA - else: - raise ValueError( - "Invalid DeviceType detected while parsing the Device class" - ) + @classmethod + def _from(cls, d: Optional[Self | torch.device | str]) -> Device: + """Cast a device-type to torch_tensorrt.Device + + Returns the corresponding torch_tensorrt.Device + """ + if isinstance(d, Device): + return d - internal_dev.gpu_id = self.gpu_id - internal_dev.dla_core = self.dla_core - internal_dev.allow_gpu_fallback = self.allow_gpu_fallback - return internal_dev 
+ elif isinstance(d, torch.device): + if d.type != "cuda": + raise ValueError('Torch Device specs must have type "cuda"') + return cls(gpu_id=d.index) - def _to_serialized_rt_device(self) -> str: - internal_dev = self._to_internal() - serialized_rt_device: str = internal_dev._to_serialized_rt_device() - return serialized_rt_device + elif d is None: + return cls(gpu_id=torch.cuda.current_device()) + + else: + return cls(d) @classmethod - def _from_torch_device(cls, torch_dev: torch.device) -> Self: - if torch_dev.type != "cuda": - raise ValueError('Torch Device specs must have type "cuda"') - gpu_id = torch_dev.index - return cls(gpu_id=gpu_id) + def _from_torch_device(cls, torch_dev: torch.device) -> Device: + return cls._from(torch_dev) @classmethod - def _current_device(cls) -> Self: - dev = _C._get_current_device() - return cls(gpu_id=dev.gpu_id) + def _current_device(cls) -> Device: + dev_id = torch.cuda.current_device() + return cls(gpu_id=dev_id) @staticmethod def _parse_device_str(s: str) -> Tuple[trt.DeviceType, int]: s = s.lower() spec = s.split(":") if spec[0] == "gpu" or spec[0] == "cuda": - return (trt.DeviceType.GPU, int(spec[1])) + return (DeviceType.GPU, int(spec[1])) elif spec[0] == "dla": - return (trt.DeviceType.DLA, int(spec[1])) + return (DeviceType.DLA, int(spec[1])) else: raise ValueError(f"Unknown device type {spec[0]}") + + def to(self, t: type) -> torch.device: + if t == torch.device: + if self.gpu_id != -1: + return torch.device(self.gpu_id) + else: + raise ValueError("Invalid GPU ID provided for the CUDA device provided") + else: + raise TypeError("Unsupported target type for device conversion") + + def _to_serialized_rt_device(self) -> str: + if not ENABLED_FEATURES.torch_tensorrt_runtime: + raise NotImplementedError("Torch-TensorRT runtime is not available") + + delim = torch.ops.tensorrt.SERIALIZED_RT_DEVICE_DELIM()[0] + dev_info = torch.cuda.get_device_properties(self.gpu_id) + rt_info = [ + self.gpu_id, + dev_info.major, + dev_info.minor, + int(self.device_type.to(trt.DeviceType)), # type: ignore[arg-type] + dev_info.name, + ] + rt_info = [str(i) for i in rt_info] + packed_rt_info: str = delim.join(rt_info) + logging.debug(f"Serialized Device Info: {packed_rt_info}") + return packed_rt_info diff --git a/py/torch_tensorrt/_Input.py b/py/torch_tensorrt/_Input.py index db36678d17..32f19ce1f0 100644 --- a/py/torch_tensorrt/_Input.py +++ b/py/torch_tensorrt/_Input.py @@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple import torch -from torch_tensorrt import _enums +from torch_tensorrt._enums import dtype, memory_format class Input(object): @@ -34,18 +34,17 @@ class _ShapeMode(Enum): shape: Optional[Tuple[int, ...] | Dict[str, Tuple[int, ...]]] = ( None #: Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. 
Dynamic inputs will have a dict of the form ``{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }`` ) - dtype: _enums.dtype = ( - _enums.dtype.unknown + dtype: dtype = ( + dtype.unknown ) #: The expected data type of the input tensor (default: torch_tensorrt.dtype.float32) _explicit_set_dtype: bool = False - format: _enums.TensorFormat = ( - _enums.TensorFormat.contiguous - ) #: The expected format of the input tensor (default: torch_tensorrt.TensorFormat.NCHW) + format: memory_format = ( + memory_format.linear + ) #: The expected format of the input tensor (default: torch_tensorrt.memory_format.linear) DOMAIN_OFFSET: float = 2.0 low_tensor_domain_incl: float = 0.0 high_tensor_domain_excl: float = low_tensor_domain_incl + DOMAIN_OFFSET - torch_dtype: torch.dtype = torch.float32 torch_tensor: torch.Tensor = None name: str = "" @@ -151,21 +150,19 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: else: raise ValueError( - "Unexpected number of positional arguments for class Input \n Found {} arguments, expected either zero or a single positional arguments".format( - len(args) - ) + f"Unexpected number of positional arguments for class Input \n Found {len(args)} arguments, expected either zero or a single positional arguments" ) if "dtype" in kwargs: - if isinstance(kwargs["dtype"], torch.dtype): - self.torch_dtype = kwargs["dtype"] + self.dtype = dtype._from(kwargs["dtype"]) - self.dtype = Input._parse_dtype(kwargs["dtype"]) - self.torch_dtype = Input._to_torch_dtype(self.dtype) + if self.dtype != dtype.unknown: self._explicit_set_dtype = True + else: + self._explicit_set_dtype = False if "format" in kwargs: - self.format = Input._parse_format(kwargs["format"]) + self.format = memory_format._from(kwargs["format"]) if "tensor_domain" in kwargs: domain = kwargs["tensor_domain"] @@ -212,6 +209,9 @@ def __str__(self) -> str: else: raise RuntimeError("Unknown input shape mode") + def __repr__(self) -> str: + return self.__str__() + @staticmethod def _supported_input_size_type(input_size: Any) -> bool: if isinstance(input_size, torch.Size): @@ -223,77 +223,6 @@ def _supported_input_size_type(input_size: Any) -> bool: else: return False - @staticmethod - def _parse_dtype(dtype: Any) -> _enums.dtype: - if isinstance(dtype, torch.dtype): - if dtype == torch.long: - return _enums.dtype.long - elif dtype == torch.int32: - return _enums.dtype.int32 - elif dtype == torch.half: - return _enums.dtype.half - elif dtype == torch.float: - return _enums.dtype.float - elif dtype == torch.float64: - return _enums.dtype.double - elif dtype == torch.bool: - return _enums.dtype.bool - else: - raise TypeError( - "Provided an unsupported data type as an input data type (support: bool, int32, long, half, float), got: " - + str(dtype) - ) - - elif isinstance(dtype, _enums.dtype): - return dtype - - else: - raise TypeError( - "Input data type needs to be specified with a torch.dtype or a torch_tensorrt.dtype, got: " - + str(type(dtype)) - ) - - @staticmethod - def _to_torch_dtype(dtype: _enums.dtype) -> torch.dtype: - if dtype == _enums.dtype.long: - return torch.long - elif dtype == _enums.dtype.int32: - return torch.int32 - elif dtype == _enums.dtype.half: - return torch.half - elif dtype == _enums.dtype.float: - return torch.float - elif dtype == _enums.dtype.bool: - return torch.bool - elif dtype == _enums.dtype.double: - return torch.float64 - else: - # Default torch_dtype used in FX path - return torch.float32 - - def is_trt_dtype(self) -> bool: - return bool(self.dtype != _enums.dtype.long) - - 
@staticmethod - def _parse_format(format: Any) -> _enums.TensorFormat: - if isinstance(format, torch.memory_format): - if format == torch.contiguous_format: - return _enums.TensorFormat.contiguous - elif format == torch.channels_last: - return _enums.TensorFormat.channels_last - else: - raise ValueError( - "Provided an unsupported tensor format (support: NCHW/contiguous_format, NHWC/channel_last)" - ) - - elif isinstance(format, _enums.TensorFormat): - return format - - else: - raise TypeError( - "Tensor format needs to be specified with either torch.memory_format or torch_tensorrt.TensorFormat" - ) - @staticmethod def _parse_tensor_domain( domain: Optional[Tuple[float, float]] @@ -415,7 +344,9 @@ def example_tensor( ) else: if isinstance(self.shape, tuple): - return torch.rand(self.shape).to(dtype=self.torch_dtype) + return torch.rand(self.shape).to( + dtype=self.dtype.to(torch.dtype, use_default=True) + ) else: RuntimeError( f"Input shape is dynamic but shapes are not provided as sequence (found: {self.shape})" @@ -434,7 +365,7 @@ def example_tensor( if isinstance(self.shape, dict): return torch.rand(self.shape[optimization_profile_field]).to( - dtype=self.torch_dtype + dtype=self.dtype.to(torch.dtype, use_default=True) ) else: raise RuntimeError( diff --git a/py/torch_tensorrt/__init__.py b/py/torch_tensorrt/__init__.py index b9d2af39c5..f95f33bc74 100644 --- a/py/torch_tensorrt/__init__.py +++ b/py/torch_tensorrt/__init__.py @@ -80,26 +80,46 @@ def _find_lib(name: str, paths: List[str]) -> str: for lib in LINUX_LIBS: ctypes.CDLL(_find_lib(lib, LINUX_PATHS)) -import torch -from torch_tensorrt._compile import * # noqa: F403 -from torch_tensorrt._Device import Device # noqa: F401 -from torch_tensorrt._enums import * # noqa: F403 -from torch_tensorrt._Input import Input # noqa: F401 -from torch_tensorrt._utils import * # noqa: F403 -from torch_tensorrt._utils import sanitized_torch_version -from torch_tensorrt.logging import * -from torch_tensorrt.ptq import * -from torch_tensorrt.runtime import * # noqa: F403 +import logging -if version.parse(sanitized_torch_version()) >= version.parse("2.1.dev"): - from torch_tensorrt.dynamo import backend # noqa: F401 +import torch +from torch_tensorrt._features import ENABLED_FEATURES, _enabled_features_str - from torch_tensorrt import dynamo # noqa: F401 +_LOGGER = logging.getLogger(__name__) +_LOGGER.debug(_enabled_features_str()) def _register_with_torch() -> None: trtorch_dir = os.path.dirname(__file__) - torch.ops.load_library(trtorch_dir + "/lib/libtorchtrt.so") + if os.path.isfile(trtorch_dir + "/lib/libtorchtrt.so"): + assert ENABLED_FEATURES.torchscript_frontend + assert ENABLED_FEATURES.torch_tensorrt_runtime + torch.ops.load_library(trtorch_dir + "/lib/libtorchtrt.so") + elif os.path.isfile(trtorch_dir + "/lib/libtorchtrt_runtime.so"): + assert ENABLED_FEATURES.torch_tensorrt_runtime + torch.ops.load_library(trtorch_dir + "/lib/libtorchtrt_runtime.so") _register_with_torch() + +from torch_tensorrt._Device import Device # noqa: F401 +from torch_tensorrt._enums import ( # noqa: F401 + DeviceType, + EngineCapability, + dtype, + memory_format, +) +from torch_tensorrt._Input import Input # noqa: F401 +from torch_tensorrt.runtime import * # noqa: F403 + +if ENABLED_FEATURES.torchscript_frontend: + from torch_tensorrt import ts + +if ENABLED_FEATURES.fx_frontend: + from torch_tensorrt import fx + +if ENABLED_FEATURES.dynamo_frontend: + from torch_tensorrt.dynamo import backend # noqa: F401 + from torch_tensorrt import dynamo # noqa: F401 + +from 
torch_tensorrt._compile import * # noqa: F403 diff --git a/py/torch_tensorrt/_compile.py b/py/torch_tensorrt/_compile.py index 9dd816e633..1381971047 100644 --- a/py/torch_tensorrt/_compile.py +++ b/py/torch_tensorrt/_compile.py @@ -6,23 +6,29 @@ import torch import torch.fx -import torch_tensorrt.ts from torch_tensorrt._enums import dtype +from torch_tensorrt._features import ENABLED_FEATURES from torch_tensorrt._Input import Input -from torch_tensorrt._utils import sanitized_torch_version +from torch_tensorrt.dynamo import _defaults from torch_tensorrt.fx import InputTensorSpec from torch_tensorrt.fx.lower import compile as fx_compile from torch_tensorrt.fx.utils import LowerPrecision -from torch_tensorrt.ts._compiler import compile as torchscript_compile from typing_extensions import TypeGuard -from packaging import version - -DYNAMO_ENABLED = version.parse(sanitized_torch_version()) >= version.parse("2.1.dev") +if ENABLED_FEATURES.torchscript_frontend: + import torch_tensorrt.ts + from torch_tensorrt.ts._compiler import compile as torchscript_compile + from torch_tensorrt.ts._compiler import ( + convert_method_to_trt_engine as ts_convert_method_to_trt_engine, + ) -if DYNAMO_ENABLED: +if ENABLED_FEATURES.dynamo_frontend: from torch._export import ExportedProgram from torch_tensorrt.dynamo._compiler import compile as dynamo_compile + from torch_tensorrt.dynamo._compiler import ( + convert_module_to_trt_engine as dynamo_convert_module_to_trt_engine, + ) + from torch_tensorrt.dynamo._tracer import trace as dynamo_trace logger = logging.getLogger(__name__) @@ -71,7 +77,7 @@ def _parse_module_type(module: Any) -> _ModuleType: return _ModuleType.ts elif isinstance(module, torch.fx.GraphModule): return _ModuleType.fx - elif DYNAMO_ENABLED and isinstance(module, ExportedProgram): + elif isinstance(module, ExportedProgram): return _ModuleType.ep elif isinstance(module, torch.nn.Module): return _ModuleType.nn @@ -79,7 +85,7 @@ def _parse_module_type(module: Any) -> _ModuleType: raise RuntimeError("Module is an unknown format") -def _get_target_ir(module_type: _ModuleType, ir: str) -> _IRType: +def _get_target_fe(module_type: _ModuleType, ir: str) -> _IRType: module_is_tsable = any(module_type == t for t in [_ModuleType.nn, _ModuleType.ts]) module_is_fxable = any(module_type == t for t in [_ModuleType.nn, _ModuleType.fx]) module_is_exportable = module_type == _ModuleType.ep @@ -90,35 +96,52 @@ def _get_target_ir(module_type: _ModuleType, ir: str) -> _IRType: ir_targets_torch_compile = ir == "torch_compile" if module_is_tsable and ir_targets_torchscript: - return _IRType.ts + if ENABLED_FEATURES.torchscript_frontend: + return _IRType.ts + else: + raise ValueError( + "Requested using the TS frontend but the TS frontend is not available in this build of Torch-TensorRT" + ) elif module_is_fxable and ir_targets_fx: - return _IRType.fx - elif module_is_fxable and ir_targets_dynamo: - return _IRType.dynamo + if ENABLED_FEATURES.fx_frontend: + return _IRType.fx + else: + raise ValueError( + "Requested using the FX frontend but the FX frontend is not available in this build of Torch-TensorRT" + ) + elif (module_is_fxable or module_is_exportable) and ir_targets_dynamo: + if ENABLED_FEATURES.dynamo_frontend: + return _IRType.dynamo + else: + raise ValueError( + "Requested using the Dynamo frontend but the Dynamo frontend is not available in this build of Torch-TensorRT" + ) elif module_is_fxable and ir_targets_torch_compile: - return _IRType.torch_compile + if ENABLED_FEATURES.dynamo_frontend: + return 
_IRType.torch_compile + else: + raise ValueError( + "Requested using the Torch-TensorRT torch.compile backend but the Torch-TensorRT torch.compile backend is not available in this build of Torch-TensorRT" + ) else: if ir == "default": # Options are listed in order of preference - if DYNAMO_ENABLED and module_is_fxable: - logger.info("ir was set to default, using dynamo as ir") + if ENABLED_FEATURES.dynamo_frontend and module_is_fxable: + logger.info("ir was set to default, using dynamo frontend") return _IRType.dynamo - elif module_is_tsable: - if DYNAMO_ENABLED: + elif ENABLED_FEATURES.torchscript_frontend and module_is_tsable: + if ENABLED_FEATURES.dynamo_frontend: logger.warning( - "Input graph is a Torchscript module but the ir provided is default (dynamo). Please set ir=torchscript to suppress the warning. Compiling the module with ir=torchscript" + "Input is a torchscript module but the ir was not specified (default=dynamo), please set ir=torchscript to suppress the warning." ) return _IRType.ts - elif module_is_exportable: + elif ENABLED_FEATURES.dynamo_frontend and module_is_exportable: + logger.info("ir was set to default, using dynamo frontend") + return _IRType.dynamo + else: raise ValueError( - "Input graph is an ExportedProgram which is not currently supported. Please provide torch.nn.Module or torch.fx.GraphModule as input." + f"Module was provided in an unsupported format\nInstalled frontends:\n\tDynamo - {ENABLED_FEATURES.dynamo_frontend}\n\tTorchScript - {ENABLED_FEATURES.torchscript_frontend}\n\tFX - {ENABLED_FEATURES.fx_frontend})" ) - else: - raise ValueError("Module was provided in an unsupported format") - elif ir == "exported_program": - raise ValueError( - "ir=exported_program is not currently supported. Supported ir options : ts|fx|dynamo" - ) else: raise ValueError("Unknown ir was requested") @@ -168,12 +191,14 @@ def compile( torch.nn.Module: Compiled Module, when run it will execute via TensorRT """ input_list = inputs if inputs is not None else [] - enabled_precisions_set = ( - enabled_precisions if enabled_precisions is not None else {torch.float} + enabled_precisions_set: Set[dtype | torch.dtype] = ( + enabled_precisions + if enabled_precisions is not None + else _defaults.ENABLED_PRECISIONS ) module_type = _parse_module_type(module) - target_ir = _get_target_ir(module_type, ir) + target_ir = _get_target_fe(module_type, ir) if target_ir == _IRType.ts: ts_mod = module if module_type == _ModuleType.nn: @@ -224,7 +249,7 @@ def compile( # Export the module torchtrt_inputs = prepare_inputs(input_list) - exp_program = torch_tensorrt.dynamo.trace(module, torchtrt_inputs, **kwargs) + exp_program = dynamo_trace(module, torchtrt_inputs, **kwargs) trt_graph_module = dynamo_compile( exp_program, inputs=torchtrt_inputs, @@ -299,7 +324,7 @@ def convert_method_to_trt_engine( ) module_type = _parse_module_type(module) - target_ir = _get_target_ir(module_type, ir) + target_ir = _get_target_fe(module_type, ir) if target_ir == _IRType.ts: ts_mod = module if module_type == _ModuleType.nn: @@ -307,19 +332,20 @@ def convert_method_to_trt_engine( "Module was provided as a torch.nn.Module, trying to script the module with torch.jit.script. 
In the event of a failure please preconvert your module to TorchScript" ) ts_mod = torch.jit.script(module) - return torch_tensorrt.ts.convert_method_to_trt_engine( # type: ignore[no-any-return] + serialized_engine: bytes = ts_convert_method_to_trt_engine( ts_mod, inputs=inputs, method_name=method_name, enabled_precisions=enabled_precisions_set, **kwargs, ) + return serialized_engine elif target_ir == _IRType.fx: raise RuntimeError( "convert_method_to_trt_engine call is not supported for ir=fx" ) elif target_ir == _IRType.dynamo: - return torch_tensorrt.dynamo.convert_module_to_trt_engine( # type: ignore[no-any-return] + return dynamo_convert_module_to_trt_engine( # type: ignore[no-any-return] module, inputs=inputs, method_name=method_name, diff --git a/py/torch_tensorrt/_enums.py b/py/torch_tensorrt/_enums.py index 44cb772dc3..350d8a299e 100644 --- a/py/torch_tensorrt/_enums.py +++ b/py/torch_tensorrt/_enums.py @@ -1,3 +1,729 @@ -from torch_tensorrt._C import EngineCapability, TensorFormat, dtype # noqa: F401 +from __future__ import annotations -from tensorrt import DeviceType # noqa: F401 +import logging +from enum import Enum, auto +from typing import Any, Optional, Type, Union + +import numpy as np +import torch +from torch_tensorrt._features import ENABLED_FEATURES + +import tensorrt as trt + + +class dtype(Enum): + """Enum to set supported dtypes in the compiler""" + + # Supported types in Torch-TensorRT + unknown = auto() + u8 = auto() + i8 = auto() + i32 = auto() + i64 = auto() + f16 = auto() + f32 = auto() + f64 = auto() + b = auto() + # TODO: Enable FP8 and BF16 + # f8 = auto() + # bf16 = auto() + + uint8 = u8 + int8 = i8 + + int32 = i32 + + long = i64 + int64 = i64 + + half = f16 + fp16 = f16 + float16 = f16 + + float = f32 + fp32 = f32 + float32 = f32 + + double = f64 + fp64 = f64 + float64 = f64 + + # TODO: Enable when FP8 is enabled + # float8 = f8 + # fp8 = f8 + + # TODO: Enable when BF16 is enabled + # bfloat16 = bf16 + + @staticmethod + def _is_np_obj(t: Any) -> bool: + if isinstance(t, np.dtype): + return True + elif isinstance(t, type): + if issubclass(t, np.generic): + return True + return False + + @classmethod + def _from( + cls, + t: Union[torch.dtype, trt.DataType, np.dtype, dtype, type], + use_default: bool = False, + ) -> dtype: + # TODO: Ideally implemented with match statement but need to wait for Py39 EoL + if isinstance(t, torch.dtype): + if t == torch.uint8: + return dtype.u8 + elif t == torch.int8: + return dtype.i8 + elif t == torch.long: + return dtype.i64 + elif t == torch.int32: + return dtype.i32 + elif t == torch.half: + return dtype.f16 + elif t == torch.float: + return dtype.f32 + elif t == torch.float64: + return dtype.f64 + elif t == torch.bool: + return dtype.b + elif use_default: + logging.warning( + f"Given dtype that does not have direct mapping to Torch-TensorRT supported types ({t}), defaulting to torch_tensorrt.dtype.float" + ) + return dtype.float + else: + raise TypeError( + f"Provided an unsupported data type as an input data type (support: bool, int32, long, half, float), got: {t}" + ) + elif isinstance(t, trt.DataType): + if t == trt.uint8: + return dtype.u8 + elif t == trt.int8: + return dtype.i8 + elif t == trt.int32: + return dtype.i32 + elif t == trt.float16: + return dtype.f16 + elif t == trt.float32: + return dtype.f32 + elif trt.__version__ >= "7.0" and t == trt.bool: + return dtype.b + else: + raise TypeError( + f"Provided an unsupported data type as an input data type (support: bool, int32, half, float), got: {t}" + ) + + elif 
dtype._is_np_obj(t): + if t == np.uint8: + return dtype.u8 + elif t == np.int8: + return dtype.i8 + elif t == np.int32: + return dtype.i32 + elif t == np.int64: + return dtype.i64 + elif t == np.float16: + return dtype.f16 + elif t == np.float32: + return dtype.f32 + elif t == np.float64: + return dtype.f64 + elif t == np.bool: + return dtype.b + elif use_default: + logging.warning( + f"Given dtype that does not have direct mapping to Torch-TensorRT supported types ({t}), defaulting to torch_tensorrt.dtype.float" + ) + return dtype.float + else: + raise TypeError( + "Provided an unsupported data type as an input data type (support: bool, int32, long, half, float), got: " + + str(t) + ) + + elif isinstance(t, dtype): + return t + + elif ENABLED_FEATURES.torchscript_frontend: + from torch_tensorrt import _C + + if isinstance(t, _C.dtype): + if t == _C.dtype.long: + return dtype.i64 + elif t == _C.dtype.int32: + return dtype.i32 + elif t == _C.dtype.int8: + return dtype.i8 + elif t == _C.dtype.half: + return dtype.f16 + elif t == _C.dtype.float: + return dtype.f32 + elif t == _C.dtype.double: + return dtype.f64 + elif t == _C.dtype.bool: + return dtype.b + elif t == _C.dtype.unknown: + return dtype.unknown + else: + raise TypeError( + f"Provided an unsupported data type as an input data type (support: bool, int32, long, half, float), got: {t}" + ) + # else: # commented out for mypy + raise TypeError( + f"Provided unsupported source type for dtype conversion (got: {t})" + ) + + @classmethod + def try_from( + cls, + t: Union[torch.dtype, trt.DataType, np.dtype, dtype], + use_default: bool = False, + ) -> Optional[dtype]: + try: + casted_format = dtype._from(t, use_default=use_default) + return casted_format + except (ValueError, TypeError) as e: + logging.debug( + f"Conversion from {t} to torch_tensorrt.dtype failed", exc_info=True + ) + return None + + def to( + self, + t: Union[Type[torch.dtype], Type[trt.DataType], Type[np.dtype], Type[dtype]], + use_default: bool = False, + ) -> Union[torch.dtype, trt.DataType, np.dtype, dtype]: + # TODO: Ideally implemented with match statement but need to wait for Py39 EoL + if t == torch.dtype: + if self == dtype.u8: + return torch.uint8 + elif self == dtype.i8: + return torch.int8 + elif self == dtype.i32: + return torch.int + elif self == dtype.i64: + return torch.long + elif self == dtype.f16: + return torch.half + elif self == dtype.f32: + return torch.float + elif self == dtype.f64: + return torch.double + elif self == dtype.b: + return torch.bool + elif use_default: + logging.warning( + f"Given dtype that does not have direct mapping to torch ({self}), defaulting to torch.float" + ) + return torch.float + else: + raise TypeError(f"Unsupported torch dtype (had: {self})") + + elif t == trt.DataType: + if self == dtype.u8: + return trt.DataType.UINT8 + if self == dtype.i8: + return trt.DataType.INT8 + elif self == dtype.i32: + return trt.DataType.INT32 + elif self == dtype.f16: + return trt.DataType.HALF + elif self == dtype.f32: + return trt.DataType.FLOAT + elif self == dtype.b: + return trt.DataType.BOOL + elif use_default: + return trt.DataType.FLOAT + else: + raise TypeError("Unsupported tensorrt dtype") + + elif t == np.dtype: + if self == dtype.u8: + return np.uint8 + elif self == dtype.i8: + return np.int8 + elif self == dtype.i32: + return np.int32 + elif self == dtype.i64: + return np.int64 + elif self == dtype.f16: + return np.float16 + elif self == dtype.f32: + return np.float32 + elif self == dtype.f64: + return np.float64 + elif self == 
dtype.b: + return np.bool_ + elif use_default: + return np.float32 + else: + raise TypeError("Unspported numpy dtype") + + elif t == dtype: + return self + + elif ENABLED_FEATURES.torchscript_frontend: + from torch_tensorrt import _C + + if t == _C.dtype: + if self == dtype.i64: + return _C.dtype.long + elif self == dtype.i8: + return _C.dtype.int8 + elif self == dtype.i32: + return _C.dtype.int32 + elif self == dtype.f16: + return _C.dtype.half + elif self == dtype.f32: + return _C.dtype.float + elif self == dtype.f64: + return _C.dtype.double + elif self == dtype.b: + return _C.dtype.bool + elif self == dtype.unknown: + return _C.dtype.unknown + else: + raise TypeError( + f"Provided an unsupported data type as an input data type (support: bool, int32, long, half, float), got: {self}" + ) + # else: # commented out for mypy + raise TypeError( + f"Provided unsupported destination type for dtype conversion {t}" + ) + + def try_to( + self, + t: Union[Type[torch.dtype], Type[trt.DataType], Type[np.dtype], Type[dtype]], + use_default: bool, + ) -> Optional[Union[torch.dtype, trt.DataType, np.dtype, dtype]]: + try: + casted_format = self.to(t, use_default) + return casted_format + except (ValueError, TypeError) as e: + logging.debug( + f"torch_tensorrt.dtype conversion to target type {t} failed", + exc_info=True, + ) + return None + + def __eq__(self, other: Union[torch.dtype, trt.DataType, np.dtype, dtype]) -> bool: + other_ = dtype._from(other) + return bool(self.value == other_.value) + + def __hash__(self) -> int: + return hash(self.value) + + # Putting aliases here that mess with mypy + bool = b + int = i32 + + +class memory_format(Enum): + + # TensorRT supported memory layouts + linear = auto() + chw2 = auto() + hwc8 = auto() + chw4 = auto() + chw16 = auto() + chw32 = auto() + dhwc8 = auto() + cdhw32 = auto() + hwc = auto() + dla_linear = auto() + dla_hwc4 = auto() + hwc16 = auto() + dhwc = auto() + + # PyTorch aliases for TRT layouts + contiguous = linear + channels_last = hwc + channels_last_3d = dhwc + + @classmethod + def _from( + cls, f: Union[torch.memory_format, trt.TensorFormat, memory_format] + ) -> memory_format: + # TODO: Ideally implemented with match statement but need to wait for Py39 EoL + if isinstance(f, torch.memory_format): + if f == torch.contiguous_format: + return memory_format.contiguous + elif f == torch.channels_last: + return memory_format.channels_last + elif f == torch.channels_last_3d: + return memory_format.channels_last_3d + else: + raise TypeError( + f"Provided an unsupported memory format for tensor, got: {dtype}" + ) + + elif isinstance(f, trt.DataType): + if f == trt.TensorFormat.LINEAR: + return memory_format.linear + elif f == trt.TensorFormat.CHW2: + return memory_format.chw2 + elif f == trt.TensorFormat.HWC8: + return memory_format.hwc8 + elif f == trt.TensorFormat.CHW4: + return memory_format.chw4 + elif f == trt.TensorFormat.CHW16: + return memory_format.chw16 + elif f == trt.TensorFormat.CHW32: + return memory_format.chw32 + elif f == trt.TensorFormat.DHWC8: + return memory_format.dhwc8 + elif f == trt.TensorFormat.CDHW32: + return memory_format.cdhw32 + elif f == trt.TensorFormat.HWC: + return memory_format.hwc + elif f == trt.TensorFormat.DLA_LINEAR: + return memory_format.dla_linear + elif f == trt.TensorFormat.DLA_HWC4: + return memory_format.dla_hwc4 + elif f == trt.TensorFormat.HWC16: + return memory_format.hwc16 + elif f == trt.TensorFormat.DHWC: + return memory_format.dhwc + else: + raise TypeError( + f"Provided an unsupported tensor format 
for tensor, got: {dtype}" + ) + + elif isinstance(f, memory_format): + return f + + elif ENABLED_FEATURES.torchscript_frontend: + from torch_tensorrt import _C + + if isinstance(f, _C.TensorFormat): + if f == _C.TensorFormat.contiguous: + return memory_format.contiguous + elif f == _C.TensorFormat.channels_last: + return memory_format.channels_last + else: + raise ValueError( + "Provided an unsupported tensor format (support: NCHW/contiguous_format, NHWC/channel_last)" + ) + # else: # commented out for mypy + raise TypeError("Provided unsupported source type for memory_format conversion") + + @classmethod + def try_from( + cls, f: Union[torch.memory_format, trt.TensorFormat, memory_format] + ) -> Optional[memory_format]: + try: + casted_format = memory_format._from(f) + return casted_format + except (ValueError, TypeError) as e: + logging.debug( + f"Conversion from {f} to torch_tensorrt.memory_format failed", + exc_info=True, + ) + return None + + def to( + self, + t: Union[ + Type[torch.memory_format], Type[trt.TensorFormat], Type[memory_format] + ], + ) -> Union[torch.memory_format, trt.TensorFormat, memory_format]: + if t == torch.memory_format: + if self == memory_format.contiguous: + return torch.contiguous_format + elif self == memory_format.channels_last: + return torch.channels_last + elif self == memory_format.channels_last_3d: + return torch.channels_last_3d + else: + raise TypeError("Unsupported torch dtype") + + elif t == trt.TensorFormat: + if self == memory_format.linear: + return trt.TensorFormat.LINEAR + elif self == memory_format.chw2: + return trt.TensorFormat.CHW2 + elif self == memory_format.hwc8: + return trt.TensorFormat.HWC8 + elif self == memory_format.chw4: + return trt.TensorFormat.CHW4 + elif self == memory_format.chw16: + return trt.TensorFormat.CHW16 + elif self == memory_format.chw32: + return trt.TensorFormat.CHW32 + elif self == memory_format.dhwc8: + return trt.TensorFormat.DHWC8 + elif self == memory_format.cdhw32: + return trt.TensorFormat.CDHW32 + elif self == memory_format.hwc: + return trt.TensorFormat.HWC + elif self == memory_format.dla_linear: + return trt.TensorFormat.DLA_LINEAR + elif self == memory_format.dla_hwc4: + return trt.TensorFormat.DLA_HWC4 + elif self == memory_format.hwc16: + return trt.TensorFormat.HWC16 + elif self == memory_format.dhwc: + return trt.TensorFormat.DHWC + else: + raise TypeError("Unsupported tensorrt memory format") + + elif t == memory_format: + return self + + elif ENABLED_FEATURES.torchscript_frontend: + from torch_tensorrt import _C + + if t == _C.TensorFormat: + if self == memory_format.contiguous: + return _C.TensorFormat.contiguous + elif self == memory_format.channels_last: + return _C.TensorFormat.channels_last + else: + raise ValueError( + "Provided an unsupported tensor format (support: NCHW/contiguous_format, NHWC/channel_last)" + ) + # else: # commented out for mypy + raise TypeError( + "Provided unsupported destination type for memory format conversion" + ) + + def try_to( + self, + t: Union[ + Type[torch.memory_format], Type[trt.TensorFormat], Type[memory_format] + ], + ) -> Optional[Union[torch.memory_format, trt.TensorFormat, memory_format]]: + try: + casted_format = self.to(t) + return casted_format + except (ValueError, TypeError) as e: + logging.debug( + f"torch_tensorrt.memory_format conversion to target type {t} failed", + exc_info=True, + ) + return None + + def __eq__( + self, other: Union[torch.memory_format, trt.TensorFormat, memory_format] + ) -> bool: + other_ = memory_format._from(other) + 
return self.value == other_.value + + def __hash__(self) -> int: + return hash(self.value) + + +class DeviceType(Enum): + UNKNOWN = auto() + GPU = auto() + DLA = auto() + + @classmethod + def _from(cls, d: Union[trt.DeviceType, DeviceType]) -> DeviceType: + if isinstance(d, trt.DeviceType): + if d == trt.DeviceType.GPU: + return DeviceType.GPU + elif d == trt.DeviceType.DLA: + return DeviceType.DLA + else: + raise ValueError( + "Provided an unsupported device type (support: GPU/DLA)" + ) + + elif isinstance(d, DeviceType): + return d + + elif ENABLED_FEATURES.torchscript_frontend: + from torch_tensorrt import _C + + if isinstance(d, _C.DeviceType): + if d == _C.DeviceType.GPU: + return DeviceType.GPU + elif d == _C.DeviceType.DLA: + return DeviceType.DLA + else: + raise ValueError( + "Provided an unsupported device type (support: GPU/DLA)" + ) + # else: # commented out for mypy + raise TypeError("Provided unsupported source type for DeviceType conversion") + + @classmethod + def try_from(cls, d: Union[trt.DeviceType, DeviceType]) -> Optional[DeviceType]: + try: + casted_format = DeviceType._from(d) + return casted_format + except (ValueError, TypeError) as e: + logging.debug( + f"Conversion from {d} to torch_tensorrt.DeviceType failed", + exc_info=True, + ) + return None + + def to( + self, + t: Union[Type[trt.DeviceType], Type[DeviceType]], + use_default: bool = False, + ) -> Union[trt.DeviceType, DeviceType]: + if t == trt.DeviceType: + if self == DeviceType.GPU: + return trt.DeviceType.GPU + elif self == DeviceType.DLA: + return trt.DeviceType.DLA + elif use_default: + return trt.DeviceType.GPU + else: + raise ValueError( + "Provided an unsupported device type (support: GPU/DLA)" + ) + + elif t == DeviceType: + return self + + elif ENABLED_FEATURES.torchscript_frontend: + from torch_tensorrt import _C + + if t == _C.DeviceType: + if self == DeviceType.GPU: + return _C.DeviceType.GPU + elif self == DeviceType.DLA: + return _C.DeviceType.DLA + else: + raise ValueError( + "Provided an unsupported device type (support: GPU/DLA)" + ) + # else: # commented out for mypy + raise TypeError( + "Provided unsupported destination type for device type conversion" + ) + + def try_to( + self, + t: Union[Type[trt.DeviceType], Type[DeviceType]], + use_default: bool = False, + ) -> Optional[Union[trt.DeviceType, DeviceType]]: + try: + casted_format = self.to(t, use_default=use_default) + return casted_format + except (ValueError, TypeError) as e: + logging.debug( + f"torch_tensorrt.DeviceType conversion to target type {t} failed", + exc_info=True, + ) + return None + + def __eq__(self, other: Union[trt.DeviceType, DeviceType]) -> bool: + other_ = DeviceType._from(other) + return bool(self.value == other_.value) + + def __hash__(self) -> int: + return hash(self.value) + + +class EngineCapability(Enum): + STANDARD = auto() + SAFETY = auto() + DLA_STANDALONE = auto() + + @classmethod + def _from( + cls, c: Union[trt.EngineCapability, EngineCapability] + ) -> EngineCapability: + if isinstance(c, trt.EngineCapability): + if c == trt.EngineCapability.STANDARD: + return EngineCapability.STANDARD + elif c == trt.EngineCapability.SAFETY: + return EngineCapability.SAFETY + elif c == trt.EngineCapability.DLA_STANDALONE: + return EngineCapability.DLA_STANDALONE + else: + raise ValueError("Provided an unsupported engine capability") + + elif isinstance(c, EngineCapability): + return c + + elif ENABLED_FEATURES.torchscript_frontend: + from torch_tensorrt import _C + + if isinstance(c, _C.EngineCapability): + if c == 
_C.EngineCapability.STANDARD: + return EngineCapability.STANDARD + elif c == _C.EngineCapability.SAFETY: + return EngineCapability.SAFETY + elif c == _C.EngineCapability.DLA_STANDALONE: + return EngineCapability.DLA_STANDALONE + else: + raise ValueError("Provided an unsupported engine capability") + # else: # commented out for mypy + raise TypeError( + "Provided unsupported source type for EngineCapability conversion" + ) + + @classmethod + def try_from( + c: Union[trt.EngineCapability, EngineCapability] + ) -> Optional[EngineCapability]: + try: + casted_format = EngineCapability._from(c) + return casted_format + except (ValueError, TypeError) as e: + logging.debug( + f"Conversion from {c} to torch_tensorrt.EngineCapablity failed", + exc_info=True, + ) + return None + + def to( + self, t: Union[Type[trt.EngineCapability], Type[EngineCapability]] + ) -> Union[trt.EngineCapability, EngineCapability]: + if t == trt.EngineCapability: + if self == EngineCapability.STANDARD: + return trt.EngineCapability.STANDARD + elif self == EngineCapability.SAFETY: + return trt.EngineCapability.SAFETY + elif self == EngineCapability.DLA_STANDALONE: + return trt.EngineCapability.DLA_STANDALONE + else: + raise ValueError("Provided an unsupported engine capability") + + elif t == EngineCapability: + return self + + elif ENABLED_FEATURES.torchscript_frontend: + from torch_tensorrt import _C + + if t == _C.EngineCapability: + if self == EngineCapability.STANDARD: + return _C.EngineCapability.STANDARD + elif self == EngineCapability.SAFETY: + return _C.EngineCapability.SAFETY + elif self == EngineCapability.DLA_STANDALONE: + return _C.EngineCapability.DLA_STANDALONE + else: + raise ValueError("Provided an unsupported engine capability") + # else: # commented out for mypy + raise TypeError( + "Provided unsupported destination type for engine capablity type conversion" + ) + + def try_to( + self, t: Union[Type[trt.EngineCapability], Type[EngineCapability]] + ) -> Optional[Union[trt.EngineCapability, EngineCapability]]: + try: + casted_format = self.to(t) + return casted_format + except (ValueError, TypeError) as e: + logging.debug( + f"torch_tensorrt.EngineCapablity conversion to target type {t} failed", + exc_info=True, + ) + return None + + def __eq__(self, other: Union[trt.EngineCapability, EngineCapability]) -> bool: + other_ = EngineCapability._from(other) + return bool(self.value == other_.value) + + def __hash__(self) -> int: + return hash(self.value) diff --git a/py/torch_tensorrt/_features.py b/py/torch_tensorrt/_features.py new file mode 100644 index 0000000000..dde99cbaf6 --- /dev/null +++ b/py/torch_tensorrt/_features.py @@ -0,0 +1,35 @@ +import os +from collections import namedtuple + +from torch_tensorrt._utils import sanitized_torch_version + +from packaging import version + +FeatureSet = namedtuple( + "FeatureSet", + [ + "torchscript_frontend", + "torch_tensorrt_runtime", + "dynamo_frontend", + "fx_frontend", + ], +) + +_TS_FE_AVAIL = os.path.isfile(os.path.dirname(__file__) + "/lib/libtorchtrt.so") +_TORCHTRT_RT_AVAIL = _TS_FE_AVAIL or os.path.isfile( + os.path.dirname(__file__) + "/lib/libtorchtrt_runtime.so" +) +_DYNAMO_FE_AVAIL = version.parse(sanitized_torch_version()) >= version.parse("2.1.dev") +_FX_FE_AVAIL = True + +ENABLED_FEATURES = FeatureSet( + _TS_FE_AVAIL, _TORCHTRT_RT_AVAIL, _DYNAMO_FE_AVAIL, _FX_FE_AVAIL +) + + +def _enabled_features_str() -> str: + enabled = lambda x: "ENABLED" if x else "DISABLED" + out_str: str = ( + f"Enabled Features:\n - Dynamo Frontend: 
{enabled(_DYNAMO_FE_AVAIL)}\n - Torch-TensorRT Runtime: {enabled(_TORCHTRT_RT_AVAIL)}\n - FX Frontend: {enabled(_FX_FE_AVAIL)}\n - TorchScript Frontend: {enabled(_TS_FE_AVAIL)}\n" # type: ignore[no-untyped-call] + ) + return out_str diff --git a/py/torch_tensorrt/_utils.py b/py/torch_tensorrt/_utils.py index b21696427b..3d5f98b5e5 100644 --- a/py/torch_tensorrt/_utils.py +++ b/py/torch_tensorrt/_utils.py @@ -1,36 +1,6 @@ from typing import Any import torch -from torch_tensorrt import _C -from torch_tensorrt._version import __version__ - - -def dump_build_info() -> None: - """Prints build information about the torch_tensorrt distribution to stdout""" - print(get_build_info()) - - -def get_build_info() -> str: - """Returns a string containing the build information of torch_tensorrt distribution - - Returns: - str: String containing the build information for torch_tensorrt distribution - """ - core_build_info = _C.get_build_info() - build_info = str( - "Torch-TensorRT Version: " - + str(__version__) - + "\n" - + "Using PyTorch Version: " - + str(torch.__version__) - + "\n" - + core_build_info - ) - return build_info - - -def set_device(gpu_id: int) -> None: - _C.set_device(gpu_id) def sanitized_torch_version() -> Any: diff --git a/py/torch_tensorrt/csrc/tensorrt_classes.cpp b/py/torch_tensorrt/csrc/tensorrt_classes.cpp index 4794a679eb..bd3aa6b305 100644 --- a/py/torch_tensorrt/csrc/tensorrt_classes.cpp +++ b/py/torch_tensorrt/csrc/tensorrt_classes.cpp @@ -235,23 +235,23 @@ std::string Device::to_str() { std::string to_str(EngineCapability value) { switch (value) { - case EngineCapability::kSAFE_GPU: - return "Safe GPU"; - case EngineCapability::kSAFE_DLA: - return "Safe DLA"; - case EngineCapability::kDEFAULT: + case EngineCapability::kDLA_STANDALONE: + return "DLA Standalone"; + case EngineCapability::kSAFETY: + return "Safety"; + case EngineCapability::kSTANDARD: default: - return "Default"; + return "Standard"; } } nvinfer1::EngineCapability toTRTEngineCapability(EngineCapability value) { switch (value) { - case EngineCapability::kSAFE_DLA: + case EngineCapability::kDLA_STANDALONE: return TRT_ENGINE_CAPABILITY_DLA_STANDALONE; - case EngineCapability::kSAFE_GPU: + case EngineCapability::kSAFETY: return TRT_ENGINE_CAPABILITY_SAFETY; - case EngineCapability::kDEFAULT: + case EngineCapability::kSTANDARD: default: return TRT_ENGINE_CAPABILITY_STANDARD; } diff --git a/py/torch_tensorrt/csrc/tensorrt_classes.h b/py/torch_tensorrt/csrc/tensorrt_classes.h index 9bdd00b7e0..89c5c8661e 100644 --- a/py/torch_tensorrt/csrc/tensorrt_classes.h +++ b/py/torch_tensorrt/csrc/tensorrt_classes.h @@ -114,9 +114,9 @@ struct TorchFallback : torch::CustomClassHolder { }; enum class EngineCapability : int8_t { - kDEFAULT, - kSAFE_GPU, - kSAFE_DLA, + kSTANDARD, + kSAFETY, + kDLA_STANDALONE, }; std::string to_str(EngineCapability value); @@ -160,7 +160,7 @@ struct CompileSpec : torch::CustomClassHolder { ADD_FIELD_GET_SET(sparse_weights, bool); ADD_FIELD_GET_SET(refit, bool); ADD_FIELD_GET_SET(debug, bool); - ADD_ENUM_GET_SET(capability, EngineCapability, static_cast(EngineCapability::kSAFE_DLA)); + ADD_ENUM_GET_SET(capability, EngineCapability, static_cast(EngineCapability::kSTANDARD)); ADD_FIELD_GET_SET(num_avg_timing_iters, int64_t); ADD_FIELD_GET_SET(workspace_size, int64_t); ADD_FIELD_GET_SET(dla_sram_size, int64_t); @@ -184,7 +184,7 @@ struct CompileSpec : torch::CustomClassHolder { bool allow_shape_tensors = false; Device device; TorchFallback torch_fallback; - EngineCapability capability = 
EngineCapability::kDEFAULT; + EngineCapability capability = EngineCapability::kSTANDARD; int64_t num_avg_timing_iters = 1; int64_t workspace_size = 0; int64_t dla_sram_size = 1048576; diff --git a/py/torch_tensorrt/csrc/torch_tensorrt_py.cpp b/py/torch_tensorrt/csrc/torch_tensorrt_py.cpp index 33c7e27398..e4d88088e4 100644 --- a/py/torch_tensorrt/csrc/torch_tensorrt_py.cpp +++ b/py/torch_tensorrt/csrc/torch_tensorrt_py.cpp @@ -261,9 +261,9 @@ PYBIND11_MODULE(_C, m) { m, "EngineCapability", "Enum to specify engine capability settings (selections of kernels to meet safety requirements)") - .value("safe_gpu", EngineCapability::kSAFE_GPU, "Use safety GPU kernels only") - .value("safe_dla", EngineCapability::kSAFE_DLA, "Use safety DLA kernels only") - .value("default", EngineCapability::kDEFAULT, "Use default behavior"); + .value("SAFETY", EngineCapability::kSAFETY, "Use safe kernels only") + .value("DLA_STANDALONE", EngineCapability::kDLA_STANDALONE, "Use DLA kernels only") + .value("STANDARD", EngineCapability::kSTANDARD, "Use default behavior"); py::enum_(m, "TensorFormat", "Enum to specifiy the memory layout of tensors") .value("contiguous", TensorFormat::kContiguous, "Contiguous memory layout (NCHW / Linear)") diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 6312532f1c..ed9a0bb7ae 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -5,43 +5,12 @@ from typing import Any, Collection, List, Optional, Sequence, Set, Tuple, Union import torch -import torch_tensorrt from torch.export import ExportedProgram from torch.fx.node import Target -from torch_tensorrt import _enums from torch_tensorrt._Device import Device -from torch_tensorrt._enums import ( # TODO: Should probabably be the TRT EngineCapability Enum - EngineCapability, -) +from torch_tensorrt._enums import EngineCapability, dtype from torch_tensorrt._Input import Input -from torch_tensorrt.dynamo import partitioning -from torch_tensorrt.dynamo._defaults import ( - DEBUG, - DEVICE, - DISABLE_TF32, - DLA_GLOBAL_DRAM_SIZE, - DLA_LOCAL_DRAM_SIZE, - DLA_SRAM_SIZE, - DRYRUN, - ENABLE_EXPERIMENTAL_DECOMPOSITIONS, - ENGINE_CAPABILITY, - HARDWARE_COMPATIBLE, - MAX_AUX_STREAMS, - MIN_BLOCK_SIZE, - NUM_AVG_TIMING_ITERS, - OPTIMIZATION_LEVEL, - OUTPUT_FORMAT, - PASS_THROUGH_BUILD_FAILURES, - PRECISION, - REFIT, - REQUIRE_FULL_COMPILATION, - SPARSE_WEIGHTS, - TRUNCATE_LONG_AND_DOUBLE, - USE_FAST_PARTITIONER, - USE_PYTHON_RUNTIME, - VERSION_COMPATIBLE, - WORKSPACE_SIZE, -) +from torch_tensorrt.dynamo import _defaults, partitioning from torch_tensorrt.dynamo._DryRunTracker import ( DryRunTracker, PerSubgraphData, @@ -76,33 +45,35 @@ def compile( exported_program: ExportedProgram, inputs: Tuple[Any, ...], *, - device: Optional[Union[Device, torch.device, str]] = DEVICE, - disable_tf32: bool = DISABLE_TF32, - sparse_weights: bool = SPARSE_WEIGHTS, - enabled_precisions: Set[torch.dtype] | Tuple[torch.dtype] = (torch.float32,), - engine_capability: EngineCapability = ENGINE_CAPABILITY, - refit: bool = REFIT, - debug: bool = DEBUG, - num_avg_timing_iters: int = NUM_AVG_TIMING_ITERS, - workspace_size: int = WORKSPACE_SIZE, - dla_sram_size: int = DLA_SRAM_SIZE, - dla_local_dram_size: int = DLA_LOCAL_DRAM_SIZE, - dla_global_dram_size: int = DLA_GLOBAL_DRAM_SIZE, - truncate_long_and_double: bool = TRUNCATE_LONG_AND_DOUBLE, - require_full_compilation: bool = REQUIRE_FULL_COMPILATION, - min_block_size: int = MIN_BLOCK_SIZE, + device: Optional[Union[Device, torch.device, 
str]] = _defaults.DEVICE, + disable_tf32: bool = _defaults.DISABLE_TF32, + sparse_weights: bool = _defaults.SPARSE_WEIGHTS, + enabled_precisions: ( + Set[torch.dtype | dtype] | Tuple[torch.dtype | dtype] + ) = _defaults.ENABLED_PRECISIONS, + engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY, + refit: bool = _defaults.REFIT, + debug: bool = _defaults.DEBUG, + num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS, + workspace_size: int = _defaults.WORKSPACE_SIZE, + dla_sram_size: int = _defaults.DLA_SRAM_SIZE, + dla_local_dram_size: int = _defaults.DLA_LOCAL_DRAM_SIZE, + dla_global_dram_size: int = _defaults.DLA_GLOBAL_DRAM_SIZE, + truncate_long_and_double: bool = _defaults.TRUNCATE_LONG_AND_DOUBLE, + require_full_compilation: bool = _defaults.REQUIRE_FULL_COMPILATION, + min_block_size: int = _defaults.MIN_BLOCK_SIZE, torch_executed_ops: Optional[Collection[Target]] = None, torch_executed_modules: Optional[List[str]] = None, - pass_through_build_failures: bool = PASS_THROUGH_BUILD_FAILURES, - max_aux_streams: Optional[int] = MAX_AUX_STREAMS, - version_compatible: bool = VERSION_COMPATIBLE, - optimization_level: Optional[int] = OPTIMIZATION_LEVEL, - use_python_runtime: bool = USE_PYTHON_RUNTIME, - use_fast_partitioner: bool = USE_FAST_PARTITIONER, - enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS, - dryrun: bool = DRYRUN, - hardware_compatible: bool = HARDWARE_COMPATIBLE, - output_format: str = OUTPUT_FORMAT, + pass_through_build_failures: bool = _defaults.PASS_THROUGH_BUILD_FAILURES, + max_aux_streams: Optional[int] = _defaults.MAX_AUX_STREAMS, + version_compatible: bool = _defaults.VERSION_COMPATIBLE, + optimization_level: Optional[int] = _defaults.OPTIMIZATION_LEVEL, + use_python_runtime: bool = _defaults.USE_PYTHON_RUNTIME, + use_fast_partitioner: bool = _defaults.USE_FAST_PARTITIONER, + enable_experimental_decompositions: bool = _defaults.ENABLE_EXPERIMENTAL_DECOMPOSITIONS, + dryrun: bool = _defaults.DRYRUN, + hardware_compatible: bool = _defaults.HARDWARE_COMPATIBLE, + output_format: str = _defaults.OUTPUT_FORMAT, **kwargs: Any, ) -> Union[ExportedProgram, torch.jit.ScriptModule, torch.fx.GraphModule]: """Compile a TorchScript module for NVIDIA GPUs using TensorRT @@ -170,6 +141,8 @@ def compile( if debug: set_log_level(logger.parent, logging.DEBUG) + engine_capability = EngineCapability._from(engine_capability) + if torch_executed_modules is not None and torch_executed_modules: logger.warning( f"Detected torch_executed_modules was non-empty: {torch_executed_modules}" @@ -182,6 +155,7 @@ def compile( # Prepare torch_trt inputs inputs = prepare_inputs(inputs) device = to_torch_tensorrt_device(device) + enabled_precisions = {dtype._from(p) for p in enabled_precisions} if not isinstance(exported_program, ExportedProgram): raise AssertionError( @@ -198,28 +172,10 @@ def compile( gm = apply_lowering_passes(gm, torch_inputs) logger.debug("Lowered Input graph: " + str(gm.graph)) - enabled_precisions = set(enabled_precisions) - - if ( - torch.float16 in enabled_precisions - or torch_tensorrt.dtype.half in enabled_precisions - ): - precision = torch.float16 - elif ( - torch.float32 in enabled_precisions - or torch_tensorrt.dtype.float in enabled_precisions - ): - precision = torch.float32 - elif len(enabled_precisions) == 0: - logger.info(f"No precision specified, defaulting to {PRECISION}") - precision = PRECISION - else: - raise ValueError( - f"Precision {enabled_precisions} not supported in the Dynamo Path" - ) - compilation_options = { - 
"precision": precision, + "enabled_precisions": ( + enabled_precisions if enabled_precisions else _defaults.ENABLED_PRECISIONS + ), "debug": debug, "device": device, "workspace_size": workspace_size, @@ -288,7 +244,7 @@ def compile_module( sample_inputs, "shape", lambda x: dict(x) if isinstance(x, dict) else tuple(x) ) dryrun_tracker.graph_input_dtypes = parse_complex_tensor_structs( - sample_inputs, "torch_dtype" + sample_inputs, "dtype", lambda t: t.to(torch.dtype, use_default=True) ) dryrun_tracker.compilation_settings = settings @@ -402,7 +358,7 @@ def compile_module( lambda x: dict(x) if isinstance(x, dict) else tuple(x), ) subgraph_data.subgraph_input_dtypes = parse_complex_tensor_structs( - submodule_inputs, "torch_dtype" + submodule_inputs, "dtype", lambda t: t.to(torch.dtype) ) submodule_outputs = submodule( @@ -463,29 +419,31 @@ def convert_module_to_trt_engine( module: torch.fx.GraphModule, method_name: str = "forward", inputs: Optional[Sequence[Input | torch.Tensor]] = None, - enabled_precisions: Optional[Set[torch.dtype | _enums.dtype]] = None, - debug: bool = DEBUG, - workspace_size: int = WORKSPACE_SIZE, - min_block_size: int = MIN_BLOCK_SIZE, - torch_executed_ops: Set[str] = set(), - pass_through_build_failures: bool = PASS_THROUGH_BUILD_FAILURES, - max_aux_streams: Optional[int] = MAX_AUX_STREAMS, - version_compatible: bool = VERSION_COMPATIBLE, - optimization_level: Optional[int] = OPTIMIZATION_LEVEL, - use_python_runtime: Optional[bool] = USE_PYTHON_RUNTIME, - truncate_long_and_double: bool = TRUNCATE_LONG_AND_DOUBLE, - use_fast_partitioner: bool = USE_FAST_PARTITIONER, - enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS, + enabled_precisions: ( + Set[torch.dtype | dtype] | Tuple[torch.dtype | dtype] + ) = _defaults.ENABLED_PRECISIONS, + debug: bool = _defaults.DEBUG, + workspace_size: int = _defaults.WORKSPACE_SIZE, + min_block_size: int = _defaults.MIN_BLOCK_SIZE, + torch_executed_ops: Optional[Set[str]] = None, + pass_through_build_failures: bool = _defaults.PASS_THROUGH_BUILD_FAILURES, + max_aux_streams: Optional[int] = _defaults.MAX_AUX_STREAMS, + version_compatible: bool = _defaults.VERSION_COMPATIBLE, + optimization_level: Optional[int] = _defaults.OPTIMIZATION_LEVEL, + use_python_runtime: Optional[bool] = _defaults.USE_PYTHON_RUNTIME, + truncate_long_and_double: bool = _defaults.TRUNCATE_LONG_AND_DOUBLE, + use_fast_partitioner: bool = _defaults.USE_FAST_PARTITIONER, + enable_experimental_decompositions: bool = _defaults.ENABLE_EXPERIMENTAL_DECOMPOSITIONS, device: Device = Device._current_device(), - require_full_compilation: bool = REQUIRE_FULL_COMPILATION, - disable_tf32: bool = DISABLE_TF32, - sparse_weights: bool = SPARSE_WEIGHTS, - refit: bool = REFIT, - engine_capability: EngineCapability = ENGINE_CAPABILITY, - num_avg_timing_iters: int = NUM_AVG_TIMING_ITERS, - dla_sram_size: int = DLA_SRAM_SIZE, - dla_local_dram_size: int = DLA_LOCAL_DRAM_SIZE, - dla_global_dram_size: int = DLA_GLOBAL_DRAM_SIZE, + require_full_compilation: bool = _defaults.REQUIRE_FULL_COMPILATION, + disable_tf32: bool = _defaults.DISABLE_TF32, + sparse_weights: bool = _defaults.SPARSE_WEIGHTS, + refit: bool = _defaults.REFIT, + engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY, + num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS, + dla_sram_size: int = _defaults.DLA_SRAM_SIZE, + dla_local_dram_size: int = _defaults.DLA_LOCAL_DRAM_SIZE, + dla_global_dram_size: int = _defaults.DLA_GLOBAL_DRAM_SIZE, calibrator: object = None, 
allow_shape_tensors: bool = False, ) -> bytes: @@ -569,34 +527,15 @@ def convert_module_to_trt_engine( set_log_level(logger.parent, logging.DEBUG) input_list = list(inputs) if inputs is not None else [] + torch_executed_ops = torch_executed_ops if torch_executed_ops is not None else set() # Prepare torch_trt inputs input_list = prepare_inputs(input_list) device = to_torch_tensorrt_device(device) - enabled_precisions = ( - enabled_precisions if enabled_precisions is not None else {torch.float} - ) - - if ( - torch.float16 in enabled_precisions - or torch_tensorrt.dtype.half in enabled_precisions - ): - precision = torch.float16 - elif ( - torch.float32 in enabled_precisions - or torch_tensorrt.dtype.float in enabled_precisions - ): - precision = torch.float32 - elif len(enabled_precisions) == 0: - logger.info(f"No precision specified, defaulting to {PRECISION}") - precision = PRECISION - else: - raise ValueError( - f"Precision {enabled_precisions} not supported in the Dynamo Path" - ) + enabled_precisions = {dtype._from(e) for e in enabled_precisions} compilation_options = { - "precision": precision, + "enabled_precisions": enabled_precisions, "debug": debug, "workspace_size": workspace_size, "min_block_size": min_block_size, diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py index ec038c0dba..c43cc78d76 100644 --- a/py/torch_tensorrt/dynamo/_defaults.py +++ b/py/torch_tensorrt/dynamo/_defaults.py @@ -1,8 +1,8 @@ import torch -from tensorrt import EngineCapability from torch_tensorrt._Device import Device +from torch_tensorrt._enums import EngineCapability, dtype -PRECISION = torch.float32 +ENABLED_PRECISIONS = {dtype.f32} DEBUG = False DEVICE = None DISABLE_TF32 = False @@ -27,6 +27,7 @@ DRYRUN = False HARDWARE_COMPATIBLE = False OUTPUT_FORMAT = "exported_program" +SUPPORTED_KERNEL_PRECISIONS = {dtype.f32, dtype.f16, dtype.i8} def default_device() -> Device: diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py index c00b049f45..c35156cdb4 100644 --- a/py/torch_tensorrt/dynamo/_settings.py +++ b/py/torch_tensorrt/dynamo/_settings.py @@ -1,10 +1,9 @@ from dataclasses import dataclass, field from typing import Collection, Optional, Union -import torch -from tensorrt import EngineCapability from torch.fx.node import Target from torch_tensorrt._Device import Device +from torch_tensorrt._enums import EngineCapability, dtype from torch_tensorrt.dynamo._defaults import ( DEBUG, DISABLE_TF32, @@ -13,6 +12,7 @@ DLA_SRAM_SIZE, DRYRUN, ENABLE_EXPERIMENTAL_DECOMPOSITIONS, + ENABLED_PRECISIONS, ENGINE_CAPABILITY, HARDWARE_COMPATIBLE, MAX_AUX_STREAMS, @@ -21,7 +21,6 @@ OPTIMIZATION_LEVEL, OUTPUT_FORMAT, PASS_THROUGH_BUILD_FAILURES, - PRECISION, REFIT, REQUIRE_FULL_COMPILATION, SPARSE_WEIGHTS, @@ -39,7 +38,7 @@ class CompilationSettings: """Compilation settings for Torch-TensorRT Dynamo Paths Args: - precision (torch.dtype): Model Layer precision + enabled_precisions (Set[dtype]): Available kernel dtype precisions debug (bool): Whether to print out verbose debugging information workspace_size (int): Workspace TRT is allowed to use for the module (0 is default) min_block_size (int): Minimum number of operators per TRT-Engine Block @@ -74,7 +73,7 @@ class CompilationSettings: output_format (str): Output format of the result of TRT compilation. Options include "exported_program" (or) "ep" | "torchscript" (or) "ts" | "graph_module" (or) "fx". 
Default is "exported_program" """ - precision: torch.dtype = PRECISION + enabled_precisions: dtype = field(default_factory=lambda: ENABLED_PRECISIONS) debug: bool = DEBUG workspace_size: int = WORKSPACE_SIZE min_block_size: int = MIN_BLOCK_SIZE @@ -92,7 +91,9 @@ class CompilationSettings: disable_tf32: bool = DISABLE_TF32 sparse_weights: bool = SPARSE_WEIGHTS refit: bool = REFIT - engine_capability: EngineCapability = ENGINE_CAPABILITY + engine_capability: EngineCapability = field( + default_factory=lambda: ENGINE_CAPABILITY + ) num_avg_timing_iters: int = NUM_AVG_TIMING_ITERS dla_sram_size: int = DLA_SRAM_SIZE dla_local_dram_size: int = DLA_LOCAL_DRAM_SIZE diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index 06ae596ed0..6362e4253b 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -4,13 +4,14 @@ from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set import numpy as np -import tensorrt as trt import torch import torch.fx from torch.fx.node import _get_qualified_name from torch.fx.passes.shape_prop import TensorMetadata from torch.utils._python_dispatch import _disable_current_modes +from torch_tensorrt._enums import dtype from torch_tensorrt._Input import Input +from torch_tensorrt.dynamo import _defaults from torch_tensorrt.dynamo._settings import CompilationSettings from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext from torch_tensorrt.dynamo.conversion._ConverterRegistry import ( @@ -22,8 +23,9 @@ get_trt_tensor, ) from torch_tensorrt.fx.observer import Observer -from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter +from torch_tensorrt.logging import TRT_LOGGER +import tensorrt as trt from packaging import version _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -50,13 +52,12 @@ def __init__( module: torch.fx.GraphModule, input_specs: Sequence[Input], logger_level: trt.ILogger.Severity = trt.ILogger.Severity.WARNING, - output_dtypes: Optional[Sequence[torch.dtype]] = None, + output_dtypes: Optional[Sequence[dtype]] = None, compilation_settings: CompilationSettings = CompilationSettings(), ): super().__init__(module) - # TODO: @narendasan replace with Torch-TensorRT Logger - self.logger = trt.Logger(logger_level) + self.logger = TRT_LOGGER self.builder = trt.Builder(self.logger) flag = 0 @@ -69,9 +70,11 @@ def __init__( self.builder.create_network(flag), compilation_settings ) + assert TRTInterpreter._all_precisions_supported( + compilation_settings.enabled_precisions + ), f"Attempted to enable kernel precisions that are not supported (got: {compilation_settings.enabled_precisions}, support: {_defaults.SUPPORTED_KERNEL_PRECISIONS})" missing_ops = self.validate_conversion() if missing_ops: - # TODO: @narendasan make sure to set logging.captureWarnings(True) warnings.warn( "Interpretation will fail due to missing operations \n" + "\n".join(f"{i}" for i in missing_ops) @@ -98,7 +101,11 @@ def __init__( self.compilation_settings = compilation_settings # Data types for TRT Module output Tensors - self.output_dtypes = output_dtypes + self.output_dtypes = ( + [dtype._from(o) for o in output_dtypes] if output_dtypes else None + ) + + _LOGGER.debug(f"Graph to be compiled to TensorRT: {self.module.graph}") def validate_conversion(self) -> Set[str]: missing_converters: Set[str] = set() @@ -116,60 +123,58 @@ def validate_conversion(self) -> Set[str]: return 
missing_converters - def run( - self, - force_fp32_output: bool = False, - strict_type_constraints: bool = False, - algorithm_selector: Optional[trt.IAlgorithmSelector] = None, - timing_cache: Optional[trt.ITimingCache] = None, - tactic_sources: Optional[int] = None, - ) -> TRTInterpreterResult: - """ - Build TensorRT engine with some configs. - Args: - force_fp32_output: force output to be fp32 - strict_type_constraints: Usually we should set it to False unless we want to control the precision of certain layer for numeric reasons. - algorithm_selector: set up algorithm selection for certain layer - timing_cache: enable timing cache for TensorRT - Return: - TRTInterpreterResult - """ - TRT_INTERPRETER_CALL_PRE_OBSERVER.observe(self.module) + @staticmethod + def _args_str(args: List[Any]) -> str: + def clean_repr(x: Any, depth: int = 0) -> Any: + if isinstance(x, trt.ITensor): + return f"{x.name} " + elif isinstance(x, torch.Tensor): + return f"" + elif isinstance(x, np.ndarray): + return ( + f"" + ) + elif isinstance(x, Sequence) and not isinstance(x, str): + if depth < 3: + return type(x)([clean_repr(i, depth=depth + 1) for i in x]) # type: ignore[call-arg] + else: + return "(...)" + else: + return x + + str_args = [clean_repr(a) for a in args] + return repr(tuple(str_args)) - precision = self.compilation_settings.precision - # For float outputs, we set their dtype to fp16 only if precision == torch.float16 and - # force_fp32_output=False. Overriden by specifying output_dtypes - self.output_fp16 = not force_fp32_output and precision == torch.float16 + @staticmethod + def _all_precisions_supported(enabled_precisions: Set[dtype]) -> bool: + return enabled_precisions.issubset(_defaults.SUPPORTED_KERNEL_PRECISIONS) - if precision == torch.int8 and not self.builder.platform_has_fast_int8: + def validate_compile_settings(self) -> None: + if ( + dtype.i8 in self.compilation_settings.enabled_precisions + and not self.builder.platform_has_fast_int8 + ): raise RuntimeError("Current platform doesn't support fast native int8!") - if precision == torch.float16 and not self.builder.platform_has_fast_fp16: + if ( + dtype.f16 in self.compilation_settings.enabled_precisions + and not self.builder.platform_has_fast_fp16 + ): warnings.warn("Current platform doesn't support fast native fp16!") - self.input_specs_iter = 0 - run_module_start_time = datetime.now() - super().run() - _LOGGER.info( - f"TRT INetwork construction elapsed time: {datetime.now() - run_module_start_time}" - ) - build_engine_start_time = datetime.now() + def _populate_trt_builder_config( + self, + strict_type_constraints: bool = False, + algorithm_selector: Optional[trt.IAlgorithmSelector] = None, + tactic_sources: Optional[int] = None, + ) -> trt.IBuilderConfig: builder_config = self.builder.create_builder_config() - if self.compilation_settings.workspace_size != 0: builder_config.set_memory_pool_limit( trt.MemoryPoolType.WORKSPACE, self.compilation_settings.workspace_size ) - cache = None - if timing_cache: - cache_file = np.array(timing_cache) - cache = builder_config.create_timing_cache(cache_file.tobytes()) - else: - cache = builder_config.create_timing_cache(b"") - builder_config.set_timing_cache(cache, False) - if version.parse(trt.__version__) >= version.parse("8.2"): builder_config.profiling_verbosity = ( trt.ProfilingVerbosity.VERBOSE @@ -201,12 +206,20 @@ def run( self.compilation_settings.optimization_level ) - builder_config.engine_capability = self.compilation_settings.engine_capability + builder_config.engine_capability = 
( + self.compilation_settings.engine_capability.to(trt.EngineCapability) + ) builder_config.avg_timing_iterations = ( self.compilation_settings.num_avg_timing_iters ) if self.compilation_settings.device.device_type == trt.DeviceType.DLA: + device_info = torch.cuda.get_device_properties( + self.compilation_settings.device.gpu_id + ) + assert (device_info.major == 8 and device_info.minor == 7) or ( + device_info.major == 7 and device_info.minor == 2 + ), "DLA is not available on non AGX systems" builder_config.DLA_core = self.compilation_settings.device.dla_core _LOGGER.info(f"Using DLA core {self.compilation_settings.device.dla_core}") builder_config.set_memory_pool_limit( @@ -222,10 +235,10 @@ def run( self.compilation_settings.dla_global_dram_size, ) - if precision == torch.float16: + if dtype.float16 in self.compilation_settings.enabled_precisions: builder_config.set_flag(trt.BuilderFlag.FP16) - if precision == torch.int8: + if dtype.int8 in self.compilation_settings.enabled_precisions: builder_config.set_flag(trt.BuilderFlag.INT8) if self.compilation_settings.sparse_weights: @@ -252,11 +265,58 @@ def run( if tactic_sources is not None: builder_config.set_tactic_sources(tactic_sources=tactic_sources) + return builder_config + + def _create_timing_cache( + self, + builder_config: trt.IBuilderConfig, + existing_cache: Optional[trt.ITimingCache] = None, + ) -> trt.ITimingCache: + cache = None + if existing_cache: + cache_file = np.array(existing_cache) + cache = builder_config.create_timing_cache(cache_file.tobytes()) + else: + cache = builder_config.create_timing_cache(b"") + builder_config.set_timing_cache(cache, False) + return cache + + def run( + self, + strict_type_constraints: bool = False, + algorithm_selector: Optional[trt.IAlgorithmSelector] = None, + existing_cache: Optional[trt.ITimingCache] = None, + tactic_sources: Optional[int] = None, + ) -> TRTInterpreterResult: + """ + Build TensorRT engine with some configs. + Args: + strict_type_constraints: Usually we should set it to False unless we want to control the precision of certain layer for numeric reasons. 
+ algorithm_selector: set up algorithm selection for certain layer + existing_cache: enable timing cache for TensorRT + Return: + TRTInterpreterResult + """ + TRT_INTERPRETER_CALL_PRE_OBSERVER.observe(self.module) + + self.input_specs_iter = 0 + run_module_start_time = datetime.now() + super().run() + _LOGGER.info( + f"TRT INetwork construction elapsed time: {datetime.now() - run_module_start_time}" + ) + build_engine_start_time = datetime.now() + + builder_config = self._populate_trt_builder_config( + strict_type_constraints, algorithm_selector, tactic_sources + ) + timing_cache = self._create_timing_cache(builder_config, existing_cache) + engine = self.builder.build_engine(self.ctx.net, builder_config) assert engine serialized_cache = ( - bytearray(cache.serialize()) + bytearray(timing_cache.serialize()) if builder_config.get_timing_cache() else bytearray() ) @@ -285,7 +345,7 @@ def run_node(self, n: torch.fx.Node) -> torch.fx.Node: del kwargs["_itensor_to_tensor_meta"] n.kwargs = kwargs - if isinstance(trt_node, trt.tensorrt.ITensor): + if isinstance(trt_node, trt.ITensor): self._itensor_to_tensor_meta[trt_node] = n.meta.get("tensor_meta") return trt_node @@ -323,10 +383,14 @@ def placeholder(self, target: str, args: Any, kwargs: Any) -> trt.ITensor: f"Unable to access shape spec for input: {target} (got: {current_input})" ) + trt_input_dtype = current_input.dtype.to(trt.DataType, use_default=True) + _LOGGER.debug( + f"Adding input to in-progress INetwork: {target} [shape={shape}, dtype={trt_input_dtype}]" + ) return self.ctx.net.add_input( name=target, shape=tuple(shape), - dtype=unified_dtype_converter(current_input.torch_dtype, Frameworks.TRT), + dtype=trt_input_dtype, ) def call_module( @@ -345,6 +409,9 @@ def call_module( converter, calling_convention = converter_packet assert self._cur_node_name is not None + _LOGGER.debug( + f"Converting node {self._cur_node_name} (kind: {target}, args: {TRTInterpreter._args_str(args)})" + ) if calling_convention is CallingConvention.LEGACY: return converter(self.ctx.net, submod, args, kwargs, self._cur_node_name) else: @@ -361,6 +428,9 @@ def call_function(self, target: str, args: Any, kwargs: Any) -> Any: converter, calling_convention = converter_packet assert self._cur_node_name is not None + _LOGGER.debug( + f"Converting node {self._cur_node_name} (kind: {target}, args: {TRTInterpreter._args_str(args)})" + ) if calling_convention is CallingConvention.LEGACY: return converter(self.ctx.net, target, args, kwargs, self._cur_node_name) else: @@ -392,6 +462,9 @@ def call_method(self, target: str, args: Any, kwargs: Any) -> Any: converter, calling_convention = converter_packet assert self._cur_node_name is not None + _LOGGER.debug( + f"Converting node {self._cur_node_name} (kind: {target}, args: {TRTInterpreter._args_str(args)})" + ) if calling_convention is CallingConvention.LEGACY: return converter(self.ctx.net, target, args, kwargs, self._cur_node_name) else: @@ -409,13 +482,13 @@ def output(self, target: str, args: Any, kwargs: Any) -> List[Any]: for output_idx in range(len(outputs)): output = outputs[output_idx] - if not isinstance(output, trt.tensorrt.ITensor): + if not isinstance(output, trt.ITensor): new_output = get_trt_tensor(self.ctx, output, target) outputs = ( outputs[:output_idx] + (new_output,) + outputs[output_idx + 1 :] ) - if not all(isinstance(output, trt.tensorrt.ITensor) for output in outputs): + if not all(isinstance(output, trt.ITensor) for output in outputs): raise RuntimeError("TensorRT requires all outputs to be Tensor!") 
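As an illustrative aside (not part of the patch itself): the interpreter hunks above drop the single `precision` value in favor of the `enabled_precisions` set carried on `CompilationSettings`. A minimal sketch of the resulting flag selection, assuming only names that already appear in this diff (the helper name `apply_precision_flags` is hypothetical):

import tensorrt as trt
from torch_tensorrt._enums import dtype

def apply_precision_flags(builder_config: trt.IBuilderConfig, enabled_precisions: set) -> None:
    # Mirrors _populate_trt_builder_config: each requested kernel precision
    # toggles the corresponding TensorRT builder flag.
    if dtype.float16 in enabled_precisions:
        builder_config.set_flag(trt.BuilderFlag.FP16)
    if dtype.int8 in enabled_precisions:
        builder_config.set_flag(trt.BuilderFlag.INT8)

FP32 kernels are always available, so `dtype.f32` needs no flag; the set-based API simply allows several precisions to be enabled at once instead of forcing a single `torch.dtype` choice.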
if self.output_dtypes is not None and len(self.output_dtypes) != len(outputs): @@ -436,6 +509,7 @@ def output(self, target: str, args: Any, kwargs: Any) -> List[Any]: "not", "ne", "isinf", + "isnan", "any", ) ): @@ -446,13 +520,13 @@ def output(self, target: str, args: Any, kwargs: Any) -> List[Any]: output.name = name self.ctx.net.mark_output(output) if output_bool: - output.dtype = trt.bool + output.dtype = trt.DataType.BOOL elif self.output_dtypes is not None: - output.dtype = unified_dtype_converter( - self.output_dtypes[i], Frameworks.TRT - ) - elif self.output_fp16 and output.dtype == trt.float32: - output.dtype = trt.float16 + output.dtype = self.output_dtypes[i].to(trt.DataType) + self._output_names.append(name) + _LOGGER.debug( + f"Marking output {name} [shape={output.shape}, dtype={output.dtype}]" + ) return list(outputs) diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py index d7cfe15694..6a2530a956 100644 --- a/py/torch_tensorrt/dynamo/conversion/_conversion.py +++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py @@ -1,10 +1,13 @@ from __future__ import annotations import io -from typing import Sequence +import logging +from typing import List, Sequence -import tensorrt as trt import torch +from torch_tensorrt._Device import Device +from torch_tensorrt._enums import dtype +from torch_tensorrt._features import ENABLED_FEATURES from torch_tensorrt._Input import Input from torch_tensorrt.dynamo._settings import CompilationSettings from torch_tensorrt.dynamo.conversion._TRTInterpreter import ( @@ -14,6 +17,37 @@ from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule from torch_tensorrt.dynamo.utils import get_torch_inputs +import tensorrt as trt + +logger = logging.getLogger(__name__) + + +def infer_module_output_dtypes( + module: torch.fx.GraphModule, + inputs: Sequence[Input], + device: Device, + truncate_long_and_double: bool = False, +) -> List[dtype]: + torch_inputs = get_torch_inputs(inputs, device) + module = module.to(device.to(torch.device)) + module_outputs = module(*torch_inputs) + + if not isinstance(module_outputs, (list, tuple)): + module_outputs = [module_outputs] + + # Int64 outputs can sometimes be generated from within other operators + # such as aten.sum - such outputs can be truncated + output_dtypes = [] + for output in module_outputs: + if truncate_long_and_double and output.dtype == dtype.float64: + output_dtypes.append(dtype.float32) + elif truncate_long_and_double and output.dtype == dtype.int64: + output_dtypes.append(dtype.int32) + else: + output_dtypes.append(dtype._from(output.dtype)) + + return output_dtypes + def interpret_module_to_result( module: torch.fx.GraphModule, @@ -28,22 +62,12 @@ def interpret_module_to_result( Returns: TRTInterpreterResult """ - torch_inputs = get_torch_inputs(inputs, settings.device) - module_outputs = module(*torch_inputs) - - if not isinstance(module_outputs, (list, tuple)): - module_outputs = [module_outputs] - - # Int64 outputs can sometimes be generated from within other operators - # such as aten.sum - such outputs can be truncated - output_dtypes = [] - for output in module_outputs: - if settings.truncate_long_and_double and output.dtype == torch.float64: - output_dtypes.append(torch.float32) - elif settings.truncate_long_and_double and output.dtype == torch.int64: - output_dtypes.append(torch.int32) - else: - output_dtypes.append(output.dtype) + output_dtypes = infer_module_output_dtypes( + module, + inputs, + 
settings.device, + truncate_long_and_double=settings.truncate_long_and_double, + ) interpreter = TRTInterpreter( module, @@ -73,7 +97,11 @@ def convert_module( """ interpreter_result = interpret_module_to_result(module, inputs, settings) - if settings.use_python_runtime: + if settings.use_python_runtime or not ENABLED_FEATURES.torch_tensorrt_runtime: + if not settings.use_python_runtime: + logger.info( + "Since Torch-TensorRT runtime is not available, using Python Runtime, some features may not be available" + ) return PythonTorchTensorRTModule( engine=interpreter_result.engine, input_names=list(interpreter_result.input_names), diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py index f9d14917f1..04f048c5f3 100644 --- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py +++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py @@ -4,23 +4,21 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, overload import numpy as np -import tensorrt as trt import torch from torch import SymBool, SymFloat, SymInt from torch.fx.node import Argument, Target +from torch_tensorrt import _enums from torch_tensorrt.dynamo._SourceIR import SourceIR from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext from torch_tensorrt.dynamo.conversion._ConverterRegistry import ( ConverterRegistry, DynamoConverterImplSignature, ) -from torch_tensorrt.fx.converters.converter_utils import ( - Frameworks, - get_axes_for_reduce_op, - unified_dtype_converter, -) +from torch_tensorrt.fx.converters.converter_utils import get_axes_for_reduce_op from torch_tensorrt.fx.types import TRTDataType, TRTTensor +import tensorrt as trt + _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -44,7 +42,6 @@ def get_node_name(node: torch.fx.Node) -> str: # like the node.meta['source_fn'] attr pass - _LOGGER.debug(f"Node meta name {node_name}") return node_name @@ -121,7 +118,7 @@ def _is_subnode_dynamic(subnode: torch.fx.Node) -> bool: def cast_trt_tensor( ctx: ConversionContext, input_val: TRTTensor, - dtype: TRTDataType, + dtype: Union[TRTDataType, torch.dtype, np.dtype, _enums.dtype], name: str, target: Target = "", source_ir: Optional[SourceIR] = None, @@ -142,7 +139,7 @@ def cast_trt_tensor( Returns: A TensorRT ITensor which has been casted to the specified dtype """ - trt_dtype = unified_dtype_converter(dtype, Frameworks.TRT) + trt_dtype = _enums.dtype._from(dtype).to(trt.DataType) if input_val.dtype != trt_dtype: source_ir = source_ir if source_ir is not None else SourceIR.UNKNOWN @@ -253,7 +250,7 @@ def create_constant( ctx: ConversionContext, value: Union[int, float, bool, np.ndarray, torch.Tensor], name: str, - dtype: Optional[Union[torch.dtype, np.dtype, TRTDataType]], + dtype: Optional[Union[torch.dtype, np.dtype, TRTDataType, _enums.dtype]], ) -> TRTTensor: """ Add a TensorRT constant layer whose value is `value` to `ctx.net`. @@ -268,7 +265,9 @@ def create_constant( Returns: A TensorRT ITensor that represents the given value. 
""" - numpy_value = to_numpy(value, dtype) + numpy_value = to_numpy( + value, _enums.dtype._from(dtype).to(np.dtype) if dtype is not None else None + ) constant = ctx.net.add_constant( (1,) if isinstance(value, (int, float, bool)) else value.shape, numpy_value.copy() if isinstance(numpy_value, np.ndarray) else numpy_value, @@ -281,7 +280,7 @@ def get_trt_tensor( ctx: ConversionContext, input_val: Any, name: str, - dtype: Optional[Union[torch.dtype, np.dtype, TRTDataType]] = None, + dtype: Optional[Union[torch.dtype, np.dtype, TRTDataType, _enums.dtype]] = None, ) -> TRTTensor: """ Given a value of random type, we try to convert it to a TensorRT ITensor. @@ -466,7 +465,7 @@ def convert_with_type_enforcement( def to_numpy( value: Optional[Union[torch.Tensor, np.ndarray, int, float, bool]], - dtype: Optional[Union[torch.dtype, np.dtype, TRTDataType]] = None, + dtype: Optional[Union[torch.dtype, np.dtype, TRTDataType, _enums.dtype]] = None, ) -> Optional[np.ndarray]: """ Convert a PyTorch Tensor, Numpy array, or scalar to a Numpy Array. If the tensor is @@ -503,7 +502,7 @@ def to_numpy( return ( output if (dtype is None or output is None) - else output.astype(unified_dtype_converter(dtype, Frameworks.NUMPY)) + else output.astype(_enums.dtype._from(dtype).to(np.dtype)) ) else: raise AssertionError( diff --git a/py/torch_tensorrt/dynamo/conversion/impl/cast.py b/py/torch_tensorrt/dynamo/conversion/impl/cast.py index bc6af1a32d..b6d024eb08 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/cast.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/cast.py @@ -1,17 +1,18 @@ import logging -from typing import Optional +from typing import Optional, Union +import numpy as np +import torch from torch.fx.node import Target +from torch_tensorrt import _enums from torch_tensorrt.dynamo._SourceIR import SourceIR from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext from torch_tensorrt.dynamo.conversion._ConverterRegistry import ConverterRegistry from torch_tensorrt.dynamo.conversion.converter_utils import cast_trt_tensor -from torch_tensorrt.fx.converters.converter_utils import ( - Frameworks, - unified_dtype_converter, -) from torch_tensorrt.fx.types import TRTDataType, TRTTensor +import tensorrt as trt + LOGGER: logging.Logger = logging.getLogger(__name__) @@ -21,7 +22,7 @@ def to_copy( source_ir: Optional[SourceIR], name: str, input: TRTTensor, - dtype: TRTDataType, + dtype: Union[TRTDataType, torch.dtype, np.dtype, _enums.dtype], force_layer: bool = False, ) -> TRTTensor: if not isinstance(input, TRTTensor): @@ -32,7 +33,7 @@ def to_copy( # If cast is forced, insert identity layer regardless of whether the dtype # doesn't change if force_layer: - trt_dtype = unified_dtype_converter(dtype, Frameworks.TRT) + trt_dtype = _enums.dtype._from(dtype).to(trt.DataType) source_ir = source_ir if source_ir is not None else SourceIR.UNKNOWN target_str = ConverterRegistry.qualified_name_or_str(target) target_name = f"{source_ir}_ops{('.' 
+ target_str) if target_str else ''}" diff --git a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py index 8282ee8698..ce4d70cef5 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py @@ -3,9 +3,9 @@ from typing import Any, Callable, Optional, Union import numpy as np -import tensorrt as trt import torch from torch.fx.node import Target +from torch_tensorrt import _enums from torch_tensorrt.dynamo._SourceIR import SourceIR from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext from torch_tensorrt.dynamo.conversion.converter_utils import ( @@ -14,7 +14,8 @@ ) from torch_tensorrt.fx.converters.converter_utils import broadcast, set_layer_name from torch_tensorrt.fx.types import TRTElementWiseOp, TRTTensor -from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter + +import tensorrt as trt def get_python_op_from_trt_elementwise_op( @@ -121,22 +122,20 @@ def convert_binary_elementwise( # dtype but we don't have a way to detect whether it makes sense for the # scalar to be float or half. Hence we go with the lhs dtype. if is_lhs_trt_tensor and isinstance(rhs_val, (float, int, bool)): - rhs_val = np.array( - [rhs_val], dtype=unified_dtype_converter(lhs_dtype, Frameworks.NUMPY) - ) + rhs_val = np.array([rhs_val], dtype=_enums.dtype._from(lhs_dtype).to(np.dtype)) if is_rhs_trt_tensor and isinstance(lhs_val, (float, int, bool)): - lhs_val = np.array( - [lhs_val], dtype=unified_dtype_converter(rhs_dtype, Frameworks.NUMPY) - ) + lhs_val = np.array([lhs_val], dtype=_enums.dtype._from(rhs_dtype).to(np.dtype)) lhs_val = get_trt_tensor(ctx, lhs_val, f"{name}_lhs", lhs_dtype) rhs_val = get_trt_tensor(ctx, rhs_val, f"{name}_rhs", rhs_dtype) - promoted_type = torch.promote_types( - unified_dtype_converter(lhs_val.dtype, Frameworks.TORCH), - unified_dtype_converter(rhs_val.dtype, Frameworks.TORCH), + promoted_type = _enums.dtype._from( + torch.promote_types( + _enums.dtype._from(lhs_val.dtype).to(torch.dtype), + _enums.dtype._from(rhs_val.dtype).to(torch.dtype), + ) ) - trt_promoted_type = unified_dtype_converter(promoted_type, Frameworks.TRT) + trt_promoted_type = promoted_type.to(trt.DataType) if trt_promoted_type != lhs_val.dtype: lhs_val = cast_trt_tensor( diff --git a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py index bece21c033..30a5203eed 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py @@ -1,10 +1,10 @@ from typing import Optional, Union import numpy as np -import tensorrt as trt import torch import torch_tensorrt.dynamo.conversion.impl as impl from torch.fx.node import Target +from torch_tensorrt import _enums from torch_tensorrt.dynamo._SourceIR import SourceIR from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext from torch_tensorrt.dynamo.conversion.converter_utils import ( @@ -20,7 +20,8 @@ from torch_tensorrt.dynamo.conversion.impl.unary.base import convert_unary from torch_tensorrt.fx.converters.converter_utils import broadcast from torch_tensorrt.fx.types import TRTTensor -from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter + +import tensorrt as trt def trunc_div( @@ -73,7 +74,7 @@ def trunc_div( ctx, other, f"{name}_other", - dtype=unified_dtype_converter(input.dtype, 
Frameworks.TORCH), + dtype=_enums.dtype._from(input.dtype).to(torch.dtype), ) abs_input_output = convert_unary( diff --git a/py/torch_tensorrt/dynamo/conversion/impl/grid.py b/py/torch_tensorrt/dynamo/conversion/impl/grid.py index 672fc97351..8981eca73c 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/grid.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/grid.py @@ -1,13 +1,12 @@ -from typing import Optional, Sequence +from typing import Optional -import tensorrt as trt -import torch from torch.fx.node import Target from torch_tensorrt.dynamo._SourceIR import SourceIR from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext -from torch_tensorrt.dynamo.conversion.converter_utils import cast_trt_tensor from torch_tensorrt.fx.converters.converter_utils import set_layer_name -from torch_tensorrt.fx.types import TRTNetwork, TRTTensor +from torch_tensorrt.fx.types import TRTTensor + +import tensorrt as trt # nearest, linear, cubic GridSamplerInterpolationMode = { diff --git a/py/torch_tensorrt/dynamo/conversion/impl/matmul.py b/py/torch_tensorrt/dynamo/conversion/impl/matmul.py index a50ec3c434..5ea29622c8 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/matmul.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/matmul.py @@ -1,13 +1,15 @@ from typing import Optional -import tensorrt as trt +import torch from torch.fx.node import Target +from torch_tensorrt import _enums from torch_tensorrt.dynamo._SourceIR import SourceIR from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor from torch_tensorrt.fx.converters.converter_utils import broadcast, set_layer_name from torch_tensorrt.fx.types import TRTTensor -from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter + +import tensorrt as trt def matrix_multiply( @@ -20,14 +22,14 @@ def matrix_multiply( input_matrix_op: trt.MatrixOperation = trt.MatrixOperation.NONE, other_matrix_op: trt.MatrixOperation = trt.MatrixOperation.NONE, ) -> TRTTensor: - if not isinstance(input, trt.tensorrt.ITensor): + if not isinstance(input, trt.ITensor): input = get_trt_tensor(ctx, input, f"{name}_input") - if not isinstance(other, trt.tensorrt.ITensor): + if not isinstance(other, trt.ITensor): other = get_trt_tensor( ctx, other, f"{name}_other", - dtype=unified_dtype_converter(input.dtype, Frameworks.TORCH), + dtype=_enums.dtype._from(input.dtype).to(torch.dtype), ) preset_diff = 0 diff --git a/py/torch_tensorrt/dynamo/conversion/impl/select.py b/py/torch_tensorrt/dynamo/conversion/impl/select.py index 470abb8f48..a4507ece3e 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/select.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/select.py @@ -2,7 +2,6 @@ from typing import Optional, Sequence, Union, cast import numpy as np -import tensorrt as trt import torch from torch.fx.node import Target from torch_tensorrt.dynamo._SourceIR import SourceIR @@ -21,6 +20,8 @@ ) from torch_tensorrt.fx.types import Shape, TRTTensor +import tensorrt as trt + _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -81,7 +82,7 @@ def index( source_ir: Optional[SourceIR], name: str, input: TRTTensor, - index: Sequence[Union[TRTTensor, np.ndarray, torch.Tensor]], + indices: Sequence[Union[TRTTensor, np.ndarray, torch.Tensor]], ) -> TRTTensor: adv_indx_indices = [] tensor_indices = [] @@ -93,12 +94,14 @@ def index( "Determining whether aten.index constant-index optimization can be invoked" ) is_numpy = all( - isinstance(ind, 
(torch.Tensor, np.ndarray)) for ind in index if ind is not None + isinstance(ind, (torch.Tensor, np.ndarray)) + for ind in indices + if ind is not None ) # here we need to check if all the index are broadcastable # if no, then we need to broadcast last_index = None - for i, ind in enumerate(index): + for i, ind in enumerate(indices): if ind is not None: _LOGGER.debug(f"Shape of {i} index is {ind.shape}") adv_indx_indices.append(i) @@ -369,7 +372,7 @@ def index( ) reshape_output = reshape_layer.get_output(0) - return reshape_output + return reshape_output def index_select( diff --git a/py/torch_tensorrt/dynamo/conversion/impl/upsample.py b/py/torch_tensorrt/dynamo/conversion/impl/upsample.py index 3313730ec3..e2e8481e24 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/upsample.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/upsample.py @@ -1,12 +1,13 @@ from typing import Optional, Sequence -import tensorrt as trt from torch.fx.node import Target from torch_tensorrt.dynamo._SourceIR import SourceIR from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext from torch_tensorrt.fx.converters.converter_utils import set_layer_name from torch_tensorrt.fx.types import TRTTensor +import tensorrt as trt + def upsample( ctx: ConversionContext, @@ -29,7 +30,7 @@ def upsample( resize_layer.scales = [1.0, 1.0] + list(scale_factors) else: raise RuntimeError( - f"At least one of out_shape and scale_factors should be specified." + "At least one of out_shape and scale_factors should be specified." ) # interpolate mode diff --git a/py/torch_tensorrt/dynamo/conversion/truncate_long_and_double.py b/py/torch_tensorrt/dynamo/conversion/truncate_long_and_double.py index 9390bc3bde..d5670be1db 100644 --- a/py/torch_tensorrt/dynamo/conversion/truncate_long_and_double.py +++ b/py/torch_tensorrt/dynamo/conversion/truncate_long_and_double.py @@ -4,6 +4,7 @@ import torch from torch.fx.node import _get_qualified_name +from torch_tensorrt._enums import dtype from torch_tensorrt._Input import Input from torch_tensorrt.dynamo.utils import get_torch_inputs @@ -217,6 +218,8 @@ def repair_long_or_double_inputs( # Set the 32bit inputs and their types to the submodule Inputs for idx in range(len(submodule_inputs)): submodule_inputs[idx].torch_tensor = submodule_torch_inputs[idx] - submodule_inputs[idx].torch_dtype = submodule_torch_inputs[idx].dtype + submodule_inputs[idx].dtype = dtype._from( + submodule_torch_inputs[idx].dtype + ) return submodule_inputs diff --git a/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py b/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py index e263b51bb2..c00d92577c 100644 --- a/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py +++ b/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py @@ -110,6 +110,7 @@ def __init__( allowed_single_node_partition_ops: Optional[Collection[str]] = None, min_block_size: int = MIN_BLOCK_SIZE, require_full_compilation: bool = REQUIRE_FULL_COMPILATION, + return_tuple: bool = False, ): """ Preprocesses graph before splitting: @@ -149,6 +150,7 @@ def __init__( self.num_trt_accelerated_subgraphs: Optional[int] = None self.allowed_single_node_partition_ops = allowed_single_node_partition_ops self.require_full_compilation = require_full_compilation + self._return_tuple = return_tuple def remove_small_acc_subgraphs(self, subgraphs: List[Subgraph]) -> List[Subgraph]: """ diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py 
b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index 3a66ed3716..20762731b0 100644 --- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py @@ -4,17 +4,18 @@ from contextlib import nullcontext from typing import Any, Dict, List, Optional, Sequence, Tuple -import tensorrt as trt import torch import torch_tensorrt from torch.nn import Module from torch_tensorrt._Device import Device +from torch_tensorrt._enums import dtype from torch_tensorrt.dynamo.runtime.tools import ( _is_switch_required, _select_rt_device, multi_gpu_device_check, ) -from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter + +import tensorrt as trt logger = logging.getLogger(__name__) @@ -84,9 +85,7 @@ def _initialize(self) -> None: ) self.input_dtypes = [ - unified_dtype_converter( - self.engine.get_binding_dtype(idx), Frameworks.TORCH - ) + dtype._from(self.engine.get_binding_dtype(idx)) for idx in self.input_binding_indices_in_order ] self.input_shapes: Sequence[Sequence[int]] = [ @@ -94,9 +93,7 @@ def _initialize(self) -> None: for idx in self.input_binding_indices_in_order ] self.output_dtypes = [ - unified_dtype_converter( - self.engine.get_binding_dtype(idx), Frameworks.TORCH - ) + dtype._from(self.engine.get_binding_dtype(idx)) for idx in self.output_binding_indices_in_order ] self.output_shapes = [ @@ -108,9 +105,7 @@ def _initialize(self) -> None: for idx in self.output_binding_indices_in_order ] self.hidden_output_dtypes = [ - unified_dtype_converter( - self.engine.get_binding_dtype(idx), Frameworks.TORCH - ) + dtype._from(self.engine.get_binding_dtype(idx)) for idx in self.hidden_output_binding_indices_in_order ] self.hidden_output_shapes = [ @@ -263,7 +258,7 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, . output = torch.empty( size=shape, - dtype=self.output_dtypes[i], + dtype=self.output_dtypes[i].to(torch.dtype), device=torch.cuda.current_device(), ) outputs.append(output) @@ -274,7 +269,7 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, . 
output = torch.empty( size=shape, - dtype=self.hidden_output_dtypes[i], + dtype=self.hidden_output_dtypes[i].to(torch.dtype), device=torch.cuda.current_device(), ) bindings[idx] = output.data_ptr() diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py index 22590fe73d..29b01990ce 100644 --- a/py/torch_tensorrt/dynamo/utils.py +++ b/py/torch_tensorrt/dynamo/utils.py @@ -6,11 +6,11 @@ import torch from torch_tensorrt._Device import Device +from torch_tensorrt._enums import dtype from torch_tensorrt._Input import Input -from torch_tensorrt.dynamo._defaults import PRECISION +from torch_tensorrt.dynamo import _defaults from torch_tensorrt.dynamo._settings import CompilationSettings -import torch_tensorrt from packaging import version logger = logging.getLogger(__name__) @@ -196,10 +196,7 @@ def to_torch_device(device: Optional[Union[Device, torch.device, str]]) -> torch Returns the corresponding torch.device """ if isinstance(device, Device): - if device.gpu_id != -1: - return torch.device(device.gpu_id) - else: - raise ValueError("Invalid GPU ID provided for the CUDA device provided") + return device.to(torch.device) elif isinstance(device, torch.device): return device @@ -218,17 +215,7 @@ def to_torch_tensorrt_device( Returns the corresponding torch_tensorrt.Device """ - if isinstance(device, Device): - return device - - elif isinstance(device, torch.device): - return Device(gpu_id=device.index) - - elif device is None: - return Device(gpu_id=torch.cuda.current_device()) - - else: - return Device(device) + return Device._from(device) def parse_dynamo_kwargs(kwargs: Any) -> CompilationSettings: @@ -257,25 +244,15 @@ def parse_dynamo_kwargs(kwargs: Any) -> CompilationSettings: # TODO: Remove once Dynamo precisions refactoring is complete if "enabled_precisions" in kwargs: - enabled_precisions = kwargs["enabled_precisions"] - - if ( - torch.float16 in enabled_precisions - or torch_tensorrt.dtype.half in enabled_precisions - ): - settings.precision = torch.float16 - elif ( - torch.float32 in enabled_precisions - or torch_tensorrt.dtype.float in enabled_precisions - ): - settings.precision = torch.float32 - elif len(enabled_precisions) == 0: - logger.info(f"No precision specified, defaulting to {PRECISION}") - settings.precision = PRECISION - else: - raise ValueError( - f"Precision {enabled_precisions} not supported in the Dynamo Path" + enabled_precisions = {dtype._from(e) for e in kwargs["enabled_precisions"]} + + if len(enabled_precisions) == 0: + logger.info( + f"No precision specified, defaulting to {_defaults.ENABLED_PRECISION}" ) + enabled_precisions = _defaults.ENABLED_PRECISIONS + + settings.enabled_precisions = enabled_precisions # Parse input runtime specification settings.use_python_runtime = use_python_runtime_parser(settings.use_python_runtime) diff --git a/py/torch_tensorrt/logging.py b/py/torch_tensorrt/logging.py index e48e3c6317..4cbb686b0d 100644 --- a/py/torch_tensorrt/logging.py +++ b/py/torch_tensorrt/logging.py @@ -1,118 +1,35 @@ -from enum import Enum +import logging from typing import Any -from torch_tensorrt._C import ( - LogLevel, - _get_is_colored_output_on, - _get_logging_prefix, - _get_reportable_log_level, - _log, - _set_is_colored_output_on, - _set_logging_prefix, - _set_reportable_log_level, -) - - -class Level(Enum): - """Enum to set the minimum required logging level to print a message to stdout""" - - InternalError = LogLevel.INTERNAL_ERROR - Error = LogLevel.ERROR - Warning = LogLevel.WARNING - Info = LogLevel.INFO - Debug = 
LogLevel.DEBUG - Graph = LogLevel.GRAPH - - @staticmethod - def _to_internal_level(external: "Level") -> LogLevel: - if external == Level.InternalError: - return LogLevel.INTERNAL_ERROR - elif external == Level.Error: - return LogLevel.ERROR - elif external == Level.Warning: - return LogLevel.WARNING - elif external == Level.Info: - return LogLevel.INFO - elif external == Level.Debug: - return LogLevel.DEBUG - elif external == Level.Graph: - return LogLevel.GRAPH - else: - raise ValueError("Unknown log severity") - - -def get_logging_prefix() -> str: - """Get the prefix set for logging messages - - Returns: - str: Prefix used for logger - """ - return str(_get_logging_prefix()) - - -def set_logging_prefix(prefix: str) -> None: - """Set the prefix used when logging messages - - Args: - prefix (str): Prefix to use for logging messages - """ - _set_logging_prefix(prefix) - - -def get_reportable_log_level() -> Level: - """Get the level required for a message to be printed in the log - - Returns: - torch_tensorrt.logging.Level: The enum representing the level required to print - """ - return Level(_get_reportable_log_level()) - - -def set_reportable_log_level(level: Level) -> None: - """Set the level required for a message to be printed to the log +from torch_tensorrt._features import ENABLED_FEATURES - Args: - level (torch_tensorrt.logging.Level): The enum representing the level required to print - """ - _set_reportable_log_level(Level._to_internal_level(level)) - - -def get_is_colored_output_on() -> bool: - """Get if colored output is enabled for logging +import tensorrt as trt - Returns: - bool: If colored output is one - """ - return bool(_get_is_colored_output_on()) +logging.captureWarnings(True) +_LOGGER = logging.getLogger("torch_tensorrt [TensorRT Conversion Context]") -def set_is_colored_output_on(colored_output_on: bool) -> None: - """Enable or disable color in the log output +class _TRTLogger(trt.ILogger): # type: ignore[misc] - Args: - colored_output_on (bool): If colored output should be enabled or not - """ - _set_is_colored_output_on(colored_output_on) + def __init__(self) -> None: + trt.ILogger.__init__(self) + def log(self, severity: trt.ILogger.Severity, msg: str) -> None: + # TODO: Move to match once py39 reaches EoL + if severity == trt.ILogger.Severity.INTERNAL_ERROR: + _LOGGER.critical(msg) + raise RuntimeError(msg) + elif severity == trt.ILogger.Severity.ERROR: + _LOGGER.error(msg) + elif severity == trt.ILogger.Severity.WARNING: + _LOGGER.warning(msg) + elif severity == trt.ILogger.Severity.INFO: + _LOGGER.info(msg) + elif severity == trt.ILogger.Severity.VERBOSE: + _LOGGER.debug(msg) -def log(level: Level, msg: str) -> None: - """Add a new message to the log - Adds a new message to the log at a specified level. 
The message - will only get printed out if Level > reportable_log_level - - Args: - level (torch_tensorrt.logging.Level): Severity of the message - msg (str): Actual message text - """ - _log(Level._to_internal_level(level), msg) - - InternalError = LogLevel.INTERNAL_ERROR - Error = LogLevel.ERROR - Warning = LogLevel.WARNING - Info = LogLevel.INFO - Debug = LogLevel.DEBUG - Graph = LogLevel.GRAPH +TRT_LOGGER = _TRTLogger() class internal_errors: @@ -125,11 +42,22 @@ class internal_errors: """ def __enter__(self) -> None: - self.external_lvl = get_reportable_log_level() - set_reportable_log_level(Level.InternalError) + self.external_lvl = _LOGGER.getEffectiveLevel() + _LOGGER.setLevel(logging.CRITICAL) + + if ENABLED_FEATURES.torchscript_frontend: + from torch_tensorrt.ts import logging as ts_logging + + self.ts_level = ts_logging.get_reportable_log_level() + ts_logging.set_reportable_log_level(ts_logging.Level.InternalError) def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None: - set_reportable_log_level(self.external_lvl) + _LOGGER.setLevel(self.external_lvl) + + if ENABLED_FEATURES.torchscript_frontend: + from torch_tensorrt.ts import logging as ts_logging + + ts_logging.set_reportable_log_level(self.ts_level) class errors: @@ -142,11 +70,22 @@ class errors: """ def __enter__(self) -> None: - self.external_lvl = get_reportable_log_level() - set_reportable_log_level(Level.Error) + self.external_lvl = _LOGGER.getEffectiveLevel() + _LOGGER.setLevel(logging.ERROR) + + if ENABLED_FEATURES.torchscript_frontend: + from torch_tensorrt.ts import logging as ts_logging + + self.ts_level = ts_logging.get_reportable_log_level() + ts_logging.set_reportable_log_level(ts_logging.Level.Error) def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None: - set_reportable_log_level(self.external_lvl) + _LOGGER.setLevel(self.external_lvl) + + if ENABLED_FEATURES.torchscript_frontend: + from torch_tensorrt.ts import logging as ts_logging + + ts_logging.set_reportable_log_level(self.ts_level) class warnings: @@ -159,11 +98,22 @@ class warnings: """ def __enter__(self) -> None: - self.external_lvl = get_reportable_log_level() - set_reportable_log_level(Level.Warning) + self.external_lvl = _LOGGER.getEffectiveLevel() + _LOGGER.setLevel(logging.WARNING) + + if ENABLED_FEATURES.torchscript_frontend: + from torch_tensorrt.ts import logging as ts_logging + + self.ts_level = ts_logging.get_reportable_log_level() + ts_logging.set_reportable_log_level(ts_logging.Level.Warning) def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None: - set_reportable_log_level(self.external_lvl) + _LOGGER.setLevel(self.external_lvl) + + if ENABLED_FEATURES.torchscript_frontend: + from torch_tensorrt.ts import logging as ts_logging + + ts_logging.set_reportable_log_level(self.ts_level) class info: @@ -176,11 +126,22 @@ class info: """ def __enter__(self) -> None: - self.external_lvl = get_reportable_log_level() - set_reportable_log_level(Level.Info) + self.external_lvl = _LOGGER.getEffectiveLevel() + _LOGGER.setLevel(logging.INFO) + + if ENABLED_FEATURES.torchscript_frontend: + from torch_tensorrt.ts import logging as ts_logging + + self.ts_level = ts_logging.get_reportable_log_level() + ts_logging.set_reportable_log_level(ts_logging.Level.Info) def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None: - set_reportable_log_level(self.external_lvl) + _LOGGER.setLevel(self.external_lvl) + + if ENABLED_FEATURES.torchscript_frontend: + from torch_tensorrt.ts import logging as 
ts_logging + + ts_logging.set_reportable_log_level(self.ts_level) class debug: @@ -193,11 +154,22 @@ class debug: """ def __enter__(self) -> None: - self.external_lvl = get_reportable_log_level() - set_reportable_log_level(Level.Debug) + self.external_lvl = _LOGGER.getEffectiveLevel() + _LOGGER.setLevel(logging.DEBUG) + + if ENABLED_FEATURES.torchscript_frontend: + from torch_tensorrt.ts import logging as ts_logging + + self.ts_level = ts_logging.get_reportable_log_level() + ts_logging.set_reportable_log_level(ts_logging.Level.Debug) def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None: - set_reportable_log_level(self.external_lvl) + _LOGGER.setLevel(self.external_lvl) + + if ENABLED_FEATURES.torchscript_frontend: + from torch_tensorrt.ts import logging as ts_logging + + ts_logging.set_reportable_log_level(self.ts_level) class graphs: @@ -211,8 +183,19 @@ class graphs: """ def __enter__(self) -> None: - self.external_lvl = get_reportable_log_level() - set_reportable_log_level(Level.Graph) + self.external_lvl = _LOGGER.getEffectiveLevel() + _LOGGER.setLevel(logging.NOTSET) + + if ENABLED_FEATURES.torchscript_frontend: + from torch_tensorrt.ts import logging as ts_logging + + self.ts_level = ts_logging.get_reportable_log_level() + ts_logging.set_reportable_log_level(ts_logging.Level.Graph) def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None: - set_reportable_log_level(self.external_lvl) + _LOGGER.setLevel(self.external_lvl) + + if ENABLED_FEATURES.torchscript_frontend: + from torch_tensorrt.ts import logging as ts_logging + + ts_logging.set_reportable_log_level(self.ts_level) diff --git a/py/torch_tensorrt/ts/_Device.py b/py/torch_tensorrt/ts/_Device.py new file mode 100644 index 0000000000..3ae10a9c4d --- /dev/null +++ b/py/torch_tensorrt/ts/_Device.py @@ -0,0 +1,69 @@ +import sys +from typing import Any + +if sys.version_info >= (3, 11): + from typing import Self +else: + from typing_extensions import Self + +import warnings + +from torch_tensorrt._Device import Device + +try: + from torch_tensorrt import _C +except ImportError: + warnings.warn( + "Unable to import torchscript frontend core and torch-tensorrt runtime. Some dependent features may be unavailable." + ) + + +class TorchScriptDevice(Device): + """ + Defines a device that can be used to specify target devices for engines + + Attributes: + device_type (torch_tensorrt.DeviceType): Target device type (GPU or DLA). Set implicitly based on if dla_core is specified. + gpu_id (int): Device ID for target GPU + dla_core (int): Core ID for target DLA core + allow_gpu_fallback (bool): Whether falling back to GPU if DLA cannot support an op should be allowed + """ + + def __init__(self, *args: Any, **kwargs: Any): + """__init__ Method for torch_tensorrt.Device + + Device accepts one of a few construction patterns + + Args: + spec (str): String with device spec e.g. "dla:0" for dla, core_id 0 + + Keyword Arguments: + gpu_id (int): ID of target GPU (will get overrided if dla_core is specified to the GPU managing DLA). If specified, no positional arguments should be provided + dla_core (int): ID of target DLA core. If specified, no positional arguments should be provided. 
+ allow_gpu_fallback (bool): Allow TensorRT to schedule operations on GPU if they are not supported on DLA (ignored if device type is not DLA) + + Examples: + - Device("gpu:1") + - Device("cuda:1") + - Device("dla:0", allow_gpu_fallback=True) + - Device(gpu_id=0, dla_core=0, allow_gpu_fallback=True) + - Device(dla_core=0, allow_gpu_fallback=True) + - Device(gpu_id=1) + """ + super().__init__(*args, **kwargs) + + def _to_internal(self) -> _C.Device: + internal_dev = _C.Device() + internal_dev.device_type = self.device_type.to(_C.DeviceType) + internal_dev.gpu_id = self.gpu_id + internal_dev.dla_core = self.dla_core + internal_dev.allow_gpu_fallback = self.allow_gpu_fallback + return internal_dev + + @classmethod + def _from(cls, d: object) -> Self: + return cls( + gpu_id=d.gpu_id, + dla_core=d.dla_core, + allow_gpu_fallback=d.allow_gpu_fallback, + ) diff --git a/py/torch_tensorrt/ts/_Input.py b/py/torch_tensorrt/ts/_Input.py index f9cbf2c333..6099efbcd2 100644 --- a/py/torch_tensorrt/ts/_Input.py +++ b/py/torch_tensorrt/ts/_Input.py @@ -1,6 +1,7 @@ from typing import Any -from torch_tensorrt import _C, _enums +from torch_tensorrt import _C +from torch_tensorrt._enums import dtype from torch_tensorrt._Input import Input @@ -49,11 +50,16 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: """ super().__init__(*args, **kwargs) + def is_trt_dtype(self) -> bool: + return bool(self.dtype != dtype.long) + def _to_internal(self) -> _C.Input: internal_in = _C.Input() if self.shape_mode == Input._ShapeMode.DYNAMIC: if isinstance(self.shape, dict): - if not Input._supported_input_size_type(self.shape["min_shape"]): + if not TorchScriptInput._supported_input_size_type( + self.shape["min_shape"] + ): raise TypeError( "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: " + str(type(self.shape["min_shape"])) @@ -62,7 +68,9 @@ def _to_internal(self) -> _C.Input: else: internal_in.min = self.shape["min_shape"] - if not Input._supported_input_size_type(self.shape["opt_shape"]): + if not TorchScriptInput._supported_input_size_type( + self.shape["opt_shape"] + ): raise TypeError( "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: " + str(type(self.shape["opt_shape"])) @@ -71,7 +79,9 @@ def _to_internal(self) -> _C.Input: else: internal_in.opt = self.shape["opt_shape"] - if not Input._supported_input_size_type(self.shape["max_shape"]): + if not TorchScriptInput._supported_input_size_type( + self.shape["max_shape"] + ): raise TypeError( "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: " + str(type(self.shape["max_shape"])) @@ -81,7 +91,7 @@ def _to_internal(self) -> _C.Input: internal_in.max = self.shape["max_shape"] internal_in.input_is_dynamic = True else: - if not Input._supported_input_size_type(self.shape): + if not TorchScriptInput._supported_input_size_type(self.shape): raise TypeError( "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: " + str(type(self.shape)) @@ -91,14 +101,11 @@ def _to_internal(self) -> _C.Input: internal_in.opt = self.shape internal_in.input_is_dynamic = False - if self.dtype != _enums.dtype.unknown: - self._explicit_set_dtype = True - else: - self._explicit_set_dtype = False - - internal_in.dtype = Input._parse_dtype(self.dtype) + internal_in.dtype = self.dtype.to(_C.dtype) internal_in._explicit_set_dtype = self._explicit_set_dtype - internal_in.format = Input._parse_format(self.format) 
+ internal_in.format = self.format.to(_C.TensorFormat) - internal_in.tensor_domain = Input._parse_tensor_domain(self.tensor_domain) + internal_in.tensor_domain = TorchScriptInput._parse_tensor_domain( + self.tensor_domain + ) return internal_in diff --git a/py/torch_tensorrt/ts/__init__.py b/py/torch_tensorrt/ts/__init__.py index 5cb45cba5c..d11db42c68 100644 --- a/py/torch_tensorrt/ts/__init__.py +++ b/py/torch_tensorrt/ts/__init__.py @@ -1,3 +1,5 @@ from torch_tensorrt.ts._compile_spec import TensorRTCompileSpec # noqa: F401 from torch_tensorrt.ts._compiler import * # noqa: F403 +from torch_tensorrt.ts._Device import TorchScriptDevice # noqa: F401 +from torch_tensorrt.ts._enums import * # noqa: F403 from torch_tensorrt.ts._Input import TorchScriptInput # noqa: F401 diff --git a/py/torch_tensorrt/ts/_compile_spec.py b/py/torch_tensorrt/ts/_compile_spec.py index 37f5fb79e3..1574de02f3 100644 --- a/py/torch_tensorrt/ts/_compile_spec.py +++ b/py/torch_tensorrt/ts/_compile_spec.py @@ -3,14 +3,17 @@ from copy import deepcopy from typing import Any, Dict, List, Optional, Set -import tensorrt as trt import torch import torch_tensorrt._C.ts as _ts_C -from torch_tensorrt import _C, _enums +from torch_tensorrt import _C from torch_tensorrt._Device import Device +from torch_tensorrt._enums import DeviceType, EngineCapability, dtype from torch_tensorrt._Input import Input -from torch_tensorrt.logging import Level, log +from torch_tensorrt.ts._Device import TorchScriptDevice from torch_tensorrt.ts._Input import TorchScriptInput +from torch_tensorrt.ts.logging import Level, log + +import tensorrt as trt def _internal_input_to_torch_class_input(i: _C.Input) -> torch.classes.tensorrt._Input: @@ -40,31 +43,11 @@ def _supported_input_size_type(input_size: Any) -> bool: ) -def _parse_op_precision(precision: Any) -> _enums.dtype: - if isinstance(precision, torch.dtype): - if precision == torch.int8: - return _enums.dtype.int8 - elif precision == torch.half: - return _enums.dtype.half - elif precision == torch.float: - return _enums.dtype.float - else: - raise TypeError( - "Provided an unsupported dtype as operating precision (support: int8, half, float), got: " - + str(precision) - ) - - elif isinstance(precision, _enums.dtype): - return precision - - else: - raise TypeError( - "Op precision type needs to be specified with a torch.dtype or a torch_tensorrt.dtype, got: " - + str(type(precision)) - ) +def _parse_op_precision(precision: Any) -> _C.dtype: + return dtype._from(precision).to(_C.dtype) -def _parse_enabled_precisions(precisions: Any) -> Set[_enums.dtype]: +def _parse_enabled_precisions(precisions: Any) -> Set[_C.dtype]: parsed_precisions = set() if any(isinstance(precisions, type) for type in [list, tuple, set]): for p in precisions: @@ -74,36 +57,8 @@ def _parse_enabled_precisions(precisions: Any) -> Set[_enums.dtype]: return parsed_precisions -def _parse_device_type(device: Any) -> _enums.DeviceType: - if isinstance(device, torch.device): - if device.type == "cuda": - return _C.DeviceType.gpu - else: - ValueError( - "Got a device type other than GPU or DLA (type: " - + str(device.type) - + ")" - ) - elif isinstance(device, _C.DeviceType): - return device - elif isinstance(device, trt.DeviceType): - if device == trt.DeviceType.DLA: - return _C.DeviceType.DLA - return _C.DeviceType.GPU - elif isinstance(device, str): - if device == "gpu" or device == "GPU": - return _C.DeviceType.GPU - elif device == "dla" or device == "DLA": - return _C.DeviceType.DLA - else: - ValueError( - "Got a device type 
other than GPU or DLA (type: " + str(device) + ")" - ) - else: - raise TypeError( - "Device specification must be of type torch.device, string or torch_tensorrt.DeviceType, but got: " - + str(type(device)) - ) +def _parse_device_type(device: Any) -> _C.DeviceType: + return DeviceType._from(device).to(_C.DeviceType) def _parse_device(device_info: Any) -> _C.Device: @@ -128,9 +83,11 @@ def _parse_device(device_info: Any) -> _C.Device: return info elif isinstance(device_info, Device): + return TorchScriptDevice._from(device_info)._to_internal() + elif isinstance(device_info, TorchScriptDevice): return device_info._to_internal() elif isinstance(device_info, torch.device): - return (Device._from_torch_device(device_info))._to_internal() + return TorchScriptDevice._from(device_info)._to_internal() else: raise ValueError( "Unsupported data for device specification. Expected either a dict, torch_tensorrt.Device or torch.Device" @@ -184,9 +141,11 @@ def _parse_input_signature(input_signature: Any, depth: int = 0) -> Any: else input_signature ) - if not i.is_trt_dtype(): + if not i.dtype.try_to(trt.DataType, use_default=True): raise TypeError( - "Using non-TRT input types with input_signature is not currently " + "Using non-TRT input types ({}) with input_signature is not currently ".format( + i.dtype + ) + "supported. Please specify inputs individually to use " + "non-TRT types." ) @@ -246,7 +205,9 @@ def _parse_compile_spec(compile_spec_: Dict[str, Any]) -> _ts_C.CompileSpec: if i.shape_mode == Input._ShapeMode.STATIC: ts_inputs.append( TorchScriptInput( - shape=i.shape, dtype=i.dtype, format=i.format + shape=i.shape, + dtype=i.dtype.to(_C.dtype), + format=i.format.to(_C.TensorFormat), )._to_internal() ) elif i.shape_mode == Input._ShapeMode.DYNAMIC: @@ -255,8 +216,8 @@ def _parse_compile_spec(compile_spec_: Dict[str, Any]) -> _ts_C.CompileSpec: min_shape=i.shape["min_shape"], opt_shape=i.shape["opt_shape"], max_shape=i.shape["max_shape"], - dtype=i.dtype, - format=i.format, + dtype=i.dtype.to(_C.dtype), + format=i.format.to(_C.TensorFormat), )._to_internal() ) info.inputs = ts_inputs @@ -306,8 +267,10 @@ def _parse_compile_spec(compile_spec_: Dict[str, Any]) -> _ts_C.CompileSpec: info.device = _parse_device(compile_spec["device"]) if "capability" in compile_spec: - assert isinstance(compile_spec["capability"], _enums.EngineCapability) - info.capability = compile_spec["capability"] + capability = EngineCapability._from(compile_spec["capability"]).to( + _C.EngineCapability + ) + info.capability = capability if "num_avg_timing_iters" in compile_spec: assert type(compile_spec["num_avg_timing_iters"]) is int @@ -347,10 +310,10 @@ def TensorRTCompileSpec( device: torch.device | Device = Device._current_device(), disable_tf32: bool = False, sparse_weights: bool = False, - enabled_precisions: Optional[Set[torch.dtype | _enums.dtype]] = None, + enabled_precisions: Optional[Set[torch.dtype | dtype]] = None, refit: bool = False, debug: bool = False, - capability: _enums.EngineCapability = _enums.EngineCapability.default, + capability: EngineCapability = EngineCapability.STANDARD, num_avg_timing_iters: int = 1, workspace_size: int = 0, dla_sram_size: int = 1048576, diff --git a/py/torch_tensorrt/ts/_compiler.py b/py/torch_tensorrt/ts/_compiler.py index e101ebe25d..3be9b7a4c2 100644 --- a/py/torch_tensorrt/ts/_compiler.py +++ b/py/torch_tensorrt/ts/_compiler.py @@ -5,6 +5,7 @@ import torch import torch_tensorrt._C.ts as _C from torch_tensorrt._Device import Device +from torch_tensorrt._enums import 
EngineCapability, dtype from torch_tensorrt._Input import Input from torch_tensorrt.ts._compile_spec import _parse_compile_spec, _parse_device @@ -18,10 +19,10 @@ def compile( device: Device = Device._current_device(), disable_tf32: bool = False, sparse_weights: bool = False, - enabled_precisions: Optional[Set[torch.dtype | _enums.dtype]] = None, + enabled_precisions: Optional[Set[torch.dtype | dtype]] = None, refit: bool = False, debug: bool = False, - capability: _enums.EngineCapability = _enums.EngineCapability.default, + capability: EngineCapability = EngineCapability.STANDARD, num_avg_timing_iters: int = 1, workspace_size: int = 0, dla_sram_size: int = 1048576, @@ -166,10 +167,10 @@ def convert_method_to_trt_engine( device: Device = Device._current_device(), disable_tf32: bool = False, sparse_weights: bool = False, - enabled_precisions: Optional[Set[torch.dtype | _enums.dtype]] = None, + enabled_precisions: Optional[Set[torch.dtype | dtype]] = None, refit: bool = False, debug: bool = False, - capability: _enums.EngineCapability = _enums.EngineCapability.default, + capability: EngineCapability = EngineCapability.STANDARD, num_avg_timing_iters: int = 1, workspace_size: int = 0, dla_sram_size: int = 1048576, diff --git a/py/torch_tensorrt/ts/_enums.py b/py/torch_tensorrt/ts/_enums.py new file mode 100644 index 0000000000..44cb772dc3 --- /dev/null +++ b/py/torch_tensorrt/ts/_enums.py @@ -0,0 +1,3 @@ +from torch_tensorrt._C import EngineCapability, TensorFormat, dtype # noqa: F401 + +from tensorrt import DeviceType # noqa: F401 diff --git a/py/torch_tensorrt/ts/_utils.py b/py/torch_tensorrt/ts/_utils.py new file mode 100644 index 0000000000..89625e5e86 --- /dev/null +++ b/py/torch_tensorrt/ts/_utils.py @@ -0,0 +1,31 @@ +import torch +from torch_tensorrt import _C +from torch_tensorrt._version import __version__ + + +def dump_build_info() -> None: + """Prints build information about the torch_tensorrt distribution to stdout""" + print(get_build_info()) + + +def get_build_info() -> str: + """Returns a string containing the build information of torch_tensorrt distribution + + Returns: + str: String containing the build information for torch_tensorrt distribution + """ + core_build_info = _C.get_build_info() + build_info = str( + "Torch-TensorRT Version: " + + str(__version__) + + "\n" + + "Using PyTorch Version: " + + str(torch.__version__) + + "\n" + + core_build_info + ) + return build_info + + +def set_device(gpu_id: int) -> None: + _C.set_device(gpu_id) diff --git a/py/torch_tensorrt/ts/logging.py b/py/torch_tensorrt/ts/logging.py new file mode 100644 index 0000000000..4220df7f19 --- /dev/null +++ b/py/torch_tensorrt/ts/logging.py @@ -0,0 +1,115 @@ +from enum import Enum + +from torch_tensorrt._C import ( + LogLevel, + _get_is_colored_output_on, + _get_logging_prefix, + _get_reportable_log_level, + _log, + _set_is_colored_output_on, + _set_logging_prefix, + _set_reportable_log_level, +) + + +class Level(Enum): + """Enum to set the minimum required logging level to print a message to stdout""" + + InternalError = LogLevel.INTERNAL_ERROR + Error = LogLevel.ERROR + Warning = LogLevel.WARNING + Info = LogLevel.INFO + Debug = LogLevel.DEBUG + Graph = LogLevel.GRAPH + + @staticmethod + def _to_internal_level(external: "Level") -> LogLevel: + if external == Level.InternalError: + return LogLevel.INTERNAL_ERROR + elif external == Level.Error: + return LogLevel.ERROR + elif external == Level.Warning: + return LogLevel.WARNING + elif external == Level.Info: + return LogLevel.INFO + elif external == 
Level.Debug: + return LogLevel.DEBUG + elif external == Level.Graph: + return LogLevel.GRAPH + else: + print(external) + raise ValueError("Unknown log severity") + + +def get_logging_prefix() -> str: + """Get the prefix set for logging messages + + Returns: + str: Prefix used for logger + """ + return str(_get_logging_prefix()) + + +def set_logging_prefix(prefix: str) -> None: + """Set the prefix used when logging messages + + Args: + prefix (str): Prefix to use for logging messages + """ + _set_logging_prefix(prefix) + + +def get_reportable_log_level() -> Level: + """Get the level required for a message to be printed in the log + + Returns: + torch_tensorrt.logging.Level: The enum representing the level required to print + """ + return Level(_get_reportable_log_level()) + + +def set_reportable_log_level(level: Level) -> None: + """Set the level required for a message to be printed to the log + + Args: + level (torch_tensorrt.logging.Level): The enum representing the level required to print + """ + _set_reportable_log_level(Level._to_internal_level(level)) + + +def get_is_colored_output_on() -> bool: + """Get if colored output is enabled for logging + + Returns: + bool: If colored output is one + """ + return bool(_get_is_colored_output_on()) + + +def set_is_colored_output_on(colored_output_on: bool) -> None: + """Enable or disable color in the log output + + Args: + colored_output_on (bool): If colored output should be enabled or not + """ + _set_is_colored_output_on(colored_output_on) + + +def log(level: Level, msg: str) -> None: + """Add a new message to the log + + Adds a new message to the log at a specified level. The message + will only get printed out if Level > reportable_log_level + + Args: + level (torch_tensorrt.logging.Level): Severity of the message + msg (str): Actual message text + """ + _log(Level._to_internal_level(level), msg) + + InternalError = LogLevel.INTERNAL_ERROR + Error = LogLevel.ERROR + Warning = LogLevel.WARNING + Info = LogLevel.INFO + Debug = LogLevel.DEBUG + Graph = LogLevel.GRAPH diff --git a/py/torch_tensorrt/ptq.py b/py/torch_tensorrt/ts/ptq.py similarity index 99% rename from py/torch_tensorrt/ptq.py rename to py/torch_tensorrt/ts/ptq.py index 5d13ab9108..ec86c620ea 100644 --- a/py/torch_tensorrt/ptq.py +++ b/py/torch_tensorrt/ts/ptq.py @@ -11,7 +11,7 @@ import torch from torch_tensorrt import _C -from torch_tensorrt.logging import Level, log +from py.torch_tensorrt.ts.logging import Level, log class CalibrationAlgo(Enum): diff --git a/pyproject.toml b/pyproject.toml index fb4b5478ae..5f681f5a15 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,7 +64,7 @@ include-package-data = false [tool.ruff] # NOTE: Synchoronize the ignores with .flake8 -ignore = [ +lint.ignore = [ # these ignores are from flake8-bugbear; please fix! "B007", "B008", "B017", "B018", # Useless expression @@ -97,7 +97,7 @@ ignore = [ "SIM118", ] #line-length = 120 -select = [ +lint.select = [ "B", "C4", "G", @@ -112,11 +112,11 @@ select = [ ] # Allow unused variables when underscore-prefixed. -dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" +lint.dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" target-version = "py311" # Allow autofix for all enabled rules (when `--fix`) is provided. 
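
(Illustrative aside, not part of the patch.) The new `py/torch_tensorrt/ts/logging.py` module above relocates the TorchScript-frontend logging controls under `torch_tensorrt.ts.logging`; note that the renamed `ptq.py` hunk imports from `py.torch_tensorrt.ts.logging`, while the sketch below assumes the package-level path `torch_tensorrt.ts.logging`, which is what the updated `tests/py/ts/api/test_logging.py` later in this diff exercises:

```python
# Sketch: driving the relocated TS logging API (assumes torch_tensorrt built
# with the TorchScript frontend).
import torch_tensorrt as torchtrt

torchtrt.ts.logging.set_logging_prefix("Python API Test: ")
torchtrt.ts.logging.set_reportable_log_level(torchtrt.ts.logging.Level.Warning)
assert (
    torchtrt.ts.logging.get_reportable_log_level()
    == torchtrt.ts.logging.Level.Warning
)
# Messages at Warning severity or higher are now emitted:
torchtrt.ts.logging.log(torchtrt.ts.logging.Level.Error, "example message")
```
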
-fixable = [ +lint.fixable = [ "A","B","C","D","E","F","G", "I","N","Q","S","T","W", "ANN", "ARG", "BLE", "COM", "DJ", @@ -125,7 +125,7 @@ fixable = [ "PGH", "PIE", "PL", "PT", "PTH", "PYI", "RET", "RSE", "RUF", "SIM", "SLF", "TCH", "TID", "TRY", "UP", "YTT"] -unfixable = [] +lint.unfixable = [] # Exclude a variety of commonly ignored directories. exclude = [ @@ -164,7 +164,7 @@ exclude = [ "__init__.py" ] -[tool.ruff.mccabe] +[tool.ruff.lint.mccabe] # Unlike Flake8, default to a complexity level of 10. max-complexity = 10 diff --git a/requirements-dev.txt b/requirements-dev.txt index 052751afec..7f97a8e276 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -9,3 +9,4 @@ transformers timm parameterized expecttest==0.1.6 +pyyaml diff --git a/setup.py b/setup.py index 38d2121461..494eaa7ee1 100644 --- a/setup.py +++ b/setup.py @@ -1,3 +1,5 @@ +# type: ignore + import glob import os import platform @@ -80,15 +82,24 @@ def load_dep_info(): CXX11_ABI = False JETPACK_VERSION = None -FX_ONLY = False +PY_ONLY = False +NO_TS = False LEGACY = False RELEASE = False CI_BUILD = False if "--fx-only" in sys.argv: - FX_ONLY = True + PY_ONLY = True sys.argv.remove("--fx-only") +if "--py-only" in sys.argv: + PY_ONLY = True + sys.argv.remove("--py-only") + +if "--no-ts" in sys.argv: + NO_TS = True + sys.argv.remove("--no-ts") + if "--legacy" in sys.argv: LEGACY = True sys.argv.remove("--legacy") @@ -97,6 +108,14 @@ def load_dep_info(): RELEASE = True sys.argv.remove("--release") +if (no_ts_env_var := os.environ.get("NO_TORCHSCRIPT")) is not None: + if no_ts_env_var == "1": + NO_TS = True + +if (py_only_env_var := os.environ.get("PYTHON_ONLY")) is not None: + if py_only_env_var == "1": + PY_ONLY = True + if (release_env_var := os.environ.get("RELEASE")) is not None: if release_env_var == "1": RELEASE = True @@ -168,7 +187,7 @@ def is_exe(fpath): BAZEL_EXE = None -if not FX_ONLY: +if not PY_ONLY: BAZEL_EXE = which("bazelisk") if BAZEL_EXE is None: @@ -177,9 +196,15 @@ def is_exe(fpath): sys.exit("Could not find bazel in PATH") -def build_libtorchtrt_pre_cxx11_abi(develop=True, use_dist_dir=True, cxx11_abi=False): +def build_libtorchtrt_pre_cxx11_abi( + develop=True, use_dist_dir=True, cxx11_abi=False, rt_only=False +): cmd = [BAZEL_EXE, "build"] - cmd.append("//:libtorchtrt") + if rt_only: + cmd.append("//:libtorchtrt_runtime") + else: + cmd.append("//:libtorchtrt") + if develop: cmd.append("--compilation_mode=dbg") else: @@ -224,7 +249,7 @@ def gen_version_file(): f.write('__tensorrt_version__ = "' + __tensorrt_version__ + '"\n') -def copy_libtorchtrt(multilinux=False): +def copy_libtorchtrt(multilinux=False, rt_only=False): if not os.path.exists(dir_path + "/torch_tensorrt/lib"): os.makedirs(dir_path + "/torch_tensorrt/lib") @@ -234,6 +259,14 @@ def copy_libtorchtrt(multilinux=False): dir_path + "/build/libtrtorch_build/libtrtorch.so", dir_path + "/trtorch/lib/libtrtorch.so", ) + elif rt_only: + os.system( + "tar -xzf " + + dir_path + + "/../bazel-bin/libtorchtrt_runtime.tar.gz --strip-components=1 -C " + + dir_path + + "/torch_tensorrt" + ) else: os.system( "tar -xzf " @@ -252,17 +285,20 @@ def initialize_options(self): def finalize_options(self): develop.finalize_options(self) + if NO_TS or PY_ONLY: + self.root_is_pure = False def run(self): - if FX_ONLY: - gen_version_file() - develop.run(self) - else: + + if not PY_ONLY: global CXX11_ABI - build_libtorchtrt_pre_cxx11_abi(develop=True, cxx11_abi=CXX11_ABI) - gen_version_file() - copy_libtorchtrt() - develop.run(self) + 
build_libtorchtrt_pre_cxx11_abi( + develop=True, cxx11_abi=CXX11_ABI, rt_only=NO_TS + ) + copy_libtorchtrt(rt_only=NO_TS) + + gen_version_file() + develop.run(self) class InstallCommand(install): @@ -273,17 +309,20 @@ def initialize_options(self): def finalize_options(self): install.finalize_options(self) + if NO_TS or PY_ONLY: + self.root_is_pure = False def run(self): - if FX_ONLY: - gen_version_file() - install.run(self) - else: + + if not PY_ONLY: global CXX11_ABI - build_libtorchtrt_pre_cxx11_abi(develop=False, cxx11_abi=CXX11_ABI) - gen_version_file() - copy_libtorchtrt() - install.run(self) + build_libtorchtrt_pre_cxx11_abi( + develop=False, cxx11_abi=CXX11_ABI, rt_only=NO_TS + ) + copy_libtorchtrt(rt_only=NO_TS) + + gen_version_file() + install.run(self) class BdistCommand(bdist_wheel): @@ -294,12 +333,18 @@ def initialize_options(self): def finalize_options(self): bdist_wheel.finalize_options(self) + if NO_TS or PY_ONLY: + self.root_is_pure = False def run(self): - global CXX11_ABI - build_libtorchtrt_pre_cxx11_abi(develop=False, cxx11_abi=CXX11_ABI) + if not PY_ONLY: + global CXX11_ABI + build_libtorchtrt_pre_cxx11_abi( + develop=False, cxx11_abi=CXX11_ABI, rt_only=NO_TS + ) + copy_libtorchtrt(rt_only=NO_TS) + gen_version_file() - copy_libtorchtrt() bdist_wheel.run(self) @@ -311,16 +356,20 @@ def initialize_options(self): def finalize_options(self): editable_wheel.finalize_options(self) + if NO_TS or PY_ONLY: + self.root_is_pure = False def run(self): - if FX_ONLY: + if PY_ONLY: gen_version_file() editable_wheel.run(self) else: global CXX11_ABI - build_libtorchtrt_pre_cxx11_abi(develop=True, cxx11_abi=CXX11_ABI) + build_libtorchtrt_pre_cxx11_abi( + develop=True, cxx11_abi=CXX11_ABI, rt_only=NO_TS + ) gen_version_file() - copy_libtorchtrt() + copy_libtorchtrt(rt_only=NO_TS) editable_wheel.run(self) @@ -436,7 +485,7 @@ def run(self): package_data = {} -if not FX_ONLY: +if not (PY_ONLY or NO_TS): ext_modules += [ cpp_extension.CUDAExtension( "torch_tensorrt._C", @@ -536,6 +585,19 @@ def run(self): ] } ) +elif NO_TS: + package_data.update( + { + "torch_tensorrt": [ + "BUILD", + "WORKSPACE", + "include/torch_tensorrt/*.h", + "include/torch_tensorrt/core/*.h", + "include/torch_tensorrt/core/runtime/*.h", + "lib/*", + ] + } + ) with open("README.md", "r", encoding="utf-8") as fh: long_description = fh.read() diff --git a/tests/py/core/test_classes.py b/tests/py/core/test_classes.py new file mode 100644 index 0000000000..171fb305ad --- /dev/null +++ b/tests/py/core/test_classes.py @@ -0,0 +1,58 @@ +import copy +import unittest +from typing import Dict + +import torch +import torch_tensorrt as torchtrt +import torchvision.models as models +from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import TorchTensorRTModule + +import tensorrt as trt + + +class TestDevice(unittest.TestCase): + def test_from_string_constructor(self): + device = torchtrt.Device("cuda:0") + self.assertEqual(device.device_type, torchtrt.DeviceType.GPU) + self.assertEqual(device.gpu_id, 0) + + device = torchtrt.Device("gpu:1") + self.assertEqual(device.device_type, torchtrt.DeviceType.GPU) + self.assertEqual(device.gpu_id, 1) + + def test_from_string_constructor_dla(self): + device = torchtrt.Device("dla:0") + self.assertEqual(device.device_type, torchtrt.DeviceType.DLA) + self.assertEqual(device.gpu_id, 0) + self.assertEqual(device.dla_core, 0) + + device = torchtrt.Device("dla:1", allow_gpu_fallback=True) + self.assertEqual(device.device_type, torchtrt.DeviceType.DLA) + self.assertEqual(device.gpu_id, 0) + 
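
(Illustrative aside, not part of the patch.) The `setup.py` changes above add `--no-ts` / `--py-only` flags, `NO_TORCHSCRIPT=1` / `PYTHON_ONLY=1` environment equivalents, and a runtime-only bazel target. A hedged sketch of driving the new build modes; the exact invocations are assumptions based on the flag parsing shown in the diff and are expected to run from the repository root:

```python
# Sketch only: invoking the new build modes via the environment variables
# parsed in setup.py above.
import os
import subprocess

# Same effect as `--no-ts`: build only the runtime-only C++ tarball pieces.
env = dict(os.environ, NO_TORCHSCRIPT="1")
subprocess.run(["python", "setup.py", "bdist_wheel"], env=env, check=True)

# Same effect as `--py-only` (or the legacy `--fx-only`): skip the bazel build entirely.
env = dict(os.environ, PYTHON_ONLY="1")
subprocess.run(["python", "setup.py", "bdist_wheel"], env=env, check=True)
```
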
self.assertEqual(device.dla_core, 1) + self.assertEqual(device.allow_gpu_fallback, True) + + def test_kwargs_gpu(self): + device = torchtrt.Device(gpu_id=0) + self.assertEqual(device.device_type, torchtrt.DeviceType.GPU) + self.assertEqual(device.gpu_id, 0) + + def test_kwargs_dla_and_settings(self): + device = torchtrt.Device(dla_core=1, allow_gpu_fallback=False) + self.assertEqual(device.device_type, torchtrt.DeviceType.DLA) + self.assertEqual(device.gpu_id, 0) + self.assertEqual(device.dla_core, 1) + self.assertEqual(device.allow_gpu_fallback, False) + + device = torchtrt.Device(gpu_id=1, dla_core=0, allow_gpu_fallback=True) + self.assertEqual(device.device_type, torchtrt.DeviceType.DLA) + self.assertEqual( + device.gpu_id, 0 + ) # Override since AGX platforms use iGPU to manage DLA + self.assertEqual(device.dla_core, 0) + self.assertEqual(device.allow_gpu_fallback, True) + + def test_from_torch(self): + device = torchtrt.Device._from_torch_device(torch.device("cuda:0")) + self.assertEqual(device.device_type, torchtrt.DeviceType.GPU) + self.assertEqual(device.gpu_id, 0) diff --git a/tests/py/dynamo/backend/test_backend_compiler.py b/tests/py/dynamo/backend/test_backend_compiler.py index a958d03120..506c9a1959 100644 --- a/tests/py/dynamo/backend/test_backend_compiler.py +++ b/tests/py/dynamo/backend/test_backend_compiler.py @@ -1,11 +1,11 @@ +# type: ignore from copy import deepcopy import torch +import torch_tensorrt from torch.testing._internal.common_utils import TestCase, run_tests from torch_tensorrt.dynamo.partitioning import fast_partition -import torch_tensorrt - from ..testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing diff --git a/tests/py/dynamo/backend/test_specialized_models.py b/tests/py/dynamo/backend/test_specialized_models.py index 3885627b5f..613fc167bb 100644 --- a/tests/py/dynamo/backend/test_specialized_models.py +++ b/tests/py/dynamo/backend/test_specialized_models.py @@ -1,3 +1,4 @@ +# type: ignore import torch import torch_tensorrt from torch.testing._internal.common_utils import TestCase, run_tests diff --git a/tests/py/dynamo/conversion/harness.py b/tests/py/dynamo/conversion/harness.py index 404f50a187..ef034c914f 100644 --- a/tests/py/dynamo/conversion/harness.py +++ b/tests/py/dynamo/conversion/harness.py @@ -1,3 +1,5 @@ +# type: ignore + import logging import time import unittest @@ -6,10 +8,12 @@ import torch from torch.testing._internal.common_utils import TestCase from torch_tensorrt import Input +from torch_tensorrt._enums import dtype from torch_tensorrt.dynamo._settings import CompilationSettings # Use interpreter, input spec, and test case from fx_ts_compat to test Dynamo Converter Registry from torch_tensorrt.dynamo.conversion import TRTInterpreter +from torch_tensorrt.dynamo.conversion._conversion import infer_module_output_dtypes from torch_tensorrt.dynamo.lowering import apply_lowering_passes from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule @@ -68,7 +72,8 @@ def run_test( interpreter_result.output_names, ) - ref_outputs = mod(*inputs) + mod = mod.cuda() + ref_outputs = mod(*cuda_inputs) torch.cuda.synchronize() start_event = torch.cuda.Event(enable_timing=True) @@ -144,7 +149,7 @@ def run_test_custom_compare_results( interpreter_result.output_names, ) res_trt = trt_mod(*cuda_inputs).cpu() - res_cpu = mod(*inputs) + res_cpu = mod(*cuda_inputs).cpu() assert len(res_trt) == len(res_cpu) assert len(res_cpu) == len(comparators) for output_trt, output_cpu, comparator in zip( @@ -208,7 +213,6 @@ def generate_graph( 
fx_module = torch.fx.symbolic_trace(mod) if enable_passes: fx_module = apply_lowering_passes(fx_module, original_inputs) - _LOGGER.info(f"FX graph= {fx_module.graph}") return fx_module def run_test( @@ -217,9 +221,8 @@ def run_test( inputs, rtol=1e-03, atol=1e-03, - precision=torch.float, + precision=dtype.f32, check_dtype=True, - output_dtypes=None, use_dynamo_tracer=False, enable_passes=False, ): @@ -234,12 +237,25 @@ def run_test( # Previous instance of the interpreter auto-casted 64-bit inputs # We replicate this behavior here compilation_settings = CompilationSettings( - precision=precision, truncate_long_and_double=True + enabled_precisions={dtype._from(precision)}, + truncate_long_and_double=True, + debug=True, ) + input_specs = [Input.from_tensor(i) for i in inputs] + + output_dtypes = None + if check_dtype: + output_dtypes = infer_module_output_dtypes( + mod, + input_specs, + compilation_settings.device, + truncate_long_and_double=compilation_settings.truncate_long_and_double, + ) + interp = TRTInterpreter( mod, - Input.from_tensors(inputs), + input_specs, output_dtypes=output_dtypes, compilation_settings=compilation_settings, ) diff --git a/tests/py/dynamo/conversion/test_abs_aten.py b/tests/py/dynamo/conversion/test_abs_aten.py index 13beeb3bfa..5778110106 100644 --- a/tests/py/dynamo/conversion/test_abs_aten.py +++ b/tests/py/dynamo/conversion/test_abs_aten.py @@ -42,7 +42,6 @@ def forward(self, input): self.run_test( abs(), inputs, - output_dtypes=[torch.int], ) diff --git a/tests/py/dynamo/conversion/test_any.py b/tests/py/dynamo/conversion/test_any.py index f82e2465be..29522145da 100644 --- a/tests/py/dynamo/conversion/test_any.py +++ b/tests/py/dynamo/conversion/test_any.py @@ -26,7 +26,7 @@ def forward(self, x): return torch.ops.aten.any.default(x) inputs = [torch.randn(*input_shape)] - self.run_test(Any(), inputs, output_dtypes=[torch.bool]) + self.run_test(Any(), inputs) @parameterized.expand( [ @@ -43,7 +43,7 @@ def forward(self, x): return torch.ops.aten.any.dim(x, dim, keep_dims) inputs = [torch.randn(*input_shape)] - self.run_test(AnyDim(), inputs, output_dtypes=[torch.bool]) + self.run_test(AnyDim(), inputs) @parameterized.expand( [ @@ -59,7 +59,7 @@ def forward(self, x): return torch.ops.aten.any.dims(x, dims, keep_dims) inputs = [torch.randn(*input_shape)] - self.run_test(AnyDims(), inputs, output_dtypes=[torch.bool]) + self.run_test(AnyDims(), inputs) @parameterized.expand( [ @@ -79,7 +79,6 @@ def forward(self, x): self.run_test( Any(), inputs, - output_dtypes=[torch.bool], ) @parameterized.expand( @@ -100,7 +99,6 @@ def forward(self, x): self.run_test( AnyDim(), inputs, - output_dtypes=[torch.bool], ) @parameterized.expand( @@ -123,7 +121,6 @@ def forward(self, x): self.run_test( AnyDims(), inputs, - output_dtypes=[torch.bool], ) @parameterized.expand( @@ -142,7 +139,6 @@ def forward(self, x): self.run_test( Any(), inputs, - output_dtypes=[torch.bool], ) @parameterized.expand( @@ -163,7 +159,6 @@ def forward(self, x): self.run_test( AnyDim(), inputs, - output_dtypes=[torch.bool], ) @parameterized.expand( @@ -186,7 +181,6 @@ def forward(self, x): self.run_test( AnyDims(), inputs, - output_dtypes=[torch.bool], ) diff --git a/tests/py/dynamo/conversion/test_bitwise_not_aten.py b/tests/py/dynamo/conversion/test_bitwise_not_aten.py index 6dd512ef16..b811f1e51a 100644 --- a/tests/py/dynamo/conversion/test_bitwise_not_aten.py +++ b/tests/py/dynamo/conversion/test_bitwise_not_aten.py @@ -25,7 +25,6 @@ def forward(self, val): bitwise_not(), inputs, enable_passes=True, - 
output_dtypes=[torch.bool], ) diff --git a/tests/py/dynamo/conversion/test_casts.py b/tests/py/dynamo/conversion/test_casts.py index c067a0b9ad..84234db857 100644 --- a/tests/py/dynamo/conversion/test_casts.py +++ b/tests/py/dynamo/conversion/test_casts.py @@ -1,6 +1,9 @@ +# type: ignore + import torch import torch.nn as nn from torch.testing._internal.common_utils import run_tests +from torch_tensorrt import dtype from torch_tensorrt.dynamo.conversion import UnsupportedOperatorException from .harness import DispatchTestCase diff --git a/tests/py/dynamo/conversion/test_eq_aten.py b/tests/py/dynamo/conversion/test_eq_aten.py index 17a372182c..3adc6774d6 100644 --- a/tests/py/dynamo/conversion/test_eq_aten.py +++ b/tests/py/dynamo/conversion/test_eq_aten.py @@ -25,7 +25,6 @@ def forward(self, lhs_val, rhs_val): self.run_test( eq(), inputs, - output_dtypes=[torch.bool], ) @parameterized.expand( @@ -43,7 +42,6 @@ def forward(self, lhs_val): self.run_test( eq(), inputs, - output_dtypes=[torch.bool], ) @parameterized.expand( @@ -61,7 +59,6 @@ def forward(self, lhs_val): self.run_test( eq(), inputs, - output_dtypes=[torch.bool], ) diff --git a/tests/py/dynamo/conversion/test_ge_aten.py b/tests/py/dynamo/conversion/test_ge_aten.py index 6b1ee6d440..bacfedafc8 100644 --- a/tests/py/dynamo/conversion/test_ge_aten.py +++ b/tests/py/dynamo/conversion/test_ge_aten.py @@ -25,7 +25,6 @@ def forward(self, lhs_val, rhs_val): self.run_test( ge(), inputs, - output_dtypes=[torch.bool], ) @parameterized.expand( @@ -43,7 +42,6 @@ def forward(self, lhs_val): self.run_test( ge(), inputs, - output_dtypes=[torch.bool], ) @parameterized.expand( @@ -61,7 +59,6 @@ def forward(self, lhs_val): self.run_test( ge(), inputs, - output_dtypes=[torch.bool], ) diff --git a/tests/py/dynamo/conversion/test_gt_aten.py b/tests/py/dynamo/conversion/test_gt_aten.py index 8d9ae24f80..0eab7c84ff 100644 --- a/tests/py/dynamo/conversion/test_gt_aten.py +++ b/tests/py/dynamo/conversion/test_gt_aten.py @@ -22,7 +22,6 @@ def forward(self, lhs_val, rhs_val): self.run_test( gt(), inputs, - output_dtypes=[torch.bool], ) @parameterized.expand( @@ -40,7 +39,6 @@ def forward(self, lhs_val): self.run_test( gt(), inputs, - output_dtypes=[torch.bool], ) @parameterized.expand( @@ -58,7 +56,6 @@ def forward(self, lhs_val): self.run_test( gt(), inputs, - output_dtypes=[torch.bool], ) diff --git a/tests/py/dynamo/conversion/test_isinf_aten.py b/tests/py/dynamo/conversion/test_isinf_aten.py index 78695dbe21..d0dce59a60 100644 --- a/tests/py/dynamo/conversion/test_isinf_aten.py +++ b/tests/py/dynamo/conversion/test_isinf_aten.py @@ -35,7 +35,6 @@ def forward(self, input): self.run_test( isinf(), inputs, - output_dtypes=[torch.bool], ) @parameterized.expand( @@ -54,7 +53,6 @@ def forward(self, input): self.run_test( isinf(), inputs, - output_dtypes=[torch.bool], ) diff --git a/tests/py/dynamo/conversion/test_isnan_aten.py b/tests/py/dynamo/conversion/test_isnan_aten.py index 5651b0ca25..a1e897a664 100644 --- a/tests/py/dynamo/conversion/test_isnan_aten.py +++ b/tests/py/dynamo/conversion/test_isnan_aten.py @@ -1,3 +1,4 @@ +# type: ignore import torch import torch.nn as nn from parameterized import parameterized @@ -36,7 +37,6 @@ def forward(self, input): self.run_test( isnan(), inputs, - output_dtypes=[torch.bool], ) @parameterized.expand( @@ -55,7 +55,6 @@ def forward(self, input): self.run_test( isnan(), inputs, - output_dtypes=[torch.bool], ) @parameterized.expand( @@ -74,7 +73,6 @@ def forward(self, input): self.run_test( isnan(), inputs, - 
output_dtypes=[torch.bool], ) diff --git a/tests/py/dynamo/conversion/test_le_aten.py b/tests/py/dynamo/conversion/test_le_aten.py index 373384c6f9..5b725213a3 100644 --- a/tests/py/dynamo/conversion/test_le_aten.py +++ b/tests/py/dynamo/conversion/test_le_aten.py @@ -25,7 +25,6 @@ def forward(self, lhs_val, rhs_val): self.run_test( le(), inputs, - output_dtypes=[torch.bool], ) @parameterized.expand( @@ -43,7 +42,6 @@ def forward(self, lhs_val): self.run_test( le(), inputs, - output_dtypes=[torch.bool], ) @parameterized.expand( @@ -61,7 +59,6 @@ def forward(self, lhs_val): self.run_test( le(), inputs, - output_dtypes=[torch.bool], ) diff --git a/tests/py/dynamo/conversion/test_logical_not_aten.py b/tests/py/dynamo/conversion/test_logical_not_aten.py index a36a8dbf72..b03fbc777e 100644 --- a/tests/py/dynamo/conversion/test_logical_not_aten.py +++ b/tests/py/dynamo/conversion/test_logical_not_aten.py @@ -22,7 +22,6 @@ def forward(self, input): self.run_test( logical_not(), inputs, - output_dtypes=[torch.bool], ) @parameterized.expand( @@ -41,7 +40,6 @@ def forward(self, input): self.run_test( logical_not(), inputs, - output_dtypes=[torch.bool], ) @parameterized.expand( @@ -60,7 +58,6 @@ def forward(self, input): self.run_test( logical_not(), inputs, - output_dtypes=[torch.bool], ) diff --git a/tests/py/dynamo/conversion/test_lt_aten.py b/tests/py/dynamo/conversion/test_lt_aten.py index 89cb7f42c5..bd4b8f1b21 100644 --- a/tests/py/dynamo/conversion/test_lt_aten.py +++ b/tests/py/dynamo/conversion/test_lt_aten.py @@ -22,7 +22,6 @@ def forward(self, lhs_val, rhs_val): self.run_test( lt(), inputs, - output_dtypes=[torch.bool], ) @parameterized.expand( @@ -40,7 +39,6 @@ def forward(self, lhs_val): self.run_test( lt(), inputs, - output_dtypes=[torch.bool], ) @parameterized.expand( @@ -58,7 +56,6 @@ def forward(self, lhs_val): self.run_test( lt(), inputs, - output_dtypes=[torch.bool], ) diff --git a/tests/py/dynamo/conversion/test_ne_aten.py b/tests/py/dynamo/conversion/test_ne_aten.py index 2450ac0945..d2f7421848 100644 --- a/tests/py/dynamo/conversion/test_ne_aten.py +++ b/tests/py/dynamo/conversion/test_ne_aten.py @@ -25,7 +25,6 @@ def forward(self, lhs_val, rhs_val): self.run_test( ne(), inputs, - output_dtypes=[torch.bool], ) @parameterized.expand( @@ -43,7 +42,6 @@ def forward(self, lhs_val): self.run_test( ne(), inputs, - output_dtypes=[torch.bool], ) @parameterized.expand( @@ -61,7 +59,6 @@ def forward(self, lhs_val): self.run_test( ne(), inputs, - output_dtypes=[torch.bool], ) diff --git a/tests/py/dynamo/conversion/test_pad_aten.py b/tests/py/dynamo/conversion/test_pad_aten.py index 2803736ad0..ca0b01b5d2 100644 --- a/tests/py/dynamo/conversion/test_pad_aten.py +++ b/tests/py/dynamo/conversion/test_pad_aten.py @@ -1,3 +1,4 @@ +# type: ignore import torch from parameterized import parameterized from torch.testing._internal.common_utils import run_tests diff --git a/tests/py/dynamo/conversion/test_scalar_tensor_aten.py b/tests/py/dynamo/conversion/test_scalar_tensor_aten.py index 28c3d7f481..d0146ed720 100644 --- a/tests/py/dynamo/conversion/test_scalar_tensor_aten.py +++ b/tests/py/dynamo/conversion/test_scalar_tensor_aten.py @@ -87,7 +87,6 @@ def forward(self): self.run_test( ScalarTensor(), inputs, - output_dtypes=None if dtype is None else [dtype], ) diff --git a/tests/py/dynamo/conversion/test_sum_aten.py b/tests/py/dynamo/conversion/test_sum_aten.py index 999b8b4997..bac8c7edf1 100644 --- a/tests/py/dynamo/conversion/test_sum_aten.py +++ b/tests/py/dynamo/conversion/test_sum_aten.py 
@@ -85,7 +85,6 @@ def forward(self, x): self.run_test( Sum(), inputs, - output_dtypes=[torch.int32], ) @parameterized.expand( @@ -108,7 +107,6 @@ def forward(self, x): self.run_test( Sum(), inputs, - output_dtypes=[torch.int32], ) diff --git a/tests/py/dynamo/lowering/test_aten_lowering_passes.py b/tests/py/dynamo/lowering/test_aten_lowering_passes.py index bc75a8aa3d..b8b3b3e249 100644 --- a/tests/py/dynamo/lowering/test_aten_lowering_passes.py +++ b/tests/py/dynamo/lowering/test_aten_lowering_passes.py @@ -1,7 +1,8 @@ -import torch -from torch.testing._internal.common_utils import TestCase, run_tests +import unittest +import torch import torch_tensorrt +from torch.testing._internal.common_utils import TestCase, run_tests from ..testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing @@ -268,6 +269,10 @@ def forward(self, q, k, v): torch._dynamo.reset() +@unittest.skipIf( + torch.cuda.get_device_properties(torch.cuda.current_device()).major < 8, + "GPU compute capability is too low to run flash attention, need Ampere (8.0) or greater", +) class TestLowerFlashAttention(TestCase): def test_lower_flash_attention(self): class FlashAttention(torch.nn.Module): diff --git a/tests/py/dynamo/models/test_models_export.py b/tests/py/dynamo/models/test_models_export.py index fd7b40592a..3da1a0976f 100644 --- a/tests/py/dynamo/models/test_models_export.py +++ b/tests/py/dynamo/models/test_models_export.py @@ -1,3 +1,4 @@ +# type: ignore import unittest import pytest diff --git a/tests/py/dynamo/runtime/test_hw_compat.py b/tests/py/dynamo/runtime/test_hw_compat.py index 4218cc7de0..29bd17cfde 100644 --- a/tests/py/dynamo/runtime/test_hw_compat.py +++ b/tests/py/dynamo/runtime/test_hw_compat.py @@ -7,6 +7,14 @@ class TestHardwareCompatibility(TestCase): + @unittest.skipIf( + not torch_tensorrt.ENABLED_FEATURES.torch_tensorrt_runtime, + "Torch-TensorRT Runtime is not available", + ) + @unittest.skipIf( + not torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 8, + "HW Compatibility is not supported on cards older than Ampere", + ) def test_hw_compat_enabled(self): class SampleModel(torch.nn.Module): def forward(self, x): @@ -58,6 +66,14 @@ def forward(self, x): torch.ops.tensorrt.ABI_VERSION() != "5", "Detected incorrect ABI version, please update this test case", ) + @unittest.skipIf( + not torch_tensorrt.ENABLED_FEATURES.torch_tensorrt_runtime, + "Torch-TensorRT runtime is not available", + ) + @unittest.skipIf( + not torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 8, + "HW Compatibility is not supported on cards older than Ampere", + ) def test_hw_compat_3080_build(self): inputs = [torch.randn(5, 7).cuda()] diff --git a/tests/py/dynamo/runtime/test_safe_mode.py b/tests/py/dynamo/runtime/test_safe_mode.py index bd196b12f0..5842b3ddc5 100644 --- a/tests/py/dynamo/runtime/test_safe_mode.py +++ b/tests/py/dynamo/runtime/test_safe_mode.py @@ -1,11 +1,16 @@ -import torch -from torch.testing._internal.common_utils import TestCase, run_tests +import unittest +import torch import torch_tensorrt +from torch.testing._internal.common_utils import TestCase, run_tests from ..testing_utilities import DECIMALS_OF_AGREEMENT +@unittest.skipIf( + not torch_tensorrt.ENABLED_FEATURES.torch_tensorrt_runtime, + "Torch-TensorRT runtime is not available", +) class TestSafeMode(TestCase): def test_multi_device_safe_mode_on(self): torch_tensorrt.runtime.set_multi_device_safe_mode(True) diff --git a/tests/py/ts/api/test_classes.py b/tests/py/ts/api/test_classes.py index 
01c805d9a1..2a152cdec7 100644 --- a/tests/py/ts/api/test_classes.py +++ b/tests/py/ts/api/test_classes.py @@ -1,58 +1,17 @@ -import unittest -import torch_tensorrt as torchtrt -from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import TorchTensorRTModule -import torch -import torchvision.models as models import copy +import unittest from typing import Dict - -class TestDevice(unittest.TestCase): - def test_from_string_constructor(self): - device = torchtrt.Device("cuda:0") - self.assertEqual(device.device_type, torchtrt.DeviceType.GPU) - self.assertEqual(device.gpu_id, 0) - - device = torchtrt.Device("gpu:1") - self.assertEqual(device.device_type, torchtrt.DeviceType.GPU) - self.assertEqual(device.gpu_id, 1) - - def test_from_string_constructor_dla(self): - device = torchtrt.Device("dla:0") - self.assertEqual(device.device_type, torchtrt.DeviceType.DLA) - self.assertEqual(device.gpu_id, 0) - self.assertEqual(device.dla_core, 0) - - device = torchtrt.Device("dla:1", allow_gpu_fallback=True) - self.assertEqual(device.device_type, torchtrt.DeviceType.DLA) - self.assertEqual(device.gpu_id, 0) - self.assertEqual(device.dla_core, 1) - self.assertEqual(device.allow_gpu_fallback, True) - - def test_kwargs_gpu(self): - device = torchtrt.Device(gpu_id=0) - self.assertEqual(device.device_type, torchtrt.DeviceType.GPU) - self.assertEqual(device.gpu_id, 0) - - def test_kwargs_dla_and_settings(self): - device = torchtrt.Device(dla_core=1, allow_gpu_fallback=False) - self.assertEqual(device.device_type, torchtrt.DeviceType.DLA) - self.assertEqual(device.gpu_id, 0) - self.assertEqual(device.dla_core, 1) - self.assertEqual(device.allow_gpu_fallback, False) - - device = torchtrt.Device(gpu_id=1, dla_core=0, allow_gpu_fallback=True) - self.assertEqual(device.device_type, torchtrt.DeviceType.DLA) - self.assertEqual(device.gpu_id, 1) - self.assertEqual(device.dla_core, 0) - self.assertEqual(device.allow_gpu_fallback, True) - - def test_from_torch(self): - device = torchtrt.Device._from_torch_device(torch.device("cuda:0")) - self.assertEqual(device.device_type, torchtrt.DeviceType.GPU) - self.assertEqual(device.gpu_id, 0) +import torch +import torch_tensorrt as torchtrt +import torchvision.models as models +from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import TorchTensorRTModule +@unittest.skipIf( + not torchtrt.ENABLED_FEATURES.torchscript_frontend, + "TorchScript Frontend is not available", +) class TestInput(unittest.TestCase): def _verify_correctness(self, struct: torchtrt.Input, target: Dict) -> bool: internal = struct._to_internal() @@ -80,10 +39,16 @@ def field_is_correct(field, equal_fn, a1, a2): target["explicit_set_dtype"], ) dtype_ = field_is_correct( - "dtype", eq, int(internal.dtype), int(target["dtype"]) + "dtype", + eq, + torchtrt.dtype._from(internal.dtype), + torchtrt.dtype._from(target["dtype"]), ) format_ = field_is_correct( - "format", eq, int(internal.format), int(target["format"]) + "format", + eq, + torchtrt.memory_format._from(internal.format), + torchtrt.memory_format._from(target["format"]), ) return all( @@ -98,7 +63,7 @@ def test_infer_from_example_tensor(self): "max": shape, "input_is_dynamic": False, "dtype": torchtrt.dtype.half, - "format": torchtrt.TensorFormat.contiguous, + "format": torchtrt.memory_format.contiguous, "explicit_set_dtype": True, } @@ -117,7 +82,7 @@ def test_static_shape(self): "max": shape, "input_is_dynamic": False, "dtype": torchtrt.dtype.unknown, - "format": torchtrt.TensorFormat.contiguous, + "format": torchtrt.memory_format.contiguous, 
"explicit_set_dtype": False, } @@ -165,7 +130,7 @@ def test_data_type(self): "max": shape, "input_is_dynamic": False, "dtype": torchtrt.dtype.half, - "format": torchtrt.TensorFormat.contiguous, + "format": torchtrt.memory_format.contiguous, "explicit_set_dtype": True, } @@ -189,11 +154,11 @@ def test_tensor_format(self): "max": shape, "input_is_dynamic": False, "dtype": torchtrt.dtype.unknown, - "format": torchtrt.TensorFormat.channels_last, + "format": torchtrt.memory_format.channels_last, "explicit_set_dtype": False, } - i = torchtrt.Input(shape, format=torchtrt.TensorFormat.channels_last) + i = torchtrt.Input(shape, format=torchtrt.memory_format.channels_last) ts_i = torchtrt.ts.TorchScriptInput( shape=i.shape, dtype=i.dtype, format=i.format ) @@ -215,7 +180,7 @@ def test_dynamic_shape(self): "max": max_shape, "input_is_dynamic": True, "dtype": torchtrt.dtype.unknown, - "format": torchtrt.TensorFormat.contiguous, + "format": torchtrt.memory_format.contiguous, "explicit_set_dtype": False, } @@ -261,6 +226,10 @@ def test_dynamic_shape(self): self.assertTrue(self._verify_correctness(ts_i, target)) +@unittest.skipIf( + not torchtrt.ENABLED_FEATURES.torchscript_frontend, + "TorchScript Frontend is not available", +) class TestTorchTensorRTModule(unittest.TestCase): @staticmethod def _get_trt_mod(): diff --git a/tests/py/ts/api/test_collections.py b/tests/py/ts/api/test_collections.py index eab67679ed..7dc79b09b4 100644 --- a/tests/py/ts/api/test_collections.py +++ b/tests/py/ts/api/test_collections.py @@ -1,9 +1,12 @@ +# type: ignore + +import os import unittest -import torch_tensorrt as torchtrt + import torch +import torch_tensorrt as torchtrt import torchvision.models as models -import os -from utils import cosine_similarity, COSINE_THRESHOLD +from utils import COSINE_THRESHOLD, cosine_similarity def find_repo_root(max_depth=10): @@ -21,6 +24,10 @@ def find_repo_root(max_depth=10): MODULE_DIR = find_repo_root() + "/tests/modules" +@unittest.skipIf( + not torchtrt.ENABLED_FEATURES.torchscript_frontend, + "TorchScript Frontend is not available", +) class TestStandardTensorInput(unittest.TestCase): def test_compile(self): self.input = torch.randn((1, 3, 224, 224)).to("cuda") @@ -49,9 +56,14 @@ def test_compile(self): ) +@unittest.skipIf( + not torchtrt.ENABLED_FEATURES.torchscript_frontend, + "TorchScript Frontend is not available", +) +@unittest.skip("TODO: @bowang007, Invalid test case, needs fixing") class TestStandardTensorInputLong(unittest.TestCase): def test_compile(self): - self.input = torch.randn((1, 3, 224, 224)).to("cuda") + self.input = torch.randn((1, 3, 224, 224)).to("cuda").to(torch.int32) self.model = ( torch.jit.load(MODULE_DIR + "/standard_tensor_input_scripted.jit.pt") .eval() @@ -66,6 +78,7 @@ def test_compile(self): "device": torchtrt.Device("gpu:0"), "enabled_precisions": {torch.float}, "truncate_long_and_double": True, + "require_full_compilation": True, } trt_mod = torchtrt.ts.compile(self.model, **compile_spec) @@ -78,6 +91,10 @@ def test_compile(self): ) +@unittest.skipIf( + not torchtrt.ENABLED_FEATURES.torchscript_frontend, + "TorchScript Frontend is not available", +) class TestStandardTensorInputDomain(unittest.TestCase): def test_compile(self): self.input = torch.randn((1, 3, 224, 224)).to("cuda") @@ -106,6 +123,10 @@ def test_compile(self): ) +@unittest.skipIf( + not torchtrt.ENABLED_FEATURES.torchscript_frontend, + "TorchScript Frontend is not available", +) class TestTupleInput(unittest.TestCase): def test_compile(self): self.input = torch.randn((1, 3, 224, 
224)).to("cuda") @@ -134,6 +155,10 @@ def test_compile(self): ) +@unittest.skipIf( + not torchtrt.ENABLED_FEATURES.torchscript_frontend, + "TorchScript Frontend is not available", +) class TestListInput(unittest.TestCase): def test_compile(self): self.input = torch.randn((1, 3, 224, 224)).to("cuda") @@ -160,6 +185,10 @@ def test_compile(self): ) +@unittest.skipIf( + not torchtrt.ENABLED_FEATURES.torchscript_frontend, + "TorchScript Frontend is not available", +) class TestTupleInputOutput(unittest.TestCase): def test_compile(self): self.input = torch.randn((1, 3, 224, 224)).to("cuda") @@ -217,6 +246,10 @@ def test_compile_full_compilation(self): ) +@unittest.skipIf( + not torchtrt.ENABLED_FEATURES.torchscript_frontend, + "TorchScript Frontend is not available", +) class TestListInputOutput(unittest.TestCase): def test_compile(self): self.input = torch.randn((1, 3, 224, 224)).to("cuda") @@ -276,6 +309,10 @@ def test_compile_full_compilation(self): ) +@unittest.skipIf( + not torchtrt.ENABLED_FEATURES.torchscript_frontend, + "TorchScript Frontend is not available", +) class TestListInputTupleOutput(unittest.TestCase): def test_compile(self): self.input = torch.randn((1, 3, 224, 224)).to("cuda") diff --git a/tests/py/ts/api/test_e2e_behavior.py b/tests/py/ts/api/test_e2e_behavior.py index 499106e9ca..7e1f3dd538 100644 --- a/tests/py/ts/api/test_e2e_behavior.py +++ b/tests/py/ts/api/test_e2e_behavior.py @@ -1,9 +1,10 @@ +import copy import unittest -import torch_tensorrt as torchtrt +from typing import Dict + import torch +import torch_tensorrt as torchtrt import torchvision.models as models -import copy -from typing import Dict from utils import same_output_format @@ -39,7 +40,7 @@ def test_input_respect_user_setting_fp32_weights_fp16_in_non_constructor(self): ts_model = torch.jit.script(self.model) input_spec = torchtrt.Input(self.input.shape) - input_spec.dtype = torch.half + input_spec.dtype = torchtrt.dtype.half trt_mod = torchtrt.ts.compile( ts_model, @@ -100,7 +101,7 @@ def test_input_respect_user_setting_fp16_weights_fp32_in_non_constuctor(self): half_mod.half() input_spec = torchtrt.Input(self.input.shape) - input_spec.dtype = torch.float + input_spec.dtype = torchtrt.dtype.float trt_mod = torchtrt.ts.compile( half_mod, diff --git a/tests/py/ts/api/test_logging.py b/tests/py/ts/api/test_logging.py index cc10fa9cc9..b07be3bff5 100644 --- a/tests/py/ts/api/test_logging.py +++ b/tests/py/ts/api/test_logging.py @@ -1,71 +1,76 @@ +import copy import unittest -import torch_tensorrt as torchtrt +from typing import Dict + import torch +import torch_tensorrt as torchtrt import torchvision.models as models -import copy -from typing import Dict +@unittest.skipIf( + not torchtrt.ENABLED_FEATURES.torchscript_frontend, + "TorchScript Frontend is not available", +) class TestLoggingAPIs(unittest.TestCase): def test_logging_prefix(self): new_prefix = "Python API Test: " - torchtrt.logging.set_logging_prefix(new_prefix) - logging_prefix = torchtrt.logging.get_logging_prefix() + torchtrt.ts.logging.set_logging_prefix(new_prefix) + logging_prefix = torchtrt.ts.logging.get_logging_prefix() self.assertEqual(new_prefix, logging_prefix) def test_reportable_log_level(self): - new_level = torchtrt.logging.Level.Error - torchtrt.logging.set_reportable_log_level(new_level) - level = torchtrt.logging.get_reportable_log_level() + new_level = torchtrt.ts.logging.Level.Error + torchtrt.ts.logging.set_reportable_log_level(new_level) + level = torchtrt.ts.logging.get_reportable_log_level() self.assertEqual(new_level, 
level) def test_is_colored_output_on(self): - torchtrt.logging.set_is_colored_output_on(True) - color = torchtrt.logging.get_is_colored_output_on() + torchtrt.ts.logging.set_is_colored_output_on(True) + color = torchtrt.ts.logging.get_is_colored_output_on() self.assertTrue(color) def test_context_managers(self): - base_lvl = torchtrt.logging.get_reportable_log_level() + base_lvl = torchtrt.ts.logging.get_reportable_log_level() with torchtrt.logging.internal_errors(): - lvl = torchtrt.logging.get_reportable_log_level() - self.assertEqual(torchtrt.logging.Level.InternalError, lvl) + lvl = torchtrt.ts.logging.get_reportable_log_level() + self.assertEqual(torchtrt.ts.logging.Level.InternalError, lvl) - lvl = torchtrt.logging.get_reportable_log_level() + lvl = torchtrt.ts.logging.get_reportable_log_level() self.assertEqual(base_lvl, lvl) with torchtrt.logging.errors(): - lvl = torchtrt.logging.get_reportable_log_level() - self.assertEqual(torchtrt.logging.Level.Error, lvl) + lvl = torchtrt.ts.logging.get_reportable_log_level() + self.assertEqual(torchtrt.ts.logging.Level.Error, lvl) - lvl = torchtrt.logging.get_reportable_log_level() + lvl = torchtrt.ts.logging.get_reportable_log_level() self.assertEqual(base_lvl, lvl) with torchtrt.logging.warnings(): - lvl = torchtrt.logging.get_reportable_log_level() - self.assertEqual(torchtrt.logging.Level.Warning, lvl) + lvl = torchtrt.ts.logging.get_reportable_log_level() + self.assertEqual(torchtrt.ts.logging.Level.Warning, lvl) - lvl = torchtrt.logging.get_reportable_log_level() + lvl = torchtrt.ts.logging.get_reportable_log_level() self.assertEqual(base_lvl, lvl) with torchtrt.logging.info(): - lvl = torchtrt.logging.get_reportable_log_level() - self.assertEqual(torchtrt.logging.Level.Info, lvl) + lvl = torchtrt.ts.logging.get_reportable_log_level() + self.assertEqual(torchtrt.ts.logging.Level.Info, lvl) - lvl = torchtrt.logging.get_reportable_log_level() + lvl = torchtrt.ts.logging.get_reportable_log_level() self.assertEqual(base_lvl, lvl) with torchtrt.logging.debug(): - lvl = torchtrt.logging.get_reportable_log_level() - self.assertEqual(torchtrt.logging.Level.Debug, lvl) + lvl = torchtrt.ts.logging.get_reportable_log_level() + self.assertEqual(torchtrt.ts.logging.Level.Debug, lvl) - lvl = torchtrt.logging.get_reportable_log_level() + lvl = torchtrt.ts.logging.get_reportable_log_level() self.assertEqual(base_lvl, lvl) with torchtrt.logging.graphs(): - lvl = torchtrt.logging.get_reportable_log_level() - self.assertEqual(torchtrt.logging.Level.Graph, lvl) + lvl = torchtrt.ts.logging.get_reportable_log_level() + self.assertEqual(torchtrt.ts.logging.Level.Graph, lvl) - lvl = torchtrt.logging.get_reportable_log_level() + lvl = torchtrt.ts.logging.get_reportable_log_level() self.assertEqual(base_lvl, lvl) diff --git a/tests/py/ts/hw/test_api_dla.py b/tests/py/ts/hw/test_api_dla.py index 5328b92233..0bf3b74010 100644 --- a/tests/py/ts/hw/test_api_dla.py +++ b/tests/py/ts/hw/test_api_dla.py @@ -1,10 +1,15 @@ import unittest -import torch_tensorrt as torchtrt + import torch +import torch_tensorrt as torchtrt import torchvision.models as models -from utils import cosine_similarity, COSINE_THRESHOLD +from utils import COSINE_THRESHOLD, cosine_similarity +@unittest.skipIf( + not torchtrt.ENABLED_FEATURES.torchscript_frontend, + "TorchScript Frontend is not available", +) class ModelTestCaseOnDLA(unittest.TestCase): def __init__(self, methodName="runTest", model=None): super(ModelTestCaseOnDLA, self).__init__(methodName) @@ -21,6 +26,10 @@ def 
parametrize(testcase_class, model=None): return suite +@unittest.skipIf( + not torchtrt.ENABLED_FEATURES.torchscript_frontend, + "TorchScript Frontend is not available", +) class TestCompile(ModelTestCaseOnDLA): def setUp(self): self.input = torch.randn((1, 3, 224, 224)).to("cuda").half() diff --git a/tests/py/ts/hw/test_multi_gpu.py b/tests/py/ts/hw/test_multi_gpu.py index b6fa3f220b..4fc2bd9223 100644 --- a/tests/py/ts/hw/test_multi_gpu.py +++ b/tests/py/ts/hw/test_multi_gpu.py @@ -1,11 +1,15 @@ import unittest -import torch_tensorrt as torchtrt + import torch +import torch_tensorrt as torchtrt import torchvision.models as models - from model_test_case import ModelTestCase +@unittest.skipIf( + not torchtrt.ENABLED_FEATURES.torchscript_frontend, + "TorchScript Frontend is not available", +) class TestMultiGpuSwitching(ModelTestCase): def setUp(self): if torch.cuda.device_count() < 2: @@ -65,6 +69,10 @@ def test_compile_script(self): ) +@unittest.skipIf( + not torchtrt.ENABLED_FEATURES.torchscript_frontend, + "TorchScript Frontend is not available", +) class TestMultiGpuSerializeDeserializeSwitching(ModelTestCase): def setUp(self): if torch.cuda.device_count() < 2: diff --git a/tests/py/ts/integrations/test_to_backend_api.py b/tests/py/ts/integrations/test_to_backend_api.py index 0f74a3af15..6e974ba2c8 100644 --- a/tests/py/ts/integrations/test_to_backend_api.py +++ b/tests/py/ts/integrations/test_to_backend_api.py @@ -1,10 +1,16 @@ +# type: ignore import unittest -import torch_tensorrt as torchtrt + import torch +import torch_tensorrt as torchtrt import torchvision.models as models -from utils import cosine_similarity, COSINE_THRESHOLD +from utils import COSINE_THRESHOLD, cosine_similarity +@unittest.skipIf( + not torchtrt.ENABLED_FEATURES.torchscript_frontend, + "TorchScript Frontend is not available", +) class TestToBackendLowering(unittest.TestCase): def setUp(self): self.input = torch.randn((1, 3, 300, 300)).to("cuda") @@ -23,7 +29,9 @@ def setUp(self): "dla_core": 0, "allow_gpu_fallback": True, }, - "capability": torchtrt.EngineCapability.default, + "capability": torchtrt.EngineCapability.STANDARD.to( + torchtrt._C.EngineCapability + ), "num_avg_timing_iters": 1, "disable_tf32": False, } diff --git a/tests/py/ts/integrations/test_trt_intercompatibility.py b/tests/py/ts/integrations/test_trt_intercompatibility.py index b938e4a1ac..ed3d906386 100644 --- a/tests/py/ts/integrations/test_trt_intercompatibility.py +++ b/tests/py/ts/integrations/test_trt_intercompatibility.py @@ -1,11 +1,17 @@ import unittest -import torch_tensorrt as torchtrt + import torch +import torch_tensorrt as torchtrt import torchvision.models as models +from utils import COSINE_THRESHOLD, cosine_similarity + import tensorrt as trt -from utils import cosine_similarity, COSINE_THRESHOLD +@unittest.skipIf( + not torchtrt.ENABLED_FEATURES.torchscript_frontend, + "TorchScript Frontend is not available", +) class TestPyTorchToTRTEngine(unittest.TestCase): def test_pt_to_trt(self): self.model = models.resnet18(pretrained=True).eval().to("cuda:0") diff --git a/tests/py/ts/models/test_models.py b/tests/py/ts/models/test_models.py index 5678e8f648..1d5c3bae3b 100644 --- a/tests/py/ts/models/test_models.py +++ b/tests/py/ts/models/test_models.py @@ -10,6 +10,10 @@ from utils import COSINE_THRESHOLD, cosine_similarity +@unittest.skipIf( + not torchtrt.ENABLED_FEATURES.torchscript_frontend, + "TorchScript Frontend is not available", +) class TestModels(unittest.TestCase): def test_resnet18(self): self.model = 
models.resnet18(pretrained=True).eval().to("cuda") diff --git a/tests/py/ts/models/test_multiple_registered_engines.py b/tests/py/ts/models/test_multiple_registered_engines.py index e8c1f95433..407502f04a 100644 --- a/tests/py/ts/models/test_multiple_registered_engines.py +++ b/tests/py/ts/models/test_multiple_registered_engines.py @@ -1,14 +1,19 @@ +import copy import unittest -import torch_tensorrt as torchtrt +from typing import Dict + +import custom_models as cm +import timm import torch +import torch_tensorrt as torchtrt import torchvision.models as models -import copy -import timm -import custom_models as cm -from typing import Dict -from utils import cosine_similarity, COSINE_THRESHOLD +from utils import COSINE_THRESHOLD, cosine_similarity +@unittest.skipIf( + not torchtrt.ENABLED_FEATURES.torchscript_frontend, + "TorchScript Frontend is not available", +) class TestModelToEngineToModel(unittest.TestCase): def test_multiple_engines(self): self.resnet18 = models.resnet18(pretrained=True).eval().to("cuda") diff --git a/tests/py/ts/ptq/test_ptq_dataloader_calibrator.py b/tests/py/ts/ptq/test_ptq_dataloader_calibrator.py index c5a84f301d..2fac02f542 100644 --- a/tests/py/ts/ptq/test_ptq_dataloader_calibrator.py +++ b/tests/py/ts/ptq/test_ptq_dataloader_calibrator.py @@ -1,13 +1,13 @@ +import os import unittest -import torch_tensorrt as torchtrt -from torch_tensorrt.logging import * + import torch import torch.nn as nn -from torch.nn import functional as F +import torch_tensorrt as torchtrt import torchvision import torchvision.transforms as transforms - -import os +from torch.nn import functional as F +from torch_tensorrt.ts.logging import * def find_repo_root(max_depth=10): @@ -49,6 +49,10 @@ def compute_accuracy(testing_dataloader, model): return correct / total +@unittest.skipIf( + not torchtrt.ENABLED_FEATURES.torchscript_frontend, + "TorchScript Frontend is not available", +) class TestAccuracy(unittest.TestCase): def test_compile_script(self): self.model = ( diff --git a/tests/py/ts/ptq/test_ptq_to_backend.py b/tests/py/ts/ptq/test_ptq_to_backend.py index 3a0a5bf336..d016dedb15 100644 --- a/tests/py/ts/ptq/test_ptq_to_backend.py +++ b/tests/py/ts/ptq/test_ptq_to_backend.py @@ -1,12 +1,13 @@ +import os import unittest -import torch_tensorrt as torchtrt -from torch_tensorrt.logging import * + import torch import torch.nn as nn -from torch.nn import functional as F +import torch_tensorrt as torchtrt import torchvision import torchvision.transforms as transforms -import os +from torch.nn import functional as F +from torch_tensorrt.ts.logging import * def find_repo_root(max_depth=10): @@ -48,6 +49,10 @@ def compute_accuracy(testing_dataloader, model): return correct / total +@unittest.skipIf( + not torchtrt.ENABLED_FEATURES.torchscript_frontend, + "TorchScript Frontend is not available", +) class TestAccuracy(unittest.TestCase): def test_compile_script(self): self.model = ( diff --git a/tests/py/ts/ptq/test_ptq_trt_calibrator.py b/tests/py/ts/ptq/test_ptq_trt_calibrator.py index 93596c895d..5ebd47e807 100644 --- a/tests/py/ts/ptq/test_ptq_trt_calibrator.py +++ b/tests/py/ts/ptq/test_ptq_trt_calibrator.py @@ -1,13 +1,15 @@ -import unittest import os -import torch_tensorrt as torchtrt -from torch_tensorrt.logging import * +import unittest + import torch -import tensorrt as trt import torch.nn as nn -from torch.nn import functional as F +import torch_tensorrt as torchtrt import torchvision import torchvision.transforms as transforms +from torch.nn import functional as F +from 
torch_tensorrt.ts.logging import * + +import tensorrt as trt def find_repo_root(max_depth=10): @@ -49,6 +51,10 @@ def compute_accuracy(testing_dataloader, model): return correct / total +@unittest.skipIf( + not torchtrt.ENABLED_FEATURES.torchscript_frontend, + "TorchScript Frontend is not available", +) class TRTEntropyCalibrator(trt.IInt8EntropyCalibrator2): def __init__(self, dataloader, **kwargs): trt.IInt8EntropyCalibrator2.__init__(self) @@ -94,6 +100,10 @@ def write_calibration_cache(self, cache): f.write(cache) +@unittest.skipIf( + not torchtrt.ENABLED_FEATURES.torchscript_frontend, + "TorchScript Frontend is not available", +) class TestAccuracy(unittest.TestCase): def test_compile_script(self): self.model = ( diff --git a/tests/py/ts/qat/test_qat_trt_accuracy.py b/tests/py/ts/qat/test_qat_trt_accuracy.py index ce574c57fe..ade2cfc865 100644 --- a/tests/py/ts/qat/test_qat_trt_accuracy.py +++ b/tests/py/ts/qat/test_qat_trt_accuracy.py @@ -1,13 +1,14 @@ +import os +import sys import unittest -import torch_tensorrt as torchtrt -from torch_tensorrt.logging import * + import torch import torch.nn as nn -from torch.nn import functional as F +import torch_tensorrt as torchtrt import torchvision import torchvision.transforms as transforms -import os -import sys +from torch.nn import functional as F +from torch_tensorrt.ts.logging import * def find_repo_root(max_depth=10): @@ -51,6 +52,10 @@ def compute_accuracy(testing_dataloader, model): return correct / total +@unittest.skipIf( + not torchtrt.ENABLED_FEATURES.torchscript_frontend, + "TorchScript Frontend is not available", +) class TestAccuracy(unittest.TestCase): def test_compile_script(self): self.model = ( diff --git a/toolchains/legacy/pyproject.toml b/toolchains/legacy/pyproject.toml index ce9e6423cb..90b2d4f2ec 100644 --- a/toolchains/legacy/pyproject.toml +++ b/toolchains/legacy/pyproject.toml @@ -64,7 +64,7 @@ include-package-data = false [tool.ruff] # NOTE: Synchoronize the ignores with .flake8 -ignore = [ +lint.ignore = [ # these ignores are from flake8-bugbear; please fix! "B007", "B008", "B017", "B018", # Useless expression @@ -97,7 +97,7 @@ ignore = [ "SIM118", ] #line-length = 120 -select = [ +lint.select = [ "B", "C4", "G", @@ -112,11 +112,11 @@ select = [ ] # Allow unused variables when underscore-prefixed. -dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" +lint.dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" target-version = "py311" # Allow autofix for all enabled rules (when `--fix`) is provided. -fixable = [ +lint.fixable = [ "A","B","C","D","E","F","G", "I","N","Q","S","T","W", "ANN", "ARG", "BLE", "COM", "DJ", @@ -125,10 +125,10 @@ fixable = [ "PGH", "PIE", "PL", "PT", "PTH", "PYI", "RET", "RSE", "RUF", "SIM", "SLF", "TCH", "TID", "TRY", "UP", "YTT"] -unfixable = [] +lint.unfixable = [] # Exclude a variety of commonly ignored directories. -exclude = [ +lint.exclude = [ ".bzr", ".direnv", ".eggs", @@ -164,7 +164,7 @@ exclude = [ "__init__.py" ] -[tool.ruff.mccabe] +[tool.ruff.lint.mccabe] # Unlike Flake8, default to a complexity level of 10. max-complexity = 10 diff --git a/versions.py b/versions.py index 772737aab7..db418a06d2 100644 --- a/versions.py +++ b/versions.py @@ -1,11 +1,10 @@ -import yaml -import re import os +import re import subprocess - from datetime import datetime from pathlib import Path -from typing import List + +import yaml __version__ = "0.0.0" __cuda_version__ = "0.0"
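
(Illustrative aside, not part of the patch.) A recurring pattern in the test updates above is gating whole TorchScript-frontend test classes on feature availability via `torch_tensorrt.ENABLED_FEATURES`. A minimal sketch of that gate; the class and method names here are hypothetical:

```python
import unittest

import torch_tensorrt as torchtrt


@unittest.skipIf(
    not torchtrt.ENABLED_FEATURES.torchscript_frontend,
    "TorchScript Frontend is not available",
)
class ExampleTSGatedTest(unittest.TestCase):
    def test_placeholder(self):
        # Runs only when the TorchScript frontend was compiled into the install.
        self.assertTrue(torchtrt.ENABLED_FEATURES.torchscript_frontend)


if __name__ == "__main__":
    unittest.main()
```
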