From 0623d1cb33902cd8dc45f3527acce267932e7980 Mon Sep 17 00:00:00 2001 From: Amy Wang Date: Fri, 8 Sep 2023 09:50:46 -0400 Subject: [PATCH] [MLIR][Tensor] Add Destination style RewritePattern for DimOp. Fold dim of a destination passing op with dim of the corresponding init. This enables canonicalization to fold away unnecessary tensor.dim ops which in turn enables folding away of other operations, as can be seen in conv_tensors_dynamic where affine.min operations were folded away. --- mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 23 +++++++++++++++++- mlir/test/Dialect/Linalg/canonicalize.mlir | 24 +++++++++++++++++-- .../Dialect/Linalg/tile-and-fuse-tensors.mlir | 12 +++------- .../Linalg/transform-tile-reduction.mlir | 14 +++++------ 4 files changed, 54 insertions(+), 19 deletions(-) diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index 42d89cd5a7620..25ddf2fc48d6f 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -579,11 +579,32 @@ struct DimOfCastOp : public OpRewritePattern { return success(); } }; + +/// Fold dim of a destination passing style op into the dim of the corresponding +/// init. 
+struct DimOfDestStyleOp : public OpRewritePattern<DimOp> { + using OpRewritePattern<DimOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(DimOp dimOp, + PatternRewriter &rewriter) const override { + auto source = dimOp.getSource(); + auto destOp = source.getDefiningOp<DestinationStyleOpInterface>(); + if (!destOp) + return failure(); + + auto resultIndex = source.cast<OpResult>().getResultNumber(); + auto initOperand = destOp.getDpsInitOperand(resultIndex); + + rewriter.updateRootInPlace( + dimOp, [&]() { dimOp.getSourceMutable().assign(initOperand->get()); }); + return success(); + } +}; } // namespace void DimOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add<DimOfCastOp>(context); + results.add<DimOfCastOp, DimOfDestStyleOp>(context); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir index 783660727ce16..297b5c4e332c8 100644 --- a/mlir/test/Dialect/Linalg/canonicalize.mlir +++ b/mlir/test/Dialect/Linalg/canonicalize.mlir @@ -397,9 +397,8 @@ func.func @fold_static_pad_fill() -> tensor<412x276xf32> { // CHECK-DAG: %[[I1:.+]] = arith.constant 1 : index // CHECK-DAG: %[[F0:.+]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[OF:.+]] = linalg.fill ins(%[[F0]] : f32) outs(%[[SRC]] : tensor<8x?x16x32xf32>) // CHECK: %[[S0:.+]] = affine.apply #[[MAP0]]()[%[[LOW0]]] -// CHECK: %[[DIM1:.+]] = tensor.dim %[[OF]], %[[I1]] : tensor<8x?x16x32xf32> +// CHECK: %[[DIM1:.+]] = tensor.dim %[[SRC]], %[[I1]] : tensor<8x?x16x32xf32> // CHECK: %[[S1:.+]] = affine.apply #[[MAP1]]()[%[[DIM1]]] // CHECK: %[[S2:.+]] = affine.apply #[[MAP2]]()[%[[HIGH2]]] // CHECK: %[[S3:.+]] = affine.apply #[[MAP3]]()[%[[LOW3]], %[[HIGH3]]] @@ -908,3 +907,24 @@ func.func @dead_softmax(%arg0: tensor<16x64x256xf32>) -> tensor<16x64x256xf32> { ins(%arg0 : tensor<16x64x256xf32>) outs(%0 : tensor<16x64x256xf32>) -> tensor<16x64x256xf32> return %arg0 : tensor<16x64x256xf32> } + +// ----- + +// CHECK-LABEL: func 
@canonicalize_dim_of_dest_style_op +// CHECK: tensor.dim +// CHECK: tensor.dim +// CHECK-NOT: tensor.dim +// CHECK: return +func.func @canonicalize_dim_of_dest_style_op(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %dim0_0 = tensor.dim %arg0, %c0 : tensor<?x?xf32> + %dim1_0 = tensor.dim %arg0, %c1 : tensor<?x?xf32> + %0 = tensor.empty(%dim0_0, %dim1_0) : tensor<?x?xf32> + %1 = linalg.copy ins(%arg0 : tensor<?x?xf32>) outs(%0 : tensor<?x?xf32>) -> tensor<?x?xf32> + %dim0_1 = tensor.dim %1, %c0 : tensor<?x?xf32> + %dim1_1 = tensor.dim %1, %c1 : tensor<?x?xf32> + %2 = tensor.empty(%dim0_1, %dim1_1) : tensor<?x?xf32> + %3 = linalg.copy ins(%1 : tensor<?x?xf32>) outs(%2 : tensor<?x?xf32>) -> tensor<?x?xf32> + return %3: tensor<?x?xf32> +} diff --git a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir index 6f21e1e20c3d4..0f27a92c119cf 100644 --- a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir +++ b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir @@ -197,10 +197,8 @@ func.func @conv_tensors_dynamic(%input: tensor, %filter: tensor (-d0 + s0, 16)> // CHECK: #[[X2_MAP:.+]] = affine_map<(d0) -> (d0 * 2)> // CHECK: #[[INPUT_BOUND:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * -2 + s0 * 2 + s1 - 2, d1 * 2 + s1 - 2)> -// CHECK: #[[BOUND16_MAP_2:.+]] = affine_map<(d0)[s0, s1] -> (-d0 + s1, -d0 + s0, 16)> // CHECK: #[[BOUND4_MAP:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 4)> // CHECK: #[[BOUND2_MAP:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 2)> -// CHECK: #[[BOUND4_MAP_2:.+]] = affine_map<(d0)[s0, s1] -> (-d0 + s1, -d0 + s0, 4)> // CHECK: #[[BOUND2_MAP_2:.+]] = affine_map<(d0, d1)[s0, s1] -> (-d0 + s0, -d1 + s1, 2)> // CHECK: func @conv_tensors_dynamic @@ -225,8 +223,6 @@ func.func @conv_tensors_dynamic(%input: tensor, %filter: tensor // CHECK-DAG: %[[INPUT_N:.+]] = tensor.dim %[[INPUT]], %[[C0]] : tensor // CHECK-DAG: %[[INPUT_C:.+]] = tensor.dim %[[INPUT]], %[[C3]] : tensor -// CHECK-DAG: %[[FILL_H:.+]] = tensor.dim %[[FILL]], %[[C1]] : tensor -// CHECK-DAG: %[[FILL_W:.+]] = tensor.dim 
%[[FILL]], %[[C2]] : tensor // CHECK: scf.for %[[IV0:.+]] = %{{.+}} to %[[ELEM_N]] step %{{.+}} iter_args(%{{.+}} = %[[FILL]]) // CHECK-NEXT: %[[SIZE_ELEM_N:.+]] = affine.min #[[BOUND8_MAP]](%[[IV0]])[%[[ELEM_N]]] @@ -234,14 +230,12 @@ func.func @conv_tensors_dynamic(%input: tensor, %filter: tensor, %filter: tensor, tensor) // CHECK-SAME: outs(%[[ST_FILL]] : tensor) -> tensor diff --git a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir index 70e535b74f055..934be889cecb2 100644 --- a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir +++ b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir @@ -43,9 +43,7 @@ transform.sequence failures(propagate) { // CHECK: arith.addf // CHECK: linalg.yield // CHECK: } -> tensor -// CHECK: %[[D3:.*]] = tensor.dim %[[PR]], %[[C0]] : tensor -// CHECK: %[[D4:.*]] = tensor.dim %[[PR]], %[[C1]] : tensor -// CHECK: %[[INS:.*]] = tensor.insert_slice %[[PR]] into %[[ARG3]][0, 0] [%[[D3]], %[[D4]]] [1, 1] : tensor into tensor +// CHECK: %[[INS:.*]] = tensor.insert_slice %[[PR]] into %[[ARG3]][0, 0] [%[[D0]], %[[PS]]] [1, 1] : tensor into tensor // CHECK: scf.yield %[[INS]] : tensor // CHECK: } // CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[MAP0]], #[[MAP1]]], iterator_types = ["parallel", "reduction"]} ins(%[[L]] : tensor) outs(%[[ARG1]] : tensor) { @@ -76,14 +74,16 @@ transform.sequence failures(propagate) { by tile_sizes = [5, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) } +// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 5)> +// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1) -> (d1)> // CHECK: func @reduction_tile_transpose // CHECK: tensor.empty(%{{.*}}) : tensor<5x?xf32> // CHECK: linalg.fill {{.*}} : tensor<5x?xf32>) -> tensor<5x?xf32> // CHECK: scf.for -// CHECK: linalg.generic -// CHECK: %[[D3:.*]] = tensor.dim 
%{{.*}}, %[[C0]] : tensor -// CHECK: %[[D4:.*]] = tensor.dim %{{.*}}, %[[C1]] : tensor -// CHECK: %[[INS:.*]] = tensor.insert_slice %[[PR]] into %[[ARG3]][0, 0] [%[[D3]], %[[D4]]] [1, 1] : tensor into tensor<5x?xf32> +// CHECK: %[[EXT:.*]] = tensor.extract_slice %[[ARG3:.*]][0, 0] [%[[D0:.*]], %[[D1:.*]]] [1, 1] : tensor<5x?xf32> to tensor +// CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[L:.*]] : tensor) outs(%[[EXT]] : tensor) +// CHECK: %[[INS:.*]] = tensor.insert_slice %[[R]] into %[[ARG3]][0, 0] [%[[D0]], %[[D1]]] [1, 1] : tensor into tensor<5x?xf32> // CHECK: scf.yield {{.*}} : tensor<5x?xf32> // CHECK: } // CHECK: linalg.generic