diff --git a/okhttp/src/test/java/io/grpc/okhttp/OkHttpServerTransportTest.java b/okhttp/src/test/java/io/grpc/okhttp/OkHttpServerTransportTest.java index 816272fbc98..455908816a8 100644 --- a/okhttp/src/test/java/io/grpc/okhttp/OkHttpServerTransportTest.java +++ b/okhttp/src/test/java/io/grpc/okhttp/OkHttpServerTransportTest.java @@ -56,11 +56,9 @@ import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; -import java.io.PipedInputStream; -import java.io.PipedOutputStream; import java.net.InetSocketAddress; +import java.net.ServerSocket; import java.net.Socket; -import java.net.SocketAddress; import java.util.ArrayDeque; import java.util.Arrays; import java.util.Deque; @@ -69,6 +67,7 @@ import java.util.concurrent.Executor; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; +import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; import okio.Buffer; import okio.BufferedSource; @@ -96,14 +95,14 @@ public class OkHttpServerTransportTest { private ServerTransportListener transportListener = mock(ServerTransportListener.class, delegatesTo(mockTransportListener)); private OkHttpServerTransport serverTransport; - private final PipeSocket socket = new PipeSocket(); + private final ExecutorService threadPool = Executors.newCachedThreadPool(); + private final SocketPair socketPair = SocketPair.create(threadPool); private final FrameWriter clientFrameWriter - = new Http2().newWriter(Okio.buffer(Okio.sink(socket.inputStreamSource)), true); + = new Http2().newWriter(Okio.buffer(Okio.sink(socketPair.getClientOutputStream())), true); private final FrameReader clientFrameReader - = new Http2().newReader(Okio.buffer(Okio.source(socket.outputStreamSink)), true); + = new Http2().newReader(Okio.buffer(Okio.source(socketPair.getClientInputStream())), true); private final FrameReader.Handler clientFramesRead = mock(FrameReader.Handler.class); private final DataFrameHandler clientDataFrames = mock(DataFrameHandler.class); - private ExecutorService threadPool = Executors.newCachedThreadPool(); private HandshakerSocketFactory handshakerSocketFactory = mock(HandshakerSocketFactory.class, delegatesTo(new PlaintextHandshakerSocketFactory())); private final FakeClock fakeClock = new FakeClock(); @@ -142,7 +141,11 @@ public void setUp() throws Exception { @After public void tearDown() throws Exception { threadPool.shutdownNow(); - socket.closeSourceAndSink(); + try { + socketPair.client.close(); + } finally { + socketPair.server.close(); + } } @Test @@ -172,7 +175,7 @@ public void maxConnectionAge() throws Exception { verifyGracefulShutdown(1); pingPong(); fakeClock.forwardNanos(TimeUnit.SECONDS.toNanos(3)); - assertThat(socket.isClosed()).isTrue(); + assertThat(socketPair.server.isClosed()).isTrue(); } @Test @@ -254,7 +257,7 @@ public void startThenShutdownTwice() throws Exception { @Test public void shutdownDuringHandshake() throws Exception { doAnswer(invocation -> { - socket.getInputStream().read(); + ((Socket) invocation.getArguments()[0]).getInputStream().read(); throw new IOException("handshake purposefully failed"); }).when(handshakerSocketFactory).handshake(any(Socket.class), any(Attributes.class)); serverBuilder.transportExecutor(threadPool); @@ -268,7 +271,7 @@ public void shutdownDuringHandshake() throws Exception { @Test public void shutdownNowDuringHandshake() throws Exception { doAnswer(invocation -> { - socket.getInputStream().read(); + ((Socket) invocation.getArguments()[0]).getInputStream().read(); throw new IOException("handshake purposefully failed"); }).when(handshakerSocketFactory).handshake(any(Socket.class), any(Attributes.class)); serverBuilder.transportExecutor(threadPool); @@ -282,12 +285,12 @@ public void shutdownNowDuringHandshake() throws Exception { @Test public void clientCloseDuringHandshake() throws Exception { doAnswer(invocation -> { - socket.getInputStream().read(); + ((Socket) invocation.getArguments()[0]).getInputStream().read(); throw new IOException("handshake purposefully failed"); }).when(handshakerSocketFactory).handshake(any(Socket.class), any(Attributes.class)); serverBuilder.transportExecutor(threadPool); initTransport(); - socket.close(); + socketPair.client.close(); verify(transportListener, timeout(TIME_OUT_MS)).transportTerminated(); verify(transportListener, never()).transportReady(any(Attributes.class)); @@ -296,7 +299,7 @@ public void clientCloseDuringHandshake() throws Exception { @Test public void closeDuringHttp2Preface() throws Exception { initTransport(); - socket.close(); + socketPair.client.close(); verify(transportListener, timeout(TIME_OUT_MS)).transportTerminated(); verify(transportListener, never()).transportReady(any(Attributes.class)); @@ -307,7 +310,7 @@ public void noSettingsDuringHttp2HandshakeSettings() throws Exception { initTransport(); clientFrameWriter.connectionPreface(); clientFrameWriter.flush(); - socket.close(); + socketPair.client.close(); verify(transportListener, timeout(TIME_OUT_MS)).transportTerminated(); verify(transportListener, never()).transportReady(any(Attributes.class)); @@ -329,7 +332,7 @@ public void startThenClientDisconnect() throws Exception { initTransport(); handshake(); - socket.closeSourceAndSink(); + socketPair.client.close(); verify(transportListener, timeout(TIME_OUT_MS)).transportTerminated(); } @@ -1086,8 +1089,8 @@ public void channelzStats() throws Exception { assertThat(stats.data.messagesReceived).isEqualTo(0); assertThat(stats.data.remoteFlowControlWindow).isEqualTo(30000); // Lower bound assertThat(stats.data.localFlowControlWindow).isEqualTo(66535); - assertThat(stats.local).isEqualTo(new InetSocketAddress("127.0.0.1", 4000)); - assertThat(stats.remote).isEqualTo(new InetSocketAddress("127.0.0.2", 5000)); + assertThat(stats.local).isEqualTo(socketPair.server.getLocalSocketAddress()); + assertThat(stats.remote).isEqualTo(socketPair.server.getRemoteSocketAddress()); } @Test @@ -1188,7 +1191,7 @@ public void keepAliveEnforcer_noticesActive() throws Exception { private void initTransport() throws Exception { serverTransport = new OkHttpServerTransport( new OkHttpServerTransport.Config(serverBuilder, Arrays.asList()), - socket); + socketPair.server); serverTransport.start(transportListener); } @@ -1357,61 +1360,44 @@ static String getContent(InputStream message) throws IOException { } } - private static class PipeSocket extends Socket { - private final PipedOutputStream outputStream = new PipedOutputStream(); - private final PipedInputStream outputStreamSink = new PipedInputStream(); - private final PipedOutputStream inputStreamSource = new PipedOutputStream(); - private final PipedInputStream inputStream = new PipedInputStream(); + private static class SocketPair { + public final Socket client; + public final Socket server; + + public SocketPair(Socket client, Socket server) { + this.client = client; + this.server = server; + } - public PipeSocket() { + public InputStream getClientInputStream() { try { - outputStreamSink.connect(outputStream); - inputStream.connect(inputStreamSource); + return client.getInputStream(); } catch (IOException ex) { - throw new AssertionError(ex); + throw new RuntimeException(ex); } } - @Override - public synchronized void close() throws IOException { + public OutputStream getClientOutputStream() { try { - outputStream.close(); - } finally { - inputStream.close(); - // PipedInputStream can only be woken by PipedOutputStream, so PipedOutputStream.close() is - // a better imitation of Socket.close(). - inputStreamSource.close(); - super.close(); + return client.getOutputStream(); + } catch (IOException ex) { + throw new RuntimeException(ex); } } - public void closeSourceAndSink() throws IOException { + public static SocketPair create(ExecutorService threadPool) { try { - outputStreamSink.close(); - } finally { - inputStreamSource.close(); + try (ServerSocket serverSocket = new ServerSocket(0)) { + Future serverFuture = threadPool.submit(() -> serverSocket.accept()); + Socket client = new Socket(); + client.connect(serverSocket.getLocalSocketAddress()); + Socket server = serverFuture.get(); + return new SocketPair(client, server); + } + } catch (Exception ex) { + throw new RuntimeException(ex); } } - - @Override - public SocketAddress getLocalSocketAddress() { - return new InetSocketAddress("127.0.0.1", 4000); - } - - @Override - public SocketAddress getRemoteSocketAddress() { - return new InetSocketAddress("127.0.0.2", 5000); - } - - @Override - public OutputStream getOutputStream() { - return outputStream; - } - - @Override - public InputStream getInputStream() { - return inputStream; - } } private interface DataFrameHandler {