Skip to content

Commit 1775b98

Browse files
authored
[mlir][spirv] Add spirv-to-llvm conversion for OpControlBarrier (#111864)
The conversion is based on the expected llvm function from the LLVM/SPIRV translation tool.
1 parent 1bbf3a3 commit 1775b98

File tree

4 files changed

+94
-3
lines changed

4 files changed

+94
-3
lines changed

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBarrierOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def SPIRV_ControlBarrierOp : SPIRV_Op<"ControlBarrier", []> {
5454
#### Example:
5555

5656
```mlir
57-
spirv.ControlBarrier "Workgroup", "Device", "Acquire|UniformMemory"
57+
spirv.ControlBarrier <Workgroup>, <Device>, <Acquire|UniformMemory>
5858
```
5959
}];
6060

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMiscOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//===-- SPIRVBarrierOps.td - MLIR SPIR-V Barrier Ops -------*- tablegen -*-===//
1+
//===-- SPIRVMiscOps.td - MLIR SPIR-V Misc Ops -------------*- tablegen -*-===//
22
//
33
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.

mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1024,6 +1024,71 @@ class ReturnValuePattern : public SPIRVToLLVMConversion<spirv::ReturnValueOp> {
10241024
}
10251025
};
10261026

1027+
static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable,
1028+
StringRef name,
1029+
ArrayRef<Type> paramTypes,
1030+
Type resultType) {
1031+
auto func = dyn_cast_or_null<LLVM::LLVMFuncOp>(
1032+
SymbolTable::lookupSymbolIn(symbolTable, name));
1033+
if (func)
1034+
return func;
1035+
1036+
OpBuilder b(symbolTable->getRegion(0));
1037+
func = b.create<LLVM::LLVMFuncOp>(
1038+
symbolTable->getLoc(), name,
1039+
LLVM::LLVMFunctionType::get(resultType, paramTypes));
1040+
func.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
1041+
func.setConvergent(true);
1042+
func.setNoUnwind(true);
1043+
func.setWillReturn(true);
1044+
return func;
1045+
}
1046+
1047+
static LLVM::CallOp createSPIRVBuiltinCall(Location loc, OpBuilder &builder,
1048+
LLVM::LLVMFuncOp func,
1049+
ValueRange args) {
1050+
auto call = builder.create<LLVM::CallOp>(loc, func, args);
1051+
call.setCConv(func.getCConv());
1052+
call.setConvergentAttr(func.getConvergentAttr());
1053+
call.setNoUnwindAttr(func.getNoUnwindAttr());
1054+
call.setWillReturnAttr(func.getWillReturnAttr());
1055+
return call;
1056+
}
1057+
1058+
class ControlBarrierPattern
1059+
: public SPIRVToLLVMConversion<spirv::ControlBarrierOp> {
1060+
public:
1061+
using SPIRVToLLVMConversion<spirv::ControlBarrierOp>::SPIRVToLLVMConversion;
1062+
1063+
LogicalResult
1064+
matchAndRewrite(spirv::ControlBarrierOp controlBarrierOp, OpAdaptor adaptor,
1065+
ConversionPatternRewriter &rewriter) const override {
1066+
constexpr StringLiteral funcName = "_Z22__spirv_ControlBarrieriii";
1067+
Operation *symbolTable =
1068+
controlBarrierOp->getParentWithTrait<OpTrait::SymbolTable>();
1069+
1070+
Type i32 = rewriter.getI32Type();
1071+
1072+
Type voidTy = rewriter.getType<LLVM::LLVMVoidType>();
1073+
LLVM::LLVMFuncOp func =
1074+
lookupOrCreateSPIRVFn(symbolTable, funcName, {i32, i32, i32}, voidTy);
1075+
1076+
Location loc = controlBarrierOp->getLoc();
1077+
Value execution = rewriter.create<LLVM::ConstantOp>(
1078+
loc, i32, static_cast<int32_t>(adaptor.getExecutionScope()));
1079+
Value memory = rewriter.create<LLVM::ConstantOp>(
1080+
loc, i32, static_cast<int32_t>(adaptor.getMemoryScope()));
1081+
Value semantics = rewriter.create<LLVM::ConstantOp>(
1082+
loc, i32, static_cast<int32_t>(adaptor.getMemorySemantics()));
1083+
1084+
auto call = createSPIRVBuiltinCall(loc, rewriter, func,
1085+
{execution, memory, semantics});
1086+
1087+
rewriter.replaceOp(controlBarrierOp, call);
1088+
return success();
1089+
}
1090+
};
1091+
10271092
/// Converts `spirv.mlir.loop` to LLVM dialect. All blocks within selection
10281093
/// should be reachable for conversion to succeed. The structure of the loop in
10291094
/// LLVM dialect will be the following:
@@ -1648,7 +1713,10 @@ void mlir::populateSPIRVToLLVMConversionPatterns(
16481713
ShiftPattern<spirv::ShiftLeftLogicalOp, LLVM::ShlOp>,
16491714

16501715
// Return ops
1651-
ReturnPattern, ReturnValuePattern>(patterns.getContext(), typeConverter);
1716+
ReturnPattern, ReturnValuePattern,
1717+
1718+
// Barrier ops
1719+
ControlBarrierPattern>(patterns.getContext(), typeConverter);
16521720

16531721
patterns.add<GlobalVariablePattern>(clientAPI, patterns.getContext(),
16541722
typeConverter);
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
// RUN: mlir-opt -convert-spirv-to-llvm %s | FileCheck %s
2+
3+
//===----------------------------------------------------------------------===//
4+
// spirv.ControlBarrierOp
5+
//===----------------------------------------------------------------------===//
6+
7+
// CHECK: llvm.func spir_funccc @_Z22__spirv_ControlBarrieriii(i32, i32, i32) attributes {convergent, no_unwind, will_return}
8+
9+
// CHECK-LABEL: @control_barrier
10+
spirv.func @control_barrier() "None" {
11+
// CHECK: [[EXECUTION:%.*]] = llvm.mlir.constant(2 : i32) : i32
12+
// CHECK: [[MEMORY:%.*]] = llvm.mlir.constant(2 : i32) : i32
13+
// CHECK: [[SEMANTICS:%.*]] = llvm.mlir.constant(768 : i32) : i32
14+
// CHECK: llvm.call spir_funccc @_Z22__spirv_ControlBarrieriii([[EXECUTION]], [[MEMORY]], [[SEMANTICS]]) {convergent, no_unwind, will_return} : (i32, i32, i32) -> ()
15+
spirv.ControlBarrier <Workgroup>, <Workgroup>, <CrossWorkgroupMemory|WorkgroupMemory>
16+
17+
// CHECK: [[EXECUTION:%.*]] = llvm.mlir.constant(2 : i32) : i32
18+
// CHECK: [[MEMORY:%.*]] = llvm.mlir.constant(2 : i32) : i32
19+
// CHECK: [[SEMANTICS:%.*]] = llvm.mlir.constant(256 : i32) : i32
20+
// CHECK: llvm.call spir_funccc @_Z22__spirv_ControlBarrieriii([[EXECUTION]], [[MEMORY]], [[SEMANTICS]]) {convergent, no_unwind, will_return} : (i32, i32, i32) -> ()
21+
spirv.ControlBarrier <Workgroup>, <Workgroup>, <WorkgroupMemory>
22+
spirv.Return
23+
}

0 commit comments

Comments
 (0)