diff --git a/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h b/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h index e0bf560dbd98b..a1f64be57fa69 100644 --- a/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h +++ b/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h @@ -20,6 +20,7 @@ class IRBuilderBase; } namespace mlir { +class SymbolTable; namespace LLVM { class ModuleTranslation; } @@ -55,11 +56,13 @@ class TargetOptions { } CompilationTarget; /// Constructor initializing the toolkit path, the list of files to link to, - /// extra command line options & the compilation target. The default - /// compilation target is `binary`. + /// extra command line options, the compilation target and a callback for + /// obtaining the parent symbol table. The default compilation target is + /// `binOrFatbin`. TargetOptions(StringRef toolkitPath = {}, ArrayRef linkFiles = {}, StringRef cmdOptions = {}, - CompilationTarget compilationTarget = binOrFatbin); + CompilationTarget compilationTarget = binOrFatbin, + function_ref getSymbolTableCallback = {}); /// Returns the typeID. TypeID getTypeID() const; @@ -80,12 +83,20 @@ class TargetOptions { /// Returns the compilation target. CompilationTarget getCompilationTarget() const; + /// Returns the result of the `getSymbolTableCallback` callback or a nullptr + /// if no callback was provided. + /// Note: The callback itself can return nullptr. It is up to the target how + /// to react to getting a nullptr, e.g., emitting an error or constructing the + /// table. + SymbolTable *getSymbolTable() const; + protected: /// Derived classes must use this constructor to initialize `typeID` to the /// appropiate value: ie. `TargetOptions(TypeID::get())`. TargetOptions(TypeID typeID, StringRef toolkitPath = {}, ArrayRef linkFiles = {}, StringRef cmdOptions = {}, - CompilationTarget compilationTarget = binOrFatbin); + CompilationTarget compilationTarget = binOrFatbin, + function_ref getSymbolTableCallback = {}); /// Path to the target toolkit. std::string toolkitPath; @@ -100,6 +111,10 @@ class TargetOptions { /// Compilation process target representation. CompilationTarget compilationTarget; + /// Callback for obtaining the parent symbol table of all the GPU modules + /// being serialized. + function_ref getSymbolTableCallback; + private: TypeID typeID; }; diff --git a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td index ba8a6266604e4..0bfb275099205 100644 --- a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td @@ -64,9 +64,11 @@ def GpuModuleToBinaryPass with an object for every target. The `format` argument can have the following values: - 1. `offloading`, `llvm`: producing an offloading representation. - 2. `assembly`, `isa`: producing assembly code. - 3. `binary`, `bin`: producing binaries. + 1. `offloading`, `llvm`: produces an offloading representation. + 2. `assembly`, `isa`: produces assembly code. + 3. `binary`, `bin`: produces binaries. + 4. `fatbinary`, `fatbin`: produces fatbinaries. + 5. `binOrFatbin`: produces bins or fatbins, the target decides which. }]; let options = [ Option<"offloadingHandler", "handler", "Attribute", "nullptr", diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp index f417a083337fc..46fb1766bc405 100644 --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -1979,20 +1979,20 @@ gpu::SelectObjectAttr::verify(function_ref emitError, // GPU target options //===----------------------------------------------------------------------===// -TargetOptions::TargetOptions(StringRef toolkitPath, - ArrayRef linkFiles, - StringRef cmdOptions, - CompilationTarget compilationTarget) +TargetOptions::TargetOptions( + StringRef toolkitPath, ArrayRef linkFiles, + StringRef cmdOptions, CompilationTarget compilationTarget, + function_ref getSymbolTableCallback) : TargetOptions(TypeID::get(), toolkitPath, linkFiles, - cmdOptions, compilationTarget) {} + cmdOptions, compilationTarget, getSymbolTableCallback) {} -TargetOptions::TargetOptions(TypeID typeID, StringRef toolkitPath, - ArrayRef linkFiles, - StringRef cmdOptions, - CompilationTarget compilationTarget) +TargetOptions::TargetOptions( + TypeID typeID, StringRef toolkitPath, ArrayRef linkFiles, + StringRef cmdOptions, CompilationTarget compilationTarget, + function_ref getSymbolTableCallback) : toolkitPath(toolkitPath.str()), linkFiles(linkFiles), cmdOptions(cmdOptions.str()), compilationTarget(compilationTarget), - typeID(typeID) {} + getSymbolTableCallback(getSymbolTableCallback), typeID(typeID) {} TypeID TargetOptions::getTypeID() const { return typeID; } @@ -2002,6 +2002,10 @@ ArrayRef TargetOptions::getLinkFiles() const { return linkFiles; } StringRef TargetOptions::getCmdOptions() const { return cmdOptions; } +SymbolTable *TargetOptions::getSymbolTable() const { + return getSymbolTableCallback ? getSymbolTableCallback() : nullptr; +} + std::pair> TargetOptions::tokenizeCmdOptions() const { std::pair> options; diff --git a/mlir/lib/Dialect/GPU/Transforms/ModuleToBinary.cpp b/mlir/lib/Dialect/GPU/Transforms/ModuleToBinary.cpp index 06b7dee6941e1..e29a1f0c3248d 100644 --- a/mlir/lib/Dialect/GPU/Transforms/ModuleToBinary.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/ModuleToBinary.cpp @@ -66,9 +66,26 @@ void GpuModuleToBinaryPass::runOnOperation() { .Default(-1); if (targetFormat == -1) getOperation()->emitError() << "Invalid format specified."; + + // Lazy symbol table builder callback. + std::optional parentTable; + auto lazyTableBuilder = [&]() -> SymbolTable * { + // Build the table if it has not been built. + if (!parentTable) { + Operation *table = SymbolTable::getNearestSymbolTable(getOperation()); + // It's up to the target attribute to determine if failing to find a + // symbol table is an error. + if (!table) + return nullptr; + parentTable = SymbolTable(table); + } + return &parentTable.value(); + }; + TargetOptions targetOptions( toolkitPath, linkFiles, cmdOptions, - static_cast(targetFormat)); + static_cast(targetFormat), + lazyTableBuilder); if (failed(transformGpuModulesToBinaries( getOperation(), offloadingHandler ? dyn_cast(