Skip to content

Commit 67f1b67

Browse files
committed
Correct instantiation of AWS session object
The region needs to be set on the config object before creating the session so that things like setting `AWS_STS_REGIONAL_ENDPOINTS` can take effect. Also, if credentials cannot be acquired, die at that point instead of carrying on.
1 parent d698755 commit 67f1b67

File tree

2 files changed

+29
-42
lines changed

2 files changed

+29
-42
lines changed

cmd/node-termination-handler.go

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@ import (
3333
"github.com/aws/aws-node-termination-handler/pkg/node"
3434
"github.com/aws/aws-node-termination-handler/pkg/observability"
3535
"github.com/aws/aws-node-termination-handler/pkg/webhook"
36+
"github.com/aws/aws-sdk-go/aws"
37+
"github.com/aws/aws-sdk-go/aws/endpoints"
38+
"github.com/aws/aws-sdk-go/aws/session"
3639
"github.com/aws/aws-sdk-go/service/autoscaling"
3740
"github.com/aws/aws-sdk-go/service/ec2"
3841
"github.com/aws/aws-sdk-go/service/sqs"
@@ -106,10 +109,11 @@ func main() {
106109
// Populate the aws region if available from node metadata and not already explicitly configured
107110
if nthConfig.AWSRegion == "" && nodeMetadata.Region != "" {
108111
nthConfig.AWSRegion = nodeMetadata.Region
109-
if nthConfig.AWSSession != nil {
110-
nthConfig.AWSSession.Config.Region = &nodeMetadata.Region
111-
}
112-
} else if nthConfig.AWSRegion == "" && nodeMetadata.Region == "" && nthConfig.EnableSQSTerminationDraining {
112+
} else if nthConfig.AWSRegion == "" && nthConfig.QueueURL != "" {
113+
nthConfig.AWSRegion = getRegionFromQueueURL(nthConfig.QueueURL)
114+
log.Debug().Str("Retrieved AWS region from queue-url: \"%s\"", nthConfig.AWSRegion)
115+
}
116+
if nthConfig.AWSRegion == "" && nthConfig.EnableSQSTerminationDraining {
113117
nthConfig.Print()
114118
log.Fatal().Msgf("Unable to find the AWS region to process queue events.")
115119
}
@@ -157,9 +161,14 @@ func main() {
157161
monitoringFns[rebalanceRecommendation] = imdsRebalanceMonitor
158162
}
159163
if nthConfig.EnableSQSTerminationDraining {
160-
creds, err := nthConfig.AWSSession.Config.Credentials.Get()
164+
cfg := aws.NewConfig().WithRegion(nthConfig.AWSRegion).WithEndpoint(nthConfig.AWSEndpoint).WithSTSRegionalEndpoint(endpoints.RegionalSTSEndpoint)
165+
sess := session.Must(session.NewSessionWithOptions(session.Options{
166+
Config: *cfg,
167+
SharedConfigState: session.SharedConfigEnable,
168+
}))
169+
creds, err := sess.Config.Credentials.Get()
161170
if err != nil {
162-
log.Err(err).Msg("Unable to get AWS credentials")
171+
log.Fatal().Err(err).Msg("Unable to get AWS credentials")
163172
}
164173
log.Debug().Msgf("AWS Credentials retrieved from provider: %s", creds.ProviderName)
165174

@@ -169,9 +178,9 @@ func main() {
169178
QueueURL: nthConfig.QueueURL,
170179
InterruptionChan: interruptionChan,
171180
CancelChan: cancelChan,
172-
SQS: sqs.New(nthConfig.AWSSession),
173-
ASG: autoscaling.New(nthConfig.AWSSession),
174-
EC2: ec2.New(nthConfig.AWSSession),
181+
SQS: sqs.New(sess),
182+
ASG: autoscaling.New(sess),
183+
EC2: ec2.New(sess),
175184
}
176185
monitoringFns[sqsEvents] = sqsMonitor
177186
}
@@ -380,3 +389,14 @@ func runPostDrainTask(node node.Node, nodeName string, drainEvent *monitor.Inter
380389
}
381390
metrics.NodeActionsInc("post-drain", nodeName, err)
382391
}
392+
393+
func getRegionFromQueueURL(queueURL string) string {
394+
for _, partition := range endpoints.DefaultPartitions() {
395+
for regionID := range partition.Regions() {
396+
if strings.Contains(queueURL, regionID) {
397+
return regionID
398+
}
399+
}
400+
}
401+
return ""
402+
}

pkg/config/config.go

Lines changed: 0 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,6 @@ import (
2020
"strconv"
2121
"strings"
2222

23-
"github.com/aws/aws-sdk-go/aws/endpoints"
24-
"github.com/aws/aws-sdk-go/aws/session"
2523
"github.com/rs/zerolog/log"
2624
)
2725

@@ -139,7 +137,6 @@ type Config struct {
139137
AWSEndpoint string
140138
QueueURL string
141139
Workers int
142-
AWSSession *session.Session
143140
}
144141

145142
//ParseCliArgs parses cli arguments and uses environment variables as fallback values
@@ -195,25 +192,6 @@ func ParseCliArgs() (config Config, err error) {
195192

196193
flag.Parse()
197194

198-
if config.EnableSQSTerminationDraining {
199-
sess := session.Must(session.NewSessionWithOptions(session.Options{
200-
SharedConfigState: session.SharedConfigEnable,
201-
}))
202-
if config.AWSRegion != "" {
203-
sess.Config.Region = &config.AWSRegion
204-
} else if *sess.Config.Region == "" && config.QueueURL != "" {
205-
config.AWSRegion = getRegionFromQueueURL(config.QueueURL)
206-
log.Debug().Str("Retrieved AWS region from queue-url: \"%s\"", config.AWSRegion)
207-
sess.Config.Region = &config.AWSRegion
208-
} else {
209-
config.AWSRegion = *sess.Config.Region
210-
}
211-
config.AWSSession = sess
212-
if config.AWSEndpoint != "" {
213-
config.AWSSession.Config.Endpoint = &config.AWSEndpoint
214-
}
215-
}
216-
217195
if isConfigProvided("pod-termination-grace-period", podTerminationGracePeriodConfigKey) && isConfigProvided("grace-period", gracePeriodConfigKey) {
218196
log.Warn().Msg("Deprecated argument \"grace-period\" and the replacement argument \"pod-termination-grace-period\" was provided. Using the newer argument \"pod-termination-grace-period\"")
219197
} else if isConfigProvided("grace-period", gracePeriodConfigKey) {
@@ -413,14 +391,3 @@ func isConfigProvided(cliArgName string, envVarName string) bool {
413391
})
414392
return cliArgProvided
415393
}
416-
417-
func getRegionFromQueueURL(queueURL string) string {
418-
for _, partition := range endpoints.DefaultPartitions() {
419-
for regionID := range partition.Regions() {
420-
if strings.Contains(queueURL, regionID) {
421-
return regionID
422-
}
423-
}
424-
}
425-
return ""
426-
}

0 commit comments

Comments
 (0)