Skip to content

Commit e556cbf

Browse files
committed
add more memref tests
1 parent ede397e commit e556cbf

File tree

3 files changed

+59
-19
lines changed

3 files changed

+59
-19
lines changed

examples/flash_attention.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def manual_attn(q, k, v):
104104
return y
105105

106106

107-
rank_reduce = memref.MemRef.rank_reduce
107+
rank_reduce = memref.rank_reduce
108108

109109

110110
# https://github.com/tspeterkim/flash-attention-minimal/blob/main/flash.cu

mlir/extras/dialects/ext/memref.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,9 @@ def store(
170170
return get_op_result_or_op_results(StoreOp(value, memref, indices, loc=loc, ip=ip))
171171

172172

173+
rank_reduce = object()
174+
175+
173176
@register_value_caster(MemRefType.static_typeid)
174177
@ShapedValue
175178
class MemRef(Value):
@@ -179,8 +182,6 @@ def __str__(self):
179182
def __repr__(self):
180183
return str(self)
181184

182-
rank_reduce = object()
183-
184185
def __getitem__(self, idx: tuple) -> "MemRef":
185186
loc = get_user_code_loc()
186187

@@ -195,9 +196,9 @@ def __getitem__(self, idx: tuple) -> "MemRef":
195196
return expand_shape(self, (0,), loc=loc)
196197

197198
idx = list((idx,) if isinstance(idx, (int, Scalar, slice)) else idx)
198-
rank_reduce = MemRef.rank_reduce in idx
199-
if rank_reduce:
200-
idx.remove(MemRef.rank_reduce)
199+
should_rank_reduce = rank_reduce in idx
200+
if should_rank_reduce:
201+
idx.remove(rank_reduce)
201202

202203
for i, d in enumerate(idx):
203204
# TODO(max): rethink this since subview and etc probably take constant attributes?
@@ -207,7 +208,7 @@ def __getitem__(self, idx: tuple) -> "MemRef":
207208
if all(isinstance(d, Scalar) for d in idx) and len(idx) == len(self.shape):
208209
return load(self, idx, loc=loc)
209210
else:
210-
return _subview(self, tuple(idx), rank_reduce=rank_reduce, loc=loc)
211+
return _subview(self, tuple(idx), rank_reduce=should_rank_reduce, loc=loc)
211212

212213
def __setitem__(self, idx, val):
213214
loc = get_user_code_loc()

tests/test_memref.py

Lines changed: 51 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
alloca_scope,
1818
alloca_scope_return,
1919
global_,
20+
rank_reduce,
2021
)
2122
from mlir.extras.dialects.ext.scf import (
2223
range_,
@@ -41,6 +42,10 @@ def test_simple_literal_indexing(ctx: MLIRContext):
4142
w = mem[2, 4, 6, 8]
4243
assert isinstance(w, Scalar)
4344

45+
two = constant(1) * 2
46+
w = mem[two, 4, 6, 8]
47+
mem[two, 4, 6, 8] = w
48+
4449
correct = dedent(
4550
"""\
4651
module {
@@ -50,6 +55,19 @@ def test_simple_literal_indexing(ctx: MLIRContext):
5055
%c6 = arith.constant 6 : index
5156
%c8 = arith.constant 8 : index
5257
%0 = memref.load %alloc[%c2, %c4, %c6, %c8] : memref<10x22x333x4444xi32>
58+
%c1_i32 = arith.constant 1 : i32
59+
%c2_i32 = arith.constant 2 : i32
60+
%1 = arith.muli %c1_i32, %c2_i32 : i32
61+
%c4_0 = arith.constant 4 : index
62+
%c6_1 = arith.constant 6 : index
63+
%c8_2 = arith.constant 8 : index
64+
%2 = arith.index_cast %1 : i32 to index
65+
%3 = memref.load %alloc[%2, %c4_0, %c6_1, %c8_2] : memref<10x22x333x4444xi32>
66+
%c4_3 = arith.constant 4 : index
67+
%c6_4 = arith.constant 6 : index
68+
%c8_5 = arith.constant 8 : index
69+
%4 = arith.index_cast %1 : i32 to index
70+
memref.store %3, %alloc[%4, %c4_3, %c6_4, %c8_5] : memref<10x22x333x4444xi32>
5371
}
5472
"""
5573
)
@@ -62,8 +80,8 @@ def test_simple_slicing(ctx: MLIRContext):
6280
w = mem[5:]
6381
w = mem[:5]
6482

65-
one = constant(1, index=True) * 2
66-
w = mem[one:]
83+
two = constant(1, index=True) * 2
84+
w = mem[two:]
6785

6886
correct = dedent(
6987
"""\
@@ -153,9 +171,9 @@ def test_ellipsis_and_full_slice_plus_coordinate_1(ctx: MLIRContext):
153171
w = mem[1, :, ...]
154172
w = mem[1, :, :, ...]
155173

156-
one = constant(1, index=True) * 2
157-
w = mem[one, :, :, ...]
158-
w = mem[one:, :, :, ...]
174+
two = constant(1, index=True) * 2
175+
w = mem[two, :, :, ...]
176+
w = mem[two:, :, :, ...]
159177

160178
correct = dedent(
161179
f"""\
@@ -183,36 +201,43 @@ def test_ellipsis_and_full_slice_plus_coordinate_1(ctx: MLIRContext):
183201
w = mem[1, :, :, :, :]
184202
except IndexError as e:
185203
assert (
186-
str(e)
187-
== "Too many indices for shaped type with rank: 5 non-None/Ellipsis indices for dim 4."
204+
str(e)
205+
== "Too many indices for shaped type with rank: 5 non-None/Ellipsis indices for dim 4."
188206
)
189207

190208

191-
192209
def test_ellipsis_and_full_slice_plus_coordinate_2(ctx: MLIRContext):
193210
sizes = (10, 22, 333, 4444)
194211
dtype_size_in_bytes = np.int32().dtype.itemsize
195212
golden_mem = np.zeros(sizes, dtype=np.int32)
196213
golden_w_1 = golden_mem[1:2, :]
214+
golden_w_1_rank_reduce = golden_mem[1, :]
197215
golden_w_2 = golden_mem[1:2, :, :]
198216
golden_w_3 = golden_mem[1:2, :, :, :]
199217
golden_w_4 = golden_mem[:, 1:2]
200218
golden_w_5 = golden_mem[:, :, 1:2]
201219

202220
golden_w_1_strides = (np.array(golden_w_1.strides) // dtype_size_in_bytes).tolist()
221+
golden_w_1_rank_reduce_strides = (
222+
np.array(golden_w_1_rank_reduce.strides) // dtype_size_in_bytes
223+
).tolist()
203224
golden_w_2_strides = (np.array(golden_w_2.strides) // dtype_size_in_bytes).tolist()
204225
golden_w_3_strides = (np.array(golden_w_3.strides) // dtype_size_in_bytes).tolist()
205226
golden_w_4_strides = (np.array(golden_w_4.strides) // dtype_size_in_bytes).tolist()
206227
golden_w_5_strides = (np.array(golden_w_5.strides) // dtype_size_in_bytes).tolist()
207228

208229
golden_w_1_offset = get_np_view_offset(golden_w_1) // dtype_size_in_bytes
230+
golden_w_1_rank_reduce_offset = (
231+
get_np_view_offset(golden_w_1_rank_reduce) // dtype_size_in_bytes
232+
)
209233
golden_w_2_offset = get_np_view_offset(golden_w_2) // dtype_size_in_bytes
210234
golden_w_3_offset = get_np_view_offset(golden_w_3) // dtype_size_in_bytes
211235
golden_w_4_offset = get_np_view_offset(golden_w_4) // dtype_size_in_bytes
212236
golden_w_5_offset = get_np_view_offset(golden_w_5) // dtype_size_in_bytes
213237

214238
mem = alloc(sizes, T.i32())
215239
w = mem[1, :]
240+
w = mem[1, :, rank_reduce]
216241
w = mem[1, :, :]
217242
w = mem[1, :, :, :]
218243
w = mem[:, 1]
@@ -224,13 +249,15 @@ def test_ellipsis_and_full_slice_plus_coordinate_2(ctx: MLIRContext):
224249
%c1 = arith.constant 1 : index
225250
%subview = memref.subview %alloc[1, 0, 0, 0] [1, 22, 333, 4444] [1, 1, 1, 1] : memref<10x22x333x4444xi32> to memref<1x22x333x4444xi32, strided<{golden_w_1_strides}, offset: {golden_w_1_offset}>>
226251
%c1_0 = arith.constant 1 : index
227-
%subview_2 = memref.subview %alloc[1, 0, 0, 0] [1, 22, 333, 4444] [1, 1, 1, 1] : memref<10x22x333x4444xi32> to memref<1x22x333x4444xi32, strided<{golden_w_2_strides}, offset: {golden_w_2_offset}>>
252+
%subview_1 = memref.subview %alloc[1, 0, 0, 0] [1, 22, 333, 4444] [1, 1, 1, 1] : memref<10x22x333x4444xi32> to memref<22x333x4444xi32, strided<{golden_w_1_rank_reduce_strides}, offset: {golden_w_1_rank_reduce_offset}>>
228253
%c1_2 = arith.constant 1 : index
229-
%subview_3 = memref.subview %alloc[1, 0, 0, 0] [1, 22, 333, 4444] [1, 1, 1, 1] : memref<10x22x333x4444xi32> to memref<1x22x333x4444xi32, strided<{golden_w_3_strides}, offset: {golden_w_3_offset}>>
254+
%subview_3 = memref.subview %alloc[1, 0, 0, 0] [1, 22, 333, 4444] [1, 1, 1, 1] : memref<10x22x333x4444xi32> to memref<1x22x333x4444xi32, strided<{golden_w_2_strides}, offset: {golden_w_2_offset}>>
230255
%c1_4 = arith.constant 1 : index
231-
%subview_5 = memref.subview %alloc[0, 1, 0, 0] [10, 1, 333, 4444] [1, 1, 1, 1] : memref<10x22x333x4444xi32> to memref<10x1x333x4444xi32, strided<{golden_w_4_strides}, offset: {golden_w_4_offset}>>
256+
%subview_5 = memref.subview %alloc[1, 0, 0, 0] [1, 22, 333, 4444] [1, 1, 1, 1] : memref<10x22x333x4444xi32> to memref<1x22x333x4444xi32, strided<{golden_w_3_strides}, offset: {golden_w_3_offset}>>
232257
%c1_6 = arith.constant 1 : index
233-
%subview_7 = memref.subview %alloc[0, 0, 1, 0] [10, 22, 1, 4444] [1, 1, 1, 1] : memref<10x22x333x4444xi32> to memref<10x22x1x4444xi32, strided<{golden_w_5_strides}, offset: {golden_w_5_offset}>>
258+
%subview_7 = memref.subview %alloc[0, 1, 0, 0] [10, 1, 333, 4444] [1, 1, 1, 1] : memref<10x22x333x4444xi32> to memref<10x1x333x4444xi32, strided<{golden_w_4_strides}, offset: {golden_w_4_offset}>>
259+
%c1_8 = arith.constant 1 : index
260+
%subview_9 = memref.subview %alloc[0, 0, 1, 0] [10, 22, 1, 4444] [1, 1, 1, 1] : memref<10x22x333x4444xi32> to memref<10x22x1x4444xi32, strided<{golden_w_5_strides}, offset: {golden_w_5_offset}>>
234261
}}
235262
"""
236263
)
@@ -688,6 +715,9 @@ def test_memref_view(ctx: MLIRContext):
688715
ab_buffer = alloc(((m * k + k * n) * byte_width_dtype,), T.i8())
689716
a_buffer = memref.view(ab_buffer, (m, k), dtype=dtype)
690717
b_buffer = memref.view(ab_buffer, (k, n), dtype=dtype, shift=m * k)
718+
two = constant(1) * 2
719+
# TODO(max): should the type here also contain the offset...?
720+
c_buffer = memref.view(ab_buffer, (k, n), dtype=dtype, shift=m * k + two)
691721

692722
correct = dedent(
693723
"""\
@@ -697,6 +727,15 @@ def test_memref_view(ctx: MLIRContext):
697727
%view = memref.view %alloc[%c0][] : memref<2048xi8> to memref<16x16xf32>
698728
%c1024 = arith.constant 1024 : index
699729
%view_0 = memref.view %alloc[%c1024][] : memref<2048xi8> to memref<16x16xf32>
730+
%c1_i32 = arith.constant 1 : i32
731+
%c2_i32 = arith.constant 2 : i32
732+
%0 = arith.muli %c1_i32, %c2_i32 : i32
733+
%c256_i32 = arith.constant 256 : i32
734+
%1 = arith.addi %c256_i32, %0 : i32
735+
%c4_i32 = arith.constant 4 : i32
736+
%2 = arith.muli %1, %c4_i32 : i32
737+
%3 = arith.index_cast %2 : i32 to index
738+
%view_1 = memref.view %alloc[%3][] : memref<2048xi8> to memref<16x16xf32>
700739
}
701740
"""
702741
)

0 commit comments

Comments
 (0)