
Commit 75b1a2a

apbose authored and narendasan committed
Converter reorg and softmax operation
softmax linting error fix
1 parent 546f975 commit 75b1a2a

File tree

5 files changed: +117 -31 lines changed


py/torch_tensorrt/fx/converters/acc_ops_converters.py

Lines changed: 2 additions & 31 deletions
@@ -28,6 +28,7 @@
 from torch_tensorrt.fx.tracer.acc_tracer.acc_ops import contiguous
 from torch_tensorrt.fx.converters.impl import activation
 from torch_tensorrt.fx.converters.impl.elementwise import trunc_div
+from torch_tensorrt.fx.converters.impl.normalization import softmax
 from torch_tensorrt.fx.converters.impl.unary import sign
 from torch_tensorrt.fx.converters.impl.elementwise.base import (
     convert_binary_elementwise,
@@ -861,37 +862,7 @@ def acc_ops_softmax(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    input_val = kwargs["input"]
-    input_ranks = len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0)  # type: ignore[union-attr]
-
-    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"softmax received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
-
-    # Used to get dim when dim is None. Copied from PyTorch softmax implementation.
-    def get_softmax_dim(ndim: int) -> int:
-        if ndim == 0 or ndim == 1 or ndim == 3:
-            ret = 0
-        else:
-            ret = 1
-        return ret
-
-    if kwargs["dim"] is None:
-        dim = get_softmax_dim(input_ranks)
-    else:
-        dim = cast(int, kwargs["dim"])
-
-    dim = get_positive_dim(dim, input_ranks)
-    if network.has_implicit_batch_dimension:
-        assert dim != 0, "Can't apply softmax on batch dimension when it's implicit."
-        dim -= 1
-
-    layer = network.add_softmax(input_val)
-    layer.axes = 1 << dim
-    set_layer_name(layer, target, name)
-    return layer.get_output(0)
+    return softmax(network, target, SourceIR.ACC, name, kwargs["input"], kwargs["dim"])


 @tensorrt_converter(acc_ops.tile)
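
The converter body removed above is not dropped: it reappears verbatim in the new shared impl module further down, and acc_ops_softmax now reduces to a single call into it. One piece worth spelling out is the default-dim rule copied from PyTorch's softmax implementation: with dim=None, softmax runs over dimension 0 for 0-D, 1-D, and 3-D inputs and over dimension 1 otherwise. A standalone sketch of that rule (demo only, not part of the commit):

    def get_softmax_dim(ndim: int) -> int:
        # Default softmax dimension when no dim is given, mirroring the helper above.
        return 0 if ndim in (0, 1, 3) else 1

    for ndim in range(5):
        print(ndim, "->", get_softmax_dim(ndim))  # 0->0, 1->0, 2->1, 3->0, 4->1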

py/torch_tensorrt/fx/converters/aten_ops_converters.py

Lines changed: 12 additions & 0 deletions
@@ -25,6 +25,7 @@
 from torch_tensorrt.fx.converters.impl import activation
 from torch_tensorrt.fx.converters.impl.elementwise import trunc_div
 from torch_tensorrt.fx.converters.impl.elementwise import rsqrt
+from torch_tensorrt.fx.converters.impl.normalization import softmax

 _LOGGER: logging.Logger = logging.getLogger(__name__)

@@ -389,6 +390,17 @@ def aten_ops_reshape(
     return layer.get_output(0)


+@tensorrt_converter(torch.ops.aten._softmax.default)
+def aten_ops_softmax(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return softmax(network, target, SourceIR.ATEN, name, args[0], args[1])
+
+
 @tensorrt_converter(torch.ops.aten.cat.default)
 def aten_ops_cat(
     network: TRTNetwork,
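
The new ATen converter pulls the tensor from args[0] and the dimension from args[1], which matches the positional schema of aten._softmax (input, dim, half_to_float). A small self-contained check of that assumption (demo only, not part of the commit):

    import torch

    x = torch.randn(1, 3, 4)
    # aten._softmax.default takes (input, dim, half_to_float) positionally.
    out = torch.ops.aten._softmax.default(x, 1, False)
    assert torch.allclose(out, torch.nn.functional.softmax(x, dim=1))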
py/torch_tensorrt/fx/converters/impl/normalization/__init__.py (new file; path inferred from the import added above)

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+from .ops import *
py/torch_tensorrt/fx/converters/impl/normalization/ops.py (new file; path inferred from the import added above)

Lines changed: 58 additions & 0 deletions

@@ -0,0 +1,58 @@
+import operator
+import warnings
+from typing import cast, Union, Callable, Any, Optional, Sequence
+
+import numpy as np
+
+# @manual=//deeplearning/trt/python:py_tensorrt
+import tensorrt as trt
+import torch
+from torch.fx.node import Target
+
+from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
+
+from torch_tensorrt.fx.converters.converter_utils import (
+    SourceIR,
+    set_layer_name,
+    get_positive_dim,
+)
+
+
+def softmax(
+    network: TRTNetwork,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+    dim: Optional[Any] = None,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    input_ranks = len(input.shape) + (1 if network.has_implicit_batch_dimension else 0)  # type: ignore[union-attr]
+
+    if not isinstance(input, TRTTensor):
+        raise RuntimeError(
+            f"softmax received input {input} that is not part "
+            "of the TensorRT region!"
+        )
+
+    # Used to get dim when dim is None. Copied from PyTorch softmax implementation.
+    def get_softmax_dim(ndim: int) -> int:
+        if ndim == 0 or ndim == 1 or ndim == 3:
+            ret = 0
+        else:
+            ret = 1
+        return ret
+
+    if dim is None:
+        dim = get_softmax_dim(input_ranks)
+    else:
+        dim = cast(int, dim)
+
+    dim = get_positive_dim(dim, input_ranks)
+    if network.has_implicit_batch_dimension:
+        assert dim != 0, "Can't apply softmax on batch dimension when it's implicit."
+        dim -= 1
+
+    layer = network.add_softmax(input)
+    layer.axes = 1 << dim
+    set_layer_name(layer, target, name)
+    return layer.get_output(0)
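
Two details of the shared implementation are easy to miss. TensorRT's softmax layer selects its reduction axis through a bit mask, which is why the code sets layer.axes = 1 << dim, and with an implicit batch dimension the batch axis is not addressable, so dim is shifted down by one after the batch-dim assertion. A minimal sketch of just that arithmetic (demo only, not part of the commit):

    def softmax_axes_mask(dim: int, implicit_batch: bool) -> int:
        # Mirrors the logic above: optionally drop the implicit batch dim,
        # then build a one-hot bit mask selecting the softmax axis.
        if implicit_batch:
            assert dim != 0, "Can't apply softmax on batch dimension when it's implicit."
            dim -= 1
        return 1 << dim

    print(softmax_axes_mask(1, implicit_batch=False))  # 2, i.e. 0b10 -> axis 1
    print(softmax_axes_mask(1, implicit_batch=True))   # 1, i.e. 0b01 -> axis 0 after the shift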
New test file for the ATen softmax converter (path not shown in this view)

Lines changed: 44 additions & 0 deletions

@@ -0,0 +1,44 @@
+import torch
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec
+
+
+class TestSoftMaxConverter(DispatchTestCase):
+    def test_softmax(self):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.softmax = torch.nn.Softmax(1)
+
+            def forward(self, x):
+                return self.softmax(x)
+
+        inputs = [torch.randn(1, 3, 224, 224)]
+        self.run_test(
+            TestModule(), inputs, expected_ops={torch.ops.aten._softmax.default}
+        )
+
+    def test_softmax_with_dynamic_shape(self):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.softmax = torch.nn.Softmax(2)
+
+            def forward(self, x):
+                return self.softmax(x)
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, 3, -1, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 3, 1, 1), (1, 3, 5, 5), (2, 3, 10, 10))],
+            ),
+        ]
+
+        self.run_test_with_dynamic_shape(
+            TestModule(), input_specs, expected_ops={torch.ops.aten._softmax.default}
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
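
Since the module ends with run_tests(), the file can be run directly as a script or collected by a test runner. A hypothetical extra case in the same style (not part of the commit, using only the APIs shown above) would cover an explicit negative dim:

    def test_softmax_negative_dim(self):
        class TestModule(torch.nn.Module):
            def forward(self, x):
                return torch.nn.functional.softmax(x, dim=-1)

        self.run_test(
            TestModule(), [torch.randn(2, 4, 8)], expected_ops={torch.ops.aten._softmax.default}
        )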
