Skip to content

Commit 75a73d8

Browse files
committed
[MLIR][DLTI] Pretty parsing and printing for DLTI attrs
Unifies parsing and printing for DLTI attributes. Introduces syntax of `#dlti.attr<key1 = val1, ..., keyN = valN>` for all queryable DLTI attributes, while retaining support for specifying key-value entry pairs with `#dlti.dl_entry` (whether to retain this is TBD). As the new format does away with much of the boilerplate, it is much easier on the eye. This makes an especially big difference for nested attributes. Updates the DLTI tests and includes fixes for misc error checking/ error messages.
1 parent 0764e55 commit 75a73d8

File tree

12 files changed

+510
-400
lines changed

12 files changed

+510
-400
lines changed

mlir/include/mlir/Dialect/DLTI/DLTIAttrs.td

+28-22
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,10 @@ def DLTI_DataLayoutSpecAttr :
9393
}];
9494
}
9595

96+
//===----------------------------------------------------------------------===//
97+
// MapAttr
98+
//===----------------------------------------------------------------------===//
99+
96100
def DLTI_MapAttr : DLTIAttr<"Map", [DLTIQueryInterface]> {
97101
let summary = "A mapping of DLTI-information by way of key-value pairs";
98102
let description = [{
@@ -106,18 +110,16 @@ def DLTI_MapAttr : DLTIAttr<"Map", [DLTIQueryInterface]> {
106110

107111
Consider the following flat encoding of a single-key dictionary
108112
```
109-
#dlti.map<#dlti.dl_entry<"CPU::cache::L1::size_in_bytes", 65536 : i32>>
113+
#dlti.map<"CPU::cache::L1::size_in_bytes" = 65536 : i32>>
110114
```
111115
versus nested maps, which make it possible to obtain sub-dictionaries of
112116
related information (with the following example making use of other
113117
attributes that also implement the `DLTIQueryInterface`):
114118
```
115-
#dlti.target_system_spec<"CPU":
116-
#dlti.target_device_spec<#dlti.dl_entry<"cache",
117-
#dlti.map<#dlti.dl_entry<"L1",
118-
#dlti.map<#dlti.dl_entry<"size_in_bytes", 65536 : i32>>>,
119-
#dlti.dl_entry<"L1d",
120-
#dlti.map<#dlti.dl_entry<"size_in_bytes", 32768 : i32>>> >>>>
119+
#dlti.target_system_spec<"CPU" =
120+
#dlti.target_device_spec<"cache" =
121+
#dlti.map<"L1" = #dlti.map<"size_in_bytes" = 65536 : i32>,
122+
"L1d" = #dlti.map<"size_in_bytes" = 32768 : i32> >>>
121123
```
122124

123125
With the flat encoding, the implied structure of the key is ignored, that is
@@ -139,7 +141,7 @@ def DLTI_MapAttr : DLTIAttr<"Map", [DLTIQueryInterface]> {
139141
);
140142
let mnemonic = "map";
141143
let genVerifyDecl = 1;
142-
let assemblyFormat = "`<` $entries `>`";
144+
let hasCustomAssemblyFormat = 1;
143145
let extraClassDeclaration = [{
144146
/// Returns the attribute associated with the key.
145147
FailureOr<Attribute> query(DataLayoutEntryKey key) {
@@ -167,20 +169,23 @@ def DLTI_TargetSystemSpecAttr :
167169
```
168170
dlti.target_system_spec =
169171
#dlti.target_system_spec<
170-
"CPU": #dlti.target_device_spec<
171-
#dlti.dl_entry<"dlti.L1_cache_size_in_bytes", 4096: ui32>>,
172-
"GPU": #dlti.target_device_spec<
173-
#dlti.dl_entry<"dlti.max_vector_op_width", 64 : ui32>>,
174-
"XPU": #dlti.target_device_spec<
175-
#dlti.dl_entry<"dlti.max_vector_op_width", 4096 : ui32>>>
172+
"CPU" = #dlti.target_device_spec<
173+
"L1_cache_size_in_bytes" = 4096: ui32>,
174+
"GPU" = #dlti.target_device_spec<
175+
"max_vector_op_width" = 64 : ui32>,
176+
"XPU" = #dlti.target_device_spec<
177+
"max_vector_op_width" = 4096 : ui32>>
176178
```
179+
180+
The verifier checks that keys are strings and pointed to values implement
181+
DLTI's TargetDeviceSpecInterface.
177182
}];
178183
let parameters = (ins
179-
ArrayRefParameter<"DeviceIDTargetDeviceSpecPair", "">:$entries
184+
ArrayRefParameter<"DataLayoutEntryInterface">:$entries
180185
);
181186
let mnemonic = "target_system_spec";
182187
let genVerifyDecl = 1;
183-
let assemblyFormat = "`<` $entries `>`";
188+
let hasCustomAssemblyFormat = 1;
184189
let extraClassDeclaration = [{
185190
/// Return the device specification that matches the given device ID
186191
std::optional<TargetDeviceSpecInterface>
@@ -197,8 +202,10 @@ def DLTI_TargetSystemSpecAttr :
197202
$cppClass::getDeviceSpecForDeviceID(
198203
TargetSystemSpecInterface::DeviceID deviceID) {
199204
for (const auto& entry : getEntries()) {
200-
if (entry.first == deviceID)
201-
return entry.second;
205+
if (entry.getKey() == DataLayoutEntryKey(deviceID))
206+
if (auto deviceSpec =
207+
llvm::dyn_cast<TargetDeviceSpecInterface>(entry.getValue()))
208+
return deviceSpec;
202209
}
203210
return std::nullopt;
204211
}
@@ -219,16 +226,15 @@ def DLTI_TargetDeviceSpecAttr :
219226

220227
Example:
221228
```
222-
#dlti.target_device_spec<
223-
#dlti.dl_entry<"dlti.max_vector_op_width", 64 : ui32>>
229+
#dlti.target_device_spec<"max_vector_op_width" = 64 : ui32>
224230
```
225231
}];
226232
let parameters = (ins
227-
ArrayRefParameter<"DataLayoutEntryInterface", "">:$entries
233+
ArrayRefParameter<"DataLayoutEntryInterface">:$entries
228234
);
229235
let mnemonic = "target_device_spec";
230236
let genVerifyDecl = 1;
231-
let assemblyFormat = "`<` $entries `>`";
237+
let hasCustomAssemblyFormat = 1;
232238

233239
let extraClassDeclaration = [{
234240
/// Returns the attribute associated with the key.

mlir/include/mlir/Interfaces/DataLayoutInterfaces.h

+2-4
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#ifndef MLIR_INTERFACES_DATALAYOUTINTERFACES_H
1616
#define MLIR_INTERFACES_DATALAYOUTINTERFACES_H
1717

18+
#include "mlir/IR/Attributes.h"
1819
#include "mlir/IR/DialectInterface.h"
1920
#include "mlir/IR/OpDefinition.h"
2021
#include "llvm/ADT/DenseMap.h"
@@ -32,10 +33,7 @@ using DataLayoutEntryKey = llvm::PointerUnion<Type, StringAttr>;
3233
using DataLayoutEntryList = llvm::SmallVector<DataLayoutEntryInterface, 4>;
3334
using DataLayoutEntryListRef = llvm::ArrayRef<DataLayoutEntryInterface>;
3435
using TargetDeviceSpecListRef = llvm::ArrayRef<TargetDeviceSpecInterface>;
35-
using DeviceIDTargetDeviceSpecPair =
36-
std::pair<StringAttr, TargetDeviceSpecInterface>;
37-
using DeviceIDTargetDeviceSpecPairListRef =
38-
llvm::ArrayRef<DeviceIDTargetDeviceSpecPair>;
36+
using TargetDeviceSpecEntry = std::pair<StringAttr, TargetDeviceSpecInterface>;
3937
class DataLayoutOpInterface;
4038
class DataLayoutSpecInterface;
4139
class ModuleOp;

mlir/include/mlir/Interfaces/DataLayoutInterfaces.td

+1-1
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,7 @@ def TargetSystemSpecInterface : AttrInterface<"TargetSystemSpecInterface", [DLTI
304304
let methods = [
305305
InterfaceMethod<
306306
/*description=*/"Returns the list of layout entries.",
307-
/*retTy=*/"llvm::ArrayRef<DeviceIDTargetDeviceSpecPair>",
307+
/*retTy=*/"llvm::ArrayRef<DataLayoutEntryInterface>",
308308
/*methodName=*/"getEntries",
309309
/*args=*/(ins)
310310
>,

0 commit comments

Comments
 (0)