Skip to content

Commit cbd4750

Browse files
committed
[mlir][mlprogram] Add mlprogram-pipeline-globals optimization pass
Added pass optimizes MLProgram global operations by reducing to only the minimal load/store operations for global tensors. This avoids unnecessary global operations throughout a program and potentially improves operation fusion. Reviewed By: jpienaar Differential Revision: https://reviews.llvm.org/D159228
1 parent b2ef297 commit cbd4750

File tree

10 files changed

+582
-4
lines changed

10 files changed

+582
-4
lines changed
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
add_subdirectory(IR)
2+
add_subdirectory(Transforms)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
set(LLVM_TARGET_DEFINITIONS Passes.td)
2+
mlir_tablegen(Passes.h.inc -gen-pass-decls -name MLProgram)
3+
add_public_tablegen_target(MLIRMLProgramPassIncGen)
4+
add_dependencies(mlir-headers MLIRMLProgramPassIncGen)
5+
6+
add_mlir_doc(Passes MLProgramPasses ./ -gen-pass-doc)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
//===- Passes.h - Pass Entrypoints ------------------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_MLPROGRAM_TRANSFORMS_PASSES_H_
10+
#define MLIR_DIALECT_MLPROGRAM_TRANSFORMS_PASSES_H_
11+
12+
#include "mlir/Dialect/Func/IR/FuncOps.h"
13+
#include "mlir/IR/BuiltinOps.h"
14+
#include "mlir/Pass/Pass.h"
15+
16+
namespace mlir {
17+
namespace ml_program {
18+
19+
#define GEN_PASS_DECL
20+
#include "mlir/Dialect/MLProgram/Transforms/Passes.h.inc"
21+
22+
//===----------------------------------------------------------------------===//
23+
// Registration
24+
//===----------------------------------------------------------------------===//
25+
26+
std::unique_ptr<OperationPass<ModuleOp>> createMLProgramPipelineGlobalsPass();
27+
28+
/// Generate the code for registering passes.
29+
#define GEN_PASS_REGISTRATION
30+
#include "mlir/Dialect/MLProgram/Transforms/Passes.h.inc"
31+
32+
} // namespace ml_program
33+
} // namespace mlir
34+
35+
#endif // MLIR_DIALECT_MLPROGRAM_TRANSFORMS_PASSES_H_
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
//===-- Passes.td - pass definition file -------------------*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_MLPROGRAM_TRANSFORMS_PASSES
10+
#define MLIR_DIALECT_MLPROGRAM_TRANSFORMS_PASSES
11+
12+
include "mlir/Pass/PassBase.td"
13+
14+
def MLProgramPipelineGlobals : Pass<"mlprogram-pipeline-globals", "ModuleOp"> {
15+
let summary = "Optimize `ml_program` global operations for read and store";
16+
let description = [{
17+
`ml_program`'s load and store operations can be optimized for
18+
write-write or write-read sets of operations. This allows known
19+
tensors to not be re-read when the value is already known in IR.
20+
21+
The pass is designed to handle both nested regions and function calls
22+
safely.
23+
}];
24+
let constructor = "mlir::ml_program::createMLProgramPipelineGlobalsPass()";
25+
}
26+
27+
#endif // MLIR_DIALECT_MLPROGRAM_TRANSFORMS_PASSES

mlir/include/mlir/InitAllPasses.h

+2
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include "mlir/Dialect/GPU/Transforms/Passes.h"
2727
#include "mlir/Dialect/LLVMIR/Transforms/Passes.h"
2828
#include "mlir/Dialect/Linalg/Passes.h"
29+
#include "mlir/Dialect/MLProgram/Transforms/Passes.h"
2930
#include "mlir/Dialect/Math/Transforms/Passes.h"
3031
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
3132
#include "mlir/Dialect/NVGPU/Transforms/Passes.h"
@@ -72,6 +73,7 @@ inline void registerAllPasses() {
7273
LLVM::registerLLVMPasses();
7374
math::registerMathPasses();
7475
memref::registerMemRefPasses();
76+
ml_program::registerMLProgramPasses();
7577
registerSCFPasses();
7678
registerShapePasses();
7779
spirv::registerSPIRVPasses();
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
add_subdirectory(IR)
2+
add_subdirectory(Transforms)

mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp

+16-4
Original file line numberDiff line numberDiff line change
@@ -178,8 +178,14 @@ LogicalResult GlobalOp::verify() {
178178
//===----------------------------------------------------------------------===//
179179

180180
GlobalOp GlobalLoadOp::getGlobalOp(SymbolTableCollection &symbolTable) {
181-
return symbolTable.lookupNearestSymbolFrom<GlobalOp>(
182-
getOperation()->getParentOp(), getGlobalAttr());
181+
for (auto parent = getOperation()->getParentOp(); parent;
182+
parent = parent->getParentOp()) {
183+
if (auto nearest = symbolTable.lookupNearestSymbolFrom<GlobalOp>(
184+
parent, getGlobalAttr())) {
185+
return nearest;
186+
}
187+
}
188+
return {};
183189
}
184190

185191
LogicalResult
@@ -253,8 +259,14 @@ GlobalLoadGraphOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
253259
//===----------------------------------------------------------------------===//
254260

255261
GlobalOp GlobalStoreOp::getGlobalOp(SymbolTableCollection &symbolTable) {
256-
return symbolTable.lookupNearestSymbolFrom<GlobalOp>(
257-
getOperation()->getParentOp(), getGlobalAttr());
262+
for (auto parent = getOperation()->getParentOp(); parent;) {
263+
if (auto nearest = symbolTable.lookupNearestSymbolFrom<GlobalOp>(
264+
parent, getGlobalAttr())) {
265+
return nearest;
266+
}
267+
parent = parent->getParentOp();
268+
}
269+
return {};
258270
}
259271

260272
LogicalResult
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
add_mlir_dialect_library(MLIRMLProgramTransforms
2+
PipelineGlobalOps.cpp
3+
4+
ADDITIONAL_HEADER_DIRS
5+
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/MLProgram/Transforms
6+
7+
DEPENDS
8+
MLIRMLProgramPassIncGen
9+
10+
LINK_LIBS PUBLIC
11+
MLIRIR
12+
MLIRMLProgramDialect
13+
MLIRPass
14+
)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,234 @@
1+
//===- PipelineGlobalOps.cpp - Pipeline Global Ops Pass -------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/MLProgram/Transforms/Passes.h"
10+
11+
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
12+
#include "mlir/Dialect/MLProgram/Transforms/Passes.h"
13+
#include "mlir/IR/BuiltinOps.h"
14+
#include "mlir/Pass/Pass.h"
15+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
16+
17+
namespace mlir {
18+
namespace ml_program {
19+
#define GEN_PASS_DEF_MLPROGRAMPIPELINEGLOBALS
20+
#include "mlir/Dialect/MLProgram/Transforms/Passes.h.inc"
21+
22+
namespace {
23+
24+
class MLProgramPipelineGlobals
25+
: public impl::MLProgramPipelineGlobalsBase<MLProgramPipelineGlobals> {
26+
public:
27+
void runOnOperation() override;
28+
29+
private:
30+
LogicalResult buildGlobalMap(ModuleOp op);
31+
32+
void ProcessBlock(Block &block, llvm::DenseSet<SymbolRefAttr> &symbolLoad,
33+
llvm::DenseSet<SymbolRefAttr> &symbolStore);
34+
35+
llvm::DenseMap<SymbolRefAttr, llvm::DenseSet<SymbolRefAttr>> loadSymbolsMap;
36+
llvm::DenseMap<SymbolRefAttr, llvm::DenseSet<SymbolRefAttr>> storeSymbolsMap;
37+
};
38+
39+
// Traverses upwards searchign for the operation mapped by the symbol.
40+
static Operation *getFromSymbol(Operation *baseOp, SymbolRefAttr symbol) {
41+
for (auto op = baseOp; op; op = op->getParentOp()) {
42+
auto lookup = SymbolTable::lookupNearestSymbolFrom(op, symbol);
43+
if (lookup)
44+
return lookup;
45+
}
46+
return nullptr;
47+
}
48+
49+
// Builds map from a symbol to MLProgram global symbols loaded or stored
50+
// during processing.
51+
LogicalResult MLProgramPipelineGlobals::buildGlobalMap(ModuleOp module) {
52+
llvm::DenseMap<SymbolRefAttr, Operation *> callableMap;
53+
auto res = module->walk([&](Operation *op) {
54+
if (auto caller = mlir::dyn_cast<CallOpInterface>(op)) {
55+
auto callable = caller.getCallableForCallee();
56+
// For now we do not know how to handle Value based tracing, so fail.
57+
if (mlir::isa<Value>(callable)) {
58+
return WalkResult::interrupt();
59+
}
60+
61+
auto symbol = mlir::dyn_cast<SymbolRefAttr>(callable);
62+
auto func = getFromSymbol(op, symbol);
63+
callableMap[symbol] = func;
64+
}
65+
return WalkResult::advance();
66+
});
67+
68+
if (res.wasInterrupted()) {
69+
return failure();
70+
}
71+
72+
// First grab all symbols loaded or stored by each function. This
73+
// will not handle calls initially.
74+
llvm::DenseMap<SymbolRefAttr, llvm::DenseSet<SymbolRefAttr>> opLoadSymbols;
75+
llvm::DenseMap<SymbolRefAttr, llvm::DenseSet<SymbolRefAttr>> opStoreSymbols;
76+
for (auto callable : callableMap) {
77+
llvm::DenseSet<SymbolRefAttr> loadSymbols;
78+
llvm::DenseSet<SymbolRefAttr> storeSymbols;
79+
80+
callable.getSecond()->walk(
81+
[&](GlobalLoadOp op) { loadSymbols.insert(op.getGlobal()); });
82+
83+
callable.getSecond()->walk(
84+
[&](GlobalStoreOp op) { storeSymbols.insert(op.getGlobal()); });
85+
86+
opLoadSymbols[callable.getFirst()] = std::move(loadSymbols);
87+
opStoreSymbols[callable.getFirst()] = std::move(storeSymbols);
88+
}
89+
90+
// For each callable function we find each global loaded/stored within the
91+
// function or a nested called function. This includes recursion checking to
92+
// avoid infinitely recursing.
93+
for (auto callable : callableMap) {
94+
SymbolRefAttr thisSymbol = llvm::dyn_cast<SymbolRefAttr>(callable.first);
95+
llvm::SmallVector<SymbolRefAttr> work = {thisSymbol};
96+
llvm::DenseSet<SymbolRefAttr> visited = {thisSymbol};
97+
llvm::DenseSet<SymbolRefAttr> loadSymbols;
98+
llvm::DenseSet<SymbolRefAttr> storeSymbols;
99+
100+
for (size_t i = 0; i < work.size(); ++i) {
101+
callableMap[work[i]]->walk([&](CallOpInterface call) {
102+
auto symbol = dyn_cast<SymbolRefAttr>(call.getCallableForCallee());
103+
if (!visited.contains(symbol)) {
104+
visited.insert(symbol);
105+
work.push_back(symbol);
106+
}
107+
});
108+
109+
for (auto load : opLoadSymbols[work[i]])
110+
loadSymbols.insert(load);
111+
112+
for (auto store : opStoreSymbols[work[i]])
113+
storeSymbols.insert(store);
114+
}
115+
116+
loadSymbolsMap[thisSymbol] = std::move(loadSymbols);
117+
storeSymbolsMap[thisSymbol] = std::move(storeSymbols);
118+
}
119+
120+
return success();
121+
}
122+
123+
// Process each operation in the block deleting unneeded loads / stores,
124+
// recursing on subblocks and checking function calls.
125+
void MLProgramPipelineGlobals::ProcessBlock(
126+
Block &block, llvm::DenseSet<SymbolRefAttr> &symbolLoad,
127+
llvm::DenseSet<SymbolRefAttr> &symbolStore) {
128+
129+
llvm::DenseMap<SymbolRefAttr, Value> previousLoads;
130+
llvm::DenseMap<SymbolRefAttr, Operation *> previousStores;
131+
llvm::SmallVector<Operation *> toDelete;
132+
for (auto &op : block) {
133+
// If this is a global load, remap to a previous value if known
134+
// and delete this load. Remember that this value is the currently
135+
// known load.
136+
if (auto load = mlir::dyn_cast<GlobalLoadOp>(op)) {
137+
auto ref = load.getGlobal();
138+
symbolLoad.insert(ref);
139+
if (previousLoads.contains(ref)) {
140+
toDelete.push_back(&op);
141+
load.getResult().replaceAllUsesWith(previousLoads[ref]);
142+
} else {
143+
previousLoads[ref] = load.getResult();
144+
}
145+
continue;
146+
}
147+
148+
// Delete a previous store if it exists and is not needed, update
149+
// the most recent known value for this global ref.
150+
if (auto store = mlir::dyn_cast<GlobalStoreOp>(op)) {
151+
auto ref = store.getGlobal();
152+
symbolStore.insert(ref);
153+
if (previousStores.contains(ref)) {
154+
toDelete.push_back(previousStores.find(ref)->getSecond());
155+
}
156+
157+
previousLoads[ref] = store.getValue();
158+
previousStores[ref] = &op;
159+
continue;
160+
}
161+
162+
// If a function is called, clear known values for loads/stores used by
163+
// the function or its sub-functions.
164+
if (auto call = mlir::dyn_cast<CallOpInterface>(op)) {
165+
auto loadSymbols =
166+
loadSymbolsMap[dyn_cast<SymbolRefAttr>(call.getCallableForCallee())];
167+
auto storeSymbols =
168+
storeSymbolsMap[dyn_cast<SymbolRefAttr>(call.getCallableForCallee())];
169+
170+
for (auto sym : loadSymbols) {
171+
previousStores.erase(sym);
172+
}
173+
174+
for (auto sym : storeSymbols) {
175+
previousLoads.erase(sym);
176+
previousStores.erase(sym);
177+
}
178+
continue;
179+
}
180+
181+
// If the op has sub-regions, recurse inside. We make no guarantees whether
182+
// the recursion occurs.
183+
llvm::DenseSet<SymbolRefAttr> opSymbolLoad;
184+
llvm::DenseSet<SymbolRefAttr> opSymbolStore;
185+
for (auto &region : op.getRegions()) {
186+
for (auto &block : region) {
187+
ProcessBlock(block, opSymbolLoad, opSymbolStore);
188+
}
189+
}
190+
191+
// Update current state from the subblock.
192+
for (auto change : opSymbolLoad) {
193+
symbolLoad.insert(change);
194+
previousStores.erase(change);
195+
}
196+
197+
for (auto change : opSymbolStore) {
198+
symbolStore.insert(change);
199+
previousLoads.erase(change);
200+
previousStores.erase(change);
201+
}
202+
}
203+
204+
for (auto op : toDelete) {
205+
op->erase();
206+
}
207+
}
208+
209+
void MLProgramPipelineGlobals::runOnOperation() {
210+
auto targetOp = getOperation();
211+
if (failed(buildGlobalMap(targetOp))) {
212+
return;
213+
}
214+
215+
for (auto &funcOp : *targetOp.getBody()) {
216+
for (auto &region : funcOp.getRegions()) {
217+
for (auto &block : region.getBlocks()) {
218+
llvm::DenseSet<SymbolRefAttr> symbolsLoaded;
219+
llvm::DenseSet<SymbolRefAttr> symbolsStored;
220+
ProcessBlock(block, symbolsLoaded, symbolsStored);
221+
}
222+
}
223+
}
224+
}
225+
226+
} // namespace
227+
228+
std::unique_ptr<OperationPass<mlir::ModuleOp>>
229+
createMLProgramPipelineGlobalsPass() {
230+
return std::make_unique<MLProgramPipelineGlobals>();
231+
}
232+
233+
} // namespace ml_program
234+
} // namespace mlir

0 commit comments

Comments
 (0)