Skip to content

Commit 7b7faea

Browse files
authored
Fix potential subscription leakage in SSR environments (#5111)
1 parent fde0be7 commit 7b7faea

File tree

7 files changed

+123
-73
lines changed

7 files changed

+123
-73
lines changed

packages/toolkit/src/query/core/buildInitiate.ts

Lines changed: 22 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ import {
2222
type QueryDefinition,
2323
type ResultTypeFrom,
2424
} from '../endpointDefinitions'
25-
import { countObjectKeys, getOrInsert, isNotNullish } from '../utils'
25+
import { filterNullishValues } from '../utils'
2626
import type {
2727
InfiniteData,
2828
InfiniteQueryConfigOptions,
@@ -271,17 +271,20 @@ export function buildInitiate({
271271
mutationThunk,
272272
api,
273273
context,
274-
internalState,
274+
getInternalState,
275275
}: {
276276
serializeQueryArgs: InternalSerializeQueryArgs
277277
queryThunk: QueryThunk
278278
infiniteQueryThunk: InfiniteQueryThunk<any>
279279
mutationThunk: MutationThunk
280280
api: Api<any, EndpointDefinitions, any, any>
281281
context: ApiContext<EndpointDefinitions>
282-
internalState: InternalMiddlewareState
282+
getInternalState: (dispatch: Dispatch) => InternalMiddlewareState
283283
}) {
284-
const { runningQueries, runningMutations } = internalState
284+
const getRunningQueries = (dispatch: Dispatch) =>
285+
getInternalState(dispatch)?.runningQueries
286+
const getRunningMutations = (dispatch: Dispatch) =>
287+
getInternalState(dispatch)?.runningMutations
285288

286289
const {
287290
unsubscribeQueryResult,
@@ -306,7 +309,7 @@ export function buildInitiate({
306309
endpointDefinition,
307310
endpointName,
308311
})
309-
return runningQueries.get(dispatch)?.[queryCacheKey] as
312+
return getRunningQueries(dispatch)?.get(queryCacheKey) as
310313
| QueryActionCreatorResult<never>
311314
| InfiniteQueryActionCreatorResult<never>
312315
| undefined
@@ -322,20 +325,20 @@ export function buildInitiate({
322325
fixedCacheKeyOrRequestId: string,
323326
) {
324327
return (dispatch: Dispatch) => {
325-
return runningMutations.get(dispatch)?.[fixedCacheKeyOrRequestId] as
328+
return getRunningMutations(dispatch)?.get(fixedCacheKeyOrRequestId) as
326329
| MutationActionCreatorResult<never>
327330
| undefined
328331
}
329332
}
330333

331334
function getRunningQueriesThunk() {
332335
return (dispatch: Dispatch) =>
333-
Object.values(runningQueries.get(dispatch) || {}).filter(isNotNullish)
336+
filterNullishValues(getRunningQueries(dispatch))
334337
}
335338

336339
function getRunningMutationsThunk() {
337340
return (dispatch: Dispatch) =>
338-
Object.values(runningMutations.get(dispatch) || {}).filter(isNotNullish)
341+
filterNullishValues(getRunningMutations(dispatch))
339342
}
340343

341344
function middlewareWarning(dispatch: Dispatch) {
@@ -429,7 +432,7 @@ You must add the middleware for RTK-Query to function correctly!`,
429432

430433
const skippedSynchronously = stateAfter.requestId !== requestId
431434

432-
const runningQuery = runningQueries.get(dispatch)?.[queryCacheKey]
435+
const runningQuery = getRunningQueries(dispatch)?.get(queryCacheKey)
433436
const selectFromState = () => selector(getState())
434437

435438
const statePromise: AnyActionCreatorResult = Object.assign(
@@ -489,14 +492,11 @@ You must add the middleware for RTK-Query to function correctly!`,
489492
)
490493

491494
if (!runningQuery && !skippedSynchronously && !forceQueryFn) {
492-
const running = getOrInsert(runningQueries, dispatch, {})
493-
running[queryCacheKey] = statePromise
495+
const runningQueries = getRunningQueries(dispatch)!
496+
runningQueries.set(queryCacheKey, statePromise)
494497

495498
statePromise.then(() => {
496-
delete running[queryCacheKey]
497-
if (!countObjectKeys(running)) {
498-
runningQueries.delete(dispatch)
499-
}
499+
runningQueries.delete(queryCacheKey)
500500
})
501501
}
502502

@@ -559,23 +559,17 @@ You must add the middleware for RTK-Query to function correctly!`,
559559
reset,
560560
})
561561

562-
const running = runningMutations.get(dispatch) || {}
563-
runningMutations.set(dispatch, running)
564-
running[requestId] = ret
562+
const runningMutations = getRunningMutations(dispatch)!
563+
564+
runningMutations.set(requestId, ret)
565565
ret.then(() => {
566-
delete running[requestId]
567-
if (!countObjectKeys(running)) {
568-
runningMutations.delete(dispatch)
569-
}
566+
runningMutations.delete(requestId)
570567
})
571568
if (fixedCacheKey) {
572-
running[fixedCacheKey] = ret
569+
runningMutations.set(fixedCacheKey, ret)
573570
ret.then(() => {
574-
if (running[fixedCacheKey] === ret) {
575-
delete running[fixedCacheKey]
576-
if (!countObjectKeys(running)) {
577-
runningMutations.delete(dispatch)
578-
}
571+
if (runningMutations.get(fixedCacheKey) === ret) {
572+
runningMutations.delete(fixedCacheKey)
579573
}
580574
})
581575
}

packages/toolkit/src/query/core/buildMiddleware/cacheCollection.ts

Lines changed: 22 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -13,27 +13,27 @@ export type ReferenceCacheCollection = never
1313

1414
/**
1515
* @example
16-
* ```ts
17-
* // codeblock-meta title="keepUnusedDataFor example"
18-
* import { createApi, fetchBaseQuery } from '@reduxjs/toolkit/query/react'
19-
* interface Post {
20-
* id: number
21-
* name: string
22-
* }
23-
* type PostsResponse = Post[]
24-
*
25-
* const api = createApi({
26-
* baseQuery: fetchBaseQuery({ baseUrl: '/' }),
27-
* endpoints: (build) => ({
28-
* getPosts: build.query<PostsResponse, void>({
29-
* query: () => 'posts',
30-
* // highlight-start
31-
* keepUnusedDataFor: 5
32-
* // highlight-end
33-
* })
34-
* })
35-
* })
36-
* ```
16+
* ```ts
17+
* // codeblock-meta title="keepUnusedDataFor example"
18+
* import { createApi, fetchBaseQuery } from '@reduxjs/toolkit/query/react'
19+
* interface Post {
20+
* id: number
21+
* name: string
22+
* }
23+
* type PostsResponse = Post[]
24+
*
25+
* const api = createApi({
26+
* baseQuery: fetchBaseQuery({ baseUrl: '/' }),
27+
* endpoints: (build) => ({
28+
* getPosts: build.query<PostsResponse, void>({
29+
* query: () => 'posts',
30+
* // highlight-start
31+
* keepUnusedDataFor: 5
32+
* // highlight-end
33+
* })
34+
* })
35+
* })
36+
* ```
3737
*/
3838
export type CacheCollectionQueryExtraOptions = {
3939
/**
@@ -64,8 +64,6 @@ export const buildCacheCollectionHandler: InternalHandlerBuilder = ({
6464
const { removeQueryResult, unsubscribeQueryResult, cacheEntriesUpserted } =
6565
api.internalActions
6666

67-
const runningQueries = internalState.runningQueries.get(mwApi.dispatch)!
68-
6967
const canTriggerUnsubscribe = isAnyOf(
7068
unsubscribeQueryResult.match,
7169
queryThunk.fulfilled,
@@ -80,8 +78,7 @@ export const buildCacheCollectionHandler: InternalHandlerBuilder = ({
8078
}
8179

8280
const hasSubscriptions = subscriptions.size > 0
83-
const isRunning = runningQueries?.[queryCacheKey] !== undefined
84-
return hasSubscriptions || isRunning
81+
return hasSubscriptions
8582
}
8683

8784
const currentRemovalTimeouts: QueryStateMeta<TimeoutId> = {}

packages/toolkit/src/query/core/buildMiddleware/index.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ export function buildMiddleware<
4545
ReducerPath extends string,
4646
TagTypes extends string,
4747
>(input: BuildMiddlewareInput<Definitions, ReducerPath, TagTypes>) {
48-
const { reducerPath, queryThunk, api, context, internalState } = input
48+
const { reducerPath, queryThunk, api, context, getInternalState } = input
4949
const { apiUid } = context
5050

5151
const actions = {
@@ -73,6 +73,8 @@ export function buildMiddleware<
7373
> = (mwApi) => {
7474
let initialized = false
7575

76+
const internalState = getInternalState(mwApi.dispatch)
77+
7678
const builderArgs = {
7779
...(input as any as BuildMiddlewareInput<
7880
EndpointDefinitions,

packages/toolkit/src/query/core/buildMiddleware/types.ts

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -47,18 +47,12 @@ export interface InternalMiddlewareState {
4747
currentSubscriptions: SubscriptionInternalState
4848
currentPolls: Map<string, QueryPollState>
4949
runningQueries: Map<
50-
Dispatch,
51-
Record<
52-
string,
53-
| QueryActionCreatorResult<any>
54-
| InfiniteQueryActionCreatorResult<any>
55-
| undefined
56-
>
57-
>
58-
runningMutations: Map<
59-
Dispatch,
60-
Record<string, MutationActionCreatorResult<any> | undefined>
50+
string,
51+
| QueryActionCreatorResult<any>
52+
| InfiniteQueryActionCreatorResult<any>
53+
| undefined
6154
>
55+
runningMutations: Map<string, MutationActionCreatorResult<any> | undefined>
6256
}
6357

6458
export interface SubscriptionSelectors {
@@ -84,7 +78,7 @@ export interface BuildMiddlewareInput<
8478
endpointName: string,
8579
queryArgs: any,
8680
) => (dispatch: Dispatch) => QueryActionCreatorResult<any> | undefined
87-
internalState: InternalMiddlewareState
81+
getInternalState: (dispatch: Dispatch) => InternalMiddlewareState
8882
}
8983

9084
export type SubMiddlewareApi = MiddlewareAPI<

packages/toolkit/src/query/core/module.ts

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
*/
44
import type {
55
ActionCreatorWithPayload,
6+
Dispatch,
67
Middleware,
78
Reducer,
89
ThunkAction,
@@ -72,6 +73,7 @@ import { buildThunks } from './buildThunks'
7273
import { createSelector as _createSelector } from './rtkImports'
7374
import { onFocus, onFocusLost, onOffline, onOnline } from './setupListeners'
7475
import type { InternalMiddlewareState } from './buildMiddleware/types'
76+
import { getOrInsertComputed } from '../utils'
7577

7678
/**
7779
* `ifOlderThan` - (default: `false` | `number`) - _number is value in seconds_
@@ -619,11 +621,17 @@ export const coreModule = ({
619621
})
620622
safeAssign(api.internalActions, sliceActions)
621623

622-
const internalState: InternalMiddlewareState = {
623-
currentSubscriptions: new Map(),
624-
currentPolls: new Map(),
625-
runningQueries: new Map(),
626-
runningMutations: new Map(),
624+
const internalStateMap = new WeakMap<Dispatch, InternalMiddlewareState>()
625+
626+
const getInternalState = (dispatch: Dispatch) => {
627+
const state = getOrInsertComputed(internalStateMap, dispatch, () => ({
628+
currentSubscriptions: new Map(),
629+
currentPolls: new Map(),
630+
runningQueries: new Map(),
631+
runningMutations: new Map(),
632+
}))
633+
634+
return state
627635
}
628636

629637
const {
@@ -641,7 +649,7 @@ export const coreModule = ({
641649
api,
642650
serializeQueryArgs: serializeQueryArgs as any,
643651
context,
644-
internalState,
652+
getInternalState,
645653
})
646654

647655
safeAssign(api.util, {
@@ -661,7 +669,7 @@ export const coreModule = ({
661669
assertTagType,
662670
selectors,
663671
getRunningQueryThunk,
664-
internalState,
672+
getInternalState,
665673
})
666674
safeAssign(api.util, middlewareActions)
667675

packages/toolkit/src/query/tests/buildMiddleware.test.tsx

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import { createApi } from '@reduxjs/toolkit/query'
22
import { delay } from 'msw'
33
import { actionsReducer, setupApiStore } from '../../tests/utils/helpers'
4+
import { vi } from 'vitest'
45

56
const baseQuery = (args?: any) => ({ data: args })
67
const api = createApi({
@@ -213,3 +214,53 @@ it('correctly stringifies subscription state and dispatches subscriptionsUpdated
213214
subscriptionState['getBananas(undefined)']?.[subscription3.requestId],
214215
).toEqual({})
215216
})
217+
218+
it('does not leak subscription state between multiple stores using the same API instance (SSR scenario)', async () => {
219+
vi.useFakeTimers()
220+
// Simulate SSR: create API once at module level
221+
const sharedApi = createApi({
222+
baseQuery: (args?: any) => ({ data: args }),
223+
tagTypes: ['Test'],
224+
endpoints: (build) => ({
225+
getTest: build.query<unknown, number>({
226+
query(id) {
227+
return { url: `test/${id}` }
228+
},
229+
}),
230+
}),
231+
})
232+
233+
// Create first store (simulating first SSR request)
234+
const store1Ref = setupApiStore(sharedApi, {}, { withoutListeners: true })
235+
236+
// Add subscription in store1
237+
const sub1 = store1Ref.store.dispatch(
238+
sharedApi.endpoints.getTest.initiate(1, {
239+
subscriptionOptions: { pollingInterval: 1000 },
240+
}),
241+
)
242+
vi.advanceTimersByTime(10)
243+
await sub1
244+
245+
// Wait for subscription sync (500ms + buffer)
246+
vi.advanceTimersByTime(600)
247+
248+
// Verify store1 has the subscription
249+
const store1SubscriptionSelectors = store1Ref.store.dispatch(
250+
sharedApi.internalActions.internal_getRTKQSubscriptions(),
251+
) as any
252+
const store1InternalSubs = store1SubscriptionSelectors.getSubscriptions()
253+
expect(store1InternalSubs.size).toBe(1)
254+
255+
// Create second store (simulating second SSR request)
256+
const store2Ref = setupApiStore(sharedApi, {}, { withoutListeners: true })
257+
258+
// Check subscriptions via internal action
259+
const store2SubscriptionSelectors = store2Ref.store.dispatch(
260+
sharedApi.internalActions.internal_getRTKQSubscriptions(),
261+
) as any
262+
263+
const store2InternalSubs = store2SubscriptionSelectors.getSubscriptions()
264+
265+
expect(store2InternalSubs.size).toBe(0)
266+
})
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
11
export function isNotNullish<T>(v: T | null | undefined): v is T {
22
return v != null
33
}
4+
5+
export function filterNullishValues<T>(map?: Map<any, T>) {
6+
return [...(map?.values() ?? [])].filter(isNotNullish) as NonNullable<T>[]
7+
}

0 commit comments

Comments
 (0)