diff --git a/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/model/AwsProxyRequest.java b/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/model/AwsProxyRequest.java index e9b141ec1..d8c594826 100644 --- a/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/model/AwsProxyRequest.java +++ b/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/model/AwsProxyRequest.java @@ -15,6 +15,7 @@ import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import java.util.HashMap; import java.util.Map; /** @@ -31,7 +32,7 @@ public class AwsProxyRequest { private String resource; private ApiGatewayRequestContext requestContext; private Map queryStringParameters; - private Map headers; + private Map headers = new HashMap<>(); // avoid NPE private Map pathParameters; private String httpMethod; private Map stageVariables; @@ -105,7 +106,11 @@ public Map getHeaders() { public void setHeaders(Map headers) { - this.headers = headers; + if (null != headers) { + this.headers = headers; + } else { + this.headers.clear(); + } } diff --git a/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsHttpServletRequest.java b/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsHttpServletRequest.java index 15ad74dac..c4d85dd2d 100644 --- a/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsHttpServletRequest.java +++ b/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsHttpServletRequest.java @@ -12,6 +12,8 @@ */ package com.amazonaws.serverless.proxy.internal.servlet; +import com.amazonaws.serverless.proxy.internal.RequestReader; +import com.amazonaws.serverless.proxy.internal.model.ApiGatewayRequestContext; import com.amazonaws.serverless.proxy.internal.model.ContainerConfig; import com.amazonaws.services.lambda.runtime.Context; @@ -24,6 +26,7 @@ import javax.servlet.http.Cookie; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpSession; +import javax.servlet.http.HttpSessionContext; import java.io.UnsupportedEncodingException; import java.net.URLDecoder; import java.net.URLEncoder; @@ -68,6 +71,7 @@ public abstract class AwsHttpServletRequest implements HttpServletRequest { private Context lambdaContext; private Map attributes; private ServletContext servletContext; + private AwsHttpSession session; protected DispatcherType dispatcherType; @@ -101,13 +105,17 @@ public String getRequestedSessionId() { @Override public HttpSession getSession(boolean b) { - return null; + if (b && null == this.session) { + ApiGatewayRequestContext requestContext = (ApiGatewayRequestContext) getAttribute(RequestReader.API_GATEWAY_CONTEXT_PROPERTY); + this.session = new AwsHttpSession(requestContext.getRequestId()); + } + return this.session; } @Override public HttpSession getSession() { - return null; + return this.session; } diff --git a/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsHttpServletResponse.java b/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsHttpServletResponse.java index 4ebfd8938..52b6b80a3 100644 --- a/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsHttpServletResponse.java +++ b/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsHttpServletResponse.java @@ -53,6 +53,7 @@ public class AwsHttpServletResponse private int statusCode; private String statusMessage; private String responseBody; + private PrintWriter writer; private ByteArrayOutputStream bodyOutputStream = new ByteArrayOutputStream(); private CountDownLatch writersCountDownLatch; private AwsHttpServletRequest request; @@ -316,7 +317,10 @@ public void close() @Override public PrintWriter getWriter() throws IOException { - return new PrintWriter(bodyOutputStream); + if (null == writer) { + writer = new PrintWriter(bodyOutputStream); + } + return writer; } @@ -358,7 +362,11 @@ public int getBufferSize() { @Override public void flushBuffer() throws IOException { + if (null != writer) { + writer.flush(); + } responseBody = new String(bodyOutputStream.toByteArray()); + log.debug("Response buffer flushed with {} bytes, latch={}", responseBody.length(), writersCountDownLatch.getCount()); isCommitted = true; writersCountDownLatch.countDown(); } diff --git a/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsHttpSession.java b/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsHttpSession.java new file mode 100644 index 000000000..6aca5947c --- /dev/null +++ b/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsHttpSession.java @@ -0,0 +1,111 @@ +package com.amazonaws.serverless.proxy.internal.servlet; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.servlet.ServletContext; +import javax.servlet.http.HttpSession; +import javax.servlet.http.HttpSessionContext; +import java.util.Enumeration; + +public class AwsHttpSession implements HttpSession { + + private static final Logger log = LoggerFactory.getLogger(AwsHttpSession.class); + private String id; + + /** + * @param id API gateway request ID. + */ + public AwsHttpSession(String id) { + if (null == id) { + throw new RuntimeException("HTTP session id (from request ID) cannot be null"); + } + log.debug("Creating session " + id); + this.id = id; + } + + @Override + public long getCreationTime() { + return 0; + } + + @Override + public String getId() { + return id; + } + + @Override + public long getLastAccessedTime() { + return 0; + } + + @Override + public ServletContext getServletContext() { + return null; + } + + @Override + public void setMaxInactiveInterval(int interval) { + + } + + @Override + public int getMaxInactiveInterval() { + return 0; + } + + @Override + public HttpSessionContext getSessionContext() { + return null; + } + + @Override + public Object getAttribute(String name) { + return null; + } + + @Override + public Object getValue(String name) { + return null; + } + + @Override + public Enumeration getAttributeNames() { + return null; + } + + @Override + public String[] getValueNames() { + return new String[0]; + } + + @Override + public void setAttribute(String name, Object value) { + + } + + @Override + public void putValue(String name, Object value) { + + } + + @Override + public void removeAttribute(String name) { + + } + + @Override + public void removeValue(String name) { + + } + + @Override + public void invalidate() { + + } + + @Override + public boolean isNew() { + return false; + } +} diff --git a/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsLambdaServletContainerHandler.java b/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsLambdaServletContainerHandler.java index 39d7806ed..cc6697cf0 100644 --- a/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsLambdaServletContainerHandler.java +++ b/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsLambdaServletContainerHandler.java @@ -22,6 +22,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import javax.servlet.Servlet; import javax.servlet.ServletContext; import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; @@ -176,11 +177,13 @@ protected void setServletContext(final ServletContext context) { * Applies the filter chain in the request lifecycle * @param request The Request object. This must be an implementation of HttpServletRequest * @param response The response object. This must be an implementation of HttpServletResponse + * @param servlet Servlet at the end of the chain (optional). * @throws IOException * @throws ServletException */ - protected void doFilter(ContainerRequestType request, ContainerResponseType response) throws IOException, ServletException { - FilterChainHolder chain = filterChainManager.getFilterChain(request); + protected void doFilter(ContainerRequestType request, ContainerResponseType response, Servlet servlet) throws IOException, ServletException { + FilterChainHolder chain = filterChainManager.getFilterChain(request, servlet); + log.debug("FilterChainHolder.doFilter {}", chain); chain.doFilter(request, response); } diff --git a/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsProxyHttpServletRequest.java b/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsProxyHttpServletRequest.java index 5f836ef11..7245522ee 100644 --- a/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsProxyHttpServletRequest.java +++ b/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsProxyHttpServletRequest.java @@ -196,9 +196,13 @@ public String getPathTranslated() { } + /** + * In AWS API Gateway, stage is never given as part of the path. + * @return + */ @Override public String getContextPath() { - return request.getRequestContext().getStage(); + return ""; } @@ -228,7 +232,7 @@ public Principal getUserPrincipal() { @Override public String getRequestURI() { - return request.getPath(); + return (getContextPath().isEmpty() ? "" : "/" + getContextPath()) + request.getPath(); } diff --git a/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsServletContext.java b/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsServletContext.java index fcdbcb5ea..957af0f01 100644 --- a/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsServletContext.java +++ b/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsServletContext.java @@ -34,6 +34,7 @@ import java.io.IOException; import java.io.InputStream; import java.net.MalformedURLException; +import java.net.URI; import java.net.URISyntaxException; import java.net.URL; import java.nio.file.Files; @@ -141,9 +142,16 @@ public int getEffectiveMinorVersion() { @Override public String getMimeType(String s) { try { - return Files.probeContentType(Paths.get(s)); + + if (s.startsWith("file:")) { // Support paths such as file:/D:/something/hello.txt + return Files.probeContentType(Paths.get(URI.create(s))); + } else if (s.startsWith("/")) { // Support paths such as file:/D:/something/hello.txt + return Files.probeContentType(Paths.get(URI.create("file://" + s))); + } else { + return Files.probeContentType(Paths.get(s)); + } } catch (IOException e) { - log.warn("Could not find content type for filter", e); + log.warn("Could not find content type for file {}", s, e); return null; } } @@ -364,6 +372,8 @@ public FilterRegistration.Dynamic addFilter(String name, Filter filter) { // filter already exists, we do nothing if (filters.containsKey(name)) { return null; + } else { + log.debug("Adding filter '{}' from {}", name, filter); } FilterHolder newFilter = new FilterHolder(name, filter, this); @@ -376,6 +386,7 @@ public FilterRegistration.Dynamic addFilter(String name, Filter filter) { @Override public FilterRegistration.Dynamic addFilter(String name, Class filterClass) { try { + log.debug("Adding filter '{}' from {}", name, filterClass.getName()); Filter newFilter = createFilter(filterClass); return addFilter(name, newFilter); } catch (ServletException e) { diff --git a/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/FilterChainHolder.java b/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/FilterChainHolder.java index 47c5ebe6e..eb42cff77 100644 --- a/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/FilterChainHolder.java +++ b/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/FilterChainHolder.java @@ -15,10 +15,8 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import javax.servlet.FilterChain; -import javax.servlet.ServletException; -import javax.servlet.ServletRequest; -import javax.servlet.ServletResponse; +import javax.servlet.*; +import javax.servlet.http.HttpServletRequest; import java.io.IOException; import java.util.ArrayList; @@ -30,6 +28,7 @@ * during a request lifecycle */ public class FilterChainHolder implements FilterChain { + private final Servlet servlet; //------------------------------------------------------------- // Variables - Private @@ -47,19 +46,22 @@ public class FilterChainHolder implements FilterChain { /** * Creates a new empty FilterChainHolder + * @param servlet */ - FilterChainHolder() { - this(new ArrayList<>()); + FilterChainHolder(Servlet servlet) { + this(new ArrayList<>(), servlet); } /** * Creates a new instance of a filter chain holder * @param allFilters A populated list of FilterHolder objects + * @param servlet */ - FilterChainHolder(List allFilters) { + FilterChainHolder(List allFilters, Servlet servlet) { filters = allFilters; resetHolder(); + this.servlet = servlet; } @@ -70,9 +72,19 @@ public class FilterChainHolder implements FilterChain { @Override public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse) throws IOException, ServletException { currentFilter++; - if (filters == null || filters.size() == 0 || currentFilter > filters.size() - 1) { + if (filters == null || filters.size() == 0 ) { log.debug("Could not find filters to execute, returning"); return; + } else if (currentFilter > filters.size() - 1) { + if (null != servlet) { + log.debug("Starting servlet {}", servlet); + servlet.service(servletRequest, servletResponse); + log.debug("Executed servlet {}", servlet); + return; + } else { + log.debug("No more filters"); + return; + } } // TODO: We do not check for async filters here @@ -82,9 +94,12 @@ public void doFilter(ServletRequest servletRequest, ServletResponse servletRespo if (!holder.isFilterInitialized()) { holder.init(); } - log.debug("Starting filter " + holder.getFilterName()); + log.debug("Starting {} {} : filter {}-{} {}", servletRequest.getDispatcherType(), ((HttpServletRequest) servletRequest).getRequestURI(), + currentFilter, holder.getFilterName(), holder.getFilter()); holder.getFilter().doFilter(servletRequest, servletResponse, this); - log.debug("Executed filter " + holder.getFilterName()); + log.debug("Executed {} {} : filter {}-{} {}", servletRequest.getDispatcherType(), ((HttpServletRequest) servletRequest).getRequestURI(), + currentFilter, holder.getFilterName(), holder.getFilter()); + currentFilter--; } @@ -144,4 +159,9 @@ public List getFilters() { private void resetHolder() { currentFilter = -1; } + + @Override + public String toString() { + return "filters=" + filters + ", servlet=" + servlet; + } } diff --git a/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/FilterChainManager.java b/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/FilterChainManager.java index 79bbbc75d..6ddfeadfb 100644 --- a/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/FilterChainManager.java +++ b/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/FilterChainManager.java @@ -13,6 +13,7 @@ package com.amazonaws.serverless.proxy.internal.servlet; import javax.servlet.DispatcherType; +import javax.servlet.Servlet; import javax.servlet.ServletContext; import javax.servlet.http.HttpServletRequest; @@ -81,18 +82,19 @@ public abstract class FilterChainManagerFilterChainHolder object that can be used to apply the filters to the request */ - FilterChainHolder getFilterChain(final HttpServletRequest request) { + FilterChainHolder getFilterChain(final HttpServletRequest request, Servlet servlet) { String targetPath = request.getServletPath(); DispatcherType type = request.getDispatcherType(); // only return the cached result if the filter list hasn't changed in the meanwhile - if (getFilterHolders().size() == filtersSize && getFilterChainCache(type, targetPath) != null) { - return getFilterChainCache(type, targetPath); + if (getFilterHolders().size() == filtersSize && getFilterChainCache(type, targetPath, servlet) != null) { + return getFilterChainCache(type, targetPath, servlet); } - FilterChainHolder chainHolder = new FilterChainHolder(); + FilterChainHolder chainHolder = new FilterChainHolder(servlet); Map registrations = getFilterHolders(); if (registrations == null || registrations.size() == 0) { @@ -134,9 +136,10 @@ FilterChainHolder getFilterChain(final HttpServletRequest request) { * initialized with the cached list of {@link FilterHolder} objects * @param type The dispatcher type for the incoming request * @param targetPath The request path - this is extracted with the getPath method of the request object + * @param servlet Servlet to put at the end of the chain (optional). * @return A populated FilterChainHolder */ - private FilterChainHolder getFilterChainCache(final DispatcherType type, final String targetPath) { + private FilterChainHolder getFilterChainCache(final DispatcherType type, final String targetPath, Servlet servlet) { TargetCacheKey key = new TargetCacheKey(); key.setDispatcherType(type); key.setTargetPath(targetPath); @@ -145,7 +148,7 @@ private FilterChainHolder getFilterChainCache(final DispatcherType type, final S return null; } - return new FilterChainHolder(filterCache.get(key)); + return new FilterChainHolder(filterCache.get(key), servlet); } diff --git a/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/FilterHolder.java b/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/FilterHolder.java index daa318285..a8446b141 100644 --- a/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/FilterHolder.java +++ b/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/FilterHolder.java @@ -29,7 +29,7 @@ public class FilterHolder { //------------------------------------------------------------- private Filter filter; - private FilterConfig filterConfig; + private FilterConfig filterConfig = new Config(); private Registration registration; private String filterName; private Map initParameters; diff --git a/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/testutils/AwsProxyRequestBuilder.java b/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/testutils/AwsProxyRequestBuilder.java index bdfd90fbd..1c9b4d36c 100644 --- a/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/testutils/AwsProxyRequestBuilder.java +++ b/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/testutils/AwsProxyRequestBuilder.java @@ -62,10 +62,12 @@ public AwsProxyRequestBuilder(String path, String httpMethod) { this.mapper = new ObjectMapper(); this.request = new AwsProxyRequest(); + this.request.setHeaders(new HashMap<>()); // avoid NPE this.request.setHttpMethod(httpMethod); this.request.setPath(path); this.request.setQueryStringParameters(new HashMap<>()); this.request.setRequestContext(new ApiGatewayRequestContext()); + this.request.getRequestContext().setRequestId("test-invoke-request"); this.request.getRequestContext().setStage("test"); ApiGatewayRequestIdentity identity = new ApiGatewayRequestIdentity(); identity.setSourceIp("127.0.0.1"); diff --git a/aws-serverless-java-container-core/src/test/java/com/amazonaws/serverless/proxy/internal/servlet/AwsFilterChainManagerTest.java b/aws-serverless-java-container-core/src/test/java/com/amazonaws/serverless/proxy/internal/servlet/AwsFilterChainManagerTest.java index 84f320b96..c5ed8f248 100644 --- a/aws-serverless-java-container-core/src/test/java/com/amazonaws/serverless/proxy/internal/servlet/AwsFilterChainManagerTest.java +++ b/aws-serverless-java-container-core/src/test/java/com/amazonaws/serverless/proxy/internal/servlet/AwsFilterChainManagerTest.java @@ -115,21 +115,21 @@ public void filterChain_getFilterChain_subsetOfFilters() { new AwsProxyRequestBuilder("/first/second", "GET").build(), lambdaContext, null ); req.setServletContext(servletContext); - FilterChainHolder fcHolder = chainManager.getFilterChain(req); + FilterChainHolder fcHolder = chainManager.getFilterChain(req, null); assertEquals(1, fcHolder.filterCount()); assertEquals("Filter1", fcHolder.getFilter(0).getFilterName()); req = new AwsProxyHttpServletRequest( new AwsProxyRequestBuilder("/second/mime", "GET").build(), lambdaContext, null ); - fcHolder = chainManager.getFilterChain(req); + fcHolder = chainManager.getFilterChain(req, null); assertEquals(1, fcHolder.filterCount()); assertEquals("Filter2", fcHolder.getFilter(0).getFilterName()); req = new AwsProxyHttpServletRequest( new AwsProxyRequestBuilder("/second/mime/third", "GET").build(), lambdaContext, null ); - fcHolder = chainManager.getFilterChain(req); + fcHolder = chainManager.getFilterChain(req, null); assertEquals(1, fcHolder.filterCount()); assertEquals("Filter2", fcHolder.getFilter(0).getFilterName()); } @@ -140,7 +140,7 @@ public void filterChain_matchMultipleTimes_expectSameMatch() { new AwsProxyRequestBuilder("/first/second", "GET").build(), lambdaContext, null ); req.setServletContext(servletContext); - FilterChainHolder fcHolder = chainManager.getFilterChain(req); + FilterChainHolder fcHolder = chainManager.getFilterChain(req, null); assertEquals(1, fcHolder.filterCount()); assertEquals("Filter1", fcHolder.getFilter(0).getFilterName()); @@ -148,7 +148,7 @@ public void filterChain_matchMultipleTimes_expectSameMatch() { new AwsProxyRequestBuilder("/first/second", "GET").build(), lambdaContext, null ); req.setServletContext(servletContext); - FilterChainHolder fcHolder2 = chainManager.getFilterChain(req2); + FilterChainHolder fcHolder2 = chainManager.getFilterChain(req2, null); assertEquals(1, fcHolder2.filterCount()); assertEquals("Filter1", fcHolder2.getFilter(0).getFilterName()); } @@ -159,7 +159,7 @@ public void filerChain_executeMultipleFilters_expectRunEachTime() { new AwsProxyRequestBuilder("/first/second", "GET").build(), lambdaContext, null ); req.setServletContext(servletContext); - FilterChainHolder fcHolder = chainManager.getFilterChain(req); + FilterChainHolder fcHolder = chainManager.getFilterChain(req, null); assertEquals(1, fcHolder.filterCount()); assertEquals("Filter1", fcHolder.getFilter(0).getFilterName()); AwsHttpServletResponse resp = new AwsHttpServletResponse(req, new CountDownLatch(1)); @@ -183,7 +183,7 @@ public void filerChain_executeMultipleFilters_expectRunEachTime() { new AwsProxyRequestBuilder("/first/second", "GET").build(), lambdaContext, null ); req2.setServletContext(servletContext); - FilterChainHolder fcHolder2 = chainManager.getFilterChain(req2); + FilterChainHolder fcHolder2 = chainManager.getFilterChain(req2, null); assertEquals(1, fcHolder2.filterCount()); assertEquals("Filter1", fcHolder2.getFilter(0).getFilterName()); assertEquals(-1, fcHolder2.currentFilter); @@ -212,14 +212,14 @@ public void filterChain_getFilterChain_multipleFilters() { req.setServletContext(servletContext); FilterRegistration.Dynamic reg = req.getServletContext().addFilter("Filter4", new MockFilter()); reg.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), true, "/second/*"); - FilterChainHolder fcHolder = chainManager.getFilterChain(req); + FilterChainHolder fcHolder = chainManager.getFilterChain(req, null); assertEquals(2, fcHolder.filterCount()); assertEquals("Filter2", fcHolder.getFilter(0).getFilterName()); assertEquals("Filter4", fcHolder.getFilter(1).getFilterName()); reg = req.getServletContext().addFilter("Filter5", new MockFilter()); reg.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, "/second/*"); - fcHolder = chainManager.getFilterChain(req); + fcHolder = chainManager.getFilterChain(req, null); assertEquals(3, fcHolder.filterCount()); assertEquals("Filter2", fcHolder.getFilter(0).getFilterName()); assertEquals("Filter4", fcHolder.getFilter(1).getFilterName()); diff --git a/aws-serverless-java-container-spark/src/main/java/com/amazonaws/serverless/proxy/spark/SparkLambdaContainerHandler.java b/aws-serverless-java-container-spark/src/main/java/com/amazonaws/serverless/proxy/spark/SparkLambdaContainerHandler.java index b4a059d30..6c7aa1a49 100644 --- a/aws-serverless-java-container-spark/src/main/java/com/amazonaws/serverless/proxy/spark/SparkLambdaContainerHandler.java +++ b/aws-serverless-java-container-spark/src/main/java/com/amazonaws/serverless/proxy/spark/SparkLambdaContainerHandler.java @@ -164,7 +164,7 @@ protected void handleRequest(AwsProxyHttpServletRequest httpServletRequest, AwsH } } - doFilter(httpServletRequest, httpServletResponse); + doFilter(httpServletRequest, httpServletResponse, null); embeddedServer.handle(httpServletRequest, httpServletResponse); } diff --git a/aws-serverless-java-container-spring/src/main/java/com/amazonaws/serverless/proxy/spring/SpringBootLambdaContainerHandler.java b/aws-serverless-java-container-spring/src/main/java/com/amazonaws/serverless/proxy/spring/SpringBootLambdaContainerHandler.java new file mode 100644 index 000000000..de7ce509b --- /dev/null +++ b/aws-serverless-java-container-spring/src/main/java/com/amazonaws/serverless/proxy/spring/SpringBootLambdaContainerHandler.java @@ -0,0 +1,254 @@ +/* + * Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package com.amazonaws.serverless.proxy.spring; + +import com.amazonaws.serverless.exceptions.ContainerInitializationException; +import com.amazonaws.serverless.proxy.internal.*; +import com.amazonaws.serverless.proxy.internal.model.AwsProxyRequest; +import com.amazonaws.serverless.proxy.internal.model.AwsProxyResponse; +import com.amazonaws.serverless.proxy.internal.servlet.*; +import com.amazonaws.services.lambda.runtime.Context; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.web.SpringServletContainerInitializer; +import org.springframework.web.WebApplicationInitializer; +import org.springframework.web.context.WebApplicationContext; +import org.springframework.web.context.support.WebApplicationContextUtils; +import org.springframework.web.servlet.DispatcherServlet; + +import javax.servlet.*; +import javax.servlet.http.HttpServletResponse; +import java.util.*; +import java.util.concurrent.CountDownLatch; + +/** + * Spring implementation of the `LambdaContainerHandler` abstract class. This class uses the `LambdaSpringApplicationInitializer` + * object behind the scenes to proxy requests. The default implementation leverages the `AwsProxyHttpServletRequest` and + * `AwsHttpServletResponse` implemented in the `aws-serverless-java-container-core` package. + * + * Important: Make sure to add {@link LambdaFlushResponseListener} in your SpringBootServletInitializer subclass configure(). + * + * @param The incoming event type + * @param The expected return type + */ +public class SpringBootLambdaContainerHandler extends AwsLambdaServletContainerHandler { + static ThreadLocal currentResponse = new ThreadLocal<>(); + private final Class springBootInitializer; + private static final Logger log = LoggerFactory.getLogger(SpringBootLambdaContainerHandler.class); + + // State vars + private boolean initialized; + + /** + * Creates a default SpringLambdaContainerHandler initialized with the `AwsProxyRequest` and `AwsProxyResponse` objects + * @param springBootInitializer {@code SpringBootServletInitializer} class + * @return An initialized instance of the `SpringLambdaContainerHandler` + * @throws ContainerInitializationException + */ + public static SpringBootLambdaContainerHandler getAwsProxyHandler(Class springBootInitializer) + throws ContainerInitializationException { + return new SpringBootLambdaContainerHandler<>( + new AwsProxyHttpServletRequestReader(), + new AwsProxyHttpServletResponseWriter(), + new AwsProxySecurityContextWriter(), + new AwsProxyExceptionHandler(), + springBootInitializer + ); + } + + /** + * Creates a new container handler with the given reader and writer objects + * + * @param requestReader An implementation of `RequestReader` + * @param responseWriter An implementation of `ResponseWriter` + * @param securityContextWriter An implementation of `SecurityContextWriter` + * @param exceptionHandler An implementation of `ExceptionHandler` + * @throws ContainerInitializationException + */ + public SpringBootLambdaContainerHandler(RequestReader requestReader, + ResponseWriter responseWriter, + SecurityContextWriter securityContextWriter, + ExceptionHandler exceptionHandler, + Class springBootInitializer) + throws ContainerInitializationException { + super(requestReader, responseWriter, securityContextWriter, exceptionHandler); + this.springBootInitializer = springBootInitializer; + } + + @Override + protected AwsHttpServletResponse getContainerResponse(AwsProxyHttpServletRequest request, CountDownLatch latch) { + return new AwsHttpServletResponse(request, latch); + } + + @Override + protected void handleRequest(AwsProxyHttpServletRequest containerRequest, AwsHttpServletResponse containerResponse, Context lambdaContext) throws Exception { + // this method of the AwsLambdaServletContainerHandler sets the servlet context + if (getServletContext() == null) { + setServletContext(new SpringBootAwsServletContext()); + } + + // wire up the application context on the first invocation + if (!initialized) { + SpringServletContainerInitializer springServletContainerInitializer = new SpringServletContainerInitializer(); + LinkedHashSet> webAppInitializers = new LinkedHashSet<>(); + webAppInitializers.add(springBootInitializer); + springServletContainerInitializer.onStartup(webAppInitializers, getServletContext()); + initialized = true; + } + + containerRequest.setServletContext(getServletContext()); + + currentResponse.set(containerResponse); + try { + WebApplicationContext applicationContext = WebApplicationContextUtils.getRequiredWebApplicationContext(getServletContext()); + DispatcherServlet dispatcherServlet = applicationContext.getBean("dispatcherServlet", DispatcherServlet.class); + // process filters & invoke servlet + log.debug("Process filters & invoke servlet: {}", dispatcherServlet); + doFilter(containerRequest, containerResponse, dispatcherServlet); + } finally { + // call the flush method to release the latch + SpringBootLambdaContainerHandler.currentResponse.remove(); + currentResponse.remove(); + } + } + + private class SpringBootAwsServletContext extends AwsServletContext { + public SpringBootAwsServletContext() { + super(SpringBootLambdaContainerHandler.this); + } + + @Override + public ServletRegistration.Dynamic addServlet(String s, String s1) { + throw new UnsupportedOperationException("Only dispatcherServlet is supported"); + } + + @Override + public ServletRegistration.Dynamic addServlet(String s, Class aClass) { + throw new UnsupportedOperationException("Only dispatcherServlet is supported"); + } + + @Override + public ServletRegistration.Dynamic addServlet(String s, Servlet servlet) { + if ("dispatcherServlet".equals(s)) { + try { + servlet.init(new ServletConfig() { + @Override + public String getServletName() { + return s; + } + + @Override + public ServletContext getServletContext() { + return SpringBootAwsServletContext.this; + } + + @Override + public String getInitParameter(String name) { + return null; + } + + @Override + public Enumeration getInitParameterNames() { + return new Enumeration() { + @Override + public boolean hasMoreElements() { + return false; + } + + @Override + public String nextElement() { + return null; + } + }; + } + }); + } catch (ServletException e) { + throw new RuntimeException("Cannot add servlet " + servlet, e); + } + return new ServletRegistration.Dynamic() { + @Override + public String getName() { + return s; + } + + @Override + public String getClassName() { + return null; + } + + @Override + public boolean setInitParameter(String name, String value) { + return false; + } + + @Override + public String getInitParameter(String name) { + return null; + } + + @Override + public Set setInitParameters(Map initParameters) { + return null; + } + + @Override + public Map getInitParameters() { + return null; + } + + @Override + public Set addMapping(String... urlPatterns) { + return null; + } + + @Override + public Collection getMappings() { + return null; + } + + @Override + public String getRunAsRole() { + return null; + } + + @Override + public void setAsyncSupported(boolean isAsyncSupported) { + + } + + @Override + public void setLoadOnStartup(int loadOnStartup) { + + } + + @Override + public Set setServletSecurity(ServletSecurityElement constraint) { + return null; + } + + @Override + public void setMultipartConfig(MultipartConfigElement multipartConfig) { + + } + + @Override + public void setRunAsRole(String roleName) { + + } + }; + } else { + throw new UnsupportedOperationException("Only dispatcherServlet is supported"); + } + } + } +} diff --git a/aws-serverless-java-container-spring/src/main/java/com/amazonaws/serverless/proxy/spring/SpringLambdaContainerHandler.java b/aws-serverless-java-container-spring/src/main/java/com/amazonaws/serverless/proxy/spring/SpringLambdaContainerHandler.java index f3a932b0b..9c4950ba4 100644 --- a/aws-serverless-java-container-spring/src/main/java/com/amazonaws/serverless/proxy/spring/SpringLambdaContainerHandler.java +++ b/aws-serverless-java-container-spring/src/main/java/com/amazonaws/serverless/proxy/spring/SpringLambdaContainerHandler.java @@ -21,7 +21,6 @@ import org.springframework.web.context.ConfigurableWebApplicationContext; import org.springframework.web.context.support.AnnotationConfigWebApplicationContext; -import javax.servlet.ServletContext; import java.util.Arrays; import java.util.concurrent.CountDownLatch; @@ -136,7 +135,7 @@ protected void handleRequest(AwsProxyHttpServletRequest containerRequest, AwsHtt containerRequest.setServletContext(getServletContext()); // process filters - doFilter(containerRequest, containerResponse); + doFilter(containerRequest, containerResponse, null); // invoke servlet initializer.dispatch(containerRequest, containerResponse); }