diff --git a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h index ed174699314e8..2a4327535c687 100644 --- a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h +++ b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h @@ -239,7 +239,7 @@ class LLVMTypeConverter : public TypeConverter { Type convertMemRefToBarePtr(BaseMemRefType type) const; /// Convert a 1D vector type into an LLVM vector type. - Type convertVectorType(VectorType type) const; + FailureOr convertVectorType(VectorType type) const; /// Options for customizing the llvm lowering. LowerToLLVMOptions options; diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp index a9e7ce9d42848..49e0513e629d9 100644 --- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp +++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp @@ -61,7 +61,12 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx, addConversion([&](MemRefType type) { return convertMemRefType(type); }); addConversion( [&](UnrankedMemRefType type) { return convertUnrankedMemRefType(type); }); - addConversion([&](VectorType type) { return convertVectorType(type); }); + addConversion([&](VectorType type) -> std::optional { + FailureOr llvmType = convertVectorType(type); + if (failed(llvmType)) + return std::nullopt; + return llvmType; + }); // LLVM-compatible types are legal, so add a pass-through conversion. Do this // before the conversions below since conversions are attempted in reverse @@ -490,10 +495,9 @@ Type LLVMTypeConverter::convertMemRefToBarePtr(BaseMemRefType type) const { /// * 1-D `vector` remains as is while, /// * n>1 `vector` convert via an (n-1)-D array type to /// `!llvm.array>>`. -/// As LLVM does not support arrays of scalable vectors, it is assumed that -/// scalable vectors are always 1-D. This condition could be relaxed once the -/// missing functionality is added in LLVM -Type LLVMTypeConverter::convertVectorType(VectorType type) const { +/// Returns failure for n-D scalable vector types as LLVM does not support +/// arrays of scalable vectors. +FailureOr LLVMTypeConverter::convertVectorType(VectorType type) const { auto elementType = convertType(type.getElementType()); if (!elementType) return {}; @@ -503,9 +507,8 @@ Type LLVMTypeConverter::convertVectorType(VectorType type) const { type.getScalableDims().back()); assert(LLVM::isCompatibleVectorType(vectorType) && "expected vector type compatible with the LLVM dialect"); - assert( - (!type.isScalable() || (type.getRank() == 1)) && - "expected 1-D scalable vector (n-D scalable vectors are not supported)"); + if (type.isScalable() && (type.getRank() > 1)) + return failure(); auto shape = type.getShape(); for (int i = shape.size() - 2; i >= 0; --i) vectorType = LLVM::LLVMArrayType::get(vectorType, shape[i]);