Skip to content

Commit f220fa7

Browse files
authored
[CIR][CodeGen][LowerToLLVM] String literals for OpenCL (#1091)
This PR supports string literals in OpenCL end to end, making it possible to use `printf`. This involves two changes:

* In CIRGen, ensure that the global symbol for a string literal is created in the correct `constant` address space.
* In LowerToLLVM, make the lowering of `GlobalViewAttr` aware of the address space of the global symbol it refers to.

A few related refactors are also applied. Two test cases from OG CodeGen are reused: `str_literals.cl` is the primary test, while `printf.cl` is an additional one.
1 parent 4463352 commit f220fa7

File tree

6 files changed

+119
-31
lines changed

6 files changed

+119
-31
lines changed

clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -375,13 +375,17 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
375375
return createAlloca(loc, addrType, type, name, alignmentIntAttr);
376376
}
377377

378-
mlir::Value createGetGlobal(cir::GlobalOp global, bool threadLocal = false) {
378+
mlir::Value createGetGlobal(mlir::Location loc, cir::GlobalOp global,
379+
bool threadLocal = false) {
379380
return create<cir::GetGlobalOp>(
380-
global.getLoc(),
381-
getPointerTo(global.getSymType(), global.getAddrSpaceAttr()),
381+
loc, getPointerTo(global.getSymType(), global.getAddrSpaceAttr()),
382382
global.getName(), threadLocal);
383383
}
384384

385+
mlir::Value createGetGlobal(cir::GlobalOp global, bool threadLocal = false) {
386+
return createGetGlobal(global.getLoc(), global, threadLocal);
387+
}
388+
385389
/// Create a copy with inferred length.
386390
cir::CopyOp createCopy(mlir::Value dst, mlir::Value src,
387391
bool isVolatile = false) {
@@ -547,8 +551,7 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
547551
});
548552

549553
if (last != block->rend())
550-
return OpBuilder::InsertPoint(block,
551-
++mlir::Block::iterator(&*last));
554+
return OpBuilder::InsertPoint(block, ++mlir::Block::iterator(&*last));
552555
return OpBuilder::InsertPoint(block, block->begin());
553556
};
554557

clang/lib/CIR/CodeGen/CIRGenExpr.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1840,12 +1840,9 @@ LValue CIRGenFunction::emitStringLiteralLValue(const StringLiteral *E) {
18401840
auto g = dyn_cast<cir::GlobalOp>(cstGlobal);
18411841
assert(g && "unaware of other symbol providers");
18421842

1843-
auto ptrTy =
1844-
cir::PointerType::get(CGM.getBuilder().getContext(), g.getSymType());
18451843
assert(g.getAlignment() && "expected alignment for string literal");
18461844
auto align = *g.getAlignment();
1847-
auto addr = builder.create<cir::GetGlobalOp>(getLoc(E->getSourceRange()),
1848-
ptrTy, g.getSymName());
1845+
auto addr = builder.createGetGlobal(getLoc(E->getSourceRange()), g);
18491846
return makeAddrLValue(
18501847
Address(addr, g.getSymType(), CharUnits::fromQuantity(align)),
18511848
E->getType(), AlignmentSource::Decl);

clang/lib/CIR/CodeGen/CIRGenModule.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1480,15 +1480,14 @@ static cir::GlobalOp
14801480
generateStringLiteral(mlir::Location loc, mlir::TypedAttr C,
14811481
cir::GlobalLinkageKind LT, CIRGenModule &CGM,
14821482
StringRef GlobalName, CharUnits Alignment) {
1483-
unsigned AddrSpace = CGM.getASTContext().getTargetAddressSpace(
1484-
CGM.getGlobalConstantAddressSpace());
1485-
assert((AddrSpace == 0 && !cir::MissingFeatures::addressSpaceInGlobalVar()) &&
1486-
"NYI");
1483+
cir::AddressSpaceAttr addrSpaceAttr =
1484+
CGM.getBuilder().getAddrSpaceAttr(CGM.getGlobalConstantAddressSpace());
14871485

14881486
// Create a global variable for this string
14891487
// FIXME(cir): check for insertion point in module level.
14901488
auto GV = CIRGenModule::createGlobalOp(CGM, loc, GlobalName, C.getType(),
1491-
!CGM.getLangOpts().WritableStrings);
1489+
!CGM.getLangOpts().WritableStrings,
1490+
addrSpaceAttr);
14921491

14931492
// Set up extra information and add to the module
14941493
GV.setAlignmentAttr(CGM.getSize(Alignment));
@@ -1559,7 +1558,8 @@ CIRGenModule::getAddrOfConstantStringFromLiteral(const StringLiteral *S,
15591558

15601559
auto ArrayTy = mlir::dyn_cast<cir::ArrayType>(GV.getSymType());
15611560
assert(ArrayTy && "String literal must be array");
1562-
auto PtrTy = cir::PointerType::get(&getMLIRContext(), ArrayTy.getEltType());
1561+
auto PtrTy =
1562+
getBuilder().getPointerTo(ArrayTy.getEltType(), GV.getAddrSpaceAttr());
15631563

15641564
return builder.getGlobalViewAttr(PtrTy, GV);
15651565
}

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 26 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,19 @@ void lowerAnnotationValue(
347347
}
348348
}
349349

350+
// Get addrspace by converting a pointer type.
351+
// TODO: The approach here is a little hacky. We should access the target info
352+
// directly to convert the address space of global op, similar to what we do
353+
// for type converter.
354+
unsigned getGlobalOpTargetAddrSpace(mlir::ConversionPatternRewriter &rewriter,
355+
const mlir::TypeConverter *converter,
356+
cir::GlobalOp op) {
357+
auto tempPtrTy = cir::PointerType::get(rewriter.getContext(), op.getSymType(),
358+
op.getAddrSpaceAttr());
359+
return cast<mlir::LLVM::LLVMPointerType>(converter->convertType(tempPtrTy))
360+
.getAddressSpace();
361+
}
362+
350363
} // namespace
351364

352365
//===----------------------------------------------------------------------===//
@@ -568,28 +581,36 @@ mlir::Value lowerCirAttrAsValue(mlir::Operation *parentOp,
568581
const mlir::TypeConverter *converter) {
569582
auto module = parentOp->getParentOfType<mlir::ModuleOp>();
570583
mlir::Type sourceType;
584+
unsigned sourceAddrSpace = 0;
571585
llvm::StringRef symName;
572586
auto *sourceSymbol =
573587
mlir::SymbolTable::lookupSymbolIn(module, globalAttr.getSymbol());
574588
if (auto llvmSymbol = dyn_cast<mlir::LLVM::GlobalOp>(sourceSymbol)) {
575589
sourceType = llvmSymbol.getType();
576590
symName = llvmSymbol.getSymName();
591+
sourceAddrSpace = llvmSymbol.getAddrSpace();
577592
} else if (auto cirSymbol = dyn_cast<cir::GlobalOp>(sourceSymbol)) {
578593
sourceType = converter->convertType(cirSymbol.getSymType());
579594
symName = cirSymbol.getSymName();
595+
sourceAddrSpace =
596+
getGlobalOpTargetAddrSpace(rewriter, converter, cirSymbol);
580597
} else if (auto llvmFun = dyn_cast<mlir::LLVM::LLVMFuncOp>(sourceSymbol)) {
581598
sourceType = llvmFun.getFunctionType();
582599
symName = llvmFun.getSymName();
600+
sourceAddrSpace = 0;
583601
} else if (auto fun = dyn_cast<cir::FuncOp>(sourceSymbol)) {
584602
sourceType = converter->convertType(fun.getFunctionType());
585603
symName = fun.getSymName();
604+
sourceAddrSpace = 0;
586605
} else {
587606
llvm_unreachable("Unexpected GlobalOp type");
588607
}
589608

590609
auto loc = parentOp->getLoc();
591610
mlir::Value addrOp = rewriter.create<mlir::LLVM::AddressOfOp>(
592-
loc, mlir::LLVM::LLVMPointerType::get(rewriter.getContext()), symName);
611+
loc,
612+
mlir::LLVM::LLVMPointerType::get(rewriter.getContext(), sourceAddrSpace),
613+
symName);
593614

594615
if (globalAttr.getIndices()) {
595616
llvm::SmallVector<mlir::LLVM::GEPArg> indices;
@@ -2322,18 +2343,6 @@ class CIRGlobalOpLowering : public mlir::OpConversionPattern<cir::GlobalOp> {
23222343
public:
23232344
using OpConversionPattern<cir::GlobalOp>::OpConversionPattern;
23242345

2325-
// Get addrspace by converting a pointer type.
2326-
// TODO: The approach here is a little hacky. We should access the target info
2327-
// directly to convert the address space of global op, similar to what we do
2328-
// for type converter.
2329-
unsigned getGlobalOpTargetAddrSpace(cir::GlobalOp op) const {
2330-
auto tempPtrTy = cir::PointerType::get(getContext(), op.getSymType(),
2331-
op.getAddrSpaceAttr());
2332-
return cast<mlir::LLVM::LLVMPointerType>(
2333-
typeConverter->convertType(tempPtrTy))
2334-
.getAddressSpace();
2335-
}
2336-
23372346
/// Replace CIR global with a region initialized LLVM global and update
23382347
/// insertion point to the end of the initializer block.
23392348
inline void setupRegionInitializedLLVMGlobalOp(
@@ -2344,7 +2353,7 @@ class CIRGlobalOpLowering : public mlir::OpConversionPattern<cir::GlobalOp> {
23442353
op, llvmType, op.getConstant(), convertLinkage(op.getLinkage()),
23452354
op.getSymName(), nullptr,
23462355
/*alignment*/ op.getAlignment().value_or(0),
2347-
/*addrSpace*/ getGlobalOpTargetAddrSpace(op),
2356+
/*addrSpace*/ getGlobalOpTargetAddrSpace(rewriter, typeConverter, op),
23482357
/*dsoLocal*/ false, /*threadLocal*/ (bool)op.getTlsModelAttr(),
23492358
/*comdat*/ mlir::SymbolRefAttr(), attributes);
23502359
newGlobalOp.getRegion().push_back(new mlir::Block());
@@ -2379,7 +2388,8 @@ class CIRGlobalOpLowering : public mlir::OpConversionPattern<cir::GlobalOp> {
23792388
if (!init.has_value()) {
23802389
rewriter.replaceOpWithNewOp<mlir::LLVM::GlobalOp>(
23812390
op, llvmType, isConst, linkage, symbol, mlir::Attribute(),
2382-
/*alignment*/ 0, /*addrSpace*/ getGlobalOpTargetAddrSpace(op),
2391+
/*alignment*/ 0,
2392+
/*addrSpace*/ getGlobalOpTargetAddrSpace(rewriter, typeConverter, op),
23832393
/*dsoLocal*/ isDsoLocal, /*threadLocal*/ (bool)op.getTlsModelAttr(),
23842394
/*comdat*/ mlir::SymbolRefAttr(), attributes);
23852395
return mlir::success();
@@ -2468,7 +2478,7 @@ class CIRGlobalOpLowering : public mlir::OpConversionPattern<cir::GlobalOp> {
24682478
auto llvmGlobalOp = rewriter.replaceOpWithNewOp<mlir::LLVM::GlobalOp>(
24692479
op, llvmType, isConst, linkage, symbol, init.value(),
24702480
/*alignment*/ op.getAlignment().value_or(0),
2471-
/*addrSpace*/ getGlobalOpTargetAddrSpace(op),
2481+
/*addrSpace*/ getGlobalOpTargetAddrSpace(rewriter, typeConverter, op),
24722482
/*dsoLocal*/ false, /*threadLocal*/ (bool)op.getTlsModelAttr(),
24732483
/*comdat*/ mlir::SymbolRefAttr(), attributes);
24742484

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
// RUN: %clang_cc1 -fclangir -no-enable-noundef-analysis -cl-std=CL1.2 -cl-ext=-+cl_khr_fp64 -triple spirv64-unknown-unknown -disable-llvm-passes -emit-cir -fno-clangir-call-conv-lowering -o %t.12fp64.cir %s
2+
// RUN: FileCheck -input-file=%t.12fp64.cir -check-prefixes=CIR-FP64,CIR-ALL %s
3+
// RUN: %clang_cc1 -fclangir -no-enable-noundef-analysis -cl-std=CL1.2 -cl-ext=-cl_khr_fp64 -triple spirv64-unknown-unknown -disable-llvm-passes -emit-cir -fno-clangir-call-conv-lowering -o %t.12nofp64.cir %s
4+
// RUN: FileCheck -input-file=%t.12nofp64.cir -check-prefixes=CIR-NOFP64,CIR-ALL %s
5+
// RUN: %clang_cc1 -fclangir -no-enable-noundef-analysis -cl-std=CL3.0 -cl-ext=+__opencl_c_fp64,+cl_khr_fp64 -triple spirv64-unknown-unknown -disable-llvm-passes -emit-cir -fno-clangir-call-conv-lowering -o %t.30fp64.cir %s
6+
// RUN: FileCheck -input-file=%t.30fp64.cir -check-prefixes=CIR-FP64,CIR-ALL %s
7+
// RUN: %clang_cc1 -fclangir -no-enable-noundef-analysis -cl-std=CL3.0 -cl-ext=-__opencl_c_fp64,-cl_khr_fp64 -triple spirv64-unknown-unknown -disable-llvm-passes -emit-cir -fno-clangir-call-conv-lowering -o %t.30nofp64.cir %s
8+
// RUN: FileCheck -input-file=%t.30nofp64.cir -check-prefixes=CIR-NOFP64,CIR-ALL %s
9+
// RUN: %clang_cc1 -fclangir -no-enable-noundef-analysis -cl-std=CL1.2 -cl-ext=-+cl_khr_fp64 -triple spirv64-unknown-unknown -disable-llvm-passes -emit-llvm -fno-clangir-call-conv-lowering -o %t.12fp64.ll %s
10+
// RUN: FileCheck -input-file=%t.12fp64.ll -check-prefixes=LLVM-FP64,LLVM-ALL %s
11+
// RUN: %clang_cc1 -fclangir -no-enable-noundef-analysis -cl-std=CL1.2 -cl-ext=-cl_khr_fp64 -triple spirv64-unknown-unknown -disable-llvm-passes -emit-llvm -fno-clangir-call-conv-lowering -o %t.12nofp64.ll %s
12+
// RUN: FileCheck -input-file=%t.12nofp64.ll -check-prefixes=LLVM-NOFP64,LLVM-ALL %s
13+
// RUN: %clang_cc1 -fclangir -no-enable-noundef-analysis -cl-std=CL3.0 -cl-ext=+__opencl_c_fp64,+cl_khr_fp64 -triple spirv64-unknown-unknown -disable-llvm-passes -emit-llvm -fno-clangir-call-conv-lowering -o %t.30fp64.ll %s
14+
// RUN: FileCheck -input-file=%t.30fp64.ll -check-prefixes=LLVM-FP64,LLVM-ALL %s
15+
// RUN: %clang_cc1 -fclangir -no-enable-noundef-analysis -cl-std=CL3.0 -cl-ext=-__opencl_c_fp64,-cl_khr_fp64 -triple spirv64-unknown-unknown -disable-llvm-passes -emit-llvm -fno-clangir-call-conv-lowering -o %t.30nofp64.ll %s
16+
// RUN: FileCheck -input-file=%t.30nofp64.ll -check-prefixes=LLVM-NOFP64,LLVM-ALL %s
17+
18+
typedef __attribute__((ext_vector_type(2))) float float2;
19+
typedef __attribute__((ext_vector_type(2))) half half2;
20+
21+
#if defined(cl_khr_fp64) || defined(__opencl_c_fp64)
22+
typedef __attribute__((ext_vector_type(2))) double double2;
23+
#endif
24+
25+
int printf(__constant const char* st, ...) __attribute__((format(printf, 1, 2)));
26+
27+
kernel void test_printf_float2(float2 arg) {
28+
printf("%v2hlf", arg);
29+
}
30+
// CIR-ALL-LABEL: @test_printf_float2(
31+
// CIR-FP64: %{{.+}} = cir.call @printf(%{{.+}}, %{{.+}}) : (!cir.ptr<!s8i, addrspace(offload_constant)>, !cir.vector<!cir.float x 2>) -> !s32i cc(spir_function)
32+
// CIR-NOFP64:%{{.+}} = cir.call @printf(%{{.+}}, %{{.+}}) : (!cir.ptr<!s8i, addrspace(offload_constant)>, !cir.vector<!cir.float x 2>) -> !s32i cc(spir_function)
33+
// LLVM-ALL-LABEL: @test_printf_float2(
34+
// LLVM-FP64: %{{.+}} = call spir_func i32 (ptr addrspace(2), ...) @{{.*}}printf{{.*}}(ptr addrspace(2) @.str, <2 x float> %{{.*}})
35+
// LLVM-NOFP64: call spir_func i32 (ptr addrspace(2), ...) @{{.*}}printf{{.*}}(ptr addrspace(2) @.str, <2 x float> %{{.*}})
36+
37+
kernel void test_printf_half2(half2 arg) {
38+
printf("%v2hf", arg);
39+
}
40+
// CIR-ALL-LABEL: @test_printf_half2(
41+
// CIR-FP64: %{{.+}} = cir.call @printf(%{{.+}}, %{{.+}}) : (!cir.ptr<!s8i, addrspace(offload_constant)>, !cir.vector<!cir.f16 x 2>) -> !s32i cc(spir_function)
42+
// CIR-NOFP64:%{{.+}} = cir.call @printf(%{{.+}}, %{{.+}}) : (!cir.ptr<!s8i, addrspace(offload_constant)>, !cir.vector<!cir.f16 x 2>) -> !s32i cc(spir_function)
43+
// LLVM-ALL-LABEL: @test_printf_half2(
44+
// LLVM-FP64: %{{.+}} = call spir_func i32 (ptr addrspace(2), ...) @{{.*}}printf{{.*}}(ptr addrspace(2) @.str.1, <2 x half> %{{.*}})
45+
// LLVM-NOFP64: %{{.+}} = call spir_func i32 (ptr addrspace(2), ...) @{{.*}}printf{{.*}}(ptr addrspace(2) @.str.1, <2 x half> %{{.*}})
46+
47+
#if defined(cl_khr_fp64) || defined(__opencl_c_fp64)
48+
kernel void test_printf_double2(double2 arg) {
49+
printf("%v2lf", arg);
50+
}
51+
// CIR-FP64-LABEL: @test_printf_double2(
52+
// CIR-FP64: %{{.+}} = cir.call @printf(%{{.+}}, %{{.+}}) : (!cir.ptr<!s8i, addrspace(offload_constant)>, !cir.vector<!cir.double x 2>) -> !s32i cc(spir_function)
53+
// LLVM-FP64-LABEL: @test_printf_double2(
54+
// LLVM-FP64: call spir_func i32 (ptr addrspace(2), ...) @{{.*}}printf{{.*}}(ptr addrspace(2) @.str.2, <2 x double> %{{.*}})
55+
#endif
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
// RUN: %clang_cc1 %s -fclangir -triple=spirv64-unknown-unknown -cl-opt-disable -emit-cir -o %t.cir -ffake-address-space-map
2+
// RUN: FileCheck -input-file=%t.cir -check-prefix=CIR %s
3+
// RUN: %clang_cc1 %s -fclangir -triple=spirv64-unknown-unknown -cl-opt-disable -emit-llvm -o %t.ll -ffake-address-space-map
4+
// RUN: FileCheck -input-file=%t.ll -check-prefix=LLVM %s
5+
6+
__constant char *__constant x = "hello world";
7+
__constant char *__constant y = "hello world";
8+
9+
// CIR: cir.global{{.*}} constant {{.*}}addrspace(offload_constant) @".str" = #cir.const_array<"hello world\00" : !cir.array<!s8i x 12>> : !cir.array<!s8i x 12>
10+
// CIR: cir.global{{.*}} constant {{.*}}addrspace(offload_constant) @x = #cir.global_view<@".str"> : !cir.ptr<!s8i, addrspace(offload_constant)>
11+
// CIR: cir.global{{.*}} constant {{.*}}addrspace(offload_constant) @y = #cir.global_view<@".str"> : !cir.ptr<!s8i, addrspace(offload_constant)>
12+
// CIR: cir.global{{.*}} constant {{.*}}addrspace(offload_constant) @".str.1" = #cir.const_array<"f\00" : !cir.array<!s8i x 2>> : !cir.array<!s8i x 2>
13+
// LLVM: addrspace(2) constant{{.*}}"hello world\00"
14+
// LLVM-NOT: addrspace(2) constant
15+
// LLVM: @x = {{(dso_local )?}}addrspace(2) constant ptr addrspace(2)
16+
// LLVM: @y = {{(dso_local )?}}addrspace(2) constant ptr addrspace(2)
17+
// LLVM: addrspace(2) constant{{.*}}"f\00"
18+
19+
void f() {
20+
// CIR: cir.store %{{.*}}, %{{.*}} : !cir.ptr<!s8i, addrspace(offload_constant)>, !cir.ptr<!cir.ptr<!s8i, addrspace(offload_constant)>, addrspace(offload_private)>
21+
// LLVM: store ptr addrspace(2) {{.*}}, ptr
22+
constant const char *f3 = __func__;
23+
}

0 commit comments

Comments
 (0)