diff --git a/mlir/docs/LangRef.md b/mlir/docs/LangRef.md index 0cfe845638c3c..a00f9f76c4253 100644 --- a/mlir/docs/LangRef.md +++ b/mlir/docs/LangRef.md @@ -183,12 +183,14 @@ starting with a `//` and going until the end of the line. ``` // Top level production -toplevel := (operation | attribute-alias-def | type-alias-def)* +toplevel := (operation | alias-block-def)* +alias-block-def := (attribute-alias-def | type-alias-def)* ``` The production `toplevel` is the top level production that is parsed by any parsing -consuming the MLIR syntax. [Operations](#operations), -[Attribute aliases](#attribute-value-aliases), and [Type aliases](#type-aliases) +consuming the MLIR syntax. [Operations](#operations) and +[Alias Blocks](#alias-block-definitions) consisting of +[Attribute aliases](#attribute-value-aliases) and [Type aliases](#type-aliases) can be declared on the toplevel. ### Identifiers and keywords @@ -880,3 +882,26 @@ version using readAttribute and readType methods. There is no restriction on what kind of information a dialect is allowed to encode to model its versioning. Currently, versioning is supported only for bytecode formats. + +## Alias Block Definitions + +An alias block is a list of subsequent attribute or type alias definitions that +are conceptually parsed as one unit. +This allows any alias definition within the block to reference any other alias +definition within the block, regardless if defined lexically later or earlier in +the block. + +```mlir +// Alias block consisting of #array, !integer_type and #integer_attr. +#array = [#integer_attr, !integer_type] +!integer_type = i32 +#integer_attr = 8 : !integer_type + +// Illegal. !other_type is not part of this alias block and defined later +// in the file. +!tuple = tuple + +func.func @foo() { ... } + +!other_type = f32 +``` diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h index 8864ef02cd3cb..0893744c6a119 100644 --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -1365,7 +1365,7 @@ class AsmParser { AttrOrTypeT> || std::is_base_of_v, AttrOrTypeT>, "Only mutable attributes or types can be cyclic"); - if (failed(pushCyclicParsing(attrOrType.getAsOpaquePointer()))) + if (failed(pushCyclicParsing(attrOrType))) return failure(); return CyclicParseReset(this); @@ -1377,11 +1377,11 @@ class AsmParser { virtual FailureOr parseResourceHandle(Dialect *dialect) = 0; - /// Pushes a new attribute or type in the form of a type erased pointer - /// into an internal set. + /// Pushes a new attribute or type into an internal set. /// Returns success if the type or attribute was inserted in the set or /// failure if it was already contained. - virtual LogicalResult pushCyclicParsing(const void *opaquePointer) = 0; + virtual LogicalResult + pushCyclicParsing(PointerUnion attrOrType) = 0; /// Removes the element that was last inserted with a successful call to /// `pushCyclicParsing`. There must be exactly one `popCyclicParsing` call diff --git a/mlir/lib/AsmParser/AsmParserImpl.h b/mlir/lib/AsmParser/AsmParserImpl.h index 30c0079cda086..1b88ca240c0eb 100644 --- a/mlir/lib/AsmParser/AsmParserImpl.h +++ b/mlir/lib/AsmParser/AsmParserImpl.h @@ -570,8 +570,9 @@ class AsmParserImpl : public BaseT { return parser.parseXInDimensionList(); } - LogicalResult pushCyclicParsing(const void *opaquePointer) override { - return success(parser.getState().cyclicParsingStack.insert(opaquePointer)); + LogicalResult + pushCyclicParsing(PointerUnion attrOrType) override { + return success(parser.getState().cyclicParsingStack.insert(attrOrType)); } void popCyclicParsing() override { diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp index 3437ac9addc5f..0ed385de8cf3b 100644 --- a/mlir/lib/AsmParser/AttributeParser.cpp +++ b/mlir/lib/AsmParser/AttributeParser.cpp @@ -49,7 +49,7 @@ using namespace mlir::detail; /// | distinct-attribute /// | extended-attribute /// -Attribute Parser::parseAttribute(Type type) { +Attribute Parser::parseAttribute(Type type, StringRef aliasDefName) { switch (getToken().getKind()) { // Parse an AffineMap or IntegerSet attribute. case Token::kw_affine_map: { @@ -117,7 +117,7 @@ Attribute Parser::parseAttribute(Type type) { // Parse an extended attribute, i.e. alias or dialect attribute. case Token::hash_identifier: - return parseExtendedAttr(type); + return parseExtendedAttr(type, aliasDefName); // Parse floating point and integer attributes. case Token::floatliteral: @@ -145,6 +145,10 @@ Attribute Parser::parseAttribute(Type type) { parseLocationInstance(locAttr) || parseToken(Token::r_paren, "expected ')' in inline location")) return Attribute(); + + if (syntaxOnly()) + return state.syntaxOnlyAttr; + return locAttr; } @@ -430,6 +434,9 @@ Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) { return FloatAttr::get(floatType, *result); } + if (syntaxOnly()) + return state.syntaxOnlyAttr; + if (!isa(type)) return emitError(loc, "integer literal not valid for specified type"), nullptr; @@ -1003,7 +1010,9 @@ Attribute Parser::parseDenseElementsAttr(Type attrType) { auto type = parseElementsLiteralType(attrType); if (!type) return nullptr; - return literalParser.getAttr(loc, type); + if (syntaxOnly()) + return state.syntaxOnlyAttr; + return literalParser.getAttr(loc, cast(type)); } Attribute Parser::parseDenseResourceElementsAttr(Type attrType) { @@ -1030,6 +1039,9 @@ Attribute Parser::parseDenseResourceElementsAttr(Type attrType) { return nullptr; } + if (syntaxOnly()) + return state.syntaxOnlyAttr; + ShapedType shapedType = dyn_cast(attrType); if (!shapedType) { emitError(typeLoc, "`dense_resource` expected a shaped type"); @@ -1044,7 +1056,7 @@ Attribute Parser::parseDenseResourceElementsAttr(Type attrType) { /// elements-literal-type ::= vector-type | ranked-tensor-type /// /// This method also checks the type has static shape. -ShapedType Parser::parseElementsLiteralType(Type type) { +Type Parser::parseElementsLiteralType(Type type) { // If the user didn't provide a type, parse the colon type for the literal. if (!type) { if (parseToken(Token::colon, "expected ':'")) @@ -1053,6 +1065,9 @@ ShapedType Parser::parseElementsLiteralType(Type type) { return nullptr; } + if (syntaxOnly()) + return state.syntaxOnlyType; + auto sType = dyn_cast(type); if (!sType) { emitError("elements literal must be a shaped type"); @@ -1077,17 +1092,23 @@ Attribute Parser::parseSparseElementsAttr(Type attrType) { // of the type. Type indiceEltType = builder.getIntegerType(64); if (consumeIf(Token::greater)) { - ShapedType type = parseElementsLiteralType(attrType); + Type type = parseElementsLiteralType(attrType); if (!type) return nullptr; + if (syntaxOnly()) + return state.syntaxOnlyAttr; + // Construct the sparse elements attr using zero element indice/value // attributes. + ShapedType shapedType = cast(type); ShapedType indicesType = - RankedTensorType::get({0, type.getRank()}, indiceEltType); - ShapedType valuesType = RankedTensorType::get({0}, type.getElementType()); + RankedTensorType::get({0, shapedType.getRank()}, indiceEltType); + ShapedType valuesType = + RankedTensorType::get({0}, shapedType.getElementType()); return getChecked( - loc, type, DenseElementsAttr::get(indicesType, ArrayRef()), + loc, shapedType, + DenseElementsAttr::get(indicesType, ArrayRef()), DenseElementsAttr::get(valuesType, ArrayRef())); } @@ -1114,6 +1135,11 @@ Attribute Parser::parseSparseElementsAttr(Type attrType) { if (!type) return nullptr; + if (syntaxOnly()) + return state.syntaxOnlyAttr; + + ShapedType shapedType = cast(type); + // If the indices are a splat, i.e. the literal parser parsed an element and // not a list, we set the shape explicitly. The indices are represented by a // 2-dimensional shape where the second dimension is the rank of the type. @@ -1121,7 +1147,8 @@ Attribute Parser::parseSparseElementsAttr(Type attrType) { // indice and thus one for the first dimension. ShapedType indicesType; if (indiceParser.getShape().empty()) { - indicesType = RankedTensorType::get({1, type.getRank()}, indiceEltType); + indicesType = + RankedTensorType::get({1, shapedType.getRank()}, indiceEltType); } else { // Otherwise, set the shape to the one parsed by the literal parser. indicesType = RankedTensorType::get(indiceParser.getShape(), indiceEltType); @@ -1131,7 +1158,7 @@ Attribute Parser::parseSparseElementsAttr(Type attrType) { // If the values are a splat, set the shape explicitly based on the number of // indices. The number of indices is encoded in the first dimension of the // indice shape type. - auto valuesEltType = type.getElementType(); + auto valuesEltType = shapedType.getElementType(); ShapedType valuesType = valuesParser.getShape().empty() ? RankedTensorType::get({indicesType.getDimSize(0)}, valuesEltType) @@ -1139,7 +1166,7 @@ Attribute Parser::parseSparseElementsAttr(Type attrType) { auto values = valuesParser.getAttr(valuesLoc, valuesType); // Build the sparse elements attribute by the indices and values. - return getChecked(loc, type, indices, values); + return getChecked(loc, shapedType, indices, values); } Attribute Parser::parseStridedLayoutAttr() { @@ -1260,6 +1287,9 @@ Attribute Parser::parseDistinctAttr(Type type) { return {}; } + if (syntaxOnly()) + return state.syntaxOnlyAttr; + // Add the distinct attribute to the parser state, if it has not been parsed // before. Otherwise, check if the parsed reference attribute matches the one // found in the parser state. diff --git a/mlir/lib/AsmParser/DialectSymbolParser.cpp b/mlir/lib/AsmParser/DialectSymbolParser.cpp index 2b1b114b90e86..2d78db35433ee 100644 --- a/mlir/lib/AsmParser/DialectSymbolParser.cpp +++ b/mlir/lib/AsmParser/DialectSymbolParser.cpp @@ -15,6 +15,7 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/DialectImplementation.h" +#include "llvm/ADT/ScopeExit.h" #include "llvm/Support/SourceMgr.h" using namespace mlir; @@ -28,18 +29,37 @@ namespace { /// hooking into the main MLIR parsing logic. class CustomDialectAsmParser : public AsmParserImpl { public: - CustomDialectAsmParser(StringRef fullSpec, Parser &parser) + CustomDialectAsmParser(StringRef fullSpec, Parser &parser, + StringRef aliasDefName) : AsmParserImpl(parser.getToken().getLoc(), parser), - fullSpec(fullSpec) {} + fullSpec(fullSpec), aliasDefName(aliasDefName) {} ~CustomDialectAsmParser() override = default; /// Returns the full specification of the symbol being parsed. This allows /// for using a separate parser if necessary. StringRef getFullSymbolSpec() const override { return fullSpec; } + LogicalResult + pushCyclicParsing(PointerUnion attrOrType) override { + // If this is an alias definition, register the mutable attribute or type. + if (!aliasDefName.empty()) { + if (auto attr = dyn_cast(attrOrType)) + parser.getState().symbols.attributeAliasDefinitions[aliasDefName] = + attr; + else + parser.getState().symbols.typeAliasDefinitions[aliasDefName] = + cast(attrOrType); + } + return AsmParserImpl::pushCyclicParsing(attrOrType); + } + private: /// The full symbol specification. StringRef fullSpec; + + /// If this parser is used to parse an alias definition, the name of the alias + /// definition. Empty otherwise. + StringRef aliasDefName; }; } // namespace @@ -156,9 +176,11 @@ ParseResult Parser::parseDialectSymbolBody(StringRef &body, } /// Parse an extended dialect symbol. -template +template static Symbol parseExtendedSymbol(Parser &p, AsmParserState *asmState, SymbolAliasMap &aliases, + ParseAliasFn &parseAliasFn, CreateFn &&createSymbol) { Token tok = p.getToken(); @@ -185,12 +207,32 @@ static Symbol parseExtendedSymbol(Parser &p, AsmParserState *asmState, // If there is no '<' token following this, and if the typename contains no // dot, then we are parsing a symbol alias. if (!hasTrailingData && !isPrettyName) { + + // Don't check the validity of alias reference in syntax-only mode. + if (p.syntaxOnly()) { + if constexpr (std::is_same_v) + return p.getState().syntaxOnlyType; + else + return p.getState().syntaxOnlyAttr; + } + // Check for an alias for this type. auto aliasIt = aliases.find(identifier); - if (aliasIt == aliases.end()) - return (p.emitWrongTokenError("undefined symbol alias id '" + identifier + - "'"), - nullptr); + if (aliasIt == aliases.end()) { + FailureOr symbol = failure(); + // Try the parse alias function if set. + if (parseAliasFn) + symbol = parseAliasFn(identifier); + + if (failed(symbol)) { + p.emitWrongTokenError("undefined symbol alias id '" + identifier + "'"); + return nullptr; + } + if (!*symbol) + return nullptr; + + aliasIt = aliases.insert({identifier, *symbol}).first; + } if (asmState) { if constexpr (std::is_same_v) asmState->addTypeAliasUses(identifier, range); @@ -237,16 +279,20 @@ static Symbol parseExtendedSymbol(Parser &p, AsmParserState *asmState, /// | `#` alias-name pretty-dialect-sym-body? (`:` type)? /// attribute-alias ::= `#` alias-name /// -Attribute Parser::parseExtendedAttr(Type type) { +Attribute Parser::parseExtendedAttr(Type type, StringRef aliasDefName) { MLIRContext *ctx = getContext(); Attribute attr = parseExtendedSymbol( *this, state.asmState, state.symbols.attributeAliasDefinitions, + state.symbols.parseUnknownAttributeAlias, [&](StringRef dialectName, StringRef symbolData, SMLoc loc) -> Attribute { // Parse an optional trailing colon type. Type attrType = type; if (consumeIf(Token::colon) && !(attrType = parseType())) return Attribute(); + if (syntaxOnly()) + return state.syntaxOnlyAttr; + // If we found a registered dialect, then ask it to parse the attribute. if (Dialect *dialect = builder.getContext()->getOrLoadDialect(dialectName)) { @@ -255,7 +301,7 @@ Attribute Parser::parseExtendedAttr(Type type) { resetToken(symbolData.data()); // Parse the attribute. - CustomDialectAsmParser customParser(symbolData, *this); + CustomDialectAsmParser customParser(symbolData, *this, aliasDefName); Attribute attr = dialect->parseAttribute(customParser, attrType); resetToken(curLexerPos); return attr; @@ -284,11 +330,15 @@ Attribute Parser::parseExtendedAttr(Type type) { /// dialect-type ::= `!` alias-name pretty-dialect-attribute-body? /// type-alias ::= `!` alias-name /// -Type Parser::parseExtendedType() { +Type Parser::parseExtendedType(StringRef aliasDefName) { MLIRContext *ctx = getContext(); return parseExtendedSymbol( *this, state.asmState, state.symbols.typeAliasDefinitions, + state.symbols.parseUnknownTypeAlias, [&](StringRef dialectName, StringRef symbolData, SMLoc loc) -> Type { + if (syntaxOnly()) + return state.syntaxOnlyType; + // If we found a registered dialect, then ask it to parse the type. if (auto *dialect = ctx->getOrLoadDialect(dialectName)) { // Temporarily reset the lexer to let the dialect parse the type. @@ -296,7 +346,7 @@ Type Parser::parseExtendedType() { resetToken(symbolData.data()); // Parse the type. - CustomDialectAsmParser customParser(symbolData, *this); + CustomDialectAsmParser customParser(symbolData, *this, aliasDefName); Type type = dialect->parseType(customParser); resetToken(curLexerPos); return type; diff --git a/mlir/lib/AsmParser/LocationParser.cpp b/mlir/lib/AsmParser/LocationParser.cpp index 61b20179800c6..8139f188c32a7 100644 --- a/mlir/lib/AsmParser/LocationParser.cpp +++ b/mlir/lib/AsmParser/LocationParser.cpp @@ -53,6 +53,9 @@ ParseResult Parser::parseCallSiteLocation(LocationAttr &loc) { if (parseToken(Token::r_paren, "expected ')' in callsite location")) return failure(); + if (syntaxOnly()) + return success(); + // Return the callsite location. loc = CallSiteLoc::get(calleeLoc, callerLoc); return success(); @@ -79,6 +82,9 @@ ParseResult Parser::parseFusedLocation(LocationAttr &loc) { LocationAttr newLoc; if (parseLocationInstance(newLoc)) return failure(); + if (syntaxOnly()) + return success(); + locations.push_back(newLoc); return success(); }; @@ -135,12 +141,15 @@ ParseResult Parser::parseNameOrFileLineColLocation(LocationAttr &loc) { if (parseLocationInstance(childLoc)) return failure(); - loc = NameLoc::get(StringAttr::get(ctx, str), childLoc); - // Parse the closing ')'. if (parseToken(Token::r_paren, "expected ')' after child location of NameLoc")) return failure(); + + if (syntaxOnly()) + return success(); + + loc = NameLoc::get(StringAttr::get(ctx, str), childLoc); } else { loc = NameLoc::get(StringAttr::get(ctx, str)); } @@ -154,6 +163,10 @@ ParseResult Parser::parseLocationInstance(LocationAttr &loc) { Attribute locAttr = parseExtendedAttr(Type()); if (!locAttr) return failure(); + + if (syntaxOnly()) + return success(); + if (!(loc = dyn_cast(locAttr))) return emitError("expected location attribute, but got") << locAttr; return success(); diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp index 84f44dba806df..4bae575c1c82a 100644 --- a/mlir/lib/AsmParser/Parser.cpp +++ b/mlir/lib/AsmParser/Parser.cpp @@ -29,6 +29,7 @@ #include "llvm/Support/Endian.h" #include "llvm/Support/MathExtras.h" #include "llvm/Support/PrettyStackTrace.h" +#include "llvm/Support/SaveAndRestore.h" #include "llvm/Support/SourceMgr.h" #include #include @@ -2417,17 +2418,12 @@ class TopLevelOperationParser : public Parser { ParseResult parse(Block *topLevelBlock, Location parserLoc); private: - /// Parse an attribute alias declaration. + /// Parse an alias block definition. /// /// attribute-alias-def ::= '#' alias-name `=` attribute-value - /// - ParseResult parseAttributeAliasDef(); - - /// Parse a type alias declaration. - /// /// type-alias-def ::= '!' alias-name `=` type - /// - ParseResult parseTypeAliasDef(); + /// alias-block-def ::= (type-alias-def | attribute-alias-def)+ + ParseResult parseAliasBlockDef(); /// Parse a top-level file metadata dictionary. /// @@ -2528,69 +2524,184 @@ class ParsedResourceEntry : public AsmParsedResourceEntry { Token value; Parser &p; }; + +/// Convenient subclass of `ParserState` which configures the parser for +/// syntax-only parsing. This only copies config and other required state for +/// parsing but does not copy side-effecting state such as the code completion +/// context. +struct SyntaxParserState : ParserState { + explicit SyntaxParserState(ParserState &state) + : ParserState(state.lex.getSourceMgr(), state.config, state.symbols, + /*asmState=*/nullptr, + /*codeCompleteContext=*/nullptr) { + syntaxOnly = true; + } +}; + } // namespace -ParseResult TopLevelOperationParser::parseAttributeAliasDef() { - assert(getToken().is(Token::hash_identifier)); - StringRef aliasName = getTokenSpelling().drop_front(); +ParseResult TopLevelOperationParser::parseAliasBlockDef() { - // Check for redefinitions. - if (state.symbols.attributeAliasDefinitions.count(aliasName) > 0) - return emitError("redefinition of attribute alias id '" + aliasName + "'"); + struct UnparsedData { + SMRange location; + StringRef text; + }; - // Make sure this isn't invading the dialect attribute namespace. - if (aliasName.contains('.')) - return emitError("attribute names with a '.' are reserved for " - "dialect-defined names"); + // Use a map vector as StringMap has non-deterministic iteration order. + using StringMapVector = + llvm::MapVector>; - SMRange location = getToken().getLocRange(); - consumeToken(Token::hash_identifier); + StringMapVector unparsedAttributeAliases; + StringMapVector unparsedTypeAliases; - // Parse the '='. - if (parseToken(Token::equal, "expected '=' in attribute alias definition")) - return failure(); + // Returns true if this alias has already been defined, either in this block + // or a previous one. + auto isRedefinition = [&](bool isType, StringRef aliasName) { + if (isType) + return state.symbols.typeAliasDefinitions.contains(aliasName) || + unparsedTypeAliases.contains(aliasName); - // Parse the attribute value. - Attribute attr = parseAttribute(); - if (!attr) - return failure(); + return state.symbols.attributeAliasDefinitions.contains(aliasName) || + unparsedAttributeAliases.contains(aliasName); + }; - // Register this alias with the parser state. - if (state.asmState) - state.asmState->addAttrAliasDefinition(aliasName, location, attr); - state.symbols.attributeAliasDefinitions[aliasName] = attr; - return success(); -} + // Collect all attribute or type alias definitions in unparsed form first. + while ( + getToken().isAny(Token::exclamation_identifier, Token::hash_identifier)) { + StringRef aliasName = getTokenSpelling().drop_front(); + + bool isType = getToken().is(mlir::Token::exclamation_identifier); + StringRef kind = isType ? "type" : "attribute"; + + // Check for redefinitions. + if (isRedefinition(isType, aliasName)) + return emitError("redefinition of ") + << kind << " alias id '" << aliasName << "'"; + + // Make sure this isn't invading the dialect namespace. + if (aliasName.contains('.')) + return emitError(kind) << " names with a '.' are reserved for " + "dialect-defined names"; + + SMRange location = getToken().getLocRange(); + consumeToken(); + + // Parse the '='. + if (parseToken(Token::equal, + "expected '=' in " + kind + " alias definition")) + return failure(); -ParseResult TopLevelOperationParser::parseTypeAliasDef() { - assert(getToken().is(Token::exclamation_identifier)); - StringRef aliasName = getTokenSpelling().drop_front(); + SyntaxParserState skippingParserState(state); + Parser syntaxOnlyParser(skippingParserState); + const char *start = getToken().getLoc().getPointer(); + syntaxOnlyParser.resetToken(start); - // Check for redefinitions. - if (state.symbols.typeAliasDefinitions.count(aliasName) > 0) - return emitError("redefinition of type alias id '" + aliasName + "'"); + // Parse just the syntax of the value, moving the lexer past the definition. + if (isType ? !syntaxOnlyParser.parseType() + : !syntaxOnlyParser.parseAttribute()) + return failure(); + + // Get the location from the lexers new position. + const char *end = syntaxOnlyParser.getToken().getLoc().getPointer(); + size_t length = end - start; - // Make sure this isn't invading the dialect type namespace. - if (aliasName.contains('.')) - return emitError("type names with a '.' are reserved for " - "dialect-defined names"); + StringMapVector &unparsedMap = + isType ? unparsedTypeAliases : unparsedAttributeAliases; - SMRange location = getToken().getLocRange(); - consumeToken(Token::exclamation_identifier); + unparsedMap[aliasName] = + UnparsedData{location, StringRef(start, length).rtrim()}; + + // Move the top-level parser past the alias definition. + resetToken(end); + } + + auto parseAttributeAlias = [&](StringRef aliasName, + const UnparsedData &unparsedData) { + llvm::SaveAndRestore>> cyclicStack( + getState().cyclicParsingStack, {}); + auto exit = saveAndResetToken(unparsedData.text.data()); + Attribute attribute = parseAttribute(Type(), aliasName); + if (!attribute) + return attribute; + + // Register this alias with the parser state. + if (state.asmState) + state.asmState->addAttrAliasDefinition(aliasName, unparsedData.location, + attribute); - // Parse the '='. - if (parseToken(Token::equal, "expected '=' in type alias definition")) + return attribute; + }; + + auto parseTypeAlias = [&](StringRef aliasName, + const UnparsedData &unparsedData) { + llvm::SaveAndRestore>> cyclicStack( + getState().cyclicParsingStack, {}); + auto exit = saveAndResetToken(unparsedData.text.data()); + Type type = parseType(aliasName); + if (!type) + return type; + + // Register this alias with the parser state. + if (state.asmState) + state.asmState->addTypeAliasDefinition(aliasName, unparsedData.location, + type); + + return type; + }; + + // Set the callbacks for the lazy parsing of alias definitions in the parser. + state.symbols.parseUnknownAttributeAlias = + [&](StringRef aliasName) -> FailureOr { + auto *iter = unparsedAttributeAliases.find(aliasName); + if (iter == unparsedAttributeAliases.end()) + return failure(); + + return parseAttributeAlias(aliasName, iter->second); + }; + state.symbols.parseUnknownTypeAlias = + [&](StringRef aliasName) -> FailureOr { + auto *iter = unparsedTypeAliases.find(aliasName); + if (iter == unparsedTypeAliases.end()) + return failure(); + + return parseTypeAlias(aliasName, iter->second); + }; + + // Reset them to nullptr at the end. Keeping them around would lead to the + // access of local variables captured in this scope after we've returned from + // this function. + auto exit = llvm::make_scope_exit([&] { + state.symbols.parseUnknownTypeAlias = nullptr; + state.symbols.parseUnknownAttributeAlias = nullptr; + }); + + // Now go through all the unparsed definitions in the block and parse them. + // The order here is not significant for correctness, but should be + // deterministic. The order can also have an impact on the maximum stack usage + // during parsing. This can be improved in the future. + auto parse = [](auto &unparsed, auto &definitions, auto &parseFn) { + for (auto &&[aliasName, unparsedData] : unparsed) { + // Avoid parsing twice. + if (definitions.contains(aliasName)) + continue; + + auto symbol = parseFn(aliasName, unparsedData); + if (!symbol) + return failure(); + definitions[aliasName] = symbol; + } + return success(); + }; + + if (failed(parse(unparsedAttributeAliases, + state.symbols.attributeAliasDefinitions, + parseAttributeAlias))) return failure(); - // Parse the type. - Type aliasedType = parseType(); - if (!aliasedType) + if (failed(parse(unparsedTypeAliases, state.symbols.typeAliasDefinitions, + parseTypeAlias))) return failure(); - // Register this alias with the parser state. - if (state.asmState) - state.asmState->addTypeAliasDefinition(aliasName, location, aliasedType); - state.symbols.typeAliasDefinitions.try_emplace(aliasName, aliasedType); return success(); } @@ -2729,15 +2840,10 @@ ParseResult TopLevelOperationParser::parse(Block *topLevelBlock, case Token::error: return failure(); - // Parse an attribute alias. + // Parse an alias def block. case Token::hash_identifier: - if (parseAttributeAliasDef()) - return failure(); - break; - - // Parse a type alias. case Token::exclamation_identifier: - if (parseTypeAliasDef()) + if (parseAliasBlockDef()) return failure(); break; diff --git a/mlir/lib/AsmParser/Parser.h b/mlir/lib/AsmParser/Parser.h index 01c55f97a08c2..0515d9bf95626 100644 --- a/mlir/lib/AsmParser/Parser.h +++ b/mlir/lib/AsmParser/Parser.h @@ -12,6 +12,7 @@ #include "ParserState.h" #include "mlir/IR/Builders.h" #include "mlir/IR/OpImplementation.h" +#include "llvm/ADT/ScopeExit.h" #include namespace mlir { @@ -132,6 +133,22 @@ class Parser { state.curToken = state.lex.lexToken(); } + /// Temporarily resets the parser to the given lexer position. The previous + /// lexer position is saved and restored on destruction of the returned + /// object. + [[nodiscard]] auto saveAndResetToken(const char *tokPos) { + const char *previous = getToken().getLoc().getPointer(); + resetToken(tokPos); + return llvm::make_scope_exit([this, previous] { resetToken(previous); }); + } + + /// Returns true if the parser is in syntax-only mode. In this mode, the + /// parser only checks the syntactic validity of the parsed elements but does + /// not verify the correctness of the parsed data. Syntax-only mode is + /// currently only supported for attribute and type parsing and skips parsing + /// dialect attributes and types entirely. + bool syntaxOnly() { return state.syntaxOnly; } + /// Consume the specified token if present and return success. On failure, /// output a diagnostic and return failure. ParseResult parseToken(Token::Kind expectedToken, const Twine &message); @@ -185,13 +202,13 @@ class Parser { OptionalParseResult parseOptionalType(Type &type); /// Parse an arbitrary type. - Type parseType(); + Type parseType(StringRef aliasDefName = ""); /// Parse a complex type. Type parseComplexType(); /// Parse an extended type. - Type parseExtendedType(); + Type parseExtendedType(StringRef aliasDefName = ""); /// Parse a function type. Type parseFunctionType(); @@ -200,7 +217,7 @@ class Parser { Type parseMemRefType(); /// Parse a non function type. - Type parseNonFunctionType(); + Type parseNonFunctionType(StringRef aliasDefName = ""); /// Parse a tensor type. Type parseTensorType(); @@ -209,7 +226,7 @@ class Parser { Type parseTupleType(); /// Parse a vector type. - VectorType parseVectorType(); + Type parseVectorType(); ParseResult parseVectorDimensionList(SmallVectorImpl &dimensions, SmallVectorImpl &scalableDims); ParseResult parseDimensionListRanked(SmallVectorImpl &dimensions, @@ -223,7 +240,7 @@ class Parser { //===--------------------------------------------------------------------===// /// Parse an arbitrary attribute with an optional type. - Attribute parseAttribute(Type type = {}); + Attribute parseAttribute(Type type = {}, StringRef aliasDefName = ""); /// Parse an optional attribute with the provided type. OptionalParseResult parseOptionalAttribute(Attribute &attribute, @@ -254,7 +271,7 @@ class Parser { Attribute parseDistinctAttr(Type type); /// Parse an extended attribute. - Attribute parseExtendedAttr(Type type); + Attribute parseExtendedAttr(Type type, StringRef aliasDefName = ""); /// Parse a float attribute. Attribute parseFloatAttr(Type type, bool isNegative); @@ -265,7 +282,7 @@ class Parser { /// Parse a dense elements attribute. Attribute parseDenseElementsAttr(Type attrType); - ShapedType parseElementsLiteralType(Type type); + Type parseElementsLiteralType(Type type); /// Parse a dense resource elements attribute. Attribute parseDenseResourceElementsAttr(Type attrType); diff --git a/mlir/lib/AsmParser/ParserState.h b/mlir/lib/AsmParser/ParserState.h index 1428ea3a82cee..0166116ff0ba3 100644 --- a/mlir/lib/AsmParser/ParserState.h +++ b/mlir/lib/AsmParser/ParserState.h @@ -32,6 +32,15 @@ struct SymbolState { /// A map from type alias identifier to Type. llvm::StringMap typeAliasDefinitions; + /// Parser functions set during the parsing of alias-block-defs to parse an + /// unknown attribute or type alias. The parameter is the name of the alias. + /// The function should return failure if no such alias could be found. + /// If any errors occurred during parsing, a null attribute or type should + /// be returned. + llvm::unique_function(StringRef)> + parseUnknownAttributeAlias; + llvm::unique_function(StringRef)> parseUnknownTypeAlias; + /// A map of dialect resource keys to the resolved resource name and handle /// to use during parsing. DenseMap cyclicParsingStack; + SetVector> cyclicParsingStack; /// An optional pointer to a struct containing high level parser state to be /// populated during parsing. @@ -88,6 +97,17 @@ struct ParserState { // popped when done. At the top-level we start with "builtin" as the // default, so that the top-level `module` operation parses as-is. SmallVector defaultDialectStack{"builtin"}; + + /// Controls whether the parser is in syntax-only mode. + bool syntaxOnly = false; + + /// Attribute and type returned by `parseType`, `parseAttribute` and the more + /// specific parsing function to signal syntactic correctness if an attribute + /// or type cannot be created without verifying the parsed data as well. + /// Callers of such function should only check for null or not null return + /// values for error signaling. + Type syntaxOnlyType = NoneType::get(config.getContext()); + Attribute syntaxOnlyAttr = UnitAttr::get(config.getContext()); }; } // namespace detail diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp index 306e850af27bc..b24cf7b7021f0 100644 --- a/mlir/lib/AsmParser/TypeParser.cpp +++ b/mlir/lib/AsmParser/TypeParser.cpp @@ -58,10 +58,10 @@ OptionalParseResult Parser::parseOptionalType(Type &type) { /// type ::= function-type /// | non-function-type /// -Type Parser::parseType() { +Type Parser::parseType(StringRef aliasDefName) { if (getToken().is(Token::l_paren)) return parseFunctionType(); - return parseNonFunctionType(); + return parseNonFunctionType(aliasDefName); } /// Parse a function result type. @@ -130,6 +130,10 @@ Type Parser::parseComplexType() { if (!elementType || parseToken(Token::greater, "expected '>' in complex type")) return nullptr; + + if (syntaxOnly()) + return state.syntaxOnlyType; + if (!isa(elementType) && !isa(elementType)) return emitError(elementTypeLoc, "invalid element type for complex"), nullptr; @@ -150,6 +154,9 @@ Type Parser::parseFunctionType() { parseFunctionResultTypes(results)) return nullptr; + if (syntaxOnly()) + return state.syntaxOnlyType; + return builder.getFunctionType(arguments, results); } @@ -195,9 +202,10 @@ Type Parser::parseMemRefType() { if (!elementType) return nullptr; - // Check that memref is formed from allowed types. - if (!BaseMemRefType::isValidElementType(elementType)) - return emitError(typeLoc, "invalid memref element type"), nullptr; + if (!syntaxOnly()) { // Check that memref is formed from allowed types. + if (!BaseMemRefType::isValidElementType(elementType)) + return emitError(typeLoc, "invalid memref element type"), nullptr; + } MemRefLayoutAttrInterface layout; Attribute memorySpace; @@ -208,6 +216,9 @@ Type Parser::parseMemRefType() { if (!attr) return failure(); + if (syntaxOnly()) + return success(); + if (isa(attr)) { layout = cast(attr); } else if (memorySpace) { @@ -235,6 +246,9 @@ Type Parser::parseMemRefType() { } } + if (syntaxOnly()) + return state.syntaxOnlyType; + if (isUnranked) return getChecked(loc, elementType, memorySpace); @@ -259,7 +273,7 @@ Type Parser::parseMemRefType() { /// float-type ::= `f16` | `bf16` | `f32` | `f64` | `f80` | `f128` /// none-type ::= `none` /// -Type Parser::parseNonFunctionType() { +Type Parser::parseNonFunctionType(StringRef aliasDefName) { switch (getToken().getKind()) { default: return (emitWrongTokenError("expected non-function type"), nullptr); @@ -342,7 +356,7 @@ Type Parser::parseNonFunctionType() { // extended type case Token::exclamation_identifier: - return parseExtendedType(); + return parseExtendedType(aliasDefName); // Handle completion of a dialect type. case Token::code_complete: @@ -437,7 +451,7 @@ Type Parser::parseTupleType() { /// vector-dim-list := (static-dim-list `x`)? (`[` static-dim-list `]` `x`)? /// static-dim-list ::= decimal-literal (`x` decimal-literal)* /// -VectorType Parser::parseVectorType() { +Type Parser::parseVectorType() { consumeToken(Token::kw_vector); if (parseToken(Token::less, "expected '<' in vector type")) @@ -458,6 +472,9 @@ VectorType Parser::parseVectorType() { if (!elementType || parseToken(Token::greater, "expected '>' in vector type")) return nullptr; + if (syntaxOnly()) + return state.syntaxOnlyType; + if (!VectorType::isValidElementType(elementType)) return emitError(typeLoc, "vector elements must be int/index/float type"), nullptr; diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 7b0da30541b16..6b0a08494ddf7 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -1192,21 +1192,10 @@ void AliasState::printAliases(AsmPrinter::Impl &p, NewLineCounter &newLine, alias.print(p.getStream()); p.getStream() << " = "; - if (alias.isTypeAlias()) { - // TODO: Support nested aliases in mutable types. - Type type = Type::getFromOpaquePointer(opaqueSymbol); - if (type.hasTrait()) - p.getStream() << type; - else - p.printTypeImpl(type); - } else { - // TODO: Support nested aliases in mutable attributes. - Attribute attr = Attribute::getFromOpaquePointer(opaqueSymbol); - if (attr.hasTrait()) - p.getStream() << attr; - else - p.printAttributeImpl(attr); - } + if (alias.isTypeAlias()) + p.printTypeImpl(Type::getFromOpaquePointer(opaqueSymbol)); + else + p.printAttributeImpl(Attribute::getFromOpaquePointer(opaqueSymbol)); p.getStream() << newLine; } diff --git a/mlir/test/Dialect/LLVMIR/types-typed-pointers.mlir b/mlir/test/Dialect/LLVMIR/types-typed-pointers.mlir index 2d63f379c2ee7..ac112f7745e51 100644 --- a/mlir/test/Dialect/LLVMIR/types-typed-pointers.mlir +++ b/mlir/test/Dialect/LLVMIR/types-typed-pointers.mlir @@ -106,7 +106,7 @@ func.func @ptr_elem_interface(%arg0: !llvm.ptr) { !baz = i64 !qux = !llvm.struct<(!baz)> -!rec = !llvm.struct<"a", (ptr>)> +!rec = !llvm.struct<"a", (ptr)> // CHECK: aliases llvm.func @aliases() { diff --git a/mlir/test/Dialect/SPIRV/IR/types.mlir b/mlir/test/Dialect/SPIRV/IR/types.mlir index e10a6fc77e856..ae50517eadd42 100644 --- a/mlir/test/Dialect/SPIRV/IR/types.mlir +++ b/mlir/test/Dialect/SPIRV/IR/types.mlir @@ -425,6 +425,11 @@ func.func private @id_struct_redefinition(!spirv.struct, Uniform>)>, Uniform>)>) func.func private @id_struct_recursive(!spirv.struct, Uniform>)>, Uniform>)>) -> () +!a = !spirv.struct)> +!b = !spirv.struct)> +// CHECK: func private @id_struct_recursive2(!spirv.struct, Uniform>)>, Uniform>)>) +func.func private @id_struct_recursive2(!a) -> () + // ----- // Equivalent to: @@ -433,6 +438,11 @@ func.func private @id_struct_recursive(!spirv.struct, Uniform>, !spirv.ptr, Uniform>)>, Uniform>)>) func.func private @id_struct_recursive(!spirv.struct, Uniform>, !spirv.ptr, Uniform>)>, Uniform>)>) -> () +!a = !spirv.struct)> +!b = !spirv.struct, !spirv.ptr)> +// CHECK: func private @id_struct_recursive2(!spirv.struct, Uniform>, !spirv.ptr, Uniform>)>, Uniform>)>) +func.func private @id_struct_recursive2(!a) -> () + // ----- //===----------------------------------------------------------------------===// diff --git a/mlir/test/IR/alias-def-groups.mlir b/mlir/test/IR/alias-def-groups.mlir new file mode 100644 index 0000000000000..71d09371fae67 --- /dev/null +++ b/mlir/test/IR/alias-def-groups.mlir @@ -0,0 +1,25 @@ +// RUN: mlir-opt -allow-unregistered-dialect -verify-diagnostics -split-input-file %s | FileCheck %s + +#array = [#integer_attr, !integer_type] +!integer_type = i32 +#integer_attr = 8 : !integer_type + +// CHECK-LABEL: func @foo() +func.func @foo() { + // CHECK-NEXT: value = [8 : i32, i32] + "foo.attr"() { value = #array} : () -> () +} + +// ----- + +// Check that only groups may reference later defined aliases. + +// expected-error@below {{undefined symbol alias id 'integer_attr'}} +#array = [!integer_type, #integer_attr] +!integer_type = i32 + +func.func @foo() { + %0 = "foo.attr"() { value = #array} +} + +#integer_attr = 8 : !integer_type diff --git a/mlir/test/IR/recursive-type.mlir b/mlir/test/IR/recursive-type-and-attr.mlir similarity index 58% rename from mlir/test/IR/recursive-type.mlir rename to mlir/test/IR/recursive-type-and-attr.mlir index 121ba095573ba..29a577bef4d05 100644 --- a/mlir/test/IR/recursive-type.mlir +++ b/mlir/test/IR/recursive-type-and-attr.mlir @@ -1,8 +1,18 @@ // RUN: mlir-opt %s -test-recursive-types | FileCheck %s -// CHECK: !testrec = !test.test_rec> -// CHECK: ![[$NAME:.*]] = !test.test_rec_alias> -// CHECK: ![[$NAME2:.*]] = !test.test_rec_alias, i32>> +// CHECK-DAG: !testrec = !test.test_rec> +// CHECK-DAG: ![[$NAME:.*]] = !test.test_rec_alias +// CHECK-DAG: ![[$NAME2:.*]] = !test.test_rec_alias> +// CHECK-DAG: #[[$ATTR:.*]] = #test.test_rec_alias +// CHECK-DAG: #[[$ATTR2:.*]] = #test.test_rec_alias + + +!name = !test.test_rec_alias +!name2 = !test.test_rec_alias> + +#attr = #test.test_rec_alias +#array = [#attr2, 5] +#attr2 = #test.test_rec_alias // CHECK-LABEL: @roundtrip func.func @roundtrip() { @@ -24,6 +34,14 @@ func.func @roundtrip() { // CHECK: () -> ![[$NAME2]] "test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias, i32>> "test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias, i32>> + + // Check that we can use these aliases, not just print them. + // CHECK: value = #[[$ATTR]] + // CHECK-SAME: () -> ![[$NAME]] + // CHECK-NEXT: value = #[[$ATTR2]] + // CHECK-SAME: () -> ![[$NAME2]] + "test.dummy_op_for_roundtrip"() { value = #attr } : () -> !name + "test.dummy_op_for_roundtrip"() { value = #attr2 } : () -> !name2 return } diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td index ec0a5548a1603..26d99218286b7 100644 --- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td +++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td @@ -323,5 +323,22 @@ def Test_IteratorTypeArrayAttr : TypedArrayAttrBase; +def TestRecursiveAliasAttr + : Test_Attr<"TestRecursiveAlias", [NativeAttrTrait<"IsMutable">]> { + let mnemonic = "test_rec_alias"; + let storageClass = "TestRecursiveAttrStorage"; + let storageNamespace = "test"; + let genStorageClass = 0; + + let parameters = (ins "llvm::StringRef":$name); + + let hasCustomAssemblyFormat = 1; + + let extraClassDeclaration = [{ + Attribute getBody() const; + + void setBody(Attribute attribute); + }]; +} #endif // TEST_ATTRDEFS diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp index 7fc2e6ab3ec0a..e0e11fe7478f9 100644 --- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp +++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp @@ -175,6 +175,62 @@ static void printTrueFalse(AsmPrinter &p, std::optional result) { p << (*result ? "true" : "false"); } +//===----------------------------------------------------------------------===// +// TestRecursiveAttr +//===----------------------------------------------------------------------===// + +Attribute TestRecursiveAliasAttr::getBody() const { return getImpl()->body; } + +void TestRecursiveAliasAttr::setBody(Attribute attribute) { + (void)Base::mutate(attribute); +} + +StringRef TestRecursiveAliasAttr::getName() const { return getImpl()->name; } + +Attribute TestRecursiveAliasAttr::parse(AsmParser &parser, Type type) { + StringRef name; + if (parser.parseLess() || parser.parseKeyword(&name)) + return nullptr; + auto rec = TestRecursiveAliasAttr::get(parser.getContext(), name); + + FailureOr cyclicParse = + parser.tryStartCyclicParse(rec); + + // If this type already has been parsed above in the stack, expect just the + // name. + if (failed(cyclicParse)) { + if (failed(parser.parseGreater())) + return nullptr; + return rec; + } + + // Otherwise, parse the body and update the type. + if (failed(parser.parseComma())) + return nullptr; + Attribute subAttr; + if (parser.parseAttribute(subAttr)) + return nullptr; + if (!subAttr || failed(parser.parseGreater())) + return nullptr; + + rec.setBody(subAttr); + + return rec; +} + +void TestRecursiveAliasAttr::print(AsmPrinter &printer) const { + + FailureOr cyclicPrint = + printer.tryStartCyclicPrint(*this); + + printer << "<" << getName(); + if (succeeded(cyclicPrint)) { + printer << ", "; + printer << getBody(); + } + printer << ">"; +} + //===----------------------------------------------------------------------===// // Tablegen Generated Definitions //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.h b/mlir/test/lib/Dialect/Test/TestAttributes.h index cc73e078bf7e2..d0f24f2738a4c 100644 --- a/mlir/test/lib/Dialect/Test/TestAttributes.h +++ b/mlir/test/lib/Dialect/Test/TestAttributes.h @@ -33,6 +33,37 @@ class TestDialect; /// A handle used to reference external elements instances. using TestDialectResourceBlobHandle = mlir::DialectResourceBlobHandle; + +/// Storage for simple named recursive attribute, where the attribute is +/// identified by its name and can "contain" another attribute, including +/// itself. +struct TestRecursiveAttrStorage : public ::mlir::AttributeStorage { + using KeyTy = ::llvm::StringRef; + + explicit TestRecursiveAttrStorage(::llvm::StringRef key) : name(key) {} + + bool operator==(const KeyTy &other) const { return name == other; } + + static TestRecursiveAttrStorage * + construct(::mlir::AttributeStorageAllocator &allocator, const KeyTy &key) { + return new (allocator.allocate()) + TestRecursiveAttrStorage(allocator.copyInto(key)); + } + + ::mlir::LogicalResult mutate(::mlir::AttributeStorageAllocator &allocator, + ::mlir::Attribute newBody) { + // Cannot set a different body than before. + if (body && body != newBody) + return ::mlir::failure(); + + body = newBody; + return ::mlir::success(); + } + + ::llvm::StringRef name; + ::mlir::Attribute body; +}; + } // namespace test #define GET_ATTRDEF_CLASSES diff --git a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp index 950af85007475..c34135e7f2792 100644 --- a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp @@ -169,6 +169,11 @@ struct TestOpAsmInterface : public OpAsmDialectInterface { //===------------------------------------------------------------------===// AliasResult getAlias(Attribute attr, raw_ostream &os) const final { + if (auto recAliasAttr = dyn_cast(attr)) { + os << recAliasAttr.getName(); + return AliasResult::FinalAlias; + } + StringAttr strAttr = dyn_cast(attr); if (!strAttr) return AliasResult::NoAlias; diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td index 2a8bdad8fb25d..9b685d30302ff 100644 --- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td +++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td @@ -373,7 +373,7 @@ def TestI32 : Test_Type<"TestI32"> { let mnemonic = "i32"; } -def TestRecursiveAlias +def TestRecursiveAliasType : Test_Type<"TestRecursiveAlias", [NativeTypeTrait<"IsMutable">]> { let mnemonic = "test_rec_alias"; let storageClass = "TestRecursiveTypeStorage";