@@ -22,6 +22,12 @@ static bool areAllConstantIntValue(ArrayRef<OpFoldResult> ofrs, int64_t value) {
22
22
ofrs, [&](OpFoldResult ofr) { return isConstantIntValue (ofr, value); });
23
23
}
24
24
25
+ // / Returns the number of shape sizes that is either dynamic or greater than 1.
26
+ static int64_t getNumGtOneDims (ArrayRef<int64_t > shape) {
27
+ return llvm::count_if (
28
+ shape, [](int64_t v) { return ShapedType::isDynamic (v) || v > 1 ; });
29
+ }
30
+
25
31
// / Packing one-dimensional tensor can be expressed as an expand shape op.
26
32
struct SimplifyPackToExpandShape : public OpRewritePattern <PackOp> {
27
33
using OpRewritePattern<PackOp>::OpRewritePattern;
@@ -34,26 +40,60 @@ struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
34
40
reassociation);
35
41
}
36
42
37
- LogicalResult matchAndRewrite (PackOp packOp,
38
- PatternRewriter &rewriter) const override {
39
- if (packOp.getPaddingValue ())
40
- return rewriter.notifyMatchFailure (packOp, " expects no padding value" );
41
-
43
+ // / Returns success() if it is only packing on the innermost dimension.
44
+ LogicalResult isPackOnInnerMostDim (RewriterBase &rewriter,
45
+ PackOp packOp) const {
42
46
auto outerDimsPerm = packOp.getOuterDimsPerm ();
43
47
if (!outerDimsPerm.empty () && !isIdentityPermutation (outerDimsPerm)) {
44
48
return rewriter.notifyMatchFailure (
45
49
packOp,
46
50
" expects outer_dims_perm is empty or an identity permutation" );
47
51
}
48
52
49
- RankedTensorType sourceType = packOp.getSourceType ();
50
- RankedTensorType destType = packOp.getDestType ();
53
+ int64_t srcRank = packOp.getSourceRank ();
51
54
ArrayRef<int64_t > dimsPos = packOp.getInnerDimsPos ();
52
- if (dimsPos.size () != 1 || (dimsPos[0 ] + 1 != sourceType. getRank () )) {
55
+ if (dimsPos.size () != 1 || (dimsPos[0 ] + 1 != srcRank )) {
53
56
return rewriter.notifyMatchFailure (
54
57
packOp, " expects packing at the innermost dimension" );
55
58
}
59
+ return success ();
60
+ }
61
+
62
+ // / Returns success() if there is only 1 dimension size in source being
63
+ // / greater than 1 and packing only happens on the dimension. It assumes that
64
+ // / the pack op does not have padding value.
65
+ LogicalResult isPack1DSrc (RewriterBase &rewriter, PackOp packOp) const {
66
+ assert (!packOp.getPaddingValue () &&
67
+ " expect the op does not have padding value." );
68
+ ArrayRef<int64_t > srcShape = packOp.getSourceType ().getShape ();
69
+ if (getNumGtOneDims (srcShape) > 1 ) {
70
+ return rewriter.notifyMatchFailure (
71
+ packOp, " expects source to have at most one non-unit dims" );
72
+ }
56
73
74
+ // The pack op does not have padding value. Non-unit inner tile size must be
75
+ // be used by the non-unit dimension.
76
+ SmallVector<int64_t > innerTiles = packOp.getStaticTiles ();
77
+ if (getNumGtOneDims (innerTiles) > 1 ) {
78
+ return rewriter.notifyMatchFailure (
79
+ packOp, " expects at most one non-unit inner tiles" );
80
+ }
81
+
82
+ return success ();
83
+ }
84
+
85
+ LogicalResult matchAndRewrite (PackOp packOp,
86
+ PatternRewriter &rewriter) const override {
87
+ if (packOp.getPaddingValue ())
88
+ return rewriter.notifyMatchFailure (packOp, " expects no padding value" );
89
+
90
+ if (failed (isPackOnInnerMostDim (rewriter, packOp)) &&
91
+ failed (isPack1DSrc (rewriter, packOp))) {
92
+ return failure ();
93
+ }
94
+
95
+ RankedTensorType sourceType = packOp.getSourceType ();
96
+ RankedTensorType destType = packOp.getDestType ();
57
97
auto reassociation =
58
98
getReassociationIndicesForReshape (sourceType, destType);
59
99
if (!reassociation)
0 commit comments