diff --git a/rpc/level0_test.go b/rpc/level0_test.go index a1f0affc..7fbeac6d 100644 --- a/rpc/level0_test.go +++ b/rpc/level0_test.go @@ -918,11 +918,7 @@ func TestRecvBootstrapCall(t *testing.T) { }) defer func() { finishTest(t, conn, p2) - select { - case <-srvShutdown: - default: - t.Error("Bootstrap client still alive after Close returned") - } + <-srvShutdown // Hangs if bootstrap client is never shut down. }() ctx := context.Background() @@ -1233,11 +1229,7 @@ func TestRecvBootstrapPipelineCall(t *testing.T) { }) defer func() { finishTest(t, conn, p2) - select { - case <-srvShutdown: - default: - t.Error("Bootstrap client still alive after Close returned") - } + <-srvShutdown // Will hang if closing does not shut down the client. }() ctx := context.Background() @@ -1699,13 +1691,9 @@ func TestRecvCancel(t *testing.T) { } // 8. Verify that returned capability was shut down. - // There's no guarantee when the release/shutdown will happen, other - // than it will be released before Close returns. - select { - case <-retcapShutdown: - default: - t.Error("returned capability was not shut down") - } + // There's no guarantee exactly when the release/shutdown will happen, + // but Close should trigger it. Otherwise, this will hang: + <-retcapShutdown } // TestSendCancel makes a call, cancels the Context, then checks to diff --git a/server/server.go b/server/server.go index 50548356..04f023c9 100644 --- a/server/server.go +++ b/server/server.go @@ -69,7 +69,7 @@ func (c *Call) Go() { return } c.acked = true - go c.srv.handleCalls(c.srv.handleCallsCtx) + go c.srv.handleCalls() } // Shutdowner is the interface that wraps the Shutdown method. @@ -84,15 +84,6 @@ type Server struct { brand any shutdown Shutdowner - // Cancels handleCallsCtx - cancelHandleCalls context.CancelFunc - - // Context used by the goroutine running handleCalls(). Note - // the calls themselves will have different contexts, which - // are not children of this context, but are supplied by - // start(). - handleCallsCtx context.Context - // wg is incremented each time a method is queued, and // decremented after it is handled. wg sync.WaitGroup @@ -114,19 +105,15 @@ func (s *Server) String() string { // guarantees message delivery order by blocking each call on the // return of the previous call or a call to Call.Go. func New(methods []Method, brand any, shutdown Shutdowner) *Server { - ctx, cancel := context.WithCancel(context.Background()) - srv := &Server{ - methods: make(sortedMethods, len(methods)), - brand: brand, - shutdown: shutdown, - callQueue: mpsc.New[*Call](), - cancelHandleCalls: cancel, - handleCallsCtx: ctx, + methods: make(sortedMethods, len(methods)), + brand: brand, + shutdown: shutdown, + callQueue: mpsc.New[*Call](), } copy(srv.methods, methods) sort.Sort(srv.methods) - go srv.handleCalls(ctx) + go srv.handleCalls() return srv } @@ -171,38 +158,20 @@ func (srv *Server) Recv(ctx context.Context, r capnp.Recv) capnp.PipelineCaller return srv.start(ctx, mm, r) } -func (srv *Server) handleCalls(ctx context.Context) { +func (srv *Server) handleCalls() { + ctx := context.Background() for { call, err := srv.callQueue.Recv(ctx) if err != nil { - // Context has been canceled; drain the rest of the queue, - // invoking handleCall() with the cancelled context to - // trigger cleanup. - var ok bool - call, ok = srv.callQueue.TryRecv() - if !ok { - return + // Queue closed; wait for outstanding calls and shut down. + if srv.shutdown != nil { + srv.wg.Wait() + srv.shutdown.Shutdown() } + return } - // The context for the individual call is not necessarily - // related to the context managing the server's lifetime - // (ctx); we need to monitor both and pass the call a - // context that will be canceled if *either* context is - // cancelled. - callCtx, cancelCall := context.WithCancel(call.ctx) - go func() { - defer cancelCall() - select { - case <-callCtx.Done(): - case <-ctx.Done(): - } - }() - func() { - defer cancelCall() - srv.handleCall(callCtx, call) - }() - + srv.handleCall(call) if call.acked { // Another goroutine has taken over; time // to retire. @@ -211,10 +180,10 @@ func (srv *Server) handleCalls(ctx context.Context) { } } -func (srv *Server) handleCall(ctx context.Context, c *Call) { +func (srv *Server) handleCall(c *Call) { defer srv.wg.Done() - err := c.method.Impl(ctx, c) + err := c.method.Impl(c.ctx, c) c.recv.ReleaseArgs() c.recv.Returner.PrepareReturn(err) @@ -246,15 +215,11 @@ func (srv *Server) Brand() capnp.Brand { return capnp.Brand{Value: serverBrand{srv.brand}} } -// Shutdown waits for ongoing calls to finish and calls Shutdown on the -// Shutdowner passed into NewServer. Shutdown must not be called more -// than once. +// Shutdown arranges for Shutdown to be called on the Shutdowner passed +// into NewServer after outstanding all calls have been serviced. +// Shutdown must not be called more than once. func (srv *Server) Shutdown() { - srv.cancelHandleCalls() - srv.wg.Wait() - if srv.shutdown != nil { - srv.shutdown.Shutdown() - } + srv.callQueue.Close() } // IsServer reports whether a brand returned by capnp.Client.Brand diff --git a/server/server_test.go b/server/server_test.go index 84ed89e3..fa3fc69d 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -7,6 +7,7 @@ import ( "sync" "sync/atomic" "testing" + "time" "github.com/stretchr/testify/require" @@ -198,17 +199,23 @@ func TestServerShutdown(t *testing.T) { echo := air.Echo_ServerToClient(blockingEchoImpl{wait}) defer echo.Release() ctx, cancel := context.WithCancel(context.Background()) - defer cancel() call, finish := echo.Echo(ctx, nil) defer finish() echo.Release() + + // Even though we've dropped the client, existing calls should + // still go through: select { case <-call.Done(): - if _, err := call.Struct(); err == nil { - t.Error("call finished without error") - } - default: - t.Error("call not done after Shutdown") + t.Error("call finished before cancel()") + case <-time.After(10 * time.Millisecond): + } + + cancel() + <-call.Done() // Will hang if cancel doesn't stop the call. + + if _, err := call.Struct(); err == nil { + t.Error("call finished without error") } }