-
Notifications
You must be signed in to change notification settings - Fork 15k
[mlir][gpu] GPUToROCDL/NVVM: use generic llvm conversion interface instead of hardcoded conversions. #124439
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Why is this a draft? Is more work needed? |
depends on #121440 |
No need to keep it as draft - these are basically stacked PRs. Add this info to the summary and a note that only the top commit should be reviewed. |
2497c43
to
1b3f7e4
Compare
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-gpu Author: Ivan Butygin (Hardcode84) ChangesUsing Add Full diff: https://github.com/llvm/llvm-project/pull/124439.diff 5 Files Affected:
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index ff79a1226c047bc..1873d95eed88f2a 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -550,14 +550,16 @@ def ConvertGpuOpsToNVVMOps : Pass<"convert-gpu-to-nvvm", "gpu::GPUModuleOp"> {
];
let options = [
Option<"indexBitwidth", "index-bitwidth", "unsigned",
- /*default=kDeriveIndexBitwidthFromDataLayout*/"0",
+ /*default=kDeriveIndexBitwidthFromDataLayout*/ "0",
"Bitwidth of the index type, 0 to use size of machine word">,
Option<"hasRedux", "has-redux", "bool", /*default=*/"false",
"Target gpu supports redux">,
Option<"useBarePtrCallConv", "use-bare-ptr-memref-call-conv", "bool",
/*default=*/"false",
"Replace memref arguments in GPU functions with bare pointers. "
- "All memrefs must have static shape.">
+ "All memrefs must have static shape.">,
+ ListOption<"filterDialects", "filter-dialects", "std::string",
+ "Run conversion patterns of only the specified dialects">,
];
}
@@ -578,20 +580,24 @@ def ConvertGpuOpsToROCDLOps : Pass<"convert-gpu-to-rocdl", "gpu::GPUModuleOp"> {
/*default=*/"\"gfx000\"",
"Chipset that these operations will run on">,
Option<"indexBitwidth", "index-bitwidth", "unsigned",
- /*default=kDeriveIndexBitwidthFromDataLayout*/"0",
+ /*default=kDeriveIndexBitwidthFromDataLayout*/ "0",
"Bitwidth of the index type, 0 to use size of machine word">,
Option<"useBarePtrCallConv", "use-bare-ptr-memref-call-conv", "bool",
/*default=*/"false",
"Replace memref arguments in GPU functions with bare pointers."
"All memrefs must have static shape">,
Option<"runtime", "runtime", "::mlir::gpu::amd::Runtime",
- "::mlir::gpu::amd::Runtime::Unknown",
- "Runtime code will be run on (default is Unknown, can also use HIP or OpenCl)",
- [{::llvm::cl::values(
- clEnumValN(::mlir::gpu::amd::Runtime::Unknown, "unknown", "Unknown (default)"),
- clEnumValN(::mlir::gpu::amd::Runtime::HIP, "HIP", "HIP"),
- clEnumValN(::mlir::gpu::amd::Runtime::OpenCL, "OpenCL", "OpenCL")
- )}]>
+ "::mlir::gpu::amd::Runtime::Unknown",
+ "Runtime code will be run on (default is Unknown, can also use HIP "
+ "or OpenCl)",
+ [{::llvm::cl::values(
+ clEnumValN(::mlir::gpu::amd::Runtime::Unknown, "unknown",
+ "Unknown (default)"),
+ clEnumValN(::mlir::gpu::amd::Runtime::HIP, "HIP", "HIP"),
+ clEnumValN(::mlir::gpu::amd::Runtime::OpenCL, "OpenCL",
+ "OpenCL"))}]>,
+ ListOption<"filterDialects", "filter-dialects", "std::string",
+ "Run conversion patterns of only the specified dialects">,
];
}
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 11363a0d60ebfa1..e03335a9f696c5a 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -11,19 +11,14 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
-
-#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
-#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
-#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
+#include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h"
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Conversion/GPUToNVVM/GPUToNVVM.h"
+#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
-#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
-#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -346,6 +341,11 @@ struct LowerGpuOpsToNVVMOpsPass
: public impl::ConvertGpuOpsToNVVMOpsBase<LowerGpuOpsToNVVMOpsPass> {
using Base::Base;
+ void getDependentDialects(DialectRegistry ®istry) const override final {
+ Base::getDependentDialects(registry);
+ registerConvertToLLVMDependentDialectLoading(registry);
+ }
+
void runOnOperation() override {
gpu::GPUModuleOp m = getOperation();
@@ -376,17 +376,44 @@ struct LowerGpuOpsToNVVMOpsPass
LLVMTypeConverter converter(m.getContext(), options);
configureGpuToNVVMTypeConverter(converter);
RewritePatternSet llvmPatterns(m.getContext());
+ LLVMConversionTarget target(getContext());
+
+ if (!filterDialects.empty()) {
+ for (StringRef dialectName : filterDialects) {
+ Dialect *dialect = getContext().getLoadedDialect(dialectName);
+ // Dialect may not be loaded if it wasn't used in source module, ignore.
+ if (!dialect)
+ continue;
+
+ auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
+ if (!iface) {
+ m.emitError()
+ << "dialect does not implement ConvertToLLVMPatternInterface: "
+ << dialectName << "\n";
+ return signalPassFailure();
+ }
+
+ iface->populateConvertToLLVMConversionPatterns(target, converter,
+ llvmPatterns);
+ }
+ } else {
+ for (Dialect *dialect : getContext().getLoadedDialects()) {
+ if (isa<math::MathDialect>(dialect)) // Need custom math lowering
+ continue;
+
+ auto iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
+ if (!iface)
+ continue;
+
+ iface->populateConvertToLLVMConversionPatterns(target, converter,
+ llvmPatterns);
+ }
+ }
- arith::populateArithToLLVMConversionPatterns(converter, llvmPatterns);
- cf::populateControlFlowToLLVMConversionPatterns(converter, llvmPatterns);
- populateFuncToLLVMConversionPatterns(converter, llvmPatterns);
- populateFinalizeMemRefToLLVMConversionPatterns(converter, llvmPatterns);
populateGpuToNVVMConversionPatterns(converter, llvmPatterns);
populateGpuWMMAToNVVMConversionPatterns(converter, llvmPatterns);
- populateVectorToLLVMConversionPatterns(converter, llvmPatterns);
if (this->hasRedux)
populateGpuSubgroupReduceOpLoweringPattern(converter, llvmPatterns);
- LLVMConversionTarget target(getContext());
configureGpuToNVVMConversionLegality(target);
if (failed(applyPartialConversion(m, target, std::move(llvmPatterns))))
signalPassFailure();
@@ -397,6 +424,7 @@ struct LowerGpuOpsToNVVMOpsPass
void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) {
target.addIllegalOp<func::FuncOp>();
+ target.addIllegalOp<cf::AssertOp>();
target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
target.addIllegalDialect<gpu::GPUDialect>();
@@ -472,8 +500,10 @@ void mlir::populateGpuToNVVMConversionPatterns(
using gpu::index_lowering::IndexKind;
using gpu::index_lowering::IntrType;
populateWithGenerated(patterns);
+
+ // Set higher benefit, so patterns will run before generic LLVM lowering.
patterns.add<GPUPrintfOpToVPrintfLowering, AssertOpToAssertfailLowering>(
- converter);
+ converter, /*benefit*/ 10);
patterns.add<
gpu::index_lowering::OpLowering<gpu::ThreadIdOp, NVVM::ThreadIdXOp,
NVVM::ThreadIdYOp, NVVM::ThreadIdZOp>>(
diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index afebded1c3ea401..48f24b3fb95494d 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -11,7 +11,6 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Pass/Pass.h"
@@ -19,8 +18,8 @@
#include "mlir/Transforms/Passes.h"
#include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h"
-#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
-#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
+#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
+#include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h"
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
@@ -28,8 +27,6 @@
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
#include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
-#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
-#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
@@ -218,6 +215,11 @@ struct LowerGpuOpsToROCDLOpsPass
this->runtime = runtime;
}
+ void getDependentDialects(DialectRegistry ®istry) const override final {
+ Base::getDependentDialects(registry);
+ registerConvertToLLVMDependentDialectLoading(registry);
+ }
+
void runOnOperation() override {
gpu::GPUModuleOp m = getOperation();
MLIRContext *ctx = m.getContext();
@@ -289,18 +291,40 @@ struct LowerGpuOpsToROCDLOpsPass
});
RewritePatternSet llvmPatterns(ctx);
+ LLVMConversionTarget target(getContext());
+
+ if (!filterDialects.empty()) {
+ for (StringRef dialectName : filterDialects) {
+ Dialect *dialect = ctx->getLoadedDialect(dialectName);
+ // Dialect may not be loaded if it wasn't used in source module, ignore.
+ if (!dialect)
+ continue;
+
+ auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
+ if (!iface) {
+ m.emitError()
+ << "dialect does not implement ConvertToLLVMPatternInterface: "
+ << dialectName << "\n";
+ return signalPassFailure();
+ }
+
+ iface->populateConvertToLLVMConversionPatterns(target, converter,
+ llvmPatterns);
+ }
+ } else {
+ for (Dialect *dialect : ctx->getLoadedDialects()) {
+ auto iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
+ if (!iface)
+ continue;
+
+ iface->populateConvertToLLVMConversionPatterns(target, converter,
+ llvmPatterns);
+ }
+ }
- mlir::arith::populateArithToLLVMConversionPatterns(converter, llvmPatterns);
populateAMDGPUToROCDLConversionPatterns(converter, llvmPatterns,
*maybeChipset);
- populateVectorToLLVMConversionPatterns(converter, llvmPatterns);
- populateMathToLLVMConversionPatterns(converter, llvmPatterns);
- cf::populateControlFlowToLLVMConversionPatterns(converter, llvmPatterns);
- cf::populateAssertToLLVMConversionPattern(converter, llvmPatterns);
- populateFuncToLLVMConversionPatterns(converter, llvmPatterns);
- populateFinalizeMemRefToLLVMConversionPatterns(converter, llvmPatterns);
populateGpuToROCDLConversionPatterns(converter, llvmPatterns, runtime);
- LLVMConversionTarget target(getContext());
configureGpuToROCDLConversionLegality(target);
if (failed(applyPartialConversion(m, target, std::move(llvmPatterns))))
signalPassFailure();
diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
index de2a4ff2079e2d7..e917ae46dfc2420 100644
--- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
@@ -1,4 +1,5 @@
// RUN: mlir-opt %s -convert-gpu-to-nvvm='has-redux=1' -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -convert-gpu-to-nvvm='has-redux=1 filter-dialects=func,arith,cf' -split-input-file | FileCheck %s
// RUN: mlir-opt %s -convert-gpu-to-nvvm='has-redux=1 use-bare-ptr-memref-call-conv=1' -split-input-file | FileCheck %s --check-prefix=CHECK-BARE
// RUN: mlir-opt %s -transform-interpreter | FileCheck %s
diff --git a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
index 11b9fa5e33f10e6..4e59578d078a98c 100644
--- a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
@@ -1,4 +1,5 @@
// RUN: mlir-opt %s -convert-gpu-to-rocdl -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -convert-gpu-to-rocdl='filter-dialects=func,arith,math' -split-input-file | FileCheck %s
// RUN: mlir-opt %s -convert-gpu-to-rocdl='index-bitwidth=32' -split-input-file | FileCheck --check-prefix=CHECK32 %s
// CHECK-LABEL: @test_module
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good overall
Nice! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks for the changes.
Just one thing: the flag name in the PR description is out of sync with the code.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for implementing that
f1b1edc
to
bcd2d22
Compare
…stead of hardcoded conversions. (llvm#124439) Using `ConvertToLLVMPatternInterface` allows to unhardcode specific dialect conversions from passes and, more importantly, allows downstream projects to inject their ops/types translation here by registering corresponding interface. Add `allowed-dialects` option so user can control which dialects can be used to populate conversions.
…stead of hardcoded conversions. (llvm#124439) Using `ConvertToLLVMPatternInterface` allows to unhardcode specific dialect conversions from passes and, more importantly, allows downstream projects to inject their ops/types translation here by registering corresponding interface. Add `allowed-dialects` option so user can control which dialects can be used to populate conversions.
Using
ConvertToLLVMPatternInterface
allows to unhardcode specific dialect conversions from passes and, more importantly, allows downstream projects to inject their ops/types translation here by registering corresponding interface.Add
allowed-dialects
option so user can control which dialects can be used to populate conversions.