diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td index 6975b18ab7f81..74bcc02ff1314 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td @@ -351,27 +351,6 @@ def LLVM_DICompileUnitAttr : LLVM_Attr<"DICompileUnit", "di_compile_unit", let assemblyFormat = "`<` struct(params) `>`"; } -//===----------------------------------------------------------------------===// -// DICompositeTypeAttr -//===----------------------------------------------------------------------===// - -def LLVM_DICompositeTypeAttr : LLVM_Attr<"DICompositeType", "di_composite_type", - /*traits=*/[], "DITypeAttr"> { - let parameters = (ins - LLVM_DITagParameter:$tag, - OptionalParameter<"StringAttr">:$name, - OptionalParameter<"DIFileAttr">:$file, - OptionalParameter<"uint32_t">:$line, - OptionalParameter<"DIScopeAttr">:$scope, - OptionalParameter<"DITypeAttr">:$baseType, - OptionalParameter<"DIFlags", "DIFlags::Zero">:$flags, - OptionalParameter<"uint64_t">:$sizeInBits, - OptionalParameter<"uint64_t">:$alignInBits, - OptionalArrayRefParameter<"DINodeAttr">:$elements - ); - let assemblyFormat = "`<` struct(params) `>`"; -} - //===----------------------------------------------------------------------===// // DIDerivedTypeAttr //===----------------------------------------------------------------------===// @@ -684,6 +663,61 @@ def LLVM_AliasScopeDomainAttr : LLVM_Attr<"AliasScopeDomain", let assemblyFormat = "`<` struct(params) `>`"; } +//===----------------------------------------------------------------------===// +// DICompositeTypeAttr +//===----------------------------------------------------------------------===// + +def LLVM_DICompositeTypeAttr : LLVM_Attr<"DICompositeType", "di_composite_type", + /*traits=*/[NativeTypeTrait<"IsMutable">], "DITypeAttr"> { + let parameters = (ins + OptionalParameter<"unsigned">:$tag, + OptionalParameter<"StringAttr">:$name, + OptionalParameter<"DIFileAttr">:$file, + OptionalParameter<"uint32_t">:$line, + OptionalParameter<"DIScopeAttr">:$scope, + OptionalParameter<"DITypeAttr">:$baseType, + OptionalParameter<"DIFlags", "DIFlags::Zero">:$flags, + OptionalParameter<"uint64_t">:$sizeInBits, + OptionalParameter<"uint64_t">:$alignInBits, + OptionalArrayRefParameter<"DINodeAttr">:$elements, + OptionalParameter<"StringAttr">:$identifier + ); + let hasCustomAssemblyFormat = 1; + let genStorageClass = 0; + let storageClass = "DICompositeTypeAttrStorage"; + let builders = [ + AttrBuilder<(ins + "unsigned":$tag, + "StringAttr":$name, + "DIFileAttr":$file, + "uint32_t":$line, + "DIScopeAttr":$scope, + "DITypeAttr":$baseType, + "DIFlags":$flags, + "uint64_t":$sizeInBits, + "uint64_t":$alignInBits, + "::llvm::ArrayRef":$elements + )>, + AttrBuilder<(ins + "StringAttr":$identifier, + "unsigned":$tag, + "StringAttr":$name, + "DIFileAttr":$file, + "uint32_t":$line, + "DIScopeAttr":$scope, + "DITypeAttr":$baseType, + "DIFlags":$flags, + "uint64_t":$sizeInBits, + "uint64_t":$alignInBits, + CArg<"::llvm::ArrayRef", "{}">:$elements + )> + ]; + let extraClassDeclaration = [{ + static DICompositeTypeAttr getIdentified(MLIRContext *context, StringAttr identifier); + void replaceElements(const ArrayRef& elements); + }]; +} + //===----------------------------------------------------------------------===// // AliasScopeAttr //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h index c370bfa2b733d..c38bf1c66bba3 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h @@ -23,6 +23,10 @@ namespace mlir { namespace LLVM { +namespace detail { + struct DICompositeTypeAttrStorage; +} // namespace detail + /// This class represents the base attribute for all debug info attributes. class DINodeAttr : public Attribute { public: diff --git a/mlir/include/mlir/Target/LLVMIR/Dialect/All.h b/mlir/include/mlir/Target/LLVMIR/Dialect/All.h index 5dfc15afb7593..8a9932ad36b67 100644 --- a/mlir/include/mlir/Target/LLVMIR/Dialect/All.h +++ b/mlir/include/mlir/Target/LLVMIR/Dialect/All.h @@ -63,7 +63,7 @@ registerAllGPUToLLVMIRTranslations(DialectRegistry ®istry) { registerLLVMDialectTranslation(registry); registerNVVMDialectTranslation(registry); registerROCDLDialectTranslation(registry); - registerSPIRVDialectTranslation(registry); + //registerSPIRVDialectTranslation(registry); // Extension required for translating GPU offloading Ops. gpu::registerOffloadingLLVMTranslationInterfaceExternalModels(registry); diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp index d085fb6af6bc1..de120bfd7a4f9 100644 --- a/mlir/lib/AsmParser/AttributeParser.cpp +++ b/mlir/lib/AsmParser/AttributeParser.cpp @@ -248,6 +248,7 @@ Attribute Parser::parseAttribute(Type type) { OptionalParseResult Parser::parseOptionalAttribute(Attribute &attribute, Type type) { switch (getToken().getKind()) { + case Token::kw_distinct: case Token::at_identifier: case Token::floatliteral: case Token::integer: diff --git a/mlir/lib/Dialect/LLVMIR/IR/AttrDetail.h b/mlir/lib/Dialect/LLVMIR/IR/AttrDetail.h new file mode 100644 index 0000000000000..ccecd540ab046 --- /dev/null +++ b/mlir/lib/Dialect/LLVMIR/IR/AttrDetail.h @@ -0,0 +1,136 @@ +//===- AttrDetail.h - Details of MLIR LLVM dialect attributes --------*- C++ +//-*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains implementation details, such as storage structures, of +// MLIR LLVM dialect attributes. +// +//===----------------------------------------------------------------------===// +#ifndef DIALECT_LLVMIR_IR_ATTRDETAIL_H +#define DIALECT_LLVMIR_IR_ATTRDETAIL_H + +#include "mlir/Dialect/LLVMIR/LLVMAttrs.h" +#include "mlir/IR/Types.h" + +namespace mlir { +namespace LLVM { +namespace detail { + +//===----------------------------------------------------------------------===// +// DICompositeTypeAttrStorage +//===----------------------------------------------------------------------===// + +struct DICompositeTypeAttrStorage : public ::mlir::AttributeStorage { + using KeyTy = std::tuple, StringAttr>; + + DICompositeTypeAttrStorage(unsigned tag, StringAttr name, DIFileAttr file, + uint32_t line, DIScopeAttr scope, + DITypeAttr baseType, DIFlags flags, + uint64_t sizeInBits, uint64_t alignInBits, + ArrayRef elements, + StringAttr identifier = StringAttr()) + : tag(tag), name(name), file(file), line(line), scope(scope), + baseType(baseType), flags(flags), sizeInBits(sizeInBits), + alignInBits(alignInBits), elements(elements), identifier(identifier) {} + + unsigned getTag() const { return tag; } + StringAttr getName() const { return name; } + DIFileAttr getFile() const { return file; } + uint32_t getLine() const { return line; } + DIScopeAttr getScope() const { return scope; } + DITypeAttr getBaseType() const { return baseType; } + DIFlags getFlags() const { return flags; } + uint64_t getSizeInBits() const { return sizeInBits; } + uint64_t getAlignInBits() const { return alignInBits; } + ArrayRef getElements() const { return elements; } + StringAttr getIdentifier() const { return identifier; } + + /// Returns true if this attribute is identified. + bool isIdentified() const { + return !(!identifier); + } + + /// Returns the respective key for this attribute. + KeyTy getAsKey() const { + if (isIdentified()) + return KeyTy(tag, name, file, line, scope, baseType, flags, sizeInBits, + alignInBits, elements, identifier); + + return KeyTy(tag, name, file, line, scope, baseType, flags, sizeInBits, + alignInBits, elements, StringAttr()); + } + + /// Compares two keys. + bool operator==(const KeyTy &other) const { + if (isIdentified()) + // Just compare against the identifier. + return identifier == std::get<10>(other); + + // Otherwise, compare the entire tuple. + return other == getAsKey(); + } + + /// Returns the hash value of the key. + static llvm::hash_code hashKey(const KeyTy &key) { + const auto &[tag, name, file, line, scope, baseType, flags, sizeInBits, + alignInBits, elements, identifier] = key; + + if (identifier) + // Only the identifier participates in the hash id. + return hash_value(identifier); + + // Otherwise, everything else is included in the hash. + return hash_combine(tag, name, file, line, scope, baseType, flags, + sizeInBits, alignInBits, elements); + } + + /// Constructs new storage for an attribute. + static DICompositeTypeAttrStorage * + construct(AttributeStorageAllocator &allocator, const KeyTy &key) { + auto [tag, name, file, line, scope, baseType, flags, sizeInBits, + alignInBits, elements, identifier] = key; + elements = allocator.copyInto(elements); + if (identifier) { + return new (allocator.allocate()) + DICompositeTypeAttrStorage(tag, name, file, line, scope, baseType, + flags, sizeInBits, alignInBits, elements, + identifier); + } + return new (allocator.allocate()) + DICompositeTypeAttrStorage(tag, name, file, line, scope, baseType, + flags, sizeInBits, alignInBits, elements); + } + + LogicalResult mutate(AttributeStorageAllocator &allocator, + const ArrayRef& elements) { + // Replace the elements. + this->elements = allocator.copyInto(elements); + return success(); + } + +private: + unsigned tag; + StringAttr name; + DIFileAttr file; + uint32_t line; + DIScopeAttr scope; + DITypeAttr baseType; + DIFlags flags; + uint64_t sizeInBits; + uint64_t alignInBits; + ArrayRef elements; + StringAttr identifier; +}; + +} // namespace detail +} // namespace LLVM +} // namespace mlir + +#endif // DIALECT_LLVMIR_IR_ATTRDETAIL_H diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp index 645a45dd96bef..c11ed72fa3557 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp @@ -10,6 +10,8 @@ // //===----------------------------------------------------------------------===// +#include "AttrDetail.h" + #include "mlir/Dialect/LLVMIR/LLVMAttrs.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/IR/Builders.h" @@ -47,6 +49,7 @@ void LLVMDialect::registerAttributes() { addAttributes< #define GET_ATTRDEF_LIST #include "mlir/Dialect/LLVMIR/LLVMOpsAttrDefs.cpp.inc" + >(); } @@ -124,7 +127,7 @@ bool MemoryEffectsAttr::isReadWrite() { } //===----------------------------------------------------------------------===// -// DIExpression +// DIExpressionAttr //===----------------------------------------------------------------------===// DIExpressionAttr DIExpressionAttr::get(MLIRContext *context) { @@ -248,3 +251,342 @@ TargetFeaturesAttr TargetFeaturesAttr::featuresAt(Operation *op) { return parentFunction.getOperation()->getAttrOfType( getAttributeName()); } + +//===----------------------------------------------------------------------===// +// DICompositeTypeAttr +//===----------------------------------------------------------------------===// + +DICompositeTypeAttr +DICompositeTypeAttr::get(MLIRContext *context, unsigned tag, StringAttr name, + DIFileAttr file, uint32_t line, DIScopeAttr scope, + DITypeAttr baseType, DIFlags flags, + uint64_t sizeInBits, uint64_t alignInBits, + ::llvm::ArrayRef elements) { + return Base::get(context, tag, name, file, line, scope, baseType, flags, + sizeInBits, alignInBits, elements, StringAttr()); +} + +DICompositeTypeAttr DICompositeTypeAttr::get( + MLIRContext *context, StringAttr identifier, unsigned tag, StringAttr name, + DIFileAttr file, uint32_t line, DIScopeAttr scope, DITypeAttr baseType, + DIFlags flags, uint64_t sizeInBits, uint64_t alignInBits, + ::llvm::ArrayRef elements) { + return Base::get(context, tag, name, file, line, scope, baseType, flags, + sizeInBits, alignInBits, elements, identifier); +} + +unsigned DICompositeTypeAttr::getTag() const { return getImpl()->getTag(); } + +StringAttr DICompositeTypeAttr::getName() const { return getImpl()->getName(); } + +DIFileAttr DICompositeTypeAttr::getFile() const { return getImpl()->getFile(); } + +uint32_t DICompositeTypeAttr::getLine() const { return getImpl()->getLine(); } + +DIScopeAttr DICompositeTypeAttr::getScope() const { + return getImpl()->getScope(); +} + +DITypeAttr DICompositeTypeAttr::getBaseType() const { + return getImpl()->getBaseType(); +} + +DIFlags DICompositeTypeAttr::getFlags() const { return getImpl()->getFlags(); } + +uint64_t DICompositeTypeAttr::getSizeInBits() const { + return getImpl()->getSizeInBits(); +} + +uint64_t DICompositeTypeAttr::getAlignInBits() const { + return getImpl()->getAlignInBits(); +} + +::llvm::ArrayRef DICompositeTypeAttr::getElements() const { + return getImpl()->getElements(); +} + +StringAttr DICompositeTypeAttr::getIdentifier() const { + return getImpl()->getIdentifier(); +} + +Attribute DICompositeTypeAttr::parse(AsmParser &parser, Type type) { + FailureOr cyclicParse; + FailureOr tag; + FailureOr name; + FailureOr file; + FailureOr line; + FailureOr scope; + FailureOr baseType; + FailureOr flags; + FailureOr sizeInBits; + FailureOr alignInBits; + SmallVector elements; + StringAttr identifier; + const Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation()); + + auto paramParser = [&]() -> LogicalResult { + StringRef paramKey; + if (parser.parseKeyword(¶mKey)) { + return parser.emitError(parser.getCurrentLocation(), + "expected parameter name."); + } + + if (parser.parseEqual()) { + return parser.emitError(parser.getCurrentLocation(), + "expected `=` following parameter name."); + } + + if (failed(tag) && paramKey == "tag") { + tag = [&]() -> FailureOr { + StringRef nameKeyword; + if (parser.parseKeyword(&nameKeyword)) + return failure(); + if (const unsigned value = llvm::dwarf::getTag(nameKeyword)) + return value; + return parser.emitError(parser.getCurrentLocation()) + << "invalid debug info debug info tag name: " << nameKeyword; + }(); + } else if (failed(name) && paramKey == "name") { + name = FieldParser::parse(parser); + if (failed(name)) { + return parser.emitError(parser.getCurrentLocation(), + "failed to parse parameter 'name'"); + } + } else if (failed(file) && paramKey == "file") { + file = FieldParser::parse(parser); + if (failed(file)) { + return parser.emitError(parser.getCurrentLocation(), + "failed to parse parameter 'file'"); + } + } else if (failed(line) && paramKey == "line") { + line = FieldParser::parse(parser); + if (failed(line)) { + return parser.emitError(parser.getCurrentLocation(), + "failed to parse parameter 'line'"); + } + } else if (failed(scope) && paramKey == "scope") { + scope = FieldParser::parse(parser); + if (failed(scope)) { + return parser.emitError(parser.getCurrentLocation(), + "failed to parse parameter 'scope'"); + } + } else if (failed(baseType) && paramKey == "baseType") { + baseType = FieldParser::parse(parser); + if (failed(baseType)) { + return parser.emitError(parser.getCurrentLocation(), + "failed to parse parameter 'baseType'"); + } + } else if (failed(flags) && paramKey == "flags") { + flags = FieldParser::parse(parser); + if (failed(flags)) { + return parser.emitError(parser.getCurrentLocation(), + "failed to parse parameter 'flags'"); + } + } else if (failed(sizeInBits) && paramKey == "sizeInBits") { + sizeInBits = FieldParser::parse(parser); + if (failed(sizeInBits)) { + return parser.emitError(parser.getCurrentLocation(), + "failed to parse parameter 'sizeInBits'"); + } + } else if (failed(alignInBits) && paramKey == "alignInBits") { + alignInBits = FieldParser::parse(parser); + if (failed(alignInBits)) { + return parser.emitError(parser.getCurrentLocation(), + "failed to parse parameter 'alignInBits'"); + } + } else { + return parser.emitError(parser.getCurrentLocation(), + "unknown parameter '") + << paramKey << "'"; + } + return success(); + }; + + // Begin parsing. + if (parser.parseLess()) { + parser.emitError(parser.getCurrentLocation(), "expected `<`"); + return {}; + } + + // First, attempt to parse the identifier attribute. + const OptionalParseResult idResult = + parser.parseOptionalAttribute(identifier); + if (idResult.has_value() && succeeded(*idResult)) { + if (succeeded(parser.parseOptionalGreater())) { + DICompositeTypeAttr result = + getIdentified(parser.getContext(), identifier); + // Cyclic parsing should not initiate with only the identifier. Only + // nested instances should terminate early. + if (succeeded(parser.tryStartCyclicParse(result))) { + parser.emitError(parser.getCurrentLocation(), + "Expected identified attribute to contain at least " + "one other parameter"); + return {}; + } + return result; + } + + if (parser.parseComma()) { + parser.emitError(parser.getCurrentLocation(), "Expected `,`"); + } + } + + // Parse immutable parameters. + if (parser.parseCommaSeparatedList(paramParser)) { + return {}; + } + + if (identifier) { + // Create the identified attribute. + DICompositeTypeAttr result = + get(parser.getContext(), identifier, tag.value_or(0), + name.value_or(StringAttr()), file.value_or(DIFileAttr()), + line.value_or(0), scope.value_or(DIScopeAttr()), + baseType.value_or(DITypeAttr()), flags.value_or(DIFlags::Zero), + sizeInBits.value_or(0), alignInBits.value_or(0)); + + // Initiate cyclic parsing. + if (cyclicParse = parser.tryStartCyclicParse(result); failed(cyclicParse)) { + return {}; + } + } + + // Parse the elements now. + if (succeeded(parser.parseOptionalLParen())) { + if (parser.parseCommaSeparatedList([&]() -> LogicalResult { + Attribute attr; + if (parser.parseAttribute(attr)) { + return parser.emitError(parser.getCurrentLocation(), + "expected attribute"); + } + elements.push_back(mlir::cast(attr)); + return success(); + })) { + return {}; + } + + if (parser.parseRParen()) { + parser.emitError(parser.getCurrentLocation(), "expected `)"); + return {}; + } + } + + // Expect the attribute to terminate. + if (parser.parseGreater()) { + parser.emitError(parser.getCurrentLocation(), "expected `>`"); + return {}; + } + + if (!identifier) + return get(loc.getContext(), tag.value_or(0), name.value_or(StringAttr()), + file.value_or(DIFileAttr()), line.value_or(0), + scope.value_or(DIScopeAttr()), baseType.value_or(DITypeAttr()), + flags.value_or(DIFlags::Zero), sizeInBits.value_or(0), + alignInBits.value_or(0), elements); + + // Replace the elements if the attribute is identified. + DICompositeTypeAttr result = getIdentified(parser.getContext(), identifier); + result.replaceElements(elements); + return result; +} + +void DICompositeTypeAttr::print(AsmPrinter &printer) const { + FailureOr cyclicPrint; + SmallVector> valuePrinters; + printer << "<"; + if (getImpl()->isIdentified()) { + cyclicPrint = printer.tryStartCyclicPrint(*this); + if (failed(cyclicPrint)) { + printer << getIdentifier() << ">"; + return; + } + valuePrinters.push_back([&]() { printer << getIdentifier(); }); + } + + if (getTag() > 0) { + valuePrinters.push_back( + [&]() { + printer << "tag = " << llvm::dwarf::TagString(getTag()); + }); + } + + if (getName()) { + valuePrinters.push_back([&]() { + + + + printer << "name = "; + printer.printStrippedAttrOrType(getName()); + }); + } + + if (getFile()) { + valuePrinters.push_back([&]() { + printer << "file = "; + printer.printStrippedAttrOrType(getFile()); + }); + } + + if (getLine() > 0) { + valuePrinters.push_back([&]() { + printer << "line = "; + printer.printStrippedAttrOrType(getLine()); + }); + } + + if (getScope()) { + valuePrinters.push_back([&]() { + printer << "scope = "; + printer.printStrippedAttrOrType(getScope()); + }); + } + + if (getBaseType()) { + valuePrinters.push_back([&]() { + printer << "baseType = "; + printer.printStrippedAttrOrType(getBaseType()); + }); + } + + if (getFlags() != DIFlags::Zero) { + valuePrinters.push_back([&]() { + printer << "flags = "; + printer.printStrippedAttrOrType(getFlags()); + }); + } + + if (getSizeInBits() > 0) { + valuePrinters.push_back([&]() { + printer << "sizeInBits = "; + printer.printStrippedAttrOrType(getSizeInBits()); + }); + } + + if (getAlignInBits() > 0) { + valuePrinters.push_back([&]() { + printer << "alignInBits = "; + printer.printStrippedAttrOrType(getAlignInBits()); + }); + } + interleaveComma(valuePrinters, printer, + [&](const std::function &fn) { fn(); }); + + if (!getElements().empty()) { + printer << " ("; + printer.printStrippedAttrOrType(getElements()); + printer << ")"; + } + printer << ">"; +} + +DICompositeTypeAttr DICompositeTypeAttr::getIdentified(MLIRContext *context, + StringAttr identifier) { + return Base::get(context, 0, StringAttr(), DIFileAttr(), 0, DIScopeAttr(), + DITypeAttr(), DIFlags::Zero, 0, 0, ArrayRef(), + identifier); +} + +void DICompositeTypeAttr::replaceElements( + const ArrayRef &elements) { + (void)Base::mutate(elements); +} diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index 53e1088f620d7..70e8a6ce20858 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -2933,14 +2933,14 @@ struct LLVMOpAsmDialectInterface : public OpAsmDialectInterface { return TypeSwitch(attr) .Case([&](auto attr) { os << decltype(attr)::getMnemonic(); return AliasResult::OverridableAlias; diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 1f7cbf349255d..220fd1fbdd558 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -744,6 +744,7 @@ class DummyAliasOperationPrinter : private OpAsmPrinter { printAttribute(attr); } LogicalResult printAlias(Attribute attr) override { + initializer.visit(attr); return success(); } diff --git a/mlir/lib/Interfaces/DataLayoutInterfaces.cpp b/mlir/lib/Interfaces/DataLayoutInterfaces.cpp index 1178417fd2a6c..ce7f6371cb150 100644 --- a/mlir/lib/Interfaces/DataLayoutInterfaces.cpp +++ b/mlir/lib/Interfaces/DataLayoutInterfaces.cpp @@ -35,9 +35,22 @@ using namespace mlir; /// Returns the bitwidth of the index type if specified in the param list. /// Assumes 64-bit index otherwise. static uint64_t getIndexBitwidth(DataLayoutEntryListRef params) { - if (params.empty()) + DataLayoutEntryInterface entry; + + // Look up the bitwidth param in the list. + for (DataLayoutEntryInterface param : params) { + if (param.getKey().is() && + mlir::isa(param.getKey().get())) + entry = param; + } + + // No corresponding entry was found, so assume the bitwidth is 64-bit. + if (!entry) return 64; - auto attr = cast(params.front().getValue()); + + // The expected attribute is a IntegerAttr. Cast to it and retreive the + // bitwidth value. + auto attr = cast(entry.getValue()); return attr.getValue().getZExtValue(); } @@ -86,16 +99,20 @@ mlir::detail::getDefaultTypeSizeInBits(Type type, const DataLayout &dataLayout, reportMissingDataLayout(type); } +template static DataLayoutEntryInterface -findEntryForIntegerType(IntegerType intType, +findEntryForType(T type, ArrayRef params) { assert(!params.empty() && "expected non-empty parameter list"); std::map sortedParams; for (DataLayoutEntryInterface entry : params) { - sortedParams.insert(std::make_pair( - entry.getKey().get().getIntOrFloatBitWidth(), entry)); + // Filter the params by integer type. + if (entry.getKey().is() && + mlir::isa(entry.getKey().get())) + sortedParams.insert(std::make_pair( + entry.getKey().get().getIntOrFloatBitWidth(), entry)); } - auto iter = sortedParams.lower_bound(intType.getWidth()); + auto iter = sortedParams.lower_bound(type.getWidth()); if (iter == sortedParams.end()) iter = std::prev(iter); @@ -122,17 +139,15 @@ getIntegerTypeABIAlignment(IntegerType intType, : kDefaultSmallIntAlignment; } - return extractABIAlignment(findEntryForIntegerType(intType, params)); + return extractABIAlignment(findEntryForType(intType, params)); } static uint64_t getFloatTypeABIAlignment(FloatType fltType, const DataLayout &dataLayout, ArrayRef params) { - assert(params.size() <= 1 && "at most one data layout entry is expected for " - "the singleton floating-point type"); if (params.empty()) return llvm::PowerOf2Ceil(dataLayout.getTypeSize(fltType).getFixedValue()); - return extractABIAlignment(params[0]); + return extractABIAlignment(findEntryForType(fltType, params)); } uint64_t mlir::detail::getDefaultABIAlignment( @@ -175,17 +190,15 @@ getIntegerTypePreferredAlignment(IntegerType intType, if (params.empty()) return llvm::PowerOf2Ceil(dataLayout.getTypeSize(intType).getFixedValue()); - return extractPreferredAlignment(findEntryForIntegerType(intType, params)); + return extractPreferredAlignment(findEntryForType(intType, params)); } static uint64_t getFloatTypePreferredAlignment(FloatType fltType, const DataLayout &dataLayout, ArrayRef params) { - assert(params.size() <= 1 && "at most one data layout entry is expected for " - "the singleton floating-point type"); if (params.empty()) return dataLayout.getTypeABIAlignment(fltType); - return extractPreferredAlignment(params[0]); + return extractPreferredAlignment(findEntryForType(fltType, params)); } uint64_t mlir::detail::getDefaultPreferredAlignment( diff --git a/mlir/lib/Target/LLVMIR/DebugTranslation.cpp b/mlir/lib/Target/LLVMIR/DebugTranslation.cpp index 16918aab54978..1f73d16b8ead4 100644 --- a/mlir/lib/Target/LLVMIR/DebugTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/DebugTranslation.cpp @@ -118,10 +118,6 @@ static DINodeT *getDistinctOrUnique(bool isDistinct, Ts &&...args) { llvm::DICompositeType * DebugTranslation::translateImpl(DICompositeTypeAttr attr) { - SmallVector elements; - for (auto member : attr.getElements()) - elements.push_back(translate(member)); - // TODO: Use distinct attributes to model this, once they have landed. // Depending on the tag, composite types must be distinct. bool isDistinct = false; @@ -133,15 +129,31 @@ DebugTranslation::translateImpl(DICompositeTypeAttr attr) { isDistinct = true; } - return getDistinctOrUnique( + // Create the composite type metadata first with an empty set of elements. + llvm::DICompositeType *result = getDistinctOrUnique( isDistinct, llvmCtx, attr.getTag(), getMDStringOrNull(attr.getName()), translate(attr.getFile()), attr.getLine(), translate(attr.getScope()), translate(attr.getBaseType()), attr.getSizeInBits(), attr.getAlignInBits(), /*OffsetInBits=*/0, /*Flags=*/static_cast(attr.getFlags()), - llvm::MDNode::get(llvmCtx, elements), + llvm::MDNode::get(llvmCtx, {}), /*RuntimeLang=*/0, /*VTableHolder=*/nullptr); + + // Short-circuit the mapping for this attribute to prevent infinite recursion + // if this composite type is encountered while translating the elements. + attrToNode[attr] = result; + + // Translate the elements. + SmallVector elements; + for (const DINodeAttr member : attr.getElements()) + elements.push_back(translate(member)); + + // Replace the elements in the resulting metadata. + result->replaceElements(llvm::MDTuple::get(llvmCtx, elements)); + + // Return the composite type. + return result; } llvm::DIDerivedType *DebugTranslation::translateImpl(DIDerivedTypeAttr attr) { diff --git a/mlir/test/Dialect/LLVMIR/debuginfo.mlir b/mlir/test/Dialect/LLVMIR/debuginfo.mlir index 53c38b4797031..c9724db7196ba 100644 --- a/mlir/test/Dialect/LLVMIR/debuginfo.mlir +++ b/mlir/test/Dialect/LLVMIR/debuginfo.mlir @@ -36,6 +36,11 @@ tag = DW_TAG_pointer_type, name = "ptr1" > +// CHECK-DAG: #[[COMP3:.*]] = #llvm.di_composite_type<"mystruct", tag = DW_TAG_structure_type, name = "array1", file = #[[FILE]], scope = #[[FILE]] (#llvm.di_composite_type<"mystruct">, #int0, #int1)> +#comp3 = #llvm.di_composite_type<"mystruct", tag = DW_TAG_structure_type, name = "struct1", + file = #file, scope = #file (#llvm.di_composite_type<"mystruct">, #int0, #int1) +> + // CHECK-DAG: #[[COMP0:.*]] = #llvm.di_composite_type #comp0 = #llvm.di_composite_type< tag = DW_TAG_array_type, name = "array0", @@ -45,9 +50,9 @@ // CHECK-DAG: #[[COMP1:.*]] = #llvm.di_composite_type> #comp1 = #llvm.di_composite_type< tag = DW_TAG_array_type, name = "array1", file = #file, - scope = #file, baseType = #int0, + scope = #file, baseType = #int0 // Specify the subrange count. - elements = #llvm.di_subrange + (#llvm.di_subrange) > // CHECK-DAG: #[[TOPLEVEL:.*]] = #llvm.di_namespace @@ -74,7 +79,7 @@ // CHECK-DAG: #[[SPTYPE0:.*]] = #llvm.di_subroutine_type #spType0 = #llvm.di_subroutine_type< - callingConvention = DW_CC_normal, types = #null, #int0, #ptr0, #ptr1, #comp0, #comp1, #comp2 + callingConvention = DW_CC_normal, types = #null, #int0, #ptr0, #ptr1, #comp0, #comp1, #comp2, #comp3 > // CHECK-DAG: #[[SPTYPE1:.*]] = #llvm.di_subroutine_type