
[mlir][Tensor] Generalize the pattern to swap tensor.collapse_shape -> tensor.expand_shape. #133819


Conversation

MaheshRavishankar
Contributor

The current patterns compared the reassociation indices for the two ops and failed if neither of them was of size 1. This patch relaxes that restriction by handling a new case where the reassociation indices may be of the same size.

Also, when generating the swapped tensor.expand_shape -> tensor.collapse_shape, if either of the new ops would be degenerate (an identity reshape), it is not generated.
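As a minimal sketch of the newly handled case (shapes and SSA names such as %src and %d0 are hypothetical; the authoritative examples are the added tests in bubble-reshapes.mlir), the leading reassociation group has the same size in both ops while the trailing group is only expanded:

```mlir
// Before: group [0, 1] has the same size in the collapse and the expand
// (previously rejected), while [2] vs. [2, 3] is only expanded.
%collapsed = tensor.collapse_shape %src [[0, 1], [2]]
    : tensor<4x8x?xf32> into tensor<32x?xf32>
%expanded = tensor.expand_shape %collapsed [[0, 1], [2, 3]]
    output_shape [4, 8, %d0, 16] : tensor<32x?xf32> into tensor<4x8x?x16xf32>

// After the swap: the same-size group turns into unit reassociations, and
// since the trailing collapse_shape would be an identity reshape it is not
// generated at all.
%expanded = tensor.expand_shape %src [[0], [1], [2, 3]]
    output_shape [4, 8, %d0, 16] : tensor<4x8x?xf32> into tensor<4x8x?x16xf32>
```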

@llvmbot
Member

llvmbot commented Mar 31, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-tensor

Author: None (MaheshRavishankar)

Changes

The current patterns compared the reassociation indices for the two ops and failed if neither of them was of size 1. This patch relaxes that restriction by handling a new case where the reassociation indices may be of the same size.

Also, when generating the swapped tensor.expand_shape -> tensor.collapse_shape, if either of the new ops would be degenerate (an identity reshape), it is not generated.
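A similar sketch of the degenerate case (again with hypothetical shapes and SSA names): when the reassociations and the per-group shapes match exactly, both the swapped expand_shape and collapse_shape would be identity reshapes, so neither is created and the expand is replaced directly with the collapse's source:

```mlir
%c = tensor.collapse_shape %src [[0, 1], [2]]
    : tensor<4x8x32xf32> into tensor<32x32xf32>
%e = tensor.expand_shape %c [[0, 1], [2]] output_shape [4, 8, 32]
    : tensor<32x32xf32> into tensor<4x8x32xf32>
// After the rewrite, no new reshape ops are created; uses of %e become uses
// of %src.
```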


Full diff: https://github.com/llvm/llvm-project/pull/133819.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp (+92-21)
  • (modified) mlir/test/Dialect/Tensor/bubble-reshapes.mlir (+58-5)
diff --git a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
index acedf51d0e240..2542039267f01 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
@@ -166,10 +166,39 @@ struct BubbleUpExpandThroughParallelCollapse
       return failure();
     }
 
-    // Reshapes are parallel to each other if none of the reassociation indices
-    // have greater than 1 index for both reshapes.
+    // Reshapes are parallel to each other (by construction the number of
+    // reassociations specified in the collapse and expand is the same) if, at
+    // every position, either
+    // 1. the reassociation indices are of the same size, or
+    // 2. the reassociation in the collapse or the expand is of size 1.
+    ArrayRef<int64_t> staticSourceSize = collapseOp.getSrcType().getShape();
+    ArrayRef<int64_t> staticResultSize = expandOp.getStaticOutputShape();
     for (auto [expandReassociation, collapseReassociation] :
          llvm::zip_equal(expandReInds, collapseReInds)) {
+      if (collapseReassociation.size() == expandReassociation.size()) {
+        // Even if the reassociations are the same, the collapse/expand should
+        // result in the same dimensions, i.e. 4x8x2 collapsed into 64 should
+        // be expanded into 4x8x2 again. In the presence of dynamic dimensions
+        // one can only verify "equality" when there is at most one dynamic
+        // dimension present, and all other static dimensions are equal.
+        ArrayRef<int64_t> collapsedStaticShapes = staticSourceSize.slice(
+            collapseReassociation.front(), collapseReassociation.size());
+        int64_t numCollapsedDynamic =
+            llvm::count_if(collapsedStaticShapes,
+                           [](int64_t d) { return ShapedType::isDynamic(d); });
+        ArrayRef<int64_t> expandedStaticShapes = staticResultSize.slice(
+            expandReassociation.front(), expandReassociation.size());
+        int64_t numExpandedDynamic =
+            llvm::count_if(expandedStaticShapes,
+                           [](int64_t d) { return ShapedType::isDynamic(d); });
+        if (numCollapsedDynamic > 1 || numExpandedDynamic > 1 ||
+            collapsedStaticShapes != expandedStaticShapes) {
+          return failure();
+        }
+        continue;
+      }
+      // If the reassociations are not the same size, one or the other needs
+      // to be of size one.
       if (collapseReassociation.size() != 1 && expandReassociation.size() != 1)
         return failure();
     }
@@ -177,33 +206,61 @@ struct BubbleUpExpandThroughParallelCollapse
     // Compute new reassociation indices and expanded/collaped shapes.
     SmallVector<ReassociationIndices> newExpandReInds, newCollapseReInds;
     Location loc = expandOp->getLoc();
-    SmallVector<OpFoldResult> collapseSizes =
+    SmallVector<OpFoldResult> sourceSizes =
         tensor::getMixedSizes(rewriter, loc, collapseOp.getSrc());
-    SmallVector<OpFoldResult> expandSizes(getMixedValues(
-        expandOp.getStaticOutputShape(), expandOp.getOutputShape(), rewriter));
+    SmallVector<OpFoldResult> resultSizes = expandOp.getMixedOutputShape();
     SmallVector<OpFoldResult> newExpandSizes;
-    int64_t index = 0, expandIndex = 0, collapseIndex = 0;
-    for (auto [idx, collapseReassociation] : llvm::enumerate(collapseReInds)) {
+
+    int64_t newExpandIndex = 0, newCollapseIndex = 0, sourceSizeIndex = 0,
+            resultSizeIndex = 0;
+
+    for (size_t idx = 0, idx_end = collapseReInds.size(); idx < idx_end;
+         idx++) {
+      auto collapseReassociation = collapseReInds[idx];
+      auto expandReassociation = expandReInds[idx];
+
+      // Case 1. The reassociations are the same in the collapse producer
+      // and expand consumer. In the swapped order, each of the final
+      // dimensions is kept as is in both the expand and the collapse. So,
+      // for every element in the `ReassociationIndices` vector add a new
+      // `ReassociationIndices` vector of size 1 for the swapped expand and
+      // collapse.
+      if (collapseReassociation.size() == expandReassociation.size()) {
+        for (size_t i = 0; i < collapseReassociation.size(); ++i) {
+          newCollapseReInds.push_back({newCollapseIndex++});
+          newExpandReInds.push_back({newExpandIndex++});
+          newExpandSizes.push_back(resultSizes[resultSizeIndex++]);
+          sourceSizeIndex++;
+        }
+        continue;
+      }
+
+      // Case 2. The `ReassociationIndices` in the collapse is of size > 1 (and
+      // in the expand is of size == 1). In this case, the original dimensions
+      // are preserved on expansion and collapsed subsequently.
       if (collapseReassociation.size() != 1) {
         ReassociationIndices newCollapseReassociation;
         for (size_t i = 0; i < collapseReassociation.size(); ++i) {
-          newCollapseReassociation.push_back(index);
-          newExpandReInds.push_back({index++});
-          newExpandSizes.push_back(collapseSizes[collapseIndex++]);
+          newCollapseReassociation.push_back(newCollapseIndex++);
+          newExpandReInds.push_back({newExpandIndex++});
+          newExpandSizes.push_back(sourceSizes[sourceSizeIndex++]);
         }
+        resultSizeIndex++;
         newCollapseReInds.push_back(newCollapseReassociation);
-        expandIndex++;
         continue;
       }
+
+      // Case 3. The `ReassociationIndices` in the expand is of size > 1 (and
+      // in the collapse is of size == 1). In this case, the expansion happens
+      // first and the expanded dimensions are preserved on collapse.
       ReassociationIndices newExpandReassociation;
-      auto expandReassociation = expandReInds[idx];
       for (size_t i = 0; i < expandReassociation.size(); ++i) {
-        newExpandReassociation.push_back(index);
-        newCollapseReInds.push_back({index++});
-        newExpandSizes.push_back(expandSizes[expandIndex++]);
+        newExpandReassociation.push_back(newExpandIndex++);
+        newCollapseReInds.push_back({newCollapseIndex++});
+        newExpandSizes.push_back(resultSizes[resultSizeIndex++]);
       }
       newExpandReInds.push_back(newExpandReassociation);
-      collapseIndex++;
+      sourceSizeIndex++;
     }
 
     // Swap reshape order.
@@ -211,11 +268,25 @@ struct BubbleUpExpandThroughParallelCollapse
     SmallVector<int64_t> staticSizes;
     dispatchIndexOpFoldResults(newExpandSizes, dynamicSizes, staticSizes);
     auto expandResultType = expandOp.getResultType().clone(staticSizes);
-    auto newExpand = rewriter.create<tensor::ExpandShapeOp>(
-        loc, expandResultType, collapseOp.getSrc(), newExpandReInds,
-        newExpandSizes);
-    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
-        expandOp, newExpand.getResult(), newCollapseReInds);
+    Value newCollapseSrc = collapseOp.getSrc();
+    // If the number of reassociation indices in the new `expand_shape` op
+    // matches the number of dimensions of the result, then the expand_shape
+    // is a no-op.
+    if (newExpandReInds.size() != newExpandSizes.size()) {
+      newCollapseSrc = rewriter.create<tensor::ExpandShapeOp>(
+          loc, expandResultType, newCollapseSrc, newExpandReInds,
+          newExpandSizes);
+    }
+
+    // If the number of reassociation indices in the new `collapse_shape` op
+    // matches the number of dimensions of the source, then the collapse_shape
+    // is a no-op.
+    Value replacement = newCollapseSrc;
+    if (newCollapseReInds.size() != newExpandSizes.size()) {
+      replacement = rewriter.create<tensor::CollapseShapeOp>(
+          loc, newCollapseSrc, newCollapseReInds);
+    }
+    rewriter.replaceOp(expandOp, replacement);
     return success();
   }
 };
diff --git a/mlir/test/Dialect/Tensor/bubble-reshapes.mlir b/mlir/test/Dialect/Tensor/bubble-reshapes.mlir
index eeed794884942..1a277af96c6f3 100644
--- a/mlir/test/Dialect/Tensor/bubble-reshapes.mlir
+++ b/mlir/test/Dialect/Tensor/bubble-reshapes.mlir
@@ -48,14 +48,67 @@ func.func @no_bubble_partial_intersecting_reshapes(%arg0: tensor<?x?x?x?xf32>, %
 
 // -----
 
-func.func @no_bubble_0d_tensor_reshapes(%arg0: tensor<?xf32>, %s0: index, %s1: index, %s2: index, %s3: index) -> tensor<?x?x?x?xf32> {
-  %collapse = tensor.collapse_shape %arg0 [] : tensor<?xf32> into tensor<f32>
+func.func @no_bubble_0d_tensor_reshapes(%arg0: tensor<1x1xf32>) -> tensor<1x1x1xf32> {
+  %collapse = tensor.collapse_shape %arg0 [] : tensor<1x1xf32> into tensor<f32>
   %expand = tensor.expand_shape %collapse []
-              output_shape [%s0, %s1, %s2, %s3] : tensor<f32> into tensor<?x?x?x?xf32>
-  return %expand : tensor<?x?x?x?xf32>
+              output_shape [1, 1, 1] : tensor<f32> into tensor<1x1x1xf32>
+  return %expand : tensor<1x1x1xf32>
 }
 //      CHECK: func @no_bubble_0d_tensor_reshapes
-// CHECK-SAME:   %[[ARG0:.+]]: tensor<?xf32>
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<1x1xf32>
 //      CHECK:   %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}]
 //      CHECK:   %[[EXPAND:.+]] = tensor.expand_shape %[[COLLAPSE]] {{\[}}]
 //      CHECK:   return %[[EXPAND]]
+
+// -----
+
+// Test the case where the reassociation indices in the collapse and expand
+// are of the same size.
+func.func @bubble_expand_match_non_unit_size_reassocation(
+      %arg0 : tensor<4x?x4x32x4x?xf16>, %arg1 : index, %arg2 : index) -> tensor<4x?x4x128x?x32xf16> {
+  %collapsed = tensor.collapse_shape %arg0 [[0, 1, 2], [3, 4], [5]]
+      : tensor<4x?x4x32x4x?xf16> into tensor<?x128x?xf16>
+  %expanded = tensor.expand_shape %collapsed [[0, 1, 2], [3], [4, 5]] output_shape [4, %arg1, 4, 128, %arg2, 32]
+      : tensor<?x128x?xf16> into tensor<4x?x4x128x?x32xf16>
+  return %expanded : tensor<4x?x4x128x?x32xf16>
+}
+//      CHECK: func @bubble_expand_match_non_unit_size_reassocation
+// CHECK-SAME:     %[[ARG0:.+]]: tensor<4x?x4x32x4x?xf16>
+// CHECK-SAME:     %[[ARG1:[a-zA-z0-9]+]]: index
+// CHECK-SAME:     %[[ARG2:[a-zA-z0-9]+]]: index
+//      CHECK:   %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]]
+// CHECK-SAME:       {{\[}}[0], [1], [2], [3], [4], [5, 6]{{\]}}
+// CHECK-SAME:       [4, %[[ARG1]], 4, 32, 4, %[[ARG2]], 32]
+//      CHECK:   %[[COLLAPSED:.+]] = tensor.collapse_shape %[[EXPANDED]]
+// CHECK-SAME:       {{\[}}[0], [1], [2], [3, 4], [5], [6]{{\]}}
+//      CHECK:   return %[[COLLAPSED]]
+
+// -----
+
+// Test the case where the trailing collapse isn't needed.
+func.func @no_collapse_generated(
+      %arg0 : tensor<4x?x4x128x?xf16>, %arg1 : index, %arg2 : index) -> tensor<4x?x4x128x?x32xf16> {
+  %collapsed = tensor.collapse_shape %arg0 [[0, 1, 2], [3], [4]]
+      : tensor<4x?x4x128x?xf16> into tensor<?x128x?xf16>
+  %expanded = tensor.expand_shape %collapsed [[0, 1, 2], [3], [4, 5]] output_shape [4, %arg1, 4, 128, %arg2, 32]
+      : tensor<?x128x?xf16> into tensor<4x?x4x128x?x32xf16>
+  return %expanded : tensor<4x?x4x128x?x32xf16>
+}
+//      CHECK: func @no_collapse_generated
+//      CHECK:   %[[EXPANDED:.+]] = tensor.expand_shape 
+//      CHECK:   return %[[EXPANDED]]
+
+// -----
+
+// Test the case where the leading expand isn't needed.
+func.func @no_expand_generated(
+      %arg0 : tensor<4x?x4x128x?x?x?xf16>, %arg1 : index, %arg2 : index, %arg3 : index) -> tensor<4x?x4x128x?x?xf16> {
+  %collapsed = tensor.collapse_shape %arg0 [[0, 1, 2], [3], [4], [5, 6]]
+      : tensor<4x?x4x128x?x?x?xf16> into tensor<?x128x?x?xf16>
+  %expanded = tensor.expand_shape %collapsed [[0, 1, 2], [3], [4], [5]] output_shape [4, %arg1, 4, 128, %arg2, %arg3]
+      : tensor<?x128x?x?xf16> into tensor<4x?x4x128x?x?xf16>
+  return %expanded : tensor<4x?x4x128x?x?xf16>
+}
+//      CHECK: func @no_expand_generated
+//      CHECK:   %[[EXPANDED:.+]] = tensor.collapse_shape
+//      CHECK:   return %[[EXPANDED]]

int64_t newExpandIndex = 0, newCollapseIndex = 0, sourceSizeIndex = 0,
resultSizeIndex = 0;

for (size_t idx = 0, idx_end = collapseReInds.size(); idx < idx_end;
Contributor

Is there a reason you are creating a new variable idx_end? (also should be camelBack)

Contributor Author

Fixed to camelCase. I tried to do

for (auto &[collapsedReassocation, expandReassocation] : llvm::zip_equal(collapseReInds, expandReInds))

and it didn't compile. Don't know why.

MaheshRavishankar added a commit to MaheshRavishankar/iree that referenced this pull request Apr 3, 2025
Signed-off-by: MaheshRavishankar <[email protected]>
@MaheshRavishankar MaheshRavishankar force-pushed the users/MaheshRavishankar/generalizeExpandCollapseSwap branch 2 times, most recently from 9c9da8b to 44631b4 Compare April 15, 2025 21:00

github-actions bot commented Apr 15, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

… -> `tensor.expand_shape`.

The current patterns compared the reassociation indices for the two ops
and failed if neither of them was of size 1. This patch relaxes that
restriction by handling a new case where the reassociation indices
may be of the same size.

Also, when generating the swapped
`tensor.expand_shape` -> `tensor.collapse_shape`, if one of them is
degenerate it is not generated.

Signed-off-by: MaheshRavishankar <[email protected]>
Signed-off-by: MaheshRavishankar <[email protected]>
@MaheshRavishankar MaheshRavishankar force-pushed the users/MaheshRavishankar/generalizeExpandCollapseSwap branch from 44631b4 to 13e48f3 Compare April 15, 2025 21:08
@MaheshRavishankar MaheshRavishankar merged commit 0f3e460 into llvm:main Apr 15, 2025
6 of 9 checks passed
var-const pushed a commit to ldionne/llvm-project that referenced this pull request Apr 17, 2025
… -> `tensor.expand_shape`. (llvm#133819)

The current patterns compared the reassociation indices for the two ops
and failed if neither of them was of size 1. This patch relaxes that
restriction by handling a new case where the reassociation indices may
be of the same size.

Also, when generating the swapped
`tensor.expand_shape` -> `tensor.collapse_shape`, if one of them is
degenerate it is not generated.

Signed-off-by: MaheshRavishankar <[email protected]>