Skip to content

Commit 5c1752e

Browse files
authored
[MLIR][DLTI] Pretty parsing and printing for DLTI attrs (#113365)
Unifies parsing and printing for DLTI attributes. Introduces a format of `#dlti.attr<key1 = val1, ..., keyN = valN>` syntax for all queryable DLTI attributes similar to that of the DictionaryAttr, while retaining support for specifying key-value pairs with `#dlti.dl_entry` (whether to retain this is TBD). As the new format does away with most of the boilerplate, it is much easier to parse for humans. This makes an especially big difference for nested attributes. Updates the DLTI-using tests and includes fixes for misc error checking/ error messages.
1 parent 9cc2981 commit 5c1752e

File tree

14 files changed

+538
-413
lines changed

14 files changed

+538
-413
lines changed

flang/test/Fir/tco-default-datalayout.fir

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,6 @@ module {
77
// CHECK: module attributes {
88
// CHECK-SAME: dlti.dl_spec = #dlti.dl_spec<
99
// ...
10-
// CHECK-SAME: #dlti.dl_entry<i64, dense<[32, 64]> : vector<2xi64>>,
10+
// CHECK-SAME: i64 = dense<[32, 64]> : vector<2xi64>,
1111
// ...
1212
// CHECK-SAME: llvm.data_layout = ""

flang/test/Fir/tco-explicit-datalayout.fir

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,6 @@ module attributes {llvm.data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i6
88
// CHECK: module attributes {
99
// CHECK-SAME: dlti.dl_spec = #dlti.dl_spec<
1010
// ...
11-
// CHECK-SAME: #dlti.dl_entry<i64, dense<128> : vector<2xi64>>,
11+
// CHECK-SAME: i64 = dense<128> : vector<2xi64>,
1212
// ...
1313
// CHECK-SAME: llvm.data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:128-i128:128-f80:128-n8:16:32:64-S128"

mlir/include/mlir/Dialect/DLTI/DLTIAttrs.td

+31-26
Original file line numberDiff line numberDiff line change
@@ -88,11 +88,15 @@ def DLTI_DataLayoutSpecAttr :
8888

8989
/// Returns the attribute associated with the key.
9090
FailureOr<Attribute> query(DataLayoutEntryKey key) {
91-
return llvm::cast<mlir::DataLayoutSpecInterface>(*this).queryHelper(key);
91+
return ::llvm::cast<mlir::DataLayoutSpecInterface>(*this).queryHelper(key);
9292
}
9393
}];
9494
}
9595

96+
//===----------------------------------------------------------------------===//
97+
// MapAttr
98+
//===----------------------------------------------------------------------===//
99+
96100
def DLTI_MapAttr : DLTIAttr<"Map", [DLTIQueryInterface]> {
97101
let summary = "A mapping of DLTI-information by way of key-value pairs";
98102
let description = [{
@@ -106,18 +110,16 @@ def DLTI_MapAttr : DLTIAttr<"Map", [DLTIQueryInterface]> {
106110

107111
Consider the following flat encoding of a single-key dictionary
108112
```
109-
#dlti.map<#dlti.dl_entry<"CPU::cache::L1::size_in_bytes", 65536 : i32>>
113+
#dlti.map<"CPU::cache::L1::size_in_bytes" = 65536 : i32>>
110114
```
111115
versus nested maps, which make it possible to obtain sub-dictionaries of
112116
related information (with the following example making use of other
113117
attributes that also implement the `DLTIQueryInterface`):
114118
```
115-
#dlti.target_system_spec<"CPU":
116-
#dlti.target_device_spec<#dlti.dl_entry<"cache",
117-
#dlti.map<#dlti.dl_entry<"L1",
118-
#dlti.map<#dlti.dl_entry<"size_in_bytes", 65536 : i32>>>,
119-
#dlti.dl_entry<"L1d",
120-
#dlti.map<#dlti.dl_entry<"size_in_bytes", 32768 : i32>>> >>>>
119+
#dlti.target_system_spec<"CPU" =
120+
#dlti.target_device_spec<"cache" =
121+
#dlti.map<"L1" = #dlti.map<"size_in_bytes" = 65536 : i32>,
122+
"L1d" = #dlti.map<"size_in_bytes" = 32768 : i32> >>>
121123
```
122124

123125
With the flat encoding, the implied structure of the key is ignored, that is
@@ -132,14 +134,13 @@ def DLTI_MapAttr : DLTIAttr<"Map", [DLTIQueryInterface]> {
132134
`transform.dlti.query ["CPU","cache","L1","size_in_bytes"] at %op` gives
133135
back the first leaf value contained. To access the other leaf, we need to do
134136
`transform.dlti.query ["CPU","cache","L1d","size_in_bytes"] at %op`.
135-
```
136137
}];
137138
let parameters = (ins
138139
ArrayRefParameter<"DataLayoutEntryInterface", "">:$entries
139140
);
140141
let mnemonic = "map";
141142
let genVerifyDecl = 1;
142-
let assemblyFormat = "`<` $entries `>`";
143+
let hasCustomAssemblyFormat = 1;
143144
let extraClassDeclaration = [{
144145
/// Returns the attribute associated with the key.
145146
FailureOr<Attribute> query(DataLayoutEntryKey key) {
@@ -167,20 +168,23 @@ def DLTI_TargetSystemSpecAttr :
167168
```
168169
dlti.target_system_spec =
169170
#dlti.target_system_spec<
170-
"CPU": #dlti.target_device_spec<
171-
#dlti.dl_entry<"dlti.L1_cache_size_in_bytes", 4096: ui32>>,
172-
"GPU": #dlti.target_device_spec<
173-
#dlti.dl_entry<"dlti.max_vector_op_width", 64 : ui32>>,
174-
"XPU": #dlti.target_device_spec<
175-
#dlti.dl_entry<"dlti.max_vector_op_width", 4096 : ui32>>>
171+
"CPU" = #dlti.target_device_spec<
172+
"L1_cache_size_in_bytes" = 4096: ui32>,
173+
"GPU" = #dlti.target_device_spec<
174+
"max_vector_op_width" = 64 : ui32>,
175+
"XPU" = #dlti.target_device_spec<
176+
"max_vector_op_width" = 4096 : ui32>>
176177
```
178+
179+
The verifier checks that keys are strings and pointed to values implement
180+
DLTI's TargetDeviceSpecInterface.
177181
}];
178182
let parameters = (ins
179-
ArrayRefParameter<"DeviceIDTargetDeviceSpecPair", "">:$entries
183+
ArrayRefParameter<"DataLayoutEntryInterface">:$entries
180184
);
181185
let mnemonic = "target_system_spec";
182186
let genVerifyDecl = 1;
183-
let assemblyFormat = "`<` $entries `>`";
187+
let hasCustomAssemblyFormat = 1;
184188
let extraClassDeclaration = [{
185189
/// Return the device specification that matches the given device ID
186190
std::optional<TargetDeviceSpecInterface>
@@ -189,16 +193,18 @@ def DLTI_TargetSystemSpecAttr :
189193

190194
/// Returns the attribute associated with the key.
191195
FailureOr<Attribute> query(DataLayoutEntryKey key) const {
192-
return llvm::cast<mlir::TargetSystemSpecInterface>(*this).queryHelper(key);
196+
return ::llvm::cast<mlir::TargetSystemSpecInterface>(*this).queryHelper(key);
193197
}
194198
}];
195199
let extraClassDefinition = [{
196200
std::optional<TargetDeviceSpecInterface>
197201
$cppClass::getDeviceSpecForDeviceID(
198202
TargetSystemSpecInterface::DeviceID deviceID) {
199203
for (const auto& entry : getEntries()) {
200-
if (entry.first == deviceID)
201-
return entry.second;
204+
if (entry.getKey() == DataLayoutEntryKey(deviceID))
205+
if (auto deviceSpec =
206+
::llvm::dyn_cast<TargetDeviceSpecInterface>(entry.getValue()))
207+
return deviceSpec;
202208
}
203209
return std::nullopt;
204210
}
@@ -219,21 +225,20 @@ def DLTI_TargetDeviceSpecAttr :
219225

220226
Example:
221227
```
222-
#dlti.target_device_spec<
223-
#dlti.dl_entry<"dlti.max_vector_op_width", 64 : ui32>>
228+
#dlti.target_device_spec<"max_vector_op_width" = 64 : ui32>
224229
```
225230
}];
226231
let parameters = (ins
227-
ArrayRefParameter<"DataLayoutEntryInterface", "">:$entries
232+
ArrayRefParameter<"DataLayoutEntryInterface">:$entries
228233
);
229234
let mnemonic = "target_device_spec";
230235
let genVerifyDecl = 1;
231-
let assemblyFormat = "`<` $entries `>`";
236+
let hasCustomAssemblyFormat = 1;
232237

233238
let extraClassDeclaration = [{
234239
/// Returns the attribute associated with the key.
235240
FailureOr<Attribute> query(DataLayoutEntryKey key) const {
236-
return llvm::cast<mlir::TargetDeviceSpecInterface>(*this).queryHelper(key);
241+
return ::llvm::cast<mlir::TargetDeviceSpecInterface>(*this).queryHelper(key);
237242
}
238243
}];
239244
}

mlir/include/mlir/Interfaces/DataLayoutInterfaces.h

+2-4
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#ifndef MLIR_INTERFACES_DATALAYOUTINTERFACES_H
1616
#define MLIR_INTERFACES_DATALAYOUTINTERFACES_H
1717

18+
#include "mlir/IR/Attributes.h"
1819
#include "mlir/IR/DialectInterface.h"
1920
#include "mlir/IR/OpDefinition.h"
2021
#include "llvm/ADT/DenseMap.h"
@@ -32,10 +33,7 @@ using DataLayoutEntryKey = llvm::PointerUnion<Type, StringAttr>;
3233
using DataLayoutEntryList = llvm::SmallVector<DataLayoutEntryInterface, 4>;
3334
using DataLayoutEntryListRef = llvm::ArrayRef<DataLayoutEntryInterface>;
3435
using TargetDeviceSpecListRef = llvm::ArrayRef<TargetDeviceSpecInterface>;
35-
using DeviceIDTargetDeviceSpecPair =
36-
std::pair<StringAttr, TargetDeviceSpecInterface>;
37-
using DeviceIDTargetDeviceSpecPairListRef =
38-
llvm::ArrayRef<DeviceIDTargetDeviceSpecPair>;
36+
using TargetDeviceSpecEntry = std::pair<StringAttr, TargetDeviceSpecInterface>;
3937
class DataLayoutOpInterface;
4038
class DataLayoutSpecInterface;
4139
class ModuleOp;

mlir/include/mlir/Interfaces/DataLayoutInterfaces.td

+3-3
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,7 @@ def TargetDeviceSpecInterface : AttrInterface<"TargetDeviceSpecInterface", [DLTI
276276
/// Helper for default implementation of `DLTIQueryInterface`'s `query`.
277277
::mlir::FailureOr<::mlir::Attribute>
278278
queryHelper(::mlir::DataLayoutEntryKey key) const {
279-
if (auto strKey = llvm::dyn_cast<StringAttr>(key))
279+
if (auto strKey = ::llvm::dyn_cast<StringAttr>(key))
280280
if (DataLayoutEntryInterface spec = getSpecForIdentifier(strKey))
281281
return spec.getValue();
282282
return ::mlir::failure();
@@ -304,7 +304,7 @@ def TargetSystemSpecInterface : AttrInterface<"TargetSystemSpecInterface", [DLTI
304304
let methods = [
305305
InterfaceMethod<
306306
/*description=*/"Returns the list of layout entries.",
307-
/*retTy=*/"llvm::ArrayRef<DeviceIDTargetDeviceSpecPair>",
307+
/*retTy=*/"::llvm::ArrayRef<DataLayoutEntryInterface>",
308308
/*methodName=*/"getEntries",
309309
/*args=*/(ins)
310310
>,
@@ -334,7 +334,7 @@ def TargetSystemSpecInterface : AttrInterface<"TargetSystemSpecInterface", [DLTI
334334
/// Helper for default implementation of `DLTIQueryInterface`'s `query`.
335335
::mlir::FailureOr<::mlir::Attribute>
336336
queryHelper(::mlir::DataLayoutEntryKey key) const {
337-
if (auto strKey = llvm::dyn_cast<::mlir::StringAttr>(key))
337+
if (auto strKey = ::llvm::dyn_cast<::mlir::StringAttr>(key))
338338
if (auto deviceSpec = getDeviceSpecForDeviceID(strKey))
339339
return *deviceSpec;
340340
return ::mlir::failure();

0 commit comments

Comments
 (0)