Skip to content

Commit 0100489

Browse files
authored
Correct instantiation of AWS session object (#421)
The region needs to be set on the config object before creating the session so that things like setting `AWS_STS_REGIONAL_ENDPOINTS` can take effect. Also, if credentials cannot be acquired, die at that point instead of carrying on.
1 parent d698755 commit 0100489

File tree

2 files changed

+29
-42
lines changed

2 files changed

+29
-42
lines changed

cmd/node-termination-handler.go

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@ import (
3333
"github.com/aws/aws-node-termination-handler/pkg/node"
3434
"github.com/aws/aws-node-termination-handler/pkg/observability"
3535
"github.com/aws/aws-node-termination-handler/pkg/webhook"
36+
"github.com/aws/aws-sdk-go/aws"
37+
"github.com/aws/aws-sdk-go/aws/endpoints"
38+
"github.com/aws/aws-sdk-go/aws/session"
3639
"github.com/aws/aws-sdk-go/service/autoscaling"
3740
"github.com/aws/aws-sdk-go/service/ec2"
3841
"github.com/aws/aws-sdk-go/service/sqs"
@@ -106,10 +109,11 @@ func main() {
106109
// Populate the aws region if available from node metadata and not already explicitly configured
107110
if nthConfig.AWSRegion == "" && nodeMetadata.Region != "" {
108111
nthConfig.AWSRegion = nodeMetadata.Region
109-
if nthConfig.AWSSession != nil {
110-
nthConfig.AWSSession.Config.Region = &nodeMetadata.Region
111-
}
112-
} else if nthConfig.AWSRegion == "" && nodeMetadata.Region == "" && nthConfig.EnableSQSTerminationDraining {
112+
} else if nthConfig.AWSRegion == "" && nthConfig.QueueURL != "" {
113+
nthConfig.AWSRegion = getRegionFromQueueURL(nthConfig.QueueURL)
114+
log.Debug().Str("Retrieved AWS region from queue-url: \"%s\"", nthConfig.AWSRegion)
115+
}
116+
if nthConfig.AWSRegion == "" && nthConfig.EnableSQSTerminationDraining {
113117
nthConfig.Print()
114118
log.Fatal().Msgf("Unable to find the AWS region to process queue events.")
115119
}
@@ -157,9 +161,14 @@ func main() {
157161
monitoringFns[rebalanceRecommendation] = imdsRebalanceMonitor
158162
}
159163
if nthConfig.EnableSQSTerminationDraining {
160-
creds, err := nthConfig.AWSSession.Config.Credentials.Get()
164+
cfg := aws.NewConfig().WithRegion(nthConfig.AWSRegion).WithEndpoint(nthConfig.AWSEndpoint).WithSTSRegionalEndpoint(endpoints.RegionalSTSEndpoint)
165+
sess := session.Must(session.NewSessionWithOptions(session.Options{
166+
Config: *cfg,
167+
SharedConfigState: session.SharedConfigEnable,
168+
}))
169+
creds, err := sess.Config.Credentials.Get()
161170
if err != nil {
162-
log.Err(err).Msg("Unable to get AWS credentials")
171+
log.Fatal().Err(err).Msg("Unable to get AWS credentials")
163172
}
164173
log.Debug().Msgf("AWS Credentials retrieved from provider: %s", creds.ProviderName)
165174

@@ -169,9 +178,9 @@ func main() {
169178
QueueURL: nthConfig.QueueURL,
170179
InterruptionChan: interruptionChan,
171180
CancelChan: cancelChan,
172-
SQS: sqs.New(nthConfig.AWSSession),
173-
ASG: autoscaling.New(nthConfig.AWSSession),
174-
EC2: ec2.New(nthConfig.AWSSession),
181+
SQS: sqs.New(sess),
182+
ASG: autoscaling.New(sess),
183+
EC2: ec2.New(sess),
175184
}
176185
monitoringFns[sqsEvents] = sqsMonitor
177186
}
@@ -380,3 +389,14 @@ func runPostDrainTask(node node.Node, nodeName string, drainEvent *monitor.Inter
380389
}
381390
metrics.NodeActionsInc("post-drain", nodeName, err)
382391
}
392+
393+
func getRegionFromQueueURL(queueURL string) string {
394+
for _, partition := range endpoints.DefaultPartitions() {
395+
for regionID := range partition.Regions() {
396+
if strings.Contains(queueURL, regionID) {
397+
return regionID
398+
}
399+
}
400+
}
401+
return ""
402+
}

pkg/config/config.go

Lines changed: 0 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,6 @@ import (
2020
"strconv"
2121
"strings"
2222

23-
"github.com/aws/aws-sdk-go/aws/endpoints"
24-
"github.com/aws/aws-sdk-go/aws/session"
2523
"github.com/rs/zerolog/log"
2624
)
2725

@@ -139,7 +137,6 @@ type Config struct {
139137
AWSEndpoint string
140138
QueueURL string
141139
Workers int
142-
AWSSession *session.Session
143140
}
144141

145142
//ParseCliArgs parses cli arguments and uses environment variables as fallback values
@@ -195,25 +192,6 @@ func ParseCliArgs() (config Config, err error) {
195192

196193
flag.Parse()
197194

198-
if config.EnableSQSTerminationDraining {
199-
sess := session.Must(session.NewSessionWithOptions(session.Options{
200-
SharedConfigState: session.SharedConfigEnable,
201-
}))
202-
if config.AWSRegion != "" {
203-
sess.Config.Region = &config.AWSRegion
204-
} else if *sess.Config.Region == "" && config.QueueURL != "" {
205-
config.AWSRegion = getRegionFromQueueURL(config.QueueURL)
206-
log.Debug().Str("Retrieved AWS region from queue-url: \"%s\"", config.AWSRegion)
207-
sess.Config.Region = &config.AWSRegion
208-
} else {
209-
config.AWSRegion = *sess.Config.Region
210-
}
211-
config.AWSSession = sess
212-
if config.AWSEndpoint != "" {
213-
config.AWSSession.Config.Endpoint = &config.AWSEndpoint
214-
}
215-
}
216-
217195
if isConfigProvided("pod-termination-grace-period", podTerminationGracePeriodConfigKey) && isConfigProvided("grace-period", gracePeriodConfigKey) {
218196
log.Warn().Msg("Deprecated argument \"grace-period\" and the replacement argument \"pod-termination-grace-period\" was provided. Using the newer argument \"pod-termination-grace-period\"")
219197
} else if isConfigProvided("grace-period", gracePeriodConfigKey) {
@@ -413,14 +391,3 @@ func isConfigProvided(cliArgName string, envVarName string) bool {
413391
})
414392
return cliArgProvided
415393
}
416-
417-
func getRegionFromQueueURL(queueURL string) string {
418-
for _, partition := range endpoints.DefaultPartitions() {
419-
for regionID := range partition.Regions() {
420-
if strings.Contains(queueURL, regionID) {
421-
return regionID
422-
}
423-
}
424-
}
425-
return ""
426-
}

0 commit comments

Comments
 (0)