diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 58af9995548e9..9a4d5e8845b21 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -765,26 +765,12 @@ static FailureOr<FillOp> foldFillPackIntoFillOp(RewriterBase &rewriter,
     if (!isEqualConstantIntOrValue(paddingValue, fillOp.value()))
       return failure();
 
-  OpBuilder::InsertionGuard guard(rewriter);
-  rewriter.setInsertionPoint(fillOp);
-
   Value packOpDest = packOp.getDest();
   if (!packOpDest.hasOneUse())
     return failure();
-  if (auto emptyOp = packOpDest.getDefiningOp<tensor::EmptyOp>()) {
-    packOpDest = tensor::PackOp::createDestinationTensor(
-        rewriter, fillOp.getLoc(), fillOp.getDpsInitOperand(0)->get(),
-        packOp.getMixedTiles(), packOp.getInnerDimsPos(),
-        packOp.getOuterDimsPerm());
-  } else {
-    DominanceInfo dom(fillOp);
-    if (!dom.properlyDominates(packOpDest, fillOp))
-      return failure();
-  }
-  Value fillDest = packOpDest;
-  return clone(rewriter, fillOp, packOpDest.getType(),
-               {fillOp.value(), fillDest});
+  return rewriter.create<linalg::FillOp>(packOp.getLoc(), fillOp.getInputs(),
+                                         packOp.getDest());
 }
 
 /// Wrapper pattern that applies foldFillPackIntoFillOp method.
 struct FoldFillWithPack : public OpRewritePattern<tensor::PackOp> {
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index e875bae473094..052dc367ca677 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -368,6 +368,25 @@ func.func @fill_pack() -> tensor<24x32x16x16xf32> {
 
 // -----
 
+func.func @fill_pack_general() -> tensor<1x1x8x4x4x8xi32>{
+  %c0_i32 = arith.constant 0 : i32
+  %alloc = memref.alloc() : memref<1x1x8x4x4x8xi32>
+  %9 = tensor.empty() : tensor<1x1x16x64xi32>
+  %extracted_slice_15 = tensor.extract_slice %9[0, 0, 0, 0] [1, 1, 16, 64] [1, 1, 1, 1] : tensor<1x1x16x64xi32> to tensor<1x1x16x64xi32>
+  %16 = linalg.fill ins(%c0_i32 : i32) outs(%extracted_slice_15 : tensor<1x1x16x64xi32>) -> tensor<1x1x16x64xi32>
+  %0 = bufferization.to_tensor %alloc restrict writable : memref<1x1x8x4x4x8xi32>
+  %pack_18 = tensor.pack %16 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %0 : tensor<1x1x16x64xi32> -> tensor<1x1x8x4x4x8xi32>
+  return %pack_18 : tensor<1x1x8x4x4x8xi32>
+}
+
+// CHECK-LABEL: func.func @fill_pack_general
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<1x1x8x4x4x8xi32>
+// CHECK: %[[TENSOR:.+]] = bufferization.to_tensor %[[ALLOC]]
+// CHECK: %[[FILL:.+]] = linalg.fill ins(%{{.+}}) outs(%[[TENSOR]]
+// CHECK: return %[[FILL]]
+
+// -----
+
 #map = affine_map<()[s0] -> (s0 ceildiv 16)>
 func.func @dynamic_fill_pack(%arg0: tensor<?x?xf32>) -> tensor<?x?x16x16xf32> {
   %cst = arith.constant 0.000000e+00 : f32