Lower SLM to XeGPU #409
Conversation
pm.addNestedPass<func::FuncOp>(createLinalgToXeGPU(
    {/*kTile=*/16, /*stages=*/1, /*dpasTiles=*/{8, 16, 16}}));
pm.addPass(createCSEPass());
Added a CSE pass to minimize the impact of the insert/extract manipulations on vectors.
imex::InsertGPUAllocsOptions insertGPUAllocsOption{
    /*clientAPI*/ "opencl", /*inRegions*/ false,
    /*isUsmArgs*/ pipelineOpts.isUsmArgs};
pm.addNestedPass<func::FuncOp>(
    imex::createInsertGPUAllocsPass(insertGPUAllocsOption));
pm.addPass(createGpuKernelOutliningPass());
pm.addPass(createCanonicalizerPass());
The canonicalizer converts vector.from_elements [%val, %val, ... %val] into vector.splat %val, which causes the imex::ConvertGPUXToSPIRVPass to fail (it seems it doesn't support vector.splat). So I removed the canonicalizer.
Looks good, some questions and comments inlined
if (!sgMap) {
  // Assuming default tensor descriptor type (blocked & in global memory).
  return xegpu::TensorDescType::get(shape, elementType, /*array_length=*/1,
                                    /*boundary_check=*/true);
}
sgMap shouldn't have anything to do with the descriptor type.
There are two types of tensor attributes that are called sg_map in the implementation of the XeGPU dialect:
- ScatterTensorDescAttr - for scatter descriptors
- BlockTensorDescAttr - for block descriptors
They describe two kinds of descriptors (the type itself is indeed the same), and the kind depends on sg_map.
sg_map has nothing to do with the tensor descriptor attributes (they are not called sg_map); it is a separate attribute that describes the data chunks accessed by individual threads within a subgroup.
ah, okay :)
renamed sgMap -> descAttr
@@ -150,5 +151,47 @@ std::pair<Value, Value> getPtrAndOffset(OpBuilder &builder, Value operand) {
  return std::make_pair(alignedPointer, offset);
}
Value flattenMemref(PatternRewriter &rewriter, Location loc, Value srcMemref) { |
I think I saw something very similar in LowerQuantOps.cpp. Maybe reuse is possible.
I think I saw something very similar in LowerQuantOps.cpp
We don't have this file in our project. What are you referring to?
Ah, okay, found it in LLVM.
They flatten tensors there, not memrefs
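For intuition, flattening a memref before building scatter descriptors amounts to linearizing indices into a 1D view of the same memory. A minimal Python sketch of that row-major linearization (helper names are illustrative, not from the actual pass):

```python
def flatten_index(row, col, ncols):
    """Linearize a (row, col) index into a flat 1D offset (row-major)."""
    return row * ncols + col

def flatten(buf2d):
    """Flatten a row-major 2D list into 1D, i.e. view the same data as 1D."""
    return [x for row in buf2d for x in row]

buf = [[0, 1, 2], [3, 4, 5]]
flat = flatten(buf)
# The flat view at the linearized index matches the 2D element.
assert flat[flatten_index(1, 2, ncols=3)] == buf[1][2]
```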
assert(llvm::all_of(storeTiles,
                    [&](Value tile) { return tile.getType() == tileType; }) &&
       "All store tiles must have the same type.");
assert(tileType.getShape().size() == 1 && "Scatter tiles must be 1D");
Is this also coming from lowering restrictions?
I would say it's an xegpu limitation: SLM for f16 can only be loaded/stored via 1D scatter descriptors.
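Given the 32-elements-per-load limit mentioned in the PR description, a tile's flat element range has to be covered by fixed-size scatter chunks. A hypothetical sketch of that chunking (the 32-element constant comes from the issue; the helper name is illustrative):

```python
CHUNK = 32  # per-access element limit for SLM scatter loads/stores (from the issue)

def scatter_chunks(num_elements, chunk=CHUNK):
    """Return (offset, length) pairs covering num_elements in <=chunk pieces."""
    return [(off, min(chunk, num_elements - off))
            for off in range(0, num_elements, chunk)]

# A 16x16 f16 tile has 256 elements -> 8 scatter accesses of 32 elements each.
chunks = scatter_chunks(16 * 16)
assert len(chunks) == 8 and all(n == 32 for _, n in chunks)
```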
// Do we need those for SLM?
/*l1_hint=*/hint, /*l2_hint=*/hint, /*l3_hint=*/hint);
not sure, will need to double-check
Well, if nothing crashes with them, I think we can keep them :D
// The shape to be loaded is split into the largest 2D loads supported
// by the hardware.
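The splitting the comment describes can be sketched as tiling a 2D shape into the largest hardware-supported blocks. A minimal Python sketch; the 8x16 maximum is an illustrative assumption, not the pass's actual hardware query:

```python
MAX_TILE = (8, 16)  # assumed max 2D block load shape, for illustration only

def block_tiles(rows, cols, max_tile=MAX_TILE):
    """Yield (row_off, col_off, height, width) tiles covering rows x cols."""
    th, tw = max_tile
    return [(r, c, min(th, rows - r), min(tw, cols - c))
            for r in range(0, rows, th)
            for c in range(0, cols, tw)]

tiles = block_tiles(16, 32)  # 2 x 2 = 4 full 8x16 tiles
assert len(tiles) == 4
```

Edge tiles simply get clamped (`min(...)`), so shapes that aren't multiples of the max tile still get full coverage.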
what happens to, say, 1d tensors?
I don't know. I would assume it will crash in the exact same way as the current linalg-to-xegpu lowering does. Here's an attempt to use the linalg-to-xegpu pass with 1D tensors/memrefs on the current main branch:
gc-opt: /home/jovyan/llvm/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp:83: static void mlir::xegpu::CreateNdDescOp::build(mlir::OpBuilder&, mlir::OperationState&, mlir::Type, mlir::TypedValue<mlir::MemRefType>, llvm::ArrayRef<mlir::OpFoldResult>): Assertion `ty.hasStaticShape() && offsets.size() == (size_t)ty.getRank()' failed.
PLEASE submit a bug report to https://github.com/llvm/llvm-project/issues/ and include the crash backtrace.
Stack dump:
0. Program arguments: ./bin/gc-opt /home/jovyan/graph-compiler/test/mlir/test/gc/Transforms/GPU/linalg-to-xegpu1d.mlir "-linalg-to-xegpu=dpas-tile=8,16,16 k-tile=16" -canonicalize -split-input-file
#0 0x00005571ec59bb30 llvm::sys::PrintStackTrace(llvm::raw_ostream&, int) (./bin/gc-opt+0x589bb30)
#1 0x00005571ec598f3f llvm::sys::RunSignalHandlers() (./bin/gc-opt+0x5898f3f)
#2 0x00005571ec599095 SignalHandler(int) Signals.cpp:0:0
#3 0x00007fd97d43f520 (/lib/x86_64-linux-gnu/libc.so.6+0x42520)
#4 0x00007fd97d4939fc __pthread_kill_implementation ./nptl/pthread_kill.c:44:76
#5 0x00007fd97d4939fc __pthread_kill_internal ./nptl/pthread_kill.c:78:10
#6 0x00007fd97d4939fc pthread_kill ./nptl/pthread_kill.c:89:10
#7 0x00007fd97d43f476 gsignal ./signal/../sysdeps/posix/raise.c:27:6
#8 0x00007fd97d4257f3 abort ./stdlib/abort.c:81:7
#9 0x00007fd97d42571b _nl_load_domain ./intl/loadmsgcat.c:1177:9
#10 0x00007fd97d436e96 (/lib/x86_64-linux-gnu/libc.so.6+0x39e96)
#11 0x00005571e931f149 mlir::xegpu::CreateNdDescOp::build(mlir::OpBuilder&, mlir::OperationState&, mlir::Type, mlir::detail::TypedValue<mlir::MemRefType>, llvm::ArrayRef<mlir::OpFoldResult>) (./bin/gc-opt+0x261f149)
#12 0x00005571e9a4de94 mlir::xegpu::CreateNdDescOp mlir::OpBuilder::create<mlir::xegpu::CreateNdDescOp, mlir::xegpu::TensorDescType&, mlir::detail::TypedValue<mlir::MemRefType>, llvm::SmallVector<mlir::OpFoldResult, 6u>&>(mlir::Location, mlir::xegpu::TensorDescType&, mlir::detail::TypedValue<mlir::MemRefType>&&, llvm::SmallVector<mlir::OpFoldResult, 6u>&) /home/jovyan/llvm/llvm-install-imex-17_oct/include/mlir/IR/Builders.h:517:22
#13 0x00005571e9a2af4e (anonymous namespace)::createDescriptorTiles(mlir::PatternRewriter&, mlir::Location, mlir::Value, llvm::ArrayRef<long>, llvm::ArrayRef<long>, llvm::ArrayRef<long>, int, bool) /home/jovyan/graph-compiler/lib/gc/Transforms/GPU/LinalgToXeGPU.cpp:661:0
#14 0x00005571e9a2b585 (anonymous namespace)::createCoarseDscTiles(mlir::PatternRewriter&, mlir::Location, mlir::Value, llvm::ArrayRef<long>, bool, bool) /home/jovyan/graph-compiler/lib/gc/Transforms/GPU/LinalgToXeGPU.cpp:735:0
#15 0x00005571e9a2fa31 (anonymous namespace)::createEltwiseKernel(mlir::linalg::LinalgOp, mlir::PatternRewriter&) /home/jovyan/graph-compiler/lib/gc/Transforms/GPU/LinalgToXeGPU.cpp:1321:0
#16 0x00005571e9a42f80 (anonymous namespace)::ConvertNamedEltwiseToXeGPU<mlir::linalg::AddOp>::matchAndRewrite(mlir::linalg::AddOp, mlir::PatternRewriter&) const /home/jovyan/graph-compiler/lib/gc/Transforms/GPU/LinalgToXeGPU.cpp:1494:0
#17 0x00005571e9a6590c mlir::detail::OpOrInterfaceRewritePatternBase<mlir::linalg::AddOp>::matchAndRewrite(mlir::Operation*, mlir::PatternRewriter&) const /home/jovyan/llvm/llvm-install-imex-17_oct/include/mlir/IR/PatternMatch.h:332:3
#18 0x00005571ec0c6bc8 mlir::PatternApplicator::matchAndRewrite(mlir::Operation*, mlir::PatternRewriter&, llvm::function_ref<bool (mlir::Pattern const&)>, llvm::function_ref<void (mlir::Pattern const&)>, llvm::function_ref<llvm::LogicalResult (mlir::Pattern const&)>) (./bin/gc-opt+0x53c6bc8)
#19 0x00005571ec08f3de (anonymous namespace)::GreedyPatternRewriteDriver::processWorklist() GreedyPatternRewriteDriver.cpp:0:0
#20 0x00005571ec091be5 mlir::applyPatternsAndFoldGreedily(mlir::Region&, mlir::FrozenRewritePatternSet const&, mlir::GreedyRewriteConfig, bool*) (./bin/gc-opt+0x5391be5)
#21 0x00005571e9915394 mlir::applyPatternsAndFoldGreedily(mlir::Operation*, mlir::FrozenRewritePatternSet const&, mlir::GreedyRewriteConfig, bool*) /home/jovyan/llvm/llvm-install-imex-17_oct/include/mlir/Transforms/GreedyPatternRewriteDriver.h:159:37
#22 0x00005571e9a30f6e (anonymous namespace)::LinalgToXeGPU::runOnOperation() /home/jovyan/graph-compiler/lib/gc/Transforms/GPU/LinalgToXeGPU.cpp:1649:0
#23 0x00005571ec1c3479 mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int) (./bin/gc-opt+0x54c3479)
#24 0x00005571ec1c3931 mlir::detail::OpToOpPassAdaptor::runPipeline(mlir::OpPassManager&, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int, mlir::PassInstrumentor*, mlir::PassInstrumentation::PipelineParentInfo const*) (./bin/gc-opt+0x54c3931)
#25 0x00005571ec1c3cd6 mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::'lambda'(mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo&)::operator()(mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo&) const Pass.cpp:0:0
#26 0x00005571ec1c29a5 mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool) (./bin/gc-opt+0x54c29a5)
#27 0x00005571ec1c3280 mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int) (./bin/gc-opt+0x54c3280)
#28 0x00005571ec1c3931 mlir::detail::OpToOpPassAdaptor::runPipeline(mlir::OpPassManager&, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int, mlir::PassInstrumentor*, mlir::PassInstrumentation::PipelineParentInfo const*) (./bin/gc-opt+0x54c3931)
#29 0x00005571ec1c4995 mlir::PassManager::run(mlir::Operation*) (./bin/gc-opt+0x54c4995)
#30 0x00005571e98c5217 performActions(llvm::raw_ostream&, std::shared_ptr<llvm::SourceMgr> const&, mlir::MLIRContext*, mlir::MlirOptMainConfig const&) MlirOptMain.cpp:0:0
#31 0x00005571e98c5c2c processBuffer(llvm::raw_ostream&, std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer>>, mlir::MlirOptMainConfig const&, mlir::DialectRegistry&, llvm::ThreadPoolInterface*) MlirOptMain.cpp:0:0
#32 0x00005571e98c5d8d llvm::LogicalResult llvm::function_ref<llvm::LogicalResult (std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer>>, llvm::raw_ostream&)>::callback_fn<mlir::MlirOptMain(llvm::raw_ostream&, std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer>>, mlir::DialectRegistry&, mlir::MlirOptMainConfig const&)::'lambda'(std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer>>, llvm::raw_ostream&)>(long, std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer>>, llvm::raw_ostream&) MlirOptMain.cpp:0:0
#33 0x00005571ec467b1f mlir::splitAndProcessBuffer(std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer>>, llvm::function_ref<llvm::LogicalResult (std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer>>, llvm::raw_ostream&)>, llvm::raw_ostream&, llvm::StringRef, llvm::StringRef)::'lambda'(llvm::StringRef)::operator()(llvm::StringRef) const ToolUtilities.cpp:0:0
#34 0x00005571ec468472 mlir::splitAndProcessBuffer(std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer>>, llvm::function_ref<llvm::LogicalResult (std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer>>, llvm::raw_ostream&)>, llvm::raw_ostream&, llvm::StringRef, llvm::StringRef) (./bin/gc-opt+0x5768472)
#35 0x00005571e98bd56c mlir::MlirOptMain(llvm::raw_ostream&, std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer>>, mlir::DialectRegistry&, mlir::MlirOptMainConfig const&) (./bin/gc-opt+0x2bbd56c)
#36 0x00005571e98c5ef0 mlir::MlirOptMain(int, char**, llvm::StringRef, llvm::StringRef, mlir::DialectRegistry&) (./bin/gc-opt+0x2bc5ef0)
#37 0x00005571e98c6417 mlir::MlirOptMain(int, char**, llvm::StringRef, mlir::DialectRegistry&) (./bin/gc-opt+0x2bc6417)
#38 0x00005571e70d410c main /home/jovyan/graph-compiler/src/gc-opt/gc-opt.cpp:75:0
#39 0x00007fd97d426d90 __libc_start_call_main ./csu/../sysdeps/nptl/libc_start_call_main.h:58:16
#40 0x00007fd97d426e40 call_init ./csu/../csu/libc-start.c:128:20
#41 0x00007fd97d426e40 __libc_start_main ./csu/../csu/libc-start.c:379:5
#42 0x00005571e70d3cf5 _start (./bin/gc-opt+0x3d3cf5)
Aborted (core dumped)
(1D is not supported by linalg-to-xegpu)
Fixes #394. The PR adds SLM support to the linalg-to-xegpu pass.

SLM requires special handling in XeGPU: we can only load it via scatter descriptors, and only 32 elements per load (more info in the issue). The flow of working with SLM in XeGPU is the following:

Creating descriptors
1. Flatten the SLM memref with memref.reinterpret_cast, since scatter descriptors only work with 1D memrefs.
2. Since the imex::ConvertGPUXToSPIRV pass doesn't allow memref.subviews inside a GPU kernel, we have to merge subview offsets with the offsets for the root xegpu descriptor. So step 2 is basically to compute offsets for the beginning of the SLM block for this thread.

Do we merge `subview` offsets for block descriptors as well?

Yes. There's a separate pass upstream (XeGPUFoldAliasOps) that does it. It only works with block descriptors though, meaning that for scattered ones we have to implement the logic on our own.

MLIR example
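The offset merging described above can be sketched as folding a 2D subview origin into the flat 1D descriptor offset, assuming a row-major parent that has already been flattened (names and numbers are illustrative):

```python
def merged_flat_offset(base_off, subview_off, parent_cols):
    """Fold a 2D subview origin (row, col) into a flat 1D descriptor offset."""
    row, col = subview_off
    return base_off + row * parent_cols + col

# This thread's SLM block starts 64 elements into the flat buffer; the subview
# begins at (2, 4) of a parent with 16 columns -> flat offset 64 + 2*16 + 4.
assert merged_flat_offset(64, (2, 4), 16) == 100
```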
Loading data
MLIR example
Storing data
MLIR example
As you can see, a lot of effort is required to load/store tiles from SLM. Even loading/storing a single 16x16 block requires 8 loads + 8 vector.insert ops + 8 stores + 8 vector.extract_strided_slice ops. It seems this won't perform very well, and we should avoid using SLM where possible (through op fusion, for example).
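The op counts above follow directly from the 32-element access limit. A small sketch of that arithmetic (a back-of-the-envelope model, not the pass's actual cost logic):

```python
def slm_tile_op_count(rows, cols, elems_per_access=32):
    """Ops needed to round-trip a tile through SLM via 1D scatter accesses."""
    n_accesses = (rows * cols) // elems_per_access
    return {
        "loads": n_accesses,
        "vector.insert": n_accesses,                  # assemble loaded 1D chunks
        "stores": n_accesses,
        "vector.extract_strided_slice": n_accesses,   # slice chunks to store
    }

# A 16x16 tile: 256 elements / 32 per access -> 8 of each op.
counts = slm_tile_op_count(16, 16)
assert all(v == 8 for v in counts.values())
```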