Skip to content

Commit 84cc186

Browse files
nikalrajoker-eph
andauthored
[mlir] Support DialectRegistry extension comparison (#101119)
`PassManager::run` loads the dependent dialects for each pass into the current context prior to invoking the individual passes. If the dependent dialect is already loaded into the context, this should be a no-op. However, if there are extensions registered in the `DialectRegistry`, the dependent dialects are unconditionally registered into the context. This poses a problem for dynamic pass pipelines, however, because they will likely be executing while the context is in an immutable state (because of the parent pass pipeline being run). To solve this, we'll update the extension registration API on `DialectRegistry` to require a type ID for each extension that is registered. Then, instead of unconditionally registered dialects into a context if extensions are present, we'll check against the extension type IDs already present in the context's internal `DialectRegistry`. The context will only be marked as dirty if there are net-new extension types present in the `DialectRegistry` populated by `PassManager::getDependentDialects`. Note: this PR removes the `addExtension` overload that utilizes `std::function` as the parameter. This is because `std::function` is copyable and potentially allocates memory for the contained function so we can't use the function pointer as the unique type ID for the extension. Downstream changes required: - Existing `DialectExtension` subclasses will need a type ID to be registered for each subclass. More details on how to register a type ID can be found here: https://github.com/llvm/llvm-project/blob/8b68e06731e0033ed3f8d6fe6292ae671611cfa1/mlir/include/mlir/Support/TypeID.h#L30 - Existing uses of the `std::function` overload of `addExtension` will need to be refactored into dedicated `DialectExtension` classes with associated type IDs. The attached `std::function` can either be inlined into or called directly from `DialectExtension::apply`. --------- Co-authored-by: Mehdi Amini <[email protected]>
1 parent 2fd2fd2 commit 84cc186

File tree

25 files changed

+167
-32
lines changed

25 files changed

+167
-32
lines changed

mlir/examples/transform/Ch2/lib/MyExtension.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@
2929
class MyExtension
3030
: public ::mlir::transform::TransformDialectExtension<MyExtension> {
3131
public:
32+
// The TypeID of this extension.
33+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MyExtension)
34+
3235
// The extension must derive the base constructor.
3336
using Base::Base;
3437

mlir/examples/transform/Ch3/lib/MyExtension.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@
3535
class MyExtension
3636
: public ::mlir::transform::TransformDialectExtension<MyExtension> {
3737
public:
38+
// The TypeID of this extension.
39+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MyExtension)
40+
3841
// The extension must derive the base constructor.
3942
using Base::Base;
4043

mlir/examples/transform/Ch4/lib/MyExtension.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@
3131
class MyExtension
3232
: public ::mlir::transform::TransformDialectExtension<MyExtension> {
3333
public:
34+
// The TypeID of this extension.
35+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MyExtension)
36+
3437
// The extension must derive the base constructor.
3538
using Base::Base;
3639

mlir/include/mlir/IR/DialectRegistry.h

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@
1414
#define MLIR_IR_DIALECTREGISTRY_H
1515

1616
#include "mlir/IR/MLIRContext.h"
17+
#include "mlir/Support/TypeID.h"
1718
#include "llvm/ADT/ArrayRef.h"
18-
#include "llvm/ADT/SmallVector.h"
19-
#include "llvm/ADT/StringRef.h"
19+
#include "llvm/ADT/MapVector.h"
2020

2121
#include <map>
2222
#include <tuple>
@@ -187,7 +187,8 @@ class DialectRegistry {
187187
nameAndRegistrationIt.second.second);
188188
// Merge the extensions.
189189
for (const auto &extension : extensions)
190-
destination.extensions.push_back(extension->clone());
190+
destination.extensions.try_emplace(extension.first,
191+
extension.second->clone());
191192
}
192193

193194
/// Return the names of dialects known to this registry.
@@ -206,47 +207,47 @@ class DialectRegistry {
206207
void applyExtensions(MLIRContext *ctx) const;
207208

208209
/// Add the given extension to the registry.
209-
void addExtension(std::unique_ptr<DialectExtensionBase> extension) {
210-
extensions.push_back(std::move(extension));
210+
bool addExtension(TypeID extensionID,
211+
std::unique_ptr<DialectExtensionBase> extension) {
212+
return extensions.try_emplace(extensionID, std::move(extension)).second;
211213
}
212214

213215
/// Add the given extensions to the registry.
214216
template <typename... ExtensionsT>
215217
void addExtensions() {
216-
(addExtension(std::make_unique<ExtensionsT>()), ...);
218+
(addExtension(TypeID::get<ExtensionsT>(), std::make_unique<ExtensionsT>()),
219+
...);
217220
}
218221

219222
/// Add an extension function that requires the given dialects.
220223
/// Note: This bare functor overload is provided in addition to the
221224
/// std::function variant to enable dialect type deduction, e.g.:
222-
/// registry.addExtension(+[](MLIRContext *ctx, MyDialect *dialect) { ... })
225+
/// registry.addExtension(+[](MLIRContext *ctx, MyDialect *dialect) {
226+
/// ... })
223227
///
224228
/// is equivalent to:
225229
/// registry.addExtension<MyDialect>(
226230
/// [](MLIRContext *ctx, MyDialect *dialect){ ... }
227231
/// )
228232
template <typename... DialectsT>
229-
void addExtension(void (*extensionFn)(MLIRContext *, DialectsT *...)) {
230-
addExtension<DialectsT...>(
231-
std::function<void(MLIRContext *, DialectsT * ...)>(extensionFn));
232-
}
233-
template <typename... DialectsT>
234-
void
235-
addExtension(std::function<void(MLIRContext *, DialectsT *...)> extensionFn) {
236-
using ExtensionFnT = std::function<void(MLIRContext *, DialectsT * ...)>;
233+
bool addExtension(void (*extensionFn)(MLIRContext *, DialectsT *...)) {
234+
using ExtensionFnT = void (*)(MLIRContext *, DialectsT *...);
237235

238236
struct Extension : public DialectExtension<Extension, DialectsT...> {
239237
Extension(const Extension &) = default;
240238
Extension(ExtensionFnT extensionFn)
241-
: extensionFn(std::move(extensionFn)) {}
239+
: DialectExtension<Extension, DialectsT...>(),
240+
extensionFn(extensionFn) {}
242241
~Extension() override = default;
243242

244243
void apply(MLIRContext *context, DialectsT *...dialects) const final {
245244
extensionFn(context, dialects...);
246245
}
247246
ExtensionFnT extensionFn;
248247
};
249-
addExtension(std::make_unique<Extension>(std::move(extensionFn)));
248+
return addExtension(TypeID::getFromOpaquePointer(
249+
reinterpret_cast<const void *>(extensionFn)),
250+
std::make_unique<Extension>(extensionFn));
250251
}
251252

252253
/// Returns true if the current registry is a subset of 'rhs', i.e. if 'rhs'
@@ -255,7 +256,7 @@ class DialectRegistry {
255256

256257
private:
257258
MapTy registry;
258-
std::vector<std::unique_ptr<DialectExtensionBase>> extensions;
259+
llvm::MapVector<TypeID, std::unique_ptr<DialectExtensionBase>> extensions;
259260
};
260261

261262
} // namespace mlir

mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ namespace {
3535
/// starting a pass pipeline that involves dialect conversion to LLVM.
3636
class LoadDependentDialectExtension : public DialectExtensionBase {
3737
public:
38+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LoadDependentDialectExtension)
39+
3840
LoadDependentDialectExtension() : DialectExtensionBase(/*dialectNames=*/{}) {}
3941

4042
void apply(MLIRContext *context,

mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,8 @@ class AffineTransformDialectExtension
157157
: public transform::TransformDialectExtension<
158158
AffineTransformDialectExtension> {
159159
public:
160+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AffineTransformDialectExtension)
161+
160162
using Base::Base;
161163

162164
void init() {

mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,9 @@ class BufferizationTransformDialectExtension
150150
: public transform::TransformDialectExtension<
151151
BufferizationTransformDialectExtension> {
152152
public:
153+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
154+
BufferizationTransformDialectExtension)
155+
153156
using Base::Base;
154157

155158
void init() {

mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,8 @@ class FuncTransformDialectExtension
236236
: public transform::TransformDialectExtension<
237237
FuncTransformDialectExtension> {
238238
public:
239+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(FuncTransformDialectExtension)
240+
239241
using Base::Base;
240242

241243
void init() {

mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -924,6 +924,8 @@ class GPUTransformDialectExtension
924924
: public transform::TransformDialectExtension<
925925
GPUTransformDialectExtension> {
926926
public:
927+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(GPUTransformDialectExtension)
928+
927929
GPUTransformDialectExtension() {
928930
declareGeneratedDialect<scf::SCFDialect>();
929931
declareGeneratedDialect<arith::ArithDialect>();

mlir/lib/Dialect/Linalg/TransformOps/DialectExtension.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ class LinalgTransformDialectExtension
3030
: public transform::TransformDialectExtension<
3131
LinalgTransformDialectExtension> {
3232
public:
33+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LinalgTransformDialectExtension)
34+
3335
using Base::Base;
3436

3537
void init() {

0 commit comments

Comments
 (0)