Skip to content

Commit 4934616

Browse files
authored
Merge pull request #200 from JacobMountain/defer_support
Defer support
2 parents 67d6cd5 + afc4c49 commit 4934616

File tree

4 files changed

+151
-30
lines changed

4 files changed

+151
-30
lines changed

src/main/java/graphql/servlet/AbstractGraphQLHttpServlet.java

+29-24
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import com.google.common.io.ByteStreams;
44
import com.google.common.io.CharStreams;
55
import graphql.ExecutionResult;
6+
import graphql.GraphQL;
7+
import graphql.execution.reactive.SingleSubscriberPublisher;
68
import graphql.introspection.IntrospectionQuery;
79
import graphql.schema.GraphQLFieldDefinition;
810
import graphql.servlet.config.GraphQLConfiguration;
@@ -13,11 +15,7 @@
1315
import graphql.servlet.core.GraphQLServletListener;
1416
import graphql.servlet.core.internal.GraphQLRequest;
1517
import graphql.servlet.core.internal.VariableMapper;
16-
import graphql.servlet.input.BatchInputPreProcessResult;
17-
import graphql.servlet.input.BatchInputPreProcessor;
18-
import graphql.servlet.input.GraphQLBatchedInvocationInput;
19-
import graphql.servlet.input.GraphQLInvocationInputFactory;
20-
import graphql.servlet.input.GraphQLSingleInvocationInput;
18+
import graphql.servlet.input.*;
2119
import org.reactivestreams.Publisher;
2220
import org.reactivestreams.Subscriber;
2321
import org.reactivestreams.Subscription;
@@ -28,24 +26,12 @@
2826
import javax.servlet.AsyncEvent;
2927
import javax.servlet.AsyncListener;
3028
import javax.servlet.Servlet;
31-
import javax.servlet.ServletException;
3229
import javax.servlet.http.HttpServlet;
3330
import javax.servlet.http.HttpServletRequest;
3431
import javax.servlet.http.HttpServletResponse;
3532
import javax.servlet.http.Part;
36-
import java.io.BufferedInputStream;
37-
import java.io.ByteArrayOutputStream;
38-
import java.io.IOException;
39-
import java.io.InputStream;
40-
import java.io.Writer;
41-
import java.util.ArrayList;
42-
import java.util.Arrays;
43-
import java.util.HashMap;
44-
import java.util.Iterator;
45-
import java.util.List;
46-
import java.util.Map;
47-
import java.util.Objects;
48-
import java.util.Optional;
33+
import java.io.*;
34+
import java.util.*;
4935
import java.util.concurrent.CountDownLatch;
5036
import java.util.concurrent.atomic.AtomicReference;
5137
import java.util.function.BiConsumer;
@@ -354,13 +340,13 @@ private void doRequest(HttpServletRequest request, HttpServletResponse response,
354340
}
355341

356342
@Override
357-
protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
343+
protected void doGet(HttpServletRequest req, HttpServletResponse resp) {
358344
init();
359345
doRequestAsync(req, resp, getHandler);
360346
}
361347

362348
@Override
363-
protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
349+
protected void doPost(HttpServletRequest req, HttpServletResponse resp) {
364350
init();
365351
doRequestAsync(req, resp, postHandler);
366352
}
@@ -373,7 +359,9 @@ private void query(GraphQLQueryInvoker queryInvoker, GraphQLObjectMapper graphQL
373359
HttpServletRequest req, HttpServletResponse resp) throws IOException {
374360
ExecutionResult result = queryInvoker.query(invocationInput);
375361

376-
if (!(result.getData() instanceof Publisher)) {
362+
boolean isDeferred = Objects.nonNull(result.getExtensions()) && result.getExtensions().containsKey(GraphQL.DEFERRED_RESULTS);
363+
364+
if (!(result.getData() instanceof Publisher || isDeferred)) {
377365
resp.setContentType(APPLICATION_JSON_UTF8);
378366
resp.setStatus(STATUS_OK);
379367
resp.getWriter().write(graphQLObjectMapper.serializeResultAsJson(result));
@@ -390,7 +378,16 @@ private void query(GraphQLQueryInvoker queryInvoker, GraphQLObjectMapper graphQL
390378
AtomicReference<Subscription> subscriptionRef = new AtomicReference<>();
391379
asyncContext.addListener(new SubscriptionAsyncListener(subscriptionRef));
392380
ExecutionResultSubscriber subscriber = new ExecutionResultSubscriber(subscriptionRef, asyncContext, graphQLObjectMapper);
393-
((Publisher<ExecutionResult>) result.getData()).subscribe(subscriber);
381+
List<Publisher<ExecutionResult>> publishers = new ArrayList<>();
382+
if (result.getData() instanceof Publisher) {
383+
publishers.add(result.getData());
384+
} else {
385+
publishers.add(new StaticDataPublisher<>(result));
386+
final Publisher<ExecutionResult> deferredResultsPublisher = (Publisher<ExecutionResult>) result.getExtensions().get(GraphQL.DEFERRED_RESULTS);
387+
publishers.add(deferredResultsPublisher);
388+
}
389+
publishers.forEach(it -> it.subscribe(subscriber));
390+
394391
if (isInAsyncThread) {
395392
// We need to delay the completion of async context until after the subscription has terminated, otherwise the AsyncContext is prematurely closed.
396393
try {
@@ -537,7 +534,6 @@ public void onStartAsync(AsyncEvent event) {
537534
}
538535
}
539536

540-
541537
private static class ExecutionResultSubscriber implements Subscriber<ExecutionResult> {
542538

543539
private final AtomicReference<Subscription> subscriptionRef;
@@ -584,4 +580,13 @@ public void await() throws InterruptedException {
584580
completedLatch.await();
585581
}
586582
}
583+
584+
private static class StaticDataPublisher<T> extends SingleSubscriberPublisher<T> implements Publisher<T> {
585+
StaticDataPublisher(T data) {
586+
super();
587+
super.offer(data);
588+
super.noMoreData();
589+
}
590+
}
591+
587592
}

src/main/java/graphql/servlet/core/GraphQLObjectMapper.java

+15-5
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,8 @@
55
import com.fasterxml.jackson.databind.MappingIterator;
66
import com.fasterxml.jackson.databind.ObjectMapper;
77
import com.fasterxml.jackson.databind.ObjectReader;
8-
import graphql.ExecutionResult;
9-
import graphql.ExecutionResultImpl;
10-
import graphql.GraphQLError;
8+
import graphql.*;
9+
import graphql.execution.ExecutionPath;
1110
import graphql.servlet.config.ConfiguringObjectMapperProvider;
1211
import graphql.servlet.config.ObjectMapperConfigurer;
1312
import graphql.servlet.config.ObjectMapperProvider;
@@ -117,12 +116,19 @@ public ExecutionResult sanitizeErrors(ExecutionResult executionResult) {
117116
} else {
118117
errors = null;
119118
}
120-
121119
return new ExecutionResultImpl(data, errors, extensions);
122120
}
123121

124122
public Map<String, Object> createResultFromExecutionResult(ExecutionResult executionResult) {
125-
return convertSanitizedExecutionResult(sanitizeErrors(executionResult));
123+
ExecutionResult sanitizedExecutionResult = sanitizeErrors(executionResult);
124+
if (executionResult instanceof DeferredExecutionResult) {
125+
sanitizedExecutionResult = DeferredExecutionResultImpl
126+
.newDeferredExecutionResult()
127+
.from(executionResult)
128+
.path(ExecutionPath.fromList(((DeferredExecutionResult) executionResult).getPath()))
129+
.build();
130+
}
131+
return convertSanitizedExecutionResult(sanitizedExecutionResult);
126132
}
127133

128134
public Map<String, Object> convertSanitizedExecutionResult(ExecutionResult executionResult) {
@@ -144,6 +150,10 @@ public Map<String, Object> convertSanitizedExecutionResult(ExecutionResult execu
144150
result.put("data", executionResult.getData());
145151
}
146152

153+
if (executionResult instanceof DeferredExecutionResult) {
154+
result.put("path", ((DeferredExecutionResult) executionResult).getPath());
155+
}
156+
147157
return result;
148158
}
149159

src/test/groovy/graphql/servlet/AbstractGraphQLHttpServletSpec.groovy

+77
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,28 @@ class AbstractGraphQLHttpServletSpec extends Specification {
283283
getBatchedResponseContent()[1].data.echo == "test"
284284
}
285285

286+
287+
def "deferred query over HTTP GET"() {
288+
setup:
289+
request.addParameter('query', 'query { echo(arg:"test") @defer }')
290+
291+
when:
292+
servlet.doGet(request, response)
293+
294+
then:
295+
response.getStatus() == STATUS_OK
296+
response.getContentType() == CONTENT_TYPE_SERVER_SENT_EVENTS
297+
getSubscriptionResponseContent()[0].data.echo == null
298+
299+
when:
300+
subscriptionLatch.await(1, TimeUnit.SECONDS)
301+
302+
then:
303+
def content = getSubscriptionResponseContent()
304+
content[1].data == "test"
305+
content[1].path == ["echo"]
306+
}
307+
286308
def "Batch Execution Handler allows limiting batches and sending error messages."() {
287309
setup:
288310
servlet = TestUtils.createBatchCustomizedServlet({ env -> env.arguments.arg }, { env -> env.arguments.arg }, { env ->
@@ -1030,6 +1052,61 @@ class AbstractGraphQLHttpServletSpec extends Specification {
10301052
getSubscriptionResponseContent()[1].data.echo == "Second\n\ntest"
10311053
}
10321054

1055+
def "defer query over HTTP POST"() {
1056+
setup:
1057+
request.setContent('{"query": "subscription Subscription($arg: String!) { echo(arg: $arg) }", "operationName": "Subscription", "variables": {"arg": "test"}}'.bytes)
1058+
request.setAsyncSupported(true)
1059+
1060+
when:
1061+
servlet.doPost(request, response)
1062+
then:
1063+
response.getStatus() == STATUS_OK
1064+
response.getContentType() == CONTENT_TYPE_SERVER_SENT_EVENTS
1065+
1066+
when:
1067+
subscriptionLatch.await(1, TimeUnit.SECONDS)
1068+
then:
1069+
getSubscriptionResponseContent()[0].data.echo == "First\n\ntest"
1070+
getSubscriptionResponseContent()[1].data.echo == "Second\n\ntest"
1071+
}
1072+
1073+
def "deferred query that takes longer than initial results, should still be sent second"() {
1074+
setup:
1075+
servlet = TestUtils.createDefaultServlet({ env ->
1076+
if (env.getField().name == "a") {
1077+
Thread.sleep(1000)
1078+
}
1079+
env.arguments.arg
1080+
})
1081+
request.setContent(mapper.writeValueAsBytes([
1082+
query: '''
1083+
{
1084+
object {
1085+
a(arg: "Hello")
1086+
b(arg: "World") @defer
1087+
}
1088+
}
1089+
'''
1090+
]))
1091+
request.setAsyncSupported(true)
1092+
1093+
when:
1094+
servlet.doPost(request, response)
1095+
1096+
then:
1097+
response.getStatus() == STATUS_OK
1098+
response.getContentType() == CONTENT_TYPE_SERVER_SENT_EVENTS
1099+
getSubscriptionResponseContent()[0].data.object.a == "Hello" // a has a Thread.sleep
1100+
1101+
when:
1102+
subscriptionLatch.await(1, TimeUnit.SECONDS)
1103+
1104+
then:
1105+
def content = getSubscriptionResponseContent()
1106+
content[1].data == "World"
1107+
content[1].path == ["object", "b"]
1108+
}
1109+
10331110
def "errors before graphql schema execution return internal server error"() {
10341111
setup:
10351112
servlet = SimpleGraphQLHttpServlet.newBuilder(GraphQLInvocationInputFactory.newBuilder {

src/test/groovy/graphql/servlet/TestUtils.groovy

+30-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package graphql.servlet
22

33
import com.google.common.io.ByteStreams
4+
import graphql.Directives
45
import graphql.Scalars
56
import graphql.execution.reactive.SingleSubscriberPublisher
67
import graphql.schema.*
@@ -15,6 +16,7 @@ import graphql.servlet.core.ApolloScalars
1516
import graphql.servlet.input.BatchInputPreProcessor
1617
import graphql.servlet.context.ContextSetting
1718

19+
import java.util.concurrent.CompletableFuture
1820
import java.util.concurrent.atomic.AtomicReference
1921

2022
class TestUtils {
@@ -95,7 +97,7 @@ class TestUtils {
9597
static def createGraphQlSchema(DataFetcher queryDataFetcher = { env -> env.arguments.arg },
9698
DataFetcher mutationDataFetcher = { env -> env.arguments.arg },
9799
DataFetcher subscriptionDataFetcher = { env ->
98-
AtomicReference<SingleSubscriberPublisher<String>> publisherRef = new AtomicReference<>();
100+
AtomicReference<SingleSubscriberPublisher<String>> publisherRef = new AtomicReference<>()
99101
publisherRef.set(new SingleSubscriberPublisher<>({ subscription ->
100102
publisherRef.get().offer(env.arguments.arg)
101103
publisherRef.get().noMoreData()
@@ -113,6 +115,32 @@ class TestUtils {
113115
}
114116
field.dataFetcher(queryDataFetcher)
115117
}
118+
.field { GraphQLFieldDefinition.Builder field ->
119+
field.name("object")
120+
field.type(
121+
GraphQLObjectType.newObject()
122+
.name("NestedObject")
123+
.field { nested ->
124+
nested.name("a")
125+
nested.type(Scalars.GraphQLString)
126+
nested.argument { argument ->
127+
argument.name("arg")
128+
argument.type(Scalars.GraphQLString)
129+
}
130+
nested.dataFetcher(queryDataFetcher)
131+
}
132+
.field { nested ->
133+
nested.name("b")
134+
nested.type(Scalars.GraphQLString)
135+
nested.argument { argument ->
136+
argument.name("arg")
137+
argument.type(Scalars.GraphQLString)
138+
}
139+
nested.dataFetcher(queryDataFetcher)
140+
}
141+
)
142+
field.dataFetcher(new StaticDataFetcher([:]))
143+
}
116144
.field { GraphQLFieldDefinition.Builder field ->
117145
field.name("returnsNullIncorrectly")
118146
field.type(new GraphQLNonNull(Scalars.GraphQLString))
@@ -174,6 +202,7 @@ class TestUtils {
174202
.mutation(mutation)
175203
.subscription(subscription)
176204
.additionalType(ApolloScalars.Upload)
205+
.additionalDirective(Directives.DeferDirective)
177206
.build()
178207
}
179208

0 commit comments

Comments
 (0)