@s-perron (Contributor) commented Oct 24, 2025

The previous check for vector bitcasts in `loadVectorFromVector` only
asserted that the target vector has no more elements than the source
vector, which is insufficient when the element types differ: the element
count says nothing about whether the two vectors have the same total
size, so it can lead to incorrect assumptions about the validity of the
cast.

This commit replaces the element count check on the bitcast path with a
comparison of the total size of the vectors in bits. This ensures that
the bitcast is only performed between vectors of the same size,
preventing potential miscompilations.

Part of #153091
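
As a rough illustration of why the element count alone is not enough, here is a minimal standalone sketch. `VecType`, `oldCheck`, and `newCheck` are hypothetical stand-ins invented for this example and are not LLVM APIs; the actual change uses `FixedVectorType` and `DataLayout::getTypeSizeInBits`, as shown in the diff.

```cpp
#include <cassert>

// Hypothetical stand-in for a fixed vector type: an element count plus an
// element width in bits. This is an assumption for illustration only.
struct VecType {
  unsigned NumElements;
  unsigned ElementBits;
  constexpr unsigned sizeInBits() const { return NumElements * ElementBits; }
};

// Models the old precondition: only the element counts are compared.
constexpr bool oldCheck(VecType Source, VecType Target) {
  return Target.NumElements <= Source.NumElements;
}

// Models the new precondition on the bitcast path: total sizes must match.
constexpr bool newCheck(VecType Source, VecType Target) {
  return Target.sizeInBits() == Source.sizeInBits();
}

constexpr VecType Double2{2, 64}; // <2 x double>, 128 bits
constexpr VecType Int4{4, 32};    // <4 x i32>, 128 bits
constexpr VecType Int2{2, 32};    // <2 x i32>, 64 bits

// A same-size cast with more target elements trips the count-based check,
// while the size-based check accepts it.
static_assert(!oldCheck(Double2, Int4), "count check rejects a valid cast");
static_assert(newCheck(Double2, Int4), "size check accepts a valid cast");

// A size-changing cast with equal element counts slips past the count-based
// check, while the size-based check rejects it.
static_assert(oldCheck(Double2, Int2), "count check accepts a bad cast");
static_assert(!newCheck(Double2, Int2), "size check rejects a bad cast");
```

The `<2 x double>` to `<4 x i32>` pair is the same combination exercised by the new `main2` test in `ptrcast-bitcast.ll`.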

@llvmbot (Member) commented Oct 24, 2025

@llvm/pr-subscribers-backend-spir-v

Author: Steven Perron (s-perron)

Full diff: https://github.com/llvm/llvm-project/pull/164997.diff

3 Files Affected:

  • (modified) llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp (+8-1)
  • (modified) llvm/test/CodeGen/SPIRV/hlsl-resources/issue-146942-ptr-cast.ll (+1-3)
  • (modified) llvm/test/CodeGen/SPIRV/pointers/ptrcast-bitcast.ll (+22)
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
index 28a1690ef0be1..a692c24363310 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
@@ -73,16 +73,23 @@ class SPIRVLegalizePointerCast : public FunctionPass {
   // Returns the loaded value.
   Value *loadVectorFromVector(IRBuilder<> &B, FixedVectorType *SourceType,
                               FixedVectorType *TargetType, Value *Source) {
-    assert(TargetType->getNumElements() <= SourceType->getNumElements());
     LoadInst *NewLoad = B.CreateLoad(SourceType, Source);
     buildAssignType(B, SourceType, NewLoad);
     Value *AssignValue = NewLoad;
     if (TargetType->getElementType() != SourceType->getElementType()) {
+      const DataLayout &DL = B.GetInsertBlock()->getModule()->getDataLayout();
+      [[maybe_unused]] TypeSize TargetTypeSize =
+          DL.getTypeSizeInBits(TargetType);
+      [[maybe_unused]] TypeSize SourceTypeSize =
+          DL.getTypeSizeInBits(SourceType);
+      assert(TargetTypeSize == SourceTypeSize);
       AssignValue = B.CreateIntrinsic(Intrinsic::spv_bitcast,
                                       {TargetType, SourceType}, {NewLoad});
       buildAssignType(B, TargetType, AssignValue);
+      return AssignValue;
     }
 
+    assert(TargetType->getNumElements() < SourceType->getNumElements());
     SmallVector<int> Mask(/* Size= */ TargetType->getNumElements());
     for (unsigned I = 0; I < TargetType->getNumElements(); ++I)
       Mask[I] = I;
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-resources/issue-146942-ptr-cast.ll b/llvm/test/CodeGen/SPIRV/hlsl-resources/issue-146942-ptr-cast.ll
index ed67344842b11..4817e7450ac2e 100644
--- a/llvm/test/CodeGen/SPIRV/hlsl-resources/issue-146942-ptr-cast.ll
+++ b/llvm/test/CodeGen/SPIRV/hlsl-resources/issue-146942-ptr-cast.ll
@@ -16,7 +16,6 @@
 define void @case1() local_unnamed_addr {
   ; CHECK: %[[#BUFFER_LOAD:]] = OpLoad %[[#FLOAT4]] %{{[0-9]+}} Aligned 16
   ; CHECK: %[[#CAST_LOAD:]] = OpBitcast %[[#INT4]] %[[#BUFFER_LOAD]]
-  ; CHECK: %[[#VEC_SHUFFLE:]] = OpVectorShuffle %[[#INT4]] %[[#CAST_LOAD]] %[[#CAST_LOAD]] 0 1 2 3
   %1 = tail call target("spirv.VulkanBuffer", [0 x <4 x float>], 12, 0) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0v4f32_12_0t(i32 0, i32 2, i32 1, i32 0, ptr nonnull @.str)
   %2 = tail call target("spirv.VulkanBuffer", [0 x <4 x i32>], 12, 1) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0v4i32_12_1t(i32 0, i32 5, i32 1, i32 0, ptr nonnull @.str.2)
   %3 = tail call noundef align 16 dereferenceable(16) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0v4f32_12_0t(target("spirv.VulkanBuffer", [0 x <4 x float>], 12, 0) %1, i32 0)
@@ -29,8 +28,7 @@ define void @case1() local_unnamed_addr {
 define void @case2() local_unnamed_addr {
   ; CHECK: %[[#BUFFER_LOAD:]] = OpLoad %[[#FLOAT4]] %{{[0-9]+}} Aligned 16
   ; CHECK: %[[#CAST_LOAD:]] = OpBitcast %[[#INT4]] %[[#BUFFER_LOAD]]
-  ; CHECK: %[[#VEC_SHUFFLE:]] = OpVectorShuffle %[[#INT4]] %[[#CAST_LOAD]] %[[#CAST_LOAD]] 0 1 2 3
-  ; CHECK: %[[#VEC_TRUNCATE:]] = OpVectorShuffle %[[#INT3]] %[[#VEC_SHUFFLE]] %[[#UNDEF_INT4]] 0 1 2
+  ; CHECK: %[[#VEC_TRUNCATE:]] = OpVectorShuffle %[[#INT3]] %[[#CAST_LOAD]] %[[#UNDEF_INT4]] 0 1 2
   %1 = tail call target("spirv.VulkanBuffer", [0 x <4 x float>], 12, 0) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0v4f32_12_0t(i32 0, i32 2, i32 1, i32 0, ptr nonnull @.str)
   %2 = tail call target("spirv.VulkanBuffer", [0 x <3 x i32>], 12, 1) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0v3i32_12_1t(i32 0, i32 5, i32 1, i32 0, ptr nonnull @.str.3)
   %3 = tail call noundef align 16 dereferenceable(16) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0v4f32_12_0t(target("spirv.VulkanBuffer", [0 x <4 x float>], 12, 0) %1, i32 0)
diff --git a/llvm/test/CodeGen/SPIRV/pointers/ptrcast-bitcast.ll b/llvm/test/CodeGen/SPIRV/pointers/ptrcast-bitcast.ll
index 84913283f6868..a1ec2cd1cfdd2 100644
--- a/llvm/test/CodeGen/SPIRV/pointers/ptrcast-bitcast.ll
+++ b/llvm/test/CodeGen/SPIRV/pointers/ptrcast-bitcast.ll
@@ -26,3 +26,25 @@ entry:
   store <4 x i32> %6, ptr addrspace(11) %7, align 16
   ret void
 }
+
+; This tests a load from a pointer that has been bitcast between vector types
+; which share the same total bit-width but have different numbers of elements.
+; Tests that legalize-pointer-casts works correctly by moving the bitcast to
+; the element that was loaded.
+
+define void @main2() local_unnamed_addr #0 {
+entry:
+; CHECK:  %[[LOAD:[0-9]+]] = OpLoad %[[#v2_double]] {{.*}}
+; CHECK:  %[[BITCAST1:[0-9]+]] = OpBitcast %[[#v4_uint]] %[[LOAD]]
+; CHECK:  %[[BITCAST2:[0-9]+]] = OpBitcast %[[#v2_double]] %[[BITCAST1]]
+; CHECK: OpStore {{%[0-9]+}} %[[BITCAST2]] {{.*}}
+
+  %0 = tail call target("spirv.VulkanBuffer", [0 x <2 x double>], 12, 1) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0v2f64_12_1t(i32 0, i32 2, i32 1, i32 0, ptr nonnull @.str.2)
+  %2 = tail call noundef align 16 dereferenceable(16) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0v2f64_12_1t(target("spirv.VulkanBuffer", [0 x <2 x double>], 12, 1) %0, i32 0)
+  %3 = load <4 x i32>, ptr addrspace(11) %2
+  %4 = tail call noundef align 16 dereferenceable(16) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0v2f64_12_1t(target("spirv.VulkanBuffer", [0 x <2 x double>], 12, 1) %0, i32 1)
+  store <4 x i32> %3, ptr addrspace(11) %4
+  ret void
+}
+
+attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
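
The round trip checked by the new `main2` test (load as `<2 x double>`, view as `<4 x i32>`, cast back before the store) can be mimicked on the host with a byte-wise reinterpretation. This is only a sketch under stated assumptions: `bitcastVector`, `roundTrip`, `V2F64`, and `V4U32` are hypothetical names for this example, and `std::memcpy` stands in for `OpBitcast`; it is not how the SPIR-V backend implements the cast.

```cpp
#include <array>
#include <cassert>
#include <cstdint>
#include <cstring>

// Reinterprets the bytes of one fixed-size "vector" as another. The
// static_assert plays the role of the new size check: the cast is only
// meaningful when both sides occupy the same number of bytes.
template <typename To, typename From>
To bitcastVector(const From &Source) {
  static_assert(sizeof(To) == sizeof(From),
                "bitcast requires equal total size");
  To Result;
  std::memcpy(&Result, &Source, sizeof(To));
  return Result;
}

using V2F64 = std::array<double, 2>;        // models <2 x double>
using V4U32 = std::array<std::uint32_t, 4>; // models <4 x i32>

// Loads as two doubles, views the same bits as four 32-bit integers, then
// casts back, which should reproduce the original value bit for bit.
inline V2F64 roundTrip(const V2F64 &Loaded) {
  V4U32 AsInts = bitcastVector<V4U32>(Loaded);
  return bitcastVector<V2F64>(AsInts);
}
```

Requesting a target of a different size, for example `bitcastVector<std::array<std::uint32_t, 2>>(Loaded)`, fails to compile, which is the compile-time analogue of the assertion added in `loadVectorFromVector`.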

@s-perron merged commit 9f72fab into llvm:main Oct 31, 2025
13 checks passed
@s-perron deleted the legalize_ptr_cast branch October 31, 2025 14:12
ckoparkar added a commit to ckoparkar/llvm-project that referenced this pull request Oct 31, 2025
* main:
  [SPIRV] Fix vector bitcast check in LegalizePointerCast (llvm#164997)
  [lldb][docs] Add troubleshooting section to scripting introduction
  [Sema] Fix parameter index checks on explicit object member functions (llvm#165586)
  To fix polymorphic pointer assignment in FORALL when LHS is unlimited polymorphic and RHS is intrinsic type target (llvm#164999)
  [CostModel][AArch64] Model cost of extract.last.active intrinsic (clastb) (llvm#165739)
  [MemProf] Select largest of matching contexts from profile (llvm#165338)
  [lldb][TypeSystem] Better support for _BitInt types (llvm#165689)
  [NVPTX] Move TMA G2S lowering to Tablegen (llvm#165710)
  [MLIR][NVVM] Extend NVVM mma ops to support fp64 (llvm#165380)
  [UTC] Support to test annotated IR (llvm#165419)
DEBADRIBASAK pushed a commit to DEBADRIBASAK/llvm-project that referenced this pull request Nov 3, 2025