diff --git a/mlir/docs/LangRef.md b/mlir/docs/LangRef.md index 0cfe845638c3c..a00f9f76c4253 100644 --- a/mlir/docs/LangRef.md +++ b/mlir/docs/LangRef.md @@ -183,12 +183,14 @@ starting with a `//` and going until the end of the line. ``` // Top level production -toplevel := (operation | attribute-alias-def | type-alias-def)* +toplevel := (operation | alias-block-def)* +alias-block-def := (attribute-alias-def | type-alias-def)* ``` The production `toplevel` is the top level production that is parsed by any parsing -consuming the MLIR syntax. [Operations](#operations), -[Attribute aliases](#attribute-value-aliases), and [Type aliases](#type-aliases) +consuming the MLIR syntax. [Operations](#operations) and +[Alias Blocks](#alias-block-definitions) consisting of +[Attribute aliases](#attribute-value-aliases) and [Type aliases](#type-aliases) can be declared on the toplevel. ### Identifiers and keywords @@ -880,3 +882,26 @@ version using readAttribute and readType methods. There is no restriction on what kind of information a dialect is allowed to encode to model its versioning. Currently, versioning is supported only for bytecode formats. + +## Alias Block Definitions + +An alias block is a list of subsequent attribute or type alias definitions that +are conceptually parsed as one unit. +This allows any alias definition within the block to reference any other alias +definition within the block, regardless if defined lexically later or earlier in +the block. + +```mlir +// Alias block consisting of #array, !integer_type and #integer_attr. +#array = [#integer_attr, !integer_type] +!integer_type = i32 +#integer_attr = 8 : !integer_type + +// Illegal. !other_type is not part of this alias block and defined later +// in the file. +!tuple = tuple + +func.func @foo() { ... } + +!other_type = f32 +``` diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp index 3437ac9addc5f..3453049a09bee 100644 --- a/mlir/lib/AsmParser/AttributeParser.cpp +++ b/mlir/lib/AsmParser/AttributeParser.cpp @@ -145,6 +145,10 @@ Attribute Parser::parseAttribute(Type type) { parseLocationInstance(locAttr) || parseToken(Token::r_paren, "expected ')' in inline location")) return Attribute(); + + if (syntaxOnly()) + return state.syntaxOnlyAttr; + return locAttr; } @@ -430,6 +434,9 @@ Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) { return FloatAttr::get(floatType, *result); } + if (syntaxOnly()) + return state.syntaxOnlyAttr; + if (!isa(type)) return emitError(loc, "integer literal not valid for specified type"), nullptr; @@ -1003,7 +1010,9 @@ Attribute Parser::parseDenseElementsAttr(Type attrType) { auto type = parseElementsLiteralType(attrType); if (!type) return nullptr; - return literalParser.getAttr(loc, type); + if (syntaxOnly()) + return state.syntaxOnlyAttr; + return literalParser.getAttr(loc, cast(type)); } Attribute Parser::parseDenseResourceElementsAttr(Type attrType) { @@ -1030,6 +1039,9 @@ Attribute Parser::parseDenseResourceElementsAttr(Type attrType) { return nullptr; } + if (syntaxOnly()) + return state.syntaxOnlyAttr; + ShapedType shapedType = dyn_cast(attrType); if (!shapedType) { emitError(typeLoc, "`dense_resource` expected a shaped type"); @@ -1044,7 +1056,7 @@ Attribute Parser::parseDenseResourceElementsAttr(Type attrType) { /// elements-literal-type ::= vector-type | ranked-tensor-type /// /// This method also checks the type has static shape. -ShapedType Parser::parseElementsLiteralType(Type type) { +Type Parser::parseElementsLiteralType(Type type) { // If the user didn't provide a type, parse the colon type for the literal. if (!type) { if (parseToken(Token::colon, "expected ':'")) @@ -1053,6 +1065,9 @@ ShapedType Parser::parseElementsLiteralType(Type type) { return nullptr; } + if (syntaxOnly()) + return state.syntaxOnlyType; + auto sType = dyn_cast(type); if (!sType) { emitError("elements literal must be a shaped type"); @@ -1077,17 +1092,23 @@ Attribute Parser::parseSparseElementsAttr(Type attrType) { // of the type. Type indiceEltType = builder.getIntegerType(64); if (consumeIf(Token::greater)) { - ShapedType type = parseElementsLiteralType(attrType); + Type type = parseElementsLiteralType(attrType); if (!type) return nullptr; + if (syntaxOnly()) + return state.syntaxOnlyAttr; + // Construct the sparse elements attr using zero element indice/value // attributes. + ShapedType shapedType = cast(type); ShapedType indicesType = - RankedTensorType::get({0, type.getRank()}, indiceEltType); - ShapedType valuesType = RankedTensorType::get({0}, type.getElementType()); + RankedTensorType::get({0, shapedType.getRank()}, indiceEltType); + ShapedType valuesType = + RankedTensorType::get({0}, shapedType.getElementType()); return getChecked( - loc, type, DenseElementsAttr::get(indicesType, ArrayRef()), + loc, shapedType, + DenseElementsAttr::get(indicesType, ArrayRef()), DenseElementsAttr::get(valuesType, ArrayRef())); } @@ -1114,6 +1135,11 @@ Attribute Parser::parseSparseElementsAttr(Type attrType) { if (!type) return nullptr; + if (syntaxOnly()) + return state.syntaxOnlyAttr; + + ShapedType shapedType = cast(type); + // If the indices are a splat, i.e. the literal parser parsed an element and // not a list, we set the shape explicitly. The indices are represented by a // 2-dimensional shape where the second dimension is the rank of the type. @@ -1121,7 +1147,8 @@ Attribute Parser::parseSparseElementsAttr(Type attrType) { // indice and thus one for the first dimension. ShapedType indicesType; if (indiceParser.getShape().empty()) { - indicesType = RankedTensorType::get({1, type.getRank()}, indiceEltType); + indicesType = + RankedTensorType::get({1, shapedType.getRank()}, indiceEltType); } else { // Otherwise, set the shape to the one parsed by the literal parser. indicesType = RankedTensorType::get(indiceParser.getShape(), indiceEltType); @@ -1131,7 +1158,7 @@ Attribute Parser::parseSparseElementsAttr(Type attrType) { // If the values are a splat, set the shape explicitly based on the number of // indices. The number of indices is encoded in the first dimension of the // indice shape type. - auto valuesEltType = type.getElementType(); + auto valuesEltType = shapedType.getElementType(); ShapedType valuesType = valuesParser.getShape().empty() ? RankedTensorType::get({indicesType.getDimSize(0)}, valuesEltType) @@ -1139,7 +1166,7 @@ Attribute Parser::parseSparseElementsAttr(Type attrType) { auto values = valuesParser.getAttr(valuesLoc, valuesType); // Build the sparse elements attribute by the indices and values. - return getChecked(loc, type, indices, values); + return getChecked(loc, shapedType, indices, values); } Attribute Parser::parseStridedLayoutAttr() { @@ -1260,6 +1287,9 @@ Attribute Parser::parseDistinctAttr(Type type) { return {}; } + if (syntaxOnly()) + return state.syntaxOnlyAttr; + // Add the distinct attribute to the parser state, if it has not been parsed // before. Otherwise, check if the parsed reference attribute matches the one // found in the parser state. diff --git a/mlir/lib/AsmParser/DialectSymbolParser.cpp b/mlir/lib/AsmParser/DialectSymbolParser.cpp index 2b1b114b90e86..5330fbc9996ff 100644 --- a/mlir/lib/AsmParser/DialectSymbolParser.cpp +++ b/mlir/lib/AsmParser/DialectSymbolParser.cpp @@ -15,6 +15,7 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/DialectImplementation.h" +#include "llvm/ADT/ScopeExit.h" #include "llvm/Support/SourceMgr.h" using namespace mlir; @@ -156,9 +157,11 @@ ParseResult Parser::parseDialectSymbolBody(StringRef &body, } /// Parse an extended dialect symbol. -template +template static Symbol parseExtendedSymbol(Parser &p, AsmParserState *asmState, SymbolAliasMap &aliases, + ParseAliasFn &parseAliasFn, CreateFn &&createSymbol) { Token tok = p.getToken(); @@ -185,12 +188,32 @@ static Symbol parseExtendedSymbol(Parser &p, AsmParserState *asmState, // If there is no '<' token following this, and if the typename contains no // dot, then we are parsing a symbol alias. if (!hasTrailingData && !isPrettyName) { + + // Don't check the validity of alias reference in syntax-only mode. + if (p.syntaxOnly()) { + if constexpr (std::is_same_v) + return p.getState().syntaxOnlyType; + else + return p.getState().syntaxOnlyAttr; + } + // Check for an alias for this type. auto aliasIt = aliases.find(identifier); - if (aliasIt == aliases.end()) - return (p.emitWrongTokenError("undefined symbol alias id '" + identifier + - "'"), - nullptr); + if (aliasIt == aliases.end()) { + FailureOr symbol = failure(); + // Try the parse alias function if set. + if (parseAliasFn) + symbol = parseAliasFn(identifier); + + if (failed(symbol)) { + p.emitWrongTokenError("undefined symbol alias id '" + identifier + "'"); + return nullptr; + } + if (!*symbol) + return nullptr; + + aliasIt = aliases.insert({identifier, *symbol}).first; + } if (asmState) { if constexpr (std::is_same_v) asmState->addTypeAliasUses(identifier, range); @@ -241,12 +264,16 @@ Attribute Parser::parseExtendedAttr(Type type) { MLIRContext *ctx = getContext(); Attribute attr = parseExtendedSymbol( *this, state.asmState, state.symbols.attributeAliasDefinitions, + state.symbols.parseUnknownAttributeAlias, [&](StringRef dialectName, StringRef symbolData, SMLoc loc) -> Attribute { // Parse an optional trailing colon type. Type attrType = type; if (consumeIf(Token::colon) && !(attrType = parseType())) return Attribute(); + if (syntaxOnly()) + return state.syntaxOnlyAttr; + // If we found a registered dialect, then ask it to parse the attribute. if (Dialect *dialect = builder.getContext()->getOrLoadDialect(dialectName)) { @@ -288,7 +315,11 @@ Type Parser::parseExtendedType() { MLIRContext *ctx = getContext(); return parseExtendedSymbol( *this, state.asmState, state.symbols.typeAliasDefinitions, + state.symbols.parseUnknownTypeAlias, [&](StringRef dialectName, StringRef symbolData, SMLoc loc) -> Type { + if (syntaxOnly()) + return state.syntaxOnlyType; + // If we found a registered dialect, then ask it to parse the type. if (auto *dialect = ctx->getOrLoadDialect(dialectName)) { // Temporarily reset the lexer to let the dialect parse the type. diff --git a/mlir/lib/AsmParser/LocationParser.cpp b/mlir/lib/AsmParser/LocationParser.cpp index 61b20179800c6..8139f188c32a7 100644 --- a/mlir/lib/AsmParser/LocationParser.cpp +++ b/mlir/lib/AsmParser/LocationParser.cpp @@ -53,6 +53,9 @@ ParseResult Parser::parseCallSiteLocation(LocationAttr &loc) { if (parseToken(Token::r_paren, "expected ')' in callsite location")) return failure(); + if (syntaxOnly()) + return success(); + // Return the callsite location. loc = CallSiteLoc::get(calleeLoc, callerLoc); return success(); @@ -79,6 +82,9 @@ ParseResult Parser::parseFusedLocation(LocationAttr &loc) { LocationAttr newLoc; if (parseLocationInstance(newLoc)) return failure(); + if (syntaxOnly()) + return success(); + locations.push_back(newLoc); return success(); }; @@ -135,12 +141,15 @@ ParseResult Parser::parseNameOrFileLineColLocation(LocationAttr &loc) { if (parseLocationInstance(childLoc)) return failure(); - loc = NameLoc::get(StringAttr::get(ctx, str), childLoc); - // Parse the closing ')'. if (parseToken(Token::r_paren, "expected ')' after child location of NameLoc")) return failure(); + + if (syntaxOnly()) + return success(); + + loc = NameLoc::get(StringAttr::get(ctx, str), childLoc); } else { loc = NameLoc::get(StringAttr::get(ctx, str)); } @@ -154,6 +163,10 @@ ParseResult Parser::parseLocationInstance(LocationAttr &loc) { Attribute locAttr = parseExtendedAttr(Type()); if (!locAttr) return failure(); + + if (syntaxOnly()) + return success(); + if (!(loc = dyn_cast(locAttr))) return emitError("expected location attribute, but got") << locAttr; return success(); diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp index 02bf9a4180639..738485fb45372 100644 --- a/mlir/lib/AsmParser/Parser.cpp +++ b/mlir/lib/AsmParser/Parser.cpp @@ -27,6 +27,7 @@ #include "llvm/Support/Endian.h" #include "llvm/Support/MathExtras.h" #include "llvm/Support/PrettyStackTrace.h" +#include "llvm/Support/SaveAndRestore.h" #include "llvm/Support/SourceMgr.h" #include #include @@ -2403,17 +2404,12 @@ class TopLevelOperationParser : public Parser { ParseResult parse(Block *topLevelBlock, Location parserLoc); private: - /// Parse an attribute alias declaration. + /// Parse an alias block definition. /// /// attribute-alias-def ::= '#' alias-name `=` attribute-value - /// - ParseResult parseAttributeAliasDef(); - - /// Parse a type alias declaration. - /// /// type-alias-def ::= '!' alias-name `=` type - /// - ParseResult parseTypeAliasDef(); + /// alias-block-def ::= (type-alias-def | attribute-alias-def)+ + ParseResult parseAliasBlockDef(); /// Parse a top-level file metadata dictionary. /// @@ -2514,69 +2510,184 @@ class ParsedResourceEntry : public AsmParsedResourceEntry { Token value; Parser &p; }; + +/// Convenient subclass of `ParserState` which configures the parser for +/// syntax-only parsing. This only copies config and other required state for +/// parsing but does not copy side-effecting state such as the code completion +/// context. +struct SyntaxParserState : ParserState { + explicit SyntaxParserState(ParserState &state) + : ParserState(state.lex.getSourceMgr(), state.config, state.symbols, + /*asmState=*/nullptr, + /*codeCompleteContext=*/nullptr) { + syntaxOnly = true; + } +}; + } // namespace -ParseResult TopLevelOperationParser::parseAttributeAliasDef() { - assert(getToken().is(Token::hash_identifier)); - StringRef aliasName = getTokenSpelling().drop_front(); +ParseResult TopLevelOperationParser::parseAliasBlockDef() { - // Check for redefinitions. - if (state.symbols.attributeAliasDefinitions.count(aliasName) > 0) - return emitError("redefinition of attribute alias id '" + aliasName + "'"); + struct UnparsedData { + SMRange location; + StringRef text; + }; - // Make sure this isn't invading the dialect attribute namespace. - if (aliasName.contains('.')) - return emitError("attribute names with a '.' are reserved for " - "dialect-defined names"); + // Use a map vector as StringMap has non-deterministic iteration order. + using StringMapVector = + llvm::MapVector>; - SMRange location = getToken().getLocRange(); - consumeToken(Token::hash_identifier); + StringMapVector unparsedAttributeAliases; + StringMapVector unparsedTypeAliases; - // Parse the '='. - if (parseToken(Token::equal, "expected '=' in attribute alias definition")) - return failure(); + // Returns true if this alias has already been defined, either in this block + // or a previous one. + auto isRedefinition = [&](bool isType, StringRef aliasName) { + if (isType) + return state.symbols.typeAliasDefinitions.contains(aliasName) || + unparsedTypeAliases.contains(aliasName); - // Parse the attribute value. - Attribute attr = parseAttribute(); - if (!attr) - return failure(); + return state.symbols.attributeAliasDefinitions.contains(aliasName) || + unparsedAttributeAliases.contains(aliasName); + }; - // Register this alias with the parser state. - if (state.asmState) - state.asmState->addAttrAliasDefinition(aliasName, location, attr); - state.symbols.attributeAliasDefinitions[aliasName] = attr; - return success(); -} + // Collect all attribute or type alias definitions in unparsed form first. + while ( + getToken().isAny(Token::exclamation_identifier, Token::hash_identifier)) { + StringRef aliasName = getTokenSpelling().drop_front(); + + bool isType = getToken().is(mlir::Token::exclamation_identifier); + StringRef kind = isType ? "type" : "attribute"; + + // Check for redefinitions. + if (isRedefinition(isType, aliasName)) + return emitError("redefinition of ") + << kind << " alias id '" << aliasName << "'"; + + // Make sure this isn't invading the dialect namespace. + if (aliasName.contains('.')) + return emitError(kind) << " names with a '.' are reserved for " + "dialect-defined names"; + + SMRange location = getToken().getLocRange(); + consumeToken(); + + // Parse the '='. + if (parseToken(Token::equal, + "expected '=' in " + kind + " alias definition")) + return failure(); -ParseResult TopLevelOperationParser::parseTypeAliasDef() { - assert(getToken().is(Token::exclamation_identifier)); - StringRef aliasName = getTokenSpelling().drop_front(); + SyntaxParserState skippingParserState(state); + Parser syntaxOnlyParser(skippingParserState); + const char *start = getToken().getLoc().getPointer(); + syntaxOnlyParser.resetToken(start); - // Check for redefinitions. - if (state.symbols.typeAliasDefinitions.count(aliasName) > 0) - return emitError("redefinition of type alias id '" + aliasName + "'"); + // Parse just the syntax of the value, moving the lexer past the definition. + if (isType ? !syntaxOnlyParser.parseType() + : !syntaxOnlyParser.parseAttribute()) + return failure(); + + // Get the location from the lexers new position. + const char *end = syntaxOnlyParser.getToken().getLoc().getPointer(); + size_t length = end - start; - // Make sure this isn't invading the dialect type namespace. - if (aliasName.contains('.')) - return emitError("type names with a '.' are reserved for " - "dialect-defined names"); + StringMapVector &unparsedMap = + isType ? unparsedTypeAliases : unparsedAttributeAliases; - SMRange location = getToken().getLocRange(); - consumeToken(Token::exclamation_identifier); + unparsedMap[aliasName] = + UnparsedData{location, StringRef(start, length).rtrim()}; + + // Move the top-level parser past the alias definition. + resetToken(end); + } + + auto parseAttributeAlias = [&](StringRef aliasName, + const UnparsedData &unparsedData) { + llvm::SaveAndRestore> cyclicStack( + getState().cyclicParsingStack, {}); + auto exit = saveAndResetToken(unparsedData.text.data()); + Attribute attribute = parseAttribute(); + if (!attribute) + return attribute; + + // Register this alias with the parser state. + if (state.asmState) + state.asmState->addAttrAliasDefinition(aliasName, unparsedData.location, + attribute); - // Parse the '='. - if (parseToken(Token::equal, "expected '=' in type alias definition")) + return attribute; + }; + + auto parseTypeAlias = [&](StringRef aliasName, + const UnparsedData &unparsedData) { + llvm::SaveAndRestore> cyclicStack( + getState().cyclicParsingStack, {}); + auto exit = saveAndResetToken(unparsedData.text.data()); + Type type = parseType(); + if (!type) + return type; + + // Register this alias with the parser state. + if (state.asmState) + state.asmState->addTypeAliasDefinition(aliasName, unparsedData.location, + type); + + return type; + }; + + // Set the callbacks for the lazy parsing of alias definitions in the parser. + state.symbols.parseUnknownAttributeAlias = + [&](StringRef aliasName) -> FailureOr { + auto *iter = unparsedAttributeAliases.find(aliasName); + if (iter == unparsedAttributeAliases.end()) + return failure(); + + return parseAttributeAlias(aliasName, iter->second); + }; + state.symbols.parseUnknownTypeAlias = + [&](StringRef aliasName) -> FailureOr { + auto *iter = unparsedTypeAliases.find(aliasName); + if (iter == unparsedTypeAliases.end()) + return failure(); + + return parseTypeAlias(aliasName, iter->second); + }; + + // Reset them to nullptr at the end. Keeping them around would lead to the + // access of local variables captured in this scope after we've returned from + // this function. + auto exit = llvm::make_scope_exit([&] { + state.symbols.parseUnknownTypeAlias = nullptr; + state.symbols.parseUnknownAttributeAlias = nullptr; + }); + + // Now go through all the unparsed definitions in the block and parse them. + // The order here is not significant for correctness, but should be + // deterministic. The order can also have an impact on the maximum stack usage + // during parsing. This can be improved in the future. + auto parse = [](auto &unparsed, auto &definitions, auto &parseFn) { + for (auto &&[aliasName, unparsedData] : unparsed) { + // Avoid parsing twice. + if (definitions.contains(aliasName)) + continue; + + auto symbol = parseFn(aliasName, unparsedData); + if (!symbol) + return failure(); + definitions[aliasName] = symbol; + } + return success(); + }; + + if (failed(parse(unparsedAttributeAliases, + state.symbols.attributeAliasDefinitions, + parseAttributeAlias))) return failure(); - // Parse the type. - Type aliasedType = parseType(); - if (!aliasedType) + if (failed(parse(unparsedTypeAliases, state.symbols.typeAliasDefinitions, + parseTypeAlias))) return failure(); - // Register this alias with the parser state. - if (state.asmState) - state.asmState->addTypeAliasDefinition(aliasName, location, aliasedType); - state.symbols.typeAliasDefinitions.try_emplace(aliasName, aliasedType); return success(); } @@ -2715,15 +2826,10 @@ ParseResult TopLevelOperationParser::parse(Block *topLevelBlock, case Token::error: return failure(); - // Parse an attribute alias. + // Parse an alias def block. case Token::hash_identifier: - if (parseAttributeAliasDef()) - return failure(); - break; - - // Parse a type alias. case Token::exclamation_identifier: - if (parseTypeAliasDef()) + if (parseAliasBlockDef()) return failure(); break; diff --git a/mlir/lib/AsmParser/Parser.h b/mlir/lib/AsmParser/Parser.h index 01c55f97a08c2..282c0b13aebd9 100644 --- a/mlir/lib/AsmParser/Parser.h +++ b/mlir/lib/AsmParser/Parser.h @@ -12,6 +12,7 @@ #include "ParserState.h" #include "mlir/IR/Builders.h" #include "mlir/IR/OpImplementation.h" +#include "llvm/ADT/ScopeExit.h" #include namespace mlir { @@ -132,6 +133,22 @@ class Parser { state.curToken = state.lex.lexToken(); } + /// Temporarily resets the parser to the given lexer position. The previous + /// lexer position is saved and restored on destruction of the returned + /// object. + [[nodiscard]] auto saveAndResetToken(const char *tokPos) { + const char *previous = getToken().getLoc().getPointer(); + resetToken(tokPos); + return llvm::make_scope_exit([this, previous] { resetToken(previous); }); + } + + /// Returns true if the parser is in syntax-only mode. In this mode, the + /// parser only checks the syntactic validity of the parsed elements but does + /// not verify the correctness of the parsed data. Syntax-only mode is + /// currently only supported for attribute and type parsing and skips parsing + /// dialect attributes and types entirely. + bool syntaxOnly() { return state.syntaxOnly; } + /// Consume the specified token if present and return success. On failure, /// output a diagnostic and return failure. ParseResult parseToken(Token::Kind expectedToken, const Twine &message); @@ -209,7 +226,7 @@ class Parser { Type parseTupleType(); /// Parse a vector type. - VectorType parseVectorType(); + Type parseVectorType(); ParseResult parseVectorDimensionList(SmallVectorImpl &dimensions, SmallVectorImpl &scalableDims); ParseResult parseDimensionListRanked(SmallVectorImpl &dimensions, @@ -265,7 +282,7 @@ class Parser { /// Parse a dense elements attribute. Attribute parseDenseElementsAttr(Type attrType); - ShapedType parseElementsLiteralType(Type type); + Type parseElementsLiteralType(Type type); /// Parse a dense resource elements attribute. Attribute parseDenseResourceElementsAttr(Type attrType); diff --git a/mlir/lib/AsmParser/ParserState.h b/mlir/lib/AsmParser/ParserState.h index 1428ea3a82cee..17602f84a9083 100644 --- a/mlir/lib/AsmParser/ParserState.h +++ b/mlir/lib/AsmParser/ParserState.h @@ -32,6 +32,15 @@ struct SymbolState { /// A map from type alias identifier to Type. llvm::StringMap typeAliasDefinitions; + /// Parser functions set during the parsing of alias-block-defs to parse an + /// unknown attribute or type alias. The parameter is the name of the alias. + /// The function should return failure if no such alias could be found. + /// If any errors occurred during parsing, a null attribute or type should + /// be returned. + llvm::unique_function(StringRef)> + parseUnknownAttributeAlias; + llvm::unique_function(StringRef)> parseUnknownTypeAlias; + /// A map of dialect resource keys to the resolved resource name and handle /// to use during parsing. DenseMap defaultDialectStack{"builtin"}; + + /// Controls whether the parser is in syntax-only mode. + bool syntaxOnly = false; + + /// Attribute and type returned by `parseType`, `parseAttribute` and the more + /// specific parsing function to signal syntactic correctness if an attribute + /// or type cannot be created without verifying the parsed data as well. + /// Callers of such function should only check for null or not null return + /// values for error signaling. + Type syntaxOnlyType = NoneType::get(config.getContext()); + Attribute syntaxOnlyAttr = UnitAttr::get(config.getContext()); }; } // namespace detail diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp index 306e850af27bc..77e87cf9b0bef 100644 --- a/mlir/lib/AsmParser/TypeParser.cpp +++ b/mlir/lib/AsmParser/TypeParser.cpp @@ -130,6 +130,10 @@ Type Parser::parseComplexType() { if (!elementType || parseToken(Token::greater, "expected '>' in complex type")) return nullptr; + + if (syntaxOnly()) + return state.syntaxOnlyType; + if (!isa(elementType) && !isa(elementType)) return emitError(elementTypeLoc, "invalid element type for complex"), nullptr; @@ -150,6 +154,9 @@ Type Parser::parseFunctionType() { parseFunctionResultTypes(results)) return nullptr; + if (syntaxOnly()) + return state.syntaxOnlyType; + return builder.getFunctionType(arguments, results); } @@ -195,9 +202,10 @@ Type Parser::parseMemRefType() { if (!elementType) return nullptr; - // Check that memref is formed from allowed types. - if (!BaseMemRefType::isValidElementType(elementType)) - return emitError(typeLoc, "invalid memref element type"), nullptr; + if (!syntaxOnly()) { // Check that memref is formed from allowed types. + if (!BaseMemRefType::isValidElementType(elementType)) + return emitError(typeLoc, "invalid memref element type"), nullptr; + } MemRefLayoutAttrInterface layout; Attribute memorySpace; @@ -208,6 +216,9 @@ Type Parser::parseMemRefType() { if (!attr) return failure(); + if (syntaxOnly()) + return success(); + if (isa(attr)) { layout = cast(attr); } else if (memorySpace) { @@ -235,6 +246,9 @@ Type Parser::parseMemRefType() { } } + if (syntaxOnly()) + return state.syntaxOnlyType; + if (isUnranked) return getChecked(loc, elementType, memorySpace); @@ -437,7 +451,7 @@ Type Parser::parseTupleType() { /// vector-dim-list := (static-dim-list `x`)? (`[` static-dim-list `]` `x`)? /// static-dim-list ::= decimal-literal (`x` decimal-literal)* /// -VectorType Parser::parseVectorType() { +Type Parser::parseVectorType() { consumeToken(Token::kw_vector); if (parseToken(Token::less, "expected '<' in vector type")) @@ -458,6 +472,9 @@ VectorType Parser::parseVectorType() { if (!elementType || parseToken(Token::greater, "expected '>' in vector type")) return nullptr; + if (syntaxOnly()) + return state.syntaxOnlyType; + if (!VectorType::isValidElementType(elementType)) return emitError(typeLoc, "vector elements must be int/index/float type"), nullptr; diff --git a/mlir/test/IR/alias-def-groups.mlir b/mlir/test/IR/alias-def-groups.mlir new file mode 100644 index 0000000000000..71d09371fae67 --- /dev/null +++ b/mlir/test/IR/alias-def-groups.mlir @@ -0,0 +1,25 @@ +// RUN: mlir-opt -allow-unregistered-dialect -verify-diagnostics -split-input-file %s | FileCheck %s + +#array = [#integer_attr, !integer_type] +!integer_type = i32 +#integer_attr = 8 : !integer_type + +// CHECK-LABEL: func @foo() +func.func @foo() { + // CHECK-NEXT: value = [8 : i32, i32] + "foo.attr"() { value = #array} : () -> () +} + +// ----- + +// Check that only groups may reference later defined aliases. + +// expected-error@below {{undefined symbol alias id 'integer_attr'}} +#array = [!integer_type, #integer_attr] +!integer_type = i32 + +func.func @foo() { + %0 = "foo.attr"() { value = #array} +} + +#integer_attr = 8 : !integer_type