diff --git a/mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h b/mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h
index cf0c96f0eba00..312f1fc5b20c9 100644
--- a/mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h
+++ b/mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h
@@ -62,10 +62,10 @@ struct FunctionCallBuilder {
 
 /// Collect a set of patterns to convert from the GPU dialect to LLVM and
 /// populate converter for gpu types.
-void populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter,
-                                         RewritePatternSet &patterns,
-                                         bool kernelBarePtrCallConv = false,
-                                         bool typeCheckKernelArgs = false);
+void populateGpuToLLVMConversionPatterns(
+    LLVMTypeConverter &converter, RewritePatternSet &patterns,
+    bool kernelBarePtrCallConv = false,
+    bool kernelIntersperseSizeCallConv = false);
 
 /// A function that maps a MemorySpace enum to a target-specific integer value.
 using MemorySpaceMapping = std::function<unsigned(gpu::AddressSpace)>;
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 0f42ffb3a8026..34f7ab46b6298 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -518,11 +518,13 @@ def GpuToLLVMConversionPass : Pass<"gpu-to-llvm", "ModuleOp"> {
            "Use bare pointers to pass memref arguments to kernels. "
            "The kernel must use the same setting for this option."
            >,
-    Option<"typeCheckKernelArgs", "type-check-kernel-args", "bool",
+    Option<"kernelIntersperseSizeCallConv", "intersperse-sizes-for-kernels", "bool",
            /*default=*/"false",
-           "Require all kernel arguments to be memrefs of rank 1 and with a "
-           "32-bit element size. This is a temporary option that will be "
-           "removed; TODO(https://github.com/llvm/llvm-project/issues/73457)."
+           "Inserts a size_t argument following each memref argument, "
+           "containing the static size in bytes of the buffer. Incompatible "
+           "arguments are rejected. This is intended for use by the Vulkan "
+           "runtime with the kernel bare pointer calling convention, to enable "
+           "dynamic binding of buffers as arguments without static type info."
            >
   ];
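
Note: the replacement option computes, for every statically shaped memref kernel argument, an interspersed size of (element bit-width / 8) * number of elements. A minimal standalone sketch of that computation (plain C++ with an illustrative helper name, independent of MLIR's actual types); the three checks reproduce the sizes expected by the new lit test further down:

    #include <cassert>
    #include <cstdint>
    #include <vector>

    // Illustrative helper mirroring the pass's size computation: the static
    // buffer size in bytes of a memref with the given static shape and a
    // byte-aligned element bit-width.
    uint64_t staticBufferSizeInBytes(const std::vector<int64_t> &shape,
                                     unsigned elementBitWidth) {
      assert(elementBitWidth % 8 == 0 && "element type must be byte-aligned");
      uint64_t numElements = 1;
      for (int64_t dim : shape)
        numElements *= static_cast<uint64_t>(dim);
      return (elementBitWidth / 8) * numElements;
    }

    int main() {
      assert(staticBufferSizeInBytes({8}, 32) == 32);     // memref<8xf32>
      assert(staticBufferSizeInBytes({8, 8}, 32) == 256); // memref<8x8xi32>
      assert(staticBufferSizeInBytes({4, 12}, 8) == 48);  // memref<4x12xi8>
      return 0;
    }
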
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index ca9883a79dc16..8017eb6bb383b 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -428,10 +428,10 @@ class LegalizeLaunchFuncOpPattern
 public:
   LegalizeLaunchFuncOpPattern(const LLVMTypeConverter &typeConverter,
                               bool kernelBarePtrCallConv,
-                              bool typeCheckKernelArgs)
+                              bool kernelIntersperseSizeCallConv)
       : ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp>(typeConverter),
         kernelBarePtrCallConv(kernelBarePtrCallConv),
-        typeCheckKernelArgs(typeCheckKernelArgs) {}
+        kernelIntersperseSizeCallConv(kernelIntersperseSizeCallConv) {}
 
 private:
   LogicalResult
@@ -439,7 +439,7 @@ class LegalizeLaunchFuncOpPattern
                   ConversionPatternRewriter &rewriter) const override;
 
   bool kernelBarePtrCallConv;
-  bool typeCheckKernelArgs;
+  bool kernelIntersperseSizeCallConv;
 };
 
 /// A rewrite pattern to convert gpu.memcpy operations into a GPU runtime
@@ -566,8 +566,9 @@ void GpuToLLVMConversionPass::runOnOperation() {
   populateFinalizeMemRefToLLVMConversionPatterns(converter, patterns);
   populateAsyncStructuralTypeConversionsAndLegality(converter, patterns,
                                                     target);
-  populateGpuToLLVMConversionPatterns(
-      converter, patterns, kernelBarePtrCallConv, typeCheckKernelArgs);
+  populateGpuToLLVMConversionPatterns(converter, patterns,
+                                      kernelBarePtrCallConv,
+                                      kernelIntersperseSizeCallConv);
 
   if (failed(
           applyPartialConversion(getOperation(), target, std::move(patterns))))
@@ -970,33 +971,55 @@ LogicalResult LegalizeLaunchFuncOpPattern::matchAndRewrite(
   else if (launchOp.getAsyncToken())
     stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult();
 
-  if (typeCheckKernelArgs) {
-    // The current non-bare-pointer ABI is a bad fit for `mgpuLaunchKernel`,
-    // which takes an untyped list of arguments. The type check here prevents
-    // accidentally violating the assumption made in vulkan-runtime-wrappers.cpp
-    // and creating a unchecked runtime ABI mismatch.
-    // TODO(https://github.com/llvm/llvm-project/issues/73457): Change the ABI
-    // here to remove the need for this type check.
-    for (Value arg : launchOp.getKernelOperands()) {
-      if (auto memrefTy = dyn_cast<MemRefType>(arg.getType())) {
-        if (memrefTy.getRank() != 1 ||
-            memrefTy.getElementTypeBitWidth() != 32) {
-          return rewriter.notifyMatchFailure(
-              launchOp, "Operand to launch op is not a rank-1 memref with "
-                        "32-bit element type.");
-        }
-      } else {
+  // Lower the kernel operands to match kernel parameters.
+  // Note: If `useBarePtrCallConv` is set in the type converter's options,
+  // the value of `kernelBarePtrCallConv` will be ignored.
+  OperandRange origArguments = launchOp.getKernelOperands();
+  SmallVector<Value, 4> llvmArguments = getTypeConverter()->promoteOperands(
+      loc, origArguments, adaptor.getKernelOperands(), rewriter,
+      /*useBarePtrCallConv=*/kernelBarePtrCallConv);
+  SmallVector<Value, 4> llvmArgumentsWithSizes;
+
+  // Intersperse size information if requested.
+  if (kernelIntersperseSizeCallConv) {
+    if (origArguments.size() != llvmArguments.size()) {
+      // This shouldn't happen if the bare-pointer calling convention is used.
+      return rewriter.notifyMatchFailure(
+          launchOp,
+          "Cannot add sizes to arguments with one-to-many LLVM IR expansion.");
+    }
+
+    llvmArgumentsWithSizes.reserve(llvmArguments.size() * 2);
+    for (auto [llvmArg, origArg] : zip_equal(llvmArguments, origArguments)) {
+      auto memrefTy = dyn_cast<MemRefType>(origArg.getType());
+      if (!memrefTy) {
         return rewriter.notifyMatchFailure(
             launchOp, "Operand to launch op is not a memref.");
       }
+
+      if (!memrefTy.hasStaticShape() ||
+          !memrefTy.getElementType().isIntOrFloat()) {
+        return rewriter.notifyMatchFailure(
+            launchOp, "Operand to launch op is not a memref with a static "
+                      "shape and an integer or float element type.");
+      }
+
+      unsigned bitwidth = memrefTy.getElementTypeBitWidth();
+      if (bitwidth % 8 != 0) {
+        return rewriter.notifyMatchFailure(
+            launchOp, "Operand to launch op is not a memref with a "
+                      "byte-aligned element type.");
+      }
+
+      uint64_t staticSize = static_cast<uint64_t>(bitwidth / 8) *
+                            static_cast<uint64_t>(memrefTy.getNumElements());
+
+      Value sizeArg = rewriter.create<LLVM::ConstantOp>(
+          loc, getIndexType(), rewriter.getIndexAttr(staticSize));
+      llvmArgumentsWithSizes.push_back(llvmArg); // Presumably a bare pointer.
+      llvmArgumentsWithSizes.push_back(sizeArg);
     }
   }
-  // Lower the kernel operands to match kernel parameters.
-  // Note: If `useBarePtrCallConv` is set in the type converter's options,
-  // the value of `kernelBarePtrCallConv` will be ignored.
-  SmallVector<Value, 4> arguments = getTypeConverter()->promoteOperands(
-      loc, launchOp.getKernelOperands(), adaptor.getKernelOperands(), rewriter,
-      /*useBarePtrCallConv=*/kernelBarePtrCallConv);
 
   std::optional<gpu::KernelDim3> clusterSize = std::nullopt;
   if (launchOp.hasClusterSize()) {
@@ -1010,7 +1033,9 @@ LogicalResult LegalizeLaunchFuncOpPattern::matchAndRewrite(
                           adaptor.getGridSizeZ()},
       gpu::KernelDim3{adaptor.getBlockSizeX(), adaptor.getBlockSizeY(),
                       adaptor.getBlockSizeZ()},
-      adaptor.getDynamicSharedMemorySize(), arguments, stream, clusterSize);
+      adaptor.getDynamicSharedMemorySize(),
+      llvmArgumentsWithSizes.empty() ? llvmArguments : llvmArgumentsWithSizes,
+      stream, clusterSize);
   if (launchOp.getAsyncToken())
     rewriter.replaceOp(launchOp, {stream});
   else
@@ -1760,10 +1785,9 @@ LogicalResult ConvertCreateBsrOpToGpuRuntimeCallPattern::matchAndRewrite(
   return success();
 }
 
-void mlir::populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter,
-                                               RewritePatternSet &patterns,
-                                               bool kernelBarePtrCallConv,
-                                               bool typeCheckKernelArgs) {
+void mlir::populateGpuToLLVMConversionPatterns(
+    LLVMTypeConverter &converter, RewritePatternSet &patterns,
+    bool kernelBarePtrCallConv, bool kernelIntersperseSizeCallConv) {
   addOpaquePointerConversion<gpu::AsyncTokenType>(converter);
   addOpaquePointerConversion<gpu::SparseDnTensorHandleType>(converter);
   addOpaquePointerConversion<gpu::SparseSpMatHandleType>(converter);
@@ -1801,7 +1825,7 @@ void mlir::populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter,
                ConvertSpMatGetSizeOpToGpuRuntimeCallPattern,
                ConvertSetCsrPointersOpToGpuRuntimeCallPattern>(converter);
   patterns.add<LegalizeLaunchFuncOpPattern>(converter, kernelBarePtrCallConv,
-                                            typeCheckKernelArgs);
+                                            kernelIntersperseSizeCallConv);
 }
 
 //===----------------------------------------------------------------------===//
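
One subtlety in the pattern above is the origArguments.size() != llvmArguments.size() guard: without the bare-pointer convention, promoteOperands expands each rank-N memref descriptor into 3 + 2N scalar values (allocated pointer, aligned pointer, offset, N sizes, N strides), so there is no single lowered value to pair a size with. A small sketch of that arithmetic (plain C++; the helper name is illustrative):

    #include <cassert>
    #include <cstddef>

    // Lowered values produced for one rank-`rank` memref argument under each
    // calling convention: just the aligned data pointer with bare pointers,
    // otherwise the unpacked descriptor fields.
    constexpr size_t loweredValuesPerMemRef(size_t rank, bool barePtrCallConv) {
      return barePtrCallConv ? 1 : 3 + 2 * rank;
    }

    int main() {
      assert(loweredValuesPerMemRef(1, /*barePtrCallConv=*/true) == 1);
      assert(loweredValuesPerMemRef(1, /*barePtrCallConv=*/false) == 5);
      assert(loweredValuesPerMemRef(2, /*barePtrCallConv=*/false) == 7);
      return 0;
    }

The rank-1, non-bare-pointer value of 5 is exactly the paramsPerMemRef = 5 assumption that the old vulkan-runtime-wrappers.cpp code relied on (see the last file in this patch).
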
diff --git a/mlir/test/Conversion/GPUCommon/lower-launch-func-bare-ptr-intersperse-size.mlir b/mlir/test/Conversion/GPUCommon/lower-launch-func-bare-ptr-intersperse-size.mlir
new file mode 100644
index 0000000000000..171b13da22713
--- /dev/null
+++ b/mlir/test/Conversion/GPUCommon/lower-launch-func-bare-ptr-intersperse-size.mlir
@@ -0,0 +1,25 @@
+// RUN: mlir-opt %s --gpu-to-llvm="use-bare-pointers-for-kernels=1 intersperse-sizes-for-kernels=1" -split-input-file | FileCheck %s
+
+module attributes {gpu.container_module, spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>} {
+  llvm.func @malloc(i64) -> !llvm.ptr
+  gpu.binary @kernels [#gpu.object<#spirv.target_env<#spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>, "">]
+  func.func @main() attributes {llvm.emit_c_interface} {
+    // CHECK: [[RANK1UMD:%.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    %rank1UndefMemrefDescriptor = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    // CHECK: [[RANK2UMD:%.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+    %rank2UndefMemrefDescriptor = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+    %c1 = arith.constant 1 : index
+    // CHECK: [[PTR1:%.*]] = llvm.extractvalue [[RANK1UMD]][1]
+    // CHECK: [[PTR2:%.*]] = llvm.extractvalue [[RANK2UMD]][1]
+    // CHECK: [[PTR3:%.*]] = llvm.extractvalue [[RANK2UMD]][1]
+    // CHECK: [[SIZE1:%.*]] = llvm.mlir.constant(32 : index) : i64
+    // CHECK: [[SIZE2:%.*]] = llvm.mlir.constant(256 : index) : i64
+    // CHECK: [[SIZE3:%.*]] = llvm.mlir.constant(48 : index) : i64
+    %6 = builtin.unrealized_conversion_cast %rank1UndefMemrefDescriptor : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> to memref<8xf32>
+    %10 = builtin.unrealized_conversion_cast %rank2UndefMemrefDescriptor : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> to memref<8x8xi32>
+    %14 = builtin.unrealized_conversion_cast %rank2UndefMemrefDescriptor : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> to memref<4x12xi8>
+    // CHECK: gpu.launch_func @kernels::@kernel_add blocks in ({{.*}}) threads in ({{.*}}) : i64 args([[PTR1]] : !llvm.ptr, [[SIZE1]] : i64, [[PTR2]] : !llvm.ptr, [[SIZE2]] : i64, [[PTR3]] : !llvm.ptr, [[SIZE3]] : i64)
+    gpu.launch_func @kernels::@kernel_add blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%6 : memref<8xf32>, %10 : memref<8x8xi32>, %14 : memref<4x12xi8>)
+    return
+  }
+}
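
The CHECK lines expect llvm.extractvalue [...][1] because the bare-pointer convention forwards field 1 of the memref descriptor, the aligned data pointer. A sketch of the corresponding layout (plain C++ mirroring the test's !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>; the static_assert assumes an LP64 target):

    #include <cstdint>

    template <typename T, int Rank>
    struct MemRefDescriptor {
      T *allocated;          // field 0: allocated (base) pointer
      T *aligned;            // field 1: aligned data pointer, the value the
                             // bare-pointer convention hands to the kernel
      int64_t offset;        // field 2
      int64_t sizes[Rank];   // field 3
      int64_t strides[Rank]; // field 4
    };

    static_assert(sizeof(MemRefDescriptor<float, 1>) ==
                      2 * sizeof(void *) + 3 * sizeof(int64_t),
                  "rank-1 descriptor has five 8-byte fields on LP64");
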
diff --git a/mlir/test/lib/Pass/TestVulkanRunnerPipeline.cpp b/mlir/test/lib/Pass/TestVulkanRunnerPipeline.cpp
index a3624eb31e26e..fa6a5a7140d8c 100644
--- a/mlir/test/lib/Pass/TestVulkanRunnerPipeline.cpp
+++ b/mlir/test/lib/Pass/TestVulkanRunnerPipeline.cpp
@@ -68,12 +68,11 @@ void buildTestVulkanRunnerPipeline(OpPassManager &passManager,
     passManager.addPass(createFinalizeMemRefToLLVMConversionPass());
     passManager.nest<func::FuncOp>().addPass(
         LLVM::createRequestCWrappersPass());
-    // vulkan-runtime-wrappers.cpp uses the non-bare-pointer calling convention,
-    // and the type check is needed to prevent accidental ABI mismatches.
+    // vulkan-runtime-wrappers.cpp requires these calling convention options.
     GpuToLLVMConversionPassOptions opt;
     opt.hostBarePtrCallConv = false;
-    opt.kernelBarePtrCallConv = false;
-    opt.typeCheckKernelArgs = true;
+    opt.kernelBarePtrCallConv = true;
+    opt.kernelIntersperseSizeCallConv = true;
     passManager.addPass(createGpuToLLVMConversionPass(opt));
   }
 }
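
A downstream pipeline can enable the same convention programmatically. A rough sketch (the aggregate header path and the generated GpuToLLVMConversionPassOptions spelling are assumptions based on this patch; the command-line equivalent, confirmed by the new lit test, is --gpu-to-llvm="use-bare-pointers-for-kernels=1 intersperse-sizes-for-kernels=1"):

    // Header paths assumed; the options struct and factory function are
    // tablegen-generated from the Passes.td entry shown above.
    #include "mlir/Conversion/Passes.h"
    #include "mlir/Pass/PassManager.h"

    void addVulkanStyleGpuToLLVM(mlir::OpPassManager &passManager) {
      mlir::GpuToLLVMConversionPassOptions opt;
      opt.hostBarePtrCallConv = false;          // host side keeps descriptors
      opt.kernelBarePtrCallConv = true;         // kernels get bare pointers...
      opt.kernelIntersperseSizeCallConv = true; // ...each followed by its size
      passManager.addPass(mlir::createGpuToLLVMConversionPass(opt));
    }
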
diff --git a/mlir/test/mlir-vulkan-runner/addi.mlir b/mlir/test/mlir-vulkan-runner/addi.mlir
index 7e212a4fb179c..e60eb7696770a 100644
--- a/mlir/test/mlir-vulkan-runner/addi.mlir
+++ b/mlir/test/mlir-vulkan-runner/addi.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt %s -test-vulkan-runner-pipeline \
-// RUN:   | mlir-vulkan-runner - --shared-libs=%vulkan-runtime-wrappers,%mlir_runner_utils --entry-point-result=void | FileCheck %s
+// RUN: mlir-opt %s -test-vulkan-runner-pipeline=to-llvm \
+// RUN:   | mlir-cpu-runner - --shared-libs=%vulkan-runtime-wrappers,%mlir_runner_utils --entry-point-result=void | FileCheck %s
 
 // CHECK-COUNT-64: [3, 3, 3, 3, 3, 3, 3, 3]
 module attributes {
diff --git a/mlir/test/mlir-vulkan-runner/addi8.mlir b/mlir/test/mlir-vulkan-runner/addi8.mlir
index e0b1a8e8875c0..6374f29fd0ed2 100644
--- a/mlir/test/mlir-vulkan-runner/addi8.mlir
+++ b/mlir/test/mlir-vulkan-runner/addi8.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt %s -test-vulkan-runner-pipeline \
-// RUN:   | mlir-vulkan-runner - --shared-libs=%vulkan-runtime-wrappers,%mlir_runner_utils --entry-point-result=void | FileCheck %s
+// RUN: mlir-opt %s -test-vulkan-runner-pipeline=to-llvm \
+// RUN:   | mlir-cpu-runner - --shared-libs=%vulkan-runtime-wrappers,%mlir_runner_utils --entry-point-result=void | FileCheck %s
 
 // CHECK-COUNT-64: [3, 3, 3, 3, 3, 3, 3, 3]
 module attributes {
diff --git a/mlir/test/mlir-vulkan-runner/mulf.mlir b/mlir/test/mlir-vulkan-runner/mulf.mlir
index 22fa034a9b455..bd1b8d9abf4c9 100644
--- a/mlir/test/mlir-vulkan-runner/mulf.mlir
+++ b/mlir/test/mlir-vulkan-runner/mulf.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt %s -test-vulkan-runner-pipeline \
-// RUN:   | mlir-vulkan-runner - --shared-libs=%vulkan-runtime-wrappers,%mlir_runner_utils --entry-point-result=void | FileCheck %s
+// RUN: mlir-opt %s -test-vulkan-runner-pipeline=to-llvm \
+// RUN:   | mlir-cpu-runner - --shared-libs=%vulkan-runtime-wrappers,%mlir_runner_utils --entry-point-result=void | FileCheck %s
 
 // CHECK-COUNT-4: [6, 6, 6, 6]
 module attributes {
diff --git a/mlir/test/mlir-vulkan-runner/subf.mlir b/mlir/test/mlir-vulkan-runner/subf.mlir
index 23496ef3abc00..e8cee0e021a27 100644
--- a/mlir/test/mlir-vulkan-runner/subf.mlir
+++ b/mlir/test/mlir-vulkan-runner/subf.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt %s -test-vulkan-runner-pipeline \
-// RUN:   | mlir-vulkan-runner - --shared-libs=%vulkan-runtime-wrappers,%mlir_runner_utils --entry-point-result=void | FileCheck %s
+// RUN: mlir-opt %s -test-vulkan-runner-pipeline=to-llvm \
+// RUN:   | mlir-cpu-runner - --shared-libs=%vulkan-runtime-wrappers,%mlir_runner_utils --entry-point-result=void | FileCheck %s
 
 // CHECK-COUNT-32: [2.2, 2.2, 2.2, 2.2]
 module attributes {
diff --git a/mlir/tools/mlir-vulkan-runner/vulkan-runtime-wrappers.cpp b/mlir/tools/mlir-vulkan-runner/vulkan-runtime-wrappers.cpp
index ffd1114cec6aa..8d1bac3b6f286 100644
--- a/mlir/tools/mlir-vulkan-runner/vulkan-runtime-wrappers.cpp
+++ b/mlir/tools/mlir-vulkan-runner/vulkan-runtime-wrappers.cpp
@@ -169,26 +169,21 @@ mgpuLaunchKernel(void *vkKernel, size_t gridX, size_t gridY, size_t gridZ,
                  void ** /*extra*/, size_t paramsCount) {
   auto manager = static_cast<VulkanRuntimeManager *>(vkRuntimeManager);
 
-  // The non-bare-pointer memref ABI interacts badly with mgpuLaunchKernel's
-  // signature:
-  // - The memref descriptor struct gets split into several elements, each
-  //   passed as their own "param".
-  // - No metadata is provided as to the rank or element type/size of a memref.
-  // Here we assume that all MemRefs have rank 1 and an element size of
-  // 4 bytes. This means each descriptor struct will have five members.
-  // TODO(https://github.com/llvm/llvm-project/issues/73457): Refactor the
-  // ABI/API of mgpuLaunchKernel to use a different ABI for memrefs, so
-  // that other memref types can also be used. This will allow migrating
-  // the remaining tests and removal of mlir-vulkan-runner.
-  const size_t paramsPerMemRef = 5;
+  // GpuToLLVMConversionPass with the kernelBarePtrCallConv and
+  // kernelIntersperseSizeCallConv options will set up the params array like:
+  // { &memref_ptr0, &memref_size0, &memref_ptr1, &memref_size1, ... }
+  const size_t paramsPerMemRef = 2;
   if (paramsCount % paramsPerMemRef != 0) {
-    abort();
+    abort(); // This would indicate a serious calling convention mismatch.
   }
   const DescriptorSetIndex setIndex = 0;
   BindingIndex bindIndex = 0;
   for (size_t i = 0; i < paramsCount; i += paramsPerMemRef) {
-    auto memref = static_cast<MemRefDescriptor<uint32_t, 1> *>(params[i]);
-    bindMemRef<uint32_t, 1>(manager, setIndex, bindIndex, memref);
+    void *memrefBufferBasePtr = *static_cast<void **>(params[i + 0]);
+    size_t memrefBufferSize = *static_cast<size_t *>(params[i + 1]);
+    VulkanHostMemoryBuffer memBuffer{memrefBufferBasePtr,
+                                     static_cast<uint32_t>(memrefBufferSize)};
+    manager->setResourceData(setIndex, bindIndex, memBuffer);
     ++bindIndex;
   }
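
The rewritten loop can be exercised standalone: for each memref, the lowered launch passes the address of the bare pointer followed by the address of its size_t byte size, so every params[i] points at one of those two slots. A minimal host-side sketch of that ABI (plain C++; buffer names are illustrative):

    #include <cassert>
    #include <cstddef>
    #include <cstdint>

    int main() {
      float buf0[8];      // stands in for a memref<8xf32> buffer
      int32_t buf1[8][8]; // stands in for a memref<8x8xi32> buffer
      void *ptr0 = buf0;
      size_t size0 = sizeof(buf0); // 32 bytes
      void *ptr1 = buf1;
      size_t size1 = sizeof(buf1); // 256 bytes

      // Layout produced by the interspersed-size convention:
      // { &ptr0, &size0, &ptr1, &size1 }
      void *params[] = {&ptr0, &size0, &ptr1, &size1};
      const size_t paramsCount = sizeof(params) / sizeof(params[0]);

      const size_t paramsPerMemRef = 2;
      assert(paramsCount % paramsPerMemRef == 0);
      for (size_t i = 0; i < paramsCount; i += paramsPerMemRef) {
        void *base = *static_cast<void **>(params[i + 0]);
        size_t size = *static_cast<size_t *>(params[i + 1]);
        assert(base != nullptr);
        assert(size == (i == 0 ? 32u : 256u));
      }
      return 0;
    }
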