-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[mlir][spirv] Add GpuToLLVM cconv suited to Vulkan, migrate last tests #123384
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||
---|---|---|---|---|---|---|---|---|---|---|
|
@@ -428,18 +428,18 @@ class LegalizeLaunchFuncOpPattern | |||||||||
public: | ||||||||||
LegalizeLaunchFuncOpPattern(const LLVMTypeConverter &typeConverter, | ||||||||||
bool kernelBarePtrCallConv, | ||||||||||
bool typeCheckKernelArgs) | ||||||||||
bool kernelIntersperseSizeCallConv) | ||||||||||
: ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp>(typeConverter), | ||||||||||
kernelBarePtrCallConv(kernelBarePtrCallConv), | ||||||||||
typeCheckKernelArgs(typeCheckKernelArgs) {} | ||||||||||
kernelIntersperseSizeCallConv(kernelIntersperseSizeCallConv) {} | ||||||||||
|
||||||||||
private: | ||||||||||
LogicalResult | ||||||||||
matchAndRewrite(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor, | ||||||||||
ConversionPatternRewriter &rewriter) const override; | ||||||||||
|
||||||||||
bool kernelBarePtrCallConv; | ||||||||||
bool typeCheckKernelArgs; | ||||||||||
bool kernelIntersperseSizeCallConv; | ||||||||||
}; | ||||||||||
|
||||||||||
/// A rewrite pattern to convert gpu.memcpy operations into a GPU runtime | ||||||||||
|
@@ -566,8 +566,9 @@ void GpuToLLVMConversionPass::runOnOperation() { | |||||||||
populateFinalizeMemRefToLLVMConversionPatterns(converter, patterns); | ||||||||||
populateAsyncStructuralTypeConversionsAndLegality(converter, patterns, | ||||||||||
target); | ||||||||||
populateGpuToLLVMConversionPatterns( | ||||||||||
converter, patterns, kernelBarePtrCallConv, typeCheckKernelArgs); | ||||||||||
populateGpuToLLVMConversionPatterns(converter, patterns, | ||||||||||
kernelBarePtrCallConv, | ||||||||||
kernelIntersperseSizeCallConv); | ||||||||||
|
||||||||||
if (failed( | ||||||||||
applyPartialConversion(getOperation(), target, std::move(patterns)))) | ||||||||||
|
@@ -970,33 +971,55 @@ LogicalResult LegalizeLaunchFuncOpPattern::matchAndRewrite( | |||||||||
else if (launchOp.getAsyncToken()) | ||||||||||
stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult(); | ||||||||||
|
||||||||||
if (typeCheckKernelArgs) { | ||||||||||
// The current non-bare-pointer ABI is a bad fit for `mgpuLaunchKernel`, | ||||||||||
// which takes an untyped list of arguments. The type check here prevents | ||||||||||
// accidentally violating the assumption made in vulkan-runtime-wrappers.cpp | ||||||||||
// and creating a unchecked runtime ABI mismatch. | ||||||||||
// TODO(https://github.com/llvm/llvm-project/issues/73457): Change the ABI | ||||||||||
// here to remove the need for this type check. | ||||||||||
for (Value arg : launchOp.getKernelOperands()) { | ||||||||||
if (auto memrefTy = dyn_cast<MemRefType>(arg.getType())) { | ||||||||||
if (memrefTy.getRank() != 1 || | ||||||||||
memrefTy.getElementTypeBitWidth() != 32) { | ||||||||||
return rewriter.notifyMatchFailure( | ||||||||||
launchOp, "Operand to launch op is not a rank-1 memref with " | ||||||||||
"32-bit element type."); | ||||||||||
} | ||||||||||
} else { | ||||||||||
// Lower the kernel operands to match kernel parameters. | ||||||||||
// Note: If `useBarePtrCallConv` is set in the type converter's options, | ||||||||||
// the value of `kernelBarePtrCallConv` will be ignored. | ||||||||||
OperandRange origArguments = launchOp.getKernelOperands(); | ||||||||||
SmallVector<Value, 8> llvmArguments = getTypeConverter()->promoteOperands( | ||||||||||
loc, origArguments, adaptor.getKernelOperands(), rewriter, | ||||||||||
/*useBarePtrCallConv=*/kernelBarePtrCallConv); | ||||||||||
SmallVector<Value, 8> llvmArgumentsWithSizes; | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: don't need to explicitly specify There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I prefer the explicit size in this case because |
||||||||||
|
||||||||||
// Intersperse size information if requested. | ||||||||||
if (kernelIntersperseSizeCallConv) { | ||||||||||
if (origArguments.size() != llvmArguments.size()) { | ||||||||||
// This shouldn't happen if the bare-pointer calling convention is used. | ||||||||||
return rewriter.notifyMatchFailure( | ||||||||||
launchOp, | ||||||||||
"Cannot add sizes to arguments with one-to-many LLVM IR expansion."); | ||||||||||
} | ||||||||||
|
||||||||||
llvmArgumentsWithSizes.reserve(llvmArguments.size() * 2); | ||||||||||
for (auto [llvmArg, origArg] : zip_equal(llvmArguments, origArguments)) { | ||||||||||
auto memrefTy = dyn_cast<MemRefType>(origArg.getType()); | ||||||||||
if (!memrefTy) { | ||||||||||
return rewriter.notifyMatchFailure( | ||||||||||
launchOp, "Operand to launch op is not a memref."); | ||||||||||
} | ||||||||||
|
||||||||||
if (!memrefTy.hasStaticShape() || | ||||||||||
!memrefTy.getElementType().isIntOrFloat()) { | ||||||||||
return rewriter.notifyMatchFailure( | ||||||||||
launchOp, "Operand to launch op is not a memref with a static " | ||||||||||
"shape and an integer or float element type."); | ||||||||||
} | ||||||||||
|
||||||||||
unsigned bitwidth = memrefTy.getElementTypeBitWidth(); | ||||||||||
if (bitwidth % 8 != 0) { | ||||||||||
return rewriter.notifyMatchFailure( | ||||||||||
launchOp, "Operand to launch op is not a memref with a " | ||||||||||
"byte-aligned element type."); | ||||||||||
} | ||||||||||
Comment on lines
+1000
to
+1012
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I feel like I can't be the first person to have wanted to do this, and I'm not sure I'm checking all the necessary conditions. Is there some handy function for this that I'm missing? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't recall seeing a helper for this either. Most code uses There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is there some rule somewhere that guarantees elements smaller than 8 bits get padded to 8? I doubt the Vulkan target can actually do anything useful with non-multiple-of-8-bits types though, so maybe this simple check is fine. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I was only able to find this: llvm-project/mlir/lib/IR/AttributeDetail.h Lines 113 to 116 in 8552c49
I think it's fine to reject such memrefs. If someone wants to support them in the future, we can work with them on adding that. |
||||||||||
|
||||||||||
uint64_t staticSize = static_cast<uint64_t>(bitwidth / 8) * | ||||||||||
static_cast<uint64_t>(memrefTy.getNumElements()); | ||||||||||
|
||||||||||
Value sizeArg = rewriter.create<LLVM::ConstantOp>( | ||||||||||
loc, getIndexType(), rewriter.getIndexAttr(staticSize)); | ||||||||||
llvmArgumentsWithSizes.push_back(llvmArg); // Presumably a bare pointer. | ||||||||||
llvmArgumentsWithSizes.push_back(sizeArg); | ||||||||||
} | ||||||||||
} | ||||||||||
// Lower the kernel operands to match kernel parameters. | ||||||||||
// Note: If `useBarePtrCallConv` is set in the type converter's options, | ||||||||||
// the value of `kernelBarePtrCallConv` will be ignored. | ||||||||||
SmallVector<Value, 4> arguments = getTypeConverter()->promoteOperands( | ||||||||||
loc, launchOp.getKernelOperands(), adaptor.getKernelOperands(), rewriter, | ||||||||||
/*useBarePtrCallConv=*/kernelBarePtrCallConv); | ||||||||||
|
||||||||||
std::optional<gpu::KernelDim3> clusterSize = std::nullopt; | ||||||||||
if (launchOp.hasClusterSize()) { | ||||||||||
|
@@ -1010,7 +1033,9 @@ LogicalResult LegalizeLaunchFuncOpPattern::matchAndRewrite( | |||||||||
adaptor.getGridSizeZ()}, | ||||||||||
gpu::KernelDim3{adaptor.getBlockSizeX(), adaptor.getBlockSizeY(), | ||||||||||
adaptor.getBlockSizeZ()}, | ||||||||||
adaptor.getDynamicSharedMemorySize(), arguments, stream, clusterSize); | ||||||||||
adaptor.getDynamicSharedMemorySize(), | ||||||||||
llvmArgumentsWithSizes.empty() ? llvmArguments : llvmArgumentsWithSizes, | ||||||||||
stream, clusterSize); | ||||||||||
if (launchOp.getAsyncToken()) | ||||||||||
rewriter.replaceOp(launchOp, {stream}); | ||||||||||
else | ||||||||||
|
@@ -1760,10 +1785,9 @@ LogicalResult ConvertCreateBsrOpToGpuRuntimeCallPattern::matchAndRewrite( | |||||||||
return success(); | ||||||||||
} | ||||||||||
|
||||||||||
void mlir::populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter, | ||||||||||
RewritePatternSet &patterns, | ||||||||||
bool kernelBarePtrCallConv, | ||||||||||
bool typeCheckKernelArgs) { | ||||||||||
void mlir::populateGpuToLLVMConversionPatterns( | ||||||||||
LLVMTypeConverter &converter, RewritePatternSet &patterns, | ||||||||||
bool kernelBarePtrCallConv, bool kernelIntersperseSizeCallConv) { | ||||||||||
addOpaquePointerConversion<gpu::AsyncTokenType>(converter); | ||||||||||
addOpaquePointerConversion<gpu::SparseDnTensorHandleType>(converter); | ||||||||||
addOpaquePointerConversion<gpu::SparseSpMatHandleType>(converter); | ||||||||||
|
@@ -1801,7 +1825,7 @@ void mlir::populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter, | |||||||||
ConvertSpMatGetSizeOpToGpuRuntimeCallPattern, | ||||||||||
ConvertSetCsrPointersOpToGpuRuntimeCallPattern>(converter); | ||||||||||
patterns.add<LegalizeLaunchFuncOpPattern>(converter, kernelBarePtrCallConv, | ||||||||||
typeCheckKernelArgs); | ||||||||||
kernelIntersperseSizeCallConv); | ||||||||||
} | ||||||||||
|
||||||||||
//===----------------------------------------------------------------------===// | ||||||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
// RUN: mlir-opt %s --gpu-to-llvm="use-bare-pointers-for-kernels=1 intersperse-sizes-for-kernels=1" -split-input-file | FileCheck %s | ||
|
||
module attributes {gpu.container_module, spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>} { | ||
llvm.func @malloc(i64) -> !llvm.ptr | ||
gpu.binary @kernels [#gpu.object<#spirv.target_env<#spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>, "">] | ||
func.func @main() attributes {llvm.emit_c_interface} { | ||
// CHECK: [[RANK1UMD:%.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> | ||
%rank1UndefMemrefDescriptor = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> | ||
// CHECK: [[RANK2UMD:%.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> | ||
%rank2UndefMemrefDescriptor = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> | ||
%c1 = arith.constant 1 : index | ||
// CHECK: [[PTR1:%.*]] = llvm.extractvalue [[RANK1UMD]][1] | ||
// CHECK: [[PTR2:%.*]] = llvm.extractvalue [[RANK2UMD]][1] | ||
// CHECK: [[PTR3:%.*]] = llvm.extractvalue [[RANK2UMD]][1] | ||
// CHECK: [[SIZE1:%.*]] = llvm.mlir.constant(32 : index) : i64 | ||
// CHECK: [[SIZE2:%.*]] = llvm.mlir.constant(256 : index) : i64 | ||
// CHECK: [[SIZE3:%.*]] = llvm.mlir.constant(48 : index) : i64 | ||
%6 = builtin.unrealized_conversion_cast %rank1UndefMemrefDescriptor : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> to memref<8xf32> | ||
%10 = builtin.unrealized_conversion_cast %rank2UndefMemrefDescriptor : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> to memref<8x8xi32> | ||
%14 = builtin.unrealized_conversion_cast %rank2UndefMemrefDescriptor : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> to memref<4x12xi8> | ||
// CHECK: gpu.launch_func @kernels::@kernel_add blocks in ({{.*}}) threads in ({{.*}}) : i64 args([[PTR1]] : !llvm.ptr, [[SIZE1]] : i64, [[PTR2]] : !llvm.ptr, [[SIZE2]] : i64, [[PTR3]] : !llvm.ptr, [[SIZE3]] : i64) | ||
gpu.launch_func @kernels::@kernel_add blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%6 : memref<8xf32>, %10 : memref<8x8xi32>, %14 : memref<4x12xi8>) | ||
return | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe better to make it enum instead of 2 bools?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That might not be a bad idea, but I am really unsure if it's possible for an option to be an enum, or what I have to do if I want that.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(Specifically the `Option` defined in `Passes.td`. I know that it's somehow leveraging the `cl` infrastructure, so enums "should" be possible, but I haven't managed to figure it out from staring at the headers and tablegen definitions; there are too many layers.)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here is the example https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Conversion/Passes.td#L1145, but we can do it in separate PR