Skip to content

[MLIR][DLTI][Transform] Introduce transform.dlti.query #101561

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions mlir/include/mlir/Dialect/DLTI/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
add_subdirectory(TransformOps)

add_mlir_dialect(DLTI dlti)
add_mlir_doc(DLTIAttrs DLTIDialect Dialects/ -gen-dialect-doc)

Expand Down
13 changes: 13 additions & 0 deletions mlir/include/mlir/Dialect/DLTI/DLTI.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,19 @@ namespace detail {
class DataLayoutEntryAttrStorage;
} // namespace detail
} // namespace mlir
namespace mlir {
namespace dlti {
/// Find the first DataLayoutSpec associated to `op`, via either the
/// DataLayoutOpInterface, a method on ModuleOp, or an attribute implementing
/// the interface, on `op` and else on `op`'s ancestors in turn.
DataLayoutSpecInterface getDataLayoutSpec(Operation *op);

/// Find the first TargetSystemSpec associated to `op`, via either the
/// DataLayoutOpInterface, a method on ModuleOp, or an attribute implementing
/// the interface, on `op` and else on `op`'s ancestors in turn.
TargetSystemSpecInterface getTargetSystemSpec(Operation *op);
} // namespace dlti
} // namespace mlir

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/DLTI/DLTIAttrs.h.inc"
Expand Down
6 changes: 6 additions & 0 deletions mlir/include/mlir/Dialect/DLTI/TransformOps/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
set(LLVM_TARGET_DEFINITIONS DLTITransformOps.td)
mlir_tablegen(DLTITransformOps.h.inc -gen-op-decls)
mlir_tablegen(DLTITransformOps.cpp.inc -gen-op-defs)
add_public_tablegen_target(MLIRDLTITransformOpsIncGen)

add_mlir_doc(DLTITransformOps DLTITransformOps Dialects/ -gen-op-doc)
38 changes: 38 additions & 0 deletions mlir/include/mlir/Dialect/DLTI/TransformOps/DLTITransformOps.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
//===- DLTITransformOps.h - DLTI transform ops ------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_DLTI_TRANSFORMOPS_DLTITRANSFORMOPS_H
#define MLIR_DIALECT_DLTI_TRANSFORMOPS_DLTITRANSFORMOPS_H

#include "mlir/Dialect/Transform/IR/TransformAttrs.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"

namespace mlir {
namespace transform {
class QueryOp;
} // namespace transform
} // namespace mlir

namespace mlir {
class DialectRegistry;

namespace dlti {
void registerTransformDialectExtension(DialectRegistry &registry);
} // namespace dlti
} // namespace mlir

////===----------------------------------------------------------------------===//
//// DLTI Transform Operations
////===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/DLTI/TransformOps/DLTITransformOps.h.inc"

#endif // MLIR_DIALECT_DLTI_TRANSFORMOPS_DLTITRANSFORMOPS_H
61 changes: 61 additions & 0 deletions mlir/include/mlir/Dialect/DLTI/TransformOps/DLTITransformOps.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
//===- DLTITransformOps.td - DLTI transform ops ------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef DLTI_TRANSFORM_OPS
#define DLTI_TRANSFORM_OPS

include "mlir/Dialect/Transform/IR/TransformDialect.td"
include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
include "mlir/Dialect/Transform/IR/TransformTypes.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/OpBase.td"

def QueryOp : Op<Transform_Dialect, "dlti.query", [
TransformOpInterface, TransformEachOpTrait,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
]> {
let summary = "Return attribute (as param) associated to key via DTLI";
let description = [{
This op queries data layout and target information associated to payload
IR by way of the DLTI dialect. A lookup is performed for the given `key`
at the `target` op, with the DLTI dialect determining which interfaces and
attributes are consulted - first checking `target` and then its ancestors.

When only `key` is provided, the lookup occurs with respect to the data
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a currently arbitrary semantics, but I can't see a better way to encode this.

@ftynse we could probably add a similar root entry to DL like we have for TI. Then there would be no ambiguity.

Not that I like the root entry itself, but it does simplify things a lot.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With the current state of DLTI, Data Layout info is always keyed by a single name. For querying Target Information, you must always specify a device ID in addition to a key to look up w.r.t. that device. So currently, whether a second key, deviceId, is provided is enough to indicate which type of DLTI info the user wants to query. (If this were no longer to be the case going forward, I would advocate for having both types of DLTI info implement a DLTIQueryInterface, or the like, probably supporting any number of keys so that nested dictionary lookups become possible.)

As for the ::$deviceId::$key syntax: I too would prefer for the leading :: to go. It's currently there to help the parser. With a custom parser, I am sure $deviceId::$key syntax is possible as well. As more changes for DLTI are planned, I was hoping to address this in later PRs (e.g. when would allow nested dictionary lookups).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't mind there being a root entry.

layout specification of DLTI. When `device` is provided, the lookup occurs
with respect to DLTI's target device specifications associated to a DLTI
system device specification.

#### Return modes

When succesful, the result, `associated_attr`, associates one attribute as a
param for each op in `target`'s payload.

If the lookup fails - as DLTI specifications or entries with the right
names are missing (i.e. the values of `device` and `key`) - a definite
failure is returned.
}];

let arguments = (ins TransformHandleTypeInterface:$target,
OptionalAttr<StrAttr>:$device,
StrAttr:$key);
let results = (outs TransformParamTypeInterface:$associated_attr);
let assemblyFormat =
"(`:``:` $device^ `:``:`)? $key `at` $target attr-dict `:`"
"functional-type(operands, results)";

let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
::mlir::transform::TransformRewriter &rewriter,
::mlir::Operation *target,
::mlir::transform::ApplyToEachResultList &results,
TransformState &state);
}];
}

#endif // DLTI_TRANSFORM_OPS
2 changes: 2 additions & 0 deletions mlir/include/mlir/InitAllExtensions.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
#include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h"
#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h"
#include "mlir/Dialect/DLTI/TransformOps/DLTITransformOps.h"
#include "mlir/Dialect/Func/Extensions/AllExtensions.h"
#include "mlir/Dialect/Func/TransformOps/FuncTransformOps.h"
#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.h"
Expand Down Expand Up @@ -69,6 +70,7 @@ inline void registerAllExtensions(DialectRegistry &registry) {
// Register all transform dialect extensions.
affine::registerTransformDialectExtension(registry);
bufferization::registerTransformDialectExtension(registry);
dlti::registerTransformDialectExtension(registry);
func::registerTransformDialectExtension(registry);
gpu::registerTransformDialectExtension(registry);
linalg::registerTransformDialectExtension(registry);
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/DLTI/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
add_subdirectory(TransformOps)
add_mlir_dialect_library(MLIRDLTIDialect
DLTI.cpp
Traits.cpp
Expand Down
35 changes: 35 additions & 0 deletions mlir/lib/Dialect/DLTI/DLTI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,41 @@ TargetSystemSpecAttr::verify(function_ref<InFlightDiagnostic()> emitError,
// DLTIDialect
//===----------------------------------------------------------------------===//

DataLayoutSpecInterface dlti::getDataLayoutSpec(Operation *op) {
DataLayoutSpecInterface dlSpec = nullptr;

for (Operation *cur = op; cur && !dlSpec; cur = cur->getParentOp()) {
if (auto dataLayoutOp = dyn_cast<DataLayoutOpInterface>(cur))
dlSpec = dataLayoutOp.getDataLayoutSpec();
else if (auto moduleOp = dyn_cast<ModuleOp>(cur))
dlSpec = moduleOp.getDataLayoutSpec();
else
for (NamedAttribute attr : cur->getAttrs())
if ((dlSpec = llvm::dyn_cast<DataLayoutSpecInterface>(attr.getValue())))
break;
}

return dlSpec;
}

TargetSystemSpecInterface dlti::getTargetSystemSpec(Operation *op) {
TargetSystemSpecInterface sysSpec = nullptr;

for (Operation *cur = op; cur && !sysSpec; cur = cur->getParentOp()) {
if (auto dataLayoutOp = dyn_cast<DataLayoutOpInterface>(cur))
sysSpec = dataLayoutOp.getTargetSystemSpec();
else if (auto moduleOp = dyn_cast<ModuleOp>(cur))
sysSpec = moduleOp.getTargetSystemSpec();
else
for (NamedAttribute attr : cur->getAttrs())
if ((sysSpec =
llvm::dyn_cast<TargetSystemSpecInterface>(attr.getValue())))
break;
}

return sysSpec;
}

constexpr const StringLiteral mlir::DLTIDialect::kDataLayoutAttrName;
constexpr const StringLiteral mlir::DLTIDialect::kDataLayoutEndiannessKey;
constexpr const StringLiteral mlir::DLTIDialect::kDataLayoutEndiannessBig;
Expand Down
15 changes: 15 additions & 0 deletions mlir/lib/Dialect/DLTI/TransformOps/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
add_mlir_dialect_library(MLIRDLTITransformOps
DLTITransformOps.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/DLTI/TransformOps

DEPENDS
MLIRDLTITransformOpsIncGen
MLIRDLTIDialect

LINK_LIBS PUBLIC
MLIRDLTIDialect
MLIRSideEffectInterfaces
MLIRTransformDialect
)
94 changes: 94 additions & 0 deletions mlir/lib/Dialect/DLTI/TransformOps/DLTITransformOps.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@

//===- DLTITransformOps.cpp - Implementation of DLTI transform ops --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/DLTI/TransformOps/DLTITransformOps.h"

#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/Dialect/Transform/Utils/Utils.h"
#include "mlir/Interfaces/DataLayoutInterfaces.h"

using namespace mlir;
using namespace mlir::transform;

#define DEBUG_TYPE "dlti-transforms"

//===----------------------------------------------------------------------===//
// QueryOp
//===----------------------------------------------------------------------===//

void transform::QueryOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
onlyReadsHandle(getTargetMutable(), effects);
producesHandle(getOperation()->getOpResults(), effects);
onlyReadsPayload(effects);
}

DiagnosedSilenceableFailure transform::QueryOp::applyToOne(
transform::TransformRewriter &rewriter, Operation *target,
transform::ApplyToEachResultList &results, TransformState &state) {
StringAttr deviceId = getDeviceAttr();
StringAttr key = getKeyAttr();

DataLayoutEntryInterface entry;
if (deviceId) {
TargetSystemSpecInterface sysSpec = dlti::getTargetSystemSpec(target);
if (!sysSpec)
return mlir::emitDefiniteFailure(target->getLoc())
<< "no target system spec associated to: " << target;

if (auto targetSpec = sysSpec.getDeviceSpecForDeviceID(deviceId))
entry = targetSpec->getSpecForIdentifier(key);
else
return mlir::emitDefiniteFailure(target->getLoc())
<< "no " << deviceId << " target device spec found";
} else {
DataLayoutSpecInterface dlSpec = dlti::getDataLayoutSpec(target);
if (!dlSpec)
return mlir::emitDefiniteFailure(target->getLoc())
<< "no data layout spec associated to: " << target;

entry = dlSpec.getSpecForIdentifier(key);
}

if (!entry)
return mlir::emitDefiniteFailure(target->getLoc())
<< "no DLTI entry for key: " << key;

results.push_back(entry.getValue());

return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//

namespace {
class DLTITransformDialectExtension
: public transform::TransformDialectExtension<
DLTITransformDialectExtension> {
public:
using Base::Base;

void init() {
registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/DLTI/TransformOps/DLTITransformOps.cpp.inc"
>();
}
};
} // namespace

#define GET_OP_CLASSES
#include "mlir/Dialect/DLTI/TransformOps/DLTITransformOps.cpp.inc"

void mlir::dlti::registerTransformDialectExtension(DialectRegistry &registry) {
registry.addExtensions<DLTITransformDialectExtension>();
}
Loading
Loading