-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[mlir][linalg] Support lowering unpack with outer_dims_perm #94477
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
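@llvm/pr-subscribers-mlir-linalg @llvm/pr-subscribers-mlir Author: Ryan Holt (ryan-holt-1) Changes: This commit adds support for lowering `tensor.unpack` with a non-identity `outer_dims_perm`. This was previously left as a not-yet-implemented case. Full diff: https://github.com/llvm/llvm-project/pull/94477.diff 2 Files Affected: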
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 91dfac802ad67..f18cfdea2faac 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -356,13 +356,6 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
FailureOr<LowerUnPackOpResult> linalg::lowerUnPack(RewriterBase &rewriter,
tensor::UnPackOp unPackOp) {
- // 1. Filter out NYI cases.
- if (!unPackOp.getOuterDimsPerm().empty() &&
- !isIdentityPermutation(unPackOp.getOuterDimsPerm())) {
- return rewriter.notifyMatchFailure(unPackOp,
- "non-identity outer dims perm NYI");
- }
-
Location loc = unPackOp->getLoc();
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(unPackOp);
@@ -391,20 +384,17 @@ FailureOr<LowerUnPackOpResult> linalg::lowerUnPack(RewriterBase &rewriter,
return LowerUnPackOpResult{/*emptyOp=*/nullptr, /*transposeOp=*/nullptr,
/*reshapeOp=*/nullptr, extractSliceOp};
}
- // 2. Compute the permutation vector to move the last `numPackedDims` into
- // the `innerPosDims` of a shape of rank `packedRank`.
- int64_t numPackedDims = unPackOp.getInnerDimsPos().size();
- auto lastDims = llvm::to_vector(
- llvm::seq<int64_t>(packedRank - numPackedDims, packedRank));
- PackingMetadata packingMetadata =
- computePackingMetadata(packedRank, unPackOp.getInnerDimsPos());
- SmallVector<int64_t> lastDimsToInsertPositionsPerm = computePermutationVector(
- packedRank, lastDims, packingMetadata.insertPositions);
+
+ // 2. Compute the permutation vector to shuffle packed shape into the shape
+ // before any outer or inner permutations have been applied.
+ PackingMetadata packingMetadata;
+ SmallVector<int64_t> packedToStripMinedShapePerm =
+ tensor::getUnPackInverseSrcPerm(unPackOp, packingMetadata);
// 3. Compute the stripMinedShape: this is the packed shape without outer and
// inner permutations.
SmallVector<int64_t> stripMinedShape(packedTensorType.getShape());
- applyPermutationToVector(stripMinedShape, lastDimsToInsertPositionsPerm);
+ applyPermutationToVector(stripMinedShape, packedToStripMinedShapePerm);
// 4. Transpose packedShape to stripMinedShape.
RankedTensorType stripMinedTensorType =
@@ -412,15 +402,15 @@ FailureOr<LowerUnPackOpResult> linalg::lowerUnPack(RewriterBase &rewriter,
RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
stripMinedTensorType, packingMetadata.reassociations);
- // Get dynamic dims from input tensor based on lastDimsToInsertPositionsPerm
+ // Get dynamic dims from input tensor based on packedToStripMinedShapePerm
// permutation.
SmallVector<OpFoldResult, 4> dims =
tensor::getMixedSizes(rewriter, loc, unPackOp.getSource());
- applyPermutationToVector(dims, lastDimsToInsertPositionsPerm);
+ applyPermutationToVector(dims, packedToStripMinedShapePerm);
auto emptyOp = rewriter.create<tensor::EmptyOp>(
loc, dims, stripMinedTensorType.getElementType());
auto transposeOp = rewriter.create<linalg::TransposeOp>(
- loc, unPackOp.getSource(), emptyOp, lastDimsToInsertPositionsPerm);
+ loc, unPackOp.getSource(), emptyOp, packedToStripMinedShapePerm);
LLVM_DEBUG(
DBGSNL(); DBGSNL(); llvm::interleaveComma(packingMetadata.insertPositions,
@@ -428,8 +418,8 @@ FailureOr<LowerUnPackOpResult> linalg::lowerUnPack(RewriterBase &rewriter,
DBGSNL(); llvm::interleaveComma(packedTensorType.getShape(),
DBGS() << "packedShape: ");
DBGSNL();
- llvm::interleaveComma(lastDimsToInsertPositionsPerm,
- DBGS() << "lastDimsToInsertPositionsPerm: ");
+ llvm::interleaveComma(packedToStripMinedShapePerm,
+ DBGS() << "packedToStripMinedShapePerm: ");
DBGSNL(); llvm::interleaveComma(
packingMetadata.reassociations, DBGS() << "reassociations: ",
[&](ReassociationIndices ri) {
diff --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
index 926969bfc7388..f34ef4f961483 100644
--- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
@@ -622,9 +622,20 @@ module attributes {transform.with_named_sequence} {
// -----
-// At the moment, we cannot lower tensor.unpack with outer_dims_perm.
-func.func @diagnostic_unpack(%arg0: tensor<32x64xf32>, %arg1: tensor<2x4x32x8xf32>) -> tensor<32x64xf32> {
- // expected-note @below {{target payload op}}
+// CHECK-LABEL: @unpack_with_outer_dims_perm
+// CHECK-SAME: %[[ARG0:.*]]: tensor<32x64xf32>, %[[ARG1:.*]]: tensor<2x4x32x8xf32>
+// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<4x8x2x32xf32>
+// CHECK: %[[TRAN:.*]] = linalg.transpose
+// CHECK-SAME: ins(%[[ARG1]] : tensor<2x4x32x8xf32>)
+// CHECK-SAME: outs(%[[EMPTY]] : tensor<4x8x2x32xf32>)
+// CHECK-SAME: permutation = [1, 3, 0, 2]
+// CHECK: %[[CLP:.*]] = tensor.collapse_shape %[[TRAN]] {{\[}}[0, 1], [2, 3]]
+// CHECK-SAME: : tensor<4x8x2x32xf32> into tensor<32x64xf32>
+// CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[CLP]][0, 0] [32, 64] [1, 1]
+// CHECK-SAME: : tensor<32x64xf32> to tensor<32x64xf32>
+// CHECK: linalg.copy ins(%[[SLICE]]
+// CHECK-SAME: : tensor<32x64xf32>) outs(%[[ARG0]] : tensor<32x64xf32>) -> tensor<32x64xf32>
+func.func @unpack_with_outer_dims_perm(%arg0: tensor<32x64xf32>, %arg1: tensor<2x4x32x8xf32>) -> tensor<32x64xf32> {
%unpack = tensor.unpack %arg1 outer_dims_perm = [1, 0]
inner_dims_pos = [1, 0] inner_tiles = [32, 8] into %arg0 : tensor<2x4x32x8xf32> -> tensor<32x64xf32>
return %unpack : tensor<32x64xf32>
@@ -634,7 +645,6 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
%unpack = transform.structured.match ops{["tensor.unpack"]} in %module_op
: (!transform.any_op) -> !transform.op<"tensor.unpack">
- // expected-error @below {{cannot lower to transpose + collapse + extract}}
transform.structured.lower_unpack %unpack : (!transform.op<"tensor.unpack">)
-> (!transform.op<"tensor.empty">,
!transform.op<"linalg.transpose">,
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
ea402fc
to
59f2f0f
Compare
This commit adds support for lowering `tensor.unpack` with a non-identity `outer_dims_perm`. This was previously left as a not-yet-implemented case.
59f2f0f
to
82383c7
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
This commit adds support for lowering `tensor.unpack` with a non-identity `outer_dims_perm`. This was previously left as a not-yet-implemented case.