Skip to content

Commit c8bb3ac

Browse files
authored
Merge pull request #83 from ashu-walmart/DataLoader-in-context
Add DataLoaderRegistry to GraphQLContext
2 parents 7d9ea36 + a4ede44 commit c8bb3ac

File tree

4 files changed

+109
-30
lines changed

4 files changed

+109
-30
lines changed

src/main/java/graphql/servlet/GraphQLContext.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
package graphql.servlet;
22

3+
import org.dataloader.DataLoaderRegistry;
4+
35
import javax.security.auth.Subject;
46
import javax.servlet.http.HttpServletRequest;
57
import javax.servlet.http.Part;
@@ -15,6 +17,8 @@ public class GraphQLContext {
1517
private Subject subject;
1618
private Map<String, List<Part>> files;
1719

20+
private DataLoaderRegistry dataLoaderRegistry;
21+
1822
public GraphQLContext(HttpServletRequest httpServletRequest, HandshakeRequest handshakeRequest, Subject subject) {
1923
this.httpServletRequest = httpServletRequest;
2024
this.handshakeRequest = handshakeRequest;
@@ -52,4 +56,12 @@ public Optional<Map<String, List<Part>>> getFiles() {
5256
public void setFiles(Map<String, List<Part>> files) {
5357
this.files = files;
5458
}
59+
60+
public Optional<DataLoaderRegistry> getDataLoaderRegistry() {
61+
return Optional.ofNullable(dataLoaderRegistry);
62+
}
63+
64+
public void setDataLoaderRegistry(DataLoaderRegistry dataLoaderRegistry) {
65+
this.dataLoaderRegistry = dataLoaderRegistry;
66+
}
5567
}

src/main/java/graphql/servlet/GraphQLQueryInvoker.java

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@
33
import graphql.ExecutionInput;
44
import graphql.ExecutionResult;
55
import graphql.GraphQL;
6+
import graphql.execution.instrumentation.ChainedInstrumentation;
67
import graphql.execution.instrumentation.Instrumentation;
78
import graphql.execution.instrumentation.SimpleInstrumentation;
9+
import graphql.execution.instrumentation.dataloader.DataLoaderDispatcherInstrumentation;
810
import graphql.execution.preparsed.NoOpPreparsedDocumentProvider;
911
import graphql.execution.preparsed.PreparsedDocumentProvider;
1012
import graphql.schema.GraphQLSchema;
@@ -13,7 +15,9 @@
1315
import javax.security.auth.Subject;
1416
import java.security.AccessController;
1517
import java.security.PrivilegedAction;
18+
import java.util.ArrayList;
1619
import java.util.Iterator;
20+
import java.util.List;
1721
import java.util.function.Supplier;
1822

1923
/**
@@ -44,17 +48,32 @@ public void query(GraphQLBatchedInvocationInput batchedInvocationInput, Executio
4448
}
4549
}
4650

47-
private GraphQL newGraphQL(GraphQLSchema schema) {
51+
private GraphQL newGraphQL(GraphQLSchema schema, Object context) {
4852
ExecutionStrategyProvider executionStrategyProvider = getExecutionStrategyProvider.get();
4953
return GraphQL.newGraphQL(schema)
5054
.queryExecutionStrategy(executionStrategyProvider.getQueryExecutionStrategy())
5155
.mutationExecutionStrategy(executionStrategyProvider.getMutationExecutionStrategy())
5256
.subscriptionExecutionStrategy(executionStrategyProvider.getSubscriptionExecutionStrategy())
53-
.instrumentation(getInstrumentation.get())
57+
.instrumentation(getInstrumentation(context))
5458
.preparsedDocumentProvider(getPreparsedDocumentProvider.get())
5559
.build();
5660
}
5761

62+
protected Instrumentation getInstrumentation(Object context) {
63+
if (context instanceof GraphQLContext) {
64+
return ((GraphQLContext) context).getDataLoaderRegistry()
65+
.map(registry -> {
66+
List<Instrumentation> instrumentations = new ArrayList<>();
67+
instrumentations.add(getInstrumentation.get());
68+
instrumentations.add(new DataLoaderDispatcherInstrumentation(registry));
69+
return new ChainedInstrumentation(instrumentations);
70+
})
71+
.map(Instrumentation.class::cast)
72+
.orElse(getInstrumentation.get());
73+
}
74+
return getInstrumentation.get();
75+
}
76+
5877
private ExecutionResult query(GraphQLInvocationInput invocationInput, ExecutionInput executionInput) {
5978
if (Subject.getSubject(AccessController.getContext()) == null && invocationInput.getSubject().isPresent()) {
6079
return Subject.doAs(invocationInput.getSubject().get(), (PrivilegedAction<ExecutionResult>) () -> {
@@ -70,7 +89,7 @@ private ExecutionResult query(GraphQLInvocationInput invocationInput, ExecutionI
7089
}
7190

7291
private ExecutionResult query(GraphQLSchema schema, ExecutionInput executionInput) {
73-
return newGraphQL(schema).execute(executionInput);
92+
return newGraphQL(schema, executionInput.getContext()).execute(executionInput);
7493
}
7594

7695
public static Builder newBuilder() {

src/main/java/graphql/servlet/OsgiGraphQLHttpServlet.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
package graphql.servlet;
22

3+
import graphql.execution.instrumentation.ChainedInstrumentation;
4+
import graphql.execution.instrumentation.Instrumentation;
5+
import graphql.execution.instrumentation.dataloader.DataLoaderDispatcherInstrumentation;
36
import graphql.execution.preparsed.NoOpPreparsedDocumentProvider;
47
import graphql.execution.preparsed.PreparsedDocumentProvider;
58
import graphql.schema.GraphQLObjectType;

src/test/groovy/graphql/servlet/AbstractGraphQLHttpServletSpec.groovy

Lines changed: 72 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,16 @@ package graphql.servlet
22

33
import com.fasterxml.jackson.databind.ObjectMapper
44
import graphql.Scalars
5+
import graphql.analysis.QueryVisitorInlineFragmentEnvironment
56
import graphql.execution.ExecutionTypeInfo
7+
import graphql.execution.instrumentation.ChainedInstrumentation
8+
import graphql.execution.instrumentation.Instrumentation
69
import graphql.schema.DataFetcher
710
import graphql.schema.GraphQLFieldDefinition
811
import graphql.schema.GraphQLNonNull
912
import graphql.schema.GraphQLObjectType
1013
import graphql.schema.GraphQLSchema
14+
import org.dataloader.DataLoaderRegistry
1115
import org.springframework.mock.web.MockHttpServletRequest
1216
import org.springframework.mock.web.MockHttpServletResponse
1317
import spock.lang.Ignore
@@ -39,39 +43,45 @@ class AbstractGraphQLHttpServletSpec extends Specification {
3943
response = new MockHttpServletResponse()
4044
}
4145

42-
def createServlet(DataFetcher queryDataFetcher = { env -> env.arguments.arg }, DataFetcher mutationDataFetcher = { env -> env.arguments.arg }) {
46+
def createServlet(DataFetcher queryDataFetcher = { env -> env.arguments.arg },
47+
DataFetcher mutationDataFetcher = { env -> env.arguments.arg }) {
48+
return SimpleGraphQLHttpServlet.newBuilder(createGraphQlSchema(queryDataFetcher, mutationDataFetcher)).build()
49+
}
50+
51+
def createGraphQlSchema(DataFetcher queryDataFetcher = { env -> env.arguments.arg },
52+
DataFetcher mutationDataFetcher = { env -> env.arguments.arg }) {
4353
GraphQLObjectType query = GraphQLObjectType.newObject()
44-
.name("Query")
45-
.field { GraphQLFieldDefinition.Builder field ->
46-
field.name("echo")
47-
field.type(Scalars.GraphQLString)
48-
field.argument { argument ->
49-
argument.name("arg")
50-
argument.type(Scalars.GraphQLString)
51-
}
52-
field.dataFetcher(queryDataFetcher)
54+
.name("Query")
55+
.field { GraphQLFieldDefinition.Builder field ->
56+
field.name("echo")
57+
field.type(Scalars.GraphQLString)
58+
field.argument { argument ->
59+
argument.name("arg")
60+
argument.type(Scalars.GraphQLString)
5361
}
54-
.field { GraphQLFieldDefinition.Builder field ->
55-
field.name("returnsNullIncorrectly")
56-
field.type(new GraphQLNonNull(Scalars.GraphQLString))
57-
field.dataFetcher({env -> null})
58-
}
59-
.build()
62+
field.dataFetcher(queryDataFetcher)
63+
}
64+
.field { GraphQLFieldDefinition.Builder field ->
65+
field.name("returnsNullIncorrectly")
66+
field.type(new GraphQLNonNull(Scalars.GraphQLString))
67+
field.dataFetcher({env -> null})
68+
}
69+
.build()
6070

6171
GraphQLObjectType mutation = GraphQLObjectType.newObject()
62-
.name("Mutation")
63-
.field { field ->
64-
field.name("echo")
65-
field.type(Scalars.GraphQLString)
66-
field.argument { argument ->
67-
argument.name("arg")
68-
argument.type(Scalars.GraphQLString)
69-
}
70-
field.dataFetcher(mutationDataFetcher)
72+
.name("Mutation")
73+
.field { field ->
74+
field.name("echo")
75+
field.type(Scalars.GraphQLString)
76+
field.argument { argument ->
77+
argument.name("arg")
78+
argument.type(Scalars.GraphQLString)
7179
}
72-
.build()
80+
field.dataFetcher(mutationDataFetcher)
81+
}
82+
.build()
7383

74-
return SimpleGraphQLHttpServlet.newBuilder(new GraphQLSchema(query, mutation, [query, mutation].toSet())).build()
84+
return new GraphQLSchema(query, mutation, [query, mutation].toSet())
7585
}
7686

7787
Map<String, Object> getResponseContent() {
@@ -852,4 +862,39 @@ class AbstractGraphQLHttpServletSpec extends Specification {
852862
then:
853863
1 * mockInputStream.reset()
854864
}
865+
866+
def "getInstrumentation returns the set Instrumentation if none is provided in the context"() {
867+
868+
setup:
869+
Instrumentation expectedInstrumentation = Mock()
870+
GraphQLContext context = new GraphQLContext(request, null, null)
871+
SimpleGraphQLHttpServlet simpleGraphQLServlet = SimpleGraphQLHttpServlet
872+
.newBuilder(createGraphQlSchema())
873+
.withQueryInvoker(GraphQLQueryInvoker.newBuilder().withInstrumentation(expectedInstrumentation).build())
874+
.build()
875+
when:
876+
Instrumentation actualInstrumentation = simpleGraphQLServlet.getQueryInvoker().getInstrumentation(context)
877+
then:
878+
actualInstrumentation == expectedInstrumentation;
879+
! (actualInstrumentation instanceof ChainedInstrumentation)
880+
881+
}
882+
883+
def "getInstrumentation returns the ChainedInstrumentation if DataLoader provided in context"() {
884+
885+
setup:
886+
Instrumentation servletInstrumentation = Mock()
887+
GraphQLContext context = new GraphQLContext(request, null, null)
888+
DataLoaderRegistry dlr = Mock()
889+
context.setDataLoaderRegistry(dlr)
890+
SimpleGraphQLHttpServlet simpleGraphQLServlet = SimpleGraphQLHttpServlet
891+
.newBuilder(createGraphQlSchema())
892+
.withQueryInvoker(GraphQLQueryInvoker.newBuilder().withInstrumentation(servletInstrumentation).build())
893+
.build();
894+
when:
895+
Instrumentation actualInstrumentation = simpleGraphQLServlet.getQueryInvoker().getInstrumentation(context)
896+
then:
897+
actualInstrumentation instanceof ChainedInstrumentation
898+
actualInstrumentation != servletInstrumentation
899+
}
855900
}

0 commit comments

Comments
 (0)