Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,22 @@
package org.apache.hc.client5.http.impl;

import java.util.Iterator;
import java.util.concurrent.atomic.AtomicReference;

import org.apache.hc.core5.annotation.Internal;
import org.apache.hc.core5.http.FormattedHeader;
import org.apache.hc.core5.http.Header;
import org.apache.hc.core5.http.HttpHeaders;
import org.apache.hc.core5.http.HttpMessage;
import org.apache.hc.core5.http.HttpVersion;
import org.apache.hc.core5.http.ParseException;
import org.apache.hc.core5.http.ProtocolException;
import org.apache.hc.core5.http.ProtocolVersion;
import org.apache.hc.core5.http.message.MessageSupport;
import org.apache.hc.core5.http.ProtocolVersionParser;
import org.apache.hc.core5.http.ssl.TLS;
import org.apache.hc.core5.util.Args;
import org.apache.hc.core5.util.CharArrayBuffer;
import org.apache.hc.core5.util.Tokenizer;

/**
* Protocol switch handler.
Expand All @@ -45,31 +52,106 @@
@Internal
public final class ProtocolSwitchStrategy {

enum ProtocolSwitch { FAILURE, TLS }
private static final ProtocolVersionParser PROTOCOL_VERSION_PARSER = ProtocolVersionParser.INSTANCE;

private static final Tokenizer TOKENIZER = Tokenizer.INSTANCE;

private static final Tokenizer.Delimiter UPGRADE_TOKEN_DELIMITER = Tokenizer.delimiters(',');

@FunctionalInterface
private interface HeaderConsumer {
void accept(CharSequence buffer, Tokenizer.Cursor cursor) throws ProtocolException;
}

public ProtocolVersion switchProtocol(final HttpMessage response) throws ProtocolException {
final Iterator<String> it = MessageSupport.iterateTokens(response, HttpHeaders.UPGRADE);
final AtomicReference<ProtocolVersion> tlsUpgrade = new AtomicReference<>();

ProtocolVersion tlsUpgrade = null;
while (it.hasNext()) {
final String token = it.next();
if (token.startsWith("TLS")) {
// TODO: Improve handling of HTTP protocol token once HttpVersion has a #parse method
try {
tlsUpgrade = token.length() == 3 ? TLS.V_1_2.getVersion() : TLS.parse(token.replace("TLS/", "TLSv"));
} catch (final ParseException ex) {
throw new ProtocolException("Invalid protocol: " + token);
parseHeaders(response, HttpHeaders.UPGRADE, (buffer, cursor) -> {
while (!cursor.atEnd()) {
TOKENIZER.skipWhiteSpace(buffer, cursor);
if (cursor.atEnd()) {
break;
}
final int tokenStart = cursor.getPos();
TOKENIZER.parseToken(buffer, cursor, UPGRADE_TOKEN_DELIMITER);
final int tokenEnd = cursor.getPos();
if (tokenStart < tokenEnd) {
final ProtocolVersion version = parseProtocolToken(buffer, tokenStart, tokenEnd);
if (version != null && "TLS".equalsIgnoreCase(version.getProtocol())) {
tlsUpgrade.set(version);
}
}
} else if (token.equals("HTTP/1.1")) {
// TODO: Improve handling of HTTP protocol token once HttpVersion has a #parse method
if (!cursor.atEnd()) {
cursor.updatePos(cursor.getPos() + 1);
}
}
});

final ProtocolVersion result = tlsUpgrade.get();
if (result != null) {
return result;
} else {
throw new ProtocolException("Invalid protocol switch response: no TLS version found");
}
}

private ProtocolVersion parseProtocolToken(final CharSequence buffer, final int start, final int end)
throws ProtocolException {
if (start >= end) {
return null;
}

if (end - start == 3) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@arturobernalg I think there is a way to make it nicer, but it is good enough for now

final char c0 = buffer.charAt(start);
final char c1 = buffer.charAt(start + 1);
final char c2 = buffer.charAt(start + 2);
if ((c0 == 'T' || c0 == 't') &&
(c1 == 'L' || c1 == 'l') &&
(c2 == 'S' || c2 == 's')) {
return TLS.V_1_2.getVersion();
}
}

try {
final Tokenizer.Cursor cursor = new Tokenizer.Cursor(start, end);
final ProtocolVersion version = PROTOCOL_VERSION_PARSER.parse(buffer, cursor, null);

if ("TLS".equalsIgnoreCase(version.getProtocol())) {
return version;
} else if (version.equals(HttpVersion.HTTP_1_1)) {
return null;
} else {
throw new ProtocolException("Unsupported protocol: " + token);
throw new ProtocolException("Unsupported protocol or HTTP version: " + buffer.subSequence(start, end));
}
} catch (final ParseException ex) {
throw new ProtocolException("Invalid protocol: " + buffer.subSequence(start, end), ex);
}
if (tlsUpgrade == null) {
throw new ProtocolException("Invalid protocol switch response");
}

private void parseHeaders(final HttpMessage message, final String name, final HeaderConsumer consumer)
throws ProtocolException {
Args.notNull(message, "Message headers");
Args.notBlank(name, "Header name");
final Iterator<Header> it = message.headerIterator(name);
while (it.hasNext()) {
parseHeader(it.next(), consumer);
}
return tlsUpgrade;
}

}
private void parseHeader(final Header header, final HeaderConsumer consumer) throws ProtocolException {
Args.notNull(header, "Header");
if (header instanceof FormattedHeader) {
final CharArrayBuffer buf = ((FormattedHeader) header).getBuffer();
final Tokenizer.Cursor cursor = new Tokenizer.Cursor(0, buf.length());
cursor.updatePos(((FormattedHeader) header).getValuePos());
consumer.accept(buf, cursor);
} else {
final String value = header.getValue();
if (value == null) {
return;
}
final Tokenizer.Cursor cursor = new Tokenizer.Cursor(0, value.length());
consumer.accept(value, cursor);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,15 @@
import org.apache.hc.core5.http.HttpResponse;
import org.apache.hc.core5.http.HttpStatus;
import org.apache.hc.core5.http.ProtocolException;
import org.apache.hc.core5.http.ProtocolVersion;
import org.apache.hc.core5.http.message.BasicHttpResponse;
import org.apache.hc.core5.http.ssl.TLS;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

/**
* Simple tests for {@link DefaultAuthenticationStrategy}.
* Simple tests for {@link ProtocolSwitchStrategy}.
*/
class TestProtocolSwitchStrategy {

Expand Down Expand Up @@ -95,4 +96,120 @@ void testSwitchInvalid() {
Assertions.assertThrows(ProtocolException.class, () -> switchStrategy.switchProtocol(response3));
}

@Test
void testNullToken() throws ProtocolException {
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
response.addHeader(HttpHeaders.UPGRADE, "TLS,");
response.addHeader(HttpHeaders.UPGRADE, null);
Assertions.assertEquals(TLS.V_1_2.getVersion(), switchStrategy.switchProtocol(response));
}

@Test
void testWhitespaceOnlyToken() throws ProtocolException {
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
response.addHeader(HttpHeaders.UPGRADE, " , TLS");
Assertions.assertEquals(TLS.V_1_2.getVersion(), switchStrategy.switchProtocol(response));
}

@Test
void testUnsupportedTlsVersion() throws Exception {
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
response.addHeader(HttpHeaders.UPGRADE, "TLS/1.4");
Assertions.assertEquals(new ProtocolVersion("TLS", 1, 4), switchStrategy.switchProtocol(response));
}

@Test
void testUnsupportedTlsMajorVersion() throws Exception {
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
response.addHeader(HttpHeaders.UPGRADE, "TLS/2.0");
Assertions.assertEquals(new ProtocolVersion("TLS", 2, 0), switchStrategy.switchProtocol(response));
}

@Test
void testUnsupportedHttpVersion() {
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
response.addHeader(HttpHeaders.UPGRADE, "HTTP/2.0");
Assertions.assertThrows(ProtocolException.class, () -> switchStrategy.switchProtocol(response),
"Unsupported HTTP version: HTTP/2.0");
}

@Test
void testInvalidTlsFormat() {
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
response.addHeader(HttpHeaders.UPGRADE, "TLS/abc");
Assertions.assertThrows(ProtocolException.class, () -> switchStrategy.switchProtocol(response),
"Invalid protocol: TLS/abc");
}

@Test
void testHttp11Only() {
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
response.addHeader(HttpHeaders.UPGRADE, "HTTP/1.1");
Assertions.assertThrows(ProtocolException.class, () -> switchStrategy.switchProtocol(response),
"Invalid protocol switch response: no TLS version found");
}

@Test
void testSwitchToTlsValid_TLS_1_2() throws Exception {
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
response.addHeader(HttpHeaders.UPGRADE, "TLS/1.2");
final ProtocolVersion result = switchStrategy.switchProtocol(response);
Assertions.assertEquals(TLS.V_1_2.getVersion(), result);
}

@Test
void testSwitchToTlsValid_TLS_1_0() throws Exception {
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
response.addHeader(HttpHeaders.UPGRADE, "TLS/1.0");
final ProtocolVersion result = switchStrategy.switchProtocol(response);
Assertions.assertEquals(TLS.V_1_0.getVersion(), result);
}

@Test
void testSwitchToTlsValid_TLS_1_1() throws Exception {
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
response.addHeader(HttpHeaders.UPGRADE, "TLS/1.1");
final ProtocolVersion result = switchStrategy.switchProtocol(response);
Assertions.assertEquals(TLS.V_1_1.getVersion(), result);
}

@Test
void testInvalidTlsFormat_NoSlash() {
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
response.addHeader(HttpHeaders.UPGRADE, "TLSv1");
Assertions.assertThrows(ProtocolException.class, () -> switchStrategy.switchProtocol(response),
"Invalid protocol: TLSv1");
}

@Test
void testSwitchToTlsValid_TLS_1() throws Exception {
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
response.addHeader(HttpHeaders.UPGRADE, "TLS/1");
final ProtocolVersion result = switchStrategy.switchProtocol(response);
Assertions.assertEquals(TLS.V_1_0.getVersion(), result);
}

@Test
void testInvalidTlsFormat_MissingMajor() {
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
response.addHeader(HttpHeaders.UPGRADE, "TLS/.1");
Assertions.assertThrows(ProtocolException.class, () -> switchStrategy.switchProtocol(response),
"Invalid protocol: TLS/.1");
}

@Test
void testMultipleHttp11Tokens() {
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
response.addHeader(HttpHeaders.UPGRADE, "HTTP/1.1, HTTP/1.1");
Assertions.assertThrows(ProtocolException.class, () -> switchStrategy.switchProtocol(response),
"Invalid protocol switch response: no TLS version found");
}

@Test
void testMixedInvalidAndValidTokens() {
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
response.addHeader(HttpHeaders.UPGRADE, "Crap, TLS/1.2, Invalid");
Assertions.assertThrows(ProtocolException.class, () -> switchStrategy.switchProtocol(response),
"Invalid protocol: Crap");
}
}