diff --git a/mlir/docs/DefiningDialects/Operations.md b/mlir/docs/DefiningDialects/Operations.md index fafda816a3881..88d58e0a1efbf 100644 --- a/mlir/docs/DefiningDialects/Operations.md +++ b/mlir/docs/DefiningDialects/Operations.md @@ -1756,6 +1756,23 @@ that it has a value within the valid range of the enum. If their wrapper attribute instead of using a bare signless integer attribute for storage. +### Enum properties + +Enums can be wrapped in properties so that they can be stored inline. +This causes a value of the enum's C++ class to become a member of the operation's +property struct and for the operation's verifier to check that the enum's value +is a valid value for the enum. + +The basic wrapper is `EnumProp`, which simply takes an `EnumInfo`. + +A less ambiguous syntax, namely putting a mnemonic and `<>`s surrounding +the enum is generated with `NamedEnumProp`, which takes a `*EnumInfo` +and a mnemonic string, which becomes part of the property's syntax. + +Both of these `EnumProp` types have a `*EnumPropWithAttrForm`, which allows for +transparently upgrading from `EnumAttr`s and optionally retaining those +attributes in the generic form. + ## Debugging Tips ### Run `mlir-tblgen` to see the generated content diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td index a9de787806452..34a30a00790ea 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td @@ -485,17 +485,16 @@ def DISubprogramFlags : I32BitEnumAttr< // IntegerOverflowFlags //===----------------------------------------------------------------------===// -def IOFnone : I32BitEnumAttrCaseNone<"none">; -def IOFnsw : I32BitEnumAttrCaseBit<"nsw", 0>; -def IOFnuw : I32BitEnumAttrCaseBit<"nuw", 1>; +def IOFnone : I32BitEnumCaseNone<"none">; +def IOFnsw : I32BitEnumCaseBit<"nsw", 0>; +def IOFnuw : I32BitEnumCaseBit<"nuw", 1>; -def IntegerOverflowFlags : I32BitEnumAttr< +def IntegerOverflowFlags : I32BitEnum< "IntegerOverflowFlags", "LLVM integer overflow flags", [IOFnone, IOFnsw, IOFnuw]> { let separator = ", "; let cppNamespace = "::mlir::LLVM"; - let genSpecializedAttr = 0; let printBitEnumPrimaryGroups = 1; } @@ -504,6 +503,11 @@ def LLVM_IntegerOverflowFlagsAttr : let assemblyFormat = "`<` $value `>`"; } +def LLVM_IntegerOverflowFlagsProp : + NamedEnumPropWithAttrForm { + let defaultValue = enum.cppType # "::" # "none"; +} + //===----------------------------------------------------------------------===// // FastmathFlags //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index b107b64e55b46..73ccef94f122f 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -60,7 +60,7 @@ class LLVM_IntArithmeticOpWithOverflowFlag traits = []> : LLVM_ArithmeticOpBase], traits)> { - dag iofArg = (ins EnumProp<"IntegerOverflowFlags", "", "IntegerOverflowFlags::none">:$overflowFlags); + dag iofArg = (ins LLVM_IntegerOverflowFlagsProp:$overflowFlags); let arguments = !con(commonArgs, iofArg); string mlirBuilder = [{ @@ -69,7 +69,7 @@ class LLVM_IntArithmeticOpWithOverflowFlag($overflowFlags) attr-dict `:` type($res) + $lhs `,` $rhs ($overflowFlags^)? attr-dict `:` type($res) }]; string llvmBuilder = "$res = builder.Create" # instName # @@ -563,10 +563,10 @@ class LLVM_CastOpWithOverflowFlag traits = []> : LLVM_Op], traits)>, LLVM_Builder<"$res = builder.Create" # instName # "($arg, $_resultType, /*Name=*/\"\", op.hasNoUnsignedWrap(), op.hasNoSignedWrap());"> { - let arguments = (ins type:$arg, EnumProp<"IntegerOverflowFlags", "", "IntegerOverflowFlags::none">:$overflowFlags); + let arguments = (ins type:$arg, LLVM_IntegerOverflowFlagsProp:$overflowFlags); let results = (outs resultType:$res); let builders = [LLVM_OneResultOpBuilder]; - let assemblyFormat = "$arg `` custom($overflowFlags) attr-dict `:` type($arg) `to` type($res)"; + let assemblyFormat = "$arg ($overflowFlags^)? attr-dict `:` type($arg) `to` type($res)"; string llvmInstName = instName; string mlirBuilder = [{ auto op = $_builder.create<$_qualCppClassName>( diff --git a/mlir/include/mlir/IR/EnumAttr.td b/mlir/include/mlir/IR/EnumAttr.td index 931126a155fbb..3f7f747ac20d3 100644 --- a/mlir/include/mlir/IR/EnumAttr.td +++ b/mlir/include/mlir/IR/EnumAttr.td @@ -10,6 +10,7 @@ #define ENUMATTR_TD include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/Properties.td" //===----------------------------------------------------------------------===// // Enum attribute kinds @@ -552,6 +553,141 @@ class EnumAttr : Property { + EnumInfo enum = enumInfo; + + let description = enum.description; + let predicate = !if( + !isa(enum), + CPred<"(static_cast<" # enum.underlyingType # ">($_self) & ~" # !cast(enum).validBits # ") == 0">, + Or)>); + + let convertFromAttribute = [{ + auto intAttr = ::mlir::dyn_cast_if_present<::mlir::IntegerAttr>($_attr); + if (!intAttr) { + return $_diag() << "expected IntegerAttr storage for }] # + enum.cppType # [{"; + } + $_storage = static_cast<}] # enum.cppType # [{>(intAttr.getValue().getZExtValue()); + return ::mlir::success(); + }]; + + let convertToAttribute = [{ + return ::mlir::IntegerAttr::get(::mlir::IntegerType::get($_ctxt, }] # enum.bitwidth + # [{), static_cast<}] # enum.underlyingType #[{>($_storage)); + }]; + + let writeToMlirBytecode = [{ + $_writer.writeVarInt(static_cast($_storage)); + }]; + + let readFromMlirBytecode = [{ + uint64_t rawValue; + if (::mlir::failed($_reader.readVarInt(rawValue))) + return ::mlir::failure(); + if (rawValue > std::numeric_limits<}] # enum.underlyingType # [{>::max()) + return ::mlir::failure(); + $_storage = static_cast<}] # enum.cppType # [{>(rawValue); + }]; + + let optionalParser = [{ + auto value = ::mlir::FieldParser>::parse($_parser); + if (::mlir::failed(value)) + return ::mlir::failure(); + if (!(value->has_value())) + return std::nullopt; + $_storage = std::move(**value); + }]; +} + +// Enum property that can have been (or, if `storeInCustomAttribute` is true, will also +// be stored as) an attribute, in addition to being stored as an integer attribute. +class EnumPropWithAttrForm + : EnumProp { + Attr attrForm = attributeForm; + bit storeInCustomAttribute = 0; + + let convertFromAttribute = [{ + auto customAttr = ::mlir::dyn_cast_if_present<}] + # attrForm.storageType # [{>($_attr); + if (customAttr) { + $_storage = customAttr.getValue(); + return ::mlir::success(); + } + auto intAttr = ::mlir::dyn_cast_if_present<::mlir::IntegerAttr>($_attr); + if (!intAttr) { + return $_diag() << "expected }] # attrForm.storageType + # [{ or IntegerAttr storage for }] # enum.cppType # [{"; + } + $_storage = static_cast<}] # enum.cppType # [{>(intAttr.getValue().getZExtValue()); + return ::mlir::success(); + }]; + + let convertToAttribute = !if(storeInCustomAttribute, [{ + return }] # attrForm.storageType # [{::get($_ctxt, $_storage); + }], [{ + return ::mlir::IntegerAttr::get(::mlir::IntegerType::get($_ctxt, }] # enumInfo.bitwidth + # [{), static_cast<}] # enum.underlyingType #[{>($_storage)); + }]); +} + +class _namedEnumPropFields { + code parser = [{ + if ($_parser.parseKeyword("}] # mnemonic # [{") + || $_parser.parseLess()) { + return ::mlir::failure(); + } + auto parseRes = ::mlir::FieldParser<}] # cppType # [{>::parse($_parser); + if (::mlir::failed(parseRes) || + ::mlir::failed($_parser.parseGreater())) { + return ::mlir::failure(); + } + $_storage = *parseRes; + }]; + + code optionalParser = [{ + if ($_parser.parseOptionalKeyword("}] # mnemonic # [{")) { + return std::nullopt; + } + if ($_parser.parseLess()) { + return ::mlir::failure(); + } + auto parseRes = ::mlir::FieldParser<}] # cppType # [{>::parse($_parser); + if (::mlir::failed(parseRes) || + ::mlir::failed($_parser.parseGreater())) { + return ::mlir::failure(); + } + $_storage = *parseRes; + }]; + + code printer = [{ + $_printer << "}] # mnemonic # [{<" << $_storage << ">"; + }]; +} + +// An EnumProp which, when printed, is surrounded by mnemonic<>. +// For example, if the enum can be a, b, or c, and the mnemonic is foo, +// the format of this property will be "foo", "foo", or "foo". +class NamedEnumProp + : EnumProp { + string mnemonic = name; + let parser = _namedEnumPropFields.parser; + let optionalParser = _namedEnumPropFields.optionalParser; + let printer = _namedEnumPropFields.printer; +} + +// A `NamedEnumProp` with an attribute form as in `EnumPropWithAttrForm`. +class NamedEnumPropWithAttrForm + : EnumPropWithAttrForm { + string mnemonic = name; + let parser = _namedEnumPropFields.parser; + let optionalParser = _namedEnumPropFields.optionalParser; + let printer = _namedEnumPropFields.printer; +} + class _symbolToValue { defvar cases = !filter(iter, enumInfo.enumerants, !eq(iter.str, case)); diff --git a/mlir/include/mlir/IR/Properties.td b/mlir/include/mlir/IR/Properties.td index 8bd8343790402..739df03c7ef2e 100644 --- a/mlir/include/mlir/IR/Properties.td +++ b/mlir/include/mlir/IR/Properties.td @@ -239,25 +239,6 @@ def I64Prop : IntProp<"int64_t">; def I32Property : IntProp<"int32_t">, Deprecated<"moved to shorter name I32Prop">; def I64Property : IntProp<"int64_t">, Deprecated<"moved to shorter name I64Prop">; -class EnumProp : - Property { - // TODO: implement predicate for enum validity. - let writeToMlirBytecode = [{ - $_writer.writeVarInt(static_cast($_storage)); - }]; - let readFromMlirBytecode = [{ - uint64_t val; - if (failed($_reader.readVarInt(val))) - return ::mlir::failure(); - $_storage = static_cast<}] # storageTypeParam # [{>(val); - }]; - let defaultValue = default; -} - -class EnumProperty : - EnumProp, - Deprecated<"moved to shorter name EnumProp">; - // Note: only a class so we can deprecate the old name class _cls_StringProp : Property<"std::string", "string"> { let interfaceType = "::llvm::StringRef"; diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index 78eb4c9b3481f..c42f906369a13 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -49,71 +49,6 @@ using mlir::LLVM::tailcallkind::getMaxEnumValForTailCallKind; #include "mlir/Dialect/LLVMIR/LLVMOpsDialect.cpp.inc" -//===----------------------------------------------------------------------===// -// Property Helpers -//===----------------------------------------------------------------------===// - -//===----------------------------------------------------------------------===// -// IntegerOverflowFlags -//===----------------------------------------------------------------------===// - -namespace mlir { -static Attribute convertToAttribute(MLIRContext *ctx, - IntegerOverflowFlags flags) { - return IntegerOverflowFlagsAttr::get(ctx, flags); -} - -static LogicalResult -convertFromAttribute(IntegerOverflowFlags &flags, Attribute attr, - function_ref emitError) { - auto flagsAttr = dyn_cast(attr); - if (!flagsAttr) { - return emitError() << "expected 'overflowFlags' attribute to be an " - "IntegerOverflowFlagsAttr, but got " - << attr; - } - flags = flagsAttr.getValue(); - return success(); -} -} // namespace mlir - -static ParseResult parseOverflowFlags(AsmParser &p, - IntegerOverflowFlags &flags) { - if (failed(p.parseOptionalKeyword("overflow"))) { - flags = IntegerOverflowFlags::none; - return success(); - } - if (p.parseLess()) - return failure(); - do { - StringRef kw; - SMLoc loc = p.getCurrentLocation(); - if (p.parseKeyword(&kw)) - return failure(); - std::optional flag = - symbolizeIntegerOverflowFlags(kw); - if (!flag) - return p.emitError(loc, - "invalid overflow flag: expected nsw, nuw, or none"); - flags = flags | *flag; - } while (succeeded(p.parseOptionalComma())); - return p.parseGreater(); -} - -static void printOverflowFlags(AsmPrinter &p, Operation *op, - IntegerOverflowFlags flags) { - if (flags == IntegerOverflowFlags::none) - return; - p << " overflow<"; - SmallVector strs; - if (bitEnumContainsAny(flags, IntegerOverflowFlags::nsw)) - strs.push_back("nsw"); - if (bitEnumContainsAny(flags, IntegerOverflowFlags::nuw)) - strs.push_back("nuw"); - llvm::interleaveComma(strs, p); - p << ">"; -} - //===----------------------------------------------------------------------===// // Attribute Helpers //===----------------------------------------------------------------------===// diff --git a/mlir/test/IR/enum-attr-invalid.mlir b/mlir/test/IR/enum-attr-invalid.mlir index 923736f28dadb..2f240a56c9874 100644 --- a/mlir/test/IR/enum-attr-invalid.mlir +++ b/mlir/test/IR/enum-attr-invalid.mlir @@ -28,3 +28,78 @@ func.func @test_parse_invalid_attr() -> () { // expected-error@+1 {{failed to parse TestEnumAttr parameter 'value'}} test.op_with_enum 1 : index } + +// ----- + +func.func @test_non_keyword_prop_enum() -> () { + // expected-error@+2 {{expected keyword for a test enum}} + // expected-error@+1 {{invalid value for property value, expected a test enum}} + test.op_with_enum_prop 0 + return +} + +// ----- + +func.func @test_wrong_keyword_prop_enum() -> () { + // expected-error@+2 {{expected one of [first, second, third] for a test enum, got: fourth}} + // expected-error@+1 {{invalid value for property value, expected a test enum}} + test.op_with_enum_prop fourth +} + +// ----- + +func.func @test_bad_integer() -> () { + // expected-error@+1 {{op property 'value' failed to satisfy constraint: a test enum}} + "test.op_with_enum_prop"() <{value = 4 : i32}> {} : () -> () +} + +// ----- + +func.func @test_bit_enum_prop_not_keyword() -> () { + // expected-error@+2 {{expected keyword for a test bit enum}} + // expected-error@+1 {{invalid value for property value1, expected a test bit enum}} + test.op_with_bit_enum_prop 0 + return +} + +// ----- + +func.func @test_bit_enum_prop_wrong_keyword() -> () { + // expected-error@+2 {{expected one of [read, write, execute] for a test bit enum, got: chroot}} + // expected-error@+1 {{invalid value for property value1, expected a test bit enum}} + test.op_with_bit_enum_prop read, chroot : () + return +} + +// ----- + +func.func @test_bit_enum_prop_bad_value() -> () { + // expected-error@+1 {{op property 'value2' failed to satisfy constraint: a test bit enum}} + "test.op_with_bit_enum_prop"() <{value1 = 7 : i32, value2 = 8 : i32}> {} : () -> () + return +} + +// ----- + +func.func @test_bit_enum_prop_named_wrong_keyword() -> () { + // expected-error@+2 {{expected 'bit_enum'}} + // expected-error@+1 {{invalid value for property value1, expected a test bit enum}} + test.op_with_bit_enum_prop_named foo + return +} + +// ----- + +func.func @test_bit_enum_prop_named_not_open() -> () { + // expected-error@+2 {{expected '<'}} + // expected-error@+1 {{invalid value for property value1, expected a test bit enum}} + test.op_with_bit_enum_prop_named bit_enum read, execute> +} + +// ----- + +func.func @test_bit_enum_prop_named_not_closed() -> () { + // expected-error@+2 {{expected '>'}} + // expected-error@+1 {{invalid value for property value1, expected a test bit enum}} + test.op_with_bit_enum_prop_named bit_enum () { test.op_with_bit_enum tag 0 : i32 return } + +// CHECK-LABEL: @test_enum_prop +func.func @test_enum_prop() -> () { + // CHECK: test.op_with_enum_prop first + test.op_with_enum_prop first + + // CHECK: test.op_with_enum_prop first + "test.op_with_enum_prop"() <{value = 0 : i32}> {} : () -> () + + // CHECK: test.op_with_enum_prop_attr_form <{value = 0 : i32}> + test.op_with_enum_prop_attr_form <{value = 0 : i32}> + // CHECK: test.op_with_enum_prop_attr_form <{value = 1 : i32}> + test.op_with_enum_prop_attr_form <{value = #test}> + + // CHECK: test.op_with_enum_prop_attr_form_always <{value = #test}> + test.op_with_enum_prop_attr_form_always <{value = #test}> + // CHECK: test.op_with_enum_prop_attr_form_always <{value = #test} + test.op_with_enum_prop_attr_form_always <{value = #test}> + + return +} + +// CHECK-LABEL @test_bit_enum_prop() +func.func @test_bit_enum_prop() -> () { + // CHECK: test.op_with_bit_enum_prop read : () + test.op_with_bit_enum_prop read read : () + + // CHECK: test.op_with_bit_enum_prop read, write write, execute + test.op_with_bit_enum_prop read, write write, execute : () + + // CHECK: test.op_with_bit_enum_prop read, execute write + "test.op_with_bit_enum_prop"() <{value1 = 5 : i32, value2 = 2 : i32}> {} : () -> () + + // CHECK: test.op_with_bit_enum_prop read, write, execute + test.op_with_bit_enum_prop read, write, execute : () + + // CHECK: test.op_with_bit_enum_prop_named bit_enum{{$}} + test.op_with_bit_enum_prop_named bit_enum bit_enum + // CHECK: test.op_with_bit_enum_prop_named bit_enum bit_enum + test.op_with_bit_enum_prop_named bit_enum bit_enum + // CHECK: test.op_with_bit_enum_prop_named bit_enum + test.op_with_bit_enum_prop_named bit_enum + + return +} diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index 31be00ace1384..85a49e05d4c73 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -423,6 +423,52 @@ def : Pat<(OpWithEnum ConstantEnumCase:$value, (OpWithEnum ConstantEnumCase, ConstantAttr)>; +//===----------------------------------------------------------------------===// +// Test Enum Properties +//===----------------------------------------------------------------------===// + +// Define the enum property. +def TestEnumProp : EnumProp; +// Define an op that contains the enum property. +def OpWithEnumProp : TEST_Op<"op_with_enum_prop"> { + let arguments = (ins TestEnumProp:$value); + let assemblyFormat = "$value attr-dict"; +} + +def TestEnumPropAttrForm : EnumPropWithAttrForm; +def OpWithEnumPropAttrForm : TEST_Op<"op_with_enum_prop_attr_form"> { + let arguments = (ins TestEnumPropAttrForm:$value); + let assemblyFormat = "prop-dict attr-dict"; +} + +def TestEnumPropAttrFormAlways : EnumPropWithAttrForm { + let storeInCustomAttribute = 1; +} +def OpWithEnumPropAttrFormAlways : TEST_Op<"op_with_enum_prop_attr_form_always"> { + let arguments = (ins TestEnumPropAttrFormAlways:$value); + let assemblyFormat = "prop-dict attr-dict"; +} + +def TestBitEnumProp : EnumProp { + let defaultValue = TestBitEnum.cppType # "::Read"; +} +def OpWithTestBitEnum : TEST_Op<"op_with_bit_enum_prop"> { + let arguments = (ins + TestBitEnumProp:$value1, + TestBitEnumProp:$value2); + let assemblyFormat = "$value1 ($value2^)? attr-dict `:` `(``)`"; +} + +def TestBitEnumPropNamed : NamedEnumProp { + let defaultValue = TestBitEnum.cppType # "::Read"; +} +def OpWithBitEnumPropNamed : TEST_Op<"op_with_bit_enum_prop_named"> { + let arguments = (ins + TestBitEnumPropNamed:$value1, + TestBitEnumPropNamed:$value2); + let assemblyFormat = "$value1 ($value2^)? attr-dict"; +} + //===----------------------------------------------------------------------===// // Test Bit Enum Attributes //===----------------------------------------------------------------------===//