Skip to content

Commit 3b27fb6

Browse files
committed
align return types from execution and subscription
with respect to possible promises
1 parent 75d3061 commit 3b27fb6

File tree

2 files changed

+144
-52
lines changed

2 files changed

+144
-52
lines changed

src/execution/__tests__/subscribe-test.ts

Lines changed: 39 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ import { expectJSON } from '../../__testUtils__/expectJSON';
55
import { resolveOnNextTick } from '../../__testUtils__/resolveOnNextTick';
66

77
import { isAsyncIterable } from '../../jsutils/isAsyncIterable';
8+
import { isPromise } from '../../jsutils/isPromise';
9+
import type { PromiseOrValue } from '../../jsutils/PromiseOrValue';
810

911
import { parse } from '../../language/parser';
1012

@@ -135,9 +137,6 @@ async function expectPromise(promise: Promise<unknown>) {
135137
}
136138

137139
return {
138-
toReject() {
139-
expect(caughtError).to.be.an.instanceOf(Error);
140-
},
141140
toRejectWith(message: string) {
142141
expect(caughtError).to.be.an.instanceOf(Error);
143142
expect(caughtError).to.have.property('message', message);
@@ -152,9 +151,9 @@ const DummyQueryType = new GraphQLObjectType({
152151
},
153152
});
154153

155-
async function subscribeWithBadFn(
154+
function subscribeWithBadFn(
156155
subscribeFn: () => unknown,
157-
): Promise<ExecutionResult> {
156+
): PromiseOrValue<ExecutionResult> {
158157
const schema = new GraphQLSchema({
159158
query: DummyQueryType,
160159
subscription: new GraphQLObjectType({
@@ -165,13 +164,28 @@ async function subscribeWithBadFn(
165164
}),
166165
});
167166
const document = parse('subscription { foo }');
168-
const result = await subscribe({ schema, document });
169167

170-
assert(!isAsyncIterable(result));
171-
expectJSON(await createSourceEventStream(schema, document)).toDeepEqual(
172-
result,
173-
);
174-
return result;
168+
const subscribeResult = subscribe({ schema, document });
169+
const streamResult = createSourceEventStream(schema, document);
170+
171+
if (isPromise(subscribeResult)) {
172+
assert(isPromise(streamResult));
173+
return Promise.all([subscribeResult, streamResult]).then((resolved) =>
174+
expectEquivalentStreamErrors(resolved[0], resolved[1]),
175+
);
176+
}
177+
178+
assert(!isPromise(streamResult));
179+
return expectEquivalentStreamErrors(subscribeResult, streamResult);
180+
}
181+
182+
function expectEquivalentStreamErrors(
183+
subscribeResult: ExecutionResult | AsyncGenerator<ExecutionResult>,
184+
createSourceEventStreamResult: ExecutionResult | AsyncIterable<unknown>,
185+
): ExecutionResult {
186+
assert(!isAsyncIterable(subscribeResult));
187+
expectJSON(createSourceEventStreamResult).toDeepEqual(subscribeResult);
188+
return subscribeResult;
175189
}
176190

177191
/* eslint-disable @typescript-eslint/require-await */
@@ -379,24 +393,22 @@ describe('Subscription Initialization Phase', () => {
379393
});
380394

381395
// @ts-expect-error (schema must not be null)
382-
(await expectPromise(subscribe({ schema: null, document }))).toRejectWith(
396+
expect(() => subscribe({ schema: null, document })).to.throw(
383397
'Expected null to be a GraphQL schema.',
384398
);
385399

386400
// @ts-expect-error
387-
(await expectPromise(subscribe({ document }))).toRejectWith(
401+
expect(() => subscribe({ document })).to.throw(
388402
'Expected undefined to be a GraphQL schema.',
389403
);
390404

391405
// @ts-expect-error (document must not be null)
392-
(await expectPromise(subscribe({ schema, document: null }))).toRejectWith(
406+
expect(() => subscribe({ schema, document: null })).to.throw(
393407
'Must provide document.',
394408
);
395409

396410
// @ts-expect-error
397-
(await expectPromise(subscribe({ schema }))).toRejectWith(
398-
'Must provide document.',
399-
);
411+
expect(() => subscribe({ schema })).to.throw('Must provide document.');
400412
});
401413

402414
it('resolves to an error if schema does not support subscriptions', async () => {
@@ -450,11 +462,17 @@ describe('Subscription Initialization Phase', () => {
450462
});
451463

452464
// @ts-expect-error
453-
(await expectPromise(subscribe({ schema, document: {} }))).toReject();
465+
expect(() => subscribe({ schema, document: {} })).to.throw();
454466
});
455467

456468
it('throws an error if subscribe does not return an iterator', async () => {
457-
(await expectPromise(subscribeWithBadFn(() => 'test'))).toRejectWith(
469+
expect(() => subscribeWithBadFn(() => 'test')).to.throw(
470+
'Subscription field must return Async Iterable. Received: "test".',
471+
);
472+
473+
const result = subscribeWithBadFn(() => Promise.resolve('test'));
474+
assert(isPromise(result));
475+
(await expectPromise(result)).toRejectWith(
458476
'Subscription field must return Async Iterable. Received: "test".',
459477
);
460478
});
@@ -472,12 +490,12 @@ describe('Subscription Initialization Phase', () => {
472490

473491
expectJSON(
474492
// Returning an error
475-
await subscribeWithBadFn(() => new Error('test error')),
493+
subscribeWithBadFn(() => new Error('test error')),
476494
).toDeepEqual(expectedResult);
477495

478496
expectJSON(
479497
// Throwing an error
480-
await subscribeWithBadFn(() => {
498+
subscribeWithBadFn(() => {
481499
throw new Error('test error');
482500
}),
483501
).toDeepEqual(expectedResult);

src/execution/subscribe.ts

Lines changed: 105 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
import { inspect } from '../jsutils/inspect';
22
import { isAsyncIterable } from '../jsutils/isAsyncIterable';
3+
import { isPromise } from '../jsutils/isPromise';
34
import type { Maybe } from '../jsutils/Maybe';
45
import { addPath, pathToArray } from '../jsutils/Path';
6+
import type { PromiseOrValue } from '../jsutils/PromiseOrValue';
57

68
import { GraphQLError } from '../error/GraphQLError';
79
import { locatedError } from '../error/locatedError';
810

9-
import type { DocumentNode } from '../language/ast';
11+
import type { DocumentNode, FieldNode } from '../language/ast';
1012

1113
import type { GraphQLFieldResolver } from '../type/definition';
1214
import type { GraphQLSchema } from '../type/schema';
@@ -47,9 +49,11 @@ import { getArgumentValues } from './values';
4749
*
4850
* Accepts either an object with named arguments, or individual arguments.
4951
*/
50-
export async function subscribe(
52+
export function subscribe(
5153
args: ExecutionArgs,
52-
): Promise<AsyncGenerator<ExecutionResult, void, void> | ExecutionResult> {
54+
): PromiseOrValue<
55+
AsyncGenerator<ExecutionResult, void, void> | ExecutionResult
56+
> {
5357
const {
5458
schema,
5559
document,
@@ -61,7 +65,7 @@ export async function subscribe(
6165
subscribeFieldResolver,
6266
} = args;
6367

64-
const resultOrStream = await createSourceEventStream(
68+
const resultOrStream = createSourceEventStream(
6569
schema,
6670
document,
6771
rootValue,
@@ -71,6 +75,42 @@ export async function subscribe(
7175
subscribeFieldResolver,
7276
);
7377

78+
if (isPromise(resultOrStream)) {
79+
return resultOrStream.then((resolvedResultOrStream) =>
80+
mapSourceToResponse(
81+
schema,
82+
document,
83+
resolvedResultOrStream,
84+
contextValue,
85+
variableValues,
86+
operationName,
87+
fieldResolver,
88+
),
89+
);
90+
}
91+
92+
return mapSourceToResponse(
93+
schema,
94+
document,
95+
resultOrStream,
96+
contextValue,
97+
variableValues,
98+
operationName,
99+
fieldResolver,
100+
);
101+
}
102+
103+
function mapSourceToResponse(
104+
schema: GraphQLSchema,
105+
document: DocumentNode,
106+
resultOrStream: ExecutionResult | AsyncIterable<unknown>,
107+
contextValue?: unknown,
108+
variableValues?: Maybe<{ readonly [variable: string]: unknown }>,
109+
operationName?: Maybe<string>,
110+
fieldResolver?: Maybe<GraphQLFieldResolver<any, any>>,
111+
): PromiseOrValue<
112+
AsyncGenerator<ExecutionResult, void, void> | ExecutionResult
113+
> {
74114
if (!isAsyncIterable(resultOrStream)) {
75115
return resultOrStream;
76116
}
@@ -81,7 +121,7 @@ export async function subscribe(
81121
// the GraphQL specification. The `execute` function provides the
82122
// "ExecuteSubscriptionEvent" algorithm, as it is nearly identical to the
83123
// "ExecuteQuery" algorithm, for which `execute` is also used.
84-
const mapSourceToResponse = (payload: unknown) =>
124+
return mapAsyncIterator(resultOrStream, (payload: unknown) =>
85125
execute({
86126
schema,
87127
document,
@@ -90,10 +130,8 @@ export async function subscribe(
90130
variableValues,
91131
operationName,
92132
fieldResolver,
93-
});
94-
95-
// Map every source value to a ExecutionResult value as described above.
96-
return mapAsyncIterator(resultOrStream, mapSourceToResponse);
133+
}),
134+
);
97135
}
98136

99137
/**
@@ -124,15 +162,15 @@ export async function subscribe(
124162
* or otherwise separating these two steps. For more on this, see the
125163
* "Supporting Subscriptions at Scale" information in the GraphQL specification.
126164
*/
127-
export async function createSourceEventStream(
165+
export function createSourceEventStream(
128166
schema: GraphQLSchema,
129167
document: DocumentNode,
130168
rootValue?: unknown,
131169
contextValue?: unknown,
132170
variableValues?: Maybe<{ readonly [variable: string]: unknown }>,
133171
operationName?: Maybe<string>,
134172
subscribeFieldResolver?: Maybe<GraphQLFieldResolver<any, any>>,
135-
): Promise<AsyncIterable<unknown> | ExecutionResult> {
173+
): PromiseOrValue<AsyncIterable<unknown> | ExecutionResult> {
136174
// If arguments are missing or incorrectly typed, this is an internal
137175
// developer mistake which should throw an early error.
138176
assertValidExecutionArguments(schema, document, variableValues);
@@ -155,30 +193,43 @@ export async function createSourceEventStream(
155193
}
156194

157195
try {
158-
const eventStream = await executeSubscription(exeContext);
196+
const eventStream = executeSubscription(exeContext);
159197

160-
// Assert field returned an event stream, otherwise yield an error.
161-
if (!isAsyncIterable(eventStream)) {
162-
throw new Error(
163-
'Subscription field must return Async Iterable. ' +
164-
`Received: ${inspect(eventStream)}.`,
198+
if (isPromise(eventStream)) {
199+
return eventStream.then(
200+
(resolvedEventStream) => ensureAsyncIterable(resolvedEventStream),
201+
(error) => handleRawError(error),
165202
);
166203
}
167204

168-
return eventStream;
205+
return ensureAsyncIterable(eventStream);
169206
} catch (error) {
170-
// If it GraphQLError, report it as an ExecutionResult, containing only errors and no data.
171-
// Otherwise treat the error as a system-class error and re-throw it.
172-
if (error instanceof GraphQLError) {
173-
return { errors: [error] };
174-
}
175-
throw error;
207+
return handleRawError(error);
176208
}
177209
}
178210

179-
async function executeSubscription(
180-
exeContext: ExecutionContext,
181-
): Promise<unknown> {
211+
function ensureAsyncIterable(eventStream: unknown): AsyncIterable<unknown> {
212+
// Assert field returned an event stream, otherwise yield an error.
213+
if (!isAsyncIterable(eventStream)) {
214+
throw new Error(
215+
'Subscription field must return Async Iterable. ' +
216+
`Received: ${inspect(eventStream)}.`,
217+
);
218+
}
219+
220+
return eventStream;
221+
}
222+
223+
function handleRawError(error: unknown): ExecutionResult {
224+
// If it GraphQLError, report it as an ExecutionResult, containing only errors and no data.
225+
// Otherwise treat the error as a system-class error and re-throw it.
226+
if (error instanceof GraphQLError) {
227+
return { errors: [error] };
228+
}
229+
throw error;
230+
}
231+
232+
function executeSubscription(exeContext: ExecutionContext): unknown {
182233
const { schema, fragments, operation, variableValues, rootValue } =
183234
exeContext;
184235

@@ -233,13 +284,36 @@ async function executeSubscription(
233284
// Call the `subscribe()` resolver or the default resolver to produce an
234285
// AsyncIterable yielding raw payloads.
235286
const resolveFn = fieldDef.subscribe ?? exeContext.subscribeFieldResolver;
236-
const eventStream = await resolveFn(rootValue, args, contextValue, info);
237287

238-
if (eventStream instanceof Error) {
239-
throw eventStream;
288+
const eventStream = resolveFn(rootValue, args, contextValue, info);
289+
290+
if (isPromise(eventStream)) {
291+
return eventStream.then(
292+
(resolvedEventStream) =>
293+
throwReturnedError(
294+
resolvedEventStream,
295+
fieldNodes,
296+
pathToArray(path),
297+
),
298+
(error) => {
299+
throw locatedError(error, fieldNodes, pathToArray(path));
300+
},
301+
);
240302
}
241-
return eventStream;
303+
304+
return throwReturnedError(eventStream, fieldNodes, pathToArray(path));
242305
} catch (error) {
243306
throw locatedError(error, fieldNodes, pathToArray(path));
244307
}
245308
}
309+
310+
function throwReturnedError(
311+
eventStream: unknown,
312+
fieldNodes: ReadonlyArray<FieldNode>,
313+
path: ReadonlyArray<string | number>,
314+
): unknown {
315+
if (eventStream instanceof Error) {
316+
throw locatedError(eventStream, fieldNodes, path);
317+
}
318+
return eventStream;
319+
}

0 commit comments

Comments
 (0)