Skip to content

Commit aa4662f

Browse files
authored
fix(transport/grpc): retain UserAgent option with new auth stack (#2690)
The only code that converted the `option.WithUserAgent` to a gRPC DialOption `WithUserAgent` was in the `dial` function, which was bypassed when leveraging the new auth library that handles most of the important dialing set up re:credentials. To fix this, I added a new helper specifically for use in `dialPoolNewAuth` that prepares dial options based on the given `DialSettings` - `prepareDialOpionsNewAuth`. I could've added the same logic to anywhere in `dialPoolNewAuth` to fix it, but I felt I needed to signal that this is the method to be modifying gRPC DialOptions in so that in the future, we can centralize further gRPC DialOption manipulation rather than just "wherever" like in the "old" `dial`. Also stored the `grpctransport.Dial` function into a global variable so that unit tests could spoof it to check on the gRPC DialOptions that were provided - just like the "old" `dial` did for testing. Updates https://togithub.com/googleapis/google-cloud-go/issues/10550 but needs to be released here, integrated there, and released there.
1 parent 786363b commit aa4662f

File tree

2 files changed

+82
-2
lines changed

2 files changed

+82
-2
lines changed

transport/grpc/dial.go

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,9 @@ var logRateLimiter = rate.Sometimes{Interval: 1 * time.Second}
5353
// Assign to var for unit test replacement
5454
var dialContext = grpc.DialContext
5555

56+
// Assign to var for unit test replacement
57+
var dialContextNewAuth = grpctransport.Dial
58+
5659
// otelStatsHandler is a singleton otelgrpc.clientHandler to be used across
5760
// all dial connections to avoid the memory leak documented in
5861
// https://github.com/open-telemetry/opentelemetry-go-contrib/issues/4226
@@ -218,12 +221,12 @@ func dialPoolNewAuth(ctx context.Context, secure bool, poolSize int, ds *interna
218221
defaultEndpointTemplate = ds.DefaultEndpoint
219222
}
220223

221-
pool, err := grpctransport.Dial(ctx, secure, &grpctransport.Options{
224+
pool, err := dialContextNewAuth(ctx, secure, &grpctransport.Options{
222225
DisableTelemetry: ds.TelemetryDisabled,
223226
DisableAuthentication: ds.NoAuth,
224227
Endpoint: ds.Endpoint,
225228
Metadata: metadata,
226-
GRPCDialOpts: ds.GRPCDialOpts,
229+
GRPCDialOpts: prepareDialOptsNewAuth(ds),
227230
PoolSize: poolSize,
228231
Credentials: creds,
229232
APIKey: ds.APIKey,
@@ -248,6 +251,15 @@ func dialPoolNewAuth(ctx context.Context, secure bool, poolSize int, ds *interna
248251
return pool, err
249252
}
250253

254+
func prepareDialOptsNewAuth(ds *internal.DialSettings) []grpc.DialOption {
255+
var opts []grpc.DialOption
256+
if ds.UserAgent != "" {
257+
opts = append(opts, grpc.WithUserAgent(ds.UserAgent))
258+
}
259+
260+
return append(opts, ds.GRPCDialOpts...)
261+
}
262+
251263
func dial(ctx context.Context, insecure bool, o *internal.DialSettings) (*grpc.ClientConn, error) {
252264
if o.HTTPClient != nil {
253265
return nil, errors.New("unsupported HTTP client specified")

transport/grpc/dial_test.go

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"strings"
1212
"testing"
1313

14+
"cloud.google.com/go/auth/grpctransport"
1415
"cloud.google.com/go/compute/metadata"
1516
"github.com/google/go-cmp/cmp"
1617
"golang.org/x/oauth2/google"
@@ -35,6 +36,73 @@ func TestDial(t *testing.T) {
3536
dial(context.Background(), false, &o)
3637
}
3738

39+
func TestDialPoolNewAuthDialOptions(t *testing.T) {
40+
oldDialContextNewAuth := dialContextNewAuth
41+
var wantNumOpts int
42+
// Replace package var in order to assert DialContext args.
43+
dialContextNewAuth = func(ctx context.Context, secure bool, opts *grpctransport.Options) (grpctransport.GRPCClientConnPool, error) {
44+
if len(opts.GRPCDialOpts) != wantNumOpts {
45+
t.Fatalf("got: %d, want: %d", len(opts.GRPCDialOpts), wantNumOpts)
46+
}
47+
return nil, nil
48+
}
49+
defer func() {
50+
dialContextNewAuth = oldDialContextNewAuth
51+
}()
52+
53+
for _, testcase := range []struct {
54+
name string
55+
ds *internal.DialSettings
56+
wantNumOpts int
57+
}{
58+
{
59+
name: "no dial options",
60+
ds: &internal.DialSettings{},
61+
wantNumOpts: 0,
62+
},
63+
{
64+
name: "with user agent",
65+
ds: &internal.DialSettings{
66+
UserAgent: "test",
67+
},
68+
wantNumOpts: 1,
69+
},
70+
} {
71+
t.Run(testcase.name, func(t *testing.T) {
72+
wantNumOpts = testcase.wantNumOpts
73+
dialPoolNewAuth(context.Background(), false, 1, testcase.ds)
74+
})
75+
}
76+
}
77+
78+
func TestPrepareDialOptsNewAuth(t *testing.T) {
79+
for _, testcase := range []struct {
80+
name string
81+
ds *internal.DialSettings
82+
wantNumOpts int
83+
}{
84+
{
85+
name: "empty",
86+
ds: &internal.DialSettings{},
87+
wantNumOpts: 0,
88+
},
89+
{
90+
name: "user agent",
91+
ds: &internal.DialSettings{
92+
UserAgent: "test",
93+
},
94+
wantNumOpts: 1,
95+
},
96+
} {
97+
t.Run(testcase.name, func(t *testing.T) {
98+
got := prepareDialOptsNewAuth(testcase.ds)
99+
if len(got) != testcase.wantNumOpts {
100+
t.Fatalf("got %d options, want %d options", len(got), testcase.wantNumOpts)
101+
}
102+
})
103+
}
104+
}
105+
38106
func TestCheckDirectPathEndPoint(t *testing.T) {
39107
for _, testcase := range []struct {
40108
name string

0 commit comments

Comments
 (0)