Skip to content

Commit 48f980c

Browse files
authored
[mlir][memref] Add memref alias folding for masked transfers (#71476)
The contents of a mask on a masked transfer are unaffected by the particular region of memory being read/stored to, so just forward the mask in subview folding patterns.
1 parent 0c6a77b commit 48f980c

File tree

2 files changed

+123
-4
lines changed

2 files changed

+123
-4
lines changed

mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -346,8 +346,6 @@ preconditionsFoldSubViewOpImpl(RewriterBase &rewriter, XferOp xferOp,
346346
"must be a vector transfer op");
347347
if (xferOp.hasOutOfBoundsDim())
348348
return rewriter.notifyMatchFailure(xferOp, "out of bounds transfer dim");
349-
if (xferOp.getMask())
350-
return rewriter.notifyMatchFailure(xferOp, "masked transfer");
351349
if (!subviewOp.hasUnitStride()) {
352350
return rewriter.notifyMatchFailure(
353351
xferOp, "non-1 stride subview, need to track strides in folded memref");
@@ -428,7 +426,7 @@ LogicalResult LoadOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
428426
AffineMapAttr::get(expandDimsToRank(
429427
op.getPermutationMap(), subViewOp.getSourceType().getRank(),
430428
subViewOp.getDroppedDims())),
431-
op.getPadding(), /*mask=*/Value(), op.getInBoundsAttr());
429+
op.getPadding(), op.getMask(), op.getInBoundsAttr());
432430
})
433431
.Case([&](gpu::SubgroupMmaLoadMatrixOp op) {
434432
rewriter.replaceOpWithNewOp<gpu::SubgroupMmaLoadMatrixOp>(
@@ -557,7 +555,7 @@ LogicalResult StoreOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
557555
AffineMapAttr::get(expandDimsToRank(
558556
op.getPermutationMap(), subViewOp.getSourceType().getRank(),
559557
subViewOp.getDroppedDims())),
560-
op.getInBoundsAttr());
558+
op.getMask(), op.getInBoundsAttr());
561559
})
562560
.Case([&](gpu::SubgroupMmaStoreMatrixOp op) {
563561
rewriter.replaceOpWithNewOp<gpu::SubgroupMmaStoreMatrixOp>(

mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,127 @@ func.func @fold_vector_transfer_write_with_inner_rank_reduced_subview(
266266

267267
// -----
268268

269+
func.func @fold_masked_vector_transfer_read_with_subview(
270+
%arg0 : memref<?x?xf32, strided<[?, ?], offset: ?>>,
271+
%arg1: index, %arg2 : index, %arg3 : index, %arg4: index, %arg5 : index,
272+
%arg6 : index, %mask : vector<4xi1>) -> vector<4xf32> {
273+
%cst = arith.constant 0.0 : f32
274+
%0 = memref.subview %arg0[%arg1, %arg2] [%arg3, %arg4] [1, 1]
275+
: memref<?x?xf32, strided<[?, ?], offset: ?>> to
276+
memref<?x?xf32, strided<[?, ?], offset: ?>>
277+
%1 = vector.transfer_read %0[%arg5, %arg6], %cst, %mask {in_bounds = [true]}
278+
: memref<?x?xf32, strided<[?, ?], offset: ?>>, vector<4xf32>
279+
return %1 : vector<4xf32>
280+
}
281+
// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
282+
// CHECK: func @fold_masked_vector_transfer_read_with_subview
283+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?xf32, strided<[?, ?], offset: ?>>
284+
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
285+
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
286+
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
287+
// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index
288+
// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: index
289+
// CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: index
290+
// CHECK-SAME: %[[MASK:[a-zA-Z0-9]+]]: vector<4xi1>
291+
// CHECK-DAG: %[[IDX0:.+]] = affine.apply #[[MAP1]]()[%[[ARG1]], %[[ARG5]]]
292+
// CHECK-DAG: %[[IDX1:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG6]]]
293+
// CHECK: vector.transfer_read %[[ARG0]][%[[IDX0]], %[[IDX1]]], %{{.*}}, %[[MASK]] {{.*}} : memref<?x?xf32
294+
295+
// -----
296+
297+
func.func @fold_masked_vector_transfer_read_with_rank_reducing_subview(
298+
%arg0 : memref<?x?x?x?xf32, strided<[?, ?, ?, ?], offset: ?>>,
299+
%arg1: index, %arg2 : index, %arg3 : index, %arg4: index, %arg5 : index,
300+
%arg6 : index, %mask : vector<4x3xi1>) -> vector<3x4xf32> {
301+
%cst = arith.constant 0.0 : f32
302+
%0 = memref.subview %arg0[0, %arg1, 0, %arg2] [1, %arg3, 1, %arg4] [1, 1, 1, 1]
303+
: memref<?x?x?x?xf32, strided<[?, ?, ?, ?], offset: ?>> to
304+
memref<?x?xf32, strided<[?, ?], offset: ?>>
305+
%1 = vector.transfer_read %0[%arg5, %arg6], %cst, %mask {
306+
permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]}
307+
: memref<?x?xf32, strided<[?, ?], offset: ?>>, vector<3x4xf32>
308+
return %1 : vector<3x4xf32>
309+
}
310+
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
311+
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d1)>
312+
// CHECK: func @fold_masked_vector_transfer_read_with_rank_reducing_subview
313+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32, strided<[?, ?, ?, ?], offset: ?>>
314+
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
315+
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
316+
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
317+
// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index
318+
// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: index
319+
// CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: index
320+
// CHECK-SAME: %[[MASK:[a-zA-Z0-9]+]]: vector<4x3xi1>
321+
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
322+
// CHECK-DAG: %[[PAD:.+]] = arith.constant 0.000000e+00 : f32
323+
// CHECK-DAG: %[[IDX0:.+]] = affine.apply #[[MAP0]]()[%[[ARG1]], %[[ARG5]]]
324+
// CHECK-DAG: %[[IDX1:.+]] = affine.apply #[[MAP0]]()[%[[ARG2]], %[[ARG6]]]
325+
// CHECK: vector.transfer_read %[[ARG0]][%[[C0]], %[[IDX0]], %[[C0]], %[[IDX1]]], %[[PAD]], %[[MASK]] {{.*}} permutation_map = #[[MAP1]]} : memref<?x?x?x?xf32
326+
327+
// -----
328+
329+
func.func @fold_masked_vector_transfer_write_with_subview(
330+
%arg0 : memref<?x?xf32, strided<[?, ?], offset: ?>>,
331+
%arg1 : vector<4xf32>, %arg2: index, %arg3 : index, %arg4 : index,
332+
%arg5: index, %arg6 : index, %arg7 : index, %mask : vector<4xi1>) {
333+
%cst = arith.constant 0.0 : f32
334+
%0 = memref.subview %arg0[%arg2, %arg3] [%arg4, %arg5] [1, 1]
335+
: memref<?x?xf32, strided<[?, ?], offset: ?>> to
336+
memref<?x?xf32, strided<[?, ?], offset: ?>>
337+
vector.transfer_write %arg1, %0[%arg6, %arg7], %mask {in_bounds = [true]}
338+
: vector<4xf32>, memref<?x?xf32, strided<[?, ?], offset: ?>>
339+
return
340+
}
341+
// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
342+
// CHECK: func @fold_masked_vector_transfer_write_with_subview
343+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?xf32, strided<[?, ?], offset: ?>>
344+
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: vector<4xf32>
345+
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
346+
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
347+
// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index
348+
// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: index
349+
// CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: index
350+
// CHECK-SAME: %[[ARG7:[a-zA-Z0-9]+]]: index
351+
// CHECK-SAME: %[[MASK:[a-zA-Z0-9]+]]: vector<4xi1>
352+
// CHECK-DAG: %[[IDX0:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG6]]]
353+
// CHECK-DAG: %[[IDX1:.+]] = affine.apply #[[MAP1]]()[%[[ARG3]], %[[ARG7]]]
354+
// CHECK-DAG: vector.transfer_write %[[ARG1]], %[[ARG0]][%[[IDX0]], %[[IDX1]]], %[[MASK]] {in_bounds = [true]} : vector<4xf32>, memref<?x?xf32
355+
356+
// -----
357+
358+
func.func @fold_masked_vector_transfer_write_with_rank_reducing_subview(
359+
%arg0 : memref<?x?x?x?xf32, strided<[?, ?, ?, ?], offset: ?>>,
360+
%arg1 : vector<3x4xf32>, %arg2: index, %arg3 : index, %arg4 : index,
361+
%arg5: index, %arg6 : index, %arg7 : index, %mask : vector<4x3xi1>) {
362+
%cst = arith.constant 0.0 : f32
363+
%0 = memref.subview %arg0[0, %arg2, 0, %arg3] [1, %arg4, 1, %arg5] [1, 1, 1, 1]
364+
: memref<?x?x?x?xf32, strided<[?, ?, ?, ?], offset: ?>> to
365+
memref<?x?xf32, strided<[?, ?], offset: ?>>
366+
vector.transfer_write %arg1, %0[%arg6, %arg7], %mask {
367+
permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]}
368+
: vector<3x4xf32>, memref<?x?xf32, strided<[?, ?], offset: ?>>
369+
return
370+
}
371+
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
372+
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d1)>
373+
// CHECK: func @fold_masked_vector_transfer_write_with_rank_reducing_subview
374+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32, strided<[?, ?, ?, ?], offset: ?>>
375+
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: vector<3x4xf32>
376+
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
377+
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
378+
// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index
379+
// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: index
380+
// CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: index
381+
// CHECK-SAME: %[[ARG7:[a-zA-Z0-9]+]]: index
382+
// CHECK-SAME: %[[MASK:[a-zA-Z0-9]+]]: vector<4x3xi1>
383+
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
384+
// CHECK-DAG: %[[IDX0:.+]] = affine.apply #[[MAP0]]()[%[[ARG2]], %[[ARG6]]]
385+
// CHECK-DAG: %[[IDX1:.+]] = affine.apply #[[MAP0]]()[%[[ARG3]], %[[ARG7]]]
386+
// CHECK-DAG: vector.transfer_write %[[ARG1]], %[[ARG0]][%[[C0]], %[[IDX0]], %[[C0]], %[[IDX1]]], %[[ARG8]] {in_bounds = [true, true], permutation_map = #[[MAP1]]} : vector<3x4xf32>, memref<?x?x?x?xf32
387+
388+
// -----
389+
269390
// Test with affine.load/store ops. We only do a basic test here since the
270391
// logic is identical to that with memref.load/store ops. The same affine.apply
271392
// ops would be generated.

0 commit comments

Comments
 (0)