diff --git a/src/createSlice.test.ts b/src/createSlice.test.ts index 77df847f00..4b4b76d7ae 100644 --- a/src/createSlice.test.ts +++ b/src/createSlice.test.ts @@ -35,7 +35,7 @@ describe('createSlice', () => { }) describe('when passing slice', () => { - const { actions, reducer } = createSlice({ + const { actions, reducer, caseReducers } = createSlice({ reducers: { increment: state => state + 1 }, @@ -57,6 +57,12 @@ describe('createSlice', () => { it('should return the correct value from reducer', () => { expect(reducer(undefined, actions.increment())).toEqual(1) }) + + it('should include the generated case reducers', () => { + expect(caseReducers).toBeTruthy() + expect(caseReducers.increment).toBeTruthy() + expect(typeof caseReducers.increment).toBe('function') + }) }) describe('when mutating state object', () => { diff --git a/src/createSlice.ts b/src/createSlice.ts index dc1210d6d3..42d437023e 100644 --- a/src/createSlice.ts +++ b/src/createSlice.ts @@ -18,7 +18,9 @@ export type SliceActionCreator

= PayloadActionCreator

export interface Slice< State = any, - ActionCreators extends { [key: string]: any } = { [key: string]: any } + CaseReducers extends SliceCaseReducerDefinitions = { + [key: string]: any + } > { /** * The slice name. @@ -34,7 +36,9 @@ export interface Slice< * Action creators for the types of actions that are handled by the slice * reducer. */ - actions: ActionCreators + actions: CaseReducerActions + + caseReducers: SliceDefinedCaseReducers } /** @@ -42,7 +46,10 @@ export interface Slice< */ export interface CreateSliceOptions< State = any, - CR extends SliceCaseReducers = SliceCaseReducers + CR extends SliceCaseReducerDefinitions< + State, + any + > = SliceCaseReducerDefinitions > { /** * The slice's name. Used to namespace the generated action types. @@ -74,15 +81,15 @@ type PayloadActions = Record< PayloadAction > -type EnhancedCaseReducer = { +type CaseReducerWithPrepare = { reducer: CaseReducer prepare: PrepareAction } -type SliceCaseReducers = { +type SliceCaseReducerDefinitions = { [ActionType in keyof PA]: | CaseReducer - | EnhancedCaseReducer + | CaseReducerWithPrepare } type IfIsReducerFunctionWithoutAction = R extends ( @@ -90,7 +97,7 @@ type IfIsReducerFunctionWithoutAction = R extends ( ) => any ? True : False -type IfIsEnhancedReducer = R extends { +type IfIsCaseReducerWithPrepare = R extends { prepare: Function } ? True @@ -106,8 +113,21 @@ type PrepareActionForReducer = R extends { prepare: infer Prepare } ? Prepare : never -type CaseReducerActions> = { - [Type in keyof CaseReducers]: IfIsEnhancedReducer< +type ActionForReducer = R extends ( + state: S, + action: PayloadAction +) => S + ? PayloadAction

+ : R extends { + reducer(state: any, action: PayloadAction): any + } + ? PayloadAction

+ : unknown + +type CaseReducerActions< + CaseReducers extends SliceCaseReducerDefinitions +> = { + [Type in keyof CaseReducers]: IfIsCaseReducerWithPrepare< CaseReducers[Type], ActionCreatorWithPreparedPayload< PrepareActionForReducer @@ -122,6 +142,16 @@ type CaseReducerActions> = { > } +type SliceDefinedCaseReducers< + CaseReducers extends SliceCaseReducerDefinitions, + State = any +> = { + [Type in keyof CaseReducers]: CaseReducer< + State, + ActionForReducer + > +} + type NoInfer = [T][T extends any ? 0 : never] type SliceCaseReducersCheck = { @@ -134,9 +164,9 @@ type SliceCaseReducersCheck = { : {} } -type RestrictEnhancedReducersToMatchReducerAndPrepare< +type RestrictCaseReducerDefinitionsToMatchReducerAndPrepare< S, - CR extends SliceCaseReducers + CR extends SliceCaseReducerDefinitions > = { reducers: SliceCaseReducersCheck> } function getType(slice: string, actionKey: string): string { @@ -153,54 +183,59 @@ function getType(slice: string, actionKey: string): string { */ export function createSlice< State, - CaseReducers extends SliceCaseReducers + CaseReducers extends SliceCaseReducerDefinitions >( options: CreateSliceOptions & - RestrictEnhancedReducersToMatchReducerAndPrepare -): Slice> + RestrictCaseReducerDefinitionsToMatchReducerAndPrepare +): Slice // internal definition is a little less restrictive export function createSlice< State, - CaseReducers extends SliceCaseReducers + CaseReducers extends SliceCaseReducerDefinitions >( options: CreateSliceOptions -): Slice> { +): Slice { const { name, initialState } = options if (!name) { throw new Error('`name` is a required option for createSlice') } const reducers = options.reducers || {} const extraReducers = options.extraReducers || {} - const actionKeys = Object.keys(reducers) - - const reducerMap = actionKeys.reduce((map, actionKey) => { - let maybeEnhancedReducer = reducers[actionKey] - map[getType(name, actionKey)] = - typeof maybeEnhancedReducer === 'function' - ? maybeEnhancedReducer - : maybeEnhancedReducer.reducer - return map - }, extraReducers) - - const reducer = createReducer(initialState, reducerMap) - - const actionMap = actionKeys.reduce( - (map, action) => { - let maybeEnhancedReducer = reducers[action] - const type = getType(name, action) - map[action] = - typeof maybeEnhancedReducer === 'function' - ? createAction(type) - : createAction(type, maybeEnhancedReducer.prepare) - return map - }, - {} as any - ) + const reducerNames = Object.keys(reducers) + + const sliceCaseReducersByName: Record = {} + const sliceCaseReducersByType: Record = {} + const actionCreators: Record = {} + + reducerNames.forEach(reducerName => { + const maybeReducerWithPrepare = reducers[reducerName] + const type = getType(name, reducerName) + + let caseReducer: CaseReducer + let prepareCallback: PrepareAction | undefined + + if (typeof maybeReducerWithPrepare === 'function') { + caseReducer = maybeReducerWithPrepare + } else { + caseReducer = maybeReducerWithPrepare.reducer + prepareCallback = maybeReducerWithPrepare.prepare + } + + sliceCaseReducersByName[reducerName] = caseReducer + sliceCaseReducersByType[type] = caseReducer + actionCreators[reducerName] = prepareCallback + ? createAction(type, prepareCallback) + : createAction(type) + }) + + const finalCaseReducers = { ...extraReducers, ...sliceCaseReducersByType } + const reducer = createReducer(initialState, finalCaseReducers) return { name, reducer, - actions: actionMap + actions: actionCreators as any, + caseReducers: sliceCaseReducersByName as any } } diff --git a/type-tests/files/createSlice.typetest.ts b/type-tests/files/createSlice.typetest.ts index 95fd9b5982..ff6299d8d0 100644 --- a/type-tests/files/createSlice.typetest.ts +++ b/type-tests/files/createSlice.typetest.ts @@ -147,6 +147,57 @@ function expectType(t: T) { expectType(counter.actions.concatMetaStrLen('test').meta) } +/* + * Test: returned case reducer has the correct type + */ +{ + const counter = createSlice({ + name: 'counter', + initialState: 0, + reducers: { + increment(state, action: PayloadAction) { + return state + action.payload + }, + decrement: { + reducer(state, action: PayloadAction) { + return state - action.payload + }, + prepare(amount: number) { + return { payload: amount } + } + } + } + }) + + // Should match positively + expectType<(state: number, action: PayloadAction) => number | void>( + counter.caseReducers.increment + ) + + // Should match positively for reducers with prepare callback + expectType<(state: number, action: PayloadAction) => number | void>( + counter.caseReducers.decrement + ) + + // Should not mismatch the payload if it's a simple reducer + // typings:expect-error + expectType<(state: number, action: PayloadAction) => number | void>( + counter.caseReducers.increment + ) + + // Should not mismatch the payload if it's a reducer with a prepare callback + // typings:expect-error + expectType<(state: number, action: PayloadAction) => number | void>( + counter.caseReducers.decrement + ) + + // Should not include entries that don't exist + // typings:expect-error + expectType<(state: number, action: PayloadAction) => number | void>( + counter.caseReducers.someThingNonExistant + ) +} + /* * Test: prepared payload does not match action payload - should cause an error. */ @@ -180,6 +231,7 @@ function expectType(t: T) { } const mySlice = createSlice({ + name: 'name', initialState, reducers: { setName: (state, action) => {