Skip to content

Commit f506d72

Browse files
committed
[mlir] Extract forall_to_for logic into reusable function and add pass
1 parent 330d898 commit f506d72

File tree

7 files changed

+171
-30
lines changed

7 files changed

+171
-30
lines changed

mlir/include/mlir/Dialect/SCF/Transforms/Passes.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,9 @@ createParallelLoopTilingPass(llvm::ArrayRef<int64_t> tileSize = {},
5959
/// loop range.
6060
std::unique_ptr<Pass> createForLoopRangeFoldingPass();
6161

62+
/// Creates a pass that converts SCF forall loops to SCF for loops.
63+
std::unique_ptr<Pass> createForallToForLoopPass();
64+
6265
// Creates a pass which lowers for loops into while loops.
6366
std::unique_ptr<Pass> createForToWhileLoopPass();
6467

mlir/include/mlir/Dialect/SCF/Transforms/Passes.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,11 @@ def SCFForLoopRangeFolding : Pass<"scf-for-loop-range-folding"> {
120120
let constructor = "mlir::createForLoopRangeFoldingPass()";
121121
}
122122

123+
def SCFForallToForLoop : Pass<"scf-forall-to-for"> {
124+
let summary = "Convert SCF forall loops to SCF for loops";
125+
let constructor = "mlir::createForallToForLoopPass()";
126+
}
127+
123128
def SCFForToWhileLoop : Pass<"scf-for-to-while"> {
124129
let summary = "Convert SCF for loops to SCF while loops";
125130
let constructor = "mlir::createForToWhileLoopPass()";

mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,15 @@ class Value;
2828
namespace scf {
2929

3030
class IfOp;
31+
class ForallOp;
3132
class ForOp;
3233
class ParallelOp;
3334
class WhileOp;
3435

36+
/// Try converting scf.forall into a set of nested scf.for loops.
37+
LogicalResult forallToForLoop(RewriterBase &rewriter, ForallOp forallOp,
38+
SmallVector<Operation *> *results);
39+
3540
/// Fuses all adjacent scf.parallel operations with identical bounds and step
3641
/// into one scf.parallel operations. Uses a naive aliasing and dependency
3742
/// analysis.

mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp

Lines changed: 8 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -69,16 +69,7 @@ transform::ForallToForOp::apply(transform::TransformRewriter &rewriter,
6969
return diag;
7070
}
7171

72-
rewriter.setInsertionPoint(target);
73-
74-
if (!target.getOutputs().empty()) {
75-
return emitSilenceableError()
76-
<< "unsupported shared outputs (didn't bufferize?)";
77-
}
78-
7972
SmallVector<OpFoldResult> lbs = target.getMixedLowerBound();
80-
SmallVector<OpFoldResult> ubs = target.getMixedUpperBound();
81-
SmallVector<OpFoldResult> steps = target.getMixedStep();
8273

8374
if (getNumResults() != lbs.size()) {
8475
DiagnosedSilenceableFailure diag =
@@ -89,28 +80,15 @@ transform::ForallToForOp::apply(transform::TransformRewriter &rewriter,
8980
return diag;
9081
}
9182

92-
auto loc = target.getLoc();
93-
SmallVector<Value> ivs;
94-
for (auto &&[lb, ub, step] : llvm::zip(lbs, ubs, steps)) {
95-
Value lbValue = getValueOrCreateConstantIndexOp(rewriter, loc, lb);
96-
Value ubValue = getValueOrCreateConstantIndexOp(rewriter, loc, ub);
97-
Value stepValue = getValueOrCreateConstantIndexOp(rewriter, loc, step);
98-
auto loop = rewriter.create<scf::ForOp>(
99-
loc, lbValue, ubValue, stepValue, ValueRange(),
100-
[](OpBuilder &, Location, Value, ValueRange) {});
101-
ivs.push_back(loop.getInductionVar());
102-
rewriter.setInsertionPointToStart(loop.getBody());
103-
rewriter.create<scf::YieldOp>(loc);
104-
rewriter.setInsertionPointToStart(loop.getBody());
83+
SmallVector<Operation *> opResults;
84+
if (failed(scf::forallToForLoop(rewriter, target, &opResults))) {
85+
DiagnosedSilenceableFailure diag = emitSilenceableError()
86+
<< "failed to convert forall into for";
87+
return diag;
10588
}
106-
rewriter.eraseOp(target.getBody()->getTerminator());
107-
rewriter.inlineBlockBefore(target.getBody(), &*rewriter.getInsertionPoint(),
108-
ivs);
109-
rewriter.eraseOp(target);
110-
111-
for (auto &&[i, iv] : llvm::enumerate(ivs)) {
112-
results.set(cast<OpResult>(getTransformed()[i]),
113-
{iv.getParentBlock()->getParentOp()});
89+
90+
for (auto [i, res] : llvm::enumerate(opResults)) {
91+
results.set(cast<OpResult>(getTransformed()[i]), {res});
11492
}
11593
return DiagnosedSilenceableFailure::success();
11694
}

mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRSCFTransforms
22
BufferDeallocationOpInterfaceImpl.cpp
33
BufferizableOpInterfaceImpl.cpp
44
Bufferize.cpp
5+
ForallToFor.cpp
56
ForToWhile.cpp
67
LoopCanonicalization.cpp
78
LoopPipelining.cpp
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
//===- ForallToFor.cpp - scf.forall to scf.for loop conversion ------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Transforms SCF.ForallOp's into SCF.ForOp's.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "mlir/Dialect/SCF/Transforms/Passes.h"
14+
15+
#include "mlir/Dialect/SCF/IR/SCF.h"
16+
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
17+
#include "mlir/IR/PatternMatch.h"
18+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
19+
20+
namespace mlir {
21+
#define GEN_PASS_DEF_SCFFORALLTOFORLOOP
22+
#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
23+
} // namespace mlir
24+
25+
using namespace llvm;
26+
using namespace mlir;
27+
using scf::ForallOp;
28+
using scf::ForOp;
29+
30+
LogicalResult
31+
mlir::scf::forallToForLoop(RewriterBase &rewriter, scf::ForallOp forallOp,
32+
SmallVector<Operation *> *results = nullptr) {
33+
rewriter.setInsertionPoint(forallOp);
34+
35+
if (!forallOp.getOutputs().empty()) {
36+
return forallOp.emitOpError()
37+
<< "unsupported shared outputs (didn't bufferize?)";
38+
}
39+
40+
SmallVector<OpFoldResult> lbs = forallOp.getMixedLowerBound();
41+
SmallVector<OpFoldResult> ubs = forallOp.getMixedUpperBound();
42+
SmallVector<OpFoldResult> steps = forallOp.getMixedStep();
43+
44+
auto loc = forallOp.getLoc();
45+
SmallVector<Value> ivs;
46+
for (auto &&[lb, ub, step] : llvm::zip(lbs, ubs, steps)) {
47+
Value lbValue = getValueOrCreateConstantIndexOp(rewriter, loc, lb);
48+
Value ubValue = getValueOrCreateConstantIndexOp(rewriter, loc, ub);
49+
Value stepValue = getValueOrCreateConstantIndexOp(rewriter, loc, step);
50+
auto loop =
51+
rewriter.create<ForOp>(loc, lbValue, ubValue, stepValue, ValueRange(),
52+
[](OpBuilder &, Location, Value, ValueRange) {});
53+
if (results)
54+
results->push_back(loop);
55+
ivs.push_back(loop.getInductionVar());
56+
rewriter.setInsertionPointToStart(loop.getBody());
57+
rewriter.create<scf::YieldOp>(loc);
58+
rewriter.setInsertionPointToStart(loop.getBody());
59+
}
60+
rewriter.eraseOp(forallOp.getBody()->getTerminator());
61+
rewriter.inlineBlockBefore(forallOp.getBody(), &*rewriter.getInsertionPoint(),
62+
ivs);
63+
rewriter.eraseOp(forallOp);
64+
return success();
65+
}
66+
67+
namespace {
68+
struct ForallToForLoopLoweringPattern : public OpRewritePattern<ForallOp> {
69+
using OpRewritePattern<ForallOp>::OpRewritePattern;
70+
71+
LogicalResult matchAndRewrite(ForallOp forallOp,
72+
PatternRewriter &rewriter) const override {
73+
if (failed(scf::forallToForLoop(rewriter, forallOp)))
74+
return failure();
75+
return success();
76+
}
77+
};
78+
79+
struct ForallToForLoop : public impl::SCFForallToForLoopBase<ForallToForLoop> {
80+
void runOnOperation() override {
81+
auto *parentOp = getOperation();
82+
MLIRContext *ctx = parentOp->getContext();
83+
RewritePatternSet patterns(ctx);
84+
patterns.add<ForallToForLoopLoweringPattern>(ctx);
85+
(void)applyPatternsAndFoldGreedily(parentOp, std::move(patterns));
86+
}
87+
};
88+
} // namespace
89+
90+
std::unique_ptr<Pass> mlir::createForallToForLoopPass() {
91+
return std::make_unique<ForallToForLoop>();
92+
}
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(scf-forall-to-for))' -split-input-file | FileCheck %s
2+
3+
func.func private @callee(%i: index, %j: index)
4+
5+
// CHECK-LABEL: @two_iters
6+
// CHECK-SAME: %[[UB1:.+]]: index, %[[UB2:.+]]: index
7+
func.func @two_iters(%ub1: index, %ub2: index) {
8+
scf.forall (%i, %j) in (%ub1, %ub2) {
9+
func.call @callee(%i, %j) : (index, index) -> ()
10+
}
11+
// CHECK: scf.for %[[IV1:.+]] = %{{.*}} to %[[UB1]]
12+
// CHECK: scf.for %[[IV2:.+]] = %{{.*}} to %[[UB2]]
13+
// CHECK: func.call @callee(%[[IV1]], %[[IV2]])
14+
return
15+
}
16+
17+
// -----
18+
19+
func.func private @callee(%i: index, %j: index)
20+
21+
// CHECK-LABEL: @repeated
22+
// CHECK-SAME: %[[UB1:.+]]: index, %[[UB2:.+]]: index
23+
func.func @repeated(%ub1: index, %ub2: index) {
24+
scf.forall (%i, %j) in (%ub1, %ub2) {
25+
func.call @callee(%i, %j) : (index, index) -> ()
26+
}
27+
// CHECK: scf.for %[[IV1:.+]] = %{{.*}} to %[[UB1]]
28+
// CHECK: scf.for %[[IV2:.+]] = %{{.*}} to %[[UB2]]
29+
// CHECK: func.call @callee(%[[IV1]], %[[IV2]])
30+
scf.forall (%i, %j) in (%ub1, %ub2) {
31+
func.call @callee(%i, %j) : (index, index) -> ()
32+
}
33+
// CHECK: scf.for %[[IV1:.+]] = %{{.*}} to %[[UB1]]
34+
// CHECK: scf.for %[[IV2:.+]] = %{{.*}} to %[[UB2]]
35+
// CHECK: func.call @callee(%[[IV1]], %[[IV2]])
36+
return
37+
}
38+
39+
// -----
40+
41+
func.func private @callee(%i: index, %j: index, %k: index, %l: index)
42+
43+
// CHECK-LABEL: @nested
44+
// CHECK-SAME: %[[UB1:.+]]: index, %[[UB2:.+]]: index, %[[UB3:.+]]: index, %[[UB4:.+]]: index
45+
func.func @nested(%ub1: index, %ub2: index, %ub3: index, %ub4: index) {
46+
// CHECK: scf.for %[[IV1:.+]] = %{{.*}} to %[[UB1]]
47+
// CHECK: scf.for %[[IV2:.+]] = %{{.*}} to %[[UB2]]
48+
// CHECK: scf.for %[[IV3:.+]] = %{{.*}} to %[[UB3]]
49+
// CHECK: scf.for %[[IV4:.+]] = %{{.*}} to %[[UB4]]
50+
// CHECK: func.call @callee(%[[IV1]], %[[IV2]], %[[IV3]], %[[IV4]])
51+
scf.forall (%i, %j) in (%ub1, %ub2) {
52+
scf.forall (%k, %l) in (%ub3, %ub4) {
53+
func.call @callee(%i, %j, %k, %l) : (index, index, index, index) -> ()
54+
}
55+
}
56+
return
57+
}

0 commit comments

Comments
 (0)