Skip to content

Commit 7f16302

Browse files
committed
[mlir] Fix bug in UnPackOp tiling implementation causing infinite loop (llvm#113571)
This fixes a bug in the tiling implementation of tensor.unpack that was causing an infinite loop when certain unpack ops get tiled and fused as a producer. The tiled implementation of tensor.unpack sometimes needs to create an additional tensor.extract_slice on the result of the tiled unpack op, but this slice was getting added to the `generatedSlices` of the tiling result. The `generatedSlices` are used to find the next producers to fuse, so it caused an infinite loop of fusing the same unpack op after it was already in the loop. This fixes the bug by adding the slice of the source instead of the result. Signed-off-by: Max Dawkins <[email protected]>
1 parent 864902e commit 7f16302

File tree

2 files changed

+51
-4
lines changed

2 files changed

+51
-4
lines changed

mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -554,13 +554,14 @@ struct UnPackOpTiling
554554
sliceSrcIndices.append(numInnerTiles, zeroAttr);
555555
sliceSrcSizes.append(unpackOp.getMixedTiles());
556556
sliceSrcStrides.append(numInnerTiles, oneAttr);
557-
Value sliceSource =
557+
SmallVector<Operation *> generatedSlices;
558+
ExtractSliceOp sliceSource =
558559
b.create<ExtractSliceOp>(loc, unpackOp.getSource(), sliceSrcIndices,
559560
sliceSrcSizes, sliceSrcStrides);
561+
generatedSlices.push_back(sliceSource);
560562

561563
SmallVector<OpFoldResult> destStrides(destRank, oneAttr);
562564
Value sliceDest;
563-
SmallVector<Operation *> generatedSlices;
564565
if (isPerfectTilingCase) {
565566
auto destSliceOp = b.create<ExtractSliceOp>(loc, unpackOp.getDest(),
566567
offsets, sizes, destStrides);
@@ -571,7 +572,7 @@ struct UnPackOpTiling
571572
unpackOp.getDestType().getElementType());
572573
}
573574

574-
SmallVector<Value> tiledOperands = {sliceSource, sliceDest};
575+
SmallVector<Value> tiledOperands = {sliceSource.getResult(), sliceDest};
575576
for (auto tile : unpackOp.getInnerTiles())
576577
tiledOperands.push_back(tile);
577578

@@ -586,7 +587,6 @@ struct UnPackOpTiling
586587
auto extractSlice =
587588
b.create<ExtractSliceOp>(loc, tiledUnpackOp->getResult(0),
588589
resultOffsetsFromDest, sizes, destStrides);
589-
generatedSlices.push_back(extractSlice);
590590
return TilingResult{
591591
{tiledUnpackOp}, {extractSlice.getResult()}, generatedSlices};
592592
}

mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -587,3 +587,50 @@ module attributes {transform.with_named_sequence} {
587587
// CHECK: %[[INSERT_SLICE:.+]] = tensor.insert_slice %[[IF_RESULT]]
588588
// CHECK: scf.yield %[[INSERT_SLICE]]
589589
// CHECK: return %[[FOR_RESULT]]
590+
591+
// -----
592+
593+
func.func @imperfect_unpack_producer_fusion(%source: tensor<1x1x288x8x4xf32>, %dest: tensor<1x2x1152xf32>) -> tensor<1x2x1152xf32> {
594+
%0 = tensor.unpack %source
595+
outer_dims_perm = [0, 1, 2]
596+
inner_dims_pos = [1, 2]
597+
inner_tiles = [8, 4] into %dest
598+
: tensor<1x1x288x8x4xf32> -> tensor<1x2x1152xf32>
599+
%1 = tensor.empty() : tensor<1x2x1152xf32>
600+
%cst = arith.constant 1.0 : f32
601+
%2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
602+
affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
603+
iterator_types = ["parallel", "parallel", "parallel"]}
604+
ins(%0 : tensor<1x2x1152xf32>)
605+
outs(%1 : tensor<1x2x1152xf32>) {
606+
^bb0(%in: f32, %out: f32):
607+
%7 = arith.addf %in, %cst : f32
608+
linalg.yield %7 : f32
609+
} -> tensor<1x2x1152xf32>
610+
return %2 : tensor<1x2x1152xf32>
611+
}
612+
613+
module attributes {transform.with_named_sequence} {
614+
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
615+
%matmul = transform.structured.match ops{["linalg.generic"]} in %arg1
616+
: (!transform.any_op) -> !transform.any_op
617+
%a, %b = transform.structured.fuse %matmul [0, 1, 0]
618+
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
619+
transform.yield
620+
}
621+
}
622+
623+
// CHECK-LABEL: func @imperfect_unpack_producer_fusion
624+
// CHECK-SAME: %[[ARG0:.+]]: tensor<1x1x288x8x4xf32>
625+
// CHECK-SAME: %[[ARG1:.+]]: tensor<1x2x1152xf32>
626+
// CHECK: %[[FOR_RESULT:.+]] = scf.for{{.*}}iter_args(%[[ITER_ARG:.+]] = {{.*}})
627+
// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]]
628+
// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[SLICE]]
629+
// CHECK-DAG: %[[UNPACK_SLICE:.+]] = tensor.extract_slice %[[UNPACK]]
630+
// CHECK-DAG: %[[INIT_SLICE:.+]] = tensor.extract_slice %[[ITER_ARG]]
631+
// CHECK: %[[GENERIC:.+]] = linalg.generic
632+
// CHECK-SAME: ins(%[[UNPACK_SLICE]]
633+
// CHECK-SAME: outs(%[[INIT_SLICE]]
634+
// CHECK: %[[INSERT_SLICE:.+]] = tensor.insert_slice %[[GENERIC]] into %[[ITER_ARG]]
635+
// CHECK: scf.yield %[[INSERT_SLICE]]
636+
// CHECK: return %[[FOR_RESULT]]

0 commit comments

Comments
 (0)