@@ -52,6 +52,13 @@ const (
52
52
// The AWS authorization header name for the security session token if available.
53
53
awsSecurityTokenHeader = "x-amz-security-token"
54
54
55
+ // The name of the header containing the session token for metadata endpoint calls
56
+ awsIMDSv2SessionTokenHeader = "X-aws-ec2-metadata-token"
57
+
58
+ awsIMDSv2SessionTtlHeader = "X-aws-ec2-metadata-token-ttl-seconds"
59
+
60
+ awsIMDSv2SessionTtl = "300"
61
+
55
62
// The AWS authorization header name for the auto-generated date.
56
63
awsDateHeader = "x-amz-date"
57
64
@@ -241,6 +248,7 @@ type awsCredentialSource struct {
241
248
RegionURL string
242
249
RegionalCredVerificationURL string
243
250
CredVerificationURL string
251
+ IMDSv2SessionTokenURL string
244
252
TargetResource string
245
253
requestSigner * awsRequestSigner
246
254
region string
@@ -268,12 +276,22 @@ func (cs awsCredentialSource) doRequest(req *http.Request) (*http.Response, erro
268
276
269
277
func (cs awsCredentialSource ) subjectToken () (string , error ) {
270
278
if cs .requestSigner == nil {
271
- awsSecurityCredentials , err := cs .getSecurityCredentials ()
279
+ awsSessionToken , err := cs .getAWSSessionToken ()
280
+ if err != nil {
281
+ return "" , err
282
+ }
283
+
284
+ headers := make (map [string ]string )
285
+ if awsSessionToken != "" {
286
+ headers [awsIMDSv2SessionTokenHeader ] = awsSessionToken
287
+ }
288
+
289
+ awsSecurityCredentials , err := cs .getSecurityCredentials (headers )
272
290
if err != nil {
273
291
return "" , err
274
292
}
275
293
276
- if cs .region , err = cs .getRegion (); err != nil {
294
+ if cs .region , err = cs .getRegion (headers ); err != nil {
277
295
return "" , err
278
296
}
279
297
@@ -340,7 +358,37 @@ func (cs awsCredentialSource) subjectToken() (string, error) {
340
358
return url .QueryEscape (string (result )), nil
341
359
}
342
360
343
- func (cs * awsCredentialSource ) getRegion () (string , error ) {
361
+ func (cs * awsCredentialSource ) getAWSSessionToken () (string , error ) {
362
+ if cs .IMDSv2SessionTokenURL == "" {
363
+ return "" , nil
364
+ }
365
+
366
+ req , err := http .NewRequest ("PUT" , cs .IMDSv2SessionTokenURL , nil )
367
+ if err != nil {
368
+ return "" , err
369
+ }
370
+
371
+ req .Header .Add (awsIMDSv2SessionTtlHeader , awsIMDSv2SessionTtl )
372
+
373
+ resp , err := cs .doRequest (req )
374
+ if err != nil {
375
+ return "" , err
376
+ }
377
+ defer resp .Body .Close ()
378
+
379
+ respBody , err := ioutil .ReadAll (io .LimitReader (resp .Body , 1 << 20 ))
380
+ if err != nil {
381
+ return "" , err
382
+ }
383
+
384
+ if resp .StatusCode != 200 {
385
+ return "" , fmt .Errorf ("oauth2/google: unable to retrieve AWS session token - %s" , string (respBody ))
386
+ }
387
+
388
+ return string (respBody ), nil
389
+ }
390
+
391
+ func (cs * awsCredentialSource ) getRegion (headers map [string ]string ) (string , error ) {
344
392
if envAwsRegion := getenv ("AWS_REGION" ); envAwsRegion != "" {
345
393
return envAwsRegion , nil
346
394
}
@@ -357,6 +405,10 @@ func (cs *awsCredentialSource) getRegion() (string, error) {
357
405
return "" , err
358
406
}
359
407
408
+ for name , value := range headers {
409
+ req .Header .Add (name , value )
410
+ }
411
+
360
412
resp , err := cs .doRequest (req )
361
413
if err != nil {
362
414
return "" , err
@@ -381,7 +433,7 @@ func (cs *awsCredentialSource) getRegion() (string, error) {
381
433
return string (respBody [:respBodyEnd ]), nil
382
434
}
383
435
384
- func (cs * awsCredentialSource ) getSecurityCredentials () (result awsSecurityCredentials , err error ) {
436
+ func (cs * awsCredentialSource ) getSecurityCredentials (headers map [ string ] string ) (result awsSecurityCredentials , err error ) {
385
437
if accessKeyID := getenv ("AWS_ACCESS_KEY_ID" ); accessKeyID != "" {
386
438
if secretAccessKey := getenv ("AWS_SECRET_ACCESS_KEY" ); secretAccessKey != "" {
387
439
return awsSecurityCredentials {
@@ -392,12 +444,12 @@ func (cs *awsCredentialSource) getSecurityCredentials() (result awsSecurityCrede
392
444
}
393
445
}
394
446
395
- roleName , err := cs .getMetadataRoleName ()
447
+ roleName , err := cs .getMetadataRoleName (headers )
396
448
if err != nil {
397
449
return
398
450
}
399
451
400
- credentials , err := cs .getMetadataSecurityCredentials (roleName )
452
+ credentials , err := cs .getMetadataSecurityCredentials (roleName , headers )
401
453
if err != nil {
402
454
return
403
455
}
@@ -413,7 +465,7 @@ func (cs *awsCredentialSource) getSecurityCredentials() (result awsSecurityCrede
413
465
return credentials , nil
414
466
}
415
467
416
- func (cs * awsCredentialSource ) getMetadataSecurityCredentials (roleName string ) (awsSecurityCredentials , error ) {
468
+ func (cs * awsCredentialSource ) getMetadataSecurityCredentials (roleName string , headers map [ string ] string ) (awsSecurityCredentials , error ) {
417
469
var result awsSecurityCredentials
418
470
419
471
req , err := http .NewRequest ("GET" , fmt .Sprintf ("%s/%s" , cs .CredVerificationURL , roleName ), nil )
@@ -422,6 +474,10 @@ func (cs *awsCredentialSource) getMetadataSecurityCredentials(roleName string) (
422
474
}
423
475
req .Header .Add ("Content-Type" , "application/json" )
424
476
477
+ for name , value := range headers {
478
+ req .Header .Add (name , value )
479
+ }
480
+
425
481
resp , err := cs .doRequest (req )
426
482
if err != nil {
427
483
return result , err
@@ -441,7 +497,7 @@ func (cs *awsCredentialSource) getMetadataSecurityCredentials(roleName string) (
441
497
return result , err
442
498
}
443
499
444
- func (cs * awsCredentialSource ) getMetadataRoleName () (string , error ) {
500
+ func (cs * awsCredentialSource ) getMetadataRoleName (headers map [ string ] string ) (string , error ) {
445
501
if cs .CredVerificationURL == "" {
446
502
return "" , errors .New ("oauth2/google: unable to determine the AWS metadata server security credentials endpoint" )
447
503
}
@@ -451,6 +507,10 @@ func (cs *awsCredentialSource) getMetadataRoleName() (string, error) {
451
507
return "" , err
452
508
}
453
509
510
+ for name , value := range headers {
511
+ req .Header .Add (name , value )
512
+ }
513
+
454
514
resp , err := cs .doRequest (req )
455
515
if err != nil {
456
516
return "" , err
0 commit comments