@@ -60,23 +60,30 @@ void ExternalNameConversionPass::runOnOperation() {
6060
6161 llvm::DenseMap<mlir::StringAttr, mlir::FlatSymbolRefAttr> remappings;
6262
63+ auto processFctOrGlobal = [&](mlir::Operation &funcOrGlobal) {
64+ auto symName = funcOrGlobal.getAttrOfType <mlir::StringAttr>(
65+ mlir::SymbolTable::getSymbolAttrName ());
66+ auto deconstructedName = fir::NameUniquer::deconstruct (symName);
67+ if (fir::NameUniquer::isExternalFacingUniquedName (deconstructedName)) {
68+ auto newName = mangleExternalName (deconstructedName, appendUnderscoreOpt);
69+ auto newAttr = mlir::StringAttr::get (context, newName);
70+ mlir::SymbolTable::setSymbolName (&funcOrGlobal, newAttr);
71+ auto newSymRef = mlir::FlatSymbolRefAttr::get (newAttr);
72+ remappings.try_emplace (symName, newSymRef);
73+ if (llvm::isa<mlir::func::FuncOp>(funcOrGlobal))
74+ funcOrGlobal.setAttr (fir::getInternalFuncNameAttrName (), symName);
75+ }
76+ };
77+
6378 auto renameFuncOrGlobalInModule = [&](mlir::Operation *module ) {
64- for (auto &funcOrGlobal : module ->getRegion (0 ).front ()) {
65- if (llvm::isa<mlir::func::FuncOp>(funcOrGlobal) ||
66- llvm::isa<fir::GlobalOp>(funcOrGlobal)) {
67- auto symName = funcOrGlobal.getAttrOfType <mlir::StringAttr>(
68- mlir::SymbolTable::getSymbolAttrName ());
69- auto deconstructedName = fir::NameUniquer::deconstruct (symName);
70- if (fir::NameUniquer::isExternalFacingUniquedName (deconstructedName)) {
71- auto newName =
72- mangleExternalName (deconstructedName, appendUnderscoreOpt);
73- auto newAttr = mlir::StringAttr::get (context, newName);
74- mlir::SymbolTable::setSymbolName (&funcOrGlobal, newAttr);
75- auto newSymRef = mlir::FlatSymbolRefAttr::get (newAttr);
76- remappings.try_emplace (symName, newSymRef);
77- if (llvm::isa<mlir::func::FuncOp>(funcOrGlobal))
78- funcOrGlobal.setAttr (fir::getInternalFuncNameAttrName (), symName);
79- }
79+ for (auto &op : module ->getRegion (0 ).front ()) {
80+ if (mlir::isa<mlir::func::FuncOp, fir::GlobalOp>(op)) {
81+ processFctOrGlobal (op);
82+ } else if (auto gpuMod = mlir::dyn_cast<mlir::gpu::GPUModuleOp>(op)) {
83+ for (auto &gpuOp : gpuMod.getBodyRegion ().front ())
84+ if (mlir::isa<mlir::func::FuncOp, fir::GlobalOp,
85+ mlir::gpu::GPUFuncOp>(gpuOp))
86+ processFctOrGlobal (gpuOp);
8087 }
8188 }
8289 };
@@ -85,23 +92,25 @@ void ExternalNameConversionPass::runOnOperation() {
8592 // globals.
8693 renameFuncOrGlobalInModule (op);
8794
88- // Do the same in GPU modules.
89- if (auto mod = mlir::dyn_cast_or_null<mlir::ModuleOp>(*op))
90- for (auto gpuMod : mod.getOps <mlir::gpu::GPUModuleOp>())
91- renameFuncOrGlobalInModule (gpuMod);
92-
9395 if (remappings.empty ())
9496 return ;
9597
9698 // Update all uses of the functions and globals that have been renamed.
9799 op.walk ([&remappings](mlir::Operation *nestedOp) {
98100 llvm::SmallVector<std::pair<mlir::StringAttr, mlir::SymbolRefAttr>> updates;
99101 for (const mlir::NamedAttribute &attr : nestedOp->getAttrDictionary ())
100- if (auto symRef = llvm::dyn_cast<mlir::SymbolRefAttr>(attr.getValue ()))
101- if (auto remap = remappings.find (symRef.getRootReference ());
102- remap != remappings.end ())
102+ if (auto symRef = llvm::dyn_cast<mlir::SymbolRefAttr>(attr.getValue ())) {
103+ if (auto remap = remappings.find (symRef.getLeafReference ());
104+ remap != remappings.end ()) {
105+ mlir::SymbolRefAttr symAttr = mlir::FlatSymbolRefAttr (remap->second );
106+ if (mlir::isa<mlir::gpu::LaunchFuncOp>(nestedOp))
107+ symAttr = mlir::SymbolRefAttr::get (
108+ symRef.getRootReference (),
109+ {mlir::FlatSymbolRefAttr (remap->second )});
103110 updates.emplace_back (std::pair<mlir::StringAttr, mlir::SymbolRefAttr>{
104- attr.getName (), mlir::SymbolRefAttr (remap->second )});
111+ attr.getName (), symAttr});
112+ }
113+ }
105114 for (auto update : updates)
106115 nestedOp->setAttr (update.first , update.second );
107116 });
0 commit comments