Skip to content

Commit ea402fc

Browse files
committed
[mlir][linalg] Support lowering unpack with outer_dims_perm
This commit adds support for lowering `tensor.unpack` with a non-identity `outer_dims_perm`. This was previously left as a not-yet-implemented case.
Parent commit: 45964eb · This commit: ea402fc

File tree

2 files changed: +26 / −26 lines

mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

+12-22
Original file line numberDiff line numberDiff line change
@@ -356,13 +356,6 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
356356

357357
FailureOr<LowerUnPackOpResult> linalg::lowerUnPack(RewriterBase &rewriter,
358358
tensor::UnPackOp unPackOp) {
359-
// 1. Filter out NYI cases.
360-
if (!unPackOp.getOuterDimsPerm().empty() &&
361-
!isIdentityPermutation(unPackOp.getOuterDimsPerm())) {
362-
return rewriter.notifyMatchFailure(unPackOp,
363-
"non-identity outer dims perm NYI");
364-
}
365-
366359
Location loc = unPackOp->getLoc();
367360
OpBuilder::InsertionGuard g(rewriter);
368361
rewriter.setInsertionPoint(unPackOp);
@@ -391,45 +384,42 @@ FailureOr<LowerUnPackOpResult> linalg::lowerUnPack(RewriterBase &rewriter,
391384
return LowerUnPackOpResult{/*emptyOp=*/nullptr, /*transposeOp=*/nullptr,
392385
/*reshapeOp=*/nullptr, extractSliceOp};
393386
}
394-
// 2. Compute the permutation vector to move the last `numPackedDims` into
395-
// the `innerPosDims` of a shape of rank `packedRank`.
396-
int64_t numPackedDims = unPackOp.getInnerDimsPos().size();
397-
auto lastDims = llvm::to_vector(
398-
llvm::seq<int64_t>(packedRank - numPackedDims, packedRank));
399-
PackingMetadata packingMetadata =
400-
computePackingMetadata(packedRank, unPackOp.getInnerDimsPos());
401-
SmallVector<int64_t> lastDimsToInsertPositionsPerm = computePermutationVector(
402-
packedRank, lastDims, packingMetadata.insertPositions);
387+
388+
// 2. Compute the permutation vector to shuffle packed shape into the shape
389+
// before any outer or inner permutations have been applied.
390+
PackingMetadata packingMetadata;
391+
SmallVector<int64_t> packedToStripMinedShapePerm =
392+
tensor::getUnPackInverseSrcPerm(unPackOp, packingMetadata);
403393

404394
// 3. Compute the stripMinedShape: this is the packed shape without outer and
405395
// inner permutations.
406396
SmallVector<int64_t> stripMinedShape(packedTensorType.getShape());
407-
applyPermutationToVector(stripMinedShape, lastDimsToInsertPositionsPerm);
397+
applyPermutationToVector(stripMinedShape, packedToStripMinedShapePerm);
408398

409399
// 4. Transpose packedShape to stripMinedShape.
410400
RankedTensorType stripMinedTensorType =
411401
RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
412402
RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
413403
stripMinedTensorType, packingMetadata.reassociations);
414404

415-
// Get dynamic dims from input tensor based on lastDimsToInsertPositionsPerm
405+
// Get dynamic dims from input tensor based on packedToStripMinedShapePerm
416406
// permutation.
417407
SmallVector<OpFoldResult, 4> dims =
418408
tensor::getMixedSizes(rewriter, loc, unPackOp.getSource());
419-
applyPermutationToVector(dims, lastDimsToInsertPositionsPerm);
409+
applyPermutationToVector(dims, packedToStripMinedShapePerm);
420410
auto emptyOp = rewriter.create<tensor::EmptyOp>(
421411
loc, dims, stripMinedTensorType.getElementType());
422412
auto transposeOp = rewriter.create<linalg::TransposeOp>(
423-
loc, unPackOp.getSource(), emptyOp, lastDimsToInsertPositionsPerm);
413+
loc, unPackOp.getSource(), emptyOp, packedToStripMinedShapePerm);
424414

425415
LLVM_DEBUG(
426416
DBGSNL(); DBGSNL(); llvm::interleaveComma(packingMetadata.insertPositions,
427417
DBGS() << "insertPositions: ");
428418
DBGSNL(); llvm::interleaveComma(packedTensorType.getShape(),
429419
DBGS() << "packedShape: ");
430420
DBGSNL();
431-
llvm::interleaveComma(lastDimsToInsertPositionsPerm,
432-
DBGS() << "lastDimsToInsertPositionsPerm: ");
421+
llvm::interleaveComma(packedToStripMinedShapePerm,
422+
DBGS() << "packedToStripMinedShapePerm: ");
433423
DBGSNL(); llvm::interleaveComma(
434424
packingMetadata.reassociations, DBGS() << "reassociations: ",
435425
[&](ReassociationIndices ri) {

mlir/test/Dialect/Linalg/transform-lower-pack.mlir

+14-4
Original file line numberDiff line numberDiff line change
@@ -622,9 +622,20 @@ module attributes {transform.with_named_sequence} {
622622

623623
// -----
624624

625-
// At the moment, we cannot lower tensor.unpack with outer_dims_perm.
626-
func.func @diagnostic_unpack(%arg0: tensor<32x64xf32>, %arg1: tensor<2x4x32x8xf32>) -> tensor<32x64xf32> {
627-
// expected-note @below {{target payload op}}
625+
// CHECK-LABEL: @unpack_with_outer_dims_perm
626+
// CHECK-SAME: %[[ARG0:.*]]: tensor<32x64xf32>, %[[ARG1:.*]]: tensor<2x4x32x8xf32>
627+
// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<4x8x2x32xf32>
628+
// CHECK: %[[TRAN:.*]] = linalg.transpose
629+
// CHECK-SAME: ins(%[[ARG1]] : tensor<2x4x32x8xf32>)
630+
// CHECK-SAME: outs(%[[EMPTY]] : tensor<4x8x2x32xf32>)
631+
// CHECK-SAME: permutation = [1, 3, 0, 2]
632+
// CHECK: %[[CLP:.*]] = tensor.collapse_shape %[[TRAN]] {{\[}}[0, 1], [2, 3]]
633+
// CHECK-SAME: : tensor<4x8x2x32xf32> into tensor<32x64xf32>
634+
// CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[CLP]][0, 0] [32, 64] [1, 1]
635+
// CHECK-SAME: : tensor<32x64xf32> to tensor<32x64xf32>
636+
// CHECK: linalg.copy ins(%[[SLICE]]
637+
// CHECK-SAME: : tensor<32x64xf32>) outs(%[[ARG0]] : tensor<32x64xf32>) -> tensor<32x64xf32>
638+
func.func @unpack_with_outer_dims_perm(%arg0: tensor<32x64xf32>, %arg1: tensor<2x4x32x8xf32>) -> tensor<32x64xf32> {
628639
%unpack = tensor.unpack %arg1 outer_dims_perm = [1, 0]
629640
inner_dims_pos = [1, 0] inner_tiles = [32, 8] into %arg0 : tensor<2x4x32x8xf32> -> tensor<32x64xf32>
630641
return %unpack : tensor<32x64xf32>
@@ -634,7 +645,6 @@ module attributes {transform.with_named_sequence} {
634645
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
635646
%unpack = transform.structured.match ops{["tensor.unpack"]} in %module_op
636647
: (!transform.any_op) -> !transform.op<"tensor.unpack">
637-
// expected-error @below {{cannot lower to transpose + collapse + extract}}
638648
transform.structured.lower_unpack %unpack : (!transform.op<"tensor.unpack">)
639649
-> (!transform.op<"tensor.empty">,
640650
!transform.op<"linalg.transpose">,

0 commit comments

Comments (0)