Skip to content

Commit 9c2cd38

Browse files
credentials/xds: fix goroutine leak in testServer (#8699)
Fixes #8694 This PR fixes a goroutine leak in `credentials/xds/xds_client_test.go`. Previously, the `testServer` used standard `Send()` calls . If a test timed out or failed before reading the expected value, the `testServer` goroutine would block indefinitely on the channel, causing a leak. Replaced blocking `Send` calls with `SendContext` in `handleConn`. This ensures that if the test ends (canceling the context), the `testServer` stops trying to send and exits its goroutine gracefully. RELEASE NOTES: None
1 parent db4cc9f commit 9c2cd38

File tree

2 files changed

+34
-32
lines changed

2 files changed

+34
-32
lines changed

credentials/xds/xds_client_test.go

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -95,40 +95,40 @@ type testHandshakeFunc func(net.Conn) handshakeResult
9595
// newTestServerWithHandshakeFunc starts a new testServer which listens for
9696
// connections on a local TCP port, and uses the provided custom handshake
9797
// function to perform TLS handshake.
98-
func newTestServerWithHandshakeFunc(f testHandshakeFunc) *testServer {
98+
func newTestServerWithHandshakeFunc(ctx context.Context, f testHandshakeFunc) *testServer {
9999
ts := &testServer{
100100
handshakeFunc: f,
101101
hsResult: testutils.NewChannel(),
102102
}
103-
ts.start()
103+
ts.start(ctx)
104104
return ts
105105
}
106106

107107
// starts actually starts listening on a local TCP port, and spawns a goroutine
108108
// to handle new connections.
109-
func (ts *testServer) start() error {
109+
func (ts *testServer) start(ctx context.Context) error {
110110
lis, err := net.Listen("tcp", "localhost:0")
111111
if err != nil {
112112
return err
113113
}
114114
ts.lis = lis
115115
ts.address = lis.Addr().String()
116-
go ts.handleConn()
116+
go ts.handleConn(ctx)
117117
return nil
118118
}
119119

120120
// handleConn accepts a new raw connection, and invokes the test provided
121121
// handshake function to perform TLS handshake, and returns the result on the
122122
// `hsResult` channel.
123-
func (ts *testServer) handleConn() {
123+
func (ts *testServer) handleConn(ctx context.Context) {
124124
for {
125125
rawConn, err := ts.lis.Accept()
126126
if err != nil {
127127
// Once the listeners closed, Accept() will return with an error.
128128
return
129129
}
130130
hsr := ts.handshakeFunc(rawConn)
131-
ts.hsResult.Send(hsr)
131+
ts.hsResult.SendContext(ctx, hsr)
132132
}
133133
}
134134

@@ -388,7 +388,9 @@ func (s) TestClientCredsSuccess(t *testing.T) {
388388

389389
for _, test := range tests {
390390
t.Run(test.desc, func(t *testing.T) {
391-
ts := newTestServerWithHandshakeFunc(test.handshakeFunc)
391+
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
392+
defer cancel()
393+
ts := newTestServerWithHandshakeFunc(ctx, test.handshakeFunc)
392394
defer ts.stop()
393395

394396
opts := ClientOptions{FallbackCreds: makeFallbackClientCreds(t)}
@@ -403,8 +405,6 @@ func (s) TestClientCredsSuccess(t *testing.T) {
403405
}
404406
defer conn.Close()
405407

406-
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
407-
defer cancel()
408408
_, ai, err := creds.ClientHandshake(test.handshakeInfoCtx(ctx), authority, conn)
409409
if err != nil {
410410
t.Fatalf("ClientHandshake() returned failed: %q", err)
@@ -418,11 +418,13 @@ func (s) TestClientCredsSuccess(t *testing.T) {
418418

419419
func (s) TestClientCredsHandshakeTimeout(t *testing.T) {
420420
clientDone := make(chan struct{})
421+
ctx, sCancel := context.WithTimeout(context.Background(), defaultTestTimeout)
422+
defer sCancel()
421423
// A handshake function which simulates a handshake timeout from the
422424
// server-side by simply blocking on the client-side handshake to timeout
423425
// and not writing any handshake data.
424426
hErr := errors.New("server handshake error")
425-
ts := newTestServerWithHandshakeFunc(func(net.Conn) handshakeResult {
427+
ts := newTestServerWithHandshakeFunc(ctx, func(net.Conn) handshakeResult {
426428
<-clientDone
427429
return handshakeResult{err: hErr}
428430
})
@@ -442,7 +444,7 @@ func (s) TestClientCredsHandshakeTimeout(t *testing.T) {
442444

443445
sCtx, sCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout)
444446
defer sCancel()
445-
ctx := newTestContextWithHandshakeInfo(sCtx, makeRootProvider(t, "x509/server_ca_cert.pem"), nil, defaultTestCertSAN)
447+
ctx = newTestContextWithHandshakeInfo(sCtx, makeRootProvider(t, "x509/server_ca_cert.pem"), nil, defaultTestCertSAN)
446448
if _, _, err := creds.ClientHandshake(ctx, authority, conn); err == nil {
447449
t.Fatal("ClientHandshake() succeeded when expected to timeout")
448450
}
@@ -489,7 +491,9 @@ func (s) TestClientCredsHandshakeFailure(t *testing.T) {
489491

490492
for _, test := range tests {
491493
t.Run(test.desc, func(t *testing.T) {
492-
ts := newTestServerWithHandshakeFunc(test.handshakeFunc)
494+
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
495+
defer cancel()
496+
ts := newTestServerWithHandshakeFunc(ctx, test.handshakeFunc)
493497
defer ts.stop()
494498

495499
opts := ClientOptions{FallbackCreds: makeFallbackClientCreds(t)}
@@ -504,8 +508,6 @@ func (s) TestClientCredsHandshakeFailure(t *testing.T) {
504508
}
505509
defer conn.Close()
506510

507-
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
508-
defer cancel()
509511
ctx = newTestContextWithHandshakeInfo(ctx, test.rootProvider, nil, test.san)
510512
if _, _, err := creds.ClientHandshake(ctx, authority, conn); err == nil || !strings.Contains(err.Error(), test.wantErr) {
511513
t.Fatalf("ClientHandshake() returned %q, wantErr %q", err, test.wantErr)
@@ -520,7 +522,9 @@ func (s) TestClientCredsHandshakeFailure(t *testing.T) {
520522
// approximation of the flow of events when the control plane specifies new
521523
// security config which results in new certificate providers being used.
522524
func (s) TestClientCredsProviderSwitch(t *testing.T) {
523-
ts := newTestServerWithHandshakeFunc(testServerTLSHandshake)
525+
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
526+
defer cancel()
527+
ts := newTestServerWithHandshakeFunc(ctx, testServerTLSHandshake)
524528
defer ts.stop()
525529

526530
opts := ClientOptions{FallbackCreds: makeFallbackClientCreds(t)}
@@ -535,8 +539,6 @@ func (s) TestClientCredsProviderSwitch(t *testing.T) {
535539
}
536540
defer conn.Close()
537541

538-
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
539-
defer cancel()
540542
// Create a root provider which will fail the handshake because it does not
541543
// use the correct trust roots.
542544
root1 := makeRootProvider(t, "x509/client_ca_cert.pem")

credentials/xds/xds_server_test.go

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -178,10 +178,12 @@ func (s) TestServerCredsHandshake_XDSHandshakeInfoError(t *testing.T) {
178178
if err != nil {
179179
t.Fatalf("NewServerCredentials(%v) failed: %v", opts, err)
180180
}
181+
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
182+
defer cancel()
181183

182184
// Create a test server which uses the xDS server credentials created above
183185
// to perform TLS handshake on incoming connections.
184-
ts := newTestServerWithHandshakeFunc(func(rawConn net.Conn) handshakeResult {
186+
ts := newTestServerWithHandshakeFunc(ctx, func(rawConn net.Conn) handshakeResult {
185187
// Create a wrapped conn which returns a nil HandshakeInfo and a non-nil error.
186188
conn := newWrappedConn(rawConn, nil, time.Now().Add(defaultTestTimeout))
187189
hiErr := errors.New("xdsHandshakeInfo error")
@@ -208,8 +210,6 @@ func (s) TestServerCredsHandshake_XDSHandshakeInfoError(t *testing.T) {
208210

209211
// Read handshake result from the testServer which will return an error if
210212
// the handshake succeeded.
211-
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
212-
defer cancel()
213213
val, err := ts.hsResult.Receive(ctx)
214214
if err != nil {
215215
t.Fatalf("testServer failed to return handshake result: %v", err)
@@ -229,10 +229,12 @@ func (s) TestServerCredsHandshakeTimeout(t *testing.T) {
229229
if err != nil {
230230
t.Fatalf("NewServerCredentials(%v) failed: %v", opts, err)
231231
}
232+
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
233+
defer cancel()
232234

233235
// Create a test server which uses the xDS server credentials created above
234236
// to perform TLS handshake on incoming connections.
235-
ts := newTestServerWithHandshakeFunc(func(rawConn net.Conn) handshakeResult {
237+
ts := newTestServerWithHandshakeFunc(ctx, func(rawConn net.Conn) handshakeResult {
236238
hi := xdsinternal.NewHandshakeInfo(makeRootProvider(t, "x509/client_ca_cert.pem"), makeIdentityProvider(t, "x509/server2_cert.pem", "x509/server2_key.pem"), nil, true)
237239

238240
// Create a wrapped conn which can return the HandshakeInfo created
@@ -258,8 +260,6 @@ func (s) TestServerCredsHandshakeTimeout(t *testing.T) {
258260
defer rawConn.Close()
259261

260262
// Read handshake result from the testServer and expect a failure result.
261-
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
262-
defer cancel()
263263
val, err := ts.hsResult.Receive(ctx)
264264
if err != nil {
265265
t.Fatalf("testServer failed to return handshake result: %v", err)
@@ -279,10 +279,12 @@ func (s) TestServerCredsHandshakeFailure(t *testing.T) {
279279
if err != nil {
280280
t.Fatalf("NewServerCredentials(%v) failed: %v", opts, err)
281281
}
282+
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
283+
defer cancel()
282284

283285
// Create a test server which uses the xDS server credentials created above
284286
// to perform TLS handshake on incoming connections.
285-
ts := newTestServerWithHandshakeFunc(func(rawConn net.Conn) handshakeResult {
287+
ts := newTestServerWithHandshakeFunc(ctx, func(rawConn net.Conn) handshakeResult {
286288
// Create a HandshakeInfo which has a root provider which does not match
287289
// the certificate sent by the client.
288290
hi := xdsinternal.NewHandshakeInfo(makeRootProvider(t, "x509/server_ca_cert.pem"), makeIdentityProvider(t, "x509/client2_cert.pem", "x509/client2_key.pem"), nil, true)
@@ -314,8 +316,6 @@ func (s) TestServerCredsHandshakeFailure(t *testing.T) {
314316

315317
// Read handshake result from the testServer which will return an error if
316318
// the handshake succeeded.
317-
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
318-
defer cancel()
319319
val, err := ts.hsResult.Receive(ctx)
320320
if err != nil {
321321
t.Fatalf("testServer failed to return handshake result: %v", err)
@@ -361,10 +361,12 @@ func (s) TestServerCredsHandshakeSuccess(t *testing.T) {
361361
if err != nil {
362362
t.Fatalf("NewServerCredentials(%v) failed: %v", opts, err)
363363
}
364+
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
365+
defer cancel()
364366

365367
// Create a test server which uses the xDS server credentials
366368
// created above to perform TLS handshake on incoming connections.
367-
ts := newTestServerWithHandshakeFunc(func(rawConn net.Conn) handshakeResult {
369+
ts := newTestServerWithHandshakeFunc(ctx, func(rawConn net.Conn) handshakeResult {
368370
// Create a HandshakeInfo with information from the test table.
369371
hi := xdsinternal.NewHandshakeInfo(test.rootProvider, test.identityProvider, nil, test.requireClientCert)
370372

@@ -406,8 +408,6 @@ func (s) TestServerCredsHandshakeSuccess(t *testing.T) {
406408
// Read the handshake result from the testServer which contains the
407409
// TLS connection state on the server-side and compare it with the
408410
// one received on the client-side.
409-
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
410-
defer cancel()
411411
val, err := ts.hsResult.Receive(ctx)
412412
if err != nil {
413413
t.Fatalf("testServer failed to return handshake result: %v", err)
@@ -433,14 +433,16 @@ func (s) TestServerCredsProviderSwitch(t *testing.T) {
433433
if err != nil {
434434
t.Fatalf("NewServerCredentials(%v) failed: %v", opts, err)
435435
}
436+
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
437+
defer cancel()
436438

437439
// The first time the handshake function is invoked, it returns a
438440
// HandshakeInfo which is expected to fail. Further invocations return a
439441
// HandshakeInfo which is expected to succeed.
440442
cnt := 0
441443
// Create a test server which uses the xDS server credentials created above
442444
// to perform TLS handshake on incoming connections.
443-
ts := newTestServerWithHandshakeFunc(func(rawConn net.Conn) handshakeResult {
445+
ts := newTestServerWithHandshakeFunc(ctx, func(rawConn net.Conn) handshakeResult {
444446
cnt++
445447
var hi *xdsinternal.HandshakeInfo
446448
if cnt == 1 {
@@ -501,8 +503,6 @@ func (s) TestServerCredsProviderSwitch(t *testing.T) {
501503
// Read the handshake result from the testServer which contains the
502504
// TLS connection state on the server-side and compare it with the
503505
// one received on the client-side.
504-
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
505-
defer cancel()
506506
val, err := ts.hsResult.Receive(ctx)
507507
if err != nil {
508508
t.Fatalf("testServer failed to return handshake result: %v", err)

0 commit comments

Comments
 (0)