1616
1717#include " mlir/Support/LLVM.h"
1818#include " llvm/ADT/DenseMap.h"
19+ #include " llvm/Support/ManagedStatic.h"
1920#include " llvm/Support/Mutex.h"
2021
2122namespace mlir {
@@ -24,80 +25,28 @@ namespace mlir {
2425// / cache has very large lock contention.
2526template <typename ValueT>
2627class ThreadLocalCache {
27- struct PerInstanceState ;
28-
29- // / The "observer" is owned by a thread-local cache instance. It is
30- // / constructed the first time a `ThreadLocalCache` instance is accessed by a
31- // / thread, unless `perInstanceState` happens to get re-allocated to the same
32- // / address as a previous one. This class is destructed the thread in which
33- // / the `thread_local` cache lives is destroyed.
34- // /
35- // / This class is called the "observer" because while values cached in
36- // / thread-local caches are owned by `PerInstanceState`, a reference is stored
37- // / via this class in the TLC. With a double pointer, it knows when the
38- // / referenced value has been destroyed.
39- struct Observer {
40- // / This is the double pointer, explicitly allocated because we need to keep
41- // / the address stable if the TLC map re-allocates. It is owned by the
42- // / observer and shared with the value owner.
43- std::shared_ptr<ValueT *> ptr = std::make_shared<ValueT *>(nullptr );
44- // / Because `Owner` living inside `PerInstanceState` contains a reference to
45- // / the double pointer, and livkewise this class contains a reference to the
46- // / value, we need to synchronize destruction of the TLC and the
47- // / `PerInstanceState` to avoid racing. This weak pointer is acquired during
48- // / TLC destruction if the `PerInstanceState` hasn't entered its destructor
49- // / yet, and prevents it from happening.
50- std::weak_ptr<PerInstanceState> keepalive;
51- };
52-
53- // / This struct owns the cache entries. It contains a reference back to the
54- // / reference inside the cache so that it can be written to null to indicate
55- // / that the cache entry is invalidated. It needs to do this because
56- // / `perInstanceState` could get re-allocated to the same pointer and we don't
57- // / remove entries from the TLC when it is deallocated. Thus, we have to reset
58- // / the TLC entries to a starting state in case the `ThreadLocalCache` lives
59- // / shorter than the threads.
60- struct Owner {
61- // / Save a pointer to the reference and write it to the newly created entry.
62- Owner (Observer &observer)
63- : value(std::make_unique<ValueT>()), ptrRef(observer.ptr) {
64- *observer.ptr = value.get ();
65- }
66- ~Owner () {
67- if (std::shared_ptr<ValueT *> ptr = ptrRef.lock ())
68- *ptr = nullptr ;
69- }
70-
71- Owner (Owner &&) = default ;
72- Owner &operator =(Owner &&) = default ;
73-
74- std::unique_ptr<ValueT> value;
75- std::weak_ptr<ValueT *> ptrRef;
76- };
77-
7828 // Keep a separate shared_ptr protected state that can be acquired atomically
7929 // instead of using shared_ptr's for each value. This avoids a problem
8030 // where the instance shared_ptr is locked() successfully, and then the
8131 // ThreadLocalCache gets destroyed before remove() can be called successfully.
8232 struct PerInstanceState {
83- // / Remove the given value entry. This is called when a thread local cache
84- // / is destructing but still contains references to values owned by the
85- // / `PerInstanceState`. Removal is required because it prevents writeback to
86- // / a pointer that was deallocated.
33+ // / Remove the given value entry. This is generally called when a thread
34+ // / local cache is destructing.
8735 void remove (ValueT *value) {
8836 // Erase the found value directly, because it is guaranteed to be in the
8937 // list.
9038 llvm::sys::SmartScopedLock<true > threadInstanceLock (instanceMutex);
91- auto it = llvm::find_if (instances, [&](Owner &instance) {
92- return instance.value .get () == value;
93- });
39+ auto it =
40+ llvm::find_if (instances, [&](std::unique_ptr<ValueT> &instance) {
41+ return instance.get () == value;
42+ });
9443 assert (it != instances.end () && " expected value to exist in cache" );
9544 instances.erase (it);
9645 }
9746
9847 // / Owning pointers to all of the values that have been constructed for this
9948 // / object in the static cache.
100- SmallVector<Owner , 1 > instances;
49+ SmallVector<std::unique_ptr<ValueT> , 1 > instances;
10150
10251 // / A mutex used when a new thread instance has been added to the cache for
10352 // / this object.
@@ -108,22 +57,22 @@ class ThreadLocalCache {
10857 // / instance of the non-static cache and a weak reference to an instance of
10958 // / ValueT. We use a weak reference here so that the object can be destroyed
11059 // / without needing to lock access to the cache itself.
111- struct CacheType : public llvm ::SmallDenseMap<PerInstanceState *, Observer> {
60+ struct CacheType
61+ : public llvm::SmallDenseMap<PerInstanceState *,
62+ std::pair<std::weak_ptr<ValueT>, ValueT *>> {
11263 ~CacheType () {
113- // Remove the values of this cache that haven't already expired. This is
114- // required because if we don't remove them, they will contain a reference
115- // back to the data here that is being destroyed.
116- for (auto &[instance, observer] : *this )
117- if (std::shared_ptr<PerInstanceState> state = observer.keepalive .lock ())
118- state->remove (*observer.ptr );
64+ // Remove the values of this cache that haven't already expired.
65+ for (auto &it : *this )
66+ if (std::shared_ptr<ValueT> value = it.second .first .lock ())
67+ it.first ->remove (value.get ());
11968 }
12069
12170 // / Clear out any unused entries within the map. This method is not
12271 // / thread-safe, and should only be called by the same thread as the cache.
12372 void clearExpiredEntries () {
12473 for (auto it = this ->begin (), e = this ->end (); it != e;) {
12574 auto curIt = it++;
126- if (!* curIt->second .ptr )
75+ if (curIt->second .first . expired () )
12776 this ->erase (curIt);
12877 }
12978 }
@@ -140,23 +89,27 @@ class ThreadLocalCache {
14089 ValueT &get () {
14190 // Check for an already existing instance for this thread.
14291 CacheType &staticCache = getStaticCache ();
143- Observer &threadInstance = staticCache[perInstanceState.get ()];
144- if (ValueT *value = *threadInstance.ptr )
92+ std::pair<std::weak_ptr<ValueT>, ValueT *> &threadInstance =
93+ staticCache[perInstanceState.get ()];
94+ if (ValueT *value = threadInstance.second )
14595 return *value;
14696
14797 // Otherwise, create a new instance for this thread.
14898 {
14999 llvm::sys::SmartScopedLock<true > threadInstanceLock (
150100 perInstanceState->instanceMutex );
151- perInstanceState->instances .emplace_back (threadInstance);
101+ threadInstance.second =
102+ perInstanceState->instances .emplace_back (std::make_unique<ValueT>())
103+ .get ();
152104 }
153- threadInstance.keepalive = perInstanceState;
105+ threadInstance.first =
106+ std::shared_ptr<ValueT>(perInstanceState, threadInstance.second );
154107
155108 // Before returning the new instance, take the chance to clear out any used
156109 // entries in the static map. The cache is only cleared within the same
157110 // thread to remove the need to lock the cache itself.
158111 staticCache.clearExpiredEntries ();
159- return ** threadInstance.ptr ;
112+ return *threadInstance.second ;
160113 }
161114 ValueT &operator *() { return get (); }
162115 ValueT *operator ->() { return &get (); }
0 commit comments