Skip to content

Commit fab234a

Browse files
authored
Revert "[mlir] Optimize ThreadLocalCache by removing atomic bottleneck" (#93306)
Reverts #93270. The change was found to introduce a race, and the forward fix for it was itself reverted, so this reverts the original change until a correct forward fix can be landed.
1 parent b008a2d commit fab234a

File tree

1 file changed

+11
-17
lines changed

1 file changed

+11
-17
lines changed

mlir/include/mlir/Support/ThreadLocalCache.h

Lines changed: 11 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -58,12 +58,11 @@ class ThreadLocalCache {
5858
/// ValueT. We use a weak reference here so that the object can be destroyed
5959
/// without needing to lock access to the cache itself.
6060
struct CacheType
61-
: public llvm::SmallDenseMap<PerInstanceState *,
62-
std::pair<std::weak_ptr<ValueT>, ValueT *>> {
61+
: public llvm::SmallDenseMap<PerInstanceState *, std::weak_ptr<ValueT>> {
6362
~CacheType() {
6463
// Remove the values of this cache that haven't already expired.
6564
for (auto &it : *this)
66-
if (std::shared_ptr<ValueT> value = it.second.first.lock())
65+
if (std::shared_ptr<ValueT> value = it.second.lock())
6766
it.first->remove(value.get());
6867
}
6968

@@ -72,7 +71,7 @@ class ThreadLocalCache {
7271
void clearExpiredEntries() {
7372
for (auto it = this->begin(), e = this->end(); it != e;) {
7473
auto curIt = it++;
75-
if (curIt->second.first.expired())
74+
if (curIt->second.expired())
7675
this->erase(curIt);
7776
}
7877
}
@@ -89,27 +88,22 @@ class ThreadLocalCache {
8988
ValueT &get() {
9089
// Check for an already existing instance for this thread.
9190
CacheType &staticCache = getStaticCache();
92-
std::pair<std::weak_ptr<ValueT>, ValueT *> &threadInstance =
93-
staticCache[perInstanceState.get()];
94-
if (ValueT *value = threadInstance.second)
91+
std::weak_ptr<ValueT> &threadInstance = staticCache[perInstanceState.get()];
92+
if (std::shared_ptr<ValueT> value = threadInstance.lock())
9593
return *value;
9694

9795
// Otherwise, create a new instance for this thread.
98-
{
99-
llvm::sys::SmartScopedLock<true> threadInstanceLock(
100-
perInstanceState->instanceMutex);
101-
threadInstance.second =
102-
perInstanceState->instances.emplace_back(std::make_unique<ValueT>())
103-
.get();
104-
}
105-
threadInstance.first =
106-
std::shared_ptr<ValueT>(perInstanceState, threadInstance.second);
96+
llvm::sys::SmartScopedLock<true> threadInstanceLock(
97+
perInstanceState->instanceMutex);
98+
perInstanceState->instances.push_back(std::make_unique<ValueT>());
99+
ValueT *instance = perInstanceState->instances.back().get();
100+
threadInstance = std::shared_ptr<ValueT>(perInstanceState, instance);
107101

108102
// Before returning the new instance, take the chance to clear out any expired
109103
// entries in the static map. The cache is only cleared within the same
110104
// thread to remove the need to lock the cache itself.
111105
staticCache.clearExpiredEntries();
112-
return *threadInstance.second;
106+
return *instance;
113107
}
114108
ValueT &operator*() { return get(); }
115109
ValueT *operator->() { return &get(); }

0 commit comments

Comments
 (0)