diff --git a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h index ed174699314e8..43db7987e650a 100644 --- a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h +++ b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h @@ -238,14 +238,15 @@ class LLVMTypeConverter : public TypeConverter { /// Convert a memref type to a bare pointer to the memref element type. Type convertMemRefToBarePtr(BaseMemRefType type) const; - /// Convert a 1D vector type into an LLVM vector type. - Type convertVectorType(VectorType type) const; - /// Options for customizing the llvm lowering. LowerToLLVMOptions options; /// Data layout analysis mapping scopes to layouts active in them. const DataLayoutAnalysis *dataLayoutAnalysis; + +protected: + /// Convert a 1D vector type into an LLVM vector type. + Type convertVectorType(VectorType type) const; }; /// Callback to convert function argument types. It converts a MemRef function diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h index ab5c179f2dd77..ad3c010816fa3 100644 --- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h @@ -43,6 +43,19 @@ std::unique_ptr createTileAllocationPass(); class ArmSMETypeConverter : public LLVMTypeConverter { public: ArmSMETypeConverter(MLIRContext *ctx, const LowerToLLVMOptions &options); + +protected: + /// Convert an n-D vector type to an LLVM vector type. + /// + /// Disables type conversion of legal 2-D scalable vector types such as + /// `vector<[16]x[16]xi8>` for ArmSME, since LLVM does not support arrays of + /// scalable vectors and the LLVM type converter asserts on such types to + /// prevent generation of illegal LLVM IR. When lowering to ArmSME these types + /// should be eliminated before lowering to LLVM. + /// + /// Types unrelated to ArmSME are converted by + /// `LLVMTypeConverter::convertVectorType`. + Type convertVectorType(VectorType type) const; }; //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp index 04570a750822a..c534ef6e408b8 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp @@ -83,21 +83,26 @@ void LowerVectorToLLVMPass::runOnOperation() { // Convert to the LLVM IR dialect. LowerToLLVMOptions options(&getContext()); options.useOpaquePointers = useOpaquePointers; - LLVMTypeConverter converter(&getContext(), options); + + LLVMTypeConverter *converter; + if (armSME) + converter = new arm_sme::ArmSMETypeConverter(&getContext(), options); + else + converter = new LLVMTypeConverter(&getContext(), options); + RewritePatternSet patterns(&getContext()); populateVectorMaskMaterializationPatterns(patterns, force32BitVectorIndices); populateVectorTransferLoweringPatterns(patterns); - populateVectorToLLVMMatrixConversionPatterns(converter, patterns); + populateVectorToLLVMMatrixConversionPatterns(*converter, patterns); populateVectorToLLVMConversionPatterns( - converter, patterns, reassociateFPReductions, force32BitVectorIndices); - populateVectorToLLVMMatrixConversionPatterns(converter, patterns); + *converter, patterns, reassociateFPReductions, force32BitVectorIndices); + populateVectorToLLVMMatrixConversionPatterns(*converter, patterns); // Architecture specific augmentations. LLVMConversionTarget target(getContext()); target.addLegalDialect(); target.addLegalDialect(); target.addLegalOp(); - arm_sme::ArmSMETypeConverter armSMEConverter(&getContext(), options); if (armNeon) { // TODO: we may or may not want to include in-dialect lowering to @@ -107,19 +112,19 @@ void LowerVectorToLLVMPass::runOnOperation() { } if (armSVE) { configureArmSVELegalizeForExportTarget(target); - populateArmSVELegalizeForLLVMExportPatterns(converter, patterns); + populateArmSVELegalizeForLLVMExportPatterns(*converter, patterns); } if (armSME) { configureArmSMELegalizeForExportTarget(target); - populateArmSMELegalizeForLLVMExportPatterns(armSMEConverter, patterns); + populateArmSMELegalizeForLLVMExportPatterns(*converter, patterns); } if (amx) { configureAMXLegalizeForExportTarget(target); - populateAMXLegalizeForLLVMExportPatterns(converter, patterns); + populateAMXLegalizeForLLVMExportPatterns(*converter, patterns); } if (x86Vector) { configureX86VectorLegalizeForExportTarget(target); - populateX86VectorLegalizeForLLVMExportPatterns(converter, patterns); + populateX86VectorLegalizeForLLVMExportPatterns(*converter, patterns); } if (failed( diff --git a/mlir/lib/Dialect/ArmSME/Transforms/ArmSMETypeConverter.cpp b/mlir/lib/Dialect/ArmSME/Transforms/ArmSMETypeConverter.cpp index 1cefc220ecf10..65da2a7a75d29 100644 --- a/mlir/lib/Dialect/ArmSME/Transforms/ArmSMETypeConverter.cpp +++ b/mlir/lib/Dialect/ArmSME/Transforms/ArmSMETypeConverter.cpp @@ -7,16 +7,17 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/ArmSME/Transforms/Passes.h" +#include "mlir/Dialect/ArmSME/Utils/Utils.h" using namespace mlir; arm_sme::ArmSMETypeConverter::ArmSMETypeConverter( MLIRContext *ctx, const LowerToLLVMOptions &options) : LLVMTypeConverter(ctx, options) { - // Disable LLVM type conversion for vectors. This is to prevent 2-d scalable - // vectors (common in the context of ArmSME), e.g. - // `vector<[16]x[16]xi8>`, - // entering the LLVM Type converter. LLVM does not support arrays of scalable - // vectors, but in the case of SME such types are effectively eliminated when - // emitting ArmSME LLVM IR intrinsics. - addConversion([&](VectorType type) { return type; }); + addConversion([&](VectorType type) { return convertVectorType(type); }); +} + +Type arm_sme::ArmSMETypeConverter::convertVectorType(VectorType type) const { + if (arm_sme::isValidSMETileVectorType(type)) + return type; + return LLVMTypeConverter::convertVectorType(type); } diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir index 514594240d22a..3f897fbf01b7b 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -1,4 +1,5 @@ // RUN: mlir-opt %s -convert-vector-to-llvm='use-opaque-pointers=1' -split-input-file | FileCheck %s +// RUN: mlir-opt %s -convert-vector-to-llvm='use-opaque-pointers=1 enable-arm-sme' -split-input-file | FileCheck %s func.func @bitcast_f32_to_i32_vector_0d(%input: vector) -> vector {