diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp index e05c801121ffc..3a30382114c8d 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp @@ -354,11 +354,13 @@ namespace { /// inserting a memref.subview dropping those unit dims. The vector shapes are /// also reduced accordingly. class TransferReadDropUnitDimsPattern - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + : public vector::MaskableOpRewritePattern { + using MaskableOpRewritePattern::MaskableOpRewritePattern; - LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp, - PatternRewriter &rewriter) const override { + FailureOr + matchAndRewriteMaskableOp(vector::TransferReadOp transferReadOp, + vector::MaskingOpInterface maskingOp, + PatternRewriter &rewriter) const override { auto loc = transferReadOp.getLoc(); Value vector = transferReadOp.getVector(); VectorType vectorType = cast(vector.getType()); @@ -376,6 +378,10 @@ class TransferReadDropUnitDimsPattern int reducedRank = getReducedRank(sourceType.getShape()); if (reducedRank == sourceType.getRank()) return failure(); + // TODO: Extend vector.mask to support 0-d vectors. In the meantime, bail + // out. + if (reducedRank == 0 && maskingOp) + return failure(); // Check if the reduced vector shape matches the reduced source shape. // Otherwise, this case is not supported yet. 
VectorType reducedVectorType = trimNonScalableUnitDims(vectorType); @@ -406,15 +412,23 @@ class TransferReadDropUnitDimsPattern SmallVector zeros(reducedRank, c0); auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank); SmallVector inBounds(reducedVectorType.getRank(), true); - auto newTransferReadOp = rewriter.create( + Operation *newTransferReadOp = rewriter.create( loc, reducedVectorType, reducedShapeSource, zeros, identityMap, transferReadOp.getPadding(), maskOp, rewriter.getBoolArrayAttr(inBounds)); + + if (maskingOp) { + auto shapeCastMask = rewriter.createOrFold( + loc, reducedVectorType.cloneWith(std::nullopt, rewriter.getI1Type()), + maskingOp.getMask()); + newTransferReadOp = mlir::vector::maskOperation( + rewriter, newTransferReadOp, shapeCastMask); + } + auto shapeCast = rewriter.createOrFold( - loc, vectorType, newTransferReadOp); - rewriter.replaceOp(transferReadOp, shapeCast); + loc, vectorType, newTransferReadOp->getResults()[0]); - return success(); + return shapeCast; } }; @@ -422,11 +436,13 @@ class TransferReadDropUnitDimsPattern /// has unit dims, by inserting a `memref.subview` dropping those unit dims. The /// vector shapes are also reduced accordingly. 
class TransferWriteDropUnitDimsPattern - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + : public vector::MaskableOpRewritePattern { + using MaskableOpRewritePattern::MaskableOpRewritePattern; - LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp, - PatternRewriter &rewriter) const override { + FailureOr + matchAndRewriteMaskableOp(vector::TransferWriteOp transferWriteOp, + vector::MaskingOpInterface maskingOp, + PatternRewriter &rewriter) const override { auto loc = transferWriteOp.getLoc(); Value vector = transferWriteOp.getVector(); VectorType vectorType = cast(vector.getType()); @@ -444,6 +460,10 @@ class TransferWriteDropUnitDimsPattern int reducedRank = getReducedRank(sourceType.getShape()); if (reducedRank == sourceType.getRank()) return failure(); + // TODO: Extend vector.mask to support 0-d vectors. In the meantime, bail + // out. + if (reducedRank == 0 && maskingOp) + return failure(); // Check if the reduced vector shape matches the reduced destination shape. // Otherwise, this case is not supported yet. 
VectorType reducedVectorType = trimNonScalableUnitDims(vectorType); @@ -474,13 +494,26 @@ class TransferWriteDropUnitDimsPattern SmallVector zeros(reducedRank, c0); auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank); SmallVector inBounds(reducedVectorType.getRank(), true); - auto shapeCast = rewriter.createOrFold( + auto shapeCastSrc = rewriter.createOrFold( loc, reducedVectorType, vector); - rewriter.replaceOpWithNewOp( - transferWriteOp, Type(), shapeCast, reducedShapeSource, zeros, - identityMap, maskOp, rewriter.getBoolArrayAttr(inBounds)); + Operation *newXferWrite = rewriter.create( + loc, Type(), shapeCastSrc, reducedShapeSource, zeros, identityMap, + maskOp, rewriter.getBoolArrayAttr(inBounds)); + + if (maskingOp) { + auto shapeCastMask = rewriter.createOrFold( + loc, reducedVectorType.cloneWith(std::nullopt, rewriter.getI1Type()), + maskingOp.getMask()); + newXferWrite = + mlir::vector::maskOperation(rewriter, newXferWrite, shapeCastMask); + } - return success(); + if (transferWriteOp.hasPureTensorSemantics()) + return newXferWrite->getResults()[0]; + + // With Memref semantics, there's no return value. Use empty value to signal + // success. 
+ return Value(); } }; diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir index 36d04bb77e3b9..bebe47ba2db9a 100644 --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -1717,6 +1717,15 @@ func.func @vector_mask_shape_mismatch(%a: vector<8xi32>, %m0: vector<16xi1>) -> // ----- +func.func @vector_mask_passthru_type_mismatch(%t0: tensor, %m0: vector) -> vector { + %ft0 = arith.constant 0.0 : f32 + // expected-error@+1 {{'vector.mask' op operand #0 must be vector of 1-bit signless integer values, but got 'vector'}} + %0 = vector.mask %m0 { vector.transfer_read %t0[], %ft0 : tensor, vector } : vector -> vector + return %0 : vector +} + +// ----- + // expected-note@+1 {{prior use here}} func.func @vector_mask_passthru_type_mismatch(%t0: tensor, %idx: index, %m0: vector<16xi1>, %pt0: vector<16xi32>) -> vector<16xf32> { %ft0 = arith.constant 0.0 : f32 diff --git a/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir b/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir index e9d12b044e2c7..8234351302f6b 100644 --- a/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir @@ -1,5 +1,9 @@ // RUN: mlir-opt %s --transform-interpreter | FileCheck %s +//----------------------------------------------------------------------------- +// [Patterns: TransferWriteDropUnitDimsPattern, TransferReadDropUnitDimsPattern] +//----------------------------------------------------------------------------- + func.func @transfer_read_rank_reducing( %arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>) -> vector<3x2xi8> { %c0 = arith.constant 0 : index @@ -14,7 +18,29 @@ func.func @transfer_read_rank_reducing( // CHECK-SAME: memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}> // CHECK: vector.transfer_read %[[SUBVIEW]] -func.func @transfer_write_rank_reducing(%arg : memref<1x1x3x2xi8,
strided<[6, 6, 2, 1], offset: ?>>, %vec : vector<3x2xi8>) { +func.func @transfer_read_rank_reducing_masked( + %arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>, + %mask: vector<3x2xi1>) -> vector<3x2xi8> { + %c0 = arith.constant 0 : index + %cst = arith.constant 0 : i8 + %v = vector.mask %mask { + vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst : + memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>, vector<3x2xi8> + } : vector<3x2xi1> -> vector<3x2xi8> + return %v : vector<3x2xi8> +} +// CHECK-LABEL: func @transfer_read_rank_reducing_masked +// CHECK-SAME: %[[ARG:.+]]: memref<1x1x3x2xi8 +// CHECK-SAME: %[[MASK:.+]]: vector<3x2xi1> +// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0] [1, 1, 3, 2] [1, 1, 1, 1] +// CHECK-SAME: memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}> +// CHECK: vector.mask %[[MASK]] +// CHECK-SAME: vector.transfer_read %[[SUBVIEW]] + +func.func @transfer_write_rank_reducing( + %arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>, + %vec : vector<3x2xi8>) { + %c0 = arith.constant 0 : index vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] : vector<3x2xi8>, memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>> @@ -26,6 +52,26 @@ func.func @transfer_write_rank_reducing(%arg : memref<1x1x3x2xi8, strided<[6, 6, // CHECK-SAME: memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}> // CHECK: vector.transfer_write %{{.*}}, %[[SUBVIEW]] +func.func @transfer_write_rank_reducing_masked( + %arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>, + %vec : vector<3x2xi8>, + %mask: vector<3x2xi1>) { + %c0 = arith.constant 0 : index + vector.mask %mask { + vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] : + vector<3x2xi8>, memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>> + } : vector<3x2xi1> + return +} +// CHECK-LABEL: func @transfer_write_rank_reducing_masked +// CHECK-SAME: %[[ARG:.+]]: memref<1x1x3x2xi8 +// CHECK-SAME: %[[VEC:.+]]: vector<3x2xi8> +// CHECK-SAME: %[[MASK:.+]]: vector<3x2xi1> +// CHECK: 
%[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0] [1, 1, 3, 2] [1, 1, 1, 1] +// CHECK-SAME: memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}> +// CHECK: vector.mask %[[MASK]] +// CHECK-SAME: vector.transfer_write %{{.*}}, %[[SUBVIEW]] + func.func @transfer_read_and_vector_rank_reducing( %arg : memref<1x1x3x2x1xf32>) -> vector<3x2x1xf32> { %c0 = arith.constant 0 : index @@ -68,6 +114,22 @@ func.func @transfer_read_and_vector_rank_reducing_to_0d( // CHECK: %[[READ:.+]] = vector.transfer_read %[[SUBVIEW]]{{.*}} : memref, vector // CHECK: vector.shape_cast %[[READ]] : vector to vector<1x1x1xf32> +func.func @transfer_read_and_vector_rank_reducing_to_0d_masked( + %arg : memref<1x1x1x1x1xf32>, + %mask: vector<1x1x1xi1>) -> vector<1x1x1xf32> { + + %c0 = arith.constant 0 : index + %cst = arith.constant 0.0 : f32 + %v = vector.mask %mask { + vector.transfer_read %arg[%c0, %c0, %c0, %c0, %c0], %cst + : memref<1x1x1x1x1xf32>, vector<1x1x1xf32> + } : vector<1x1x1xi1> -> vector<1x1x1xf32> + return %v : vector<1x1x1xf32> +} +// CHECK-LABEL: func @transfer_read_and_vector_rank_reducing_to_0d_masked +// CHECK-NOT: vector.shape_cast +// CHECK-NOT: memref.subview + func.func @transfer_write_and_vector_rank_reducing_to_0d( %arg : memref<1x1x1x1x1xf32>, %vec : vector<1x1x1xf32>) { @@ -82,6 +144,23 @@ func.func @transfer_write_and_vector_rank_reducing_to_0d( // CHECK: %[[SHCAST:.+]] = vector.shape_cast %[[VECTOR]] : vector<1x1x1xf32> to vector // CHECK: vector.transfer_write %[[SHCAST]], %[[SUBVIEW]]{{.*}} : vector, memref +func.func @transfer_write_and_vector_rank_reducing_to_0d_masked( + %arg : memref<1x1x1x1x1xf32>, + %vec : vector<1x1x1xf32>, + %mask: vector<1x1x1xi1>) { + + %c0 = arith.constant 0 : index + %cst = arith.constant 0.0 : f32 + vector.mask %mask { + vector.transfer_write %vec, %arg[%c0, %c0, %c0, %c0, %c0] : + vector<1x1x1xf32>, memref<1x1x1x1x1xf32> + } : vector<1x1x1xi1> + return +} +// CHECK-LABEL: func @transfer_write_and_vector_rank_reducing_to_0d_masked +// 
CHECK-NOT: vector.shape_cast +// CHECK-NOT: memref.subview + func.func @transfer_read_dynamic_rank_reducing( %arg : memref>) -> vector<[16]x1xi8> { %c0 = arith.constant 0 : index