diff --git a/src/main/java/graphql/servlet/ApolloSubscriptionConnectionListener.java b/src/main/java/graphql/servlet/ApolloSubscriptionConnectionListener.java new file mode 100644 index 00000000..16cef03e --- /dev/null +++ b/src/main/java/graphql/servlet/ApolloSubscriptionConnectionListener.java @@ -0,0 +1,17 @@ +package graphql.servlet; + +import java.util.Optional; + +public interface ApolloSubscriptionConnectionListener extends SubscriptionConnectionListener { + + String CONNECT_RESULT_KEY = "CONNECT_RESULT"; + + default boolean isKeepAliveEnabled() { + return false; + } + + default Optional onConnect(Object payload) { + return Optional.empty(); + } + +} diff --git a/src/main/java/graphql/servlet/DefaultGraphQLContextBuilder.java b/src/main/java/graphql/servlet/DefaultGraphQLContextBuilder.java index 949cab7d..8208d202 100644 --- a/src/main/java/graphql/servlet/DefaultGraphQLContextBuilder.java +++ b/src/main/java/graphql/servlet/DefaultGraphQLContextBuilder.java @@ -2,6 +2,7 @@ import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; +import javax.websocket.Session; import javax.websocket.server.HandshakeRequest; public class DefaultGraphQLContextBuilder implements GraphQLContextBuilder { @@ -12,8 +13,8 @@ public GraphQLContext build(HttpServletRequest httpServletRequest, HttpServletRe } @Override - public GraphQLContext build(HandshakeRequest handshakeRequest) { - return new GraphQLContext(handshakeRequest); + public GraphQLContext build(Session session, HandshakeRequest handshakeRequest) { + return new GraphQLContext(session, handshakeRequest); } @Override diff --git a/src/main/java/graphql/servlet/GraphQLContext.java b/src/main/java/graphql/servlet/GraphQLContext.java index ab4129b8..c8c454dd 100644 --- a/src/main/java/graphql/servlet/GraphQLContext.java +++ b/src/main/java/graphql/servlet/GraphQLContext.java @@ -6,14 +6,17 @@ import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import javax.servlet.http.Part; +import javax.websocket.Session; import javax.websocket.server.HandshakeRequest; import java.util.List; import java.util.Map; import java.util.Optional; public class GraphQLContext { + private HttpServletRequest httpServletRequest; private HttpServletResponse httpServletResponse; + private Session session; private HandshakeRequest handshakeRequest; private Subject subject; @@ -21,9 +24,10 @@ public class GraphQLContext { private DataLoaderRegistry dataLoaderRegistry; - public GraphQLContext(HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse, HandshakeRequest handshakeRequest, Subject subject) { + public GraphQLContext(HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse, Session session, HandshakeRequest handshakeRequest, Subject subject) { this.httpServletRequest = httpServletRequest; this.httpServletResponse = httpServletResponse; + this.session = session; this.handshakeRequest = handshakeRequest; this.subject = subject; } @@ -33,27 +37,40 @@ public GraphQLContext(HttpServletRequest httpServletRequest) { } public GraphQLContext(HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse) { - this(httpServletRequest, httpServletResponse, null, null); + this(httpServletRequest, httpServletResponse, null, null, null); } - public GraphQLContext(HandshakeRequest handshakeRequest) { - this(null, null, handshakeRequest, null); + public GraphQLContext(Session session, HandshakeRequest handshakeRequest) { + this(null, null, session, handshakeRequest, null); } public GraphQLContext() { - this(null, null, null, null); + this(null, null, null, null, null); } public Optional getHttpServletRequest() { return Optional.ofNullable(httpServletRequest); } - public Optional getHttpServletResponse() { return Optional.ofNullable(httpServletResponse); } + public Optional getHttpServletResponse() { + return Optional.ofNullable(httpServletResponse); + } public Optional getSubject() { return Optional.ofNullable(subject); } + public Optional getSession() { + return Optional.ofNullable(session); + } + + public Optional getConnectResult() { + if (session != null) { + return Optional.ofNullable(session.getUserProperties().get(ApolloSubscriptionConnectionListener.CONNECT_RESULT_KEY)); + } + return Optional.empty(); + } + public Optional getHandshakeRequest() { return Optional.ofNullable(handshakeRequest); } diff --git a/src/main/java/graphql/servlet/GraphQLContextBuilder.java b/src/main/java/graphql/servlet/GraphQLContextBuilder.java index b948ee0e..1195ef91 100644 --- a/src/main/java/graphql/servlet/GraphQLContextBuilder.java +++ b/src/main/java/graphql/servlet/GraphQLContextBuilder.java @@ -2,11 +2,14 @@ import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; +import javax.websocket.Session; import javax.websocket.server.HandshakeRequest; public interface GraphQLContextBuilder { + GraphQLContext build(HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse); - GraphQLContext build(HandshakeRequest handshakeRequest); + + GraphQLContext build(Session session, HandshakeRequest handshakeRequest); /** * Only used for MBean calls. diff --git a/src/main/java/graphql/servlet/GraphQLInvocationInputFactory.java b/src/main/java/graphql/servlet/GraphQLInvocationInputFactory.java index afd47c2c..d8d4dd1e 100644 --- a/src/main/java/graphql/servlet/GraphQLInvocationInputFactory.java +++ b/src/main/java/graphql/servlet/GraphQLInvocationInputFactory.java @@ -5,6 +5,7 @@ import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; +import javax.websocket.Session; import javax.websocket.server.HandshakeRequest; import java.util.List; import java.util.function.Supplier; @@ -70,20 +71,20 @@ private GraphQLBatchedInvocationInput create(List graphQLRequest ); } - public GraphQLSingleInvocationInput create(GraphQLRequest graphQLRequest, HandshakeRequest request) { + public GraphQLSingleInvocationInput create(GraphQLRequest graphQLRequest, Session session, HandshakeRequest request) { return new GraphQLSingleInvocationInput( graphQLRequest, schemaProviderSupplier.get().getSchema(request), - contextBuilderSupplier.get().build(request), + contextBuilderSupplier.get().build(session, request), rootObjectBuilderSupplier.get().build(request) ); } - public GraphQLBatchedInvocationInput create(List graphQLRequest, HandshakeRequest request) { + public GraphQLBatchedInvocationInput create(List graphQLRequest, Session session, HandshakeRequest request) { return new GraphQLBatchedInvocationInput( graphQLRequest, schemaProviderSupplier.get().getSchema(request), - contextBuilderSupplier.get().build(request), + contextBuilderSupplier.get().build(session, request), rootObjectBuilderSupplier.get().build(request) ); } diff --git a/src/main/java/graphql/servlet/GraphQLWebsocketServlet.java b/src/main/java/graphql/servlet/GraphQLWebsocketServlet.java index 6d22db90..9c341525 100644 --- a/src/main/java/graphql/servlet/GraphQLWebsocketServlet.java +++ b/src/main/java/graphql/servlet/GraphQLWebsocketServlet.java @@ -10,7 +10,6 @@ import java.io.IOException; import java.util.Collections; import java.util.HashMap; -import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.concurrent.atomic.AtomicBoolean; @@ -48,7 +47,11 @@ public class GraphQLWebsocketServlet extends Endpoint { private final Object cacheLock = new Object(); public GraphQLWebsocketServlet(GraphQLQueryInvoker queryInvoker, GraphQLInvocationInputFactory invocationInputFactory, GraphQLObjectMapper graphQLObjectMapper) { - this.subscriptionHandlerInput = new SubscriptionHandlerInput(invocationInputFactory, queryInvoker, graphQLObjectMapper); + this(queryInvoker, invocationInputFactory, graphQLObjectMapper, null); + } + + public GraphQLWebsocketServlet(GraphQLQueryInvoker queryInvoker, GraphQLInvocationInputFactory invocationInputFactory, GraphQLObjectMapper graphQLObjectMapper, SubscriptionConnectionListener subscriptionConnectionListener) { + this.subscriptionHandlerInput = new SubscriptionHandlerInput(invocationInputFactory, queryInvoker, graphQLObjectMapper, subscriptionConnectionListener); } @Override diff --git a/src/main/java/graphql/servlet/SubscriptionConnectionListener.java b/src/main/java/graphql/servlet/SubscriptionConnectionListener.java new file mode 100644 index 00000000..e381523d --- /dev/null +++ b/src/main/java/graphql/servlet/SubscriptionConnectionListener.java @@ -0,0 +1,7 @@ +package graphql.servlet; + +/** + * Marker interface + */ +public interface SubscriptionConnectionListener { +} diff --git a/src/main/java/graphql/servlet/internal/ApolloSubscriptionProtocolHandler.java b/src/main/java/graphql/servlet/internal/ApolloSubscriptionProtocolHandler.java index dd0e29b8..09f8a6c6 100644 --- a/src/main/java/graphql/servlet/internal/ApolloSubscriptionProtocolHandler.java +++ b/src/main/java/graphql/servlet/internal/ApolloSubscriptionProtocolHandler.java @@ -4,6 +4,8 @@ import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonValue; import graphql.ExecutionResult; +import graphql.servlet.ApolloSubscriptionConnectionListener; +import graphql.servlet.GraphQLSingleInvocationInput; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -13,6 +15,7 @@ import java.io.IOException; import java.util.HashMap; import java.util.Map; +import java.util.Optional; import static graphql.servlet.internal.ApolloSubscriptionProtocolHandler.OperationMessage.Type.GQL_COMPLETE; import static graphql.servlet.internal.ApolloSubscriptionProtocolHandler.OperationMessage.Type.GQL_CONNECTION_TERMINATE; @@ -30,9 +33,14 @@ public class ApolloSubscriptionProtocolHandler extends SubscriptionProtocolHandl private static final CloseReason TERMINATE_CLOSE_REASON = new CloseReason(CloseReason.CloseCodes.NORMAL_CLOSURE, "client requested " + GQL_CONNECTION_TERMINATE.getType()); private final SubscriptionHandlerInput input; + private final ApolloSubscriptionConnectionListener connectionListener; public ApolloSubscriptionProtocolHandler(SubscriptionHandlerInput subscriptionHandlerInput) { this.input = subscriptionHandlerInput; + this.connectionListener = subscriptionHandlerInput.getSubscriptionConnectionListener() + .filter(ApolloSubscriptionConnectionListener.class::isInstance) + .map(ApolloSubscriptionConnectionListener.class::cast) + .orElse(new ApolloSubscriptionConnectionListener() {}); } @Override @@ -48,19 +56,28 @@ public void onMessage(HandshakeRequest request, Session session, WsSessionSubscr switch(message.getType()) { case GQL_CONNECTION_INIT: + try { + Optional connectionResponse = connectionListener.onConnect(message.getPayload()); + connectionResponse.ifPresent(it -> session.getUserProperties().put(ApolloSubscriptionConnectionListener.CONNECT_RESULT_KEY, it)); + } catch (Throwable t) { + sendMessage(session, OperationMessage.Type.GQL_CONNECTION_ERROR, t.getMessage()); + return; + } + sendMessage(session, OperationMessage.Type.GQL_CONNECTION_ACK, message.getId()); - sendMessage(session, OperationMessage.Type.GQL_CONNECTION_KEEP_ALIVE, message.getId()); + + if (connectionListener.isKeepAliveEnabled()) { + sendMessage(session, OperationMessage.Type.GQL_CONNECTION_KEEP_ALIVE, message.getId()); + } break; case GQL_START: + GraphQLSingleInvocationInput graphQLSingleInvocationInput = createInvocationInput(session, message); handleSubscriptionStart( session, subscriptions, message.id, - input.getQueryInvoker().query(input.getInvocationInputFactory().create( - input.getGraphQLObjectMapper().getJacksonMapper().convertValue(message.payload, GraphQLRequest.class), - (HandshakeRequest) session.getUserProperties().get(HandshakeRequest.class.getName()) - )) + input.getQueryInvoker().query(graphQLSingleInvocationInput) ); break; @@ -81,6 +98,16 @@ public void onMessage(HandshakeRequest request, Session session, WsSessionSubscr } } + private GraphQLSingleInvocationInput createInvocationInput(Session session, OperationMessage message) { + GraphQLRequest graphQLRequest = input.getGraphQLObjectMapper() + .getJacksonMapper() + .convertValue(message.getPayload(), GraphQLRequest.class); + HandshakeRequest handshakeRequest = (HandshakeRequest) session.getUserProperties() + .get(HandshakeRequest.class.getName()); + + return input.getInvocationInputFactory().create(graphQLRequest, session, handshakeRequest); + } + @SuppressWarnings("unchecked") private void handleSubscriptionStart(Session session, WsSessionSubscriptions subscriptions, String id, ExecutionResult executionResult) { executionResult = input.getGraphQLObjectMapper().sanitizeErrors(executionResult); diff --git a/src/main/java/graphql/servlet/internal/SubscriptionHandlerInput.java b/src/main/java/graphql/servlet/internal/SubscriptionHandlerInput.java index 5bc1a3f8..a6fae095 100644 --- a/src/main/java/graphql/servlet/internal/SubscriptionHandlerInput.java +++ b/src/main/java/graphql/servlet/internal/SubscriptionHandlerInput.java @@ -3,17 +3,22 @@ import graphql.servlet.GraphQLInvocationInputFactory; import graphql.servlet.GraphQLObjectMapper; import graphql.servlet.GraphQLQueryInvoker; +import graphql.servlet.SubscriptionConnectionListener; + +import java.util.Optional; public class SubscriptionHandlerInput { private final GraphQLInvocationInputFactory invocationInputFactory; private final GraphQLQueryInvoker queryInvoker; private final GraphQLObjectMapper graphQLObjectMapper; + private final SubscriptionConnectionListener subscriptionConnectionListener; - public SubscriptionHandlerInput(GraphQLInvocationInputFactory invocationInputFactory, GraphQLQueryInvoker queryInvoker, GraphQLObjectMapper graphQLObjectMapper) { + public SubscriptionHandlerInput(GraphQLInvocationInputFactory invocationInputFactory, GraphQLQueryInvoker queryInvoker, GraphQLObjectMapper graphQLObjectMapper, SubscriptionConnectionListener subscriptionConnectionListener) { this.invocationInputFactory = invocationInputFactory; this.queryInvoker = queryInvoker; this.graphQLObjectMapper = graphQLObjectMapper; + this.subscriptionConnectionListener = subscriptionConnectionListener; } public GraphQLInvocationInputFactory getInvocationInputFactory() { @@ -27,4 +32,8 @@ public GraphQLQueryInvoker getQueryInvoker() { public GraphQLObjectMapper getGraphQLObjectMapper() { return graphQLObjectMapper; } + + public Optional getSubscriptionConnectionListener() { + return Optional.ofNullable(subscriptionConnectionListener); + } } diff --git a/src/test/groovy/graphql/servlet/AbstractGraphQLHttpServletSpec.groovy b/src/test/groovy/graphql/servlet/AbstractGraphQLHttpServletSpec.groovy index 6ec0d05c..0350f1d7 100644 --- a/src/test/groovy/graphql/servlet/AbstractGraphQLHttpServletSpec.groovy +++ b/src/test/groovy/graphql/servlet/AbstractGraphQLHttpServletSpec.groovy @@ -1021,7 +1021,7 @@ class AbstractGraphQLHttpServletSpec extends Specification { setup: Instrumentation expectedInstrumentation = Mock() - GraphQLContext context = new GraphQLContext(request, response, null, null) + GraphQLContext context = new GraphQLContext(request, response, null, null, null) SimpleGraphQLHttpServlet simpleGraphQLServlet = SimpleGraphQLHttpServlet .newBuilder(TestUtils.createGraphQlSchema()) .withQueryInvoker(GraphQLQueryInvoker.newBuilder().withInstrumentation(expectedInstrumentation).build()) @@ -1037,7 +1037,7 @@ class AbstractGraphQLHttpServletSpec extends Specification { def "getInstrumentation returns the ChainedInstrumentation if DataLoader provided in context"() { setup: Instrumentation servletInstrumentation = Mock() - GraphQLContext context = new GraphQLContext(request, response, null, null) + GraphQLContext context = new GraphQLContext(request, response, null, null, null) DataLoaderRegistry dlr = Mock() context.setDataLoaderRegistry(dlr) SimpleGraphQLHttpServlet simpleGraphQLServlet = SimpleGraphQLHttpServlet