-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[mlir] Fix bugs in expand_shape patterns after semantics changes #94631
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
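@llvm/pr-subscribers-mlir-tensor @llvm/pr-subscribers-mlir
Author: None (Max191)
Changes: After the `output_shape` field was added to `expand_shape` ops, dynamically sized expand shapes are now possible, but this was not accounted for in the folder. This PR tightens the constraints of the folder to fix this.
Full diff: https://github.com/llvm/llvm-project/pull/94631.diff
2 Files Affected: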
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index e8f6edc3f133e..3b986f4a60064 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -85,21 +85,55 @@ bool isReassociationValid(ArrayRef<AffineMap> reassociation,
template <typename ReshapeOpTy, typename InverseReshapeOpTy>
static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp,
ArrayRef<Attribute> operands) {
-
+ // Fold identity reshape.
if (reshapeOp.getSrcType() == reshapeOp.getType())
return reshapeOp.getSrc();
- // Fold producer-consumer reshape ops where the operand type of the
- // producer is same as the return type of the consumer.
- auto reshapeSrcOp =
- reshapeOp.getSrc().template getDefiningOp<InverseReshapeOpTy>();
- if (reshapeSrcOp && reshapeSrcOp.getSrcType() == reshapeOp.getResultType())
- return reshapeSrcOp.getSrc();
-
// Reshape of a constant can be replaced with a new constant.
if (auto elements = dyn_cast_or_null<DenseElementsAttr>(operands.front()))
return elements.reshape(cast<ShapedType>(reshapeOp.getResult().getType()));
+ // Fold if the producer reshape source has the same shape with at most 1
+ // dynamic dimension.
+ auto reshapeSrcOp =
+ reshapeOp.getSrc().template getDefiningOp<InverseReshapeOpTy>();
+ if (!reshapeSrcOp)
+ return nullptr;
+ auto srcType = reshapeSrcOp.getSrcType();
+ auto resultType = reshapeOp.getResultType();
+ if (srcType != resultType)
+ return nullptr;
+
+ // If the reshapes are expanding and then collapsing, the ops can be folded
+ // despite multiple dynamic dimensions.
+ if (srcType.getRank() < reshapeSrcOp.getResultType().getRank())
+ return reshapeSrcOp.getSrc();
+ // Otherwise, only 1 dynamic dimension is allowed.
+ if (srcType == resultType &&
+ llvm::count_if(srcType.getShape(), ShapedType::isDynamic) < 2) {
+ return reshapeSrcOp.getSrc();
+ }
+
+ // Fold producer-consumer reshape ops when they are perfect inverses of each
+ // other:
+ // 1) Reassociation indices are equivalent.
+ // 2) Boundary types are equivalent.
+ // 3) No reassociations have more than 1 dynamic dimension, and reassociated
+ // shapes are equal for each reassociation.
+ auto reassociations = reshapeOp.getReassociationIndices();
+ auto inverseReassociations = reshapeSrcOp.getReassociationIndices();
+ if (reassociations != inverseReassociations)
+ return nullptr;
+ ArrayRef<int64_t> expandedSrcShape = srcType.getShape();
+ ArrayRef<int64_t> expandedResultShape = resultType.getShape();
+ if (llvm::none_of(reassociations, [&](auto reInd) {
+ auto srcSlice = expandedSrcShape.slice(reInd.front(), reInd.size());
+ auto resSlice = expandedResultShape.slice(reInd.front(), reInd.size());
+ return srcSlice == resSlice &&
+ llvm::count_if(srcSlice, ShapedType::isDynamic) > 1;
+ })) {
+ return reshapeSrcOp.getSrc();
+ }
return nullptr;
}
@@ -360,10 +394,12 @@ struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
resultShape.slice(resultIndices.front(), resultIndices.size());
if (srcSubShape.size() == resultSubShape.size()) {
- if (srcSubShape == resultSubShape)
+ if (srcSubShape == resultSubShape &&
+ llvm::count_if(srcSubShape, ShapedType::isDynamic) < 2) {
composedReassociation.push_back(srcIndices);
- else
+ } else {
return std::nullopt;
+ }
}
// Find reassociation to collapse `srcSubShape` into `resultSubShape`.
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index f7fbd3834288b..4a04d37d4be29 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1139,7 +1139,7 @@ func.func @fold_collapse_of_expand(%arg0 : tensor<12x4xf32>) -> tensor<12x4xf32>
return %1 : tensor<12x4xf32>
}
// CHECK-LABEL: @fold_collapse_of_expand
-// CHECK-NOT: linalg.{{.*}}shape
+// CHECK-NOT: tensor.{{.*}}_shape
// -----
@@ -1152,7 +1152,60 @@ func.func @fold_collapse_of_expand_dynamic(%arg0 : tensor<?x?xf32>, %arg1: index
return %1 : tensor<?x?xf32>
}
// CHECK-LABEL: @fold_collapse_of_expand_dynamic
-// CHECK-NOT: linalg.{{.*}}_shape
+// CHECK-NOT: tensor.{{.*}}_shape
+
+// -----
+
+func.func @fold_collapse_of_expand_fully_dynamic(%arg0 : tensor<?x?xf32>, %arg1: index, %arg2: index, %arg3: index)
+ -> tensor<?x?xf32> {
+ %0 = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [%arg1, %arg2, %arg3]
+ : tensor<?x?xf32> into tensor<?x?x?xf32>
+ %1 = tensor.collapse_shape %0 [[0, 1], [2]]
+ : tensor<?x?x?xf32> into tensor<?x?xf32>
+ return %1 : tensor<?x?xf32>
+}
+// CHECK-LABEL: @fold_collapse_of_expand_fully_dynamic
+// CHECK-NOT: tensor.{{.*}}_shape
+
+// -----
+
+func.func @fold_expand_of_collapse(%arg0 : tensor<3x4x4xf32>) -> tensor<3x4x4xf32> {
+ %0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
+ : tensor<3x4x4xf32> into tensor<12x4xf32>
+ %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [3, 4, 4]
+ : tensor<12x4xf32> into tensor<3x4x4xf32>
+ return %1 : tensor<3x4x4xf32>
+}
+// CHECK-LABEL: @fold_expand_of_collapse
+// CHECK-NOT: tensor.{{.*}}_shape
+
+// -----
+
+func.func @fold_expand_of_collapse_dynamic(%arg0 : tensor<?x4x?xf32>, %arg1: index, %arg2: index)
+ -> tensor<?x4x?xf32> {
+ %0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
+ : tensor<?x4x?xf32> into tensor<?x?xf32>
+ %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%arg1, 4, %arg2]
+ : tensor<?x?xf32> into tensor<?x4x?xf32>
+ return %1 : tensor<?x4x?xf32>
+}
+// CHECK-LABEL: @fold_expand_of_collapse_dynamic
+// CHECK-NOT: tensor.{{.*}}_shape
+
+// -----
+
+func.func @no_fold_expand_of_collapse_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1: index, %arg2: index, %arg3: index)
+ -> tensor<?x?x?xf32> {
+ %0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
+ : tensor<?x?x?xf32> into tensor<?x?xf32>
+ %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%arg1, %arg2, %arg3]
+ : tensor<?x?xf32> into tensor<?x?x?xf32>
+ return %1 : tensor<?x?x?xf32>
+}
+// CHECK-LABEL: @no_fold_expand_of_collapse_dynamic
+// CHECK: tensor.collapse_shape
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape
+// CHECK: return %[[EXPAND]]
// -----
if (srcType.getRank() < reshapeSrcOp.getResultType().getRank())
  return reshapeSrcOp.getSrc();
ArrayRef<int64_t> expandedSrcShape = srcType.getShape();
ArrayRef<int64_t> expandedResultShape = resultType.getShape();
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: unused variable
After the `output_shape` field was added to `expand_shape` ops, dynamically sized expand shapes are now possible, but this was not accounted for in the folder. This PR tightens the constraints of the folder to fix this.