diff --git a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h index d79b90f840ce8..38b5e492a8ed8 100644 --- a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h +++ b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h @@ -161,6 +161,41 @@ class LLVMTypeConverter : public TypeConverter { /// Check if a memref type can be converted to a bare pointer. static bool canConvertToBarePtr(BaseMemRefType type); + /// Convert a memref type into a list of LLVM IR types that will form the + /// memref descriptor. If `unpackAggregates` is true the `sizes` and `strides` + /// arrays in the descriptors are unpacked to individual index-typed elements, + /// else they are kept as rank-sized arrays of index type. In particular, + /// the list will contain: + /// - two pointers to the memref element type, followed by + /// - an index-typed offset, followed by + /// - (if unpackAggregates = true) + /// - one index-typed size per dimension of the memref, followed by + /// - one index-typed stride per dimension of the memref. + /// - (if unpackArrregates = false) + /// - one rank-sized array of index-type for the size of each dimension + /// - one rank-sized array of index-type for the stride of each dimension + /// + /// For example, memref is converted to the following list: + /// - `!llvm<"float*">` (allocated pointer), + /// - `!llvm<"float*">` (aligned pointer), + /// - `i64` (offset), + /// - `i64`, `i64` (sizes), + /// - `i64`, `i64` (strides). + /// These types can be recomposed to a memref descriptor struct. + SmallVector getMemRefDescriptorFields(MemRefType type, + bool unpackAggregates) const; + + /// Convert an unranked memref type into a list of non-aggregate LLVM IR types + /// that will form the unranked memref descriptor. In particular, this list + /// contains: + /// - an integer rank, followed by + /// - a pointer to the memref descriptor struct. + /// For example, memref<*xf32> is converted to the following list: + /// i64 (rank) + /// !llvm<"i8*"> (type-erased pointer). + /// These types can be recomposed to a unranked memref descriptor struct. + SmallVector getUnrankedMemRefDescriptorFields() const; + protected: /// Pointer to the LLVM dialect. LLVM::LLVMDialect *llvmDialect; @@ -213,41 +248,6 @@ class LLVMTypeConverter : public TypeConverter { /// Convert a memref type into an LLVM type that captures the relevant data. Type convertMemRefType(MemRefType type) const; - /// Convert a memref type into a list of LLVM IR types that will form the - /// memref descriptor. If `unpackAggregates` is true the `sizes` and `strides` - /// arrays in the descriptors are unpacked to individual index-typed elements, - /// else they are kept as rank-sized arrays of index type. In particular, - /// the list will contain: - /// - two pointers to the memref element type, followed by - /// - an index-typed offset, followed by - /// - (if unpackAggregates = true) - /// - one index-typed size per dimension of the memref, followed by - /// - one index-typed stride per dimension of the memref. - /// - (if unpackArrregates = false) - /// - one rank-sized array of index-type for the size of each dimension - /// - one rank-sized array of index-type for the stride of each dimension - /// - /// For example, memref is converted to the following list: - /// - `!llvm<"float*">` (allocated pointer), - /// - `!llvm<"float*">` (aligned pointer), - /// - `i64` (offset), - /// - `i64`, `i64` (sizes), - /// - `i64`, `i64` (strides). - /// These types can be recomposed to a memref descriptor struct. - SmallVector getMemRefDescriptorFields(MemRefType type, - bool unpackAggregates) const; - - /// Convert an unranked memref type into a list of non-aggregate LLVM IR types - /// that will form the unranked memref descriptor. In particular, this list - /// contains: - /// - an integer rank, followed by - /// - a pointer to the memref descriptor struct. - /// For example, memref<*xf32> is converted to the following list: - /// i64 (rank) - /// !llvm<"i8*"> (type-erased pointer). - /// These types can be recomposed to a unranked memref descriptor struct. - SmallVector getUnrankedMemRefDescriptorFields() const; - /// Convert an unranked memref type to an LLVM type that captures the /// runtime rank and a pointer to the static ranked memref desc Type convertUnrankedMemRefType(UnrankedMemRefType type) const; diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp index e2ab0ed6f66cc..1a7951282d3f7 100644 --- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp +++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp @@ -44,6 +44,74 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx, const DataLayoutAnalysis *analysis) : LLVMTypeConverter(ctx, LowerToLLVMOptions(ctx), analysis) {} +/// Helper function that checks if the given value range is a bare pointer. +static bool isBarePointer(ValueRange values) { + return values.size() == 1 && + isa(values.front().getType()); +}; + +/// Pack SSA values into an unranked memref descriptor struct. +static Value packUnrankedMemRefDesc(OpBuilder &builder, + UnrankedMemRefType resultType, + ValueRange inputs, Location loc, + const LLVMTypeConverter &converter) { + // Note: Bare pointers are not supported for unranked memrefs because a + // memref descriptor cannot be built just from a bare pointer. + if (TypeRange(inputs) != converter.getUnrankedMemRefDescriptorFields()) + return Value(); + return UnrankedMemRefDescriptor::pack(builder, loc, converter, resultType, + inputs); +} + +/// Pack SSA values into a ranked memref descriptor struct. +static Value packRankedMemRefDesc(OpBuilder &builder, MemRefType resultType, + ValueRange inputs, Location loc, + const LLVMTypeConverter &converter) { + assert(resultType && "expected non-null result type"); + if (isBarePointer(inputs)) + return MemRefDescriptor::fromStaticShape(builder, loc, converter, + resultType, inputs[0]); + if (TypeRange(inputs) == + converter.getMemRefDescriptorFields(resultType, + /*unpackAggregates=*/true)) + return MemRefDescriptor::pack(builder, loc, converter, resultType, inputs); + // The inputs are neither a bare pointer nor an unpacked memref descriptor. + // This materialization function cannot be used. + return Value(); +} + +/// MemRef descriptor elements -> UnrankedMemRefType +static Value unrankedMemRefMaterialization(OpBuilder &builder, + UnrankedMemRefType resultType, + ValueRange inputs, Location loc, + const LLVMTypeConverter &converter) { + // An argument materialization must return a value of type + // `resultType`, so insert a cast from the memref descriptor type + // (!llvm.struct) to the original memref type. + Value packed = + packUnrankedMemRefDesc(builder, resultType, inputs, loc, converter); + if (!packed) + return Value(); + return builder.create(loc, resultType, packed) + .getResult(0); +}; + +/// MemRef descriptor elements -> MemRefType +static Value rankedMemRefMaterialization(OpBuilder &builder, + MemRefType resultType, + ValueRange inputs, Location loc, + const LLVMTypeConverter &converter) { + // An argument materialization must return a value of type `resultType`, + // so insert a cast from the memref descriptor type (!llvm.struct) to the + // original memref type. + Value packed = + packRankedMemRefDesc(builder, resultType, inputs, loc, converter); + if (!packed) + return Value(); + return builder.create(loc, resultType, packed) + .getResult(0); +} + /// Create an LLVMTypeConverter using custom LowerToLLVMOptions. LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx, const LowerToLLVMOptions &options, @@ -166,81 +234,29 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx, .getResult(0); }); - // Helper function that checks if the given value range is a bare pointer. - auto isBarePointer = [](ValueRange values) { - return values.size() == 1 && - isa(values.front().getType()); - }; - - // TODO: For some reason, `this` is nullptr in here, so the LLVMTypeConverter - // must be passed explicitly. - auto packUnrankedMemRefDesc = - [&](OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs, - Location loc, LLVMTypeConverter &converter) -> Value { - // Note: Bare pointers are not supported for unranked memrefs because a - // memref descriptor cannot be built just from a bare pointer. - if (TypeRange(inputs) != converter.getUnrankedMemRefDescriptorFields()) - return Value(); - return UnrankedMemRefDescriptor::pack(builder, loc, converter, resultType, - inputs); - }; - - // MemRef descriptor elements -> UnrankedMemRefType - auto unrakedMemRefMaterialization = [&](OpBuilder &builder, - UnrankedMemRefType resultType, - ValueRange inputs, Location loc) { - // An argument materialization must return a value of type - // `resultType`, so insert a cast from the memref descriptor type - // (!llvm.struct) to the original memref type. - Value packed = - packUnrankedMemRefDesc(builder, resultType, inputs, loc, *this); - if (!packed) - return Value(); - return builder.create(loc, resultType, packed) - .getResult(0); - }; - - // TODO: For some reason, `this` is nullptr in here, so the LLVMTypeConverter - // must be passed explicitly. - auto packRankedMemRefDesc = [&](OpBuilder &builder, MemRefType resultType, - ValueRange inputs, Location loc, - LLVMTypeConverter &converter) -> Value { - assert(resultType && "expected non-null result type"); - if (isBarePointer(inputs)) - return MemRefDescriptor::fromStaticShape(builder, loc, converter, - resultType, inputs[0]); - if (TypeRange(inputs) == - converter.getMemRefDescriptorFields(resultType, - /*unpackAggregates=*/true)) - return MemRefDescriptor::pack(builder, loc, converter, resultType, - inputs); - // The inputs are neither a bare pointer nor an unpacked memref descriptor. - // This materialization function cannot be used. - return Value(); - }; - - // MemRef descriptor elements -> MemRefType - auto rankedMemRefMaterialization = [&](OpBuilder &builder, - MemRefType resultType, - ValueRange inputs, Location loc) { - // An argument materialization must return a value of type `resultType`, - // so insert a cast from the memref descriptor type (!llvm.struct) to the - // original memref type. - Value packed = - packRankedMemRefDesc(builder, resultType, inputs, loc, *this); - if (!packed) - return Value(); - return builder.create(loc, resultType, packed) - .getResult(0); - }; - // Argument materializations convert from the new block argument types // (multiple SSA values that make up a memref descriptor) back to the // original block argument type. - addArgumentMaterialization(unrakedMemRefMaterialization); - addArgumentMaterialization(rankedMemRefMaterialization); - addSourceMaterialization(unrakedMemRefMaterialization); - addSourceMaterialization(rankedMemRefMaterialization); + addArgumentMaterialization([&](OpBuilder &builder, + UnrankedMemRefType resultType, + ValueRange inputs, Location loc) { + return unrankedMemRefMaterialization(builder, resultType, inputs, loc, + *this); + }); + addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType, + ValueRange inputs, Location loc) { + return rankedMemRefMaterialization(builder, resultType, inputs, loc, *this); + }); + addSourceMaterialization([&](OpBuilder &builder, + UnrankedMemRefType resultType, ValueRange inputs, + Location loc) { + return unrankedMemRefMaterialization(builder, resultType, inputs, loc, + *this); + }); + addSourceMaterialization([&](OpBuilder &builder, MemRefType resultType, + ValueRange inputs, Location loc) { + return rankedMemRefMaterialization(builder, resultType, inputs, loc, *this); + }); // Bare pointer -> Packed MemRef descriptor addTargetMaterialization([&](OpBuilder &builder, Type resultType,