[mlir] Optimize ThreadLocalCache by removing atomic bottleneck #93270

Merged: 1 commit into llvm:main on May 24, 2024

Conversation


@Mogball Mogball commented May 24, 2024

The ThreadLocalCache implementation is used by the MLIRContext (among other things) to try to manage thread contention in the StorageUniquers. There is a bunch of fancy shared pointer/weak pointer setup that keeps everything alive across threads at the right time, but a huge bottleneck is the `weak_ptr::lock` call inside the `::get` method.

This is because the `lock` method has to hit the atomic refcount several times, and this bottlenecks performance across many threads. However, all it is doing is checking whether the storage is initialized. We know the weak pointer cannot be expired, because the thread-local cache object we're calling into owns the memory and must still be alive for the method call to be valid. Thus, we can store an extra `ValueT *` inside the thread-local cache for speedy retrieval when the cache is already initialized for the thread, which is the common case.

This also tightens the critical section in the same method by scoping the mutex to just the mutation of `perInstanceState`.

Before:

<img width="560" alt="image" src="https://github.com/llvm/llvm-project/assets/15016832/f4ea3f32-6649-4c10-88c4-b7522031e8c9">

After:

<img width="344" alt="image" src="https://github.com/llvm/llvm-project/assets/15016832/1216db25-3dc1-4b0f-be89-caeff622dd35">
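The fast-path idea can be illustrated with a minimal, self-contained sketch (this is not the MLIR code; `Owner` and `getCached` are hypothetical stand-ins): the per-thread map pairs the `weak_ptr` (kept for lifetime tracking and cleanup) with a raw pointer that serves as a cheap "already initialized" flag on the hot path.

```cpp
#include <cassert>
#include <map>
#include <memory>
#include <vector>

// Hypothetical owner of the per-thread instances, standing in for
// PerInstanceState in the real code.
struct Owner {
  std::vector<std::unique_ptr<int>> instances;
};

int &getCached(std::map<Owner *, std::pair<std::weak_ptr<int>, int *>> &cache,
               std::shared_ptr<Owner> &owner) {
  auto &entry = cache[owner.get()];
  // Fast path: a plain pointer load, no atomic refcount traffic at all.
  if (int *value = entry.second)
    return *value;
  // Slow path: create the instance and publish the raw pointer.
  owner->instances.push_back(std::make_unique<int>(42));
  entry.second = owner->instances.back().get();
  // Aliasing constructor: shares ownership with `owner` but points at the
  // instance, so expiry of the weak_ptr tracks the owner's lifetime.
  entry.first = std::shared_ptr<int>(owner, entry.second);
  return *entry.second;
}
```

On every call after the first, only `entry.second` is read; `weak_ptr::lock` (and its atomic refcount round trips) is never touched on the hot path.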

@Mogball Mogball requested review from jpienaar and joker-eph May 24, 2024 03:36
@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels May 24, 2024

llvmbot commented May 24, 2024

@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir

Author: Jeff Niu (Mogball)


Full diff: https://github.com/llvm/llvm-project/pull/93270.diff

1 file affected:

  • (modified) mlir/include/mlir/Support/ThreadLocalCache.h (+17-11)
diff --git a/mlir/include/mlir/Support/ThreadLocalCache.h b/mlir/include/mlir/Support/ThreadLocalCache.h
index 1be94ca14bcfa..d19257bf6e25e 100644
--- a/mlir/include/mlir/Support/ThreadLocalCache.h
+++ b/mlir/include/mlir/Support/ThreadLocalCache.h
@@ -58,11 +58,12 @@ class ThreadLocalCache {
   /// ValueT. We use a weak reference here so that the object can be destroyed
   /// without needing to lock access to the cache itself.
   struct CacheType
-      : public llvm::SmallDenseMap<PerInstanceState *, std::weak_ptr<ValueT>> {
+      : public llvm::SmallDenseMap<PerInstanceState *,
+                                   std::pair<std::weak_ptr<ValueT>, ValueT *>> {
     ~CacheType() {
       // Remove the values of this cache that haven't already expired.
       for (auto &it : *this)
-        if (std::shared_ptr<ValueT> value = it.second.lock())
+        if (std::shared_ptr<ValueT> value = it.second.first.lock())
           it.first->remove(value.get());
     }
 
@@ -71,7 +72,7 @@ class ThreadLocalCache {
     void clearExpiredEntries() {
       for (auto it = this->begin(), e = this->end(); it != e;) {
         auto curIt = it++;
-        if (curIt->second.expired())
+        if (curIt->second.first.expired())
           this->erase(curIt);
       }
     }
@@ -88,22 +89,27 @@ class ThreadLocalCache {
   ValueT &get() {
     // Check for an already existing instance for this thread.
     CacheType &staticCache = getStaticCache();
-    std::weak_ptr<ValueT> &threadInstance = staticCache[perInstanceState.get()];
-    if (std::shared_ptr<ValueT> value = threadInstance.lock())
+    std::pair<std::weak_ptr<ValueT>, ValueT *> &threadInstance =
+        staticCache[perInstanceState.get()];
+    if (ValueT *value = threadInstance.second)
       return *value;
 
     // Otherwise, create a new instance for this thread.
-    llvm::sys::SmartScopedLock<true> threadInstanceLock(
-        perInstanceState->instanceMutex);
-    perInstanceState->instances.push_back(std::make_unique<ValueT>());
-    ValueT *instance = perInstanceState->instances.back().get();
-    threadInstance = std::shared_ptr<ValueT>(perInstanceState, instance);
+    {
+      llvm::sys::SmartScopedLock<true> threadInstanceLock(
+          perInstanceState->instanceMutex);
+      threadInstance.second =
+          perInstanceState->instances.emplace_back(std::make_unique<ValueT>())
+              .get();
+    }
+    threadInstance.first =
+        std::shared_ptr<ValueT>(perInstanceState, threadInstance.second);
 
     // Before returning the new instance, take the chance to clear out any used
     // entries in the static map. The cache is only cleared within the same
     // thread to remove the need to lock the cache itself.
     staticCache.clearExpiredEntries();
-    return *instance;
+    return *threadInstance.second;
   }
   ValueT &operator*() { return get(); }
   ValueT *operator->() { return &get(); }
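The diff relies on `std::shared_ptr`'s aliasing constructor (`std::shared_ptr<ValueT>(perInstanceState, instance)`), which can be surprising on first read. A small illustrative sketch (the `State` type here is hypothetical): the returned pointer refers to a member, but the control block and refcount belong to the owning object, so the member stays alive as long as any alias does.

```cpp
#include <cassert>
#include <memory>

struct State {
  int value = 7;
};

// Aliasing constructor: the result shares ownership with `state` but
// points at `state->value` instead of the State object itself.
std::shared_ptr<int> aliasMember(const std::shared_ptr<State> &state) {
  return std::shared_ptr<int>(state, &state->value);
}
```

This is why the cache's `weak_ptr<ValueT>` expires exactly when the owning `PerInstanceState` goes away, even though it points at an individual value.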


@jpienaar jpienaar left a comment

Nice :)

@Mogball Mogball merged commit 1b803fe into llvm:main May 24, 2024
8 of 10 checks passed
jpienaar added a commit that referenced this pull request May 24, 2024
joker-eph pushed a commit that referenced this pull request May 24, 2024
…k" (#93306)

Reverts #93270

This was found to have a race and the forward fix was reverted; reverting this until it can be forward fixed.