Skip to content

Commit f2780fe

Browse files
committed
[SPIR-V] Add legalize-addrspace-cast pass
This commit adds a new pass in the backend which propagates the addrspace of the pointers down to the last use, making sure the addrspace remains consistent, and thus stripping any addrspacecast. This is required to lower LLVM-IR to logical SPIR-V, which does not support generic pointers. This is now required as HLSL emits several address spaces, and thus addrspacecasts in some cases: Example 1: resource access ```llvm %handle = tail call target("spirv.VulkanBuffer", ...) %rptr = @llvm.spv.resource.getpointer(%handle, ...); %cptr = addrspacecast ptr addrspace(11) %rptr to ptr %fptr = load i32, ptr %cptr ``` Example 2: object methods ```llvm define void @objectmethod(ptr %this) { } define void @foo(ptr addrspace(11) %object) { call void @objectmethod(ptr addrspacecast(addrspace(11) %object to ptr)); } ```
1 parent 4135a63 commit f2780fe

File tree

6 files changed

+236
-0
lines changed

6 files changed

+236
-0
lines changed

llvm/lib/Target/SPIRV/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ add_llvm_target(SPIRVCodeGen
2828
SPIRVInstructionSelector.cpp
2929
SPIRVStripConvergentIntrinsics.cpp
3030
SPIRVLegalizePointerCast.cpp
31+
SPIRVLegalizeAddrspaceCast.cpp
3132
SPIRVMergeRegionExitTargets.cpp
3233
SPIRVISelLowering.cpp
3334
SPIRVLegalizerInfo.cpp

llvm/lib/Target/SPIRV/SPIRV.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ FunctionPass *createSPIRVStructurizerPass();
2424
FunctionPass *createSPIRVMergeRegionExitTargetsPass();
2525
FunctionPass *createSPIRVStripConvergenceIntrinsicsPass();
2626
FunctionPass *createSPIRVLegalizePointerCastPass(SPIRVTargetMachine *TM);
27+
FunctionPass *createSPIRVLegalizeAddrspaceCastPass(SPIRVTargetMachine *TM);
2728
FunctionPass *createSPIRVRegularizerPass();
2829
FunctionPass *createSPIRVPreLegalizerCombiner();
2930
FunctionPass *createSPIRVPreLegalizerPass();
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
//===-- SPIRVLegalizeAddrspaceCast.cpp ----------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "SPIRV.h"
10+
#include "SPIRVSubtarget.h"
11+
#include "SPIRVTargetMachine.h"
12+
#include "SPIRVUtils.h"
13+
#include "llvm/CodeGen/IntrinsicLowering.h"
14+
#include "llvm/IR/IRBuilder.h"
15+
#include "llvm/IR/IntrinsicInst.h"
16+
#include "llvm/IR/Intrinsics.h"
17+
#include "llvm/IR/IntrinsicsSPIRV.h"
18+
#include "llvm/Transforms/Utils/Cloning.h"
19+
#include "llvm/Transforms/Utils/LowerMemIntrinsics.h"
20+
21+
using namespace llvm;
22+
23+
namespace llvm {
24+
void initializeSPIRVLegalizeAddrspaceCastPass(PassRegistry &);
25+
}
26+
27+
class SPIRVLegalizeAddrspaceCast : public FunctionPass {
28+
29+
public:
30+
SPIRVLegalizeAddrspaceCast(SPIRVTargetMachine *TM)
31+
: FunctionPass(ID), TM(TM) {
32+
initializeSPIRVLegalizeAddrspaceCastPass(*PassRegistry::getPassRegistry());
33+
};
34+
35+
void gatherAddrspaceCast(Function &F) {
36+
WorkList.clear();
37+
std::vector<User *> ToVisit;
38+
for (auto &BB : F)
39+
for (auto &I : BB)
40+
ToVisit.push_back(&I);
41+
42+
std::unordered_set<User *> Visited;
43+
while (ToVisit.size() > 0) {
44+
User *I = ToVisit.back();
45+
ToVisit.pop_back();
46+
if (Visited.count(I) != 0)
47+
continue;
48+
Visited.insert(I);
49+
50+
if (AddrSpaceCastInst *AI = dyn_cast<AddrSpaceCastInst>(I))
51+
WorkList.insert(AI);
52+
else if (auto *AO = dyn_cast<AddrSpaceCastOperator>(I))
53+
WorkList.insert(AO);
54+
55+
for (auto &O : I->operands())
56+
if (User *U = dyn_cast<User>(&O))
57+
ToVisit.push_back(U);
58+
}
59+
}
60+
61+
void propagateAddrspace(User *U) {
62+
if (!U->getType()->isPointerTy())
63+
return;
64+
65+
if (AddrSpaceCastOperator *AO = dyn_cast<AddrSpaceCastOperator>(U)) {
66+
for (auto &Use : AO->uses())
67+
WorkList.insert(Use.getUser());
68+
69+
AO->mutateType(AO->getPointerOperand()->getType());
70+
AO->replaceAllUsesWith(AO->getPointerOperand());
71+
DeadUsers.insert(AO);
72+
return;
73+
}
74+
75+
if (AddrSpaceCastInst *AC = dyn_cast<AddrSpaceCastInst>(U)) {
76+
for (auto &Use : AC->uses())
77+
WorkList.insert(Use.getUser());
78+
79+
AC->mutateType(AC->getPointerOperand()->getType());
80+
AC->replaceAllUsesWith(AC->getPointerOperand());
81+
return;
82+
}
83+
84+
PointerType *NewType = nullptr;
85+
for (Use &U : U->operands()) {
86+
PointerType *PT = dyn_cast<PointerType>(U.get()->getType());
87+
if (!PT)
88+
continue;
89+
90+
if (NewType == nullptr)
91+
NewType = PT;
92+
else {
93+
// We could imagine a function calls taking 2 pointers to distinct
94+
// address spaces which returns a pointer. But we want to run this
95+
// pass after inlining, so we'll assume this doesn't happen.
96+
assert(NewType->getAddressSpace() == PT->getAddressSpace());
97+
}
98+
}
99+
100+
assert(NewType != nullptr);
101+
U->mutateType(NewType);
102+
}
103+
104+
virtual bool runOnFunction(Function &F) override {
105+
const SPIRVSubtarget &ST = TM->getSubtarget<SPIRVSubtarget>(F);
106+
GR = ST.getSPIRVGlobalRegistry();
107+
108+
DeadUsers.clear();
109+
gatherAddrspaceCast(F);
110+
111+
while (WorkList.size() > 0) {
112+
User *U = *WorkList.begin();
113+
WorkList.erase(U);
114+
propagateAddrspace(U);
115+
}
116+
117+
for (User *U : DeadUsers) {
118+
if (Instruction *I = dyn_cast<Instruction>(U))
119+
I->eraseFromParent();
120+
}
121+
return DeadUsers.size() != 0;
122+
}
123+
124+
private:
125+
SPIRVTargetMachine *TM = nullptr;
126+
SPIRVGlobalRegistry *GR = nullptr;
127+
std::unordered_set<User *> WorkList;
128+
std::unordered_set<User *> DeadUsers;
129+
130+
public:
131+
static char ID;
132+
};
133+
134+
char SPIRVLegalizeAddrspaceCast::ID = 0;
135+
INITIALIZE_PASS(SPIRVLegalizeAddrspaceCast, "spirv-legalize-addrspacecast",
136+
"SPIRV legalize addrspacecast", false, false)
137+
138+
FunctionPass *
139+
llvm::createSPIRVLegalizeAddrspaceCastPass(SPIRVTargetMachine *TM) {
140+
return new SPIRVLegalizeAddrspaceCast(TM);
141+
}

llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,8 @@ void SPIRVPassConfig::addIRPasses() {
190190
TargetPassConfig::addIRPasses();
191191

192192
if (TM.getSubtargetImpl()->isVulkanEnv()) {
193+
addPass(createSPIRVLegalizeAddrspaceCastPass(&getTM<SPIRVTargetMachine>()));
194+
193195
// 1. Simplify loop for subsequent transformations. After this steps, loops
194196
// have the following properties:
195197
// - loops have a single entry edge (pre-header to loop header).
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
; RUN: llc -verify-machineinstrs -O3 -mtriple=spirv-unknown-vulkan1.3-compute %s -o - | FileCheck %s --match-full-lines
2+
; RUN: %if spirv-tools %{ llc -O3 -mtriple=spirv-unknown-vulkan1.3-compute %s -o - -filetype=obj | spirv-val %}
3+
4+
; FIXME(134119): enable-this once Offset decoration are added.
5+
; XFAIL: spirv-tools
6+
7+
%S2 = type { { [10 x { i32, i32 } ] }, i32 }
8+
9+
; CHECK-DAG: %[[#uint:]] = OpTypeInt 32 0
10+
; CHECK-DAG: %[[#uint_0:]] = OpConstant %[[#uint]] 0
11+
; CHECK-DAG: %[[#uint_1:]] = OpConstant %[[#uint]] 1
12+
; CHECK-DAG: %[[#uint_3:]] = OpConstant %[[#uint]] 3
13+
; CHECK-DAG: %[[#uint_10:]] = OpConstant %[[#uint]] 10
14+
; CHECK-DAG: %[[#uint_11:]] = OpConstant %[[#uint]] 11
15+
; CHECK-DAG: %[[#ptr_StorageBuffer_uint:]] = OpTypePointer StorageBuffer %[[#uint]]
16+
17+
; CHECK-DAG: %[[#t_s2_s_a_s:]] = OpTypeStruct %[[#uint]] %[[#uint]]
18+
; CHECK-DAG: %[[#t_s2_s_a:]] = OpTypeArray %[[#t_s2_s_a_s]] %[[#uint_10]]
19+
; CHECK-DAG: %[[#t_s2_s:]] = OpTypeStruct %[[#t_s2_s_a]]
20+
; CHECK-DAG: %[[#t_s2:]] = OpTypeStruct %[[#t_s2_s]] %[[#uint]]
21+
22+
; CHECK-DAG: %[[#ptr_StorageBuffer_struct:]] = OpTypePointer StorageBuffer %[[#t_s2]]
23+
; CHECK-DAG: %[[#rarr:]] = OpTypeRuntimeArray %[[#t_s2]]
24+
; CHECK-DAG: %[[#rarr_struct:]] = OpTypeStruct %[[#rarr]]
25+
; CHECK-DAG: %[[#spirv_VulkanBuffer:]] = OpTypePointer StorageBuffer %[[#rarr_struct]]
26+
27+
declare target("spirv.VulkanBuffer", [0 x %S2], 12, 1) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0s_Ss_12_1t(i32, i32, i32, i32, i1)
28+
29+
define void @main() "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" {
30+
entry:
31+
%handle = tail call target("spirv.VulkanBuffer", [0 x %S2], 12, 1) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0s_Ss_12_1t(i32 0, i32 0, i32 1, i32 0, i1 false)
32+
; CHECK: %[[#resource:]] = OpVariable %[[#spirv_VulkanBuffer]] StorageBuffer
33+
34+
%ptr = tail call noundef align 4 dereferenceable(4) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0s_Ss_12_1t(target("spirv.VulkanBuffer", [0 x %S2], 12, 1) %handle, i32 0)
35+
; CHECK: %[[#a:]] = OpCopyObject %[[#spirv_VulkanBuffer]] %[[#resource]]
36+
; CHECK: %[[#b:]] = OpAccessChain %[[#ptr_StorageBuffer_struct]] %[[#a:]] %[[#uint_0]] %[[#uint_0]]
37+
%casted = addrspacecast ptr addrspace(11) %ptr to ptr
38+
39+
; CHECK: %[[#ptr2:]] = OpInBoundsAccessChain %[[#ptr_StorageBuffer_uint]] %[[#b:]] %[[#uint_0]] %[[#uint_0]] %[[#uint_3]] %[[#uint_1]]
40+
%ptr2 = getelementptr inbounds %S2, ptr %casted, i64 0, i32 0, i32 0, i32 3, i32 1
41+
42+
; CHECK: OpStore %[[#ptr2]] %[[#uint_10]] Aligned 4
43+
store i32 10, ptr %ptr2, align 4
44+
45+
; Another store, but this time using LLVM's ability to load the first element
46+
; without an explicit GEP. The backend has to determine the ptr type and
47+
; generate the appropriate access chain.
48+
; CHECK: %[[#ptr3:]] = OpInBoundsAccessChain %[[#ptr_StorageBuffer_uint]] %[[#b:]] %[[#uint_0]] %[[#uint_0]] %[[#uint_0]] %[[#uint_0]]
49+
; CHECK: OpStore %[[#ptr3]] %[[#uint_11]] Aligned 4
50+
store i32 11, ptr %casted, align 4
51+
ret void
52+
}
53+
54+
declare ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0s_S2s_12_1t(target("spirv.VulkanBuffer", [0 x %S2], 12, 1), i32)
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
; RUN: llc -verify-machineinstrs -O3 -mtriple=spirv-unknown-vulkan1.3-compute %s -o - | FileCheck %s
2+
; RUN: %if spirv-tools %{ llc -O3 -mtriple=spirv-unknown-vulkan1.3-compute %s -o - -filetype=obj | spirv-val %}
3+
4+
; FIXME(134119): enable-this once Offset decoration are added.
5+
; XFAIL: spirv-tools
6+
7+
%struct.S = type { i32 }
8+
9+
; CHECK-DAG: %[[#uint:]] = OpTypeInt 32 0
10+
; CHECK-DAG: %[[#uint_0:]] = OpConstant %[[#uint]] 0
11+
; CHECK-DAG: %[[#uint_10:]] = OpConstant %[[#uint]] 10
12+
; CHECK-DAG: %[[#ptr_StorageBuffer_uint:]] = OpTypePointer StorageBuffer %[[#uint]]
13+
; CHECK-DAG: %[[#struct:]] = OpTypeStruct %[[#uint]]
14+
; CHECK-DAG: %[[#ptr_StorageBuffer_struct:]] = OpTypePointer StorageBuffer %[[#struct]]
15+
; CHECK-DAG: %[[#rarr:]] = OpTypeRuntimeArray %[[#struct]]
16+
; CHECK-DAG: %[[#rarr_struct:]] = OpTypeStruct %[[#rarr]]
17+
; CHECK-DAG: %[[#spirv_VulkanBuffer:]] = OpTypePointer StorageBuffer %[[#rarr_struct]]
18+
19+
declare target("spirv.VulkanBuffer", [0 x %struct.S], 12, 1) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0s_struct.Ss_12_1t(i32, i32, i32, i32, i1)
20+
21+
define void @main() "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" {
22+
entry:
23+
%handle = tail call target("spirv.VulkanBuffer", [0 x %struct.S], 12, 1) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0s_struct.Ss_12_1t(i32 0, i32 0, i32 1, i32 0, i1 false)
24+
; CHECK: %[[#resource:]] = OpVariable %[[#spirv_VulkanBuffer]] StorageBuffer
25+
26+
%ptr = tail call noundef align 4 dereferenceable(4) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0s_struct.Ss_12_1t(target("spirv.VulkanBuffer", [0 x %struct.S], 12, 1) %handle, i32 0)
27+
; CHECK: %[[#a:]] = OpCopyObject %[[#spirv_VulkanBuffer]] %[[#resource]]
28+
; CHECK: %[[#b:]] = OpAccessChain %[[#ptr_StorageBuffer_struct]] %[[#a:]] %[[#uint_0]] %[[#uint_0]]
29+
; CHECK: %[[#c:]] = OpInBoundsAccessChain %[[#ptr_StorageBuffer_uint]] %[[#b:]] %[[#uint_0]]
30+
%casted = addrspacecast ptr addrspace(11) %ptr to ptr
31+
32+
; CHECK: OpStore %[[#c]] %[[#uint_10]] Aligned 4
33+
store i32 10, ptr %casted, align 4
34+
ret void
35+
}
36+
37+
declare ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0s_struct.Ss_12_1t(target("spirv.VulkanBuffer", [0 x %struct.S], 12, 1), i32)

0 commit comments

Comments
 (0)