[mlir][Tensor] Generalize the pattern to swap tensor.collapse_shape -> tensor.expand_shape. #133819

Merged
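As a sketch of the rewrite that `BubbleUpExpandThroughParallelCollapse` performs (the shapes, SSA names, and output sizes below are illustrative and do not appear in the patch), a `tensor.collapse_shape` followed by a parallel `tensor.expand_shape` is swapped so the expand is applied to the original source and the collapse follows it:

// Before: collapse of %arg0, then a parallel expand of the collapsed value.
%collapse = tensor.collapse_shape %arg0 [[0, 1], [2]]
    : tensor<2x3x?xf32> into tensor<6x?xf32>
%expand = tensor.expand_shape %collapse [[0], [1, 2]] output_shape [6, %s0, %s1]
    : tensor<6x?xf32> into tensor<6x?x?xf32>

// After: the expand is bubbled up to the original source and the collapse follows it.
%expanded = tensor.expand_shape %arg0 [[0], [1], [2, 3]] output_shape [2, 3, %s0, %s1]
    : tensor<2x3x?xf32> into tensor<2x3x?x?xf32>
%collapsed = tensor.collapse_shape %expanded [[0, 1], [2], [3]]
    : tensor<2x3x?x?xf32> into tensor<6x?x?xf32>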
112 changes: 91 additions & 21 deletions mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
@@ -167,56 +167,126 @@ struct BubbleUpExpandThroughParallelCollapse
return failure();
}

// Reshapes are parallel to each other if none of the reassociation indices
// have greater than 1 index for both reshapes.
// Reshapes are parallel to each other (by construction the number of
// reassociations specified in the collapse and expand are the same) if, at
// every position, either
// 1. the reassociation indices are of the same size, or
// 2. the reassociation in the collapse or the expand is of size 1.
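// For example (illustrative reassociations, not taken from the input IR): a
// collapse with reassociations {[0, 1], [2]} and an expand with
// {[0], [1, 2]} are parallel (group sizes 2/1 and 1/2), whereas a collapse
// with {[0, 1, 2], [3]} and an expand with {[0, 1], [2, 3]} are not, since
// at position 0 the sizes are 3 and 2 and neither is 1.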
ArrayRef<int64_t> staticSourceSize = collapseOp.getSrcType().getShape();
ArrayRef<int64_t> staticResultSize = expandOp.getStaticOutputShape();
for (auto [expandReassociation, collapseReassociation] :
llvm::zip_equal(expandReInds, collapseReInds)) {
if (collapseReassociation.size() == expandReassociation.size()) {
// Even if the reassociations are the same, the collapse/expand should
// result in the same dimensions. e.g. 4x8x2 collapsed into 64 should be
// expanded into 4x8x2 again. In the presence of dynamic dimensions one can
// only verify "equality" when there is at most one dynamic dimension
// present on each side, and all other static dimensions are equal.
ArrayRef<int64_t> collapsedStaticShapes = staticSourceSize.slice(
collapseReassociation.front(), collapseReassociation.size());
int64_t numCollapsedDynamic =
llvm::count_if(collapsedStaticShapes,
[](int64_t d) { return ShapedType::isDynamic(d); });
ArrayRef<int64_t> expandedStaticShapes = staticResultSize.slice(
expandReassociation.front(), expandReassociation.size());
int64_t numExpandedDynamic =
llvm::count_if(expandedStaticShapes,
[](int64_t d) { return ShapedType::isDynamic(d); });
if (numCollapsedDynamic > 1 || numExpandedDynamic > 1 ||
collapsedStaticShapes != expandedStaticShapes) {
return failure();
}
continue;
}
// If the reassociations are not of the same size, one or the other needs
// to be of size one.
if (collapseReassociation.size() != 1 && expandReassociation.size() != 1)
return failure();
}

// Compute new reassociation indices and expanded/collapsed shapes.
SmallVector<ReassociationIndices> newExpandReInds, newCollapseReInds;
Location loc = expandOp->getLoc();
SmallVector<OpFoldResult> collapseSizes =
SmallVector<OpFoldResult> sourceSizes =
tensor::getMixedSizes(rewriter, loc, collapseOp.getSrc());
SmallVector<OpFoldResult> expandSizes(getMixedValues(
expandOp.getStaticOutputShape(), expandOp.getOutputShape(), rewriter));
SmallVector<OpFoldResult> resultSizes = expandOp.getMixedOutputShape();
SmallVector<OpFoldResult> newExpandSizes;
int64_t index = 0, expandIndex = 0, collapseIndex = 0;
for (auto [idx, collapseReassociation] : llvm::enumerate(collapseReInds)) {

int64_t newExpandIndex = 0, newCollapseIndex = 0, sourceSizeIndex = 0,
resultSizeIndex = 0;

for (size_t idx = 0, idxEnd = collapseReInds.size(); idx < idxEnd; idx++) {
auto &collapseReassociation = collapseReInds[idx];
auto &expandReassociation = expandReInds[idx];

// Case 1. The reassociations are the same in the collapse producer
// and expand consumer. In the swapped ops, each of the final
// dimensions is kept as is in both the expand and the collapse. So,
// for every element in the `ReassociationIndices` vector add a new
// size-1 `ReassociationIndices` vector for the swapped expand and
// collapse.
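// For example (illustrative shapes), if a collapse group folds 4x?x4 into ?
// and the matching expand group recreates 4x?x4, the swapped ops keep each
// of the dimensions 4, ? and 4 in its own singleton group on both sides.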
if (collapseReassociation.size() == expandReassociation.size()) {
for (size_t i = 0; i < collapseReassociation.size(); ++i) {
newCollapseReInds.push_back({newCollapseIndex++});
newExpandReInds.push_back({newExpandIndex++});
newExpandSizes.push_back(resultSizes[resultSizeIndex++]);
sourceSizeIndex++;
}
continue;
}

// Case 2. The `ReassociationIndices` in the collapse is of size > 1 (and
// in the expand is of size == 1). In this case, the original dimensions
// are preserved on expansion and collapsed subsequently.
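// For example (illustrative shapes), a collapse group folding 32x4 into 128
// paired with a size-1 expand group keeping 128 becomes two singleton
// expand groups for 32 and 4, followed by a collapse group folding them
// back into 128.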
if (collapseReassociation.size() != 1) {
ReassociationIndices newCollapseReassociation;
for (size_t i = 0; i < collapseReassociation.size(); ++i) {
newCollapseReassociation.push_back(index);
newExpandReInds.push_back({index++});
newExpandSizes.push_back(collapseSizes[collapseIndex++]);
newCollapseReassociation.push_back(newCollapseIndex++);
newExpandReInds.push_back({newExpandIndex++});
newExpandSizes.push_back(sourceSizes[sourceSizeIndex++]);
}
resultSizeIndex++;
newCollapseReInds.push_back(newCollapseReassociation);
expandIndex++;
continue;
}

// Case 3. The `ReassociationIndices` in the expand is of size > 1 (and
// in the collapse is of size == 1). In this case, the expansion happens
// first and the expanded dimensions are preserved on collapse.
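// For example (illustrative shapes), a size-1 collapse group keeping ?
// paired with an expand group splitting it into ?x32 becomes an expand
// group that splits the source dimension into ? and 32, followed by
// singleton collapse groups for the two new dimensions.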
ReassociationIndices newExpandReassociation;
auto expandReassociation = expandReInds[idx];
for (size_t i = 0; i < expandReassociation.size(); ++i) {
newExpandReassociation.push_back(index);
newCollapseReInds.push_back({index++});
newExpandSizes.push_back(expandSizes[expandIndex++]);
newExpandReassociation.push_back(newExpandIndex++);
newCollapseReInds.push_back({newCollapseIndex++});
newExpandSizes.push_back(resultSizes[resultSizeIndex++]);
}
newExpandReInds.push_back(newExpandReassociation);
collapseIndex++;
sourceSizeIndex++;
}

// Swap reshape order.
SmallVector<Value> dynamicSizes;
SmallVector<int64_t> staticSizes;
dispatchIndexOpFoldResults(newExpandSizes, dynamicSizes, staticSizes);
auto expandResultType = expandOp.getResultType().clone(staticSizes);
auto newExpand = rewriter.create<tensor::ExpandShapeOp>(
loc, expandResultType, collapseOp.getSrc(), newExpandReInds,
newExpandSizes);
rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
expandOp, newExpand.getResult(), newCollapseReInds);
Value newCollapseSrc = collapseOp.getSrc();
// If the number of reassociation indices in the new `expand_shape` op
// matches the number of dimensions of the result, then the expand_shape
// is a no-op.
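// (For example, if every group above hit Case 1, the swapped expand would
// be an identity reshape and is elided here.)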
if (newExpandReInds.size() != newExpandSizes.size()) {
newCollapseSrc = rewriter.create<tensor::ExpandShapeOp>(
loc, expandResultType, newCollapseSrc, newExpandReInds,
newExpandSizes);
}

// If the number of reassociation indices in the new `collapse_shape` op
// matches the number of dimensions of the source, then the collapse_shape
// is a no-op.
Value replacement = newCollapseSrc;
if (newCollapseReInds.size() != newExpandSizes.size()) {
replacement = rewriter.create<tensor::CollapseShapeOp>(
loc, newCollapseSrc, newCollapseReInds);
}
rewriter.replaceOp(expandOp, replacement);
return success();
}
};
63 changes: 58 additions & 5 deletions mlir/test/Dialect/Tensor/bubble-reshapes.mlir
@@ -48,14 +48,67 @@ func.func @no_bubble_partial_intersecting_reshapes(%arg0: tensor<?x?x?x?xf32>, %

// -----

func.func @no_bubble_0d_tensor_reshapes(%arg0: tensor<?xf32>, %s0: index, %s1: index, %s2: index, %s3: index) -> tensor<?x?x?x?xf32> {
%collapse = tensor.collapse_shape %arg0 [] : tensor<?xf32> into tensor<f32>
func.func @no_bubble_0d_tensor_reshapes(%arg0: tensor<1x1xf32>) -> tensor<1x1x1xf32> {
%collapse = tensor.collapse_shape %arg0 [] : tensor<1x1xf32> into tensor<f32>
%expand = tensor.expand_shape %collapse []
output_shape [%s0, %s1, %s2, %s3] : tensor<f32> into tensor<?x?x?x?xf32>
return %expand : tensor<?x?x?x?xf32>
output_shape [1, 1, 1] : tensor<f32> into tensor<1x1x1xf32>
return %expand : tensor<1x1x1xf32>
}
// CHECK: func @no_bubble_0d_tensor_reshapes
// CHECK-SAME: %[[ARG0:.+]]: tensor<?xf32>
// CHECK-SAME: %[[ARG0:.+]]: tensor<1x1xf32>
// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}]
// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[COLLAPSE]] {{\[}}]
// CHECK: return %[[EXPAND]]

// -----

// Test the case where the reassociation indices in the collapse and expand
// are of the same size.
func.func @bubble_expand_match_non_unit_size_reassocation(
%arg0 : tensor<4x?x4x32x4x?xf16>, %arg1 : index, %arg2 : index) -> tensor<4x?x4x128x?x32xf16> {
%collapsed = tensor.collapse_shape %arg0 [[0, 1, 2], [3, 4], [5]]
: tensor<4x?x4x32x4x?xf16> into tensor<?x128x?xf16>
%expanded = tensor.expand_shape %collapsed [[0, 1, 2], [3], [4, 5]] output_shape [4, %arg1, 4, 128, %arg2, 32]
: tensor<?x128x?xf16> into tensor<4x?x4x128x?x32xf16>
return %expanded : tensor<4x?x4x128x?x32xf16>
}
// CHECK: func @bubble_expand_match_non_unit_size_reassocation
// CHECK-SAME: %[[ARG0:.+]]: tensor<4x?x4x32x4x?xf16>
// CHECK-SAME: %[[ARG1:[a-zA-z0-9]+]]: index
// CHECK-SAME: %[[ARG2:[a-zA-z0-9]+]]: index
// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]]
// CHECK-SAME: {{\[}}[0], [1], [2], [3], [4], [5, 6]{{\]}}
// CHECK-SAME: [4, %[[ARG1]], 4, 32, 4, %[[ARG2]], 32]
// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[EXPANDED]]
// CHECK-SAME: {{\[}}[0], [1], [2], [3, 4], [5], [6]{{\]}}
// CHECK: return %[[COLLAPSED]]

// -----

// Test the case where the trailing collapse isn't needed.
func.func @no_collapse_generated(
%arg0 : tensor<4x?x4x128x?xf16>, %arg1 : index, %arg2 : index) -> tensor<4x?x4x128x?x32xf16> {
%collapsed = tensor.collapse_shape %arg0 [[0, 1, 2], [3], [4]]
: tensor<4x?x4x128x?xf16> into tensor<?x128x?xf16>
%expanded = tensor.expand_shape %collapsed [[0, 1, 2], [3], [4, 5]] output_shape [4, %arg1, 4, 128, %arg2, 32]
: tensor<?x128x?xf16> into tensor<4x?x4x128x?x32xf16>
return %expanded : tensor<4x?x4x128x?x32xf16>
}
// CHECK: func @no_collapse_generated
// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape
// CHECK: return %[[EXPANDED]]

// -----

// Test the case where the leading expand isn't needed.
func.func @no_expand_generated(
%arg0 : tensor<4x?x4x128x?x?x?xf16>, %arg1 : index, %arg2 : index, %arg3 : index) -> tensor<4x?x4x128x?x?xf16> {
%collapsed = tensor.collapse_shape %arg0 [[0, 1, 2], [3], [4], [5, 6]]
: tensor<4x?x4x128x?x?x?xf16> into tensor<?x128x?x?xf16>
%expanded = tensor.expand_shape %collapsed [[0, 1, 2], [3], [4], [5]] output_shape [4, %arg1, 4, 128, %arg2, %arg3]
: tensor<?x128x?x?xf16> into tensor<4x?x4x128x?x?xf16>
return %expanded : tensor<4x?x4x128x?x?xf16>
}
// CHECK: func @no_expand_generated
// CHECK: %[[EXPANDED:.+]] = tensor.collapse_shape
// CHECK: return %[[EXPANDED]]