Commit 377547e

Author: Wei
Changes done internally at Facebook (#1194)
6703b98dff0695d91026f057b951dba1355825fa Shreyansh Prajapati <[email protected]> Test dynamic shape support for acc_ops.prod
c822345d6d673e1653c2208435e34ab400bada3d Jason Park <[email protected]> Add support for generic torch ops to be used in training.
e5758602a0592d6c2b71d6d66a0398c4dd9b5e20 Shreyansh Prajapati <[email protected]> Test dynamic shape support for repeat interleave
c13c633f04df162500eed477c0569eb2b81eb070 Shreyansh Prajapati <[email protected]> Test dynamic shape support for reduce ops
863476cf43b210922b88585b8f196dd84fbebb56 Shreyansh Prajapati <[email protected]> Test dynamic shape support for acc_op.convolution
68dff39793e5c30c20010919a855bb3d984015d7 Ruichao Xiao <[email protected]> [fbcode][GPU][DHEN] fuse split squeeze cat as reshape
f8b920769507ebd2ff02419b4aece25451298a95 Ruichao Xiao <[email protected]> [fbcode][DHEN][GPU] reorder and merge cats whose input is a sublist of another cat
5b6a8d2d6be979983a52ac96225fefb510c3817c Andrew Or <[email protected]> [Quant][fx] Rename convert_to_reference to convert_to_reference_fx
996a0e080b8a8bc0b292a7c2ac92f41f6db33a2e Shreyansh Prajapati <[email protected]> Test dynamic shape support for acc_op.expand
084631fe74b304fbb9481ca15fd452a3714fb1b8 Shreyansh Prajapati <[email protected]> Test dynamic shape support for acc_op.to_dtype
b3195e76329ccddbb5c4640cfa884d0e457d2d34 Shreyansh Prajapati <[email protected]> Test dynamic shape support for std
a5d964e62bdf769cf8c2e67321138b33e1f524a7 Shreyansh Prajapati <[email protected]> Test dynamic shape support for acc_op.tile
3d33d45b2fc7f10f25c22946ba474b227e4b6529 Shreyansh Prajapati <[email protected]> Test dynamic shape support for squeeze
09085abf63d7e7732e2cd66e600e8afc6d58964f Shreyansh Prajapati <[email protected]> Test dynamic shape support for acc_op.topk
65edc7ea12899e9bd2af42c890a64de853d9b7fe Huamin Li <[email protected]> temporarily skip gelu tests
d11e521f9b90554ca86912a49920afa4406bb40d Shirong Wu <[email protected]> Suppress accuracy check for remove_reshape_with_batch_size_change
6d948298b2327d229e010a34f1c221b11d2eb504 Ankur Singla <[email protected]> [GPULowering] Suppress accuracy check for fuse_unsqueeze_cat_sum
e780b647fc9571b77d9f41c963041a6ac3d66f33 Janet Yang <[email protected]> Lower xrayvideo2022 to fx2trt
433c7207fef16b1fdff985546ea969c39fa83e7c generatedunixname89002005287564 <[email protected]> [Codemod][Remove @noautodeps and @autodeps-skip tags] deeplearning/trt 1/2
66fdb65cffa925660c77b4758388399db3cbfe48 Scott Wolchok <[email protected]> [fx2ait] Minor Python cleanup in acc_ops_getitem
188132ecb2c19bcbf83cb2dc381f6e3798629f87 generatedunixname89002005324833 <[email protected]> [AutoAccept][Codemod][FBSourceBuckFormatLinter] Daily `arc lint --take BUCKFORMAT`
4536bae4686dd01f2149541ea7fb330e178a4969 Wei Wei <[email protected]> [fx2trt] support sub
064602e666f86c110d931cd90a8536112a19b4ad Shreyansh Prajapati <[email protected]> Test dynamic shape support for acc_ops.interpolate
9dfd0ee0cecb1975e3f53c44de237d67ca443ec5 Shreyansh Prajapati <[email protected]> Test dynamic shape support for unary_ops
39b9efad8d5d82463a2016d135c0cf277de1c3c6 Shreyansh Prajapati <[email protected]> Test dynamic shape support for unsqueeze
2bb17667d1dabc95391950426fc1f921eb3d0959 Shreyansh Prajapati <[email protected]> Test dynamic shape support for acc_ops.split
64dfb7b096686cb2fd33197340dc72f30d525456 Shirong Wu <[email protected]> Group LN trt plugin
438f670e28df59b0734baa092a514fba3d75eb4f Shreyansh Prajapati <[email protected]> Test dynamic shape support for acc_ops.avgpool
df0fe32dae4343827bd9b37b72daae761b02f228 Shreyansh Prajapati <[email protected]> Test dynamic shape support for acc_ops masked fill
44fe735d3493ea2d05a56b49093e4a23dd63a98e Shreyansh Prajapati <[email protected]> Test dynamic shape support for acc_ops.pad
4f931acca706d8ce79045ceafef2ea0486609149 Wei Wei <[email protected]> [fx2trt] torch.max dynamic shape test
bf6f6cbe217d26a95ca9122574adf7de3966db9e Shreyansh Prajapati <[email protected]> Change the name of the test from full_reduce to dim_reduce
1c5680ed107d9206f3514eff4069a3f6c870ba8c Shreyansh Prajapati <[email protected]> Test dynamic shape support for acc_ops.type_as
33e4c175a4f5fec78ac0b1c8eb262ca777c7aaba Shreyansh Prajapati <[email protected]> Test dynamic shape support for acc_ops.min
f37be34bcef9716080b8bafbd1f4ad72e412c44c Wei Wei <[email protected]> [fx2trt] plugin for grid_sample
57b5cc6a0f4839686ae360361a3a13b424794ee7 generatedunixname89002005367269 <[email protected]> [AutoAccept][Codemod][FBSourceBlackLinter] Daily `arc lint --take BLACK`
eb741cc5e5a7babdc94e72d411670905f54da3e0 Shreyansh Prajapati <[email protected]> Updated the dynamic shape support for narrow op
521c36b96a14741ae89d7af6cbb658120bcec2ea Shreyansh Prajapati <[email protected]> Removing the comment for 4 dims dynamic shape support after analysis
e947343375967fe9efb0a16fdb9f63bff1449328 Shreyansh Prajapati <[email protected]> Updated the pad test for dynamic batch for analysis
3d64087014e91bc301a315eae43683b1aa2b66bc Oleg Khabinov <[email protected]> [trt_bc] Some improvements
dfd937a56fa01aca88a89b46176befdac4c202c4 Shreyansh Prajapati <[email protected]> Updated the test for as_strided op for analysis
11d76d0420dcaa4bb8890dcdeb86b6e534af831c Bangsheng Tang <[email protected]> [gpu][infer] replace fx2trt_layer_norm with fbgemm layer_norm
932046ff6ea6dead114c0222b23ca3854690cffa Wei Wei <[email protected]> [fx2trt] bridge the dynamic batch and fixed shape
1 parent 2b224b2 commit 377547e

38 files changed (+900 / −127 lines)

examples/fx/quantized_resnet_test.py

Lines changed: 2 additions & 2 deletions
@@ -8,7 +8,7 @@
 import torchvision.models as models
 from torch.ao.quantization.quantize_fx import (
     convert_fx,
-    convert_to_reference,
+    convert_to_reference_fx,
     prepare_fx,
 )
 from torch.fx.experimental.normalize import NormalizeArgs
@@ -52,7 +52,7 @@ def build_int8_trt(rn18):
     prepared = prepare_fx(rn18, {"": qconfig}, data)
     for _ in range(10):
         prepared(data)
-    quantized_rn18 = convert_to_reference(prepared)
+    quantized_rn18 = convert_to_reference_fx(prepared)
    ref_res = quantized_rn18(data)
    print("quantized model:", quantized_rn18)

py/torch_tensorrt/_compile.py

Lines changed: 1 addition & 1 deletion
@@ -117,7 +117,7 @@ def compile(module: Any, ir="default", inputs=[], enabled_precisions=set([_enums
         else:
             raise ValueError(f"Precision {enabled_precisions} not supported on FX")
 
-        return lower_to_trt(module, inputs, lower_precision=lower_precision, max_batch_size=inputs[0].size(0), explicit_batch_dimension=True)
+        return lower_to_trt(module, inputs, lower_precision=lower_precision, max_batch_size=inputs[0].size(0), explicit_batch_dimension=True, dynamic_batch=False)
     else:
         raise RuntimeError("Module is an unknown format or the ir requested is unknown")
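With this change the generic compile() entry point pins dynamic_batch=False when it hands off to the FX lowering path, so existing callers keep a fixed batch dimension even though lower_to_trt's own default flips to True later in this commit. A minimal usage sketch, assuming the FX frontend is selected with ir="fx" and using an illustrative torchvision model (neither is part of this diff):

    import torch
    import torchvision.models as models
    import torch_tensorrt

    # Illustrative module and sample input; any traceable nn.Module would do.
    model = models.resnet18().eval().cuda()
    inputs = [torch.randn(1, 3, 224, 224).cuda()]

    # compile() dispatches to lower_to_trt() for the FX ir and, after this change,
    # keeps the batch dimension fixed (dynamic_batch=False) on that path.
    trt_model = torch_tensorrt.compile(
        model, ir="fx", inputs=inputs, enabled_precisions={torch.float}
    )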

py/torch_tensorrt/fx/converters/acc_ops_converters.py

Lines changed: 39 additions & 8 deletions
@@ -591,9 +591,33 @@ def acc_ops_batch_norm(
     )
     power = np.ones_like(scale)
 
+    # For BatchNorm1d, reshape 1d to 2d
+    output_shape = input_val.shape
+    if not network.has_implicit_batch_dimension and len(input_val.shape) < 4:
+        assert (
+            len(get_dynamic_dims(input_val.shape)) <= 1
+        ), "BatchNorm1D with more than one dynamic dims is not currently supported."
+        reshape_layer = network.add_shuffle(input_val)
+        if len(input_val.shape) == 2:
+            reshape_layer.reshape_dims = (input_val.shape[0], input_val.shape[1], 1, 1)
+        else:  # len(input_val.shape) == 3
+            reshape_layer.reshape_dims = (
+                input_val.shape[0],
+                input_val.shape[1],
+                input_val.shape[2],
+                1,
+            )
+        set_layer_name(reshape_layer, target, f"{name}_reshape_2d")
+        input_val = reshape_layer.get_output(0)
     layer = network.add_scale(input_val, trt.ScaleMode.CHANNEL, bias, scale, power)
     set_layer_name(layer, target, name)
 
+    # For BatchNorm1d, reshape output back to 1d
+    if not network.has_implicit_batch_dimension and len(output_shape) < 4:
+        reshape_output_layer = network.add_shuffle(layer.get_output(0))
+        reshape_output_layer.reshape_dims = tuple(output_shape)
+        set_layer_name(reshape_output_layer, target, f"{name}_reshape_1d")
+        layer = reshape_output_layer
     return layer.get_output(0)
 
 
@@ -614,7 +638,18 @@ def acc_ops_layer_norm(network, target, args, kwargs, name):
     eps_field = trt.PluginField(
         "eps", np.array([kwargs["eps"]], dtype=np.float32), trt.PluginFieldType.FLOAT32
     )
-    field_collection = trt.PluginFieldCollection([gamma_field, beta_field, eps_field])
+    try:
+        normalized_shape = np.array(kwargs["normalized_shape"], dtype=np.int32)
+    except TypeError:
+        print("Unable to convert normalized_shape to a field, fall back to []")
+        normalized_shape = np.array([], dtype=np.int32)
+
+    normalized_shape_filed = trt.PluginField(
+        "normalized_shape", normalized_shape, trt.PluginFieldType.INT32
+    )
+    field_collection = trt.PluginFieldCollection(
+        [gamma_field, beta_field, eps_field, normalized_shape_filed]
+    )
 
     try:
         if network.has_implicit_batch_dimension:
@@ -2838,11 +2873,7 @@ def num_slice_types(slices):
     """
     Gather the number of slice in getitem slices.
     """
-    num_slice = 0
-    for s in slices:
-        if isinstance(s, slice) or isinstance(s, int):
-            num_slice += 1
-    return num_slice
+    return sum(1 for s in slices if isinstance(s, slice) or isinstance(s, int))
 
 
 def slice_to_trt_params(py_slice, dim_size):
     """
@@ -2878,9 +2909,9 @@ def slice_to_trt_params(py_slice, dim_size):
     new_slices = []
     for s in slices:
         if s == Ellipsis:
-            while num_ellipsis > 0:
+            # pass explicit start to guard against negative num_ellipsis
+            for _ in range(0, num_ellipsis):
                 new_slices.append(slice(None, None, None))
-                num_ellipsis -= 1
         else:
             new_slices.append(s)
     slices = new_slices
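The acc_ops_batch_norm change above lets BatchNorm1d-style inputs (2D or 3D) pass through a converter that applies per-channel scaling to a 4D-style tensor: the input is temporarily padded with trailing unit dimensions, scaled, then restored to its original shape. A rough PyTorch-level sketch of that shape bookkeeping (illustrative only, not the converter code):

    import torch

    x = torch.randn(8, 3, 5)          # (N, C, L): a typical BatchNorm1d input
    output_shape = x.shape
    x4d = x.reshape(8, 3, 5, 1)       # pad to (N, C, H, W) so per-channel scaling applies
    # ... per-channel scale/shift (the TensorRT scale layer) happens on the 4D view ...
    y = x4d.reshape(output_shape)     # reshape back to the original 1d layout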

py/torch_tensorrt/fx/converters/converter_utils.py

Lines changed: 5 additions & 1 deletion
@@ -41,6 +41,11 @@ def get_trt_plugin(
     Returns:
         A TensorRT plugin that can be added to TensorRT network as Plugin layer.
     """
+    # print the registered plugins
+    # PLUGIN_CREATORS = trt.get_plugin_registry().plugin_creator_list
+    # for plugin_creator in PLUGIN_CREATORS:
+    #     print(plugin_creator.name)
+
     plugin_registry = trt.get_plugin_registry()
     plugin_creator = plugin_registry.get_plugin_creator(
         plugin_name, version, plugin_namespace
@@ -214,7 +219,6 @@ def create_constant(
 
     if dtype:
         value = value.to(dtype)
-
     constant = network.add_constant(value.shape, to_numpy(value))
     constant.name = name
     return constant.get_output(0)

py/torch_tensorrt/fx/input_tensor_spec.py

Lines changed: 6 additions & 4 deletions
@@ -6,15 +6,17 @@
 from .utils import get_dynamic_dims
 
 
-def generate_input_specs(
-    inputs, lower_setting, additional_inputs=None, fixed_shape=False
-):
+def generate_input_specs(inputs, lower_setting, additional_inputs=None):
     # AIT lower setting doesn't have explicit_batch_dimension field and
     # we just return None.
     if not hasattr(lower_setting, "explicit_batch_dimension"):
         return None
 
-    if not lower_setting.explicit_batch_dimension or fixed_shape:
+    # dynamic_batch is TRT only flag. It does not exist in AIT lower setting
+    if (
+        not lower_setting.explicit_batch_dimension
+        or lower_setting.dynamic_batch is False
+    ):
         return InputTensorSpec.from_tensors(inputs)
 
     # If we don't have additional inputs, we assume the first dimension
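In other words, when the batch dimension is implicit or dynamic_batch is off, generate_input_specs falls back to fixed-shape specs taken directly from the sample tensors; otherwise it produces specs with a dynamic batch dimension like the ones the new tests in this commit build by hand. A small sketch of the two flavors (the concrete shapes are illustrative assumptions):

    import torch
    from torch_tensorrt.fx.input_tensor_spec import InputTensorSpec

    sample = torch.randn(2, 3, 5)

    # Fixed-shape specs: one spec per tensor, shapes copied from the samples.
    fixed_specs = InputTensorSpec.from_tensors([sample])

    # Dynamic-batch spec: -1 marks the batch dim, shape_ranges gives (min, opt, max) shapes.
    dynamic_spec = InputTensorSpec(
        shape=(-1, 3, 5),
        dtype=torch.float32,
        shape_ranges=[((1, 3, 5), (2, 3, 5), (4, 3, 5))],
    )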

py/torch_tensorrt/fx/lower.py

Lines changed: 1 addition & 1 deletion
@@ -36,7 +36,7 @@ def lower_to_trt(
     timing_cache_prefix="",
     save_timing_cache=False,
     cuda_graph_batch_size=-1,
-    dynamic_batch=False,
+    dynamic_batch=True,
 ) -> nn.Module:
     """
     Takes in original module, input and lowering setting, run lowering workflow to turn module

py/torch_tensorrt/fx/lower_setting.py

Lines changed: 1 addition & 1 deletion
@@ -86,4 +86,4 @@ class LowerSetting(LowerSettingBasic):
     cuda_graph_batch_size: int = -1
     preset_lowerer: str = ""
     opt_profile_replica: int = 1
-    dynamic_batch: bool = False
+    dynamic_batch: bool = True
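With dynamic_batch now defaulting to True both here and in lower_to_trt, existing callers get a dynamic batch dimension unless they opt out explicitly. A minimal opt-out sketch (the module and input are placeholders; max_batch_size is shown only because the _compile.py path above also passes it):

    import torch
    import torch.nn as nn
    from torch_tensorrt.fx.lower import lower_to_trt

    model = nn.Sequential(nn.Linear(5, 5), nn.ReLU()).cuda().eval()
    inputs = [torch.randn(2, 5).cuda()]

    # dynamic_batch defaults to True after this change; pass False to keep the
    # previous fixed-batch behavior.
    trt_module = lower_to_trt(model, inputs, max_batch_size=2, dynamic_batch=False)
    out = trt_module(*inputs)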

py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py

Lines changed: 8 additions & 0 deletions
@@ -1,3 +1,4 @@
+import datetime
 from functools import partial, wraps
 from typing import Any, Callable, Optional, Sequence
 
@@ -142,6 +143,9 @@ def lower_func(split_result: SplitResult) -> nn.Module:
 
         # Only acc submodules will be lowered.
         if not submod_name.startswith(split_result.non_acc_submodule_prefix):
+            print("Now lowering submodule", submod_name)
+            lowering_start_time = datetime.datetime.now()
+
             self.lower_setting.input_specs = generate_input_specs(
                 submod_inputs,
                 self.lower_setting,
@@ -156,6 +160,10 @@ def lower_func(split_result: SplitResult) -> nn.Module:
             LOWER_SPLIT_POST_OBSERVER.observe(
                 submod_name, lowered_module, submod_inputs
            )
+            print(
+                f"Lowering submodule {submod_name} elapsed time",
+                datetime.datetime.now() - lowering_start_time,
+            )
 
     return split_result.split_module

py/torch_tensorrt/fx/test/converters/acc_op/test_as_strided.py

Lines changed: 8 additions & 4 deletions
@@ -3,7 +3,9 @@
 import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec
+from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase
+
+# from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec
 
 
 class TestConverter(AccTestCase):
@@ -30,7 +32,9 @@ def forward(self, x):
             test_implicit_batch_dim=False,
         )
 
-    # Testing with shape (-1, -1, -1, -1) results into error: RuntimeError: setStorage: sizes [2, 3], strides [1, 2], storage offset 0, and itemsize 8 requiring a storage size of 48 are out of bounds for storage of size 1
+    # Testing with shape (-1, 3) results into error:
+    # RuntimeError: setStorage: sizes [2, 3], strides [1, 2], storage offset 0, and itemsize 8 requiring a storage size of 48 are out of bounds for storage of size 16
+
     """
     def test_as_strided_with_dynamic_shape_four_dimensions(self):
         class Stride(nn.Module):
@@ -39,9 +43,9 @@ def forward(self, x):
 
             input_specs = [
                 InputTensorSpec(
-                    shape=(-1, -1, -1, -1),
+                    shape=(-1, 3),
                     dtype=torch.float32,
-                    shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (5, 5, 5, 5))],
+                    shape_ranges=[((1, 3), (2, 3), (2, 3))],
                 ),
             ]

py/torch_tensorrt/fx/test/converters/acc_op/test_avgpool.py

Lines changed: 42 additions & 32 deletions
@@ -39,6 +39,48 @@ def forward(self, x):
         inputs = [torch.randn(1, 3, 224)]
         self.run_test(TestModule(), inputs, expected_ops={acc_ops.avg_pool1d})
 
+    @parameterized.expand(
+        [
+            ("default", 1),
+            ("kernal_size", 3),
+            ("stride", 1, 2),
+            ("tuple_parameters", 2, (1,), (1,)),
+            param("padding", 2, padding=1),
+            param("ceil_mode", 1, ceil_mode=True),
+            param("include_pad", 2, padding=1, count_include_pad=False),
+        ]
+    )
+    def test_avg_pool1d_with_dynamic_shape(
+        self,
+        test_name="default",
+        kernel_size=1,
+        stride=1,
+        padding=0,
+        ceil_mode=False,
+        count_include_pad=True,
+    ):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.avg_pool = torch.nn.AvgPool1d(
+                    kernel_size, stride, padding, ceil_mode, count_include_pad
+                )
+
+            def forward(self, x):
+                return self.avg_pool(x)
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, 3, 3),
+                dtype=torch.float32,
+                shape_ranges=[((1, 3, 3), (3, 3, 3), (3, 3, 3))],
+            ),
+        ]
+
+        self.run_test_with_dynamic_shape(
+            TestModule(), input_specs, expected_ops={acc_ops.avg_pool1d}
+        )
+
     def test_avg_pool2d_with_dynamic_shape_four_dimensions(
         self,
         test_name="default",
@@ -218,38 +260,6 @@ def forward(self, x):
             TestModule(), input_specs, expected_ops={acc_ops.avg_pool2d}
         )
 
-    # Testing with (-1, -1, -1, -1) results in error: RuntimeError: ShapeProp error for: node=%avg_pool1d : [#users=1] = call_function[target=torch.avg_pool1d](args = (%x, (1,), (1,), (0,), False, True), kwargs = {}) with meta={}
-    """
-    def test_avg_pool1d_with_dynamic_shape_four_dimensions(
-        self,
-        test_name="default",
-        kernel_size=1,
-        stride=1,
-        padding=0,
-        ceil_mode=False,
-        count_include_pad=True,
-    ):
-        class TestModule(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.avg_pool = torch.nn.AvgPool1d(
-                    kernel_size, stride, padding, ceil_mode, count_include_pad
-                )
-
-            def forward(self, x):
-                return self.avg_pool(x)
-
-        input_specs = [
-            InputTensorSpec(
-                shape=(-1, -1, -1, -1),
-                dtype=torch.float32,
-                shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (5, 5, 5, 5))],
-            ),
-        ]
-
-        self.run_test(TestModule(), input_specs, expected_ops={acc_ops.avg_pool1d})
-    """
-
 
 if __name__ == "__main__":
     run_tests()

py/torch_tensorrt/fx/test/converters/acc_op/test_batchnorm.py

Lines changed: 21 additions & 0 deletions
@@ -17,6 +17,27 @@ def forward(self, x):
         inputs = [torch.randn(1, 3, 224, 224)]
         self.run_test(TestModule(), inputs, expected_ops={acc_ops.batch_norm})
 
+    def test_batchnorm1d_with_dynamic_shape(self):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.bn = torch.nn.BatchNorm1d(3)
+
+            def forward(self, x):
+                return self.bn(x)
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, 3, 5),
+                dtype=torch.float32,
+                shape_ranges=[((2, 3, 5), (6, 3, 5), (10, 3, 5))],
+            ),
+        ]
+
+        self.run_test_with_dynamic_shape(
+            TestModule(), input_specs, expected_ops={acc_ops.batch_norm}
+        )
+
     def test_batchnorm_with_dynamic_shape(self):
         class TestModule(torch.nn.Module):
             def __init__(self):

py/torch_tensorrt/fx/test/converters/acc_op/test_binary_ops.py

Lines changed: 2 additions & 0 deletions
@@ -13,6 +13,8 @@
 elementwise_ops = [
     ((lambda x, y: x + y), acc_ops.add, NEED_TEST_BOTH_CONSTANTS_CASE),
     ((lambda x, y: x - y), acc_ops.sub, NEED_TEST_BOTH_CONSTANTS_CASE),
+    ((lambda x, y: torch.sub(x, y)), acc_ops.sub, False),
+    ((lambda x, y: x.sub(y)), acc_ops.sub, False),
     ((lambda x, y: x / y), acc_ops.div, NEED_TEST_BOTH_CONSTANTS_CASE),
     ((lambda x, y: x // y), acc_ops.floor_div, NEED_TEST_BOTH_CONSTANTS_CASE),
     (

py/torch_tensorrt/fx/test/converters/acc_op/test_clamp.py

Lines changed: 4 additions & 5 deletions
@@ -2,7 +2,7 @@
 import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
 from parameterized import param, parameterized
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase
+from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec
 
 
 class TestClampConverter(AccTestCase):
@@ -27,8 +27,6 @@ def forward(self, x):
         inputs = [torch.randn(3, 4)]
         self.run_test(TestModule(), inputs, expected_ops={acc_ops.clamp})
 
-    # Error: RuntimeError: ShapeProp error for: node=%clamp : [#users=1] = call_function[target=torch.clamp](args = (%x, 1, 0), kwargs = {}) with meta={}
-    """
     @parameterized.expand(
         [
            param("default", min=-1, max=0),
@@ -55,8 +53,9 @@ def forward(self, x):
             ),
         ]
 
-        self.run_test(TestModule(), input_specs, expected_ops={acc_ops.clamp})
-    """
+        self.run_test_with_dynamic_shape(
+            TestModule(), input_specs, expected_ops={acc_ops.clamp}
+        )
if __name__ == "__main__":
