
[mlir][spirv] Add GpuToLLVM cconv suited to Vulkan, migrate last tests #123384

Conversation

andfau-amd
Contributor

This commit is a follow-up to 99a562b, which migrated some of the mlir-vulkan-runner tests to mlir-cpu-runner using a new pipeline and set of wrappers. That commit could not migrate all the tests, because the existing calling conventions/ABIs for kernel arguments generated by GPUToLLVMConversionPass were not a good fit for the Vulkan runtime. This commit fixes this and migrates the remaining tests. With this commit, mlir-vulkan-runner and many related components are now unused, and they will be removed in a later commit (see #73457).

The old calling conventions require both the caller (host LLVM code) and callee (device code) to have compile-time knowledge of the precise argument types. This works for CUDA, ROCm and SYCL, where there is a C-like calling convention agreed between the host and device code, and the runtime passes through arguments as raw data without comprehension. For Vulkan, however, the interface declared by the shader/kernel is in a more abstract form, so the device code has indirect access to the argument data, and the runtime must process the arguments to set up and bind appropriately-sized buffer descriptors.

This commit introduces a new calling convention option to meet the Vulkan runtime's needs. It lowers memref arguments to {void*, size_t} pairs, which can be trivially interpreted by the runtime without it needing to know the original argument types. Unlike the stopgap measure in the previous commit, this system can support memrefs of various ranks and element types, which unblocked migrating the remaining tests.
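To make the new convention concrete, below is a minimal sketch of how a runtime could consume such an argument list. It mirrors the change to vulkan-runtime-wrappers.cpp further down in the diff, but the names here (launchWithInterspersedSizes, bindBuffer) are placeholders, not the wrapper's real API.

#include <cstddef>
#include <cstdlib>

// Sketch only: with intersperse-sizes-for-kernels, every memref kernel
// argument arrives as a (bare base pointer, static size in bytes) pair, so
// the runtime can bind one buffer per pair without knowing rank or element
// type.
void launchWithInterspersedSizes(void **params, size_t paramsCount) {
  const size_t paramsPerMemRef = 2; // pointer + size
  if (paramsCount % paramsPerMemRef != 0)
    abort(); // would indicate a calling-convention mismatch

  for (size_t i = 0; i < paramsCount; i += paramsPerMemRef) {
    void *basePtr = *static_cast<void **>(params[i]);
    size_t sizeInBytes = *static_cast<size_t *>(params[i + 1]);
    // bindBuffer stands in for the runtime's descriptor-binding call
    // (the Vulkan wrappers use setResourceData).
    // bindBuffer(/*set=*/0, /*binding=*/i / paramsPerMemRef, basePtr, sizeInBytes);
    (void)basePtr;
    (void)sizeInBytes;
  }
}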

@llvmbot
Member

llvmbot commented Jan 17, 2025

@llvm/pr-subscribers-mlir-gpu
@llvm/pr-subscribers-mlir-spirv
@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir

Author: Andrea Faulds (andfau-amd)

Full diff: https://github.com/llvm/llvm-project/pull/123384.diff

10 Files Affected:

  • (modified) mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h (+1-1)
  • (modified) mlir/include/mlir/Conversion/Passes.td (+6-4)
  • (modified) mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp (+56-33)
  • (added) mlir/test/Conversion/GPUCommon/lower-launch-func-bare-ptr-intersperse-size.mlir (+25)
  • (modified) mlir/test/lib/Pass/TestVulkanRunnerPipeline.cpp (+3-4)
  • (modified) mlir/test/mlir-vulkan-runner/addi.mlir (+2-2)
  • (modified) mlir/test/mlir-vulkan-runner/addi8.mlir (+2-2)
  • (modified) mlir/test/mlir-vulkan-runner/mulf.mlir (+2-2)
  • (modified) mlir/test/mlir-vulkan-runner/subf.mlir (+2-2)
  • (modified) mlir/tools/mlir-vulkan-runner/vulkan-runtime-wrappers.cpp (+11-15)
diff --git a/mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h b/mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h
index cf0c96f0eba000c..e682fb1ff49d1de 100644
--- a/mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h
+++ b/mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h
@@ -65,7 +65,7 @@ struct FunctionCallBuilder {
 void populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                          RewritePatternSet &patterns,
                                          bool kernelBarePtrCallConv = false,
-                                         bool typeCheckKernelArgs = false);
+                                         bool kernelIntersperseSizeCallConv = false);
 
 /// A function that maps a MemorySpace enum to a target-specific integer value.
 using MemorySpaceMapping = std::function<unsigned(gpu::AddressSpace)>;
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 0f42ffb3a802667..34f7ab46b629820 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -518,11 +518,13 @@ def GpuToLLVMConversionPass : Pass<"gpu-to-llvm", "ModuleOp"> {
              "Use bare pointers to pass memref arguments to kernels. "
              "The kernel must use the same setting for this option."
           >,
-    Option<"typeCheckKernelArgs", "type-check-kernel-args", "bool",
+    Option<"kernelIntersperseSizeCallConv", "intersperse-sizes-for-kernels", "bool",
            /*default=*/"false",
-             "Require all kernel arguments to be memrefs of rank 1 and with a "
-             "32-bit element size. This is a temporary option that will be "
-             "removed; TODO(https://github.com/llvm/llvm-project/issues/73457)."
+           "Inserts a size_t argument following each memref argument, "
+           "containing the static size in bytes of the buffer. Incompatible "
+           "arguments are rejected. This is intended for use by the Vulkan "
+           "runtime with the kernel bare pointer calling convention, to enable "
+           "dynamic binding of buffers as arguments without static type info."
           >
   ];
 
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index ca9883a79dc168f..9d61fe9c7ac1fb9 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -428,10 +428,10 @@ class LegalizeLaunchFuncOpPattern
 public:
   LegalizeLaunchFuncOpPattern(const LLVMTypeConverter &typeConverter,
                               bool kernelBarePtrCallConv,
-                              bool typeCheckKernelArgs)
+                              bool kernelIntersperseSizeCallConv)
       : ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp>(typeConverter),
         kernelBarePtrCallConv(kernelBarePtrCallConv),
-        typeCheckKernelArgs(typeCheckKernelArgs) {}
+        kernelIntersperseSizeCallConv(kernelIntersperseSizeCallConv) {}
 
 private:
   LogicalResult
@@ -439,7 +439,7 @@ class LegalizeLaunchFuncOpPattern
                   ConversionPatternRewriter &rewriter) const override;
 
   bool kernelBarePtrCallConv;
-  bool typeCheckKernelArgs;
+  bool kernelIntersperseSizeCallConv;
 };
 
 /// A rewrite pattern to convert gpu.memcpy operations into a GPU runtime
@@ -566,8 +566,9 @@ void GpuToLLVMConversionPass::runOnOperation() {
   populateFinalizeMemRefToLLVMConversionPatterns(converter, patterns);
   populateAsyncStructuralTypeConversionsAndLegality(converter, patterns,
                                                     target);
-  populateGpuToLLVMConversionPatterns(
-      converter, patterns, kernelBarePtrCallConv, typeCheckKernelArgs);
+  populateGpuToLLVMConversionPatterns(converter, patterns,
+                                      kernelBarePtrCallConv,
+                                      kernelIntersperseSizeCallConv);
 
   if (failed(
           applyPartialConversion(getOperation(), target, std::move(patterns))))
@@ -970,33 +971,56 @@ LogicalResult LegalizeLaunchFuncOpPattern::matchAndRewrite(
   else if (launchOp.getAsyncToken())
     stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult();
 
-  if (typeCheckKernelArgs) {
-    // The current non-bare-pointer ABI is a bad fit for `mgpuLaunchKernel`,
-    // which takes an untyped list of arguments. The type check here prevents
-    // accidentally violating the assumption made in vulkan-runtime-wrappers.cpp
-    // and creating a unchecked runtime ABI mismatch.
-    // TODO(https://github.com/llvm/llvm-project/issues/73457): Change the ABI
-    // here to remove the need for this type check.
-    for (Value arg : launchOp.getKernelOperands()) {
-      if (auto memrefTy = dyn_cast<MemRefType>(arg.getType())) {
-        if (memrefTy.getRank() != 1 ||
-            memrefTy.getElementTypeBitWidth() != 32) {
-          return rewriter.notifyMatchFailure(
-              launchOp, "Operand to launch op is not a rank-1 memref with "
-                        "32-bit element type.");
-        }
-      } else {
+  // Lower the kernel operands to match kernel parameters.
+  // Note: If `useBarePtrCallConv` is set in the type converter's options,
+  // the value of `kernelBarePtrCallConv` will be ignored.
+  OperandRange origArguments = launchOp.getKernelOperands();
+  SmallVector<Value, 4> llvmArguments = getTypeConverter()->promoteOperands(
+      loc, origArguments, adaptor.getKernelOperands(), rewriter,
+      /*useBarePtrCallConv=*/kernelBarePtrCallConv);
+
+  // Intersperse size information if requested.
+  if (kernelIntersperseSizeCallConv) {
+    if (origArguments.size() != llvmArguments.size()) {
+      // This shouldn't happen if the bare-pointer calling convention is used.
+      return rewriter.notifyMatchFailure(
+          launchOp,
+          "Cannot add sizes to arguments with one-to-many LLVM IR expansion.");
+    }
+
+    SmallVector<Value, 8> llvmArgumentsWithSizes;
+    llvmArgumentsWithSizes.reserve(llvmArguments.size() * 2);
+    for (auto [llvmArg, origArg] : zip_equal(llvmArguments, origArguments)) {
+      auto memrefTy = dyn_cast<MemRefType>(origArg.getType());
+      if (!memrefTy) {
         return rewriter.notifyMatchFailure(
             launchOp, "Operand to launch op is not a memref.");
       }
+
+      if (!memrefTy.hasStaticShape() ||
+          !memrefTy.getElementType().isIntOrFloat()) {
+        return rewriter.notifyMatchFailure(
+            launchOp, "Operand to launch op is not a memref with a static "
+                      "shape and an integer or float element type.");
+      }
+
+      unsigned bitwidth = memrefTy.getElementTypeBitWidth();
+      if (bitwidth % 8 != 0) {
+        return rewriter.notifyMatchFailure(
+            launchOp, "Operand to launch op is not a memref with a "
+                      "byte-aligned element type.");
+      }
+
+      uint64_t staticSize =
+          uint64_t(bitwidth / 8) * uint64_t(memrefTy.getNumElements());
+
+      Value sizeArg = rewriter.create<LLVM::ConstantOp>(
+          loc, getIndexType(), rewriter.getIndexAttr(staticSize));
+      llvmArgumentsWithSizes.push_back(llvmArg); // presumably bare pointer
+      llvmArgumentsWithSizes.push_back(sizeArg);
     }
+    llvmArguments = std::move(llvmArgumentsWithSizes);
   }
-  // Lower the kernel operands to match kernel parameters.
-  // Note: If `useBarePtrCallConv` is set in the type converter's options,
-  // the value of `kernelBarePtrCallConv` will be ignored.
-  SmallVector<Value, 4> arguments = getTypeConverter()->promoteOperands(
-      loc, launchOp.getKernelOperands(), adaptor.getKernelOperands(), rewriter,
-      /*useBarePtrCallConv=*/kernelBarePtrCallConv);
 
   std::optional<gpu::KernelDim3> clusterSize = std::nullopt;
   if (launchOp.hasClusterSize()) {
@@ -1010,7 +1034,7 @@ LogicalResult LegalizeLaunchFuncOpPattern::matchAndRewrite(
                       adaptor.getGridSizeZ()},
       gpu::KernelDim3{adaptor.getBlockSizeX(), adaptor.getBlockSizeY(),
                       adaptor.getBlockSizeZ()},
-      adaptor.getDynamicSharedMemorySize(), arguments, stream, clusterSize);
+      adaptor.getDynamicSharedMemorySize(), llvmArguments, stream, clusterSize);
   if (launchOp.getAsyncToken())
     rewriter.replaceOp(launchOp, {stream});
   else
@@ -1760,10 +1784,9 @@ LogicalResult ConvertCreateBsrOpToGpuRuntimeCallPattern::matchAndRewrite(
   return success();
 }
 
-void mlir::populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter,
-                                               RewritePatternSet &patterns,
-                                               bool kernelBarePtrCallConv,
-                                               bool typeCheckKernelArgs) {
+void mlir::populateGpuToLLVMConversionPatterns(
+    LLVMTypeConverter &converter, RewritePatternSet &patterns,
+    bool kernelBarePtrCallConv, bool kernelIntersperseSizeCallConv) {
   addOpaquePointerConversion<gpu::AsyncTokenType>(converter);
   addOpaquePointerConversion<gpu::SparseDnTensorHandleType>(converter);
   addOpaquePointerConversion<gpu::SparseSpMatHandleType>(converter);
@@ -1801,7 +1824,7 @@ void mlir::populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter,
                ConvertSpMatGetSizeOpToGpuRuntimeCallPattern,
                ConvertSetCsrPointersOpToGpuRuntimeCallPattern>(converter);
   patterns.add<LegalizeLaunchFuncOpPattern>(converter, kernelBarePtrCallConv,
-                                            typeCheckKernelArgs);
+                                            kernelIntersperseSizeCallConv);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/GPUCommon/lower-launch-func-bare-ptr-intersperse-size.mlir b/mlir/test/Conversion/GPUCommon/lower-launch-func-bare-ptr-intersperse-size.mlir
new file mode 100644
index 000000000000000..171b13da227136c
--- /dev/null
+++ b/mlir/test/Conversion/GPUCommon/lower-launch-func-bare-ptr-intersperse-size.mlir
@@ -0,0 +1,25 @@
+// RUN: mlir-opt %s --gpu-to-llvm="use-bare-pointers-for-kernels=1 intersperse-sizes-for-kernels=1" -split-input-file | FileCheck %s
+
+module attributes {gpu.container_module, spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>} {
+  llvm.func @malloc(i64) -> !llvm.ptr
+  gpu.binary @kernels  [#gpu.object<#spirv.target_env<#spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>, "">]
+  func.func @main() attributes {llvm.emit_c_interface} {
+    // CHECK: [[RANK1UMD:%.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    %rank1UndefMemrefDescriptor = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    // CHECK: [[RANK2UMD:%.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+    %rank2UndefMemrefDescriptor = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+    %c1 = arith.constant 1 : index
+    // CHECK: [[PTR1:%.*]] = llvm.extractvalue [[RANK1UMD]][1]
+    // CHECK: [[PTR2:%.*]] = llvm.extractvalue [[RANK2UMD]][1]
+    // CHECK: [[PTR3:%.*]] = llvm.extractvalue [[RANK2UMD]][1]
+    // CHECK: [[SIZE1:%.*]] = llvm.mlir.constant(32 : index) : i64
+    // CHECK: [[SIZE2:%.*]] = llvm.mlir.constant(256 : index) : i64
+    // CHECK: [[SIZE3:%.*]] = llvm.mlir.constant(48 : index) : i64
+    %6 = builtin.unrealized_conversion_cast %rank1UndefMemrefDescriptor : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> to memref<8xf32>
+    %10 = builtin.unrealized_conversion_cast %rank2UndefMemrefDescriptor : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> to memref<8x8xi32>
+    %14 = builtin.unrealized_conversion_cast %rank2UndefMemrefDescriptor : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> to memref<4x12xi8>
+    // CHECK: gpu.launch_func  @kernels::@kernel_add blocks in ({{.*}}) threads in ({{.*}}) : i64 args([[PTR1]] : !llvm.ptr, [[SIZE1]] : i64, [[PTR2]] : !llvm.ptr, [[SIZE2]] : i64, [[PTR3]] : !llvm.ptr, [[SIZE3]] : i64)
+    gpu.launch_func  @kernels::@kernel_add blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1)  args(%6 : memref<8xf32>, %10 : memref<8x8xi32>, %14 : memref<4x12xi8>)
+    return
+  }
+}
diff --git a/mlir/test/lib/Pass/TestVulkanRunnerPipeline.cpp b/mlir/test/lib/Pass/TestVulkanRunnerPipeline.cpp
index a3624eb31e26e57..691053b694a2344 100644
--- a/mlir/test/lib/Pass/TestVulkanRunnerPipeline.cpp
+++ b/mlir/test/lib/Pass/TestVulkanRunnerPipeline.cpp
@@ -68,12 +68,11 @@ void buildTestVulkanRunnerPipeline(OpPassManager &passManager,
     passManager.addPass(createFinalizeMemRefToLLVMConversionPass());
     passManager.nest<func::FuncOp>().addPass(
         LLVM::createRequestCWrappersPass());
-    // vulkan-runtime-wrappers.cpp uses the non-bare-pointer calling convention,
-    // and the type check is needed to prevent accidental ABI mismatches.
+    // These calling convention options match vulkan-runtime-wrappers.cpp
     GpuToLLVMConversionPassOptions opt;
     opt.hostBarePtrCallConv = false;
-    opt.kernelBarePtrCallConv = false;
-    opt.typeCheckKernelArgs = true;
+    opt.kernelBarePtrCallConv = true;
+    opt.kernelIntersperseSizeCallConv = true;
     passManager.addPass(createGpuToLLVMConversionPass(opt));
   }
 }
diff --git a/mlir/test/mlir-vulkan-runner/addi.mlir b/mlir/test/mlir-vulkan-runner/addi.mlir
index 7e212a4fb179c29..e60eb7696770a89 100644
--- a/mlir/test/mlir-vulkan-runner/addi.mlir
+++ b/mlir/test/mlir-vulkan-runner/addi.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt %s -test-vulkan-runner-pipeline \
-// RUN:   | mlir-vulkan-runner - --shared-libs=%vulkan-runtime-wrappers,%mlir_runner_utils --entry-point-result=void | FileCheck %s
+// RUN: mlir-opt %s -test-vulkan-runner-pipeline=to-llvm \
+// RUN:   | mlir-cpu-runner - --shared-libs=%vulkan-runtime-wrappers,%mlir_runner_utils --entry-point-result=void | FileCheck %s
 
 // CHECK-COUNT-64: [3, 3, 3, 3, 3, 3, 3, 3]
 module attributes {
diff --git a/mlir/test/mlir-vulkan-runner/addi8.mlir b/mlir/test/mlir-vulkan-runner/addi8.mlir
index e0b1a8e8875c020..6374f29fd0ed2d5 100644
--- a/mlir/test/mlir-vulkan-runner/addi8.mlir
+++ b/mlir/test/mlir-vulkan-runner/addi8.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt %s -test-vulkan-runner-pipeline \
-// RUN:   | mlir-vulkan-runner - --shared-libs=%vulkan-runtime-wrappers,%mlir_runner_utils --entry-point-result=void | FileCheck %s
+// RUN: mlir-opt %s -test-vulkan-runner-pipeline=to-llvm \
+// RUN:   | mlir-cpu-runner - --shared-libs=%vulkan-runtime-wrappers,%mlir_runner_utils --entry-point-result=void | FileCheck %s
 
 // CHECK-COUNT-64: [3, 3, 3, 3, 3, 3, 3, 3]
 module attributes {
diff --git a/mlir/test/mlir-vulkan-runner/mulf.mlir b/mlir/test/mlir-vulkan-runner/mulf.mlir
index 22fa034a9b455fe..bd1b8d9abf4c9f2 100644
--- a/mlir/test/mlir-vulkan-runner/mulf.mlir
+++ b/mlir/test/mlir-vulkan-runner/mulf.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt %s -test-vulkan-runner-pipeline \
-// RUN:   | mlir-vulkan-runner - --shared-libs=%vulkan-runtime-wrappers,%mlir_runner_utils --entry-point-result=void | FileCheck %s
+// RUN: mlir-opt %s -test-vulkan-runner-pipeline=to-llvm \
+// RUN:   | mlir-cpu-runner - --shared-libs=%vulkan-runtime-wrappers,%mlir_runner_utils --entry-point-result=void | FileCheck %s
 
 // CHECK-COUNT-4: [6, 6, 6, 6]
 module attributes {
diff --git a/mlir/test/mlir-vulkan-runner/subf.mlir b/mlir/test/mlir-vulkan-runner/subf.mlir
index 23496ef3abc00da..e8cee0e021a2770 100644
--- a/mlir/test/mlir-vulkan-runner/subf.mlir
+++ b/mlir/test/mlir-vulkan-runner/subf.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt %s -test-vulkan-runner-pipeline \
-// RUN:   | mlir-vulkan-runner - --shared-libs=%vulkan-runtime-wrappers,%mlir_runner_utils --entry-point-result=void | FileCheck %s
+// RUN: mlir-opt %s -test-vulkan-runner-pipeline=to-llvm \
+// RUN:   | mlir-cpu-runner - --shared-libs=%vulkan-runtime-wrappers,%mlir_runner_utils --entry-point-result=void | FileCheck %s
 
 // CHECK-COUNT-32: [2.2, 2.2, 2.2, 2.2]
 module attributes {
diff --git a/mlir/tools/mlir-vulkan-runner/vulkan-runtime-wrappers.cpp b/mlir/tools/mlir-vulkan-runner/vulkan-runtime-wrappers.cpp
index ffd1114cec6aa37..f1c86cb23928c52 100644
--- a/mlir/tools/mlir-vulkan-runner/vulkan-runtime-wrappers.cpp
+++ b/mlir/tools/mlir-vulkan-runner/vulkan-runtime-wrappers.cpp
@@ -169,26 +169,22 @@ mgpuLaunchKernel(void *vkKernel, size_t gridX, size_t gridY, size_t gridZ,
                  void ** /*extra*/, size_t paramsCount) {
   auto manager = static_cast<VulkanRuntimeManager *>(vkRuntimeManager);
 
-  // The non-bare-pointer memref ABI interacts badly with mgpuLaunchKernel's
-  // signature:
-  // - The memref descriptor struct gets split into several elements, each
-  //   passed as their own "param".
-  // - No metadata is provided as to the rank or element type/size of a memref.
-  //   Here we assume that all MemRefs have rank 1 and an element size of
-  //   4 bytes. This means each descriptor struct will have five members.
-  // TODO(https://github.com/llvm/llvm-project/issues/73457): Refactor the
-  //       ABI/API of mgpuLaunchKernel to use a different ABI for memrefs, so
-  //       that other memref types can also be used. This will allow migrating
-  //       the remaining tests and removal of mlir-vulkan-runner.
-  const size_t paramsPerMemRef = 5;
+  // GpuToLLVMConversionPass with the kernelBarePtrCallConv and
+  // kernelIntersperseSizeCallConv options will set up the params array like:
+  // { &memref_ptr0, &memref_size0, &memref_ptr1, &memref_size1, ... }
+  const size_t paramsPerMemRef = 2;
   if (paramsCount % paramsPerMemRef != 0) {
-    abort();
+    abort(); // This would indicate a serious calling convention mismatch.
   }
   const DescriptorSetIndex setIndex = 0;
   BindingIndex bindIndex = 0;
   for (size_t i = 0; i < paramsCount; i += paramsPerMemRef) {
-    auto memref = static_cast<MemRefDescriptor<uint32_t, 1> *>(params[i]);
-    bindMemRef<uint32_t, 1>(manager, setIndex, bindIndex, memref);
+    void *memrefBufferBasePtr = *static_cast<void **>(params[i + 0]);
+    size_t memrefBufferSize = *static_cast<size_t *>(params[i + 1]);
+    VulkanHostMemoryBuffer memBuffer{memrefBufferBasePtr,
+                                     static_cast<uint32_t>(memrefBufferSize)};
+    reinterpret_cast<VulkanRuntimeManager *>(vkRuntimeManager)
+        ->setResourceData(setIndex, bindIndex, memBuffer);
     ++bindIndex;
   }
 

Comment on lines +1000 to +1012
      if (!memrefTy.hasStaticShape() ||
          !memrefTy.getElementType().isIntOrFloat()) {
        return rewriter.notifyMatchFailure(
            launchOp, "Operand to launch op is not a memref with a static "
                      "shape and an integer or float element type.");
      }

      unsigned bitwidth = memrefTy.getElementTypeBitWidth();
      if (bitwidth % 8 != 0) {
        return rewriter.notifyMatchFailure(
            launchOp, "Operand to launch op is not a memref with a "
                      "byte-aligned element type.");
      }
Contributor Author

I feel like I can't be the first person to have wanted to do this, and I'm not sure I'm checking all the necessary conditions. Is there some handy function for this that I'm missing?

Member

I don't recall seeing a helper for this either. Most code uses llvm::ceilDiv for storage size though instead of rejecting element sizes that are not multiples of 8 bits.

Contributor Author

@andfau-amd andfau-amd Jan 20, 2025

Is there some rule somewhere that guarantees elements smaller than 8 bits get padded to 8?

I doubt the Vulkan target can actually do anything useful with non-multiple-of-8-bits types though, so maybe this simple check is fine.

Member

I was only able to find this:

// Non 1-bit dense elements are padded to 8-bits.
size_t storageSize = llvm::divideCeil(elementWidth, CHAR_BIT);
assert(((data.size() / storageSize) == numElements) &&
       "data does not hold expected number of elements");
but it doesn't give us a definite answer. It could be that memref data layout is not universally specified on purpose. cc: @Hardcode84

I think it's fine to reject such memrefs. If someone wants to support them in the future, we can work with them on adding that.
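
For reference, the alternative mentioned above would look roughly like the sketch below: compute a byte size with llvm::divideCeil instead of rejecting sub-byte element widths. The PR keeps the stricter check; whether sub-byte elements are actually padded to whole bytes for memrefs is exactly the open question in this thread, so this is an assumption, not settled behavior.

#include "llvm/Support/MathExtras.h"
#include <cstdint>

// Sketch only (not what the PR does): derive a storage size for sub-byte
// element widths by rounding up to whole bytes, assuming such elements are
// byte-padded.
uint64_t storageSizeInBytes(unsigned elementBitwidth, uint64_t numElements) {
  uint64_t bytesPerElement = llvm::divideCeil(elementBitwidth, 8);
  return bytesPerElement * numElements;
}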

}
llvmArguments = std::move(llvmArgumentsWithSizes);
Contributor Author

This is pretty ugly but I don't have a better idea. Inserting into an existing list would be ugly and slow (in a very theoretical sense).

Contributor Author

I decided I disliked this enough to change it to having two arrays on the stack and a conditional selection. The advantage of this setup is there's no move from a SmallVector<T, 8> into a SmallVector<T, 4>, which had been making me wince a bit.
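
A generic illustration of that "two containers plus conditional selection" shape, with plain ints standing in for MLIR Values; this is not the merged code, just the pattern being described.

#include <cstdio>
#include <vector>

int main() {
  std::vector<int> plainArgs = {10, 20, 30};
  bool intersperseSizes = true;

  // Build the interspersed list only when the option is enabled.
  std::vector<int> argsWithSizes;
  if (intersperseSizes) {
    argsWithSizes.reserve(plainArgs.size() * 2);
    for (int arg : plainArgs) {
      argsWithSizes.push_back(arg);
      argsWithSizes.push_back(/*size placeholder=*/4);
    }
  }

  // Conditional selection: no move between the two containers.
  const std::vector<int> &launchArgs =
      intersperseSizes ? argsWithSizes : plainArgs;

  for (int v : launchArgs)
    std::printf("%d ", v);
  std::printf("\n");
  return 0;
}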


github-actions bot commented Jan 17, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@andfau-amd andfau-amd force-pushed the 73457-runner-migration-vulkan-pipeline-with-serialization-wrappers-new-abi branch from c71170f to b02dd86 on January 17, 2025 at 18:30
bool typeCheckKernelArgs = false);
void populateGpuToLLVMConversionPatterns(
LLVMTypeConverter &converter, RewritePatternSet &patterns,
bool kernelBarePtrCallConv = false,
Contributor

Maybe better to make it enum instead of 2 bools?

Contributor Author

That might not be a bad idea, but I am really unsure if it's possible for an option to be an enum, or what I have to do if I want that.

Contributor Author

(Specifically the Option defined in Passes.td. I know that it's somehow leveraging the cl infrastructure so enums "should" be possible but I haven't managed to figure it out from staring at the headers and tablegen definitions, there's too many layers.)

Member

@kuhar kuhar left a comment

Looks good overall, just some nits

@andfau-amd andfau-amd force-pushed the 73457-runner-migration-vulkan-pipeline-with-serialization-wrappers-new-abi branch from b02dd86 to 8fb279b on January 20, 2025 at 17:59
Member

@kuhar kuhar left a comment

LGTM

  SmallVector<Value, 8> llvmArguments = getTypeConverter()->promoteOperands(
      loc, origArguments, adaptor.getKernelOperands(), rewriter,
      /*useBarePtrCallConv=*/kernelBarePtrCallConv);
  SmallVector<Value, 8> llvmArgumentsWithSizes;
Contributor

nit: don't need to explicitly specify 8 unless this is some special case (0/1)

Contributor Author

I prefer the explicit size in this case because promoteOperands returns a SmallVector<Value, 4>, which isn't going to be large enough once the sizes are added (three operands becomes six). But I guess it doesn't matter too much...

@andfau-amd andfau-amd merged commit 733be4e into llvm:main Jan 21, 2025
8 checks passed