Skip to content

Commit 38eb55a

Browse files
authored
[mlir][llvm] Return failure from type converter for n-D scalable vectors (#65450)
This patch changes vector type conversion to return failure on n-D scalable vector types instead of asserting. This is an alternative approach to #65261 that aims to enable lowering of Vector ops directly to ArmSME intrinsics where possible, and seems more consistent with other type conversions. It's trivial to hit the assert at the moment and it could be interpreted as n-D scalable vector types being a bug, when they're valid types in the Vector dialect. By returning failure it will generally fail more gracefully, particularly for release builds or other builds where assertions are disabled.
1 parent 13e5faf commit 38eb55a

File tree

2 files changed

+12
-9
lines changed

2 files changed

+12
-9
lines changed

mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,7 @@ class LLVMTypeConverter : public TypeConverter {
239239
Type convertMemRefToBarePtr(BaseMemRefType type) const;
240240

241241
/// Convert a 1D vector type into an LLVM vector type.
242-
Type convertVectorType(VectorType type) const;
242+
FailureOr<Type> convertVectorType(VectorType type) const;
243243

244244
/// Options for customizing the llvm lowering.
245245
LowerToLLVMOptions options;

mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,12 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
6161
addConversion([&](MemRefType type) { return convertMemRefType(type); });
6262
addConversion(
6363
[&](UnrankedMemRefType type) { return convertUnrankedMemRefType(type); });
64-
addConversion([&](VectorType type) { return convertVectorType(type); });
64+
addConversion([&](VectorType type) -> std::optional<Type> {
65+
FailureOr<Type> llvmType = convertVectorType(type);
66+
if (failed(llvmType))
67+
return std::nullopt;
68+
return llvmType;
69+
});
6570

6671
// LLVM-compatible types are legal, so add a pass-through conversion. Do this
6772
// before the conversions below since conversions are attempted in reverse
@@ -490,10 +495,9 @@ Type LLVMTypeConverter::convertMemRefToBarePtr(BaseMemRefType type) const {
490495
/// * 1-D `vector<axT>` remains as is while,
491496
/// * n>1 `vector<ax...xkxT>` convert via an (n-1)-D array type to
492497
/// `!llvm.array<ax...array<jxvector<kxT>>>`.
493-
/// As LLVM does not support arrays of scalable vectors, it is assumed that
494-
/// scalable vectors are always 1-D. This condition could be relaxed once the
495-
/// missing functionality is added in LLVM
496-
Type LLVMTypeConverter::convertVectorType(VectorType type) const {
498+
/// Returns failure for n-D scalable vector types as LLVM does not support
499+
/// arrays of scalable vectors.
500+
FailureOr<Type> LLVMTypeConverter::convertVectorType(VectorType type) const {
497501
auto elementType = convertType(type.getElementType());
498502
if (!elementType)
499503
return {};
@@ -503,9 +507,8 @@ Type LLVMTypeConverter::convertVectorType(VectorType type) const {
503507
type.getScalableDims().back());
504508
assert(LLVM::isCompatibleVectorType(vectorType) &&
505509
"expected vector type compatible with the LLVM dialect");
506-
assert(
507-
(!type.isScalable() || (type.getRank() == 1)) &&
508-
"expected 1-D scalable vector (n-D scalable vectors are not supported)");
510+
if (type.isScalable() && (type.getRank() > 1))
511+
return failure();
509512
auto shape = type.getShape();
510513
for (int i = shape.size() - 2; i >= 0; --i)
511514
vectorType = LLVM::LLVMArrayType::get(vectorType, shape[i]);

0 commit comments

Comments
 (0)