Skip to content

Handle Apollo subscription onConnect #111

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Oct 6, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package graphql.servlet;

import java.util.Optional;

public interface ApolloSubscriptionConnectionListener extends SubscriptionConnectionListener {

String CONNECT_RESULT_KEY = "CONNECT_RESULT";

default boolean isKeepAliveEnabled() {
return false;
}

default Optional<Object> onConnect(Object payload) {
return Optional.empty();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.websocket.Session;
import javax.websocket.server.HandshakeRequest;

public class DefaultGraphQLContextBuilder implements GraphQLContextBuilder {
Expand All @@ -12,8 +13,8 @@ public GraphQLContext build(HttpServletRequest httpServletRequest, HttpServletRe
}

@Override
public GraphQLContext build(HandshakeRequest handshakeRequest) {
return new GraphQLContext(handshakeRequest);
public GraphQLContext build(Session session, HandshakeRequest handshakeRequest) {
return new GraphQLContext(session, handshakeRequest);
}

@Override
Expand Down
29 changes: 23 additions & 6 deletions src/main/java/graphql/servlet/GraphQLContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,28 @@
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.Part;
import javax.websocket.Session;
import javax.websocket.server.HandshakeRequest;
import java.util.List;
import java.util.Map;
import java.util.Optional;

public class GraphQLContext {

private HttpServletRequest httpServletRequest;
private HttpServletResponse httpServletResponse;
private Session session;
private HandshakeRequest handshakeRequest;

private Subject subject;
private Map<String, List<Part>> files;

private DataLoaderRegistry dataLoaderRegistry;

public GraphQLContext(HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse, HandshakeRequest handshakeRequest, Subject subject) {
public GraphQLContext(HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse, Session session, HandshakeRequest handshakeRequest, Subject subject) {
this.httpServletRequest = httpServletRequest;
this.httpServletResponse = httpServletResponse;
this.session = session;
this.handshakeRequest = handshakeRequest;
this.subject = subject;
}
Expand All @@ -33,27 +37,40 @@ public GraphQLContext(HttpServletRequest httpServletRequest) {
}

public GraphQLContext(HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse) {
this(httpServletRequest, httpServletResponse, null, null);
this(httpServletRequest, httpServletResponse, null, null, null);
}

public GraphQLContext(HandshakeRequest handshakeRequest) {
this(null, null, handshakeRequest, null);
public GraphQLContext(Session session, HandshakeRequest handshakeRequest) {
this(null, null, session, handshakeRequest, null);
}

public GraphQLContext() {
this(null, null, null, null);
this(null, null, null, null, null);
}

public Optional<HttpServletRequest> getHttpServletRequest() {
return Optional.ofNullable(httpServletRequest);
}

public Optional<HttpServletResponse> getHttpServletResponse() { return Optional.ofNullable(httpServletResponse); }
public Optional<HttpServletResponse> getHttpServletResponse() {
return Optional.ofNullable(httpServletResponse);
}

public Optional<Subject> getSubject() {
return Optional.ofNullable(subject);
}

public Optional<Session> getSession() {
return Optional.ofNullable(session);
}

public Optional<Object> getConnectResult() {
if (session != null) {
return Optional.ofNullable(session.getUserProperties().get(ApolloSubscriptionConnectionListener.CONNECT_RESULT_KEY));
}
return Optional.empty();
}

public Optional<HandshakeRequest> getHandshakeRequest() {
return Optional.ofNullable(handshakeRequest);
}
Expand Down
5 changes: 4 additions & 1 deletion src/main/java/graphql/servlet/GraphQLContextBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.websocket.Session;
import javax.websocket.server.HandshakeRequest;

public interface GraphQLContextBuilder {

GraphQLContext build(HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse);
GraphQLContext build(HandshakeRequest handshakeRequest);

GraphQLContext build(Session session, HandshakeRequest handshakeRequest);

/**
* Only used for MBean calls.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.websocket.Session;
import javax.websocket.server.HandshakeRequest;
import java.util.List;
import java.util.function.Supplier;
Expand Down Expand Up @@ -70,20 +71,20 @@ private GraphQLBatchedInvocationInput create(List<GraphQLRequest> graphQLRequest
);
}

public GraphQLSingleInvocationInput create(GraphQLRequest graphQLRequest, HandshakeRequest request) {
public GraphQLSingleInvocationInput create(GraphQLRequest graphQLRequest, Session session, HandshakeRequest request) {
return new GraphQLSingleInvocationInput(
graphQLRequest,
schemaProviderSupplier.get().getSchema(request),
contextBuilderSupplier.get().build(request),
contextBuilderSupplier.get().build(session, request),
rootObjectBuilderSupplier.get().build(request)
);
}

public GraphQLBatchedInvocationInput create(List<GraphQLRequest> graphQLRequest, HandshakeRequest request) {
public GraphQLBatchedInvocationInput create(List<GraphQLRequest> graphQLRequest, Session session, HandshakeRequest request) {
return new GraphQLBatchedInvocationInput(
graphQLRequest,
schemaProviderSupplier.get().getSchema(request),
contextBuilderSupplier.get().build(request),
contextBuilderSupplier.get().build(session, request),
rootObjectBuilderSupplier.get().build(request)
);
}
Expand Down
7 changes: 5 additions & 2 deletions src/main/java/graphql/servlet/GraphQLWebsocketServlet.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicBoolean;
Expand Down Expand Up @@ -48,7 +47,11 @@ public class GraphQLWebsocketServlet extends Endpoint {
private final Object cacheLock = new Object();

public GraphQLWebsocketServlet(GraphQLQueryInvoker queryInvoker, GraphQLInvocationInputFactory invocationInputFactory, GraphQLObjectMapper graphQLObjectMapper) {
this.subscriptionHandlerInput = new SubscriptionHandlerInput(invocationInputFactory, queryInvoker, graphQLObjectMapper);
this(queryInvoker, invocationInputFactory, graphQLObjectMapper, null);
}

public GraphQLWebsocketServlet(GraphQLQueryInvoker queryInvoker, GraphQLInvocationInputFactory invocationInputFactory, GraphQLObjectMapper graphQLObjectMapper, SubscriptionConnectionListener subscriptionConnectionListener) {
this.subscriptionHandlerInput = new SubscriptionHandlerInput(invocationInputFactory, queryInvoker, graphQLObjectMapper, subscriptionConnectionListener);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package graphql.servlet;

/**
* Marker interface
*/
public interface SubscriptionConnectionListener {
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonValue;
import graphql.ExecutionResult;
import graphql.servlet.ApolloSubscriptionConnectionListener;
import graphql.servlet.GraphQLSingleInvocationInput;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand All @@ -13,6 +15,7 @@
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;

import static graphql.servlet.internal.ApolloSubscriptionProtocolHandler.OperationMessage.Type.GQL_COMPLETE;
import static graphql.servlet.internal.ApolloSubscriptionProtocolHandler.OperationMessage.Type.GQL_CONNECTION_TERMINATE;
Expand All @@ -30,9 +33,14 @@ public class ApolloSubscriptionProtocolHandler extends SubscriptionProtocolHandl
private static final CloseReason TERMINATE_CLOSE_REASON = new CloseReason(CloseReason.CloseCodes.NORMAL_CLOSURE, "client requested " + GQL_CONNECTION_TERMINATE.getType());

private final SubscriptionHandlerInput input;
private final ApolloSubscriptionConnectionListener connectionListener;

public ApolloSubscriptionProtocolHandler(SubscriptionHandlerInput subscriptionHandlerInput) {
this.input = subscriptionHandlerInput;
this.connectionListener = subscriptionHandlerInput.getSubscriptionConnectionListener()
.filter(ApolloSubscriptionConnectionListener.class::isInstance)
.map(ApolloSubscriptionConnectionListener.class::cast)
.orElse(new ApolloSubscriptionConnectionListener() {});
}

@Override
Expand All @@ -48,19 +56,28 @@ public void onMessage(HandshakeRequest request, Session session, WsSessionSubscr

switch(message.getType()) {
case GQL_CONNECTION_INIT:
try {
Optional<Object> connectionResponse = connectionListener.onConnect(message.getPayload());
connectionResponse.ifPresent(it -> session.getUserProperties().put(ApolloSubscriptionConnectionListener.CONNECT_RESULT_KEY, it));
} catch (Throwable t) {
sendMessage(session, OperationMessage.Type.GQL_CONNECTION_ERROR, t.getMessage());
return;
}

sendMessage(session, OperationMessage.Type.GQL_CONNECTION_ACK, message.getId());
sendMessage(session, OperationMessage.Type.GQL_CONNECTION_KEEP_ALIVE, message.getId());

if (connectionListener.isKeepAliveEnabled()) {
sendMessage(session, OperationMessage.Type.GQL_CONNECTION_KEEP_ALIVE, message.getId());
}
break;

case GQL_START:
GraphQLSingleInvocationInput graphQLSingleInvocationInput = createInvocationInput(session, message);
handleSubscriptionStart(
session,
subscriptions,
message.id,
input.getQueryInvoker().query(input.getInvocationInputFactory().create(
input.getGraphQLObjectMapper().getJacksonMapper().convertValue(message.payload, GraphQLRequest.class),
(HandshakeRequest) session.getUserProperties().get(HandshakeRequest.class.getName())
))
input.getQueryInvoker().query(graphQLSingleInvocationInput)
);
break;

Expand All @@ -81,6 +98,16 @@ public void onMessage(HandshakeRequest request, Session session, WsSessionSubscr
}
}

private GraphQLSingleInvocationInput createInvocationInput(Session session, OperationMessage message) {
GraphQLRequest graphQLRequest = input.getGraphQLObjectMapper()
.getJacksonMapper()
.convertValue(message.getPayload(), GraphQLRequest.class);
HandshakeRequest handshakeRequest = (HandshakeRequest) session.getUserProperties()
.get(HandshakeRequest.class.getName());

return input.getInvocationInputFactory().create(graphQLRequest, session, handshakeRequest);
}

@SuppressWarnings("unchecked")
private void handleSubscriptionStart(Session session, WsSessionSubscriptions subscriptions, String id, ExecutionResult executionResult) {
executionResult = input.getGraphQLObjectMapper().sanitizeErrors(executionResult);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,22 @@
import graphql.servlet.GraphQLInvocationInputFactory;
import graphql.servlet.GraphQLObjectMapper;
import graphql.servlet.GraphQLQueryInvoker;
import graphql.servlet.SubscriptionConnectionListener;

import java.util.Optional;

public class SubscriptionHandlerInput {

private final GraphQLInvocationInputFactory invocationInputFactory;
private final GraphQLQueryInvoker queryInvoker;
private final GraphQLObjectMapper graphQLObjectMapper;
private final SubscriptionConnectionListener subscriptionConnectionListener;

public SubscriptionHandlerInput(GraphQLInvocationInputFactory invocationInputFactory, GraphQLQueryInvoker queryInvoker, GraphQLObjectMapper graphQLObjectMapper) {
public SubscriptionHandlerInput(GraphQLInvocationInputFactory invocationInputFactory, GraphQLQueryInvoker queryInvoker, GraphQLObjectMapper graphQLObjectMapper, SubscriptionConnectionListener subscriptionConnectionListener) {
this.invocationInputFactory = invocationInputFactory;
this.queryInvoker = queryInvoker;
this.graphQLObjectMapper = graphQLObjectMapper;
this.subscriptionConnectionListener = subscriptionConnectionListener;
}

public GraphQLInvocationInputFactory getInvocationInputFactory() {
Expand All @@ -27,4 +32,8 @@ public GraphQLQueryInvoker getQueryInvoker() {
public GraphQLObjectMapper getGraphQLObjectMapper() {
return graphQLObjectMapper;
}

public Optional<SubscriptionConnectionListener> getSubscriptionConnectionListener() {
return Optional.ofNullable(subscriptionConnectionListener);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1021,7 +1021,7 @@ class AbstractGraphQLHttpServletSpec extends Specification {

setup:
Instrumentation expectedInstrumentation = Mock()
GraphQLContext context = new GraphQLContext(request, response, null, null)
GraphQLContext context = new GraphQLContext(request, response, null, null, null)
SimpleGraphQLHttpServlet simpleGraphQLServlet = SimpleGraphQLHttpServlet
.newBuilder(TestUtils.createGraphQlSchema())
.withQueryInvoker(GraphQLQueryInvoker.newBuilder().withInstrumentation(expectedInstrumentation).build())
Expand All @@ -1037,7 +1037,7 @@ class AbstractGraphQLHttpServletSpec extends Specification {
def "getInstrumentation returns the ChainedInstrumentation if DataLoader provided in context"() {
setup:
Instrumentation servletInstrumentation = Mock()
GraphQLContext context = new GraphQLContext(request, response, null, null)
GraphQLContext context = new GraphQLContext(request, response, null, null, null)
DataLoaderRegistry dlr = Mock()
context.setDataLoaderRegistry(dlr)
SimpleGraphQLHttpServlet simpleGraphQLServlet = SimpleGraphQLHttpServlet
Expand Down