Skip to content

Commit 6fe7207

Browse files
committed
Merge branch 'TYsewyn-fix/lb-reconstruct-uri'
2 parents 858e706 + 84c00a4 commit 6fe7207

File tree

2 files changed

+116
-39
lines changed

2 files changed

+116
-39
lines changed

spring-cloud-gateway-core/src/main/java/org/springframework/cloud/gateway/filter/LoadBalancerClientFilter.java

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -26,16 +26,15 @@
2626
import org.springframework.cloud.gateway.support.NotFoundException;
2727
import org.springframework.core.Ordered;
2828
import org.springframework.web.server.ServerWebExchange;
29-
import org.springframework.web.util.UriComponentsBuilder;
3029

3130
import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.GATEWAY_REQUEST_URL_ATTR;
3231
import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.addOriginalRequestUrl;
33-
import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.containsEncodedQuery;
3432

3533
import reactor.core.publisher.Mono;
3634

3735
/**
3836
* @author Spencer Gibb
37+
* @author Tim Ysewyn
3938
*/
4039
public class LoadBalancerClientFilter implements GlobalFilter, Ordered {
4140

@@ -70,15 +69,9 @@ public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
7069
throw new NotFoundException("Unable to find instance for " + url.getHost());
7170
}
7271

73-
/*URI uri = exchange.getRequest().getURI();
74-
URI requestUrl = loadBalancer.reconstructURI(instance, uri);*/
75-
boolean encoded = containsEncodedQuery(url);
76-
URI requestUrl = UriComponentsBuilder.fromUri(url)
77-
.scheme(instance.isSecure()? "https" : "http") //TODO: support websockets
78-
.host(instance.getHost())
79-
.port(instance.getPort())
80-
.build(encoded)
81-
.toUri();
72+
URI uri = exchange.getRequest().getURI();
73+
URI requestUrl = loadBalancer.reconstructURI(instance, uri);
74+
8275
log.trace("LoadBalancerClientFilter url chosen: " + requestUrl);
8376
exchange.getAttributes().put(GATEWAY_REQUEST_URL_ATTR, requestUrl);
8477
return chain.filter(exchange);

spring-cloud-gateway-core/src/test/java/org/springframework/cloud/gateway/filter/LoadBalancerClientFilterTests.java

Lines changed: 112 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,47 +1,129 @@
1-
/*
2-
* Copyright 2013-2017 the original author or authors.
3-
*
4-
* Licensed under the Apache License, Version 2.0 (the "License");
5-
* you may not use this file except in compliance with the License.
6-
* You may obtain a copy of the License at
7-
*
8-
* http://www.apache.org/licenses/LICENSE-2.0
9-
*
10-
* Unless required by applicable law or agreed to in writing, software
11-
* distributed under the License is distributed on an "AS IS" BASIS,
12-
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13-
* See the License for the specific language governing permissions and
14-
* limitations under the License.
15-
*
16-
*/
17-
181
package org.springframework.cloud.gateway.filter;
192

203
import java.net.URI;
214
import java.util.Collections;
5+
import java.util.LinkedHashSet;
226

7+
import com.netflix.loadbalancer.ILoadBalancer;
8+
import com.netflix.loadbalancer.Server;
9+
import org.junit.Before;
2310
import org.junit.Test;
11+
import org.junit.runner.RunWith;
2412
import org.mockito.ArgumentCaptor;
13+
import org.mockito.InjectMocks;
14+
import org.mockito.Mock;
15+
import org.mockito.junit.MockitoJUnitRunner;
2516
import org.springframework.cloud.client.DefaultServiceInstance;
17+
import org.springframework.cloud.client.ServiceInstance;
2618
import org.springframework.cloud.client.loadbalancer.LoadBalancerClient;
19+
import org.springframework.cloud.gateway.support.NotFoundException;
20+
import org.springframework.cloud.netflix.ribbon.RibbonLoadBalancerClient;
21+
import org.springframework.cloud.netflix.ribbon.RibbonLoadBalancerContext;
22+
import org.springframework.cloud.netflix.ribbon.SpringClientFactory;
2723
import org.springframework.http.HttpMethod;
2824
import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
2925
import org.springframework.mock.web.server.MockServerWebExchange;
3026
import org.springframework.web.server.ServerWebExchange;
3127
import org.springframework.web.util.UriComponentsBuilder;
3228

3329
import static org.assertj.core.api.Assertions.assertThat;
30+
import static org.mockito.ArgumentMatchers.any;
31+
import static org.mockito.ArgumentMatchers.eq;
3432
import static org.mockito.Mockito.mock;
33+
import static org.mockito.Mockito.verify;
34+
import static org.mockito.Mockito.verifyNoMoreInteractions;
35+
import static org.mockito.Mockito.verifyZeroInteractions;
3536
import static org.mockito.Mockito.when;
37+
import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.GATEWAY_ORIGINAL_REQUEST_URL_ATTR;
3638
import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.GATEWAY_REQUEST_URL_ATTR;
3739

3840
import reactor.core.publisher.Mono;
3941

4042
/**
4143
* @author Spencer Gibb
44+
* @author Tim Ysewyn
4245
*/
46+
@RunWith(MockitoJUnitRunner.class)
4347
public class LoadBalancerClientFilterTests {
4448

49+
private ServerWebExchange exchange;
50+
51+
@Mock
52+
private GatewayFilterChain chain;
53+
54+
@Mock
55+
private LoadBalancerClient loadBalancerClient;
56+
57+
@InjectMocks
58+
private LoadBalancerClientFilter loadBalancerClientFilter;
59+
60+
@Before
61+
public void setup() {
62+
exchange = MockServerWebExchange.from(MockServerHttpRequest.get("loadbalancerclient.org").build());
63+
}
64+
65+
@Test
66+
public void shouldNotFilterWhenGatewayRequestUrlIsMissing() {
67+
loadBalancerClientFilter.filter(exchange, chain);
68+
69+
verify(chain).filter(exchange);
70+
verifyNoMoreInteractions(chain);
71+
verifyZeroInteractions(loadBalancerClient);
72+
}
73+
74+
@Test
75+
public void shouldNotFilterWhenGatewayRequestUrlSchemeIsNotLb() {
76+
URI uri = UriComponentsBuilder.fromUriString("http://myservice").build().toUri();
77+
exchange.getAttributes().put(GATEWAY_REQUEST_URL_ATTR, uri);
78+
79+
loadBalancerClientFilter.filter(exchange, chain);
80+
81+
verify(chain).filter(exchange);
82+
verifyNoMoreInteractions(chain);
83+
verifyZeroInteractions(loadBalancerClient);
84+
}
85+
86+
@Test(expected = NotFoundException.class)
87+
public void shouldThrowExceptionWhenNoServiceInstanceIsFound() {
88+
URI uri = UriComponentsBuilder.fromUriString("lb://myservice").build().toUri();
89+
exchange.getAttributes().put(GATEWAY_REQUEST_URL_ATTR, uri);
90+
91+
loadBalancerClientFilter.filter(exchange, chain);
92+
}
93+
94+
@Test
95+
public void shouldFilter() {
96+
URI url = UriComponentsBuilder.fromUriString("lb://myservice").build().toUri();
97+
exchange.getAttributes().put(GATEWAY_REQUEST_URL_ATTR, url);
98+
99+
ServiceInstance serviceInstance = new DefaultServiceInstance("myservice", "localhost", 8080, true);
100+
when(loadBalancerClient.choose("myservice")).thenReturn(serviceInstance);
101+
102+
URI requestUrl = UriComponentsBuilder.fromUriString("https://localhost:8080").build().toUri();
103+
when(loadBalancerClient.reconstructURI(any(ServiceInstance.class), any(URI.class))).thenReturn(requestUrl);
104+
105+
loadBalancerClientFilter.filter(exchange, chain);
106+
107+
assertThat((LinkedHashSet<URI>)exchange.getAttribute(GATEWAY_ORIGINAL_REQUEST_URL_ATTR)).contains(url);
108+
109+
verify(loadBalancerClient).choose("myservice");
110+
111+
ArgumentCaptor<URI> urlArgumentCaptor = ArgumentCaptor.forClass(URI.class);
112+
verify(loadBalancerClient).reconstructURI(eq(serviceInstance), urlArgumentCaptor.capture());
113+
114+
URI uri = urlArgumentCaptor.getValue();
115+
assertThat(uri).isNotNull();
116+
assertThat(uri.toString()).isEqualTo("loadbalancerclient.org");
117+
118+
verifyNoMoreInteractions(loadBalancerClient);
119+
120+
assertThat((URI)exchange.getAttribute(GATEWAY_REQUEST_URL_ATTR)).isEqualTo(requestUrl);
121+
122+
verify(chain).filter(exchange);
123+
verifyNoMoreInteractions(chain);
124+
}
125+
126+
45127
@Test
46128
public void happyPath() {
47129
MockServerHttpRequest request = MockServerHttpRequest
@@ -119,18 +201,20 @@ private ServerWebExchange testFilter(MockServerHttpRequest request, URI uri) {
119201
ServerWebExchange exchange = MockServerWebExchange.from(request);
120202
exchange.getAttributes().put(GATEWAY_REQUEST_URL_ATTR, uri);
121203

122-
GatewayFilterChain filterChain = mock(GatewayFilterChain.class);
123-
124204
ArgumentCaptor<ServerWebExchange> captor = ArgumentCaptor.forClass(ServerWebExchange.class);
125-
when(filterChain.filter(captor.capture())).thenReturn(Mono.empty());
126-
127-
LoadBalancerClient loadBalancerClient = mock(LoadBalancerClient.class);
128-
when(loadBalancerClient.choose("service1")).
129-
thenReturn(new DefaultServiceInstance("service1", "service1-host1", 8081,
130-
false, Collections.emptyMap()));
131-
132-
LoadBalancerClientFilter filter = new LoadBalancerClientFilter(loadBalancerClient);
133-
filter.filter(exchange, filterChain);
205+
when(chain.filter(captor.capture())).thenReturn(Mono.empty());
206+
207+
SpringClientFactory clientFactory = mock(SpringClientFactory.class);
208+
ILoadBalancer loadBalancer = mock(ILoadBalancer.class);
209+
210+
when(clientFactory.getLoadBalancerContext("service1")).thenReturn(new RibbonLoadBalancerContext(loadBalancer));
211+
when(clientFactory.getLoadBalancer("service1")).thenReturn(loadBalancer);
212+
when(loadBalancer.chooseServer(any())).thenReturn(new Server("service1-host1", 8081));
213+
214+
RibbonLoadBalancerClient client = new RibbonLoadBalancerClient(clientFactory);
215+
216+
LoadBalancerClientFilter filter = new LoadBalancerClientFilter(client);
217+
filter.filter(exchange, chain);
134218

135219
return captor.getValue();
136220
}

0 commit comments

Comments
 (0)