@@ -25,29 +25,80 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
   Location loc = gpuFuncOp.getLoc();

   SmallVector<LLVM::GlobalOp, 3> workgroupBuffers;
-  workgroupBuffers.reserve(gpuFuncOp.getNumWorkgroupAttributions());
-  for (const auto [idx, attribution] :
-       llvm::enumerate(gpuFuncOp.getWorkgroupAttributions())) {
-    auto type = dyn_cast<MemRefType>(attribution.getType());
-    assert(type && type.hasStaticShape() && "unexpected type in attribution");
-
-    uint64_t numElements = type.getNumElements();
-
-    auto elementType =
-        cast<Type>(typeConverter->convertType(type.getElementType()));
-    auto arrayType = LLVM::LLVMArrayType::get(elementType, numElements);
-    std::string name =
-        std::string(llvm::formatv("__wg_{0}_{1}", gpuFuncOp.getName(), idx));
-    uint64_t alignment = 0;
-    if (auto alignAttr =
-            dyn_cast_or_null<IntegerAttr>(gpuFuncOp.getWorkgroupAttributionAttr(
-                idx, LLVM::LLVMDialect::getAlignAttrName())))
-      alignment = alignAttr.getInt();
-    auto globalOp = rewriter.create<LLVM::GlobalOp>(
-        gpuFuncOp.getLoc(), arrayType, /*isConstant=*/false,
-        LLVM::Linkage::Internal, name, /*value=*/Attribute(), alignment,
-        workgroupAddrSpace);
-    workgroupBuffers.push_back(globalOp);
+  if (encodeWorkgroupAttributionsAsArguments) {
+    // Append an `llvm.ptr` argument to the function signature to encode
+    // workgroup attributions.
+
+    ArrayRef<BlockArgument> workgroupAttributions =
+        gpuFuncOp.getWorkgroupAttributions();
+    size_t numAttributions = workgroupAttributions.size();
+
+    // Insert all arguments at the end.
+    unsigned index = gpuFuncOp.getNumArguments();
+    SmallVector<unsigned> argIndices(numAttributions, index);
+
+    // New arguments will simply be `llvm.ptr` with the correct address space.
+    Type workgroupPtrType =
+        rewriter.getType<LLVM::LLVMPointerType>(workgroupAddrSpace);
+    SmallVector<Type> argTypes(numAttributions, workgroupPtrType);
+
+    // Attributes: noalias, llvm.mlir.workgroup_attribution(<size>, <type>)
+    std::array attrs{
+        rewriter.getNamedAttr(LLVM::LLVMDialect::getNoAliasAttrName(),
+                              rewriter.getUnitAttr()),
+        rewriter.getNamedAttr(
+            getDialect().getWorkgroupAttributionAttrHelper().getName(),
+            rewriter.getUnitAttr()),
+    };
+    SmallVector<DictionaryAttr> argAttrs;
+    for (BlockArgument attribution : workgroupAttributions) {
+      auto attributionType = cast<MemRefType>(attribution.getType());
+      IntegerAttr numElements =
+          rewriter.getI64IntegerAttr(attributionType.getNumElements());
+      Type llvmElementType =
+          getTypeConverter()->convertType(attributionType.getElementType());
+      if (!llvmElementType)
+        return failure();
+      TypeAttr type = TypeAttr::get(llvmElementType);
+      attrs.back().setValue(
+          rewriter.getAttr<LLVM::WorkgroupAttributionAttr>(numElements, type));
+      argAttrs.push_back(rewriter.getDictionaryAttr(attrs));
+    }
+
+    // Locations match the function location.
+    SmallVector<Location> argLocs(numAttributions, gpuFuncOp.getLoc());
+
+    // Perform the signature modification.
+    rewriter.modifyOpInPlace(
+        gpuFuncOp, [gpuFuncOp, &argIndices, &argTypes, &argAttrs, &argLocs]() {
+          static_cast<FunctionOpInterface>(gpuFuncOp).insertArguments(
+              argIndices, argTypes, argAttrs, argLocs);
+        });
+  } else {
+    workgroupBuffers.reserve(gpuFuncOp.getNumWorkgroupAttributions());
+    for (auto [idx, attribution] :
+         llvm::enumerate(gpuFuncOp.getWorkgroupAttributions())) {
+      auto type = dyn_cast<MemRefType>(attribution.getType());
+      assert(type && type.hasStaticShape() && "unexpected type in attribution");
+
+      uint64_t numElements = type.getNumElements();
+
+      auto elementType =
+          cast<Type>(typeConverter->convertType(type.getElementType()));
+      auto arrayType = LLVM::LLVMArrayType::get(elementType, numElements);
+      std::string name =
+          std::string(llvm::formatv("__wg_{0}_{1}", gpuFuncOp.getName(), idx));
+      uint64_t alignment = 0;
+      if (auto alignAttr = dyn_cast_or_null<IntegerAttr>(
+              gpuFuncOp.getWorkgroupAttributionAttr(
+                  idx, LLVM::LLVMDialect::getAlignAttrName())))
+        alignment = alignAttr.getInt();
+      auto globalOp = rewriter.create<LLVM::GlobalOp>(
+          gpuFuncOp.getLoc(), arrayType, /*isConstant=*/false,
+          LLVM::Linkage::Internal, name, /*value=*/Attribute(), alignment,
+          workgroupAddrSpace);
+      workgroupBuffers.push_back(globalOp);
+    }
   }

   // Remap proper input types.
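Aside (not part of the patch): the loop above pairs each attribution's total element count with its converted element type inside an LLVM::WorkgroupAttributionAttr, which then rides on the appended `llvm.ptr` argument next to a unit `llvm.noalias`. A minimal, hedged sketch of that construction in isolation, assuming only standard MLIR headers; the helper name and the sample memref<32x4xf32> type are invented for illustration.

#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"

using namespace mlir;

// Hypothetical helper, not from the patch: builds the attribute payload the
// same way the lowering does, for a fixed sample attribution type.
static LLVM::WorkgroupAttributionAttr
buildSampleWorkgroupAttributionAttr(MLIRContext &context) {
  context.loadDialect<LLVM::LLVMDialect>();
  OpBuilder builder(&context);

  // Stand-in for a static workgroup attribution, e.g. memref<32x4xf32>.
  auto attributionType = MemRefType::get({32, 4}, builder.getF32Type());

  // Same two pieces the pattern feeds to the attribute: the total element
  // count (32 * 4 = 128) and the converted element type (f32 maps to itself).
  IntegerAttr numElements =
      builder.getI64IntegerAttr(attributionType.getNumElements());
  TypeAttr elementType = TypeAttr::get(builder.getF32Type());
  return builder.getAttr<LLVM::WorkgroupAttributionAttr>(numElements,
                                                         elementType);
}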
@@ -101,16 +152,19 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
   // attribute. The former is necessary for further translation while the
   // latter is expected by gpu.launch_func.
   if (gpuFuncOp.isKernel()) {
-    attributes.emplace_back(kernelAttributeName, rewriter.getUnitAttr());
+    if (kernelAttributeName)
+      attributes.emplace_back(kernelAttributeName, rewriter.getUnitAttr());
     // Set the dialect-specific block size attribute if there is one.
-    if (kernelBlockSizeAttributeName.has_value() && knownBlockSize) {
-      attributes.emplace_back(kernelBlockSizeAttributeName.value(),
-                              knownBlockSize);
+    if (kernelBlockSizeAttributeName && knownBlockSize) {
+      attributes.emplace_back(kernelBlockSizeAttributeName, knownBlockSize);
     }
   }
+  LLVM::CConv callingConvention = gpuFuncOp.isKernel()
+                                      ? kernelCallingConvention
+                                      : nonKernelCallingConvention;
   auto llvmFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
       gpuFuncOp.getLoc(), gpuFuncOp.getName(), funcType,
-      LLVM::Linkage::External, /*dsoLocal=*/false, /*cconv=*/LLVM::CConv::C,
+      LLVM::Linkage::External, /*dsoLocal=*/false, callingConvention,
       /*comdat=*/nullptr, attributes);

   {
@@ -125,24 +179,51 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
     rewriter.setInsertionPointToStart(&gpuFuncOp.front());
     unsigned numProperArguments = gpuFuncOp.getNumArguments();

-    for (const auto [idx, global] : llvm::enumerate(workgroupBuffers)) {
-      auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext(),
-                                                global.getAddrSpace());
-      Value address = rewriter.create<LLVM::AddressOfOp>(
-          loc, ptrType, global.getSymNameAttr());
-      Value memory =
-          rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getType(), address,
-                                       ArrayRef<LLVM::GEPArg>{0, 0});
-
-      // Build a memref descriptor pointing to the buffer to plug with the
-      // existing memref infrastructure. This may use more registers than
-      // otherwise necessary given that memref sizes are fixed, but we can try
-      // and canonicalize that away later.
-      Value attribution = gpuFuncOp.getWorkgroupAttributions()[idx];
-      auto type = cast<MemRefType>(attribution.getType());
-      auto descr = MemRefDescriptor::fromStaticShape(
-          rewriter, loc, *getTypeConverter(), type, memory);
-      signatureConversion.remapInput(numProperArguments + idx, descr);
+    if (encodeWorkgroupAttributionsAsArguments) {
+      // Build a MemRefDescriptor with each of the arguments added above.
+
+      unsigned numAttributions = gpuFuncOp.getNumWorkgroupAttributions();
+      assert(numProperArguments >= numAttributions &&
+             "Expecting attributions to be encoded as arguments already");
+
+      // Arguments encoding workgroup attributions will be in positions
+      // [numProperArguments, numProperArguments+numAttributions)
+      ArrayRef<BlockArgument> attributionArguments =
+          gpuFuncOp.getArguments().slice(numProperArguments - numAttributions,
+                                         numAttributions);
+      for (auto [idx, vals] : llvm::enumerate(llvm::zip_equal(
+               gpuFuncOp.getWorkgroupAttributions(), attributionArguments))) {
+        auto [attribution, arg] = vals;
+        auto type = cast<MemRefType>(attribution.getType());
+
+        // Arguments are of llvm.ptr type and attributions are of memref type:
+        // we need to wrap them in memref descriptors.
+        Value descr = MemRefDescriptor::fromStaticShape(
+            rewriter, loc, *getTypeConverter(), type, arg);
+
+        // And remap the arguments.
+        signatureConversion.remapInput(numProperArguments + idx, descr);
+      }
+    } else {
+      for (const auto [idx, global] : llvm::enumerate(workgroupBuffers)) {
+        auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext(),
+                                                  global.getAddrSpace());
+        Value address = rewriter.create<LLVM::AddressOfOp>(
+            loc, ptrType, global.getSymNameAttr());
+        Value memory =
+            rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getType(),
+                                         address, ArrayRef<LLVM::GEPArg>{0, 0});
+
+        // Build a memref descriptor pointing to the buffer to plug with the
+        // existing memref infrastructure. This may use more registers than
+        // otherwise necessary given that memref sizes are fixed, but we can try
+        // and canonicalize that away later.
+        Value attribution = gpuFuncOp.getWorkgroupAttributions()[idx];
+        auto type = cast<MemRefType>(attribution.getType());
+        auto descr = MemRefDescriptor::fromStaticShape(
+            rewriter, loc, *getTypeConverter(), type, memory);
+        signatureConversion.remapInput(numProperArguments + idx, descr);
+      }
     }

     // Rewrite private memory attributions to alloca'ed buffers.
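Aside (not part of the patch): both branches above end in MemRefDescriptor::fromStaticShape, which turns a bare pointer into the descriptor struct the rest of the memref lowering consumes. A hedged, self-contained sketch of that call in isolation; the helper name and the sample memref<64xf32> type are local to this example.

#include "mlir/Conversion/LLVMCommon/MemRefBuilder.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

// Hypothetical helper, not from the patch: wraps a raw pointer (assumed to be
// an !llvm.ptr in the memref's address space) in a descriptor for a memref
// whose shape is fully static, mirroring the remapping done above.
static Value wrapPointerAsStaticMemRef(OpBuilder &builder, Location loc,
                                       const LLVMTypeConverter &converter,
                                       Value rawPtr) {
  // Sample static type; sizes and strides are compile-time constants, so the
  // descriptor only needs the one pointer (used as allocated and aligned).
  auto memrefType = MemRefType::get({64}, builder.getF32Type());
  MemRefDescriptor descriptor = MemRefDescriptor::fromStaticShape(
      builder, loc, converter, memrefType, rawPtr);
  // The descriptor converts implicitly to the packed struct Value, which is
  // what signatureConversion.remapInput(...) consumes in the pattern above.
  return descriptor;
}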
@@ -239,6 +320,8 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
       copyPointerAttribute(LLVM::LLVMDialect::getDereferenceableAttrName());
       copyPointerAttribute(
           LLVM::LLVMDialect::getDereferenceableOrNullAttrName());
+      copyPointerAttribute(
+          LLVM::LLVMDialect::WorkgroupAttributionAttrHelper::getNameStr());
     }
   }
   rewriter.eraseOp(gpuFuncOp);