-from typing import Optional
+from typing import Optional, Union
 
 import tensorrt as trt
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
+from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 from torch_tensorrt.dynamo.conversion.converter_utils import (
     cast_trt_tensor,
+    flatten_dims,
     get_axes_for_reduce_op,
 )
 from torch_tensorrt.fx.converters.converter_utils import (
     get_positive_dim,
     set_layer_name,
 )
-from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
+from torch_tensorrt.fx.types import TRTTensor
 
 from . import squeeze
 
 
 def argmax(
-    network: TRTNetwork,
+    ctx: ConversionContext,
     target: Target,
     source_ir: Optional[SourceIR],
     name: str,
     input: TRTTensor,
-    dim: int = 0,
+    dim: Union[int, None],
     keep_dim: bool = False,
 ) -> TRTTensor:
     if not isinstance(input, TRTTensor):
         raise RuntimeError(
             f"argmax received input {input} that is not part " "of the TensorRT region!"
         )
+
     if input.dtype == trt.int32:
-        input = cast_trt_tensor(network, input, trt.float32, name)
-    if dim < 0:
+        input = cast_trt_tensor(ctx, input, trt.float32, name)
+
+    # Three different cases here:
+    # 1. dim is None: flatten the input tensor first and reduce over the flattened view;
+    #    the output shape is rebuilt from keep_dim at the end
+    # 2. input rank == 1: the TopK layer does not support a 1-dimensional topk operation, so broadcast the input to rank 2
+    # 3. normal case: no additional handling needed
+    out = input
+
+    if dim is None:
+        shuffle_layer = ctx.net.add_shuffle(input)
+        shuffle_layer.reshape_dims = (*flatten_dims(input, 0, -1), 1)
+        set_layer_name(shuffle_layer, target, name + "_flatten")
+        out = shuffle_layer.get_output(0)
+    elif len(input.shape) == 1:
+        shuffle_layer = ctx.net.add_shuffle(input)
+        shuffle_layer.reshape_dims = (*input.shape, 1)
+        set_layer_name(shuffle_layer, target, name + "_broadcast")
+        out = shuffle_layer.get_output(0)
+    elif dim < 0:
         dim = len(tuple(input.shape)) + dim
-    reduce_mask = get_axes_for_reduce_op(get_positive_dim(dim, len(input.shape)))
-    topk_layer = network.add_topk(input, trt.TopKOperation.MAX, 1, reduce_mask)
+
+    reduce_mask = get_axes_for_reduce_op(0)
+    if dim is not None:
+        reduce_mask = get_axes_for_reduce_op(get_positive_dim(dim, len(out.shape)))
+
+    topk_layer = ctx.net.add_topk(out, trt.TopKOperation.MAX, 1, reduce_mask)
     set_layer_name(topk_layer, target, name)
 
     out = topk_layer.get_output(1)
 
-    if not keep_dim:
+    if dim is None:
+        out_shuffle_layer = ctx.net.add_shuffle(out)
+        out_shuffle_layer.reshape_dims = (1,) * len(input.shape) if keep_dim else ()
+        set_layer_name(out_shuffle_layer, target, name + "_broadcast")
+        out = out_shuffle_layer.get_output(0)
+    elif len(input.shape) == 1:
         out = squeeze.squeeze(
-            network, target, SourceIR.ATEN, name + "_squeeze", out, dim
+            ctx,
+            target,
+            SourceIR.ATEN,
+            name + "_squeeze",
+            out,
+            1 if keep_dim else [0, 1],
         )
+    elif not keep_dim:
+        out = squeeze.squeeze(ctx, target, SourceIR.ATEN, name + "_squeeze", out, dim)
 
     return out
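
For reference, the three cases called out in the comment of the new code correspond to standard torch.argmax eager semantics roughly as sketched below. This is an illustrative sketch only, not part of the diff, and the example tensors are made up:

import torch

x = torch.tensor([[1.0, 5.0, 2.0],
                  [4.0, 3.0, 6.0]])

# Case 3 (normal): reduce over an explicit dimension.
torch.argmax(x, dim=1)                 # tensor([1, 2])
torch.argmax(x, dim=1, keepdim=True)   # tensor([[1], [2]]), shape (2, 1)

# Case 1: dim=None reduces over the flattened tensor, which is why the
# converter flattens the input before the TopK reduction.
torch.argmax(x)                        # tensor(5): index into the flattened view

# Case 2: 1-D input; per the comment in the diff, the TopK layer does not
# support a 1-D topk, hence the broadcast to rank 2 before reducing.
torch.argmax(torch.tensor([3.0, 7.0, 1.0]))  # tensor(1)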