@@ -85,36 +85,6 @@ struct SCFTilingOptions {
85
85
return *this ;
86
86
}
87
87
88
- // / Specify how reduction dimensions should be tiled.
89
- // /
90
- // / Tiling can be thought of as splitting a dimension into 2 and materializing
91
- // / the outer dimension as a loop:
92
- // /
93
- // / op[original] -> op[original / x, x] -> loop[original] { op[x] }
94
- // /
95
- // / For parallel dimensions, the split can only happen in one way, with both
96
- // / dimensions being parallel. For reduction dimensions however, there is a
97
- // / choice in how we split the reduction dimension. This enum exposes this
98
- // / choice.
99
- enum class ReductionTilingStrategy {
100
- // [reduction] -> [reduction1, reduction2]
101
- // -> loop[reduction1] { [reduction2] }
102
- FullReduction,
103
- // [reduction] -> [reduction1, parallel2]
104
- // -> loop[reduction1] { [parallel2] }; merge[reduction1]
105
- PartialReductionOuterReduction,
106
- // [reduction] -> [parallel1, reduction2]
107
- // -> loop[parallel1] { [reduction2] }; merge[parallel1]
108
- PartialReductionOuterParallel
109
- };
110
- ReductionTilingStrategy reductionStrategy =
111
- ReductionTilingStrategy::FullReduction;
112
- SCFTilingOptions &
113
- setReductionTilingStrategy (ReductionTilingStrategy strategy) {
114
- reductionStrategy = strategy;
115
- return *this ;
116
- }
117
-
118
88
// / Specify mapping of loops to devices. This is only respected when the loop
119
89
// / constructs support such a mapping (like `scf.forall`). Will be ignored
120
90
// / when using loop constructs that dont support such a mapping (like
@@ -132,16 +102,11 @@ struct SCFTilingResult {
132
102
// / matter except the last op. The replacements are expected to be the results
133
103
// / of the last op.
134
104
SmallVector<Operation *> tiledOps;
135
- // / The initial destination values passed to the tiled operations.
136
- SmallVector<Value> initialValues;
137
105
// / The `scf.for` operations that iterate over the tiles.
138
106
SmallVector<LoopLikeOpInterface> loops;
139
- // / The result generated by the loop nest in tiling, may hold partial results,
140
- // / which need to be merged to match the computation of the untiled operation.
141
- // / `mergeResult` contains the operations used to perform this merge from
142
- // / partial results and the values that can be used as replacements of
143
- // / the untiled operation.
144
- MergeResult mergeResult;
107
+ // / Values to use as replacements for the untiled op. Is the same size as the
108
+ // / number of results of the untiled op.
109
+ SmallVector<Value> replacements;
145
110
// / Slices generated after tiling that can be used for fusing with the tiled
146
111
// / producer.
147
112
SmallVector<Operation *> generatedSlices;
@@ -335,6 +300,20 @@ tileAndFuseConsumerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp);
335
300
FailureOr<SmallVector<scf::ForOp>>
336
301
lowerToLoopsUsingSCFForOp (RewriterBase &rewriter, TilingInterface op);
337
302
303
+ // / Transformation information returned after reduction tiling.
304
+ struct SCFReductionTilingResult {
305
+ // / The partial reduction tiled op generated.
306
+ SmallVector<Operation *> parallelTiledOps;
307
+ // / The final reduction operation merging all the partial reductions.
308
+ SmallVector<Operation *> mergeOps;
309
+ // / Initial values used for reduction.
310
+ SmallVector<Value> initialValues;
311
+ // / The loop operations that iterate over the tiles.
312
+ SmallVector<LoopLikeOpInterface> loops;
313
+ // / The replacements to use for the results of the tiled operation.
314
+ SmallVector<Value> replacements;
315
+ };
316
+
338
317
// / Method to tile a reduction and generate a parallel op within a serial loop.
339
318
// / Each of the partial reductions are calculated in parallel. Then after the
340
319
// / loop all the partial reduction are merged into a final reduction.
@@ -359,7 +338,7 @@ lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op);
359
338
// / %6 = linalg.generic %1 ["parallel", "reduction"]
360
339
// / : tensor<7x4xf32> -> tensor<7xf32>
361
340
// / ```
362
- FailureOr<scf::SCFTilingResult >
341
+ FailureOr<scf::SCFReductionTilingResult >
363
342
tileReductionUsingScf (RewriterBase &b, PartialReductionOpInterface op,
364
343
ArrayRef<OpFoldResult> tileSize);
365
344
0 commit comments