From d8fc2be7009c88b835f643e2583709221e7a78ca Mon Sep 17 00:00:00 2001 From: Andrey Pavlenko Date: Tue, 10 Sep 2024 11:53:47 +0200 Subject: [PATCH 1/8] Convert a subset of GPU dialect ops to the GPU OpenCL runtime calls Added a new path, that converts the following GPU dialect ops to the corresponding callsof the GPU OpenCL runtime functions (to be implemented later): - gpu.alloc, gpu.dealloc, gpu.memcpy and gpu.launch The first argument of each runtime's function is a pointer to the context structure. This is not a cl_context, this is an execution context, i.e. a single execution of the module's main function. It contains the queue, wait list (in case of out-of-order mode) and someother data, required for the module ops execution. It's expected, that the pointer to the context is passed to the module's main function as the last argument of type memref with zero dims. For each gpu.launch operation, 2 additional functions are created: - getXXXKernel(): returns the kernel pointer, stored in a global variable. If it's NULL, calls createXXXKernel(). - createXXXKernel(): Calls the runtime's function, that creates a kernel. SPIRV, kernel name, and sizes are passed to the function. The returned pointer is saved in the global var using `llvm.cmpxchg`, to make sure it doesn't overwrite a kernel, created by another thread. Finally, a destructor function is created, that calls the corresponding runtime's kernel destroy function and passes the pointers, stored in the global vars. This function must be called by themodule owner, when destroying the module. The kernel is not a cl_kernel, but a runtime's internal structure, that contains a compiledcl_program, preconfigured cl_kernel and other data, required for the kernel execution. The runtime's launch function clones the preconfigured kernel, sets the arguments and enqueues a command to execute the kernel. --- .../GPURuntime/GpuOclRuntime.h | 27 + include/gc/Transforms/Passes.td | 7 + lib/gc/Transforms/GPU/CMakeLists.txt | 1 + lib/gc/Transforms/GPU/GpuToGpuOcl.cpp | 531 ++++++++++++++++++ 4 files changed, 566 insertions(+) create mode 100644 include/gc/ExecutionEngine/GPURuntime/GpuOclRuntime.h create mode 100644 lib/gc/Transforms/GPU/GpuToGpuOcl.cpp diff --git a/include/gc/ExecutionEngine/GPURuntime/GpuOclRuntime.h b/include/gc/ExecutionEngine/GPURuntime/GpuOclRuntime.h new file mode 100644 index 000000000..fa576c4e4 --- /dev/null +++ b/include/gc/ExecutionEngine/GPURuntime/GpuOclRuntime.h @@ -0,0 +1,27 @@ +//===-- GpuOclRuntime.h - GPU OpenCL runtime --------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef GC_GPUOCLMODULE_H +#define GC_GPUOCLMODULE_H + +#define GC_GPU_OCL_MALLOC "gcGpuOclMaloc" +#define GC_GPU_OCL_DEALLOC "gcGpuOclDealloc" +#define GC_GPU_OCL_MEMCPY "gcGpuOclMemcpy" +#define GC_GPU_OCL_KERNEL_CREATE "gcGpuOclKernelCreate" +#define GC_GPU_OCL_KERNEL_DESTROY "gcGpuOclKernelDestroy" +#define GC_GPU_OCL_KERNEL_LAUNCH "gcGpuOclKernelLaunch" +#define GC_GPU_OCL_MOD_DESTRUCTOR "gcGpuOclModuleDestructor" + +#ifndef GC_GPU_OCL_DEF_ONLY + +// TBD + +#else +#undef GC_GPU_OCL_DEF_ONLY +#endif +#endif diff --git a/include/gc/Transforms/Passes.td b/include/gc/Transforms/Passes.td index 5151a0335..6ba973361 100644 --- a/include/gc/Transforms/Passes.td +++ b/include/gc/Transforms/Passes.td @@ -93,6 +93,13 @@ def LinalgToXeGPU : Pass<"linalg-to-xegpu", "func::FuncOp"> { "DPAS register block sizes MxNxK">, ]; } + +def GpuToGpuOcl : Pass<"gpu-to-gpuocl", "ModuleOp"> { + let summary = "Convert the GPU operations to GpuOclRuntime calls."; + let description = [{ + Convert the gpu alloc, dealloc, memcpy and launch operations to GpuOclRuntime calls. + }]; +} #endif // GC_USE_IMEX def IterativeTilingAndFusion : Pass<"iterative-tiling-and-fusion", diff --git a/lib/gc/Transforms/GPU/CMakeLists.txt b/lib/gc/Transforms/GPU/CMakeLists.txt index 13f9c2981..5fd06e8db 100644 --- a/lib/gc/Transforms/GPU/CMakeLists.txt +++ b/lib/gc/Transforms/GPU/CMakeLists.txt @@ -1,4 +1,5 @@ gc_add_mlir_library(GcGpuPasses + GpuToGpuOcl.cpp LinalgToXeGPU.cpp Pipeline.cpp diff --git a/lib/gc/Transforms/GPU/GpuToGpuOcl.cpp b/lib/gc/Transforms/GPU/GpuToGpuOcl.cpp new file mode 100644 index 000000000..aa42115c9 --- /dev/null +++ b/lib/gc/Transforms/GPU/GpuToGpuOcl.cpp @@ -0,0 +1,531 @@ +//===-- GpuToGpuOcl.cpp - GpuToGpuOcl path ----------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#include + +#define GC_GPU_OCL_DEF_ONLY +#include "gc/ExecutionEngine/GPURuntime/GpuOclRuntime.h" + +#include "mlir/Conversion/LLVMCommon/ConversionTarget.h" +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Conversion/Passes.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/GPU/Transforms/Passes.h" + +using namespace mlir; + +namespace mlir { +namespace gc { +#define GEN_PASS_DECL_GPUTOGPUOCL +#define GEN_PASS_DEF_GPUTOGPUOCL +#include "gc/Transforms/Passes.h.inc" +} // namespace gc +} // namespace mlir + +namespace { + +LLVM::CallOp funcCall(OpBuilder &builder, const StringRef name, + const Type returnType, const ArrayRef argTypes, + const Location loc, const ArrayRef arguments, + bool isVarArg = false) { + auto module = builder.getBlock()->getParent()->getParentOfType(); + auto function = module.lookupSymbol(name); + if (!function) { + auto type = LLVM::LLVMFunctionType::get(returnType, argTypes, isVarArg); + function = OpBuilder::atBlockEnd(module.getBody()) + .create(loc, name, type); + } + return builder.create(loc, function, arguments); +} + +// Assuming that the pointer to GcGpuOclContext is passed as the last +// memref with zero dims argument of the current function. +Value getCtxPtr(const OpBuilder &rewriter) { + auto func = + rewriter.getBlock()->getParent()->getParentOfType(); + return func.getArgument(func.getNumArguments() - 3); +} + +struct Helper final { + LLVMTypeConverter &converter; + Type voidType; + Type ptrType; + Type idxType; + mutable std::set> kernelNames; + + explicit Helper(MLIRContext *ctx, LLVMTypeConverter &converter) + : converter(converter), voidType(LLVM::LLVMVoidType::get(ctx)), + ptrType(LLVM::LLVMPointerType::get(ctx)), + idxType(IntegerType::get(ctx, converter.getPointerBitwidth())) {} + + Value idxConstant(OpBuilder &rewriter, const Location loc, + size_t value) const { + return rewriter.create( + loc, idxType, + rewriter.getIntegerAttr(idxType, static_cast(value))); + } + + void destroyKernels(OpBuilder &rewriter, Location loc, + ArrayRef kernelPtrs) const { + auto size = idxConstant(rewriter, loc, kernelPtrs.size()); + auto kernelPtrsArray = + rewriter.create(loc, ptrType, ptrType, size); + for (size_t i = 0, n = kernelPtrs.size(); i < n; i++) { + auto elementPtr = + rewriter.create(loc, ptrType, ptrType, kernelPtrsArray, + idxConstant(rewriter, loc, i)); + rewriter.create(loc, kernelPtrs[i], elementPtr); + } + + funcCall(rewriter, GC_GPU_OCL_KERNEL_DESTROY, voidType, {idxType, ptrType}, + loc, {size, kernelPtrsArray}); + } +}; + +template +struct ConvertOpPattern : ConvertOpToLLVMPattern { + const Helper &helper; + + explicit ConvertOpPattern(const Helper &helper) + : ConvertOpToLLVMPattern(helper.converter), helper(helper) {} +}; + +struct ConvertAlloc final : ConvertOpPattern { + explicit ConvertAlloc(const Helper &helper) : ConvertOpPattern(helper) {} + + LogicalResult + matchAndRewrite(gpu::AllocOp allocOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = allocOp.getLoc(); + MemRefType type = allocOp.getType(); + auto shape = type.getShape(); + auto dynamics = adaptor.getDynamicSizes(); + + if (shape.empty() || dynamics.empty()) { + int64_t staticSize; + if (shape.empty()) { + staticSize = 0; + } else { + staticSize = type.getElementType().getIntOrFloatBitWidth() / 8; + for (auto dim : shape) { + assert(dim != ShapedType::kDynamic); + staticSize *= dim; + } + } + auto size = helper.idxConstant(rewriter, loc, staticSize); + auto ptr = funcCall(rewriter, GC_GPU_OCL_MALLOC, helper.ptrType, + {helper.ptrType, helper.idxType}, loc, + {getCtxPtr(rewriter), size}) + .getResult(); + Value replacement = MemRefDescriptor::fromStaticShape( + rewriter, loc, helper.converter, type, ptr, ptr); + rewriter.replaceOp(allocOp, replacement); + return success(); + } + + auto ndims = shape.size(); + SmallVector newShape; + SmallVector newStrides(ndims); + auto staticSize = type.getElementType().getIntOrFloatBitWidth() / 8; + auto size = dynamics[0]; + + auto idxMul = [&](Value x, Value y) -> Value { + if (auto xConst = getConstantIntValue(x)) { + if (auto yConst = getConstantIntValue(y)) { + return helper.idxConstant(rewriter, loc, + xConst.value() * yConst.value()); + } + } + return rewriter.create(loc, x, y); + }; + + for (size_t i = 0, j = 0; i < ndims; i++) { + auto dim = shape[i]; + if (dim == ShapedType::kDynamic) { + auto dynSize = dynamics[j++]; + newShape.emplace_back(dynSize); + if (j != 1) { + size = idxMul(size, dynSize); + } + } else { + staticSize *= dim; + newShape.emplace_back(helper.idxConstant(rewriter, loc, dim)); + } + } + + size = idxMul(size, helper.idxConstant(rewriter, loc, staticSize)); + auto ptr = funcCall(rewriter, GC_GPU_OCL_MALLOC, helper.ptrType, + {helper.ptrType, helper.idxType}, loc, + {getCtxPtr(rewriter), size}) + .getResult(); + + newStrides[ndims - 1] = helper.idxConstant(rewriter, loc, 1); + for (int i = static_cast(ndims) - 2; i >= 0; i--) { + newStrides[i] = idxMul(newStrides[i + 1], newShape[i]); + ; + } + + auto dsc = MemRefDescriptor::undef(rewriter, loc, + helper.converter.convertType(type)); + dsc.setAllocatedPtr(rewriter, loc, ptr); + dsc.setAlignedPtr(rewriter, loc, ptr); + dsc.setOffset(rewriter, loc, helper.idxConstant(rewriter, loc, 0)); + + for (unsigned i = 0, n = static_cast(ndims); i < n; i++) { + dsc.setSize(rewriter, loc, i, newShape[i]); + dsc.setStride(rewriter, loc, i, newStrides[i]); + } + + rewriter.replaceOp(allocOp, static_cast(dsc)); + return success(); + } +}; + +struct ConvertDealloc final : ConvertOpPattern { + explicit ConvertDealloc(const Helper &helper) : ConvertOpPattern(helper) {} + + LogicalResult + matchAndRewrite(gpu::DeallocOp gpuDealloc, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = gpuDealloc.getLoc(); + MemRefDescriptor dsc(adaptor.getMemref()); + auto ptr = dsc.allocatedPtr(rewriter, loc); + auto oclDealloc = funcCall(rewriter, GC_GPU_OCL_DEALLOC, helper.voidType, + {helper.ptrType, helper.ptrType}, loc, + {getCtxPtr(rewriter), ptr}); + rewriter.replaceOp(gpuDealloc, oclDealloc); + return success(); + } +}; + +struct ConvertMemcpy final : ConvertOpPattern { + explicit ConvertMemcpy(const Helper &helper) : ConvertOpPattern(helper) {} + + LogicalResult + matchAndRewrite(gpu::MemcpyOp gpuMemcpy, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = gpuMemcpy.getLoc(); + auto srcType = gpuMemcpy.getSrc().getType(); + auto elementSize = srcType.getElementType().getIntOrFloatBitWidth() / 8; + uint64_t numElements = 0; + for (auto dim : srcType.getShape()) { + if (dim == ShapedType::kDynamic) { + gpuMemcpy.emitOpError() + << "dynamic shapes are not currently not supported"; + return failure(); + } + numElements = numElements ? numElements * dim : dim; + } + + MemRefDescriptor srcDsc(adaptor.getSrc()); + MemRefDescriptor dstDsc(adaptor.getDst()); + auto srcPtr = srcDsc.alignedPtr(rewriter, loc); + auto dstPtr = dstDsc.alignedPtr(rewriter, loc); + auto size = helper.idxConstant(rewriter, loc, elementSize * numElements); + auto oclMemcpy = funcCall( + rewriter, GC_GPU_OCL_MEMCPY, helper.voidType, + {helper.ptrType, helper.ptrType, helper.ptrType, helper.idxType}, loc, + {getCtxPtr(rewriter), srcPtr, dstPtr, size}); + rewriter.replaceOp(gpuMemcpy, oclMemcpy); + return success(); + } +}; + +struct ConvertLaunch final : ConvertOpPattern { + + explicit ConvertLaunch(const Helper &helper) : ConvertOpPattern(helper) {} + + LogicalResult + matchAndRewrite(gpu::LaunchFuncOp gpuLaunch, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto kernelPtr = getKernel(gpuLaunch, adaptor, rewriter); + if (!kernelPtr) { + return failure(); + } + + const Location loc = gpuLaunch.getLoc(); + auto kernelArgs = adaptor.getKernelOperands(); + std::vector args; + args.reserve(kernelArgs.size() + 2); + args.emplace_back(getCtxPtr(rewriter)); + args.emplace_back(kernelPtr.value()); + + int i = 0; + for (auto arg : kernelArgs) { + if (isa(gpuLaunch.getKernelOperand(i++).getType())) { + MemRefDescriptor desc(arg); + args.emplace_back(desc.alignedPtr(rewriter, loc)); + } else { + args.emplace_back(arg); + } + } + + const auto gpuOclLaunch = + funcCall(rewriter, GC_GPU_OCL_KERNEL_LAUNCH, helper.voidType, + {helper.ptrType, helper.ptrType}, loc, args, true); + rewriter.replaceOp(gpuLaunch, gpuOclLaunch); + return success(); + } + +private: + // Returns the kernel pointer stored in the global var ...name_Ptr. + // If it's NULL, calls the createKernel() function. + std::optional getKernel(gpu::LaunchFuncOp &gpuLaunch, + OpAdaptor &adaptor, + ConversionPatternRewriter &rewriter) const { + auto loc = gpuLaunch.getLoc(); + auto ctx = getCtxPtr(rewriter); + auto mod = rewriter.getBlock()->getParent()->getParentOfType(); + auto kernelModName = gpuLaunch.getKernelModuleName(); + SmallString<128> getFuncName("getGcGpuOclKernel_"); + getFuncName.append(kernelModName); + + if (helper.kernelNames.insert(SmallString<32>(kernelModName)).second) { + auto insPoint = rewriter.saveInsertionPoint(); + SmallString<128> strBuf("gcGpuOclKernel_"); + strBuf.append(kernelModName); + strBuf.append("_"); + auto strBufStart = strBuf.size(); + auto str = [&strBuf, + strBufStart](const char *chars) -> SmallString<128> & { + strBuf.truncate(strBufStart); + strBuf.append(chars); + return strBuf; + }; + + SmallString<128> createFuncName("createGcGpuOclKernel_"); + createFuncName.append(kernelModName); + if (!createKernel(gpuLaunch, adaptor, rewriter, loc, mod, createFuncName, + str)) { + return std::nullopt; + } + + auto function = rewriter.create( + loc, getFuncName, + LLVM::LLVMFunctionType::get(helper.ptrType, {helper.ptrType}), + LLVM::Linkage::Internal); + rewriter.setInsertionPointToStart(function.addEntryBlock(rewriter)); + + auto ptr = mod.lookupSymbol(str("Ptr")); + assert(ptr); + auto null = rewriter.create(loc, helper.ptrType); + auto ptrPtr = rewriter.create(loc, ptr); + auto ptrVal = rewriter.create(loc, helper.ptrType, ptrPtr); + auto cmp = rewriter.create(loc, LLVM::ICmpPredicate::eq, + ptrVal, null); + + auto body = &function.getBody(); + auto thenBlock = rewriter.createBlock(body); + auto elseBlock = rewriter.createBlock(body); + rewriter.setInsertionPointToEnd(&body->front()); + rewriter.create(loc, cmp, thenBlock, elseBlock); + + // Then block + rewriter.setInsertionPointToStart(thenBlock); + auto result = funcCall(rewriter, createFuncName, helper.ptrType, + {helper.ptrType}, loc, {function.getArgument(0)}); + rewriter.create(loc, result.getResult()); + + // Else block + rewriter.setInsertionPointToStart(elseBlock); + rewriter.create(loc, ptrVal); + + rewriter.restoreInsertionPoint(insPoint); + } + + auto kernelFunc = mod.lookupSymbol(getFuncName); + if (!kernelFunc) { + gpuLaunch.emitOpError() << "Function " << getFuncName << " not found!"; + return std::nullopt; + } + return rewriter.create(loc, kernelFunc, ValueRange(ctx)) + .getResult(); + } + + // Create a new kernel and save the pointer to the global variable + // ...name_Ptr. + bool createKernel( + gpu::LaunchFuncOp &gpuLaunch, OpAdaptor &adaptor, + ConversionPatternRewriter &rewriter, Location &loc, ModuleOp &mod, + StringRef funcName, + const std::function &(const char *chars)> &str) const { + auto kernelModName = gpuLaunch.getKernelModuleName(); + auto kernelMod = SymbolTable::lookupNearestSymbolFrom( + gpuLaunch, kernelModName); + if (!kernelMod) { + gpuLaunch.emitOpError() << "Module " << kernelModName << " not found!"; + return false; + } + const auto binaryAttr = kernelMod->getAttrOfType("gpu.binary"); + if (!binaryAttr) { + kernelMod.emitOpError() << "missing 'gpu.binary' attribute"; + return false; + } + + rewriter.setInsertionPointToStart(mod.getBody()); + // The kernel pointer is stored here + rewriter.create(loc, helper.ptrType, /*isConstant=*/false, + LLVM::Linkage::Internal, str("Ptr"), + rewriter.getZeroAttr(helper.ptrType)); + rewriter.eraseOp(kernelMod); + + auto function = rewriter.create( + loc, funcName, + LLVM::LLVMFunctionType::get(helper.ptrType, {helper.ptrType}), + LLVM::Linkage::Internal); + rewriter.setInsertionPointToStart(function.addEntryBlock(rewriter)); + + auto ptr = mod.lookupSymbol(str("Ptr")); + assert(ptr); + SmallVector nameChars(kernelModName.getValue().begin(), + kernelModName.getValue().end()); + nameChars.emplace_back('\0'); + // Kernel name and SPIRV are stored as global strings + auto name = LLVM::createGlobalString( + loc, rewriter, str("Name"), + StringRef(nameChars.data(), nameChars.size()), LLVM::Linkage::Internal); + auto spirv = LLVM::createGlobalString(loc, rewriter, str("SPIRV"), + binaryAttr.getValue(), + LLVM::Linkage::Internal); + auto spirvSize = rewriter.create( + loc, helper.idxType, + mlir::IntegerAttr::get(helper.idxType, + static_cast(binaryAttr.size()))); + + SmallVector globalSize; + SmallVector localSize; + SmallVector argSize; + kernelMod->walk([&](gpu::GPUFuncOp func) { + if (func.getName() == gpuLaunch.getKernelName()) { + for (auto s : func.getKnownGridSize().value()) { + globalSize.emplace_back(s); + } + for (auto s : func.getKnownBlockSize().value()) { + localSize.emplace_back(s); + } + } + }); + assert(globalSize.size() == 3 && localSize.size() == 3); + globalSize = {globalSize[0] * localSize[0], globalSize[1] * localSize[1], + globalSize[2] * localSize[2]}; + for (auto arg : adaptor.getKernelOperands()) { + auto type = arg.getType(); + auto size = type.isIntOrFloat() ? type.getIntOrFloatBitWidth() / 8 : 0; + argSize.emplace_back(size); + } + + auto array = [&](SmallVector &values) { + auto size = helper.idxConstant(rewriter, loc, values.size()); + auto arrayPtr = rewriter.create(loc, helper.ptrType, + helper.idxType, size); + for (size_t i = 0, n = values.size(); i < n; i++) { + auto elementPtr = rewriter.create( + loc, helper.ptrType, helper.idxType, arrayPtr, + helper.idxConstant(rewriter, loc, i)); + rewriter.create( + loc, helper.idxConstant(rewriter, loc, values[i]), elementPtr); + } + return arrayPtr.getResult(); + }; + + auto ctx = function.getArgument(0); + auto argNum = + helper.idxConstant(rewriter, loc, adaptor.getKernelOperands().size()); + auto createKernelCall = funcCall( + rewriter, GC_GPU_OCL_KERNEL_CREATE, helper.ptrType, + {helper.ptrType, helper.idxType, helper.ptrType, helper.ptrType, + helper.ptrType, helper.ptrType, helper.idxType, helper.ptrType}, + loc, + {ctx, spirvSize, spirv, name, array(globalSize), array(localSize), + argNum, array(argSize)}); + auto result = createKernelCall.getResult(); + + // Save the kernel pointer to the global var using CAS + auto null = rewriter.create(loc, helper.ptrType); + auto ptrPtr = rewriter.create(loc, ptr); + auto casResult = rewriter.create( + loc, ptrPtr, null, result, LLVM::AtomicOrdering::acq_rel, + LLVM::AtomicOrdering::monotonic); + auto casFlag = rewriter.create( + loc, rewriter.getI1Type(), casResult, 1); + + auto body = &function.getBody(); + auto thenBlock = rewriter.createBlock(body); + auto elseBlock = rewriter.createBlock(body); + rewriter.setInsertionPointToEnd(&body->front()); + rewriter.create(loc, casFlag, thenBlock, elseBlock); + + // Then block + rewriter.setInsertionPointToStart(thenBlock); + rewriter.create(loc, result); + + // Else block + // The kernel has already been created by another thread, destroying this + // one. + rewriter.setInsertionPointToStart(elseBlock); + helper.destroyKernels(rewriter, loc, result); + result = rewriter.create(loc, helper.ptrType, + casResult, 0); + rewriter.create(loc, result); + + rewriter.setInsertionPointAfter(function); + return true; + } +}; + +struct GpuToGpuOcl final : gc::impl::GpuToGpuOclBase { + + void runOnOperation() override { + const auto ctx = &getContext(); + const LLVMConversionTarget target(getContext()); + LLVMTypeConverter converter(ctx); + Helper helper(ctx, converter); + RewritePatternSet patterns(ctx); + + populateGpuToLLVMConversionPatterns(converter, patterns); + patterns.insert( + helper); + + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) { + signalPassFailure(); + return; + } + + // Add gpuOclDestructor() function that destroys all the kernels + auto mod = llvm::dyn_cast(getOperation()); + assert(mod); + OpBuilder rewriter(mod.getBody(), mod.getBody()->end()); + auto destruct = rewriter.create( + mod.getLoc(), GC_GPU_OCL_MOD_DESTRUCTOR, + LLVM::LLVMFunctionType::get(helper.voidType, {}), + LLVM::Linkage::External); + auto loc = destruct.getLoc(); + rewriter.setInsertionPointToStart(destruct.addEntryBlock(rewriter)); + // Add memory fence + rewriter.create(loc, LLVM::AtomicOrdering::acquire); + + SmallVector kernelPtrs; + SmallString<128> strBuf("gcGpuOclKernel_"); + auto strBufStart = strBuf.size(); + kernelPtrs.reserve(helper.kernelNames.size()); + for (auto &name : helper.kernelNames) { + strBuf.truncate(strBufStart); + strBuf.append(name); + strBuf.append("_Ptr"); + auto ptr = mod.lookupSymbol(strBuf); + assert(ptr); + auto ptrVal = rewriter.create( + loc, helper.ptrType, rewriter.create(loc, ptr)); + kernelPtrs.emplace_back(ptrVal); + } + + helper.destroyKernels(rewriter, loc, kernelPtrs); + rewriter.create(loc, ValueRange{}); + } +}; +} // namespace \ No newline at end of file From cc85da7402437181d4b1ff553da831c2d7cdb607 Mon Sep 17 00:00:00 2001 From: Andrey Pavlenko Date: Fri, 13 Sep 2024 00:38:54 +0200 Subject: [PATCH 2/8] Use constexpr instead of #define --- .../GPURuntime/GpuOclRuntime.h | 24 +++++----- lib/gc/Transforms/GPU/GpuToGpuOcl.cpp | 48 ++++++++++--------- 2 files changed, 38 insertions(+), 34 deletions(-) diff --git a/include/gc/ExecutionEngine/GPURuntime/GpuOclRuntime.h b/include/gc/ExecutionEngine/GPURuntime/GpuOclRuntime.h index fa576c4e4..b01b9f2c6 100644 --- a/include/gc/ExecutionEngine/GPURuntime/GpuOclRuntime.h +++ b/include/gc/ExecutionEngine/GPURuntime/GpuOclRuntime.h @@ -6,22 +6,24 @@ // //===----------------------------------------------------------------------===// -#ifndef GC_GPUOCLMODULE_H -#define GC_GPUOCLMODULE_H +#ifndef GC_GPUOCLRUNTIME_H +#define GC_GPUOCLRUNTIME_H -#define GC_GPU_OCL_MALLOC "gcGpuOclMaloc" -#define GC_GPU_OCL_DEALLOC "gcGpuOclDealloc" -#define GC_GPU_OCL_MEMCPY "gcGpuOclMemcpy" -#define GC_GPU_OCL_KERNEL_CREATE "gcGpuOclKernelCreate" -#define GC_GPU_OCL_KERNEL_DESTROY "gcGpuOclKernelDestroy" -#define GC_GPU_OCL_KERNEL_LAUNCH "gcGpuOclKernelLaunch" -#define GC_GPU_OCL_MOD_DESTRUCTOR "gcGpuOclModuleDestructor" +namespace mlir::gc::gpu { +constexpr char GPU_OCL_MALLOC[] = "gcGpuOclMalloc"; +constexpr char GPU_OCL_DEALLOC[] = "gcGpuOclDealloc"; +constexpr char GPU_OCL_MEMCPY[] = "gcGpuOclMemcpy"; +constexpr char GPU_OCL_KERNEL_CREATE[] = "gcGpuOclKernelCreate"; +constexpr char GPU_OCL_KERNEL_DESTROY[] = "gcGpuOclKernelDestroy"; +constexpr char GPU_OCL_KERNEL_LAUNCH[] = "gcGpuOclKernelLaunch"; +constexpr char GPU_OCL_MOD_DESTRUCTOR[] = "gcGpuOclModuleDestructor"; +} // namespace mlir::gc::gpu -#ifndef GC_GPU_OCL_DEF_ONLY +#ifndef GC_GPU_OCL_CONST_ONLY // TBD #else -#undef GC_GPU_OCL_DEF_ONLY +#undef GC_GPU_OCL_CONST_ONLY #endif #endif diff --git a/lib/gc/Transforms/GPU/GpuToGpuOcl.cpp b/lib/gc/Transforms/GPU/GpuToGpuOcl.cpp index aa42115c9..838930534 100644 --- a/lib/gc/Transforms/GPU/GpuToGpuOcl.cpp +++ b/lib/gc/Transforms/GPU/GpuToGpuOcl.cpp @@ -5,9 +5,9 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// -#include +#include -#define GC_GPU_OCL_DEF_ONLY +#define GC_GPU_OCL_CONST_ONLY #include "gc/ExecutionEngine/GPURuntime/GpuOclRuntime.h" #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" @@ -17,17 +17,15 @@ #include "mlir/Dialect/GPU/Transforms/Passes.h" using namespace mlir; +using namespace mlir::gc::gpu; -namespace mlir { -namespace gc { +namespace mlir::gc { #define GEN_PASS_DECL_GPUTOGPUOCL #define GEN_PASS_DEF_GPUTOGPUOCL #include "gc/Transforms/Passes.h.inc" -} // namespace gc -} // namespace mlir +} // namespace mlir::gc namespace { - LLVM::CallOp funcCall(OpBuilder &builder, const StringRef name, const Type returnType, const ArrayRef argTypes, const Location loc, const ArrayRef arguments, @@ -42,8 +40,10 @@ LLVM::CallOp funcCall(OpBuilder &builder, const StringRef name, return builder.create(loc, function, arguments); } -// Assuming that the pointer to GcGpuOclContext is passed as the last -// memref with zero dims argument of the current function. +// Assuming that the pointer to the context is passed as the last argument +// of the current function of type memref with zero dims. When lowering +// to LLVM, the memref arg is replaced with 3 args of types ptr, ptr, i64. +// Returning the first one. Value getCtxPtr(const OpBuilder &rewriter) { auto func = rewriter.getBlock()->getParent()->getParentOfType(); @@ -55,7 +55,7 @@ struct Helper final { Type voidType; Type ptrType; Type idxType; - mutable std::set> kernelNames; + mutable std::unordered_set kernelNames; explicit Helper(MLIRContext *ctx, LLVMTypeConverter &converter) : converter(converter), voidType(LLVM::LLVMVoidType::get(ctx)), @@ -81,7 +81,7 @@ struct Helper final { rewriter.create(loc, kernelPtrs[i], elementPtr); } - funcCall(rewriter, GC_GPU_OCL_KERNEL_DESTROY, voidType, {idxType, ptrType}, + funcCall(rewriter, GPU_OCL_KERNEL_DESTROY, voidType, {idxType, ptrType}, loc, {size, kernelPtrsArray}); } }; @@ -117,7 +117,7 @@ struct ConvertAlloc final : ConvertOpPattern { } } auto size = helper.idxConstant(rewriter, loc, staticSize); - auto ptr = funcCall(rewriter, GC_GPU_OCL_MALLOC, helper.ptrType, + auto ptr = funcCall(rewriter, GPU_OCL_MALLOC, helper.ptrType, {helper.ptrType, helper.idxType}, loc, {getCtxPtr(rewriter), size}) .getResult(); @@ -158,7 +158,7 @@ struct ConvertAlloc final : ConvertOpPattern { } size = idxMul(size, helper.idxConstant(rewriter, loc, staticSize)); - auto ptr = funcCall(rewriter, GC_GPU_OCL_MALLOC, helper.ptrType, + auto ptr = funcCall(rewriter, GPU_OCL_MALLOC, helper.ptrType, {helper.ptrType, helper.idxType}, loc, {getCtxPtr(rewriter), size}) .getResult(); @@ -194,7 +194,7 @@ struct ConvertDealloc final : ConvertOpPattern { auto loc = gpuDealloc.getLoc(); MemRefDescriptor dsc(adaptor.getMemref()); auto ptr = dsc.allocatedPtr(rewriter, loc); - auto oclDealloc = funcCall(rewriter, GC_GPU_OCL_DEALLOC, helper.voidType, + auto oclDealloc = funcCall(rewriter, GPU_OCL_DEALLOC, helper.voidType, {helper.ptrType, helper.ptrType}, loc, {getCtxPtr(rewriter), ptr}); rewriter.replaceOp(gpuDealloc, oclDealloc); @@ -227,7 +227,7 @@ struct ConvertMemcpy final : ConvertOpPattern { auto dstPtr = dstDsc.alignedPtr(rewriter, loc); auto size = helper.idxConstant(rewriter, loc, elementSize * numElements); auto oclMemcpy = funcCall( - rewriter, GC_GPU_OCL_MEMCPY, helper.voidType, + rewriter, GPU_OCL_MEMCPY, helper.voidType, {helper.ptrType, helper.ptrType, helper.ptrType, helper.idxType}, loc, {getCtxPtr(rewriter), srcPtr, dstPtr, size}); rewriter.replaceOp(gpuMemcpy, oclMemcpy); @@ -249,7 +249,7 @@ struct ConvertLaunch final : ConvertOpPattern { const Location loc = gpuLaunch.getLoc(); auto kernelArgs = adaptor.getKernelOperands(); - std::vector args; + SmallVector args; args.reserve(kernelArgs.size() + 2); args.emplace_back(getCtxPtr(rewriter)); args.emplace_back(kernelPtr.value()); @@ -265,7 +265,7 @@ struct ConvertLaunch final : ConvertOpPattern { } const auto gpuOclLaunch = - funcCall(rewriter, GC_GPU_OCL_KERNEL_LAUNCH, helper.voidType, + funcCall(rewriter, GPU_OCL_KERNEL_LAUNCH, helper.voidType, {helper.ptrType, helper.ptrType}, loc, args, true); rewriter.replaceOp(gpuLaunch, gpuOclLaunch); return success(); @@ -284,7 +284,9 @@ struct ConvertLaunch final : ConvertOpPattern { SmallString<128> getFuncName("getGcGpuOclKernel_"); getFuncName.append(kernelModName); - if (helper.kernelNames.insert(SmallString<32>(kernelModName)).second) { + if (helper.kernelNames + .insert(std::string(kernelModName.begin(), kernelModName.end())) + .second) { auto insPoint = rewriter.saveInsertionPoint(); SmallString<128> strBuf("gcGpuOclKernel_"); strBuf.append(kernelModName); @@ -391,10 +393,10 @@ struct ConvertLaunch final : ConvertOpPattern { auto spirv = LLVM::createGlobalString(loc, rewriter, str("SPIRV"), binaryAttr.getValue(), LLVM::Linkage::Internal); - auto spirvSize = rewriter.create( + auto spirvSize = rewriter.create( loc, helper.idxType, - mlir::IntegerAttr::get(helper.idxType, - static_cast(binaryAttr.size()))); + IntegerAttr::get(helper.idxType, + static_cast(binaryAttr.size()))); SmallVector globalSize; SmallVector localSize; @@ -436,7 +438,7 @@ struct ConvertLaunch final : ConvertOpPattern { auto argNum = helper.idxConstant(rewriter, loc, adaptor.getKernelOperands().size()); auto createKernelCall = funcCall( - rewriter, GC_GPU_OCL_KERNEL_CREATE, helper.ptrType, + rewriter, GPU_OCL_KERNEL_CREATE, helper.ptrType, {helper.ptrType, helper.idxType, helper.ptrType, helper.ptrType, helper.ptrType, helper.ptrType, helper.idxType, helper.ptrType}, loc, @@ -501,7 +503,7 @@ struct GpuToGpuOcl final : gc::impl::GpuToGpuOclBase { assert(mod); OpBuilder rewriter(mod.getBody(), mod.getBody()->end()); auto destruct = rewriter.create( - mod.getLoc(), GC_GPU_OCL_MOD_DESTRUCTOR, + mod.getLoc(), GPU_OCL_MOD_DESTRUCTOR, LLVM::LLVMFunctionType::get(helper.voidType, {}), LLVM::Linkage::External); auto loc = destruct.getLoc(); From 3eda0d7be20d1eb7a1b34a37a1f6a5417eea8196 Mon Sep 17 00:00:00 2001 From: Andrey Pavlenko Date: Sat, 14 Sep 2024 21:51:44 +0200 Subject: [PATCH 3/8] Get grid/block sizes from gpu.launch_func --- lib/gc/Transforms/GPU/GpuToGpuOcl.cpp | 43 +++++++++++++-------------- 1 file changed, 21 insertions(+), 22 deletions(-) diff --git a/lib/gc/Transforms/GPU/GpuToGpuOcl.cpp b/lib/gc/Transforms/GPU/GpuToGpuOcl.cpp index 838930534..6a725a434 100644 --- a/lib/gc/Transforms/GPU/GpuToGpuOcl.cpp +++ b/lib/gc/Transforms/GPU/GpuToGpuOcl.cpp @@ -398,29 +398,23 @@ struct ConvertLaunch final : ConvertOpPattern { IntegerAttr::get(helper.idxType, static_cast(binaryAttr.size()))); - SmallVector globalSize; - SmallVector localSize; - SmallVector argSize; - kernelMod->walk([&](gpu::GPUFuncOp func) { - if (func.getName() == gpuLaunch.getKernelName()) { - for (auto s : func.getKnownGridSize().value()) { - globalSize.emplace_back(s); - } - for (auto s : func.getKnownBlockSize().value()) { - localSize.emplace_back(s); - } - } - }); - assert(globalSize.size() == 3 && localSize.size() == 3); - globalSize = {globalSize[0] * localSize[0], globalSize[1] * localSize[1], - globalSize[2] * localSize[2]}; + SmallVector gridSize; + SmallVector blockSize; + SmallVector argSize; + gridSize.emplace_back(gpuLaunch.getGridSizeX()); + gridSize.emplace_back(gpuLaunch.getGridSizeY()); + gridSize.emplace_back(gpuLaunch.getGridSizeZ()); + blockSize.emplace_back(gpuLaunch.getBlockSizeX()); + blockSize.emplace_back(gpuLaunch.getBlockSizeY()); + blockSize.emplace_back(gpuLaunch.getBlockSizeZ()); + for (auto arg : adaptor.getKernelOperands()) { auto type = arg.getType(); auto size = type.isIntOrFloat() ? type.getIntOrFloatBitWidth() / 8 : 0; - argSize.emplace_back(size); + argSize.emplace_back(helper.idxConstant(rewriter, loc, size)); } - auto array = [&](SmallVector &values) { + auto array = [&](SmallVector &values) { auto size = helper.idxConstant(rewriter, loc, values.size()); auto arrayPtr = rewriter.create(loc, helper.ptrType, helper.idxType, size); @@ -428,8 +422,13 @@ struct ConvertLaunch final : ConvertOpPattern { auto elementPtr = rewriter.create( loc, helper.ptrType, helper.idxType, arrayPtr, helper.idxConstant(rewriter, loc, i)); - rewriter.create( - loc, helper.idxConstant(rewriter, loc, values[i]), elementPtr); + auto value = values[i]; + if (auto cast = value.getDefiningOp()) { + assert(getConstantIntValue(cast.getOperand(0))); + value = helper.idxConstant( + rewriter, loc, getConstantIntValue(cast.getOperand(0)).value()); + } + rewriter.create(loc, value, elementPtr); } return arrayPtr.getResult(); }; @@ -442,8 +441,8 @@ struct ConvertLaunch final : ConvertOpPattern { {helper.ptrType, helper.idxType, helper.ptrType, helper.ptrType, helper.ptrType, helper.ptrType, helper.idxType, helper.ptrType}, loc, - {ctx, spirvSize, spirv, name, array(globalSize), array(localSize), - argNum, array(argSize)}); + {ctx, spirvSize, spirv, name, array(gridSize), array(blockSize), argNum, + array(argSize)}); auto result = createKernelCall.getResult(); // Save the kernel pointer to the global var using CAS From 39350f3d6abbe8a2e985ccf3840793a0972a0a65 Mon Sep 17 00:00:00 2001 From: Andrey Pavlenko Date: Mon, 16 Sep 2024 18:14:32 +0200 Subject: [PATCH 4/8] Pass value pointer but not the values to kernelLaunch() --- lib/gc/Transforms/GPU/GpuToGpuOcl.cpp | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/lib/gc/Transforms/GPU/GpuToGpuOcl.cpp b/lib/gc/Transforms/GPU/GpuToGpuOcl.cpp index 6a725a434..e48025b0d 100644 --- a/lib/gc/Transforms/GPU/GpuToGpuOcl.cpp +++ b/lib/gc/Transforms/GPU/GpuToGpuOcl.cpp @@ -256,11 +256,17 @@ struct ConvertLaunch final : ConvertOpPattern { int i = 0; for (auto arg : kernelArgs) { - if (isa(gpuLaunch.getKernelOperand(i++).getType())) { + if (auto type = gpuLaunch.getKernelOperand(i++).getType(); + isa(type)) { MemRefDescriptor desc(arg); args.emplace_back(desc.alignedPtr(rewriter, loc)); } else { - args.emplace_back(arg); + // Store the arg on the stack and pass the pointer + auto ptr = rewriter.create( + loc, helper.ptrType, typeConverter->convertType(type), + helper.idxConstant(rewriter, loc, 1)); + rewriter.create(loc, arg, ptr); + args.emplace_back(ptr); } } @@ -352,7 +358,7 @@ struct ConvertLaunch final : ConvertOpPattern { // ...name_Ptr. bool createKernel( gpu::LaunchFuncOp &gpuLaunch, OpAdaptor &adaptor, - ConversionPatternRewriter &rewriter, Location &loc, ModuleOp &mod, + ConversionPatternRewriter &rewriter, const Location &loc, ModuleOp &mod, StringRef funcName, const std::function &(const char *chars)> &str) const { auto kernelModName = gpuLaunch.getKernelModuleName(); @@ -410,6 +416,8 @@ struct ConvertLaunch final : ConvertOpPattern { for (auto arg : adaptor.getKernelOperands()) { auto type = arg.getType(); + // Assuming, that the value is either an integer or a float or a pointer. + // In the latter case, the size is 0 bytes. auto size = type.isIntOrFloat() ? type.getIntOrFloatBitWidth() / 8 : 0; argSize.emplace_back(helper.idxConstant(rewriter, loc, size)); } From 8b1c013d79cc7531b4cc0e6774674abec9e46cf4 Mon Sep 17 00:00:00 2001 From: Andrey Pavlenko Date: Thu, 19 Sep 2024 22:58:11 +0200 Subject: [PATCH 5/8] Added test, fixes --- include/gc/Transforms/Passes.td | 7 +++ lib/gc/Transforms/GPU/AddContextArg.cpp | 45 +++++++++++++ lib/gc/Transforms/GPU/CMakeLists.txt | 1 + lib/gc/Transforms/GPU/GpuToGpuOcl.cpp | 1 + lib/gc/Transforms/GPU/Pipeline.cpp | 9 ++- .../Transforms/IterativeTilingAndFusion.cpp | 2 +- .../test/gc/gpu-runner/XeGPU/lit.local.cfg | 2 + .../test/gc/gpu-runner/gpu-to-gpuocl.mlir | 63 +++++++++++++++++++ test/mlir/test/gc/gpu-runner/lit.local.cfg | 5 +- 9 files changed, 130 insertions(+), 5 deletions(-) create mode 100644 lib/gc/Transforms/GPU/AddContextArg.cpp create mode 100644 test/mlir/test/gc/gpu-runner/XeGPU/lit.local.cfg create mode 100644 test/mlir/test/gc/gpu-runner/gpu-to-gpuocl.mlir diff --git a/include/gc/Transforms/Passes.td b/include/gc/Transforms/Passes.td index 6ba973361..2ddf0a06e 100644 --- a/include/gc/Transforms/Passes.td +++ b/include/gc/Transforms/Passes.td @@ -94,6 +94,13 @@ def LinalgToXeGPU : Pass<"linalg-to-xegpu", "func::FuncOp"> { ]; } +def AddContextArg : Pass<"add-ctx-arg", "func::FuncOp"> { + let summary = "Add a context argument."; + let description = [{ + Add a new memref argument to the function, that could be used to pass some context. + }]; +} + def GpuToGpuOcl : Pass<"gpu-to-gpuocl", "ModuleOp"> { let summary = "Convert the GPU operations to GpuOclRuntime calls."; let description = [{ diff --git a/lib/gc/Transforms/GPU/AddContextArg.cpp b/lib/gc/Transforms/GPU/AddContextArg.cpp new file mode 100644 index 000000000..d731fbb62 --- /dev/null +++ b/lib/gc/Transforms/GPU/AddContextArg.cpp @@ -0,0 +1,45 @@ +//===-- AddContextArg.cpp - Add context argument ----------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#include "mlir/Conversion/Passes.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" + +namespace mlir::gc { +#define GEN_PASS_DECL_ADDCONTEXTARG +#define GEN_PASS_DEF_ADDCONTEXTARG +#include "gc/Transforms/Passes.h.inc" +} // namespace mlir::gc + +using namespace mlir; + +namespace { +struct AddContextArg final : gc::impl::AddContextArgBase { + void runOnOperation() override { + auto func = getOperation(); + auto funcType = func.getFunctionType(); + auto argTypes = llvm::to_vector<8>(funcType.getInputs()); + auto resultTypes = llvm::to_vector<1>(funcType.getResults()); + auto ctx = func->getContext(); + auto newArgType = MemRefType::get({}, IntegerType::get(ctx, 8)); + argTypes.emplace_back(newArgType); + auto newFuncType = FunctionType::get(ctx, argTypes, resultTypes); + func.setType(newFuncType); + + if (func.getBody().hasOneBlock()) { + func.getBody().front().addArgument(newArgType, func.getLoc()); + } + + // Find all function calls and append the last argument of the current + // function to the call. + func.walk([&](func::CallOp call) { + auto args = llvm::to_vector<8>(call.getOperands()); + args.emplace_back(func.getArgument(func.getNumArguments() - 1)); + call->setOperands(args); + }); + } +}; +} // namespace diff --git a/lib/gc/Transforms/GPU/CMakeLists.txt b/lib/gc/Transforms/GPU/CMakeLists.txt index 5fd06e8db..3909681e3 100644 --- a/lib/gc/Transforms/GPU/CMakeLists.txt +++ b/lib/gc/Transforms/GPU/CMakeLists.txt @@ -1,4 +1,5 @@ gc_add_mlir_library(GcGpuPasses + AddContextArg.cpp GpuToGpuOcl.cpp LinalgToXeGPU.cpp Pipeline.cpp diff --git a/lib/gc/Transforms/GPU/GpuToGpuOcl.cpp b/lib/gc/Transforms/GPU/GpuToGpuOcl.cpp index e48025b0d..dfcd1daba 100644 --- a/lib/gc/Transforms/GPU/GpuToGpuOcl.cpp +++ b/lib/gc/Transforms/GPU/GpuToGpuOcl.cpp @@ -316,6 +316,7 @@ struct ConvertLaunch final : ConvertOpPattern { loc, getFuncName, LLVM::LLVMFunctionType::get(helper.ptrType, {helper.ptrType}), LLVM::Linkage::Internal); + function.setAlwaysInline(true); rewriter.setInsertionPointToStart(function.addEntryBlock(rewriter)); auto ptr = mod.lookupSymbol(str("Ptr")); diff --git a/lib/gc/Transforms/GPU/Pipeline.cpp b/lib/gc/Transforms/GPU/Pipeline.cpp index d7bc69e13..a507e2c63 100644 --- a/lib/gc/Transforms/GPU/Pipeline.cpp +++ b/lib/gc/Transforms/GPU/Pipeline.cpp @@ -47,8 +47,11 @@ struct GPUPipelineOption : PassPipelineOptions { llvm::cl::init(true)}; }; -void populateGPUPipeline(mlir::OpPassManager &pm, +void populateGPUPipeline(OpPassManager &pm, const GPUPipelineOption &pipelineOption) { + // Add an argument for the GPU context + pm.addNestedPass(createAddContextArg()); + pm.addNestedPass(createIterativeTilingAndFusion()); pm.addPass(bufferization::createEmptyTensorEliminationPass()); @@ -91,6 +94,7 @@ void populateGPUPipeline(mlir::OpPassManager &pm, /*isUsmArgs*/ pipelineOption.isUsmArgs.getValue()}; pm.addNestedPass( imex::createInsertGPUAllocsPass(insertGPUAllocsOption)); + pm.addPass(createGpuKernelOutliningPass()); pm.addPass(createCanonicalizerPass()); pm.addPass(imex::createSetSPIRVCapabilitiesPass()); @@ -109,7 +113,6 @@ void populateGPUPipeline(mlir::OpPassManager &pm, pm.addNestedPass(LLVM::createRequestCWrappersPass()); pm.addPass(imex::createSerializeSPIRVPass()); pm.addPass(createConvertVectorToSCFPass()); - pm.addPass(imex::createConvertGPUToGPUXPass()); pm.addPass(createConvertSCFToCFPass()); pm.addPass(createConvertControlFlowToLLVMPass()); pm.addPass(createConvertVectorToLLVMPass()); @@ -117,7 +120,7 @@ void populateGPUPipeline(mlir::OpPassManager &pm, pm.addPass(createArithToLLVMConversionPass()); pm.addPass(createConvertFuncToLLVMPass()); pm.addPass(createConvertMathToLLVMPass()); - pm.addPass(imex::createConvertGPUXToLLVMPass()); + pm.addPass(createGpuToGpuOcl()); pm.addPass(createConvertIndexToLLVMPass()); pm.addPass(memref::createExpandStridedMetadataPass()); pm.addPass(createLowerAffinePass()); diff --git a/lib/gc/Transforms/IterativeTilingAndFusion.cpp b/lib/gc/Transforms/IterativeTilingAndFusion.cpp index a486c29b0..d94db20c3 100644 --- a/lib/gc/Transforms/IterativeTilingAndFusion.cpp +++ b/lib/gc/Transforms/IterativeTilingAndFusion.cpp @@ -680,7 +680,7 @@ defaultTilingOfType(RewriterBase &rewriter, Operation *op, } else { defaultTileSize.resize(iteratorTypes.size(), rewriter.getIndexAttr(0)); // Try tileSize from `32` to `16`. - SmallVector tsOrder = {32, 16}; + SmallVector tsOrder = {16, 32}; // Record how many dims have been tiled, including fully tiled, i.e. // tileSize == dimSize. unsigned nonOneTileDims = diff --git a/test/mlir/test/gc/gpu-runner/XeGPU/lit.local.cfg b/test/mlir/test/gc/gpu-runner/XeGPU/lit.local.cfg new file mode 100644 index 000000000..152c26255 --- /dev/null +++ b/test/mlir/test/gc/gpu-runner/XeGPU/lit.local.cfg @@ -0,0 +1,2 @@ +# GPUX is currently disabled +config.unsupported = True diff --git a/test/mlir/test/gc/gpu-runner/gpu-to-gpuocl.mlir b/test/mlir/test/gc/gpu-runner/gpu-to-gpuocl.mlir new file mode 100644 index 000000000..7742b8d19 --- /dev/null +++ b/test/mlir/test/gc/gpu-runner/gpu-to-gpuocl.mlir @@ -0,0 +1,63 @@ +// RUN: gc-opt %s --gc-gpu-pipeline | FileCheck %s + +module @test { + func.func @entry(%arg0: memref<32x32xf32>, %arg1: memref<32x32xf32>, %arg2: memref<32x32xf32>) { + %0 = bufferization.to_tensor %arg0 restrict : memref<32x32xf32> + %1 = bufferization.to_tensor %arg1 restrict : memref<32x32xf32> + %2 = tensor.empty() : tensor<32x32xf32> + %3 = linalg.add ins(%1, %0 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%2 : tensor<32x32xf32>) -> tensor<32x32xf32> + bufferization.materialize_in_destination %3 in restrict writable %arg2 : (tensor<32x32xf32>, memref<32x32xf32>) -> () + return + } +} + +// CHECK: llvm.mlir.global internal constant @gcGpuOclKernel_entry_kernel_SPIRV +// CHECK: llvm.mlir.global internal constant @gcGpuOclKernel_entry_kernel_Name +// CHECK: llvm.mlir.global internal @gcGpuOclKernel_entry_kernel_Ptr + +// CHECK: llvm.func internal @createGcGpuOclKernel_entry_kernel([[CTX:%.+]]: !llvm.ptr) -> !llvm.ptr +// CHECK: [[PTR_ADDR:%.+]] = llvm.mlir.addressof @gcGpuOclKernel_entry_kernel_Ptr +// CHECK: [[ZERO:%.+]] = llvm.mlir.zero +// CHECK: [[ONE:%.+]] = llvm.mlir.constant(1 : i64) : i64 +// CHECK: [[NEW_PTR:%.+]] = llvm.call @gcGpuOclKernelCreate([[CTX]] +// CHECK: [[CMPXCHG:%.+]] = llvm.cmpxchg [[PTR_ADDR]], [[ZERO]], [[NEW_PTR]] +// CHECK: [[FLAG:%.+]] = llvm.extractvalue [[CMPXCHG]][1] +// CHECK: llvm.cond_br [[FLAG]], [[BB1:\^.+]], [[BB2:\^.+]] +// CHECK: [[BB1]]: +// CHECK: llvm.return [[NEW_PTR]] +// CHECK: [[BB2]]: +// CHECK: [[ARRAY:%.+]] = llvm.alloca [[ONE]] +// CHECK: llvm.store [[NEW_PTR]], [[ARRAY]] +// CHECK: llvm.call @gcGpuOclKernelDestroy([[ONE]], [[ARRAY]]) +// CHECK: [[OLD_PTR:%.+]] = llvm.extractvalue [[CMPXCHG]][0] +// CHECK: llvm.return [[OLD_PTR]] + +// CHECK: llvm.func internal @getGcGpuOclKernel_entry_kernel([[CTX:%.+]]: !llvm.ptr) -> !llvm.ptr attributes {always_inline} +// CHECK: [[ZERO:%.+]] = llvm.mlir.zero +// CHECK: [[PTR_ADDR:%.+]] = llvm.mlir.addressof @gcGpuOclKernel_entry_kernel_Ptr +// CHECK: [[PTR:%.+]] = llvm.load [[PTR_ADDR]] +// CHECK: [[ICMP:%.+]] = llvm.icmp "eq" [[PTR]], [[ZERO]] +// CHECK: llvm.cond_br [[ICMP]], [[BB1:\^.+]], [[BB2:\^.+]] +// CHECK: [[BB1]]: +// CHECK: [[NEW_PTR:%.+]] = llvm.call @createGcGpuOclKernel_entry_kernel([[CTX]]) +// CHECK: llvm.return [[NEW_PTR]] +// CHECK: [[BB2]]: +// CHECK: llvm.return [[PTR]] + +// CHECK: llvm.func @entry +// CHECK: [[KERNEL:%.+]] = llvm.call @getGcGpuOclKernel_entry_kernel([[CTX:%.+]]) : (!llvm.ptr) -> !llvm.ptr +// CHECK: llvm.call @gcGpuOclKernelLaunch([[CTX]], [[KERNEL]], + +// CHECK: llvm.func @gcGpuOclKernelCreate +// CHECK: llvm.func @gcGpuOclKernelDestroy +// CHECK: llvm.func @gcGpuOclKernelLaunch + + +// CHECK: llvm.func @gcGpuOclModuleDestructor() +// CHECK: [[ONE:%.+]] = llvm.mlir.constant(1 : i64) : i64 +// CHECK: [[PTR_ADDR:%.+]] = llvm.mlir.addressof @gcGpuOclKernel_entry_kernel_Ptr +// CHECK: llvm.fence acquire +// CHECK: [[PTR:%.+]] = llvm.load [[PTR_ADDR]] +// CHECK: [[ARRAY:%.+]] = llvm.alloca [[ONE]] +// CHECK: llvm.store [[PTR]], [[ARRAY]] +// CHECK: llvm.call @gcGpuOclKernelDestroy([[ONE]], [[ARRAY]]) diff --git a/test/mlir/test/gc/gpu-runner/lit.local.cfg b/test/mlir/test/gc/gpu-runner/lit.local.cfg index f180dd41b..5ed13b0d2 100644 --- a/test/mlir/test/gc/gpu-runner/lit.local.cfg +++ b/test/mlir/test/gc/gpu-runner/lit.local.cfg @@ -1,2 +1,5 @@ if not config.gc_use_imex: - config.unsupported = True \ No newline at end of file + config.unsupported = True +else: + # FIXME: Enable when the GPU runner is implemented. + config.excludes = ['mlp.mlir'] From 501d434f75f4b4c5c425dd61e3fa05a5b2e21238 Mon Sep 17 00:00:00 2001 From: Andrey Pavlenko Date: Fri, 20 Sep 2024 17:18:20 +0200 Subject: [PATCH 6/8] Added support for vector types --- lib/gc/Transforms/GPU/GpuToGpuOcl.cpp | 27 +++++++++++++++++++++++---- 1 file changed, 23 insertions(+), 4 deletions(-) diff --git a/lib/gc/Transforms/GPU/GpuToGpuOcl.cpp b/lib/gc/Transforms/GPU/GpuToGpuOcl.cpp index dfcd1daba..e8fa2440e 100644 --- a/lib/gc/Transforms/GPU/GpuToGpuOcl.cpp +++ b/lib/gc/Transforms/GPU/GpuToGpuOcl.cpp @@ -415,11 +415,30 @@ struct ConvertLaunch final : ConvertOpPattern { blockSize.emplace_back(gpuLaunch.getBlockSizeY()); blockSize.emplace_back(gpuLaunch.getBlockSizeZ()); - for (auto arg : adaptor.getKernelOperands()) { + for (auto arg : gpuLaunch.getKernelOperands()) { auto type = arg.getType(); - // Assuming, that the value is either an integer or a float or a pointer. - // In the latter case, the size is 0 bytes. - auto size = type.isIntOrFloat() ? type.getIntOrFloatBitWidth() / 8 : 0; + size_t size; + if (isa(type)) { + size = 0; // A special case for pointers + } else if (type.isIndex()) { + size = helper.idxType.getIntOrFloatBitWidth() / 8; + } else if (type.isIntOrFloat()) { + size = type.getIntOrFloatBitWidth() / 8; + } else if (auto vectorType = dyn_cast(type)) { + type = vectorType.getElementType(); + if (type.isIntOrFloat()) { + size = type.getIntOrFloatBitWidth(); + } else if (type.isIndex()) { + size = helper.idxType.getIntOrFloatBitWidth(); + } else { + llvm::errs() << "Unsupported vector element type: " << type << "\n"; + return false; + } + size *= vectorType.getNumElements() / 8; + } else { + llvm::errs() << "Unsupported type: " << type << "\n"; + return false; + } argSize.emplace_back(helper.idxConstant(rewriter, loc, size)); } From 443ff0e62f673f0ca475cebb2d01870300d5f182 Mon Sep 17 00:00:00 2001 From: Andrey Pavlenko Date: Tue, 24 Sep 2024 00:24:22 +0200 Subject: [PATCH 7/8] Apply changes from code review --- lib/gc/Transforms/GPU/AddContextArg.cpp | 17 ++- lib/gc/Transforms/GPU/GpuToGpuOcl.cpp | 119 ++++++++---------- lib/gc/Transforms/GPU/Pipeline.cpp | 40 +++--- .../Transforms/IterativeTilingAndFusion.cpp | 2 +- .../test/gc/Transforms/GPU/gpu-to-gpuocl.mlir | 98 +++++++++++++++ .../test/gc/gpu-runner/gpu-to-gpuocl.mlir | 63 ---------- 6 files changed, 180 insertions(+), 159 deletions(-) create mode 100644 test/mlir/test/gc/Transforms/GPU/gpu-to-gpuocl.mlir delete mode 100644 test/mlir/test/gc/gpu-runner/gpu-to-gpuocl.mlir diff --git a/lib/gc/Transforms/GPU/AddContextArg.cpp b/lib/gc/Transforms/GPU/AddContextArg.cpp index d731fbb62..037f33d77 100644 --- a/lib/gc/Transforms/GPU/AddContextArg.cpp +++ b/lib/gc/Transforms/GPU/AddContextArg.cpp @@ -20,6 +20,10 @@ namespace { struct AddContextArg final : gc::impl::AddContextArgBase { void runOnOperation() override { auto func = getOperation(); + if (func.isExternal()) { + return; + } + auto funcType = func.getFunctionType(); auto argTypes = llvm::to_vector<8>(funcType.getInputs()); auto resultTypes = llvm::to_vector<1>(funcType.getResults()); @@ -28,14 +32,19 @@ struct AddContextArg final : gc::impl::AddContextArgBase { argTypes.emplace_back(newArgType); auto newFuncType = FunctionType::get(ctx, argTypes, resultTypes); func.setType(newFuncType); - - if (func.getBody().hasOneBlock()) { - func.getBody().front().addArgument(newArgType, func.getLoc()); - } + func.getBody().front().addArgument(newArgType, func.getLoc()); // Find all function calls and append the last argument of the current // function to the call. + auto module = func->getParentOfType(); func.walk([&](func::CallOp call) { + // If the function to be called is defined in the current module, then the + // context arg will be added to this function signature either and, thus, + // wee need add the context arg to the function call. + if (auto callee = module.lookupSymbol(call.getCallee()); + !callee || callee.isExternal()) { + return; + } auto args = llvm::to_vector<8>(call.getOperands()); args.emplace_back(func.getArgument(func.getNumArguments() - 1)); call->setOperands(args); diff --git a/lib/gc/Transforms/GPU/GpuToGpuOcl.cpp b/lib/gc/Transforms/GPU/GpuToGpuOcl.cpp index e8fa2440e..f765d29cf 100644 --- a/lib/gc/Transforms/GPU/GpuToGpuOcl.cpp +++ b/lib/gc/Transforms/GPU/GpuToGpuOcl.cpp @@ -69,6 +69,30 @@ struct Helper final { rewriter.getIntegerAttr(idxType, static_cast(value))); } + Value calculateStaticSize(OpBuilder &rewriter, const Location loc, + const MemRefType type) const { + if (type.getRank() == 0) { + return idxConstant(rewriter, loc, 0); + } + + auto elementType = type.getElementType(); + if (!elementType.isIntOrIndexOrFloat()) { + return nullptr; + } + + int64_t numElements = 1; + for (auto dim : type.getShape()) { + if (dim == ShapedType::kDynamic) { + return nullptr; + } + numElements = numElements * dim; + } + auto elementSize = elementType.isIndex() + ? idxType.getIntOrFloatBitWidth() + : elementType.getIntOrFloatBitWidth(); + return idxConstant(rewriter, loc, elementSize * numElements / 8); + } + void destroyKernels(OpBuilder &rewriter, Location loc, ArrayRef kernelPtrs) const { auto size = idxConstant(rewriter, loc, kernelPtrs.size()); @@ -102,24 +126,11 @@ struct ConvertAlloc final : ConvertOpPattern { ConversionPatternRewriter &rewriter) const override { auto loc = allocOp.getLoc(); MemRefType type = allocOp.getType(); - auto shape = type.getShape(); - auto dynamics = adaptor.getDynamicSizes(); - if (shape.empty() || dynamics.empty()) { - int64_t staticSize; - if (shape.empty()) { - staticSize = 0; - } else { - staticSize = type.getElementType().getIntOrFloatBitWidth() / 8; - for (auto dim : shape) { - assert(dim != ShapedType::kDynamic); - staticSize *= dim; - } - } - auto size = helper.idxConstant(rewriter, loc, staticSize); + if (auto staticSize = helper.calculateStaticSize(rewriter, loc, type)) { auto ptr = funcCall(rewriter, GPU_OCL_MALLOC, helper.ptrType, {helper.ptrType, helper.idxType}, loc, - {getCtxPtr(rewriter), size}) + {getCtxPtr(rewriter), staticSize}) .getResult(); Value replacement = MemRefDescriptor::fromStaticShape( rewriter, loc, helper.converter, type, ptr, ptr); @@ -127,57 +138,32 @@ struct ConvertAlloc final : ConvertOpPattern { return success(); } - auto ndims = shape.size(); - SmallVector newShape; - SmallVector newStrides(ndims); - auto staticSize = type.getElementType().getIntOrFloatBitWidth() / 8; - auto size = dynamics[0]; - - auto idxMul = [&](Value x, Value y) -> Value { - if (auto xConst = getConstantIntValue(x)) { - if (auto yConst = getConstantIntValue(y)) { - return helper.idxConstant(rewriter, loc, - xConst.value() * yConst.value()); - } - } - return rewriter.create(loc, x, y); - }; - - for (size_t i = 0, j = 0; i < ndims; i++) { - auto dim = shape[i]; - if (dim == ShapedType::kDynamic) { - auto dynSize = dynamics[j++]; - newShape.emplace_back(dynSize); - if (j != 1) { - size = idxMul(size, dynSize); - } - } else { - staticSize *= dim; - newShape.emplace_back(helper.idxConstant(rewriter, loc, dim)); - } + auto dstType = helper.converter.convertType(type); + if (!dstType) { + allocOp.emitError() << "Failed to convert the MemRefType"; + return failure(); } - size = idxMul(size, helper.idxConstant(rewriter, loc, staticSize)); + SmallVector shape; + SmallVector strides; + Value size; + getMemRefDescriptorSizes(loc, type, adaptor.getDynamicSizes(), rewriter, + shape, strides, size); + assert(shape.size() == strides.size()); + auto ptr = funcCall(rewriter, GPU_OCL_MALLOC, helper.ptrType, {helper.ptrType, helper.idxType}, loc, {getCtxPtr(rewriter), size}) .getResult(); - newStrides[ndims - 1] = helper.idxConstant(rewriter, loc, 1); - for (int i = static_cast(ndims) - 2; i >= 0; i--) { - newStrides[i] = idxMul(newStrides[i + 1], newShape[i]); - ; - } - - auto dsc = MemRefDescriptor::undef(rewriter, loc, - helper.converter.convertType(type)); + auto dsc = MemRefDescriptor::undef(rewriter, loc, dstType); dsc.setAllocatedPtr(rewriter, loc, ptr); dsc.setAlignedPtr(rewriter, loc, ptr); dsc.setOffset(rewriter, loc, helper.idxConstant(rewriter, loc, 0)); - for (unsigned i = 0, n = static_cast(ndims); i < n; i++) { - dsc.setSize(rewriter, loc, i, newShape[i]); - dsc.setStride(rewriter, loc, i, newStrides[i]); + for (unsigned i = 0, n = static_cast(shape.size()); i < n; i++) { + dsc.setSize(rewriter, loc, i, shape[i]); + dsc.setStride(rewriter, loc, i, strides[i]); } rewriter.replaceOp(allocOp, static_cast(dsc)); @@ -209,23 +195,24 @@ struct ConvertMemcpy final : ConvertOpPattern { matchAndRewrite(gpu::MemcpyOp gpuMemcpy, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = gpuMemcpy.getLoc(); + MemRefDescriptor srcDsc(adaptor.getSrc()); + MemRefDescriptor dstDsc(adaptor.getDst()); auto srcType = gpuMemcpy.getSrc().getType(); - auto elementSize = srcType.getElementType().getIntOrFloatBitWidth() / 8; - uint64_t numElements = 0; - for (auto dim : srcType.getShape()) { - if (dim == ShapedType::kDynamic) { - gpuMemcpy.emitOpError() - << "dynamic shapes are not currently not supported"; - return failure(); + Value size = helper.calculateStaticSize(rewriter, loc, srcType); + + if (!size) { + auto numElements = helper.idxConstant(rewriter, loc, 1); + for (unsigned i = 0, n = srcType.getRank(); i < n; i++) { + numElements = rewriter.create( + loc, numElements, srcDsc.size(rewriter, loc, i)); } - numElements = numElements ? numElements * dim : dim; + size = rewriter.create( + loc, numElements, + getSizeInBytes(loc, srcType.getElementType(), rewriter)); } - MemRefDescriptor srcDsc(adaptor.getSrc()); - MemRefDescriptor dstDsc(adaptor.getDst()); auto srcPtr = srcDsc.alignedPtr(rewriter, loc); auto dstPtr = dstDsc.alignedPtr(rewriter, loc); - auto size = helper.idxConstant(rewriter, loc, elementSize * numElements); auto oclMemcpy = funcCall( rewriter, GPU_OCL_MEMCPY, helper.voidType, {helper.ptrType, helper.ptrType, helper.ptrType, helper.idxType}, loc, diff --git a/lib/gc/Transforms/GPU/Pipeline.cpp b/lib/gc/Transforms/GPU/Pipeline.cpp index a507e2c63..3d07d6ba8 100644 --- a/lib/gc/Transforms/GPU/Pipeline.cpp +++ b/lib/gc/Transforms/GPU/Pipeline.cpp @@ -6,45 +6,35 @@ // //===----------------------------------------------------------------------===// +#include + +#include "gc/Transforms/Passes.h" + +#include "imex/Conversion/Passes.h" +#include "imex/Transforms/Passes.h" + #include "mlir/Conversion/Passes.h" -#include "mlir/Dialect/Arith/Transforms/Passes.h" #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" #include "mlir/Dialect/Bufferization/Transforms/Passes.h" -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/GPU/Transforms/Passes.h" #include "mlir/Dialect/LLVMIR/Transforms/Passes.h" #include "mlir/Dialect/Linalg/Passes.h" -#include "mlir/Dialect/Math/Transforms/Passes.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/Transforms/Passes.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/IR/DialectRegistry.h" +#include "mlir/Dialect/SPIRV/Transforms/Passes.h" #include "mlir/InitAllPasses.h" #include "mlir/Pass/PassManager.h" -#include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/Passes.h" -#include - -#include "mlir/Dialect/GPU/IR/GPUDialect.h" -#include "mlir/Dialect/GPU/Transforms/Passes.h" -#include "mlir/Dialect/SPIRV/Transforms/Passes.h" - -#include -#include - -#include - -#include "gc/Transforms/Passes.h" namespace mlir::gc { struct GPUPipelineOption : PassPipelineOptions { - PassOptions::Option isUsmArgs{ + Option isUsmArgs{ *this, "is-usm-args", - llvm::cl::desc("Whether to use USM(unified shared memory) func args, in " - "which the host and device could access the same buffer " - "and there is no need to add memcpy explicitly"), - llvm::cl::init(true)}; + desc("Whether to use USM(unified shared memory) func args, in " + "which the host and device could access the same buffer " + "and there is no need to add memcpy explicitly"), + init(true)}; }; void populateGPUPipeline(OpPassManager &pm, diff --git a/lib/gc/Transforms/IterativeTilingAndFusion.cpp b/lib/gc/Transforms/IterativeTilingAndFusion.cpp index d94db20c3..a486c29b0 100644 --- a/lib/gc/Transforms/IterativeTilingAndFusion.cpp +++ b/lib/gc/Transforms/IterativeTilingAndFusion.cpp @@ -680,7 +680,7 @@ defaultTilingOfType(RewriterBase &rewriter, Operation *op, } else { defaultTileSize.resize(iteratorTypes.size(), rewriter.getIndexAttr(0)); // Try tileSize from `32` to `16`. - SmallVector tsOrder = {16, 32}; + SmallVector tsOrder = {32, 16}; // Record how many dims have been tiled, including fully tiled, i.e. // tileSize == dimSize. unsigned nonOneTileDims = diff --git a/test/mlir/test/gc/Transforms/GPU/gpu-to-gpuocl.mlir b/test/mlir/test/gc/Transforms/GPU/gpu-to-gpuocl.mlir new file mode 100644 index 000000000..8a5571571 --- /dev/null +++ b/test/mlir/test/gc/Transforms/GPU/gpu-to-gpuocl.mlir @@ -0,0 +1,98 @@ +// RUN: gc-opt %s --gpu-to-gpuocl | FileCheck %s + +module @test attributes {gpu.container_module} { + llvm.func @entry(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: i64, %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i64, %arg7: !llvm.ptr, %arg8: !llvm.ptr, %arg9: i64) attributes {llvm.emit_c_interface} { + %0 = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + %1 = llvm.insertvalue %arg0, %0[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + %2 = llvm.insertvalue %arg1, %1[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + %3 = llvm.insertvalue %arg2, %2[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + %4 = llvm.insertvalue %arg3, %3[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + %5 = llvm.insertvalue %arg4, %4[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + %6 = llvm.insertvalue %arg5, %5[3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + %7 = llvm.insertvalue %arg6, %6[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + %8 = builtin.unrealized_conversion_cast %7 : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> to memref<64x64xf32> + %gpu_mem = gpu.alloc host_shared () : memref<64x64xf32> + gpu.memcpy %gpu_mem, %8 : memref<64x64xf32>, memref<64x64xf32> + %9 = llvm.mlir.constant(32 : index) : i64 + %10 = builtin.unrealized_conversion_cast %9 : i64 to index + %11 = llvm.mlir.constant(2 : index) : i64 + %12 = builtin.unrealized_conversion_cast %11 : i64 to index + %13 = llvm.mlir.constant(1 : index) : i64 + %14 = builtin.unrealized_conversion_cast %13 : i64 to index + gpu.launch_func @entry_kernel::@entry_kernel blocks in (%12, %12, %14) threads in (%14, %14, %14) args(%10 : index, %gpu_mem : memref<64x64xf32>) + gpu.memcpy %8, %gpu_mem : memref<64x64xf32>, memref<64x64xf32> + gpu.dealloc %gpu_mem : memref<64x64xf32> + llvm.return + } + + gpu.module @entry_kernel attributes {gpu.binary = "Some SPIRV here \00"} { + gpu.func @entry_kernel(%arg0: index, %arg1: memref<64x64xf32>) kernel attributes {} { + gpu.return + } + } +} + +// CHECK: llvm.mlir.global internal constant @gcGpuOclKernel_entry_kernel_SPIRV +// CHECK: llvm.mlir.global internal constant @gcGpuOclKernel_entry_kernel_Name +// CHECK: llvm.mlir.global internal @gcGpuOclKernel_entry_kernel_Ptr + +// CHECK: llvm.func internal @createGcGpuOclKernel_entry_kernel([[CTX:%.+]]: !llvm.ptr) -> !llvm.ptr +// CHECK: [[NEW_PTR:%.+]] = llvm.call @gcGpuOclKernelCreate([[CTX]] +// CHECK: [[ZERO:%.+]] = llvm.mlir.zero +// CHECK: [[PTR_ADDR:%.+]] = llvm.mlir.addressof @gcGpuOclKernel_entry_kernel_Ptr +// CHECK: [[CMPXCHG:%.+]] = llvm.cmpxchg [[PTR_ADDR]], [[ZERO]], [[NEW_PTR]] +// CHECK: [[FLAG:%.+]] = llvm.extractvalue [[CMPXCHG]][1] +// CHECK: llvm.cond_br [[FLAG]], [[BB1:\^.+]], [[BB2:\^.+]] +// CHECK: [[BB1]]: +// CHECK: llvm.return [[NEW_PTR]] +// CHECK: [[BB2]]: +// CHECK: [[ONE:%.+]] = llvm.mlir.constant(1 : i64) : i64 +// CHECK: [[ARRAY:%.+]] = llvm.alloca [[ONE]] +// CHECK: [[ADDR:%.+]] = llvm.getelementptr [[ARRAY]] +// CHECK: llvm.store [[NEW_PTR]], [[ADDR]] +// CHECK: llvm.call @gcGpuOclKernelDestroy([[ONE]], [[ARRAY]]) +// CHECK: [[OLD_PTR:%.+]] = llvm.extractvalue [[CMPXCHG]][0] +// CHECK: llvm.return [[OLD_PTR]] + +// CHECK: llvm.func internal @getGcGpuOclKernel_entry_kernel([[CTX:%.+]]: !llvm.ptr) -> !llvm.ptr attributes {always_inline} +// CHECK: [[ZERO:%.+]] = llvm.mlir.zero +// CHECK: [[PTR_ADDR:%.+]] = llvm.mlir.addressof @gcGpuOclKernel_entry_kernel_Ptr +// CHECK: [[PTR:%.+]] = llvm.load [[PTR_ADDR]] +// CHECK: [[ICMP:%.+]] = llvm.icmp "eq" [[PTR]], [[ZERO]] +// CHECK: llvm.cond_br [[ICMP]], [[BB1:\^.+]], [[BB2:\^.+]] +// CHECK: [[BB1]]: +// CHECK: [[NEW_PTR:%.+]] = llvm.call @createGcGpuOclKernel_entry_kernel([[CTX]]) +// CHECK: llvm.return [[NEW_PTR]] +// CHECK: [[BB2]]: +// CHECK: llvm.return [[PTR]] + +// CHECK: llvm.func @entry(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: i64, %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i64, [[CTX:%.+]]: !llvm.ptr, %arg8: !llvm.ptr, %arg9: i64) +// CHECK: [[SIZE:%.+]] = llvm.mlir.constant(16384 : i64) : i64 +// CHECK: llvm.call @gcGpuOclMalloc([[CTX]], [[SIZE]]) +// CHECK: [[SIZE:%.+]] = llvm.mlir.constant(16384 : i64) : i64 +// CHECK: [[SRC:%.+]] = llvm.extractvalue +// CHECK: [[DST:%.+]] = llvm.extractvalue [[GPU_MEMREF:%.+]][1] +// CHECK: llvm.call @gcGpuOclMemcpy([[CTX]], [[SRC]], [[DST]], [[SIZE]]) +// CHECK: [[KERNEL:%.+]] = llvm.call @getGcGpuOclKernel_entry_kernel([[CTX:%.+]]) : (!llvm.ptr) -> !llvm.ptr +// CHECK: llvm.call @gcGpuOclKernelLaunch([[CTX]], [[KERNEL]], +// CHECK: [[SIZE:%.+]] = llvm.mlir.constant(16384 : i64) : i64 +// CHECK: [[SRC:%.+]] = llvm.extractvalue [[GPU_MEMREF:%.+]][1] +// CHECK: [[DST:%.+]] = llvm.extractvalue +// CHECK: llvm.call @gcGpuOclMemcpy([[CTX]], [[SRC]], [[DST]], [[SIZE]]) +// CHECK: [[GPU_PTR:%.+]] = llvm.extractvalue [[GPU_MEMREF:%.+]][0] +// CHECK: llvm.call @gcGpuOclDealloc([[CTX]], [[GPU_PTR]]) + +// CHECK: llvm.func @gcGpuOclKernelCreate +// CHECK: llvm.func @gcGpuOclKernelDestroy +// CHECK: llvm.func @gcGpuOclKernelLaunch + + +// CHECK: llvm.func @gcGpuOclModuleDestructor() +// CHECK: llvm.fence acquire +// CHECK: [[PTR_ADDR:%.+]] = llvm.mlir.addressof @gcGpuOclKernel_entry_kernel_Ptr +// CHECK: [[PTR:%.+]] = llvm.load [[PTR_ADDR]] +// CHECK: [[ONE:%.+]] = llvm.mlir.constant(1 : i64) : i64 +// CHECK: [[ARRAY:%.+]] = llvm.alloca [[ONE]] +// CHECK: [[ADDR:%.+]] = llvm.getelementptr [[ARRAY]] +// CHECK: llvm.store [[PTR]], [[ADDR]] +// CHECK: llvm.call @gcGpuOclKernelDestroy([[ONE]], [[ARRAY]]) diff --git a/test/mlir/test/gc/gpu-runner/gpu-to-gpuocl.mlir b/test/mlir/test/gc/gpu-runner/gpu-to-gpuocl.mlir deleted file mode 100644 index 7742b8d19..000000000 --- a/test/mlir/test/gc/gpu-runner/gpu-to-gpuocl.mlir +++ /dev/null @@ -1,63 +0,0 @@ -// RUN: gc-opt %s --gc-gpu-pipeline | FileCheck %s - -module @test { - func.func @entry(%arg0: memref<32x32xf32>, %arg1: memref<32x32xf32>, %arg2: memref<32x32xf32>) { - %0 = bufferization.to_tensor %arg0 restrict : memref<32x32xf32> - %1 = bufferization.to_tensor %arg1 restrict : memref<32x32xf32> - %2 = tensor.empty() : tensor<32x32xf32> - %3 = linalg.add ins(%1, %0 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%2 : tensor<32x32xf32>) -> tensor<32x32xf32> - bufferization.materialize_in_destination %3 in restrict writable %arg2 : (tensor<32x32xf32>, memref<32x32xf32>) -> () - return - } -} - -// CHECK: llvm.mlir.global internal constant @gcGpuOclKernel_entry_kernel_SPIRV -// CHECK: llvm.mlir.global internal constant @gcGpuOclKernel_entry_kernel_Name -// CHECK: llvm.mlir.global internal @gcGpuOclKernel_entry_kernel_Ptr - -// CHECK: llvm.func internal @createGcGpuOclKernel_entry_kernel([[CTX:%.+]]: !llvm.ptr) -> !llvm.ptr -// CHECK: [[PTR_ADDR:%.+]] = llvm.mlir.addressof @gcGpuOclKernel_entry_kernel_Ptr -// CHECK: [[ZERO:%.+]] = llvm.mlir.zero -// CHECK: [[ONE:%.+]] = llvm.mlir.constant(1 : i64) : i64 -// CHECK: [[NEW_PTR:%.+]] = llvm.call @gcGpuOclKernelCreate([[CTX]] -// CHECK: [[CMPXCHG:%.+]] = llvm.cmpxchg [[PTR_ADDR]], [[ZERO]], [[NEW_PTR]] -// CHECK: [[FLAG:%.+]] = llvm.extractvalue [[CMPXCHG]][1] -// CHECK: llvm.cond_br [[FLAG]], [[BB1:\^.+]], [[BB2:\^.+]] -// CHECK: [[BB1]]: -// CHECK: llvm.return [[NEW_PTR]] -// CHECK: [[BB2]]: -// CHECK: [[ARRAY:%.+]] = llvm.alloca [[ONE]] -// CHECK: llvm.store [[NEW_PTR]], [[ARRAY]] -// CHECK: llvm.call @gcGpuOclKernelDestroy([[ONE]], [[ARRAY]]) -// CHECK: [[OLD_PTR:%.+]] = llvm.extractvalue [[CMPXCHG]][0] -// CHECK: llvm.return [[OLD_PTR]] - -// CHECK: llvm.func internal @getGcGpuOclKernel_entry_kernel([[CTX:%.+]]: !llvm.ptr) -> !llvm.ptr attributes {always_inline} -// CHECK: [[ZERO:%.+]] = llvm.mlir.zero -// CHECK: [[PTR_ADDR:%.+]] = llvm.mlir.addressof @gcGpuOclKernel_entry_kernel_Ptr -// CHECK: [[PTR:%.+]] = llvm.load [[PTR_ADDR]] -// CHECK: [[ICMP:%.+]] = llvm.icmp "eq" [[PTR]], [[ZERO]] -// CHECK: llvm.cond_br [[ICMP]], [[BB1:\^.+]], [[BB2:\^.+]] -// CHECK: [[BB1]]: -// CHECK: [[NEW_PTR:%.+]] = llvm.call @createGcGpuOclKernel_entry_kernel([[CTX]]) -// CHECK: llvm.return [[NEW_PTR]] -// CHECK: [[BB2]]: -// CHECK: llvm.return [[PTR]] - -// CHECK: llvm.func @entry -// CHECK: [[KERNEL:%.+]] = llvm.call @getGcGpuOclKernel_entry_kernel([[CTX:%.+]]) : (!llvm.ptr) -> !llvm.ptr -// CHECK: llvm.call @gcGpuOclKernelLaunch([[CTX]], [[KERNEL]], - -// CHECK: llvm.func @gcGpuOclKernelCreate -// CHECK: llvm.func @gcGpuOclKernelDestroy -// CHECK: llvm.func @gcGpuOclKernelLaunch - - -// CHECK: llvm.func @gcGpuOclModuleDestructor() -// CHECK: [[ONE:%.+]] = llvm.mlir.constant(1 : i64) : i64 -// CHECK: [[PTR_ADDR:%.+]] = llvm.mlir.addressof @gcGpuOclKernel_entry_kernel_Ptr -// CHECK: llvm.fence acquire -// CHECK: [[PTR:%.+]] = llvm.load [[PTR_ADDR]] -// CHECK: [[ARRAY:%.+]] = llvm.alloca [[ONE]] -// CHECK: llvm.store [[PTR]], [[ARRAY]] -// CHECK: llvm.call @gcGpuOclKernelDestroy([[ONE]], [[ARRAY]]) From b31e64d4973d48412022488f2597cc98b4ce10df Mon Sep 17 00:00:00 2001 From: Andrey Pavlenko Date: Tue, 24 Sep 2024 16:59:31 +0200 Subject: [PATCH 8/8] Removed the path from pipeline --- lib/gc/Transforms/GPU/Pipeline.cpp | 49 +++++++++++-------- .../test/gc/gpu-runner/XeGPU/lit.local.cfg | 2 - test/mlir/test/gc/gpu-runner/lit.local.cfg | 5 +- 3 files changed, 29 insertions(+), 27 deletions(-) delete mode 100644 test/mlir/test/gc/gpu-runner/XeGPU/lit.local.cfg diff --git a/lib/gc/Transforms/GPU/Pipeline.cpp b/lib/gc/Transforms/GPU/Pipeline.cpp index 3d07d6ba8..d7bc69e13 100644 --- a/lib/gc/Transforms/GPU/Pipeline.cpp +++ b/lib/gc/Transforms/GPU/Pipeline.cpp @@ -6,42 +6,49 @@ // //===----------------------------------------------------------------------===// -#include - -#include "gc/Transforms/Passes.h" - -#include "imex/Conversion/Passes.h" -#include "imex/Transforms/Passes.h" - #include "mlir/Conversion/Passes.h" +#include "mlir/Dialect/Arith/Transforms/Passes.h" #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" #include "mlir/Dialect/Bufferization/Transforms/Passes.h" -#include "mlir/Dialect/GPU/IR/GPUDialect.h" -#include "mlir/Dialect/GPU/Transforms/Passes.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/Transforms/Passes.h" #include "mlir/Dialect/Linalg/Passes.h" +#include "mlir/Dialect/Math/Transforms/Passes.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/Transforms/Passes.h" -#include "mlir/Dialect/SPIRV/Transforms/Passes.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/DialectRegistry.h" #include "mlir/InitAllPasses.h" #include "mlir/Pass/PassManager.h" +#include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/Passes.h" +#include + +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/GPU/Transforms/Passes.h" +#include "mlir/Dialect/SPIRV/Transforms/Passes.h" + +#include +#include + +#include + +#include "gc/Transforms/Passes.h" namespace mlir::gc { struct GPUPipelineOption : PassPipelineOptions { - Option isUsmArgs{ + PassOptions::Option isUsmArgs{ *this, "is-usm-args", - desc("Whether to use USM(unified shared memory) func args, in " - "which the host and device could access the same buffer " - "and there is no need to add memcpy explicitly"), - init(true)}; + llvm::cl::desc("Whether to use USM(unified shared memory) func args, in " + "which the host and device could access the same buffer " + "and there is no need to add memcpy explicitly"), + llvm::cl::init(true)}; }; -void populateGPUPipeline(OpPassManager &pm, +void populateGPUPipeline(mlir::OpPassManager &pm, const GPUPipelineOption &pipelineOption) { - // Add an argument for the GPU context - pm.addNestedPass(createAddContextArg()); - pm.addNestedPass(createIterativeTilingAndFusion()); pm.addPass(bufferization::createEmptyTensorEliminationPass()); @@ -84,7 +91,6 @@ void populateGPUPipeline(OpPassManager &pm, /*isUsmArgs*/ pipelineOption.isUsmArgs.getValue()}; pm.addNestedPass( imex::createInsertGPUAllocsPass(insertGPUAllocsOption)); - pm.addPass(createGpuKernelOutliningPass()); pm.addPass(createCanonicalizerPass()); pm.addPass(imex::createSetSPIRVCapabilitiesPass()); @@ -103,6 +109,7 @@ void populateGPUPipeline(OpPassManager &pm, pm.addNestedPass(LLVM::createRequestCWrappersPass()); pm.addPass(imex::createSerializeSPIRVPass()); pm.addPass(createConvertVectorToSCFPass()); + pm.addPass(imex::createConvertGPUToGPUXPass()); pm.addPass(createConvertSCFToCFPass()); pm.addPass(createConvertControlFlowToLLVMPass()); pm.addPass(createConvertVectorToLLVMPass()); @@ -110,7 +117,7 @@ void populateGPUPipeline(OpPassManager &pm, pm.addPass(createArithToLLVMConversionPass()); pm.addPass(createConvertFuncToLLVMPass()); pm.addPass(createConvertMathToLLVMPass()); - pm.addPass(createGpuToGpuOcl()); + pm.addPass(imex::createConvertGPUXToLLVMPass()); pm.addPass(createConvertIndexToLLVMPass()); pm.addPass(memref::createExpandStridedMetadataPass()); pm.addPass(createLowerAffinePass()); diff --git a/test/mlir/test/gc/gpu-runner/XeGPU/lit.local.cfg b/test/mlir/test/gc/gpu-runner/XeGPU/lit.local.cfg deleted file mode 100644 index 152c26255..000000000 --- a/test/mlir/test/gc/gpu-runner/XeGPU/lit.local.cfg +++ /dev/null @@ -1,2 +0,0 @@ -# GPUX is currently disabled -config.unsupported = True diff --git a/test/mlir/test/gc/gpu-runner/lit.local.cfg b/test/mlir/test/gc/gpu-runner/lit.local.cfg index 5ed13b0d2..f180dd41b 100644 --- a/test/mlir/test/gc/gpu-runner/lit.local.cfg +++ b/test/mlir/test/gc/gpu-runner/lit.local.cfg @@ -1,5 +1,2 @@ if not config.gc_use_imex: - config.unsupported = True -else: - # FIXME: Enable when the GPU runner is implemented. - config.excludes = ['mlp.mlir'] + config.unsupported = True \ No newline at end of file