@@ -33,6 +33,9 @@ import (
33
33
"github.com/aws/aws-node-termination-handler/pkg/node"
34
34
"github.com/aws/aws-node-termination-handler/pkg/observability"
35
35
"github.com/aws/aws-node-termination-handler/pkg/webhook"
36
+ "github.com/aws/aws-sdk-go/aws"
37
+ "github.com/aws/aws-sdk-go/aws/endpoints"
38
+ "github.com/aws/aws-sdk-go/aws/session"
36
39
"github.com/aws/aws-sdk-go/service/autoscaling"
37
40
"github.com/aws/aws-sdk-go/service/ec2"
38
41
"github.com/aws/aws-sdk-go/service/sqs"
@@ -106,10 +109,11 @@ func main() {
106
109
// Populate the aws region if available from node metadata and not already explicitly configured
107
110
if nthConfig .AWSRegion == "" && nodeMetadata .Region != "" {
108
111
nthConfig .AWSRegion = nodeMetadata .Region
109
- if nthConfig .AWSSession != nil {
110
- nthConfig .AWSSession .Config .Region = & nodeMetadata .Region
111
- }
112
- } else if nthConfig .AWSRegion == "" && nodeMetadata .Region == "" && nthConfig .EnableSQSTerminationDraining {
112
+ } else if nthConfig .AWSRegion == "" && nthConfig .QueueURL != "" {
113
+ nthConfig .AWSRegion = getRegionFromQueueURL (nthConfig .QueueURL )
114
+ log .Debug ().Str ("Retrieved AWS region from queue-url: \" %s\" " , nthConfig .AWSRegion )
115
+ }
116
+ if nthConfig .AWSRegion == "" && nthConfig .EnableSQSTerminationDraining {
113
117
nthConfig .Print ()
114
118
log .Fatal ().Msgf ("Unable to find the AWS region to process queue events." )
115
119
}
@@ -157,9 +161,14 @@ func main() {
157
161
monitoringFns [rebalanceRecommendation ] = imdsRebalanceMonitor
158
162
}
159
163
if nthConfig .EnableSQSTerminationDraining {
160
- creds , err := nthConfig .AWSSession .Config .Credentials .Get ()
164
+ cfg := aws .NewConfig ().WithRegion (nthConfig .AWSRegion ).WithEndpoint (nthConfig .AWSEndpoint ).WithSTSRegionalEndpoint (endpoints .RegionalSTSEndpoint )
165
+ sess := session .Must (session .NewSessionWithOptions (session.Options {
166
+ Config : * cfg ,
167
+ SharedConfigState : session .SharedConfigEnable ,
168
+ }))
169
+ creds , err := sess .Config .Credentials .Get ()
161
170
if err != nil {
162
- log .Err (err ).Msg ("Unable to get AWS credentials" )
171
+ log .Fatal (). Err (err ).Msg ("Unable to get AWS credentials" )
163
172
}
164
173
log .Debug ().Msgf ("AWS Credentials retrieved from provider: %s" , creds .ProviderName )
165
174
@@ -169,9 +178,9 @@ func main() {
169
178
QueueURL : nthConfig .QueueURL ,
170
179
InterruptionChan : interruptionChan ,
171
180
CancelChan : cancelChan ,
172
- SQS : sqs .New (nthConfig . AWSSession ),
173
- ASG : autoscaling .New (nthConfig . AWSSession ),
174
- EC2 : ec2 .New (nthConfig . AWSSession ),
181
+ SQS : sqs .New (sess ),
182
+ ASG : autoscaling .New (sess ),
183
+ EC2 : ec2 .New (sess ),
175
184
}
176
185
monitoringFns [sqsEvents ] = sqsMonitor
177
186
}
@@ -380,3 +389,14 @@ func runPostDrainTask(node node.Node, nodeName string, drainEvent *monitor.Inter
380
389
}
381
390
metrics .NodeActionsInc ("post-drain" , nodeName , err )
382
391
}
392
+
393
+ func getRegionFromQueueURL (queueURL string ) string {
394
+ for _ , partition := range endpoints .DefaultPartitions () {
395
+ for regionID := range partition .Regions () {
396
+ if strings .Contains (queueURL , regionID ) {
397
+ return regionID
398
+ }
399
+ }
400
+ }
401
+ return ""
402
+ }
0 commit comments