Skip to content

Commit 0cf7aaf

Browse files
authored
[MLIR][Vector] Update Transfer{Read|Write}DropUnitDimsPattern patterns (#112394)
Updates `TransferWriteDropUnitDimsPattern` and `TransferReadDropUnitDimsPattern` to inherit from `MaskableOpRewritePattern` so that masked versions of `vector.transfer_read`/`vector.transfer_write` ops are also supported:

```mlir
%v = vector.mask %mask {
  vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
    memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>, vector<3x2xi8>
} : vector<3x2xi1> -> vector<3x2xi8>
```
1 parent 4102625 commit 0cf7aaf

File tree

3 files changed

+139
-18
lines changed

3 files changed

+139
-18
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp

Lines changed: 50 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -354,11 +354,13 @@ namespace {
354354
/// inserting a memref.subview dropping those unit dims. The vector shapes are
355355
/// also reduced accordingly.
356356
class TransferReadDropUnitDimsPattern
357-
: public OpRewritePattern<vector::TransferReadOp> {
358-
using OpRewritePattern::OpRewritePattern;
357+
: public vector::MaskableOpRewritePattern<vector::TransferReadOp> {
358+
using MaskableOpRewritePattern::MaskableOpRewritePattern;
359359

360-
LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
361-
PatternRewriter &rewriter) const override {
360+
FailureOr<Value>
361+
matchAndRewriteMaskableOp(vector::TransferReadOp transferReadOp,
362+
vector::MaskingOpInterface maskingOp,
363+
PatternRewriter &rewriter) const override {
362364
auto loc = transferReadOp.getLoc();
363365
Value vector = transferReadOp.getVector();
364366
VectorType vectorType = cast<VectorType>(vector.getType());
@@ -376,6 +378,10 @@ class TransferReadDropUnitDimsPattern
376378
int reducedRank = getReducedRank(sourceType.getShape());
377379
if (reducedRank == sourceType.getRank())
378380
return failure();
381+
// TODO: Extend vector.mask to support 0-d vectors. In the meantime, bail
382+
// out.
383+
if (reducedRank == 0 && maskingOp)
384+
return failure();
379385
// Check if the reduced vector shape matches the reduced source shape.
380386
// Otherwise, this case is not supported yet.
381387
VectorType reducedVectorType = trimNonScalableUnitDims(vectorType);
@@ -406,27 +412,37 @@ class TransferReadDropUnitDimsPattern
406412
SmallVector<Value> zeros(reducedRank, c0);
407413
auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
408414
SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
409-
auto newTransferReadOp = rewriter.create<vector::TransferReadOp>(
415+
Operation *newTransferReadOp = rewriter.create<vector::TransferReadOp>(
410416
loc, reducedVectorType, reducedShapeSource, zeros, identityMap,
411417
transferReadOp.getPadding(), maskOp,
412418
rewriter.getBoolArrayAttr(inBounds));
419+
420+
if (maskingOp) {
421+
auto shapeCastMask = rewriter.createOrFold<vector::ShapeCastOp>(
422+
loc, reducedVectorType.cloneWith(std::nullopt, rewriter.getI1Type()),
423+
maskingOp.getMask());
424+
newTransferReadOp = mlir::vector::maskOperation(
425+
rewriter, newTransferReadOp, shapeCastMask);
426+
}
427+
413428
auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
414-
loc, vectorType, newTransferReadOp);
415-
rewriter.replaceOp(transferReadOp, shapeCast);
429+
loc, vectorType, newTransferReadOp->getResults()[0]);
416430

417-
return success();
431+
return shapeCast;
418432
}
419433
};
420434

421435
/// Rewrites `vector.transfer_write` ops where the "source" (i.e. destination)
422436
/// has unit dims, by inserting a `memref.subview` dropping those unit dims. The
423437
/// vector shapes are also reduced accordingly.
424438
class TransferWriteDropUnitDimsPattern
425-
: public OpRewritePattern<vector::TransferWriteOp> {
426-
using OpRewritePattern::OpRewritePattern;
439+
: public vector::MaskableOpRewritePattern<vector::TransferWriteOp> {
440+
using MaskableOpRewritePattern::MaskableOpRewritePattern;
427441

428-
LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
429-
PatternRewriter &rewriter) const override {
442+
FailureOr<Value>
443+
matchAndRewriteMaskableOp(vector::TransferWriteOp transferWriteOp,
444+
vector::MaskingOpInterface maskingOp,
445+
PatternRewriter &rewriter) const override {
430446
auto loc = transferWriteOp.getLoc();
431447
Value vector = transferWriteOp.getVector();
432448
VectorType vectorType = cast<VectorType>(vector.getType());
@@ -444,6 +460,10 @@ class TransferWriteDropUnitDimsPattern
444460
int reducedRank = getReducedRank(sourceType.getShape());
445461
if (reducedRank == sourceType.getRank())
446462
return failure();
463+
// TODO: Extend vector.mask to support 0-d vectors. In the meantime, bail
464+
// out.
465+
if (reducedRank == 0 && maskingOp)
466+
return failure();
447467
// Check if the reduced vector shape matches the reduced destination shape.
448468
// Otherwise, this case is not supported yet.
449469
VectorType reducedVectorType = trimNonScalableUnitDims(vectorType);
@@ -474,13 +494,26 @@ class TransferWriteDropUnitDimsPattern
474494
SmallVector<Value> zeros(reducedRank, c0);
475495
auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
476496
SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
477-
auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
497+
auto shapeCastSrc = rewriter.createOrFold<vector::ShapeCastOp>(
478498
loc, reducedVectorType, vector);
479-
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
480-
transferWriteOp, Type(), shapeCast, reducedShapeSource, zeros,
481-
identityMap, maskOp, rewriter.getBoolArrayAttr(inBounds));
499+
Operation *newXferWrite = rewriter.create<vector::TransferWriteOp>(
500+
loc, Type(), shapeCastSrc, reducedShapeSource, zeros, identityMap,
501+
maskOp, rewriter.getBoolArrayAttr(inBounds));
502+
503+
if (maskingOp) {
504+
auto shapeCastMask = rewriter.createOrFold<vector::ShapeCastOp>(
505+
loc, reducedVectorType.cloneWith(std::nullopt, rewriter.getI1Type()),
506+
maskingOp.getMask());
507+
newXferWrite =
508+
mlir::vector::maskOperation(rewriter, newXferWrite, shapeCastMask);
509+
}
482510

483-
return success();
511+
if (transferWriteOp.hasPureTensorSemantics())
512+
return newXferWrite->getResults()[0];
513+
514+
// With Memref semantics, there's no return value. Use empty value to signal
515+
// success.
516+
return Value();
484517
}
485518
};
486519

mlir/test/Dialect/Vector/invalid.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1717,6 +1717,15 @@ func.func @vector_mask_shape_mismatch(%a: vector<8xi32>, %m0: vector<16xi1>) ->
17171717

17181718
// -----
17191719

1720+
func.func @vector_mask_passthru_type_mismatch(%t0: tensor<f32>, %m0: vector<i1>) -> vector<f32> {
1721+
%ft0 = arith.constant 0.0 : f32
1722+
// expected-error@+1 {{'vector.mask' op operand #0 must be vector of 1-bit signless integer values, but got 'vector<i1>'}}
1723+
%0 = vector.mask %m0 { vector.transfer_read %t0[], %ft0 : tensor<f32>, vector<f32> } : vector<i1> -> vector<f32>
1724+
return %0 : vector<f32>
1725+
}
1726+
1727+
// -----
1728+
17201729
// expected-note@+1 {{prior use here}}
17211730
func.func @vector_mask_passthru_type_mismatch(%t0: tensor<?xf32>, %idx: index, %m0: vector<16xi1>, %pt0: vector<16xi32>) -> vector<16xf32> {
17221731
%ft0 = arith.constant 0.0 : f32

mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir

Lines changed: 80 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
// RUN: mlir-opt %s --transform-interpreter | FileCheck %s
22

3+
//-----------------------------------------------------------------------------
4+
// [Patterns: TransferWriteDropUnitDimsPattern, TransferReadDropUnitDimsPattern]
5+
//-----------------------------------------------------------------------------
6+
37
func.func @transfer_read_rank_reducing(
48
%arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>) -> vector<3x2xi8> {
59
%c0 = arith.constant 0 : index
@@ -14,7 +18,29 @@ func.func @transfer_read_rank_reducing(
1418
// CHECK-SAME: memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}>
1519
// CHECK: vector.transfer_read %[[SUBVIEW]]
1620

17-
func.func @transfer_write_rank_reducing(%arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>, %vec : vector<3x2xi8>) {
21+
func.func @transfer_read_rank_reducing_masked(
22+
%arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>,
23+
%mask: vector<3x2xi1>) -> vector<3x2xi8> {
24+
%c0 = arith.constant 0 : index
25+
%cst = arith.constant 0 : i8
26+
%v = vector.mask %mask {
27+
vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
28+
memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>, vector<3x2xi8>
29+
} : vector<3x2xi1> -> vector<3x2xi8>
30+
return %v : vector<3x2xi8>
31+
}
32+
// CHECK-LABEL: func @transfer_read_rank_reducing_masked
33+
// CHECK-SAME: %[[ARG:.+]]: memref<1x1x3x2xi8
34+
// CHECK-SAME: %[[MASK:.+]]: vector<3x2xi1>
35+
// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0] [1, 1, 3, 2] [1, 1, 1, 1]
36+
// CHECK-SAME: memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}>
37+
// CHECK: vector.mask %[[MASK]]
38+
// CHECK-SAME: vector.transfer_read %[[SUBVIEW]]
39+
40+
func.func @transfer_write_rank_reducing(
41+
%arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>,
42+
%vec : vector<3x2xi8>) {
43+
1844
%c0 = arith.constant 0 : index
1945
vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
2046
vector<3x2xi8>, memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>
@@ -26,6 +52,26 @@ func.func @transfer_write_rank_reducing(%arg : memref<1x1x3x2xi8, strided<[6, 6,
2652
// CHECK-SAME: memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}>
2753
// CHECK: vector.transfer_write %{{.*}}, %[[SUBVIEW]]
2854

55+
func.func @transfer_write_rank_reducing_masked(
56+
%arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>,
57+
%vec : vector<3x2xi8>,
58+
%mask: vector<3x2xi1>) {
59+
%c0 = arith.constant 0 : index
60+
vector.mask %mask {
61+
vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
62+
vector<3x2xi8>, memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>
63+
} : vector<3x2xi1>
64+
return
65+
}
66+
// CHECK-LABEL: func @transfer_write_rank_reducing_masked
67+
// CHECK-SAME: %[[ARG:.+]]: memref<1x1x3x2xi8
68+
// CHECK-SAME: %[[VEC:.+]]: vector<3x2xi8>
69+
// CHECK-SAME: %[[MASK:.+]]: vector<3x2xi1>
70+
// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0] [1, 1, 3, 2] [1, 1, 1, 1]
71+
// CHECK-SAME: memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}>
72+
// CHECK: vector.mask %[[MASK]]
73+
// CHECK-SAME: vector.transfer_write %{{.*}}, %[[SUBVIEW]]
74+
2975
func.func @transfer_read_and_vector_rank_reducing(
3076
%arg : memref<1x1x3x2x1xf32>) -> vector<3x2x1xf32> {
3177
%c0 = arith.constant 0 : index
@@ -68,6 +114,22 @@ func.func @transfer_read_and_vector_rank_reducing_to_0d(
68114
// CHECK: %[[READ:.+]] = vector.transfer_read %[[SUBVIEW]]{{.*}} : memref<f32>, vector<f32>
69115
// CHECK: vector.shape_cast %[[READ]] : vector<f32> to vector<1x1x1xf32>
70116

117+
func.func @transfer_read_and_vector_rank_reducing_to_0d_masked(
118+
%arg : memref<1x1x1x1x1xf32>,
119+
%mask: vector<1x1x1xi1>) -> vector<1x1x1xf32> {
120+
121+
%c0 = arith.constant 0 : index
122+
%cst = arith.constant 0.0 : f32
123+
%v = vector.mask %mask {
124+
vector.transfer_read %arg[%c0, %c0, %c0, %c0, %c0], %cst
125+
: memref<1x1x1x1x1xf32>, vector<1x1x1xf32>
126+
} : vector<1x1x1xi1> -> vector<1x1x1xf32>
127+
return %v : vector<1x1x1xf32>
128+
}
129+
// CHECK-LABEL: func @transfer_read_and_vector_rank_reducing_to_0d_masked
130+
// CHECK-NOT: vector.shape_cast
131+
// CHECK-NOT: memref.subview
132+
71133
func.func @transfer_write_and_vector_rank_reducing_to_0d(
72134
%arg : memref<1x1x1x1x1xf32>,
73135
%vec : vector<1x1x1xf32>) {
@@ -82,6 +144,23 @@ func.func @transfer_write_and_vector_rank_reducing_to_0d(
82144
// CHECK: %[[SHCAST:.+]] = vector.shape_cast %[[VECTOR]] : vector<1x1x1xf32> to vector<f32>
83145
// CHECK: vector.transfer_write %[[SHCAST]], %[[SUBVIEW]]{{.*}} : vector<f32>, memref<f32>
84146

147+
func.func @transfer_write_and_vector_rank_reducing_to_0d_masked(
148+
%arg : memref<1x1x1x1x1xf32>,
149+
%vec : vector<1x1x1xf32>,
150+
%mask: vector<1x1x1xi1>) {
151+
152+
%c0 = arith.constant 0 : index
153+
%cst = arith.constant 0.0 : f32
154+
vector.mask %mask {
155+
vector.transfer_write %vec, %arg[%c0, %c0, %c0, %c0, %c0] :
156+
vector<1x1x1xf32>, memref<1x1x1x1x1xf32>
157+
} : vector<1x1x1xi1>
158+
return
159+
}
160+
// CHECK-LABEL: func @transfer_write_and_vector_rank_reducing_to_0d_masked
161+
// CHECK-NOT: vector.shape_cast
162+
// CHECK-NOT: memref.subview
163+
85164
func.func @transfer_read_dynamic_rank_reducing(
86165
%arg : memref<?x1xi8, strided<[?, ?], offset: ?>>) -> vector<[16]x1xi8> {
87166
%c0 = arith.constant 0 : index

0 commit comments

Comments
 (0)