diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp index 28a1690ef0be1..a692c24363310 100644 --- a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp @@ -73,16 +73,23 @@ class SPIRVLegalizePointerCast : public FunctionPass { // Returns the loaded value. Value *loadVectorFromVector(IRBuilder<> &B, FixedVectorType *SourceType, FixedVectorType *TargetType, Value *Source) { - assert(TargetType->getNumElements() <= SourceType->getNumElements()); LoadInst *NewLoad = B.CreateLoad(SourceType, Source); buildAssignType(B, SourceType, NewLoad); Value *AssignValue = NewLoad; if (TargetType->getElementType() != SourceType->getElementType()) { + const DataLayout &DL = B.GetInsertBlock()->getModule()->getDataLayout(); + [[maybe_unused]] TypeSize TargetTypeSize = + DL.getTypeSizeInBits(TargetType); + [[maybe_unused]] TypeSize SourceTypeSize = + DL.getTypeSizeInBits(SourceType); + assert(TargetTypeSize == SourceTypeSize); AssignValue = B.CreateIntrinsic(Intrinsic::spv_bitcast, {TargetType, SourceType}, {NewLoad}); buildAssignType(B, TargetType, AssignValue); + return AssignValue; } + assert(TargetType->getNumElements() < SourceType->getNumElements()); SmallVector Mask(/* Size= */ TargetType->getNumElements()); for (unsigned I = 0; I < TargetType->getNumElements(); ++I) Mask[I] = I; diff --git a/llvm/test/CodeGen/SPIRV/hlsl-resources/issue-146942-ptr-cast.ll b/llvm/test/CodeGen/SPIRV/hlsl-resources/issue-146942-ptr-cast.ll index ed67344842b11..4817e7450ac2e 100644 --- a/llvm/test/CodeGen/SPIRV/hlsl-resources/issue-146942-ptr-cast.ll +++ b/llvm/test/CodeGen/SPIRV/hlsl-resources/issue-146942-ptr-cast.ll @@ -16,7 +16,6 @@ define void @case1() local_unnamed_addr { ; CHECK: %[[#BUFFER_LOAD:]] = OpLoad %[[#FLOAT4]] %{{[0-9]+}} Aligned 16 ; CHECK: %[[#CAST_LOAD:]] = OpBitcast %[[#INT4]] %[[#BUFFER_LOAD]] - ; CHECK: %[[#VEC_SHUFFLE:]] = OpVectorShuffle %[[#INT4]] %[[#CAST_LOAD]] %[[#CAST_LOAD]] 0 1 2 3 %1 = tail call target("spirv.VulkanBuffer", [0 x <4 x float>], 12, 0) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0v4f32_12_0t(i32 0, i32 2, i32 1, i32 0, ptr nonnull @.str) %2 = tail call target("spirv.VulkanBuffer", [0 x <4 x i32>], 12, 1) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0v4i32_12_1t(i32 0, i32 5, i32 1, i32 0, ptr nonnull @.str.2) %3 = tail call noundef align 16 dereferenceable(16) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0v4f32_12_0t(target("spirv.VulkanBuffer", [0 x <4 x float>], 12, 0) %1, i32 0) @@ -29,8 +28,7 @@ define void @case1() local_unnamed_addr { define void @case2() local_unnamed_addr { ; CHECK: %[[#BUFFER_LOAD:]] = OpLoad %[[#FLOAT4]] %{{[0-9]+}} Aligned 16 ; CHECK: %[[#CAST_LOAD:]] = OpBitcast %[[#INT4]] %[[#BUFFER_LOAD]] - ; CHECK: %[[#VEC_SHUFFLE:]] = OpVectorShuffle %[[#INT4]] %[[#CAST_LOAD]] %[[#CAST_LOAD]] 0 1 2 3 - ; CHECK: %[[#VEC_TRUNCATE:]] = OpVectorShuffle %[[#INT3]] %[[#VEC_SHUFFLE]] %[[#UNDEF_INT4]] 0 1 2 + ; CHECK: %[[#VEC_TRUNCATE:]] = OpVectorShuffle %[[#INT3]] %[[#CAST_LOAD]] %[[#UNDEF_INT4]] 0 1 2 %1 = tail call target("spirv.VulkanBuffer", [0 x <4 x float>], 12, 0) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0v4f32_12_0t(i32 0, i32 2, i32 1, i32 0, ptr nonnull @.str) %2 = tail call target("spirv.VulkanBuffer", [0 x <3 x i32>], 12, 1) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0v3i32_12_1t(i32 0, i32 5, i32 1, i32 0, ptr nonnull @.str.3) %3 = tail call noundef align 16 dereferenceable(16) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0v4f32_12_0t(target("spirv.VulkanBuffer", [0 x <4 x float>], 12, 0) %1, i32 0) diff --git a/llvm/test/CodeGen/SPIRV/pointers/ptrcast-bitcast.ll b/llvm/test/CodeGen/SPIRV/pointers/ptrcast-bitcast.ll index 84913283f6868..a1ec2cd1cfdd2 100644 --- a/llvm/test/CodeGen/SPIRV/pointers/ptrcast-bitcast.ll +++ b/llvm/test/CodeGen/SPIRV/pointers/ptrcast-bitcast.ll @@ -26,3 +26,25 @@ entry: store <4 x i32> %6, ptr addrspace(11) %7, align 16 ret void } + +; This tests a load from a pointer that has been bitcast between vector types +; which share the same total bit-width but have different numbers of elements. +; Tests that legalize-pointer-casts works correctly by moving the bitcast to +; the element that was loaded. + +define void @main2() local_unnamed_addr #0 { +entry: +; CHECK: %[[LOAD:[0-9]+]] = OpLoad %[[#v2_double]] {{.*}} +; CHECK: %[[BITCAST1:[0-9]+]] = OpBitcast %[[#v4_uint]] %[[LOAD]] +; CHECK: %[[BITCAST2:[0-9]+]] = OpBitcast %[[#v2_double]] %[[BITCAST1]] +; CHECK: OpStore {{%[0-9]+}} %[[BITCAST2]] {{.*}} + + %0 = tail call target("spirv.VulkanBuffer", [0 x <2 x double>], 12, 1) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0v2f64_12_1t(i32 0, i32 2, i32 1, i32 0, ptr nonnull @.str.2) + %2 = tail call noundef align 16 dereferenceable(16) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0v2f64_12_1t(target("spirv.VulkanBuffer", [0 x <2 x double>], 12, 1) %0, i32 0) + %3 = load <4 x i32>, ptr addrspace(11) %2 + %4 = tail call noundef align 16 dereferenceable(16) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0v2f64_12_1t(target("spirv.VulkanBuffer", [0 x <2 x double>], 12, 1) %0, i32 1) + store <4 x i32> %3, ptr addrspace(11) %4 + ret void +} + +attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }