From 0e7c32325a873871f1337ade892f5e726e8db41e Mon Sep 17 00:00:00 2001
From: Matthias Springer
Date: Tue, 5 Dec 2023 16:18:00 +0900
Subject: [PATCH] [mlir][linalg] Fix invalid IR in Linalg op fusion

Linalg op fusion (`Linalg/Transforms/Fusion.cpp`) used to generate
invalid fused producer ops:
```
error: 'linalg.conv_2d_nhwc_hwcf' op expected type of operand #2 ('tensor<1x8x16x4xf32>') to match type of corresponding result ('tensor<?x?x?x?xf32>')
note: see current operation:
%24 = "linalg.conv_2d_nhwc_hwcf"(%21, %22, %23) <{dilations = dense<1> : tensor<2xi64>, operandSegmentSizes = array<i32: 2, 1>, strides = dense<2> : tensor<2xi64>}> ({
^bb0(%arg9: f32, %arg10: f32, %arg11: f32):
  %28 = "arith.mulf"(%arg9, %arg10) <{fastmath = #arith.fastmath<none>}> : (f32, f32) -> f32
  %29 = "arith.addf"(%arg11, %28) <{fastmath = #arith.fastmath<none>}> : (f32, f32) -> f32
  "linalg.yield"(%29) : (f32) -> ()
}) {linalg.memoized_indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 * 2 + d4, d2 * 2 + d5, d6)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>]} : (tensor<1x?x?x3xf32>, tensor<3x3x3x4xf32>, tensor<1x8x16x4xf32>) -> tensor<?x?x?x?xf32>
```

This is a problem because the input IR to the greedy pattern rewriter
during `-test-linalg-greedy-fusion` is invalid. This commit fixes tests
such as `mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir` when
verifying the IR after each pattern application (#74270).
---
 mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp | 24 +++++++-----------------
 1 file changed, 7 insertions(+), 17 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index 11bd886c36e53..e48188fe516d3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -144,27 +144,17 @@ static LinalgOp fuse(OpBuilder &b, LinalgOp producer,
       b, loc, producer, getTiledOperands(producer), ivs, tileSizes, sizeBounds,
       /*omitPartialTileCheck=*/false));
 
-  // Iterate over the results in order.
-  // Extract the subtensor type from the linearized range.
-  // Since we do not enforce any canonicalizations on the fly, this is always
-  // fully dynamic at construction time.
+  // Take result types from the tiled init operands.
+  MutableOperandRange producerDpsInits = producer.getDpsInitsMutable();
   SmallVector<Type, 4> resultTypes;
   resultTypes.reserve(producer->getNumResults());
-  for (Value operand : producer.getDpsInits()) {
-    auto tensorType = dyn_cast<RankedTensorType>(operand.getType());
-    if (!tensorType)
-      continue;
-    unsigned rank = tensorType.getRank();
-    SmallVector<int64_t, 4> staticOffsetsVector(
-        rank, ShapedType::kDynamic);
-    SmallVector<int64_t, 4> staticSizesVector(rank, ShapedType::kDynamic);
-    SmallVector<int64_t, 4> staticStridesVector(
-        rank, ShapedType::kDynamic);
-    resultTypes.push_back(tensor::ExtractSliceOp::inferResultType(
-        tensorType, staticOffsetsVector, staticSizesVector,
-        staticStridesVector));
+  int64_t firstInitOperandIdx =
+      static_cast<OperandRange>(producerDpsInits).getBeginOperandIndex();
+  for (int64_t i = 0, e = producer->getNumResults(); i < e; ++i) {
+    resultTypes.push_back(clonedShapes[firstInitOperandIdx + i].getType());
   }
 
+  // Clone the producer with new operands and result types.
   LinalgOp clonedOp = clone(b, producer, resultTypes, clonedShapes);
 
   // Shift all IndexOp results by the tile offset.
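
Note: the hunk above is the entire functional change. For readability, here is
the new result-type computation restated as a contiguous C++ snippet; this is
only a sketch of the code in the diff, with `b`, `producer`, and `clonedShapes`
being the values already in scope in `fuse()`:
```
// The result types of the cloned (fused) producer are taken directly from
// the tiled init operands in `clonedShapes`, instead of being rebuilt as
// fully dynamic tensors that no longer match the tiled operand types.
MutableOperandRange producerDpsInits = producer.getDpsInitsMutable();
SmallVector<Type, 4> resultTypes;
resultTypes.reserve(producer->getNumResults());
// `clonedShapes` is indexed like the producer's operand list, so the tiled
// init operands start at the index of the first DPS init operand.
int64_t firstInitOperandIdx =
    static_cast<OperandRange>(producerDpsInits).getBeginOperandIndex();
for (int64_t i = 0, e = producer->getNumResults(); i < e; ++i)
  resultTypes.push_back(clonedShapes[firstInitOperandIdx + i].getType());

// Clone the producer with the tiled operands and the matching result types.
LinalgOp clonedOp = clone(b, producer, resultTypes, clonedShapes);
```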