Skip to content

Introduce Customizable AuthorizationFailureHandler in OAuth2AuthorizationRequestRedirectFilter #14168

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2022 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -25,13 +25,15 @@

import org.springframework.core.log.LogMessage;
import org.springframework.http.HttpStatus;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.web.DefaultRedirectStrategy;
import org.springframework.security.web.RedirectStrategy;
import org.springframework.security.web.authentication.AuthenticationFailureHandler;
import org.springframework.security.web.savedrequest.HttpSessionRequestCache;
import org.springframework.security.web.savedrequest.RequestCache;
import org.springframework.security.web.util.ThrowableAnalyzer;
Expand Down Expand Up @@ -97,6 +99,8 @@ public class OAuth2AuthorizationRequestRedirectFilter extends OncePerRequestFilt

private RequestCache requestCache = new HttpSessionRequestCache();

private AuthenticationFailureHandler authenticationFailureHandler = this::unsuccessfulRedirectForAuthorization;

/**
* Constructs an {@code OAuth2AuthorizationRequestRedirectFilter} using the provided
* parameters.
Expand Down Expand Up @@ -163,6 +167,18 @@ public final void setRequestCache(RequestCache requestCache) {
this.requestCache = requestCache;
}

/**
* Sets the {@link AuthenticationFailureHandler} used to handle errors redirecting to
* the Authorization Server's Authorization Endpoint.
* @param authenticationFailureHandler the {@link AuthenticationFailureHandler} used
* to handle errors redirecting to the Authorization Server's Authorization Endpoint
* @since 6.3
*/
public void setAuthenticationFailureHandler(AuthenticationFailureHandler authenticationFailureHandler) {
Assert.notNull(authenticationFailureHandler, "authenticationFailureHandler cannot be null");
this.authenticationFailureHandler = authenticationFailureHandler;
}

@Override
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
throws ServletException, IOException {
Expand All @@ -174,7 +190,8 @@ protected void doFilterInternal(HttpServletRequest request, HttpServletResponse
}
}
catch (Exception ex) {
this.unsuccessfulRedirectForAuthorization(request, response, ex);
AuthenticationException wrappedException = new OAuth2AuthorizationRequestException(ex);
this.authenticationFailureHandler.onAuthenticationFailure(request, response, wrappedException);
return;
}
try {
Expand All @@ -199,7 +216,8 @@ protected void doFilterInternal(HttpServletRequest request, HttpServletResponse
this.sendRedirectForAuthorization(request, response, authorizationRequest);
}
catch (Exception failed) {
this.unsuccessfulRedirectForAuthorization(request, response, failed);
AuthenticationException wrappedException = new OAuth2AuthorizationRequestException(ex);
this.authenticationFailureHandler.onAuthenticationFailure(request, response, wrappedException);
}
return;
}
Expand All @@ -223,9 +241,10 @@ private void sendRedirectForAuthorization(HttpServletRequest request, HttpServle
}

private void unsuccessfulRedirectForAuthorization(HttpServletRequest request, HttpServletResponse response,
Exception ex) throws IOException {
LogMessage message = LogMessage.format("Authorization Request failed: %s", ex);
if (InvalidClientRegistrationIdException.class.isAssignableFrom(ex.getClass())) {
AuthenticationException ex) throws IOException {
Throwable cause = ex.getCause();
LogMessage message = LogMessage.format("Authorization Request failed: %s", cause);
if (InvalidClientRegistrationIdException.class.isAssignableFrom(cause.getClass())) {
// Log an invalid registrationId at WARN level to allow these errors to be
// tuned separately from other errors
this.logger.warn(message, ex);
Expand All @@ -250,4 +269,12 @@ protected void initExtractorMap() {

}

private static final class OAuth2AuthorizationRequestException extends AuthenticationException {

OAuth2AuthorizationRequestException(Throwable cause) {
super(cause.getMessage(), cause);
}

}

}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2022 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -119,6 +119,11 @@ public void setRequestCacheWhenRequestCacheIsNullThenThrowIllegalArgumentExcepti
assertThatIllegalArgumentException().isThrownBy(() -> this.filter.setRequestCache(null));
}

@Test
public void setAuthenticationFailureHandlerIsNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.filter.setAuthenticationFailureHandler(null));
}

@Test
public void doFilterWhenNotAuthorizationRequestThenNextFilter() throws Exception {
String requestUri = "/path";
Expand All @@ -144,6 +149,31 @@ public void doFilterWhenAuthorizationRequestWithInvalidClientThenStatusInternalS
assertThat(response.getErrorMessage()).isEqualTo(HttpStatus.INTERNAL_SERVER_ERROR.getReasonPhrase());
}

@Test
public void doFilterWhenAuthorizationRequestWithInvalidClientAndCustomFailureHandlerThenCustomError()
throws Exception {
String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI + "/"
+ this.registration1.getRegistrationId() + "-invalid";
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
request.setServletPath(requestUri);
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
this.filter.setAuthenticationFailureHandler((request1, response1, ex) -> {
Throwable cause = ex.getCause();
if (InvalidClientRegistrationIdException.class.isAssignableFrom(cause.getClass())) {
response1.sendError(HttpStatus.BAD_REQUEST.value(), HttpStatus.BAD_REQUEST.getReasonPhrase());
}
else {
response1.sendError(HttpStatus.INTERNAL_SERVER_ERROR.value(),
HttpStatus.INTERNAL_SERVER_ERROR.getReasonPhrase());
}
});
this.filter.doFilter(request, response, filterChain);
verifyNoMoreInteractions(filterChain);
assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value());
assertThat(response.getErrorMessage()).isEqualTo(HttpStatus.BAD_REQUEST.getReasonPhrase());
}

@Test
public void doFilterWhenAuthorizationRequestOAuth2LoginThenRedirectForAuthorization() throws Exception {
String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI + "/"
Expand Down