Skip to content

Commit d89e2de

Browse files
committed
align return types from execution and subscription
with respect to possible promises
1 parent 75d3061 commit d89e2de

File tree

2 files changed

+107
-39
lines changed

2 files changed

+107
-39
lines changed

src/execution/__tests__/subscribe-test.ts

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import { expectJSON } from '../../__testUtils__/expectJSON';
55
import { resolveOnNextTick } from '../../__testUtils__/resolveOnNextTick';
66

77
import { isAsyncIterable } from '../../jsutils/isAsyncIterable';
8+
import { isPromise } from '../../jsutils/isPromise';
89

910
import { parse } from '../../language/parser';
1011

@@ -135,9 +136,6 @@ async function expectPromise(promise: Promise<unknown>) {
135136
}
136137

137138
return {
138-
toReject() {
139-
expect(caughtError).to.be.an.instanceOf(Error);
140-
},
141139
toRejectWith(message: string) {
142140
expect(caughtError).to.be.an.instanceOf(Error);
143141
expect(caughtError).to.have.property('message', message);
@@ -379,24 +377,22 @@ describe('Subscription Initialization Phase', () => {
379377
});
380378

381379
// @ts-expect-error (schema must not be null)
382-
(await expectPromise(subscribe({ schema: null, document }))).toRejectWith(
380+
expect(() => subscribe({ schema: null, document })).to.throw(
383381
'Expected null to be a GraphQL schema.',
384382
);
385383

386384
// @ts-expect-error
387-
(await expectPromise(subscribe({ document }))).toRejectWith(
385+
expect(() => subscribe({ document })).to.throw(
388386
'Expected undefined to be a GraphQL schema.',
389387
);
390388

391389
// @ts-expect-error (document must not be null)
392-
(await expectPromise(subscribe({ schema, document: null }))).toRejectWith(
390+
expect(() => subscribe({ schema, document: null })).to.throw(
393391
'Must provide document.',
394392
);
395393

396394
// @ts-expect-error
397-
(await expectPromise(subscribe({ schema }))).toRejectWith(
398-
'Must provide document.',
399-
);
395+
expect(() => subscribe({ schema })).to.throw('Must provide document.');
400396
});
401397

402398
it('resolves to an error if schema does not support subscriptions', async () => {
@@ -450,11 +446,17 @@ describe('Subscription Initialization Phase', () => {
450446
});
451447

452448
// @ts-expect-error
453-
(await expectPromise(subscribe({ schema, document: {} }))).toReject();
449+
expect(() => subscribe({ schema, document: {} })).to.throw();
454450
});
455451

456452
it('throws an error if subscribe does not return an iterator', async () => {
457-
(await expectPromise(subscribeWithBadFn(() => 'test'))).toRejectWith(
453+
expect(() => subscribeWithBadFn(() => 'test')).to.throw(
454+
'Subscription field must return Async Iterable. Received: "test".',
455+
);
456+
457+
const result = subscribeWithBadFn(() => Promise.resolve('test'));
458+
assert(isPromise(result));
459+
(await expectPromise(result)).toRejectWith(
458460
'Subscription field must return Async Iterable. Received: "test".',
459461
);
460462
});
@@ -472,12 +474,12 @@ describe('Subscription Initialization Phase', () => {
472474

473475
expectJSON(
474476
// Returning an error
475-
await subscribeWithBadFn(() => new Error('test error')),
477+
subscribeWithBadFn(() => new Error('test error')),
476478
).toDeepEqual(expectedResult);
477479

478480
expectJSON(
479481
// Throwing an error
480-
await subscribeWithBadFn(() => {
482+
subscribeWithBadFn(() => {
481483
throw new Error('test error');
482484
}),
483485
).toDeepEqual(expectedResult);

src/execution/subscribe.ts

Lines changed: 92 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import { inspect } from '../jsutils/inspect';
22
import { isAsyncIterable } from '../jsutils/isAsyncIterable';
3+
import { isPromise } from '../jsutils/isPromise';
34
import type { Maybe } from '../jsutils/Maybe';
45
import { addPath, pathToArray } from '../jsutils/Path';
6+
import type { PromiseOrValue } from '../jsutils/PromiseOrValue';
57

68
import { GraphQLError } from '../error/GraphQLError';
79
import { locatedError } from '../error/locatedError';
@@ -47,9 +49,11 @@ import { getArgumentValues } from './values';
4749
*
4850
* Accepts either an object with named arguments, or individual arguments.
4951
*/
50-
export async function subscribe(
52+
export function subscribe(
5153
args: ExecutionArgs,
52-
): Promise<AsyncGenerator<ExecutionResult, void, void> | ExecutionResult> {
54+
): PromiseOrValue<
55+
AsyncGenerator<ExecutionResult, void, void> | ExecutionResult
56+
> {
5357
const {
5458
schema,
5559
document,
@@ -61,7 +65,7 @@ export async function subscribe(
6165
subscribeFieldResolver,
6266
} = args;
6367

64-
const resultOrStream = await createSourceEventStream(
68+
const resultOrStream = createSourceEventStream(
6569
schema,
6670
document,
6771
rootValue,
@@ -71,6 +75,42 @@ export async function subscribe(
7175
subscribeFieldResolver,
7276
);
7377

78+
if (isPromise(resultOrStream)) {
79+
return resultOrStream.then((resolvedResultOrStream) =>
80+
mapSourceToResponse(
81+
schema,
82+
document,
83+
resolvedResultOrStream,
84+
contextValue,
85+
variableValues,
86+
operationName,
87+
fieldResolver,
88+
),
89+
);
90+
}
91+
92+
return mapSourceToResponse(
93+
schema,
94+
document,
95+
resultOrStream,
96+
contextValue,
97+
variableValues,
98+
operationName,
99+
fieldResolver,
100+
);
101+
}
102+
103+
function mapSourceToResponse(
104+
schema: GraphQLSchema,
105+
document: DocumentNode,
106+
resultOrStream: ExecutionResult | AsyncIterable<unknown>,
107+
contextValue?: unknown,
108+
variableValues?: Maybe<{ readonly [variable: string]: unknown }>,
109+
operationName?: Maybe<string>,
110+
fieldResolver?: Maybe<GraphQLFieldResolver<any, any>>,
111+
): PromiseOrValue<
112+
AsyncGenerator<ExecutionResult, void, void> | ExecutionResult
113+
> {
74114
if (!isAsyncIterable(resultOrStream)) {
75115
return resultOrStream;
76116
}
@@ -81,7 +121,7 @@ export async function subscribe(
81121
// the GraphQL specification. The `execute` function provides the
82122
// "ExecuteSubscriptionEvent" algorithm, as it is nearly identical to the
83123
// "ExecuteQuery" algorithm, for which `execute` is also used.
84-
const mapSourceToResponse = (payload: unknown) =>
124+
return mapAsyncIterator(resultOrStream, (payload: unknown) =>
85125
execute({
86126
schema,
87127
document,
@@ -90,10 +130,8 @@ export async function subscribe(
90130
variableValues,
91131
operationName,
92132
fieldResolver,
93-
});
94-
95-
// Map every source value to a ExecutionResult value as described above.
96-
return mapAsyncIterator(resultOrStream, mapSourceToResponse);
133+
}),
134+
);
97135
}
98136

99137
/**
@@ -124,15 +162,15 @@ export async function subscribe(
124162
* or otherwise separating these two steps. For more on this, see the
125163
* "Supporting Subscriptions at Scale" information in the GraphQL specification.
126164
*/
127-
export async function createSourceEventStream(
165+
export function createSourceEventStream(
128166
schema: GraphQLSchema,
129167
document: DocumentNode,
130168
rootValue?: unknown,
131169
contextValue?: unknown,
132170
variableValues?: Maybe<{ readonly [variable: string]: unknown }>,
133171
operationName?: Maybe<string>,
134172
subscribeFieldResolver?: Maybe<GraphQLFieldResolver<any, any>>,
135-
): Promise<AsyncIterable<unknown> | ExecutionResult> {
173+
): PromiseOrValue<AsyncIterable<unknown> | ExecutionResult> {
136174
// If arguments are missing or incorrectly typed, this is an internal
137175
// developer mistake which should throw an early error.
138176
assertValidExecutionArguments(schema, document, variableValues);
@@ -155,17 +193,22 @@ export async function createSourceEventStream(
155193
}
156194

157195
try {
158-
const eventStream = await executeSubscription(exeContext);
159-
160-
// Assert field returned an event stream, otherwise yield an error.
161-
if (!isAsyncIterable(eventStream)) {
162-
throw new Error(
163-
'Subscription field must return Async Iterable. ' +
164-
`Received: ${inspect(eventStream)}.`,
165-
);
196+
const eventStream = executeSubscription(exeContext);
197+
198+
if (isPromise(eventStream)) {
199+
return eventStream
200+
.then((resolvedEventStream) => ensureAsyncIterable(resolvedEventStream))
201+
.then(undefined, (error) => {
202+
// If it GraphQLError, report it as an ExecutionResult, containing only errors and no data.
203+
// Otherwise treat the error as a system-class error and re-throw it.
204+
if (error instanceof GraphQLError) {
205+
return { errors: [error] };
206+
}
207+
throw error;
208+
});
166209
}
167210

168-
return eventStream;
211+
return ensureAsyncIterable(eventStream);
169212
} catch (error) {
170213
// If it GraphQLError, report it as an ExecutionResult, containing only errors and no data.
171214
// Otherwise treat the error as a system-class error and re-throw it.
@@ -176,9 +219,19 @@ export async function createSourceEventStream(
176219
}
177220
}
178221

179-
async function executeSubscription(
180-
exeContext: ExecutionContext,
181-
): Promise<unknown> {
222+
function ensureAsyncIterable(eventStream: unknown): AsyncIterable<unknown> {
223+
// Assert field returned an event stream, otherwise yield an error.
224+
if (!isAsyncIterable(eventStream)) {
225+
throw new Error(
226+
'Subscription field must return Async Iterable. ' +
227+
`Received: ${inspect(eventStream)}.`,
228+
);
229+
}
230+
231+
return eventStream;
232+
}
233+
234+
function executeSubscription(exeContext: ExecutionContext): unknown {
182235
const { schema, fragments, operation, variableValues, rootValue } =
183236
exeContext;
184237

@@ -233,13 +286,26 @@ async function executeSubscription(
233286
// Call the `subscribe()` resolver or the default resolver to produce an
234287
// AsyncIterable yielding raw payloads.
235288
const resolveFn = fieldDef.subscribe ?? exeContext.subscribeFieldResolver;
236-
const eventStream = await resolveFn(rootValue, args, contextValue, info);
237289

238-
if (eventStream instanceof Error) {
239-
throw eventStream;
290+
const eventStream = resolveFn(rootValue, args, contextValue, info);
291+
292+
if (isPromise(eventStream)) {
293+
return eventStream
294+
.then((resolvedEventStream) => throwReturnedError(resolvedEventStream))
295+
.then(undefined, (error) => {
296+
throw locatedError(error, fieldNodes, pathToArray(path));
297+
});
240298
}
241-
return eventStream;
299+
300+
return throwReturnedError(eventStream);
242301
} catch (error) {
243302
throw locatedError(error, fieldNodes, pathToArray(path));
244303
}
245304
}
305+
306+
function throwReturnedError(eventStream: unknown): unknown {
307+
if (eventStream instanceof Error) {
308+
throw eventStream;
309+
}
310+
return eventStream;
311+
}

0 commit comments

Comments
 (0)