Skip to content

Commit 84cc186

Browse files
nikalrajoker-eph
andauthored
[mlir] Support DialectRegistry extension comparison (#101119)
`PassManager::run` loads the dependent dialects for each pass into the current context prior to invoking the individual passes. If the dependent dialect is already loaded into the context, this should be a no-op. However, if there are extensions registered in the `DialectRegistry`, the dependent dialects are unconditionally registered into the context. This poses a problem for dynamic pass pipelines, however, because they will likely be executing while the context is in an immutable state (because of the parent pass pipeline being run). To solve this, we'll update the extension registration API on `DialectRegistry` to require a type ID for each extension that is registered. Then, instead of unconditionally registered dialects into a context if extensions are present, we'll check against the extension type IDs already present in the context's internal `DialectRegistry`. The context will only be marked as dirty if there are net-new extension types present in the `DialectRegistry` populated by `PassManager::getDependentDialects`. Note: this PR removes the `addExtension` overload that utilizes `std::function` as the parameter. This is because `std::function` is copyable and potentially allocates memory for the contained function so we can't use the function pointer as the unique type ID for the extension. Downstream changes required: - Existing `DialectExtension` subclasses will need a type ID to be registered for each subclass. More details on how to register a type ID can be found here: https://github.com/llvm/llvm-project/blob/8b68e06731e0033ed3f8d6fe6292ae671611cfa1/mlir/include/mlir/Support/TypeID.h#L30 - Existing uses of the `std::function` overload of `addExtension` will need to be refactored into dedicated `DialectExtension` classes with associated type IDs. The attached `std::function` can either be inlined into or called directly from `DialectExtension::apply`. --------- Co-authored-by: Mehdi Amini <[email protected]>
1 parent 2fd2fd2 commit 84cc186

File tree

25 files changed

+167
-32
lines changed

25 files changed

+167
-32
lines changed

mlir/examples/transform/Ch2/lib/MyExtension.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@
2929
class MyExtension
3030
: public ::mlir::transform::TransformDialectExtension<MyExtension> {
3131
public:
32+
// The TypeID of this extension.
33+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MyExtension)
34+
3235
// The extension must derive the base constructor.
3336
using Base::Base;
3437

mlir/examples/transform/Ch3/lib/MyExtension.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@
3535
class MyExtension
3636
: public ::mlir::transform::TransformDialectExtension<MyExtension> {
3737
public:
38+
// The TypeID of this extension.
39+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MyExtension)
40+
3841
// The extension must derive the base constructor.
3942
using Base::Base;
4043

mlir/examples/transform/Ch4/lib/MyExtension.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@
3131
class MyExtension
3232
: public ::mlir::transform::TransformDialectExtension<MyExtension> {
3333
public:
34+
// The TypeID of this extension.
35+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MyExtension)
36+
3437
// The extension must derive the base constructor.
3538
using Base::Base;
3639

mlir/include/mlir/IR/DialectRegistry.h

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@
1414
#define MLIR_IR_DIALECTREGISTRY_H
1515

1616
#include "mlir/IR/MLIRContext.h"
17+
#include "mlir/Support/TypeID.h"
1718
#include "llvm/ADT/ArrayRef.h"
18-
#include "llvm/ADT/SmallVector.h"
19-
#include "llvm/ADT/StringRef.h"
19+
#include "llvm/ADT/MapVector.h"
2020

2121
#include <map>
2222
#include <tuple>
@@ -187,7 +187,8 @@ class DialectRegistry {
187187
nameAndRegistrationIt.second.second);
188188
// Merge the extensions.
189189
for (const auto &extension : extensions)
190-
destination.extensions.push_back(extension->clone());
190+
destination.extensions.try_emplace(extension.first,
191+
extension.second->clone());
191192
}
192193

193194
/// Return the names of dialects known to this registry.
@@ -206,47 +207,47 @@ class DialectRegistry {
206207
void applyExtensions(MLIRContext *ctx) const;
207208

208209
/// Add the given extension to the registry.
209-
void addExtension(std::unique_ptr<DialectExtensionBase> extension) {
210-
extensions.push_back(std::move(extension));
210+
bool addExtension(TypeID extensionID,
211+
std::unique_ptr<DialectExtensionBase> extension) {
212+
return extensions.try_emplace(extensionID, std::move(extension)).second;
211213
}
212214

213215
/// Add the given extensions to the registry.
214216
template <typename... ExtensionsT>
215217
void addExtensions() {
216-
(addExtension(std::make_unique<ExtensionsT>()), ...);
218+
(addExtension(TypeID::get<ExtensionsT>(), std::make_unique<ExtensionsT>()),
219+
...);
217220
}
218221

219222
/// Add an extension function that requires the given dialects.
220223
/// Note: This bare functor overload is provided in addition to the
221224
/// std::function variant to enable dialect type deduction, e.g.:
222-
/// registry.addExtension(+[](MLIRContext *ctx, MyDialect *dialect) { ... })
225+
/// registry.addExtension(+[](MLIRContext *ctx, MyDialect *dialect) {
226+
/// ... })
223227
///
224228
/// is equivalent to:
225229
/// registry.addExtension<MyDialect>(
226230
/// [](MLIRContext *ctx, MyDialect *dialect){ ... }
227231
/// )
228232
template <typename... DialectsT>
229-
void addExtension(void (*extensionFn)(MLIRContext *, DialectsT *...)) {
230-
addExtension<DialectsT...>(
231-
std::function<void(MLIRContext *, DialectsT * ...)>(extensionFn));
232-
}
233-
template <typename... DialectsT>
234-
void
235-
addExtension(std::function<void(MLIRContext *, DialectsT *...)> extensionFn) {
236-
using ExtensionFnT = std::function<void(MLIRContext *, DialectsT * ...)>;
233+
bool addExtension(void (*extensionFn)(MLIRContext *, DialectsT *...)) {
234+
using ExtensionFnT = void (*)(MLIRContext *, DialectsT *...);
237235

238236
struct Extension : public DialectExtension<Extension, DialectsT...> {
239237
Extension(const Extension &) = default;
240238
Extension(ExtensionFnT extensionFn)
241-
: extensionFn(std::move(extensionFn)) {}
239+
: DialectExtension<Extension, DialectsT...>(),
240+
extensionFn(extensionFn) {}
242241
~Extension() override = default;
243242

244243
void apply(MLIRContext *context, DialectsT *...dialects) const final {
245244
extensionFn(context, dialects...);
246245
}
247246
ExtensionFnT extensionFn;
248247
};
249-
addExtension(std::make_unique<Extension>(std::move(extensionFn)));
248+
return addExtension(TypeID::getFromOpaquePointer(
249+
reinterpret_cast<const void *>(extensionFn)),
250+
std::make_unique<Extension>(extensionFn));
250251
}
251252

252253
/// Returns true if the current registry is a subset of 'rhs', i.e. if 'rhs'
@@ -255,7 +256,7 @@ class DialectRegistry {
255256

256257
private:
257258
MapTy registry;
258-
std::vector<std::unique_ptr<DialectExtensionBase>> extensions;
259+
llvm::MapVector<TypeID, std::unique_ptr<DialectExtensionBase>> extensions;
259260
};
260261

261262
} // namespace mlir

mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ namespace {
3535
/// starting a pass pipeline that involves dialect conversion to LLVM.
3636
class LoadDependentDialectExtension : public DialectExtensionBase {
3737
public:
38+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LoadDependentDialectExtension)
39+
3840
LoadDependentDialectExtension() : DialectExtensionBase(/*dialectNames=*/{}) {}
3941

4042
void apply(MLIRContext *context,

mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,8 @@ class AffineTransformDialectExtension
157157
: public transform::TransformDialectExtension<
158158
AffineTransformDialectExtension> {
159159
public:
160+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AffineTransformDialectExtension)
161+
160162
using Base::Base;
161163

162164
void init() {

mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,9 @@ class BufferizationTransformDialectExtension
150150
: public transform::TransformDialectExtension<
151151
BufferizationTransformDialectExtension> {
152152
public:
153+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
154+
BufferizationTransformDialectExtension)
155+
153156
using Base::Base;
154157

155158
void init() {

mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,8 @@ class FuncTransformDialectExtension
236236
: public transform::TransformDialectExtension<
237237
FuncTransformDialectExtension> {
238238
public:
239+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(FuncTransformDialectExtension)
240+
239241
using Base::Base;
240242

241243
void init() {

mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -924,6 +924,8 @@ class GPUTransformDialectExtension
924924
: public transform::TransformDialectExtension<
925925
GPUTransformDialectExtension> {
926926
public:
927+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(GPUTransformDialectExtension)
928+
927929
GPUTransformDialectExtension() {
928930
declareGeneratedDialect<scf::SCFDialect>();
929931
declareGeneratedDialect<arith::ArithDialect>();

mlir/lib/Dialect/Linalg/TransformOps/DialectExtension.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ class LinalgTransformDialectExtension
3030
: public transform::TransformDialectExtension<
3131
LinalgTransformDialectExtension> {
3232
public:
33+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LinalgTransformDialectExtension)
34+
3335
using Base::Base;
3436

3537
void init() {

mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,8 @@ class MemRefTransformDialectExtension
309309
: public transform::TransformDialectExtension<
310310
MemRefTransformDialectExtension> {
311311
public:
312+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MemRefTransformDialectExtension)
313+
312314
using Base::Base;
313315

314316
void init() {

mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1135,6 +1135,8 @@ class NVGPUTransformDialectExtension
11351135
: public transform::TransformDialectExtension<
11361136
NVGPUTransformDialectExtension> {
11371137
public:
1138+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(NVGPUTransformDialectExtension)
1139+
11381140
NVGPUTransformDialectExtension() {
11391141
declareGeneratedDialect<arith::ArithDialect>();
11401142
declareGeneratedDialect<affine::AffineDialect>();

mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -613,6 +613,8 @@ class SCFTransformDialectExtension
613613
: public transform::TransformDialectExtension<
614614
SCFTransformDialectExtension> {
615615
public:
616+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SCFTransformDialectExtension)
617+
616618
using Base::Base;
617619

618620
void init() {

mlir/lib/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@ class SparseTensorTransformDialectExtension
3838
: public transform::TransformDialectExtension<
3939
SparseTensorTransformDialectExtension> {
4040
public:
41+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
42+
SparseTensorTransformDialectExtension)
43+
4144
SparseTensorTransformDialectExtension() {
4245
declareGeneratedDialect<sparse_tensor::SparseTensorDialect>();
4346
registerTransformOps<

mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,8 @@ class TensorTransformDialectExtension
236236
: public transform::TransformDialectExtension<
237237
TensorTransformDialectExtension> {
238238
public:
239+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TensorTransformDialectExtension)
240+
239241
using Base::Base;
240242

241243
void init() {

mlir/lib/Dialect/Transform/DebugExtension/DebugExtension.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ namespace {
2020
class DebugExtension
2121
: public transform::TransformDialectExtension<DebugExtension> {
2222
public:
23+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DebugExtension)
24+
2325
void init() {
2426
registerTransformOps<
2527
#define GET_OP_LIST

mlir/lib/Dialect/Transform/IRDLExtension/IRDLExtension.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ namespace {
1818
class IRDLExtension
1919
: public transform::TransformDialectExtension<IRDLExtension> {
2020
public:
21+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(IRDLExtension)
22+
2123
void init() {
2224
registerTransformOps<
2325
#define GET_OP_LIST

mlir/lib/Dialect/Transform/LoopExtension/LoopExtension.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ namespace {
2020
class LoopExtension
2121
: public transform::TransformDialectExtension<LoopExtension> {
2222
public:
23+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LoopExtension)
24+
2325
void init() {
2426
registerTransformOps<
2527
#define GET_OP_LIST

mlir/lib/Dialect/Transform/PDLExtension/PDLExtension.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ namespace {
3838
/// with Transform dialect operations.
3939
class PDLExtension : public transform::TransformDialectExtension<PDLExtension> {
4040
public:
41+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PDLExtension)
42+
4143
void init() {
4244
registerTransformOps<
4345
#define GET_OP_LIST

mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,8 @@ class VectorTransformDialectExtension
212212
: public transform::TransformDialectExtension<
213213
VectorTransformDialectExtension> {
214214
public:
215+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(VectorTransformDialectExtension)
216+
215217
VectorTransformDialectExtension() {
216218
declareGeneratedDialect<vector::VectorDialect>();
217219
declareGeneratedDialect<LLVM::LLVMDialect>();

mlir/lib/IR/Dialect.cpp

Lines changed: 48 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,20 @@
1111
#include "mlir/IR/Diagnostics.h"
1212
#include "mlir/IR/DialectImplementation.h"
1313
#include "mlir/IR/DialectInterface.h"
14+
#include "mlir/IR/DialectRegistry.h"
1415
#include "mlir/IR/ExtensibleDialect.h"
1516
#include "mlir/IR/MLIRContext.h"
1617
#include "mlir/IR/Operation.h"
18+
#include "mlir/Support/TypeID.h"
1719
#include "llvm/ADT/MapVector.h"
20+
#include "llvm/ADT/SetOperations.h"
21+
#include "llvm/ADT/SmallVector.h"
22+
#include "llvm/ADT/SmallVectorExtras.h"
1823
#include "llvm/ADT/Twine.h"
1924
#include "llvm/Support/Debug.h"
2025
#include "llvm/Support/ManagedStatic.h"
2126
#include "llvm/Support/Regex.h"
27+
#include <memory>
2228

2329
#define DEBUG_TYPE "dialect"
2430

@@ -173,6 +179,40 @@ bool dialect_extension_detail::hasPromisedInterface(Dialect &dialect,
173179
// DialectRegistry
174180
//===----------------------------------------------------------------------===//
175181

182+
namespace {
183+
template <typename Fn>
184+
void applyExtensionsFn(
185+
Fn &&applyExtension,
186+
const llvm::MapVector<TypeID, std::unique_ptr<DialectExtensionBase>>
187+
&extensions) {
188+
// Note: Additional extensions may be added while applying an extension.
189+
// The iterators will be invalidated if extensions are added so we'll keep
190+
// a copy of the extensions for ourselves.
191+
192+
const auto extractExtension =
193+
[](const auto &entry) -> DialectExtensionBase * {
194+
return entry.second.get();
195+
};
196+
197+
auto startIt = extensions.begin(), endIt = extensions.end();
198+
size_t count = 0;
199+
while (startIt != endIt) {
200+
count += endIt - startIt;
201+
202+
// Grab the subset of extensions we'll apply in this iteration.
203+
const auto subset =
204+
llvm::map_to_vector(llvm::make_range(startIt, endIt), extractExtension);
205+
206+
for (const auto *ext : subset)
207+
applyExtension(*ext);
208+
209+
// Book-keep for the next iteration.
210+
startIt = extensions.begin() + count;
211+
endIt = extensions.end();
212+
}
213+
}
214+
} // namespace
215+
176216
DialectRegistry::DialectRegistry() { insert<BuiltinDialect>(); }
177217

178218
DialectAllocatorFunctionRef
@@ -258,9 +298,7 @@ void DialectRegistry::applyExtensions(Dialect *dialect) const {
258298
extension.apply(ctx, requiredDialects);
259299
};
260300

261-
// Note: Additional extensions may be added while applying an extension.
262-
for (int i = 0; i < static_cast<int>(extensions.size()); ++i)
263-
applyExtension(*extensions[i]);
301+
applyExtensionsFn(applyExtension, extensions);
264302
}
265303

266304
void DialectRegistry::applyExtensions(MLIRContext *ctx) const {
@@ -285,15 +323,17 @@ void DialectRegistry::applyExtensions(MLIRContext *ctx) const {
285323
extension.apply(ctx, requiredDialects);
286324
};
287325

288-
// Note: Additional extensions may be added while applying an extension.
289-
for (int i = 0; i < static_cast<int>(extensions.size()); ++i)
290-
applyExtension(*extensions[i]);
326+
applyExtensionsFn(applyExtension, extensions);
291327
}
292328

293329
bool DialectRegistry::isSubsetOf(const DialectRegistry &rhs) const {
294-
// Treat any extensions conservatively.
295-
if (!extensions.empty())
330+
// Check that all extension keys are present in 'rhs'.
331+
const auto hasExtension = [&](const auto &key) {
332+
return rhs.extensions.contains(key);
333+
};
334+
if (!llvm::all_of(make_first_range(extensions), hasExtension))
296335
return false;
336+
297337
// Check that the current dialects fully overlap with the dialects in 'rhs'.
298338
return llvm::all_of(
299339
registry, [&](const auto &it) { return rhs.registry.count(it.first); });

mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -874,6 +874,8 @@ class TestTransformDialectExtension
874874
: public transform::TransformDialectExtension<
875875
TestTransformDialectExtension> {
876876
public:
877+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTransformDialectExtension)
878+
877879
using Base::Base;
878880

879881
void init() {

mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -382,6 +382,9 @@ class TestTilingInterfaceDialectExtension
382382
: public transform::TransformDialectExtension<
383383
TestTilingInterfaceDialectExtension> {
384384
public:
385+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
386+
TestTilingInterfaceDialectExtension)
387+
385388
using Base::Base;
386389

387390
void init() {

mlir/unittests/Dialect/Transform/BuildOnlyExtensionTest.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ using namespace mlir::transform;
1818
namespace {
1919
class Extension : public TransformDialectExtension<Extension> {
2020
public:
21+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(Extension)
22+
2123
using Base::Base;
2224
void init() { declareGeneratedDialect<func::FuncDialect>(); }
2325
};

0 commit comments

Comments
 (0)