diff --git a/httpserver/src/main/java/io/esastack/httpserver/HttpServer.java b/httpserver/src/main/java/io/esastack/httpserver/HttpServer.java index 90c9dbc..571ce13 100644 --- a/httpserver/src/main/java/io/esastack/httpserver/HttpServer.java +++ b/httpserver/src/main/java/io/esastack/httpserver/HttpServer.java @@ -19,7 +19,6 @@ import io.esastack.httpserver.impl.HttpServerImpl; import io.esastack.httpserver.metrics.Metrics; import io.netty.channel.Channel; -import io.netty.channel.ChannelHandlerContext; import io.netty.channel.EventLoopGroup; import io.netty.util.concurrent.Future; @@ -78,6 +77,15 @@ static HttpServer create(String name, ServerOptions options) { */ HttpServer handle(Consumer h); + /** + * Sets the handler for listening connection init. + * + * @param h handler + * + * @return this + */ + HttpServer onConnectionInit(Consumer h); + /** * Sets the handler for listening connection connected. * @@ -85,7 +93,7 @@ static HttpServer create(String name, ServerOptions options) { * * @return this */ - HttpServer onConnected(Consumer h); + HttpServer onConnected(Consumer h); /** * Sets the handler for listening connection disconnected. diff --git a/httpserver/src/main/java/io/esastack/httpserver/impl/HttpServerChannelInitializr.java b/httpserver/src/main/java/io/esastack/httpserver/impl/HttpServerChannelInitializr.java index a84e2ff..b76dcf7 100644 --- a/httpserver/src/main/java/io/esastack/httpserver/impl/HttpServerChannelInitializr.java +++ b/httpserver/src/main/java/io/esastack/httpserver/impl/HttpServerChannelInitializr.java @@ -53,17 +53,20 @@ final class HttpServerChannelInitializr extends ChannelInitializer { private final ServerRuntime runtime; private final SslHelper sslHelper; private final Consumer handler; - private final Consumer onConnected; + private final Consumer onConnectionInit; + private final Consumer onConnected; private final Consumer onDisconnected; HttpServerChannelInitializr(ServerRuntime runtime, SslHelper sslHelper, Consumer handler, - Consumer onConnected, + Consumer onConnectionInit, + Consumer onConnected, Consumer onDisconnected) { this.runtime = runtime; this.sslHelper = sslHelper; this.handler = handler; + this.onConnectionInit = onConnectionInit; this.onConnected = onConnected; this.onDisconnected = onDisconnected; } @@ -76,6 +79,14 @@ protected void initChannel(Channel ch) { return; } + if (onConnectionInit != null) { + try { + onConnectionInit.accept(ch); + } catch (Throwable t) { + Loggers.logger().warn("Error while processing onConnectionInit()", t); + } + } + final ChannelPipeline pipeline = ch.pipeline(); // options for each accepted channel applyChannelOptions(ch); @@ -91,7 +102,7 @@ protected void initChannel(Channel ch) { } if (onConnected != null) { try { - onConnected.accept(ctx); + onConnected.accept(ch); } catch (Throwable t) { Loggers.logger().warn("Error while processing onConnected()", t); } diff --git a/httpserver/src/main/java/io/esastack/httpserver/impl/HttpServerImpl.java b/httpserver/src/main/java/io/esastack/httpserver/impl/HttpServerImpl.java index 28e629c..0823f52 100644 --- a/httpserver/src/main/java/io/esastack/httpserver/impl/HttpServerImpl.java +++ b/httpserver/src/main/java/io/esastack/httpserver/impl/HttpServerImpl.java @@ -51,7 +51,8 @@ public class HttpServerImpl implements HttpServer { private final ServerRuntime runtime; private Consumer handler; - private Consumer onConnected; + private Consumer onConnectionInit; + private Consumer onConnected; private Consumer onDisconnected; private final CloseFuture closeFuture = new CloseFuture(); private final CopyOnWriteArrayList closures = new CopyOnWriteArrayList<>(); @@ -78,7 +79,14 @@ public synchronized HttpServerImpl handle(Consumer h) { } @Override - public synchronized HttpServerImpl onConnected(Consumer h) { + public HttpServer onConnectionInit(Consumer h) { + checkStarted(); + this.onConnectionInit = h; + return this; + } + + @Override + public synchronized HttpServerImpl onConnected(Consumer h) { checkStarted(); this.onConnected = h; return this; @@ -140,7 +148,7 @@ public Future closeFuture() { private synchronized HttpServerImpl listen0(SocketAddress address) { checkStarted(); Checks.checkNotNull(address, "address"); - Checks.checkNotNull(handler, "Request handler required"); + Checks.checkNotNull(handler, "Request handler required. Set it by HttpServer.handle(xxx)"); final ServerBootstrap bootstrap = new ServerBootstrap(); final Transport transport = Transports.transport(options().isPreferNativeTransport()); @@ -153,6 +161,7 @@ private synchronized HttpServerImpl listen0(SocketAddress address) { bootstrap.childHandler(new HttpServerChannelInitializr(runtime, sslHelper, handler, + onConnectionInit, onConnected, onDisconnected)); final EventLoopGroup bossGroup = transport.loop(options().getBossThreads(), diff --git a/httpserver/src/test/java/io/esastack/httpserver/impl/HttpServerChannelInitializrTest.java b/httpserver/src/test/java/io/esastack/httpserver/impl/HttpServerChannelInitializrTest.java index ad8bd65..f6345d4 100644 --- a/httpserver/src/test/java/io/esastack/httpserver/impl/HttpServerChannelInitializrTest.java +++ b/httpserver/src/test/java/io/esastack/httpserver/impl/HttpServerChannelInitializrTest.java @@ -67,6 +67,7 @@ void testCloseChannelIfServerIsShutdown() { new SslHelper(null, false), r -> r.response().end(), null, + null, null); final EmbeddedChannel channel = new EmbeddedChannel(initializr); assertFalse(channel.isActive()); @@ -81,6 +82,7 @@ void testApplyWriteBufferHighWaterMark() { new SslHelper(null, false), r -> r.response().end(), null, + null, null); final EmbeddedChannel channel = new EmbeddedChannel(initializr); assertEquals(WriteBufferWaterMark.DEFAULT.high() + 1, @@ -96,6 +98,7 @@ void testApplyWriteBufferLowWaterMark() { new SslHelper(null, false), r -> r.response().end(), null, + null, null); final EmbeddedChannel channel = new EmbeddedChannel(initializr); assertEquals(WriteBufferWaterMark.DEFAULT.low() - 1, @@ -112,6 +115,7 @@ void testApplyWriteBufferWaterMark() { new SslHelper(null, false), r -> r.response().end(), null, + null, null); final EmbeddedChannel channel = new EmbeddedChannel(initializr); assertEquals(WriteBufferWaterMark.DEFAULT.low() - 1, @@ -129,6 +133,7 @@ void testHAProxyDetectorInitialization() { new SslHelper(null, false), r -> r.response().end(), null, + null, null); final EmbeddedChannel channel = new EmbeddedChannel(initializr); assertNotNull(channel.pipeline().get(HAProxyDetector.class)); @@ -144,6 +149,7 @@ void testHAProxyDecoderInitialization() { new SslHelper(null, false), r -> r.response().end(), null, + null, null); final EmbeddedChannel channel = new EmbeddedChannel(initializr); assertNotNull(channel.pipeline().get(HAProxyMessageDecoder.class)); @@ -162,6 +168,7 @@ void testHAProxyOffInitialization() { new SslHelper(null, false), r -> r.response().end(), null, + null, null); final EmbeddedChannel channel = new EmbeddedChannel(initializr); assertNull(channel.pipeline().get(HAProxyMessageDecoder.class)); @@ -169,7 +176,8 @@ void testHAProxyOffInitialization() { } @Test - void testChannelActiveAndInactiveListener() { + void testChannelInitAndActiveAndInactiveListener() { + final AtomicBoolean init = new AtomicBoolean(); final AtomicBoolean active = new AtomicBoolean(); final AtomicBoolean inActive = new AtomicBoolean(); final ServerRuntime runtime = Helper.serverRuntime(ServerOptionsConfigure.newOpts() @@ -179,11 +187,13 @@ void testChannelActiveAndInactiveListener() { new HttpServerChannelInitializr(runtime, new SslHelper(null, false), r -> r.response().end(), - ctx -> active.set(true), + c -> init.set(true), + c -> active.set(true), c -> inActive.set(true)); final EmbeddedChannel channel = new EmbeddedChannel(initializr); + assertTrue(init.get()); assertTrue(active.get()); assertFalse(inActive.get()); @@ -203,6 +213,7 @@ void testChannelActiveAndInactiveListener() { @Test void testErrorOccurredInConnectionHandlerShouldBeIgnored() { + final AtomicBoolean init = new AtomicBoolean(); final AtomicBoolean active = new AtomicBoolean(); final AtomicBoolean inActive = new AtomicBoolean(); final ServerRuntime runtime = Helper.serverRuntime(ServerOptionsConfigure.newOpts() @@ -212,7 +223,11 @@ void testErrorOccurredInConnectionHandlerShouldBeIgnored() { new HttpServerChannelInitializr(runtime, new SslHelper(null, false), r -> r.response().end(), - ctx -> { + c -> { + init.set(true); + ExceptionUtils.throwException(new IllegalStateException()); + }, + c -> { active.set(true); ExceptionUtils.throwException(new IllegalStateException()); }, @@ -223,6 +238,7 @@ void testErrorOccurredInConnectionHandlerShouldBeIgnored() { final EmbeddedChannel channel = new EmbeddedChannel(initializr); + assertTrue(init.get()); assertTrue(active.get()); assertFalse(inActive.get()); channel.close(); @@ -426,6 +442,7 @@ private static void testUnsupportedUpgradeProtocol(boolean logging, new SslHelper(null, false), r -> r.response().end(), null, + null, null); final EmbeddedChannel channel = new EmbeddedChannel(initializr); @@ -502,6 +519,7 @@ private static void testUpgradeError(boolean logging, }); }, null, + null, null); final EmbeddedChannel channel = new EmbeddedChannel(initializr); @@ -573,6 +591,7 @@ private static void testUpgrade(boolean logging, new SslHelper(null, false), r -> r.response().end(), null, + null, null); final EmbeddedChannel channel = new EmbeddedChannel(initializr); @@ -637,6 +656,7 @@ private static void testHttp1HandlerInitialization(boolean logging, new SslHelper(null, false), r -> r.response().end(), null, + null, null); final EmbeddedChannel channel = new EmbeddedChannel(initializr); @@ -681,6 +701,7 @@ private static void testSslDetectFailedWithHttp2Disabled(boolean logging, new SslHelper(runtime.options().getSsl(), false), r -> r.response().end(), null, + null, null); final EmbeddedChannel channel = new EmbeddedChannel(initializr); @@ -737,6 +758,7 @@ private static void testSslDetectFailedWithHttp2Enabled(boolean logging, new SslHelper(runtime.options().getSsl(), false), r -> r.response().end(), null, + null, null); final EmbeddedChannel channel = new EmbeddedChannel(initializr); @@ -805,6 +827,7 @@ private static void testH1HandlerInitializationAfterH2cDetector(boolean logging, new SslHelper(null, false), r -> r.response().end(), null, + null, null); final EmbeddedChannel channel = new EmbeddedChannel(initializr); @@ -869,6 +892,7 @@ private static void testH2cHandlerInitializationAfterH2cDetector(boolean logging new SslHelper(null, false), r -> r.response().end(), null, + null, null); final EmbeddedChannel channel = new EmbeddedChannel(initializr); diff --git a/httpserver/src/test/java/io/esastack/httpserver/impl/HttpServerImplTest.java b/httpserver/src/test/java/io/esastack/httpserver/impl/HttpServerImplTest.java index eef83e3..a936e91 100644 --- a/httpserver/src/test/java/io/esastack/httpserver/impl/HttpServerImplTest.java +++ b/httpserver/src/test/java/io/esastack/httpserver/impl/HttpServerImplTest.java @@ -51,11 +51,13 @@ void testStatus() throws InterruptedException { final AtomicBoolean onClose = new AtomicBoolean(); + final AtomicBoolean onConnectionInit = new AtomicBoolean(); final AtomicBoolean onConnected = new AtomicBoolean(); final AtomicBoolean onDisconnected = new AtomicBoolean(); final AtomicBoolean onHandle = new AtomicBoolean(); final CompletableFuture closed = new CompletableFuture<>(); assertDoesNotThrow(() -> server.onClose(() -> onClose.set(true))); + assertDoesNotThrow(() -> server.onConnectionInit(ctx -> onConnectionInit.set(true))); assertDoesNotThrow(() -> server.onConnected(ctx -> onConnected.set(true))); assertDoesNotThrow(() -> server.onDisconnected(ctx -> onDisconnected.set(true))); assertDoesNotThrow(() -> server.handle(ctx -> onHandle.set(true))); @@ -100,10 +102,12 @@ void testStatus() throws InterruptedException { assertNotNull(server.bossGroup()); assertNotNull(server.ioGroup()); + assertThrows(IllegalStateException.class, () -> server.onConnectionInit(ctx -> onConnected.set(true))); assertThrows(IllegalStateException.class, () -> server.onConnected(ctx -> onConnected.set(true))); assertThrows(IllegalStateException.class, () -> server.onDisconnected(ctx -> onDisconnected.set(true))); assertThrows(IllegalStateException.class, () -> server.handle(ctx -> onHandle.set(true))); + assertFalse(onConnectionInit.get()); assertFalse(onConnected.get()); assertFalse(onDisconnected.get()); assertFalse(onHandle.get());