diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
index e8a09c4741043..dd6b0e8682564 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -59,8 +59,8 @@ void populateDropRedundantInsertSliceRankExpansionPatterns(
 /// `tensor.collapse_shape` into other ops.
 void populateReassociativeReshapeFoldingPatterns(RewritePatternSet &patterns);
 
-/// Populates `patterns` with patterns that fold tensor.empty with
-/// tensor.[extract_slice|expand_shape|collapse_shape].
+/// Populates `patterns` with patterns that fold tensor.empty with its
+/// consumers.
 ///
 /// If `singleUseOnly` is set to "true", only tensor.empty ops with a single
 /// use are folded.
diff --git a/mlir/lib/Dialect/Tensor/Transforms/EmptyOpPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/EmptyOpPatterns.cpp
index 7a707e749e69b..43ad0acaf7420 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/EmptyOpPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/EmptyOpPatterns.cpp
@@ -93,6 +93,49 @@ struct FoldEmptyTensorWithExtractSliceOp
   bool foldSingleUseOnly = false;
 };
 
+/// tensor.empty does not define any tensor contents, so an unpadded pack
+/// can be folded away.
+struct FoldEmptyTensorWithPackOp : public OpRewritePattern<PackOp> {
+  using OpRewritePattern<PackOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(PackOp packOp,
+                                PatternRewriter &rewriter) const override {
+    // Check for tensor.empty source.
+    auto emptyOp = packOp.getSource().getDefiningOp<EmptyOp>();
+    if (!emptyOp)
+      return failure();
+
+    // Check for padding.
+    // Packing with padding cannot be simply removed.
+    if (packOp.getPaddingValue())
+      return rewriter.notifyMatchFailure(packOp, "expects no padding value");
+
+    // Replace the pack directly with its destination.
+    rewriter.replaceOp(packOp, packOp.getDest());
+
+    return success();
+  }
+};
+
+/// tensor.empty does not define any tensor contents, so an unpack
+/// can be folded away.
+struct FoldEmptyTensorWithUnPackOp : public OpRewritePattern<UnPackOp> {
+  using OpRewritePattern<UnPackOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(UnPackOp unPackOp,
+                                PatternRewriter &rewriter) const override {
+    // Check for tensor.empty source.
+    auto emptyOp = unPackOp.getSource().getDefiningOp<EmptyOp>();
+    if (!emptyOp)
+      return failure();
+
+    // Replace the unpack directly with its destination.
+    rewriter.replaceOp(unPackOp, unPackOp.getDest());
+
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::tensor::populateFoldTensorEmptyPatterns(RewritePatternSet &patterns,
@@ -101,4 +101,6 @@ void mlir::tensor::populateFoldTensorEmptyPatterns(RewritePatternSet &patterns,
                FoldEmptyTensorWithReshapeOp<tensor::ExpandShapeOp>,
                FoldEmptyTensorWithReshapeOp<tensor::CollapseShapeOp>>(
       patterns.getContext(), /*benefit=*/1, foldSingleUseOnly);
+  patterns.add<FoldEmptyTensorWithPackOp, FoldEmptyTensorWithUnPackOp>(
+      patterns.getContext(), /*benefit=*/1);
 }
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 914e5e8b8c4b8..f7fbd3834288b 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2523,4 +2523,3 @@ func.func @dim_out_of_bounds() -> vector<7xi32> {
   %16 = affine.vector_load %alloc_21[%c1, %c1, %dim] : memref<?x26x2xi32>, vector<7xi32>
   return %16 : vector<7xi32>
 }
-
diff --git a/mlir/test/Dialect/Tensor/fold-empty-op.mlir b/mlir/test/Dialect/Tensor/fold-empty-op.mlir
index e200a4f892613..e94f6ec7ec56e 100644
--- a/mlir/test/Dialect/Tensor/fold-empty-op.mlir
+++ b/mlir/test/Dialect/Tensor/fold-empty-op.mlir
@@ -64,6 +64,79 @@ func.func @rank_reducing_empty_tensor_extract(%sz : index, %idx : index) -> tensor<2xf32> {
   return %r: tensor<2xf32>
 }
 
+func.func @pack_empty(%arg0: tensor<8x8x32x32xf32>) -> tensor<8x8x32x32xf32> {
+  %empty_unpacked = tensor.empty() : tensor<256x256xf32>
+  %packed = tensor.pack %empty_unpacked
+    inner_dims_pos = [0, 1] inner_tiles = [32, 32]
+    into %arg0 : tensor<256x256xf32> -> tensor<8x8x32x32xf32>
+  return %packed : tensor<8x8x32x32xf32>
+}
+
+// CHECK-LABEL: func.func @pack_empty(
+// CHECK-SAME: %[[T:.+]]: tensor<8x8x32x32xf32>
+// CHECK-NOT: tensor.pack
+// CHECK: return %[[T]] : tensor<8x8x32x32xf32>
+
+func.func @pack_empty_dynamic(%arg0: tensor<?x?x?x?xf32>, %dim0: index, %dim1: index) -> tensor<?x?x?x?xf32> {
+  %empty_unpacked = tensor.empty(%dim0, %dim1) : tensor<?x?xf32>
+  %packed = tensor.pack %empty_unpacked
+    inner_dims_pos = [0, 1] inner_tiles = [32, 32]
+    into %arg0 : tensor<?x?xf32> -> tensor<?x?x?x?xf32>
+  return %packed : tensor<?x?x?x?xf32>
+}
+
+// CHECK-LABEL: func.func @pack_empty_dynamic(
+// CHECK-SAME: %[[T:.+]]: tensor<?x?x?x?xf32>,
+// CHECK-SAME: %[[DIM0:[a-zA-Z0-9_]+]]: index,
+// CHECK-SAME: %[[DIM1:[a-zA-Z0-9_]+]]: index
+// CHECK-NOT: tensor.pack
+// CHECK: return %[[T]] : tensor<?x?x?x?xf32>
+
+func.func @unpack_empty(%arg0: tensor<256x256xf32>) -> tensor<256x256xf32> {
+  %empty_packed = tensor.empty() : tensor<8x8x32x32xf32>
+  %unpacked = tensor.unpack %empty_packed
+    inner_dims_pos = [0, 1] inner_tiles = [32, 32]
+    into %arg0 : tensor<8x8x32x32xf32> -> tensor<256x256xf32>
+  return %unpacked : tensor<256x256xf32>
+}
+
+// CHECK-LABEL: func.func @unpack_empty(
+// CHECK-SAME: %[[T:.+]]: tensor<256x256xf32>
+// CHECK-NOT: tensor.unpack
+// CHECK: return %[[T]] : tensor<256x256xf32>
+
+func.func @unpack_empty_dynamic(%arg0: tensor<?x?xf32>, %dim0: index, %dim1: index, %dim2: index, %dim3: index) -> tensor<?x?xf32> {
+  %empty_packed = tensor.empty(%dim0, %dim1, %dim2, %dim3) : tensor<?x?x?x?xf32>
+  %unpacked = tensor.unpack %empty_packed
+    inner_dims_pos = [0, 1] inner_tiles = [32, 32]
+    into %arg0 : tensor<?x?x?x?xf32> -> tensor<?x?xf32>
+  return %unpacked : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: func.func @unpack_empty_dynamic(
+// CHECK-SAME: %[[T:.+]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[DIM0:[a-zA-Z0-9_]+]]: index,
+// CHECK-SAME: %[[DIM1:[a-zA-Z0-9_]+]]: index,
+// CHECK-SAME: %[[DIM2:[a-zA-Z0-9_]+]]: index,
+// CHECK-SAME: %[[DIM3:[a-zA-Z0-9_]+]]: index
+// CHECK-NOT: tensor.unpack
+// CHECK: return %[[T]] : tensor<?x?xf32>
+
+func.func @pack_padded_empty(%arg0: tensor<8x8x32x32xf32>) -> tensor<8x8x32x32xf32> {
+  %pad = arith.constant 1.0 : f32
+  %empty_unpacked = tensor.empty() : tensor<256x256xf32>
+  %packed = tensor.pack %empty_unpacked
+    padding_value(%pad : f32)
+    inner_dims_pos = [0, 1] inner_tiles = [32, 32]
+    into %arg0 : tensor<256x256xf32> -> tensor<8x8x32x32xf32>
+  return %packed : tensor<8x8x32x32xf32>
+}
+
+// CHECK-LABEL: func.func @pack_padded_empty(
+// CHECK-SAME: %[[T:.+]]: tensor<8x8x32x32xf32>
+// CHECK: %[[PACK:.+]] = tensor.pack
+// CHECK: return %[[PACK]] : tensor<8x8x32x32xf32>
+
 // -----
 
 module attributes {transform.with_named_sequence} {
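
Usage sketch (commentary, not part of the patch): the snippet below is one way a downstream pass could pick up the extended pattern set once this lands. Only populateFoldTensorEmptyPatterns and applyPatternsAndFoldGreedily are existing MLIR entry points; the foldAwayTensorEmpty helper and the choice of root op are illustrative assumptions.

#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Illustrative helper: collect the tensor.empty folding patterns (which now
// include the pack/unpack folds above) and apply them greedily below `root`.
static mlir::LogicalResult foldAwayTensorEmpty(mlir::Operation *root) {
  mlir::RewritePatternSet patterns(root->getContext());
  // The pack/unpack folds are registered unconditionally; foldSingleUseOnly
  // only restricts the pre-existing extract_slice/reshape folds.
  mlir::tensor::populateFoldTensorEmptyPatterns(patterns,
                                                /*foldSingleUseOnly=*/false);
  return mlir::applyPatternsAndFoldGreedily(root, std::move(patterns));
}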