@@ -86,6 +86,54 @@ def forward(self, x):
8686 expected_ops = {torch .ops .aten .index .Tensor , operator .getitem },
8787 )
8888
89+ def test_index_zero_index_one_SD (self ):
90+ class TestModule (nn .Module ):
91+ def forward (self , x ):
92+ index0 = torch .tensor ([0 , 0 , 1 , 1 , 2 , 2 , 3 , 3 , 4 , 4 , 5 , 5 , 6 , 6 , 7 , 7 ])
93+ index1 = torch .tensor ([0 , 0 , 1 , 1 , 2 , 2 , 3 , 3 , 4 , 4 , 5 , 5 , 6 , 6 , 7 , 7 ])
94+ indices = [None , index0 , index1 , None ]
95+ out = torch .ops .aten .index .Tensor (x , indices )
96+ return out
97+
98+ input = [torch .randn (2 , 1280 , 8 , 8 )]
99+ self .run_test (
100+ TestModule (),
101+ input ,
102+ expected_ops = {torch .ops .aten .index .Tensor , operator .getitem },
103+ )
104+
105+ def test_index_zero_index_one_SD_unsqueeze (self ):
106+ class TestModule (nn .Module ):
107+ def forward (self , x ):
108+ index0 = torch .tensor ([0 , 0 , 1 , 1 , 2 , 2 , 3 , 3 , 4 , 4 , 5 , 5 , 6 , 6 , 7 , 7 ])
109+ index1 = index0 .unsqueeze (0 ).T .long ()
110+ indices = [None , None , index1 , index1 ]
111+ out = torch .ops .aten .index .Tensor (x , indices )
112+ return out
113+
114+ input = [torch .randn (2 , 1280 , 8 , 8 )]
115+ self .run_test (
116+ TestModule (),
117+ input ,
118+ expected_ops = {torch .ops .aten .index .Tensor },
119+ )
120+
121+ def test_index_zero_index_one_index_two_SD_unsqueeze (self ):
122+ class TestModule (nn .Module ):
123+ def forward (self , x ):
124+ index0 = torch .tensor ([0 , 0 , 1 , 1 , 2 , 2 , 3 , 3 , 4 , 4 , 5 , 5 , 6 , 6 , 7 , 7 ])
125+ index1 = index0 .unsqueeze (0 ).T .long ()
126+ indices = [None , None , index0 , index1 ]
127+ out = torch .ops .aten .index .Tensor (x , indices )
128+ return out
129+
130+ input = [torch .randn (2 , 1280 , 8 , 8 )]
131+ self .run_test (
132+ TestModule (),
133+ input ,
134+ expected_ops = {torch .ops .aten .index .Tensor },
135+ )
136+
89137
90138if __name__ == "__main__" :
91139 run_tests ()
0 commit comments