Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,9 @@ public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse res

if (principal != null) {
attributes.put("user", principal);
return true;
}

return false;
return true;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,21 @@ public class TokenAuthenticationFilter extends OncePerRequestFilter {
public void doFilterInternal(@NonNull HttpServletRequest request,
@NonNull HttpServletResponse response,
@NonNull FilterChain filterChain) throws ServletException, IOException {
Optional<String> token = authorizationHeaderParser.parseToken(request);
if (token.isEmpty()) {
filterChain.doFilter(request, response);
return;
}
Optional<String> resolvedToken = resolveToken(request);

TokenAuthentication authToken = new TokenAuthentication(token.get());
Authentication auth = authenticationManager.authenticate(authToken);
SecurityContextHolder.getContext().setAuthentication(auth);
resolvedToken.filter(token -> !token.isBlank()).ifPresent(token -> {
TokenAuthentication authToken = new TokenAuthentication(token);
Authentication auth = authenticationManager.authenticate(authToken);
SecurityContextHolder.getContext().setAuthentication(auth);
});

filterChain.doFilter(request, response);
}

private Optional<String> resolveToken(HttpServletRequest request) {
if (request.getRequestURI().startsWith("/connect")) {
return Optional.ofNullable(request.getParameter("token"));
}
return authorizationHeaderParser.parseToken(request);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,16 @@

import static java.util.concurrent.TimeUnit.SECONDS;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.ThrowableAssert.catchThrowable;
import static org.junit.jupiter.api.Assertions.assertAll;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

import com.example.solidconnection.auth.service.AccessToken;
import com.example.solidconnection.auth.service.AuthTokenProvider;
import com.example.solidconnection.siteuser.domain.SiteUser;
import com.example.solidconnection.siteuser.fixture.SiteUserFixture;
import com.example.solidconnection.support.TestContainerSpringBootTest;
import java.util.List;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.DisplayName;
Expand All @@ -22,11 +20,9 @@
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.web.server.LocalServerPort;
import org.springframework.messaging.converter.MappingJackson2MessageConverter;
import org.springframework.messaging.simp.stomp.StompHeaders;
import org.springframework.messaging.simp.stomp.StompSession;
import org.springframework.messaging.simp.stomp.StompSessionHandlerAdapter;
import org.springframework.web.client.HttpClientErrorException;
import org.springframework.web.socket.WebSocketHttpHeaders;
import org.springframework.web.socket.client.standard.StandardWebSocketClient;
import org.springframework.web.socket.messaging.WebSocketStompClient;
import org.springframework.web.socket.sockjs.client.SockJsClient;
Expand Down Expand Up @@ -67,48 +63,29 @@ void tearDown() {
@Nested
class WebSocket_핸드셰이크_및_STOMP_세션_수립_테스트 {

private final BlockingQueue<Throwable> transportErrorQueue = new ArrayBlockingQueue<>(1);

private final StompSessionHandlerAdapter sessionHandler = new StompSessionHandlerAdapter() {
@Override
public void handleTransportError(StompSession session, Throwable exception) {
transportErrorQueue.add(exception);
}
};

@Test
void 인증된_사용자는_핸드셰이크를_성공한다() throws Exception {
// given
SiteUser user = siteUserFixture.사용자();
AccessToken accessToken = authTokenProvider.generateAccessToken(user);

WebSocketHttpHeaders handshakeHeaders = new WebSocketHttpHeaders();
handshakeHeaders.add("Authorization", "Bearer " + accessToken.token());
String tokenUrl = url + "?token=" + accessToken.token();

// when
stompSession = stompClient.connectAsync(url, handshakeHeaders, new StompHeaders(), sessionHandler).get(5, SECONDS);
stompSession = stompClient.connectAsync(tokenUrl, new StompSessionHandlerAdapter() {
}).get(5, SECONDS);

// then
assertAll(
() -> assertThat(stompSession).isNotNull(),
() -> assertThat(transportErrorQueue).isEmpty()
);
assertThat(stompSession.isConnected()).isTrue();
}

@Test
void 인증되지_않은_사용자는_핸드셰이크를_실패한다() {
// when
Throwable thrown = catchThrowable(() -> {
stompSession = stompClient.connectAsync(url, new WebSocketHttpHeaders(), new StompHeaders(), sessionHandler).get(5, SECONDS);
});

// then
assertAll(
() -> assertThat(thrown)
.isInstanceOf(ExecutionException.class)
.hasCauseInstanceOf(HttpClientErrorException.Unauthorized.class),
() -> assertThat(transportErrorQueue).hasSize(1)
);
// when & then
assertThatThrownBy(() -> {
stompClient.connectAsync(url, new StompSessionHandlerAdapter() {
}).get(5, TimeUnit.SECONDS);
}).isInstanceOf(ExecutionException.class)
.hasCauseInstanceOf(HttpClientErrorException.Unauthorized.class);
}
}
}
Loading