@@ -1024,6 +1024,71 @@ class ReturnValuePattern : public SPIRVToLLVMConversion<spirv::ReturnValueOp> {
1024
1024
}
1025
1025
};
1026
1026
1027
+ static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn (Operation *symbolTable,
1028
+ StringRef name,
1029
+ ArrayRef<Type> paramTypes,
1030
+ Type resultType) {
1031
+ auto func = dyn_cast_or_null<LLVM::LLVMFuncOp>(
1032
+ SymbolTable::lookupSymbolIn (symbolTable, name));
1033
+ if (func)
1034
+ return func;
1035
+
1036
+ OpBuilder b (symbolTable->getRegion (0 ));
1037
+ func = b.create <LLVM::LLVMFuncOp>(
1038
+ symbolTable->getLoc (), name,
1039
+ LLVM::LLVMFunctionType::get (resultType, paramTypes));
1040
+ func.setCConv (LLVM::cconv::CConv::SPIR_FUNC);
1041
+ func.setConvergent (true );
1042
+ func.setNoUnwind (true );
1043
+ func.setWillReturn (true );
1044
+ return func;
1045
+ }
1046
+
1047
+ static LLVM::CallOp createSPIRVBuiltinCall (Location loc, OpBuilder &builder,
1048
+ LLVM::LLVMFuncOp func,
1049
+ ValueRange args) {
1050
+ auto call = builder.create <LLVM::CallOp>(loc, func, args);
1051
+ call.setCConv (func.getCConv ());
1052
+ call.setConvergentAttr (func.getConvergentAttr ());
1053
+ call.setNoUnwindAttr (func.getNoUnwindAttr ());
1054
+ call.setWillReturnAttr (func.getWillReturnAttr ());
1055
+ return call;
1056
+ }
1057
+
1058
+ class ControlBarrierPattern
1059
+ : public SPIRVToLLVMConversion<spirv::ControlBarrierOp> {
1060
+ public:
1061
+ using SPIRVToLLVMConversion<spirv::ControlBarrierOp>::SPIRVToLLVMConversion;
1062
+
1063
+ LogicalResult
1064
+ matchAndRewrite (spirv::ControlBarrierOp controlBarrierOp, OpAdaptor adaptor,
1065
+ ConversionPatternRewriter &rewriter) const override {
1066
+ constexpr StringLiteral funcName = " _Z22__spirv_ControlBarrieriii" ;
1067
+ Operation *symbolTable =
1068
+ controlBarrierOp->getParentWithTrait <OpTrait::SymbolTable>();
1069
+
1070
+ Type i32 = rewriter.getI32Type ();
1071
+
1072
+ Type voidTy = rewriter.getType <LLVM::LLVMVoidType>();
1073
+ LLVM::LLVMFuncOp func =
1074
+ lookupOrCreateSPIRVFn (symbolTable, funcName, {i32, i32, i32}, voidTy);
1075
+
1076
+ Location loc = controlBarrierOp->getLoc ();
1077
+ Value execution = rewriter.create <LLVM::ConstantOp>(
1078
+ loc, i32, static_cast <int32_t >(adaptor.getExecutionScope ()));
1079
+ Value memory = rewriter.create <LLVM::ConstantOp>(
1080
+ loc, i32, static_cast <int32_t >(adaptor.getMemoryScope ()));
1081
+ Value semantics = rewriter.create <LLVM::ConstantOp>(
1082
+ loc, i32, static_cast <int32_t >(adaptor.getMemorySemantics ()));
1083
+
1084
+ auto call = createSPIRVBuiltinCall (loc, rewriter, func,
1085
+ {execution, memory, semantics});
1086
+
1087
+ rewriter.replaceOp (controlBarrierOp, call);
1088
+ return success ();
1089
+ }
1090
+ };
1091
+
1027
1092
// / Converts `spirv.mlir.loop` to LLVM dialect. All blocks within selection
1028
1093
// / should be reachable for conversion to succeed. The structure of the loop in
1029
1094
// / LLVM dialect will be the following:
@@ -1648,7 +1713,10 @@ void mlir::populateSPIRVToLLVMConversionPatterns(
1648
1713
ShiftPattern<spirv::ShiftLeftLogicalOp, LLVM::ShlOp>,
1649
1714
1650
1715
// Return ops
1651
- ReturnPattern, ReturnValuePattern>(patterns.getContext (), typeConverter);
1716
+ ReturnPattern, ReturnValuePattern,
1717
+
1718
+ // Barrier ops
1719
+ ControlBarrierPattern>(patterns.getContext (), typeConverter);
1652
1720
1653
1721
patterns.add <GlobalVariablePattern>(clientAPI, patterns.getContext (),
1654
1722
typeConverter);
0 commit comments