From d18c013f03b9f774604530f29c08608780c82470 Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Mon, 17 Mar 2025 18:33:18 -0700 Subject: [PATCH 01/11] fix: Fix BF16 compilation issues --- py/torch_tensorrt/_enums.py | 7 +- .../dynamo/conversion/_TRTInterpreter.py | 7 +- .../dynamo/conversion/converter_utils.py | 57 ++++++- .../dynamo/conversion/impl/conv.py | 14 +- .../dynamo/conversion/impl/deconv.py | 14 +- .../conversion/test_convolution_aten.py | 160 +++++++++--------- tests/py/dynamo/models/test_models.py | 44 +++++ 7 files changed, 196 insertions(+), 107 deletions(-) diff --git a/py/torch_tensorrt/_enums.py b/py/torch_tensorrt/_enums.py index c706c345d6..9b73e6a67e 100644 --- a/py/torch_tensorrt/_enums.py +++ b/py/torch_tensorrt/_enums.py @@ -4,6 +4,7 @@ from enum import Enum, auto from typing import Any, Optional, Type, Union +import ml_dtypes import numpy as np import tensorrt as trt import torch @@ -416,10 +417,8 @@ def to( return np.float64 elif self == dtype.b: return np.bool_ - # TODO: Consider using ml_dtypes when issues like this are resolved: - # https://github.com/pytorch/pytorch/issues/109873 - # elif self == dtype.bf16: - # return ml_dtypes.bfloat16 + elif self == dtype.bf16: + return ml_dtypes.bfloat16 elif use_default: return np.float32 else: diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index 7f26a7c3e6..ddc1d828f8 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -41,6 +41,7 @@ get_node_io, get_node_name, get_trt_tensor, + to_torch, ) from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, get_model_device, to_torch_device from torch_tensorrt.fx.observer import Observer @@ -869,8 +870,6 @@ def call_function(self, target: str, args: Any, kwargs: Any) -> Any: def get_attr(self, target: str, args: Any, kwargs: Any) -> np.ndarray: with _disable_current_modes(): - from torch_tensorrt.dynamo.conversion.converter_utils import to_numpy - frozen_attr = self.fetch_attr(target) if isinstance(frozen_attr, torch.nn.Parameter): @@ -878,9 +877,7 @@ def get_attr(self, target: str, args: Any, kwargs: Any) -> np.ndarray: else: constant_tensor = frozen_attr - network_constant = to_numpy(constant_tensor) - - return network_constant + return to_torch(constant_tensor) def call_method(self, target: str, args: Any, kwargs: Any) -> Any: assert isinstance(target, str) diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py index 62526080c4..cb1ca6550e 100644 --- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py +++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py @@ -344,10 +344,13 @@ def create_constant( # Rank 0 constant is required in IFillLayer inputs. 
if min_rank == 0: shape = trt.Dims() - numpy_value = to_numpy(value, dtype) + + torch_value = to_torch(value, dtype) + trt_dtype = _enums.dtype._from(torch_value.dtype).to(trt.DataType, use_default=True) + weights = trt.Weights(trt_dtype, torch_value.data_ptr(), torch_value.numel()) constant = ctx.net.add_constant( - shape if isinstance(value, (int, float, bool)) else value.shape, - numpy_value.copy() if isinstance(numpy_value, np.ndarray) else numpy_value, + shape if isinstance(value, (int, float, bool)) else list(torch_value.shape), + weights, ) constant.name = name return constant.get_output(0) @@ -564,6 +567,9 @@ def to_numpy( value = value.dequantize() elif value.dtype == torch.bfloat16: # TODO: Remove when numpy has a BF16 type + _LOGGER.warning( + "Requested a conversion of bfloat16 tensor from torch to numpy which isn't supported. Casting this tensor to FP32 precision currently. Please use to_torch() API for better data representation", + ) value = value.to(torch.float) output = value.cpu().detach().contiguous().numpy() @@ -589,6 +595,51 @@ def to_numpy( ) +def to_torch( + value: Optional[Union[torch.Tensor, np.ndarray, int, float, bool]], + dtype: Optional[Union[torch.dtype, np.dtype, TRTDataType, _enums.dtype]] = None, +) -> Optional[np.ndarray]: + """ + Convert a Numpy array, or scalar to a PyTorch tensor and move it to CPU + Args: + value (Optional[Union[torch.Tensor, np.ndarray, int, float, bool]]): + A PyTorch tensor, Numpy array, int, float, or bool + dtype (Optional[Union[torch.dtype, np.dtype, TRTDataType]]): + If a dtype is given, we will convert the type of the given `value` to this dtype. + Returns: + A Numpy array or None, if the input was None. + """ + + cpu_device = torch.device("cpu") + if value is None: + return None + + elif isinstance(value, torch.Tensor): + return value.to(cpu_device) + + elif isinstance(value, np.ndarray): + output = torch.from_numpy(value).to(cpu_device) + return ( + output.to(_enums.dtype._from(dtype).to(torch.dtype, use_default=True)) + if dtype + else output + ) + + elif isinstance(value, int): + return torch.tensor([value], device=cpu_device, dtype=torch.int32) + + elif isinstance(value, float): + return torch.tensor([value], device=cpu_device, dtype=torch.float32) + + elif isinstance(value, bool): + return torch.tensor([value], device=cpu_device, dtype=torch.bool) + + else: + raise AssertionError( + f"to_torch can only be called on None, bool, int, float, np.ndarray, or torch.Tensor, got an object of type: {type(value)}" + ) + + def flatten_dims( input: Sequence[Union[TRTTensor, torch.Tensor, np.ndarray]], start_dim: int, diff --git a/py/torch_tensorrt/dynamo/conversion/impl/conv.py b/py/torch_tensorrt/dynamo/conversion/impl/conv.py index 25419d7f60..f27fb13e97 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/conv.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/conv.py @@ -13,7 +13,7 @@ cast_trt_tensor, extend_attr_to_tuple, get_trt_tensor, - to_numpy, + to_torch, ) from torch_tensorrt.fx.converters.converter_utils import ( get_dyn_range, @@ -45,7 +45,6 @@ def convNd( assert input.shape[1] != -1, "Channel dim can't be dynamic for convolution." 
num_dims = len(input.shape) - 2 - if is_conv1d: # Apply an unsqueeze operation to transform the conv1d problem into conv2d input = impl.unsqueeze.unsqueeze( @@ -54,8 +53,8 @@ def convNd( # Process bias terms if isinstance(bias, (torch.Tensor, np.ndarray)): - # Transform the bias constant into a Numpy array - bias = to_numpy(bias, dtype=input.dtype) + bias = to_torch(bias, dtype=input.dtype) + bias = get_trt_tensor(ctx, bias, f"{name}_bias") elif isinstance(bias, TRTTensor): bias = get_trt_tensor(ctx, bias, f"{name}_bias") @@ -74,12 +73,11 @@ def convNd( ctx, target, source_ir, weight.name + "_unsqueeze_conv1d", weight, -1 ) elif isinstance(weight, (torch.Tensor, np.ndarray)): - # Transform the weight constant into a Numpy array - weight = to_numpy(weight, dtype=input.dtype) - + weight = to_torch(weight, dtype=input.dtype) # Append new dimension (unsqueeze) if the convolution is 1d if is_conv1d: - weight = np.expand_dims(weight, -1) + weight = torch.unsqueeze(weight, -1) + weight = get_trt_tensor(ctx, weight, f"{name}_weight") else: raise RuntimeError( diff --git a/py/torch_tensorrt/dynamo/conversion/impl/deconv.py b/py/torch_tensorrt/dynamo/conversion/impl/deconv.py index d19a92e646..629cecf5db 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/deconv.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/deconv.py @@ -6,13 +6,12 @@ import tensorrt as trt import torch from torch.fx.node import Target - from torch_tensorrt.dynamo.conversion import impl from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext from torch_tensorrt.dynamo.conversion.converter_utils import ( extend_attr_to_tuple, get_trt_tensor, - to_numpy, + to_torch, ) from torch_tensorrt.fx.converters.converter_utils import ( SourceIR, @@ -53,7 +52,8 @@ def deconvNd( # Process bias terms if isinstance(bias, (torch.Tensor, np.ndarray)): # Transform the bias constant into a Numpy array - bias = to_numpy(bias) + bias = to_torch(bias, dtype=input.dtype) + bias = get_trt_tensor(ctx, bias, f"{name}_bias") elif isinstance(bias, TRTTensor): bias = get_trt_tensor(ctx, bias, f"{name}_bias") @@ -73,12 +73,12 @@ def deconvNd( ) elif isinstance(weight, (torch.Tensor, np.ndarray)): - # Transform the weight constant into a Numpy array - weight = to_numpy(weight) - + weight = to_torch(weight, dtype=input.dtype) # Append new dimension (unsqueeze) if the deconvolution is 1d if is_deconv1d: - weight = np.expand_dims(weight, axis=-1) + weight = torch.unsqueeze(weight, -1) + + weight = get_trt_tensor(ctx, weight, f"{name}_weight") else: raise RuntimeError( diff --git a/tests/py/dynamo/conversion/test_convolution_aten.py b/tests/py/dynamo/conversion/test_convolution_aten.py index 78d7fc4cca..a81310ee99 100644 --- a/tests/py/dynamo/conversion/test_convolution_aten.py +++ b/tests/py/dynamo/conversion/test_convolution_aten.py @@ -10,11 +10,11 @@ class TestConvolutionConverter(DispatchTestCase): @parameterized.expand( [ ("default", 1), - param("no_bias", 1, bias=False), - ("tuple_parameters", 1, (1), (1)), - param("non_zero_padding", 1, padding=1), - param("dilation", 1, dilation=2), - param("groups", 1, groups=3), + # param("no_bias", 1, bias=False), + # ("tuple_parameters", 1, (1), (1)), + # param("non_zero_padding", 1, padding=1), + # param("dilation", 1, dilation=2), + # param("groups", 1, groups=3), ] ) def test_conv1d( @@ -45,87 +45,87 @@ def forward(self, x): enable_passes=True, ) - @parameterized.expand( - [ - ("default", 1), - param("no_bias", 1, bias=False), - ("tuple_parameters", 1, (1), (1)), - 
param("non_zero_padding", 1, padding=1), - param("dilation", 1, dilation=2), - ] - ) - def test_conv1d_TRTTensor_weight( - self, - _, - kernel_size, - stride=1, - padding=0, - dilation=1, - groups=1, - bias=True, - ): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() + # @parameterized.expand( + # [ + # ("default", 1), + # param("no_bias", 1, bias=False), + # ("tuple_parameters", 1, (1), (1)), + # param("non_zero_padding", 1, padding=1), + # param("dilation", 1, dilation=2), + # ] + # ) + # def test_conv1d_TRTTensor_weight( + # self, + # _, + # kernel_size, + # stride=1, + # padding=0, + # dilation=1, + # groups=1, + # bias=True, + # ): + # class TestModule(torch.nn.Module): + # def __init__(self): + # super().__init__() - def forward(self, x, w): - return torch.ops.aten.convolution.default( - x, - w, - None, - (stride,) if isinstance(stride, int) else stride, - (padding,) if isinstance(padding, int) else padding, - (dilation,) if isinstance(dilation, int) else dilation, - False, - (0,), - groups, - ) + # def forward(self, x, w): + # return torch.ops.aten.convolution.default( + # x, + # w, + # None, + # (stride,) if isinstance(stride, int) else stride, + # (padding,) if isinstance(padding, int) else padding, + # (dilation,) if isinstance(dilation, int) else dilation, + # False, + # (0,), + # groups, + # ) - inputs = [ - torch.randn(1, 3, 32), - torch.randn( - 6, 3, 1 - ), # Conv1d weight shape: (out_channels, in_channels, kernel_size) - ] - self.run_test( - TestModule(), - inputs, - use_dynamo_tracer=True, - ) + # inputs = [ + # torch.randn(1, 3, 32), + # torch.randn( + # 6, 3, 1 + # ), # Conv1d weight shape: (out_channels, in_channels, kernel_size) + # ] + # self.run_test( + # TestModule(), + # inputs, + # use_dynamo_tracer=True, + # ) - def test_conv1d_with_dynamic_shape( - self, - kernel_size=1, - stride=1, - padding=0, - dilation=1, - groups=1, - bias=True, - ): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.conv = torch.nn.Conv1d( - 3, 6, kernel_size, stride, padding, dilation, groups, bias - ) + # def test_conv1d_with_dynamic_shape( + # self, + # kernel_size=1, + # stride=1, + # padding=0, + # dilation=1, + # groups=1, + # bias=True, + # ): + # class TestModule(torch.nn.Module): + # def __init__(self): + # super().__init__() + # self.conv = torch.nn.Conv1d( + # 3, 6, kernel_size, stride, padding, dilation, groups, bias + # ) - def forward(self, x): - return self.conv(x) + # def forward(self, x): + # return self.conv(x) - input_specs = [ - Input( - shape=(-1, 3, 3), - dtype=torch.float32, - shape_ranges=[((1, 3, 3), (3, 3, 3), (5, 3, 3))], - ), - ] + # input_specs = [ + # Input( + # shape=(-1, 3, 3), + # dtype=torch.float32, + # shape_ranges=[((1, 3, 3), (3, 3, 3), (5, 3, 3))], + # ), + # ] - self.run_test_with_dynamic_shape( - TestModule(), - input_specs, - use_dynamo_tracer=True, - enable_passes=True, - ) + # self.run_test_with_dynamic_shape( + # TestModule(), + # input_specs, + # use_dynamo_tracer=True, + # enable_passes=True, + # ) @parameterized.expand( [ diff --git a/tests/py/dynamo/models/test_models.py b/tests/py/dynamo/models/test_models.py index b6f986711a..acf59a0d5a 100644 --- a/tests/py/dynamo/models/test_models.py +++ b/tests/py/dynamo/models/test_models.py @@ -182,3 +182,47 @@ def test_resnet18_half(ir): # Clean up model env torch._dynamo.reset() + + +@pytest.mark.unit +def test_bf16_model(ir): + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = 
torch.nn.Conv2d(3, 16, 3, stride=1, bias=True) + self.relu = torch.nn.ReLU() + + def forward(self, x): + out = self.conv(x) + out = self.relu(out) + return out + + model = MyModule().eval().cuda().to(torch.bfloat16) + input = torch.randn((1, 3, 224, 224)).to("cuda").to(torch.bfloat16) + + compile_spec = { + "inputs": [ + torchtrt.Input( + input.shape, dtype=torch.bfloat16, format=torch.contiguous_format + ) + ], + "device": torchtrt.Device("cuda:0"), + "enabled_precisions": {torch.float32}, + "ir": ir, + "pass_through_build_failures": True, + "min_block_size": 1, + "cache_built_engines": False, + "reuse_cached_engines": False, + "use_explicit_typing": True, + } + + trt_mod = torchtrt.compile(model, **compile_spec) + cos_sim = cosine_similarity(model(input), trt_mod(input)) + breakpoint() + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"BF16 model TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + # Clean up model env + torch._dynamo.reset() From daa97a8a7df688cc46f5c0df7b26e80a02d45083 Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Tue, 18 Mar 2025 10:02:04 -0700 Subject: [PATCH 02/11] chore: minor fixes --- .../dynamo/conversion/converter_utils.py | 38 +++-- .../conversion/test_convolution_aten.py | 160 +++++++++--------- 2 files changed, 108 insertions(+), 90 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py index cb1ca6550e..e53718cb06 100644 --- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py +++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py @@ -346,14 +346,32 @@ def create_constant( shape = trt.Dims() torch_value = to_torch(value, dtype) - trt_dtype = _enums.dtype._from(torch_value.dtype).to(trt.DataType, use_default=True) - weights = trt.Weights(trt_dtype, torch_value.data_ptr(), torch_value.numel()) - constant = ctx.net.add_constant( - shape if isinstance(value, (int, float, bool)) else list(torch_value.shape), - weights, - ) - constant.name = name - return constant.get_output(0) + if torch_value: + if torch_value.dtype == torch.bfloat16: + torch_value_fp32 = torch_value.to(torch.float32) + numpy_value = torch_value_fp32.numpy() + else: + numpy_value = torch_value.numpy() + + constant = ctx.net.add_constant( + shape if isinstance(value, (int, float, bool)) else list(torch_value.shape), + numpy_value, + ) + constant.name = name + + if torch_value.dtype == torch.bfloat16: + return cast_trt_tensor( + ctx, + constant.get_output(0), + trt.DataType.BF16, + name + "_bf16_cast", + ) + + return constant.get_output(0) + else: + raise ValueError( + f"Cannot convert tensor '{name}' to a TensorRT constant because its value is None." 
+ ) def get_trt_tensor( @@ -615,10 +633,10 @@ def to_torch( return None elif isinstance(value, torch.Tensor): - return value.to(cpu_device) + return value.to(cpu_device).contiguous() elif isinstance(value, np.ndarray): - output = torch.from_numpy(value).to(cpu_device) + output = torch.from_numpy(value).to(cpu_device).contiguous() return ( output.to(_enums.dtype._from(dtype).to(torch.dtype, use_default=True)) if dtype diff --git a/tests/py/dynamo/conversion/test_convolution_aten.py b/tests/py/dynamo/conversion/test_convolution_aten.py index a81310ee99..78d7fc4cca 100644 --- a/tests/py/dynamo/conversion/test_convolution_aten.py +++ b/tests/py/dynamo/conversion/test_convolution_aten.py @@ -10,11 +10,11 @@ class TestConvolutionConverter(DispatchTestCase): @parameterized.expand( [ ("default", 1), - # param("no_bias", 1, bias=False), - # ("tuple_parameters", 1, (1), (1)), - # param("non_zero_padding", 1, padding=1), - # param("dilation", 1, dilation=2), - # param("groups", 1, groups=3), + param("no_bias", 1, bias=False), + ("tuple_parameters", 1, (1), (1)), + param("non_zero_padding", 1, padding=1), + param("dilation", 1, dilation=2), + param("groups", 1, groups=3), ] ) def test_conv1d( @@ -45,87 +45,87 @@ def forward(self, x): enable_passes=True, ) - # @parameterized.expand( - # [ - # ("default", 1), - # param("no_bias", 1, bias=False), - # ("tuple_parameters", 1, (1), (1)), - # param("non_zero_padding", 1, padding=1), - # param("dilation", 1, dilation=2), - # ] - # ) - # def test_conv1d_TRTTensor_weight( - # self, - # _, - # kernel_size, - # stride=1, - # padding=0, - # dilation=1, - # groups=1, - # bias=True, - # ): - # class TestModule(torch.nn.Module): - # def __init__(self): - # super().__init__() + @parameterized.expand( + [ + ("default", 1), + param("no_bias", 1, bias=False), + ("tuple_parameters", 1, (1), (1)), + param("non_zero_padding", 1, padding=1), + param("dilation", 1, dilation=2), + ] + ) + def test_conv1d_TRTTensor_weight( + self, + _, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + ): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() - # def forward(self, x, w): - # return torch.ops.aten.convolution.default( - # x, - # w, - # None, - # (stride,) if isinstance(stride, int) else stride, - # (padding,) if isinstance(padding, int) else padding, - # (dilation,) if isinstance(dilation, int) else dilation, - # False, - # (0,), - # groups, - # ) + def forward(self, x, w): + return torch.ops.aten.convolution.default( + x, + w, + None, + (stride,) if isinstance(stride, int) else stride, + (padding,) if isinstance(padding, int) else padding, + (dilation,) if isinstance(dilation, int) else dilation, + False, + (0,), + groups, + ) - # inputs = [ - # torch.randn(1, 3, 32), - # torch.randn( - # 6, 3, 1 - # ), # Conv1d weight shape: (out_channels, in_channels, kernel_size) - # ] - # self.run_test( - # TestModule(), - # inputs, - # use_dynamo_tracer=True, - # ) + inputs = [ + torch.randn(1, 3, 32), + torch.randn( + 6, 3, 1 + ), # Conv1d weight shape: (out_channels, in_channels, kernel_size) + ] + self.run_test( + TestModule(), + inputs, + use_dynamo_tracer=True, + ) - # def test_conv1d_with_dynamic_shape( - # self, - # kernel_size=1, - # stride=1, - # padding=0, - # dilation=1, - # groups=1, - # bias=True, - # ): - # class TestModule(torch.nn.Module): - # def __init__(self): - # super().__init__() - # self.conv = torch.nn.Conv1d( - # 3, 6, kernel_size, stride, padding, dilation, groups, bias - # ) + def 
test_conv1d_with_dynamic_shape( + self, + kernel_size=1, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + ): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv1d( + 3, 6, kernel_size, stride, padding, dilation, groups, bias + ) - # def forward(self, x): - # return self.conv(x) + def forward(self, x): + return self.conv(x) - # input_specs = [ - # Input( - # shape=(-1, 3, 3), - # dtype=torch.float32, - # shape_ranges=[((1, 3, 3), (3, 3, 3), (5, 3, 3))], - # ), - # ] + input_specs = [ + Input( + shape=(-1, 3, 3), + dtype=torch.float32, + shape_ranges=[((1, 3, 3), (3, 3, 3), (5, 3, 3))], + ), + ] - # self.run_test_with_dynamic_shape( - # TestModule(), - # input_specs, - # use_dynamo_tracer=True, - # enable_passes=True, - # ) + self.run_test_with_dynamic_shape( + TestModule(), + input_specs, + use_dynamo_tracer=True, + enable_passes=True, + ) @parameterized.expand( [ From 008f3d4403a1a34e46191ef76ebf83471a75b178 Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Tue, 18 Mar 2025 10:22:48 -0700 Subject: [PATCH 03/11] chore: minor fix --- .../dynamo/conversion/converter_utils.py | 2 +- tests/py/dynamo/models/test_models.py | 49 ++++++++++++++++++- 2 files changed, 49 insertions(+), 2 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py index e53718cb06..e412bd626c 100644 --- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py +++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py @@ -346,7 +346,7 @@ def create_constant( shape = trt.Dims() torch_value = to_torch(value, dtype) - if torch_value: + if torch_value is not None: if torch_value.dtype == torch.bfloat16: torch_value_fp32 = torch_value.to(torch.float32) numpy_value = torch_value_fp32.numpy() diff --git a/tests/py/dynamo/models/test_models.py b/tests/py/dynamo/models/test_models.py index acf59a0d5a..6314baa5ec 100644 --- a/tests/py/dynamo/models/test_models.py +++ b/tests/py/dynamo/models/test_models.py @@ -218,7 +218,7 @@ def forward(self, x): trt_mod = torchtrt.compile(model, **compile_spec) cos_sim = cosine_similarity(model(input), trt_mod(input)) - breakpoint() + assertions.assertTrue( cos_sim > COSINE_THRESHOLD, msg=f"BF16 model TRT outputs don't match with the original model. 
Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", @@ -226,3 +226,50 @@ def forward(self, x): # Clean up model env torch._dynamo.reset() + + +@pytest.mark.unit +def test_bf16_fallback_model(ir): + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 16, 3, padding=1, stride=1, bias=True) + self.relu = torch.nn.ReLU() + self.conv2 = torch.nn.Conv2d(16, 16, 3, padding=1, stride=1, bias=True) + + def forward(self, x): + out = self.conv(x) + out = self.relu(out) + out = self.conv2(out) + return out + + model = MyModule().eval().cuda().to(torch.bfloat16) + input = torch.randn((1, 3, 224, 224)).to("cuda").to(torch.bfloat16) + + compile_spec = { + "inputs": [ + torchtrt.Input( + input.shape, dtype=torch.bfloat16, format=torch.contiguous_format + ) + ], + "device": torchtrt.Device("cuda:0"), + "enabled_precisions": {torch.float32}, + "ir": ir, + "pass_through_build_failures": True, + "min_block_size": 1, + "cache_built_engines": False, + "reuse_cached_engines": False, + "use_explicit_typing": True, + "torch_executed_ops": {"torch.ops.aten.relu.default"}, + } + + trt_mod = torchtrt.compile(model, **compile_spec) + cos_sim = cosine_similarity(model(input), trt_mod(input)) + + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"BF16 fallback model TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + # Clean up model env + torch._dynamo.reset() From a7b63042ac61a8b1861ffa6d15f367017f5e7e90 Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Tue, 18 Mar 2025 10:33:54 -0700 Subject: [PATCH 04/11] chore: revert bf16 enum fix --- py/torch_tensorrt/_enums.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/py/torch_tensorrt/_enums.py b/py/torch_tensorrt/_enums.py index 9b73e6a67e..c706c345d6 100644 --- a/py/torch_tensorrt/_enums.py +++ b/py/torch_tensorrt/_enums.py @@ -4,7 +4,6 @@ from enum import Enum, auto from typing import Any, Optional, Type, Union -import ml_dtypes import numpy as np import tensorrt as trt import torch @@ -417,8 +416,10 @@ def to( return np.float64 elif self == dtype.b: return np.bool_ - elif self == dtype.bf16: - return ml_dtypes.bfloat16 + # TODO: Consider using ml_dtypes when issues like this are resolved: + # https://github.com/pytorch/pytorch/issues/109873 + # elif self == dtype.bf16: + # return ml_dtypes.bfloat16 elif use_default: return np.float32 else: From 062a94be3cbda86df6da64a4b3036755620efd9b Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Wed, 19 Mar 2025 01:40:47 -0700 Subject: [PATCH 05/11] chore: fix CI failures --- .../dynamo/conversion/converter_utils.py | 47 ++++++++++--------- 1 file changed, 26 insertions(+), 21 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py index e412bd626c..ddfb7e3fd3 100644 --- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py +++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py @@ -9,6 +9,7 @@ import tensorrt as trt import torch import torch_tensorrt.dynamo.conversion.impl as impl +from torch.fx.experimental.proxy_tensor import unset_fake_temporarily from torch.fx.node import Argument, Target from torch.fx.passes.shape_prop import TensorMetadata from torch_tensorrt import _enums @@ -629,33 +630,37 @@ def to_torch( """ cpu_device = torch.device("cpu") - if value is None: - return None + torch_dtype = ( + _enums.dtype._from(dtype).to(torch.dtype, use_default=True) if dtype 
else None + ) - elif isinstance(value, torch.Tensor): - return value.to(cpu_device).contiguous() + with unset_fake_temporarily(): + if value is None: + return None - elif isinstance(value, np.ndarray): - output = torch.from_numpy(value).to(cpu_device).contiguous() - return ( - output.to(_enums.dtype._from(dtype).to(torch.dtype, use_default=True)) - if dtype - else output - ) + elif isinstance(value, torch.Tensor): + output = torch.atleast_1d(value).to(cpu_device).contiguous() - elif isinstance(value, int): - return torch.tensor([value], device=cpu_device, dtype=torch.int32) + elif isinstance(value, np.ndarray): + output = ( + torch.atleast_1d(torch.from_numpy(value)).to(cpu_device).contiguous() + ) - elif isinstance(value, float): - return torch.tensor([value], device=cpu_device, dtype=torch.float32) + elif isinstance(value, int): + output = torch.tensor([value], device=cpu_device, dtype=torch.int32) - elif isinstance(value, bool): - return torch.tensor([value], device=cpu_device, dtype=torch.bool) + elif isinstance(value, float): + output = torch.tensor([value], device=cpu_device, dtype=torch.float32) - else: - raise AssertionError( - f"to_torch can only be called on None, bool, int, float, np.ndarray, or torch.Tensor, got an object of type: {type(value)}" - ) + elif isinstance(value, bool): + output = torch.tensor([value], device=cpu_device, dtype=torch.bool) + + else: + raise AssertionError( + f"to_torch can only be called on None, bool, int, float, np.ndarray, or torch.Tensor, got an object of type: {type(value)}" + ) + + return output.to(torch_dtype) if torch_dtype else output def flatten_dims( From 7105474b08b4aa3fa1bf4bedd8bcdeae400c55f8 Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Wed, 19 Mar 2025 17:17:13 -0700 Subject: [PATCH 06/11] chore: bug fix --- py/torch_tensorrt/dynamo/conversion/converter_utils.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py index ddfb7e3fd3..6dc0309e13 100644 --- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py +++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py @@ -355,7 +355,11 @@ def create_constant( numpy_value = torch_value.numpy() constant = ctx.net.add_constant( - shape if isinstance(value, (int, float, bool)) else list(torch_value.shape), + ( + shape + if isinstance(value, (int, float, bool)) or min_rank == 0 + else list(torch_value.shape) + ), numpy_value, ) constant.name = name From 7e9d388a09cfd50ba789d6102f1dae0eae54b980 Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Fri, 21 Mar 2025 01:15:58 -0700 Subject: [PATCH 07/11] chore: fix CI test failures --- .../dynamo/conversion/converter_utils.py | 78 +++++++++---------- 1 file changed, 39 insertions(+), 39 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py index 6dc0309e13..cb6dc2e010 100644 --- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py +++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py @@ -341,42 +341,44 @@ def create_constant( Returns: A TensorRT ITensor that represents the given value. """ - shape = (1,) - # Rank 0 constant is required in IFillLayer inputs. 
- if min_rank == 0: - shape = trt.Dims() - - torch_value = to_torch(value, dtype) - if torch_value is not None: - if torch_value.dtype == torch.bfloat16: - torch_value_fp32 = torch_value.to(torch.float32) - numpy_value = torch_value_fp32.numpy() - else: - numpy_value = torch_value.numpy() - - constant = ctx.net.add_constant( - ( - shape - if isinstance(value, (int, float, bool)) or min_rank == 0 - else list(torch_value.shape) - ), - numpy_value, - ) - constant.name = name + with unset_fake_temporarily(): + shape = (1,) - if torch_value.dtype == torch.bfloat16: - return cast_trt_tensor( - ctx, - constant.get_output(0), - trt.DataType.BF16, - name + "_bf16_cast", + # Rank 0 constant is required in IFillLayer inputs. + if min_rank == 0: + shape = trt.Dims() + + torch_value = to_torch(value, dtype) + if torch_value is not None: + if torch_value.dtype == torch.bfloat16: + torch_value_fp32 = torch_value.to(torch.float32) + numpy_value = torch_value_fp32.numpy() + else: + numpy_value = torch_value.numpy() + + constant = ctx.net.add_constant( + ( + shape + if isinstance(value, (int, float, bool)) or min_rank == 0 + else list(torch_value.shape) + ), + numpy_value, ) + constant.name = name - return constant.get_output(0) - else: - raise ValueError( - f"Cannot convert tensor '{name}' to a TensorRT constant because its value is None." - ) + if torch_value.dtype == torch.bfloat16: + return cast_trt_tensor( + ctx, + constant.get_output(0), + trt.DataType.BF16, + name + "_bf16_cast", + ) + + return constant.get_output(0) + else: + raise ValueError( + f"Cannot convert tensor '{name}' to a TensorRT constant because its value is None." + ) def get_trt_tensor( @@ -621,7 +623,7 @@ def to_numpy( def to_torch( value: Optional[Union[torch.Tensor, np.ndarray, int, float, bool]], dtype: Optional[Union[torch.dtype, np.dtype, TRTDataType, _enums.dtype]] = None, -) -> Optional[np.ndarray]: +) -> Optional[torch.Tensor]: """ Convert a Numpy array, or scalar to a PyTorch tensor and move it to CPU Args: @@ -630,7 +632,7 @@ def to_torch( dtype (Optional[Union[torch.dtype, np.dtype, TRTDataType]]): If a dtype is given, we will convert the type of the given `value` to this dtype. Returns: - A Numpy array or None, if the input was None. + A PyTorch tensor or None, if the input was None. 
""" cpu_device = torch.device("cpu") @@ -643,12 +645,10 @@ def to_torch( return None elif isinstance(value, torch.Tensor): - output = torch.atleast_1d(value).to(cpu_device).contiguous() + output = value.to(cpu_device).contiguous() elif isinstance(value, np.ndarray): - output = ( - torch.atleast_1d(torch.from_numpy(value)).to(cpu_device).contiguous() - ) + output = torch.from_numpy(value).to(cpu_device).contiguous() elif isinstance(value, int): output = torch.tensor([value], device=cpu_device, dtype=torch.int32) From 2c61ccbc9a7750db0e6b2a2bf601c65be5090312 Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Tue, 25 Mar 2025 17:15:47 -0700 Subject: [PATCH 08/11] chore: additional CI test failure fixes --- .../dynamo/conversion/_TRTInterpreter.py | 31 ++++++++++--------- .../dynamo/conversion/converter_utils.py | 21 +++++++------ tests/py/dynamo/models/test_models_export.py | 4 +-- 3 files changed, 31 insertions(+), 25 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index 3fc6c518e6..2a31924df5 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -21,6 +21,7 @@ import tensorrt as trt import torch import torch.fx +from torch.fx.experimental.proxy_tensor import unset_fake_temporarily from torch.fx.node import _get_qualified_name from torch.fx.passes.shape_prop import TensorMetadata from torch.utils._python_dispatch import _disable_current_modes @@ -409,12 +410,13 @@ def find_weight( np_map: the map from weight name to np values in INetworkDefinition state_dict: state of the graph module """ - network_weight = torch.from_numpy(np_map[weight_name]).to(device) - for sd_w_name, sd_weight in state_dict.items(): - if TRTInterpreter.check_weight_equal(sd_weight, network_weight, device): - del state_dict[sd_w_name] - return sd_w_name - return "" + with unset_fake_temporarily(): + network_weight = torch.from_numpy(np_map[weight_name]).to(device) + for sd_w_name, sd_weight in state_dict.items(): + if TRTInterpreter.check_weight_equal(sd_weight, network_weight, device): + del state_dict[sd_w_name] + return sd_w_name + return "" @staticmethod def check_weight_equal( @@ -422,14 +424,15 @@ def check_weight_equal( network_weight: Union[torch.Tensor, np.ndarray], device: torch.device, ) -> Any: - if not isinstance(network_weight, torch.Tensor): - network_weight = torch.from_numpy(network_weight).to(device) - try: - return sd_weight.shape == network_weight.shape and torch.all( - torch.abs(sd_weight - network_weight) < 0.01 - ) - except Exception: - return torch.all(sd_weight == network_weight) + with unset_fake_temporarily(): + if not isinstance(network_weight, torch.Tensor): + network_weight = torch.from_numpy(network_weight).to(device) + try: + return sd_weight.shape == network_weight.shape and torch.all( + torch.abs(sd_weight - network_weight) < 0.01 + ) + except Exception: + return torch.all(sd_weight == network_weight) def _save_weight_mapping(self) -> None: """ diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py index cb6dc2e010..4813058e5b 100644 --- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py +++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py @@ -342,13 +342,20 @@ def create_constant( A TensorRT ITensor that represents the given value. 
""" with unset_fake_temporarily(): - shape = (1,) + torch_value = to_torch(value, dtype) + if torch_value.dtype == torch.float64: + raise ValueError( + "TensorRT does not support float64 (double) precision. To resolve this, please set truncate_double=True in your compilation settings and re-run the model." + ) # Rank 0 constant is required in IFillLayer inputs. - if min_rank == 0: + if min_rank == 0 and isinstance(value, (int, float, bool)): shape = trt.Dims() - - torch_value = to_torch(value, dtype) + elif list(torch_value.shape) == []: + shape = (1,) + else: + shape = list(torch_value.shape) + # breakpoint() if torch_value is not None: if torch_value.dtype == torch.bfloat16: torch_value_fp32 = torch_value.to(torch.float32) @@ -357,11 +364,7 @@ def create_constant( numpy_value = torch_value.numpy() constant = ctx.net.add_constant( - ( - shape - if isinstance(value, (int, float, bool)) or min_rank == 0 - else list(torch_value.shape) - ), + shape, numpy_value, ) constant.name = name diff --git a/tests/py/dynamo/models/test_models_export.py b/tests/py/dynamo/models/test_models_export.py index 469ed569d1..f5230f3ace 100644 --- a/tests/py/dynamo/models/test_models_export.py +++ b/tests/py/dynamo/models/test_models_export.py @@ -257,7 +257,6 @@ def calibrate_loop(model): def test_base_int8(ir): import modelopt.torch.quantization as mtq from modelopt.torch.quantization.utils import export_torch_mode - from torch.export._trace import _export class SimpleNetwork(torch.nn.Module): def __init__(self): @@ -285,7 +284,7 @@ def calibrate_loop(model): with torch.no_grad(): with export_torch_mode(): - exp_program = _export(model, (input_tensor,)) + exp_program = torch.export.export(model, (input_tensor,)) trt_model = torchtrt.dynamo.compile( exp_program, inputs=[input_tensor], @@ -294,6 +293,7 @@ def calibrate_loop(model): debug=True, cache_built_engines=False, reuse_cached_engines=False, + truncate_double=True, ) outputs_trt = trt_model(input_tensor) assert torch.allclose(output_pyt, outputs_trt, rtol=5e-3, atol=1e-2) From 7bc6eaf49814fa8294ab061a677fb7448e59154e Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Thu, 27 Mar 2025 13:15:45 -0700 Subject: [PATCH 09/11] chore: updates --- py/torch_tensorrt/dynamo/_refit.py | 122 +++++++++--------- .../dynamo/conversion/_TRTInterpreter.py | 2 +- .../dynamo/conversion/converter_utils.py | 2 +- .../dynamo/conversion/impl/quantize.py | 82 ++++++------ 4 files changed, 108 insertions(+), 100 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py index 96fc6daad2..c128e9cc82 100644 --- a/py/torch_tensorrt/dynamo/_refit.py +++ b/py/torch_tensorrt/dynamo/_refit.py @@ -9,6 +9,7 @@ import tensorrt as trt import torch from torch.export import ExportedProgram +from torch.fx.experimental.proxy_tensor import unset_fake_temporarily from torch_tensorrt._enums import dtype from torch_tensorrt._Input import Input from torch_tensorrt.dynamo import partitioning @@ -144,71 +145,72 @@ def _refit_single_trt_engine_with_gm( Refit a TensorRT Engine in place """ - refitted = set() - torch_device = get_model_device(new_gm) - refitter = trt.Refitter(old_engine, TRT_LOGGER) - weight_list = refitter.get_all_weights() - - if weight_name_map: - # Get the refitting mapping - trt_wt_location = ( - trt.TensorLocation.DEVICE - if torch_device.type == "cuda" - else trt.TensorLocation.HOST - ) + with unset_fake_temporarily(): + refitted = set() + torch_device = get_model_device(new_gm) + refitter = trt.Refitter(old_engine, TRT_LOGGER) + weight_list = 
refitter.get_all_weights() + + if weight_name_map: + # Get the refitting mapping + trt_wt_location = ( + trt.TensorLocation.DEVICE + if torch_device.type == "cuda" + else trt.TensorLocation.HOST + ) - constant_mapping: dict[str, Any] = weight_name_map.pop( - "constant_mapping", {} - ) # type: ignore - mapping = construct_refit_mapping_from_weight_name_map( - weight_name_map, new_gm.state_dict() - ) - constant_mapping_with_type = {} - - for constant_name, val in constant_mapping.items(): - np_weight_type = val.dtype - val_tensor = torch.from_numpy(val).cuda() - trt_dtype = dtype.try_from(np_weight_type).to(trt.DataType) - torch_dtype = dtype.try_from(np_weight_type).to(torch.dtype) - constant_mapping_with_type[constant_name] = ( - val_tensor.clone().reshape(-1).contiguous().to(torch_dtype), - trt_dtype, + constant_mapping: dict[str, Any] = weight_name_map.pop( + "constant_mapping", {} + ) # type: ignore + mapping = construct_refit_mapping_from_weight_name_map( + weight_name_map, new_gm.state_dict() ) + constant_mapping_with_type = {} + + for constant_name, val in constant_mapping.items(): + np_weight_type = val.dtype + val_tensor = torch.from_numpy(val).cuda() + trt_dtype = dtype.try_from(np_weight_type).to(trt.DataType) + torch_dtype = dtype.try_from(np_weight_type).to(torch.dtype) + constant_mapping_with_type[constant_name] = ( + val_tensor.clone().reshape(-1).contiguous().to(torch_dtype), + trt_dtype, + ) - mapping.update(constant_mapping_with_type) + mapping.update(constant_mapping_with_type) - for layer_name in weight_list: - if layer_name not in mapping: - logger.warning(f"{layer_name} is not found in weight mapping.") - continue - # Use Numpy to create weights - weight, weight_dtype = mapping[layer_name] - trt_wt_tensor = trt.Weights( - weight_dtype, weight.data_ptr(), torch.numel(weight) - ) - refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location) - assert ( - len(refitter.get_missing_weights()) == 0 - ), "Fast refitting failed due to incomplete mapping" + for layer_name in weight_list: + if layer_name not in mapping: + logger.warning(f"{layer_name} is not found in weight mapping.") + continue + # Use Numpy to create weights + weight, weight_dtype = mapping[layer_name] + trt_wt_tensor = trt.Weights( + weight_dtype, weight.data_ptr(), torch.numel(weight) + ) + refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location) + assert ( + len(refitter.get_missing_weights()) == 0 + ), "Fast refitting failed due to incomplete mapping" - else: - mapping = construct_refit_mapping(new_gm, input_list, settings) - trt_wt_location = trt.TensorLocation.HOST - for layer_name in weight_list: - if layer_name not in mapping: - raise AssertionError(f"{layer_name} is not found in weight mapping") - # Use Numpy to create weights - weight, datatype = mapping[layer_name] - trt_wt_tensor = trt.Weights(datatype, weight.ctypes.data, weight.size) - refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location) - refitted.add(layer_name) - - if len(refitted) != len(weight_list): - logger.warning("Not all weights have been refitted!!!") - - if not refitter.refit_cuda_engine(): - logger.error("Error: failed to refit new weights.") - raise AssertionError("Refitting failed.") + else: + mapping = construct_refit_mapping(new_gm, input_list, settings) + trt_wt_location = trt.TensorLocation.HOST + for layer_name in weight_list: + if layer_name not in mapping: + raise AssertionError(f"{layer_name} is not found in weight mapping") + # Use Numpy to create weights + weight, datatype = 
mapping[layer_name] + trt_wt_tensor = trt.Weights(datatype, weight.ctypes.data, weight.size) + refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location) + refitted.add(layer_name) + + if len(refitted) != len(weight_list): + logger.warning("Not all weights have been refitted!!!") + + if not refitter.refit_cuda_engine(): + logger.error("Error: failed to refit new weights.") + raise AssertionError("Refitting failed.") def refit_module_weights( diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index 2a31924df5..17f2fccbff 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -891,7 +891,7 @@ def call_function(self, target: str, args: Any, kwargs: Any) -> Any: return converter(self.ctx, target, args, kwargs, self._cur_node_name) def get_attr(self, target: str, args: Any, kwargs: Any) -> np.ndarray: - with _disable_current_modes(): + with _disable_current_modes(), unset_fake_temporarily(): frozen_attr = self.fetch_attr(target) if isinstance(frozen_attr, torch.nn.Parameter): diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py index 4813058e5b..8f000ac94d 100644 --- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py +++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py @@ -352,7 +352,7 @@ def create_constant( if min_rank == 0 and isinstance(value, (int, float, bool)): shape = trt.Dims() elif list(torch_value.shape) == []: - shape = (1,) + shape = trt.Dims() else: shape = list(torch_value.shape) # breakpoint() diff --git a/py/torch_tensorrt/dynamo/conversion/impl/quantize.py b/py/torch_tensorrt/dynamo/conversion/impl/quantize.py index b97840cd09..e472ed3092 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/quantize.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/quantize.py @@ -1,11 +1,13 @@ -from typing import Optional +from typing import Optional, Union import numpy as np import tensorrt as trt +import torch +from torch.fx.experimental.proxy_tensor import unset_fake_temporarily from torch.fx.node import Target from torch_tensorrt.dynamo._SourceIR import SourceIR from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext -from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor +from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor, to_torch from torch_tensorrt.fx.converters.converter_utils import set_layer_name from torch_tensorrt.fx.types import TRTTensor @@ -16,7 +18,7 @@ def quantize( source_ir: Optional[SourceIR], name: str, input_tensor: TRTTensor, - amax: np.ndarray, + amax: Union[np.ndarray, torch.Tensor], num_bits: int, exponent_bits: int, ) -> TRTTensor: @@ -24,40 +26,44 @@ def quantize( Adds quantize and dequantize ops (QDQ) which quantize to INT8 or FP8 based on the output_type set and dequantizes them back. """ - if isinstance(input_tensor, TRTTensor) and input_tensor.dtype not in ( - trt.float32, - trt.float16, - ): - raise ValueError( - f"quantize converter received an input of {input_tensor.dtype} type. 
Supported types: float32 | float16" - ) - if num_bits != 8 or exponent_bits not in (0, 4): - raise ValueError( - f"quantize converter currently only accept INT8 or FP8 based quantize, got {num_bits=}, {exponent_bits=}" - ) - if num_bits == 8 and exponent_bits == 0: - max_bound = 127 - elif num_bits == 8 and exponent_bits == 4: - max_bound = 448 - scale = np.divide(amax, max_bound) - scale = get_trt_tensor(ctx, scale, name + "_scale") - # Add Q node - quantize_layer = ctx.net.add_quantize(input_tensor, scale) - if num_bits == 8 and exponent_bits == 0: - quantize_layer.set_output_type(0, trt.DataType.INT8) - elif num_bits == 8 and exponent_bits == 4: - quantize_layer.set_output_type(0, trt.DataType.FP8) - set_layer_name(quantize_layer, target, name + "_quantize", source_ir) - q_output = quantize_layer.get_output(0) - # Add DQ node - dequantize_layer = ctx.net.add_dequantize(q_output, scale) - set_layer_name(dequantize_layer, target, name + "_dequantize", source_ir) - if num_bits == 8 and exponent_bits == 0: - dequantize_layer.precision = trt.DataType.INT8 - elif num_bits == 8 and exponent_bits == 4: - # Set DQ layer precision to FP8 - dequantize_layer.precision = trt.DataType.FP8 - dq_output = dequantize_layer.get_output(0) + with unset_fake_temporarily(): + if isinstance(input_tensor, TRTTensor) and input_tensor.dtype not in ( + trt.float32, + trt.float16, + ): + raise ValueError( + f"quantize converter received an input of {input_tensor.dtype} type. Supported types: float32 | float16" + ) + if num_bits != 8 or exponent_bits not in (0, 4): + raise ValueError( + f"quantize converter currently only accept INT8 or FP8 based quantize, got {num_bits=}, {exponent_bits=}" + ) + if num_bits == 8 and exponent_bits == 0: + max_bound = 127 + elif num_bits == 8 and exponent_bits == 4: + max_bound = 448 - return dq_output + amax = to_torch(amax, None) + scale = torch.divide(amax, max_bound) + scale = get_trt_tensor(ctx, scale, name + "_scale") + # Add Q node + quantize_layer = ctx.net.add_quantize(input_tensor, scale) + if num_bits == 8 and exponent_bits == 0: + quantize_layer.set_output_type(0, trt.DataType.INT8) + elif num_bits == 8 and exponent_bits == 4: + quantize_layer.set_output_type(0, trt.DataType.FP8) + + set_layer_name(quantize_layer, target, name + "_quantize", source_ir) + q_output = quantize_layer.get_output(0) + # Add DQ node + dequantize_layer = ctx.net.add_dequantize(q_output, scale) + set_layer_name(dequantize_layer, target, name + "_dequantize", source_ir) + if num_bits == 8 and exponent_bits == 0: + dequantize_layer.precision = trt.DataType.INT8 + elif num_bits == 8 and exponent_bits == 4: + # Set DQ layer precision to FP8 + dequantize_layer.precision = trt.DataType.FP8 + dq_output = dequantize_layer.get_output(0) + + return dq_output From 0d5e91fc840df3fc7e77f8423394a935267bbdcd Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Thu, 27 Mar 2025 13:16:49 -0700 Subject: [PATCH 10/11] chore: updates --- py/torch_tensorrt/dynamo/conversion/converter_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py index 8f000ac94d..bcb8495c67 100644 --- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py +++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py @@ -355,7 +355,7 @@ def create_constant( shape = trt.Dims() else: shape = list(torch_value.shape) - # breakpoint() + if torch_value is not None: if torch_value.dtype == torch.bfloat16: torch_value_fp32 
= torch_value.to(torch.float32) From c748fac000182f5e53aff0c21e3fbe2b9abbcc71 Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Mon, 31 Mar 2025 21:35:07 -0700 Subject: [PATCH 11/11] chore: updates --- tests/py/dynamo/models/test_models_export.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/py/dynamo/models/test_models_export.py b/tests/py/dynamo/models/test_models_export.py index f5230f3ace..6f96e259b0 100644 --- a/tests/py/dynamo/models/test_models_export.py +++ b/tests/py/dynamo/models/test_models_export.py @@ -249,6 +249,7 @@ def calibrate_loop(model): @unittest.skipIf( platform.system() != "Linux" + or torch.cuda.get_device_capability() < (8, 9) or not importlib.util.find_spec("modelopt") or Version(metadata.version("nvidia-modelopt")) < Version("0.17.0"), "modelopt 0.17.0 or later is required, Int8 quantization is supported in modelopt since 0.17.0 or later for linux",
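
Taken together, the series makes three changes: converter constants now flow through to_torch (CPU torch tensors) instead of to_numpy; BF16 weights are folded by exporting them as FP32 numpy arrays and casting the resulting constant layer back to trt.DataType.BF16, since numpy has no native bfloat16 dtype (patch 01's attempt to hand TensorRT a raw trt.Weights view of the tensor's data_ptr() was replaced in patch 02 by this copy-plus-cast route); and the constant, refit, and quantize paths are wrapped in unset_fake_temporarily so real tensor data can be materialized even while FakeTensorMode tracing is active.

Below is a minimal end-to-end sketch of the workflow these patches enable, distilled from the test_bf16_model test added in patch 01. The module, input shape, and compile options mirror that test; the ConvReLU class name and the standalone-script framing are illustrative additions, and the device is assumed to be cuda:0.

import torch
import torch.nn.functional as F
import torch_tensorrt as torchtrt

class ConvReLU(torch.nn.Module):
    # Same structure as MyModule in test_bf16_model.
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        return self.relu(self.conv(x))

model = ConvReLU().eval().cuda().to(torch.bfloat16)
inp = torch.randn((1, 3, 224, 224), device="cuda", dtype=torch.bfloat16)

# use_explicit_typing=True builds a strongly typed TensorRT network, so the
# BF16 dtypes of the weights are honored rather than autotuned away;
# enabled_precisions stays {torch.float32}, exactly as in the test.
trt_model = torchtrt.compile(
    model,
    inputs=[
        torchtrt.Input(inp.shape, dtype=torch.bfloat16, format=torch.contiguous_format)
    ],
    enabled_precisions={torch.float32},
    min_block_size=1,
    use_explicit_typing=True,
)

# The test gates on cosine similarity rather than exact equality, since
# BF16 accumulation differs between PyTorch and TensorRT kernels.
sim = F.cosine_similarity(
    model(inp).float().flatten(), trt_model(inp).float().flatten(), dim=0
)
print(f"cosine similarity: {sim.item():.6f}")

Inside the engine build, conv and deconv weights take the same path the sketch exercises implicitly: to_torch keeps them as BF16 torch tensors (cast to the input dtype), create_constant folds them through the FP32-numpy detour, and cast_trt_tensor restores BF16 before the convolution layer consumes them.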