1717 alloca_scope ,
1818 alloca_scope_return ,
1919 global_ ,
20+ rank_reduce ,
2021)
2122from mlir .extras .dialects .ext .scf import (
2223 range_ ,
@@ -41,6 +42,10 @@ def test_simple_literal_indexing(ctx: MLIRContext):
4142 w = mem [2 , 4 , 6 , 8 ]
4243 assert isinstance (w , Scalar )
4344
45+ two = constant (1 ) * 2
46+ w = mem [two , 4 , 6 , 8 ]
47+ mem [two , 4 , 6 , 8 ] = w
48+
4449 correct = dedent (
4550 """\
4651 module {
@@ -50,6 +55,19 @@ def test_simple_literal_indexing(ctx: MLIRContext):
5055 %c6 = arith.constant 6 : index
5156 %c8 = arith.constant 8 : index
5257 %0 = memref.load %alloc[%c2, %c4, %c6, %c8] : memref<10x22x333x4444xi32>
58+ %c1_i32 = arith.constant 1 : i32
59+ %c2_i32 = arith.constant 2 : i32
60+ %1 = arith.muli %c1_i32, %c2_i32 : i32
61+ %c4_0 = arith.constant 4 : index
62+ %c6_1 = arith.constant 6 : index
63+ %c8_2 = arith.constant 8 : index
64+ %2 = arith.index_cast %1 : i32 to index
65+ %3 = memref.load %alloc[%2, %c4_0, %c6_1, %c8_2] : memref<10x22x333x4444xi32>
66+ %c4_3 = arith.constant 4 : index
67+ %c6_4 = arith.constant 6 : index
68+ %c8_5 = arith.constant 8 : index
69+ %4 = arith.index_cast %1 : i32 to index
70+ memref.store %3, %alloc[%4, %c4_3, %c6_4, %c8_5] : memref<10x22x333x4444xi32>
5371 }
5472 """
5573 )
@@ -62,8 +80,8 @@ def test_simple_slicing(ctx: MLIRContext):
6280 w = mem [5 :]
6381 w = mem [:5 ]
6482
65- one = constant (1 , index = True ) * 2
66- w = mem [one :]
83+ two = constant (1 , index = True ) * 2
84+ w = mem [two :]
6785
6886 correct = dedent (
6987 """\
@@ -153,9 +171,9 @@ def test_ellipsis_and_full_slice_plus_coordinate_1(ctx: MLIRContext):
153171 w = mem [1 , :, ...]
154172 w = mem [1 , :, :, ...]
155173
156- one = constant (1 , index = True ) * 2
157- w = mem [one , :, :, ...]
158- w = mem [one :, :, :, ...]
174+ two = constant (1 , index = True ) * 2
175+ w = mem [two , :, :, ...]
176+ w = mem [two :, :, :, ...]
159177
160178 correct = dedent (
161179 f"""\
@@ -183,36 +201,43 @@ def test_ellipsis_and_full_slice_plus_coordinate_1(ctx: MLIRContext):
183201 w = mem [1 , :, :, :, :]
184202 except IndexError as e :
185203 assert (
186- str (e )
187- == "Too many indices for shaped type with rank: 5 non-None/Ellipsis indices for dim 4."
204+ str (e )
205+ == "Too many indices for shaped type with rank: 5 non-None/Ellipsis indices for dim 4."
188206 )
189207
190208
191-
192209def test_ellipsis_and_full_slice_plus_coordinate_2 (ctx : MLIRContext ):
193210 sizes = (10 , 22 , 333 , 4444 )
194211 dtype_size_in_bytes = np .int32 ().dtype .itemsize
195212 golden_mem = np .zeros (sizes , dtype = np .int32 )
196213 golden_w_1 = golden_mem [1 :2 , :]
214+ golden_w_1_rank_reduce = golden_mem [1 , :]
197215 golden_w_2 = golden_mem [1 :2 , :, :]
198216 golden_w_3 = golden_mem [1 :2 , :, :, :]
199217 golden_w_4 = golden_mem [:, 1 :2 ]
200218 golden_w_5 = golden_mem [:, :, 1 :2 ]
201219
202220 golden_w_1_strides = (np .array (golden_w_1 .strides ) // dtype_size_in_bytes ).tolist ()
221+ golden_w_1_rank_reduce_strides = (
222+ np .array (golden_w_1_rank_reduce .strides ) // dtype_size_in_bytes
223+ ).tolist ()
203224 golden_w_2_strides = (np .array (golden_w_2 .strides ) // dtype_size_in_bytes ).tolist ()
204225 golden_w_3_strides = (np .array (golden_w_3 .strides ) // dtype_size_in_bytes ).tolist ()
205226 golden_w_4_strides = (np .array (golden_w_4 .strides ) // dtype_size_in_bytes ).tolist ()
206227 golden_w_5_strides = (np .array (golden_w_5 .strides ) // dtype_size_in_bytes ).tolist ()
207228
208229 golden_w_1_offset = get_np_view_offset (golden_w_1 ) // dtype_size_in_bytes
230+ golden_w_1_rank_reduce_offset = (
231+ get_np_view_offset (golden_w_1_rank_reduce ) // dtype_size_in_bytes
232+ )
209233 golden_w_2_offset = get_np_view_offset (golden_w_2 ) // dtype_size_in_bytes
210234 golden_w_3_offset = get_np_view_offset (golden_w_3 ) // dtype_size_in_bytes
211235 golden_w_4_offset = get_np_view_offset (golden_w_4 ) // dtype_size_in_bytes
212236 golden_w_5_offset = get_np_view_offset (golden_w_5 ) // dtype_size_in_bytes
213237
214238 mem = alloc (sizes , T .i32 ())
215239 w = mem [1 , :]
240+ w = mem [1 , :, rank_reduce ]
216241 w = mem [1 , :, :]
217242 w = mem [1 , :, :, :]
218243 w = mem [:, 1 ]
@@ -224,13 +249,15 @@ def test_ellipsis_and_full_slice_plus_coordinate_2(ctx: MLIRContext):
224249 %c1 = arith.constant 1 : index
225250 %subview = memref.subview %alloc[1, 0, 0, 0] [1, 22, 333, 4444] [1, 1, 1, 1] : memref<10x22x333x4444xi32> to memref<1x22x333x4444xi32, strided<{ golden_w_1_strides } , offset: { golden_w_1_offset } >>
226251 %c1_0 = arith.constant 1 : index
227- %subview_2 = memref.subview %alloc[1, 0, 0, 0] [1, 22, 333, 4444] [1, 1, 1, 1] : memref<10x22x333x4444xi32> to memref<1x22x333x4444xi32 , strided<{ golden_w_2_strides } , offset: { golden_w_2_offset } >>
252+ %subview_1 = memref.subview %alloc[1, 0, 0, 0] [1, 22, 333, 4444] [1, 1, 1, 1] : memref<10x22x333x4444xi32> to memref<22x333x4444xi32 , strided<{ golden_w_1_rank_reduce_strides } , offset: { golden_w_1_rank_reduce_offset } >>
228253 %c1_2 = arith.constant 1 : index
229- %subview_3 = memref.subview %alloc[1, 0, 0, 0] [1, 22, 333, 4444] [1, 1, 1, 1] : memref<10x22x333x4444xi32> to memref<1x22x333x4444xi32, strided<{ golden_w_3_strides } , offset: { golden_w_3_offset } >>
254+ %subview_3 = memref.subview %alloc[1, 0, 0, 0] [1, 22, 333, 4444] [1, 1, 1, 1] : memref<10x22x333x4444xi32> to memref<1x22x333x4444xi32, strided<{ golden_w_2_strides } , offset: { golden_w_2_offset } >>
230255 %c1_4 = arith.constant 1 : index
231- %subview_5 = memref.subview %alloc[0, 1 , 0, 0] [10, 1 , 333, 4444] [1, 1, 1, 1] : memref<10x22x333x4444xi32> to memref<10x1x333x4444xi32 , strided<{ golden_w_4_strides } , offset: { golden_w_4_offset } >>
256+ %subview_5 = memref.subview %alloc[1, 0 , 0, 0] [1, 22 , 333, 4444] [1, 1, 1, 1] : memref<10x22x333x4444xi32> to memref<1x22x333x4444xi32 , strided<{ golden_w_3_strides } , offset: { golden_w_3_offset } >>
232257 %c1_6 = arith.constant 1 : index
233- %subview_7 = memref.subview %alloc[0, 0, 1, 0] [10, 22, 1, 4444] [1, 1, 1, 1] : memref<10x22x333x4444xi32> to memref<10x22x1x4444xi32, strided<{ golden_w_5_strides } , offset: { golden_w_5_offset } >>
258+ %subview_7 = memref.subview %alloc[0, 1, 0, 0] [10, 1, 333, 4444] [1, 1, 1, 1] : memref<10x22x333x4444xi32> to memref<10x1x333x4444xi32, strided<{ golden_w_4_strides } , offset: { golden_w_4_offset } >>
259+ %c1_8 = arith.constant 1 : index
260+ %subview_9 = memref.subview %alloc[0, 0, 1, 0] [10, 22, 1, 4444] [1, 1, 1, 1] : memref<10x22x333x4444xi32> to memref<10x22x1x4444xi32, strided<{ golden_w_5_strides } , offset: { golden_w_5_offset } >>
234261 }}
235262 """
236263 )
@@ -688,6 +715,9 @@ def test_memref_view(ctx: MLIRContext):
688715 ab_buffer = alloc (((m * k + k * n ) * byte_width_dtype ,), T .i8 ())
689716 a_buffer = memref .view (ab_buffer , (m , k ), dtype = dtype )
690717 b_buffer = memref .view (ab_buffer , (k , n ), dtype = dtype , shift = m * k )
718+ two = constant (1 ) * 2
719+ # TODO(max): should the type here also contain the offset...?
720+ c_buffer = memref .view (ab_buffer , (k , n ), dtype = dtype , shift = m * k + two )
691721
692722 correct = dedent (
693723 """\
@@ -697,6 +727,15 @@ def test_memref_view(ctx: MLIRContext):
697727 %view = memref.view %alloc[%c0][] : memref<2048xi8> to memref<16x16xf32>
698728 %c1024 = arith.constant 1024 : index
699729 %view_0 = memref.view %alloc[%c1024][] : memref<2048xi8> to memref<16x16xf32>
730+ %c1_i32 = arith.constant 1 : i32
731+ %c2_i32 = arith.constant 2 : i32
732+ %0 = arith.muli %c1_i32, %c2_i32 : i32
733+ %c256_i32 = arith.constant 256 : i32
734+ %1 = arith.addi %c256_i32, %0 : i32
735+ %c4_i32 = arith.constant 4 : i32
736+ %2 = arith.muli %1, %c4_i32 : i32
737+ %3 = arith.index_cast %2 : i32 to index
738+ %view_1 = memref.view %alloc[%3][] : memref<2048xi8> to memref<16x16xf32>
700739 }
701740 """
702741 )
0 commit comments