Skip to content

[mlir] Improve EnumProp, making it take an EnumInfo #132349

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions mlir/docs/DefiningDialects/Operations.md
Original file line number Diff line number Diff line change
Expand Up @@ -1756,6 +1756,23 @@ that it has a value within the valid range of the enum. If their
wrapper attribute instead of using a bare signless integer attribute
for storage.

### Enum properties

Enums can be wrapped in properties so that they can be stored inline.
This causes a value of the enum's C++ class to become a member of the operation's
property struct and for the operation's verifier to check that the enum's value
is a valid value for the enum.

The basic wrapper is `EnumProp`, which simply takes an `EnumInfo`.

A less ambiguous syntax, namely putting a mnemonic and `<>`s surrounding
the enum is generated with `NamedEnumProp`, which takes a `*EnumInfo`
and a mnemonic string, which becomes part of the property's syntax.

Both of these `EnumProp` types have a `*EnumPropWithAttrForm`, which allows for
transparently upgrading from `EnumAttr`s and optionally retaining those
attributes in the generic form.

## Debugging Tips

### Run `mlir-tblgen` to see the generated content
Expand Down
14 changes: 9 additions & 5 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td
Original file line number Diff line number Diff line change
Expand Up @@ -485,17 +485,16 @@ def DISubprogramFlags : I32BitEnumAttr<
// IntegerOverflowFlags
//===----------------------------------------------------------------------===//

def IOFnone : I32BitEnumAttrCaseNone<"none">;
def IOFnsw : I32BitEnumAttrCaseBit<"nsw", 0>;
def IOFnuw : I32BitEnumAttrCaseBit<"nuw", 1>;
def IOFnone : I32BitEnumCaseNone<"none">;
def IOFnsw : I32BitEnumCaseBit<"nsw", 0>;
def IOFnuw : I32BitEnumCaseBit<"nuw", 1>;

def IntegerOverflowFlags : I32BitEnumAttr<
def IntegerOverflowFlags : I32BitEnum<
"IntegerOverflowFlags",
"LLVM integer overflow flags",
[IOFnone, IOFnsw, IOFnuw]> {
let separator = ", ";
let cppNamespace = "::mlir::LLVM";
let genSpecializedAttr = 0;
let printBitEnumPrimaryGroups = 1;
}

Expand All @@ -504,6 +503,11 @@ def LLVM_IntegerOverflowFlagsAttr :
let assemblyFormat = "`<` $value `>`";
}

def LLVM_IntegerOverflowFlagsProp :
NamedEnumPropWithAttrForm<IntegerOverflowFlags, "overflow", LLVM_IntegerOverflowFlagsAttr> {
let defaultValue = enum.cppType # "::" # "none";
}

//===----------------------------------------------------------------------===//
// FastmathFlags
//===----------------------------------------------------------------------===//
Expand Down
8 changes: 4 additions & 4 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class LLVM_IntArithmeticOpWithOverflowFlag<string mnemonic, string instName,
list<Trait> traits = []> :
LLVM_ArithmeticOpBase<AnySignlessInteger, mnemonic, instName,
!listconcat([DeclareOpInterfaceMethods<IntegerOverflowFlagsInterface>], traits)> {
dag iofArg = (ins EnumProp<"IntegerOverflowFlags", "", "IntegerOverflowFlags::none">:$overflowFlags);
dag iofArg = (ins LLVM_IntegerOverflowFlagsProp:$overflowFlags);
let arguments = !con(commonArgs, iofArg);

string mlirBuilder = [{
Expand All @@ -69,7 +69,7 @@ class LLVM_IntArithmeticOpWithOverflowFlag<string mnemonic, string instName,
$res = op;
}];
let assemblyFormat = [{
$lhs `,` $rhs `` custom<OverflowFlags>($overflowFlags) attr-dict `:` type($res)
$lhs `,` $rhs ($overflowFlags^)? attr-dict `:` type($res)
}];
string llvmBuilder =
"$res = builder.Create" # instName #
Expand Down Expand Up @@ -563,10 +563,10 @@ class LLVM_CastOpWithOverflowFlag<string mnemonic, string instName, Type type,
Type resultType, list<Trait> traits = []> :
LLVM_Op<mnemonic, !listconcat([Pure], [DeclareOpInterfaceMethods<IntegerOverflowFlagsInterface>], traits)>,
LLVM_Builder<"$res = builder.Create" # instName # "($arg, $_resultType, /*Name=*/\"\", op.hasNoUnsignedWrap(), op.hasNoSignedWrap());"> {
let arguments = (ins type:$arg, EnumProp<"IntegerOverflowFlags", "", "IntegerOverflowFlags::none">:$overflowFlags);
let arguments = (ins type:$arg, LLVM_IntegerOverflowFlagsProp:$overflowFlags);
let results = (outs resultType:$res);
let builders = [LLVM_OneResultOpBuilder];
let assemblyFormat = "$arg `` custom<OverflowFlags>($overflowFlags) attr-dict `:` type($arg) `to` type($res)";
let assemblyFormat = "$arg ($overflowFlags^)? attr-dict `:` type($arg) `to` type($res)";
string llvmInstName = instName;
string mlirBuilder = [{
auto op = $_builder.create<$_qualCppClassName>(
Expand Down
136 changes: 136 additions & 0 deletions mlir/include/mlir/IR/EnumAttr.td
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#define ENUMATTR_TD

include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/Properties.td"

//===----------------------------------------------------------------------===//
// Enum attribute kinds
Expand Down Expand Up @@ -552,6 +553,141 @@ class EnumAttr<Dialect dialect, EnumInfo enumInfo, string name = "",
let assemblyFormat = "$value";
}

// A property wrapping by a C++ enum. This class will automatically create bytecode
// serialization logic for the given enum, as well as arranging for parser and
// printer calls.
class EnumProp<EnumInfo enumInfo> : Property<enumInfo.cppType, enumInfo.summary> {
EnumInfo enum = enumInfo;

let description = enum.description;
let predicate = !if(
!isa<BitEnumBase>(enum),
CPred<"(static_cast<" # enum.underlyingType # ">($_self) & ~" # !cast<BitEnumBase>(enum).validBits # ") == 0">,
Or<!foreach(case, enum.enumerants, CPred<"$_self == " # enum.cppType # "::" # case.symbol>)>);

let convertFromAttribute = [{
auto intAttr = ::mlir::dyn_cast_if_present<::mlir::IntegerAttr>($_attr);
if (!intAttr) {
return $_diag() << "expected IntegerAttr storage for }] #
enum.cppType # [{";
}
$_storage = static_cast<}] # enum.cppType # [{>(intAttr.getValue().getZExtValue());
return ::mlir::success();
}];

let convertToAttribute = [{
return ::mlir::IntegerAttr::get(::mlir::IntegerType::get($_ctxt, }] # enum.bitwidth
# [{), static_cast<}] # enum.underlyingType #[{>($_storage));
}];

let writeToMlirBytecode = [{
$_writer.writeVarInt(static_cast<uint64_t>($_storage));
}];

let readFromMlirBytecode = [{
uint64_t rawValue;
if (::mlir::failed($_reader.readVarInt(rawValue)))
return ::mlir::failure();
if (rawValue > std::numeric_limits<}] # enum.underlyingType # [{>::max())
return ::mlir::failure();
$_storage = static_cast<}] # enum.cppType # [{>(rawValue);
Copy link
Collaborator

@joker-eph joker-eph Apr 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we check that the read bits don't overflow the storage?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The static cast will truncate if needed, no?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My point is that this should be a failure instead of silently discarding some bytecode inputs.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I can see that - will do

}];

let optionalParser = [{
auto value = ::mlir::FieldParser<std::optional<}] # enum.cppType # [{>>::parse($_parser);
if (::mlir::failed(value))
return ::mlir::failure();
if (!(value->has_value()))
return std::nullopt;
$_storage = std::move(**value);
}];
}

// Enum property that can have been (or, if `storeInCustomAttribute` is true, will also
// be stored as) an attribute, in addition to being stored as an integer attribute.
class EnumPropWithAttrForm<EnumInfo enumInfo, Attr attributeForm>
: EnumProp<enumInfo> {
Attr attrForm = attributeForm;
bit storeInCustomAttribute = 0;

let convertFromAttribute = [{
auto customAttr = ::mlir::dyn_cast_if_present<}]
# attrForm.storageType # [{>($_attr);
if (customAttr) {
$_storage = customAttr.getValue();
return ::mlir::success();
}
auto intAttr = ::mlir::dyn_cast_if_present<::mlir::IntegerAttr>($_attr);
if (!intAttr) {
return $_diag() << "expected }] # attrForm.storageType
# [{ or IntegerAttr storage for }] # enum.cppType # [{";
}
$_storage = static_cast<}] # enum.cppType # [{>(intAttr.getValue().getZExtValue());
return ::mlir::success();
}];

let convertToAttribute = !if(storeInCustomAttribute, [{
return }] # attrForm.storageType # [{::get($_ctxt, $_storage);
}], [{
return ::mlir::IntegerAttr::get(::mlir::IntegerType::get($_ctxt, }] # enumInfo.bitwidth
# [{), static_cast<}] # enum.underlyingType #[{>($_storage));
}]);
}

class _namedEnumPropFields<string cppType, string mnemonic> {
code parser = [{
if ($_parser.parseKeyword("}] # mnemonic # [{")
|| $_parser.parseLess()) {
return ::mlir::failure();
}
auto parseRes = ::mlir::FieldParser<}] # cppType # [{>::parse($_parser);
if (::mlir::failed(parseRes) ||
::mlir::failed($_parser.parseGreater())) {
return ::mlir::failure();
}
$_storage = *parseRes;
}];

code optionalParser = [{
if ($_parser.parseOptionalKeyword("}] # mnemonic # [{")) {
return std::nullopt;
}
if ($_parser.parseLess()) {
return ::mlir::failure();
}
auto parseRes = ::mlir::FieldParser<}] # cppType # [{>::parse($_parser);
if (::mlir::failed(parseRes) ||
::mlir::failed($_parser.parseGreater())) {
return ::mlir::failure();
}
$_storage = *parseRes;
}];

code printer = [{
$_printer << "}] # mnemonic # [{<" << $_storage << ">";
}];
}

// An EnumProp which, when printed, is surrounded by mnemonic<>.
// For example, if the enum can be a, b, or c, and the mnemonic is foo,
// the format of this property will be "foo<a>", "foo<b>", or "foo<c>".
class NamedEnumProp<EnumInfo enumInfo, string name>
: EnumProp<enumInfo> {
string mnemonic = name;
let parser = _namedEnumPropFields<enum.cppType, mnemonic>.parser;
let optionalParser = _namedEnumPropFields<enum.cppType, mnemonic>.optionalParser;
let printer = _namedEnumPropFields<enum.cppType, mnemonic>.printer;
}

// A `NamedEnumProp` with an attribute form as in `EnumPropWithAttrForm`.
class NamedEnumPropWithAttrForm<EnumInfo enumInfo, string name, Attr attributeForm>
: EnumPropWithAttrForm<enumInfo, attributeForm> {
string mnemonic = name;
let parser = _namedEnumPropFields<enum.cppType, mnemonic>.parser;
let optionalParser = _namedEnumPropFields<enum.cppType, mnemonic>.optionalParser;
let printer = _namedEnumPropFields<enumInfo.cppType, mnemonic>.printer;
}

class _symbolToValue<EnumInfo enumInfo, string case> {
defvar cases =
!filter(iter, enumInfo.enumerants, !eq(iter.str, case));
Expand Down
19 changes: 0 additions & 19 deletions mlir/include/mlir/IR/Properties.td
Original file line number Diff line number Diff line change
Expand Up @@ -239,25 +239,6 @@ def I64Prop : IntProp<"int64_t">;
def I32Property : IntProp<"int32_t">, Deprecated<"moved to shorter name I32Prop">;
def I64Property : IntProp<"int64_t">, Deprecated<"moved to shorter name I64Prop">;

class EnumProp<string storageTypeParam, string desc = "", string default = ""> :
Property<storageTypeParam, desc> {
// TODO: implement predicate for enum validity.
let writeToMlirBytecode = [{
$_writer.writeVarInt(static_cast<uint64_t>($_storage));
}];
let readFromMlirBytecode = [{
uint64_t val;
if (failed($_reader.readVarInt(val)))
return ::mlir::failure();
$_storage = static_cast<}] # storageTypeParam # [{>(val);
}];
let defaultValue = default;
}

class EnumProperty<string storageTypeParam, string desc = "", string default = ""> :
EnumProp<storageTypeParam, desc, default>,
Deprecated<"moved to shorter name EnumProp">;

// Note: only a class so we can deprecate the old name
class _cls_StringProp : Property<"std::string", "string"> {
let interfaceType = "::llvm::StringRef";
Expand Down
65 changes: 0 additions & 65 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,71 +49,6 @@ using mlir::LLVM::tailcallkind::getMaxEnumValForTailCallKind;

#include "mlir/Dialect/LLVMIR/LLVMOpsDialect.cpp.inc"

//===----------------------------------------------------------------------===//
// Property Helpers
//===----------------------------------------------------------------------===//

//===----------------------------------------------------------------------===//
// IntegerOverflowFlags
//===----------------------------------------------------------------------===//

namespace mlir {
static Attribute convertToAttribute(MLIRContext *ctx,
IntegerOverflowFlags flags) {
return IntegerOverflowFlagsAttr::get(ctx, flags);
}

static LogicalResult
convertFromAttribute(IntegerOverflowFlags &flags, Attribute attr,
function_ref<InFlightDiagnostic()> emitError) {
auto flagsAttr = dyn_cast<IntegerOverflowFlagsAttr>(attr);
if (!flagsAttr) {
return emitError() << "expected 'overflowFlags' attribute to be an "
"IntegerOverflowFlagsAttr, but got "
<< attr;
}
flags = flagsAttr.getValue();
return success();
}
} // namespace mlir

static ParseResult parseOverflowFlags(AsmParser &p,
IntegerOverflowFlags &flags) {
if (failed(p.parseOptionalKeyword("overflow"))) {
flags = IntegerOverflowFlags::none;
return success();
}
if (p.parseLess())
return failure();
do {
StringRef kw;
SMLoc loc = p.getCurrentLocation();
if (p.parseKeyword(&kw))
return failure();
std::optional<IntegerOverflowFlags> flag =
symbolizeIntegerOverflowFlags(kw);
if (!flag)
return p.emitError(loc,
"invalid overflow flag: expected nsw, nuw, or none");
flags = flags | *flag;
} while (succeeded(p.parseOptionalComma()));
return p.parseGreater();
}

static void printOverflowFlags(AsmPrinter &p, Operation *op,
IntegerOverflowFlags flags) {
if (flags == IntegerOverflowFlags::none)
return;
p << " overflow<";
SmallVector<StringRef, 2> strs;
if (bitEnumContainsAny(flags, IntegerOverflowFlags::nsw))
strs.push_back("nsw");
if (bitEnumContainsAny(flags, IntegerOverflowFlags::nuw))
strs.push_back("nuw");
llvm::interleaveComma(strs, p);
p << ">";
}

//===----------------------------------------------------------------------===//
// Attribute Helpers
//===----------------------------------------------------------------------===//
Expand Down
Loading
Loading