From 13e48f3aee49f41df42192c1c2ee61d8c1a44ee5 Mon Sep 17 00:00:00 2001
From: MaheshRavishankar
Date: Mon, 31 Mar 2025 15:50:39 -0500
Subject: [PATCH] [mlir][Tensor] Generalize the pattern to swap
 `tensor.collapse_shape` -> `tensor.expand_shape`.

The current pattern compared the reassociation indices of the two ops and
failed if neither of them was of size 1. This patch relaxes that
restriction by also handling the case where the reassociation indices at a
given position are of the same size.

It also generalizes the generation of the swapped `tensor.expand_shape` ->
`tensor.collapse_shape` pair: if either of the swapped ops is degenerate
(a no-op reshape), that op is not generated.

Signed-off-by: MaheshRavishankar
---
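Note (illustrative only, not part of the change): with the relaxed
condition, a pair whose reassociation groups have the same size on both
sides now matches. For example, with shapes made up for exposition:

  %collapsed = tensor.collapse_shape %0 [[0, 1]]
      : tensor<4x8xf32> into tensor<32xf32>
  %expanded = tensor.expand_shape %collapsed [[0, 1]] output_shape [4, 8]
      : tensor<32xf32> into tensor<4x8xf32>

The single group has size 2 in both ops and the expand restores the same
4x8 shape, so the reshapes are treated as parallel (and since both swapped
reshapes are degenerate here, the pair folds away to %0 outright).
Expanding to [8, 4] instead, or having more than one dynamic dimension in
a group, is still rejected.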
 .../Tensor/Transforms/ReshapePatterns.cpp     | 112 ++++++++++++++----
 mlir/test/Dialect/Tensor/bubble-reshapes.mlir |  63 +++++++++-
 2 files changed, 149 insertions(+), 26 deletions(-)

diff --git a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
index eed44e60d6591..a3de7f9b44ae6 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
@@ -167,10 +167,39 @@ struct BubbleUpExpandThroughParallelCollapse
       return failure();
     }
 
-    // Reshapes are parallel to each other if none of the reassociation
-    // indices have greater than 1 index for both reshapes.
+    // Reshapes are parallel to each other (by construction the number of
+    // reassociations specified in the collapse and expand are the same) if,
+    // at every position,
+    //   1. the reassociation indices are of the same size, or
+    //   2. the reassociation in either the collapse or the expand is of
+    //      size 1.
+    ArrayRef<int64_t> staticSourceSize = collapseOp.getSrcType().getShape();
+    ArrayRef<int64_t> staticResultSize = expandOp.getStaticOutputShape();
     for (auto [expandReassociation, collapseReassociation] :
          llvm::zip_equal(expandReInds, collapseReInds)) {
+      if (collapseReassociation.size() == expandReassociation.size()) {
+        // Even if the reassociations are of the same size, the
+        // collapse/expand should result in the same dimensions, i.e. 4x8x2
+        // collapsed into 64 should be expanded into 4x8x2 again. In the
+        // presence of dynamic dimensions, this "equality" can only be
+        // verified when at most one dynamic dimension is present and all
+        // other static dimensions are equal.
+        ArrayRef<int64_t> collapsedStaticShapes = staticSourceSize.slice(
+            collapseReassociation.front(), collapseReassociation.size());
+        int64_t numCollapsedDynamic =
+            llvm::count_if(collapsedStaticShapes,
+                           [](int64_t d) { return ShapedType::isDynamic(d); });
+        ArrayRef<int64_t> expandedStaticShapes = staticResultSize.slice(
+            expandReassociation.front(), expandReassociation.size());
+        int64_t numExpandedDynamic =
+            llvm::count_if(expandedStaticShapes,
+                           [](int64_t d) { return ShapedType::isDynamic(d); });
+        if (numCollapsedDynamic > 1 || numExpandedDynamic > 1 ||
+            collapsedStaticShapes != expandedStaticShapes) {
+          return failure();
+        }
+        continue;
+      }
+      // If the reassociations are not of the same size, one or the other
+      // needs to be of size one.
       if (collapseReassociation.size() != 1 && expandReassociation.size() != 1)
         return failure();
     }
 
@@ -178,33 +207,60 @@ struct BubbleUpExpandThroughParallelCollapse
     // Compute new reassociation indices and expanded/collapsed shapes.
     SmallVector<ReassociationIndices> newExpandReInds, newCollapseReInds;
     Location loc = expandOp->getLoc();
-    SmallVector<OpFoldResult> collapseSizes =
+    SmallVector<OpFoldResult> sourceSizes =
         tensor::getMixedSizes(rewriter, loc, collapseOp.getSrc());
-    SmallVector<OpFoldResult> expandSizes(getMixedValues(
-        expandOp.getStaticOutputShape(), expandOp.getOutputShape(), rewriter));
+    SmallVector<OpFoldResult> resultSizes = expandOp.getMixedOutputShape();
     SmallVector<OpFoldResult> newExpandSizes;
-    int64_t index = 0, expandIndex = 0, collapseIndex = 0;
-    for (auto [idx, collapseReassociation] : llvm::enumerate(collapseReInds)) {
+
+    int64_t newExpandIndex = 0, newCollapseIndex = 0, sourceSizeIndex = 0,
+            resultSizeIndex = 0;
+
+    for (size_t idx = 0, idxEnd = collapseReInds.size(); idx < idxEnd; idx++) {
+      auto &collapseReassociation = collapseReInds[idx];
+      auto &expandReassociation = expandReInds[idx];
+
+      // Case 1. The reassociations are the same in the collapse producer and
+      // the expand consumer. Each of the final dimensions is kept as is in
+      // both the swapped expand and the swapped collapse. So, for every
+      // element of the `ReassociationIndices` vector, add a new
+      // `ReassociationIndices` vector (of size 1) to the swapped expand and
+      // collapse.
+      if (collapseReassociation.size() == expandReassociation.size()) {
+        for (size_t i = 0; i < collapseReassociation.size(); ++i) {
+          newCollapseReInds.push_back({newCollapseIndex++});
+          newExpandReInds.push_back({newExpandIndex++});
+          newExpandSizes.push_back(resultSizes[resultSizeIndex++]);
+          sourceSizeIndex++;
+        }
+        continue;
+      }
+
+      // Case 2. The `ReassociationIndices` in the collapse is of size > 1
+      // (and in the expand is of size == 1). In this case, the original
+      // dimensions are preserved on expansion and collapsed subsequently.
       if (collapseReassociation.size() != 1) {
         ReassociationIndices newCollapseReassociation;
         for (size_t i = 0; i < collapseReassociation.size(); ++i) {
-          newCollapseReassociation.push_back(index);
-          newExpandReInds.push_back({index++});
-          newExpandSizes.push_back(collapseSizes[collapseIndex++]);
+          newCollapseReassociation.push_back(newCollapseIndex++);
+          newExpandReInds.push_back({newExpandIndex++});
+          newExpandSizes.push_back(sourceSizes[sourceSizeIndex++]);
         }
+        resultSizeIndex++;
         newCollapseReInds.push_back(newCollapseReassociation);
-        expandIndex++;
         continue;
       }
+
+      // Case 3. The `ReassociationIndices` in the expand is of size > 1 (and
+      // in the collapse is of size == 1). In this case, the expansion happens
+      // first and the expanded dimensions are preserved on collapse.
       ReassociationIndices newExpandReassociation;
-      auto expandReassociation = expandReInds[idx];
       for (size_t i = 0; i < expandReassociation.size(); ++i) {
-        newExpandReassociation.push_back(index);
-        newCollapseReInds.push_back({index++});
-        newExpandSizes.push_back(expandSizes[expandIndex++]);
+        newExpandReassociation.push_back(newExpandIndex++);
+        newCollapseReInds.push_back({newCollapseIndex++});
+        newExpandSizes.push_back(resultSizes[resultSizeIndex++]);
       }
       newExpandReInds.push_back(newExpandReassociation);
-      collapseIndex++;
+      sourceSizeIndex++;
     }
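+
+    // For example (hypothetical shapes, for illustration only): swapping
+    //   %c = tensor.collapse_shape %0 [[0, 1], [2]]
+    //       : tensor<?x?x32xf32> into tensor<?x32xf32>
+    //   %e = tensor.expand_shape %c [[0], [1, 2]] output_shape [%d, 4, 8]
+    //       : tensor<?x32xf32> into tensor<?x4x8xf32>
+    // hits case 2 and then case 3, producing reassociations
+    // [[0], [1], [2, 3]] for the new expand (into tensor<?x?x4x8xf32>) and
+    // [[0, 1], [2], [3]] for the new collapse (into tensor<?x4x8xf32>).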
 
     // Swap reshape order.
@@ -212,11 +268,25 @@ struct BubbleUpExpandThroughParallelCollapse
     SmallVector<int64_t> staticSizes;
     dispatchIndexOpFoldResults(newExpandSizes, dynamicSizes, staticSizes);
     auto expandResultType = expandOp.getResultType().clone(staticSizes);
-    auto newExpand = rewriter.create<tensor::ExpandShapeOp>(
-        loc, expandResultType, collapseOp.getSrc(), newExpandReInds,
-        newExpandSizes);
-    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
-        expandOp, newExpand.getResult(), newCollapseReInds);
+    Value newCollapseSrc = collapseOp.getSrc();
+    // If the number of reassociation indices in the new `expand_shape` op
+    // matches the number of dimensions of the result, then the expand_shape
+    // is a no-op and is not created.
+    if (newExpandReInds.size() != newExpandSizes.size()) {
+      newCollapseSrc = rewriter.create<tensor::ExpandShapeOp>(
+          loc, expandResultType, newCollapseSrc, newExpandReInds,
+          newExpandSizes);
+    }
+
+    // If the number of reassociation indices in the new `collapse_shape` op
+    // matches the number of dimensions of the source, then the
+    // collapse_shape is a no-op and is not created.
+    Value replacement = newCollapseSrc;
+    if (newCollapseReInds.size() != newExpandSizes.size()) {
+      replacement = rewriter.create<tensor::CollapseShapeOp>(
+          loc, newCollapseSrc, newCollapseReInds);
+    }
+    rewriter.replaceOp(expandOp, replacement);
     return success();
   }
 };
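For instance (a minimal sketch with made-up shapes), swapping

  %c = tensor.collapse_shape %0 [[0, 1], [2, 3]]
      : tensor<4x8x?x16xf32> into tensor<32x?xf32>
  %e = tensor.expand_shape %c [[0, 1], [2]] output_shape [4, 8, %d]
      : tensor<32x?xf32> into tensor<4x8x?xf32>

produces only singleton groups on the expand side, so no expand_shape is
emitted and the rewrite reduces to a single
tensor.collapse_shape %0 [[0], [1], [2, 3]] into tensor<4x8x?xf32>. The
tests @no_collapse_generated and @no_expand_generated below exercise the
two elision paths.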
diff --git a/mlir/test/Dialect/Tensor/bubble-reshapes.mlir b/mlir/test/Dialect/Tensor/bubble-reshapes.mlir
index eeed794884942..81bf8e3f60e2c 100644
--- a/mlir/test/Dialect/Tensor/bubble-reshapes.mlir
+++ b/mlir/test/Dialect/Tensor/bubble-reshapes.mlir
@@ -48,14 +48,67 @@ func.func @no_bubble_partial_intersecting_reshapes
 
 // -----
 
-func.func @no_bubble_0d_tensor_reshapes(%arg0: tensor<f32>, %s0: index, %s1: index, %s2: index, %s3: index) -> tensor<?x?x?x?xf32> {
-  %collapse = tensor.collapse_shape %arg0 [] : tensor<f32> into tensor<f32>
+func.func @no_bubble_0d_tensor_reshapes(%arg0: tensor<1x1xf32>) -> tensor<1x1x1xf32> {
+  %collapse = tensor.collapse_shape %arg0 [] : tensor<1x1xf32> into tensor<f32>
   %expand = tensor.expand_shape %collapse []
-      output_shape [%s0, %s1, %s2, %s3] : tensor<f32> into tensor<?x?x?x?xf32>
-  return %expand : tensor<?x?x?x?xf32>
+      output_shape [1, 1, 1] : tensor<f32> into tensor<1x1x1xf32>
+  return %expand : tensor<1x1x1xf32>
 }
 // CHECK: func @no_bubble_0d_tensor_reshapes
-// CHECK-SAME: %[[ARG0:.+]]: tensor<f32>
+// CHECK-SAME: %[[ARG0:.+]]: tensor<1x1xf32>
 // CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}]
 // CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[COLLAPSE]] {{\[}}]
 // CHECK: return %[[EXPAND]]
+
+// -----
+
+// Test the case where the reassociation indices in the collapse and expand
+// are of the same size.
+func.func @bubble_expand_match_non_unit_size_reassocation(
+    %arg0 : tensor<4x?x4x32x4x?xf16>, %arg1 : index, %arg2 : index) -> tensor<4x?x4x128x?x32xf16> {
+  %collapsed = tensor.collapse_shape %arg0 [[0, 1, 2], [3, 4], [5]]
+      : tensor<4x?x4x32x4x?xf16> into tensor<?x128x?xf16>
+  %expanded = tensor.expand_shape %collapsed [[0, 1, 2], [3], [4, 5]] output_shape [4, %arg1, 4, 128, %arg2, 32]
+      : tensor<?x128x?xf16> into tensor<4x?x4x128x?x32xf16>
+  return %expanded : tensor<4x?x4x128x?x32xf16>
+}
+// CHECK: func @bubble_expand_match_non_unit_size_reassocation
+// CHECK-SAME: %[[ARG0:.+]]: tensor<4x?x4x32x4x?xf16>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]]
+// CHECK-SAME: {{\[}}[0], [1], [2], [3], [4], [5, 6]{{\]}}
+// CHECK-SAME: [4, %[[ARG1]], 4, 32, 4, %[[ARG2]], 32]
+// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[EXPANDED]]
+// CHECK-SAME: {{\[}}[0], [1], [2], [3, 4], [5], [6]{{\]}}
+// CHECK: return %[[COLLAPSED]]
+
+// -----
+
+// Test the case where the trailing collapse isn't needed.
+func.func @no_collapse_generated(
+    %arg0 : tensor<4x?x4x128x?xf16>, %arg1 : index, %arg2 : index) -> tensor<4x?x4x128x?x32xf16> {
+  %collapsed = tensor.collapse_shape %arg0 [[0, 1, 2], [3], [4]]
+      : tensor<4x?x4x128x?xf16> into tensor<?x128x?xf16>
+  %expanded = tensor.expand_shape %collapsed [[0, 1, 2], [3], [4, 5]] output_shape [4, %arg1, 4, 128, %arg2, 32]
+      : tensor<?x128x?xf16> into tensor<4x?x4x128x?x32xf16>
+  return %expanded : tensor<4x?x4x128x?x32xf16>
+}
+// CHECK: func @no_collapse_generated
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape
+// CHECK: return %[[EXPANDED]]
+
+// -----
+
+// Test the case where the leading expand isn't needed.
+func.func @no_expand_generated(
+    %arg0 : tensor<4x?x4x128x?x?x?xf16>, %arg1 : index, %arg2 : index, %arg3 : index) -> tensor<4x?x4x128x?x?xf16> {
+  %collapsed = tensor.collapse_shape %arg0 [[0, 1, 2], [3], [4], [5, 6]]
+      : tensor<4x?x4x128x?x?x?xf16> into tensor<?x128x?x?xf16>
+  %expanded = tensor.expand_shape %collapsed [[0, 1, 2], [3], [4], [5]] output_shape [4, %arg1, 4, 128, %arg2, %arg3]
+      : tensor<?x128x?x?xf16> into tensor<4x?x4x128x?x?xf16>
+  return %expanded : tensor<4x?x4x128x?x?xf16>
+}
+// CHECK: func @no_expand_generated
+// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape
+// CHECK: return %[[COLLAPSED]]