Commit f59eef6

[mlir][tensor] Enhance SimplifyPackToExpandShape for unit dim cases. (#79247)
Progress on iree-org/iree#16181
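
In effect, a `tensor.pack` whose source has at most one non-unit dimension now folds into a pure reshape. A minimal sketch of the rewrite (the shapes mirror one of the new tests below; `%src` is an illustrative operand name, not from the patch):

```mlir
// Before: packing tensor<1x32xf32> only splits and rearranges unit dims,
// so no data movement is needed.
%pack = tensor.pack %src inner_dims_pos = [0, 1] inner_tiles = [1, 2]
    into %dest : tensor<1x32xf32> -> tensor<1x16x1x2xf32>

// After: the same result expressed as a reshape.
%expanded = tensor.expand_shape %src [[0], [1, 2, 3]]
    : tensor<1x32xf32> into tensor<1x16x1x2xf32>
```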
1 parent 816cc9d

2 files changed: +99 −8 lines changed
mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp

Lines changed: 48 additions & 8 deletions
```diff
@@ -22,6 +22,12 @@ static bool areAllConstantIntValue(ArrayRef<OpFoldResult> ofrs, int64_t value) {
       ofrs, [&](OpFoldResult ofr) { return isConstantIntValue(ofr, value); });
 }
 
+/// Returns the number of shape sizes that are either dynamic or greater than 1.
+static int64_t getNumGtOneDims(ArrayRef<int64_t> shape) {
+  return llvm::count_if(
+      shape, [](int64_t v) { return ShapedType::isDynamic(v) || v > 1; });
+}
+
 /// Packing a one-dimensional tensor can be expressed as an expand shape op.
 struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
   using OpRewritePattern<PackOp>::OpRewritePattern;
@@ -34,26 +40,60 @@ struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
         reassociation);
   }
 
-  LogicalResult matchAndRewrite(PackOp packOp,
-                                PatternRewriter &rewriter) const override {
-    if (packOp.getPaddingValue())
-      return rewriter.notifyMatchFailure(packOp, "expects no padding value");
-
+  /// Returns success() if the op is only packing on the innermost dimension.
+  LogicalResult isPackOnInnerMostDim(RewriterBase &rewriter,
+                                     PackOp packOp) const {
     auto outerDimsPerm = packOp.getOuterDimsPerm();
     if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) {
       return rewriter.notifyMatchFailure(
           packOp,
           "expects outer_dims_perm is empty or an identity permutation");
     }
 
-    RankedTensorType sourceType = packOp.getSourceType();
-    RankedTensorType destType = packOp.getDestType();
+    int64_t srcRank = packOp.getSourceRank();
     ArrayRef<int64_t> dimsPos = packOp.getInnerDimsPos();
-    if (dimsPos.size() != 1 || (dimsPos[0] + 1 != sourceType.getRank())) {
+    if (dimsPos.size() != 1 || (dimsPos[0] + 1 != srcRank)) {
       return rewriter.notifyMatchFailure(
           packOp, "expects packing at the innermost dimension");
     }
+    return success();
+  }
+
+  /// Returns success() if at most one dimension size in the source is greater
+  /// than 1 and packing only happens on that dimension. It assumes that the
+  /// pack op does not have a padding value.
+  LogicalResult isPack1DSrc(RewriterBase &rewriter, PackOp packOp) const {
+    assert(!packOp.getPaddingValue() &&
+           "expects the op to not have a padding value");
+    ArrayRef<int64_t> srcShape = packOp.getSourceType().getShape();
+    if (getNumGtOneDims(srcShape) > 1) {
+      return rewriter.notifyMatchFailure(
+          packOp, "expects source to have at most one non-unit dim");
+    }
 
+    // The pack op does not have a padding value, so any non-unit inner tile
+    // size must be used by the non-unit dimension.
+    SmallVector<int64_t> innerTiles = packOp.getStaticTiles();
+    if (getNumGtOneDims(innerTiles) > 1) {
+      return rewriter.notifyMatchFailure(
+          packOp, "expects at most one non-unit inner tile");
+    }
+
+    return success();
+  }
+
+  LogicalResult matchAndRewrite(PackOp packOp,
+                                PatternRewriter &rewriter) const override {
+    if (packOp.getPaddingValue())
+      return rewriter.notifyMatchFailure(packOp, "expects no padding value");
+
+    if (failed(isPackOnInnerMostDim(rewriter, packOp)) &&
+        failed(isPack1DSrc(rewriter, packOp))) {
+      return failure();
+    }
+
+    RankedTensorType sourceType = packOp.getSourceType();
+    RankedTensorType destType = packOp.getDestType();
     auto reassociation =
         getReassociationIndicesForReshape(sourceType, destType);
     if (!reassociation)
```
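With this change, matchAndRewrite admits a pack if either helper succeeds. A hedged sketch of the two admission paths (operand names and shapes here are illustrative, not taken from the patch):

```mlir
// Admitted by isPackOnInnerMostDim: a single inner tile on the innermost
// source dimension, with an empty (or identity) outer_dims_perm.
%p0 = tensor.pack %a inner_dims_pos = [1] inner_tiles = [8]
    into %d0 : tensor<5x16xf32> -> tensor<5x2x8xf32>

// Admitted by isPack1DSrc: all but one source dim is a unit dim, so even a
// permuted pack only ever reshapes the data.
%p1 = tensor.pack %b outer_dims_perm = [1, 0] inner_dims_pos = [0, 1]
    inner_tiles = [2, 1] into %d1 : tensor<32x1xf32> -> tensor<1x16x2x1xf32>
```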

mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir

Lines changed: 51 additions & 0 deletions
```diff
@@ -83,6 +83,57 @@ func.func @single_first_inner_dim_packing(%arg0: tensor<256x5xf32>) -> tensor<8x
 
 // -----
 
+// CHECK-LABEL: func.func @pack_1x32_to_1x32x1x1
+// CHECK-SAME:    %[[ARG0:[0-9a-zA-Z]+]]
+// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2, 3]]
+// CHECK:         return %[[EXPANDED]]
+func.func @pack_1x32_to_1x32x1x1(%arg0 : tensor<1x32xf32>) -> tensor<1x32x1x1xf32> {
+  %empty = tensor.empty() : tensor<1x32x1x1xf32>
+  %pack = tensor.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [1, 1] into %empty
+    : tensor<1x32xf32> -> tensor<1x32x1x1xf32>
+  return %pack : tensor<1x32x1x1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @pack_1x32_to_1x16x1x2
+// CHECK-SAME:    %[[ARG0:[0-9a-zA-Z]+]]
+// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2, 3]]
+// CHECK:         return %[[EXPANDED]]
+func.func @pack_1x32_to_1x16x1x2(%arg0 : tensor<1x32xf32>) -> tensor<1x16x1x2xf32> {
+  %empty = tensor.empty() : tensor<1x16x1x2xf32>
+  %pack = tensor.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [1, 2] into %empty
+    : tensor<1x32xf32> -> tensor<1x16x1x2xf32>
+  return %pack : tensor<1x16x1x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @pack_32x1_to_16x1x2x1
+// CHECK-SAME:    %[[ARG0:[0-9a-zA-Z]+]]
+// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1, 2], [3]]
+// CHECK:         return %[[EXPANDED]]
+func.func @pack_32x1_to_16x1x2x1(%arg0 : tensor<32x1xf32>) -> tensor<1x16x2x1xf32> {
+  %empty = tensor.empty() : tensor<1x16x2x1xf32>
+  %pack = tensor.pack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [2, 1] into %empty
+    : tensor<32x1xf32> -> tensor<1x16x2x1xf32>
+  return %pack : tensor<1x16x2x1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @pack_32x1_to_16x1x1x2
+// CHECK-NOT:     tensor.expand_shape
+// CHECK:         tensor.pack
+func.func @pack_32x1_to_16x1x1x2(%arg0 : tensor<32x1xf32>) -> tensor<16x1x1x2xf32> {
+  %empty = tensor.empty() : tensor<16x1x1x2xf32>
+  %pack = tensor.pack %arg0 inner_dims_pos = [1, 0] inner_tiles = [1, 2] into %empty
+    : tensor<32x1xf32> -> tensor<16x1x1x2xf32>
+  return %pack : tensor<16x1x1x2xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func.func @unpack_1d_to_collapse
 // CHECK-SAME:    %[[ARG0:.+]]: tensor<8x32xf32>)
 // CHECK:         %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1]] : tensor<8x32xf32> into tensor<256xf32>
```
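As a usage note, the tests in this file are driven through the tensor test-pattern pass; the RUN line below reflects the existing file header as I recall it rather than anything in this diff, so treat the exact flag name as an assumption:

```mlir
// RUN: mlir-opt -split-input-file \
// RUN:   -test-tensor-transform-patterns=test-simplify-pack-unpack-patterns %s | FileCheck %s
```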
