From c1f4264a71d6d80350056d0d9ca86a0ac2c1e04f Mon Sep 17 00:00:00 2001 From: James Newling Date: Tue, 15 Apr 2025 13:09:43 -0700 Subject: [PATCH 1/6] fix edge case where n=k (rank-preserving shape_cast) --- .../mlir/Dialect/Vector/IR/VectorOps.td | 17 ++--- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 63 ++++++++++--------- mlir/test/Dialect/Vector/invalid.mlir | 15 +++-- mlir/test/Dialect/Vector/ops.mlir | 8 +++ 4 files changed, 61 insertions(+), 42 deletions(-) diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index 7fc56b1aa4e7e..a9e25f23ef90f 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -2244,18 +2244,19 @@ def Vector_ShapeCastOp : Results<(outs AnyVectorOfAnyRank:$result)> { let summary = "shape_cast casts between vector shapes"; let description = [{ - The shape_cast operation casts between an n-D source vector shape and - a k-D result vector shape (the element type remains the same). + The shape_cast operation casts from an n-D source vector to a k-D result + vector. The element type remains the same, as does the number of elements + (product of dimensions). + + If reducing or preserving rank (n >= k), all result dimension sizes must be + products of contiguous source dimension sizes. If expanding rank (n < k), + source dimensions must all factor into contiguous sequences of destination + dimension sizes. - If reducing rank (n > k), result dimension sizes must be a product - of contiguous source dimension sizes. - If expanding rank (n < k), source dimensions must factor into a - contiguous sequence of destination dimension sizes. Each source dim is expanded (or contiguous sequence of source dims combined) in source dimension list order (i.e. 0 <= i < n), to produce a contiguous sequence of result dims (or a single result dim), in result dimension list - order (i.e. 0 <= j < k). The product of all source dimension sizes and all - result dimension sizes must match. + order (i.e. 0 <= j < k). It is currently assumed that this operation does not require moving data, and that it will be folded away before lowering vector operations. diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index bee5c1fd6ed58..554dbba081898 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -5534,10 +5534,10 @@ void ShapeCastOp::inferResultRanges(ArrayRef argRanges, /// Returns true if each element of 'a' is equal to the product of a contiguous /// sequence of the elements of 'b'. Returns false otherwise. -static bool isValidShapeCast(ArrayRef a, ArrayRef b) { +static bool isValidExpandingShapeCast(ArrayRef a, ArrayRef b) { unsigned rankA = a.size(); unsigned rankB = b.size(); - assert(rankA < rankB); + assert(rankA <= rankB); auto isOne = [](int64_t v) { return v == 1; }; @@ -5573,34 +5573,36 @@ static LogicalResult verifyVectorShapeCast(Operation *op, VectorType resultVectorType) { // Check that element type is the same. 
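// (Annotation for review, not part of the original patch.) Examples of what
// this verifier accepts under the documentation above:
//   vector<2x3x4xf32> -> vector<6x4xf32>   : valid collapse (6 = 2 * 3, contiguous dims)
//   vector<6x4xf32>   -> vector<2x3x4xf32> : valid expansion
//   vector<3x2xf32>   -> vector<2x3xf32>   : rejected, since 2 is not the product
//                                            of a contiguous prefix of 3x2 (this is
//                                            the new rank-preserving test in invalid.mlir)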
if (sourceVectorType.getElementType() != resultVectorType.getElementType()) - return op->emitOpError("source/result vectors must have same element type"); - auto sourceShape = sourceVectorType.getShape(); - auto resultShape = resultVectorType.getShape(); + return op->emitOpError("has different source and result element types"); + ArrayRef lowRankShape = sourceVectorType.getShape(); + ArrayRef highRankShape = resultVectorType.getShape(); + if (lowRankShape.size() > highRankShape.size()) + std::swap(lowRankShape, highRankShape); // Check that product of source dim sizes matches product of result dim sizes. - int64_t sourceDimProduct = std::accumulate( - sourceShape.begin(), sourceShape.end(), 1LL, std::multiplies{}); - int64_t resultDimProduct = std::accumulate( - resultShape.begin(), resultShape.end(), 1LL, std::multiplies{}); - if (sourceDimProduct != resultDimProduct) - return op->emitOpError("source/result number of elements must match"); - - // Check that expanding/contracting rank cases. - unsigned sourceRank = sourceVectorType.getRank(); - unsigned resultRank = resultVectorType.getRank(); - if (sourceRank < resultRank) { - if (!isValidShapeCast(sourceShape, resultShape)) - return op->emitOpError("invalid shape cast"); - } else if (sourceRank > resultRank) { - if (!isValidShapeCast(resultShape, sourceShape)) - return op->emitOpError("invalid shape cast"); + int64_t nLowRankElms = + std::accumulate(lowRankShape.begin(), lowRankShape.end(), 1LL, + std::multiplies{}); + int64_t nHighRankElms = + std::accumulate(highRankShape.begin(), highRankShape.end(), 1LL, + std::multiplies{}); + + if (nLowRankElms != nHighRankElms) { + return op->emitOpError( + "has a different number of source and result elements"); + } + + if (!isValidExpandingShapeCast(lowRankShape, highRankShape)) { + return op->emitOpError( + "is invalid (does not uniformly collapse or expand)"); } // Check that (non-)scalability is preserved int64_t sourceNScalableDims = sourceVectorType.getNumScalableDims(); int64_t resultNScalableDims = resultVectorType.getNumScalableDims(); if (sourceNScalableDims != resultNScalableDims) - return op->emitOpError("different number of scalable dims at source (") + return op->emitOpError( + "has a different number of scalable dims at source (") << sourceNScalableDims << ") and result (" << resultNScalableDims << ")"; sourceVectorType.getNumDynamicDims(); @@ -5634,17 +5636,18 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) { // Only allows valid transitive folding (expand/collapse dimensions). 
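// (Annotation for review, not part of the original patch.) Two casts in a row
// fold into one only when the composed cast is itself a pure expansion or a
// pure collapse. For example
//   vector<16xf16> -> vector<2x8xf16> -> vector<2x2x4xf16>
// folds to the single cast vector<16xf16> -> vector<2x2x4xf16> (see the
// updated consecutive_shape_cast test later in this series), while a chain
// that expands and then collapses is left as two ops.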
VectorType srcType = otherOp.getSource().getType(); + if (resultType == srcType) return otherOp.getSource(); - if (srcType.getRank() < resultType.getRank()) { - if (!isValidShapeCast(srcType.getShape(), resultType.getShape())) - return {}; - } else if (srcType.getRank() > resultType.getRank()) { - if (!isValidShapeCast(resultType.getShape(), srcType.getShape())) - return {}; - } else { + + ArrayRef lowRankShape = srcType.getShape(); + ArrayRef highRankShape = resultType.getShape(); + if (lowRankShape.size() > highRankShape.size()) + std::swap(lowRankShape, highRankShape); + + if (!isValidExpandingShapeCast(lowRankShape, highRankShape)) return {}; - } + setOperand(otherOp.getSource()); return getResult(); } diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir index dbf829e014b8d..9f94fb0574504 100644 --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -1132,28 +1132,35 @@ func.func @cannot_print_string_with_source_set(%vec: vector<[4]xf32>) { // ----- func.func @shape_cast_wrong_element_type(%arg0 : vector<5x1x3x2xf32>) { - // expected-error@+1 {{op source/result vectors must have same element type}} + // expected-error@+1 {{op has different source and result element types}} %0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<15x2xi32> } // ----- func.func @shape_cast_wrong_num_elements(%arg0 : vector<5x1x3x2xf32>) { - // expected-error@+1 {{op source/result number of elements must match}} + // expected-error@+1 {{op has a different number of source and result elements}} %0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<10x2xf32> } // ----- +func.func @shape_cast_invalid_rank_preservating(%arg0 : vector<3x2xf32>) { + // expected-error@+1 {{op is invalid (does not uniformly collapse or expand)}} + %0 = vector.shape_cast %arg0 : vector<3x2xf32> to vector<2x3xf32> +} + +// ----- + func.func @shape_cast_invalid_rank_reduction(%arg0 : vector<5x1x3x2xf32>) { - // expected-error@+1 {{invalid shape cast}} + // expected-error@+1 {{op is invalid (does not uniformly collapse or expand)}} %0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<2x15xf32> } // ----- func.func @shape_cast_invalid_rank_expansion(%arg0 : vector<15x2xf32>) { - // expected-error@+1 {{invalid shape cast}} + // expected-error@+1 {{op is invalid (does not uniformly collapse or expand)}} %0 = vector.shape_cast %arg0 : vector<15x2xf32> to vector<5x2x3x1xf32> } diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir index 8ae1e9f9d0c64..527bccf8383ca 100644 --- a/mlir/test/Dialect/Vector/ops.mlir +++ b/mlir/test/Dialect/Vector/ops.mlir @@ -576,6 +576,14 @@ func.func @shape_cast_0d(%arg0 : vector<1x1x1x1xf32>) -> (vector<1x1x1x1xf32>) { return %1 : vector<1x1x1x1xf32> } +// CHECK-LABEL: @shape_cast_rank_preserving +func.func @shape_cast_rank_preserving(%arg0 : vector<1x4xf32>) -> vector<4x1xf32> { + + // CHECK: vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4x1xf32> + %0 = vector.shape_cast %arg0 : vector<1x4xf32> to vector<4x1xf32> + return %0 : vector<4x1xf32> +} + // CHECK-LABEL: @bitcast func.func @bitcast(%arg0 : vector<5x1x3x2xf32>, %arg1 : vector<8x1xi32>, From b4914638b9426472c780cc94ac29224353d1b9d4 Mon Sep 17 00:00:00 2001 From: James Newling Date: Tue, 15 Apr 2025 15:10:54 -0700 Subject: [PATCH 2/6] clang-format --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp 
b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 554dbba081898..120dd57659e6b 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -5534,7 +5534,8 @@ void ShapeCastOp::inferResultRanges(ArrayRef argRanges, /// Returns true if each element of 'a' is equal to the product of a contiguous /// sequence of the elements of 'b'. Returns false otherwise. -static bool isValidExpandingShapeCast(ArrayRef a, ArrayRef b) { +static bool isValidExpandingShapeCast(ArrayRef a, + ArrayRef b) { unsigned rankA = a.size(); unsigned rankB = b.size(); assert(rankA <= rankB); From 7bfc219d9d3e4881a70d6866ec59b077da22e71b Mon Sep 17 00:00:00 2001 From: James Newling Date: Tue, 15 Apr 2025 16:14:40 -0700 Subject: [PATCH 3/6] update tests --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 41 ++++++------- mlir/test/Dialect/Vector/canonicalize.mlir | 10 ++-- ...-shape-cast-lowering-scalable-vectors.mlir | 58 +++++++++---------- ...vector-shape-cast-lowering-transforms.mlir | 21 ------- 4 files changed, 52 insertions(+), 78 deletions(-) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 120dd57659e6b..07b6baf961a3e 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -5534,11 +5534,12 @@ void ShapeCastOp::inferResultRanges(ArrayRef argRanges, /// Returns true if each element of 'a' is equal to the product of a contiguous /// sequence of the elements of 'b'. Returns false otherwise. -static bool isValidExpandingShapeCast(ArrayRef a, - ArrayRef b) { +static bool isExpandingShapeCast(ArrayRef a, ArrayRef b) { unsigned rankA = a.size(); unsigned rankB = b.size(); - assert(rankA <= rankB); + if (rankA > rankB) { + return false; + } auto isOne = [](int64_t v) { return v == 1; }; @@ -5565,35 +5566,34 @@ static bool isValidExpandingShapeCast(ArrayRef a, if (j < rankB && llvm::all_of(b.slice(j), isOne)) j = rankB; } - return i == rankA && j == rankB; } +static bool isValidShapeCast(ArrayRef a, ArrayRef b) { + return isExpandingShapeCast(a, b) || isExpandingShapeCast(b, a); +} + static LogicalResult verifyVectorShapeCast(Operation *op, VectorType sourceVectorType, VectorType resultVectorType) { // Check that element type is the same. if (sourceVectorType.getElementType() != resultVectorType.getElementType()) return op->emitOpError("has different source and result element types"); - ArrayRef lowRankShape = sourceVectorType.getShape(); - ArrayRef highRankShape = resultVectorType.getShape(); - if (lowRankShape.size() > highRankShape.size()) - std::swap(lowRankShape, highRankShape); + ArrayRef inShape = sourceVectorType.getShape(); + ArrayRef outShape = resultVectorType.getShape(); // Check that product of source dim sizes matches product of result dim sizes. 
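// (Annotation for review, not part of the original patch.) Worked example of
// this element-count check: vector<5x1x3x2xf32> -> vector<10x2xf32> has
// 5*1*3*2 = 30 source elements but 10*2 = 20 result elements, so the verifier
// reports "has a different number of source and result elements". Shapes with
// equal products (e.g. 2x2x9 and 4x3x3, both 36 elements) pass this check but
// can still be rejected by the collapse/expand check that follows.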
- int64_t nLowRankElms = - std::accumulate(lowRankShape.begin(), lowRankShape.end(), 1LL, - std::multiplies{}); - int64_t nHighRankElms = - std::accumulate(highRankShape.begin(), highRankShape.end(), 1LL, - std::multiplies{}); - - if (nLowRankElms != nHighRankElms) { + int64_t nInElms = std::accumulate(inShape.begin(), inShape.end(), 1LL, + std::multiplies{}); + int64_t nOutElms = std::accumulate(outShape.begin(), outShape.end(), 1LL, + std::multiplies{}); + + if (nInElms != nOutElms) { return op->emitOpError( "has a different number of source and result elements"); } - if (!isValidExpandingShapeCast(lowRankShape, highRankShape)) { + if (!isValidShapeCast(inShape, outShape)) { return op->emitOpError( "is invalid (does not uniformly collapse or expand)"); } @@ -5641,12 +5641,7 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) { if (resultType == srcType) return otherOp.getSource(); - ArrayRef lowRankShape = srcType.getShape(); - ArrayRef highRankShape = resultType.getShape(); - if (lowRankShape.size() > highRankShape.size()) - std::swap(lowRankShape, highRankShape); - - if (!isValidExpandingShapeCast(lowRankShape, highRankShape)) + if (!isValidShapeCast(srcType.getShape(), resultType.getShape())) return {}; setOperand(otherOp.getSource()); diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index 78b0ea78849e8..8d24e1bf2ba94 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -1290,12 +1290,12 @@ func.func @extract_strided_broadcast4(%arg0: f32) -> vector<1x4xf32> { // ----- // CHECK-LABEL: consecutive_shape_cast -// CHECK: %[[C:.*]] = vector.shape_cast %{{.*}} : vector<16xf16> to vector<4x4xf16> -// CHECK-NEXT: return %[[C]] : vector<4x4xf16> -func.func @consecutive_shape_cast(%arg0: vector<16xf16>) -> vector<4x4xf16> { +// CHECK: %[[C:.*]] = vector.shape_cast %{{.*}} : vector<16xf16> to vector<2x2x4xf16> +// CHECK-NEXT: return %[[C]] : vector<2x2x4xf16> +func.func @consecutive_shape_cast(%arg0: vector<16xf16>) -> vector<2x2x4xf16> { %0 = vector.shape_cast %arg0 : vector<16xf16> to vector<2x8xf16> - %1 = vector.shape_cast %0 : vector<2x8xf16> to vector<4x4xf16> - return %1 : vector<4x4xf16> + %1 = vector.shape_cast %0 : vector<2x8xf16> to vector<2x2x4xf16> + return %1 : vector<2x2x4xf16> } // ----- diff --git a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-scalable-vectors.mlir b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-scalable-vectors.mlir index f4becad3c79c1..2faa47c1b08a8 100644 --- a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-scalable-vectors.mlir +++ b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-scalable-vectors.mlir @@ -74,23 +74,23 @@ func.func @i8_1d_to_2d_last_dim_scalable(%arg0: vector<[32]xi8>) -> vector<4x[8] // CHECK-LABEL: f32_permute_leading_non_scalable_dims // CHECK-SAME: %[[arg0:.*]]: vector<2x3x[4]xf32> -func.func @f32_permute_leading_non_scalable_dims(%arg0: vector<2x3x[4]xf32>) -> vector<3x2x[4]xf32> { - // CHECK-NEXT: %[[ub:.*]] = ub.poison : vector<3x2x[4]xf32> +func.func @f32_permute_leading_non_scalable_dims(%arg0: vector<2x3x[4]xf32>) -> vector<1x6x[4]xf32> { + // CHECK-NEXT: %[[ub:.*]] = ub.poison : vector<1x6x[4]xf32> // CHECK-NEXT: %[[subvec0:.*]] = vector.extract %[[arg0]][0, 0] : vector<[4]xf32> from vector<2x3x[4]xf32> - // CHECK-NEXT: %[[res0:.*]] = vector.insert %[[subvec0]], %[[ub]] [0, 0] : vector<[4]xf32> into vector<3x2x[4]xf32> + // CHECK-NEXT: %[[res0:.*]] = vector.insert %[[subvec0]], %[[ub]] [0, 0] : 
vector<[4]xf32> into vector<1x6x[4]xf32> // CHECK-NEXT: %[[subvec1:.*]] = vector.extract %[[arg0]][0, 1] : vector<[4]xf32> from vector<2x3x[4]xf32> - // CHECK-NEXT: %[[res1:.*]] = vector.insert %[[subvec1]], %[[res0]] [0, 1] : vector<[4]xf32> into vector<3x2x[4]xf32> + // CHECK-NEXT: %[[res1:.*]] = vector.insert %[[subvec1]], %[[res0]] [0, 1] : vector<[4]xf32> into vector<1x6x[4]xf32> // CHECK-NEXT: %[[subvec2:.*]] = vector.extract %[[arg0]][0, 2] : vector<[4]xf32> from vector<2x3x[4]xf32> - // CHECK-NEXT: %[[res2:.*]] = vector.insert %[[subvec2]], %[[res1]] [1, 0] : vector<[4]xf32> into vector<3x2x[4]xf32> + // CHECK-NEXT: %[[res2:.*]] = vector.insert %[[subvec2]], %[[res1]] [0, 2] : vector<[4]xf32> into vector<1x6x[4]xf32> // CHECK-NEXT: %[[subvec3:.*]] = vector.extract %[[arg0]][1, 0] : vector<[4]xf32> from vector<2x3x[4]xf32> - // CHECK-NEXT: %[[res3:.*]] = vector.insert %[[subvec3]], %[[res2]] [1, 1] : vector<[4]xf32> into vector<3x2x[4]xf32> + // CHECK-NEXT: %[[res3:.*]] = vector.insert %[[subvec3]], %[[res2]] [0, 3] : vector<[4]xf32> into vector<1x6x[4]xf32> // CHECK-NEXT: %[[subvec4:.*]] = vector.extract %[[arg0]][1, 1] : vector<[4]xf32> from vector<2x3x[4]xf32> - // CHECK-NEXT: %[[res4:.*]] = vector.insert %[[subvec4]], %[[res3]] [2, 0] : vector<[4]xf32> into vector<3x2x[4]xf32> + // CHECK-NEXT: %[[res4:.*]] = vector.insert %[[subvec4]], %[[res3]] [0, 4] : vector<[4]xf32> into vector<1x6x[4]xf32> // CHECK-NEXT: %[[subvec5:.*]] = vector.extract %[[arg0]][1, 2] : vector<[4]xf32> from vector<2x3x[4]xf32> - // CHECK-NEXT: %[[res5:.*]] = vector.insert %[[subvec5]], %[[res4]] [2, 1] : vector<[4]xf32> into vector<3x2x[4]xf32> - %res = vector.shape_cast %arg0: vector<2x3x[4]xf32> to vector<3x2x[4]xf32> - // CHECK-NEXT: return %[[res5]] : vector<3x2x[4]xf32> - return %res : vector<3x2x[4]xf32> + // CHECK-NEXT: %[[res5:.*]] = vector.insert %[[subvec5]], %[[res4]] [0, 5] : vector<[4]xf32> into vector<1x6x[4]xf32> + %res = vector.shape_cast %arg0: vector<2x3x[4]xf32> to vector<1x6x[4]xf32> + // CHECK-NEXT: return %[[res5]] : vector<1x6x[4]xf32> + return %res : vector<1x6x[4]xf32> } // ----- @@ -117,48 +117,48 @@ func.func @f64_flatten_leading_non_scalable_dims(%arg0: vector<2x2x[2]xf64>) -> // CHECK-LABEL: f32_reduce_trailing_scalable_dim // CHECK-SAME: %[[arg0:.*]]: vector<3x[4]xf32> -func.func @f32_reduce_trailing_scalable_dim(%arg0: vector<3x[4]xf32>) -> vector<6x[2]xf32> +func.func @f32_reduce_trailing_scalable_dim(%arg0: vector<3x[4]xf32>) -> vector<3x2x[2]xf32> { - // CHECK-NEXT: %[[ub:.*]] = ub.poison : vector<6x[2]xf32> + // CHECK-NEXT: %[[ub:.*]] = ub.poison : vector<3x2x[2]xf32> // CHECK-NEXT: %[[srcvec0:.*]] = vector.extract %[[arg0]][0] : vector<[4]xf32> from vector<3x[4]xf32> // CHECK-NEXT: %[[subvec0:.*]] = vector.scalable.extract %[[srcvec0]][0] : vector<[2]xf32> from vector<[4]xf32> - // CHECK-NEXT: %[[res0:.*]] = vector.insert %[[subvec0]], %[[ub]] [0] : vector<[2]xf32> into vector<6x[2]xf32> + // CHECK-NEXT: %[[res0:.*]] = vector.insert %[[subvec0]], %[[ub]] [0, 0] : vector<[2]xf32> into vector<3x2x[2]xf32> // CHECK-NEXT: %[[subvec1:.*]] = vector.scalable.extract %[[srcvec0]][2] : vector<[2]xf32> from vector<[4]xf32> - // CHECK-NEXT: %[[res1:.*]] = vector.insert %[[subvec1]], %[[res0]] [1] : vector<[2]xf32> into vector<6x[2]xf32> + // CHECK-NEXT: %[[res1:.*]] = vector.insert %[[subvec1]], %[[res0]] [0, 1] : vector<[2]xf32> into vector<3x2x[2]xf32> // CHECK-NEXT: %[[srcvec1:.*]] = vector.extract %[[arg0]][1] : vector<[4]xf32> from vector<3x[4]xf32> // CHECK-NEXT: 
%[[subvec2:.*]] = vector.scalable.extract %[[srcvec1]][0] : vector<[2]xf32> from vector<[4]xf32> - // CHECK-NEXT: %[[res2:.*]] = vector.insert %[[subvec2]], %[[res1]] [2] : vector<[2]xf32> into vector<6x[2]xf32> + // CHECK-NEXT: %[[res2:.*]] = vector.insert %[[subvec2]], %[[res1]] [1, 0] : vector<[2]xf32> into vector<3x2x[2]xf32> // CHECK-NEXT: %[[subvec3:.*]] = vector.scalable.extract %[[srcvec1]][2] : vector<[2]xf32> from vector<[4]xf32> - // CHECK-NEXT: %[[res3:.*]] = vector.insert %[[subvec3]], %[[res2]] [3] : vector<[2]xf32> into vector<6x[2]xf32> + // CHECK-NEXT: %[[res3:.*]] = vector.insert %[[subvec3]], %[[res2]] [1, 1] : vector<[2]xf32> into vector<3x2x[2]xf32> // CHECK-NEXT: %[[srcvec2:.*]] = vector.extract %[[arg0]][2] : vector<[4]xf32> from vector<3x[4]xf32> // CHECK-NEXT: %[[subvec4:.*]] = vector.scalable.extract %[[srcvec2]][0] : vector<[2]xf32> from vector<[4]xf32> - // CHECK-NEXT: %[[res4:.*]] = vector.insert %[[subvec4]], %[[res3]] [4] : vector<[2]xf32> into vector<6x[2]xf32> + // CHECK-NEXT: %[[res4:.*]] = vector.insert %[[subvec4]], %[[res3]] [2, 0] : vector<[2]xf32> into vector<3x2x[2]xf32> // CHECK-NEXT: %[[subvec5:.*]] = vector.scalable.extract %[[srcvec2]][2] : vector<[2]xf32> from vector<[4]xf32> - // CHECK-NEXT: %[[res5:.*]] = vector.insert %[[subvec5]], %[[res4]] [5] : vector<[2]xf32> into vector<6x[2]xf32> - %res = vector.shape_cast %arg0: vector<3x[4]xf32> to vector<6x[2]xf32> - // CHECK-NEXT: return %[[res5]] : vector<6x[2]xf32> - return %res: vector<6x[2]xf32> + // CHECK-NEXT: %[[res5:.*]] = vector.insert %[[subvec5]], %[[res4]] [2, 1] : vector<[2]xf32> into vector<3x2x[2]xf32> + %res = vector.shape_cast %arg0: vector<3x[4]xf32> to vector<3x2x[2]xf32> + // CHECK-NEXT: return %[[res5]] : vector<3x2x[2]xf32> + return %res: vector<3x2x[2]xf32> } // ----- // CHECK-LABEL: f32_increase_trailing_scalable_dim -// CHECK-SAME: %[[arg0:.*]]: vector<4x[2]xf32> -func.func @f32_increase_trailing_scalable_dim(%arg0: vector<4x[2]xf32>) -> vector<2x[4]xf32> +// CHECK-SAME: %[[arg0:.*]]: vector<2x2x[2]xf32> +func.func @f32_increase_trailing_scalable_dim(%arg0: vector<2x2x[2]xf32>) -> vector<2x[4]xf32> { // CHECK-DAG: %[[ub0:.*]] = ub.poison : vector<2x[4]xf32> // CHECK-DAG: %[[ub1:.*]] = ub.poison : vector<[4]xf32> - // CHECK-NEXT: %[[subvec0:.*]] = vector.extract %[[arg0]][0] : vector<[2]xf32> from vector<4x[2]xf32> + // CHECK-NEXT: %[[subvec0:.*]] = vector.extract %[[arg0]][0, 0] : vector<[2]xf32> from vector<2x2x[2]xf32> // CHECK-NEXT: %[[resvec1:.*]] = vector.scalable.insert %[[subvec0]], %[[ub1]][0] : vector<[2]xf32> into vector<[4]xf32> - // CHECK-NEXT: %[[subvec1:.*]] = vector.extract %[[arg0]][1] : vector<[2]xf32> from vector<4x[2]xf32> + // CHECK-NEXT: %[[subvec1:.*]] = vector.extract %[[arg0]][0, 1] : vector<[2]xf32> from vector<2x2x[2]xf32> // CHECK-NEXT: %[[resvec2:.*]] = vector.scalable.insert %[[subvec1]], %[[resvec1]][2] : vector<[2]xf32> into vector<[4]xf32> // CHECK-NEXT: %[[res0:.*]] = vector.insert %[[resvec2]], %[[ub0]] [0] : vector<[4]xf32> into vector<2x[4]xf32> - // CHECK-NEXT: %[[subvec3:.*]] = vector.extract %[[arg0]][2] : vector<[2]xf32> from vector<4x[2]xf32> + // CHECK-NEXT: %[[subvec3:.*]] = vector.extract %[[arg0]][1, 0] : vector<[2]xf32> from vector<2x2x[2]xf32> // CHECK-NEXT: %[[resvec4:.*]] = vector.scalable.insert %[[subvec3]], %[[ub1]][0] : vector<[2]xf32> into vector<[4]xf32> - // CHECK-NEXT: %[[subvec4:.*]] = vector.extract %[[arg0]][3] : vector<[2]xf32> from vector<4x[2]xf32> + // CHECK-NEXT: %[[subvec4:.*]] = vector.extract %[[arg0]][1, 1] 
: vector<[2]xf32> from vector<2x2x[2]xf32> // CHECK-NEXT: %[[resvec5:.*]] = vector.scalable.insert %[[subvec4]], %[[resvec4]][2] : vector<[2]xf32> into vector<[4]xf32> // CHECK-NEXT: %[[res1:.*]] = vector.insert %[[resvec5]], %[[res0]] [1] : vector<[4]xf32> into vector<2x[4]xf32> - %res = vector.shape_cast %arg0: vector<4x[2]xf32> to vector<2x[4]xf32> + %res = vector.shape_cast %arg0: vector<2x2x[2]xf32> to vector<2x[4]xf32> // CHECK-NEXT: return %[[res1]] : vector<2x[4]xf32> return %res: vector<2x[4]xf32> } diff --git a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir index ef32f8c6a1cdb..fbfe3789b871b 100644 --- a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir @@ -57,27 +57,6 @@ func.func @shape_casts(%a: vector<2x2xf32>) -> (vector<4xf32>, vector<2x2xf32>) return %r0, %1 : vector<4xf32>, vector<2x2xf32> } -// CHECK-LABEL: func @shape_cast_2d2d -// CHECK-SAME: %[[A:.*]]: vector<3x2xf32> -// CHECK: %[[UB:.*]] = ub.poison : vector<2x3xf32> -// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : f32 from vector<3x2xf32> -// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[UB]] [0, 0] : f32 into vector<2x3xf32> -// CHECK: %[[T2:.*]] = vector.extract %[[A]][0, 1] : f32 from vector<3x2xf32> -// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [0, 1] : f32 into vector<2x3xf32> -// CHECK: %[[T4:.*]] = vector.extract %[[A]][1, 0] : f32 from vector<3x2xf32> -// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [0, 2] : f32 into vector<2x3xf32> -// CHECK: %[[T6:.*]] = vector.extract %[[A]][1, 1] : f32 from vector<3x2xf32> -// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [1, 0] : f32 into vector<2x3xf32> -// CHECK: %[[T8:.*]] = vector.extract %[[A]][2, 0] : f32 from vector<3x2xf32> -// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [1, 1] : f32 into vector<2x3xf32> -// CHECK: %[[T10:.*]] = vector.extract %[[A]][2, 1] : f32 from vector<3x2xf32> -// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [1, 2] : f32 into vector<2x3xf32> -// CHECK: return %[[T11]] : vector<2x3xf32> - -func.func @shape_cast_2d2d(%arg0 : vector<3x2xf32>) -> vector<2x3xf32> { - %s = vector.shape_cast %arg0: vector<3x2xf32> to vector<2x3xf32> - return %s : vector<2x3xf32> -} // CHECK-LABEL: func @shape_cast_3d1d // CHECK-SAME: %[[A:.*]]: vector<1x3x2xf32> From db1b71738607d73eb15c9fd14f94961dd6d0dc20 Mon Sep 17 00:00:00 2001 From: James Newling Date: Tue, 15 Apr 2025 18:56:31 -0700 Subject: [PATCH 4/6] tighten --- .../mlir/Dialect/Vector/IR/VectorOps.td | 9 +- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 95 +++++++++---------- mlir/test/Dialect/Vector/canonicalize.mlir | 16 ++-- mlir/test/Dialect/Vector/ops.mlir | 20 ++++ 4 files changed, 81 insertions(+), 59 deletions(-) diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index a9e25f23ef90f..7d5b5048131d8 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -2248,10 +2248,11 @@ def Vector_ShapeCastOp : vector. The element type remains the same, as does the number of elements (product of dimensions). - If reducing or preserving rank (n >= k), all result dimension sizes must be - products of contiguous source dimension sizes. If expanding rank (n < k), - source dimensions must all factor into contiguous sequences of destination - dimension sizes. 
+ A shape_cast must be either collapsing or expanding. Collapsing means all + result dimension sizes are products of contiguous source dimension sizes. + Expanding means source dimensions all factor into contiguous sequences of + destination dimension sizes. Size 1 dimensions in source and destination + are ignored. Each source dim is expanded (or contiguous sequence of source dims combined) in source dimension list order (i.e. 0 <= i < n), to produce a contiguous diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 07b6baf961a3e..d0797db00d528 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -5532,41 +5532,34 @@ void ShapeCastOp::inferResultRanges(ArrayRef argRanges, setResultRanges(getResult(), argRanges.front()); } -/// Returns true if each element of 'a' is equal to the product of a contiguous -/// sequence of the elements of 'b'. Returns false otherwise. +/// Returns true if each element of 'a' is either 1 or equal to the product of a +/// contiguous sequence of the elements of 'b'. Returns false otherwise. +/// +/// This function assumes that the product of elements in a and b are the same. static bool isExpandingShapeCast(ArrayRef a, ArrayRef b) { - unsigned rankA = a.size(); - unsigned rankB = b.size(); - if (rankA > rankB) { - return false; - } - - auto isOne = [](int64_t v) { return v == 1; }; - - // Special-case for n-D to 0-d shape cast. 'b' must be all ones to be shape - // casted to a 0-d vector. - if (rankA == 0 && llvm::all_of(b, isOne)) - return true; + unsigned rankA = a.size(); unsigned i = 0; unsigned j = 0; - while (i < rankA && j < rankB) { + while (i < rankA) { + if (a[i] == 1) { + ++i; + continue; + } + int64_t dimA = a[i]; int64_t dimB = 1; - while (dimB < dimA && j < rankB) + + while (dimB < dimA) { dimB *= b[j++]; - if (dimA != dimB) - break; - ++i; + } - // Handle the case when trailing dimensions are of size 1. - // Include them into the contiguous sequence. - if (i < rankA && llvm::all_of(a.slice(i), isOne)) - i = rankA; - if (j < rankB && llvm::all_of(b.slice(j), isOne)) - j = rankB; + if (dimA != dimB) { + return false; + } + ++i; } - return i == rankA && j == rankB; + return true; } static bool isValidShapeCast(ArrayRef a, ArrayRef b) { @@ -5582,7 +5575,8 @@ static LogicalResult verifyVectorShapeCast(Operation *op, ArrayRef inShape = sourceVectorType.getShape(); ArrayRef outShape = resultVectorType.getShape(); - // Check that product of source dim sizes matches product of result dim sizes. + // Check that product of source dim sizes matches product of result dim + // sizes. int64_t nInElms = std::accumulate(inShape.begin(), inShape.end(), 1LL, std::multiplies{}); int64_t nOutElms = std::accumulate(outShape.begin(), outShape.end(), 1LL, @@ -5702,8 +5696,8 @@ static VectorType trimTrailingOneDims(VectorType oldType) { /// /// Looks at `vector.shape_cast` Ops that simply "drop" the trailing unit /// dimension. If the input vector comes from `vector.create_mask` for which -/// the corresponding mask input value is 1 (e.g. `%c1` below), then it is safe -/// to fold shape_cast into create_mask. +/// the corresponding mask input value is 1 (e.g. `%c1` below), then it is +/// safe to fold shape_cast into create_mask. 
/// /// BEFORE: /// %1 = vector.create_mask %c1, %dim, %c1, %c1 : vector<1x[4]x1x1xi1> @@ -5970,8 +5964,8 @@ LogicalResult TypeCastOp::verify() { auto resultType = getResultMemRefType(); if (getElementTypeOrSelf(getElementTypeOrSelf(sourceType)) != getElementTypeOrSelf(getElementTypeOrSelf(resultType))) - return emitOpError( - "expects result and operand with same underlying scalar type: ") + return emitOpError("expects result and operand with same underlying " + "scalar type: ") << resultType; if (extractShape(sourceType) != extractShape(resultType)) return emitOpError( @@ -6009,7 +6003,8 @@ OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) { return attr.reshape(getResultVectorType()); // Eliminate identity transpose ops. This happens when the dimensions of the - // input vector remain in their original order after the transpose operation. + // input vector remain in their original order after the transpose + // operation. ArrayRef perm = getPermutation(); // Check if the permutation of the dimensions contains sequential values: @@ -6068,7 +6063,8 @@ class TransposeFolder final : public OpRewritePattern { return result; }; - // Return if the input of 'transposeOp' is not defined by another transpose. + // Return if the input of 'transposeOp' is not defined by another + // transpose. vector::TransposeOp parentTransposeOp = transposeOp.getVector().getDefiningOp(); if (!parentTransposeOp) @@ -6212,8 +6208,9 @@ LogicalResult ConstantMaskOp::verify() { return emitOpError( "only supports 'none set' or 'all set' scalable dimensions"); } - // Verify that if one mask dim size is zero, they all should be zero (because - // the mask region is a conjunction of each mask dimension interval). + // Verify that if one mask dim size is zero, they all should be zero + // (because the mask region is a conjunction of each mask dimension + // interval). bool anyZeros = llvm::is_contained(maskDimSizes, 0); bool allZeros = llvm::all_of(maskDimSizes, [](int64_t s) { return s == 0; }); if (anyZeros && !allZeros) @@ -6251,7 +6248,8 @@ void CreateMaskOp::build(OpBuilder &builder, OperationState &result, LogicalResult CreateMaskOp::verify() { auto vectorType = llvm::cast(getResult().getType()); - // Verify that an operand was specified for each result vector each dimension. + // Verify that an operand was specified for each result vector each + // dimension. if (vectorType.getRank() == 0) { if (getNumOperands() != 1) return emitOpError( @@ -6458,8 +6456,8 @@ void mlir::vector::MaskOp::print(OpAsmPrinter &p) { void MaskOp::ensureTerminator(Region ®ion, Builder &builder, Location loc) { OpTrait::SingleBlockImplicitTerminator::Impl< MaskOp>::ensureTerminator(region, builder, loc); - // Keep the default yield terminator if the number of masked operations is not - // the expected. This case will trigger a verification failure. + // Keep the default yield terminator if the number of masked operations is + // not the expected. This case will trigger a verification failure. Block &block = region.front(); if (block.getOperations().size() != 2) return; @@ -6563,9 +6561,9 @@ LogicalResult MaskOp::fold(FoldAdaptor adaptor, return success(); } -// Elides empty vector.mask operations with or without return values. Propagates -// the yielded values by the vector.yield terminator, if any, or erases the op, -// otherwise. +// Elides empty vector.mask operations with or without return values. +// Propagates the yielded values by the vector.yield terminator, if any, or +// erases the op, otherwise. 
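// (Annotation for review, not part of the original patch; the IR below is an
// illustrative sketch and should be checked against the vector.mask docs.)
//   %0 = vector.mask %m { vector.yield %a : vector<8xf32> } : vector<8xi1> -> vector<8xf32>
// contains nothing but its terminator, so the pattern that follows replaces
// %0 with %a; an empty masked op with no results is simply erased.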
class ElideEmptyMaskOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -6668,7 +6666,8 @@ OpFoldResult SplatOp::fold(FoldAdaptor adaptor) { if (!isa_and_nonnull(constOperand)) return {}; - // SplatElementsAttr::get treats single value for second arg as being a splat. + // SplatElementsAttr::get treats single value for second arg as being a + // splat. return SplatElementsAttr::get(getType(), {constOperand}); } @@ -6790,12 +6789,12 @@ Operation *mlir::vector::maskOperation(OpBuilder &builder, } /// Creates a vector select operation that picks values from `newValue` or -/// `passthru` for each result vector lane based on `mask`. This utility is used -/// to propagate the pass-thru value of vector.mask or for cases where only the -/// pass-thru value propagation is needed. VP intrinsics do not support -/// pass-thru values and every mask-out lane is set to poison. LLVM backends are -/// usually able to match op + select patterns and fold them into a native -/// target instructions. +/// `passthru` for each result vector lane based on `mask`. This utility is +/// used to propagate the pass-thru value of vector.mask or for cases where +/// only the pass-thru value propagation is needed. VP intrinsics do not +/// support pass-thru values and every mask-out lane is set to poison. LLVM +/// backends are usually able to match op + select patterns and fold them into +/// a native target instructions. Value mlir::vector::selectPassthru(OpBuilder &builder, Value mask, Value newValue, Value passthru) { if (!mask) diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index 8d24e1bf2ba94..986e11d948052 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -950,14 +950,16 @@ func.func @insert_no_fold_scalar_to_0d(%v: vector) -> vector { // ----- +// The definition of shape_cast stipulates that it must be either expanding or collapsing, +// it cannot be a mixture of both. 
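// (Annotation for review, not part of the original patch.) In the test below,
// 2x2x9 -> 2x2x3x3 expands the trailing 9 into 3x3, while 2x2x3x3 -> 4x3x3
// collapses the leading 2x2 into 4; the composition 2x2x9 -> 4x3x3 is neither
// a pure expansion nor a pure collapse, so both casts must be kept.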
// CHECK-LABEL: dont_fold_expand_collapse -// CHECK: %[[A:.*]] = vector.shape_cast %{{.*}} : vector<1x1x64xf32> to vector<1x1x8x8xf32> -// CHECK: %[[B:.*]] = vector.shape_cast %{{.*}} : vector<1x1x8x8xf32> to vector<8x8xf32> -// CHECK: return %[[B]] : vector<8x8xf32> -func.func @dont_fold_expand_collapse(%arg0: vector<1x1x64xf32>) -> vector<8x8xf32> { - %0 = vector.shape_cast %arg0 : vector<1x1x64xf32> to vector<1x1x8x8xf32> - %1 = vector.shape_cast %0 : vector<1x1x8x8xf32> to vector<8x8xf32> - return %1 : vector<8x8xf32> +// CHECK: %[[A:.*]] = vector.shape_cast %{{.*}} : vector<2x2x9xf32> to vector<2x2x3x3xf32> +// CHECK: %[[B:.*]] = vector.shape_cast %{{.*}} : vector<2x2x3x3xf32> to vector<4x3x3xf32> +// CHECK: return %[[B]] : vector<4x3x3xf32> +func.func @dont_fold_expand_collapse(%arg0: vector<2x2x9xf32>) -> vector<4x3x3xf32> { + %0 = vector.shape_cast %arg0 : vector<2x2x9xf32> to vector<2x2x3x3xf32> + %1 = vector.shape_cast %0 : vector<2x2x3x3xf32> to vector<4x3x3xf32> + return %1 : vector<4x3x3xf32> } // ----- diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir index 527bccf8383ca..504c6c300e9f0 100644 --- a/mlir/test/Dialect/Vector/ops.mlir +++ b/mlir/test/Dialect/Vector/ops.mlir @@ -581,9 +581,29 @@ func.func @shape_cast_rank_preserving(%arg0 : vector<1x4xf32>) -> vector<4x1xf32 // CHECK: vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4x1xf32> %0 = vector.shape_cast %arg0 : vector<1x4xf32> to vector<4x1xf32> + return %0 : vector<4x1xf32> } + +// CHECK-LABEL: @collapse_but_increase_rank +func.func @collapse_but_increase_rank(%arg0 : vector<2x3x5x7xf32>) -> vector<1x6x1x35x1xf32> { + + // CHECK: vector.shape_cast %{{.*}} : vector<2x3x5x7xf32> to vector<1x6x1x35x1xf32> + %0 = vector.shape_cast %arg0 : vector<2x3x5x7xf32> to vector<1x6x1x35x1xf32> + + return %0 : vector<1x6x1x35x1xf32> +} + +// CHECK-LABEL: @expand_but_decrease_rank +func.func @expand_but_decrease_rank(%arg0 : vector<1x1x6xi8>) -> vector<2x3xi8> { + + // CHECK: vector.shape_cast %{{.*}} : vector<1x1x6xi8> to vector<2x3xi8> + %0 = vector.shape_cast %arg0 : vector<1x1x6xi8> to vector<2x3xi8> + + return %0 : vector<2x3xi8> +} + // CHECK-LABEL: @bitcast func.func @bitcast(%arg0 : vector<5x1x3x2xf32>, %arg1 : vector<8x1xi32>, From ed2b46ad841b06b5788e8ea7653d701d55c1b761 Mon Sep 17 00:00:00 2001 From: James Newling Date: Tue, 15 Apr 2025 18:58:23 -0700 Subject: [PATCH 5/6] fix name --- mlir/test/Dialect/Vector/invalid.mlir | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir index 9f94fb0574504..45b7b44d47039 100644 --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -1145,7 +1145,7 @@ func.func @shape_cast_wrong_num_elements(%arg0 : vector<5x1x3x2xf32>) { // ----- -func.func @shape_cast_invalid_rank_preservating(%arg0 : vector<3x2xf32>) { +func.func @shape_cast_invalid_rank_preserving(%arg0 : vector<3x2xf32>) { // expected-error@+1 {{op is invalid (does not uniformly collapse or expand)}} %0 = vector.shape_cast %arg0 : vector<3x2xf32> to vector<2x3xf32> } From 19cac4f70c0501b860342e4eb3b7c13c40a3959f Mon Sep 17 00:00:00 2001 From: James Newling Date: Tue, 15 Apr 2025 19:05:48 -0700 Subject: [PATCH 6/6] undo formatting noise --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 45 +++++++++++------------- 1 file changed, 20 insertions(+), 25 deletions(-) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 
d0797db00d528..20162e93c88e8 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -5696,8 +5696,8 @@ static VectorType trimTrailingOneDims(VectorType oldType) { /// /// Looks at `vector.shape_cast` Ops that simply "drop" the trailing unit /// dimension. If the input vector comes from `vector.create_mask` for which -/// the corresponding mask input value is 1 (e.g. `%c1` below), then it is -/// safe to fold shape_cast into create_mask. +/// the corresponding mask input value is 1 (e.g. `%c1` below), then it is safe +/// to fold shape_cast into create_mask. /// /// BEFORE: /// %1 = vector.create_mask %c1, %dim, %c1, %c1 : vector<1x[4]x1x1xi1> @@ -5964,8 +5964,8 @@ LogicalResult TypeCastOp::verify() { auto resultType = getResultMemRefType(); if (getElementTypeOrSelf(getElementTypeOrSelf(sourceType)) != getElementTypeOrSelf(getElementTypeOrSelf(resultType))) - return emitOpError("expects result and operand with same underlying " - "scalar type: ") + return emitOpError( + "expects result and operand with same underlying scalar type: ") << resultType; if (extractShape(sourceType) != extractShape(resultType)) return emitOpError( @@ -6003,8 +6003,7 @@ OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) { return attr.reshape(getResultVectorType()); // Eliminate identity transpose ops. This happens when the dimensions of the - // input vector remain in their original order after the transpose - // operation. + // input vector remain in their original order after the transpose operation. ArrayRef perm = getPermutation(); // Check if the permutation of the dimensions contains sequential values: @@ -6063,8 +6062,7 @@ class TransposeFolder final : public OpRewritePattern { return result; }; - // Return if the input of 'transposeOp' is not defined by another - // transpose. + // Return if the input of 'transposeOp' is not defined by another transpose. vector::TransposeOp parentTransposeOp = transposeOp.getVector().getDefiningOp(); if (!parentTransposeOp) @@ -6208,9 +6206,8 @@ LogicalResult ConstantMaskOp::verify() { return emitOpError( "only supports 'none set' or 'all set' scalable dimensions"); } - // Verify that if one mask dim size is zero, they all should be zero - // (because the mask region is a conjunction of each mask dimension - // interval). + // Verify that if one mask dim size is zero, they all should be zero (because + // the mask region is a conjunction of each mask dimension interval). bool anyZeros = llvm::is_contained(maskDimSizes, 0); bool allZeros = llvm::all_of(maskDimSizes, [](int64_t s) { return s == 0; }); if (anyZeros && !allZeros) @@ -6248,8 +6245,7 @@ void CreateMaskOp::build(OpBuilder &builder, OperationState &result, LogicalResult CreateMaskOp::verify() { auto vectorType = llvm::cast(getResult().getType()); - // Verify that an operand was specified for each result vector each - // dimension. + // Verify that an operand was specified for each result vector each dimension. if (vectorType.getRank() == 0) { if (getNumOperands() != 1) return emitOpError( @@ -6457,7 +6453,7 @@ void MaskOp::ensureTerminator(Region ®ion, Builder &builder, Location loc) { OpTrait::SingleBlockImplicitTerminator::Impl< MaskOp>::ensureTerminator(region, builder, loc); // Keep the default yield terminator if the number of masked operations is - // not the expected. This case will trigger a verification failure. + // not as expected. This case will trigger a verification failure. 
Block &block = region.front(); if (block.getOperations().size() != 2) return; @@ -6561,9 +6557,9 @@ LogicalResult MaskOp::fold(FoldAdaptor adaptor, return success(); } -// Elides empty vector.mask operations with or without return values. -// Propagates the yielded values by the vector.yield terminator, if any, or -// erases the op, otherwise. +// Elides empty vector.mask operations with or without return values. Propagates +// the yielded values by the vector.yield terminator, if any, or erases the op, +// otherwise. class ElideEmptyMaskOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -6666,8 +6662,7 @@ OpFoldResult SplatOp::fold(FoldAdaptor adaptor) { if (!isa_and_nonnull(constOperand)) return {}; - // SplatElementsAttr::get treats single value for second arg as being a - // splat. + // SplatElementsAttr::get treats single value for second arg as being a splat. return SplatElementsAttr::get(getType(), {constOperand}); } @@ -6789,12 +6784,12 @@ Operation *mlir::vector::maskOperation(OpBuilder &builder, } /// Creates a vector select operation that picks values from `newValue` or -/// `passthru` for each result vector lane based on `mask`. This utility is -/// used to propagate the pass-thru value of vector.mask or for cases where -/// only the pass-thru value propagation is needed. VP intrinsics do not -/// support pass-thru values and every mask-out lane is set to poison. LLVM -/// backends are usually able to match op + select patterns and fold them into -/// a native target instructions. +/// `passthru` for each result vector lane based on `mask`. This utility is used +/// to propagate the pass-thru value of vector.mask or for cases where only the +/// pass-thru value propagation is needed. VP intrinsics do not support +/// pass-thru values and every mask-out lane is set to poison. LLVM backends are +/// usually able to match op + select patterns and fold them into a native +/// target instructions. Value mlir::vector::selectPassthru(OpBuilder &builder, Value mask, Value newValue, Value passthru) { if (!mask)