22
33import torch
44import torch .nn as nn
5+ from harness import DispatchTestCase
56from torch .testing ._internal .common_utils import run_tests
67from torch_tensorrt import Input
78
8- from .harness import DispatchTestCase
9-
109
1110class TestIndexConverter (DispatchTestCase ):
1211 def test_index_zero_two_dim (self ):
@@ -27,6 +26,21 @@ def forward(self, x):
2726 input ,
2827 )
2928
29+ def test_index_zero_two_dim_ITensor (self ):
30+ class TestModule (nn .Module ):
31+ def forward (self , x , index0 ):
32+ indices = [None , index0 ]
33+ out = torch .ops .aten .index .Tensor (x , indices )
34+ return out
35+
36+ input = torch .randn (2 , 2 )
37+ index0 = torch .randint (0 , 1 , (1 , 1 ))
38+ index0 = index0 .to (torch .int32 )
39+ self .run_test (
40+ TestModule (),
41+ [input , index0 ],
42+ )
43+
3044 def test_index_zero_index_three_dim (self ):
3145 class TestModule (nn .Module ):
3246 def __init__ (self ):
@@ -44,6 +58,18 @@ def forward(self, x):
4458 input ,
4559 )
4660
61+ def test_index_zero_index_three_dim_ITensor (self ):
62+ class TestModule (nn .Module ):
63+ def forward (self , x , index0 ):
64+ indices = [None , index0 , None ]
65+ out = torch .ops .aten .index .Tensor (x , indices )
66+ return out
67+
68+ input = torch .randn (2 , 2 , 2 )
69+ index0 = torch .randint (0 , 1 , (1 , 1 ))
70+ index0 = index0 .to (torch .int32 )
71+ self .run_test (TestModule (), [input , index0 ])
72+
4773 def test_index_zero_index_one_index_two_three_dim (self ):
4874 class TestModule (nn .Module ):
4975 def __init__ (self ):
0 commit comments