Skip to content

Commit fc74db4

Browse files
authored
[mlir][Linalg] Fix foldFillPackIntoFillOp to work for general cases (#74148)
1 parent 005c833 commit fc74db4

File tree

2 files changed

+21
-16
lines changed

2 files changed

+21
-16
lines changed

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 2 additions & 16 deletions
Original file line number · Diff line number · Diff line change
@@ -765,26 +765,12 @@ static FailureOr<FillOp> foldFillPackIntoFillOp(RewriterBase &rewriter,
765765
if (!isEqualConstantIntOrValue(paddingValue, fillOp.value()))
766766
return failure();
767767

768-
OpBuilder::InsertionGuard guard(rewriter);
769-
rewriter.setInsertionPoint(fillOp);
770-
771768
Value packOpDest = packOp.getDest();
772769
if (!packOpDest.hasOneUse())
773770
return failure();
774-
if (auto emptyOp = packOpDest.getDefiningOp<tensor::EmptyOp>()) {
775-
packOpDest = tensor::PackOp::createDestinationTensor(
776-
rewriter, fillOp.getLoc(), fillOp.getDpsInitOperand(0)->get(),
777-
packOp.getMixedTiles(), packOp.getInnerDimsPos(),
778-
packOp.getOuterDimsPerm());
779-
} else {
780-
DominanceInfo dom(fillOp);
781-
if (!dom.properlyDominates(packOpDest, fillOp))
782-
return failure();
783-
}
784771

785-
Value fillDest = packOpDest;
786-
return clone(rewriter, fillOp, packOpDest.getType(),
787-
{fillOp.value(), fillDest});
772+
return rewriter.create<linalg::FillOp>(packOp.getLoc(), fillOp.getInputs(),
773+
packOp.getDest());
788774
}
789775

790776
/// Wrapper pattern that applies foldFillPackIntoFillOp method.

mlir/test/Dialect/Linalg/canonicalize.mlir

Lines changed: 19 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -368,6 +368,25 @@ func.func @fill_pack() -> tensor<24x32x16x16xf32> {
368368

369369
// -----
370370

371+
func.func @fill_pack_general() -> tensor<1x1x8x4x4x8xi32>{
372+
%c0_i32 = arith.constant 0 : i32
373+
%alloc = memref.alloc() : memref<1x1x8x4x4x8xi32>
374+
%9 = tensor.empty() : tensor<1x1x16x64xi32>
375+
%extracted_slice_15 = tensor.extract_slice %9[0, 0, 0, 0] [1, 1, 16, 64] [1, 1, 1, 1] : tensor<1x1x16x64xi32> to tensor<1x1x16x64xi32>
376+
%16 = linalg.fill ins(%c0_i32 : i32) outs(%extracted_slice_15 : tensor<1x1x16x64xi32>) -> tensor<1x1x16x64xi32>
377+
%0 = bufferization.to_tensor %alloc restrict writable : memref<1x1x8x4x4x8xi32>
378+
%pack_18 = tensor.pack %16 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %0 : tensor<1x1x16x64xi32> -> tensor<1x1x8x4x4x8xi32>
379+
return %pack_18 : tensor<1x1x8x4x4x8xi32>
380+
}
381+
382+
// CHECK-LABEL: func.func @fill_pack_general
383+
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<1x1x8x4x4x8xi32>
384+
// CHECK: %[[TENSOR:.+]] = bufferization.to_tensor %[[ALLOC]]
385+
// CHECK: %[[FILL:.+]] = linalg.fill ins(%{{.+}}) outs(%[[TENSOR]]
386+
// CHECK: return %[[FILL]]
387+
388+
// -----
389+
371390
#map = affine_map<()[s0] -> (s0 ceildiv 16)>
372391
func.func @dynamic_fill_pack(%arg0: tensor<?x?xf32>) -> tensor<?x?x16x16xf32> {
373392
%cst = arith.constant 0.000000e+00 : f32

0 commit comments

Comments (0)