diff --git a/src/execution/__tests__/defer-test.ts b/src/execution/__tests__/defer-test.ts index 5cad95bbc3..f9c0639306 100644 --- a/src/execution/__tests__/defer-test.ts +++ b/src/execution/__tests__/defer-test.ts @@ -605,11 +605,6 @@ describe('Execute: defer directive', () => { data: { slowField: 'slow', friends: [{}, {}, {}] }, path: ['hero'], }, - ], - hasNext: true, - }, - { - incremental: [ { data: { name: 'Han' }, path: ['hero', 'friends', 0] }, { data: { name: 'Leia' }, path: ['hero', 'friends', 1] }, { data: { name: 'C-3PO' }, path: ['hero', 'friends', 2] }, @@ -653,11 +648,6 @@ describe('Execute: defer directive', () => { }, path: ['hero'], }, - ], - hasNext: true, - }, - { - incremental: [ { data: { name: 'Han' }, path: ['hero', 'friends', 0] }, { data: { name: 'Leia' }, path: ['hero', 'friends', 1] }, { data: { name: 'C-3PO' }, path: ['hero', 'friends', 2] }, diff --git a/src/execution/__tests__/stream-test.ts b/src/execution/__tests__/stream-test.ts index aed5211ae1..65d7f67381 100644 --- a/src/execution/__tests__/stream-test.ts +++ b/src/execution/__tests__/stream-test.ts @@ -151,11 +151,10 @@ describe('Execute: stream directive', () => { hasNext: true, }, { - incremental: [{ items: ['banana'], path: ['scalarList', 1] }], - hasNext: true, - }, - { - incremental: [{ items: ['coconut'], path: ['scalarList', 2] }], + incremental: [ + { items: ['banana'], path: ['scalarList', 1] }, + { items: ['coconut'], path: ['scalarList', 2] }, + ], hasNext: false, }, ]); @@ -173,15 +172,11 @@ describe('Execute: stream directive', () => { hasNext: true, }, { - incremental: [{ items: ['apple'], path: ['scalarList', 0] }], - hasNext: true, - }, - { - incremental: [{ items: ['banana'], path: ['scalarList', 1] }], - hasNext: true, - }, - { - incremental: [{ items: ['coconut'], path: ['scalarList', 2] }], + incremental: [ + { items: ['apple'], path: ['scalarList', 0] }, + { items: ['banana'], path: ['scalarList', 1] }, + { items: ['coconut'], path: ['scalarList', 2] }, + ], hasNext: false, }, ]); @@ -230,11 +225,6 @@ describe('Execute: stream directive', () => { path: ['scalarList', 1], label: 'scalar-stream', }, - ], - hasNext: true, - }, - { - incremental: [ { items: ['coconut'], path: ['scalarList', 2], @@ -296,11 +286,6 @@ describe('Execute: stream directive', () => { items: [['banana', 'banana', 'banana']], path: ['scalarListList', 1], }, - ], - hasNext: true, - }, - { - incremental: [ { items: [['coconut', 'coconut', 'coconut']], path: ['scalarListList', 2], @@ -379,20 +364,10 @@ describe('Execute: stream directive', () => { items: [{ name: 'Luke', id: '1' }], path: ['friendList', 0], }, - ], - hasNext: true, - }, - { - incremental: [ { items: [{ name: 'Han', id: '2' }], path: ['friendList', 1], }, - ], - hasNext: true, - }, - { - incremental: [ { items: [{ name: 'Leia', id: '3' }], path: ['friendList', 2], @@ -531,11 +506,6 @@ describe('Execute: stream directive', () => { }, ], }, - ], - hasNext: true, - }, - { - incremental: [ { items: [{ name: 'Leia', id: '3' }], path: ['friendList', 2], @@ -707,7 +677,12 @@ describe('Execute: stream directive', () => { hasNext: true, }, }, - { done: false, value: { hasNext: false } }, + { + done: false, + value: { + hasNext: false, + }, + }, { done: true, value: undefined }, ]); }); @@ -935,11 +910,6 @@ describe('Execute: stream directive', () => { }, ], }, - ], - hasNext: true, - }, - { - incremental: [ { items: [{ nonNullName: 'Han' }], path: ['friendList', 2], @@ -984,11 +954,6 @@ describe('Execute: stream directive', () => { }, ], }, - ], - hasNext: true, - }, - { - incremental: [ { items: [{ nonNullName: 'Han' }], path: ['friendList', 2], @@ -1117,11 +1082,6 @@ describe('Execute: stream directive', () => { }, ], }, - ], - hasNext: true, - }, - { - incremental: [ { items: [{ nonNullName: 'Han' }], path: ['friendList', 2], @@ -1407,9 +1367,6 @@ describe('Execute: stream directive', () => { ], }, ], - hasNext: true, - }, - { hasNext: false, }, ]); @@ -1556,9 +1513,6 @@ describe('Execute: stream directive', () => { path: ['friendList', 2], }, ], - hasNext: true, - }, - { hasNext: false, }, ]); @@ -1612,15 +1566,6 @@ describe('Execute: stream directive', () => { data: { scalarField: 'slow', nestedFriendList: [] }, path: ['nestedObject'], }, - ], - hasNext: true, - }, - done: false, - }); - const result3 = await iterator.next(); - expectJSON(result3).toDeepEqual({ - value: { - incremental: [ { items: [{ name: 'Luke' }], path: ['nestedObject', 'nestedFriendList', 0], @@ -1630,8 +1575,8 @@ describe('Execute: stream directive', () => { }, done: false, }); - const result4 = await iterator.next(); - expectJSON(result4).toDeepEqual({ + const result3 = await iterator.next(); + expectJSON(result3).toDeepEqual({ value: { incremental: [ { @@ -1643,13 +1588,13 @@ describe('Execute: stream directive', () => { }, done: false, }); - const result5 = await iterator.next(); - expectJSON(result5).toDeepEqual({ + const result4 = await iterator.next(); + expectJSON(result4).toDeepEqual({ value: { hasNext: false }, done: false, }); - const result6 = await iterator.next(); - expectJSON(result6).toDeepEqual({ + const result5 = await iterator.next(); + expectJSON(result5).toDeepEqual({ value: undefined, done: true, }); diff --git a/src/execution/execute.ts b/src/execution/execute.ts index 1bc6c4267b..6ce7964736 100644 --- a/src/execution/execute.ts +++ b/src/execution/execute.ts @@ -53,6 +53,7 @@ import { collectSubfields as _collectSubfields, } from './collectFields.js'; import { mapAsyncIterable } from './mapAsyncIterable.js'; +import { Publisher } from './publisher.js'; import { getArgumentValues, getDirectiveValues, @@ -121,7 +122,11 @@ export interface ExecutionContext { typeResolver: GraphQLTypeResolver; subscribeFieldResolver: GraphQLFieldResolver; errors: Array; - subsequentPayloads: Set; + publisher: Publisher< + AsyncPayloadRecord, + IncrementalResult, + SubsequentIncrementalExecutionResult + >; } /** @@ -357,13 +362,14 @@ function executeImpl( return result.then( (data) => { const initialResult = buildResponse(data, exeContext.errors); - if (exeContext.subsequentPayloads.size > 0) { + const publisher = exeContext.publisher; + if (publisher.hasNext()) { return { initialResult: { ...initialResult, hasNext: true, }, - subsequentResults: yieldSubsequentPayloads(exeContext), + subsequentResults: publisher.subscribe(), }; } return initialResult; @@ -375,13 +381,14 @@ function executeImpl( ); } const initialResult = buildResponse(result, exeContext.errors); - if (exeContext.subsequentPayloads.size > 0) { + const publisher = exeContext.publisher; + if (publisher.hasNext()) { return { initialResult: { ...initialResult, hasNext: true, }, - subsequentResults: yieldSubsequentPayloads(exeContext), + subsequentResults: publisher.subscribe(), }; } return initialResult; @@ -503,7 +510,7 @@ export function buildExecutionContext( fieldResolver: fieldResolver ?? defaultFieldResolver, typeResolver: typeResolver ?? defaultTypeResolver, subscribeFieldResolver: subscribeFieldResolver ?? defaultFieldResolver, - subsequentPayloads: new Set(), + publisher: new Publisher(resultFromAsyncPayloadRecord, payloadFromResults), errors: [], }; } @@ -515,7 +522,7 @@ function buildPerEventExecutionContext( return { ...exeContext, rootValue: payload, - subsequentPayloads: new Set(), + publisher: new Publisher(resultFromAsyncPayloadRecord, payloadFromResults), errors: [], }; } @@ -1791,16 +1798,21 @@ function executeDeferredFragment( fields, asyncPayloadRecord, ); - - if (isPromise(promiseOrData)) { - promiseOrData = promiseOrData.then(null, (e) => { - asyncPayloadRecord.errors.push(e); - return null; - }); - } } catch (e) { asyncPayloadRecord.errors.push(e); - promiseOrData = null; + asyncPayloadRecord.addData(null); + return; + } + + if (isPromise(promiseOrData)) { + promiseOrData.then( + (value) => asyncPayloadRecord.addData(value), + (error) => { + asyncPayloadRecord.errors.push(error); + asyncPayloadRecord.addData(null); + }, + ); + return; } asyncPayloadRecord.addData(promiseOrData); } @@ -1823,7 +1835,7 @@ function executeStreamField( exeContext, }); if (isPromise(item)) { - const completedItems = completePromisedValue( + completePromisedValue( exeContext, itemType, fieldNodes, @@ -1832,15 +1844,14 @@ function executeStreamField( item, asyncPayloadRecord, ).then( - (value) => [value], + (value) => asyncPayloadRecord.addItems([value]), (error) => { asyncPayloadRecord.errors.push(error); filterSubsequentPayloads(exeContext, path, asyncPayloadRecord); - return null; + asyncPayloadRecord.addItems(null); }, ); - asyncPayloadRecord.addItems(completedItems); return asyncPayloadRecord; } @@ -1873,7 +1884,7 @@ function executeStreamField( } if (isPromise(completedItem)) { - const completedItems = completedItem + completedItem .then(undefined, (rawError) => { const error = locatedError(rawError, fieldNodes, pathToArray(itemPath)); const handledError = handleFieldError( @@ -1885,15 +1896,14 @@ function executeStreamField( return handledError; }) .then( - (value) => [value], + (value) => asyncPayloadRecord.addItems([value]), (error) => { asyncPayloadRecord.errors.push(error); filterSubsequentPayloads(exeContext, path, asyncPayloadRecord); - return null; + asyncPayloadRecord.addItems(null); }, ); - asyncPayloadRecord.addItems(completedItems); return asyncPayloadRecord; } @@ -2008,22 +2018,19 @@ async function executeStreamIterator( const { done, value: completedItem } = iteration; - let completedItems: PromiseOrValue | null>; if (isPromise(completedItem)) { - completedItems = completedItem.then( - (value) => [value], + completedItem.then( + (resolvedItem) => asyncPayloadRecord.addItems([resolvedItem]), (error) => { asyncPayloadRecord.errors.push(error); filterSubsequentPayloads(exeContext, path, asyncPayloadRecord); - return null; + asyncPayloadRecord.addItems(null); }, ); } else { - completedItems = [completedItem]; + asyncPayloadRecord.addItems([completedItem]); } - asyncPayloadRecord.addItems(completedItems); - if (done) { break; } @@ -2038,132 +2045,49 @@ function filterSubsequentPayloads( currentAsyncRecord: AsyncPayloadRecord | undefined, ): void { const nullPathArray = pathToArray(nullPath); - exeContext.subsequentPayloads.forEach((asyncRecord) => { + exeContext.publisher.filter((asyncRecord) => { if (asyncRecord === currentAsyncRecord) { // don't remove payload from where error originates - return; + return true; } for (let i = 0; i < nullPathArray.length; i++) { if (asyncRecord.path[i] !== nullPathArray[i]) { // asyncRecord points to a path unaffected by this payload - return; + return true; } } - // asyncRecord path points to nulled error field - if (isStreamPayload(asyncRecord) && asyncRecord.iterator?.return) { - asyncRecord.iterator.return().catch(() => { - // ignore error - }); - } - exeContext.subsequentPayloads.delete(asyncRecord); + + return false; }); } -function getCompletedIncrementalResults( - exeContext: ExecutionContext, -): Array { - const incrementalResults: Array = []; - for (const asyncPayloadRecord of exeContext.subsequentPayloads) { - const incrementalResult: IncrementalResult = {}; - if (!asyncPayloadRecord.isCompleted) { - continue; - } - exeContext.subsequentPayloads.delete(asyncPayloadRecord); - if (isStreamPayload(asyncPayloadRecord)) { - const items = asyncPayloadRecord.items; - if (asyncPayloadRecord.isCompletedIterator) { - // async iterable resolver just finished but there may be pending payloads - continue; - } - (incrementalResult as IncrementalStreamResult).items = items; - } else { - const data = asyncPayloadRecord.data; - (incrementalResult as IncrementalDeferResult).data = data ?? null; - } - - incrementalResult.path = asyncPayloadRecord.path; - if (asyncPayloadRecord.label) { - incrementalResult.label = asyncPayloadRecord.label; - } - if (asyncPayloadRecord.errors.length > 0) { - incrementalResult.errors = asyncPayloadRecord.errors; - } - incrementalResults.push(incrementalResult); +function resultFromAsyncPayloadRecord( + asyncPayloadRecord: AsyncPayloadRecord, +): IncrementalResult { + const incrementalResult: IncrementalResult = {}; + if (isStreamPayload(asyncPayloadRecord)) { + const items = asyncPayloadRecord.items; + (incrementalResult as IncrementalStreamResult).items = items; + } else { + const data = asyncPayloadRecord.data; + (incrementalResult as IncrementalDeferResult).data = data ?? null; } - return incrementalResults; -} - -function yieldSubsequentPayloads( - exeContext: ExecutionContext, -): AsyncGenerator { - let isDone = false; - - async function next(): Promise< - IteratorResult - > { - if (isDone) { - return { value: undefined, done: true }; - } - await Promise.race( - Array.from(exeContext.subsequentPayloads).map((p) => p.promise), - ); - - if (isDone) { - // a different call to next has exhausted all payloads - return { value: undefined, done: true }; - } - - const incremental = getCompletedIncrementalResults(exeContext); - const hasNext = exeContext.subsequentPayloads.size > 0; - - if (!incremental.length && hasNext) { - return next(); - } - - if (!hasNext) { - isDone = true; - } - - return { - value: incremental.length ? { incremental, hasNext } : { hasNext }, - done: false, - }; + incrementalResult.path = asyncPayloadRecord.path; + if (asyncPayloadRecord.label) { + incrementalResult.label = asyncPayloadRecord.label; } - - function returnStreamIterators() { - const promises: Array>> = []; - exeContext.subsequentPayloads.forEach((asyncPayloadRecord) => { - if ( - isStreamPayload(asyncPayloadRecord) && - asyncPayloadRecord.iterator?.return - ) { - promises.push(asyncPayloadRecord.iterator.return()); - } - }); - return Promise.all(promises); + if (asyncPayloadRecord.errors.length > 0) { + incrementalResult.errors = asyncPayloadRecord.errors; } + return incrementalResult; +} - return { - [Symbol.asyncIterator]() { - return this; - }, - next, - async return(): Promise< - IteratorResult - > { - await returnStreamIterators(); - isDone = true; - return { value: undefined, done: true }; - }, - async throw( - error?: unknown, - ): Promise> { - await returnStreamIterators(); - isDone = true; - return Promise.reject(error); - }, - }; +function payloadFromResults( + incremental: ReadonlyArray, + hasNext: boolean, +): SubsequentIncrementalExecutionResult { + return incremental.length ? { incremental, hasNext } : { hasNext }; } class DeferredFragmentRecord { @@ -2171,12 +2095,14 @@ class DeferredFragmentRecord { errors: Array; label: string | undefined; path: Array; - promise: Promise; data: ObjMap | null; parentContext: AsyncPayloadRecord | undefined; - isCompleted: boolean; - _exeContext: ExecutionContext; - _resolve?: (arg: PromiseOrValue | null>) => void; + _publisher: Publisher< + AsyncPayloadRecord, + IncrementalResult, + SubsequentIncrementalExecutionResult + >; + constructor(opts: { label: string | undefined; path: Path | undefined; @@ -2188,27 +2114,14 @@ class DeferredFragmentRecord { this.path = pathToArray(opts.path); this.parentContext = opts.parentContext; this.errors = []; - this._exeContext = opts.exeContext; - this._exeContext.subsequentPayloads.add(this); - this.isCompleted = false; + this._publisher = opts.exeContext.publisher; + this._publisher.add(this); this.data = null; - this.promise = new Promise | null>((resolve) => { - this._resolve = (promiseOrValue) => { - resolve(promiseOrValue); - }; - }).then((data) => { - this.data = data; - this.isCompleted = true; - }); } - addData(data: PromiseOrValue | null>) { - const parentData = this.parentContext?.promise; - if (parentData) { - this._resolve?.(parentData.then(() => data)); - return; - } - this._resolve?.(data); + addData(data: ObjMap | null) { + this.data = data; + this._publisher.complete(this); } } @@ -2218,13 +2131,15 @@ class StreamRecord { label: string | undefined; path: Array; items: Array | null; - promise: Promise; parentContext: AsyncPayloadRecord | undefined; iterator: AsyncIterator | undefined; isCompletedIterator?: boolean; - isCompleted: boolean; - _exeContext: ExecutionContext; - _resolve?: (arg: PromiseOrValue | null>) => void; + _publisher: Publisher< + AsyncPayloadRecord, + IncrementalResult, + SubsequentIncrementalExecutionResult + >; + constructor(opts: { label: string | undefined; path: Path | undefined; @@ -2239,27 +2154,14 @@ class StreamRecord { this.parentContext = opts.parentContext; this.iterator = opts.iterator; this.errors = []; - this._exeContext = opts.exeContext; - this._exeContext.subsequentPayloads.add(this); - this.isCompleted = false; + this._publisher = opts.exeContext.publisher; + this._publisher.add(this); this.items = null; - this.promise = new Promise | null>((resolve) => { - this._resolve = (promiseOrValue) => { - resolve(promiseOrValue); - }; - }).then((items) => { - this.items = items; - this.isCompleted = true; - }); } - addItems(items: PromiseOrValue | null>) { - const parentData = this.parentContext?.promise; - if (parentData) { - this._resolve?.(parentData.then(() => items)); - return; - } - this._resolve?.(items); + addItems(items: Array | null) { + this.items = items; + this._publisher.complete(this); } setIsCompletedIterator() { diff --git a/src/execution/publisher.ts b/src/execution/publisher.ts new file mode 100644 index 0000000000..7fe4f37074 --- /dev/null +++ b/src/execution/publisher.ts @@ -0,0 +1,252 @@ +interface Source { + parentContext: this | undefined; + isCompletedIterator?: boolean | undefined; + iterator?: AsyncIterator | undefined; +} + +interface HasParent { + parentContext: T; +} + +function hasParent(value: T): value is T & HasParent { + return (value as HasParent).parentContext !== undefined; +} + +type ToIncrementalResult = ( + source: TSource, +) => TIncremental; + +type ToPayload = ( + incremental: ReadonlyArray, + hasNext: boolean, +) => TPayload; + +/** + * @internal + */ +export class Publisher { + // This is safe because a promise executor within the constructor will assign this. + trigger!: () => void; + signal: Promise; + pending: Set; + waiting: Set>; + waitingByParent: Map>>; + pushed: WeakSet; + current: Set; + toIncrementalResult: ToIncrementalResult; + toPayload: ToPayload; + + constructor( + toIncrementalResult: ToIncrementalResult, + toPayload: ToPayload, + ) { + this.signal = new Promise((resolve) => { + this.trigger = resolve; + }); + this.pending = new Set(); + this.waiting = new Set(); + this.waitingByParent = new Map(); + this.pushed = new WeakSet(); + this.current = new Set(); + this.toIncrementalResult = toIncrementalResult; + this.toPayload = toPayload; + } + + add(source: TSource): void { + this.pending.add(source); + } + + complete(source: TSource): void { + // if source has been filtered, ignore completion + if (!this.pending.has(source)) { + return; + } + + this.pending.delete(source); + + if (!hasParent(source)) { + this._push(source); + this.trigger(); + return; + } + + const parentContext = source.parentContext; + if (this.pushed.has(source.parentContext)) { + this._push(source); + this.trigger(); + return; + } + + this.waiting.add(source); + + const waitingByParent = this.waitingByParent.get(parentContext); + if (waitingByParent) { + waitingByParent.add(source); + return; + } + + this.waitingByParent.set(parentContext, new Set([source])); + } + + _push(source: TSource): void { + this.pushed.add(source); + this.current.add(source); + + const waitingByParent = this.waitingByParent.get(source); + if (waitingByParent === undefined) { + return; + } + + for (const child of waitingByParent) { + this.waitingByParent.delete(child); + this.waiting.delete(child); + this._push(child); + } + } + + hasNext(): boolean { + return ( + this.pending.size > 0 || this.waiting.size > 0 || this.current.size > 0 + ); + } + + filter(predicate: (source: TSource) => boolean): void { + const iterators = new Set>(); + for (const set of [this.pending, this.current]) { + set.forEach((source) => { + if (predicate(source)) { + return; + } + if (source.iterator?.return) { + iterators.add(source.iterator); + } + set.delete(source); + }); + } + + this.waiting.forEach((source) => { + if (predicate(source)) { + return; + } + + if (source.iterator?.return) { + iterators.add(source.iterator); + } + + this.waiting.delete(source); + + const parentContext = source.parentContext; + const children = this.waitingByParent.get(parentContext); + // TODO: children can never be undefined, but TS doesn't know that + children?.delete(source); + }); + + for (const iterator of iterators) { + iterator.return?.().catch(() => { + // ignore error + }); + } + } + + _getCompletedIncrementalResults(): Array { + const incrementalResults: Array = []; + for (const source of this.current) { + this.current.delete(source); + if (source.isCompletedIterator) { + continue; + } + incrementalResults.push(this.toIncrementalResult(source)); + } + return incrementalResults; + } + + subscribe(): AsyncGenerator { + let isDone = false; + + const next = async (): Promise> => { + if (isDone) { + return { value: undefined, done: true }; + } + + const incremental = this._getCompletedIncrementalResults(); + if (!incremental.length) { + return onSignal(); + } + + const hasNext = this.hasNext(); + + if (!hasNext) { + isDone = true; + } + + return { + value: this.toPayload(incremental, hasNext), + done: false, + }; + }; + + const onSignal = async (): Promise> => { + await this.signal; + + if (isDone) { + return { value: undefined, done: true }; + } + + const incremental = this._getCompletedIncrementalResults(); + + this.signal = new Promise((resolve) => { + this.trigger = resolve; + }); + + const hasNext = this.hasNext(); + if (!incremental.length && hasNext) { + return onSignal(); + } + + if (!hasNext) { + isDone = true; + } + + return { + value: this.toPayload(incremental, hasNext), + done: false, + }; + }; + + const returnIterators = () => { + const iterators = new Set>(); + for (const set of [this.pending, this.waiting, this.current]) { + for (const source of set) { + if (source.iterator?.return) { + iterators.add(source.iterator); + } + } + } + + const promises: Array>> = []; + for (const iterator of iterators) { + if (iterator?.return) { + promises.push(iterator.return()); + } + } + return Promise.all(promises); + }; + + return { + [Symbol.asyncIterator]() { + return this; + }, + next, + async return(): Promise> { + isDone = true; + await returnIterators(); + return { value: undefined, done: true }; + }, + async throw(error?: unknown): Promise> { + isDone = true; + await returnIterators(); + return Promise.reject(error); + }, + }; + } +}