Skip to content

[mlir][SPIRV] Fix build (2) #111265

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 5, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 16 additions & 15 deletions mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ static Value optionallyTruncateOrExtend(Location loc, Value value,

/// Broadcasts the value to vector with `numElements` number of elements.
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements,
const LLVMTypeConverter &typeConverter,
const TypeConverter &typeConverter,
ConversionPatternRewriter &rewriter) {
auto vectorType = VectorType::get(numElements, toBroadcast.getType());
auto llvmVectorType = typeConverter.convertType(vectorType);
Expand All @@ -166,7 +166,7 @@ static Value broadcast(Location loc, Value toBroadcast, unsigned numElements,

/// Broadcasts the value. If `srcType` is a scalar, the value remains unchanged.
static Value optionallyBroadcast(Location loc, Value value, Type srcType,
const LLVMTypeConverter &typeConverter,
const TypeConverter &typeConverter,
ConversionPatternRewriter &rewriter) {
if (auto vectorType = dyn_cast<VectorType>(srcType)) {
unsigned numElements = vectorType.getNumElements();
Expand All @@ -186,8 +186,7 @@ static Value optionallyBroadcast(Location loc, Value value, Type srcType,
/// Then cast `Offset` and `Count` if their bit width is different
/// from `Base` bit width.
static Value processCountOrOffset(Location loc, Value value, Type srcType,
Type dstType,
const LLVMTypeConverter &converter,
Type dstType, const TypeConverter &converter,
ConversionPatternRewriter &rewriter) {
Value broadcasted =
optionallyBroadcast(loc, value, srcType, converter, rewriter);
Expand All @@ -197,7 +196,7 @@ static Value processCountOrOffset(Location loc, Value value, Type srcType,
/// Converts SPIR-V struct with a regular (according to `VulkanLayoutUtils`)
/// offset to LLVM struct. Otherwise, the conversion is not supported.
static Type convertStructTypeWithOffset(spirv::StructType type,
const LLVMTypeConverter &converter) {
const TypeConverter &converter) {
if (type != VulkanLayoutUtils::decorateType(type))
return nullptr;

Expand All @@ -210,7 +209,7 @@ static Type convertStructTypeWithOffset(spirv::StructType type,

/// Converts SPIR-V struct with no offset to packed LLVM struct.
static Type convertStructTypePacked(spirv::StructType type,
const LLVMTypeConverter &converter) {
const TypeConverter &converter) {
SmallVector<Type> elementsVector;
if (failed(converter.convertTypes(type.getElementTypes(), elementsVector)))
return nullptr;
Expand All @@ -227,10 +226,11 @@ static Value createI32ConstantOf(Location loc, PatternRewriter &rewriter,
}

/// Utility for `spirv.Load` and `spirv.Store` conversion.
static LogicalResult replaceWithLoadOrStore(
Operation *op, ValueRange operands, ConversionPatternRewriter &rewriter,
const LLVMTypeConverter &typeConverter, unsigned alignment, bool isVolatile,
bool isNonTemporal) {
static LogicalResult replaceWithLoadOrStore(Operation *op, ValueRange operands,
ConversionPatternRewriter &rewriter,
const TypeConverter &typeConverter,
unsigned alignment, bool isVolatile,
bool isNonTemporal) {
if (auto loadOp = dyn_cast<spirv::LoadOp>(op)) {
auto dstType = typeConverter.convertType(loadOp.getType());
if (!dstType)
Expand Down Expand Up @@ -271,7 +271,7 @@ static std::optional<Type> convertArrayType(spirv::ArrayType type,
/// Converts SPIR-V pointer type to LLVM pointer. Pointer's storage class is not
/// modelled at the moment.
static Type convertPointerType(spirv::PointerType type,
const LLVMTypeConverter &converter,
const TypeConverter &converter,
spirv::ClientAPI clientAPI) {
unsigned addressSpace =
storageClassToAddressSpace(clientAPI, type.getStorageClass());
Expand All @@ -292,7 +292,7 @@ static std::optional<Type> convertRuntimeArrayType(spirv::RuntimeArrayType type,
/// Converts SPIR-V struct to LLVM struct. There is no support of structs with
/// member decorations. Also, only natural offset is supported.
static Type convertStructType(spirv::StructType type,
const LLVMTypeConverter &converter) {
const TypeConverter &converter) {
SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations;
type.getMemberDecorations(memberDecorations);
if (!memberDecorations.empty())
Expand Down Expand Up @@ -1378,9 +1378,10 @@ class FuncConversionPattern : public SPIRVToLLVMConversion<spirv::FuncOp> {
auto funcType = funcOp.getFunctionType();
TypeConverter::SignatureConversion signatureConverter(
funcType.getNumInputs());
auto llvmType = getTypeConverter()->convertFunctionSignature(
funcType, /*isVariadic=*/false, /*useBarePtrCallConv=*/false,
signatureConverter);
auto llvmType = static_cast<const LLVMTypeConverter *>(getTypeConverter())
->convertFunctionSignature(
funcType, /*isVariadic=*/false,
/*useBarePtrCallConv=*/false, signatureConverter);
if (!llvmType)
return failure();

Expand Down
Loading