@@ -510,21 +510,21 @@ func (e *ec2Wrapper) getClientUsingAssumedRole(instanceRegion, roleARN, clusterN
510510 }
511511 e .log .Info ("created rate limited http client" , "qps" , qps , "burst" , burst )
512512
513- // Get the regional sts end point
514- regionalSTSEndpoint , err := endpoints .DefaultResolver ().
515- EndpointFor ("sts" , aws .StringValue (userStsSession .Config .Region ), endpoints .STSRegionalEndpointOption )
516- if err != nil {
517- return nil , fmt .Errorf ("failed to get the regional sts endoint for region %s: %v" ,
518- * userStsSession .Config .Region , err )
519- }
520-
513+ // GetPartition ID, SourceAccount and SourceARN
521514 roleARN = strings .Trim (roleARN , "\" " )
522515
523- sourceAcct , sourceArn , err := utils .GetSourceAcctAndArn (roleARN , region , clusterName )
516+ sourceAcct , partitionID , sourceArn , err := utils .GetSourceAcctAndArn (roleARN , region , clusterName )
524517 if err != nil {
525518 return nil , err
526519 }
527520
521+ // Get the regional sts end point
522+ regionalSTSEndpoint , err := e .getRegionalStsEndpoint (partitionID , region )
523+ if err != nil {
524+ return nil , fmt .Errorf ("failed to get the regional sts endpoint for region %s: %v %v" ,
525+ * userStsSession .Config .Region , err , partitionID )
526+ }
527+
528528 regionalProvider := & stscreds.AssumeRoleProvider {
529529 Client : e .createSTSClient (userStsSession , client , regionalSTSEndpoint , sourceAcct , sourceArn ),
530530 RoleARN : roleARN ,
@@ -547,7 +547,7 @@ func (e *ec2Wrapper) getClientUsingAssumedRole(instanceRegion, roleARN, clusterN
547547 // If the regional STS endpoint is different than the global STS endpoint then add the global sts endpoint
548548 if regionalSTSEndpoint .URL != globalSTSEndpoint .URL {
549549 globalProvider := & stscreds.AssumeRoleProvider {
550- Client : e .createSTSClient (userStsSession , client , regionalSTSEndpoint , sourceAcct , sourceArn ),
550+ Client : e .createSTSClient (userStsSession , client , globalSTSEndpoint , sourceAcct , sourceArn ),
551551 RoleARN : roleARN ,
552552 Duration : time .Minute * 60 ,
553553 }
@@ -892,3 +892,35 @@ func (e *ec2Wrapper) DisassociateTrunkInterface(input *ec2.DisassociateTrunkInte
892892 }
893893 return err
894894}
895+
896+ func (e * ec2Wrapper ) getRegionalStsEndpoint (partitionID , region string ) (endpoints.ResolvedEndpoint , error ) {
897+ var partition * endpoints.Partition
898+ var stsServiceID = "sts"
899+ for _ , p := range endpoints .DefaultPartitions () {
900+ if partitionID == p .ID () {
901+ partition = & p
902+ break
903+ }
904+ }
905+ if partition == nil {
906+ return endpoints.ResolvedEndpoint {}, fmt .Errorf ("partition %s not valid" , partitionID )
907+ }
908+
909+ stsSvc , ok := partition .Services ()[stsServiceID ]
910+ if ! ok {
911+ e .log .Info ("STS service not found in partition, generating default endpoint." , "Partition:" , partitionID )
912+ // Add the host of the current instances region if the service doesn't already exists in the partition
913+ // so we don't fail if the service is not present in the go sdk but matches the instances region.
914+ res , err := partition .EndpointFor (stsServiceID , region , endpoints .STSRegionalEndpointOption , endpoints .ResolveUnknownServiceOption )
915+ if err != nil {
916+ return endpoints.ResolvedEndpoint {}, fmt .Errorf ("error resolving endpoint for %s in partition %s. err: %v" , region , partition .ID (), err )
917+ }
918+ return res , nil
919+ }
920+
921+ res , err := stsSvc .ResolveEndpoint (region , endpoints .STSRegionalEndpointOption )
922+ if err != nil {
923+ return endpoints.ResolvedEndpoint {}, fmt .Errorf ("error resolving endpoint for %s in partition %s. err: %v" , region , partition .ID (), err )
924+ }
925+ return res , nil
926+ }
0 commit comments