Skip to content

Commit 1be9366

Browse files
authored
[LLVM->SPIRV] Cast the GEP base pointer to source type upon mismatch (#3255)
The source element type used in a GEP may differ from the actual type of the pointer operand (e.g., ptr i8 vs. ptr [N x T]). This mismatch can lead to incorrect address computations during translation to SPIR-V of GEP used in constexpr context, which requires that pointer types match the type of the object being accessed. This patch inserts an explicit bitcast to convert the GEP pointer operand to the expected type, derived from the GEP’s source element type, before emitting an PtrAccessChain. This ensures the resulting SPIR-V instruction has a correctly typed base pointer and produces valid indexing behavior. For example: Before this change, the following GEP was translated incorrectly: getelementptr(i8, ptr addrspace(1) @a_var, i64 2) Whereas this nearly equivalent GEP was handled correctly: getelementptr inbounds ([2 x i8], ptr @a_var, i64 0, i64 1) Previously, the first form was incorrectly interpreted as: getelementptr inbounds ([2 x i8], ptr @a_var, i64 0, i64 2)
1 parent 58512e7 commit 1be9366

File tree

2 files changed

+44
-0
lines changed

2 files changed

+44
-0
lines changed

lib/SPIRV/SPIRVWriter.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1474,6 +1474,18 @@ SPIRVValue *LLVMToSPIRVBase::transConstant(Value *V) {
14741474
if (auto *ConstUE = dyn_cast<ConstantExpr>(V)) {
14751475
if (auto *GEP = dyn_cast<GEPOperator>(ConstUE)) {
14761476
auto *TransPointerOperand = transValue(GEP->getPointerOperand(), nullptr);
1477+
// Determine the expected pointer type from the GEP source element type.
1478+
Type *GepSourceElemTy = GEP->getSourceElementType();
1479+
SPIRVType *ExpectedPtrTy =
1480+
transPointerType(GepSourceElemTy, GEP->getPointerAddressSpace());
1481+
1482+
// Ensure the base pointer's type matches the GEP's effective source
1483+
// element type.
1484+
if (TransPointerOperand->getType() != ExpectedPtrTy) {
1485+
TransPointerOperand = BM->addUnaryInst(OpBitcast, ExpectedPtrTy,
1486+
TransPointerOperand, nullptr);
1487+
}
1488+
14771489
std::vector<SPIRVWord> Ops = {TransPointerOperand->getId()};
14781490
for (unsigned I = 0, E = GEP->getNumIndices(); I != E; ++I)
14791491
Ops.push_back(transValue(GEP->getOperand(I + 1), nullptr)->getId());

test/gep-operand-source-mismatch.ll

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
; RUN: llvm-as %s -o %t.bc
2+
; RUN: llvm-spirv %t.bc -o %t.spv
3+
; RUN: spirv-val %t.spv
4+
; RUN: llvm-spirv %t.spv -o %t.spt --to-text
5+
; RUN: llvm-spirv -r %t.spv -o %t.rev.bc
6+
; RUN: llvm-dis %t.rev.bc
7+
; RUN: FileCheck %s --input-file %t.spt -check-prefix=CHECK-SPIRV
8+
; RUN: FileCheck %s --input-file %t.rev.ll -check-prefix=CHECK-LLVM
9+
10+
; Make sure that when the GEP operand type doesn't match the source element type (here operand a_var is [2 x i16], but the source element is i8),
11+
; we cast the operand to the source element pointer type (a_var to i8*).
12+
13+
target datalayout = "e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
14+
target triple = "spir-unknown-unknown"
15+
16+
; CHECK-SPIRV-DAG: Name [[A_VAR:[0-9]+]] "a_var"
17+
; CHECK-SPIRV-DAG: Name [[GLOBAL_PTR:[0-9]+]] "global_ptr"
18+
19+
; CHECK-SPIRV-DAG: TypeArray [[ARRAY_TYPE:[0-9]+]] [[USHORT_TYPE:[0-9]+]] [[CONST_2:[0-9]+]]
20+
; CHECK-SPIRV-DAG: TypePointer [[ARRAY_PTR_TYPE:[0-9]+]] 5 [[ARRAY_TYPE]]
21+
22+
; CHECK-SPIRV-DAG: Variable [[ARRAY_PTR_TYPE]] [[A_VAR]] 5 [[INIT_ID:[0-9]+]]
23+
; CHECK-SPIRV-DAG: SpecConstantOp [[I8PTR:[0-9]+]] [[BITCAST:[0-9]+]] 124 [[A_VAR]]
24+
; CHECK-SPIRV-DAG: SpecConstantOp [[I8PTR]] [[PTRCHAIN:[0-9]+]] 67 [[BITCAST]] [[INDEX_ID:[0-9]+]]
25+
; CHECK-SPIRV-DAG: TypePointer [[PTR_PTR_TYPE:[0-9]+]] 5 [[I8PTR]]
26+
; CHECK-SPIRV-DAG: Variable [[PTR_PTR_TYPE]] [[GLOBAL_PTR]] 5 [[PTRCHAIN]]
27+
28+
; CHECK-LLVM: @global_ptr = addrspace(1) global ptr addrspace(1) getelementptr (i8, ptr addrspace(1) @a_var, i64 2), align 8
29+
; CHECK-LLVM-NOT: @global_ptr = addrspace(1) global ptr addrspace(1) getelementptr ([2 x i16], ptr addrspace(1) @a_var, i64 2), align 8
30+
31+
@a_var = dso_local addrspace(1) global [2 x i16] [i16 4, i16 5], align 2
32+
@global_ptr = dso_local addrspace(1) global ptr addrspace(1) getelementptr (i8, ptr addrspace(1) @a_var, i64 2), align 8

0 commit comments

Comments
 (0)