diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp index e36e3951a31ec..74c169c9a7e76 100644 --- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp @@ -149,7 +149,7 @@ static Value optionallyTruncateOrExtend(Location loc, Value value, /// Broadcasts the value to vector with `numElements` number of elements. static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, - const LLVMTypeConverter &typeConverter, + const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter) { auto vectorType = VectorType::get(numElements, toBroadcast.getType()); auto llvmVectorType = typeConverter.convertType(vectorType); @@ -166,7 +166,7 @@ static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, /// Broadcasts the value. If `srcType` is a scalar, the value remains unchanged. static Value optionallyBroadcast(Location loc, Value value, Type srcType, - const LLVMTypeConverter &typeConverter, + const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter) { if (auto vectorType = dyn_cast(srcType)) { unsigned numElements = vectorType.getNumElements(); @@ -186,8 +186,7 @@ static Value optionallyBroadcast(Location loc, Value value, Type srcType, /// Then cast `Offset` and `Count` if their bit width is different /// from `Base` bit width. static Value processCountOrOffset(Location loc, Value value, Type srcType, - Type dstType, - const LLVMTypeConverter &converter, + Type dstType, const TypeConverter &converter, ConversionPatternRewriter &rewriter) { Value broadcasted = optionallyBroadcast(loc, value, srcType, converter, rewriter); @@ -197,7 +196,7 @@ static Value processCountOrOffset(Location loc, Value value, Type srcType, /// Converts SPIR-V struct with a regular (according to `VulkanLayoutUtils`) /// offset to LLVM struct. Otherwise, the conversion is not supported. static Type convertStructTypeWithOffset(spirv::StructType type, - const LLVMTypeConverter &converter) { + const TypeConverter &converter) { if (type != VulkanLayoutUtils::decorateType(type)) return nullptr; @@ -210,7 +209,7 @@ static Type convertStructTypeWithOffset(spirv::StructType type, /// Converts SPIR-V struct with no offset to packed LLVM struct. static Type convertStructTypePacked(spirv::StructType type, - const LLVMTypeConverter &converter) { + const TypeConverter &converter) { SmallVector elementsVector; if (failed(converter.convertTypes(type.getElementTypes(), elementsVector))) return nullptr; @@ -227,10 +226,11 @@ static Value createI32ConstantOf(Location loc, PatternRewriter &rewriter, } /// Utility for `spirv.Load` and `spirv.Store` conversion. -static LogicalResult replaceWithLoadOrStore( - Operation *op, ValueRange operands, ConversionPatternRewriter &rewriter, - const LLVMTypeConverter &typeConverter, unsigned alignment, bool isVolatile, - bool isNonTemporal) { +static LogicalResult replaceWithLoadOrStore(Operation *op, ValueRange operands, + ConversionPatternRewriter &rewriter, + const TypeConverter &typeConverter, + unsigned alignment, bool isVolatile, + bool isNonTemporal) { if (auto loadOp = dyn_cast(op)) { auto dstType = typeConverter.convertType(loadOp.getType()); if (!dstType) @@ -271,7 +271,7 @@ static std::optional convertArrayType(spirv::ArrayType type, /// Converts SPIR-V pointer type to LLVM pointer. Pointer's storage class is not /// modelled at the moment. static Type convertPointerType(spirv::PointerType type, - const LLVMTypeConverter &converter, + const TypeConverter &converter, spirv::ClientAPI clientAPI) { unsigned addressSpace = storageClassToAddressSpace(clientAPI, type.getStorageClass()); @@ -292,7 +292,7 @@ static std::optional convertRuntimeArrayType(spirv::RuntimeArrayType type, /// Converts SPIR-V struct to LLVM struct. There is no support of structs with /// member decorations. Also, only natural offset is supported. static Type convertStructType(spirv::StructType type, - const LLVMTypeConverter &converter) { + const TypeConverter &converter) { SmallVector memberDecorations; type.getMemberDecorations(memberDecorations); if (!memberDecorations.empty()) @@ -1378,9 +1378,10 @@ class FuncConversionPattern : public SPIRVToLLVMConversion { auto funcType = funcOp.getFunctionType(); TypeConverter::SignatureConversion signatureConverter( funcType.getNumInputs()); - auto llvmType = getTypeConverter()->convertFunctionSignature( - funcType, /*isVariadic=*/false, /*useBarePtrCallConv=*/false, - signatureConverter); + auto llvmType = static_cast(getTypeConverter()) + ->convertFunctionSignature( + funcType, /*isVariadic=*/false, + /*useBarePtrCallConv=*/false, signatureConverter); if (!llvmType) return failure();