diff --git a/src/main/java/graphql/servlet/OsgiGraphQLHttpServlet.java b/src/main/java/graphql/servlet/OsgiGraphQLHttpServlet.java index e1cd9aa6..80d37490 100644 --- a/src/main/java/graphql/servlet/OsgiGraphQLHttpServlet.java +++ b/src/main/java/graphql/servlet/OsgiGraphQLHttpServlet.java @@ -18,6 +18,7 @@ import graphql.servlet.core.DefaultGraphQLRootObjectBuilder; import graphql.servlet.config.DefaultGraphQLSchemaProvider; import graphql.servlet.config.ExecutionStrategyProvider; +import graphql.servlet.config.GraphQLCodeRegistryProvider; import graphql.servlet.context.GraphQLContextBuilder; import graphql.servlet.core.GraphQLErrorHandler; import graphql.servlet.config.GraphQLMutationProvider; @@ -45,6 +46,7 @@ import graphql.execution.preparsed.PreparsedDocumentProvider; import graphql.schema.GraphQLObjectType; import graphql.schema.GraphQLType; +import graphql.schema.GraphQLCodeRegistry; @Component( service={javax.servlet.http.HttpServlet.class,javax.servlet.Servlet.class}, @@ -67,6 +69,7 @@ public class OsgiGraphQLHttpServlet extends AbstractGraphQLHttpServlet { private InstrumentationProvider instrumentationProvider = new NoOpInstrumentationProvider(); private GraphQLErrorHandler errorHandler = new DefaultGraphQLErrorHandler(); private PreparsedDocumentProvider preparsedDocumentProvider = NoOpPreparsedDocumentProvider.INSTANCE; + private GraphQLCodeRegistryProvider codeRegistryProvider = () -> GraphQLCodeRegistry.newCodeRegistry().build(); private GraphQLSchemaProvider schemaProvider; @@ -191,6 +194,7 @@ private void doUpdateSchema() { .mutation(mutationType) .subscription(subscriptionType) .additionalTypes(types) + .codeRegistry(codeRegistryProvider.getCodeRegistry()) .build()); } @@ -208,6 +212,9 @@ public void bindProvider(GraphQLProvider provider) { if (provider instanceof GraphQLTypesProvider) { typesProviders.add((GraphQLTypesProvider) provider); } + if (provider instanceof GraphQLCodeRegistryProvider) { + codeRegistryProvider = (GraphQLCodeRegistryProvider) provider; + } updateSchema(); } public void unbindProvider(GraphQLProvider provider) { @@ -223,6 +230,9 @@ public void unbindProvider(GraphQLProvider provider) { if (provider instanceof GraphQLTypesProvider) { typesProviders.remove(provider); } + if (provider instanceof GraphQLCodeRegistryProvider) { + codeRegistryProvider = () -> GraphQLCodeRegistry.newCodeRegistry().build(); + } updateSchema(); } @@ -322,6 +332,16 @@ public void unsetPreparsedDocumentProvider(PreparsedDocumentProvider preparsedDo this.preparsedDocumentProvider = NoOpPreparsedDocumentProvider.INSTANCE; } + @Reference(cardinality = ReferenceCardinality.OPTIONAL, policy= ReferencePolicy.DYNAMIC, policyOption = ReferencePolicyOption.GREEDY) + public void bindCodeRegistryProvider(GraphQLCodeRegistryProvider graphQLCodeRegistryProvider) { + this.codeRegistryProvider = graphQLCodeRegistryProvider; + updateSchema(); + } + public void unbindCodeRegistryProvider(GraphQLCodeRegistryProvider graphQLCodeRegistryProvider) { + this.codeRegistryProvider = () -> GraphQLCodeRegistry.newCodeRegistry().build(); + updateSchema(); + } + public GraphQLContextBuilder getContextBuilder() { return contextBuilder; } diff --git a/src/main/java/graphql/servlet/config/GraphQLCodeRegistryProvider.java b/src/main/java/graphql/servlet/config/GraphQLCodeRegistryProvider.java new file mode 100644 index 00000000..51e46528 --- /dev/null +++ b/src/main/java/graphql/servlet/config/GraphQLCodeRegistryProvider.java @@ -0,0 +1,7 @@ +package graphql.servlet.config; + +import graphql.schema.GraphQLCodeRegistry; + +public interface GraphQLCodeRegistryProvider extends GraphQLProvider { + GraphQLCodeRegistry getCodeRegistry(); +} diff --git a/src/test/groovy/graphql/servlet/OsgiGraphQLHttpServletSpec.groovy b/src/test/groovy/graphql/servlet/OsgiGraphQLHttpServletSpec.groovy index 7cc25824..1b111ab2 100644 --- a/src/test/groovy/graphql/servlet/OsgiGraphQLHttpServletSpec.groovy +++ b/src/test/groovy/graphql/servlet/OsgiGraphQLHttpServletSpec.groovy @@ -1,9 +1,13 @@ package graphql.servlet +import graphql.AssertException import graphql.annotations.annotationTypes.GraphQLField import graphql.annotations.annotationTypes.GraphQLName import graphql.annotations.processor.GraphQLAnnotations +import graphql.schema.GraphQLCodeRegistry import graphql.schema.GraphQLFieldDefinition +import graphql.schema.GraphQLInterfaceType +import graphql.servlet.config.GraphQLCodeRegistryProvider import graphql.servlet.config.GraphQLMutationProvider import graphql.servlet.config.GraphQLQueryProvider import graphql.servlet.config.GraphQLSubscriptionProvider @@ -122,4 +126,30 @@ class OsgiGraphQLHttpServletSpec extends Specification { then: servlet.getSchemaProvider().getSchema().getSubscriptionType() == null } + + static class TestCodeRegistryProvider implements GraphQLCodeRegistryProvider { + @Override + GraphQLCodeRegistry getCodeRegistry() { + return GraphQLCodeRegistry.newCodeRegistry().typeResolver("Type", { env -> null }).build(); + } + } + + def "code registry provider adds type resolver"() { + setup: + OsgiGraphQLHttpServlet servlet = new OsgiGraphQLHttpServlet() + TestCodeRegistryProvider codeRegistryProvider = new TestCodeRegistryProvider() + + when: + servlet.bindCodeRegistryProvider(codeRegistryProvider) + servlet.getSchemaProvider().getSchema().getCodeRegistry().getTypeResolver(GraphQLInterfaceType.newInterface().name("Type").build()) + then: + notThrown AssertException + + when: + servlet.unbindCodeRegistryProvider(codeRegistryProvider) + servlet.getSchemaProvider().getSchema().getCodeRegistry().getTypeResolver(GraphQLInterfaceType.newInterface().name("Type").build()) + then: + thrown AssertException + + } }