diff --git a/mlir/docs/DefiningDialects/Operations.md b/mlir/docs/DefiningDialects/Operations.md index 528070cd3ebff..fafda816a3881 100644 --- a/mlir/docs/DefiningDialects/Operations.md +++ b/mlir/docs/DefiningDialects/Operations.md @@ -1498,22 +1498,17 @@ optionality, default values, etc.: * `AllAttrOf`: adapts an attribute with [multiple constraints](#combining-constraints). -### Enum attributes +## Enum definition -Some attributes can only take values from a predefined enum, e.g., the -comparison kind of a comparison op. To define such attributes, ODS provides -several mechanisms: `IntEnumAttr`, and `BitEnumAttr`. +MLIR is capabable of generating C++ enums, both those that represent a set +of values drawn from a list or that can hold a combination of flags +using the `IntEnum` and `BitEnum` classes, respectively. -* `IntEnumAttr`: each enum case is an integer, the attribute is stored as a - [`IntegerAttr`][IntegerAttr] in the op. -* `BitEnumAttr`: each enum case is a either the empty case, a single bit, - or a group of single bits, and the attribute is stored as a - [`IntegerAttr`][IntegerAttr] in the op. - -All these `*EnumAttr` attributes require fully specifying all of the allowed -cases via their corresponding `*EnumAttrCase`. With this, ODS is able to +All these `IntEnum` and `BitEnum` classes require fully specifying all of the allowed +cases via a `EnumCase` or `BitEnumCase` subclass, respectively. With this, ODS is able to generate additional verification to only accept allowed cases. To facilitate the -interaction between `*EnumAttr`s and their C++ consumers, the +interaction between tablegen enums and the attributes or properties that wrap them and +to make them easier to use in C++, the [`EnumsGen`][EnumsGen] TableGen backend can generate a few common utilities: a C++ enum class, `llvm::DenseMapInfo` for the enum class, conversion functions from/to strings. This is controlled via the `-gen-enum-decls` and @@ -1522,10 +1517,10 @@ from/to strings. This is controlled via the `-gen-enum-decls` and For example, given the following `EnumAttr`: ```tablegen -def Case15: I32EnumAttrCase<"Case15", 15>; -def Case20: I32EnumAttrCase<"Case20", 20>; +def Case15: I32EnumCase<"Case15", 15>; +def Case20: I32EnumCase<"Case20", 20>; -def MyIntEnum: I32EnumAttr<"MyIntEnum", "An example int enum", +def MyIntEnum: I32Enum<"MyIntEnum", "An example int enum", [Case15, Case20]> { let cppNamespace = "Outer::Inner"; let stringToSymbolFnName = "ConvertToEnum"; @@ -1611,14 +1606,17 @@ std::optional symbolizeMyIntEnum(uint32_t value) { Similarly for the following `BitEnumAttr` definition: ```tablegen -def None: I32BitEnumAttrCaseNone<"None">; -def Bit0: I32BitEnumAttrCaseBit<"Bit0", 0, "tagged">; -def Bit1: I32BitEnumAttrCaseBit<"Bit1", 1>; -def Bit2: I32BitEnumAttrCaseBit<"Bit2", 2>; -def Bit3: I32BitEnumAttrCaseBit<"Bit3", 3>; - -def MyBitEnum: BitEnumAttr<"MyBitEnum", "An example bit enum", - [None, Bit0, Bit1, Bit2, Bit3]>; +def None: I32BitEnumCaseNone<"None">; +def Bit0: I32BitEnumCaseBit<"Bit0", 0, "tagged">; +def Bit1: I32BitEnumCaseBit<"Bit1", 1>; +def Bit2: I32BitEnumCaseBit<"Bit2", 2>; +def Bit3: I32BitEnumCaseBit<"Bit3", 3>; + +def MyBitEnum: I32BitEnum<"MyBitEnum", "An example bit enum", + [None, Bit0, Bit1, Bit2, Bit3]> { + // Note: this is the default value, and is listed for illustrative purposes. + let separator = "|"; +} ``` We can have: @@ -1738,6 +1736,26 @@ std::optional symbolizeMyBitEnum(uint32_t value) { } ``` +### Wrapping enums in attributes + +There are several mechanisms for creating an `Attribute` whose values are +taken from a `*Enum`. + +The most common of these is to use the `EnumAttr` class, which takes +an `EnumInfo` (either a `IntEnum` or `BitEnum`) as a parameter and constructs +an attribute that holds one argument - value of the enum. This attribute +is defined within a dialect and can have its assembly format customized to, +for example, print angle brackets around the enum value or assign a mnemonic. + +An older form involves using the `*IntEnumAttr` and `*BitEnumATtr` classes +and their corresponding `*EnumAttrCase` classes (which can be used +anywhere a `*EnumCase` is needed). These classes store their values +as a `SignlessIntegerAttr` of their bitwidth, imposing the constraint on it +that it has a value within the valid range of the enum. If their +`genSpecializedAttr` parameter is set, they will also generate a +wrapper attribute instead of using a bare signless integer attribute +for storage. + ## Debugging Tips ### Run `mlir-tblgen` to see the generated content diff --git a/mlir/include/mlir/IR/EnumAttr.td b/mlir/include/mlir/IR/EnumAttr.td index 9fec28f03ec28..e5406546b1950 100644 --- a/mlir/include/mlir/IR/EnumAttr.td +++ b/mlir/include/mlir/IR/EnumAttr.td @@ -14,8 +14,8 @@ include "mlir/IR/AttrTypeBase.td" //===----------------------------------------------------------------------===// // Enum attribute kinds -// Additional information for an enum attribute case. -class EnumAttrCaseInfo { +// Additional information for an enum case. +class EnumCase { // The C++ enumerant symbol. string symbol = sym; @@ -26,29 +26,56 @@ class EnumAttrCaseInfo { // The string representation of the enumerant. May be the same as symbol. string str = strVal; + + // The bitwidth of the enum. + int width = widthVal; } // An enum attribute case stored with IntegerAttr, which has an integer value, // its representation as a string and a C++ symbol name which may be different. +// Not needed when using the newer `EnumCase` form for defining enum cases. class IntEnumAttrCaseBase : - EnumAttrCaseInfo, + EnumCase, SignlessIntegerAttrBase { let predicate = CPred<"::llvm::cast<::mlir::IntegerAttr>($_self).getInt() == " # intVal>; } -// Cases of integer enum attributes with a specific type. By default, the string +// Cases of integer enums with a specific type. By default, the string // representation is the same as the C++ symbol name. +class I32EnumCase + : EnumCase; +class I64EnumCase + : EnumCase; + +// Cases of integer enum attributes with a specific type. By default, the string +// representation is the same as the C++ symbol name. These forms +// are not needed when using the newer `EnumCase` form. class I32EnumAttrCase : IntEnumAttrCaseBase; class I64EnumAttrCase : IntEnumAttrCaseBase; -// A bit enum case stored with an IntegerAttr. `val` here is *not* the ordinal -// number of a bit that is set. It is an integer value with bits set to match -// the case. +// A bit enum case. `val` here is *not* the ordinal number of a bit +// that is set. It is an integer value with bits set to match the case. +class BitEnumCaseBase : + EnumCase; +// Bit enum attr cases. The string representation is the same as the C++ symbol +// name unless otherwise specified. +class I8BitEnumCase + : BitEnumCaseBase; +class I16BitEnumCase + : BitEnumCaseBase; +class I32BitEnumCase + : BitEnumCaseBase; +class I64BitEnumCase + : BitEnumCaseBase; + +// A form of `BitEnumCaseBase` that also inherits from `Attr` and encodes +// the width of the enum, which was defined when enums were always +// stored in attributes. class BitEnumAttrCaseBase : - EnumAttrCaseInfo, + BitEnumCaseBase, SignlessIntegerAttrBase; class I8BitEnumAttrCase @@ -61,6 +88,19 @@ class I64BitEnumAttrCase : BitEnumAttrCaseBase; // The special bit enum case with no bits set (i.e. value = 0). +class BitEnumCaseNone + : BitEnumCaseBase; + +class I8BitEnumCaseNone + : BitEnumCaseNone; +class I16BitEnumCaseNone + : BitEnumCaseNone; +class I32BitEnumCaseNone + : BitEnumCaseNone; +class I64BitEnumCaseNone + : BitEnumCaseNone; + +// Older forms, used when enums were necessarily attributes. class I8BitEnumAttrCaseNone : I8BitEnumAttrCase; class I16BitEnumAttrCaseNone @@ -70,6 +110,24 @@ class I32BitEnumAttrCaseNone class I64BitEnumAttrCaseNone : I64BitEnumAttrCase; +// A bit enum case for a single bit, specified by a bit position `pos`. +// The `pos` argument refers to the index of the bit, and is limited +// to be in the range [0, width). +class BitEnumCaseBit + : BitEnumCaseBase { + assert !and(!ge(pos, 0), !lt(pos, width)), + "bit position larger than underlying storage"; +} + +class I8BitEnumCaseBit + : BitEnumCaseBit; +class I16BitEnumCaseBit + : BitEnumCaseBit; +class I32BitEnumCaseBit + : BitEnumCaseBit; +class I64BitEnumCaseBit + : BitEnumCaseBit; + // A bit enum case for a single bit, specified by a bit position. // The pos argument refers to the index of the bit, and is limited // to be in the range [0, bitwidth). @@ -90,12 +148,17 @@ class I64BitEnumAttrCaseBit // A bit enum case for a group/list of previously declared cases, providing // a convenient alias for that group. +class BitEnumCaseGroup cases, string str = sym> + : BitEnumCaseBase; + +// The attribute-only form of `BitEnumCaseGroup`. class BitEnumAttrCaseGroup cases, string str = sym> + list cases, string str = sym> : BitEnumAttrCaseBase; - class I8BitEnumAttrCaseGroup cases, string str = sym> : BitEnumAttrCaseGroup; @@ -109,29 +172,36 @@ class I64BitEnumAttrCaseGroup cases, string str = sym> : BitEnumAttrCaseGroup; -// Additional information for an enum attribute. -class EnumAttrInfo< - string name, list cases, Attr baseClass> : - Attr { - +// Information describing an enum and the functions that should be generated for it. +class EnumInfo cases, int width> { + string summary = summaryValue; // Generate a description of this enums members for the MLIR docs. - let description = + string description = "Enum cases:\n" # !interleave( !foreach(case, cases, "* " # case.str # " (`" # case.symbol # "`)"), "\n"); + // The C++ namespace for this enum + string cppNamespace = ""; + // The C++ enum class name string className = name; + // C++ type wrapped by attribute + string cppType = cppNamespace # "::" # className; + // List of all accepted cases - list enumerants = cases; + list enumerants = cases; // The following fields are only used by the EnumsGen backend to generate // an enum class definition and conversion utility functions. + // The bitwidth underlying the class + int bitwidth = width; + // The underlying type for the C++ enum class. An empty string mean the // underlying type is not explicitly specified. - string underlyingType = ""; + string underlyingType = "uint" # width # "_t"; // The name of the utility function that converts a value of the underlying // type to the corresponding symbol. It will have the following signature: @@ -165,6 +235,15 @@ class EnumAttrInfo< // static constexpr unsigned (); // ``` string maxEnumValFnName = "getMaxEnumValFor" # name; +} + +// A wrapper around `EnumInfo` that also makes the Enum an attribute +// if `genSeecializedAttr` is 1 (though `EnumAttr` is the preferred means +// to accomplish this) or declares that the enum will be stored in an attribute. +class EnumAttrInfo< + string name, list cases, SignlessIntegerAttrBase baseClass> : + EnumInfo(baseClass.valueType).bitwidth>, + Attr { // Generate specialized Attribute class bit genSpecializedAttr = 1; @@ -188,15 +267,25 @@ class EnumAttrInfo< baseAttrClass.constBuilderCall); let valueType = baseAttrClass.valueType; - // C++ type wrapped by attribute - string cppType = cppNamespace # "::" # className; - // Parser and printer code used by the EnumParameter class, to be provided by // derived classes string parameterParser = ?; string parameterPrinter = ?; } +// An attribute holding a single integer value. +class IntEnum cases, int width> + : EnumInfo; + +class I32Enum cases> + : IntEnum; +class I64Enum cases> + : IntEnum; + // An enum attribute backed by IntegerAttr. // // Op attributes of this kind are stored as IntegerAttr. Extra verification will @@ -245,13 +334,73 @@ class I64EnumAttr cases> : let underlyingType = "uint64_t"; } +// The base mixin for bit enums that are stored as an integer. +// This is used by both BitEnum and BitEnumAttr, which need to have a set of +// extra properties that bit enums have which normal enums don't. However, +// we can't just use BitEnum as a base class of BitEnumAttr, since BitEnumAttr +// also inherits from EnumAttrInfo, causing double inheritance of EnumInfo. +class BitEnumBase cases> { + // Determine "valid" bits from enum cases for error checking + int validBits = !foldl(0, cases, value, bitcase, !or(value, bitcase.value)); + + // The delimiter used to separate bit enum cases in strings. Only "|" and + // "," (along with optional spaces) are supported due to the use of the + // parseSeparatorFn in parameterParser below. + // Spaces in the separator string are used for printing, but will be optional + // for parsing. + string separator = "|"; + assert !or(!ge(!find(separator, "|"), 0), !ge(!find(separator, ","), 0)), + "separator must contain '|' or ',' for parameter parsing"; + + // Print the "primary group" only for bits that are members of case groups + // that have all bits present. When the value is 0, printing will display both + // both individual bit case names AND the names for all groups that the bit is + // contained in. When the value is 1, for each bit that is set AND is a member + // of a group with all bits set, only the "primary group" (i.e. the first + // group with all bits set in reverse declaration order) will be printed (for + // conciseness). + bit printBitEnumPrimaryGroups = 0; + + // 1 if the operator<< for this enum should put quotes around values with + // multiple entries. Off by default in the general case but on for BitEnumAttrs + // since that was the original behavior. + bit printBitEnumQuoted = 0; +} + +// A bit enum stored as an integer. +// +// Enums of these kind are staored as an integer. Attributes or properties deriving +// from this enum will have additional verification generated on them to make sure +// only allowed bits are set. Helper methods are generated to parse a sring of enum +// values generated by the specified separator to a symbol and vice versa. +class BitEnum cases, int width> + : EnumInfo, BitEnumBase { + // We need to return a string because we may concatenate symbols for multiple + // bits together. + let symbolToStringFnRetType = "std::string"; +} + +class I8BitEnum cases> + : BitEnum; +class I16BitEnum cases> + : BitEnum; +class I32BitEnum cases> + : BitEnum; + +class I64BitEnum cases> + : BitEnum; + // A bit enum stored with an IntegerAttr. // // Op attributes of this kind are stored as IntegerAttr. Extra verification will // be generated on the integer to make sure only allowed bits are set. Besides, // helper methods are generated to parse a string separated with a specified // delimiter to a symbol and vice versa. -class BitEnumAttrBase cases, +class BitEnumAttrBase cases, string summary> : SignlessIntegerAttrBase { let predicate = And<[ @@ -264,24 +413,13 @@ class BitEnumAttrBase cases, } class BitEnumAttr cases> - : EnumAttrInfo> { - // Determine "valid" bits from enum cases for error checking - int validBits = !foldl(0, cases, value, bitcase, !or(value, bitcase.value)); - + list cases> + : EnumAttrInfo>, + BitEnumBase { // We need to return a string because we may concatenate symbols for multiple // bits together. let symbolToStringFnRetType = "std::string"; - // The delimiter used to separate bit enum cases in strings. Only "|" and - // "," (along with optional spaces) are supported due to the use of the - // parseSeparatorFn in parameterParser below. - // Spaces in the separator string are used for printing, but will be optional - // for parsing. - string separator = "|"; - assert !or(!ge(!find(separator, "|"), 0), !ge(!find(separator, ","), 0)), - "separator must contain '|' or ',' for parameter parsing"; - // Parsing function that corresponds to the enum separator. Only // "," and "|" are supported by this definition. string parseSeparatorFn = !if(!ge(!find(separator, "|"), 0), @@ -312,36 +450,30 @@ class BitEnumAttr cases> + list cases> : BitEnumAttr { let underlyingType = "uint8_t"; } class I16BitEnumAttr cases> + list cases> : BitEnumAttr { let underlyingType = "uint16_t"; } class I32BitEnumAttr cases> + list cases> : BitEnumAttr { let underlyingType = "uint32_t"; } class I64BitEnumAttr cases> + list cases> : BitEnumAttr { let underlyingType = "uint64_t"; } @@ -349,11 +481,13 @@ class I64BitEnumAttr +class EnumParameter : AttrParameter { - let parser = enumInfo.parameterParser; - let printer = enumInfo.parameterPrinter; + let parser = !if(!isa(enumInfo), + !cast(enumInfo).parameterParser, ?); + let printer = !if(!isa(enumInfo), + !cast(enumInfo).parameterPrinter, ?); } // An attribute backed by a C++ enum. The attribute contains a single @@ -384,14 +518,14 @@ class EnumParameter // The op will appear in the IR as `my_dialect.my_op first`. However, the // generic format of the attribute will be `#my_dialect<"enum first">`. Override // the attribute's assembly format as required. -class EnumAttr traits = []> : AttrDef { let summary = enumInfo.summary; let description = enumInfo.description; // The backing enumeration. - EnumAttrInfo enum = enumInfo; + EnumInfo enum = enumInfo; // Inherit the C++ namespace from the enum. let cppNamespace = enumInfo.cppNamespace; @@ -417,41 +551,42 @@ class EnumAttr { +class _symbolToValue { defvar cases = - !filter(iter, enumAttrInfo.enumerants, !eq(iter.str, case)); + !filter(iter, enumInfo.enumerants, !eq(iter.str, case)); assert !not(!empty(cases)), "failed to find enum-case '" # case # "'"; // `!empty` check to not cause an error if the cases are empty. // The assertion catches the issue later and emits a proper error message. - string value = enumAttrInfo.cppType # "::" + string value = enumInfo.cppType # "::" # !if(!empty(cases), "", !head(cases).symbol); } -class _bitSymbolsToValue { +class _bitSymbolsToValue { + assert !isa(bitEnum), "_bitSymbolsToValue not given a bit enum"; defvar pos = !find(case, "|"); // Recursive instantiation looking up the symbol before the `|` in // enum cases. string value = !if( - !eq(pos, -1), /*baseCase=*/_symbolToValue.value, - /*rec=*/_symbolToValue.value # "|" - # _bitSymbolsToValue.value + !eq(pos, -1), /*baseCase=*/_symbolToValue.value, + /*rec=*/_symbolToValue.value # "|" + # _bitSymbolsToValue.value ); } class ConstantEnumCaseBase + EnumInfo enumInfo, string case> : ConstantAttr(enumAttrInfo), - _bitSymbolsToValue(enumAttrInfo), case>.value, - _symbolToValue.value + !if(!isa(enumInfo), + _bitSymbolsToValue.value, + _symbolToValue.value ) >; /// Attribute constraint matching a constant enum case. `attribute` should be -/// one of `EnumAttrInfo` or `EnumAttr` and `symbol` the string representation +/// one of `EnumInfo` or `EnumAttr` and `symbol` the string representation /// of an enum case. Multiple enum values of a bit-enum can be combined using /// `|` as a separator. Note that there mustn't be any whitespace around the /// separator. @@ -463,10 +598,10 @@ class ConstantEnumCaseBase class ConstantEnumCase : ConstantEnumCaseBase(attribute), !cast(attribute), + !if(!isa(attribute), !cast(attribute), !cast(attribute).enum), case> { - assert !or(!isa(attribute), !isa(attribute)), - "attribute must be one of 'EnumAttr' or 'EnumAttrInfo'"; + assert !or(!isa(attribute), !isa(attribute)), + "attribute must be one of 'EnumAttr' or 'EnumInfo'"; } #endif // ENUMATTR_TD diff --git a/mlir/include/mlir/TableGen/EnumInfo.h b/mlir/include/mlir/TableGen/EnumInfo.h index 5bc7ffb6a8a35..aea76c01e7e7a 100644 --- a/mlir/include/mlir/TableGen/EnumInfo.h +++ b/mlir/include/mlir/TableGen/EnumInfo.h @@ -84,6 +84,9 @@ class EnumInfo { // Returns the description of the enum. StringRef getDescription() const; + // Returns the bitwidth of the enum. + int64_t getBitwidth() const; + // Returns the underlying type. StringRef getUnderlyingType() const; @@ -119,6 +122,7 @@ class EnumInfo { // Only applicable for bit enums. bool printBitEnumPrimaryGroups() const; + bool printBitEnumQuoted() const; // Returns the TableGen definition this EnumAttrCase was constructed from. const llvm::Record &getDef() const; diff --git a/mlir/lib/TableGen/EnumInfo.cpp b/mlir/lib/TableGen/EnumInfo.cpp index 9f491d30f0e7f..6128c53557cc4 100644 --- a/mlir/lib/TableGen/EnumInfo.cpp +++ b/mlir/lib/TableGen/EnumInfo.cpp @@ -18,8 +18,8 @@ using llvm::Init; using llvm::Record; EnumCase::EnumCase(const Record *record) : def(record) { - assert(def->isSubClassOf("EnumAttrCaseInfo") && - "must be subclass of TableGen 'EnumAttrCaseInfo' class"); + assert(def->isSubClassOf("EnumCase") && + "must be subclass of TableGen 'EnumCase' class"); } EnumCase::EnumCase(const DefInit *init) : EnumCase(init->getDef()) {} @@ -35,8 +35,8 @@ int64_t EnumCase::getValue() const { return def->getValueAsInt("value"); } const Record &EnumCase::getDef() const { return *def; } EnumInfo::EnumInfo(const Record *record) : def(record) { - assert(isSubClassOf("EnumAttrInfo") && - "must be subclass of TableGen 'EnumAttrInfo' class"); + assert(isSubClassOf("EnumInfo") && + "must be subclass of TableGen 'EnumInfo' class"); } EnumInfo::EnumInfo(const Record &record) : EnumInfo(&record) {} @@ -55,7 +55,7 @@ std::optional EnumInfo::asEnumAttr() const { return std::nullopt; } -bool EnumInfo::isBitEnum() const { return isSubClassOf("BitEnumAttr"); } +bool EnumInfo::isBitEnum() const { return isSubClassOf("BitEnumBase"); } StringRef EnumInfo::getEnumClassName() const { return def->getValueAsString("className"); @@ -73,6 +73,8 @@ StringRef EnumInfo::getCppNamespace() const { return def->getValueAsString("cppNamespace"); } +int64_t EnumInfo::getBitwidth() const { return def->getValueAsInt("bitwidth"); } + StringRef EnumInfo::getUnderlyingType() const { return def->getValueAsString("underlyingType"); } @@ -127,4 +129,8 @@ bool EnumInfo::printBitEnumPrimaryGroups() const { return def->getValueAsBit("printBitEnumPrimaryGroups"); } +bool EnumInfo::printBitEnumQuoted() const { + return def->getValueAsBit("printBitEnumQuoted"); +} + const Record &EnumInfo::getDef() const { return *def; } diff --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp index 73e2803c21dae..d83df3e415c36 100644 --- a/mlir/lib/TableGen/Pattern.cpp +++ b/mlir/lib/TableGen/Pattern.cpp @@ -57,7 +57,7 @@ bool DagLeaf::isNativeCodeCall() const { bool DagLeaf::isConstantAttr() const { return isSubClassOf("ConstantAttr"); } -bool DagLeaf::isEnumCase() const { return isSubClassOf("EnumAttrCaseInfo"); } +bool DagLeaf::isEnumCase() const { return isSubClassOf("EnumCase"); } bool DagLeaf::isStringAttr() const { return isa(def); } diff --git a/mlir/test/Dialect/LLVMIR/func.mlir b/mlir/test/Dialect/LLVMIR/func.mlir index 74dd862ce8fb2..7caea3920255a 100644 --- a/mlir/test/Dialect/LLVMIR/func.mlir +++ b/mlir/test/Dialect/LLVMIR/func.mlir @@ -428,7 +428,7 @@ module { module { "llvm.func"() ({ - // expected-error @below {{invalid Calling Conventions specification: cc_12}} + // expected-error @below {{expected one of [ccc, fastcc, coldcc, cc_10, cc_11, anyregcc, preserve_mostcc, preserve_allcc, swiftcc, cxx_fast_tlscc, tailcc, cfguard_checkcc, swifttailcc, x86_stdcallcc, x86_fastcallcc, arm_apcscc, arm_aapcscc, arm_aapcs_vfpcc, msp430_intrcc, x86_thiscallcc, ptx_kernelcc, ptx_devicecc, spir_funccc, spir_kernelcc, intel_ocl_bicc, x86_64_sysvcc, win64cc, x86_vectorcallcc, hhvmcc, hhvm_ccc, x86_intrcc, avr_intrcc, avr_builtincc, amdgpu_vscc, amdgpu_gscc, amdgpu_cscc, amdgpu_kernelcc, x86_regcallcc, amdgpu_hscc, msp430_builtincc, amdgpu_lscc, amdgpu_escc, aarch64_vectorcallcc, aarch64_sve_vectorcallcc, wasm_emscripten_invokecc, amdgpu_gfxcc, m68k_intrcc] for Calling Conventions, got: cc_12}} // expected-error @below {{failed to parse CConvAttr parameter 'CallingConv' which is to be a `CConv`}} }) {sym_name = "generic_unknown_calling_convention", CConv = #llvm.cconv, function_type = !llvm.func} : () -> () } diff --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir index 5a005a393d8ac..4f280bde1aecc 100644 --- a/mlir/test/IR/attribute.mlir +++ b/mlir/test/IR/attribute.mlir @@ -535,7 +535,7 @@ func.func @allowed_cases_pass() { // ----- func.func @disallowed_case_sticky_fail() { - // expected-error@+2 {{expected test::TestBitEnum to be one of: read, write, execute}} + // expected-error@+2 {{expected one of [read, write, execute] for a test bit enum, got: sticky}} // expected-error@+1 {{failed to parse TestBitEnumAttr}} "test.op_with_bit_enum"() {value = #test.bit_enum} : () -> () } diff --git a/mlir/test/lib/Dialect/Test/TestEnumDefs.td b/mlir/test/lib/Dialect/Test/TestEnumDefs.td index 1ddfca0b22315..7441ea5a9726b 100644 --- a/mlir/test/lib/Dialect/Test/TestEnumDefs.td +++ b/mlir/test/lib/Dialect/Test/TestEnumDefs.td @@ -42,11 +42,10 @@ def TestEnum let cppNamespace = "test"; } -def TestSimpleEnum : I32EnumAttr<"SimpleEnum", "", [ - I32EnumAttrCase<"a", 0>, - I32EnumAttrCase<"b", 1> +def TestSimpleEnum : I32Enum<"SimpleEnum", "", [ + I32EnumCase<"a", 0>, + I32EnumCase<"b", 1> ]> { - let genSpecializedAttr = 0; let cppNamespace = "::test"; } @@ -56,24 +55,22 @@ def TestSimpleEnum : I32EnumAttr<"SimpleEnum", "", [ // Define the C++ enum. def TestBitEnum - : I32BitEnumAttr<"TestBitEnum", "a test bit enum", [ - I32BitEnumAttrCaseBit<"Read", 0, "read">, - I32BitEnumAttrCaseBit<"Write", 1, "write">, - I32BitEnumAttrCaseBit<"Execute", 2, "execute">, + : I32BitEnum<"TestBitEnum", "a test bit enum", [ + I32BitEnumCaseBit<"Read", 0, "read">, + I32BitEnumCaseBit<"Write", 1, "write">, + I32BitEnumCaseBit<"Execute", 2, "execute">, ]> { - let genSpecializedAttr = 0; let cppNamespace = "test"; let separator = ", "; } // Define an enum with a different separator def TestBitEnumVerticalBar - : I32BitEnumAttr<"TestBitEnumVerticalBar", "another test bit enum", [ - I32BitEnumAttrCaseBit<"User", 0, "user">, - I32BitEnumAttrCaseBit<"Group", 1, "group">, - I32BitEnumAttrCaseBit<"Other", 2, "other">, + : I32BitEnum<"TestBitEnumVerticalBar", "another test bit enum", [ + I32BitEnumCaseBit<"User", 0, "user">, + I32BitEnumCaseBit<"Group", 1, "group">, + I32BitEnumCaseBit<"Other", 2, "other">, ]> { - let genSpecializedAttr = 0; let cppNamespace = "test"; let separator = " | "; } diff --git a/mlir/test/mlir-tblgen/enums-gen.td b/mlir/test/mlir-tblgen/enums-gen.td index c3a768e42236c..8489cff7c429d 100644 --- a/mlir/test/mlir-tblgen/enums-gen.td +++ b/mlir/test/mlir-tblgen/enums-gen.td @@ -5,12 +5,12 @@ include "mlir/IR/EnumAttr.td" include "mlir/IR/OpBase.td" // Test bit enums -def None: I32BitEnumAttrCaseNone<"None", "none">; -def Bit0: I32BitEnumAttrCaseBit<"Bit0", 0, "tagged">; -def Bit1: I32BitEnumAttrCaseBit<"Bit1", 1>; -def Bit2: I32BitEnumAttrCaseBit<"Bit2", 2>; +def None: I32BitEnumCaseNone<"None", "none">; +def Bit0: I32BitEnumCaseBit<"Bit0", 0, "tagged">; +def Bit1: I32BitEnumCaseBit<"Bit1", 1>; +def Bit2: I32BitEnumCaseBit<"Bit2", 2>; def Bit3: I32BitEnumAttrCaseBit<"Bit3", 3>; -def BitGroup: I32BitEnumAttrCaseGroup<"BitGroup", [ +def BitGroup: BitEnumCaseGroup<"BitGroup", [ Bit0, Bit1 ]>; @@ -42,7 +42,7 @@ def MyBitEnum: I32BitEnumAttr<"MyBitEnum", "An example bit enum", // DECL: // Symbolize the keyword. // DECL: if (::std::optional<::MyBitEnum> attr = ::symbolizeEnum<::MyBitEnum>(enumKeyword)) // DECL: return *attr; -// DECL: return parser.emitError(loc, "invalid An example bit enum specification: ") << enumKeyword; +// DECL: return parser.emitError(loc, "expected one of [none, tagged, Bit1, Bit2, Bit3, BitGroup] for An example bit enum, got: ") << enumKeyword; // DECL: } // DECL: inline ::llvm::raw_ostream &operator<<(::llvm::raw_ostream &p, ::MyBitEnum value) { @@ -73,7 +73,7 @@ def MyBitEnum: I32BitEnumAttr<"MyBitEnum", "An example bit enum", // Test enum printer generation for non non-keyword enums. -def NonKeywordBit: I32BitEnumAttrCaseBit<"Bit0", 0, "tag-ged">; +def NonKeywordBit: I32BitEnumCaseBit<"Bit0", 0, "tag-ged">; def MyMixedNonKeywordBitEnum: I32BitEnumAttr<"MyMixedNonKeywordBitEnum", "An example bit enum", [ NonKeywordBit, Bit1 @@ -101,3 +101,32 @@ def MyNonKeywordBitEnum: I32BitEnumAttr<"MyNonKeywordBitEnum", "An example bit e // DECL: inline ::llvm::raw_ostream &operator<<(::llvm::raw_ostream &p, ::MyNonKeywordBitEnum value) { // DECL: auto valueStr = stringifyEnum(value); // DECL: return p << '"' << valueStr << '"'; + +def MyNonQuotedPrintBitEnum + : I32BitEnum<"MyNonQuotedPrintBitEnum", "Example new-style bit enum", + [None, Bit0, Bit1, Bit2, Bit3, BitGroup]>; + +// DECL: struct FieldParser<::MyNonQuotedPrintBitEnum, ::MyNonQuotedPrintBitEnum> { +// DECL: template +// DECL: static FailureOr<::MyNonQuotedPrintBitEnum> parse(ParserT &parser) { +// DECL: ::MyNonQuotedPrintBitEnum flags = {}; +// DECL: do { + // DECL: // Parse the keyword containing a part of the enum. +// DECL: ::llvm::StringRef enumKeyword; +// DECL: auto loc = parser.getCurrentLocation(); +// DECL: if (failed(parser.parseOptionalKeyword(&enumKeyword))) { +// DECL: return parser.emitError(loc, "expected keyword for Example new-style bit enum"); +// DECL: } +// DECL: // Symbolize the keyword. +// DECL: if (::std::optional<::MyNonQuotedPrintBitEnum> flag = ::symbolizeEnum<::MyNonQuotedPrintBitEnum>(enumKeyword)) +// DECL: flags = flags | *flag; +// DECL: } else { +// DECL: return parser.emitError(loc, "expected one of [none, tagged, Bit1, Bit2, Bit3, BitGroup] for Example new-style bit enum, got: ") << enumKeyword; +// DECL: } +// DECL: } while (::mlir::succeeded(parser.parseOptionalVerticalBar())); +// DECL: return flags; +// DECL: } + +// DECL: inline ::llvm::raw_ostream &operator<<(::llvm::raw_ostream &p, ::MyNonQuotedPrintBitEnum value) { +// DECL: auto valueStr = stringifyEnum(value); +// DECL-NEXT: return p << valueStr; diff --git a/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp index 5d4d9e90fff67..8e2d6114e48eb 100644 --- a/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp +++ b/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp @@ -85,17 +85,6 @@ static void emitEnumClass(EnumInfo enumInfo, raw_ostream &os) { os << "\n"; } -/// Attempts to extract the bitwidth B from string "uintB_t" describing the -/// type. This bitwidth information is not readily available in ODS. Returns -/// `false` on success, `true` on failure. -static bool extractUIntBitwidth(StringRef uintType, int64_t &bitwidth) { - if (!uintType.consume_front("uint")) - return true; - if (!uintType.consume_back("_t")) - return true; - return uintType.getAsInteger(/*Radix=*/10, bitwidth); -} - /// Emits an attribute builder for the given enum attribute to support automatic /// conversion between enum values and attributes in Python. Returns /// `false` on success, `true` on failure. @@ -104,12 +93,7 @@ static bool emitAttributeBuilder(const EnumInfo &enumInfo, raw_ostream &os) { if (!enumAttrInfo) return false; - int64_t bitwidth; - if (extractUIntBitwidth(enumInfo.getUnderlyingType(), bitwidth)) { - llvm::errs() << "failed to identify bitwidth of " - << enumInfo.getUnderlyingType(); - return true; - } + int64_t bitwidth = enumInfo.getBitwidth(); os << formatv("@register_attribute_builder(\"{0}\")\n", enumAttrInfo->getAttrDefName()); os << formatv("def _{0}(x, context):\n", @@ -140,7 +124,7 @@ static bool emitDialectEnumAttributeBuilder(StringRef attrDefName, static bool emitPythonEnums(const RecordKeeper &records, raw_ostream &os) { os << fileHeader; for (const Record *it : - records.getAllDerivedDefinitionsIfDefined("EnumAttrInfo")) { + records.getAllDerivedDefinitionsIfDefined("EnumInfo")) { EnumInfo enumInfo(*it); emitEnumClass(enumInfo, os); emitAttributeBuilder(enumInfo, os); diff --git a/mlir/tools/mlir-tblgen/EnumsGen.cpp b/mlir/tools/mlir-tblgen/EnumsGen.cpp index fa6fad156b747..9941a203bc5cb 100644 --- a/mlir/tools/mlir-tblgen/EnumsGen.cpp +++ b/mlir/tools/mlir-tblgen/EnumsGen.cpp @@ -77,11 +77,22 @@ static void emitParserPrinter(const EnumInfo &enumInfo, StringRef qualName, // Check which cases shouldn't be printed using a keyword. llvm::BitVector nonKeywordCases(cases.size()); - for (auto [index, caseVal] : llvm::enumerate(cases)) - if (!mlir::tblgen::canFormatStringAsKeyword(caseVal.getStr())) - nonKeywordCases.set(index); - - // Generate the parser and the start of the printer for the enum. + std::string casesList; + llvm::raw_string_ostream caseListOs(casesList); + caseListOs << "["; + llvm::interleaveComma(llvm::enumerate(cases), caseListOs, + [&](auto enumerant) { + StringRef name = enumerant.value().getStr(); + if (!mlir::tblgen::canFormatStringAsKeyword(name)) { + nonKeywordCases.set(enumerant.index()); + caseListOs << "\\\"" << name << "\\\""; + } + caseListOs << name; + }); + caseListOs << "]"; + + // Generate the parser and the start of the printer for the enum, excluding + // non-quoted bit enums. const char *parsedAndPrinterStart = R"( namespace mlir { template @@ -100,7 +111,7 @@ struct FieldParser<{0}, {0}> {{ // Symbolize the keyword. if (::std::optional<{0}> attr = {1}::symbolizeEnum<{0}>(enumKeyword)) return *attr; - return parser.emitError(loc, "invalid {2} specification: ") << enumKeyword; + return parser.emitError(loc, "expected one of {3} for {2}, got: ") << enumKeyword; } }; @@ -121,7 +132,7 @@ struct FieldParser, std::optional<{0}>> {{ // Symbolize the keyword. if (::std::optional<{0}> attr = {1}::symbolizeEnum<{0}>(enumKeyword)) return attr; - return parser.emitError(loc, "invalid {2} specification: ") << enumKeyword; + return parser.emitError(loc, "expected one of {3} for {2}, got: ") << enumKeyword; } }; } // namespace mlir @@ -131,8 +142,94 @@ inline ::llvm::raw_ostream &operator<<(::llvm::raw_ostream &p, {0} value) {{ auto valueStr = stringifyEnum(value); )"; - os << formatv(parsedAndPrinterStart, qualName, cppNamespace, - enumInfo.getSummary()); + const char *parsedAndPrinterStartUnquotedBitEnum = R"( + namespace mlir { + template + struct FieldParser; + + template<> + struct FieldParser<{0}, {0}> {{ + template + static FailureOr<{0}> parse(ParserT &parser) {{ + {0} flags = {{}; + do {{ + // Parse the keyword containing a part of the enum. + ::llvm::StringRef enumKeyword; + auto loc = parser.getCurrentLocation(); + if (failed(parser.parseOptionalKeyword(&enumKeyword))) {{ + return parser.emitError(loc, "expected keyword for {2}"); + } + + // Symbolize the keyword. + if (::std::optional<{0}> flag = {1}::symbolizeEnum<{0}>(enumKeyword)) {{ + flags = flags | *flag; + } else {{ + return parser.emitError(loc, "expected one of {3} for {2}, got: ") << enumKeyword; + } + } while (::mlir::succeeded(parser.{5}())); + return flags; + } + }; + + /// Support for std::optional, useful in attribute/type definition where the enum is + /// used as: + /// + /// let parameters = (ins OptionalParameter<"std::optional">:$value); + template<> + struct FieldParser, std::optional<{0}>> {{ + template + static FailureOr> parse(ParserT &parser) {{ + {0} flags = {{}; + bool firstIter = true; + do {{ + // Parse the keyword containing a part of the enum. + ::llvm::StringRef enumKeyword; + auto loc = parser.getCurrentLocation(); + if (failed(parser.parseOptionalKeyword(&enumKeyword))) {{ + if (firstIter) + return std::optional<{0}>{{}; + return parser.emitError(loc, "expected keyword for {2} after '{4}'"); + } + firstIter = false; + + // Symbolize the keyword. + if (::std::optional<{0}> flag = {1}::symbolizeEnum<{0}>(enumKeyword)) {{ + flags = flags | *flag; + } else {{ + return parser.emitError(loc, "expected one of {3} for {2}, got: ") << enumKeyword; + } + } while(::mlir::succeeded(parser.{5}())); + return std::optional<{0}>{{flags}; + } + }; + } // namespace mlir + + namespace llvm { + inline ::llvm::raw_ostream &operator<<(::llvm::raw_ostream &p, {0} value) {{ + auto valueStr = stringifyEnum(value); + )"; + + bool isNewStyleBitEnum = + enumInfo.isBitEnum() && !enumInfo.printBitEnumQuoted(); + + if (isNewStyleBitEnum) { + if (nonKeywordCases.any()) + return PrintFatalError( + "bit enum " + qualName + + " cannot be printed unquoted with cases that cannot be keywords"); + StringRef separator = enumInfo.getDef().getValueAsString("separator"); + StringRef parseSeparatorFn = + llvm::StringSwitch(separator.trim()) + .Case("|", "parseOptionalVerticalBar") + .Case(",", "parseOptionalComma") + .Default("error, enum seperator must be '|' or ','"); + os << formatv(parsedAndPrinterStartUnquotedBitEnum, qualName, cppNamespace, + enumInfo.getSummary(), casesList, separator, + parseSeparatorFn); + } else { + os << formatv(parsedAndPrinterStart, qualName, cppNamespace, + enumInfo.getSummary(), casesList); + } // If all cases require a string, always wrap. if (nonKeywordCases.all()) { @@ -160,7 +257,10 @@ inline ::llvm::raw_ostream &operator<<(::llvm::raw_ostream &p, {0} value) {{ // If this is a bit enum, conservatively print the string form if the value // is not a power of two (i.e. not a single bit case) and not a known case. - } else if (enumInfo.isBitEnum()) { + // Only do this if we're using the old-style parser that parses the enum as + // one keyword, as opposed to the new form, where we can print the value + // as-is. + } else if (enumInfo.isBitEnum() && !isNewStyleBitEnum) { // Process the known multi-bit cases that use valid keywords. SmallVector validMultiBitCases; for (auto [index, caseVal] : llvm::enumerate(cases)) { @@ -670,7 +770,7 @@ static bool emitEnumDecls(const RecordKeeper &records, raw_ostream &os) { llvm::emitSourceFileHeader("Enum Utility Declarations", os, records); for (const Record *def : - records.getAllDerivedDefinitionsIfDefined("EnumAttrInfo")) + records.getAllDerivedDefinitionsIfDefined("EnumInfo")) emitEnumDecl(*def, os); return false; @@ -708,7 +808,7 @@ static bool emitEnumDefs(const RecordKeeper &records, raw_ostream &os) { llvm::emitSourceFileHeader("Enum Utility Definitions", os, records); for (const Record *def : - records.getAllDerivedDefinitionsIfDefined("EnumAttrInfo")) + records.getAllDerivedDefinitionsIfDefined("EnumInfo")) emitEnumDef(*def, os); return false; diff --git a/mlir/tools/mlir-tblgen/OpDocGen.cpp b/mlir/tools/mlir-tblgen/OpDocGen.cpp index f53aebb302dc9..077f9d1ea2b13 100644 --- a/mlir/tools/mlir-tblgen/OpDocGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDocGen.cpp @@ -406,7 +406,7 @@ static void emitEnumDoc(const EnumInfo &def, raw_ostream &os) { static void emitEnumDoc(const RecordKeeper &records, raw_ostream &os) { os << "\n"; - for (const Record *def : records.getAllDerivedDefinitions("EnumAttrInfo")) + for (const Record *def : records.getAllDerivedDefinitions("EnumInfo")) emitEnumDoc(EnumInfo(def), os); } @@ -526,7 +526,7 @@ static bool emitDialectDoc(const RecordKeeper &records, raw_ostream &os) { auto typeDefs = records.getAllDerivedDefinitionsIfDefined("DialectType"); auto typeDefDefs = records.getAllDerivedDefinitionsIfDefined("TypeDef"); auto attrDefDefs = records.getAllDerivedDefinitionsIfDefined("AttrDef"); - auto enumDefs = records.getAllDerivedDefinitionsIfDefined("EnumAttrInfo"); + auto enumDefs = records.getAllDerivedDefinitionsIfDefined("EnumInfo"); std::vector dialectAttrs; std::vector dialectAttrDefs; diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp index 7a6189c09f426..f94ed17aeb4e0 100644 --- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp +++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp @@ -455,7 +455,7 @@ static bool emitEnumDecls(const RecordKeeper &records, raw_ostream &os) { llvm::emitSourceFileHeader("SPIR-V Enum Availability Declarations", os, records); - auto defs = records.getAllDerivedDefinitions("EnumAttrInfo"); + auto defs = records.getAllDerivedDefinitions("EnumInfo"); for (const auto *def : defs) emitEnumDecl(*def, os); @@ -487,7 +487,7 @@ static bool emitEnumDefs(const RecordKeeper &records, raw_ostream &os) { llvm::emitSourceFileHeader("SPIR-V Enum Availability Definitions", os, records); - auto defs = records.getAllDerivedDefinitions("EnumAttrInfo"); + auto defs = records.getAllDerivedDefinitions("EnumInfo"); for (const auto *def : defs) emitEnumDef(*def, os); @@ -1262,7 +1262,7 @@ static void emitEnumGetAttrNameFnDefn(const EnumInfo &enumInfo, static bool emitAttrUtils(const RecordKeeper &records, raw_ostream &os) { llvm::emitSourceFileHeader("SPIR-V Attribute Utilities", os, records); - auto defs = records.getAllDerivedDefinitions("EnumAttrInfo"); + auto defs = records.getAllDerivedDefinitions("EnumInfo"); os << "#ifndef MLIR_DIALECT_SPIRV_IR_ATTR_UTILS_H_\n"; os << "#define MLIR_DIALECT_SPIRV_IR_ATTR_UTILS_H_\n"; emitEnumGetAttrNameFnDecl(os); diff --git a/mlir/utils/spirv/gen_spirv_dialect.py b/mlir/utils/spirv/gen_spirv_dialect.py index 99ed3489b4cbd..d2d0b410f52df 100755 --- a/mlir/utils/spirv/gen_spirv_dialect.py +++ b/mlir/utils/spirv/gen_spirv_dialect.py @@ -288,11 +288,11 @@ def get_availability_spec(enum_case, for_op, for_cap): def gen_operand_kind_enum_attr(operand_kind): - """Generates the TableGen EnumAttr definition for the given operand kind. + """Generates the TableGen EnumInfo definition for the given operand kind. Returns: - The operand kind's name - - A string containing the TableGen EnumAttr definition + - A string containing the TableGen EnumInfo definition """ if "enumerants" not in operand_kind: return "", ""