-
Notifications
You must be signed in to change notification settings - Fork 14.6k
[MLIR][GPU-LLVM] Add in-pass signature update option for opencl kernels #105664
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 7 commits
2e8465c
9238460
7212344
1961dff
53baba2
4b18490
dfbcd23
b39e055
d5519a5
598f9b1
3fde4ae
71006c6
a1666d6
c2b22c7
a0a5a00
abb6ef9
0f79242
69ed255
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,12 +9,14 @@ | |
#include "mlir/Conversion/GPUToLLVMSPV/GPUToLLVMSPVPass.h" | ||
|
||
#include "../GPUCommon/GPUOpsLowering.h" | ||
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" | ||
#include "mlir/Conversion/GPUCommon/AttrToSPIRVConverter.h" | ||
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h" | ||
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h" | ||
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h" | ||
#include "mlir/Conversion/LLVMCommon/Pattern.h" | ||
#include "mlir/Conversion/LLVMCommon/TypeConverter.h" | ||
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" | ||
#include "mlir/Conversion/SPIRVCommon/AttrToLLVMConverter.h" | ||
#include "mlir/Dialect/GPU/IR/GPUDialect.h" | ||
#include "mlir/Dialect/LLVMIR/LLVMAttrs.h" | ||
|
@@ -34,6 +36,8 @@ | |
#include "llvm/ADT/TypeSwitch.h" | ||
#include "llvm/Support/FormatVariadic.h" | ||
|
||
#define DEBUG_TYPE "gpu-to-llvm-spv" | ||
|
||
using namespace mlir; | ||
|
||
namespace mlir { | ||
|
@@ -306,6 +310,36 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> { | |
} | ||
}; | ||
|
||
class MemorySpaceToOpenCLMemorySpaceConverter final : public TypeConverter { | ||
public: | ||
MemorySpaceToOpenCLMemorySpaceConverter() { | ||
addConversion([](Type t) { return t; }); | ||
addConversion([this](BaseMemRefType memRefType) -> std::optional<Type> { | ||
kurapov-peter marked this conversation as resolved.
Show resolved
Hide resolved
|
||
// Attach global addr space attribute to memrefs with no addr space attr | ||
Attribute memSpaceAttr = memRefType.getMemorySpace(); | ||
if (memSpaceAttr) | ||
return std::nullopt; | ||
|
||
auto addrSpaceAttr = gpu::AddressSpaceAttr::get( | ||
memRefType.getContext(), gpu::AddressSpace::Global); | ||
if (auto rankedType = dyn_cast<MemRefType>(memRefType)) { | ||
return MemRefType::get(memRefType.getShape(), | ||
memRefType.getElementType(), | ||
rankedType.getLayout(), addrSpaceAttr); | ||
} | ||
return UnrankedMemRefType::get(memRefType.getElementType(), | ||
addrSpaceAttr); | ||
}); | ||
addConversion([this](FunctionType type) { | ||
auto inputs = llvm::map_to_vector( | ||
type.getInputs(), [this](Type ty) { return convertType(ty); }); | ||
auto results = llvm::map_to_vector( | ||
type.getResults(), [this](Type ty) { return convertType(ty); }); | ||
return FunctionType::get(type.getContext(), inputs, results); | ||
}); | ||
} | ||
}; | ||
victor-eds marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
//===----------------------------------------------------------------------===// | ||
// GPU To LLVM-SPV Pass. | ||
//===----------------------------------------------------------------------===// | ||
|
@@ -325,16 +359,45 @@ struct GPUToLLVMSPVConversionPass final | |
LLVMTypeConverter converter(context, options); | ||
LLVMConversionTarget target(*context); | ||
|
||
if (forceOpenclAddressSpaces) { | ||
MemorySpaceToOpenCLMemorySpaceConverter converter; | ||
AttrTypeReplacer replacer; | ||
replacer.addReplacement([&converter](BaseMemRefType origType) | ||
-> std::optional<BaseMemRefType> { | ||
return converter.convertType<BaseMemRefType>(origType); | ||
}); | ||
|
||
replacer.recursivelyReplaceElementsIn(getOperation(), | ||
/*replaceAttrs=*/true, | ||
/*replaceLocs=*/false, | ||
/*replaceTypes=*/true); | ||
} | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This looks like an overkill, isn't there another way? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I tried a natural way - putting additional conversion to the llvm converter but it didn't work out. First, it expects a legal value as the output, so I can't just add a memref. I tried to run So instead, I added this. It doesn't have to think about legality and all since it's not an llvm converter. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I took a look into this once more. It seems like if I were to reuse There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. True, conversion to LLVM assumes memspace |
||
target.addIllegalOp<gpu::BarrierOp, gpu::BlockDimOp, gpu::BlockIdOp, | ||
gpu::GPUFuncOp, gpu::GlobalIdOp, gpu::GridDimOp, | ||
gpu::ReturnOp, gpu::ShuffleOp, gpu::ThreadIdOp>(); | ||
|
||
populateGpuToLLVMSPVConversionPatterns(converter, patterns); | ||
populateFuncToLLVMConversionPatterns(converter, patterns); | ||
populateFinalizeMemRefToLLVMConversionPatterns(converter, patterns); | ||
kurapov-peter marked this conversation as resolved.
Show resolved
Hide resolved
|
||
populateGpuMemorySpaceAttributeConversions(converter); | ||
|
||
if (failed(applyPartialConversion(getOperation(), target, | ||
std::move(patterns)))) | ||
signalPassFailure(); | ||
|
||
// `func.func`s are not handled by the lowering, so need a proper calling | ||
// convention set separately. | ||
getOperation().walk([](LLVM::LLVMFuncOp f) { | ||
if (f.getCConv() == LLVM::CConv::C) { | ||
f.setCConv(LLVM::CConv::SPIR_FUNC); | ||
} | ||
}); | ||
getOperation().walk([](LLVM::CallOp c) { | ||
if (c.getCConv() == LLVM::CConv::C) { | ||
c.setCConv(LLVM::CConv::SPIR_FUNC); | ||
} | ||
}); | ||
kurapov-peter marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} | ||
}; | ||
} // namespace | ||
|
Uh oh!
There was an error while loading. Please reload this page.