Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 79 additions & 11 deletions mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -518,10 +518,25 @@ LogicalResult LoadOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices)))
return failure();
llvm::TypeSwitch<Operation *, void>(loadOp)
.Case<affine::AffineLoadOp, memref::LoadOp>([&](auto op) {
rewriter.replaceOpWithNewOp<decltype(op)>(
.Case([&](affine::AffineLoadOp op) {
rewriter.replaceOpWithNewOp<affine::AffineLoadOp>(
loadOp, expandShapeOp.getViewSource(), sourceIndices);
})
.Case([&](memref::LoadOp op) {
rewriter.replaceOpWithNewOp<memref::LoadOp>(
loadOp, expandShapeOp.getViewSource(), sourceIndices,
op.getNontemporal());
})
.Case([&](vector::LoadOp op) {
rewriter.replaceOpWithNewOp<vector::LoadOp>(
op, op.getType(), expandShapeOp.getViewSource(), sourceIndices,
op.getNontemporal());
})
.Case([&](vector::MaskedLoadOp op) {
rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
op, op.getType(), expandShapeOp.getViewSource(), sourceIndices,
op.getMask(), op.getPassThru());
})
.Default([](Operation *) { llvm_unreachable("unexpected operation."); });
return success();
}
Expand Down Expand Up @@ -551,10 +566,25 @@ LogicalResult LoadOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
loadOp.getLoc(), rewriter, collapseShapeOp, indices, sourceIndices)))
return failure();
llvm::TypeSwitch<Operation *, void>(loadOp)
.Case<affine::AffineLoadOp, memref::LoadOp>([&](auto op) {
rewriter.replaceOpWithNewOp<decltype(op)>(
.Case([&](affine::AffineLoadOp op) {
rewriter.replaceOpWithNewOp<affine::AffineLoadOp>(
loadOp, collapseShapeOp.getViewSource(), sourceIndices);
})
.Case([&](memref::LoadOp op) {
rewriter.replaceOpWithNewOp<memref::LoadOp>(
loadOp, collapseShapeOp.getViewSource(), sourceIndices,
op.getNontemporal());
})
.Case([&](vector::LoadOp op) {
rewriter.replaceOpWithNewOp<vector::LoadOp>(
op, op.getType(), collapseShapeOp.getViewSource(), sourceIndices,
op.getNontemporal());
})
.Case([&](vector::MaskedLoadOp op) {
rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
op, op.getType(), collapseShapeOp.getViewSource(), sourceIndices,
op.getMask(), op.getPassThru());
})
.Default([](Operation *) { llvm_unreachable("unexpected operation."); });
return success();
}
Expand Down Expand Up @@ -651,10 +681,25 @@ LogicalResult StoreOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
storeOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices)))
return failure();
llvm::TypeSwitch<Operation *, void>(storeOp)
.Case<affine::AffineStoreOp, memref::StoreOp>([&](auto op) {
rewriter.replaceOpWithNewOp<decltype(op)>(storeOp, storeOp.getValue(),
expandShapeOp.getViewSource(),
sourceIndices);
.Case([&](affine::AffineStoreOp op) {
rewriter.replaceOpWithNewOp<affine::AffineStoreOp>(
storeOp, op.getValueToStore(), expandShapeOp.getViewSource(),
sourceIndices);
})
.Case([&](memref::StoreOp op) {
rewriter.replaceOpWithNewOp<memref::StoreOp>(
storeOp, op.getValueToStore(), expandShapeOp.getViewSource(),
sourceIndices, op.getNontemporal());
})
.Case([&](vector::StoreOp op) {
rewriter.replaceOpWithNewOp<vector::StoreOp>(
op, op.getValueToStore(), expandShapeOp.getViewSource(),
sourceIndices, op.getNontemporal());
})
.Case([&](vector::MaskedStoreOp op) {
rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
op, expandShapeOp.getViewSource(), sourceIndices, op.getMask(),
op.getValueToStore());
})
.Default([](Operation *) { llvm_unreachable("unexpected operation."); });
return success();
Expand Down Expand Up @@ -685,11 +730,26 @@ LogicalResult StoreOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
storeOp.getLoc(), rewriter, collapseShapeOp, indices, sourceIndices)))
return failure();
llvm::TypeSwitch<Operation *, void>(storeOp)
.Case<affine::AffineStoreOp, memref::StoreOp>([&](auto op) {
rewriter.replaceOpWithNewOp<decltype(op)>(
storeOp, storeOp.getValue(), collapseShapeOp.getViewSource(),
.Case([&](affine::AffineStoreOp op) {
rewriter.replaceOpWithNewOp<affine::AffineStoreOp>(
storeOp, op.getValueToStore(), collapseShapeOp.getViewSource(),
sourceIndices);
})
.Case([&](memref::StoreOp op) {
rewriter.replaceOpWithNewOp<memref::StoreOp>(
storeOp, op.getValueToStore(), collapseShapeOp.getViewSource(),
sourceIndices, op.getNontemporal());
})
.Case([&](vector::StoreOp op) {
rewriter.replaceOpWithNewOp<vector::StoreOp>(
op, op.getValueToStore(), collapseShapeOp.getViewSource(),
sourceIndices, op.getNontemporal());
})
.Case([&](vector::MaskedStoreOp op) {
rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
op, collapseShapeOp.getViewSource(), sourceIndices, op.getMask(),
op.getValueToStore());
})
.Default([](Operation *) { llvm_unreachable("unexpected operation."); });
return success();
}
Expand Down Expand Up @@ -763,12 +823,20 @@ void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) {
StoreOpOfSubViewOpFolder<gpu::SubgroupMmaStoreMatrixOp>,
LoadOpOfExpandShapeOpFolder<affine::AffineLoadOp>,
LoadOpOfExpandShapeOpFolder<memref::LoadOp>,
LoadOpOfExpandShapeOpFolder<vector::LoadOp>,
LoadOpOfExpandShapeOpFolder<vector::MaskedLoadOp>,
StoreOpOfExpandShapeOpFolder<affine::AffineStoreOp>,
StoreOpOfExpandShapeOpFolder<memref::StoreOp>,
StoreOpOfExpandShapeOpFolder<vector::StoreOp>,
StoreOpOfExpandShapeOpFolder<vector::MaskedStoreOp>,
LoadOpOfCollapseShapeOpFolder<affine::AffineLoadOp>,
LoadOpOfCollapseShapeOpFolder<memref::LoadOp>,
LoadOpOfCollapseShapeOpFolder<vector::LoadOp>,
LoadOpOfCollapseShapeOpFolder<vector::MaskedLoadOp>,
StoreOpOfCollapseShapeOpFolder<affine::AffineStoreOp>,
StoreOpOfCollapseShapeOpFolder<memref::StoreOp>,
StoreOpOfCollapseShapeOpFolder<vector::StoreOp>,
StoreOpOfCollapseShapeOpFolder<vector::MaskedStoreOp>,
SubViewOfSubViewFolder, NVGPUAsyncCopyOpSubViewOpFolder>(
patterns.getContext());
}
Expand Down
172 changes: 160 additions & 12 deletions mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -473,10 +473,10 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_3d(%ar
func.func @fold_dynamic_subview_with_memref_load_expand_shape(%arg0 : memref<16x?xf32, strided<[16, 1]>>, %arg1 : index, %arg2 : index, %sz0: index) -> f32 {
%c0 = arith.constant 0 : index
%expand_shape = memref.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [1, 16, %sz0, 1] : memref<16x?xf32, strided<[16, 1]>> into memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
%0 = memref.load %expand_shape[%c0, %arg1, %arg2, %c0] : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
%0 = memref.load %expand_shape[%c0, %arg1, %arg2, %c0] {nontemporal = true} : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
return %0 : f32
}
// CHECK-NEXT: %[[VAL1:.*]] = memref.load %[[ARG0]][%[[ARG1]], %[[ARG2]]] : memref<16x?xf32, strided<[16, 1]>>
// CHECK-NEXT: %[[VAL1:.*]] = memref.load %[[ARG0]][%[[ARG1]], %[[ARG2]]] {nontemporal = true} : memref<16x?xf32, strided<[16, 1]>>
// CHECK-NEXT: return %[[VAL1]] : f32

// -----
Expand All @@ -487,11 +487,11 @@ func.func @fold_dynamic_subview_with_memref_store_expand_shape(%arg0 : memref<16
%c0 = arith.constant 0 : index
%c1f32 = arith.constant 1.0 : f32
%expand_shape = memref.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [1, 16, %sz0, 1] : memref<16x?xf32, strided<[16, 1]>> into memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
memref.store %c1f32, %expand_shape[%c0, %arg1, %arg2, %c0] : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
memref.store %c1f32, %expand_shape[%c0, %arg1, %arg2, %c0] {nontemporal = true} : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
return
}
// CHECK: %[[C1F32:.*]] = arith.constant 1.000000e+00 : f32
// CHECK-NEXT: memref.store %[[C1F32]], %[[ARG0]][%[[ARG1]], %[[ARG2]]] : memref<16x?xf32, strided<[16, 1]>>
// CHECK-NEXT: memref.store %[[C1F32]], %[[ARG0]][%[[ARG1]], %[[ARG2]]] {nontemporal = true} : memref<16x?xf32, strided<[16, 1]>>
// CHECK-NEXT: return

// -----
Expand Down Expand Up @@ -819,29 +819,29 @@ func.func @test_ldmatrix(%arg0: memref<4x32x32xf16, 3>, %arg1: index, %arg2: ind

// -----

func.func @fold_vector_load(
func.func @fold_vector_load_subview(
%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index) -> vector<12x32xf32> {
%0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
%1 = vector.load %0[] : memref<f32, strided<[], offset: ?>>, vector<12x32xf32>
return %1 : vector<12x32xf32>
}

// CHECK: func @fold_vector_load
// CHECK: func @fold_vector_load_subview
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
// CHECK: vector.load %[[ARG0]][%[[ARG1]], %[[ARG2]]] : memref<12x32xf32>, vector<12x32xf32>

// -----

func.func @fold_vector_maskedload(
func.func @fold_vector_maskedload_subview(
%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3: vector<32xi1>, %arg4: vector<32xf32>) -> vector<32xf32> {
%0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
%1 = vector.maskedload %0[], %arg3, %arg4 : memref<f32, strided<[], offset: ?>>, vector<32xi1>, vector<32xf32> into vector<32xf32>
return %1 : vector<32xf32>
}

// CHECK: func @fold_vector_maskedload
// CHECK: func @fold_vector_maskedload_subview
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
Expand All @@ -851,14 +851,14 @@ func.func @fold_vector_maskedload(

// -----

func.func @fold_vector_store(
func.func @fold_vector_store_subview(
%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3: vector<2x32xf32>) -> () {
%0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
vector.store %arg3, %0[] : memref<f32, strided<[], offset: ?>>, vector<2x32xf32>
return
}

// CHECK: func @fold_vector_store
// CHECK: func @fold_vector_store_subview
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
Expand All @@ -868,18 +868,166 @@ func.func @fold_vector_store(

// -----

func.func @fold_vector_maskedstore(
func.func @fold_vector_maskedstore_subview(
%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3: vector<32xi1>, %arg4: vector<32xf32>) -> () {
%0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
vector.maskedstore %0[], %arg3, %arg4 : memref<f32, strided<[], offset: ?>>, vector<32xi1>, vector<32xf32>
return
}

// CHECK: func @fold_vector_maskedstore
// CHECK: func @fold_vector_maskedstore_subview
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: vector<32xi1>
// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: vector<32xf32>
// CHECK: vector.maskedstore %[[ARG0]][%[[ARG1]], %[[ARG2]]], %[[ARG3]], %[[ARG4]] : memref<12x32xf32>, vector<32xi1>, vector<32xf32>
// CHECK: return

// -----

func.func @fold_vector_load_expand_shape(
%arg0 : memref<32xf32>, %arg1 : index) -> vector<8xf32> {
%c0 = arith.constant 0 : index
%0 = memref.expand_shape %arg0 [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32>
%1 = vector.load %0[%arg1, %c0] {nontemporal = true} : memref<4x8xf32>, vector<8xf32>
return %1 : vector<8xf32>
}

// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 * 8)>
// CHECK-LABEL: func @fold_vector_load_expand_shape
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<32xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
// CHECK: %[[IDX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]]]
// CHECK: vector.load %[[ARG0]][%[[IDX]]] {nontemporal = true}

// -----

func.func @fold_vector_maskedload_expand_shape(
%arg0 : memref<32xf32>, %arg1 : index, %arg3: vector<8xi1>, %arg4: vector<8xf32>) -> vector<8xf32> {
%c0 = arith.constant 0 : index
%0 = memref.expand_shape %arg0 [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32>
%1 = vector.maskedload %0[%arg1, %c0], %arg3, %arg4 : memref<4x8xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
return %1 : vector<8xf32>
}

// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 * 8)>
// CHECK-LABEL: func @fold_vector_maskedload_expand_shape
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<32xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: vector<8xi1>
// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: vector<8xf32>
// CHECK: %[[IDX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]]]
// CHECK: vector.maskedload %[[ARG0]][%[[IDX]]], %[[ARG3]], %[[ARG4]]

// -----

func.func @fold_vector_store_expand_shape(
%arg0 : memref<32xf32>, %arg1 : index, %val : vector<8xf32>) {
%c0 = arith.constant 0 : index
%0 = memref.expand_shape %arg0 [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32>
vector.store %val, %0[%arg1, %c0] {nontemporal = true} : memref<4x8xf32>, vector<8xf32>
return
}

// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 * 8)>
// CHECK-LABEL: func @fold_vector_store_expand_shape
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<32xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
// CHECK: %[[IDX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]]]
// CHECK: vector.store %{{.*}}, %[[ARG0]][%[[IDX]]] {nontemporal = true}

// -----

func.func @fold_vector_maskedstore_expand_shape(
%arg0 : memref<32xf32>, %arg1 : index, %arg3: vector<8xi1>, %arg4: vector<8xf32>) {
%c0 = arith.constant 0 : index
%0 = memref.expand_shape %arg0 [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32>
vector.maskedstore %0[%arg1, %c0], %arg3, %arg4 : memref<4x8xf32>, vector<8xi1>, vector<8xf32>
return
}

// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 * 8)>
// CHECK-LABEL: func @fold_vector_maskedstore_expand_shape
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<32xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: vector<8xi1>
// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: vector<8xf32>
// CHECK: %[[IDX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]]]
// CHECK: vector.maskedstore %[[ARG0]][%[[IDX]]], %[[ARG3]], %[[ARG4]]

// -----

func.func @fold_vector_load_collapse_shape(
%arg0 : memref<4x8xf32>, %arg1 : index) -> vector<8xf32> {
%0 = memref.collapse_shape %arg0 [[0, 1]] : memref<4x8xf32> into memref<32xf32>
%1 = vector.load %0[%arg1] {nontemporal = true} : memref<32xf32>, vector<8xf32>
return %1 : vector<8xf32>
}

// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 floordiv 8)>
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 mod 8)>
// CHECK-LABEL: func @fold_vector_load_collapse_shape
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<4x8xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
// CHECK: %[[IDX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]]]
// CHECK: %[[IDX1:.*]] = affine.apply #[[$MAP1]]()[%[[ARG1]]]
// CHECK: vector.load %[[ARG0]][%[[IDX]], %[[IDX1]]] {nontemporal = true}

// -----

func.func @fold_vector_maskedload_collapse_shape(
%arg0 : memref<4x8xf32>, %arg1 : index, %arg3: vector<8xi1>, %arg4: vector<8xf32>) -> vector<8xf32> {
%0 = memref.collapse_shape %arg0 [[0, 1]] : memref<4x8xf32> into memref<32xf32>
%1 = vector.maskedload %0[%arg1], %arg3, %arg4 : memref<32xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
return %1 : vector<8xf32>
}

// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 floordiv 8)>
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 mod 8)>
// CHECK-LABEL: func @fold_vector_maskedload_collapse_shape
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<4x8xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: vector<8xi1>
// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: vector<8xf32>
// CHECK: %[[IDX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]]]
// CHECK: %[[IDX1:.*]] = affine.apply #[[$MAP1]]()[%[[ARG1]]]
// CHECK: vector.maskedload %[[ARG0]][%[[IDX]], %[[IDX1]]], %[[ARG3]], %[[ARG4]]

// -----

func.func @fold_vector_store_collapse_shape(
%arg0 : memref<4x8xf32>, %arg1 : index, %val : vector<8xf32>) {
%0 = memref.collapse_shape %arg0 [[0, 1]] : memref<4x8xf32> into memref<32xf32>
vector.store %val, %0[%arg1] {nontemporal = true} : memref<32xf32>, vector<8xf32>
return
}

// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 floordiv 8)>
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 mod 8)>
// CHECK-LABEL: func @fold_vector_store_collapse_shape
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<4x8xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
// CHECK: %[[IDX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]]]
// CHECK: %[[IDX1:.*]] = affine.apply #[[$MAP1]]()[%[[ARG1]]]
// CHECK: vector.store %{{.*}}, %[[ARG0]][%[[IDX]], %[[IDX1]]] {nontemporal = true}

// -----

func.func @fold_vector_maskedstore_collapse_shape(
%arg0 : memref<4x8xf32>, %arg1 : index, %arg3: vector<8xi1>, %arg4: vector<8xf32>) {
%0 = memref.collapse_shape %arg0 [[0, 1]] : memref<4x8xf32> into memref<32xf32>
vector.maskedstore %0[%arg1], %arg3, %arg4 : memref<32xf32>, vector<8xi1>, vector<8xf32>
return
}

// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 floordiv 8)>
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 mod 8)>
// CHECK-LABEL: func @fold_vector_maskedstore_collapse_shape
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<4x8xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: vector<8xi1>
// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: vector<8xf32>
// CHECK: %[[IDX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]]]
// CHECK: %[[IDX1:.*]] = affine.apply #[[$MAP1]]()[%[[ARG1]]]
// CHECK: vector.maskedstore %[[ARG0]][%[[IDX]], %[[IDX1]]], %[[ARG3]], %[[ARG4]]