Skip to content

Conversation

Hardcode84
Copy link
Contributor

@Hardcode84 Hardcode84 commented Jan 26, 2025

Using ConvertToLLVMPatternInterface allows to unhardcode specific dialect conversions from passes and, more importantly, allows downstream projects to inject their ops/types translation here by registering corresponding interface.

Add allowed-dialects option so user can control which dialects can be used to populate conversions.

@banach-space
Copy link
Contributor

Why is this a draft? Is more work needed?

@Hardcode84
Copy link
Contributor Author

depends on #121440

@banach-space
Copy link
Contributor

depends on #121440

No need to keep it as draft - these are basically stacked PRs. Add this info to the summary and a note that only the top commit should be reviewed.

@Hardcode84 Hardcode84 force-pushed the to-gpu-llvm-translate branch from 2497c43 to 1b3f7e4 Compare February 8, 2025 17:22
@Hardcode84 Hardcode84 marked this pull request as ready for review February 8, 2025 17:27
@llvmbot
Copy link
Member

llvmbot commented Feb 8, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-gpu

Author: Ivan Butygin (Hardcode84)

Changes

Using ConvertToLLVMPatternInterface allows to unhardcode specific dialect conversions from passes and, more importantly, allows downstream projects to inject their ops/types translation here by registering corresponding interface.

Add filter-dialects option so user can control which dialects can be used to populate conversions.


Full diff: https://github.com/llvm/llvm-project/pull/124439.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Conversion/Passes.td (+16-10)
  • (modified) mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp (+44-14)
  • (modified) mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp (+37-13)
  • (modified) mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir (+1)
  • (modified) mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir (+1)
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index ff79a1226c047bc..1873d95eed88f2a 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -550,14 +550,16 @@ def ConvertGpuOpsToNVVMOps : Pass<"convert-gpu-to-nvvm", "gpu::GPUModuleOp"> {
   ];
   let options = [
     Option<"indexBitwidth", "index-bitwidth", "unsigned",
-           /*default=kDeriveIndexBitwidthFromDataLayout*/"0",
+           /*default=kDeriveIndexBitwidthFromDataLayout*/ "0",
            "Bitwidth of the index type, 0 to use size of machine word">,
     Option<"hasRedux", "has-redux", "bool", /*default=*/"false",
            "Target gpu supports redux">,
     Option<"useBarePtrCallConv", "use-bare-ptr-memref-call-conv", "bool",
            /*default=*/"false",
            "Replace memref arguments in GPU functions with bare pointers. "
-           "All memrefs must have static shape.">
+           "All memrefs must have static shape.">,
+    ListOption<"filterDialects", "filter-dialects", "std::string",
+               "Run conversion patterns of only the specified dialects">,
   ];
 }
 
@@ -578,20 +580,24 @@ def ConvertGpuOpsToROCDLOps : Pass<"convert-gpu-to-rocdl", "gpu::GPUModuleOp"> {
            /*default=*/"\"gfx000\"",
            "Chipset that these operations will run on">,
     Option<"indexBitwidth", "index-bitwidth", "unsigned",
-           /*default=kDeriveIndexBitwidthFromDataLayout*/"0",
+           /*default=kDeriveIndexBitwidthFromDataLayout*/ "0",
            "Bitwidth of the index type, 0 to use size of machine word">,
     Option<"useBarePtrCallConv", "use-bare-ptr-memref-call-conv", "bool",
            /*default=*/"false",
            "Replace memref arguments in GPU functions with bare pointers."
            "All memrefs must have static shape">,
     Option<"runtime", "runtime", "::mlir::gpu::amd::Runtime",
-          "::mlir::gpu::amd::Runtime::Unknown",
-          "Runtime code will be run on (default is Unknown, can also use HIP or OpenCl)",
-          [{::llvm::cl::values(
-            clEnumValN(::mlir::gpu::amd::Runtime::Unknown, "unknown", "Unknown (default)"),
-            clEnumValN(::mlir::gpu::amd::Runtime::HIP, "HIP", "HIP"),
-            clEnumValN(::mlir::gpu::amd::Runtime::OpenCL, "OpenCL", "OpenCL")
-          )}]>
+           "::mlir::gpu::amd::Runtime::Unknown",
+           "Runtime code will be run on (default is Unknown, can also use HIP "
+           "or OpenCl)",
+           [{::llvm::cl::values(
+               clEnumValN(::mlir::gpu::amd::Runtime::Unknown, "unknown",
+                          "Unknown (default)"),
+               clEnumValN(::mlir::gpu::amd::Runtime::HIP, "HIP", "HIP"),
+               clEnumValN(::mlir::gpu::amd::Runtime::OpenCL, "OpenCL",
+                          "OpenCL"))}]>,
+    ListOption<"filterDialects", "filter-dialects", "std::string",
+               "Run conversion patterns of only the specified dialects">,
   ];
 }
 
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 11363a0d60ebfa1..e03335a9f696c5a 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -11,19 +11,14 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
-
-#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
-#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
-#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
+#include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h"
 #include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
 #include "mlir/Conversion/GPUToNVVM/GPUToNVVM.h"
+#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
 #include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
-#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
-#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -346,6 +341,11 @@ struct LowerGpuOpsToNVVMOpsPass
     : public impl::ConvertGpuOpsToNVVMOpsBase<LowerGpuOpsToNVVMOpsPass> {
   using Base::Base;
 
+  void getDependentDialects(DialectRegistry &registry) const override final {
+    Base::getDependentDialects(registry);
+    registerConvertToLLVMDependentDialectLoading(registry);
+  }
+
   void runOnOperation() override {
     gpu::GPUModuleOp m = getOperation();
 
@@ -376,17 +376,44 @@ struct LowerGpuOpsToNVVMOpsPass
     LLVMTypeConverter converter(m.getContext(), options);
     configureGpuToNVVMTypeConverter(converter);
     RewritePatternSet llvmPatterns(m.getContext());
+    LLVMConversionTarget target(getContext());
+
+    if (!filterDialects.empty()) {
+      for (StringRef dialectName : filterDialects) {
+        Dialect *dialect = getContext().getLoadedDialect(dialectName);
+        // Dialect may not be loaded if it wasn't used in source module, ignore.
+        if (!dialect)
+          continue;
+
+        auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
+        if (!iface) {
+          m.emitError()
+              << "dialect does not implement ConvertToLLVMPatternInterface: "
+              << dialectName << "\n";
+          return signalPassFailure();
+        }
+
+        iface->populateConvertToLLVMConversionPatterns(target, converter,
+                                                       llvmPatterns);
+      }
+    } else {
+      for (Dialect *dialect : getContext().getLoadedDialects()) {
+        if (isa<math::MathDialect>(dialect)) // Need custom math lowering
+          continue;
+
+        auto iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
+        if (!iface)
+          continue;
+
+        iface->populateConvertToLLVMConversionPatterns(target, converter,
+                                                       llvmPatterns);
+      }
+    }
 
-    arith::populateArithToLLVMConversionPatterns(converter, llvmPatterns);
-    cf::populateControlFlowToLLVMConversionPatterns(converter, llvmPatterns);
-    populateFuncToLLVMConversionPatterns(converter, llvmPatterns);
-    populateFinalizeMemRefToLLVMConversionPatterns(converter, llvmPatterns);
     populateGpuToNVVMConversionPatterns(converter, llvmPatterns);
     populateGpuWMMAToNVVMConversionPatterns(converter, llvmPatterns);
-    populateVectorToLLVMConversionPatterns(converter, llvmPatterns);
     if (this->hasRedux)
       populateGpuSubgroupReduceOpLoweringPattern(converter, llvmPatterns);
-    LLVMConversionTarget target(getContext());
     configureGpuToNVVMConversionLegality(target);
     if (failed(applyPartialConversion(m, target, std::move(llvmPatterns))))
       signalPassFailure();
@@ -397,6 +424,7 @@ struct LowerGpuOpsToNVVMOpsPass
 
 void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) {
   target.addIllegalOp<func::FuncOp>();
+  target.addIllegalOp<cf::AssertOp>();
   target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
   target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
   target.addIllegalDialect<gpu::GPUDialect>();
@@ -472,8 +500,10 @@ void mlir::populateGpuToNVVMConversionPatterns(
   using gpu::index_lowering::IndexKind;
   using gpu::index_lowering::IntrType;
   populateWithGenerated(patterns);
+
+  // Set higher benefit, so patterns will run before generic LLVM lowering.
   patterns.add<GPUPrintfOpToVPrintfLowering, AssertOpToAssertfailLowering>(
-      converter);
+      converter, /*benefit*/ 10);
   patterns.add<
       gpu::index_lowering::OpLowering<gpu::ThreadIdOp, NVVM::ThreadIdXOp,
                                       NVVM::ThreadIdYOp, NVVM::ThreadIdZOp>>(
diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index afebded1c3ea401..48f24b3fb95494d 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -11,7 +11,6 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
 #include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"
 #include "mlir/Dialect/Arith/Transforms/Passes.h"
 #include "mlir/Pass/Pass.h"
@@ -19,8 +18,8 @@
 #include "mlir/Transforms/Passes.h"
 
 #include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h"
-#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
-#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
+#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
+#include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h"
 #include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
 #include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
@@ -28,8 +27,6 @@
 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
 #include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
 #include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
-#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
-#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
@@ -218,6 +215,11 @@ struct LowerGpuOpsToROCDLOpsPass
       this->runtime = runtime;
   }
 
+  void getDependentDialects(DialectRegistry &registry) const override final {
+    Base::getDependentDialects(registry);
+    registerConvertToLLVMDependentDialectLoading(registry);
+  }
+
   void runOnOperation() override {
     gpu::GPUModuleOp m = getOperation();
     MLIRContext *ctx = m.getContext();
@@ -289,18 +291,40 @@ struct LowerGpuOpsToROCDLOpsPass
         });
 
     RewritePatternSet llvmPatterns(ctx);
+    LLVMConversionTarget target(getContext());
+
+    if (!filterDialects.empty()) {
+      for (StringRef dialectName : filterDialects) {
+        Dialect *dialect = ctx->getLoadedDialect(dialectName);
+        // Dialect may not be loaded if it wasn't used in source module, ignore.
+        if (!dialect)
+          continue;
+
+        auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
+        if (!iface) {
+          m.emitError()
+              << "dialect does not implement ConvertToLLVMPatternInterface: "
+              << dialectName << "\n";
+          return signalPassFailure();
+        }
+
+        iface->populateConvertToLLVMConversionPatterns(target, converter,
+                                                       llvmPatterns);
+      }
+    } else {
+      for (Dialect *dialect : ctx->getLoadedDialects()) {
+        auto iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
+        if (!iface)
+          continue;
+
+        iface->populateConvertToLLVMConversionPatterns(target, converter,
+                                                       llvmPatterns);
+      }
+    }
 
-    mlir::arith::populateArithToLLVMConversionPatterns(converter, llvmPatterns);
     populateAMDGPUToROCDLConversionPatterns(converter, llvmPatterns,
                                             *maybeChipset);
-    populateVectorToLLVMConversionPatterns(converter, llvmPatterns);
-    populateMathToLLVMConversionPatterns(converter, llvmPatterns);
-    cf::populateControlFlowToLLVMConversionPatterns(converter, llvmPatterns);
-    cf::populateAssertToLLVMConversionPattern(converter, llvmPatterns);
-    populateFuncToLLVMConversionPatterns(converter, llvmPatterns);
-    populateFinalizeMemRefToLLVMConversionPatterns(converter, llvmPatterns);
     populateGpuToROCDLConversionPatterns(converter, llvmPatterns, runtime);
-    LLVMConversionTarget target(getContext());
     configureGpuToROCDLConversionLegality(target);
     if (failed(applyPartialConversion(m, target, std::move(llvmPatterns))))
       signalPassFailure();
diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
index de2a4ff2079e2d7..e917ae46dfc2420 100644
--- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
@@ -1,4 +1,5 @@
 // RUN: mlir-opt %s -convert-gpu-to-nvvm='has-redux=1' -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -convert-gpu-to-nvvm='has-redux=1 filter-dialects=func,arith,cf' -split-input-file | FileCheck %s
 // RUN: mlir-opt %s -convert-gpu-to-nvvm='has-redux=1 use-bare-ptr-memref-call-conv=1' -split-input-file | FileCheck %s --check-prefix=CHECK-BARE
 // RUN: mlir-opt %s -transform-interpreter | FileCheck %s
 
diff --git a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
index 11b9fa5e33f10e6..4e59578d078a98c 100644
--- a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
@@ -1,4 +1,5 @@
 // RUN: mlir-opt %s -convert-gpu-to-rocdl -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -convert-gpu-to-rocdl='filter-dialects=func,arith,math' -split-input-file | FileCheck %s
 // RUN: mlir-opt %s -convert-gpu-to-rocdl='index-bitwidth=32' -split-input-file | FileCheck --check-prefix=CHECK32 %s
 
 // CHECK-LABEL: @test_module

Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good overall

@joker-eph
Copy link
Collaborator

Nice!

Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks for the changes.

Just one thing: the flag name in the PR description is out of sync with the code.

Copy link
Member

@grypp grypp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for implementing that

@Hardcode84 Hardcode84 force-pushed the to-gpu-llvm-translate branch from f1b1edc to bcd2d22 Compare February 13, 2025 11:51
@Hardcode84 Hardcode84 merged commit aecb764 into llvm:main Feb 13, 2025
8 checks passed
@Hardcode84 Hardcode84 deleted the to-gpu-llvm-translate branch February 13, 2025 14:53
joaosaffran pushed a commit to joaosaffran/llvm-project that referenced this pull request Feb 14, 2025
…stead of hardcoded conversions. (llvm#124439)

Using `ConvertToLLVMPatternInterface` allows to unhardcode specific
dialect conversions from passes and, more importantly, allows downstream
projects to inject their ops/types translation here by registering
corresponding interface.

Add `allowed-dialects` option so user can control which dialects can be
used to populate conversions.
sivan-shani pushed a commit to sivan-shani/llvm-project that referenced this pull request Feb 24, 2025
…stead of hardcoded conversions. (llvm#124439)

Using `ConvertToLLVMPatternInterface` allows to unhardcode specific
dialect conversions from passes and, more importantly, allows downstream
projects to inject their ops/types translation here by registering
corresponding interface.

Add `allowed-dialects` option so user can control which dialects can be
used to populate conversions.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants