Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 32 additions & 34 deletions packages/action-listener-middleware/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -59,25 +59,6 @@ function assertFunction(
}
}

export const hasMatchFunction = <T>(
v: Matcher<T>
): v is HasMatchFunction<T> => {
return v && typeof (v as HasMatchFunction<T>).match === 'function'
}

export const isActionCreator = (
item: Function
): item is TypedActionCreator<any> => {
return (
typeof item === 'function' &&
typeof (item as any).type === 'string' &&
hasMatchFunction(item as any)
)
}

/** @public */
export type Matcher<T> = HasMatchFunction<T> | MatchFunction<T>

type Unsubscribe = () => void

type GuardedType<T> = T extends (x: any, ...args: unknown[]) => x is infer T
Expand Down Expand Up @@ -110,6 +91,7 @@ export interface ActionListenerMiddlewareAPI<S, D extends Dispatch<AnyAction>>
extends MiddlewareAPI<D, S> {
getOriginalState: () => S
unsubscribe(): void
subscribe(): void
condition: ConditionFunction<S>
currentPhase: MiddlewarePhase
// TODO Figure out how to pass this through the other types correctly
Expand Down Expand Up @@ -146,11 +128,15 @@ export interface CreateListenerMiddlewareOptions<ExtraArgument = unknown> {
onError?: ListenerErrorHandler
}

/**
* The possible overloads and options for defining a listener. The return type of each function is specified as a generic arg, so the overloads can be reused for multiple different functions
*/
interface AddListenerOverloads<
Return,
S = unknown,
D extends Dispatch = ThunkDispatch<S, unknown, AnyAction>
> {
/** Accepts a "listener predicate" that is also a TS type predicate for the action*/
<MA extends AnyAction, LP extends ListenerPredicate<MA, S>>(
options: {
actionCreator?: never
Expand All @@ -160,6 +146,8 @@ interface AddListenerOverloads<
listener: ActionListener<ListenerPredicateGuardedActionType<LP>, S, D>
} & ActionListenerOptions
): Return

/** Accepts an RTK action creator, like `incrementByAmount` */
<C extends TypedActionCreator<any>>(
options: {
actionCreator: C
Expand All @@ -169,6 +157,8 @@ interface AddListenerOverloads<
listener: ActionListener<ReturnType<C>, S, D>
} & ActionListenerOptions
): Return

/** Accepts a specific action type string */
<T extends string>(
options: {
actionCreator?: never
Expand All @@ -178,6 +168,8 @@ interface AddListenerOverloads<
listener: ActionListener<Action<T>, S, D>
} & ActionListenerOptions
): Return

/** Accepts an RTK matcher function, such as `incrementByAmount.match` */
<MA extends AnyAction, M extends MatchFunction<MA>>(
options: {
actionCreator?: never
Expand All @@ -188,6 +180,7 @@ interface AddListenerOverloads<
} & ActionListenerOptions
): Return

/** Accepts a "listener predicate" that just returns a boolean, no type assertion */
<LP extends AnyActionListenerPredicate<S>>(
options: {
actionCreator?: never
Expand All @@ -210,19 +203,22 @@ interface RemoveListenerOverloads<
(type: string, listener: ActionListener<AnyAction, S, D>): boolean
}

/** A "pre-typed" version of `addListenerAction`, so the listener args are well-typed */
export type TypedAddListenerAction<
S,
D extends Dispatch<AnyAction> = ThunkDispatch<S, unknown, AnyAction>,
Payload = ListenerEntry<S, D>,
T extends string = 'actionListenerMiddleware/add'
> = BaseActionCreator<Payload, T> &
AddListenerOverloads<PayloadAction<Payload>, S, D>
AddListenerOverloads<PayloadAction<Payload, T>, S, D>

/** A "pre-typed" version of `middleware.addListener`, so the listener args are well-typed */
export type TypedAddListener<
S,
D extends Dispatch<AnyAction> = ThunkDispatch<S, unknown, AnyAction>
> = AddListenerOverloads<Unsubscribe, S, D>

/** @internal An single listener entry */
type ListenerEntry<
S = unknown,
D extends Dispatch<AnyAction> = Dispatch<AnyAction>
Expand All @@ -235,16 +231,13 @@ type ListenerEntry<
predicate: ListenerPredicate<AnyAction, S>
}

/** A "pre-typed" version of `createListenerEntry`, so the listener args are well-typed */
export type TypedCreateListenerEntry<
S,
D extends Dispatch<AnyAction> = ThunkDispatch<S, unknown, AnyAction>
> = AddListenerOverloads<ListenerEntry<S, D>, S, D>

export type TypedAddListenerPrepareFunction<
S,
D extends Dispatch<AnyAction> = ThunkDispatch<S, unknown, AnyAction>
> = AddListenerOverloads<{ payload: ListenerEntry<S, D> }, S, D>

// A shorthand form of the accepted args, solely so that `createListenerEntry` has validly-typed conditional logic when checking the options contents
type FallbackAddListenerOptions = (
| { actionCreator: TypedActionCreator<string> }
| { type: string }
Expand All @@ -253,6 +246,7 @@ type FallbackAddListenerOptions = (
) &
ActionListenerOptions & { listener: ActionListener<any, any, any> }

/** Accepts the possible options for creating a listener, and returns a formatted listener entry */
export const createListenerEntry: TypedCreateListenerEntry<unknown> = (
options: FallbackAddListenerOptions
) => {
Expand Down Expand Up @@ -336,6 +330,7 @@ export const addListenerAction = createAction(
'actionListenerMiddleware/add',
function prepare(options: unknown) {
const entry = createListenerEntry(
// Fake out TS here
options as Parameters<AddListenerOverloads<unknown>>[0]
)

Expand Down Expand Up @@ -406,14 +401,6 @@ export function createActionListenerMiddleware<
D extends Dispatch<AnyAction> = ThunkDispatch<S, unknown, AnyAction>,
ExtraArgument = unknown
>(middlewareOptions: CreateListenerMiddlewareOptions<ExtraArgument> = {}) {
type ListenerEntry = ActionListenerOptions & {
id: string
listener: ActionListener<any, S, D>
unsubscribe: () => void
type?: string
predicate: ListenerPredicate<any, any>
}

const listenerMap = new Map<string, ListenerEntry>()
const { extra, onError = defaultErrorHandler } = middlewareOptions

Expand All @@ -434,7 +421,15 @@ export function createActionListenerMiddleware<
D
> = (api) => (next) => (action) => {
if (addListenerAction.match(action)) {
return insertEntry(action.payload)
let entry = findListenerEntry(
(existingEntry) => existingEntry.listener === action.payload.listener
)

if (!entry) {
entry = action.payload
}

return insertEntry(entry)
}
if (removeListenerAction.match(action)) {
removeListener(action.payload.type, action.payload.listener)
Expand Down Expand Up @@ -479,6 +474,9 @@ export function createActionListenerMiddleware<
currentPhase,
extra,
unsubscribe: entry.unsubscribe,
subscribe: () => {
listenerMap.set(entry.id, entry)
},
})
} catch (listenerError) {
safelyNotifyError(onError, listenerError)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,21 @@ import {
configureStore,
createAction,
createSlice,
AnyAction,
isAnyOf,
PayloadAction,
} from '@reduxjs/toolkit'

import type { AnyAction, PayloadAction, Action } from '@reduxjs/toolkit'

import {
createActionListenerMiddleware,
createListenerEntry,
addListenerAction,
removeListenerAction,
} from '../index'

import type {
When,
ActionListenerMiddlewareAPI,
ActionListenerMiddleware,
TypedCreateListenerEntry,
TypedAddListenerAction,
TypedAddListener,
} from '../index'
Expand All @@ -27,6 +29,7 @@ const middlewareApi = {
dispatch: expect.any(Function),
currentPhase: expect.stringMatching(/beforeReducer|afterReducer/),
unsubscribe: expect.any(Function),
subscribe: expect.any(Function),
}

const noop = () => {}
Expand Down Expand Up @@ -344,6 +347,7 @@ describe('createActionListenerMiddleware', () => {
listener,
})
)
expectType<Action<'actionListenerMiddleware/add'>>(unsubscribe)

store.dispatch(testAction1('a'))
// TODO This return type isn't correct
Expand Down Expand Up @@ -419,6 +423,36 @@ describe('createActionListenerMiddleware', () => {
])
})

test('Can re-subscribe via middleware api', async () => {
let numListenerRuns = 0
middleware.addListener({
actionCreator: testAction1,
listener: async (action, listenerApi) => {
numListenerRuns++

listenerApi.unsubscribe()

await listenerApi.condition(testAction2.match)

listenerApi.subscribe()
},
})

store.dispatch(testAction1('a'))
expect(numListenerRuns).toBe(1)

store.dispatch(testAction1('a'))
expect(numListenerRuns).toBe(1)

store.dispatch(testAction2('b'))
expect(numListenerRuns).toBe(1)

await delay(5)

store.dispatch(testAction1('b'))
expect(numListenerRuns).toBe(2)
})

const whenMap: [When, string, string, number][] = [
[undefined, 'reducer', 'listener', 1],
['beforeReducer', 'listener', 'reducer', 1],
Expand Down