
Commit b02dd86

[mlir][spirv] Add GpuToLLVM cconv suited to Vulkan, migrate last tests
This commit is a follow-up to 99a562b, which migrated some of the mlir-vulkan-runner tests to mlir-cpu-runner using a new pipeline and set of wrappers. That commit could not migrate all the tests, because the existing calling conventions/ABIs for kernel arguments generated by GPUToLLVMConversionPass were not a good fit for the Vulkan runtime. This commit fixes this and migrates the remaining tests. With this commit, mlir-vulkan-runner and many related components are now unused, and they will be removed in a later commit (see llvm#73457).

The old calling conventions require both the caller (host LLVM code) and callee (device code) to have compile-time knowledge of the precise argument types. This works for CUDA, ROCm and SYCL, where there is a C-like calling convention agreed between the host and device code, and the runtime passes through arguments as raw data without comprehension. For Vulkan, however, the interface declared by the shader/kernel is in a more abstract form, so the device code has indirect access to the argument data, and the runtime must process the arguments to set up and bind appropriately-sized buffer descriptors.

This commit introduces a new calling convention option to meet the Vulkan runtime's needs. It lowers memref arguments to {void*, size_t} pairs, which can be trivially interpreted by the runtime without it needing to know the original argument types. Unlike the stopgap measure in the previous commit, this system can support memrefs of various ranks and element types, which unblocked migrating the remaining tests.
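As a rough sketch of what the runtime sees under the new convention (illustrative C++ only, not code from this commit; the struct and names are hypothetical): each memref kernel argument arrives as a bare buffer pointer immediately followed by its static size in bytes.

    #include <cstddef>
    #include <cstdio>

    // Hypothetical picture of one lowered kernel argument: a bare data pointer
    // paired with the buffer's statically known size in bytes.
    struct LoweredMemRefArg {
      void *basePtr;
      size_t sizeInBytes;
    };

    int main() {
      float data[8] = {};
      // A memref<8xf32> argument would be passed as {data, 8 * sizeof(float)} = {data, 32}.
      LoweredMemRefArg arg{data, sizeof(data)};
      std::printf("buffer of %zu bytes at %p\n", arg.sizeInBytes, arg.basePtr);
      return 0;
    }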
1 parent fbea21a commit b02dd86

File tree

10 files changed: +113, -68 lines

mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h

Lines changed: 4 additions & 4 deletions
@@ -62,10 +62,10 @@ struct FunctionCallBuilder {
 
 /// Collect a set of patterns to convert from the GPU dialect to LLVM and
 /// populate converter for gpu types.
-void populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter,
-                                         RewritePatternSet &patterns,
-                                         bool kernelBarePtrCallConv = false,
-                                         bool typeCheckKernelArgs = false);
+void populateGpuToLLVMConversionPatterns(
+    LLVMTypeConverter &converter, RewritePatternSet &patterns,
+    bool kernelBarePtrCallConv = false,
+    bool kernelIntersperseSizeCallConv = false);
 
 /// A function that maps a MemorySpace enum to a target-specific integer value.
 using MemorySpaceMapping = std::function<unsigned(gpu::AddressSpace)>;

mlir/include/mlir/Conversion/Passes.td

Lines changed: 6 additions & 4 deletions
@@ -518,11 +518,13 @@ def GpuToLLVMConversionPass : Pass<"gpu-to-llvm", "ModuleOp"> {
            "Use bare pointers to pass memref arguments to kernels. "
            "The kernel must use the same setting for this option."
           >,
-    Option<"typeCheckKernelArgs", "type-check-kernel-args", "bool",
+    Option<"kernelIntersperseSizeCallConv", "intersperse-sizes-for-kernels", "bool",
            /*default=*/"false",
-           "Require all kernel arguments to be memrefs of rank 1 and with a "
-           "32-bit element size. This is a temporary option that will be "
-           "removed; TODO(https://github.com/llvm/llvm-project/issues/73457)."
+           "Inserts a size_t argument following each memref argument, "
+           "containing the static size in bytes of the buffer. Incompatible "
+           "arguments are rejected. This is intended for use by the Vulkan "
+           "runtime with the kernel bare pointer calling convention, to enable "
+           "dynamic binding of buffers as arguments without static type info."
           >
   ];
 
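For reference, the new option is meant to be used together with the existing bare-pointer kernel option; the test added later in this commit drives the pass like this (copied from its RUN line):

    mlir-opt %s --gpu-to-llvm="use-bare-pointers-for-kernels=1 intersperse-sizes-for-kernels=1" -split-input-file | FileCheck %s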
mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp

Lines changed: 56 additions & 33 deletions
@@ -428,18 +428,18 @@ class LegalizeLaunchFuncOpPattern
 public:
   LegalizeLaunchFuncOpPattern(const LLVMTypeConverter &typeConverter,
                               bool kernelBarePtrCallConv,
-                              bool typeCheckKernelArgs)
+                              bool kernelIntersperseSizeCallConv)
       : ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp>(typeConverter),
         kernelBarePtrCallConv(kernelBarePtrCallConv),
-        typeCheckKernelArgs(typeCheckKernelArgs) {}
+        kernelIntersperseSizeCallConv(kernelIntersperseSizeCallConv) {}
 
 private:
   LogicalResult
   matchAndRewrite(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override;
 
   bool kernelBarePtrCallConv;
-  bool typeCheckKernelArgs;
+  bool kernelIntersperseSizeCallConv;
 };
 
 /// A rewrite pattern to convert gpu.memcpy operations into a GPU runtime
@@ -566,8 +566,9 @@ void GpuToLLVMConversionPass::runOnOperation() {
   populateFinalizeMemRefToLLVMConversionPatterns(converter, patterns);
   populateAsyncStructuralTypeConversionsAndLegality(converter, patterns,
                                                     target);
-  populateGpuToLLVMConversionPatterns(
-      converter, patterns, kernelBarePtrCallConv, typeCheckKernelArgs);
+  populateGpuToLLVMConversionPatterns(converter, patterns,
+                                      kernelBarePtrCallConv,
+                                      kernelIntersperseSizeCallConv);
 
   if (failed(
           applyPartialConversion(getOperation(), target, std::move(patterns))))
@@ -970,33 +971,56 @@ LogicalResult LegalizeLaunchFuncOpPattern::matchAndRewrite(
   else if (launchOp.getAsyncToken())
     stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult();
 
-  if (typeCheckKernelArgs) {
-    // The current non-bare-pointer ABI is a bad fit for `mgpuLaunchKernel`,
-    // which takes an untyped list of arguments. The type check here prevents
-    // accidentally violating the assumption made in vulkan-runtime-wrappers.cpp
-    // and creating a unchecked runtime ABI mismatch.
-    // TODO(https://github.com/llvm/llvm-project/issues/73457): Change the ABI
-    // here to remove the need for this type check.
-    for (Value arg : launchOp.getKernelOperands()) {
-      if (auto memrefTy = dyn_cast<MemRefType>(arg.getType())) {
-        if (memrefTy.getRank() != 1 ||
-            memrefTy.getElementTypeBitWidth() != 32) {
-          return rewriter.notifyMatchFailure(
-              launchOp, "Operand to launch op is not a rank-1 memref with "
-                        "32-bit element type.");
-        }
-      } else {
+  // Lower the kernel operands to match kernel parameters.
+  // Note: If `useBarePtrCallConv` is set in the type converter's options,
+  // the value of `kernelBarePtrCallConv` will be ignored.
+  OperandRange origArguments = launchOp.getKernelOperands();
+  SmallVector<Value, 4> llvmArguments = getTypeConverter()->promoteOperands(
+      loc, origArguments, adaptor.getKernelOperands(), rewriter,
+      /*useBarePtrCallConv=*/kernelBarePtrCallConv);
+
+  // Intersperse size information if requested.
+  if (kernelIntersperseSizeCallConv) {
+    if (origArguments.size() != llvmArguments.size()) {
+      // This shouldn't happen if the bare-pointer calling convention is used.
+      return rewriter.notifyMatchFailure(
+          launchOp,
+          "Cannot add sizes to arguments with one-to-many LLVM IR expansion.");
+    }
+
+    SmallVector<Value, 8> llvmArgumentsWithSizes;
+    llvmArgumentsWithSizes.reserve(llvmArguments.size() * 2);
+    for (auto [llvmArg, origArg] : zip_equal(llvmArguments, origArguments)) {
+      auto memrefTy = dyn_cast<MemRefType>(origArg.getType());
+      if (!memrefTy) {
         return rewriter.notifyMatchFailure(
             launchOp, "Operand to launch op is not a memref.");
       }
+
+      if (!memrefTy.hasStaticShape() ||
+          !memrefTy.getElementType().isIntOrFloat()) {
+        return rewriter.notifyMatchFailure(
+            launchOp, "Operand to launch op is not a memref with a static "
+                      "shape and an integer or float element type.");
+      }
+
+      unsigned bitwidth = memrefTy.getElementTypeBitWidth();
+      if (bitwidth % 8 != 0) {
+        return rewriter.notifyMatchFailure(
+            launchOp, "Operand to launch op is not a memref with a "
+                      "byte-aligned element type.");
+      }
+
+      uint64_t staticSize =
+          uint64_t(bitwidth / 8) * uint64_t(memrefTy.getNumElements());
+
+      Value sizeArg = rewriter.create<LLVM::ConstantOp>(
+          loc, getIndexType(), rewriter.getIndexAttr(staticSize));
+      llvmArgumentsWithSizes.push_back(llvmArg); // presumably bare pointer
+      llvmArgumentsWithSizes.push_back(sizeArg);
     }
+    llvmArguments = std::move(llvmArgumentsWithSizes);
   }
-  // Lower the kernel operands to match kernel parameters.
-  // Note: If `useBarePtrCallConv` is set in the type converter's options,
-  // the value of `kernelBarePtrCallConv` will be ignored.
-  SmallVector<Value, 4> arguments = getTypeConverter()->promoteOperands(
-      loc, launchOp.getKernelOperands(), adaptor.getKernelOperands(), rewriter,
-      /*useBarePtrCallConv=*/kernelBarePtrCallConv);
 
   std::optional<gpu::KernelDim3> clusterSize = std::nullopt;
   if (launchOp.hasClusterSize()) {
@@ -1010,7 +1034,7 @@ LogicalResult LegalizeLaunchFuncOpPattern::matchAndRewrite(
                       adaptor.getGridSizeZ()},
       gpu::KernelDim3{adaptor.getBlockSizeX(), adaptor.getBlockSizeY(),
                       adaptor.getBlockSizeZ()},
-      adaptor.getDynamicSharedMemorySize(), arguments, stream, clusterSize);
+      adaptor.getDynamicSharedMemorySize(), llvmArguments, stream, clusterSize);
   if (launchOp.getAsyncToken())
     rewriter.replaceOp(launchOp, {stream});
   else
@@ -1760,10 +1784,9 @@ LogicalResult ConvertCreateBsrOpToGpuRuntimeCallPattern::matchAndRewrite(
   return success();
 }
 
-void mlir::populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter,
-                                               RewritePatternSet &patterns,
-                                               bool kernelBarePtrCallConv,
-                                               bool typeCheckKernelArgs) {
+void mlir::populateGpuToLLVMConversionPatterns(
+    LLVMTypeConverter &converter, RewritePatternSet &patterns,
+    bool kernelBarePtrCallConv, bool kernelIntersperseSizeCallConv) {
   addOpaquePointerConversion<gpu::AsyncTokenType>(converter);
   addOpaquePointerConversion<gpu::SparseDnTensorHandleType>(converter);
   addOpaquePointerConversion<gpu::SparseSpMatHandleType>(converter);
@@ -1801,7 +1824,7 @@ void mlir::populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter,
                ConvertSpMatGetSizeOpToGpuRuntimeCallPattern,
                ConvertSetCsrPointersOpToGpuRuntimeCallPattern>(converter);
   patterns.add<LegalizeLaunchFuncOpPattern>(converter, kernelBarePtrCallConv,
-                                            typeCheckKernelArgs);
+                                            kernelIntersperseSizeCallConv);
 }
 
 //===----------------------------------------------------------------------===//
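To make the staticSize arithmetic above concrete, here is a small self-contained C++ sketch (a hypothetical helper, not part of the pass) reproducing bytes = (elementBitWidth / 8) * numElements for the memref types exercised by the new test below:

    #include <cassert>
    #include <cstdint>

    // Hypothetical helper mirroring the staticSize expression in
    // LegalizeLaunchFuncOpPattern::matchAndRewrite.
    static uint64_t memrefStaticSizeInBytes(uint64_t elementBitWidth,
                                            uint64_t numElements) {
      assert(elementBitWidth % 8 == 0 && "element type must be byte-aligned");
      return (elementBitWidth / 8) * numElements;
    }

    int main() {
      assert(memrefStaticSizeInBytes(32, 8) == 32);   // memref<8xf32>
      assert(memrefStaticSizeInBytes(32, 64) == 256); // memref<8x8xi32>
      assert(memrefStaticSizeInBytes(8, 48) == 48);   // memref<4x12xi8>
      return 0;
    }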
Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
+// RUN: mlir-opt %s --gpu-to-llvm="use-bare-pointers-for-kernels=1 intersperse-sizes-for-kernels=1" -split-input-file | FileCheck %s
+
+module attributes {gpu.container_module, spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>} {
+  llvm.func @malloc(i64) -> !llvm.ptr
+  gpu.binary @kernels [#gpu.object<#spirv.target_env<#spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>, "">]
+  func.func @main() attributes {llvm.emit_c_interface} {
+    // CHECK: [[RANK1UMD:%.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    %rank1UndefMemrefDescriptor = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    // CHECK: [[RANK2UMD:%.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+    %rank2UndefMemrefDescriptor = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+    %c1 = arith.constant 1 : index
+    // CHECK: [[PTR1:%.*]] = llvm.extractvalue [[RANK1UMD]][1]
+    // CHECK: [[PTR2:%.*]] = llvm.extractvalue [[RANK2UMD]][1]
+    // CHECK: [[PTR3:%.*]] = llvm.extractvalue [[RANK2UMD]][1]
+    // CHECK: [[SIZE1:%.*]] = llvm.mlir.constant(32 : index) : i64
+    // CHECK: [[SIZE2:%.*]] = llvm.mlir.constant(256 : index) : i64
+    // CHECK: [[SIZE3:%.*]] = llvm.mlir.constant(48 : index) : i64
+    %6 = builtin.unrealized_conversion_cast %rank1UndefMemrefDescriptor : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> to memref<8xf32>
+    %10 = builtin.unrealized_conversion_cast %rank2UndefMemrefDescriptor : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> to memref<8x8xi32>
+    %14 = builtin.unrealized_conversion_cast %rank2UndefMemrefDescriptor : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> to memref<4x12xi8>
+    // CHECK: gpu.launch_func @kernels::@kernel_add blocks in ({{.*}}) threads in ({{.*}}) : i64 args([[PTR1]] : !llvm.ptr, [[SIZE1]] : i64, [[PTR2]] : !llvm.ptr, [[SIZE2]] : i64, [[PTR3]] : !llvm.ptr, [[SIZE3]] : i64)
+    gpu.launch_func @kernels::@kernel_add blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%6 : memref<8xf32>, %10 : memref<8x8xi32>, %14 : memref<4x12xi8>)
+    return
+  }
+}

mlir/test/lib/Pass/TestVulkanRunnerPipeline.cpp

Lines changed: 3 additions & 4 deletions
@@ -68,12 +68,11 @@ void buildTestVulkanRunnerPipeline(OpPassManager &passManager,
     passManager.addPass(createFinalizeMemRefToLLVMConversionPass());
     passManager.nest<func::FuncOp>().addPass(
         LLVM::createRequestCWrappersPass());
-    // vulkan-runtime-wrappers.cpp uses the non-bare-pointer calling convention,
-    // and the type check is needed to prevent accidental ABI mismatches.
+    // These calling convention options match vulkan-runtime-wrappers.cpp
    GpuToLLVMConversionPassOptions opt;
    opt.hostBarePtrCallConv = false;
-    opt.kernelBarePtrCallConv = false;
-    opt.typeCheckKernelArgs = true;
+    opt.kernelBarePtrCallConv = true;
+    opt.kernelIntersperseSizeCallConv = true;
    passManager.addPass(createGpuToLLVMConversionPass(opt));
   }
 }

mlir/test/mlir-vulkan-runner/addi.mlir

Lines changed: 2 additions & 2 deletions
@@ -1,5 +1,5 @@
-// RUN: mlir-opt %s -test-vulkan-runner-pipeline \
-// RUN:   | mlir-vulkan-runner - --shared-libs=%vulkan-runtime-wrappers,%mlir_runner_utils --entry-point-result=void | FileCheck %s
+// RUN: mlir-opt %s -test-vulkan-runner-pipeline=to-llvm \
+// RUN:   | mlir-cpu-runner - --shared-libs=%vulkan-runtime-wrappers,%mlir_runner_utils --entry-point-result=void | FileCheck %s
 
 // CHECK-COUNT-64: [3, 3, 3, 3, 3, 3, 3, 3]
 module attributes {

mlir/test/mlir-vulkan-runner/addi8.mlir

Lines changed: 2 additions & 2 deletions
@@ -1,5 +1,5 @@
-// RUN: mlir-opt %s -test-vulkan-runner-pipeline \
-// RUN:   | mlir-vulkan-runner - --shared-libs=%vulkan-runtime-wrappers,%mlir_runner_utils --entry-point-result=void | FileCheck %s
+// RUN: mlir-opt %s -test-vulkan-runner-pipeline=to-llvm \
+// RUN:   | mlir-cpu-runner - --shared-libs=%vulkan-runtime-wrappers,%mlir_runner_utils --entry-point-result=void | FileCheck %s
 
 // CHECK-COUNT-64: [3, 3, 3, 3, 3, 3, 3, 3]
 module attributes {

mlir/test/mlir-vulkan-runner/mulf.mlir

Lines changed: 2 additions & 2 deletions
@@ -1,5 +1,5 @@
-// RUN: mlir-opt %s -test-vulkan-runner-pipeline \
-// RUN:   | mlir-vulkan-runner - --shared-libs=%vulkan-runtime-wrappers,%mlir_runner_utils --entry-point-result=void | FileCheck %s
+// RUN: mlir-opt %s -test-vulkan-runner-pipeline=to-llvm \
+// RUN:   | mlir-cpu-runner - --shared-libs=%vulkan-runtime-wrappers,%mlir_runner_utils --entry-point-result=void | FileCheck %s
 
 // CHECK-COUNT-4: [6, 6, 6, 6]
 module attributes {

mlir/test/mlir-vulkan-runner/subf.mlir

Lines changed: 2 additions & 2 deletions
@@ -1,5 +1,5 @@
-// RUN: mlir-opt %s -test-vulkan-runner-pipeline \
-// RUN:   | mlir-vulkan-runner - --shared-libs=%vulkan-runtime-wrappers,%mlir_runner_utils --entry-point-result=void | FileCheck %s
+// RUN: mlir-opt %s -test-vulkan-runner-pipeline=to-llvm \
+// RUN:   | mlir-cpu-runner - --shared-libs=%vulkan-runtime-wrappers,%mlir_runner_utils --entry-point-result=void | FileCheck %s
 
 // CHECK-COUNT-32: [2.2, 2.2, 2.2, 2.2]
 module attributes {

mlir/tools/mlir-vulkan-runner/vulkan-runtime-wrappers.cpp

Lines changed: 11 additions & 15 deletions
@@ -169,26 +169,22 @@ mgpuLaunchKernel(void *vkKernel, size_t gridX, size_t gridY, size_t gridZ,
                  void ** /*extra*/, size_t paramsCount) {
   auto manager = static_cast<VulkanRuntimeManager *>(vkRuntimeManager);
 
-  // The non-bare-pointer memref ABI interacts badly with mgpuLaunchKernel's
-  // signature:
-  // - The memref descriptor struct gets split into several elements, each
-  //   passed as their own "param".
-  // - No metadata is provided as to the rank or element type/size of a memref.
-  // Here we assume that all MemRefs have rank 1 and an element size of
-  // 4 bytes. This means each descriptor struct will have five members.
-  // TODO(https://github.com/llvm/llvm-project/issues/73457): Refactor the
-  // ABI/API of mgpuLaunchKernel to use a different ABI for memrefs, so
-  // that other memref types can also be used. This will allow migrating
-  // the remaining tests and removal of mlir-vulkan-runner.
-  const size_t paramsPerMemRef = 5;
+  // GpuToLLVMConversionPass with the kernelBarePtrCallConv and
+  // kernelIntersperseSizeCallConv options will set up the params array like:
+  // { &memref_ptr0, &memref_size0, &memref_ptr1, &memref_size1, ... }
+  const size_t paramsPerMemRef = 2;
   if (paramsCount % paramsPerMemRef != 0) {
-    abort();
+    abort(); // This would indicate a serious calling convention mismatch.
   }
   const DescriptorSetIndex setIndex = 0;
   BindingIndex bindIndex = 0;
   for (size_t i = 0; i < paramsCount; i += paramsPerMemRef) {
-    auto memref = static_cast<MemRefDescriptor<uint32_t, 1> *>(params[i]);
-    bindMemRef<uint32_t, 1>(manager, setIndex, bindIndex, memref);
+    void *memrefBufferBasePtr = *static_cast<void **>(params[i + 0]);
+    size_t memrefBufferSize = *static_cast<size_t *>(params[i + 1]);
+    VulkanHostMemoryBuffer memBuffer{memrefBufferBasePtr,
+                                     static_cast<uint32_t>(memrefBufferSize)};
+    reinterpret_cast<VulkanRuntimeManager *>(vkRuntimeManager)
+        ->setResourceData(setIndex, bindIndex, memBuffer);
     ++bindIndex;
   }
 
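For illustration only (the buffer and its packing below are assumptions, not the actual runner code), a host-side caller would lay out the untyped params array with one {pointer, size} pair per memref, each entry holding the address of the value, which is exactly how the rewritten wrapper dereferences params[i] and params[i + 1]:

    #include <cstddef>
    #include <vector>

    int main() {
      // Buffer backing a hypothetical memref<8xf32> kernel argument.
      std::vector<float> buffer(8, 0.0f);
      void *basePtr = buffer.data();
      size_t sizeInBytes = buffer.size() * sizeof(float); // 32 bytes

      // One {pointer, size} pair per memref; paramsCount must be a multiple of 2.
      void *params[] = {&basePtr, &sizeInBytes};
      size_t paramsCount = sizeof(params) / sizeof(params[0]);
      (void)paramsCount; // would be forwarded to mgpuLaunchKernel along with params
      return 0;
    }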