Skip to content

Commit 509fe55

Browse files
committed
[mlir] Decouple enum generation from attributes, adding EnumInfo and EnumCase
This commit pulls apart the inherent attribute dependence of classes like EnumAttrInfo and EnumAttrCase, factoring them out into simpler EnumCase and EnumInfo variants. This allows specifying the cases of an enum without needing to make the cases, or the EnumInfo itself, a subclass of SignlessIntegerAttrBase. The existing classes are retained as subclasses of the new ones, both for backwards compatibility and to allow attribute-specific information. In addition, the new BitEnum class changes its default printer/parser behavior: cases when multiple keywords appear, like having both nuw and nsw in overflow flags, will no longer be quoted by the operator<<, and the FieldParser instance will now expect multiple keywords. All instances of BitEnumAttr retain the old behavior.
1 parent d130635 commit 509fe55

File tree

13 files changed

+389
-134
lines changed

13 files changed

+389
-134
lines changed

mlir/include/mlir/IR/EnumAttr.td

Lines changed: 203 additions & 68 deletions
Large diffs are not rendered by default.

mlir/include/mlir/TableGen/EnumInfo.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,9 @@ class EnumInfo {
8585
// Returns the description of the enum.
8686
StringRef getDescription() const;
8787

88+
// Returns the bitwidth of the enum.
89+
int64_t getBitwidth() const;
90+
8891
// Returns the underlying type.
8992
StringRef getUnderlyingType() const;
9093

@@ -120,6 +123,7 @@ class EnumInfo {
120123
// Only applicable for bit enums.
121124

122125
bool printBitEnumPrimaryGroups() const;
126+
bool printBitEnumQuoted() const;
123127

124128
// Returns the TableGen definition this EnumAttrCase was constructed from.
125129
const llvm::Record &getDef() const;

mlir/lib/TableGen/EnumInfo.cpp

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ using llvm::Init;
1818
using llvm::Record;
1919

2020
EnumCase::EnumCase(const Record *record) : def(record) {
21-
assert(def->isSubClassOf("EnumAttrCaseInfo") &&
22-
"must be subclass of TableGen 'EnumAttrCaseInfo' class");
21+
assert(def->isSubClassOf("EnumCase") &&
22+
"must be subclass of TableGen 'EnumCase' class");
2323
}
2424

2525
EnumCase::EnumCase(const DefInit *init) : EnumCase(init->getDef()) {}
@@ -35,8 +35,8 @@ int64_t EnumCase::getValue() const { return def->getValueAsInt("value"); }
3535
const Record &EnumCase::getDef() const { return *def; }
3636

3737
EnumInfo::EnumInfo(const Record *record) : def(record) {
38-
assert(isSubClassOf("EnumAttrInfo") &&
39-
"must be subclass of TableGen 'EnumAttrInfo' class");
38+
assert(isSubClassOf("EnumInfo") &&
39+
"must be subclass of TableGen 'EnumInfo' class");
4040
}
4141

4242
EnumInfo::EnumInfo(const Record &record) : EnumInfo(&record) {}
@@ -55,7 +55,7 @@ std::optional<Attribute> EnumInfo::asEnumAttr() const {
5555
return std::nullopt;
5656
}
5757

58-
bool EnumInfo::isBitEnum() const { return isSubClassOf("BitEnumAttr"); }
58+
bool EnumInfo::isBitEnum() const { return isSubClassOf("BitEnumBase"); }
5959

6060
StringRef EnumInfo::getEnumClassName() const {
6161
return def->getValueAsString("className");
@@ -73,6 +73,8 @@ StringRef EnumInfo::getCppNamespace() const {
7373
return def->getValueAsString("cppNamespace");
7474
}
7575

76+
int64_t EnumInfo::getBitwidth() const { return def->getValueAsInt("bitwidth"); }
77+
7678
StringRef EnumInfo::getUnderlyingType() const {
7779
return def->getValueAsString("underlyingType");
7880
}
@@ -127,4 +129,8 @@ bool EnumInfo::printBitEnumPrimaryGroups() const {
127129
return def->getValueAsBit("printBitEnumPrimaryGroups");
128130
}
129131

132+
bool EnumInfo::printBitEnumQuoted() const {
133+
return def->getValueAsBit("printBitEnumQuoted");
134+
}
135+
130136
const Record &EnumInfo::getDef() const { return *def; }

mlir/lib/TableGen/Pattern.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ bool DagLeaf::isNativeCodeCall() const {
5757

5858
bool DagLeaf::isConstantAttr() const { return isSubClassOf("ConstantAttr"); }
5959

60-
bool DagLeaf::isEnumCase() const { return isSubClassOf("EnumAttrCaseInfo"); }
60+
bool DagLeaf::isEnumCase() const { return isSubClassOf("EnumCase"); }
6161

6262
bool DagLeaf::isStringAttr() const { return isa<llvm::StringInit>(def); }
6363

mlir/test/Dialect/LLVMIR/func.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -428,7 +428,7 @@ module {
428428

429429
module {
430430
"llvm.func"() ({
431-
// expected-error @below {{invalid Calling Conventions specification: cc_12}}
431+
// expected-error @below {{expected one of [ccc, fastcc, coldcc, cc_10, cc_11, anyregcc, preserve_mostcc, preserve_allcc, swiftcc, cxx_fast_tlscc, tailcc, cfguard_checkcc, swifttailcc, x86_stdcallcc, x86_fastcallcc, arm_apcscc, arm_aapcscc, arm_aapcs_vfpcc, msp430_intrcc, x86_thiscallcc, ptx_kernelcc, ptx_devicecc, spir_funccc, spir_kernelcc, intel_ocl_bicc, x86_64_sysvcc, win64cc, x86_vectorcallcc, hhvmcc, hhvm_ccc, x86_intrcc, avr_intrcc, avr_builtincc, amdgpu_vscc, amdgpu_gscc, amdgpu_cscc, amdgpu_kernelcc, x86_regcallcc, amdgpu_hscc, msp430_builtincc, amdgpu_lscc, amdgpu_escc, aarch64_vectorcallcc, aarch64_sve_vectorcallcc, wasm_emscripten_invokecc, amdgpu_gfxcc, m68k_intrcc] for Calling Conventions, got: cc_12}}
432432
// expected-error @below {{failed to parse CConvAttr parameter 'CallingConv' which is to be a `CConv`}}
433433
}) {sym_name = "generic_unknown_calling_convention", CConv = #llvm.cconv<cc_12>, function_type = !llvm.func<i64 (i64, i64)>} : () -> ()
434434
}

mlir/test/IR/attribute.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -535,7 +535,7 @@ func.func @allowed_cases_pass() {
535535
// -----
536536

537537
func.func @disallowed_case_sticky_fail() {
538-
// expected-error@+2 {{expected test::TestBitEnum to be one of: read, write, execute}}
538+
// expected-error@+2 {{expected one of [read, write, execute] for a test bit enum, got: sticky}}
539539
// expected-error@+1 {{failed to parse TestBitEnumAttr}}
540540
"test.op_with_bit_enum"() {value = #test.bit_enum<sticky>} : () -> ()
541541
}

mlir/test/lib/Dialect/Test/TestEnumDefs.td

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,10 @@ def TestEnum
4242
let cppNamespace = "test";
4343
}
4444

45-
def TestSimpleEnum : I32EnumAttr<"SimpleEnum", "", [
46-
I32EnumAttrCase<"a", 0>,
47-
I32EnumAttrCase<"b", 1>
45+
def TestSimpleEnum : I32Enum<"SimpleEnum", "", [
46+
I32EnumCase<"a", 0>,
47+
I32EnumCase<"b", 1>
4848
]> {
49-
let genSpecializedAttr = 0;
5049
let cppNamespace = "::test";
5150
}
5251

@@ -56,24 +55,22 @@ def TestSimpleEnum : I32EnumAttr<"SimpleEnum", "", [
5655

5756
// Define the C++ enum.
5857
def TestBitEnum
59-
: I32BitEnumAttr<"TestBitEnum", "a test bit enum", [
60-
I32BitEnumAttrCaseBit<"Read", 0, "read">,
61-
I32BitEnumAttrCaseBit<"Write", 1, "write">,
62-
I32BitEnumAttrCaseBit<"Execute", 2, "execute">,
58+
: I32BitEnum<"TestBitEnum", "a test bit enum", [
59+
I32BitEnumCaseBit<"Read", 0, "read">,
60+
I32BitEnumCaseBit<"Write", 1, "write">,
61+
I32BitEnumCaseBit<"Execute", 2, "execute">,
6362
]> {
64-
let genSpecializedAttr = 0;
6563
let cppNamespace = "test";
6664
let separator = ", ";
6765
}
6866

6967
// Define an enum with a different separator
7068
def TestBitEnumVerticalBar
71-
: I32BitEnumAttr<"TestBitEnumVerticalBar", "another test bit enum", [
72-
I32BitEnumAttrCaseBit<"User", 0, "user">,
73-
I32BitEnumAttrCaseBit<"Group", 1, "group">,
74-
I32BitEnumAttrCaseBit<"Other", 2, "other">,
69+
: I32BitEnum<"TestBitEnumVerticalBar", "another test bit enum", [
70+
I32BitEnumCaseBit<"User", 0, "user">,
71+
I32BitEnumCaseBit<"Group", 1, "group">,
72+
I32BitEnumCaseBit<"Other", 2, "other">,
7573
]> {
76-
let genSpecializedAttr = 0;
7774
let cppNamespace = "test";
7875
let separator = " | ";
7976
}

mlir/test/mlir-tblgen/enums-gen.td

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@ include "mlir/IR/EnumAttr.td"
55
include "mlir/IR/OpBase.td"
66

77
// Test bit enums
8-
def None: I32BitEnumAttrCaseNone<"None", "none">;
9-
def Bit0: I32BitEnumAttrCaseBit<"Bit0", 0, "tagged">;
10-
def Bit1: I32BitEnumAttrCaseBit<"Bit1", 1>;
11-
def Bit2: I32BitEnumAttrCaseBit<"Bit2", 2>;
8+
def None: I32BitEnumCaseNone<"None", "none">;
9+
def Bit0: I32BitEnumCaseBit<"Bit0", 0, "tagged">;
10+
def Bit1: I32BitEnumCaseBit<"Bit1", 1>;
11+
def Bit2: I32BitEnumCaseBit<"Bit2", 2>;
1212
def Bit3: I32BitEnumAttrCaseBit<"Bit3", 3>;
13-
def BitGroup: I32BitEnumAttrCaseGroup<"BitGroup", [
13+
def BitGroup: BitEnumCaseGroup<"BitGroup", [
1414
Bit0, Bit1
1515
]>;
1616

@@ -42,7 +42,7 @@ def MyBitEnum: I32BitEnumAttr<"MyBitEnum", "An example bit enum",
4242
// DECL: // Symbolize the keyword.
4343
// DECL: if (::std::optional<::MyBitEnum> attr = ::symbolizeEnum<::MyBitEnum>(enumKeyword))
4444
// DECL: return *attr;
45-
// DECL: return parser.emitError(loc, "invalid An example bit enum specification: ") << enumKeyword;
45+
// DECL: return parser.emitError(loc, "expected one of [none, tagged, Bit1, Bit2, Bit3, BitGroup] for An example bit enum, got: ") << enumKeyword;
4646
// DECL: }
4747

4848
// DECL: inline ::llvm::raw_ostream &operator<<(::llvm::raw_ostream &p, ::MyBitEnum value) {
@@ -73,7 +73,7 @@ def MyBitEnum: I32BitEnumAttr<"MyBitEnum", "An example bit enum",
7373

7474
// Test enum printer generation for non non-keyword enums.
7575

76-
def NonKeywordBit: I32BitEnumAttrCaseBit<"Bit0", 0, "tag-ged">;
76+
def NonKeywordBit: I32BitEnumCaseBit<"Bit0", 0, "tag-ged">;
7777
def MyMixedNonKeywordBitEnum: I32BitEnumAttr<"MyMixedNonKeywordBitEnum", "An example bit enum", [
7878
NonKeywordBit,
7979
Bit1
@@ -101,3 +101,32 @@ def MyNonKeywordBitEnum: I32BitEnumAttr<"MyNonKeywordBitEnum", "An example bit e
101101
// DECL: inline ::llvm::raw_ostream &operator<<(::llvm::raw_ostream &p, ::MyNonKeywordBitEnum value) {
102102
// DECL: auto valueStr = stringifyEnum(value);
103103
// DECL: return p << '"' << valueStr << '"';
104+
105+
def MyNonQuotedPrintBitEnum
106+
: I32BitEnum<"MyNonQuotedPrintBitEnum", "Example new-style bit enum",
107+
[None, Bit0, Bit1, Bit2, Bit3, BitGroup]>;
108+
109+
// DECL: struct FieldParser<::MyNonQuotedPrintBitEnum, ::MyNonQuotedPrintBitEnum> {
110+
// DECL: template <typename ParserT>
111+
// DECL: static FailureOr<::MyNonQuotedPrintBitEnum> parse(ParserT &parser) {
112+
// DECL: ::MyNonQuotedPrintBitEnum flags = {};
113+
// DECL: do {
114+
// DECL: // Parse the keyword containing a part of the enum.
115+
// DECL: ::llvm::StringRef enumKeyword;
116+
// DECL: auto loc = parser.getCurrentLocation();
117+
// DECL: if (failed(parser.parseOptionalKeyword(&enumKeyword))) {
118+
// DECL: return parser.emitError(loc, "expected keyword for Example new-style bit enum");
119+
// DECL: }
120+
// DECL: // Symbolize the keyword.
121+
// DECL: if (::std::optional<::MyNonQuotedPrintBitEnum> flag = ::symbolizeEnum<::MyNonQuotedPrintBitEnum>(enumKeyword))
122+
// DECL: flags = flags | *flag;
123+
// DECL: } else {
124+
// DECL: return parser.emitError(loc, "expected one of [none, tagged, Bit1, Bit2, Bit3, BitGroup] for Example new-style bit enum, got: ") << enumKeyword;
125+
// DECL: }
126+
// DECL: } while (::mlir::succeeded(parser.parseOptionalVerticalBar()));
127+
// DECL: return flags;
128+
// DECL: }
129+
130+
// DECL: inline ::llvm::raw_ostream &operator<<(::llvm::raw_ostream &p, ::MyNonQuotedPrintBitEnum value) {
131+
// DECL: auto valueStr = stringifyEnum(value);
132+
// DECL-NEXT: return p << valueStr;

mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp

Lines changed: 2 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -85,17 +85,6 @@ static void emitEnumClass(EnumInfo enumInfo, raw_ostream &os) {
8585
os << "\n";
8686
}
8787

88-
/// Attempts to extract the bitwidth B from string "uintB_t" describing the
89-
/// type. This bitwidth information is not readily available in ODS. Returns
90-
/// `false` on success, `true` on failure.
91-
static bool extractUIntBitwidth(StringRef uintType, int64_t &bitwidth) {
92-
if (!uintType.consume_front("uint"))
93-
return true;
94-
if (!uintType.consume_back("_t"))
95-
return true;
96-
return uintType.getAsInteger(/*Radix=*/10, bitwidth);
97-
}
98-
9988
/// Emits an attribute builder for the given enum attribute to support automatic
10089
/// conversion between enum values and attributes in Python. Returns
10190
/// `false` on success, `true` on failure.
@@ -104,12 +93,7 @@ static bool emitAttributeBuilder(const EnumInfo &enumInfo, raw_ostream &os) {
10493
if (!enumAttrInfo)
10594
return false;
10695

107-
int64_t bitwidth;
108-
if (extractUIntBitwidth(enumInfo.getUnderlyingType(), bitwidth)) {
109-
llvm::errs() << "failed to identify bitwidth of "
110-
<< enumInfo.getUnderlyingType();
111-
return true;
112-
}
96+
int64_t bitwidth = enumInfo.getBitwidth();
11397
os << formatv("@register_attribute_builder(\"{0}\")\n",
11498
enumAttrInfo->getAttrDefName());
11599
os << formatv("def _{0}(x, context):\n",
@@ -140,7 +124,7 @@ static bool emitDialectEnumAttributeBuilder(StringRef attrDefName,
140124
static bool emitPythonEnums(const RecordKeeper &records, raw_ostream &os) {
141125
os << fileHeader;
142126
for (const Record *it :
143-
records.getAllDerivedDefinitionsIfDefined("EnumAttrInfo")) {
127+
records.getAllDerivedDefinitionsIfDefined("EnumInfo")) {
144128
EnumInfo enumInfo(*it);
145129
emitEnumClass(enumInfo, os);
146130
emitAttributeBuilder(enumInfo, os);

0 commit comments

Comments
 (0)