diff --git a/src/createSlice.test.ts b/src/createSlice.test.ts
index 77df847f00..4b4b76d7ae 100644
--- a/src/createSlice.test.ts
+++ b/src/createSlice.test.ts
@@ -35,7 +35,7 @@ describe('createSlice', () => {
})
describe('when passing slice', () => {
- const { actions, reducer } = createSlice({
+ const { actions, reducer, caseReducers } = createSlice({
reducers: {
increment: state => state + 1
},
@@ -57,6 +57,12 @@ describe('createSlice', () => {
it('should return the correct value from reducer', () => {
expect(reducer(undefined, actions.increment())).toEqual(1)
})
+
+ it('should include the generated case reducers', () => {
+ expect(caseReducers).toBeTruthy()
+ expect(caseReducers.increment).toBeTruthy()
+ expect(typeof caseReducers.increment).toBe('function')
+ })
})
describe('when mutating state object', () => {
diff --git a/src/createSlice.ts b/src/createSlice.ts
index dc1210d6d3..42d437023e 100644
--- a/src/createSlice.ts
+++ b/src/createSlice.ts
@@ -18,7 +18,9 @@ export type SliceActionCreator
= PayloadActionCreator
export interface Slice<
State = any,
- ActionCreators extends { [key: string]: any } = { [key: string]: any }
+ CaseReducers extends SliceCaseReducerDefinitions = {
+ [key: string]: any
+ }
> {
/**
* The slice name.
@@ -34,7 +36,9 @@ export interface Slice<
* Action creators for the types of actions that are handled by the slice
* reducer.
*/
- actions: ActionCreators
+ actions: CaseReducerActions
+
+ caseReducers: SliceDefinedCaseReducers
}
/**
@@ -42,7 +46,10 @@ export interface Slice<
*/
export interface CreateSliceOptions<
State = any,
- CR extends SliceCaseReducers = SliceCaseReducers
+ CR extends SliceCaseReducerDefinitions<
+ State,
+ any
+ > = SliceCaseReducerDefinitions
> {
/**
* The slice's name. Used to namespace the generated action types.
@@ -74,15 +81,15 @@ type PayloadActions = Record<
PayloadAction
>
-type EnhancedCaseReducer = {
+type CaseReducerWithPrepare = {
reducer: CaseReducer
prepare: PrepareAction
}
-type SliceCaseReducers = {
+type SliceCaseReducerDefinitions = {
[ActionType in keyof PA]:
| CaseReducer
- | EnhancedCaseReducer
+ | CaseReducerWithPrepare
}
type IfIsReducerFunctionWithoutAction = R extends (
@@ -90,7 +97,7 @@ type IfIsReducerFunctionWithoutAction = R extends (
) => any
? True
: False
-type IfIsEnhancedReducer = R extends {
+type IfIsCaseReducerWithPrepare = R extends {
prepare: Function
}
? True
@@ -106,8 +113,21 @@ type PrepareActionForReducer = R extends { prepare: infer Prepare }
? Prepare
: never
-type CaseReducerActions> = {
- [Type in keyof CaseReducers]: IfIsEnhancedReducer<
+type ActionForReducer = R extends (
+ state: S,
+ action: PayloadAction
+) => S
+ ? PayloadAction
+ : R extends {
+ reducer(state: any, action: PayloadAction): any
+ }
+ ? PayloadAction
+ : unknown
+
+type CaseReducerActions<
+ CaseReducers extends SliceCaseReducerDefinitions
+> = {
+ [Type in keyof CaseReducers]: IfIsCaseReducerWithPrepare<
CaseReducers[Type],
ActionCreatorWithPreparedPayload<
PrepareActionForReducer
@@ -122,6 +142,16 @@ type CaseReducerActions> = {
>
}
+type SliceDefinedCaseReducers<
+ CaseReducers extends SliceCaseReducerDefinitions,
+ State = any
+> = {
+ [Type in keyof CaseReducers]: CaseReducer<
+ State,
+ ActionForReducer
+ >
+}
+
type NoInfer = [T][T extends any ? 0 : never]
type SliceCaseReducersCheck = {
@@ -134,9 +164,9 @@ type SliceCaseReducersCheck = {
: {}
}
-type RestrictEnhancedReducersToMatchReducerAndPrepare<
+type RestrictCaseReducerDefinitionsToMatchReducerAndPrepare<
S,
- CR extends SliceCaseReducers
+ CR extends SliceCaseReducerDefinitions
> = { reducers: SliceCaseReducersCheck> }
function getType(slice: string, actionKey: string): string {
@@ -153,54 +183,59 @@ function getType(slice: string, actionKey: string): string {
*/
export function createSlice<
State,
- CaseReducers extends SliceCaseReducers
+ CaseReducers extends SliceCaseReducerDefinitions
>(
options: CreateSliceOptions &
- RestrictEnhancedReducersToMatchReducerAndPrepare
-): Slice>
+ RestrictCaseReducerDefinitionsToMatchReducerAndPrepare
+): Slice
// internal definition is a little less restrictive
export function createSlice<
State,
- CaseReducers extends SliceCaseReducers
+ CaseReducers extends SliceCaseReducerDefinitions
>(
options: CreateSliceOptions
-): Slice> {
+): Slice {
const { name, initialState } = options
if (!name) {
throw new Error('`name` is a required option for createSlice')
}
const reducers = options.reducers || {}
const extraReducers = options.extraReducers || {}
- const actionKeys = Object.keys(reducers)
-
- const reducerMap = actionKeys.reduce((map, actionKey) => {
- let maybeEnhancedReducer = reducers[actionKey]
- map[getType(name, actionKey)] =
- typeof maybeEnhancedReducer === 'function'
- ? maybeEnhancedReducer
- : maybeEnhancedReducer.reducer
- return map
- }, extraReducers)
-
- const reducer = createReducer(initialState, reducerMap)
-
- const actionMap = actionKeys.reduce(
- (map, action) => {
- let maybeEnhancedReducer = reducers[action]
- const type = getType(name, action)
- map[action] =
- typeof maybeEnhancedReducer === 'function'
- ? createAction(type)
- : createAction(type, maybeEnhancedReducer.prepare)
- return map
- },
- {} as any
- )
+ const reducerNames = Object.keys(reducers)
+
+ const sliceCaseReducersByName: Record = {}
+ const sliceCaseReducersByType: Record = {}
+ const actionCreators: Record = {}
+
+ reducerNames.forEach(reducerName => {
+ const maybeReducerWithPrepare = reducers[reducerName]
+ const type = getType(name, reducerName)
+
+ let caseReducer: CaseReducer
+ let prepareCallback: PrepareAction | undefined
+
+ if (typeof maybeReducerWithPrepare === 'function') {
+ caseReducer = maybeReducerWithPrepare
+ } else {
+ caseReducer = maybeReducerWithPrepare.reducer
+ prepareCallback = maybeReducerWithPrepare.prepare
+ }
+
+ sliceCaseReducersByName[reducerName] = caseReducer
+ sliceCaseReducersByType[type] = caseReducer
+ actionCreators[reducerName] = prepareCallback
+ ? createAction(type, prepareCallback)
+ : createAction(type)
+ })
+
+ const finalCaseReducers = { ...extraReducers, ...sliceCaseReducersByType }
+ const reducer = createReducer(initialState, finalCaseReducers)
return {
name,
reducer,
- actions: actionMap
+ actions: actionCreators as any,
+ caseReducers: sliceCaseReducersByName as any
}
}
diff --git a/type-tests/files/createSlice.typetest.ts b/type-tests/files/createSlice.typetest.ts
index 95fd9b5982..ff6299d8d0 100644
--- a/type-tests/files/createSlice.typetest.ts
+++ b/type-tests/files/createSlice.typetest.ts
@@ -147,6 +147,57 @@ function expectType(t: T) {
expectType(counter.actions.concatMetaStrLen('test').meta)
}
+/*
+ * Test: returned case reducer has the correct type
+ */
+{
+ const counter = createSlice({
+ name: 'counter',
+ initialState: 0,
+ reducers: {
+ increment(state, action: PayloadAction) {
+ return state + action.payload
+ },
+ decrement: {
+ reducer(state, action: PayloadAction) {
+ return state - action.payload
+ },
+ prepare(amount: number) {
+ return { payload: amount }
+ }
+ }
+ }
+ })
+
+ // Should match positively
+ expectType<(state: number, action: PayloadAction) => number | void>(
+ counter.caseReducers.increment
+ )
+
+ // Should match positively for reducers with prepare callback
+ expectType<(state: number, action: PayloadAction) => number | void>(
+ counter.caseReducers.decrement
+ )
+
+ // Should not mismatch the payload if it's a simple reducer
+ // typings:expect-error
+ expectType<(state: number, action: PayloadAction) => number | void>(
+ counter.caseReducers.increment
+ )
+
+ // Should not mismatch the payload if it's a reducer with a prepare callback
+ // typings:expect-error
+ expectType<(state: number, action: PayloadAction) => number | void>(
+ counter.caseReducers.decrement
+ )
+
+ // Should not include entries that don't exist
+ // typings:expect-error
+ expectType<(state: number, action: PayloadAction) => number | void>(
+ counter.caseReducers.someThingNonExistant
+ )
+}
+
/*
* Test: prepared payload does not match action payload - should cause an error.
*/
@@ -180,6 +231,7 @@ function expectType(t: T) {
}
const mySlice = createSlice({
+ name: 'name',
initialState,
reducers: {
setName: (state, action) => {