Skip to content

Commit 4fbb5f9

Browse files
authored
[mlir] introduce transform.loop.forall_to_for (#65474)
Add a straightforward sequentialization transform from `scf.forall` to a nest of `scf.for` in absence of results and expose it as a transform op. This is helpful in combination with other transform ops, particularly fusion, that work best on parallel-by-construction `scf.forall` but later need to target sequential `for` loops.
1 parent f557986 commit 4fbb5f9

File tree

4 files changed

+174
-0
lines changed

4 files changed

+174
-0
lines changed

mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ namespace func {
2020
class FuncOp;
2121
} // namespace func
2222
namespace scf {
23+
class ForallOp;
2324
class ForOp;
2425
class IfOp;
2526
} // namespace scf

mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,34 @@ def ApplySCFStructuralConversionPatternsOp : Op<Transform_Dialect,
4040

4141
def Transform_ScfForOp : Transform_ConcreteOpType<"scf.for">;
4242

43+
def ForallToForOp : Op<Transform_Dialect, "loop.forall_to_for",
44+
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
45+
DeclareOpInterfaceMethods<TransformOpInterface>]> {
46+
let summary = "Converts scf.forall into a nest of scf.for operations";
47+
let description = [{
48+
Converts the `scf.forall` operation pointed to by the given handle into a
49+
set of nested `scf.for` operations. Each new operation corresponds to one
50+
induction variable of the original "multifor" loop.
51+
52+
The operand handle must be associated with exactly one payload operation.
53+
54+
Loops with shared outputs are currently not supported.
55+
56+
#### Return Modes
57+
58+
Consumes the operand handle. Produces a silenceable failure if the operand
59+
is not associated with a single `scf.forall` payload operation.
60+
Returns as many handles as the given `forall` op has induction variables
61+
that are associated with the generated `scf.for` loops.
62+
Produces a silenceable failure if another number of resulting handles is
63+
requested.
64+
}];
65+
let arguments = (ins TransformHandleTypeInterface:$target);
66+
let results = (outs Variadic<TransformHandleTypeInterface>:$transformed);
67+
68+
let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)";
69+
}
70+
4371
def GetParentForOp : Op<Transform_Dialect, "loop.get_parent_for",
4472
[NavigationTransformOpTrait, MemoryEffectsOpInterface,
4573
DeclareOpInterfaceMethods<TransformOpInterface>]> {

mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h"
1010
#include "mlir/Dialect/Affine/IR/AffineOps.h"
1111
#include "mlir/Dialect/Affine/LoopUtils.h"
12+
#include "mlir/Dialect/Arith/IR/Arith.h"
13+
#include "mlir/Dialect/Arith/Utils/Utils.h"
1214
#include "mlir/Dialect/Func/IR/FuncOps.h"
1315
#include "mlir/Dialect/SCF/IR/SCF.h"
1416
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
@@ -17,8 +19,11 @@
1719
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
1820
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
1921
#include "mlir/Dialect/Transform/IR/TransformOps.h"
22+
#include "mlir/Dialect/Utils/StaticValueUtils.h"
2023
#include "mlir/Dialect/Vector/IR/VectorOps.h"
24+
#include "mlir/IR/BuiltinAttributes.h"
2125
#include "mlir/IR/Dominance.h"
26+
#include "mlir/IR/OpDefinition.h"
2227

2328
using namespace mlir;
2429
using namespace mlir::affine;
@@ -47,6 +52,7 @@ void transform::ApplySCFStructuralConversionPatternsOp::
4752
//===----------------------------------------------------------------------===//
4853
// GetParentForOp
4954
//===----------------------------------------------------------------------===//
55+
5056
DiagnosedSilenceableFailure
5157
transform::GetParentForOp::apply(transform::TransformRewriter &rewriter,
5258
transform::TransformResults &results,
@@ -76,6 +82,72 @@ transform::GetParentForOp::apply(transform::TransformRewriter &rewriter,
7682
return DiagnosedSilenceableFailure::success();
7783
}
7884

85+
//===----------------------------------------------------------------------===//
86+
// ForallToForOp
87+
//===----------------------------------------------------------------------===//
88+
89+
DiagnosedSilenceableFailure
90+
transform::ForallToForOp::apply(transform::TransformRewriter &rewriter,
91+
transform::TransformResults &results,
92+
transform::TransformState &state) {
93+
auto payload = state.getPayloadOps(getTarget());
94+
if (!llvm::hasSingleElement(payload))
95+
return emitSilenceableError() << "expected a single payload op";
96+
97+
auto target = dyn_cast<scf::ForallOp>(*payload.begin());
98+
if (!target) {
99+
DiagnosedSilenceableFailure diag =
100+
emitSilenceableError() << "expected the payload to be scf.forall";
101+
diag.attachNote((*payload.begin())->getLoc()) << "payload op";
102+
return diag;
103+
}
104+
105+
rewriter.setInsertionPoint(target);
106+
107+
if (!target.getOutputs().empty()) {
108+
return emitSilenceableError()
109+
<< "unsupported shared outputs (didn't bufferize?)";
110+
}
111+
112+
SmallVector<OpFoldResult> lbs = target.getMixedLowerBound();
113+
SmallVector<OpFoldResult> ubs = target.getMixedUpperBound();
114+
SmallVector<OpFoldResult> steps = target.getMixedStep();
115+
116+
if (getNumResults() != lbs.size()) {
117+
DiagnosedSilenceableFailure diag =
118+
emitSilenceableError()
119+
<< "op expects as many results (" << getNumResults()
120+
<< ") as payload has induction variables (" << lbs.size() << ")";
121+
diag.attachNote(target.getLoc()) << "payload op";
122+
return diag;
123+
}
124+
125+
auto loc = target.getLoc();
126+
SmallVector<Value> ivs;
127+
for (auto &&[lb, ub, step] : llvm::zip(lbs, ubs, steps)) {
128+
Value lbValue = getValueOrCreateConstantIndexOp(rewriter, loc, lb);
129+
Value ubValue = getValueOrCreateConstantIndexOp(rewriter, loc, ub);
130+
Value stepValue = getValueOrCreateConstantIndexOp(rewriter, loc, step);
131+
auto loop = rewriter.create<scf::ForOp>(
132+
loc, lbValue, ubValue, stepValue, ValueRange(),
133+
[](OpBuilder &, Location, Value, ValueRange) {});
134+
ivs.push_back(loop.getInductionVar());
135+
rewriter.setInsertionPointToStart(loop.getBody());
136+
rewriter.create<scf::YieldOp>(loc);
137+
rewriter.setInsertionPointToStart(loop.getBody());
138+
}
139+
rewriter.eraseOp(target.getBody()->getTerminator());
140+
rewriter.inlineBlockBefore(target.getBody(), &*rewriter.getInsertionPoint(),
141+
ivs);
142+
rewriter.eraseOp(target);
143+
144+
for (auto &&[i, iv] : llvm::enumerate(ivs)) {
145+
results.set(cast<OpResult>(getTransformed()[i]),
146+
{iv.getParentBlock()->getParentOp()});
147+
}
148+
return DiagnosedSilenceableFailure::success();
149+
}
150+
79151
//===----------------------------------------------------------------------===//
80152
// LoopOutlineOp
81153
//===----------------------------------------------------------------------===//
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file --verify-diagnostics | FileCheck %s
2+
3+
func.func private @callee(%i: index, %j: index)
4+
5+
// CHECK-LABEL: @two_iters
6+
// CHECK-SAME: %[[UB1:.+]]: index, %[[UB2:.+]]: index
7+
func.func @two_iters(%ub1: index, %ub2: index) {
8+
scf.forall (%i, %j) in (%ub1, %ub2) {
9+
func.call @callee(%i, %j) : (index, index) -> ()
10+
}
11+
// CHECK: scf.for %[[IV1:.+]] = %{{.*}} to %[[UB1]]
12+
// CHECK: scf.for %[[IV2:.+]] = %{{.*}} to %[[UB2]]
13+
// CHECK: func.call @callee(%[[IV1]], %[[IV2]])
14+
return
15+
}
16+
17+
transform.sequence failures(propagate) {
18+
^bb0(%arg0: !transform.any_op):
19+
%0 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
20+
transform.loop.forall_to_for %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
21+
}
22+
23+
// -----
24+
25+
func.func private @callee(%i: index, %j: index)
26+
27+
func.func @repeated(%ub1: index, %ub2: index) {
28+
scf.forall (%i, %j) in (%ub1, %ub2) {
29+
func.call @callee(%i, %j) : (index, index) -> ()
30+
}
31+
scf.forall (%i, %j) in (%ub1, %ub2) {
32+
func.call @callee(%i, %j) : (index, index) -> ()
33+
}
34+
return
35+
}
36+
37+
transform.sequence failures(propagate) {
38+
^bb0(%arg0: !transform.any_op):
39+
%0 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
40+
// expected-error @below {{expected a single payload op}}
41+
transform.loop.forall_to_for %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
42+
}
43+
44+
// -----
45+
46+
func.func private @callee(%i: index, %j: index)
47+
48+
func.func @repeated(%ub1: index, %ub2: index) {
49+
// expected-note @below {{payload op}}
50+
scf.forall (%i, %j) in (%ub1, %ub2) {
51+
func.call @callee(%i, %j) : (index, index) -> ()
52+
}
53+
return
54+
}
55+
56+
transform.sequence failures(propagate) {
57+
^bb0(%arg0: !transform.any_op):
58+
%0 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
59+
// expected-error @below {{op expects as many results (1) as payload has induction variables (2)}}
60+
transform.loop.forall_to_for %0 : (!transform.any_op) -> !transform.any_op
61+
}
62+
63+
// -----
64+
65+
// expected-note @below {{payload op}}
66+
func.func private @callee(%i: index, %j: index)
67+
68+
transform.sequence failures(propagate) {
69+
^bb0(%arg0: !transform.any_op):
70+
%0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
71+
// expected-error @below {{expected the payload to be scf.forall}}
72+
transform.loop.forall_to_for %0 : (!transform.any_op) -> !transform.any_op
73+
}

0 commit comments

Comments
 (0)