
Commit 0691d05

Adding the broadcast check and corresponding tests
1 parent c94114a commit 0691d05

2 files changed: +58 -0 lines changed


py/torch_tensorrt/dynamo/conversion/impl/select.py

Lines changed: 10 additions & 0 deletions
@@ -5,6 +5,7 @@
 import tensorrt as trt
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
+from torch_tensorrt.dynamo.conversion.converter_utils import broadcastable
 from torch_tensorrt.dynamo.conversion.impl.elementwise import convert_binary_elementwise
 from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape
 from torch_tensorrt.fx.converters.converter_utils import (
@@ -84,12 +85,21 @@ def index(
     # check if the input is dynamic
     dynamic_shape = has_dynamic_shape(input.shape)
 
+    # here we need to check if all the indices are broadcastable
+    # if not, we need to broadcast them
+
+    last_index = None
+    broadcast_shape_len = 0
     for i, ind in enumerate(index):
         if ind is not None:
             _LOGGER.debug(f"Shape of {i} index is {ind.shape}")
             adv_indx_indices.append(i)
             # torch.nn.parameter.Parameter=> torch.Tensor
             ind = get_trt_tensor(network, ind, f"parameter_to_fp32_tensor_{i}")
+            if last_index is not None:
+                if not broadcastable(ind, last_index):
+                    raise AssertionError("The indices should be broadcastable")
+            last_index = ind
             tensor_indices.append(ind)
 
     if not tensor_indices:

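A note on the added guard: broadcastable (imported above from converter_utils) is used to verify that each index tensor can be broadcast against the previous one before the gather is built. As a rough illustration of the rule such a check enforces, shapes are compared from the trailing dimension and every pair of sizes must either match or contain a 1. The sketch below is illustrative only, not the torch_tensorrt implementation; shapes_broadcastable is a hypothetical helper name.

# Illustrative sketch only -- not torch_tensorrt's broadcastable(); it shows the
# usual trailing-dimension broadcast rule the added guard is meant to enforce.
from typing import Sequence

def shapes_broadcastable(shape_a: Sequence[int], shape_b: Sequence[int]) -> bool:
    # Compare from the last dimension backwards; a pair of sizes is compatible
    # when they are equal or one of them is 1.
    for dim_a, dim_b in zip(reversed(shape_a), reversed(shape_b)):
        if dim_a != dim_b and dim_a != 1 and dim_b != 1:
            return False
    return True

print(shapes_broadcastable((16,), (16, 1)))  # True  -- e.g. index0 vs. index0.unsqueeze(0).T
print(shapes_broadcastable((16,), (3, 5)))   # False -- 16 vs. 5 cannot be broadcast
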
tests/py/dynamo/conversion/test_index_aten.py

Lines changed: 48 additions & 0 deletions
@@ -86,6 +86,54 @@ def forward(self, x):
             expected_ops={torch.ops.aten.index.Tensor, operator.getitem},
         )
 
+    def test_index_zero_index_one_SD(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                index0 = torch.tensor([0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7])
+                index1 = torch.tensor([0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7])
+                indices = [None, index0, index1, None]
+                out = torch.ops.aten.index.Tensor(x, indices)
+                return out
+
+        input = [torch.randn(2, 1280, 8, 8)]
+        self.run_test(
+            TestModule(),
+            input,
+            expected_ops={torch.ops.aten.index.Tensor, operator.getitem},
+        )
+
+    def test_index_zero_index_one_SD_unsqueeze(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                index0 = torch.tensor([0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7])
+                index1 = index0.unsqueeze(0).T.long()
+                indices = [None, None, index1, index1]
+                out = torch.ops.aten.index.Tensor(x, indices)
+                return out
+
+        input = [torch.randn(2, 1280, 8, 8)]
+        self.run_test(
+            TestModule(),
+            input,
+            expected_ops={torch.ops.aten.index.Tensor},
+        )
+
+    def test_index_zero_index_one_index_two_SD_unsqueeze(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                index0 = torch.tensor([0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7])
+                index1 = index0.unsqueeze(0).T.long()
+                indices = [None, None, index0, index1]
+                out = torch.ops.aten.index.Tensor(x, indices)
+                return out
+
+        input = [torch.randn(2, 1280, 8, 8)]
+        self.run_test(
+            TestModule(),
+            input,
+            expected_ops={torch.ops.aten.index.Tensor},
+        )
+
 
 if __name__ == "__main__":
     run_tests()

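For reference, the None entries in the index list passed to torch.ops.aten.index.Tensor behave like full slices (:) in ordinary advanced indexing, and the index tensors are broadcast together before gathering, which is the behavior the converter's new check mirrors. A standalone eager-mode sanity check along the lines of the first new test (not part of the committed test file) would be:

# Mirrors test_index_zero_index_one_SD in eager mode; standalone illustration only.
import torch

x = torch.randn(2, 1280, 8, 8)
index0 = torch.tensor([0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7])
index1 = torch.tensor([0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7])

# A None entry leaves that dimension untouched, like `:` in advanced indexing.
out_op = torch.ops.aten.index.Tensor(x, [None, index0, index1, None])
out_idx = x[:, index0, index1, :]

print(out_op.shape)                  # torch.Size([2, 16, 8])
print(torch.equal(out_op, out_idx))  # True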