@@ -56,6 +56,8 @@ class CommonSPIRTargetCodeGenInfo : public TargetCodeGenInfo {
56
56
llvm::Type *
57
57
getHLSLType (CodeGenModule &CGM, const Type *Ty,
58
58
const SmallVector<int32_t > *Packoffsets = nullptr ) const override ;
59
+ llvm::Type *getSampleType (clang::CodeGen::CodeGenModule &CGM,
60
+ clang::QualType Ty, llvm::LLVMContext &Ctx) const ;
59
61
llvm::Type *getSPIRVImageTypeFromHLSLResource (
60
62
const HLSLAttributedResourceType::Attributes &attributes,
61
63
llvm::Type *ElementType, llvm::LLVMContext &Ctx) const ;
@@ -483,12 +485,13 @@ llvm::Type *CommonSPIRTargetCodeGenInfo::getHLSLType(
483
485
assert (!ResAttrs.IsROV &&
484
486
" Rasterizer order views not implemented for SPIR-V yet" );
485
487
486
- llvm::Type *ElemType = CGM.getTypes ().ConvertType (ContainedTy);
487
488
if (!ResAttrs.RawBuffer ) {
488
489
// convert element type
490
+ llvm::Type *ElemType = getSampleType (CGM, ContainedTy, Ctx);
489
491
return getSPIRVImageTypeFromHLSLResource (ResAttrs, ElemType, Ctx);
490
492
}
491
493
494
+ llvm::Type *ElemType = CGM.getTypes ().ConvertType (ContainedTy);
492
495
llvm::ArrayType *RuntimeArrayType = llvm::ArrayType::get (ElemType, 0 );
493
496
uint32_t StorageClass = /* StorageBuffer storage class */ 12 ;
494
497
bool IsWritable = ResAttrs.ResourceClass == llvm::dxil::ResourceClass::UAV;
@@ -515,16 +518,41 @@ llvm::Type *CommonSPIRTargetCodeGenInfo::getHLSLType(
515
518
return nullptr ;
516
519
}
517
520
521
+ llvm::Type *CommonSPIRTargetCodeGenInfo::getSampleType (
522
+ clang::CodeGen::CodeGenModule &CGM,
523
+ clang::QualType Ty, llvm::LLVMContext &Ctx) const {
524
+ Ty = Ty->getCanonicalTypeUnqualified ();
525
+ if (const VectorType *V = dyn_cast<VectorType>(Ty))
526
+ Ty = V->getElementType ();
527
+ assert (!Ty->isVectorType () && " We still have a vector type." );
528
+
529
+ if (!Ty->isSignedIntegerType ())
530
+ return CGM.getTypes ().ConvertType (Ty);
531
+
532
+ // We need to maintain signed integers as sign integers for the sampled type.
533
+ // See https://docs.vulkan.org/spec/latest/appendices/spirvenv.html#spirvenv-image-signedness.
534
+
535
+ uint32_t OpTypeInt = 21 ;
536
+ uint32_t BitWidth = CGM.getContext ().getIntWidth (Ty);
537
+ uint32_t AlignmentInBytes = CGM.getContext ().getTypeAlign (Ty)/8 ;
538
+ llvm::Type* BitWidthLitType = getInlineSpirvConstant (CGM, nullptr , llvm::APInt (32 , BitWidth));
539
+ llvm::Type* SignLitType = getInlineSpirvConstant (CGM, nullptr , llvm::APInt (32 , 1 ));
540
+
541
+ return llvm::TargetExtType::get (Ctx, " spirv.Type" , {BitWidthLitType, SignLitType},
542
+ {OpTypeInt, BitWidth / 8 , AlignmentInBytes});
543
+ }
544
+
518
545
llvm::Type *CommonSPIRTargetCodeGenInfo::getSPIRVImageTypeFromHLSLResource (
519
546
const HLSLAttributedResourceType::Attributes &attributes,
520
- llvm::Type *ElementType, llvm::LLVMContext &Ctx) const {
521
-
522
- if (ElementType->isVectorTy ())
523
- ElementType = ElementType->getScalarType ();
547
+ llvm::Type *SampledType, llvm::LLVMContext &Ctx) const {
524
548
525
- assert ((ElementType->isIntegerTy () || ElementType->isFloatingPointTy ()) &&
526
- " The element type for a SPIR-V resource must be a scalar integer or "
527
- " floating point type." );
549
+ // assert((SampledType->isIntegerTy() || SampledType->isFloatingPointTy() ||
550
+ // (SampledType->isTargetExtTy() &&
551
+ // SampledType->getTargetExtName() == "spirv.Type" &&
552
+ // cast<llvm::TargetExtType>(SampledType)->getIntParameter(0) ==
553
+ // /* OpTypeInt */ 21)) &&
554
+ // "The element type for a SPIR-V resource must be a scalar integer or "
555
+ // "floating point type.");
528
556
529
557
// These parameters correspond to the operands to the OpTypeImage SPIR-V
530
558
// instruction. See
@@ -553,7 +581,7 @@ llvm::Type *CommonSPIRTargetCodeGenInfo::getSPIRVImageTypeFromHLSLResource(
553
581
// Setting to unknown for now.
554
582
IntParams[5 ] = 0 ;
555
583
556
- return llvm::TargetExtType::get (Ctx, " spirv.Image" , {ElementType }, IntParams);
584
+ return llvm::TargetExtType::get (Ctx, " spirv.Image" , {SampledType }, IntParams);
557
585
}
558
586
559
587
std::unique_ptr<TargetCodeGenInfo>
0 commit comments