Skip to content

Commit b586149

Browse files
authored
[mlir][tensor] Fold pack and unpack of empty input tensor (#92247)
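Extends the `tensor.empty` folding patterns to cover `tensor.pack` and `tensor.unpack` consumers: these operations are folded away when their source is produced by `tensor.empty` (a pack that carries a padding value is left in place).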
1 parent dbfedc6 commit b586149

File tree

4 files changed

+120
-3
lines changed

4 files changed

+120
-3
lines changed

mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,8 @@ void populateDropRedundantInsertSliceRankExpansionPatterns(
5959
/// `tensor.collapse_shape` into other ops.
6060
void populateReassociativeReshapeFoldingPatterns(RewritePatternSet &patterns);
6161

62-
/// Populates `patterns` with patterns that fold tensor.empty with
63-
/// tensor.[extract_slice|expand_shape|collapse_shape].
62+
/// Populates `patterns` with patterns that fold tensor.empty with its
63+
/// consumers.
6464
///
6565
/// If `singleUseOnly` is set to "true", only tensor.empty ops with a single
6666
/// use are folded.

mlir/lib/Dialect/Tensor/Transforms/EmptyOpPatterns.cpp

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,49 @@ struct FoldEmptyTensorWithExtractSliceOp
9393
bool foldSingleUseOnly = false;
9494
};
9595

96+
/// tensor.empty does not define any tensor contents, so an unpadded pack
97+
/// can be folded away.
98+
struct FoldEmptyTensorWithPackOp : public OpRewritePattern<PackOp> {
99+
using OpRewritePattern<PackOp>::OpRewritePattern;
100+
101+
LogicalResult matchAndRewrite(PackOp packOp,
102+
PatternRewriter &rewriter) const override {
103+
// Check for tensor.empty source.
104+
auto emptyOp = packOp.getSource().getDefiningOp<EmptyOp>();
105+
if (!emptyOp)
106+
return failure();
107+
108+
// Check for padding.
109+
// Packing with padding cannot be simply removed.
110+
if (packOp.getPaddingValue())
111+
return rewriter.notifyMatchFailure(packOp, "expects no padding value");
112+
113+
// Replace the pack directly with its destination.
114+
rewriter.replaceOp(packOp, packOp.getDest());
115+
116+
return success();
117+
}
118+
};
119+
120+
/// tensor.empty does not define any tensor contents, so an unpack
121+
/// can be folded away.
122+
struct FoldEmptyTensorWithUnPackOp : public OpRewritePattern<UnPackOp> {
123+
using OpRewritePattern<UnPackOp>::OpRewritePattern;
124+
125+
LogicalResult matchAndRewrite(UnPackOp unPackOp,
126+
PatternRewriter &rewriter) const override {
127+
// Check for tensor.empty source.
128+
auto emptyOp = unPackOp.getSource().getDefiningOp<EmptyOp>();
129+
if (!emptyOp)
130+
return failure();
131+
132+
// Replace the unpack directly with its destination.
133+
rewriter.replaceOp(unPackOp, unPackOp.getDest());
134+
135+
return success();
136+
}
137+
};
138+
96139
} // namespace
97140

98141
void mlir::tensor::populateFoldTensorEmptyPatterns(RewritePatternSet &patterns,
@@ -101,4 +144,6 @@ void mlir::tensor::populateFoldTensorEmptyPatterns(RewritePatternSet &patterns,
101144
FoldEmptyTensorWithReshapeOp<tensor::ExpandShapeOp>,
102145
FoldEmptyTensorWithReshapeOp<tensor::CollapseShapeOp>>(
103146
patterns.getContext(), /*benefit=*/1, foldSingleUseOnly);
147+
patterns.add<FoldEmptyTensorWithPackOp, FoldEmptyTensorWithUnPackOp>(
148+
patterns.getContext(), /*benefit=*/1);
104149
}

mlir/test/Dialect/Tensor/canonicalize.mlir

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2523,4 +2523,3 @@ func.func @dim_out_of_bounds() -> vector<7xi32> {
25232523
%16 = affine.vector_load %alloc_21[%c1, %c1, %dim] : memref<?x26x2xi32>, vector<7xi32>
25242524
return %16 : vector<7xi32>
25252525
}
2526-

mlir/test/Dialect/Tensor/fold-empty-op.mlir

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,79 @@ func.func @rank_reducing_empty_tensor_extract(%sz : index, %idx : index) -> tens
6464
return %r: tensor<2xf32>
6565
}
6666

67+
// Static-shape case: a 256x256 tensor.empty is packed with 32x32 inner tiles
// into %arg0, with no padding value. The FileCheck directives that follow this
// function expect no tensor.pack to remain and %arg0 to be returned.
func.func @pack_empty(%arg0: tensor<8x8x32x32xf32>) -> tensor<8x8x32x32xf32> {
68+
%empty_unpacked = tensor.empty() : tensor<256x256xf32>
69+
%packed = tensor.pack %empty_unpacked
70+
inner_dims_pos = [0, 1] inner_tiles = [32, 32]
71+
into %arg0 : tensor<256x256xf32> -> tensor<8x8x32x32xf32>
72+
return %packed : tensor<8x8x32x32xf32>
73+
}
74+
75+
// CHECK-LABEL: func.func @pack_empty(
76+
// CHECK-SAME: %[[T:.+]]: tensor<8x8x32x32xf32>
77+
// CHECK-NOT: tensor.pack
78+
// CHECK: return %[[T]] : tensor<8x8x32x32xf32>
79+
80+
// Dynamic-shape case: a tensor.empty sized by %dim0 and %dim1 is packed with
// 32x32 inner tiles into %arg0, with no padding value. The FileCheck
// directives that follow this function expect no tensor.pack to remain and
// %arg0 to be returned.
func.func @pack_empty_dynamic(%arg0: tensor<?x?x?x?xf32>, %dim0: index, %dim1: index) -> tensor<?x?x?x?xf32> {
81+
%empty_unpacked = tensor.empty(%dim0, %dim1) : tensor<?x?xf32>
82+
%packed = tensor.pack %empty_unpacked
83+
inner_dims_pos = [0, 1] inner_tiles = [32, 32]
84+
into %arg0 : tensor<?x?xf32> -> tensor<?x?x?x?xf32>
85+
return %packed : tensor<?x?x?x?xf32>
86+
}
87+
88+
// CHECK-LABEL: func.func @pack_empty_dynamic(
89+
// CHECK-SAME: %[[T:.+]]: tensor<?x?x?x?xf32>,
90+
// CHECK-SAME: %[[DIM0:[a-zA-Z0-9_]+]]: index,
91+
// CHECK-SAME: %[[DIM1:[a-zA-Z0-9_]+]]: index
92+
// CHECK-NOT: tensor.pack
93+
// CHECK: return %[[T]] : tensor<?x?x?x?xf32>
94+
95+
// Static-shape case: an 8x8x32x32 tensor.empty is unpacked with 32x32 inner
// tiles into %arg0. The FileCheck directives that follow this function expect
// no tensor.unpack to remain and %arg0 to be returned.
func.func @unpack_empty(%arg0: tensor<256x256xf32>) -> tensor<256x256xf32> {
96+
%empty_packed = tensor.empty() : tensor<8x8x32x32xf32>
97+
%unpacked = tensor.unpack %empty_packed
98+
inner_dims_pos = [0, 1] inner_tiles = [32, 32]
99+
into %arg0 : tensor<8x8x32x32xf32> -> tensor<256x256xf32>
100+
return %unpacked : tensor<256x256xf32>
101+
}
102+
103+
// CHECK-LABEL: func.func @unpack_empty(
104+
// CHECK-SAME: %[[T:.+]]: tensor<256x256xf32>
105+
// CHECK-NOT: tensor.unpack
106+
// CHECK: return %[[T]] : tensor<256x256xf32>
107+
108+
// Dynamic-shape case: a tensor.empty sized by %dim0..%dim3 is unpacked with
// 32x32 inner tiles into %arg0. The FileCheck directives that follow this
// function expect no tensor.unpack to remain and %arg0 to be returned.
func.func @unpack_empty_dynamic(%arg0: tensor<?x?xf32>, %dim0: index, %dim1: index, %dim2: index, %dim3: index) -> tensor<?x?xf32> {
109+
%empty_packed = tensor.empty(%dim0, %dim1, %dim2, %dim3) : tensor<?x?x?x?xf32>
110+
%unpacked = tensor.unpack %empty_packed
111+
inner_dims_pos = [0, 1] inner_tiles = [32, 32]
112+
into %arg0 : tensor<?x?x?x?xf32> -> tensor<?x?xf32>
113+
return %unpacked : tensor<?x?xf32>
114+
}
115+
116+
// CHECK-LABEL: func.func @unpack_empty_dynamic(
117+
// CHECK-SAME: %[[T:.+]]: tensor<?x?xf32>,
118+
// CHECK-SAME: %[[DIM0:[a-zA-Z0-9_]+]]: index,
119+
// CHECK-SAME: %[[DIM1:[a-zA-Z0-9_]+]]: index,
120+
// CHECK-SAME: %[[DIM2:[a-zA-Z0-9_]+]]: index,
121+
// CHECK-SAME: %[[DIM3:[a-zA-Z0-9_]+]]: index
122+
// CHECK-NOT: tensor.unpack
123+
// CHECK: return %[[T]] : tensor<?x?xf32>
124+
125+
// Negative case: the pack of a tensor.empty carries a padding value (%pad).
// The FileCheck directives that follow this function expect the tensor.pack
// to remain and its result to be returned.
func.func @pack_padded_empty(%arg0: tensor<8x8x32x32xf32>) -> tensor<8x8x32x32xf32> {
126+
%pad = arith.constant 1.0 : f32
127+
%empty_unpacked = tensor.empty() : tensor<256x256xf32>
128+
%packed = tensor.pack %empty_unpacked
129+
padding_value(%pad : f32)
130+
inner_dims_pos = [0, 1] inner_tiles = [32, 32]
131+
into %arg0 : tensor<256x256xf32> -> tensor<8x8x32x32xf32>
132+
return %packed : tensor<8x8x32x32xf32>
133+
}
134+
135+
// CHECK-LABEL: func.func @pack_padded_empty(
136+
// CHECK-SAME: %[[T:.+]]: tensor<8x8x32x32xf32>
137+
// CHECK: %[[PACK:.+]] = tensor.pack
138+
// CHECK: return %[[PACK]] : tensor<8x8x32x32xf32>
139+
67140
// -----
68141

69142
module attributes {transform.with_named_sequence} {

0 commit comments

Comments
 (0)