Skip to content

Commit cdc09f0

Browse files
author
Julien Pivotto
authored
Merge pull request #387 from roidelapluie/useragent
Useragent for OAuth2
2 parents d75e027 + db0284d commit cdc09f0

File tree

2 files changed

+89
-6
lines changed

2 files changed

+89
-6
lines changed

config/http_config.go

+42-2
Original file line numberDiff line numberDiff line change
@@ -372,6 +372,7 @@ type httpClientOptions struct {
372372
keepAlivesEnabled bool
373373
http2Enabled bool
374374
idleConnTimeout time.Duration
375+
userAgent string
375376
}
376377

377378
// HTTPClientOption defines an option that can be applied to the HTTP client.
@@ -405,6 +406,13 @@ func WithIdleConnTimeout(timeout time.Duration) HTTPClientOption {
405406
}
406407
}
407408

409+
// WithUserAgent allows setting the user agent.
410+
func WithUserAgent(ua string) HTTPClientOption {
411+
return func(opts *httpClientOptions) {
412+
opts.userAgent = ua
413+
}
414+
}
415+
408416
// NewClient returns a http.Client using the specified http.RoundTripper.
409417
func newClient(rt http.RoundTripper) *http.Client {
410418
return &http.Client{Transport: rt}
@@ -497,8 +505,12 @@ func NewRoundTripperFromConfig(cfg HTTPClientConfig, name string, optFuncs ...HT
497505
rt = NewBasicAuthRoundTripper(cfg.BasicAuth.Username, cfg.BasicAuth.Password, cfg.BasicAuth.PasswordFile, rt)
498506
}
499507

508+
if opts.userAgent != "" {
509+
rt = NewUserAgentRoundTripper(opts.userAgent, rt)
510+
}
511+
500512
if cfg.OAuth2 != nil {
501-
rt = NewOAuth2RoundTripper(cfg.OAuth2, rt)
513+
rt = NewOAuth2RoundTripper(cfg.OAuth2, rt, &opts)
502514
}
503515
// Return a new configured RoundTripper.
504516
return rt, nil
@@ -619,12 +631,14 @@ type oauth2RoundTripper struct {
619631
next http.RoundTripper
620632
secret string
621633
mtx sync.RWMutex
634+
opts *httpClientOptions
622635
}
623636

624-
func NewOAuth2RoundTripper(config *OAuth2, next http.RoundTripper) http.RoundTripper {
637+
func NewOAuth2RoundTripper(config *OAuth2, next http.RoundTripper, opts *httpClientOptions) http.RoundTripper {
625638
return &oauth2RoundTripper{
626639
config: config,
627640
next: next,
641+
opts: opts,
628642
}
629643
}
630644

@@ -681,6 +695,10 @@ func (rt *oauth2RoundTripper) RoundTrip(req *http.Request) (*http.Response, erro
681695
}
682696
}
683697

698+
if rt.opts.userAgent != "" {
699+
t = NewUserAgentRoundTripper(rt.opts.userAgent, t)
700+
}
701+
684702
ctx := context.WithValue(context.Background(), oauth2.HTTPClient, &http.Client{Transport: t})
685703
tokenSource := config.TokenSource(ctx)
686704

@@ -911,6 +929,28 @@ func (t *tlsRoundTripper) CloseIdleConnections() {
911929
}
912930
}
913931

932+
type userAgentRoundTripper struct {
933+
userAgent string
934+
rt http.RoundTripper
935+
}
936+
937+
// NewUserAgentRoundTripper adds the user agent every request header.
938+
func NewUserAgentRoundTripper(userAgent string, rt http.RoundTripper) http.RoundTripper {
939+
return &userAgentRoundTripper{userAgent, rt}
940+
}
941+
942+
func (rt *userAgentRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
943+
req = cloneRequest(req)
944+
req.Header.Set("User-Agent", rt.userAgent)
945+
return rt.rt.RoundTrip(req)
946+
}
947+
948+
func (rt *userAgentRoundTripper) CloseIdleConnections() {
949+
if ci, ok := rt.rt.(closeIdler); ok {
950+
ci.CloseIdleConnections()
951+
}
952+
}
953+
914954
func (c HTTPClientConfig) String() string {
915955
b, err := yaml.Marshal(c)
916956
if err != nil {

config/http_config_test.go

+47-4
Original file line numberDiff line numberDiff line change
@@ -1198,7 +1198,7 @@ client_secret: 2
11981198
scopes:
11991199
- A
12001200
- B
1201-
token_url: %s
1201+
token_url: %s/token
12021202
endpoint_params:
12031203
hi: hello
12041204
`, ts.URL)
@@ -1207,7 +1207,7 @@ endpoint_params:
12071207
ClientSecret: "2",
12081208
Scopes: []string{"A", "B"},
12091209
EndpointParams: map[string]string{"hi": "hello"},
1210-
TokenURL: ts.URL,
1210+
TokenURL: fmt.Sprintf("%s/token", ts.URL),
12111211
}
12121212

12131213
var unmarshalledConfig OAuth2
@@ -1219,7 +1219,7 @@ endpoint_params:
12191219
t.Fatalf("Got unmarshalled config %v, expected %v", unmarshalledConfig, expectedConfig)
12201220
}
12211221

1222-
rt := NewOAuth2RoundTripper(&expectedConfig, http.DefaultTransport)
1222+
rt := NewOAuth2RoundTripper(&expectedConfig, http.DefaultTransport, &defaultHTTPClientOptions)
12231223

12241224
client := http.Client{
12251225
Transport: rt,
@@ -1232,6 +1232,49 @@ endpoint_params:
12321232
}
12331233
}
12341234

1235+
func TestOAuth2UserAgent(t *testing.T) {
1236+
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1237+
if r.Header.Get("User-Agent") != "myuseragent" {
1238+
t.Fatalf("Expected User-Agent header in oauth request to be 'myuseragent', got '%s'", r.Header.Get("User-Agent"))
1239+
}
1240+
1241+
res, _ := json.Marshal(oauth2TestServerResponse{
1242+
AccessToken: "12345",
1243+
TokenType: "Bearer",
1244+
})
1245+
w.Header().Add("Content-Type", "application/json")
1246+
_, _ = w.Write(res)
1247+
}))
1248+
defer ts.Close()
1249+
1250+
config := DefaultHTTPClientConfig
1251+
config.OAuth2 = &OAuth2{
1252+
ClientID: "1",
1253+
ClientSecret: "2",
1254+
Scopes: []string{"A", "B"},
1255+
EndpointParams: map[string]string{"hi": "hello"},
1256+
TokenURL: fmt.Sprintf("%s/token", ts.URL),
1257+
}
1258+
1259+
rt, err := NewRoundTripperFromConfig(config, "test_oauth2", WithUserAgent("myuseragent"))
1260+
if err != nil {
1261+
t.Fatal(err)
1262+
}
1263+
1264+
client := http.Client{
1265+
Transport: rt,
1266+
}
1267+
resp, err := client.Get(ts.URL)
1268+
if err != nil {
1269+
t.Fatal(err)
1270+
}
1271+
1272+
authorization := resp.Request.Header.Get("Authorization")
1273+
if authorization != "Bearer 12345" {
1274+
t.Fatalf("Expected authorization header to be 'Bearer 12345', got '%s'", authorization)
1275+
}
1276+
}
1277+
12351278
func TestOAuth2WithFile(t *testing.T) {
12361279
var expectedAuth *string
12371280
var previousAuth string
@@ -1294,7 +1337,7 @@ endpoint_params:
12941337
t.Fatalf("Got unmarshalled config %v, expected %v", unmarshalledConfig, expectedConfig)
12951338
}
12961339

1297-
rt := NewOAuth2RoundTripper(&expectedConfig, http.DefaultTransport)
1340+
rt := NewOAuth2RoundTripper(&expectedConfig, http.DefaultTransport, &defaultHTTPClientOptions)
12981341

12991342
client := http.Client{
13001343
Transport: rt,

0 commit comments

Comments
 (0)