@@ -348,6 +348,19 @@ void lowerAnnotationValue(
348
348
}
349
349
}
350
350
351
+ // Get addrspace by converting a pointer type.
352
+ // TODO: The approach here is a little hacky. We should access the target info
353
+ // directly to convert the address space of global op, similar to what we do
354
+ // for type converter.
355
+ unsigned getGlobalOpTargetAddrSpace (mlir::ConversionPatternRewriter &rewriter,
356
+ const mlir::TypeConverter *converter,
357
+ mlir::cir::GlobalOp op) {
358
+ auto tempPtrTy = mlir::cir::PointerType::get (
359
+ rewriter.getContext (), op.getSymType (), op.getAddrSpaceAttr ());
360
+ return cast<mlir::LLVM::LLVMPointerType>(converter->convertType (tempPtrTy))
361
+ .getAddressSpace ();
362
+ }
363
+
351
364
} // namespace
352
365
353
366
// ===----------------------------------------------------------------------===//
@@ -571,28 +584,36 @@ mlir::Value lowerCirAttrAsValue(mlir::Operation *parentOp,
571
584
const mlir::TypeConverter *converter) {
572
585
auto module = parentOp->getParentOfType <mlir::ModuleOp>();
573
586
mlir::Type sourceType;
587
+ unsigned sourceAddrSpace = 0 ;
574
588
llvm::StringRef symName;
575
589
auto *sourceSymbol =
576
590
mlir::SymbolTable::lookupSymbolIn (module , globalAttr.getSymbol ());
577
591
if (auto llvmSymbol = dyn_cast<mlir::LLVM::GlobalOp>(sourceSymbol)) {
578
592
sourceType = llvmSymbol.getType ();
579
593
symName = llvmSymbol.getSymName ();
594
+ sourceAddrSpace = llvmSymbol.getAddrSpace ();
580
595
} else if (auto cirSymbol = dyn_cast<mlir::cir::GlobalOp>(sourceSymbol)) {
581
596
sourceType = converter->convertType (cirSymbol.getSymType ());
582
597
symName = cirSymbol.getSymName ();
598
+ sourceAddrSpace =
599
+ getGlobalOpTargetAddrSpace (rewriter, converter, cirSymbol);
583
600
} else if (auto llvmFun = dyn_cast<mlir::LLVM::LLVMFuncOp>(sourceSymbol)) {
584
601
sourceType = llvmFun.getFunctionType ();
585
602
symName = llvmFun.getSymName ();
603
+ sourceAddrSpace = 0 ;
586
604
} else if (auto fun = dyn_cast<mlir::cir::FuncOp>(sourceSymbol)) {
587
605
sourceType = converter->convertType (fun.getFunctionType ());
588
606
symName = fun.getSymName ();
607
+ sourceAddrSpace = 0 ;
589
608
} else {
590
609
llvm_unreachable (" Unexpected GlobalOp type" );
591
610
}
592
611
593
612
auto loc = parentOp->getLoc ();
594
613
mlir::Value addrOp = rewriter.create <mlir::LLVM::AddressOfOp>(
595
- loc, mlir::LLVM::LLVMPointerType::get (rewriter.getContext ()), symName);
614
+ loc,
615
+ mlir::LLVM::LLVMPointerType::get (rewriter.getContext (), sourceAddrSpace),
616
+ symName);
596
617
597
618
if (globalAttr.getIndices ()) {
598
619
llvm::SmallVector<mlir::LLVM::GEPArg> indices;
@@ -2349,18 +2370,6 @@ class CIRGlobalOpLowering
2349
2370
public:
2350
2371
using OpConversionPattern<mlir::cir::GlobalOp>::OpConversionPattern;
2351
2372
2352
- // Get addrspace by converting a pointer type.
2353
- // TODO: The approach here is a little hacky. We should access the target info
2354
- // directly to convert the address space of global op, similar to what we do
2355
- // for type converter.
2356
- unsigned getGlobalOpTargetAddrSpace (mlir::cir::GlobalOp op) const {
2357
- auto tempPtrTy = mlir::cir::PointerType::get (getContext (), op.getSymType (),
2358
- op.getAddrSpaceAttr ());
2359
- return cast<mlir::LLVM::LLVMPointerType>(
2360
- typeConverter->convertType (tempPtrTy))
2361
- .getAddressSpace ();
2362
- }
2363
-
2364
2373
// / Replace CIR global with a region initialized LLVM global and update
2365
2374
// / insertion point to the end of the initializer block.
2366
2375
inline void setupRegionInitializedLLVMGlobalOp (
@@ -2371,7 +2380,7 @@ class CIRGlobalOpLowering
2371
2380
op, llvmType, op.getConstant (), convertLinkage (op.getLinkage ()),
2372
2381
op.getSymName (), nullptr ,
2373
2382
/* alignment*/ op.getAlignment ().value_or (0 ),
2374
- /* addrSpace*/ getGlobalOpTargetAddrSpace (op),
2383
+ /* addrSpace*/ getGlobalOpTargetAddrSpace (rewriter, typeConverter, op),
2375
2384
/* dsoLocal*/ false , /* threadLocal*/ (bool )op.getTlsModelAttr (),
2376
2385
/* comdat*/ mlir::SymbolRefAttr (), attributes);
2377
2386
newGlobalOp.getRegion ().push_back (new mlir::Block ());
@@ -2406,7 +2415,8 @@ class CIRGlobalOpLowering
2406
2415
if (!init.has_value ()) {
2407
2416
rewriter.replaceOpWithNewOp <mlir::LLVM::GlobalOp>(
2408
2417
op, llvmType, isConst, linkage, symbol, mlir::Attribute (),
2409
- /* alignment*/ 0 , /* addrSpace*/ getGlobalOpTargetAddrSpace (op),
2418
+ /* alignment*/ 0 ,
2419
+ /* addrSpace*/ getGlobalOpTargetAddrSpace (rewriter, typeConverter, op),
2410
2420
/* dsoLocal*/ isDsoLocal, /* threadLocal*/ (bool )op.getTlsModelAttr (),
2411
2421
/* comdat*/ mlir::SymbolRefAttr (), attributes);
2412
2422
return mlir::success ();
@@ -2498,7 +2508,7 @@ class CIRGlobalOpLowering
2498
2508
auto llvmGlobalOp = rewriter.replaceOpWithNewOp <mlir::LLVM::GlobalOp>(
2499
2509
op, llvmType, isConst, linkage, symbol, init.value (),
2500
2510
/* alignment*/ op.getAlignment ().value_or (0 ),
2501
- /* addrSpace*/ getGlobalOpTargetAddrSpace (op),
2511
+ /* addrSpace*/ getGlobalOpTargetAddrSpace (rewriter, typeConverter, op),
2502
2512
/* dsoLocal*/ false , /* threadLocal*/ (bool )op.getTlsModelAttr (),
2503
2513
/* comdat*/ mlir::SymbolRefAttr (), attributes);
2504
2514
0 commit comments