Skip to content

Commit b711077

Browse files
committed
Revert "[mlir][SCF] Unify tileUsingFor and tileReductionUsingFor implementation (llvm#120115)"
This reverts commit 4b56345.
1 parent d3821b5 commit b711077

File tree

4 files changed

+225
-305
lines changed

4 files changed

+225
-305
lines changed

mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h

Lines changed: 18 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -85,36 +85,6 @@ struct SCFTilingOptions {
8585
return *this;
8686
}
8787

88-
/// Specify how reduction dimensions should be tiled.
89-
///
90-
/// Tiling can be thought of as splitting a dimension into 2 and materializing
91-
/// the outer dimension as a loop:
92-
///
93-
/// op[original] -> op[original / x, x] -> loop[original] { op[x] }
94-
///
95-
/// For parallel dimensions, the split can only happen in one way, with both
96-
/// dimensions being parallel. For reduction dimensions however, there is a
97-
/// choice in how we split the reduction dimension. This enum exposes this
98-
/// choice.
99-
enum class ReductionTilingStrategy {
100-
// [reduction] -> [reduction1, reduction2]
101-
// -> loop[reduction1] { [reduction2] }
102-
FullReduction,
103-
// [reduction] -> [reduction1, parallel2]
104-
// -> loop[reduction1] { [parallel2] }; merge[reduction1]
105-
PartialReductionOuterReduction,
106-
// [reduction] -> [parallel1, reduction2]
107-
// -> loop[parallel1] { [reduction2] }; merge[parallel1]
108-
PartialReductionOuterParallel
109-
};
110-
ReductionTilingStrategy reductionStrategy =
111-
ReductionTilingStrategy::FullReduction;
112-
SCFTilingOptions &
113-
setReductionTilingStrategy(ReductionTilingStrategy strategy) {
114-
reductionStrategy = strategy;
115-
return *this;
116-
}
117-
11888
/// Specify mapping of loops to devices. This is only respected when the loop
11989
/// constructs support such a mapping (like `scf.forall`). Will be ignored
12090
/// when using loop constructs that don't support such a mapping (like
@@ -132,16 +102,11 @@ struct SCFTilingResult {
132102
/// matter except the last op. The replacements are expected to be the results
133103
/// of the last op.
134104
SmallVector<Operation *> tiledOps;
135-
/// The initial destination values passed to the tiled operations.
136-
SmallVector<Value> initialValues;
137105
/// The `scf.for` operations that iterate over the tiles.
138106
SmallVector<LoopLikeOpInterface> loops;
139-
/// The result generated by the loop nest in tiling, may hold partial results,
140-
/// which need to be merged to match the computation of the untiled operation.
141-
/// `mergeResult` contains the operations used to perform this merge from
142-
/// partial results and the values that can be used as replacements of
143-
/// the untiled operation.
144-
MergeResult mergeResult;
107+
/// Values to use as replacements for the untiled op. Is the same size as the
108+
/// number of results of the untiled op.
109+
SmallVector<Value> replacements;
145110
/// Slices generated after tiling that can be used for fusing with the tiled
146111
/// producer.
147112
SmallVector<Operation *> generatedSlices;
@@ -335,6 +300,20 @@ tileAndFuseConsumerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp);
335300
FailureOr<SmallVector<scf::ForOp>>
336301
lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op);
337302

303+
/// Transformation information returned after reduction tiling.
304+
struct SCFReductionTilingResult {
305+
/// The partial reduction tiled ops generated.
306+
SmallVector<Operation *> parallelTiledOps;
307+
/// The final reduction operations merging all the partial reductions.
308+
SmallVector<Operation *> mergeOps;
309+
/// Initial values used for reduction.
310+
SmallVector<Value> initialValues;
311+
/// The loop operations that iterate over the tiles.
312+
SmallVector<LoopLikeOpInterface> loops;
313+
/// The replacements to use for the results of the tiled operation.
314+
SmallVector<Value> replacements;
315+
};
316+
338317
/// Method to tile a reduction and generate a parallel op within a serial loop.
339318
/// Each of the partial reductions is calculated in parallel. Then after the
340319
/// loop all the partial reductions are merged into a final reduction.
@@ -359,7 +338,7 @@ lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op);
359338
/// %6 = linalg.generic %1 ["parallel", "reduction"]
360339
/// : tensor<7x4xf32> -> tensor<7xf32>
361340
/// ```
362-
FailureOr<scf::SCFTilingResult>
341+
FailureOr<scf::SCFReductionTilingResult>
363342
tileReductionUsingScf(RewriterBase &b, PartialReductionOpInterface op,
364343
ArrayRef<OpFoldResult> tileSize);
365344

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2223,7 +2223,7 @@ transform::ScalarizeOp::applyToOne(transform::TransformRewriter &rewriter,
22232223
return emitDefaultDefiniteFailure(target);
22242224

22252225
if (target->getNumResults())
2226-
rewriter.replaceOp(target, maybeTilingResult->mergeResult.replacements);
2226+
rewriter.replaceOp(target, maybeTilingResult->replacements);
22272227
else
22282228
rewriter.eraseOp(target);
22292229

@@ -2630,18 +2630,17 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForOp::applyToOne(
26302630
transform::ApplyToEachResultList &results,
26312631
transform::TransformState &state) {
26322632
rewriter.setInsertionPoint(target);
2633-
FailureOr<scf::SCFTilingResult> result = scf::tileReductionUsingScf(
2633+
FailureOr<scf::SCFReductionTilingResult> result = scf::tileReductionUsingScf(
26342634
rewriter, cast<PartialReductionOpInterface>(target.getOperation()),
26352635
getAsOpFoldResult(rewriter.getI64ArrayAttr(getTileSizes())));
26362636

26372637
if (failed(result))
26382638
return emitDefaultSilenceableFailure(target);
2639-
rewriter.replaceOp(target, result->mergeResult.replacements);
26402639
for (Value initValue : result->initialValues)
26412640
results.push_back(initValue.getDefiningOp());
2642-
for (auto parallelTiledOp : result->tiledOps)
2641+
for (auto parallelTiledOp : result->parallelTiledOps)
26432642
results.push_back(parallelTiledOp);
2644-
for (auto mergeOp : result->mergeResult.mergeOps)
2643+
for (auto mergeOp : result->mergeOps)
26452644
results.push_back(mergeOp);
26462645
results.push_back(result->loops.front());
26472646
return DiagnosedSilenceableFailure::success();
@@ -3065,7 +3064,7 @@ transform::TileUsingForOp::apply(transform::TransformRewriter &rewriter,
30653064
if (failed(maybeTilingResult))
30663065
return DiagnosedSilenceableFailure::definiteFailure();
30673066

3068-
rewriter.replaceOp(op, maybeTilingResult->mergeResult.replacements);
3067+
rewriter.replaceOp(op, maybeTilingResult->replacements);
30693068

30703069
tiled.append(maybeTilingResult->tiledOps);
30713070
for (const auto &en2 : llvm::enumerate(maybeTilingResult->loops))
@@ -3304,7 +3303,7 @@ DiagnosedSilenceableFailure transform::tileToForallOpImpl(
33043303
if (failed(maybeTilingResult))
33053304
return transformOp.emitDefaultSilenceableFailure(tileableOp);
33063305

3307-
rewriter.replaceOp(tileableOp, maybeTilingResult->mergeResult.replacements);
3306+
rewriter.replaceOp(tileableOp, maybeTilingResult->replacements);
33083307

33093308
tilingResult = *maybeTilingResult;
33103309

0 commit comments

Comments
 (0)