Skip to content

Commit 8719ce7

Browse files
committed
Remove contexts without timeouts
1 parent 9afb49d commit 8719ce7

File tree

13 files changed

+167
-106
lines changed

13 files changed

+167
-106
lines changed

credentials/google/google_test.go

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import (
2222
"context"
2323
"net"
2424
"testing"
25+
"time"
2526

2627
"github.com/google/go-cmp/cmp"
2728
"google.golang.org/grpc/credentials"
@@ -31,6 +32,8 @@ import (
3132
"google.golang.org/grpc/resolver"
3233
)
3334

35+
var defaultTestTimeout = 10 * time.Second
36+
3437
type s struct {
3538
grpctest.Tester
3639
}
@@ -103,6 +106,8 @@ func overrideNewCredsFuncs() func() {
103106
// modes), ClientHandshake does either tls or alts base on the cluster name in
104107
// attributes.
105108
func (s) TestClientHandshakeBasedOnClusterName(t *testing.T) {
109+
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
110+
defer cancel()
106111
defer overrideNewCredsFuncs()()
107112
for bundleTyp, tc := range map[string]credentials.Bundle{
108113
"defaultCredsWithOptions": NewDefaultCredentialsWithOptions(DefaultCredentialsOptions{}),
@@ -116,36 +121,36 @@ func (s) TestClientHandshakeBasedOnClusterName(t *testing.T) {
116121
}{
117122
{
118123
name: "no cluster name",
119-
ctx: context.Background(),
124+
ctx: ctx,
120125
wantTyp: "tls",
121126
},
122127
{
123128
name: "with non-CFE cluster name",
124-
ctx: icredentials.NewClientHandshakeInfoContext(context.Background(), credentials.ClientHandshakeInfo{
129+
ctx: icredentials.NewClientHandshakeInfoContext(ctx, credentials.ClientHandshakeInfo{
125130
Attributes: xds.SetXDSHandshakeClusterName(resolver.Address{}, "lalala").Attributes,
126131
}),
127132
// non-CFE backends should use alts.
128133
wantTyp: "alts",
129134
},
130135
{
131136
name: "with CFE cluster name",
132-
ctx: icredentials.NewClientHandshakeInfoContext(context.Background(), credentials.ClientHandshakeInfo{
137+
ctx: icredentials.NewClientHandshakeInfoContext(ctx, credentials.ClientHandshakeInfo{
133138
Attributes: xds.SetXDSHandshakeClusterName(resolver.Address{}, "google_cfe_bigtable.googleapis.com").Attributes,
134139
}),
135140
// CFE should use tls.
136141
wantTyp: "tls",
137142
},
138143
{
139144
name: "with xdstp CFE cluster name",
140-
ctx: icredentials.NewClientHandshakeInfoContext(context.Background(), credentials.ClientHandshakeInfo{
145+
ctx: icredentials.NewClientHandshakeInfoContext(ctx, credentials.ClientHandshakeInfo{
141146
Attributes: xds.SetXDSHandshakeClusterName(resolver.Address{}, "xdstp://traffic-director-c2p.xds.googleapis.com/envoy.config.cluster.v3.Cluster/google_cfe_bigtable.googleapis.com").Attributes,
142147
}),
143148
// CFE should use tls.
144149
wantTyp: "tls",
145150
},
146151
{
147152
name: "with xdstp non-CFE cluster name",
148-
ctx: icredentials.NewClientHandshakeInfoContext(context.Background(), credentials.ClientHandshakeInfo{
153+
ctx: icredentials.NewClientHandshakeInfoContext(ctx, credentials.ClientHandshakeInfo{
149154
Attributes: xds.SetXDSHandshakeClusterName(resolver.Address{}, "xdstp://other.com/envoy.config.cluster.v3.Cluster/google_cfe_bigtable.googleapis.com").Attributes,
150155
}),
151156
// non-CFE should use atls.
@@ -176,6 +181,8 @@ func (s) TestClientHandshakeBasedOnClusterName(t *testing.T) {
176181
}
177182

178183
func TestDefaultCredentialsWithOptions(t *testing.T) {
184+
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
185+
defer cancel()
179186
md1 := map[string]string{"foo": "tls"}
180187
md2 := map[string]string{"foo": "alts"}
181188
tests := []struct {
@@ -248,7 +255,7 @@ func TestDefaultCredentialsWithOptions(t *testing.T) {
248255
t.Run(tc.desc, func(t *testing.T) {
249256
bundle := NewDefaultCredentialsWithOptions(tc.defaultCredsOpts)
250257
ri := credentials.RequestInfo{AuthInfo: tc.authInfo}
251-
ctx := icredentials.NewRequestInfoContext(context.Background(), ri)
258+
ctx := icredentials.NewRequestInfoContext(ctx, ri)
252259
got, err := bundle.PerRPCCredentials().GetRequestMetadata(ctx, "uri")
253260
if err != nil {
254261
t.Fatalf("Bundle's PerRPCCredentials().GetRequestMetadata() unexpected error = %v", err)

credentials/google/xds_test.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,10 @@ import (
2929
)
3030

3131
func (s) TestIsDirectPathCluster(t *testing.T) {
32+
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
33+
defer cancel()
3234
c := func(cluster string) context.Context {
33-
return icredentials.NewClientHandshakeInfoContext(context.Background(), credentials.ClientHandshakeInfo{
35+
return icredentials.NewClientHandshakeInfoContext(ctx, credentials.ClientHandshakeInfo{
3436
Attributes: xds.SetXDSHandshakeClusterName(resolver.Address{}, cluster).Attributes,
3537
})
3638
}
@@ -40,7 +42,7 @@ func (s) TestIsDirectPathCluster(t *testing.T) {
4042
ctx context.Context
4143
want bool
4244
}{
43-
{"not an xDS cluster", context.Background(), false},
45+
{"not an xDS cluster", ctx, false},
4446
{"cfe", c("google_cfe_bigtable.googleapis.com"), false},
4547
{"non-cfe", c("google_bigtable.googleapis.com"), true},
4648
{"starts with xdstp but not cfe format", c("xdstp:google_cfe_bigtable.googleapis.com"), true},

internal/transport/handler_server_test.go

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -312,8 +312,10 @@ func (s) TestHandlerTransport_HandleStreams(t *testing.T) {
312312
st.bodyw.Close() // no body
313313
s.WriteStatus(status.New(codes.OK, ""))
314314
}
315+
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
316+
defer cancel()
315317
st.ht.HandleStreams(
316-
context.Background(), func(s *ServerStream) { go handleStream(s) },
318+
ctx, func(s *ServerStream) { go handleStream(s) },
317319
)
318320
wantHeader := http.Header{
319321
"Date": nil,
@@ -345,8 +347,10 @@ func handleStreamCloseBodyTest(t *testing.T, statusCode codes.Code, msg string)
345347
handleStream := func(s *ServerStream) {
346348
s.WriteStatus(status.New(statusCode, msg))
347349
}
350+
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
351+
defer cancel()
348352
st.ht.HandleStreams(
349-
context.Background(), func(s *ServerStream) { go handleStream(s) },
353+
ctx, func(s *ServerStream) { go handleStream(s) },
350354
)
351355
wantHeader := http.Header{
352356
"Date": nil,
@@ -394,8 +398,10 @@ func (s) TestHandlerTransport_HandleStreams_Timeout(t *testing.T) {
394398
}
395399
s.WriteStatus(status.New(codes.DeadlineExceeded, "too slow"))
396400
}
401+
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
402+
defer cancel()
397403
ht.HandleStreams(
398-
context.Background(), func(s *ServerStream) { go runStream(s) },
404+
ctx, func(s *ServerStream) { go runStream(s) },
399405
)
400406
wantHeader := http.Header{
401407
"Date": nil,
@@ -446,8 +452,10 @@ func (s) TestHandlerTransport_HandleStreams_WriteStatusWrite(t *testing.T) {
446452

447453
func testHandlerTransportHandleStreams(t *testing.T, handleStream func(st *handleStreamTest, s *ServerStream)) {
448454
st := newHandleStreamTest(t)
455+
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
456+
defer cancel()
449457
st.ht.HandleStreams(
450-
context.Background(), func(s *ServerStream) { go handleStream(st, s) },
458+
ctx, func(s *ServerStream) { go handleStream(st, s) },
451459
)
452460
}
453461

@@ -479,8 +487,10 @@ func (s) TestHandlerTransport_HandleStreams_ErrDetails(t *testing.T) {
479487
handleStream := func(s *ServerStream) {
480488
s.WriteStatus(st)
481489
}
490+
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
491+
defer cancel()
482492
hst.ht.HandleStreams(
483-
context.Background(), func(s *ServerStream) { go handleStream(s) },
493+
ctx, func(s *ServerStream) { go handleStream(s) },
484494
)
485495
wantHeader := http.Header{
486496
"Date": nil,

internal/transport/transport_test.go

Lines changed: 42 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -381,21 +381,23 @@ func (s *server) start(t *testing.T, port int, serverConfig *ServerConfig, ht hT
381381
h := &testStreamHandler{t: transport.(*http2Server)}
382382
s.h = h
383383
s.mu.Unlock()
384+
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
385+
defer cancel()
384386
switch ht {
385387
case notifyCall:
386-
go transport.HandleStreams(context.Background(), h.handleStreamAndNotify)
388+
go transport.HandleStreams(ctx, h.handleStreamAndNotify)
387389
case suspended:
388-
go transport.HandleStreams(context.Background(), func(*ServerStream) {})
390+
go transport.HandleStreams(ctx, func(*ServerStream) {})
389391
case misbehaved:
390-
go transport.HandleStreams(context.Background(), func(s *ServerStream) {
392+
go transport.HandleStreams(ctx, func(s *ServerStream) {
391393
go h.handleStreamMisbehave(t, s)
392394
})
393395
case encodingRequiredStatus:
394-
go transport.HandleStreams(context.Background(), func(s *ServerStream) {
396+
go transport.HandleStreams(ctx, func(s *ServerStream) {
395397
go h.handleStreamEncodingRequiredStatus(s)
396398
})
397399
case invalidHeaderField:
398-
go transport.HandleStreams(context.Background(), func(s *ServerStream) {
400+
go transport.HandleStreams(ctx, func(s *ServerStream) {
399401
go h.handleStreamInvalidHeaderField(s)
400402
})
401403
case delayRead:
@@ -404,15 +406,15 @@ func (s *server) start(t *testing.T, port int, serverConfig *ServerConfig, ht hT
404406
s.mu.Lock()
405407
close(s.ready)
406408
s.mu.Unlock()
407-
go transport.HandleStreams(context.Background(), func(s *ServerStream) {
409+
go transport.HandleStreams(ctx, func(s *ServerStream) {
408410
go h.handleStreamDelayRead(t, s)
409411
})
410412
case pingpong:
411-
go transport.HandleStreams(context.Background(), func(s *ServerStream) {
413+
go transport.HandleStreams(ctx, func(s *ServerStream) {
412414
go h.handleStreamPingPong(t, s)
413415
})
414416
default:
415-
go transport.HandleStreams(context.Background(), func(s *ServerStream) {
417+
go transport.HandleStreams(ctx, func(s *ServerStream) {
416418
go h.handleStream(t, s)
417419
})
418420
}
@@ -464,13 +466,15 @@ func setUpWithOptions(t *testing.T, port int, sc *ServerConfig, ht hType, copts
464466
addr := resolver.Address{Addr: "localhost:" + server.port}
465467
copts.ChannelzParent = channelzSubChannel(t)
466468

467-
connectCtx, cancel := context.WithDeadline(context.Background(), time.Now().Add(2*time.Second))
468-
ct, connErr := NewHTTP2Client(connectCtx, context.Background(), addr, copts, func(GoAwayReason) {})
469+
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
470+
t.Cleanup(cancel)
471+
connectCtx, cCancel := context.WithTimeout(context.Background(), 2*time.Second)
472+
ct, connErr := NewHTTP2Client(connectCtx, ctx, addr, copts, func(GoAwayReason) {})
469473
if connErr != nil {
470-
cancel() // Do not cancel in success path.
474+
cCancel() // Do not cancel in success path.
471475
t.Fatalf("failed to create transport: %v", connErr)
472476
}
473-
return server, ct.(*http2Client), cancel
477+
return server, ct.(*http2Client), cCancel
474478
}
475479

476480
func setUpWithNoPingServer(t *testing.T, copts ConnectOptions, connCh chan net.Conn) (*http2Client, func()) {
@@ -495,18 +499,20 @@ func setUpWithNoPingServer(t *testing.T, copts ConnectOptions, connCh chan net.C
495499
}
496500
connCh <- conn
497501
}()
498-
connectCtx, cancel := context.WithDeadline(context.Background(), time.Now().Add(2*time.Second))
499-
tr, err := NewHTTP2Client(connectCtx, context.Background(), resolver.Address{Addr: lis.Addr().String()}, copts, func(GoAwayReason) {})
502+
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
503+
t.Cleanup(cancel)
504+
connectCtx, cCancel := context.WithTimeout(context.Background(), 2*time.Second)
505+
tr, err := NewHTTP2Client(connectCtx, ctx, resolver.Address{Addr: lis.Addr().String()}, copts, func(GoAwayReason) {})
500506
if err != nil {
501-
cancel() // Do not cancel in success path.
507+
cCancel() // Do not cancel in success path.
502508
// Server clean-up.
503509
lis.Close()
504510
if conn, ok := <-connCh; ok {
505511
conn.Close()
506512
}
507513
t.Fatalf("Failed to dial: %v", err)
508514
}
509-
return tr.(*http2Client), cancel
515+
return tr.(*http2Client), cCancel
510516
}
511517

512518
// TestInflightStreamClosing ensures that closing in-flight stream
@@ -739,7 +745,7 @@ func (s) TestLargeMessageWithDelayRead(t *testing.T) {
739745
Host: "localhost",
740746
Method: "foo.Large",
741747
}
742-
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Second*10))
748+
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
743749
defer cancel()
744750
s, err := ct.NewStream(ctx, callHdr)
745751
if err != nil {
@@ -833,7 +839,7 @@ func (s) TestGracefulClose(t *testing.T) {
833839
// Correctly clean up the server
834840
server.stop()
835841
}()
836-
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Second*10))
842+
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
837843
defer cancel()
838844

839845
// Create a stream that will exist for this whole test and confirm basic
@@ -969,7 +975,7 @@ func (s) TestMaxStreams(t *testing.T) {
969975
// Try and create a new stream.
970976
go func() {
971977
defer close(done)
972-
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Second*10))
978+
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
973979
defer cancel()
974980
if _, err := ct.NewStream(ctx, callHdr); err != nil {
975981
t.Errorf("Failed to open stream: %v", err)
@@ -1353,7 +1359,9 @@ func (s) TestClientHonorsConnectContext(t *testing.T) {
13531359

13541360
parent := channelzSubChannel(t)
13551361
copts := ConnectOptions{ChannelzParent: parent}
1356-
_, err = NewHTTP2Client(connectCtx, context.Background(), resolver.Address{Addr: lis.Addr().String()}, copts, func(GoAwayReason) {})
1362+
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
1363+
defer cancel()
1364+
_, err = NewHTTP2Client(connectCtx, ctx, resolver.Address{Addr: lis.Addr().String()}, copts, func(GoAwayReason) {})
13571365
if err == nil {
13581366
t.Fatalf("NewHTTP2Client() returned successfully; wanted error")
13591367
}
@@ -1365,7 +1373,7 @@ func (s) TestClientHonorsConnectContext(t *testing.T) {
13651373
// Test context deadline.
13661374
connectCtx, cancel = context.WithTimeout(context.Background(), 100*time.Millisecond)
13671375
defer cancel()
1368-
_, err = NewHTTP2Client(connectCtx, context.Background(), resolver.Address{Addr: lis.Addr().String()}, copts, func(GoAwayReason) {})
1376+
_, err = NewHTTP2Client(connectCtx, ctx, resolver.Address{Addr: lis.Addr().String()}, copts, func(GoAwayReason) {})
13691377
if err == nil {
13701378
t.Fatalf("NewHTTP2Client() returned successfully; wanted error")
13711379
}
@@ -1440,12 +1448,14 @@ func (s) TestClientWithMisbehavedServer(t *testing.T) {
14401448
}
14411449
}
14421450
}()
1443-
connectCtx, cancel := context.WithDeadline(context.Background(), time.Now().Add(2*time.Second))
1451+
connectCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
14441452
defer cancel()
14451453

14461454
parent := channelzSubChannel(t)
14471455
copts := ConnectOptions{ChannelzParent: parent}
1448-
ct, err := NewHTTP2Client(connectCtx, context.Background(), resolver.Address{Addr: lis.Addr().String()}, copts, func(GoAwayReason) {})
1456+
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
1457+
defer cancel()
1458+
ct, err := NewHTTP2Client(connectCtx, ctx, resolver.Address{Addr: lis.Addr().String()}, copts, func(GoAwayReason) {})
14491459
if err != nil {
14501460
t.Fatalf("Error while creating client transport: %v", err)
14511461
}
@@ -1779,9 +1789,11 @@ func waitWhileTrue(t *testing.T, condition func() (bool, error)) {
17791789
// If any error occurs on a call to Stream.Read, future calls
17801790
// should continue to return that same error.
17811791
func (s) TestReadGivesSameErrorAfterAnyErrorOccurs(t *testing.T) {
1792+
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
1793+
defer cancel()
17821794
testRecvBuffer := newRecvBuffer()
17831795
s := &Stream{
1784-
ctx: context.Background(),
1796+
ctx: ctx,
17851797
buf: testRecvBuffer,
17861798
requestRead: func(int) {},
17871799
}
@@ -2450,7 +2462,7 @@ func (s) TestClientHandshakeInfo(t *testing.T) {
24502462
Addr: "localhost:" + server.port,
24512463
Attributes: attributes.New(testAttrKey, testAttrVal),
24522464
}
2453-
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(2*time.Second))
2465+
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
24542466
defer cancel()
24552467
creds := &attrTransportCreds{}
24562468

@@ -2485,7 +2497,7 @@ func (s) TestClientHandshakeInfoDialer(t *testing.T) {
24852497
Addr: "localhost:" + server.port,
24862498
Attributes: attributes.New(testAttrKey, testAttrVal),
24872499
}
2488-
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(2*time.Second))
2500+
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
24892501
defer cancel()
24902502

24912503
var attr *attributes.Attributes
@@ -2829,7 +2841,7 @@ func (s) TestClientCloseReturnsAfterReaderCompletes(t *testing.T) {
28292841

28302842
// Create a client transport with a custom dialer that hangs the Read()
28312843
// after Close().
2832-
ct, err := NewHTTP2Client(ctx, context.Background(), addr, copts, func(GoAwayReason) {})
2844+
ct, err := NewHTTP2Client(ctx, ctx, addr, copts, func(GoAwayReason) {})
28332845
if err != nil {
28342846
t.Fatalf("Failed to create transport: %v", err)
28352847
}
@@ -2915,14 +2927,14 @@ func (s) TestClientCloseReturnsEarlyWhenGoAwayWriteHangs(t *testing.T) {
29152927
}
29162928
copts := ConnectOptions{Dialer: dialer}
29172929
copts.ChannelzParent = channelzSubChannel(t)
2930+
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
2931+
defer cancel()
29182932
// Create client transport with custom dialer
2919-
ct, connErr := NewHTTP2Client(connectCtx, context.Background(), addr, copts, func(GoAwayReason) {})
2933+
ct, connErr := NewHTTP2Client(connectCtx, ctx, addr, copts, func(GoAwayReason) {})
29202934
if connErr != nil {
29212935
t.Fatalf("failed to create transport: %v", connErr)
29222936
}
29232937

2924-
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
2925-
defer cancel()
29262938
if _, err := ct.NewStream(ctx, &CallHdr{}); err != nil {
29272939
t.Fatalf("Failed to open stream: %v", err)
29282940
}

scripts/vet.sh

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,10 +77,7 @@ git grep '"github.com/envoyproxy/go-control-plane/envoy' -- '*.go' ':(exclude)*.
7777

7878
# - Ensure all context usages are done with timeout.
7979
# Context tests under benchmark are excluded as they are testing the performance of context.Background() and context.TODO().
80-
# TODO: Remove the exclusions once the tests are updated to use context.WithTimeout().
81-
# See https://github.com/grpc/grpc-go/issues/7304
82-
git grep -e 'context.Background()' --or -e 'context.TODO()' -- "*_test.go" | grep -v "benchmark/primitives/context_test.go" | grep -v "credential
83-
s/google" | grep -v "internal/transport/" | grep -v "xds/internal/" | grep -v "security/advancedtls" | grep -v 'context.WithTimeout(' | not grep -v 'context.WithCancel('
80+
git grep -e 'context.Background()' --or -e 'context.TODO()' -- "*_test.go" | grep -v "benchmark/primitives/context_test.go" | grep -v 'context.WithTimeout(' | not grep -v 'context.WithCancel('
8481

8582
# Disallow usage of net.ParseIP in favour of netip.ParseAddr as the former
8683
# can't parse link local IPv6 addresses.

0 commit comments

Comments
 (0)