55import torch
66
77import helion
8+ from helion ._compat import get_tensor_descriptor_fn_name
89from helion ._compat import supports_tensor_descriptor
910from helion ._testing import DEVICE
1011from helion ._testing import TestCase
@@ -41,15 +42,15 @@ def kernel_with_permutation(x: torch.Tensor) -> torch.Tensor:
4142 kernel_with_permutation ,
4243 (x ,),
4344 indexing = "tensor_descriptor" ,
44- block_sizes = [4 , 8 ],
45+ block_sizes = [8 , 8 ],
4546 )
4647
4748 # Check that the result is correct
4849 expected = x + 1.0
4950 torch .testing .assert_close (result , expected )
5051
5152 # Check that the generated code contains permutation calls
52- self .assertIn ("tl.make_tensor_descriptor" , code )
53+ self .assertIn (get_tensor_descriptor_fn_name () , code )
5354 # The tensor descriptor should be created with permuted dimensions
5455 # (sizes and strides should be reordered so stride==1 dim is last)
5556
@@ -77,15 +78,15 @@ def kernel_no_permutation(x: torch.Tensor) -> torch.Tensor:
7778 kernel_no_permutation ,
7879 (x ,),
7980 indexing = "tensor_descriptor" ,
80- block_sizes = [4 , 8 ],
81+ block_sizes = [8 , 8 ],
8182 )
8283
8384 # Check that the result is correct
8485 expected = x * 2.0
8586 torch .testing .assert_close (result , expected )
8687
8788 # Check that the generated code contains tensor descriptor
88- self .assertIn ("tl.make_tensor_descriptor" , code )
89+ self .assertIn (get_tensor_descriptor_fn_name () , code )
8990 # Should not contain permute calls since no permutation needed
9091 self .assertNotIn ("tl.permute" , code )
9192
@@ -121,7 +122,7 @@ def kernel_3d_permutation(x: torch.Tensor) -> torch.Tensor:
121122 torch .testing .assert_close (result , expected )
122123
123124 # Should contain both tensor descriptor and permute operations
124- self .assertIn ("tl.make_tensor_descriptor" , code )
125+ self .assertIn (get_tensor_descriptor_fn_name () , code )
125126 self .assertIn ("tl.permute" , code )
126127
127128 @unittest .skipUnless (
@@ -149,15 +150,15 @@ def kernel_transpose_case(x: torch.Tensor) -> torch.Tensor:
149150 kernel_transpose_case ,
150151 (x ,),
151152 indexing = "tensor_descriptor" ,
152- block_sizes = [4 , 8 ],
153+ block_sizes = [8 , 8 ],
153154 )
154155
155156 # Check correctness
156157 expected = x * x
157158 torch .testing .assert_close (result , expected )
158159
159160 # Should handle the permutation properly
160- self .assertIn ("tl.make_tensor_descriptor" , code )
161+ self .assertIn (get_tensor_descriptor_fn_name () , code )
161162 self .assertIn ("tl.permute" , code )
162163
163164 @unittest .skipUnless (
@@ -183,14 +184,14 @@ def kernel_different_blocks(x: torch.Tensor) -> torch.Tensor:
183184 kernel_different_blocks ,
184185 (x ,),
185186 indexing = "tensor_descriptor" ,
186- block_sizes = [4 , 8 ],
187+ block_sizes = [8 , 8 ],
187188 )
188189
189190 expected = x + 5.0
190191 torch .testing .assert_close (result , expected )
191192
192193 # Should contain permutation and tensor descriptor
193- self .assertIn ("tl.make_tensor_descriptor" , code )
194+ self .assertIn (get_tensor_descriptor_fn_name () , code )
194195 self .assertIn ("tl.permute" , code )
195196
196197 # The block sizes should also be permuted in the tensor descriptor
@@ -223,14 +224,14 @@ def kernel_store_permutation(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
223224 kernel_store_permutation ,
224225 (x , y ),
225226 indexing = "tensor_descriptor" ,
226- block_sizes = [4 , 8 ],
227+ block_sizes = [8 , 8 ],
227228 )
228229
229230 expected = x * 3.0
230231 torch .testing .assert_close (result , expected )
231232
232233 # Should have permutation for both load and store
233- self .assertIn ("tl.make_tensor_descriptor" , code )
234+ self .assertIn (get_tensor_descriptor_fn_name () , code )
234235 self .assertIn ("tl.permute" , code )
235236
236237 @unittest .skipUnless (
@@ -301,7 +302,7 @@ def kernel_small_block(x: torch.Tensor) -> torch.Tensor:
301302
302303 # Should fall back to block_ptr or pointer indexing instead of tensor descriptor
303304 # If our fix works, this should NOT contain tensor descriptor
304- self .assertNotIn ("tl.make_tensor_descriptor" , code )
305+ self .assertNotIn (get_tensor_descriptor_fn_name () , code )
305306
306307 # But should still work correctly
307308 expected = x + 1.0
0 commit comments