Skip to content

Commit 1f48bd5

Browse files
committed
add test for mem[v:]
1 parent a7c3755 commit 1f48bd5

File tree

4 files changed

+117
-66
lines changed

4 files changed

+117
-66
lines changed

mlir/extras/dialects/ext/_shaped_value.py

Lines changed: 46 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -163,14 +163,13 @@ def _indices_to_indexer(
163163
elif isinstance(idx_e, slice):
164164
# Normalize the slice to use None when possible
165165
start, stop, step = idx_e.start, idx_e.stop, idx_e.step
166-
if step is None or isinstance(step, int) and step == 1:
166+
if isinstance(step, int) and step == 1:
167167
step = None
168168
if step is None:
169169
if start is None or isinstance(start, int) and start == 0:
170170
start = None
171171
if (
172-
stop is None
173-
or isinstance(stop, int)
172+
isinstance(stop, int)
174173
and in_shape[in_axis] != ShapedType.get_dynamic_size()
175174
and stop >= in_shape[in_axis]
176175
):
@@ -205,6 +204,7 @@ def _indices_to_indexer(
205204
step = 1
206205
if stop is None:
207206
stop = in_shape[in_axis]
207+
208208
indices[in_axis] = slice(start, stop, step)
209209

210210
out_axis += 1
@@ -307,20 +307,55 @@ def _has_index_type(e: Any) -> bool:
307307

308308
def _is_constant_index(e: Any) -> bool:
309309
return (
310-
isinstance(e, Scalar)
311-
and e.is_constant()
310+
(isinstance(e, Scalar) and e.is_constant())
312311
or isinstance(e, (int, float, bool))
313-
or isinstance(e, slice)
314-
and _is_constant_scalar(e.start)
315-
and _is_constant_scalar(e.stop)
316-
and _is_constant_scalar(e.step)
312+
or (
313+
isinstance(e, slice)
314+
and _is_constant_scalar(e.start)
315+
and _is_constant_scalar(e.stop)
316+
and _is_constant_scalar(e.step)
317+
)
317318
)
318319

319320

320321
def _is_constant_scalar(e: Any) -> bool:
321322
return (
322-
isinstance(e, Scalar)
323-
and e.is_constant()
323+
(isinstance(e, Scalar) and e.is_constant())
324324
or (isinstance(e, (int, float, bool)) and e != ShapedType.get_dynamic_size())
325325
or e is None
326326
)
327+
328+
329+
def _maybe_compute_size(start, stop, step):
330+
from ....dialects import arith
331+
332+
# TODO(max): figure out how to use actual canonicalizers
333+
if (
334+
isinstance(start, Value)
335+
and isinstance(stop, Value)
336+
and stop.owner.operands[0]._eq(start)
337+
and stop.owner.operands[1].is_constant()
338+
):
339+
return stop.owner.operands[1].literal_value
340+
elif (
341+
isinstance(start, Value)
342+
and isinstance(start.owner.opview, arith.MulIOp)
343+
and isinstance(stop, Value)
344+
and isinstance(stop.owner.opview, arith.MulIOp)
345+
and isinstance(stop.owner.operands[0].owner.opview, arith.AddIOp)
346+
and start.owner.operands[0] == stop.owner.operands[0].owner.operands[0]
347+
and stop.owner.operands[1].is_constant()
348+
and isinstance(step, int)
349+
or (isinstance(step, Scalar) and step.is_constant())
350+
):
351+
# looks like this
352+
# l = lambda l: l * D
353+
# r = lambda r: (r + 1) * D
354+
# a, b, c = (
355+
# A[l(i) : r(i), l(j) : r(j)],
356+
# B[l(i) : r(i), l(j) : r(j)],
357+
# C[l(i) : r(i), l(j) : r(j)],
358+
# )
359+
return stop.owner.operands[1]
360+
else:
361+
return stop - start

mlir/extras/dialects/ext/memref.py

Lines changed: 48 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import numpy as np
77

8-
from ._shaped_value import ShapedValue, _indices_to_indexer
8+
from ._shaped_value import ShapedValue, _indices_to_indexer, _maybe_compute_size
99
from .arith import Scalar, constant, index_cast
1010
from .tensor import compute_result_shape_reassoc_list
1111
from .vector import Vector
@@ -26,8 +26,8 @@
2626
)
2727
from ....dialects.memref import (
2828
_is_static_int_like,
29-
_infer_memref_subview_result_type,
3029
_generated_subview,
30+
_is_constant_int_like,
3131
)
3232
from ....dialects.memref import *
3333
from ....ir import (
@@ -283,37 +283,44 @@ def expand_shape(
283283
)
284284

285285

286-
def _maybe_compute_size(start, stop, step):
287-
# TODO(max): figure out how to use actual canonicalizers
288-
if (
289-
isinstance(start, Value)
290-
and isinstance(stop, Value)
291-
and stop.owner.operands[0]._eq(start)
292-
and stop.owner.operands[1].is_constant()
293-
):
294-
return stop.owner.operands[1].literal_value
295-
elif (
296-
isinstance(start, Value)
297-
and isinstance(start.owner.opview, arith.MulIOp)
298-
and isinstance(stop, Value)
299-
and isinstance(stop.owner.opview, arith.MulIOp)
300-
and isinstance(stop.owner.operands[0].owner.opview, arith.AddIOp)
301-
and start.owner.operands[0] == stop.owner.operands[0].owner.operands[0]
302-
and stop.owner.operands[1].is_constant()
303-
and isinstance(step, int)
304-
or (isinstance(step, Scalar) and step.is_constant())
305-
):
306-
# looks like this
307-
# l = lambda l: l * D
308-
# r = lambda r: (r + 1) * D
309-
# a, b, c = (
310-
# A[l(i) : r(i), l(j) : r(j)],
311-
# B[l(i) : r(i), l(j) : r(j)],
312-
# C[l(i) : r(i), l(j) : r(j)],
313-
# )
314-
return stop.owner.operands[1]
286+
def _infer_memref_subview_result_type(source_memref_type, offsets, sizes, strides):
287+
source_strides, source_offset = source_memref_type.get_strides_and_offset()
288+
# "canonicalize" from tuple|list -> list
289+
offsets, sizes, strides, source_strides = map(
290+
list, (offsets, sizes, strides, source_strides)
291+
)
292+
293+
if any(not _is_static_int_like(i) for i in offsets + [source_offset]):
294+
target_offset = ShapedType.get_dynamic_size()
295+
else:
296+
target_offset = source_offset
297+
for offset, target_stride in zip(offsets, source_strides):
298+
target_offset += offset * target_stride
299+
300+
target_strides = []
301+
for source_stride, static_stride in zip(source_strides, strides):
302+
target_strides.append(source_stride * static_stride)
303+
304+
if all(isinstance(s, int) for s in sizes):
305+
# If default striding then no need to complicate things for downstream ops (e.g., expand_shape).
306+
default_strides = list(accumulate(sizes[1:][::-1], operator.mul))[::-1] + [1]
315307
else:
316-
return stop - start
308+
default_strides = None
309+
sizes = [
310+
s if isinstance(s, int) else ShapedType.get_dynamic_size() for s in sizes
311+
]
312+
313+
if target_strides == default_strides and target_offset == 0:
314+
layout = None
315+
else:
316+
layout = StridedLayoutAttr.get(target_offset, target_strides)
317+
318+
return MemRefType.get(
319+
sizes,
320+
source_memref_type.element_type,
321+
layout,
322+
source_memref_type.memory_space,
323+
)
317324

318325

319326
def subview(
@@ -333,21 +340,17 @@ def subview(
333340
sizes = []
334341
if strides is None:
335342
strides = []
336-
source_strides, source_offset = source.type.get_strides_and_offset()
337-
if result_type is None and all(
338-
all(_is_static_int_like(i) for i in s) for s in [sizes, strides, source_strides]
339-
):
343+
344+
for s in [offsets, sizes, strides]:
345+
for idx, i in enumerate(s):
346+
if _is_constant_int_like(i):
347+
s[idx] = i.owner.opview.literal_value
348+
349+
if result_type is None:
340350
# If any are arith.constant results then this will canonicalize to python int
341351
# (which can then be used to fully specify the subview).
342-
(
343-
offsets,
344-
sizes,
345-
strides,
346-
result_type,
347-
) = _infer_memref_subview_result_type(source.type, offsets, sizes, strides)
348-
elif result_type is None:
349-
raise ValueError(
350-
"mixed static/dynamic offset/sizes/strides requires explicit result type."
352+
result_type = _infer_memref_subview_result_type(
353+
source.type, offsets, sizes, strides
351354
)
352355

353356
offsets, _packed_offsets, static_offsets = _dispatch_mixed_values(offsets)

tests/test_gpu.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1088,7 +1088,7 @@ def test_amdgpu_bank_conflicts(ctx: MLIRContext):
10881088

10891089
set_container_module(ctx.module)
10901090

1091-
M = 1024
1091+
M = 128
10921092

10931093
@gpu_func
10941094
def no_bank_conflicts(A: T.memref(M, M, T.f32()), B: T.memref(M, M, T.f32())):

tests/test_memref.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -62,12 +62,21 @@ def test_simple_slicing(ctx: MLIRContext):
6262
w = mem[5:]
6363
w = mem[:5]
6464

65+
one = constant(1, index=True) * 2
66+
w = mem[one:]
67+
6568
correct = dedent(
6669
"""\
6770
module {
6871
%alloc = memref.alloc() : memref<10xi32>
6972
%subview = memref.subview %alloc[5] [5] [1] : memref<10xi32> to memref<5xi32, strided<[1], offset: 5>>
7073
%subview_0 = memref.subview %alloc[0] [5] [1] : memref<10xi32> to memref<5xi32>
74+
%c1 = arith.constant 1 : index
75+
%c2 = arith.constant 2 : index
76+
%0 = arith.muli %c1, %c2 : index
77+
%c10 = arith.constant 10 : index
78+
%1 = arith.subi %c10, %0 : index
79+
%subview_1 = memref.subview %alloc[%0] [%1] [1] : memref<10xi32> to memref<?xi32, strided<[1], offset: ?>>
7180
}
7281
"""
7382
)
@@ -146,14 +155,7 @@ def test_ellipsis_and_full_slice_plus_coordinate_1(ctx: MLIRContext):
146155

147156
one = constant(1, index=True) * 2
148157
w = mem[one, :, :, ...]
149-
150-
try:
151-
w = mem[1, :, :, :, :]
152-
except IndexError as e:
153-
assert (
154-
str(e)
155-
== "Too many indices for shaped type with rank: 5 non-None/Ellipsis indices for dim 4."
156-
)
158+
w = mem[one:, :, :, ...]
157159

158160
correct = dedent(
159161
f"""\
@@ -169,12 +171,23 @@ def test_ellipsis_and_full_slice_plus_coordinate_1(ctx: MLIRContext):
169171
%c2 = arith.constant 2 : index
170172
%0 = arith.muli %c1_4, %c2 : index
171173
%subview_5 = memref.subview %alloc[%0, 0, 0, 0] [1, 22, 333, 4444] [1, 1, 1, 1] : memref<10x22x333x4444xi32> to memref<1x22x333x4444xi32, strided<{golden_w_3_strides}, offset: ?>>
172-
%c1_6 = arith.constant 1 : index
174+
%c10 = arith.constant 10 : index
175+
%1 = arith.subi %c10, %0 : index
176+
%subview_6 = memref.subview %alloc[%0, 0, 0, 0] [%1, 22, 333, 4444] [1, 1, 1, 1] : memref<10x22x333x4444xi32> to memref<?x22x333x4444xi32, strided<{golden_w_3_strides}, offset: ?>>
173177
}}
174178
"""
175179
)
176180
filecheck(correct, ctx.module)
177181

182+
try:
183+
w = mem[1, :, :, :, :]
184+
except IndexError as e:
185+
assert (
186+
str(e)
187+
== "Too many indices for shaped type with rank: 5 non-None/Ellipsis indices for dim 4."
188+
)
189+
190+
178191

179192
def test_ellipsis_and_full_slice_plus_coordinate_2(ctx: MLIRContext):
180193
sizes = (10, 22, 333, 4444)

0 commit comments

Comments
 (0)