[mlir][Tensor] Generalize the pattern to swap `tensor.collapse_shape` -> `tensor.expand_shape` #133819
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-tensor

Author: None (MaheshRavishankar)

Changes

The current patterns compared the reassociation indices for the two ops and failed if neither of them was of size 1. This patch relaxes that restriction by handling a new case where the reassociation indices might be of the same size. It also generalizes the generation of the swapped `tensor.expand_shape` -> `tensor.collapse_shape`: if one of them is degenerate, it is not generated.

Full diff: https://github.com/llvm/llvm-project/pull/133819.diff

2 Files Affected:
- mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
- mlir/test/Dialect/Tensor/bubble-reshapes.mlir
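To illustrate the relaxed compatibility condition, here is a hypothetical Python model of the check (the actual implementation is the C++ in the diff below); the helper name and `DYNAMIC` sentinel are assumptions, with `DYNAMIC` standing in for `ShapedType::kDynamic`:

```python
DYNAMIC = -1  # stands in for ShapedType::kDynamic


def reshapes_are_parallel(collapse_re, expand_re, src_shape, result_shape):
    """Model of the check: at every reassociation position, either the two
    groups have the same size (and describe the same dimensions), or one of
    them has size 1."""
    for c, e in zip(collapse_re, expand_re):
        if len(c) == len(e):
            c_dims = [src_shape[i] for i in c]
            e_dims = [result_shape[i] for i in e]
            # With dynamic sizes, equality is only verifiable when each
            # group contains at most one dynamic dimension.
            if c_dims.count(DYNAMIC) > 1 or e_dims.count(DYNAMIC) > 1:
                return False
            if c_dims != e_dims:
                return False
        elif len(c) != 1 and len(e) != 1:
            return False
    return True
```

For example, collapsing 4x8x2 into 64 and then expanding back to 4x8x2 is accepted, while expanding to 2x8x4 is rejected because the same-size groups describe different dimensions.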
diff --git a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
index acedf51d0e240..2542039267f01 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
@@ -166,10 +166,39 @@ struct BubbleUpExpandThroughParallelCollapse
return failure();
}
- // Reshapes are parallel to each other if none of the reassociation indices
- // have greater than 1 index for both reshapes.
+ // Reshapes are parallel to each other (by construction the number of
+ // reassociations specified in the collapse and expand are the same) if, at
+ // every position, either
+ // 1. the reassociation indices are of the same size, or
+ // 2. the reassociation in the collapse or the expand is of size 1.
+ ArrayRef<int64_t> staticSourceSize = collapseOp.getSrcType().getShape();
+ ArrayRef<int64_t> staticResultSize = expandOp.getStaticOutputShape();
for (auto [expandReassociation, collapseReassociation] :
llvm::zip_equal(expandReInds, collapseReInds)) {
+ if (collapseReassociation.size() == expandReassociation.size()) {
+ // Even if the reassociations are the same, the collapse/expand should
+ // result in the same dimensions, i.e. 4x8x2 collapsed into 64 should be
+ // expanded into 4x8x2 again. In the presence of dynamic dimensions one
+ // can only verify "equality" when there is at most one dynamic dimension
+ // present, and all other static dimensions are equal.
+ ArrayRef<int64_t> collapsedStaticShapes = staticSourceSize.slice(
+ collapseReassociation.front(), collapseReassociation.size());
+ int64_t numCollapsedDynamic =
+ llvm::count_if(collapsedStaticShapes,
+ [](int64_t d) { return ShapedType::isDynamic(d); });
+ ArrayRef<int64_t> expandedStaticShapes = staticResultSize.slice(
+ expandReassociation.front(), expandReassociation.size());
+ int64_t numExpandedDynamic =
+ llvm::count_if(expandedStaticShapes,
+ [](int64_t d) { return ShapedType::isDynamic(d); });
+ if (numCollapsedDynamic > 1 || numExpandedDynamic > 1 ||
+ collapsedStaticShapes != expandedStaticShapes) {
+ return failure();
+ }
+ continue;
+ }
+ // If the reassociations are not the same, one or the other needs to be
+ // of size one.
if (collapseReassociation.size() != 1 && expandReassociation.size() != 1)
return failure();
}
@@ -177,33 +206,61 @@ struct BubbleUpExpandThroughParallelCollapse
// Compute new reassociation indices and expanded/collaped shapes.
SmallVector<ReassociationIndices> newExpandReInds, newCollapseReInds;
Location loc = expandOp->getLoc();
- SmallVector<OpFoldResult> collapseSizes =
+ SmallVector<OpFoldResult> sourceSizes =
tensor::getMixedSizes(rewriter, loc, collapseOp.getSrc());
- SmallVector<OpFoldResult> expandSizes(getMixedValues(
- expandOp.getStaticOutputShape(), expandOp.getOutputShape(), rewriter));
+ SmallVector<OpFoldResult> resultSizes = expandOp.getMixedOutputShape();
SmallVector<OpFoldResult> newExpandSizes;
- int64_t index = 0, expandIndex = 0, collapseIndex = 0;
- for (auto [idx, collapseReassociation] : llvm::enumerate(collapseReInds)) {
+
+ int64_t newExpandIndex = 0, newCollapseIndex = 0, sourceSizeIndex = 0,
+ resultSizeIndex = 0;
+
+ for (size_t idx = 0, idx_end = collapseReInds.size(); idx < idx_end;
+ idx++) {
+ auto collapseReassociation = collapseReInds[idx];
+ auto expandReassociation = expandReInds[idx];
+
+ // Case 1. The reassociations are the same in the collapse producer
+ // and expand consumer. In the swapped expand, each of the final
+ // dimensions is kept as is in the expand and the collapse. So,
+ // for every element in the `ReassociationIndices` vector add a new
+ // `ReassociationIndices` vector for the swapped expand and collapse
+ // (of size 1).
+ if (collapseReassociation.size() == expandReassociation.size()) {
+ for (size_t i = 0; i < collapseReassociation.size(); ++i) {
+ newCollapseReInds.push_back({newCollapseIndex++});
+ newExpandReInds.push_back({newExpandIndex++});
+ newExpandSizes.push_back(resultSizes[resultSizeIndex++]);
+ sourceSizeIndex++;
+ }
+ continue;
+ }
+
+ // Case 2. The `ReassociationIndices` in the collapse is of size > 1 (and
+ // in the expand is of size == 1). In this case, the original dimensions
+ // are preserved on expansion and collapsed subsequently.
if (collapseReassociation.size() != 1) {
ReassociationIndices newCollapseReassociation;
for (size_t i = 0; i < collapseReassociation.size(); ++i) {
- newCollapseReassociation.push_back(index);
- newExpandReInds.push_back({index++});
- newExpandSizes.push_back(collapseSizes[collapseIndex++]);
+ newCollapseReassociation.push_back(newCollapseIndex++);
+ newExpandReInds.push_back({newExpandIndex++});
+ newExpandSizes.push_back(sourceSizes[sourceSizeIndex++]);
}
+ resultSizeIndex++;
newCollapseReInds.push_back(newCollapseReassociation);
- expandIndex++;
continue;
}
+
+ // Case 3. The `ReassociationIndices` in the expand is of size > 1 (and
+ // in the collapse is of size == 1). In this case, the expansion happens
+ // first and the expanded dimensions are preserved on collapse.
ReassociationIndices newExpandReassociation;
- auto expandReassociation = expandReInds[idx];
for (size_t i = 0; i < expandReassociation.size(); ++i) {
- newExpandReassociation.push_back(index);
- newCollapseReInds.push_back({index++});
- newExpandSizes.push_back(expandSizes[expandIndex++]);
+ newExpandReassociation.push_back(newExpandIndex++);
+ newCollapseReInds.push_back({newCollapseIndex++});
+ newExpandSizes.push_back(resultSizes[resultSizeIndex++]);
}
newExpandReInds.push_back(newExpandReassociation);
- collapseIndex++;
+ sourceSizeIndex++;
}
// Swap reshape order.
@@ -211,11 +268,25 @@ struct BubbleUpExpandThroughParallelCollapse
SmallVector<int64_t> staticSizes;
dispatchIndexOpFoldResults(newExpandSizes, dynamicSizes, staticSizes);
auto expandResultType = expandOp.getResultType().clone(staticSizes);
- auto newExpand = rewriter.create<tensor::ExpandShapeOp>(
- loc, expandResultType, collapseOp.getSrc(), newExpandReInds,
- newExpandSizes);
- rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
- expandOp, newExpand.getResult(), newCollapseReInds);
+ Value newCollapseSrc = collapseOp.getSrc();
+ // If the number of reassociation indices in the new `expand_shape` op
+ // matches the number of dimensions of the result, then the expand_shape
+ // is a no-op.
+ if (newExpandReInds.size() != newExpandSizes.size()) {
+ newCollapseSrc = rewriter.create<tensor::ExpandShapeOp>(
+ loc, expandResultType, newCollapseSrc, newExpandReInds,
+ newExpandSizes);
+ }
+
+ // If the number of reassociation indices in the new `collapse_shape` op
+ // matches the number of dimensions of the source, then the collapse_shape
+ // is a no-op.
+ Value replacement = newCollapseSrc;
+ if (newCollapseReInds.size() != newExpandSizes.size()) {
+ replacement = rewriter.create<tensor::CollapseShapeOp>(
+ loc, newCollapseSrc, newCollapseReInds);
+ }
+ rewriter.replaceOp(expandOp, replacement);
return success();
}
};
diff --git a/mlir/test/Dialect/Tensor/bubble-reshapes.mlir b/mlir/test/Dialect/Tensor/bubble-reshapes.mlir
index eeed794884942..1a277af96c6f3 100644
--- a/mlir/test/Dialect/Tensor/bubble-reshapes.mlir
+++ b/mlir/test/Dialect/Tensor/bubble-reshapes.mlir
@@ -48,14 +48,67 @@ func.func @no_bubble_partial_intersecting_reshapes(%arg0: tensor<?x?x?x?xf32>, %
// -----
-func.func @no_bubble_0d_tensor_reshapes(%arg0: tensor<?xf32>, %s0: index, %s1: index, %s2: index, %s3: index) -> tensor<?x?x?x?xf32> {
- %collapse = tensor.collapse_shape %arg0 [] : tensor<?xf32> into tensor<f32>
+func.func @no_bubble_0d_tensor_reshapes(%arg0: tensor<1x1xf32>) -> tensor<1x1x1xf32> {
+ %collapse = tensor.collapse_shape %arg0 [] : tensor<1x1xf32> into tensor<f32>
%expand = tensor.expand_shape %collapse []
- output_shape [%s0, %s1, %s2, %s3] : tensor<f32> into tensor<?x?x?x?xf32>
- return %expand : tensor<?x?x?x?xf32>
+ output_shape [1, 1, 1] : tensor<f32> into tensor<1x1x1xf32>
+ return %expand : tensor<1x1x1xf32>
}
// CHECK: func @no_bubble_0d_tensor_reshapes
-// CHECK-SAME: %[[ARG0:.+]]: tensor<?xf32>
+// CHECK-SAME: %[[ARG0:.+]]: tensor<1x1xf32>
// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}]
// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[COLLAPSE]] {{\[}}]
// CHECK: return %[[EXPAND]]
+
+// -----
+
+// Test the case where the reassociation indices in the collapse and expand
+// are of the same size.
+func.func @bubble_expand_match_non_unit_size_reassocation(
+ %arg0 : tensor<4x?x4x32x4x?xf16>, %arg1 : index, %arg2 : index) -> tensor<4x?x4x128x?x32xf16> {
+ %collapsed = tensor.collapse_shape %arg0 [[0, 1, 2], [3, 4], [5]]
+ : tensor<4x?x4x32x4x?xf16> into tensor<?x128x?xf16>
+ %expanded = tensor.expand_shape %collapsed [[0, 1, 2], [3], [4, 5]] output_shape [4, %arg1, 4, 128, %arg2, 32]
+ : tensor<?x128x?xf16> into tensor<4x?x4x128x?x32xf16>
+ return %expanded : tensor<4x?x4x128x?x32xf16>
+}
+// CHECK: func @bubble_expand_match_non_unit_size_reassocation
+// CHECK-SAME: %[[ARG0:.+]]: tensor<4x?x4x32x4x?xf16>
+// CHECK-SAME: %[[ARG1:[a-zA-z0-9]+]]: index
+// CHECK-SAME: %[[ARG2:[a-zA-z0-9]+]]: index
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]]
+// CHECK-SAME: {{\[}}[0], [1], [2], [3], [4], [5, 6]{{\]}}
+// CHECK-SAME: [4, %[[ARG1]], 4, 32, 4, %[[ARG2]], 32]
+// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[EXPANDED]]
+// CHECK-SAME: {{\[}}[0], [1], [2], [3, 4], [5], [6]{{\]}}
+// CHECK: return %[[COLLAPSED]]
+
+// -----
+
+// Test the case where the trailing collapse isn't needed.
+func.func @no_collapse_generated(
+ %arg0 : tensor<4x?x4x128x?xf16>, %arg1 : index, %arg2 : index) -> tensor<4x?x4x128x?x32xf16> {
+ %collapsed = tensor.collapse_shape %arg0 [[0, 1, 2], [3], [4]]
+ : tensor<4x?x4x128x?xf16> into tensor<?x128x?xf16>
+ %expanded = tensor.expand_shape %collapsed [[0, 1, 2], [3], [4, 5]] output_shape [4, %arg1, 4, 128, %arg2, 32]
+ : tensor<?x128x?xf16> into tensor<4x?x4x128x?x32xf16>
+ return %expanded : tensor<4x?x4x128x?x32xf16>
+}
+// CHECK: func @no_collapse_generated
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape
+// CHECK: return %[[EXPANDED]]
+
+// -----
+
+// Test the case where the leading expand isn't needed.
+func.func @no_expand_generated(
+ %arg0 : tensor<4x?x4x128x?x?x?xf16>, %arg1 : index, %arg2 : index, %arg3 : index) -> tensor<4x?x4x128x?x?xf16> {
+ %collapsed = tensor.collapse_shape %arg0 [[0, 1, 2], [3], [4], [5, 6]]
+ : tensor<4x?x4x128x?x?x?xf16> into tensor<?x128x?x?xf16>
+ %expanded = tensor.expand_shape %collapsed [[0, 1, 2], [3], [4], [5]] output_shape [4, %arg1, 4, 128, %arg2, %arg3]
+ : tensor<?x128x?x?xf16> into tensor<4x?x4x128x?x?xf16>
+ return %expanded : tensor<4x?x4x128x?x?xf16>
+}
+// CHECK: func @no_expand_generated
+// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape
+// CHECK: return %[[COLLAPSED]]
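The three cases in the rewritten loop can be modeled with a small, hypothetical Python sketch (index bookkeeping only; the real C++ above also threads through the mixed source/result sizes, and the function name is an assumption):

```python
def swap_reassociations(collapse_re, expand_re):
    """Compute reassociation maps for the swapped expand -> collapse,
    assuming the pair already passed the parallelism check."""
    new_expand, new_collapse = [], []
    inter = 0  # index into the intermediate (expanded) shape
    for c, e in zip(collapse_re, expand_re):
        if len(c) == len(e):
            # Case 1: same-size groups -> every dimension maps 1:1
            # through both new reshapes.
            for _ in c:
                new_expand.append([inter])
                new_collapse.append([inter])
                inter += 1
        elif len(c) != 1:
            # Case 2: collapse group > 1, expand group == 1; the new expand
            # keeps each source dimension, the new collapse merges them.
            group = []
            for _ in c:
                new_expand.append([inter])
                group.append(inter)
                inter += 1
            new_collapse.append(group)
        else:
            # Case 3: collapse group == 1, expand group > 1; the new expand
            # splits the dimension first, the new collapse keeps each piece.
            group = []
            for _ in e:
                group.append(inter)
                new_collapse.append([inter])
                inter += 1
            new_expand.append(group)
    return new_expand, new_collapse
```

Running it on the first new test case (`collapse [[0, 1, 2], [3, 4], [5]]`, `expand [[0, 1, 2], [3], [4, 5]]`) reproduces the reassociation maps in the CHECK lines above.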
int64_t newExpandIndex = 0, newCollapseIndex = 0, sourceSizeIndex = 0,
        resultSizeIndex = 0;

for (size_t idx = 0, idx_end = collapseReInds.size(); idx < idx_end;
Is there a reason you are creating a new variable `idx_end`? (also should be camelBack)
Fixed to camelCase. I tried to do
for (auto &[collapsedReassocation, expandReassociation] : llvm::zip_equal(collapseReInds, expandReInds))
and it didn't compile. Don't know why.
✅ With the latest revision this PR passed the C/C++ code formatter.
The current patterns compared the reassociation indices for the two ops and failed if neither of them was of size 1. This patch relaxes that restriction by handling a new case where the reassociation indices might be of the same size.

Also generalizes the generation of the swapped `tensor.expand_shape` -> `tensor.collapse_shape`: if one of them is degenerate, it is not generated.
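The degenerate-reshape elision can be stated simply: a swapped `expand_shape` (or `collapse_shape`) is a no-op exactly when the number of reassociation groups equals the number of dimensions on its expanded side, which forces every group to have size one. A minimal sketch of that check (the function name is hypothetical):

```python
def reshape_is_noop(reassociation, num_dims):
    """A reshape is a no-op when it has one reassociation group per
    dimension, which forces every group to have size one."""
    return len(reassociation) == num_dims
```

In the `no_collapse_generated` test above, the swapped expand covers six intermediate dimensions with six groups, so the trailing `tensor.collapse_shape` is not emitted.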