Skip to content

Commit ed2a128

Browse files
committed
[SYCL] Handle address space casts in the LowerWGScope
1 parent 4109213 commit ed2a128

File tree

2 files changed

+94
-2
lines changed

2 files changed

+94
-2
lines changed

llvm/lib/SYCLLowerIR/LowerWGScope.cpp

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,8 @@ static bool mayHaveSideEffects(const Instruction *I) {
237237
case Instruction::Call:
238238
assert(!isPFWICall(I) && "pfwi must have been handled separately");
239239
return true;
240+
case Instruction::AddrSpaceCast:
241+
return false;
240242
default:
241243
return true;
242244
}
@@ -630,6 +632,11 @@ static void fixupPrivateMemoryPFWILambdaCaptures(CallInst *PFWICall) {
630632
// whether it is an alloca with "work_item_scope"
631633
SmallVector<CaptureDesc, 4> PrivMemCaptures;
632634

635+
// Look through cast
636+
auto *Cast = dyn_cast<AddrSpaceCastInst>(LambdaObj);
637+
if (Cast)
638+
LambdaObj = Cast->getOperand(0);
639+
633640
for (auto *U : LambdaObj->users()) {
634641
GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(U);
635642

@@ -779,13 +786,15 @@ PreservedAnalyses SYCLLowerWGScopePass::run(Function &F, const llvm::Triple &TT,
779786
// globals.
780787
Instruction *I = BB.getFirstNonPHI();
781788

782-
for (; I->getOpcode() == Instruction::Alloca; I = I->getNextNode()) {
789+
for (; I->getOpcode() == Instruction::Alloca ||
790+
I->getOpcode() == Instruction::AddrSpaceCast;
791+
I = I->getNextNode()) {
783792
auto *AllocaI = dyn_cast<AllocaInst>(I);
784793
// Allocas marked with "work_item_scope" are those originating from
785794
// cl::sycl::private_memory<T> variables, which must be in private memory.
786795
// No shadows/materialization is needed for them because they can be
787796
// updated only within PFWIs
788-
if (!AllocaI->getMetadata(WI_SCOPE_MD))
797+
if (AllocaI && !AllocaI->getMetadata(WI_SCOPE_MD))
789798
Allocas.insert(AllocaI);
790799
}
791800
for (; I && (I != BB.getTerminator()); I = I->getNextNode()) {
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
2+
; RUN: opt < %s -LowerWGScope -S | FileCheck %s
3+
4+
%struct.ham = type { i64, i64, i32, i32 }
5+
%struct.bar = type { i64 }
6+
%struct.spam = type { i64, i64, i64, i64, i32 }
7+
8+
; CHECK: @[[SHADOW4:.*]] = internal unnamed_addr addrspace(3) global %struct.ham addrspace(4)*
9+
; CHECK: @[[SHADOW3:.*]] = internal unnamed_addr addrspace(3) global %struct.spam
10+
; CHECK: @[[SHADOW2:.*]] = internal unnamed_addr addrspace(3) global %struct.ham
11+
; CHECK: @[[SHADOW1:.*]] = internal unnamed_addr addrspace(3) global %struct.bar
12+
13+
define linkonce_odr dso_local spir_func void @foo(%struct.ham addrspace(4)* dereferenceable_or_null(56) %arg, %struct.bar* byval(%struct.bar) align 8 %arg1) !work_group_scope !0 {
14+
; CHECK-LABEL: @foo(
15+
; CHECK-NEXT: bb:
16+
; CHECK-NEXT: [[TMP:%.*]] = alloca [[STRUCT_HAM:%.*]] addrspace(4)*, align 8
17+
; CHECK-NEXT: [[TMP0:%.*]] = load i64, i64 addrspace(1)* @__spirv_BuiltInLocalInvocationIndex, align 4
18+
; CHECK-NEXT: call void @_Z22__spirv_ControlBarrierjjj(i32 2, i32 2, i32 272) [[ATTR0:#.*]]
19+
; CHECK-NEXT: [[CMPZ3:%.*]] = icmp eq i64 [[TMP0]], 0
20+
; CHECK-NEXT: br i1 [[CMPZ3]], label [[LEADER:%.*]], label [[MERGE:%.*]]
21+
; CHECK: leader:
22+
; CHECK-NEXT: [[TMP1:%.*]] = bitcast %struct.bar* [[ARG1:%.*]] to i8*
23+
; CHECK-NEXT: call void @llvm.memcpy.p3i8.p0i8.i64(i8 addrspace(3)* align 8 bitcast (%struct.bar addrspace(3)* @[[SHADOW1]] to i8 addrspace(3)*), i8* align 8 [[TMP1]], i64 8, i1 false)
24+
; CHECK-NEXT: [[TMP2:%.*]] = bitcast [[STRUCT_HAM]] addrspace(4)* [[ARG:%.*]] to i8 addrspace(4)*
25+
; CHECK-NEXT: call void @llvm.memcpy.p3i8.p4i8.i64(i8 addrspace(3)* align 16 bitcast (%struct.ham addrspace(3)* @[[SHADOW2]] to i8 addrspace(3)*), i8 addrspace(4)* align 8 [[TMP2]], i64 24, i1 false)
26+
; CHECK-NEXT: br label [[MERGE]]
27+
; CHECK: merge:
28+
; CHECK-NEXT: call void @_Z22__spirv_ControlBarrierjjj(i32 2, i32 2, i32 272) [[ATTR0]]
29+
; CHECK-NEXT: [[TMP3:%.*]] = bitcast %struct.bar* [[ARG1]] to i8*
30+
; CHECK-NEXT: call void @llvm.memcpy.p0i8.p3i8.i64(i8* align 8 [[TMP3]], i8 addrspace(3)* align 8 bitcast (%struct.bar addrspace(3)* @[[SHADOW1]] to i8 addrspace(3)*), i64 8, i1 false)
31+
; CHECK-NEXT: [[TMP4:%.*]] = bitcast [[STRUCT_HAM]] addrspace(4)* [[ARG]] to i8 addrspace(4)*
32+
; CHECK-NEXT: call void @llvm.memcpy.p4i8.p3i8.i64(i8 addrspace(4)* align 8 [[TMP4]], i8 addrspace(3)* align 16 bitcast (%struct.ham addrspace(3)* @[[SHADOW2]] to i8 addrspace(3)*), i64 24, i1 false)
33+
; CHECK-NEXT: [[TMP2:%.*]] = addrspacecast [[STRUCT_HAM]] addrspace(4)** [[TMP]] to [[STRUCT_HAM]] addrspace(4)* addrspace(4)*
34+
; CHECK-NEXT: [[TMP3:%.*]] = alloca [[STRUCT_SPAM:%.*]], align 8
35+
; CHECK-NEXT: [[TMP4:%.*]] = addrspacecast %struct.spam* [[TMP3]] to [[STRUCT_SPAM]] addrspace(4)*
36+
; CHECK-NEXT: [[TMP5:%.*]] = load i64, i64 addrspace(1)* @__spirv_BuiltInLocalInvocationIndex, align 4
37+
; CHECK-NEXT: call void @_Z22__spirv_ControlBarrierjjj(i32 2, i32 2, i32 272) [[ATTR0]]
38+
; CHECK-NEXT: [[CMPZ:%.*]] = icmp eq i64 [[TMP5]], 0
39+
; CHECK-NEXT: br i1 [[CMPZ]], label [[WG_LEADER:%.*]], label [[WG_CF:%.*]]
40+
; CHECK: wg_leader:
41+
; CHECK-NEXT: store [[STRUCT_HAM]] addrspace(4)* [[ARG]], [[STRUCT_HAM]] addrspace(4)* addrspace(4)* [[TMP2]], align 8
42+
; CHECK-NEXT: br label [[WG_CF]]
43+
; CHECK: wg_cf:
44+
; CHECK-NEXT: [[TMP6:%.*]] = load i64, i64 addrspace(1)* @__spirv_BuiltInLocalInvocationIndex, align 4
45+
; CHECK-NEXT: call void @_Z22__spirv_ControlBarrierjjj(i32 2, i32 2, i32 272) [[ATTR0]]
46+
; CHECK-NEXT: [[CMPZ2:%.*]] = icmp eq i64 [[TMP6]], 0
47+
; CHECK-NEXT: br i1 [[CMPZ2]], label [[TESTMAT:%.*]], label [[LEADERMAT:%.*]]
48+
; CHECK: TestMat:
49+
; CHECK-NEXT: [[TMP7:%.*]] = bitcast %struct.spam* [[TMP3]] to i8*
50+
; CHECK-NEXT: call void @llvm.memcpy.p3i8.p0i8.i64(i8 addrspace(3)* align 16 bitcast (%struct.spam addrspace(3)* @[[SHADOW3]] to i8 addrspace(3)*), i8* align 8 [[TMP7]], i64 36, i1 false)
51+
; CHECK-NEXT: [[MAT_LD:%.*]] = load [[STRUCT_HAM]] addrspace(4)*, [[STRUCT_HAM]] addrspace(4)** [[TMP]], align 8
52+
; CHECK-NEXT: store [[STRUCT_HAM]] addrspace(4)* [[MAT_LD]], [[STRUCT_HAM]] addrspace(4)* addrspace(3)* @[[SHADOW4]], align 8
53+
; CHECK-NEXT: br label [[LEADERMAT]]
54+
; CHECK: LeaderMat:
55+
; CHECK-NEXT: call void @_Z22__spirv_ControlBarrierjjj(i32 2, i32 2, i32 272) [[ATTR0]]
56+
; CHECK-NEXT: [[MAT_LD1:%.*]] = load [[STRUCT_HAM]] addrspace(4)*, [[STRUCT_HAM]] addrspace(4)* addrspace(3)* @[[SHADOW4]], align 8
57+
; CHECK-NEXT: store [[STRUCT_HAM]] addrspace(4)* [[MAT_LD1]], [[STRUCT_HAM]] addrspace(4)** [[TMP]], align 8
58+
; CHECK-NEXT: [[TMP8:%.*]] = bitcast %struct.spam* [[TMP3]] to i8*
59+
; CHECK-NEXT: call void @llvm.memcpy.p0i8.p3i8.i64(i8* align 8 [[TMP8]], i8 addrspace(3)* align 16 bitcast (%struct.spam addrspace(3)* @[[SHADOW3]] to i8 addrspace(3)*), i64 36, i1 false)
60+
; CHECK-NEXT: call void @_Z22__spirv_ControlBarrierjjj(i32 2, i32 2, i32 272) [[ATTR0]]
61+
; CHECK-NEXT: [[TMP5:%.*]] = addrspacecast %struct.bar* [[ARG1]] to [[STRUCT_BAR:%.*]] addrspace(4)*
62+
; CHECK-NEXT: [[TMP6:%.*]] = addrspacecast [[STRUCT_SPAM]] addrspace(4)* [[TMP4]] to %struct.spam*
63+
; CHECK-NEXT: call spir_func void @widget(%struct.bar addrspace(4)* dereferenceable_or_null(32) [[TMP5]], %struct.spam* byval(%struct.spam) align 8 [[TMP6]])
64+
; CHECK-NEXT: ret void
65+
;
66+
bb:
67+
%tmp = alloca %struct.ham addrspace(4)*, align 8
68+
%tmp2 = addrspacecast %struct.ham addrspace(4)** %tmp to %struct.ham addrspace(4)* addrspace(4)*
69+
%tmp3 = alloca %struct.spam, align 8
70+
%tmp4 = addrspacecast %struct.spam* %tmp3 to %struct.spam addrspace(4)*
71+
store %struct.ham addrspace(4)* %arg, %struct.ham addrspace(4)* addrspace(4)* %tmp2, align 8
72+
%tmp5 = addrspacecast %struct.bar* %arg1 to %struct.bar addrspace(4)*
73+
%tmp6 = addrspacecast %struct.spam addrspace(4)* %tmp4 to %struct.spam*
74+
call spir_func void @widget(%struct.bar addrspace(4)* dereferenceable_or_null(32) %tmp5, %struct.spam* byval(%struct.spam) align 8 %tmp6)
75+
ret void
76+
}
77+
78+
define linkonce_odr dso_local spir_func void @widget(%struct.bar addrspace(4)* dereferenceable_or_null(32) %arg, %struct.spam* byval(%struct.spam) align 8 %arg1) !work_item_scope !0 !parallel_for_work_item !0 {
79+
bb:
80+
ret void
81+
}
82+
83+
!0 = !{}

0 commit comments

Comments
 (0)