Skip to content

okhttp: Use a real socket during server transport testing #10281

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 15, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 47 additions & 61 deletions okhttp/src/test/java/io/grpc/okhttp/OkHttpServerTransportTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,9 @@
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.PipedInputStream;
import java.io.PipedOutputStream;
import java.net.InetSocketAddress;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.SocketAddress;
import java.util.ArrayDeque;
import java.util.Arrays;
import java.util.Deque;
Expand All @@ -69,6 +67,7 @@
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import okio.Buffer;
import okio.BufferedSource;
Expand Down Expand Up @@ -96,14 +95,14 @@ public class OkHttpServerTransportTest {
private ServerTransportListener transportListener
= mock(ServerTransportListener.class, delegatesTo(mockTransportListener));
private OkHttpServerTransport serverTransport;
private final PipeSocket socket = new PipeSocket();
private final ExecutorService threadPool = Executors.newCachedThreadPool();
private final SocketPair socketPair = SocketPair.create(threadPool);
private final FrameWriter clientFrameWriter
= new Http2().newWriter(Okio.buffer(Okio.sink(socket.inputStreamSource)), true);
= new Http2().newWriter(Okio.buffer(Okio.sink(socketPair.getClientOutputStream())), true);
private final FrameReader clientFrameReader
= new Http2().newReader(Okio.buffer(Okio.source(socket.outputStreamSink)), true);
= new Http2().newReader(Okio.buffer(Okio.source(socketPair.getClientInputStream())), true);
private final FrameReader.Handler clientFramesRead = mock(FrameReader.Handler.class);
private final DataFrameHandler clientDataFrames = mock(DataFrameHandler.class);
private ExecutorService threadPool = Executors.newCachedThreadPool();
private HandshakerSocketFactory handshakerSocketFactory
= mock(HandshakerSocketFactory.class, delegatesTo(new PlaintextHandshakerSocketFactory()));
private final FakeClock fakeClock = new FakeClock();
Expand Down Expand Up @@ -142,7 +141,11 @@ public void setUp() throws Exception {
@After
public void tearDown() throws Exception {
threadPool.shutdownNow();
socket.closeSourceAndSink();
try {
socketPair.client.close();
} finally {
socketPair.server.close();
}
}

@Test
Expand Down Expand Up @@ -172,7 +175,7 @@ public void maxConnectionAge() throws Exception {
verifyGracefulShutdown(1);
pingPong();
fakeClock.forwardNanos(TimeUnit.SECONDS.toNanos(3));
assertThat(socket.isClosed()).isTrue();
assertThat(socketPair.server.isClosed()).isTrue();
}

@Test
Expand Down Expand Up @@ -254,7 +257,7 @@ public void startThenShutdownTwice() throws Exception {
@Test
public void shutdownDuringHandshake() throws Exception {
doAnswer(invocation -> {
socket.getInputStream().read();
((Socket) invocation.getArguments()[0]).getInputStream().read();
throw new IOException("handshake purposefully failed");
}).when(handshakerSocketFactory).handshake(any(Socket.class), any(Attributes.class));
serverBuilder.transportExecutor(threadPool);
Expand All @@ -268,7 +271,7 @@ public void shutdownDuringHandshake() throws Exception {
@Test
public void shutdownNowDuringHandshake() throws Exception {
doAnswer(invocation -> {
socket.getInputStream().read();
((Socket) invocation.getArguments()[0]).getInputStream().read();
throw new IOException("handshake purposefully failed");
}).when(handshakerSocketFactory).handshake(any(Socket.class), any(Attributes.class));
serverBuilder.transportExecutor(threadPool);
Expand All @@ -282,12 +285,12 @@ public void shutdownNowDuringHandshake() throws Exception {
@Test
public void clientCloseDuringHandshake() throws Exception {
doAnswer(invocation -> {
socket.getInputStream().read();
((Socket) invocation.getArguments()[0]).getInputStream().read();
throw new IOException("handshake purposefully failed");
}).when(handshakerSocketFactory).handshake(any(Socket.class), any(Attributes.class));
serverBuilder.transportExecutor(threadPool);
initTransport();
socket.close();
socketPair.client.close();

verify(transportListener, timeout(TIME_OUT_MS)).transportTerminated();
verify(transportListener, never()).transportReady(any(Attributes.class));
Expand All @@ -296,7 +299,7 @@ public void clientCloseDuringHandshake() throws Exception {
@Test
public void closeDuringHttp2Preface() throws Exception {
initTransport();
socket.close();
socketPair.client.close();

verify(transportListener, timeout(TIME_OUT_MS)).transportTerminated();
verify(transportListener, never()).transportReady(any(Attributes.class));
Expand All @@ -307,7 +310,7 @@ public void noSettingsDuringHttp2HandshakeSettings() throws Exception {
initTransport();
clientFrameWriter.connectionPreface();
clientFrameWriter.flush();
socket.close();
socketPair.client.close();

verify(transportListener, timeout(TIME_OUT_MS)).transportTerminated();
verify(transportListener, never()).transportReady(any(Attributes.class));
Expand All @@ -329,7 +332,7 @@ public void startThenClientDisconnect() throws Exception {
initTransport();
handshake();

socket.closeSourceAndSink();
socketPair.client.close();
verify(transportListener, timeout(TIME_OUT_MS)).transportTerminated();
}

Expand Down Expand Up @@ -1086,8 +1089,8 @@ public void channelzStats() throws Exception {
assertThat(stats.data.messagesReceived).isEqualTo(0);
assertThat(stats.data.remoteFlowControlWindow).isEqualTo(30000); // Lower bound
assertThat(stats.data.localFlowControlWindow).isEqualTo(66535);
assertThat(stats.local).isEqualTo(new InetSocketAddress("127.0.0.1", 4000));
assertThat(stats.remote).isEqualTo(new InetSocketAddress("127.0.0.2", 5000));
assertThat(stats.local).isEqualTo(socketPair.server.getLocalSocketAddress());
assertThat(stats.remote).isEqualTo(socketPair.server.getRemoteSocketAddress());
}

@Test
Expand Down Expand Up @@ -1188,7 +1191,7 @@ public void keepAliveEnforcer_noticesActive() throws Exception {
private void initTransport() throws Exception {
serverTransport = new OkHttpServerTransport(
new OkHttpServerTransport.Config(serverBuilder, Arrays.asList()),
socket);
socketPair.server);
serverTransport.start(transportListener);
}

Expand Down Expand Up @@ -1357,61 +1360,44 @@ static String getContent(InputStream message) throws IOException {
}
}

private static class PipeSocket extends Socket {
private final PipedOutputStream outputStream = new PipedOutputStream();
private final PipedInputStream outputStreamSink = new PipedInputStream();
private final PipedOutputStream inputStreamSource = new PipedOutputStream();
private final PipedInputStream inputStream = new PipedInputStream();
private static class SocketPair {
public final Socket client;
public final Socket server;

public SocketPair(Socket client, Socket server) {
this.client = client;
this.server = server;
}

public PipeSocket() {
public InputStream getClientInputStream() {
try {
outputStreamSink.connect(outputStream);
inputStream.connect(inputStreamSource);
return client.getInputStream();
} catch (IOException ex) {
throw new AssertionError(ex);
throw new RuntimeException(ex);
}
}

@Override
public synchronized void close() throws IOException {
public OutputStream getClientOutputStream() {
try {
outputStream.close();
} finally {
inputStream.close();
// PipedInputStream can only be woken by PipedOutputStream, so PipedOutputStream.close() is
// a better imitation of Socket.close().
inputStreamSource.close();
super.close();
return client.getOutputStream();
} catch (IOException ex) {
throw new RuntimeException(ex);
}
}

public void closeSourceAndSink() throws IOException {
public static SocketPair create(ExecutorService threadPool) {
try {
outputStreamSink.close();
} finally {
inputStreamSource.close();
try (ServerSocket serverSocket = new ServerSocket(0)) {
Future<Socket> serverFuture = threadPool.submit(() -> serverSocket.accept());
Socket client = new Socket();
client.connect(serverSocket.getLocalSocketAddress());
Socket server = serverFuture.get();
return new SocketPair(client, server);
}
} catch (Exception ex) {
throw new RuntimeException(ex);
}
}

@Override
public SocketAddress getLocalSocketAddress() {
return new InetSocketAddress("127.0.0.1", 4000);
}

@Override
public SocketAddress getRemoteSocketAddress() {
return new InetSocketAddress("127.0.0.2", 5000);
}

@Override
public OutputStream getOutputStream() {
return outputStream;
}

@Override
public InputStream getInputStream() {
return inputStream;
}
}

private interface DataFrameHandler {
Expand Down