Skip to content

[mlir][SCF] Unify tileUsingFor and tileReductionUsingFor implementation #120115

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 39 additions & 18 deletions mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,36 @@ struct SCFTilingOptions {
return *this;
}

/// Specify how reduction dimensions should be tiled.
///
/// Tiling can be thought of as splitting a dimension into 2 and materializing
/// the outer dimension as a loop:
///
/// op[original] -> op[original / x, x] -> loop[original] { op[x] }
///
/// For parallel dimensions, the split can only happen in one way, with both
/// dimensions being parallel. For reduction dimensions however, there is a
/// choice in how we split the reduction dimension. This enum exposes this
/// choice.
enum class ReductionTilingStrategy {
// [reduction] -> [reduction1, reduction2]
// -> loop[reduction1] { [reduction2] }
FullReduction,
// [reduction] -> [reduction1, parallel2]
// -> loop[reduction1] { [parallel2] }; merge[reduction1]
PartialReductionOuterReduction,
// [reduction] -> [parallel1, reduction2]
// -> loop[parallel1] { [reduction2] }; merge[parallel1]
PartialReductionOuterParallel
};
ReductionTilingStrategy reductionStrategy =
ReductionTilingStrategy::FullReduction;
SCFTilingOptions &
setReductionTilingStrategy(ReductionTilingStrategy strategy) {
reductionStrategy = strategy;
return *this;
}

/// Specify mapping of loops to devices. This is only respected when the loop
/// constructs support such a mapping (like `scf.forall`). Will be ignored
/// when using loop constructs that dont support such a mapping (like
Expand All @@ -102,11 +132,16 @@ struct SCFTilingResult {
/// matter except the last op. The replacements are expected to be the results
/// of the last op.
SmallVector<Operation *> tiledOps;
/// The initial destination values passed to the tiled operations.
SmallVector<Value> initialValues;
/// The `scf.for` operations that iterate over the tiles.
SmallVector<LoopLikeOpInterface> loops;
/// Values to use as replacements for the untiled op. Is the same size as the
/// number of results of the untiled op.
SmallVector<Value> replacements;
/// The result generated by the loop nest in tiling, may hold partial results,
/// which need to be merged to match the computation of the untiled operation.
/// `mergeResult` contains the operations used to perform this merge from
/// partial results and the values that can be used as replacements of
/// the untiled operation.
MergeResult mergeResult;
/// Slices generated after tiling that can be used for fusing with the tiled
/// producer.
SmallVector<Operation *> generatedSlices;
Expand Down Expand Up @@ -300,20 +335,6 @@ tileAndFuseConsumerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp);
FailureOr<SmallVector<scf::ForOp>>
lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op);

/// Transformation information returned after reduction tiling.
struct SCFReductionTilingResult {
/// The partial reduction tiled op generated.
SmallVector<Operation *> parallelTiledOps;
/// The final reduction operation merging all the partial reductions.
SmallVector<Operation *> mergeOps;
/// Initial values used for reduction.
SmallVector<Value> initialValues;
/// The loop operations that iterate over the tiles.
SmallVector<LoopLikeOpInterface> loops;
/// The replacements to use for the results of the tiled operation.
SmallVector<Value> replacements;
};

/// Method to tile a reduction and generate a parallel op within a serial loop.
/// Each of the partial reductions are calculated in parallel. Then after the
/// loop all the partial reduction are merged into a final reduction.
Expand All @@ -338,7 +359,7 @@ struct SCFReductionTilingResult {
/// %6 = linalg.generic %1 ["parallel", "reduction"]
/// : tensor<7x4xf32> -> tensor<7xf32>
/// ```
FailureOr<scf::SCFReductionTilingResult>
FailureOr<scf::SCFTilingResult>
tileReductionUsingScf(RewriterBase &b, PartialReductionOpInterface op,
ArrayRef<OpFoldResult> tileSize);

Expand Down
13 changes: 7 additions & 6 deletions mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2224,7 +2224,7 @@ transform::ScalarizeOp::applyToOne(transform::TransformRewriter &rewriter,
return emitDefaultDefiniteFailure(target);

if (target->getNumResults())
rewriter.replaceOp(target, maybeTilingResult->replacements);
rewriter.replaceOp(target, maybeTilingResult->mergeResult.replacements);
else
rewriter.eraseOp(target);

Expand Down Expand Up @@ -2631,17 +2631,18 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForOp::applyToOne(
transform::ApplyToEachResultList &results,
transform::TransformState &state) {
rewriter.setInsertionPoint(target);
FailureOr<scf::SCFReductionTilingResult> result = scf::tileReductionUsingScf(
FailureOr<scf::SCFTilingResult> result = scf::tileReductionUsingScf(
rewriter, cast<PartialReductionOpInterface>(target.getOperation()),
getAsOpFoldResult(rewriter.getI64ArrayAttr(getTileSizes())));

if (failed(result))
return emitDefaultSilenceableFailure(target);
rewriter.replaceOp(target, result->mergeResult.replacements);
for (Value initValue : result->initialValues)
results.push_back(initValue.getDefiningOp());
for (auto parallelTiledOp : result->parallelTiledOps)
for (auto parallelTiledOp : result->tiledOps)
results.push_back(parallelTiledOp);
for (auto mergeOp : result->mergeOps)
for (auto mergeOp : result->mergeResult.mergeOps)
results.push_back(mergeOp);
results.push_back(result->loops.front());
return DiagnosedSilenceableFailure::success();
Expand Down Expand Up @@ -3065,7 +3066,7 @@ transform::TileUsingForOp::apply(transform::TransformRewriter &rewriter,
if (failed(maybeTilingResult))
return DiagnosedSilenceableFailure::definiteFailure();

rewriter.replaceOp(op, maybeTilingResult->replacements);
rewriter.replaceOp(op, maybeTilingResult->mergeResult.replacements);

tiled.append(maybeTilingResult->tiledOps);
for (const auto &en2 : llvm::enumerate(maybeTilingResult->loops))
Expand Down Expand Up @@ -3304,7 +3305,7 @@ DiagnosedSilenceableFailure transform::tileToForallOpImpl(
if (failed(maybeTilingResult))
return transformOp.emitDefaultSilenceableFailure(tileableOp);

rewriter.replaceOp(tileableOp, maybeTilingResult->replacements);
rewriter.replaceOp(tileableOp, maybeTilingResult->mergeResult.replacements);

tilingResult = *maybeTilingResult;

Expand Down
Loading
Loading