@@ -347,6 +347,19 @@ void lowerAnnotationValue(
347
347
}
348
348
}
349
349
350
+ // Get addrspace by converting a pointer type.
351
+ // TODO: The approach here is a little hacky. We should access the target info
352
+ // directly to convert the address space of global op, similar to what we do
353
+ // for type converter.
354
+ unsigned getGlobalOpTargetAddrSpace (mlir::ConversionPatternRewriter &rewriter,
355
+ const mlir::TypeConverter *converter,
356
+ cir::GlobalOp op) {
357
+ auto tempPtrTy = cir::PointerType::get (rewriter.getContext (), op.getSymType (),
358
+ op.getAddrSpaceAttr ());
359
+ return cast<mlir::LLVM::LLVMPointerType>(converter->convertType (tempPtrTy))
360
+ .getAddressSpace ();
361
+ }
362
+
350
363
} // namespace
351
364
352
365
// ===----------------------------------------------------------------------===//
@@ -568,28 +581,36 @@ mlir::Value lowerCirAttrAsValue(mlir::Operation *parentOp,
568
581
const mlir::TypeConverter *converter) {
569
582
auto module = parentOp->getParentOfType <mlir::ModuleOp>();
570
583
mlir::Type sourceType;
584
+ unsigned sourceAddrSpace = 0 ;
571
585
llvm::StringRef symName;
572
586
auto *sourceSymbol =
573
587
mlir::SymbolTable::lookupSymbolIn (module , globalAttr.getSymbol ());
574
588
if (auto llvmSymbol = dyn_cast<mlir::LLVM::GlobalOp>(sourceSymbol)) {
575
589
sourceType = llvmSymbol.getType ();
576
590
symName = llvmSymbol.getSymName ();
591
+ sourceAddrSpace = llvmSymbol.getAddrSpace ();
577
592
} else if (auto cirSymbol = dyn_cast<cir::GlobalOp>(sourceSymbol)) {
578
593
sourceType = converter->convertType (cirSymbol.getSymType ());
579
594
symName = cirSymbol.getSymName ();
595
+ sourceAddrSpace =
596
+ getGlobalOpTargetAddrSpace (rewriter, converter, cirSymbol);
580
597
} else if (auto llvmFun = dyn_cast<mlir::LLVM::LLVMFuncOp>(sourceSymbol)) {
581
598
sourceType = llvmFun.getFunctionType ();
582
599
symName = llvmFun.getSymName ();
600
+ sourceAddrSpace = 0 ;
583
601
} else if (auto fun = dyn_cast<cir::FuncOp>(sourceSymbol)) {
584
602
sourceType = converter->convertType (fun.getFunctionType ());
585
603
symName = fun.getSymName ();
604
+ sourceAddrSpace = 0 ;
586
605
} else {
587
606
llvm_unreachable (" Unexpected GlobalOp type" );
588
607
}
589
608
590
609
auto loc = parentOp->getLoc ();
591
610
mlir::Value addrOp = rewriter.create <mlir::LLVM::AddressOfOp>(
592
- loc, mlir::LLVM::LLVMPointerType::get (rewriter.getContext ()), symName);
611
+ loc,
612
+ mlir::LLVM::LLVMPointerType::get (rewriter.getContext (), sourceAddrSpace),
613
+ symName);
593
614
594
615
if (globalAttr.getIndices ()) {
595
616
llvm::SmallVector<mlir::LLVM::GEPArg> indices;
@@ -2322,18 +2343,6 @@ class CIRGlobalOpLowering : public mlir::OpConversionPattern<cir::GlobalOp> {
2322
2343
public:
2323
2344
using OpConversionPattern<cir::GlobalOp>::OpConversionPattern;
2324
2345
2325
- // Get addrspace by converting a pointer type.
2326
- // TODO: The approach here is a little hacky. We should access the target info
2327
- // directly to convert the address space of global op, similar to what we do
2328
- // for type converter.
2329
- unsigned getGlobalOpTargetAddrSpace (cir::GlobalOp op) const {
2330
- auto tempPtrTy = cir::PointerType::get (getContext (), op.getSymType (),
2331
- op.getAddrSpaceAttr ());
2332
- return cast<mlir::LLVM::LLVMPointerType>(
2333
- typeConverter->convertType (tempPtrTy))
2334
- .getAddressSpace ();
2335
- }
2336
-
2337
2346
// / Replace CIR global with a region initialized LLVM global and update
2338
2347
// / insertion point to the end of the initializer block.
2339
2348
inline void setupRegionInitializedLLVMGlobalOp (
@@ -2344,7 +2353,7 @@ class CIRGlobalOpLowering : public mlir::OpConversionPattern<cir::GlobalOp> {
2344
2353
op, llvmType, op.getConstant (), convertLinkage (op.getLinkage ()),
2345
2354
op.getSymName (), nullptr ,
2346
2355
/* alignment*/ op.getAlignment ().value_or (0 ),
2347
- /* addrSpace*/ getGlobalOpTargetAddrSpace (op),
2356
+ /* addrSpace*/ getGlobalOpTargetAddrSpace (rewriter, typeConverter, op),
2348
2357
/* dsoLocal*/ false , /* threadLocal*/ (bool )op.getTlsModelAttr (),
2349
2358
/* comdat*/ mlir::SymbolRefAttr (), attributes);
2350
2359
newGlobalOp.getRegion ().push_back (new mlir::Block ());
@@ -2379,7 +2388,8 @@ class CIRGlobalOpLowering : public mlir::OpConversionPattern<cir::GlobalOp> {
2379
2388
if (!init.has_value ()) {
2380
2389
rewriter.replaceOpWithNewOp <mlir::LLVM::GlobalOp>(
2381
2390
op, llvmType, isConst, linkage, symbol, mlir::Attribute (),
2382
- /* alignment*/ 0 , /* addrSpace*/ getGlobalOpTargetAddrSpace (op),
2391
+ /* alignment*/ 0 ,
2392
+ /* addrSpace*/ getGlobalOpTargetAddrSpace (rewriter, typeConverter, op),
2383
2393
/* dsoLocal*/ isDsoLocal, /* threadLocal*/ (bool )op.getTlsModelAttr (),
2384
2394
/* comdat*/ mlir::SymbolRefAttr (), attributes);
2385
2395
return mlir::success ();
@@ -2468,7 +2478,7 @@ class CIRGlobalOpLowering : public mlir::OpConversionPattern<cir::GlobalOp> {
2468
2478
auto llvmGlobalOp = rewriter.replaceOpWithNewOp <mlir::LLVM::GlobalOp>(
2469
2479
op, llvmType, isConst, linkage, symbol, init.value (),
2470
2480
/* alignment*/ op.getAlignment ().value_or (0 ),
2471
- /* addrSpace*/ getGlobalOpTargetAddrSpace (op),
2481
+ /* addrSpace*/ getGlobalOpTargetAddrSpace (rewriter, typeConverter, op),
2472
2482
/* dsoLocal*/ false , /* threadLocal*/ (bool )op.getTlsModelAttr (),
2473
2483
/* comdat*/ mlir::SymbolRefAttr (), attributes);
2474
2484
0 commit comments