diff --git a/examples/fx/quantized_resnet_test.py b/examples/fx/quantized_resnet_test.py index c725c27aad..64d7579414 100644 --- a/examples/fx/quantized_resnet_test.py +++ b/examples/fx/quantized_resnet_test.py @@ -8,7 +8,7 @@ import torchvision.models as models from torch.ao.quantization.quantize_fx import ( convert_fx, - convert_to_reference, + convert_to_reference_fx, prepare_fx, ) from torch.fx.experimental.normalize import NormalizeArgs @@ -52,7 +52,7 @@ def build_int8_trt(rn18): prepared = prepare_fx(rn18, {"": qconfig}, data) for _ in range(10): prepared(data) - quantized_rn18 = convert_to_reference(prepared) + quantized_rn18 = convert_to_reference_fx(prepared) ref_res = quantized_rn18(data) print("quantized model:", quantized_rn18) diff --git a/py/torch_tensorrt/_compile.py b/py/torch_tensorrt/_compile.py index c697979a43..c6550ae7c7 100644 --- a/py/torch_tensorrt/_compile.py +++ b/py/torch_tensorrt/_compile.py @@ -117,7 +117,7 @@ def compile(module: Any, ir="default", inputs=[], enabled_precisions=set([_enums else: raise ValueError(f"Precision {enabled_precisions} not supported on FX") - return lower_to_trt(module, inputs, lower_precision=lower_precision, max_batch_size=inputs[0].size(0), explicit_batch_dimension=True) + return lower_to_trt(module, inputs, lower_precision=lower_precision, max_batch_size=inputs[0].size(0), explicit_batch_dimension=True, dynamic_batch=False) else: raise RuntimeError("Module is an unknown format or the ir requested is unknown") diff --git a/py/torch_tensorrt/fx/converters/acc_ops_converters.py b/py/torch_tensorrt/fx/converters/acc_ops_converters.py index 5558df28f5..95436e762e 100644 --- a/py/torch_tensorrt/fx/converters/acc_ops_converters.py +++ b/py/torch_tensorrt/fx/converters/acc_ops_converters.py @@ -591,9 +591,33 @@ def acc_ops_batch_norm( ) power = np.ones_like(scale) + # For BatchNorm1d, reshape 1d to 2d + output_shape = input_val.shape + if not network.has_implicit_batch_dimension and len(input_val.shape) < 4: + 
assert ( + len(get_dynamic_dims(input_val.shape)) <= 1 + ), "BatchNorm1D with more than one dynamic dims is not currently supported." + reshape_layer = network.add_shuffle(input_val) + if len(input_val.shape) == 2: + reshape_layer.reshape_dims = (input_val.shape[0], input_val.shape[1], 1, 1) + else: # len(input_val.shape) == 3 + reshape_layer.reshape_dims = ( + input_val.shape[0], + input_val.shape[1], + input_val.shape[2], + 1, + ) + set_layer_name(reshape_layer, target, f"{name}_reshape_2d") + input_val = reshape_layer.get_output(0) layer = network.add_scale(input_val, trt.ScaleMode.CHANNEL, bias, scale, power) set_layer_name(layer, target, name) + # For BatchNorm1d, reshape output back to 1d + if not network.has_implicit_batch_dimension and len(output_shape) < 4: + reshape_output_layer = network.add_shuffle(layer.get_output(0)) + reshape_output_layer.reshape_dims = tuple(output_shape) + set_layer_name(reshape_output_layer, target, f"{name}_reshape_1d") + layer = reshape_output_layer return layer.get_output(0) @@ -614,7 +638,18 @@ def acc_ops_layer_norm(network, target, args, kwargs, name): eps_field = trt.PluginField( "eps", np.array([kwargs["eps"]], dtype=np.float32), trt.PluginFieldType.FLOAT32 ) - field_collection = trt.PluginFieldCollection([gamma_field, beta_field, eps_field]) + try: + normalized_shape = np.array(kwargs["normalized_shape"], dtype=np.int32) + except TypeError: + print("Unable to convert normalized_shape to a field, fall back to []") + normalized_shape = np.array([], dtype=np.int32) + + normalized_shape_filed = trt.PluginField( + "normalized_shape", normalized_shape, trt.PluginFieldType.INT32 + ) + field_collection = trt.PluginFieldCollection( + [gamma_field, beta_field, eps_field, normalized_shape_filed] + ) try: if network.has_implicit_batch_dimension: @@ -2838,11 +2873,7 @@ def num_slice_types(slices): """ Gather the number of slice in getitem slices. 
""" - num_slice = 0 - for s in slices: - if isinstance(s, slice) or isinstance(s, int): - num_slice += 1 - return num_slice + return sum(1 for s in slices if isinstance(s, slice) or isinstance(s, int)) def slice_to_trt_params(py_slice, dim_size): """ @@ -2878,9 +2909,9 @@ def slice_to_trt_params(py_slice, dim_size): new_slices = [] for s in slices: if s == Ellipsis: - while num_ellipsis > 0: + # pass explicit start to guard against negative num_ellipsis + for _ in range(0, num_ellipsis): new_slices.append(slice(None, None, None)) - num_ellipsis -= 1 else: new_slices.append(s) slices = new_slices diff --git a/py/torch_tensorrt/fx/converters/converter_utils.py b/py/torch_tensorrt/fx/converters/converter_utils.py index e54bd83efb..50c6f6fb03 100644 --- a/py/torch_tensorrt/fx/converters/converter_utils.py +++ b/py/torch_tensorrt/fx/converters/converter_utils.py @@ -41,6 +41,11 @@ def get_trt_plugin( Returns: A TensorRT plugin that can be added to TensorRT network as Plugin layer. """ + # print the registered plugins + # PLUGIN_CREATORS = trt.get_plugin_registry().plugin_creator_list + # for plugin_creator in PLUGIN_CREATORS: + # print(plugin_creator.name) + plugin_registry = trt.get_plugin_registry() plugin_creator = plugin_registry.get_plugin_creator( plugin_name, version, plugin_namespace @@ -214,7 +219,6 @@ def create_constant( if dtype: value = value.to(dtype) - constant = network.add_constant(value.shape, to_numpy(value)) constant.name = name return constant.get_output(0) diff --git a/py/torch_tensorrt/fx/input_tensor_spec.py b/py/torch_tensorrt/fx/input_tensor_spec.py index 9429a7661f..910bb8228e 100644 --- a/py/torch_tensorrt/fx/input_tensor_spec.py +++ b/py/torch_tensorrt/fx/input_tensor_spec.py @@ -6,15 +6,17 @@ from .utils import get_dynamic_dims -def generate_input_specs( - inputs, lower_setting, additional_inputs=None, fixed_shape=False -): +def generate_input_specs(inputs, lower_setting, additional_inputs=None): # AIT lower setting doesn't have 
explicit_batch_dimension field and # we just return None. if not hasattr(lower_setting, "explicit_batch_dimension"): return None - if not lower_setting.explicit_batch_dimension or fixed_shape: + # dynamic_batch is TRT only flag. It does not exist in AIT lower setting + if ( + not lower_setting.explicit_batch_dimension + or lower_setting.dynamic_batch is False + ): return InputTensorSpec.from_tensors(inputs) # If we don't have additional inputs, we assume the first dimension diff --git a/py/torch_tensorrt/fx/lower.py b/py/torch_tensorrt/fx/lower.py index 82791faf12..c99004585a 100644 --- a/py/torch_tensorrt/fx/lower.py +++ b/py/torch_tensorrt/fx/lower.py @@ -36,7 +36,7 @@ def lower_to_trt( timing_cache_prefix="", save_timing_cache=False, cuda_graph_batch_size=-1, - dynamic_batch=False, + dynamic_batch=True, ) -> nn.Module: """ Takes in original module, input and lowering setting, run lowering workflow to turn module diff --git a/py/torch_tensorrt/fx/lower_setting.py b/py/torch_tensorrt/fx/lower_setting.py index 50a0b5f32a..d3f2cc9a14 100644 --- a/py/torch_tensorrt/fx/lower_setting.py +++ b/py/torch_tensorrt/fx/lower_setting.py @@ -86,4 +86,4 @@ class LowerSetting(LowerSettingBasic): cuda_graph_batch_size: int = -1 preset_lowerer: str = "" opt_profile_replica: int = 1 - dynamic_batch: bool = False + dynamic_batch: bool = True diff --git a/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py b/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py index ef39fd3bf7..ee09da1ce5 100644 --- a/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py +++ b/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py @@ -1,3 +1,4 @@ +import datetime from functools import partial, wraps from typing import Any, Callable, Optional, Sequence @@ -142,6 +143,9 @@ def lower_func(split_result: SplitResult) -> nn.Module: # Only acc submodules will be lowered. 
if not submod_name.startswith(split_result.non_acc_submodule_prefix): + print("Now lowering submodule", submod_name) + lowering_start_time = datetime.datetime.now() + self.lower_setting.input_specs = generate_input_specs( submod_inputs, self.lower_setting, @@ -156,6 +160,10 @@ def lower_func(split_result: SplitResult) -> nn.Module: LOWER_SPLIT_POST_OBSERVER.observe( submod_name, lowered_module, submod_inputs ) + print( + f"Lowering submodule {submod_name} elapsed time", + datetime.datetime.now() - lowering_start_time, + ) return split_result.split_module diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_as_strided.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_as_strided.py index 9ba1f83474..003c8bd3e0 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_as_strided.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_as_strided.py @@ -3,7 +3,9 @@ import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase + +# from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec class TestConverter(AccTestCase): @@ -30,7 +32,9 @@ def forward(self, x): test_implicit_batch_dim=False, ) - # Testing with shape (-1, -1, -1, -1) results into error: RuntimeError: setStorage: sizes [2, 3], strides [1, 2], storage offset 0, and itemsize 8 requiring a storage size of 48 are out of bounds for storage of size 1 + # Testing with shape (-1, 3) results into error: + # RuntimeError: setStorage: sizes [2, 3], strides [1, 2], storage offset 0, and itemsize 8 requiring a storage size of 48 are out of bounds for storage of size 16 + """ def test_as_strided_with_dynamic_shape_four_dimensions(self): class Stride(nn.Module): @@ -39,9 +43,9 @@ def forward(self, x): input_specs = [ InputTensorSpec( - 
shape=(-1, -1, -1, -1), + shape=(-1, 3), dtype=torch.float32, - shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (5, 5, 5, 5))], + shape_ranges=[((1, 3), (2, 3), (2, 3))], ), ] diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_avgpool.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_avgpool.py index 91e9ca9c90..d5bf56e678 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_avgpool.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_avgpool.py @@ -39,6 +39,48 @@ def forward(self, x): inputs = [torch.randn(1, 3, 224)] self.run_test(TestModule(), inputs, expected_ops={acc_ops.avg_pool1d}) + @parameterized.expand( + [ + ("default", 1), + ("kernal_size", 3), + ("stride", 1, 2), + ("tuple_parameters", 2, (1,), (1,)), + param("padding", 2, padding=1), + param("ceil_mode", 1, ceil_mode=True), + param("include_pad", 2, padding=1, count_include_pad=False), + ] + ) + def test_avg_pool1d_with_dynamic_shape( + self, + test_name="default", + kernel_size=1, + stride=1, + padding=0, + ceil_mode=False, + count_include_pad=True, + ): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.avg_pool = torch.nn.AvgPool1d( + kernel_size, stride, padding, ceil_mode, count_include_pad + ) + + def forward(self, x): + return self.avg_pool(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, 3, 3), + dtype=torch.float32, + shape_ranges=[((1, 3, 3), (3, 3, 3), (3, 3, 3))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={acc_ops.avg_pool1d} + ) + def test_avg_pool2d_with_dynamic_shape_four_dimensions( self, test_name="default", @@ -218,38 +260,6 @@ def forward(self, x): TestModule(), input_specs, expected_ops={acc_ops.avg_pool2d} ) - # Testing with (-1, -1, -1, -1) results in error: RuntimeError: ShapeProp error for: node=%avg_pool1d : [#users=1] = call_function[target=torch.avg_pool1d](args = (%x, (1,), (1,), (0,), False, True), kwargs = {}) with meta={} - """ - def 
test_avg_pool1d_with_dynamic_shape_four_dimensions( - self, - test_name="default", - kernel_size=1, - stride=1, - padding=0, - ceil_mode=False, - count_include_pad=True, - ): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.avg_pool = torch.nn.AvgPool1d( - kernel_size, stride, padding, ceil_mode, count_include_pad - ) - - def forward(self, x): - return self.avg_pool(x) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (5, 5, 5, 5))], - ), - ] - - self.run_test(TestModule(), input_specs, expected_ops={acc_ops.avg_pool1d}) - """ - if __name__ == "__main__": run_tests() diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_batchnorm.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_batchnorm.py index 7b282f5bde..24f26d5480 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_batchnorm.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_batchnorm.py @@ -17,6 +17,27 @@ def forward(self, x): inputs = [torch.randn(1, 3, 224, 224)] self.run_test(TestModule(), inputs, expected_ops={acc_ops.batch_norm}) + def test_batchnorm1d_with_dynamic_shape(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.bn = torch.nn.BatchNorm1d(3) + + def forward(self, x): + return self.bn(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, 3, 5), + dtype=torch.float32, + shape_ranges=[((2, 3, 5), (6, 3, 5), (10, 3, 5))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={acc_ops.batch_norm} + ) + def test_batchnorm_with_dynamic_shape(self): class TestModule(torch.nn.Module): def __init__(self): diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_binary_ops.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_binary_ops.py index 6da7a4e205..e122c2a414 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_binary_ops.py +++ 
b/py/torch_tensorrt/fx/test/converters/acc_op/test_binary_ops.py @@ -13,6 +13,8 @@ elementwise_ops = [ ((lambda x, y: x + y), acc_ops.add, NEED_TEST_BOTH_CONSTANTS_CASE), ((lambda x, y: x - y), acc_ops.sub, NEED_TEST_BOTH_CONSTANTS_CASE), + ((lambda x, y: torch.sub(x, y)), acc_ops.sub, False), + ((lambda x, y: x.sub(y)), acc_ops.sub, False), ((lambda x, y: x / y), acc_ops.div, NEED_TEST_BOTH_CONSTANTS_CASE), ((lambda x, y: x // y), acc_ops.floor_div, NEED_TEST_BOTH_CONSTANTS_CASE), ( diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_clamp.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_clamp.py index 2ba6273daa..5291331c67 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_clamp.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_clamp.py @@ -2,7 +2,7 @@ import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import param, parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec class TestClampConverter(AccTestCase): @@ -27,8 +27,6 @@ def forward(self, x): inputs = [torch.randn(3, 4)] self.run_test(TestModule(), inputs, expected_ops={acc_ops.clamp}) - # Error: RuntimeError: ShapeProp error for: node=%clamp : [#users=1] = call_function[target=torch.clamp](args = (%x, 1, 0), kwargs = {}) with meta={} - """ @parameterized.expand( [ param("default", min=-1, max=0), @@ -55,8 +53,9 @@ def forward(self, x): ), ] - self.run_test(TestModule(), input_specs, expected_ops={acc_ops.clamp}) - """ + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={acc_ops.clamp} + ) if __name__ == "__main__": diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_convolution.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_convolution.py index c379da0217..e08484cd56 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_convolution.py 
+++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_convolution.py @@ -44,6 +44,43 @@ def forward(self, x): test_explicit_precision=True, ) + @parameterized.expand( + [ + ("default", 1), + ] + ) + def test_conv1d_with_dynamic_shape( + self, + _, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + ): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv1d( + 3, 6, kernel_size, stride, padding, dilation, groups, bias + ) + + def forward(self, x): + return self.conv(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, 3, 3), + dtype=torch.float32, + shape_ranges=[((1, 3, 3), (3, 3, 3), (5, 3, 3))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={acc_ops.conv1d} + ) + @parameterized.expand( [ ("default", 1), @@ -77,6 +114,9 @@ def forward(self, x): inputs = [torch.randn(1, 3, 32, 32)] self.run_test(TestModule(), inputs, expected_ops={acc_ops.conv2d}) + # Testing with (-1, -1, -1, -1) results into Error: + # AssertionError: Channel dim can't be dynamic for convolution. + def test_conv2d_with_dynamic_shape(self): class TestModule(torch.nn.Module): def __init__(self): @@ -130,6 +170,9 @@ def forward(self, x): inputs = [torch.randn(1, 3, 32, 32, 32)] self.run_test(TestModule(), inputs, expected_ops={acc_ops.conv3d}) + # Testing with (-1, -1, -1, -1, -1) results into Error: + # AssertionError: Channel dim can't be dynamic for convolution. 
+ def test_conv3d_with_dynamic_shape(self): class TestModule(torch.nn.Module): def __init__(self): diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_expand.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_expand.py index 49a16d9e1d..af0dff943c 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_expand.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_expand.py @@ -27,6 +27,8 @@ def forward(self, x): expected_ops={acc_ops.expand}, ) + # Dynamic shape is not suitable for the expand operation. + if __name__ == "__main__": run_tests() diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_gelu.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_gelu.py index 7fdd5da3c7..c33088a498 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_gelu.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_gelu.py @@ -1,3 +1,5 @@ +import unittest + import torch import torch.nn as nn import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops @@ -5,6 +7,9 @@ from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec +@unittest.skip( + reason="Could not find CustomGeluPluginDynamic. 
Enable it once we upgrade TRT to 8.4" +) class TestGELU(AccTestCase): def test_gelu(self): class TestModule(nn.Module): diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_interpolate.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_interpolate.py index b97cacf7d2..f0054e5cb7 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_interpolate.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_interpolate.py @@ -3,7 +3,7 @@ import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec class TestInterpolateConverter(AccTestCase): @@ -106,6 +106,37 @@ def forward(self, x): expected_ops={acc_ops.interpolate}, ) + @parameterized.expand( + [ + # 4D + ("4d_dim_scale", (2, 3, 4, 5), (None), (2), ("nearest"), (None)), + ] + ) + def test_interpolate_with_dynamic_shape_four_dimensions( + self, _, init_size, size, scale_factor, mode, align_corners + ): + class Interpolate(nn.Module): + def forward(self, x): + return torch.nn.functional.interpolate( + x, + size=size, + scale_factor=scale_factor, + mode=mode, + align_corners=align_corners, + ) # only one of size or scale_factor should be defined + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (1, 2, 3, 3), (3, 3, 3, 3))], + ), + ] + + self.run_test_with_dynamic_shape( + Interpolate(), input_specs, expected_ops={acc_ops.interpolate} + ) + if __name__ == "__main__": run_tests() diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_masked_fill.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_masked_fill.py index 337938fc5a..d04ed6ed33 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_masked_fill.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_masked_fill.py @@ 
-35,6 +35,9 @@ def forward(self, x): test_implicit_batch_dim=False, ) + # Testing with (-1, -1, -1, -1) results into following error: + # RuntimeError: Trying to create tensor with negative dimension -1: [-1, -1, -1, -1] + @parameterized.expand( [ ("same_dims", (2, 3), (2, 3), 5), diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_matmul.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_matmul.py index 9da0161f3f..50e1f5bfcd 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_matmul.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_matmul.py @@ -3,7 +3,7 @@ import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec class TestMatMulConverter(AccTestCase): @@ -85,6 +85,33 @@ def forward(self, input, other): test_implicit_batch_dim=test_implicit_batch_dim, ) + def test_matmal_dynamic_shape( + self, + ): + class Matmul(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input, other): + return torch.matmul(input, other) + + input_specs = [ + InputTensorSpec( + shape=(-1, 1, 2, 3), + dtype=torch.float32, + shape_ranges=[((1, 1, 2, 3), (9, 1, 2, 3), (9, 1, 2, 3))], + ), + InputTensorSpec( + shape=(-1, -1, 3, 3), + dtype=torch.float32, + shape_ranges=[((1, 1, 3, 3), (9, 4, 3, 3), (9, 4, 3, 3))], + ), + ] + + self.run_test_with_dynamic_shape( + Matmul(), input_specs, expected_ops={acc_ops.matmul} + ) + if __name__ == "__main__": run_tests() diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_max.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_max.py index 746e61cb30..1da3dd07fa 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_max.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_max.py @@ -79,6 +79,35 @@ def forward(self, input, 
other): class TestMaxConverterWithDynamicShape(AccTestCase): + @parameterized.expand( + [ + # keepdim can not be False for dynamic shape + ("dim0_keepdim", 0, True), + ("dim1_keepdim", 1, True), + ("dim2_keepdim", 2, True), + ("dim3_keepdim", 3, True), + ] + ) + def test_max_dim_reduce(self, _, dim, keepdim): + class MaxDimReduce(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.max(x, dim, keepdim=keepdim) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 5, 5), (2, 3, 5, 5), (2, 3, 5, 5))], + ), + ] + + self.run_test_with_dynamic_shape( + MaxDimReduce(), input_specs, expected_ops={acc_ops.max_dim_reduce} + ) + def test_max_full_reduce( self, ): diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_min.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_min.py index c88c08cfb1..33b2aa5671 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_min.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_min.py @@ -79,6 +79,34 @@ def forward(self, input, other): class TestMinConverterWithDynamicShape(AccTestCase): + @parameterized.expand( + [ + ("dim0_keepdim", 0, True), + ("dim1_keepdim", 1, True), + ("dim2_keepdim", 2, True), + ("dim3_keepdim", 3, True), + ] + ) + def test_min_dim_reduce(self, test_name, dim, keepdim): + class MinDimReduce(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.min(x, dim, keepdim) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 5, 5), (2, 3, 5, 5), (2, 3, 5, 5))], + ), + ] + + self.run_test_with_dynamic_shape( + MinDimReduce(), input_specs, expected_ops={acc_ops.min_dim_reduce} + ) + def test_min_full_reduce( self, ): diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_narrow.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_narrow.py index 15243cb259..9c2a4f34ab 100644 --- 
a/py/torch_tensorrt/fx/test/converters/acc_op/test_narrow.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_narrow.py @@ -3,14 +3,13 @@ import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec -class TestNarrowConverter(AccTestCase): +class TestNarrowConverterWithDynamicShape(AccTestCase): @parameterized.expand( [ ("positive_dim", 1, 0, 1), - ("negative_dim", -1, 1, 2), ] ) def test_narrow(self, _, dim, start, length): @@ -18,19 +17,20 @@ class Narrow(nn.Module): def forward(self, x): return x.narrow(dim, start, length) - inputs = [torch.randn(1, 2, 3, 4)] - self.run_test( - Narrow(), - inputs, - expected_ops={acc_ops.slice_tensor}, - test_explicit_batch_dim=False, + input_specs = [ + InputTensorSpec( + shape=(-1, 3, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 3, 5, 5), (2, 3, 5, 5), (2, 3, 5, 5))], + ), + ] + + self.run_test_with_dynamic_shape( + Narrow(), input_specs, expected_ops={acc_ops.slice_tensor} ) -# Testing with (-1, -1, -1 , -1) results in following error: -# AssertionError: Can't chunk on dynamic shape dimension! 
-""" -class TestNarrowConverterWithDynamicShape(AccTestCase): +class TestNarrowConverter(AccTestCase): @parameterized.expand( [ ("positive_dim", 1, 0, 1), @@ -42,18 +42,14 @@ class Narrow(nn.Module): def forward(self, x): return x.narrow(dim, start, length) - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 5, 5), (2, 3, 5, 5), (2, 3, 5, 5))], - ), - ] - - self.run_test_with_dynamic_shape( - Narrow(), input_specs, expected_ops={acc_ops.slice_tensor} + inputs = [torch.randn(1, 2, 3, 4)] + self.run_test( + Narrow(), + inputs, + expected_ops={acc_ops.slice_tensor}, + test_explicit_batch_dim=False, ) -""" + if __name__ == "__main__": run_tests() diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_pad.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_pad.py index 7a9e9544c3..c82eee79ee 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_pad.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_pad.py @@ -5,10 +5,12 @@ import torch.nn as nn import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from parameterized import param, parameterized +from parameterized import parameterized from torch.testing._internal.common_utils import run_tests from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase +# from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec + class TestPadConverter(AccTestCase): @parameterized.expand( @@ -51,6 +53,27 @@ def forward(self, x): test_implicit_batch_dim=False, ) + # Testing with (-1, 3, 3, 3) results into following error: + # test_pad_with_dynamic_shape_four_dimensions_0_2d (deeplearning.trt.torch_tensorrt.py.torch_tensorrt.fx.test.converters.acc_op.test_pad.TestPadConverter) ... [07/15/2022-09:23:18] [TRT] [E] 2: [intInterval.cpp::max::26] Error Code 2: Internal Error (Assertion !empty() failed. 
) + # Segmentation fault (core dumped) + + """ + def test_pad_with_dynamic_shape_four_dimensions(self): + class Pad(nn.Module): + def forward(self, x): + return torch.nn.functional.pad(x, (1, 1)) + + input_specs = [ + InputTensorSpec( + shape=(-1, 3, 3, 3), + dtype=torch.float32, + shape_ranges=[((1, 3, 3, 3), (2, 3, 3, 3), (2, 3, 3, 3))], + ), + ] + + self.run_test_with_dynamic_shape(Pad(), input_specs, expected_ops={acc_ops.pad}) + """ + @parameterized.expand( [ ("3d", (2, 2, 3, 1, 2, 2)), diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_reduce_ops.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_reduce_ops.py index ae9932fd61..879a0e0eb5 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_reduce_ops.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_reduce_ops.py @@ -2,7 +2,7 @@ import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec reduce_ops = [(torch.sum, acc_ops.sum), (torch.mean, acc_ops.mean)] @@ -76,6 +76,33 @@ def forward(self, x): test_implicit_batch_dim=False, ) + @parameterized.expand( + [ + (f"{acc_op.__name__}_no_dim_no_keepdim", op, acc_op) + for op, acc_op in reduce_ops + ] + ) + def test_reduce_all_dims_with_dynamic_shape_four_dimensions( + self, + test_name, + op, + expected_acc_op, + ): + class Reduce(torch.nn.Module): + def forward(self, x): + return op(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (3, 3, 3, 3))], + ), + ] + self.run_test_with_dynamic_shape( + Reduce(), input_specs, expected_ops={expected_acc_op} + ) + if __name__ == "__main__": run_tests() diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_split.py 
b/py/torch_tensorrt/fx/test/converters/acc_op/test_split.py index 811be2c05a..29d174d9fd 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_split.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_split.py @@ -76,6 +76,9 @@ def forward(self, x): }, ) + # Testing with (-1, -1, -1) results into following error: + # AssertionError: Can't chunk on dynamic shape dimension! + @parameterized.expand( [ ("split_with_size", [2, 3, 5], 1), diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_squeeze.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_squeeze.py index 3af02f73ab..d265def896 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_squeeze.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_squeeze.py @@ -14,6 +14,12 @@ def forward(self, x): inputs = [torch.randn(1, 2, 1)] self.run_test(Squeeze(), inputs, expected_ops={acc_ops.squeeze}) + # Testing with shape=(-1, -1, -1, -1) results in error: + # AssertionError: We don't support squeeze dynamic dim. + + # Testing with more than one dynamic dim results in error: + # AssertionError: Currently more than one dynamic dim for input to squeeze is not supported. 
+ def test_squeeze_with_dynamic_shape(self): class Squeeze(nn.Module): def forward(self, x): diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_std.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_std.py index 900a875e33..7bcfaa46f2 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_std.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_std.py @@ -2,7 +2,7 @@ import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec class TestMinConverter(AccTestCase): @@ -29,6 +29,36 @@ def forward(self, x): expected_ops={acc_ops.mean, acc_ops.sub, acc_ops.pow, acc_ops.sqrt}, ) + @parameterized.expand( + [ + ("norm_1d", (-1), False), + ("norm_1d", (-1), True), + ("norm_2d", (2, 3), False), + ("norm_2d", (2, 3), True), + ] + ) + def test_std_with_dynamic_shape_four_dimensions(self, _, dim, unbiased): + class Std(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.std(x, dim, unbiased=unbiased, keepdim=True) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (3, 3, 3, 3))], + ), + ] + + self.run_test_with_dynamic_shape( + Std(), + input_specs, + expected_ops={acc_ops.mean, acc_ops.sub, acc_ops.pow, acc_ops.sqrt}, + ) + @parameterized.expand( [ ("norm_1d", (-1), True), @@ -52,6 +82,36 @@ def forward(self, x): expected_ops={acc_ops.mean, acc_ops.sub, acc_ops.pow, acc_ops.sqrt}, ) + @parameterized.expand( + [ + ("norm_1d", (-1), True), + ("norm_1d", (-1), False), + ("norm_2d", (2, 3), True), + ("norm_2d", (2, 3), False), + ] + ) + def test_std_method_with_dynamic_shape_four_dimensions(self, _, dim, unbiased): + class Std(torch.nn.Module): + def __init__(self): + super().__init__() 
+ + def forward(self, x): + return x.std(dim, unbiased=unbiased, keepdim=True) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (3, 3, 3, 3))], + ), + ] + + self.run_test_with_dynamic_shape( + Std(), + input_specs, + expected_ops={acc_ops.mean, acc_ops.sub, acc_ops.pow, acc_ops.sqrt}, + ) + if __name__ == "__main__": run_tests() diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_tile.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_tile.py index d83bef5a67..cd8e6f97b5 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_tile.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_tile.py @@ -68,6 +68,32 @@ def forward(self, x): Tile(dims), input_specs, expected_ops={acc_ops.tile} ) + @parameterized.expand( + [ + ("all_dynamic_dim", (-1, -1), (1, 2, 2, 1)), + ] + ) + def test_tile_with_dynamic_shape_four_dimensions(self, _, shape, dims): + class Tile(nn.Module): + def __init__(self, dims): + super().__init__() + self.dims = dims + + def forward(self, x): + return torch.tile(x, self.dims) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 3), (3, 3, 3, 3), (3, 3, 3, 3))], + ), + ] + + self.run_test_with_dynamic_shape( + Tile(dims), input_specs, expected_ops={acc_ops.tile} + ) + def test_tile_non_int_dims(self): class Tile(nn.Module): def __init__(self): @@ -88,6 +114,32 @@ def forward(self, x, y): expected_ops={acc_ops.tile}, ) + def test_tile_non_int_dims_with_dynamic_shape_four_dimensions(self): + class Tile(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + y = y * 2 + return torch.tile(x, (1, y.shape[1], y.shape[1])) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 3), (3, 3, 3, 3), (3, 3, 3, 3))], + ), + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 
1, 1, 3), (3, 3, 3, 3), (3, 3, 3, 3))], + ), + ] + + self.run_test_with_dynamic_shape( + Tile(), input_specs, expected_ops={acc_ops.tile} + ) + if __name__ == "__main__": run_tests() diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_to_dtype.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_to_dtype.py index 63eb3345d9..67a07d83cf 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_to_dtype.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_to_dtype.py @@ -1,7 +1,7 @@ import torch import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec from torch_tensorrt.fx.utils import LowerPrecision @@ -23,6 +23,27 @@ def forward(self, x): precision=LowerPrecision.FP16, ) + # Testing with shape=(-1, -1, -1, -1) results in the following error: + # Error: assert engine + """ + def test_fp16_with_dynamic_shape_four_dimension(self): + class To(torch.nn.Module): + def forward(self, x): + return x.to(torch.float16) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float16, + shape_ranges=[((1, 1, 3, 3), (3, 3, 3, 3), (3, 3, 3, 3))], + ).cuda(), + ] + + self.run_test_with_dynamic_shape( + To(), input_specs, expected_ops={acc_ops.to_dtype} + ) + """ + def test_fp32(self): class To(torch.nn.Module): def forward(self, x): @@ -72,6 +93,25 @@ def forward(self, x): precision=LowerPrecision.FP32, ) + def test_cuda_with_dynamic_shape_four_dimensions(self): + class To(torch.nn.Module): + def forward(self, x): + x = x.to(torch.device("cuda")) + # append extra layer since to(device) is skipped in TRT + return x + torch.randn(3, 3, 3, 3).cuda() + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float16, + shape_ranges=[((1, 1, 3, 3), (3, 3, 3, 3), (3, 3, 3, 3))], + ), + ] + + 
self.run_test_with_dynamic_shape( + To(), input_specs, expected_ops={acc_ops.to_dtype, acc_ops.add} + ) + def test_device(self): class To(torch.nn.Module): def __init__(self): @@ -95,6 +135,29 @@ def forward(self, x): precision=LowerPrecision.FP32, ) + def test_device_with_dynamic_shape_four_dimensions(self): + class To(torch.nn.Module): + def __init__(self): + super().__init__() + self.a = torch.randn(3, 3, 3, 3) + + def forward(self, x): + idevice = x.device + a = self.a.to(idevice) + return x + a + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float16, + shape_ranges=[((1, 1, 3, 3), (3, 3, 3, 3), (3, 3, 3, 3))], + ), + ] + + self.run_test_with_dynamic_shape( + To(), input_specs, expected_ops={acc_ops.to_dtype, acc_ops.add} + ) + def test_device_fp16(self): class To(torch.nn.Module): def __init__(self): @@ -121,24 +184,35 @@ def forward(self, x): precision=LowerPrecision.FP16, ) - def test_tensor(self): + # Testing with shape=(-1, -1, -1, -1) results in the following error: + # Error: assert engine + """ + def test_device_fp16_with_dynamic_shape_four_dimensions(self): class To(torch.nn.Module): - def forward(self, x, y): - return y.to(x) + def __init__(self): + super().__init__() + self.a = torch.randn(2, 2) - input = torch.randn(2, 2).half().cuda() - other = torch.randn(2, 2) - inputs = [ - input, - other, + def forward(self, x): + idevice = x.device + idtype = x.dtype + a = self.a.to(idevice) + # fx tracer could not handle "to(idevice, torch.float16)" + # TypeError: to() received an invalid combination of arguments - got (Attribute, torch.dtype) + return a.to(idtype) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float16, + shape_ranges=[((2, 2, 2, 2), (4, 4, 4, 4), (4, 4, 4, 4))], + ), ] - self.run_test( - To(), - inputs, - expected_ops={acc_ops.to_dtype}, - test_implicit_batch_dim=False, - precision=LowerPrecision.FP16, + + self.run_test_with_dynamic_shape( + To(), input_specs, 
expected_ops={acc_ops.to_dtype} ) + """ # tensor.float() def test_float(self): @@ -158,6 +232,27 @@ def forward(self, x): precision=LowerPrecision.FP32, ) + # tensor.float() + def test_float_with_dynamic_shape_four_dimensions(self): + class To(torch.nn.Module): + def forward(self, x): + return x.float() + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (3, 3, 3, 3))], + ), + ] + + self.run_test_with_dynamic_shape( + To(), input_specs, expected_ops={acc_ops.to_dtype} + ) + + # Half is not suitable for dynamic shape + # Error: assert engine + # tensor.half() def test_half(self): class To(torch.nn.Module): @@ -197,6 +292,27 @@ def forward(self, x): precision=LowerPrecision.FP32, ) + # tensor.int() + def test_int_with_dynamic_shape_four_dimensions(self): + class To(torch.nn.Module): + def forward(self, x): + x = x.int() + # we do not expect int to be output type, so add an extra layer + x = x.float() + return x + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.int, + shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (3, 3, 3, 3))], + ), + ] + + self.run_test_with_dynamic_shape( + To(), input_specs, expected_ops={acc_ops.to_dtype} + ) + if __name__ == "__main__": run_tests() diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_topk.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_topk.py index 53dfe63190..7ae93bf9bd 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_topk.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_topk.py @@ -3,7 +3,7 @@ import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec class TestTopKConverter(AccTestCase): @@ -41,6 +41,44 @@ def forward(self, x): 
test_implicit_batch_dim=(dim != 0), ) + @parameterized.expand( + [ + ("top1", 1, -1), + ("top2", 2, -1), + ("none_dim", 1, None), + ("smallest", 1, -1, False), + ("top1_dim0", 1, 0, False), + ] + ) + def test_topk_with_dynamic_shape_four_dimensions(self, _, k, dim, largest=True): + class TopK(nn.Module): + def __init__(self, k, dim): + super().__init__() + self.k = k + self.dim = dim + self.largest = largest + + def forward(self, x): + if self.dim is not None: + out = torch.topk( + x, k=self.k, dim=self.dim, largest=self.largest, sorted=False + ) + else: + out = torch.topk(x, k=self.k, largest=self.largest, sorted=False) + return out[0], out[1] + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (3, 3, 3, 3))], + ), + ] + + self.run_test_with_dynamic_shape( + TopK(k, dim), input_specs, expected_ops={acc_ops.topk} + ) + if __name__ == "__main__": run_tests() diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_type_as.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_type_as.py index cf3eb972d5..839ff44566 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_type_as.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_type_as.py @@ -1,7 +1,7 @@ import torch import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec from torch_tensorrt.fx.utils import LowerPrecision @@ -103,6 +103,25 @@ def forward(self, input): precision=LowerPrecision.FP16, ) + def test_type_tensor_with_dynamic_shape_four_dimensions(self): + class Type_as(torch.nn.Module): + def forward(self, input): + return input.type(dtype=torch.float32) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.int, + shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (3, 3, 3, 3))], + ), + ] 
+ + self.run_test_with_dynamic_shape( + Type_as(), + input_specs, + expected_ops={acc_ops.to_dtype}, + ) + def test_type_tensor_ext(self): class Type_as(torch.nn.Module): def forward(self, input, other): diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_unary_ops.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_unary_ops.py index c1299d809d..7fad26dc84 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_unary_ops.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_unary_ops.py @@ -6,7 +6,7 @@ import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec unary_ops = [ (torch.sin, acc_ops.sin), @@ -45,6 +45,30 @@ def forward(self, x): self.run_test(m, inputs, expected_ops={expected_op}) +class TestUnaryOpConvertersWithDynamicShapeFourDimensions(AccTestCase): + @parameterized.expand([(op[1].__name__, op[0], op[1]) for op in unary_ops]) + def test_unary_ops(self, name, orig_op: Callable, expected_op): + class TestModule(nn.Module): + def __init__(self, orig_op): + super().__init__() + self.orig_op = orig_op + + def forward(self, x): + return self.orig_op(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (3, 3, 3, 3))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(orig_op), input_specs, expected_ops={expected_op} + ) + + class TestUnaryOpNotConverters(AccTestCase): @parameterized.expand( [ @@ -70,6 +94,37 @@ def forward(self, x): ) +class TestUnaryOpNotConvertersWithDynamicShapeFourDimensions(AccTestCase): + @parameterized.expand( + [ + ("not_bool", torch.logical_not, acc_ops.logical_not, torch.bool), + ("not_float", torch.logical_not, acc_ops.logical_not, torch.float), + ("not_int", 
torch.logical_not, acc_ops.logical_not, torch.int), + ] + ) + def test_unary_ops(self, name, orig_op: Callable, expected_op, input_dtype): + class TestModule(nn.Module): + def __init__(self, orig_op): + super().__init__() + self.orig_op = orig_op + + def forward(self, x): + x = self.orig_op(x) + return self.orig_op(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (3, 3, 3, 3))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(orig_op), input_specs, expected_ops={expected_op} + ) + + class TestUnaryRSQRTConverters(AccTestCase): def test_unary_ops(self): class TestModule(nn.Module): @@ -81,5 +136,24 @@ def forward(self, x): self.run_test(m, inputs, expected_ops={acc_ops.sqrt, acc_ops.reciprocal}) +class TestUnaryRSQRTConvertersWithDynamicShapeFourDimensions(AccTestCase): + def test_unary_ops(self): + class TestModule(nn.Module): + def forward(self, x): + return torch.rsqrt(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (3, 3, 3, 3))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={acc_ops.sqrt, acc_ops.reciprocal} + ) + + if __name__ == "__main__": run_tests() diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_unsqueeze.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_unsqueeze.py index 097cf5435d..26d23e0e54 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_unsqueeze.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_unsqueeze.py @@ -26,6 +26,9 @@ def forward(self, x): inputs = [torch.randn(1, 2, 3)] self.run_test(Unsqueeze(dim), inputs, expected_ops={acc_ops.unsqueeze}) + # Testing with more than one dynamic dims results in following error: + # AssertionError: Currently we don't support unsqueeze with more than one dynamic dims. 
+ @parameterized.expand( [ ("negative_dim_dynamic", -4), diff --git a/py/torch_tensorrt/fx/test/quant/test_quant_trt.py b/py/torch_tensorrt/fx/test/quant/test_quant_trt.py index 4fabc7f18d..9c16d60853 100644 --- a/py/torch_tensorrt/fx/test/quant/test_quant_trt.py +++ b/py/torch_tensorrt/fx/test/quant/test_quant_trt.py @@ -15,7 +15,7 @@ from torch.ao.quantization.backend_config.observation_type import ObservationType from torch.ao.quantization.fx.match_utils import MatchAllNode from torch.ao.quantization.quantize_fx import ( - convert_to_reference, + convert_to_reference_fx, get_tensorrt_backend_config_dict, prepare_fx, prepare_qat_fx, @@ -96,7 +96,9 @@ def forward(self, x): ) self.checkGraphModuleNodes(mp, expected_node_occurrence=prepare_count_check) mp(torch.randn(1, 1, 4, 4)) - mq = convert_to_reference(mp, backend_config_dict=self.trt_backend_config_dict) + mq = convert_to_reference_fx( + mp, backend_config_dict=self.trt_backend_config_dict + ) self.checkGraphModuleNodes(mq, expected_node_occurrence=convert_count_check) def test_quantized_input_quantized_output(self): @@ -258,7 +260,7 @@ def forward(self, x): ) # check converted/quantized model - m = convert_to_reference(m, backend_config_dict=backend_config_dict) + m = convert_to_reference_fx(m, backend_config_dict=backend_config_dict) self.checkGraphModuleNodes(m, expected_node_occurrence=convert_count_check) self.checkGraphModuleNodes( m.standalone, expected_node_occurrence=standalone_convert_count_check @@ -273,7 +275,7 @@ def forward(self, x): backend_config_dict=backend_config_dict, ) ref_m(data) - ref_m = convert_to_reference(ref_m, backend_config_dict=backend_config_dict) + ref_m = convert_to_reference_fx(ref_m, backend_config_dict=backend_config_dict) ref_res = ref_m(data) self.assertEqual(res, ref_res) @@ -433,7 +435,7 @@ def _test_module( self.checkGraphModuleNodes(prepared, expected_node_occurrence=no_prepare) # calibration prepared(*inputs) - quantized = convert_to_reference( + quantized = 
convert_to_reference_fx( prepared, backend_config_dict=self.trt_backend_config_dict, ) @@ -551,7 +553,7 @@ def forward(self, x): example_inputs, backend_config_dict=self.trt_backend_config_dict, ) - m = convert_to_reference(m, backend_config_dict=self.trt_backend_config_dict) + m = convert_to_reference_fx(m, backend_config_dict=self.trt_backend_config_dict) expected_occurrence = { ns.call_function(torch.quantize_per_tensor): 5, ns.call_method("dequantize"): 5, @@ -584,7 +586,7 @@ def forward(self, x): ) # calibration prepared(linear_module_input) - quantized = convert_to_reference( + quantized = convert_to_reference_fx( prepared, backend_config_dict=self.trt_backend_config_dict, ) @@ -614,7 +616,7 @@ def forward(self, x): backend_config_dict=self.trt_backend_config_dict, ) self.assertTrue(len(dict(prepared.named_children())) == 1) - quantized = convert_to_reference( + quantized = convert_to_reference_fx( prepared, backend_config_dict=self.trt_backend_config_dict, ) @@ -650,7 +652,7 @@ def forward(self, x): ns.call_module(torch.ao.quantization.HistogramObserver): 2, } self.checkGraphModuleNodes(prepared, expected_node_occurrence=node_occurrence) - quantized = convert_to_reference( + quantized = convert_to_reference_fx( prepared, backend_config_dict=self.trt_backend_config_dict, ) @@ -719,7 +721,7 @@ def conv_add_extra_inputs_getter(pattern): ns.call_module(torch.ao.quantization.HistogramObserver): 3, } self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) - m = convert_to_reference(m, backend_config_dict=modified_backend_config_dict) + m = convert_to_reference_fx(m, backend_config_dict=modified_backend_config_dict) node_occurrence = { ns.call_function(torch.quantize_per_tensor): 3, ns.call_method("dequantize"): 3, @@ -831,7 +833,7 @@ def forward(self, x): self.checkGraphModuleNodes( m.standalone, expected_node_occurrence=standalone_node_occurrence ) - m = convert_to_reference(m, backend_config_dict=backend_config_dict) + m = 
convert_to_reference_fx(m, backend_config_dict=backend_config_dict) node_occurrence = { # two inputs for standalone module ns.call_function(torch.quantize_per_tensor): 2, @@ -870,7 +872,7 @@ def forward(self, x): example_inputs, backend_config_dict=self.trt_backend_config_dict, ) - quantized = convert_to_reference( + quantized = convert_to_reference_fx( prepared, backend_config_dict=self.trt_backend_config_dict, ) diff --git a/py/torch_tensorrt/fx/test/tracer/test_acc_tracer.py b/py/torch_tensorrt/fx/test/tracer/test_acc_tracer.py index a78329c9ef..ab7f932acf 100644 --- a/py/torch_tensorrt/fx/test/tracer/test_acc_tracer.py +++ b/py/torch_tensorrt/fx/test/tracer/test_acc_tracer.py @@ -2577,5 +2577,6 @@ def test_all_acc_ops_registered(self): acc_ops.einsum, acc_ops.as_strided, acc_ops.var, + acc_ops.grid_sample, }, ) diff --git a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_ops.py b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_ops.py index 1ce6fa39ea..8814bf4075 100644 --- a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_ops.py +++ b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_ops.py @@ -975,7 +975,9 @@ def rescale_quantize_per_channel(*, input, acc_out_ty=None): @register_acc_op_properties(AccOpProperty.pointwise) +@register_acc_op_mapping(op_and_target=("call_function", torch.sub)) @register_acc_op_mapping(op_and_target=("call_function", operator.sub)) +@register_acc_op_mapping(op_and_target=("call_method", "sub")) @register_acc_op def sub(*, input, other): return input - other @@ -2740,6 +2742,34 @@ def expand_as_mapper(node: torch.fx.Node, _: nn.Module) -> torch.fx.Node: return expand_node +@register_acc_op_mapping( + op_and_target=("call_function", torch.nn.functional.grid_sample), + arg_replacement_tuples=[ + ("input", "input"), + ("grid", "grid"), + ("mode", "mode", this_arg_is_optional), + ("padding_mode", "padding_mode", this_arg_is_optional), + ("align_corners", "align_corners", this_arg_is_optional), + ], +) +@register_acc_op +def grid_sample( + *, + input, + grid, 
+ mode="bilinear", + padding_mode="zeros", + align_corners=None, +): + return torch.nn.functional.grid_sample( + input=input, + grid=grid, + mode=mode, + padding_mode=padding_mode, + align_corners=align_corners, + ) + + @register_acc_op_mapping( op_and_target=("call_function", torch.nn.functional.interpolate), arg_replacement_tuples=[ diff --git a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_tracer.py b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_tracer.py index 9f3f911261..c535b062ee 100644 --- a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_tracer.py +++ b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_tracer.py @@ -6,7 +6,7 @@ import textwrap import warnings from types import FunctionType -from typing import Any, Dict, Optional, Sequence, Set, Tuple, Type +from typing import Any, Dict, Iterable, Optional, Sequence, Set, Tuple, Type import torch import torch.jit as jit @@ -41,8 +41,11 @@ class Acc_Rewriter(ast.NodeTransformer): def __init__(self): super().__init__() self.exceptions_rewritten: Set[Type[Exception]] = set() + self.exceptions_bool_rewritten: Set[Type[Exception]] = set() - def rewrite(self, fn: FunctionType) -> Tuple[FunctionType, Set[Type[Exception]]]: + def rewrite( + self, fn: FunctionType + ) -> Tuple[FunctionType, Set[Type[Exception]], Set[Type[Exception]]]: # Normalize the source lines sourcelines, _ = inspect.getsourcelines(fn) @@ -65,7 +68,7 @@ def rewrite(self, fn: FunctionType) -> Tuple[FunctionType, Set[Type[Exception]]] # Return the correct FunctionType object and the Exceptions that were # rewritten during visit_If. 
- return fn_compiled, self.exceptions_rewritten + return fn_compiled, self.exceptions_rewritten, self.exceptions_bool_rewritten def visit_Assert(self, node: ast.Assert): """ @@ -161,7 +164,23 @@ def _reuse_loc(node): assert isinstance(exc_wrapper_node, ast.Expression) exc_wrapper_call_node = exc_wrapper_node.body assert isinstance(exc_wrapper_call_node, ast.Call) - exc_wrapper_call_node.args = [if_node.test, exc_msg] + if isinstance(if_node.test, ast.BoolOp) and isinstance( + if_node.test.op, ast.And + ): + self.exceptions_bool_rewritten.add(exc_type) + bool_wrapper_node = ast.parse( + f"self.{_get_exception_wrapper_attr_name(exc_type)}_bool()", mode="eval" + ) + assert isinstance(bool_wrapper_node, ast.Expression) + bool_wrapper_call_node = bool_wrapper_node.body + assert isinstance(bool_wrapper_call_node, ast.Call) + bool_wrapper_call_node.args = if_node.test.values + exc_wrapper_call_node.args = [ + _reuse_loc(bool_wrapper_call_node), + exc_msg, + ] + else: + exc_wrapper_call_node.args = [if_node.test, exc_msg] # Ensure that the new node conforms to the Python AST grammar expr_wrapper = _reuse_loc(ast.Expr(_reuse_loc(exc_wrapper_call_node))) @@ -206,6 +225,39 @@ def forward(self, cond: bool, msg: str): raise self.exc if msg is None else self.exc(msg) +class ConditionalExceptionBoolCondWrapper(nn.Module): + """ + This is a wrapper class for boolean ops used inside conditionals + raising exceptions. + This currently only handles binary input cases for the `and` operator + at one level of depth + For example: + + .. code-block:: python + + if self.name != "x" and self.name != "y": + raise AssertionError(f"Name was not x: {self.name}") + + rewrites the `self.name != "x" and self.name != "y"` with + a `_conditional_exception_wrapper_AssertionError_bool` as follows: + + .. 
code-block:: python + + self._conditional_exception_wrapper_AssertionError( + self._conditional_exception_wrapper_AssertionError_bool(self.name != "x", self.name != "y"), f"Name was not x: {self.name}" + ) + """ + + # Mark as impure so that calls to it will not be removed during DCE. + _is_impure = True + + def __init__(self, op): + super().__init__() + + def forward(self, *conds: Iterable): + return all(conds) + + # Custom tracer that traces to the functional level and rewrites asserts and # exceptions. class AccRewritingTracer(Tracer): @@ -217,6 +269,7 @@ class AccRewritingTracer(Tracer): # Note: Treat ConditionalExceptionWrapper as a leaf so that we don't # trace into it, because it contains control flow and raises an exception. DEFAULT_LEAF_MODULE_LIST = { + ConditionalExceptionBoolCondWrapper, ConditionalExceptionWrapper, torch.nn.quantized.Linear, torch.nn.quantized.Conv2d, @@ -318,6 +371,7 @@ def rewrite_module(m: nn.Module): # Acc_Rewriter calls into in this module so we can add them in init # below. all_added_wrappers: Set[Type[Exception]] = set() + all_added_bool_wrappers: Set[Type[Exception]] = set() # Note: Make this a subclass of our base class. 
class RewrittenModule(base_class): # type: ignore[valid-type, misc] @@ -349,8 +403,13 @@ class RewrittenModule(base_class): # type: ignore[valid-type, misc] if base_class not in allow_list: vars()[method_name] = method else: - vars()[method_name], added_wrappers = Acc_Rewriter().rewrite(method) + ( + vars()[method_name], + added_wrappers, + added_bool_wrappers, + ) = Acc_Rewriter().rewrite(method) all_added_wrappers.update(added_wrappers) + all_added_bool_wrappers.update(added_bool_wrappers) def __init__(self, orig): nn.Module.__init__(self) @@ -365,6 +424,16 @@ def __init__(self, orig): wrapper_name, ConditionalExceptionWrapper(exc_type), ) + + for exc_type in all_added_bool_wrappers: + wrapper_name = f"{_get_exception_wrapper_attr_name(exc_type)}_bool" + assert not hasattr(self, wrapper_name) + setattr( + self, + wrapper_name, + ConditionalExceptionBoolCondWrapper(exc_type), + ) + # Recursively rewrite and copy all module attrs of this module. for k, v in orig.__dict__.items(): if k == "_modules": @@ -403,9 +472,12 @@ def _remove_exceptions(gm: torch.fx.GraphModule) -> bool: found in GraphModule gm. Returns whether the graph is modified. 
""" changed = False - for node in gm.graph.nodes: - if node.op == "call_module" and isinstance( - gm.get_submodule(node.target), ConditionalExceptionWrapper + for node in reversed(gm.graph.nodes): + if node.op == "call_module" and ( + isinstance(gm.get_submodule(node.target), ConditionalExceptionWrapper) + or isinstance( + gm.get_submodule(node.target), ConditionalExceptionBoolCondWrapper + ) ): gm.graph.erase_node(node) changed = True diff --git a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_utils.py b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_utils.py index b4856e116f..425fafddac 100644 --- a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_utils.py +++ b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_utils.py @@ -171,7 +171,9 @@ def map_tensor_metadata(a: Any, fn: Callable): Map some `fn` to `a`, where `a` is either a TensorMetadata, or else a tuple/list recursively containing TensorMetadata. """ - if isinstance(a, TensorMetadata): + if isinstance(a, int): + return 1 + elif isinstance(a, TensorMetadata): return fn(a) elif isinstance(a, tuple): return tuple(map_tensor_metadata(elem, fn) for elem in a)