diff --git a/etc/redux-toolkit.api.md b/etc/redux-toolkit.api.md index 401db6562e..ff4d4e41ff 100644 --- a/etc/redux-toolkit.api.md +++ b/etc/redux-toolkit.api.md @@ -13,6 +13,7 @@ import { DeepPartial } from 'redux'; import { Dispatch } from 'redux'; import { Draft } from 'immer'; import { Middleware } from 'redux'; +import { MiddlewareAPI } from 'redux'; import { OutputParametricSelector } from 'reselect'; import { OutputSelector } from 'reselect'; import { ParametricSelector } from 'reselect'; @@ -50,6 +51,17 @@ export interface ActionCreatorWithPreparedPayload; } +// @alpha (undocumented) +export type ActionListener, O extends ActionListenerOptions> = (action: A, api: ActionListenerMiddlewareAPI) => void; + +// @alpha (undocumented) +export interface ActionListenerMiddlewareAPI, O extends ActionListenerOptions> extends MiddlewareAPI { + // (undocumented) + stopPropagation: WhenFromOptions extends 'before' ? () => void : undefined; + // (undocumented) + unsubscribe(): void; +} + // @public export interface ActionReducerMapBuilder { addCase>(actionCreator: ActionCreator, reducer: CaseReducer>): ActionReducerMapBuilder; @@ -59,6 +71,16 @@ export interface ActionReducerMapBuilder { // @public @deprecated export type Actions = Record; +// @alpha (undocumented) +export const addListenerAction: BaseActionCreator<{ + type: string; + listener: ActionListener; + options: ActionListenerOptions; +}, "actionListenerMiddleware/add", never, never> & { + , S, D extends Dispatch, O extends ActionListenerOptions>(actionCreator: C, listener: ActionListener, S, D, O>, options?: O | undefined): AddListenerAction, S, D, O>; + , O_1 extends ActionListenerOptions>(type: string, listener: ActionListener, options?: O_1 | undefined): AddListenerAction; +}; + // @public export type AsyncThunkAction = (dispatch: GetDispatch, getState: () => GetState, extra: GetExtra) => Promise(type: T): Payl // @public export function createAction, T extends string = string>(type: T, prepareAction: PA): PayloadActionCreator['payload'], T, PA>; +// @alpha (undocumented) +export function createActionListenerMiddleware = Dispatch>(): Middleware<(action: Action<"actionListenerMiddleware/add">) => () => void, S, D> & { + addListener: { + , O extends ActionListenerOptions>(actionCreator: C, listener: ActionListener, S, D, O>, options?: O | undefined): () => void; + (type: T, listener: ActionListener, S, D, O_1>, options?: O_1 | undefined): () => void; + }; + removeListener: { + >(actionCreator: C_1, listener: ActionListener, S, D, any>): boolean; + (type: string, listener: ActionListener): boolean; + }; +} & WithMiddlewareType) => () => void, S, D>>; + // @public (undocumented) export function createAsyncThunk(typePrefix: string, payloadCreator: (arg: ThunkArg, thunkAPI: GetThunkAPI) => Promise>> | Returned | RejectWithValue>, options?: AsyncThunkOptions): IsAny AsyncThunkAction, unknown extends ThunkArg ? (arg: ThunkArg) => AsyncThunkAction : [ThunkArg] extends [void] | [undefined] ? () => AsyncThunkAction : [void] extends [ThunkArg] ? (arg?: ThunkArg | undefined) => AsyncThunkAction : [undefined] extends [ThunkArg] ? (arg?: ThunkArg | undefined) => AsyncThunkAction : (arg: ThunkArg) => AsyncThunkAction> & { pending: ActionCreatorWithPreparedPayload<[string, ThunkArg], undefined, string, never, { @@ -342,6 +376,15 @@ export type PrepareAction

= ((...args: any[]) => { error: any; }); +// @alpha (undocumented) +export const removeListenerAction: BaseActionCreator<{ + type: string; + listener: ActionListener; +}, "actionListenerMiddleware/remove", never, never> & { + , S, D extends Dispatch>(actionCreator: C, listener: ActionListener, S, D, any>): RemoveListenerAction, S, D>; + >(type: string, listener: ActionListener): RemoveListenerAction; +}; + export { Selector } // @public diff --git a/src/createAction.ts b/src/createAction.ts index 93558d7be6..ca3ffbac7a 100644 --- a/src/createAction.ts +++ b/src/createAction.ts @@ -81,7 +81,7 @@ export type _ActionCreatorWithPreparedPayload< * * @inheritdoc {redux#ActionCreator} */ -interface BaseActionCreator { +export interface BaseActionCreator { type: T match(action: Action): action is PayloadAction } diff --git a/src/createActionListenerMiddleware.test.ts b/src/createActionListenerMiddleware.test.ts new file mode 100644 index 0000000000..b3bf831db0 --- /dev/null +++ b/src/createActionListenerMiddleware.test.ts @@ -0,0 +1,354 @@ +import { configureStore } from './configureStore' +import { + createActionListenerMiddleware, + addListenerAction, + removeListenerAction, + When, + ActionListenerMiddlewareAPI +} from './createActionListenerMiddleware' +import { createAction } from './createAction' +import { AnyAction } from 'redux' + +const middlewareApi = { + getState: expect.any(Function), + dispatch: expect.any(Function), + stopPropagation: expect.any(Function), + unsubscribe: expect.any(Function) +} + +const noop = () => {} + +describe('createActionListenerMiddleware', () => { + let store = configureStore({ + reducer: () => ({}), + middleware: [createActionListenerMiddleware()] as const + }) + let reducer: jest.Mock + let middleware: ReturnType + + const testAction1 = createAction('testAction1') + type TestAction1 = ReturnType + const testAction2 = createAction('testAction2') + + beforeEach(() => { + middleware = createActionListenerMiddleware() + reducer = jest.fn(() => ({})) + store = configureStore({ + reducer, + middleware: [middleware] as const + }) + }) + + test('directly subscribing', () => { + const listener = jest.fn((_: TestAction1) => {}) + + middleware.addListener(testAction1, listener) + + store.dispatch(testAction1('a')) + store.dispatch(testAction2('b')) + store.dispatch(testAction1('c')) + + expect(listener.mock.calls).toEqual([ + [testAction1('a'), middlewareApi], + [testAction1('c'), middlewareApi] + ]) + }) + + test('subscribing with the same listener will not make it trigger twice (like EventTarget.addEventListener())', () => { + const listener = jest.fn((_: TestAction1) => {}) + + middleware.addListener(testAction1, listener) + middleware.addListener(testAction1, listener) + + store.dispatch(testAction1('a')) + store.dispatch(testAction2('b')) + store.dispatch(testAction1('c')) + + expect(listener.mock.calls).toEqual([ + [testAction1('a'), middlewareApi], + [testAction1('c'), middlewareApi] + ]) + }) + + test('unsubscribing via callback', () => { + const listener = jest.fn((_: TestAction1) => {}) + + const unsubscribe = middleware.addListener(testAction1, listener) + + store.dispatch(testAction1('a')) + unsubscribe() + store.dispatch(testAction2('b')) + store.dispatch(testAction1('c')) + + expect(listener.mock.calls).toEqual([[testAction1('a'), middlewareApi]]) + }) + + test('directly unsubscribing', () => { + const listener = jest.fn((_: TestAction1) => {}) + + middleware.addListener(testAction1, listener) + + store.dispatch(testAction1('a')) + + middleware.removeListener(testAction1, listener) + store.dispatch(testAction2('b')) + store.dispatch(testAction1('c')) + + expect(listener.mock.calls).toEqual([[testAction1('a'), middlewareApi]]) + }) + + test('unsubscribing without any subscriptions does not trigger an error', () => { + middleware.removeListener(testAction1, noop) + }) + + test('subscribing via action', () => { + const listener = jest.fn((_: TestAction1) => {}) + + store.dispatch(addListenerAction(testAction1, listener)) + + store.dispatch(testAction1('a')) + store.dispatch(testAction2('b')) + store.dispatch(testAction1('c')) + + expect(listener.mock.calls).toEqual([ + [testAction1('a'), middlewareApi], + [testAction1('c'), middlewareApi] + ]) + }) + + test('unsubscribing via callback from dispatch', () => { + const listener = jest.fn((_: TestAction1) => {}) + + const unsubscribe = store.dispatch(addListenerAction(testAction1, listener)) + + store.dispatch(testAction1('a')) + unsubscribe() + store.dispatch(testAction2('b')) + store.dispatch(testAction1('c')) + + expect(listener.mock.calls).toEqual([[testAction1('a'), middlewareApi]]) + }) + + test('unsubscribing via action', () => { + const listener = jest.fn((_: TestAction1) => {}) + + middleware.addListener(testAction1, listener) + + store.dispatch(testAction1('a')) + + store.dispatch(removeListenerAction(testAction1, listener)) + store.dispatch(testAction2('b')) + store.dispatch(testAction1('c')) + + expect(listener.mock.calls).toEqual([[testAction1('a'), middlewareApi]]) + }) + + const unforwaredActions: [string, AnyAction][] = [ + ['addListenerAction', addListenerAction(testAction1, noop)], + ['removeListenerAction', removeListenerAction(testAction1, noop)] + ] + test.each(unforwaredActions)( + '"%s" is not forwarded to the reducer', + (_, action) => { + reducer.mockClear() + + store.dispatch(testAction1('a')) + store.dispatch(action) + store.dispatch(testAction2('b')) + + expect(reducer.mock.calls).toEqual([ + [{}, testAction1('a')], + [{}, testAction2('b')] + ]) + } + ) + + test('"can unsubscribe via middleware api', () => { + const listener = jest.fn( + ( + action: TestAction1, + api: ActionListenerMiddlewareAPI + ) => { + if (action.payload === 'b') { + api.unsubscribe() + } + } + ) + + middleware.addListener(testAction1, listener) + + store.dispatch(testAction1('a')) + store.dispatch(testAction1('b')) + store.dispatch(testAction1('c')) + + expect(listener.mock.calls).toEqual([ + [testAction1('a'), middlewareApi], + [testAction1('b'), middlewareApi] + ]) + }) + + const whenMap: [When, string, string][] = [ + [undefined, 'reducer', 'listener'], + ['before', 'listener', 'reducer'], + ['after', 'reducer', 'listener'] + ] + test.each(whenMap)( + 'with "when" set to %s, %s runs before %s', + (when, _, shouldRunLast) => { + let whoRanLast = '' + + reducer.mockClear() + reducer.mockImplementationOnce(() => { + whoRanLast = 'reducer' + }) + const listener = jest.fn(() => { + whoRanLast = 'listener' + }) + + middleware.addListener(testAction1, listener, when ? { when } : {}) + + store.dispatch(testAction1('a')) + expect(reducer).toHaveBeenCalledTimes(1) + expect(listener).toHaveBeenCalledTimes(1) + expect(whoRanLast).toBe(shouldRunLast) + } + ) + + test('mixing "before" and "after"', () => { + const calls: Function[] = [] + function before1() { + calls.push(before1) + } + function before2() { + calls.push(before2) + } + function after1() { + calls.push(after1) + } + function after2() { + calls.push(after2) + } + + middleware.addListener(testAction1, before1, { when: 'before' }) + middleware.addListener(testAction1, before2, { when: 'before' }) + middleware.addListener(testAction1, after1, { when: 'after' }) + middleware.addListener(testAction1, after2, { when: 'after' }) + + store.dispatch(testAction1('a')) + store.dispatch(testAction2('a')) + + expect(calls).toEqual([before1, before2, after1, after2]) + }) + + test('mixing "before" and "after" with stopPropagation', () => { + const calls: Function[] = [] + function before1() { + calls.push(before1) + } + function before2(_: any, api: any) { + calls.push(before2) + api.stopPropagation() + } + function before3() { + calls.push(before3) + } + function after1() { + calls.push(after1) + } + function after2() { + calls.push(after2) + } + + middleware.addListener(testAction1, before1, { when: 'before' }) + middleware.addListener(testAction1, before2, { when: 'before' }) + middleware.addListener(testAction1, before3, { when: 'before' }) + middleware.addListener(testAction1, after1, { when: 'after' }) + middleware.addListener(testAction1, after2, { when: 'after' }) + + store.dispatch(testAction1('a')) + store.dispatch(testAction2('a')) + + expect(calls).toEqual([before1, before2]) + }) + + test('by default, actions are forwarded to the store', () => { + reducer.mockClear() + + const listener = jest.fn((_: TestAction1) => {}) + + middleware.addListener(testAction1, listener) + + store.dispatch(testAction1('a')) + + expect(reducer.mock.calls).toEqual([[{}, testAction1('a')]]) + }) + + test('calling `api.stopPropagation` in the listeners prevents actions from being forwarded to the store', () => { + reducer.mockClear() + + middleware.addListener( + testAction1, + (action: TestAction1, api) => { + if (action.payload === 'b') { + api.stopPropagation() + } + }, + { when: 'before' } + ) + + store.dispatch(testAction1('a')) + store.dispatch(testAction1('b')) + store.dispatch(testAction1('c')) + + expect(reducer.mock.calls).toEqual([ + [{}, testAction1('a')], + [{}, testAction1('c')] + ]) + }) + + test('calling `api.stopPropagation` with `when` set to "after" causes an error to be thrown', () => { + reducer.mockClear() + + middleware.addListener( + testAction1, + (action: TestAction1, api) => { + if (action.payload === 'b') { + // @ts-ignore TypeScript would already prevent this from being called with "after" + api.stopPropagation() + } + }, + { when: 'after' } + ) + + store.dispatch(testAction1('a')) + expect(() => { + store.dispatch(testAction1('b')) + }).toThrowErrorMatchingInlineSnapshot( + `"stopPropagation can only be called by action listeners with the \`when\` option set to \\"before\\""` + ) + }) + + test('calling `api.stopPropagation` asynchronously causes an error to be thrown', finish => { + reducer.mockClear() + + middleware.addListener( + testAction1, + (action: TestAction1, api) => { + if (action.payload === 'b') { + setTimeout(() => { + expect(() => { + api.stopPropagation() + }).toThrowErrorMatchingInlineSnapshot( + `"stopPropagation can only be called synchronously"` + ) + finish() + }) + } + }, + { when: 'before' } + ) + + store.dispatch(testAction1('a')) + store.dispatch(testAction1('b')) + }) +}) diff --git a/src/createActionListenerMiddleware.ts b/src/createActionListenerMiddleware.ts new file mode 100644 index 0000000000..f554086337 --- /dev/null +++ b/src/createActionListenerMiddleware.ts @@ -0,0 +1,332 @@ +import { Middleware, Dispatch, AnyAction, MiddlewareAPI, Action } from 'redux' +import { TypedActionCreator } from './mapBuilders' +import { createAction, BaseActionCreator } from './createAction' +import { WithMiddlewareType } from './tsHelpers' + +export type When = 'before' | 'after' | undefined +type WhenFromOptions< + O extends ActionListenerOptions +> = O extends ActionListenerOptions ? O['when'] : never + +/** + * @alpha + */ +export interface ActionListenerMiddlewareAPI< + S, + D extends Dispatch, + O extends ActionListenerOptions +> extends MiddlewareAPI { + stopPropagation: WhenFromOptions extends 'before' ? () => void : undefined + unsubscribe(): void +} + +/** + * @alpha + */ +export type ActionListener< + A extends AnyAction, + S, + D extends Dispatch, + O extends ActionListenerOptions +> = (action: A, api: ActionListenerMiddlewareAPI) => void + +export interface ActionListenerOptions { + /** + * Determines if the listener runs 'before' or 'after' the reducers have been called. + * If set to 'before', calling `api.stopPropagation()` from the listener becomes possible. + * Defaults to 'before'. + */ + when?: When +} + +export interface AddListenerAction< + A extends AnyAction, + S, + D extends Dispatch, + O extends ActionListenerOptions +> { + type: 'actionListenerMiddleware/add' + payload: { + type: string + listener: ActionListener + options?: O + } +} + +/** + * @alpha + */ +export const addListenerAction = createAction( + 'actionListenerMiddleware/add', + function prepare( + typeOrActionCreator: string | TypedActionCreator, + listener: ActionListener, + options?: ActionListenerOptions + ) { + const type = + typeof typeOrActionCreator === 'string' + ? typeOrActionCreator + : (typeOrActionCreator as TypedActionCreator).type + + return { + payload: { + type, + listener, + options + } + } + } +) as BaseActionCreator< + { + type: string + listener: ActionListener + options: ActionListenerOptions + }, + 'actionListenerMiddleware/add' +> & { + < + C extends TypedActionCreator, + S, + D extends Dispatch, + O extends ActionListenerOptions + >( + actionCreator: C, + listener: ActionListener, S, D, O>, + options?: O + ): AddListenerAction, S, D, O> + + ( + type: string, + listener: ActionListener, + options?: O + ): AddListenerAction +} + +interface RemoveListenerAction< + A extends AnyAction, + S, + D extends Dispatch +> { + type: 'actionListenerMiddleware/remove' + payload: { + type: string + listener: ActionListener + } +} + +/** + * @alpha + */ +export const removeListenerAction = createAction( + 'actionListenerMiddleware/remove', + function prepare( + typeOrActionCreator: string | TypedActionCreator, + listener: ActionListener + ) { + const type = + typeof typeOrActionCreator === 'string' + ? typeOrActionCreator + : (typeOrActionCreator as TypedActionCreator).type + + return { + payload: { + type, + listener + } + } + } +) as BaseActionCreator< + { type: string; listener: ActionListener }, + 'actionListenerMiddleware/remove' +> & { + , S, D extends Dispatch>( + actionCreator: C, + listener: ActionListener, S, D, any> + ): RemoveListenerAction, S, D> + + ( + type: string, + listener: ActionListener + ): RemoveListenerAction +} + +/** + * @alpha + */ +export function createActionListenerMiddleware< + S, + D extends Dispatch = Dispatch +>() { + type ListenerEntry = ActionListenerOptions & { + listener: ActionListener + } + + const listenerMap: Record | undefined> = {} + const middleware: Middleware< + { + (action: Action<'actionListenerMiddleware/add'>): Unsubscribe + }, + S, + D + > = api => next => action => { + if (addListenerAction.match(action)) { + const unsubscribe = addListener( + action.payload.type, + action.payload.listener, + action.payload.options + ) + + return unsubscribe + } + if (removeListenerAction.match(action)) { + removeListener(action.payload.type, action.payload.listener) + + return + } + + const listeners = listenerMap[action.type] + if (listeners) { + const defaultWhen = 'after' + let result: unknown + for (const phase of ['before', 'after'] as const) { + for (const entry of listeners) { + if (phase !== (entry.when || defaultWhen)) { + continue + } + let stoppedPropagation = false + let currentPhase = phase + let synchronousListenerFinished = false + entry.listener(action, { + ...api, + stopPropagation() { + if (currentPhase === 'before') { + if (!synchronousListenerFinished) { + stoppedPropagation = true + } else { + throw new Error( + 'stopPropagation can only be called synchronously' + ) + } + } else { + throw new Error( + 'stopPropagation can only be called by action listeners with the `when` option set to "before"' + ) + } + }, + unsubscribe() { + listeners.delete(entry) + } + }) + synchronousListenerFinished = true + if (stoppedPropagation) { + return action + } + } + if (phase === 'before') { + result = next(action) + } else { + return result + } + } + } + return next(action) + } + + type Unsubscribe = () => void + + function addListener< + C extends TypedActionCreator, + O extends ActionListenerOptions + >( + actionCreator: C, + listener: ActionListener, S, D, O>, + options?: O + ): Unsubscribe + function addListener( + type: T, + listener: ActionListener, S, D, O>, + options?: O + ): Unsubscribe + function addListener( + typeOrActionCreator: string | TypedActionCreator, + listener: ActionListener, + options?: ActionListenerOptions + ): Unsubscribe { + const type = + typeof typeOrActionCreator === 'string' + ? typeOrActionCreator + : typeOrActionCreator.type + + const listeners = getListenerMap(type) + + let entry = findListenerEntry(listeners, listener) + + if (!entry) { + entry = { + ...options, + listener + } + + listeners.add(entry) + } + + return () => listeners.delete(entry!) + } + + function getListenerMap(type: string) { + if (!listenerMap[type]) { + listenerMap[type] = new Set() + } + return listenerMap[type]! + } + + function removeListener>( + actionCreator: C, + listener: ActionListener, S, D, any> + ): boolean + function removeListener( + type: string, + listener: ActionListener + ): boolean + function removeListener( + typeOrActionCreator: string | TypedActionCreator, + listener: ActionListener + ): boolean { + const type = + typeof typeOrActionCreator === 'string' + ? typeOrActionCreator + : typeOrActionCreator.type + + const listeners = listenerMap[type] + + if (!listeners) { + return false + } + + let entry = findListenerEntry(listeners, listener) + + if (!entry) { + return false + } + + listeners.delete(entry) + return true + } + + function findListenerEntry( + entries: Set, + listener: Function + ): ListenerEntry | undefined { + for (const entry of entries) { + if (entry.listener === listener) { + return entry + } + } + } + + return Object.assign( + middleware, + { addListener, removeListener }, + {} as WithMiddlewareType + ) +} diff --git a/src/index.ts b/src/index.ts index f4a1843d92..0cb8290ab9 100644 --- a/src/index.ts +++ b/src/index.ts @@ -104,4 +104,12 @@ export { SerializedError } from './createAsyncThunk' +export { + createActionListenerMiddleware, + addListenerAction, + removeListenerAction, + ActionListener, + ActionListenerMiddlewareAPI +} from './createActionListenerMiddleware' + export { nanoid } from './nanoid' diff --git a/src/tsHelpers.ts b/src/tsHelpers.ts index d624cd359b..5899565d98 100644 --- a/src/tsHelpers.ts +++ b/src/tsHelpers.ts @@ -65,6 +65,11 @@ export type IsUnknownOrNonInferrable = AtLeastTS35< IsEmptyObj> > +const declaredMiddlewareType: unique symbol = undefined as any +export type WithMiddlewareType> = { + [declaredMiddlewareType]: T +} + /** * Combines all dispatch signatures of all middlewares in the array `M` into * one intersected dispatch signature. @@ -72,7 +77,19 @@ export type IsUnknownOrNonInferrable = AtLeastTS35< export type DispatchForMiddlewares = M extends ReadonlyArray ? UnionToIntersection< M[number] extends infer MiddlewareValues - ? MiddlewareValues extends Middleware + ? MiddlewareValues extends WithMiddlewareType< + infer DeclaredMiddlewareType + > + ? DeclaredMiddlewareType extends Middleware< + infer DispatchExt, + any, + any + > + ? DispatchExt extends Function + ? DispatchExt + : never + : never + : MiddlewareValues extends Middleware ? DispatchExt extends Function ? DispatchExt : never