
[FX] Changes done internally at Facebook #1194


Merged: 1 commit, Jul 22, 2022
4 changes: 2 additions & 2 deletions examples/fx/quantized_resnet_test.py
@@ -8,7 +8,7 @@
import torchvision.models as models
from torch.ao.quantization.quantize_fx import (
convert_fx,
convert_to_reference,
convert_to_reference_fx,
prepare_fx,
)
from torch.fx.experimental.normalize import NormalizeArgs
@@ -52,7 +52,7 @@ def build_int8_trt(rn18):
prepared = prepare_fx(rn18, {"": qconfig}, data)
for _ in range(10):
prepared(data)
quantized_rn18 = convert_to_reference(prepared)
quantized_rn18 = convert_to_reference_fx(prepared)
ref_res = quantized_rn18(data)
print("quantized model:", quantized_rn18)

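Note: this appears to track the upstream rename of `convert_to_reference` to `convert_to_reference_fx` in torch.ao.quantization. A minimal sketch of the post-training-quantization flow this test exercises, mirroring the calls in the diff (model choice and calibration loop are illustrative):

```python
import torch
import torchvision.models as models
from torch.ao.quantization import get_default_qconfig
from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_fx

# Build a reference-quantized ResNet-18: insert observers, run calibration
# data through the model, then convert so quant/dequant ops stay explicit.
rn18 = models.resnet18().eval()
data = torch.randn(1, 3, 224, 224)
qconfig = get_default_qconfig("fbgemm")
prepared = prepare_fx(rn18, {"": qconfig}, data)
for _ in range(10):
    prepared(data)
quantized_rn18 = convert_to_reference_fx(prepared)
print(quantized_rn18(data).shape)
```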
2 changes: 1 addition & 1 deletion py/torch_tensorrt/_compile.py
@@ -117,7 +117,7 @@ def compile(module: Any, ir="default", inputs=[], enabled_precisions=set([_enums
else:
raise ValueError(f"Precision {enabled_precisions} not supported on FX")

return lower_to_trt(module, inputs, lower_precision=lower_precision, max_batch_size=inputs[0].size(0), explicit_batch_dimension=True)
return lower_to_trt(module, inputs, lower_precision=lower_precision, max_batch_size=inputs[0].size(0), explicit_batch_dimension=True, dynamic_batch=False)
else:
raise RuntimeError("Module is an unknown format or the ir requested is unknown")

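With `dynamic_batch=False` passed here, the `ir="fx"` path now builds an engine whose batch dimension is pinned to `inputs[0].size(0)`. A hedged usage sketch of that front end; the model is illustrative, and the `torch_tensorrt.compile` arguments are taken from the signature shown above:

```python
import torch
import torchvision.models as models

import torch_tensorrt

model = models.resnet18().eval().cuda()
inputs = [torch.randn(8, 3, 224, 224).cuda()]

# The FX lowering path is called with max_batch_size=inputs[0].size(0) and
# dynamic_batch=False, so the resulting engine expects batch size 8 here.
trt_model = torch_tensorrt.compile(
    model,
    ir="fx",
    inputs=inputs,
    enabled_precisions={torch.float},
)
out = trt_model(*inputs)
```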
47 changes: 39 additions & 8 deletions py/torch_tensorrt/fx/converters/acc_ops_converters.py
@@ -591,9 +591,33 @@ def acc_ops_batch_norm(
)
power = np.ones_like(scale)

# For BatchNorm1d, reshape 1d to 2d
output_shape = input_val.shape
if not network.has_implicit_batch_dimension and len(input_val.shape) < 4:
assert (
len(get_dynamic_dims(input_val.shape)) <= 1
), "BatchNorm1D with more than one dynamic dims is not currently supported."
reshape_layer = network.add_shuffle(input_val)
if len(input_val.shape) == 2:
reshape_layer.reshape_dims = (input_val.shape[0], input_val.shape[1], 1, 1)
else: # len(input_val.shape) == 3
reshape_layer.reshape_dims = (
input_val.shape[0],
input_val.shape[1],
input_val.shape[2],
1,
)
set_layer_name(reshape_layer, target, f"{name}_reshape_2d")
input_val = reshape_layer.get_output(0)
layer = network.add_scale(input_val, trt.ScaleMode.CHANNEL, bias, scale, power)
set_layer_name(layer, target, name)

# For BatchNorm1d, reshape output back to 1d
if not network.has_implicit_batch_dimension and len(output_shape) < 4:
reshape_output_layer = network.add_shuffle(layer.get_output(0))
reshape_output_layer.reshape_dims = tuple(output_shape)
set_layer_name(reshape_output_layer, target, f"{name}_reshape_1d")
layer = reshape_output_layer
return layer.get_output(0)


@@ -614,7 +638,18 @@ def acc_ops_layer_norm(network, target, args, kwargs, name):
eps_field = trt.PluginField(
"eps", np.array([kwargs["eps"]], dtype=np.float32), trt.PluginFieldType.FLOAT32
)
field_collection = trt.PluginFieldCollection([gamma_field, beta_field, eps_field])
try:
normalized_shape = np.array(kwargs["normalized_shape"], dtype=np.int32)
except TypeError:
print("Unable to convert normalized_shape to a field, fall back to []")
normalized_shape = np.array([], dtype=np.int32)

normalized_shape_field = trt.PluginField(
"normalized_shape", normalized_shape, trt.PluginFieldType.INT32
)
field_collection = trt.PluginFieldCollection(
[gamma_field, beta_field, eps_field, normalized_shape_field]
)

try:
if network.has_implicit_batch_dimension:
@@ -2838,11 +2873,7 @@ def num_slice_types(slices):
"""
Count the number of slice and int entries in getitem slices.
"""
num_slice = 0
for s in slices:
if isinstance(s, slice) or isinstance(s, int):
num_slice += 1
return num_slice
return sum(1 for s in slices if isinstance(s, slice) or isinstance(s, int))

def slice_to_trt_params(py_slice, dim_size):
"""
@@ -2878,9 +2909,9 @@ def slice_to_trt_params(py_slice, dim_size):
new_slices = []
for s in slices:
if s == Ellipsis:
while num_ellipsis > 0:
# pass explicit start to guard against negative num_ellipsis
for _ in range(0, num_ellipsis):
new_slices.append(slice(None, None, None))
num_ellipsis -= 1
else:
new_slices.append(s)
slices = new_slices
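The BatchNorm1d handling added above works by shuffling 2-D and 3-D inputs up to 4-D so TensorRT's channel-wise scale layer applies, then shuffling back to the original shape. A PyTorch-level sketch of the same shape round-trip (illustrative only, not the converter code):

```python
import torch

def bn1d_via_4d(x: torch.Tensor, bn2d: torch.nn.BatchNorm2d) -> torch.Tensor:
    """Apply a 2-D batch norm to an (N, C) or (N, C, L) input by padding with
    trailing singleton dims, mirroring the shuffle layers in acc_ops_batch_norm."""
    original_shape = x.shape
    while x.dim() < 4:
        x = x.unsqueeze(-1)  # (N, C) -> (N, C, 1, 1); (N, C, L) -> (N, C, L, 1)
    return bn2d(x).reshape(original_shape)

bn1d = torch.nn.BatchNorm1d(3).eval()
bn2d = torch.nn.BatchNorm2d(3).eval()
x = torch.randn(4, 3, 5)
# Freshly initialized BN modules share the same affine params and running
# stats, so the 4-D round trip matches the native 1-D op.
assert torch.allclose(bn1d_via_4d(x, bn2d), bn1d(x), atol=1e-6)
```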
6 changes: 5 additions & 1 deletion py/torch_tensorrt/fx/converters/converter_utils.py
@@ -41,6 +41,11 @@ def get_trt_plugin(
Returns:
A TensorRT plugin that can be added to TensorRT network as Plugin layer.
"""
# print the registered plugins
# PLUGIN_CREATORS = trt.get_plugin_registry().plugin_creator_list
# for plugin_creator in PLUGIN_CREATORS:
# print(plugin_creator.name)

plugin_registry = trt.get_plugin_registry()
plugin_creator = plugin_registry.get_plugin_creator(
plugin_name, version, plugin_namespace
@@ -214,7 +219,6 @@ def create_constant(

if dtype:
value = value.to(dtype)

constant = network.add_constant(value.shape, to_numpy(value))
constant.name = name
return constant.get_output(0)
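The commented-out registry dump added to `get_trt_plugin` above is a debugging aid: when `get_plugin_creator` finds nothing for a plugin (for example the layer_norm plugin used by the converter changes above), listing the registry shows what is actually available. A runnable version of that snippet, using the standard TensorRT Python API:

```python
import tensorrt as trt

# Dump every plugin creator registered with TensorRT. Handy when
# get_trt_plugin() fails its lookup by (name, version, namespace).
registry = trt.get_plugin_registry()
for plugin_creator in registry.plugin_creator_list:
    print(plugin_creator.name, plugin_creator.plugin_version, plugin_creator.plugin_namespace)
```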
10 changes: 6 additions & 4 deletions py/torch_tensorrt/fx/input_tensor_spec.py
@@ -6,15 +6,17 @@
from .utils import get_dynamic_dims


def generate_input_specs(
inputs, lower_setting, additional_inputs=None, fixed_shape=False
):
def generate_input_specs(inputs, lower_setting, additional_inputs=None):
# The AIT lower setting doesn't have an explicit_batch_dimension field,
# so we just return None.
if not hasattr(lower_setting, "explicit_batch_dimension"):
return None

if not lower_setting.explicit_batch_dimension or fixed_shape:
# dynamic_batch is a TRT-only flag. It does not exist in the AIT lower setting.
if (
not lower_setting.explicit_batch_dimension
or lower_setting.dynamic_batch is False
):
return InputTensorSpec.from_tensors(inputs)

# If we don't have additional inputs, we assume the first dimension
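`generate_input_specs` now keys off `lower_setting.dynamic_batch` instead of a `fixed_shape` argument: when the batch dimension is not explicit or `dynamic_batch` is `False`, it falls back to fixed shapes taken directly from the sample tensors. A rough sketch of the two spec styles; `from_tensors` appears in the diff, while `from_tensors_with_dynamic_batch_size` and its range argument are assumptions about the surrounding fx API:

```python
import torch
from torch_tensorrt.fx.input_tensor_spec import InputTensorSpec

inputs = [torch.randn(8, 3, 224, 224)]

# dynamic_batch=False (or an implicit-batch build): shapes are frozen to the
# sample tensors.
static_specs = InputTensorSpec.from_tensors(inputs)

# dynamic_batch=True with an explicit batch dimension: the batch dim is left
# dynamic and covered by a (min, opt, max) range. Helper name assumed; adjust
# to whatever your torch_tensorrt.fx version exposes.
dynamic_specs = InputTensorSpec.from_tensors_with_dynamic_batch_size(
    inputs, batch_size_range=(1, 8, 16)
)
```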
2 changes: 1 addition & 1 deletion py/torch_tensorrt/fx/lower.py
@@ -36,7 +36,7 @@ def lower_to_trt(
timing_cache_prefix="",
save_timing_cache=False,
cuda_graph_batch_size=-1,
dynamic_batch=False,
dynamic_batch=True,
) -> nn.Module:
"""
Takes in original module, input and lowering setting, run lowering workflow to turn module
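Because `dynamic_batch` now defaults to `True` here (and in `LowerSetting` below), callers that want the old fixed-batch behaviour must opt out explicitly, which is what `_compile.py` above now does. A usage sketch under that assumption, with an illustrative model:

```python
import torch
import torchvision.models as models

from torch_tensorrt.fx.lower import lower_to_trt

model = models.resnet18().eval().cuda()
inputs = [torch.randn(8, 3, 224, 224).cuda()]

# New default: with an explicit batch dimension, the batch dim stays dynamic.
trt_dynamic = lower_to_trt(model, inputs, explicit_batch_dimension=True, max_batch_size=16)

# Previous behaviour: pin the batch dimension, as _compile.py now requests.
trt_fixed = lower_to_trt(
    model,
    inputs,
    explicit_batch_dimension=True,
    max_batch_size=inputs[0].size(0),
    dynamic_batch=False,
)
```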
2 changes: 1 addition & 1 deletion py/torch_tensorrt/fx/lower_setting.py
@@ -86,4 +86,4 @@ class LowerSetting(LowerSettingBasic):
cuda_graph_batch_size: int = -1
preset_lowerer: str = ""
opt_profile_replica: int = 1
dynamic_batch: bool = False
dynamic_batch: bool = True
8 changes: 8 additions & 0 deletions py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py
@@ -1,3 +1,4 @@
import datetime
from functools import partial, wraps
from typing import Any, Callable, Optional, Sequence

@@ -142,6 +143,9 @@ def lower_func(split_result: SplitResult) -> nn.Module:

# Only acc submodules will be lowered.
if not submod_name.startswith(split_result.non_acc_submodule_prefix):
print("Now lowering submodule", submod_name)
lowering_start_time = datetime.datetime.now()

self.lower_setting.input_specs = generate_input_specs(
submod_inputs,
self.lower_setting,
@@ -156,6 +160,10 @@ def lower_func(split_result: SplitResult) -> nn.Module:
LOWER_SPLIT_POST_OBSERVER.observe(
submod_name, lowered_module, submod_inputs
)
print(
f"Lowering submodule {submod_name} elapsed time",
datetime.datetime.now() - lowering_start_time,
)

return split_result.split_module

12 changes: 8 additions & 4 deletions py/torch_tensorrt/fx/test/converters/acc_op/test_as_strided.py
@@ -3,7 +3,9 @@
import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec
from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase

# from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec


class TestConverter(AccTestCase):
@@ -30,7 +32,9 @@ def forward(self, x):
test_implicit_batch_dim=False,
)

# Testing with shape (-1, -1, -1, -1) results in the error: RuntimeError: setStorage: sizes [2, 3], strides [1, 2], storage offset 0, and itemsize 8 requiring a storage size of 48 are out of bounds for storage of size 1
# Testing with shape (-1, 3) results in the error:
# RuntimeError: setStorage: sizes [2, 3], strides [1, 2], storage offset 0, and itemsize 8 requiring a storage size of 48 are out of bounds for storage of size 16

"""
def test_as_strided_with_dynamic_shape_four_dimensions(self):
class Stride(nn.Module):
@@ -39,9 +43,9 @@ def forward(self, x):
input_specs = [
InputTensorSpec(
shape=(-1, -1, -1, -1),
shape=(-1, 3),
dtype=torch.float32,
shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (5, 5, 5, 5))],
shape_ranges=[((1, 3), (2, 3), (2, 3))],
),
]
74 changes: 42 additions & 32 deletions py/torch_tensorrt/fx/test/converters/acc_op/test_avgpool.py
@@ -39,6 +39,48 @@ def forward(self, x):
inputs = [torch.randn(1, 3, 224)]
self.run_test(TestModule(), inputs, expected_ops={acc_ops.avg_pool1d})

@parameterized.expand(
[
("default", 1),
("kernal_size", 3),
("stride", 1, 2),
("tuple_parameters", 2, (1,), (1,)),
param("padding", 2, padding=1),
param("ceil_mode", 1, ceil_mode=True),
param("include_pad", 2, padding=1, count_include_pad=False),
]
)
def test_avg_pool1d_with_dynamic_shape(
self,
test_name="default",
kernel_size=1,
stride=1,
padding=0,
ceil_mode=False,
count_include_pad=True,
):
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.avg_pool = torch.nn.AvgPool1d(
kernel_size, stride, padding, ceil_mode, count_include_pad
)

def forward(self, x):
return self.avg_pool(x)

input_specs = [
InputTensorSpec(
shape=(-1, 3, 3),
dtype=torch.float32,
shape_ranges=[((1, 3, 3), (3, 3, 3), (3, 3, 3))],
),
]

self.run_test_with_dynamic_shape(
TestModule(), input_specs, expected_ops={acc_ops.avg_pool1d}
)

def test_avg_pool2d_with_dynamic_shape_four_dimensions(
self,
test_name="default",
@@ -218,38 +260,6 @@ def forward(self, x):
TestModule(), input_specs, expected_ops={acc_ops.avg_pool2d}
)

# Testing with (-1, -1, -1, -1) results in error: RuntimeError: ShapeProp error for: node=%avg_pool1d : [#users=1] = call_function[target=torch.avg_pool1d](args = (%x, (1,), (1,), (0,), False, True), kwargs = {}) with meta={}
"""
def test_avg_pool1d_with_dynamic_shape_four_dimensions(
self,
test_name="default",
kernel_size=1,
stride=1,
padding=0,
ceil_mode=False,
count_include_pad=True,
):
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.avg_pool = torch.nn.AvgPool1d(
kernel_size, stride, padding, ceil_mode, count_include_pad
)
def forward(self, x):
return self.avg_pool(x)
input_specs = [
InputTensorSpec(
shape=(-1, -1, -1, -1),
dtype=torch.float32,
shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (5, 5, 5, 5))],
),
]
self.run_test(TestModule(), input_specs, expected_ops={acc_ops.avg_pool1d})
"""


if __name__ == "__main__":
run_tests()
21 changes: 21 additions & 0 deletions py/torch_tensorrt/fx/test/converters/acc_op/test_batchnorm.py
@@ -17,6 +17,27 @@ def forward(self, x):
inputs = [torch.randn(1, 3, 224, 224)]
self.run_test(TestModule(), inputs, expected_ops={acc_ops.batch_norm})

def test_batchnorm1d_with_dynamic_shape(self):
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.bn = torch.nn.BatchNorm1d(3)

def forward(self, x):
return self.bn(x)

input_specs = [
InputTensorSpec(
shape=(-1, 3, 5),
dtype=torch.float32,
shape_ranges=[((2, 3, 5), (6, 3, 5), (10, 3, 5))],
),
]

self.run_test_with_dynamic_shape(
TestModule(), input_specs, expected_ops={acc_ops.batch_norm}
)

def test_batchnorm_with_dynamic_shape(self):
class TestModule(torch.nn.Module):
def __init__(self):
@@ -13,6 +13,8 @@
elementwise_ops = [
((lambda x, y: x + y), acc_ops.add, NEED_TEST_BOTH_CONSTANTS_CASE),
((lambda x, y: x - y), acc_ops.sub, NEED_TEST_BOTH_CONSTANTS_CASE),
((lambda x, y: torch.sub(x, y)), acc_ops.sub, False),
((lambda x, y: x.sub(y)), acc_ops.sub, False),
((lambda x, y: x / y), acc_ops.div, NEED_TEST_BOTH_CONSTANTS_CASE),
((lambda x, y: x // y), acc_ops.floor_div, NEED_TEST_BOTH_CONSTANTS_CASE),
(
9 changes: 4 additions & 5 deletions py/torch_tensorrt/fx/test/converters/acc_op/test_clamp.py
@@ -2,7 +2,7 @@
import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
from parameterized import param, parameterized
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase
from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec


class TestClampConverter(AccTestCase):
@@ -27,8 +27,6 @@ def forward(self, x):
inputs = [torch.randn(3, 4)]
self.run_test(TestModule(), inputs, expected_ops={acc_ops.clamp})

# Error: RuntimeError: ShapeProp error for: node=%clamp : [#users=1] = call_function[target=torch.clamp](args = (%x, 1, 0), kwargs = {}) with meta={}
"""
@parameterized.expand(
[
param("default", min=-1, max=0),
@@ -55,8 +53,9 @@ def forward(self, x):
),
]

self.run_test(TestModule(), input_specs, expected_ops={acc_ops.clamp})
"""
self.run_test_with_dynamic_shape(
TestModule(), input_specs, expected_ops={acc_ops.clamp}
)


if __name__ == "__main__":