
Commit f423be1

apbose authored and narendasan committed
Converter reorg and softmax operation
softmax linting error fix
1 parent 9611d67 commit f423be1

File tree

5 files changed: +117 -31 lines changed

py/torch_tensorrt/fx/converters/acc_ops_converters.py

Lines changed: 2 additions & 31 deletions
@@ -28,6 +28,7 @@
 from torch_tensorrt.fx.tracer.acc_tracer.acc_ops import contiguous
 from torch_tensorrt.fx.converters.impl import activation
 from torch_tensorrt.fx.converters.impl.elementwise import trunc_div
+from torch_tensorrt.fx.converters.impl.normalization import softmax
 from torch_tensorrt.fx.converters.impl.unary import sign
 from torch_tensorrt.fx.converters.impl.elementwise.base import (
     convert_binary_elementwise,
@@ -858,37 +859,7 @@ def acc_ops_softmax(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    input_val = kwargs["input"]
-    input_ranks = len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0)  # type: ignore[union-attr]
-
-    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"softmax received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
-
-    # Used to get dim when dim is None. Copied from PyTorch softmax implementation.
-    def get_softmax_dim(ndim: int) -> int:
-        if ndim == 0 or ndim == 1 or ndim == 3:
-            ret = 0
-        else:
-            ret = 1
-        return ret
-
-    if kwargs["dim"] is None:
-        dim = get_softmax_dim(input_ranks)
-    else:
-        dim = cast(int, kwargs["dim"])
-
-    dim = get_positive_dim(dim, input_ranks)
-    if network.has_implicit_batch_dimension:
-        assert dim != 0, "Can't apply softmax on batch dimension when it's implicit."
-        dim -= 1
-
-    layer = network.add_softmax(input_val)
-    layer.axes = 1 << dim
-    set_layer_name(layer, target, name)
-    return layer.get_output(0)
+    return softmax(network, target, SourceIR.ACC, name, kwargs["input"], kwargs["dim"])


 @tensorrt_converter(acc_ops.tile)

py/torch_tensorrt/fx/converters/aten_ops_converters.py

Lines changed: 12 additions & 0 deletions
@@ -24,6 +24,7 @@
 import torch_tensorrt.fx.tracer.acc_tracer.acc_utils as acc_utils
 from torch_tensorrt.fx.converters.impl import activation
 from torch_tensorrt.fx.converters.impl.elementwise import trunc_div
+from torch_tensorrt.fx.converters.impl.normalization import softmax

 _LOGGER: logging.Logger = logging.getLogger(__name__)

@@ -352,6 +353,17 @@ def aten_ops_reshape(
     return layer.get_output(0)


+@tensorrt_converter(torch.ops.aten._softmax.default)
+def aten_ops_softmax(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return softmax(network, target, SourceIR.ATEN, name, args[0], args[1])
+
+
 @tensorrt_converter(torch.ops.aten.cat.default)
 def aten_ops_cat(
     network: TRTNetwork,
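The new ATen converter forwards args[0] and args[1] straight to the shared helper. For orientation, torch.ops.aten._softmax.default follows the schema _softmax(Tensor self, int dim, bool half_to_float) -> Tensor, so those positions are the input tensor and the softmax dimension. A minimal standalone check of that mapping (not part of the commit):

    import torch

    # _softmax(Tensor self, int dim, bool half_to_float) -> Tensor:
    # args[0] is the input, args[1] is the dim the converter passes through.
    x = torch.randn(2, 4)
    y = torch.ops.aten._softmax.default(x, 1, False)
    print(torch.allclose(y, torch.softmax(x, dim=1)))  # True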
Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+from .ops import *
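This one-line file (its path is not shown in this view, but it reads like a package __init__) presumably re-exports the new ops module so that from torch_tensorrt.fx.converters.impl.normalization import softmax, used by both converter files above, resolves.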
Lines changed: 58 additions & 0 deletions
@@ -0,0 +1,58 @@
+import operator
+import warnings
+from typing import cast, Union, Callable, Any, Optional, Sequence
+
+import numpy as np
+
+# @manual=//deeplearning/trt/python:py_tensorrt
+import tensorrt as trt
+import torch
+from torch.fx.node import Target
+
+from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
+
+from torch_tensorrt.fx.converters.converter_utils import (
+    SourceIR,
+    set_layer_name,
+    get_positive_dim,
+)
+
+
+def softmax(
+    network: TRTNetwork,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+    dim: Optional[Any] = None,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    input_ranks = len(input.shape) + (1 if network.has_implicit_batch_dimension else 0)  # type: ignore[union-attr]
+
+    if not isinstance(input, TRTTensor):
+        raise RuntimeError(
+            f"softmax received input {input} that is not part "
+            "of the TensorRT region!"
+        )
+
+    # Used to get dim when dim is None. Copied from PyTorch softmax implementation.
+    def get_softmax_dim(ndim: int) -> int:
+        if ndim == 0 or ndim == 1 or ndim == 3:
+            ret = 0
+        else:
+            ret = 1
+        return ret
+
+    if dim is None:
+        dim = get_softmax_dim(input_ranks)
+    else:
+        dim = cast(int, dim)
+
+    dim = get_positive_dim(dim, input_ranks)
+    if network.has_implicit_batch_dimension:
+        assert dim != 0, "Can't apply softmax on batch dimension when it's implicit."
+        dim -= 1
+
+    layer = network.add_softmax(input)
+    layer.axes = 1 << dim
+    set_layer_name(layer, target, name)
+    return layer.get_output(0)
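The helper selects the softmax axis through TensorRT's axes bitmask: after get_softmax_dim supplies a default when dim is None and the implicit-batch branch shifts dim down by one, 1 << dim sets the bit for the chosen dimension. A small standalone sketch of that arithmetic (no TensorRT required; the shape is illustrative):

    def get_softmax_dim(ndim: int) -> int:
        # Same default-dim rule the helper copies from PyTorch.
        return 0 if ndim in (0, 1, 3) else 1

    rank = 4                     # e.g. an explicit-batch input of shape (1, 3, 224, 224)
    dim = get_softmax_dim(rank)  # dim was None, so the default rule picks 1
    axes = 1 << dim              # ISoftMaxLayer.axes is a bitmask: bit 1 set -> 0b10 == 2
    print(dim, axes)             # 1 2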
Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
+import torch
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec
+
+
+class TestSoftMaxConverter(DispatchTestCase):
+    def test_softmax(self):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.softmax = torch.nn.Softmax(1)
+
+            def forward(self, x):
+                return self.softmax(x)
+
+        inputs = [torch.randn(1, 3, 224, 224)]
+        self.run_test(
+            TestModule(), inputs, expected_ops={torch.ops.aten._softmax.default}
+        )
+
+    def test_softmax_with_dynamic_shape(self):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.softmax = torch.nn.Softmax(2)
+
+            def forward(self, x):
+                return self.softmax(x)
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, 3, -1, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 3, 1, 1), (1, 3, 5, 5), (2, 3, 10, 10))],
+            ),
+        ]
+
+        self.run_test_with_dynamic_shape(
+            TestModule(), input_specs, expected_ops={torch.ops.aten._softmax.default}
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
