-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[mlir][scf] Add reductions support to scf.parallel fusion
#75955
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
2fd5a4d
d858a4a
47ec48e
65a3b05
083707e
8f7b4a4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -161,29 +161,85 @@ static bool isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop, | |
} | ||
|
||
/// Prepends operations of firstPloop's body into secondPloop's body. | ||
static void fuseIfLegal(ParallelOp firstPloop, ParallelOp secondPloop, | ||
OpBuilder b, | ||
/// Updates secondPloop with new loop. | ||
static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop, | ||
OpBuilder builder, | ||
llvm::function_ref<bool(Value, Value)> mayAlias) { | ||
Block *block1 = firstPloop.getBody(); | ||
Block *block2 = secondPloop.getBody(); | ||
IRMapping firstToSecondPloopIndices; | ||
firstToSecondPloopIndices.map(firstPloop.getBody()->getArguments(), | ||
secondPloop.getBody()->getArguments()); | ||
firstToSecondPloopIndices.map(block1->getArguments(), block2->getArguments()); | ||
|
||
if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices, | ||
mayAlias)) | ||
return; | ||
|
||
b.setInsertionPointToStart(secondPloop.getBody()); | ||
for (auto &op : firstPloop.getBody()->without_terminator()) | ||
b.clone(op, firstToSecondPloopIndices); | ||
DominanceInfo dom; | ||
// We are fusing first loop into second, make sure there are no users of the | ||
// first loop results between loops. | ||
for (Operation *user : firstPloop->getUsers()) | ||
if (!dom.properlyDominates(secondPloop, user, /*enclosingOpOk*/ false)) | ||
return; | ||
|
||
ValueRange inits1 = firstPloop.getInitVals(); | ||
ValueRange inits2 = secondPloop.getInitVals(); | ||
|
||
SmallVector<Value> newInitVars(inits1.begin(), inits1.end()); | ||
newInitVars.append(inits2.begin(), inits2.end()); | ||
|
||
IRRewriter b(builder); | ||
b.setInsertionPoint(secondPloop); | ||
auto newSecondPloop = b.create<ParallelOp>( | ||
secondPloop.getLoc(), secondPloop.getLowerBound(), | ||
secondPloop.getUpperBound(), secondPloop.getStep(), newInitVars); | ||
|
||
Block *newBlock = newSecondPloop.getBody(); | ||
auto term1 = cast<ReduceOp>(block1->getTerminator()); | ||
auto term2 = cast<ReduceOp>(block2->getTerminator()); | ||
|
||
b.inlineBlockBefore(block2, newBlock, newBlock->begin(), | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
What's the difference between: These seem the same to me?
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
We need to change the results count when merging.
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Can you not just re-generate the
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Each reduction corresponds to the parent |
||
newBlock->getArguments()); | ||
b.inlineBlockBefore(block1, newBlock, newBlock->begin(), | ||
newBlock->getArguments()); | ||
|
||
ValueRange results = newSecondPloop.getResults(); | ||
if (!results.empty()) { | ||
b.setInsertionPointToEnd(newBlock); | ||
|
||
ValueRange reduceArgs1 = term1.getOperands(); | ||
ValueRange reduceArgs2 = term2.getOperands(); | ||
SmallVector<Value> newReduceArgs(reduceArgs1.begin(), reduceArgs1.end()); | ||
newReduceArgs.append(reduceArgs2.begin(), reduceArgs2.end()); | ||
|
||
auto newReduceOp = b.create<scf::ReduceOp>(term2.getLoc(), newReduceArgs); | ||
|
||
for (auto &&[i, reg] : llvm::enumerate(llvm::concat<Region>( | ||
term1.getReductions(), term2.getReductions()))) { | ||
Block &oldRedBlock = reg.front(); | ||
Block &newRedBlock = newReduceOp.getReductions()[i].front(); | ||
b.inlineBlockBefore(&oldRedBlock, &newRedBlock, newRedBlock.begin(), | ||
newRedBlock.getArguments()); | ||
} | ||
|
||
firstPloop.replaceAllUsesWith(results.take_front(inits1.size())); | ||
secondPloop.replaceAllUsesWith(results.take_back(inits2.size())); | ||
} | ||
term1->erase(); | ||
term2->erase(); | ||
firstPloop.erase(); | ||
secondPloop.erase(); | ||
secondPloop = newSecondPloop; | ||
} | ||
|
||
void mlir::scf::naivelyFuseParallelOps( | ||
Region &region, llvm::function_ref<bool(Value, Value)> mayAlias) { | ||
OpBuilder b(region); | ||
// Consider every single block and attempt to fuse adjacent loops. | ||
SmallVector<SmallVector<ParallelOp>, 1> ploopChains; | ||
for (auto &block : region) { | ||
SmallVector<SmallVector<ParallelOp, 8>, 1> ploopChains{{}}; | ||
ploopChains.clear(); | ||
rengolin marked this conversation as resolved.
Show resolved
Hide resolved
|
||
ploopChains.push_back({}); | ||
|
||
// Not using `walk()` to traverse only top-level parallel loops and also | ||
// make sure that there are no side-effecting ops between the parallel | ||
// loops. | ||
|
@@ -201,7 +257,7 @@ void mlir::scf::naivelyFuseParallelOps( | |
// TODO: Handle region side effects properly. | ||
noSideEffects &= isMemoryEffectFree(&op) && op.getNumRegions() == 0; | ||
} | ||
for (ArrayRef<ParallelOp> ploops : ploopChains) { | ||
for (MutableArrayRef<ParallelOp> ploops : ploopChains) { | ||
for (int i = 0, e = ploops.size(); i + 1 < e; ++i) | ||
fuseIfLegal(ploops[i], ploops[i + 1], b, mayAlias); | ||
} | ||
|
Uh oh!
There was an error while loading. Please reload this page.