Skip to content

Commit 0cf68e8

Browse files
committed
add dialect registry extension comparison by key
1 parent 26669bd commit 0cf68e8

File tree

2 files changed

+43
-8
lines changed

2 files changed

+43
-8
lines changed

mlir/lib/IR/Dialect.cpp

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -291,9 +291,19 @@ void DialectRegistry::applyExtensions(MLIRContext *ctx) const {
291291
}
292292

293293
bool DialectRegistry::isSubsetOf(const DialectRegistry &rhs) const {
294-
// Treat any extensions conservatively.
295-
if (!extensions.empty())
294+
// Check that all extension keys are present in 'rhs'.
295+
llvm::DenseSet<llvm::StringRef> rhsExtensionKeys;
296+
{
297+
auto rhsKeys = llvm::map_range(rhs.extensions,
298+
[](const auto &item) { return item.first; });
299+
rhsExtensionKeys.insert(rhsKeys.begin(), rhsKeys.end());
300+
}
301+
302+
if (!llvm::all_of(extensions, [&rhsExtensionKeys](const auto &extension) {
303+
return rhsExtensionKeys.contains(extension.first);
304+
}))
296305
return false;
306+
297307
// Check that the current dialects fully overlap with the dialects in 'rhs'.
298308
return llvm::all_of(
299309
registry, [&](const auto &it) { return rhs.registry.count(it.first); });

mlir/unittests/IR/DialectTest.cpp

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,8 @@ TEST(Dialect, DelayedInterfaceRegistration) {
7777
// Delayed registration of an interface for TestDialect.
7878
registry.addExtension(
7979
"TEST_DIALECT_DELAYED", +[](MLIRContext *ctx, TestDialect *dialect) {
80-
dialect->addInterfaces<TestDialectInterface>();
81-
});
80+
dialect->addInterfaces<TestDialectInterface>();
81+
});
8282

8383
MLIRContext context(registry);
8484

@@ -116,8 +116,8 @@ TEST(Dialect, RepeatedDelayedRegistration) {
116116
registry.insert<TestDialect>();
117117
registry.addExtension(
118118
"TEST_DIALECT", +[](MLIRContext *ctx, TestDialect *dialect) {
119-
dialect->addInterfaces<TestDialectInterface>();
120-
});
119+
dialect->addInterfaces<TestDialectInterface>();
120+
});
121121
MLIRContext context(registry);
122122

123123
// Load the TestDialect and check that the interface got registered for it.
@@ -132,8 +132,8 @@ TEST(Dialect, RepeatedDelayedRegistration) {
132132
secondRegistry.insert<TestDialect>();
133133
secondRegistry.addExtension(
134134
"TEST_DIALECT", +[](MLIRContext *ctx, TestDialect *dialect) {
135-
dialect->addInterfaces<TestDialectInterface>();
136-
});
135+
dialect->addInterfaces<TestDialectInterface>();
136+
});
137137
context.appendDialectRegistry(secondRegistry);
138138
testDialectInterface = dyn_cast<TestDialectInterfaceBase>(testDialect);
139139
EXPECT_TRUE(testDialectInterface != nullptr);
@@ -193,4 +193,29 @@ TEST(Dialect, NestedDialectExtension) {
193193
EXPECT_GE(counter2, 1);
194194
}
195195

196+
TEST(Dialect, SubsetWithExtensions) {
197+
DialectRegistry registry1, registry2;
198+
registry1.insert<TestDialect>();
199+
registry2.insert<TestDialect>();
200+
201+
// Validate that the registries are equivalent.
202+
ASSERT_TRUE(registry1.isSubsetOf(registry2));
203+
ASSERT_TRUE(registry2.isSubsetOf(registry1));
204+
205+
// Add extensions to registry2.
206+
int counter;
207+
registry2.addExtension("EXT", std::make_unique<DummyExtension>(&counter, 0));
208+
209+
// Expect that (1) is a subset of (2) but not the other way around.
210+
ASSERT_TRUE(registry1.isSubsetOf(registry2));
211+
ASSERT_FALSE(registry2.isSubsetOf(registry1));
212+
213+
// Add extensions to registry1.
214+
registry1.addExtension("EXT", std::make_unique<DummyExtension>(&counter, 0));
215+
216+
// Expect that (1) and (2) are equivalent.
217+
ASSERT_TRUE(registry1.isSubsetOf(registry2));
218+
ASSERT_TRUE(registry2.isSubsetOf(registry1));
219+
}
220+
196221
} // namespace

0 commit comments

Comments
 (0)