55from torch .testing ._internal .common_utils import run_tests
66from torch_tensorrt import Input
77
8- from . harness import DispatchTestCase
8+ from harness import DispatchTestCase
99
1010
1111class TestIndexConverter (DispatchTestCase ):
@@ -26,6 +26,21 @@ def forward(self, x):
2626 TestModule (),
2727 input ,
2828 )
29+
30+ def test_index_zero_two_dim_ITensor (self ):
31+ class TestModule (nn .Module ):
32+ def forward (self , x , index0 ):
33+ indices = [None , index0 ]
34+ out = torch .ops .aten .index .Tensor (x , indices )
35+ return out
36+
37+ input = torch .randn (2 , 2 )
38+ index0 = torch .randint (0 , 1 , (1 , 1 ))
39+ index0 = index0 .to (torch .int32 )
40+ self .run_test (
41+ TestModule (),
42+ [input , index0 ],
43+ )
2944
3045 def test_index_zero_index_three_dim (self ):
3146 class TestModule (nn .Module ):
@@ -43,6 +58,21 @@ def forward(self, x):
4358 TestModule (),
4459 input ,
4560 )
61+
62+ def test_index_zero_index_three_dim_ITensor (self ):
63+ class TestModule (nn .Module ):
64+ def forward (self , x , index0 ):
65+ indices = [None , index0 , None ]
66+ out = torch .ops .aten .index .Tensor (x , indices )
67+ return out
68+
69+ input = torch .randn (2 , 2 , 2 )
70+ index0 = torch .randint (0 , 1 , (1 , 1 ))
71+ index0 = index0 .to (torch .int32 )
72+ self .run_test (
73+ TestModule (),
74+ [input , index0 ]
75+ )
4676
4777 def test_index_zero_index_one_index_two_three_dim (self ):
4878 class TestModule (nn .Module ):
0 commit comments