Skip to content

Commit 7448ef1

Browse files
committed
[MLIR][DLTI][Transform] Introduce transform.dlti.query
This transform op makes it possible to query attributes associated to IR by means of the DLTI dialect. The op takes both a `key` and a target `op` to perform the query at. Facility functions automatically find the closest ancestor op which defines the appropriate DLTI interface or has an attribute implementing a DLTI interface. By default the lookup uses the data layout interfaces of DLTI. If the optional `device` parameter is provided, the lookup happens with respect to the interfaces for TargetSystemSpec and TargetDeviceSpec. This op uses new free-standing functions in the `dlti` namespace to not only look up specifications via the DataLayoutSpecOpInterface and on ModuleOps but also on any ancestor op that has an appropriate DLTI attribute.
1 parent 3a8a0b8 commit 7448ef1

File tree

11 files changed

+556
-0
lines changed

11 files changed

+556
-0
lines changed

mlir/include/mlir/Dialect/DLTI/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
add_subdirectory(TransformOps)
2+
13
add_mlir_dialect(DLTI dlti)
24
add_mlir_doc(DLTIAttrs DLTIDialect Dialects/ -gen-dialect-doc)
35

mlir/include/mlir/Dialect/DLTI/DLTI.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,12 @@ namespace detail {
2222
class DataLayoutEntryAttrStorage;
2323
} // namespace detail
2424
} // namespace mlir
25+
namespace mlir {
26+
namespace dlti {
27+
DataLayoutSpecInterface getDataLayoutSpec(Operation *op);
28+
TargetSystemSpecInterface getTargetSystemSpec(Operation *op);
29+
} // namespace dlti
30+
} // namespace mlir
2531

2632
#define GET_ATTRDEF_CLASSES
2733
#include "mlir/Dialect/DLTI/DLTIAttrs.h.inc"
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
set(LLVM_TARGET_DEFINITIONS DLTITransformOps.td)
2+
mlir_tablegen(DLTITransformOps.h.inc -gen-op-decls)
3+
mlir_tablegen(DLTITransformOps.cpp.inc -gen-op-defs)
4+
add_public_tablegen_target(MLIRDLTITransformOpsIncGen)
5+
6+
add_mlir_doc(DLTITransformOps DLTITransformOps Dialects/ -gen-op-doc)
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
//===- DLTITransformOps.h - DLTI transform ops ------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_DLTI_TRANSFORMOPS_DLTITRANSFORMOPS_H
10+
#define MLIR_DIALECT_DLTI_TRANSFORMOPS_DLTITRANSFORMOPS_H
11+
12+
#include "mlir/Dialect/Transform/IR/TransformAttrs.h"
13+
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
14+
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
15+
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
16+
17+
namespace mlir {
18+
namespace transform {
19+
class QueryOp;
20+
} // namespace transform
21+
} // namespace mlir
22+
23+
namespace mlir {
24+
class DialectRegistry;
25+
26+
namespace dlti {
27+
void registerTransformDialectExtension(DialectRegistry &registry);
28+
} // namespace dlti
29+
} // namespace mlir
30+
31+
////===----------------------------------------------------------------------===//
32+
//// DLTI Transform Operations
33+
////===----------------------------------------------------------------------===//
34+
35+
#define GET_OP_CLASSES
36+
#include "mlir/Dialect/DLTI/TransformOps/DLTITransformOps.h.inc"
37+
38+
#endif // MLIR_DIALECT_DLTI_TRANSFORMOPS_DLTITRANSFORMOPS_H
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
//===- DLTITransformOps.td - DLTI transform ops ------------*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef DLTI_TRANSFORM_OPS
10+
#define DLTI_TRANSFORM_OPS
11+
12+
include "mlir/Dialect/Transform/IR/TransformDialect.td"
13+
include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
14+
include "mlir/Dialect/Transform/IR/TransformTypes.td"
15+
include "mlir/Interfaces/SideEffectInterfaces.td"
16+
include "mlir/IR/OpBase.td"
17+
18+
def QueryOp : Op<Transform_Dialect, "dlti.query", [
19+
TransformOpInterface, TransformEachOpTrait,
20+
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
21+
]> {
22+
let summary = "Return attribute (as param) associated to key via DTLI";
23+
let description = [{
24+
This op queries data layout and target information associated to payload
25+
IR by way of the DLTI dialect. A lookup is performed for the given `key`
26+
at the `target` op, with the DLTI dialect determining which interfaces and
27+
attributes are consulted.
28+
29+
When only `key` is provided, the lookup occurs with respect to the data
30+
layout specification of DLTI. When `device` is provided, the lookup occurs
31+
with respect to DLTI's target device specifications associated to a DLTI
32+
system device specification.
33+
34+
#### Return modes
35+
36+
When succesfull, the result, `associated_attr`, associates one attribute
37+
as a param for each op in `target`'s payload.
38+
39+
If the lookup fails - as DLTI specifications or entries with the right
40+
names are missing (i.e. the values of `device` and `key`) - a definite
41+
failure is returned.
42+
}];
43+
44+
let arguments = (ins TransformHandleTypeInterface:$target,
45+
OptionalAttr<StrAttr>:$device,
46+
StrAttr:$key);
47+
let results = (outs TransformParamTypeInterface:$associated_attr);
48+
let assemblyFormat =
49+
"(`:``:` $device^ `:``:`)? $key `at` $target attr-dict `:`"
50+
"functional-type(operands, results)";
51+
52+
let extraClassDeclaration = [{
53+
::mlir::DiagnosedSilenceableFailure applyToOne(
54+
::mlir::transform::TransformRewriter &rewriter,
55+
::mlir::Operation *target,
56+
::mlir::transform::ApplyToEachResultList &results,
57+
TransformState &state);
58+
}];
59+
}
60+
61+
#endif // DLTI_TRANSFORM_OPS

mlir/include/mlir/InitAllExtensions.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
2626
#include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h"
2727
#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h"
28+
#include "mlir/Dialect/DLTI/TransformOps/DLTITransformOps.h"
2829
#include "mlir/Dialect/Func/Extensions/AllExtensions.h"
2930
#include "mlir/Dialect/Func/TransformOps/FuncTransformOps.h"
3031
#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.h"
@@ -69,6 +70,7 @@ inline void registerAllExtensions(DialectRegistry &registry) {
6970
// Register all transform dialect extensions.
7071
affine::registerTransformDialectExtension(registry);
7172
bufferization::registerTransformDialectExtension(registry);
73+
dlti::registerTransformDialectExtension(registry);
7274
func::registerTransformDialectExtension(registry);
7375
gpu::registerTransformDialectExtension(registry);
7476
linalg::registerTransformDialectExtension(registry);

mlir/lib/Dialect/DLTI/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
add_subdirectory(TransformOps)
12
add_mlir_dialect_library(MLIRDLTIDialect
23
DLTI.cpp
34
Traits.cpp

mlir/lib/Dialect/DLTI/DLTI.cpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -389,6 +389,41 @@ TargetSystemSpecAttr::verify(function_ref<InFlightDiagnostic()> emitError,
389389
return success();
390390
}
391391

392+
DataLayoutSpecInterface dlti::getDataLayoutSpec(Operation *op) {
393+
DataLayoutSpecInterface dlSpec = nullptr;
394+
395+
for (Operation *cur = op; cur && !dlSpec; cur = cur->getParentOp()) {
396+
if (auto dataLayoutOp = dyn_cast<DataLayoutOpInterface>(cur))
397+
dlSpec = dataLayoutOp.getDataLayoutSpec();
398+
else if (auto moduleOp = dyn_cast<ModuleOp>(cur))
399+
dlSpec = moduleOp.getDataLayoutSpec();
400+
else
401+
for (NamedAttribute attr : cur->getAttrs())
402+
if ((dlSpec = llvm::dyn_cast<DataLayoutSpecInterface>(attr.getValue())))
403+
break;
404+
}
405+
406+
return dlSpec;
407+
}
408+
409+
TargetSystemSpecInterface dlti::getTargetSystemSpec(Operation *op) {
410+
TargetSystemSpecInterface sysSpec = nullptr;
411+
412+
for (Operation *cur = op; cur && !sysSpec; cur = cur->getParentOp()) {
413+
if (auto dataLayoutOp = dyn_cast<DataLayoutOpInterface>(cur))
414+
sysSpec = dataLayoutOp.getTargetSystemSpec();
415+
else if (auto moduleOp = dyn_cast<ModuleOp>(cur))
416+
sysSpec = moduleOp.getTargetSystemSpec();
417+
else
418+
for (NamedAttribute attr : cur->getAttrs())
419+
if ((sysSpec =
420+
llvm::dyn_cast<TargetSystemSpecInterface>(attr.getValue())))
421+
break;
422+
}
423+
424+
return sysSpec;
425+
}
426+
392427
//===----------------------------------------------------------------------===//
393428
// DLTIDialect
394429
//===----------------------------------------------------------------------===//
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
add_mlir_dialect_library(MLIRDLTITransformOps
2+
DLTITransformOps.cpp
3+
4+
ADDITIONAL_HEADER_DIRS
5+
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/DLTI/TransformOps
6+
7+
DEPENDS
8+
MLIRDLTITransformOpsIncGen
9+
MLIRDLTIDialect
10+
11+
LINK_LIBS PUBLIC
12+
MLIRDLTIDialect
13+
MLIRSideEffectInterfaces
14+
MLIRTransformDialect
15+
)
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
2+
//===- DLTITransformOps.cpp - Implementation of DLTI transform ops --------===//
3+
//
4+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5+
// See https://llvm.org/LICENSE.txt for license information.
6+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
//
8+
//===----------------------------------------------------------------------===//
9+
10+
#include "mlir/Dialect/DLTI/TransformOps/DLTITransformOps.h"
11+
12+
#include "mlir/Dialect/DLTI/DLTI.h"
13+
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
14+
#include "mlir/Dialect/Transform/Utils/Utils.h"
15+
#include "mlir/Interfaces/DataLayoutInterfaces.h"
16+
17+
using namespace mlir;
18+
using namespace mlir::transform;
19+
20+
#define DEBUG_TYPE "dlti-transforms"
21+
22+
//===----------------------------------------------------------------------===//
23+
// FuseOp
24+
//===----------------------------------------------------------------------===//
25+
26+
void transform::QueryOp::getEffects(
27+
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
28+
onlyReadsHandle(getTargetMutable(), effects);
29+
producesHandle(getOperation()->getOpResults(), effects);
30+
onlyReadsPayload(effects);
31+
}
32+
33+
DiagnosedSilenceableFailure transform::QueryOp::applyToOne(
34+
transform::TransformRewriter &rewriter, Operation *target,
35+
transform::ApplyToEachResultList &results, TransformState &state) {
36+
StringAttr deviceId = getDeviceAttr();
37+
StringAttr key = getKeyAttr();
38+
39+
DataLayoutEntryInterface entry;
40+
if (deviceId) {
41+
TargetSystemSpecInterface sysSpec = dlti::getTargetSystemSpec(target);
42+
if (!sysSpec)
43+
return mlir::emitDefiniteFailure(target->getLoc())
44+
<< "no target system spec associated to: " << target;
45+
46+
if (auto targetSpec = sysSpec.getDeviceSpecForDeviceID(deviceId))
47+
entry = targetSpec->getSpecForIdentifier(key);
48+
else
49+
return mlir::emitDefiniteFailure(target->getLoc())
50+
<< "no " << deviceId << " target device spec found";
51+
} else {
52+
DataLayoutSpecInterface dlSpec = dlti::getDataLayoutSpec(target);
53+
if (!dlSpec)
54+
return mlir::emitDefiniteFailure(target->getLoc())
55+
<< "no data layout spec associated to: " << target;
56+
57+
entry = dlSpec.getSpecForIdentifier(key);
58+
}
59+
60+
if (!entry)
61+
return mlir::emitDefiniteFailure(target->getLoc())
62+
<< "no DLTI entry for key: " << key;
63+
64+
results.push_back(entry.getValue());
65+
66+
return DiagnosedSilenceableFailure::success();
67+
}
68+
69+
//===----------------------------------------------------------------------===//
70+
// Transform op registration
71+
//===----------------------------------------------------------------------===//
72+
73+
namespace {
74+
class DLTITransformDialectExtension
75+
: public transform::TransformDialectExtension<
76+
DLTITransformDialectExtension> {
77+
public:
78+
using Base::Base;
79+
80+
void init() {
81+
registerTransformOps<
82+
#define GET_OP_LIST
83+
#include "mlir/Dialect/DLTI/TransformOps/DLTITransformOps.cpp.inc"
84+
>();
85+
}
86+
};
87+
} // namespace
88+
89+
#define GET_OP_CLASSES
90+
#include "mlir/Dialect/DLTI/TransformOps/DLTITransformOps.cpp.inc"
91+
92+
void mlir::dlti::registerTransformDialectExtension(DialectRegistry &registry) {
93+
registry.addExtensions<DLTITransformDialectExtension>();
94+
}

0 commit comments

Comments
 (0)