-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[mlir][tensor] Fold pack and unpack of empty input tensor #92247
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Adds canonicalization to pack and unpack to fold away operations when their source is a `tensor.empty`.
@llvm/pr-subscribers-mlir-tensor @llvm/pr-subscribers-mlir Author: Adam Siemieniuk (adam-smnk) ChangesAdds canonicalization to pack and unpack to fold away operations when their source is a Full diff: https://github.com/llvm/llvm-project/pull/92247.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 414bd7459af8f..428bf61e2fe5a 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -4200,6 +4200,12 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
return success();
}
+ // Fold away packing an empty source tensor.
+ if (auto emptyTensor = packOp.getSource().getDefiningOp<tensor::EmptyOp>()) {
+ rewriter.replaceOp(packOp, packOp.getDest());
+ return success();
+ }
+
// Insert tensor.cast ops if static shape inference is available..
SmallVector<int64_t> srcShape, destShape;
if (inferStaticShape(packOp, srcShape, destShape)) {
@@ -4435,6 +4441,13 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
return success();
}
+ // Fold away unpacking an empty source tensor.
+ if (auto emptyTensor =
+ unPackOp.getSource().getDefiningOp<tensor::EmptyOp>()) {
+ rewriter.replaceOp(unPackOp, unPackOp.getDest());
+ return success();
+ }
+
// Insert tensor.cast ops if static shape inference is available..
SmallVector<int64_t> srcShape, destShape;
if (inferStaticShape(unPackOp, srcShape, destShape)) {
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 8036d996d2324..4922251363950 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2486,3 +2486,24 @@ func.func @dim_out_of_bounds() -> vector<7xi32> {
return %16 : vector<7xi32>
}
+// -----
+
+// CHECK: func.func @pack_empty(
+// CHECK-SAME: %[[T:.+]]: tensor<8x8x32x32xf32>
+// CHECK: return %[[T]] : tensor<8x8x32x32xf32>
+func.func @pack_empty(%arg0: tensor<8x8x32x32xf32>) -> tensor<8x8x32x32xf32> {
+ %empty_unpacked = tensor.empty() : tensor<256x256xf32>
+ %packed = tensor.pack %empty_unpacked inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %arg0 : tensor<256x256xf32> -> tensor<8x8x32x32xf32>
+ return %packed : tensor<8x8x32x32xf32>
+}
+
+// -----
+
+// CHECK: func.func @unpack_empty(
+// CHECK-SAME: %[[T:.+]]: tensor<256x256xf32>
+// CHECK: return %[[T]] : tensor<256x256xf32>
+func.func @unpack_empty(%arg0: tensor<256x256xf32>) -> tensor<256x256xf32> {
+ %empty_packed = tensor.empty() : tensor<8x8x32x32xf32>
+ %unpacked = tensor.unpack %empty_packed inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %arg0 : tensor<8x8x32x32xf32> -> tensor<256x256xf32>
+ return %unpacked : tensor<256x256xf32>
+}
|
It also could be an optional pattern maybe in |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for adding these. I don't think these should be folders or canonicalizations. There are patterns that do folding of tensor.empty and it's consumers, these are better suited there.
Also packs have padding semantics. Those need to be accounted for. I think if there is padding you can't do this transformation
Thanks for the feedback. I'll move it to the folder patterns.
At first I treated packing on empty as mostly UB but I suppose you could materialize it as a valid initialization of a buffer. |
This is where you should add this https://github.com/shark-infra/llvm-project/blob/f44eaa13d0c56630791e6484fa3d86a813fa716c/mlir/lib/Dialect/Tensor/Transforms/EmptyOpPatterns.cpp#L98 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the ping @hanhanW.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks! Just a small comment.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Have a few comments on the lit tests. Looks good otherwise. Please address before landing.
Extends
tensor.empty
folding patterns with pack and unpack consumers to fold away the operations when their source is empty.