Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ different replacement can also be specified if you wish.
`removeHeaders` on `Preprocessors` removes any occurrences of the named headers
from the request or response.

`removeMatchingHeaders` on `Preprocessors` applies the given patterns on every header and
removes them when matching.


[[customizing-requests-and-responses-preprocessors-replace-patterns]]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;
import java.util.regex.Pattern;

import org.springframework.http.HttpHeaders;
import org.springframework.restdocs.operation.OperationRequest;
Expand All @@ -27,20 +28,28 @@
import org.springframework.restdocs.operation.OperationResponseFactory;

/**
* An {@link OperationPreprocessor} that removes headers.
* An {@link OperationPreprocessor} that removes headers. The headers to remove are
* provided as constructor arguments and can be either plain string or patterns to match
* against the headers found
*
* @author Andy Wilkinson
*/
class HeaderRemovingOperationPreprocessor implements OperationPreprocessor {

private final OperationRequestFactory requestFactory = new OperationRequestFactory();

private final OperationResponseFactory responseFactory = new OperationResponseFactory();

private final Set<String> headersToRemove;
private final Set<String> plainHeadersToRemove;
private final Set<Pattern> patternHeadersToRemove;

HeaderRemovingOperationPreprocessor(String ... headersToRemove) {
this.plainHeadersToRemove = new HashSet<>(Arrays.asList(headersToRemove));
this.patternHeadersToRemove = null;
}

HeaderRemovingOperationPreprocessor(String... headersToRemove) {
this.headersToRemove = new HashSet<>(Arrays.asList(headersToRemove));
HeaderRemovingOperationPreprocessor(Pattern ... patternHeadersToRemove) {
this.plainHeadersToRemove = null;
this.patternHeadersToRemove = new HashSet<>(Arrays.asList(patternHeadersToRemove));
}

@Override
Expand All @@ -58,8 +67,25 @@ public OperationRequest preprocess(OperationRequest request) {
private HttpHeaders removeHeaders(HttpHeaders originalHeaders) {
HttpHeaders processedHeaders = new HttpHeaders();
processedHeaders.putAll(originalHeaders);
for (String headerToRemove : this.headersToRemove) {
processedHeaders.remove(headerToRemove);
if (this.plainHeadersToRemove != null) {
for (String headerToRemove : this.plainHeadersToRemove) {
processedHeaders.remove(headerToRemove);
}
}
else {
Set<String> toRemove = new HashSet<>();
for (String headerToCheck : originalHeaders.keySet()) {
for (Pattern pattern : this.patternHeadersToRemove) {
if (pattern.matcher(headerToCheck).matches()) {
toRemove.add(headerToCheck);
}
}
}
// Remove afterwards to avoid side effects when removing while iterating over
// the set keys :
for (String headerToRemove : toRemove) {
processedHeaders.remove(headerToRemove);
}
}
return processedHeaders;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,28 @@ public static OperationPreprocessor prettyPrint() {
* Returns an {@code OperationPreprocessor} that will remove headers from the request
* or response.
*
* @param headersToRemove the names of the headers to remove
* @param headersToRemove the names of the headers to remove.
* @return the preprocessor
*/
public static OperationPreprocessor removeHeaders(String... headersToRemove) {
return new HeaderRemovingOperationPreprocessor(headersToRemove);
}

/**
* Returns an {@code OperationPreprocessor} that will remove headers from the request
* or response based on a pattern match.
*
* @param headerPatternsToRemove pattern for the header names to remove. Every matchig header will be removed.
* @return the preprocessor
*/
public static OperationPreprocessor removeMatchingHeaders(String... headerPatternsToRemove) {
Pattern[] patterns = new Pattern[headerPatternsToRemove.length];
for (int i = 0; i < headerPatternsToRemove.length; i++) {
patterns[i] = Pattern.compile(headerPatternsToRemove[i]);
}
return new HeaderRemovingOperationPreprocessor(patterns);
}

/**
* Returns an {@code OperationPreprocessor} that will mask the href of hypermedia
* links in the request or response.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import java.net.URI;
import java.util.Arrays;
import java.util.Collections;
import java.util.regex.Pattern;

import org.junit.Test;
import org.springframework.http.HttpHeaders;
Expand Down Expand Up @@ -66,18 +67,43 @@ public void modifyRequestHeaders() {

@Test
public void modifyResponseHeaders() {
OperationResponse response = this.responseFactory.create(HttpStatus.OK,
getHttpHeaders(), new byte[0]);
OperationResponse response = createResponse();
OperationResponse preprocessed = this.preprocessor.preprocess(response);
assertThat(preprocessed.getHeaders().size(), is(equalTo(1)));
assertThat(preprocessed.getHeaders(), hasEntry("a", Arrays.asList("alpha")));
}

private HttpHeaders getHttpHeaders() {
@Test
public void modifyWithPattern() {
OperationResponse response = createResponse("content-length", "1234");
HeaderRemovingOperationPreprocessor processor =
new HeaderRemovingOperationPreprocessor(Pattern.compile("co.*le(.)gth]"));
OperationResponse preprocessed = processor.preprocess(response);
assertThat(preprocessed.getHeaders().size(), is(equalTo(2)));
assertThat(preprocessed.getHeaders(), hasEntry("a", Arrays.asList("alpha")));
assertThat(preprocessed.getHeaders(), hasEntry("b", Arrays.asList("bravo", "banana")));
}

@Test
public void removeAllHeaders() {
HeaderRemovingOperationPreprocessor processor =
new HeaderRemovingOperationPreprocessor(Pattern.compile(".*"));
OperationResponse preprocessed = processor.preprocess(createResponse());
assertThat(preprocessed.getHeaders().size(), is(equalTo(0)));
}

private OperationResponse createResponse(String ... extraHeaders) {
return this.responseFactory.create(HttpStatus.OK, getHttpHeaders(extraHeaders), new byte[0]);
}

private HttpHeaders getHttpHeaders(String ... extraHeaders) {
HttpHeaders httpHeaders = new HttpHeaders();
httpHeaders.add("a", "alpha");
httpHeaders.add("b", "bravo");
httpHeaders.add("b", "banana");
for (int i = 0; i < extraHeaders.length; i += 2) {
httpHeaders.add(extraHeaders[i], extraHeaders[i + 1]);
}
return httpHeaders;
}

Expand Down