Skip to content

Commit 4f6d8a5

Browse files
committed
net/rpc: wait for responses to be written before closing Codec
If there are no more requests being made, wait to shut down the response-writing codec until the pending requests are all answered. Fixes #17239. Change-Id: Ie62c63ada536171df4e70b73c95f98f778069972 Reviewed-on: https://go-review.googlesource.com/79515 Run-TryBot: Russ Cox <[email protected]> TryBot-Result: Gobot Gobot <[email protected]> Reviewed-by: Rob Pike <[email protected]>
1 parent 70ee9b4 commit 4f6d8a5

File tree

2 files changed

+63
-3
lines changed

2 files changed

+63
-3
lines changed

src/net/rpc/server.go

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -372,7 +372,10 @@ func (m *methodType) NumCalls() (n uint) {
372372
return n
373373
}
374374

375-
func (s *service) call(server *Server, sending *sync.Mutex, mtype *methodType, req *Request, argv, replyv reflect.Value, codec ServerCodec) {
375+
func (s *service) call(server *Server, sending *sync.Mutex, wg *sync.WaitGroup, mtype *methodType, req *Request, argv, replyv reflect.Value, codec ServerCodec) {
376+
if wg != nil {
377+
defer wg.Done()
378+
}
376379
mtype.Lock()
377380
mtype.numCalls++
378381
mtype.Unlock()
@@ -456,6 +459,7 @@ func (server *Server) ServeConn(conn io.ReadWriteCloser) {
456459
// decode requests and encode responses.
457460
func (server *Server) ServeCodec(codec ServerCodec) {
458461
sending := new(sync.Mutex)
462+
wg := new(sync.WaitGroup)
459463
for {
460464
service, mtype, req, argv, replyv, keepReading, err := server.readRequest(codec)
461465
if err != nil {
@@ -472,8 +476,12 @@ func (server *Server) ServeCodec(codec ServerCodec) {
472476
}
473477
continue
474478
}
475-
go service.call(server, sending, mtype, req, argv, replyv, codec)
479+
wg.Add(1)
480+
go service.call(server, sending, wg, mtype, req, argv, replyv, codec)
476481
}
482+
// We've seen that there are no more requests.
483+
// Wait for responses to be sent before closing codec.
484+
wg.Wait()
477485
codec.Close()
478486
}
479487

@@ -493,7 +501,7 @@ func (server *Server) ServeRequest(codec ServerCodec) error {
493501
}
494502
return err
495503
}
496-
service.call(server, sending, mtype, req, argv, replyv, codec)
504+
service.call(server, sending, nil, mtype, req, argv, replyv, codec)
497505
return nil
498506
}
499507

src/net/rpc/server_test.go

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,11 @@ func (t *Arith) Error(args *Args, reply *Reply) error {
7575
panic("ERROR")
7676
}
7777

78+
func (t *Arith) SleepMilli(args *Args, reply *Reply) error {
79+
time.Sleep(time.Duration(args.A) * time.Millisecond)
80+
return nil
81+
}
82+
7883
type hidden int
7984

8085
func (t *hidden) Exported(args Args, reply *Reply) error {
@@ -693,6 +698,53 @@ func TestAcceptExitAfterListenerClose(t *testing.T) {
693698
newServer.Accept(l)
694699
}
695700

701+
func TestShutdown(t *testing.T) {
702+
var l net.Listener
703+
l, _ = listenTCP()
704+
ch := make(chan net.Conn, 1)
705+
go func() {
706+
defer l.Close()
707+
c, err := l.Accept()
708+
if err != nil {
709+
t.Error(err)
710+
}
711+
ch <- c
712+
}()
713+
c, err := net.Dial("tcp", l.Addr().String())
714+
if err != nil {
715+
t.Fatal(err)
716+
}
717+
c1 := <-ch
718+
if c1 == nil {
719+
t.Fatal(err)
720+
}
721+
722+
newServer := NewServer()
723+
newServer.Register(new(Arith))
724+
go newServer.ServeConn(c1)
725+
726+
args := &Args{7, 8}
727+
reply := new(Reply)
728+
client := NewClient(c)
729+
err = client.Call("Arith.Add", args, reply)
730+
if err != nil {
731+
t.Fatal(err)
732+
}
733+
734+
// On an unloaded system 10ms is usually enough to fail 100% of the time
735+
// with a broken server. On a loaded system, a broken server might incorrectly
736+
// be reported as passing, but we're OK with that kind of flakiness.
737+
// If the code is correct, this test will never fail, regardless of timeout.
738+
args.A = 10 // 10 ms
739+
done := make(chan *Call, 1)
740+
call := client.Go("Arith.SleepMilli", args, reply, done)
741+
c.(*net.TCPConn).CloseWrite()
742+
<-done
743+
if call.Error != nil {
744+
t.Fatal(err)
745+
}
746+
}
747+
696748
func benchmarkEndToEnd(dial func() (*Client, error), b *testing.B) {
697749
once.Do(startServer)
698750
client, err := dial()

0 commit comments

Comments
 (0)