diff --git a/packages/toolkit/src/combineSlices.ts b/packages/toolkit/src/combineSlices.ts index 58d07ece28..1a10538cd8 100644 --- a/packages/toolkit/src/combineSlices.ts +++ b/packages/toolkit/src/combineSlices.ts @@ -8,6 +8,7 @@ import type { UnionToIntersection, WithOptionalProp, } from './tsHelpers' +import { emplace } from './utils' type SliceLike = { reducerPath: ReducerPath @@ -330,37 +331,34 @@ const stateProxyMap = new WeakMap() const createStateProxy = ( state: State, reducerMap: Partial> -) => { - let proxy = stateProxyMap.get(state) - if (!proxy) { - proxy = new Proxy(state, { - get: (target, prop, receiver) => { - if (prop === ORIGINAL_STATE) return target - const result = Reflect.get(target, prop, receiver) - if (typeof result === 'undefined') { - const reducer = reducerMap[prop.toString()] - if (reducer) { - // ensure action type is random, to prevent reducer treating it differently - const reducerResult = reducer(undefined, { type: nanoid() }) - if (typeof reducerResult === 'undefined') { - throw new Error( - `The slice reducer for key "${prop.toString()}" returned undefined when called for selector(). ` + - `If the state passed to the reducer is undefined, you must ` + - `explicitly return the initial state. The initial state may ` + - `not be undefined. If you don't want to set a value for this reducer, ` + - `you can use null instead of undefined.` - ) +) => + emplace(stateProxyMap, state, { + insert: () => + new Proxy(state, { + get: (target, prop, receiver) => { + if (prop === ORIGINAL_STATE) return target + const result = Reflect.get(target, prop, receiver) + if (typeof result === 'undefined') { + const reducer = reducerMap[prop.toString()] + if (reducer) { + // ensure action type is random, to prevent reducer treating it differently + const reducerResult = reducer(undefined, { type: nanoid() }) + if (typeof reducerResult === 'undefined') { + throw new Error( + `The slice reducer for key "${prop.toString()}" returned undefined when called for selector(). ` + + `If the state passed to the reducer is undefined, you must ` + + `explicitly return the initial state. The initial state may ` + + `not be undefined. If you don't want to set a value for this reducer, ` + + `you can use null instead of undefined.` + ) + } + return reducerResult } - return reducerResult } - } - return result - }, - }) - stateProxyMap.set(state, proxy) - } - return proxy as State -} + return result + }, + }), + }) as State const original = (state: any) => { if (!isStateProxy(state)) { diff --git a/packages/toolkit/src/createSlice.ts b/packages/toolkit/src/createSlice.ts index bc43c15ad4..4f8c1049ef 100644 --- a/packages/toolkit/src/createSlice.ts +++ b/packages/toolkit/src/createSlice.ts @@ -1,4 +1,5 @@ import type { Action, UnknownAction, Reducer } from 'redux' +import type { Selector } from 'reselect' import type { ActionCreatorWithoutPayload, PayloadAction, @@ -15,7 +16,7 @@ import type { import { createReducer } from './createReducer' import type { ActionReducerMapBuilder } from './mapBuilders' import { executeReducerBuilderCallback } from './mapBuilders' -import type { Id, Tail } from './tsHelpers' +import type { Id } from './tsHelpers' import type { InjectConfig } from './combineSlices' import type { AsyncThunk, @@ -25,6 +26,7 @@ import type { OverrideThunkApiConfigs, } from './createAsyncThunk' import { createAsyncThunk } from './createAsyncThunk' +import { capitalize, emplace } from './utils' interface InjectIntoConfig extends InjectConfig { reducerPath?: NewReducerPath @@ -40,7 +42,8 @@ export interface Slice< CaseReducers extends SliceCaseReducers = SliceCaseReducers, Name extends string = string, ReducerPath extends string = Name, - Selectors extends SliceSelectors = SliceSelectors + Selectors extends SliceSelectors = SliceSelectors, + SelectorFactories extends SliceSelectorFactories = SliceSelectorFactories > { /** * The slice name. @@ -97,6 +100,34 @@ export interface Slice< SliceDefinedSelectors > + /** + * Get localised slice selectors (expects to be called with *just* the slice's state as the first parameter) + */ + getSelectorFactories( + this: this + ): Id> + + /** + * Get globalised slice selector factories (`selectState` callback is expected to receive first parameter and return slice state) + */ + getSelectorFactories( + this: this, + selectState: (rootState: RootState) => State + ): Id> + + /** + * Selector factories that assume the slice's state is `rootState[slice.reducerPath]` (which is usually the case) + * + * Equivalent to `slice.getSelectors((state: RootState) => state[slice.reducerPath])`. + */ + selectorFactories: Id< + SliceDefinedSelectorFactories< + State, + SelectorFactories, + { [K in ReducerPath]: State } + > + > + /** * Inject slice into provided reducer (return value from `combineSlices`), and return injected slice. */ @@ -129,10 +160,11 @@ interface InjectedSlice< CaseReducers extends SliceCaseReducers = SliceCaseReducers, Name extends string = string, ReducerPath extends string = Name, - Selectors extends SliceSelectors = SliceSelectors + Selectors extends SliceSelectors = SliceSelectors, + SelectorFactories extends SliceSelectorFactories = SliceSelectorFactories > extends Omit< Slice, - 'getSelectors' | 'selectors' + 'getSelectors' | 'selectors' | 'getSelectorFactories' | 'selectorFactories' > { /** * Get localised slice selectors (expects to be called with *just* the slice's state as the first parameter) @@ -159,6 +191,33 @@ interface InjectedSlice< > > + /** + * Get localised slice selectors (expects to be called with *just* the slice's state as the first parameter) + */ + getSelectorFactories(): Id< + SliceDefinedSelectorFactories + > + + /** + * Get globalised slice selector factories (`selectState` callback is expected to receive first parameter and return slice state) + */ + getSelectorFactories( + selectState: (rootState: RootState) => State + ): Id> + + /** + * Selector factories that assume the slice's state is `rootState[slice.reducerPath]` (which is usually the case) + * + * Equivalent to `slice.getSelectors((state: RootState) => state[slice.reducerPath])`. + */ + selectorFactories: Id< + SliceDefinedSelectorFactories< + State, + SelectorFactories, + { [K in ReducerPath]: State } + > + > + /** * Select the slice state, using the slice's current reducerPath. * @@ -177,7 +236,8 @@ export interface CreateSliceOptions< CR extends SliceCaseReducers = SliceCaseReducers, Name extends string = string, ReducerPath extends string = Name, - Selectors extends SliceSelectors = SliceSelectors + Selectors extends SliceSelectors = SliceSelectors, + SelectorFactories extends SliceSelectorFactories = SliceSelectorFactories > { /** * The slice's name. Used to namespace the generated action types. @@ -251,6 +311,8 @@ createSlice({ * A map of selectors that receive the slice's state and any additional arguments, and return a result. */ selectors?: Selectors + + selectorFactories?: SelectorFactories } export enum ReducerType { @@ -434,7 +496,27 @@ export type SliceCaseReducers = * The type describing a slice's `selectors` option. */ export type SliceSelectors = { - [K: string]: (sliceState: State, ...args: any[]) => any + [K: string]: Selector +} + +export type SelectorFactory< + FactoryParams extends readonly any[] = any[], + State = any, + Result = unknown, + Params extends readonly any[] = any[] +> = (...args: FactoryParams) => Selector + +export type RemappedSelectorFactory< + SF extends SelectorFactory, + NewState +> = SF extends SelectorFactory + ? ((...args: FP) => RemappedSelector, NewState>) & { + unwrapped: SF + } + : never + +export type SliceSelectorFactories = { + [K: string]: SelectorFactory } type SliceActionType< @@ -533,10 +615,36 @@ type SliceDefinedSelectors< Selectors extends SliceSelectors, RootState > = { - [K in keyof Selectors as string extends K ? never : K]: ( - rootState: RootState, - ...args: Tail> - ) => ReturnType + [K in keyof Selectors as string extends K ? never : K]: RemappedSelector< + Selectors[K], + RootState + > +} + +type RemappedSelector = S extends Selector< + any, + infer R, + infer P +> + ? Selector & { unwrapped: S } + : never + +/** + * Extracts the final selector factories type from the `selectorFactories` object. + * + * Removes the `string` index signature from the default value. + */ +type SliceDefinedSelectorFactories< + State, + SelectorFactories extends SliceSelectorFactories, + RootState +> = { + [K in keyof SelectorFactories as string extends K + ? never + : `make${Capitalize}`]: RemappedSelectorFactory< + SelectorFactories[K], + RootState + > } /** @@ -582,10 +690,18 @@ export function createSlice< CaseReducers extends SliceCaseReducers, Name extends string, Selectors extends SliceSelectors, + SelectorFactories extends SliceSelectorFactories, ReducerPath extends string = Name >( - options: CreateSliceOptions -): Slice { + options: CreateSliceOptions< + State, + CaseReducers, + Name, + ReducerPath, + Selectors, + SelectorFactories + > +): Slice { const { name, reducerPath = name as unknown as ReducerPath } = options if (!name) { throw new Error('`name` is a required option for createSlice') @@ -681,14 +797,29 @@ export function createSlice< const injectedSelectorCache = new WeakMap< Slice, WeakMap< - (rootState: any) => State | undefined, - Record any> + Selector, + Record> + > + >() + + const injectedSelectorFactoryCache = new WeakMap< + Slice, + WeakMap< + Selector, + Record> > >() let _reducer: ReducerWithInitialState - const slice: Slice = { + const slice: Slice< + State, + CaseReducers, + Name, + ReducerPath, + Selectors, + SelectorFactories + > = { name, reducerPath, reducer(state, action) { @@ -703,35 +834,26 @@ export function createSlice< return _reducer.getInitialState() }, - getSelectors(selectState: (rootState: any) => State = selectSelf) { - let selectorCache = injectedSelectorCache.get(this) - if (!selectorCache) { - selectorCache = new WeakMap() - injectedSelectorCache.set(this, selectorCache) - } - let cached = selectorCache.get(selectState) - if (!cached) { - cached = {} - for (const [name, selector] of Object.entries( - options.selectors ?? {} - )) { - cached[name] = (rootState: any, ...args: any[]) => { - let sliceState = selectState.call(this, rootState) - if (typeof sliceState === 'undefined') { - // check if injectInto has been called - if (this !== slice) { - sliceState = this.getInitialState() - } else if (process.env.NODE_ENV !== 'production') { - throw new Error( - 'selectState returned undefined for an uninjected slice reducer' - ) - } - } - return selector(sliceState, ...args) + getSelectors(selectState: Selector = selectSelf) { + const selectorCache = emplace(injectedSelectorCache, this, { + insert: () => new WeakMap(), + }) + const cached = emplace(selectorCache, selectState, { + insert: () => { + const map: Record> = {} + for (const [name, selector] of Object.entries( + options.selectors ?? {} + )) { + map[name] = wrapSelector( + this, + selector, + selectState, + this !== slice + ) } - } - selectorCache.set(selectState, cached) - } + return map + }, + }) return cached as any }, selectSlice(state) { @@ -759,10 +881,59 @@ export function createSlice< reducerPath, } as any }, + getSelectorFactories(selectState: Selector = selectSelf) { + const selectorCache = emplace(injectedSelectorFactoryCache, this, { + insert: () => new WeakMap(), + }) + const cached = emplace(selectorCache, selectState, { + insert: () => { + const map: Record> = {} + for (const [name, selectorFactory] of Object.entries( + options.selectorFactories ?? {} + )) { + map[`make${capitalize(name)}`] = Object.assign( + (...args: any[]) => { + const selector = selectorFactory(...args) + return wrapSelector(this, selector, selectState, this !== slice) + }, + { unwrapped: selectorFactory } + ) + } + return map + }, + }) + return cached as any + }, + get selectorFactories() { + return this.getSelectorFactories(this.selectSlice) + }, } return slice } +function wrapSelector>( + slice: Slice, + selector: S, + selectState: Selector, + injected?: boolean +) { + function wrapper(rootState: NewState, ...args: any[]) { + let sliceState = selectState.call(slice, rootState) + if (typeof sliceState === 'undefined') { + if (injected) { + sliceState = slice.getInitialState() + } else if (process.env.NODE_ENV !== 'production') { + throw new Error( + 'selectState returned undefined for an uninjected slice reducer' + ) + } + } + return selector(sliceState, ...args) + } + wrapper.unwrapped = selector + return wrapper as RemappedSelector +} + interface ReducerHandlingContext { sliceCaseReducersByName: Record< string, diff --git a/packages/toolkit/src/dynamicMiddleware/index.ts b/packages/toolkit/src/dynamicMiddleware/index.ts index 12f5b6522a..e05c16c274 100644 --- a/packages/toolkit/src/dynamicMiddleware/index.ts +++ b/packages/toolkit/src/dynamicMiddleware/index.ts @@ -7,7 +7,7 @@ import { compose } from 'redux' import { createAction, isAction } from '../createAction' import { isAllOf } from '../matchers' import { nanoid } from '../nanoid' -import { find } from '../utils' +import { find, emplace } from '../utils' import type { WithMiddleware, AddMiddleware, @@ -68,16 +68,9 @@ export const createDynamicMiddleware = < { withTypes: () => addMiddleware } ) as AddMiddleware - const getFinalMiddleware: Middleware<{}, State, Dispatch> = (api) => { - const appliedMiddleware = Array.from(middlewareMap.values()).map( - (entry) => { - let applied = entry.applied.get(api) - if (!applied) { - applied = entry.middleware(api) - entry.applied.set(api, applied) - } - return applied - } + const currentMiddleware: Middleware<{}, State, Dispatch> = (api) => { + const appliedMiddleware = Array.from(middlewareMap.values()).map((entry) => + emplace(entry.applied, api, { insert: () => entry.middleware(api) }) ) return compose(...appliedMiddleware) } @@ -94,7 +87,7 @@ export const createDynamicMiddleware = < addMiddleware(...action.payload) return api.dispatch } - return getFinalMiddleware(api)(next)(action) + return currentMiddleware(api)(next)(action) } return { diff --git a/packages/toolkit/src/tests/createSlice.test.ts b/packages/toolkit/src/tests/createSlice.test.ts index 4755a4e5dc..a25b8464df 100644 --- a/packages/toolkit/src/tests/createSlice.test.ts +++ b/packages/toolkit/src/tests/createSlice.test.ts @@ -1,5 +1,6 @@ import { vi } from 'vitest' import type { PayloadAction, WithSlice } from '@reduxjs/toolkit' +import { createSelector } from '@reduxjs/toolkit' import { configureStore, combineSlices, @@ -464,12 +465,15 @@ describe('createSlice', () => { reducers: {}, selectors: { selectSlice: (state) => state, - selectMultiple: (state, multiplier: number) => state * multiplier, + selectMultiple: Object.assign( + (state: number, multiplier: number) => state * multiplier, + { test: 0 } + ), }, }) - it('expects reducer under slice.name if no selectState callback passed', () => { + it('expects reducer under slice.reducerPath if no selectState callback passed', () => { const testState = { - [slice.name]: slice.getInitialState(), + [slice.reducerPath]: slice.getInitialState(), } const { selectSlice, selectMultiple } = slice.selectors expect(selectSlice(testState)).toBe(slice.getInitialState()) @@ -485,6 +489,64 @@ describe('createSlice', () => { expect(selectSlice(customState)).toBe(slice.getInitialState()) expect(selectMultiple(customState, 2)).toBe(slice.getInitialState() * 2) }) + it('allows accessing properties on the selector', () => { + expect(slice.selectors.selectMultiple.unwrapped.test).toBe(0) + }) + }) + describe('slice selector factories', () => { + const slice = createSlice({ + name: 'counter', + initialState: 42, + reducers: {}, + selectorFactories: { + selectMultiple: (multiple: number) => (value) => multiple * value, + selectMemoized: () => + createSelector( + (value: number) => value, + (value) => ({ value }) + ), + }, + }) + it('expects reducer under slice.reducerPath if no selectState callback passed', () => { + const testState = { + [slice.reducerPath]: slice.getInitialState(), + } + const { makeSelectMultiple, makeSelectMemoized } = slice.selectorFactories + const selectMemoized = makeSelectMemoized() + expect(selectMemoized(testState)).toEqual({ + value: slice.getInitialState(), + }) + + const selectDouble = makeSelectMultiple(2) + expect(selectDouble(testState)).toBe(slice.getInitialState() * 2) + }) + it('allows passing a selector for a custom location', () => { + const customState = { + number: slice.getInitialState(), + } + const { makeSelectMemoized, makeSelectMultiple } = + slice.getSelectorFactories((state: typeof customState) => state.number) + const selectMemoized = makeSelectMemoized() + expect(selectMemoized(customState)).toEqual({ + value: slice.getInitialState(), + }) + + const selectDouble = makeSelectMultiple(2) + expect(selectDouble(customState)).toBe(slice.getInitialState() * 2) + }) + it('creates a new instance per call, and allows accessing properties on the selector', () => { + const testState = { + [slice.reducerPath]: slice.getInitialState(), + } + const { makeSelectMemoized } = slice.selectorFactories + const selectMemoized1 = makeSelectMemoized() + const selectMemoized2 = makeSelectMemoized() + expect(selectMemoized1).not.toBe(selectMemoized2) + + const result = selectMemoized1(testState) + + expect(selectMemoized1.unwrapped.lastResult()).toBe(result) + }) }) describe('slice injections', () => { it('uses injectInto to inject slice into combined reducer', () => { diff --git a/packages/toolkit/src/tests/createSlice.typetest.ts b/packages/toolkit/src/tests/createSlice.typetest.ts index 94ad5b773c..4ee6bdd13a 100644 --- a/packages/toolkit/src/tests/createSlice.typetest.ts +++ b/packages/toolkit/src/tests/createSlice.typetest.ts @@ -16,7 +16,7 @@ import type { ThunkDispatch, ValidateSliceCaseReducers, } from '@reduxjs/toolkit' -import { configureStore, isRejected } from '@reduxjs/toolkit' +import { createSelector, configureStore, isRejected } from '@reduxjs/toolkit' import { createAction, createSlice } from '@reduxjs/toolkit' import { expectExactType, expectType, expectUnknown } from './helpers' import { castDraft } from 'immer' @@ -561,6 +561,49 @@ const value = actionCreators.anyKey expectType(nestedSelectors.selectToFixed(nestedState)) } +/** + * Test: selector factories + */ +{ + const sliceWithSelectors = createSlice({ + name: 'counter', + initialState: { value: 0 }, + reducers: { + increment: (state) => { + state.value += 1 + }, + }, + selectors: { + selectValue: (state) => state.value, + }, + selectorFactories: { + selectMultiply: () => + createSelector( + (state: { value: number }) => state.value, + (_: unknown, multiplier: number) => multiplier, + (value, multiplier) => value * multiplier + ), + selectToFixed: Object.assign( + (dp: number) => + createSelector( + (state: { value: number }) => state.value, + (value) => value.toFixed(dp) + ), + { test: 0 } + ), + }, + }) + + const selectMultiply = + sliceWithSelectors.selectorFactories.makeSelectMultiply() + expectType( + sliceWithSelectors.selectorFactories.makeSelectToFixed.unwrapped.test + ) + const selectToFixed = sliceWithSelectors + .getSelectorFactories() + .makeSelectToFixed(2) +} + /** * Test: reducer callback */ diff --git a/packages/toolkit/src/utils.ts b/packages/toolkit/src/utils.ts index 2e29b2bc9e..db075a97eb 100644 --- a/packages/toolkit/src/utils.ts +++ b/packages/toolkit/src/utils.ts @@ -87,3 +87,83 @@ export class Tuple = []> extends Array< export function freezeDraftable(val: T) { return isDraftable(val) ? createNextState(val, () => {}) : val } + +export function capitalize(str: string) { + return str.replace(str[0], str[0].toUpperCase()) +} + +interface WeakMapEmplaceHandler { + /** + * Will be called to get value, if no value is currently in map. + */ + insert?(key: K, map: WeakMap): V + /** + * Will be called to update a value, if one exists already. + */ + update?(previous: V, key: K, map: WeakMap): V +} + +interface MapEmplaceHandler { + /** + * Will be called to get value, if no value is currently in map. + */ + insert?(key: K, map: Map): V + /** + * Will be called to update a value, if one exists already. + */ + update?(previous: V, key: K, map: Map): V +} + +export function emplace( + map: Map, + key: K, + handler: MapEmplaceHandler +): V +export function emplace( + map: WeakMap, + key: K, + handler: WeakMapEmplaceHandler +): V +/** + * Allow inserting a new value, or updating an existing one + * @throws if called for a key with no current value and no `insert` handler is provided + * @returns current value in map (after insertion/updating) + * ```ts + * // return current value if already in map, otherwise initialise to 0 and return that + * const num = emplace(map, key, { + * insert: () => 0 + * }) + * + * // increase current value by one if already in map, otherwise initialise to 0 + * const num = emplace(map, key, { + * update: (n) => n + 1, + * insert: () => 0, + * }) + * + * // only update if value's already in the map - and increase it by one + * if (map.has(key)) { + * const num = emplace(map, key, { + * update: (n) => n + 1, + * }) + * } + * ``` + */ +export function emplace( + map: WeakMap, + key: K, + handler: WeakMapEmplaceHandler +): V { + if (map.has(key)) { + let value = map.get(key) as V + if (handler.update) { + value = handler.update(value, key, map) + map.set(key, value) + } + return value + } + if (!handler.insert) + throw new Error('No insert provided for key not already in map') + const inserted = handler.insert(key, map) + map.set(key, inserted) + return inserted +}