Skip to content

Commit 713480f

Browse files
committed
align return types from execution and subscription
with respect to possible promises
1 parent ea1894a commit 713480f

File tree

2 files changed

+129
-47
lines changed

2 files changed

+129
-47
lines changed

src/execution/__tests__/subscribe-test.ts

Lines changed: 40 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ import { expectJSON } from '../../__testUtils__/expectJSON';
55
import { resolveOnNextTick } from '../../__testUtils__/resolveOnNextTick';
66

77
import { isAsyncIterable } from '../../jsutils/isAsyncIterable';
8+
import { isPromise } from '../../jsutils/isPromise';
9+
import type { PromiseOrValue } from '../../jsutils/PromiseOrValue';
810

911
import { parse } from '../../language/parser';
1012

@@ -135,9 +137,6 @@ async function expectPromise(promise: Promise<unknown>) {
135137
}
136138

137139
return {
138-
toReject() {
139-
expect(caughtError).to.be.an.instanceOf(Error);
140-
},
141140
toRejectWith(message: string) {
142141
expect(caughtError).to.be.an.instanceOf(Error);
143142
expect(caughtError).to.have.property('message', message);
@@ -152,9 +151,9 @@ const DummyQueryType = new GraphQLObjectType({
152151
},
153152
});
154153

155-
async function subscribeWithBadFn(
154+
function subscribeWithBadFn(
156155
subscribeFn: () => unknown,
157-
): Promise<ExecutionResult> {
156+
): PromiseOrValue<ExecutionResult> {
158157
const schema = new GraphQLSchema({
159158
query: DummyQueryType,
160159
subscription: new GraphQLObjectType({
@@ -165,13 +164,28 @@ async function subscribeWithBadFn(
165164
}),
166165
});
167166
const document = parse('subscription { foo }');
168-
const result = await subscribe({ schema, document });
169167

170-
assert(!isAsyncIterable(result));
171-
expectJSON(await createSourceEventStream(schema, document)).toDeepEqual(
172-
result,
173-
);
174-
return result;
168+
const subscribeResult = subscribe({ schema, document });
169+
const streamResult = createSourceEventStream(schema, document);
170+
171+
if (isPromise(subscribeResult)) {
172+
assert(isPromise(streamResult));
173+
return Promise.all([subscribeResult, streamResult]).then((resolved) =>
174+
expectEquivalentStreamErrors(resolved[0], resolved[1]),
175+
);
176+
}
177+
178+
assert(!isPromise(streamResult));
179+
return expectEquivalentStreamErrors(subscribeResult, streamResult);
180+
}
181+
182+
function expectEquivalentStreamErrors(
183+
subscribeResult: ExecutionResult | AsyncGenerator<ExecutionResult>,
184+
createSourceEventStreamResult: ExecutionResult | AsyncIterable<unknown>,
185+
): ExecutionResult {
186+
assert(!isAsyncIterable(subscribeResult));
187+
expectJSON(createSourceEventStreamResult).toDeepEqual(subscribeResult);
188+
return subscribeResult;
175189
}
176190

177191
/* eslint-disable @typescript-eslint/require-await */
@@ -379,24 +393,22 @@ describe('Subscription Initialization Phase', () => {
379393
});
380394

381395
// @ts-expect-error (schema must not be null)
382-
(await expectPromise(subscribe({ schema: null, document }))).toRejectWith(
396+
expect(() => subscribe({ schema: null, document })).to.throw(
383397
'Expected null to be a GraphQL schema.',
384398
);
385399

386400
// @ts-expect-error
387-
(await expectPromise(subscribe({ document }))).toRejectWith(
401+
expect(() => subscribe({ document })).to.throw(
388402
'Expected undefined to be a GraphQL schema.',
389403
);
390404

391405
// @ts-expect-error (document must not be null)
392-
(await expectPromise(subscribe({ schema, document: null }))).toRejectWith(
406+
expect(() => subscribe({ schema, document: null })).to.throw(
393407
'Must provide document.',
394408
);
395409

396410
// @ts-expect-error
397-
(await expectPromise(subscribe({ schema }))).toRejectWith(
398-
'Must provide document.',
399-
);
411+
expect(() => subscribe({ schema })).to.throw('Must provide document.');
400412
});
401413

402414
it('resolves to an error if schema does not support subscriptions', async () => {
@@ -450,11 +462,11 @@ describe('Subscription Initialization Phase', () => {
450462
});
451463

452464
// @ts-expect-error
453-
(await expectPromise(subscribe({ schema, document: {} }))).toReject();
465+
expect(() => subscribe({ schema, document: {} })).to.throw();
454466
});
455467

456468
it('throws an error if subscribe does not return an iterator', async () => {
457-
expectJSON(await subscribeWithBadFn(() => 'test')).toDeepEqual({
469+
const expectedResult = {
458470
errors: [
459471
{
460472
message:
@@ -463,7 +475,13 @@ describe('Subscription Initialization Phase', () => {
463475
path: ['foo'],
464476
},
465477
],
466-
});
478+
};
479+
480+
expectJSON(subscribeWithBadFn(() => 'test')).toDeepEqual(expectedResult);
481+
482+
const result = subscribeWithBadFn(() => Promise.resolve('test'));
483+
assert(isPromise(result));
484+
expectJSON(await result).toDeepEqual(expectedResult);
467485
});
468486

469487
it('resolves to an error for subscription resolver errors', async () => {
@@ -479,12 +497,12 @@ describe('Subscription Initialization Phase', () => {
479497

480498
expectJSON(
481499
// Returning an error
482-
await subscribeWithBadFn(() => new Error('test error')),
500+
subscribeWithBadFn(() => new Error('test error')),
483501
).toDeepEqual(expectedResult);
484502

485503
expectJSON(
486504
// Throwing an error
487-
await subscribeWithBadFn(() => {
505+
subscribeWithBadFn(() => {
488506
throw new Error('test error');
489507
}),
490508
).toDeepEqual(expectedResult);

src/execution/subscribe.ts

Lines changed: 89 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
import { inspect } from '../jsutils/inspect';
22
import { isAsyncIterable } from '../jsutils/isAsyncIterable';
3+
import { isPromise } from '../jsutils/isPromise';
34
import type { Maybe } from '../jsutils/Maybe';
5+
import type { Path } from '../jsutils/Path';
46
import { addPath, pathToArray } from '../jsutils/Path';
7+
import type { PromiseOrValue } from '../jsutils/PromiseOrValue';
58

69
import { GraphQLError } from '../error/GraphQLError';
710
import { locatedError } from '../error/locatedError';
811

9-
import type { DocumentNode } from '../language/ast';
12+
import type { DocumentNode, FieldNode } from '../language/ast';
1013

1114
import type { GraphQLFieldResolver } from '../type/definition';
1215
import type { GraphQLSchema } from '../type/schema';
@@ -47,9 +50,11 @@ import { getArgumentValues } from './values';
4750
*
4851
* Accepts either an object with named arguments, or individual arguments.
4952
*/
50-
export async function subscribe(
53+
export function subscribe(
5154
args: ExecutionArgs,
52-
): Promise<AsyncGenerator<ExecutionResult, void, void> | ExecutionResult> {
55+
): PromiseOrValue<
56+
AsyncGenerator<ExecutionResult, void, void> | ExecutionResult
57+
> {
5358
const {
5459
schema,
5560
document,
@@ -61,7 +66,7 @@ export async function subscribe(
6166
subscribeFieldResolver,
6267
} = args;
6368

64-
const resultOrStream = await createSourceEventStream(
69+
const resultOrStream = createSourceEventStream(
6570
schema,
6671
document,
6772
rootValue,
@@ -71,6 +76,42 @@ export async function subscribe(
7176
subscribeFieldResolver,
7277
);
7378

79+
if (isPromise(resultOrStream)) {
80+
return resultOrStream.then((resolvedResultOrStream) =>
81+
mapSourceToResponse(
82+
schema,
83+
document,
84+
resolvedResultOrStream,
85+
contextValue,
86+
variableValues,
87+
operationName,
88+
fieldResolver,
89+
),
90+
);
91+
}
92+
93+
return mapSourceToResponse(
94+
schema,
95+
document,
96+
resultOrStream,
97+
contextValue,
98+
variableValues,
99+
operationName,
100+
fieldResolver,
101+
);
102+
}
103+
104+
function mapSourceToResponse(
105+
schema: GraphQLSchema,
106+
document: DocumentNode,
107+
resultOrStream: ExecutionResult | AsyncIterable<unknown>,
108+
contextValue?: unknown,
109+
variableValues?: Maybe<{ readonly [variable: string]: unknown }>,
110+
operationName?: Maybe<string>,
111+
fieldResolver?: Maybe<GraphQLFieldResolver<any, any>>,
112+
): PromiseOrValue<
113+
AsyncGenerator<ExecutionResult, void, void> | ExecutionResult
114+
> {
74115
if (!isAsyncIterable(resultOrStream)) {
75116
return resultOrStream;
76117
}
@@ -81,7 +122,7 @@ export async function subscribe(
81122
// the GraphQL specification. The `execute` function provides the
82123
// "ExecuteSubscriptionEvent" algorithm, as it is nearly identical to the
83124
// "ExecuteQuery" algorithm, for which `execute` is also used.
84-
const mapSourceToResponse = (payload: unknown) =>
125+
return mapAsyncIterator(resultOrStream, (payload: unknown) =>
85126
execute({
86127
schema,
87128
document,
@@ -90,10 +131,8 @@ export async function subscribe(
90131
variableValues,
91132
operationName,
92133
fieldResolver,
93-
});
94-
95-
// Map every source value to a ExecutionResult value as described above.
96-
return mapAsyncIterator(resultOrStream, mapSourceToResponse);
134+
}),
135+
);
97136
}
98137

99138
/**
@@ -124,15 +163,15 @@ export async function subscribe(
124163
* or otherwise separating these two steps. For more on this, see the
125164
* "Supporting Subscriptions at Scale" information in the GraphQL specification.
126165
*/
127-
export async function createSourceEventStream(
166+
export function createSourceEventStream(
128167
schema: GraphQLSchema,
129168
document: DocumentNode,
130169
rootValue?: unknown,
131170
contextValue?: unknown,
132171
variableValues?: Maybe<{ readonly [variable: string]: unknown }>,
133172
operationName?: Maybe<string>,
134173
subscribeFieldResolver?: Maybe<GraphQLFieldResolver<any, any>>,
135-
): Promise<AsyncIterable<unknown> | ExecutionResult> {
174+
): PromiseOrValue<AsyncIterable<unknown> | ExecutionResult> {
136175
// If arguments are missing or incorrectly typed, this is an internal
137176
// developer mistake which should throw an early error.
138177
assertValidExecutionArguments(schema, document, variableValues);
@@ -155,17 +194,20 @@ export async function createSourceEventStream(
155194
}
156195

157196
try {
158-
const eventStream = await executeSubscription(exeContext);
197+
const eventStream = executeSubscription(exeContext);
198+
if (isPromise(eventStream)) {
199+
return eventStream.then(undefined, (error) => ({ errors: [error] }));
200+
}
159201

160202
return eventStream;
161203
} catch (error) {
162204
return { errors: [error] };
163205
}
164206
}
165207

166-
async function executeSubscription(
208+
function executeSubscription(
167209
exeContext: ExecutionContext,
168-
): Promise<AsyncIterable<unknown>> {
210+
): PromiseOrValue<AsyncIterable<unknown> | ExecutionResult> {
169211
const { schema, fragments, operation, variableValues, rootValue } =
170212
exeContext;
171213

@@ -220,22 +262,44 @@ async function executeSubscription(
220262
// Call the `subscribe()` resolver or the default resolver to produce an
221263
// AsyncIterable yielding raw payloads.
222264
const resolveFn = fieldDef.subscribe ?? exeContext.subscribeFieldResolver;
223-
const eventStream = await resolveFn(rootValue, args, contextValue, info);
224-
225-
if (eventStream instanceof Error) {
226-
throw eventStream;
227-
}
265+
const eventStream = resolveFn(rootValue, args, contextValue, info);
228266

229-
// Assert field returned an event stream, otherwise yield an error.
230-
if (!isAsyncIterable(eventStream)) {
231-
throw new GraphQLError(
232-
'Subscription field must return Async Iterable. ' +
233-
`Received: ${inspect(eventStream)}.`,
267+
if (isPromise(eventStream)) {
268+
return eventStream.then(
269+
(resolvedEventStream) =>
270+
ensureAsyncIterable(resolvedEventStream, fieldNodes, path),
271+
(error) => {
272+
throw locatedError(error, fieldNodes, pathToArray(path));
273+
},
234274
);
235275
}
236276

237-
return eventStream;
277+
return ensureAsyncIterable(eventStream, fieldNodes, path);
238278
} catch (error) {
239279
throw locatedError(error, fieldNodes, pathToArray(path));
240280
}
241281
}
282+
283+
function ensureAsyncIterable(
284+
eventStream: unknown,
285+
fieldNodes: ReadonlyArray<FieldNode>,
286+
path: Path,
287+
): AsyncIterable<unknown> {
288+
if (eventStream instanceof Error) {
289+
throw locatedError(eventStream, fieldNodes, pathToArray(path));
290+
}
291+
292+
// Assert field returned an event stream, otherwise yield an error.
293+
if (!isAsyncIterable(eventStream)) {
294+
throw locatedError(
295+
new GraphQLError(
296+
'Subscription field must return Async Iterable. ' +
297+
`Received: ${inspect(eventStream)}.`,
298+
),
299+
fieldNodes,
300+
pathToArray(path),
301+
);
302+
}
303+
304+
return eventStream;
305+
}

0 commit comments

Comments
 (0)