[FX] aten2trt and some pass fixes #1390

Merged: 1 commit, Oct 5, 2022
17 changes: 17 additions & 0 deletions py/torch_tensorrt/fx/README.md
@@ -2,3 +2,20 @@ FX2TRT is merged as FX module in Torch-TensorRT

- The user guide is in [link](../../../docsrc/tutorials/getting_started_with_fx_path.rst#installation)
- The examples are moved to [link](../../../examples/fx)

* Method 1. Follow the installation instructions for Torch-TensorRT.
* Method 2. Install the FX path only (Python-only path), avoiding the C++ build for the TorchScript path:
```
$ conda create --name python_env python=3.8
$ conda activate python_env
# PyTorch 1.12 or later is recommended
$ conda install pytorch torchvision torchtext cudatoolkit=11.3 -c pytorch-nightly
# Install TensorRT python package
$ pip3 install nvidia-pyindex
$ pip3 install nvidia-tensorrt==8.2.4.2
$ git clone https://github.com/pytorch/TensorRT.git
$ cd TensorRT/py && python setup.py install --fx-only && cd ..
$ python -c "import torch_tensorrt.fx"
# Test an example by running:
$ python py/torch_tensorrt/fx/example/lower_example.py
```
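
After the FX-only install, a quick end-to-end check can be run from Python. The following is a minimal sketch, assuming the `compile` helper exposed by `torch_tensorrt.fx.lower`; the toy model, shapes, and tolerances are illustrative, and a CUDA device plus the TensorRT wheel installed above are required:

```python
import torch
from torch_tensorrt.fx.lower import compile  # assumed FX-path lowering entry point

# Toy model and input; any symbolically traceable eager-mode module works.
model = torch.nn.Sequential(torch.nn.Linear(16, 32), torch.nn.ReLU()).cuda().eval()
inputs = [torch.randn(8, 16, device="cuda")]

# Lower through the FX path and sanity-check against the eager output.
trt_model = compile(model, inputs, max_batch_size=8)
torch.testing.assert_close(
    trt_model(*inputs).float(), model(*inputs), rtol=1e-2, atol=1e-2
)
```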
1 change: 1 addition & 0 deletions py/torch_tensorrt/fx/converters/__init__.py
@@ -13,6 +13,7 @@
from .transformation import * # noqa: F401 F403
from .quantization import * # noqa: F401 F403
from .acc_ops_converters import * # noqa: F401 F403
from .aten_ops_converters import * # noqa: F401 F403

TRT_LOGGER = trt.Logger()
trt.init_libnvinfer_plugins(TRT_LOGGER, "")
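
With this import in place, the aten converters register themselves through the same `@tensorrt_converter` decorator the acc_ops converters use. A minimal sketch of the registry pattern, with a hypothetical custom target (illustrative only, not part of this PR):

```python
import tensorrt as trt
from torch_tensorrt.fx.converter_registry import tensorrt_converter

# Hypothetical target: the registry keys converters by callable, and the new
# aten_ops_converters module registers torch.ops.aten targets the same way.
def my_custom_add(x, y):
    return x + y

@tensorrt_converter(my_custom_add)
def my_custom_add_converter(network, target, args, kwargs, name):
    # At conversion time args hold TRT ITensors; emit an elementwise SUM layer.
    layer = network.add_elementwise(args[0], args[1], trt.ElementWiseOperation.SUM)
    layer.name = name
    return layer.get_output(0)
```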
104 changes: 87 additions & 17 deletions py/torch_tensorrt/fx/converters/acc_ops_converters.py
@@ -21,11 +21,63 @@
from ..utils import get_dynamic_dims, torch_dtype_from_trt, torch_dtype_to_trt

from .converter_utils import * # noqa: F403

from torch_tensorrt.fx.passes.lower_basic_pass import (
trt_transposed_linear,
trt_transposed_matmul,
)

_LOGGER: logging.Logger = logging.getLogger(__name__)


@tensorrt_converter(trt_transposed_matmul)
def trt_transposed_matmul_converter(network, target, args, kwargs, name):
lhs, rhs, lhs_transposed, rhs_transposed = args

if isinstance(lhs, torch.nn.Parameter):
lhs = get_trt_tensor(network, lhs, f"{name}_lhs")
if isinstance(rhs, torch.nn.Parameter):
rhs = get_trt_tensor(network, rhs, f"{name}_rhs")
layer = network.add_matrix_multiply(
lhs,
trt.MatrixOperation.TRANSPOSE if lhs_transposed else trt.MatrixOperation.NONE,
rhs,
trt.MatrixOperation.TRANSPOSE if rhs_transposed else trt.MatrixOperation.NONE,
)
set_layer_name(layer, target, name)
return layer.get_output(0)


@tensorrt_converter(trt_transposed_linear)
def trt_transposed_linear_converter(network, target, args, kwargs, name):
input, weight, bias = args

weight = get_trt_tensor(network, weight.t(), f"{name}_weight")
bias = get_trt_tensor(network, bias.reshape(1, -1), f"{name}_bias")

input, weight = broadcast(
network,
input,
weight,
f"{input.name}_broadcast",
f"{weight.name}_broadcast",
)
layer = network.add_matrix_multiply(
input,
trt.MatrixOperation.TRANSPOSE,
weight,
trt.MatrixOperation.NONE,
)
set_layer_name(layer, target, f"{name}_mm")
return add_binary_elementwise_layer(
network,
layer.get_output(0),
bias,
trt.ElementWiseOperation.SUM,
target,
f"{name}_add",
)
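
For context, `trt_transposed_matmul` and `trt_transposed_linear` are fusion targets produced by `lower_basic_pass`: the pass folds adjacent transpose nodes into the matmul so the converters above can emit a single matrix-multiply layer with `trt.MatrixOperation.TRANSPOSE` flags. A sketch of the matmul target's eager semantics (paraphrased, not necessarily the verbatim source):

```python
import torch

# Sketch of the fused target: transposition travels as boolean flags rather
# than separate graph nodes, which is what lets the converter above map it
# onto trt.MatrixOperation.TRANSPOSE directly.
def transposed_matmul_reference(
    lhs: torch.Tensor, rhs: torch.Tensor, lhs_transposed: bool, rhs_transposed: bool
) -> torch.Tensor:
    if lhs_transposed:
        lhs = lhs.transpose(-1, -2)
    if rhs_transposed:
        rhs = rhs.transpose(-1, -2)
    return torch.matmul(lhs, rhs)

# e.g. transposed_matmul_reference(a, b, True, False) == a.transpose(-1, -2) @ b
```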


@tensorrt_converter(acc_ops.conv1d)
def acc_ops_conv1d(
network: TRTNetwork,
@@ -1975,7 +2027,10 @@ def acc_ops_max_poolnd(
f"MaxPool2d received input {input_val} that is not part "
"of the TensorRT region!"
)
-    extend_len = 2 if target == acc_ops.max_pool2d else 3
+    if target not in (acc_ops.max_pool2d, acc_ops.max_pool3d):
+        extend_len = 2 if len(kwargs["kernel_size"]) == 2 else 3
+    else:
+        extend_len = 2 if target == acc_ops.max_pool2d else 3
kernel_size = extend_attr_to_tuple(kwargs["kernel_size"], extend_len)
stride = extend_attr_to_tuple(kwargs["stride"], extend_len)
padding = extend_attr_to_tuple(kwargs["padding"], extend_len)
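
The new branch serves the aten2trt path: an aten-level pooling node does not arrive as `acc_ops.max_pool2d`/`max_pool3d`, so the pooling rank is inferred from the kernel_size argument rather than the target symbol. A toy illustration with hypothetical traced values:

```python
# Hypothetical kwargs from an aten-level max_pool node: kernel_size is an
# explicit sequence, so its length determines whether pooling is 2-D or 3-D.
kwargs = {"kernel_size": [3, 3, 3]}
extend_len = 2 if len(kwargs["kernel_size"]) == 2 else 3
assert extend_len == 3  # 3-D pooling
```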
@@ -2259,8 +2314,11 @@ def acc_ops_adaptive_avg_poolnd(
f"AdaptiveAvgPool2d received input {input_val} that is not part "
"of the TensorRT region!"
)
+    if target not in (acc_ops.adaptive_avg_pool3d, acc_ops.adaptive_avg_pool2d):
+        extend_len = 2 if len(kwargs["output_size"]) == 2 else 3
+    else:
+        extend_len = 2 if target == acc_ops.adaptive_avg_pool2d else 3

-    extend_len = 2 if target == acc_ops.adaptive_avg_pool2d else 3
assert all(
input_val.shape[-(i + 1)] != -1 for i in range(extend_len)
), "AdaptiveAvgPool2d and AdaptiveAvgPool3d currently doesn't support dynamic shapes for last two dims."
@@ -2747,7 +2805,10 @@ def acc_ops_linear(

if isinstance(kwargs["weight"], torch.Tensor):
weight = get_trt_tensor(network, kwargs["weight"].t(), f"{name}_weight")
-        weight_op = trt.MatrixOperation.NONE
+        if target is not acc_ops.linear:
+            weight_op = trt.MatrixOperation.TRANSPOSE
+        else:
+            weight_op = trt.MatrixOperation.NONE
else:
assert isinstance(
kwargs["weight"], TRTTensor
@@ -2782,17 +2843,26 @@ def acc_ops_linear(
return res


-def add_clamp(network, input, val, op):
-    acc_ops_clamp_shape = (1,) * len(input.shape)  # broadcast all dimensions
-    acc_ops_clamp_tensor = (
-        val
-        * torch.ones(acc_ops_clamp_shape, dtype=torch_dtype_from_trt(input.dtype))
-        .cpu()
-        .numpy()
-    )
-    acc_ops_clamp_trt = network.add_constant(acc_ops_clamp_shape, acc_ops_clamp_tensor)
-    layer = network.add_elementwise(input, acc_ops_clamp_trt.get_output(0), op)
+def add_clamp(network, input, val, op, name):
+    if not len(input.shape):
+        # clamping scalar
+        acc_ops_clamp_trt = get_trt_tensor(
+            network,
+            squeeze_left(torch.tensor([val], dtype=torch_dtype_from_trt(input.dtype))),
+            f"{name}_clamp_{val}",
+        )
+    else:
+        acc_ops_clamp_shape = (1,) * len(input.shape)  # broadcast all dimensions
+        acc_ops_clamp_tensor = (
+            val
+            * torch.ones(acc_ops_clamp_shape, dtype=torch_dtype_from_trt(input.dtype))
+            .cpu()
+            .numpy()
+        )
+        acc_ops_clamp_trt = network.add_constant(
+            acc_ops_clamp_shape, acc_ops_clamp_tensor
+        ).get_output(0)
+    layer = network.add_elementwise(input, acc_ops_clamp_trt, op)
     return layer
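
A note on the scalar branch above: a 0-d input has shape `()`, so the old `(1,) * len(input.shape)` recipe produced an empty broadcast shape that could not be paired with a rank-1 clamp constant. A small PyTorch-only sketch of the rank bookkeeping, with `squeeze_left` emulated inline (this mirrors the intent, not the converter internals):

```python
import torch

# A 0-d ("scalar") input has shape (), so (1,) * len(input.shape) collapses
# to the empty tuple and a rank-1 constant cannot broadcast against it.
scalar_input = torch.tensor(3.0)                  # shape: torch.Size([])
assert (1,) * len(scalar_input.shape) == ()

# squeeze_left strips leading size-1 dims so the clamp constant's rank
# matches the scalar input's rank of zero.
clamp_const = torch.tensor([1.0])                 # shape: (1,)
while clamp_const.dim() > 0 and clamp_const.shape[0] == 1:
    clamp_const = clamp_const.squeeze(0)
assert clamp_const.shape == scalar_input.shape == torch.Size([])
```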


@@ -2816,13 +2886,13 @@ def acc_ops_clamp(

if min_val is not None:
clamp_min_layer = add_clamp(
-            network, input_val, min_val, trt.ElementWiseOperation.MAX
+            network, input_val, min_val, trt.ElementWiseOperation.MAX, name
)
set_layer_name(clamp_min_layer, target, f"{name}_clamp_min")
input_val = clamp_min_layer.get_output(0)
if max_val is not None:
clamp_max_layer = add_clamp(
-            network, input_val, max_val, trt.ElementWiseOperation.MIN
+            network, input_val, max_val, trt.ElementWiseOperation.MIN, name
)
set_layer_name(clamp_max_layer, target, f"{name}_clamp_max")
input_val = clamp_max_layer.get_output(0)