@@ -46,6 +46,21 @@ struct GpuTilingAndFusion final
46
46
void runOnOperation () override {
47
47
IRRewriter rewriter (&getContext ());
48
48
scf::SCFTileAndFuseOptions opts;
49
// Fusion control callback installed on the tile-and-fuse options.
// It inspects the operation that owns `originalProducer` and either returns
// std::nullopt or a default-constructed ControlFnResult.
// NOTE(review): presumably std::nullopt tells the tile-and-fuse driver not to
// fuse this producer - confirm against the SCFTileAndFuseOptions documentation.
// `candidateSliceOp` and `isDestinationOperand` are not read by this callback.
+ opts.setFusionControlFn (
50
+ [&](tensor::ExtractSliceOp candidateSliceOp, OpResult originalProducer,
51
+ bool isDestinationOperand)
52
+ -> std::optional<scf::SCFTileAndFuseOptions::ControlFnResult> {
53
+ Operation *op = originalProducer.getOwner ();
54
// Defensive guard: bail out when the result has no owning operation.
+ if (!op) {
55
+ return std::nullopt;
56
+ }
57
// Only linalg ops are filtered: one whose indexing maps are not all
// projected permutations yields std::nullopt. Non-linalg producers
// fall through to the default result below.
+ if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
58
+ if (!linalgOp.hasOnlyProjectedPermutations ()) {
59
+ return std::nullopt;
60
+ }
61
+ }
62
// Every other producer gets a default-constructed ControlFnResult.
+ return scf::SCFTileAndFuseOptions::ControlFnResult{};
63
+ });
49
64
opts.tilingOptions .setLoopType (scf::SCFTilingOptions::LoopType::ForallOp);
50
65
// The outer loop is converted to a GPU kernel and the tile sizes are mapped
51
66
// to the grid sizes.
@@ -77,13 +92,15 @@ struct GpuTilingAndFusion final
77
92
assert (itTypes.size () == itDomains.size ());
78
93
79
94
// TODO: Add a parameter to the options?
80
- size_t totalSize = calcOperandsSize (op) * euThreads ;
95
+ size_t totalSize = calcOperandsSize (op);
81
96
unsigned loopCount = 0 ;
97
+ SmallVector<int64_t > sizes;
82
98
83
99
for (auto [t, r] : zip (itTypes, itDomains)) {
84
100
if (t == utils::IteratorType::parallel) {
85
101
if (auto v = getConstantIntValue (r.size )) {
86
102
loopCount++;
103
+ sizes.emplace_back (*v);
87
104
totalSize *= *v;
88
105
} else {
89
106
return calcDynamicSizes (builder, ti, euMem, euThreads);
@@ -95,19 +112,25 @@ struct GpuTilingAndFusion final
95
112
return {};
96
113
}
97
114
98
- // TODO: In case of different sizes, calculate the ratio for each loop
99
- double ratio = std::pow (static_cast <double >(totalSize) /
100
- static_cast <double >(euMem),
101
- 1.0 / loopCount);
102
- ratio = std::max (1.0 , ratio);
115
+ auto outerTileSize = static_cast <size_t >(
116
+ std::ceil (static_cast <double >(euMem) /
117
+ static_cast <double >(calcOperandsSize (op))));
118
+ SmallVector<int64_t > outerTiles;
119
+ SmallVector<int64_t > innerTiles;
120
+ normaliseTiles (outerTileSize, sizes, outerTiles);
121
+ normaliseTiles (euThreads, sizes, innerTiles);
122
+
123
+ unsigned counter = 0 ;
103
124
SmallVector<OpFoldResult> tiles;
104
125
tiles.reserve (itDomains.size ());
105
126
106
127
for (auto [t, r] : zip (itTypes, itDomains)) {
107
128
if (t != utils::IteratorType::parallel) {
108
129
tiles.emplace_back (builder.getIndexAttr (1 ));
109
130
} else if (auto v = getConstantIntValue (r.size )) {
110
- tiles.emplace_back (ceil (builder, *v, ratio));
131
+ tiles.emplace_back (
132
+ ceil (builder, outerTiles[counter], innerTiles[counter]));
133
+ counter++;
111
134
} else {
112
135
abort (); // Must never get here
113
136
}
@@ -174,7 +197,8 @@ struct GpuTilingAndFusion final
174
197
static std::optional<TilingInterface> findTi (Operation *op) {
175
198
std::optional<TilingInterface> last;
176
199
op->walk <WalkOrder::PreOrder>([&](linalg::LinalgOp linalgOp) {
177
- if (!linalgOp->getParentOfType <scf::ForallOp>()) {
200
+ if (linalgOp.hasOnlyProjectedPermutations () &&
201
+ !linalgOp->getParentOfType <scf::ForallOp>()) {
178
202
if (auto ti = dyn_cast<TilingInterface>(linalgOp.getOperation ())) {
179
203
last = ti;
180
204
}
0 commit comments