diff --git a/mlir/examples/transform/Ch2/lib/MyExtension.cpp b/mlir/examples/transform/Ch2/lib/MyExtension.cpp index 68d538a098018..b4b27e97d266e 100644 --- a/mlir/examples/transform/Ch2/lib/MyExtension.cpp +++ b/mlir/examples/transform/Ch2/lib/MyExtension.cpp @@ -29,6 +29,9 @@ class MyExtension : public ::mlir::transform::TransformDialectExtension { public: + // The TypeID of this extension. + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MyExtension) + // The extension must derive the base constructor. using Base::Base; diff --git a/mlir/examples/transform/Ch3/lib/MyExtension.cpp b/mlir/examples/transform/Ch3/lib/MyExtension.cpp index f7a99423a52ee..4b2123fa71d31 100644 --- a/mlir/examples/transform/Ch3/lib/MyExtension.cpp +++ b/mlir/examples/transform/Ch3/lib/MyExtension.cpp @@ -35,6 +35,9 @@ class MyExtension : public ::mlir::transform::TransformDialectExtension { public: + // The TypeID of this extension. + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MyExtension) + // The extension must derive the base constructor. using Base::Base; diff --git a/mlir/examples/transform/Ch4/lib/MyExtension.cpp b/mlir/examples/transform/Ch4/lib/MyExtension.cpp index 38c8ca1125a24..fa0ffc9dc2e8a 100644 --- a/mlir/examples/transform/Ch4/lib/MyExtension.cpp +++ b/mlir/examples/transform/Ch4/lib/MyExtension.cpp @@ -31,6 +31,9 @@ class MyExtension : public ::mlir::transform::TransformDialectExtension { public: + // The TypeID of this extension. + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MyExtension) + // The extension must derive the base constructor. using Base::Base; diff --git a/mlir/include/mlir/IR/DialectRegistry.h b/mlir/include/mlir/IR/DialectRegistry.h index 8e394988119da..2c1f6964998e8 100644 --- a/mlir/include/mlir/IR/DialectRegistry.h +++ b/mlir/include/mlir/IR/DialectRegistry.h @@ -14,9 +14,9 @@ #define MLIR_IR_DIALECTREGISTRY_H #include "mlir/IR/MLIRContext.h" +#include "mlir/Support/TypeID.h" #include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/MapVector.h" #include #include @@ -187,7 +187,8 @@ class DialectRegistry { nameAndRegistrationIt.second.second); // Merge the extensions. for (const auto &extension : extensions) - destination.extensions.push_back(extension->clone()); + destination.extensions.try_emplace(extension.first, + extension.second->clone()); } /// Return the names of dialects known to this registry. @@ -206,39 +207,37 @@ class DialectRegistry { void applyExtensions(MLIRContext *ctx) const; /// Add the given extension to the registry. - void addExtension(std::unique_ptr extension) { - extensions.push_back(std::move(extension)); + bool addExtension(TypeID extensionID, + std::unique_ptr extension) { + return extensions.try_emplace(extensionID, std::move(extension)).second; } /// Add the given extensions to the registry. template void addExtensions() { - (addExtension(std::make_unique()), ...); + (addExtension(TypeID::get(), std::make_unique()), + ...); } /// Add an extension function that requires the given dialects. /// Note: This bare functor overload is provided in addition to the /// std::function variant to enable dialect type deduction, e.g.: - /// registry.addExtension(+[](MLIRContext *ctx, MyDialect *dialect) { ... }) + /// registry.addExtension(+[](MLIRContext *ctx, MyDialect *dialect) { + /// ... }) /// /// is equivalent to: /// registry.addExtension( /// [](MLIRContext *ctx, MyDialect *dialect){ ... } /// ) template - void addExtension(void (*extensionFn)(MLIRContext *, DialectsT *...)) { - addExtension( - std::function(extensionFn)); - } - template - void - addExtension(std::function extensionFn) { - using ExtensionFnT = std::function; + bool addExtension(void (*extensionFn)(MLIRContext *, DialectsT *...)) { + using ExtensionFnT = void (*)(MLIRContext *, DialectsT *...); struct Extension : public DialectExtension { Extension(const Extension &) = default; Extension(ExtensionFnT extensionFn) - : extensionFn(std::move(extensionFn)) {} + : DialectExtension(), + extensionFn(extensionFn) {} ~Extension() override = default; void apply(MLIRContext *context, DialectsT *...dialects) const final { @@ -246,7 +245,9 @@ class DialectRegistry { } ExtensionFnT extensionFn; }; - addExtension(std::make_unique(std::move(extensionFn))); + return addExtension(TypeID::getFromOpaquePointer( + reinterpret_cast(extensionFn)), + std::make_unique(extensionFn)); } /// Returns true if the current registry is a subset of 'rhs', i.e. if 'rhs' @@ -255,7 +256,7 @@ class DialectRegistry { private: MapTy registry; - std::vector> extensions; + llvm::MapVector> extensions; }; } // namespace mlir diff --git a/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp b/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp index 6135117348a5b..b2407a258c271 100644 --- a/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp +++ b/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp @@ -35,6 +35,8 @@ namespace { /// starting a pass pipeline that involves dialect conversion to LLVM. class LoadDependentDialectExtension : public DialectExtensionBase { public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LoadDependentDialectExtension) + LoadDependentDialectExtension() : DialectExtensionBase(/*dialectNames=*/{}) {} void apply(MLIRContext *context, diff --git a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp index 6457655cfe416..eb52297940722 100644 --- a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp +++ b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp @@ -157,6 +157,8 @@ class AffineTransformDialectExtension : public transform::TransformDialectExtension< AffineTransformDialectExtension> { public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AffineTransformDialectExtension) + using Base::Base; void init() { diff --git a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp index e10c7bd914e35..a1d7bb995fc73 100644 --- a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp +++ b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp @@ -150,6 +150,9 @@ class BufferizationTransformDialectExtension : public transform::TransformDialectExtension< BufferizationTransformDialectExtension> { public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( + BufferizationTransformDialectExtension) + using Base::Base; void init() { diff --git a/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp b/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp index b632b25d0cc67..2728936bf33fd 100644 --- a/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp +++ b/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp @@ -236,6 +236,8 @@ class FuncTransformDialectExtension : public transform::TransformDialectExtension< FuncTransformDialectExtension> { public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(FuncTransformDialectExtension) + using Base::Base; void init() { diff --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp index 3661c5dea4525..1528da914d546 100644 --- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp +++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp @@ -924,6 +924,8 @@ class GPUTransformDialectExtension : public transform::TransformDialectExtension< GPUTransformDialectExtension> { public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(GPUTransformDialectExtension) + GPUTransformDialectExtension() { declareGeneratedDialect(); declareGeneratedDialect(); diff --git a/mlir/lib/Dialect/Linalg/TransformOps/DialectExtension.cpp b/mlir/lib/Dialect/Linalg/TransformOps/DialectExtension.cpp index f4244ca962232..4591802ce74ac 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/DialectExtension.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/DialectExtension.cpp @@ -30,6 +30,8 @@ class LinalgTransformDialectExtension : public transform::TransformDialectExtension< LinalgTransformDialectExtension> { public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LinalgTransformDialectExtension) + using Base::Base; void init() { diff --git a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp index 8469e84c668cb..89640ac323b68 100644 --- a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp +++ b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp @@ -309,6 +309,8 @@ class MemRefTransformDialectExtension : public transform::TransformDialectExtension< MemRefTransformDialectExtension> { public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MemRefTransformDialectExtension) + using Base::Base; void init() { diff --git a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp index 733fde78e4259..0c2275bbc4b22 100644 --- a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp +++ b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp @@ -1135,6 +1135,8 @@ class NVGPUTransformDialectExtension : public transform::TransformDialectExtension< NVGPUTransformDialectExtension> { public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(NVGPUTransformDialectExtension) + NVGPUTransformDialectExtension() { declareGeneratedDialect(); declareGeneratedDialect(); diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp index c4a55c302d0a3..551411bb14765 100644 --- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp +++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp @@ -613,6 +613,8 @@ class SCFTransformDialectExtension : public transform::TransformDialectExtension< SCFTransformDialectExtension> { public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SCFTransformDialectExtension) + using Base::Base; void init() { diff --git a/mlir/lib/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.cpp b/mlir/lib/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.cpp index ca19259ebffa6..bdec43825ddc2 100644 --- a/mlir/lib/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.cpp +++ b/mlir/lib/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.cpp @@ -38,6 +38,9 @@ class SparseTensorTransformDialectExtension : public transform::TransformDialectExtension< SparseTensorTransformDialectExtension> { public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( + SparseTensorTransformDialectExtension) + SparseTensorTransformDialectExtension() { declareGeneratedDialect(); registerTransformOps< diff --git a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp index 33016f84056e9..f911619d71227 100644 --- a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp +++ b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp @@ -236,6 +236,8 @@ class TensorTransformDialectExtension : public transform::TransformDialectExtension< TensorTransformDialectExtension> { public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TensorTransformDialectExtension) + using Base::Base; void init() { diff --git a/mlir/lib/Dialect/Transform/DebugExtension/DebugExtension.cpp b/mlir/lib/Dialect/Transform/DebugExtension/DebugExtension.cpp index e369daddb00cb..d69535169f956 100644 --- a/mlir/lib/Dialect/Transform/DebugExtension/DebugExtension.cpp +++ b/mlir/lib/Dialect/Transform/DebugExtension/DebugExtension.cpp @@ -20,6 +20,8 @@ namespace { class DebugExtension : public transform::TransformDialectExtension { public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DebugExtension) + void init() { registerTransformOps< #define GET_OP_LIST diff --git a/mlir/lib/Dialect/Transform/IRDLExtension/IRDLExtension.cpp b/mlir/lib/Dialect/Transform/IRDLExtension/IRDLExtension.cpp index 94004365b8a1a..9dc95490b14bb 100644 --- a/mlir/lib/Dialect/Transform/IRDLExtension/IRDLExtension.cpp +++ b/mlir/lib/Dialect/Transform/IRDLExtension/IRDLExtension.cpp @@ -18,6 +18,8 @@ namespace { class IRDLExtension : public transform::TransformDialectExtension { public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(IRDLExtension) + void init() { registerTransformOps< #define GET_OP_LIST diff --git a/mlir/lib/Dialect/Transform/LoopExtension/LoopExtension.cpp b/mlir/lib/Dialect/Transform/LoopExtension/LoopExtension.cpp index b33288fd7b991..0a099b5bc75ab 100644 --- a/mlir/lib/Dialect/Transform/LoopExtension/LoopExtension.cpp +++ b/mlir/lib/Dialect/Transform/LoopExtension/LoopExtension.cpp @@ -20,6 +20,8 @@ namespace { class LoopExtension : public transform::TransformDialectExtension { public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LoopExtension) + void init() { registerTransformOps< #define GET_OP_LIST diff --git a/mlir/lib/Dialect/Transform/PDLExtension/PDLExtension.cpp b/mlir/lib/Dialect/Transform/PDLExtension/PDLExtension.cpp index 2c770abd56d52..27c5dc332a428 100644 --- a/mlir/lib/Dialect/Transform/PDLExtension/PDLExtension.cpp +++ b/mlir/lib/Dialect/Transform/PDLExtension/PDLExtension.cpp @@ -38,6 +38,8 @@ namespace { /// with Transform dialect operations. class PDLExtension : public transform::TransformDialectExtension { public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PDLExtension) + void init() { registerTransformOps< #define GET_OP_LIST diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp index 2e9aa88011825..bc423a3781bf0 100644 --- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp +++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp @@ -212,6 +212,8 @@ class VectorTransformDialectExtension : public transform::TransformDialectExtension< VectorTransformDialectExtension> { public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(VectorTransformDialectExtension) + VectorTransformDialectExtension() { declareGeneratedDialect(); declareGeneratedDialect(); diff --git a/mlir/lib/IR/Dialect.cpp b/mlir/lib/IR/Dialect.cpp index 965386681f270..cc80677a4078f 100644 --- a/mlir/lib/IR/Dialect.cpp +++ b/mlir/lib/IR/Dialect.cpp @@ -11,14 +11,20 @@ #include "mlir/IR/Diagnostics.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/DialectInterface.h" +#include "mlir/IR/DialectRegistry.h" #include "mlir/IR/ExtensibleDialect.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Operation.h" +#include "mlir/Support/TypeID.h" #include "llvm/ADT/MapVector.h" +#include "llvm/ADT/SetOperations.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/SmallVectorExtras.h" #include "llvm/ADT/Twine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/ManagedStatic.h" #include "llvm/Support/Regex.h" +#include #define DEBUG_TYPE "dialect" @@ -173,6 +179,40 @@ bool dialect_extension_detail::hasPromisedInterface(Dialect &dialect, // DialectRegistry //===----------------------------------------------------------------------===// +namespace { +template +void applyExtensionsFn( + Fn &&applyExtension, + const llvm::MapVector> + &extensions) { + // Note: Additional extensions may be added while applying an extension. + // The iterators will be invalidated if extensions are added so we'll keep + // a copy of the extensions for ourselves. + + const auto extractExtension = + [](const auto &entry) -> DialectExtensionBase * { + return entry.second.get(); + }; + + auto startIt = extensions.begin(), endIt = extensions.end(); + size_t count = 0; + while (startIt != endIt) { + count += endIt - startIt; + + // Grab the subset of extensions we'll apply in this iteration. + const auto subset = + llvm::map_to_vector(llvm::make_range(startIt, endIt), extractExtension); + + for (const auto *ext : subset) + applyExtension(*ext); + + // Book-keep for the next iteration. + startIt = extensions.begin() + count; + endIt = extensions.end(); + } +} +} // namespace + DialectRegistry::DialectRegistry() { insert(); } DialectAllocatorFunctionRef @@ -258,9 +298,7 @@ void DialectRegistry::applyExtensions(Dialect *dialect) const { extension.apply(ctx, requiredDialects); }; - // Note: Additional extensions may be added while applying an extension. - for (int i = 0; i < static_cast(extensions.size()); ++i) - applyExtension(*extensions[i]); + applyExtensionsFn(applyExtension, extensions); } void DialectRegistry::applyExtensions(MLIRContext *ctx) const { @@ -285,15 +323,17 @@ void DialectRegistry::applyExtensions(MLIRContext *ctx) const { extension.apply(ctx, requiredDialects); }; - // Note: Additional extensions may be added while applying an extension. - for (int i = 0; i < static_cast(extensions.size()); ++i) - applyExtension(*extensions[i]); + applyExtensionsFn(applyExtension, extensions); } bool DialectRegistry::isSubsetOf(const DialectRegistry &rhs) const { - // Treat any extensions conservatively. - if (!extensions.empty()) + // Check that all extension keys are present in 'rhs'. + const auto hasExtension = [&](const auto &key) { + return rhs.extensions.contains(key); + }; + if (!llvm::all_of(make_first_range(extensions), hasExtension)) return false; + // Check that the current dialects fully overlap with the dialects in 'rhs'. return llvm::all_of( registry, [&](const auto &it) { return rhs.registry.count(it.first); }); diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp index b8a4b9470d736..c023aad4a3ee7 100644 --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp @@ -874,6 +874,8 @@ class TestTransformDialectExtension : public transform::TransformDialectExtension< TestTransformDialectExtension> { public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTransformDialectExtension) + using Base::Base; void init() { diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp index a99441cd7147b..7aa7b58433f36 100644 --- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp +++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp @@ -382,6 +382,9 @@ class TestTilingInterfaceDialectExtension : public transform::TransformDialectExtension< TestTilingInterfaceDialectExtension> { public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( + TestTilingInterfaceDialectExtension) + using Base::Base; void init() { diff --git a/mlir/unittests/Dialect/Transform/BuildOnlyExtensionTest.cpp b/mlir/unittests/Dialect/Transform/BuildOnlyExtensionTest.cpp index 40fb752ffd6eb..d2a4999594a9e 100644 --- a/mlir/unittests/Dialect/Transform/BuildOnlyExtensionTest.cpp +++ b/mlir/unittests/Dialect/Transform/BuildOnlyExtensionTest.cpp @@ -18,6 +18,8 @@ using namespace mlir::transform; namespace { class Extension : public TransformDialectExtension { public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(Extension) + using Base::Base; void init() { declareGeneratedDialect(); } }; diff --git a/mlir/unittests/IR/DialectTest.cpp b/mlir/unittests/IR/DialectTest.cpp index e99d46e6d2643..7dd6a01c3389b 100644 --- a/mlir/unittests/IR/DialectTest.cpp +++ b/mlir/unittests/IR/DialectTest.cpp @@ -8,6 +8,7 @@ #include "mlir/IR/Dialect.h" #include "mlir/IR/DialectInterface.h" +#include "mlir/Support/TypeID.h" #include "gtest/gtest.h" using namespace mlir; @@ -140,15 +141,22 @@ namespace { /// A dummy extension that increases a counter when being applied and /// recursively adds additional extensions. struct DummyExtension : DialectExtension { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DummyExtension) + DummyExtension(int *counter, int numRecursive) : DialectExtension(), counter(counter), numRecursive(numRecursive) {} void apply(MLIRContext *ctx, TestDialect *dialect) const final { ++(*counter); DialectRegistry nestedRegistry; - for (int i = 0; i < numRecursive; ++i) - nestedRegistry.addExtension( - std::make_unique(counter, /*numRecursive=*/0)); + for (int i = 0; i < numRecursive; ++i) { + // Create unique TypeIDs for these recursive extensions so they don't get + // de-duplicated. + auto extension = + std::make_unique(counter, /*numRecursive=*/0); + auto typeID = TypeID::getFromOpaquePointer(extension.get()); + nestedRegistry.addExtension(typeID, std::move(extension)); + } // Adding additional extensions may trigger a reallocation of the // `extensions` vector in the dialect registry. ctx->appendDialectRegistry(nestedRegistry); @@ -166,20 +174,56 @@ TEST(Dialect, NestedDialectExtension) { // Add an extension that adds 100 more extensions. int counter1 = 0; - registry.addExtension(std::make_unique(&counter1, 100)); + registry.addExtension(TypeID::get(), + std::make_unique(&counter1, 100)); // Add one more extension. This should not crash. int counter2 = 0; - registry.addExtension(std::make_unique(&counter2, 0)); + registry.addExtension(TypeID::getFromOpaquePointer(&counter2), + std::make_unique(&counter2, 0)); // Load dialect and apply extensions. MLIRContext context(registry); Dialect *testDialect = context.getOrLoadDialect(); ASSERT_TRUE(testDialect != nullptr); - // Extensions may be applied multiple times. Make sure that each expected + // Extensions are de-duplicated by typeID. Make sure that each expected // extension was applied at least once. EXPECT_GE(counter1, 101); EXPECT_GE(counter2, 1); } +TEST(Dialect, SubsetWithExtensions) { + DialectRegistry registry1, registry2; + registry1.insert(); + registry2.insert(); + + // Validate that the registries are equivalent. + ASSERT_TRUE(registry1.isSubsetOf(registry2)); + ASSERT_TRUE(registry2.isSubsetOf(registry1)); + + // Add extensions to registry2. + int counter = 0; + registry2.addExtension(TypeID::get(), + std::make_unique(&counter, 0)); + + // Expect that (1) is a subset of (2) but not the other way around. + ASSERT_TRUE(registry1.isSubsetOf(registry2)); + ASSERT_FALSE(registry2.isSubsetOf(registry1)); + + // Add extensions to registry1. + registry1.addExtension(TypeID::get(), + std::make_unique(&counter, 0)); + + // Expect that (1) and (2) are equivalent. + ASSERT_TRUE(registry1.isSubsetOf(registry2)); + ASSERT_TRUE(registry2.isSubsetOf(registry1)); + + // Load dialect and apply extensions. + MLIRContext context(registry1); + context.getOrLoadDialect(); + context.appendDialectRegistry(registry2); + // Expect that the extension as only invoked once. + ASSERT_EQ(counter, 1); +} + } // namespace