diff --git a/pkg/monitor/sqsevent/asg-lifecycle-event.go b/pkg/monitor/sqsevent/asg-lifecycle-event.go index 2a703e4b..87746161 100644 --- a/pkg/monitor/sqsevent/asg-lifecycle-event.go +++ b/pkg/monitor/sqsevent/asg-lifecycle-event.go @@ -57,7 +57,7 @@ type LifecycleDetail struct { LifecycleTransition string `json:"LifecycleTransition"` } -func (m SQSMonitor) asgTerminationToInterruptionEvent(event EventBridgeEvent, messages []*sqs.Message) (monitor.InterruptionEvent, error) { +func (m SQSMonitor) asgTerminationToInterruptionEvent(event EventBridgeEvent, message *sqs.Message) (monitor.InterruptionEvent, error) { lifecycleDetail := &LifecycleDetail{} err := json.Unmarshal(event.Detail, lifecycleDetail) if err != nil { @@ -94,7 +94,7 @@ func (m SQSMonitor) asgTerminationToInterruptionEvent(event EventBridgeEvent, me log.Info().Msgf("Completed ASG Lifecycle Hook (%s) for instance %s", lifecycleDetail.LifecycleHookName, lifecycleDetail.EC2InstanceID) - errs := m.deleteMessages(messages) + errs := m.deleteMessages([]*sqs.Message{message}) if errs != nil { return errs[0] } @@ -111,7 +111,7 @@ func (m SQSMonitor) asgTerminationToInterruptionEvent(event EventBridgeEvent, me if nodeName == "" { log.Info().Msg("Node name is empty, assuming instance was already terminated, deleting queue message") - errs := m.deleteMessages(messages) + errs := m.deleteMessages([]*sqs.Message{message}) if errs != nil { log.Warn().Errs("errors", errs).Msg("There was an error deleting the messages") } diff --git a/pkg/monitor/sqsevent/ec2-state-change-event.go b/pkg/monitor/sqsevent/ec2-state-change-event.go index 8875b5a8..7db7fff8 100644 --- a/pkg/monitor/sqsevent/ec2-state-change-event.go +++ b/pkg/monitor/sqsevent/ec2-state-change-event.go @@ -50,7 +50,7 @@ type EC2StateChangeDetail struct { const instanceStatesToDrain = "stopping,stopped,shutting-down,terminated" -func (m SQSMonitor) ec2StateChangeToInterruptionEvent(event EventBridgeEvent, messages []*sqs.Message) (monitor.InterruptionEvent, error) { +func (m SQSMonitor) ec2StateChangeToInterruptionEvent(event EventBridgeEvent, message *sqs.Message) (monitor.InterruptionEvent, error) { ec2StateChangeDetail := &EC2StateChangeDetail{} err := json.Unmarshal(event.Detail, ec2StateChangeDetail) if err != nil { @@ -75,7 +75,7 @@ func (m SQSMonitor) ec2StateChangeToInterruptionEvent(event EventBridgeEvent, me Description: fmt.Sprintf("EC2 State Change event received. Instance went into %s at %s \n", ec2StateChangeDetail.State, event.getTime()), } interruptionEvent.PostDrainTask = func(interruptionEvent monitor.InterruptionEvent, n node.Node) error { - errs := m.deleteMessages([]*sqs.Message{messages[0]}) + errs := m.deleteMessages([]*sqs.Message{message}) if errs != nil { return errs[0] } diff --git a/pkg/monitor/sqsevent/rebalance-recommendation-event.go b/pkg/monitor/sqsevent/rebalance-recommendation-event.go index 15b7dbcd..b12d5a40 100644 --- a/pkg/monitor/sqsevent/rebalance-recommendation-event.go +++ b/pkg/monitor/sqsevent/rebalance-recommendation-event.go @@ -46,7 +46,7 @@ type RebalanceRecommendationDetail struct { InstanceID string `json:"instance-id"` } -func (m SQSMonitor) rebalanceRecommendationToInterruptionEvent(event EventBridgeEvent, messages []*sqs.Message) (monitor.InterruptionEvent, error) { +func (m SQSMonitor) rebalanceRecommendationToInterruptionEvent(event EventBridgeEvent, message *sqs.Message) (monitor.InterruptionEvent, error) { rebalanceRecDetail := &RebalanceRecommendationDetail{} err := json.Unmarshal(event.Detail, rebalanceRecDetail) if err != nil { @@ -67,7 +67,7 @@ func (m SQSMonitor) rebalanceRecommendationToInterruptionEvent(event EventBridge Description: fmt.Sprintf("Rebalance recommendation event received. Instance will be cordoned at %s \n", event.getTime()), } interruptionEvent.PostDrainTask = func(interruptionEvent monitor.InterruptionEvent, n node.Node) error { - errs := m.deleteMessages(messages) + errs := m.deleteMessages([]*sqs.Message{message}) if errs != nil { return errs[0] } diff --git a/pkg/monitor/sqsevent/spot-itn-event.go b/pkg/monitor/sqsevent/spot-itn-event.go index b0cd647e..6b578619 100644 --- a/pkg/monitor/sqsevent/spot-itn-event.go +++ b/pkg/monitor/sqsevent/spot-itn-event.go @@ -48,7 +48,7 @@ type SpotInterruptionDetail struct { InstanceAction string `json:"instance-action"` } -func (m SQSMonitor) spotITNTerminationToInterruptionEvent(event EventBridgeEvent, messages []*sqs.Message) (monitor.InterruptionEvent, error) { +func (m SQSMonitor) spotITNTerminationToInterruptionEvent(event EventBridgeEvent, message *sqs.Message) (monitor.InterruptionEvent, error) { spotInterruptionDetail := &SpotInterruptionDetail{} err := json.Unmarshal(event.Detail, spotInterruptionDetail) if err != nil { @@ -69,7 +69,7 @@ func (m SQSMonitor) spotITNTerminationToInterruptionEvent(event EventBridgeEvent Description: fmt.Sprintf("Spot Interruption event received. Instance will be interrupted at %s \n", event.getTime()), } interruptionEvent.PostDrainTask = func(interruptionEvent monitor.InterruptionEvent, n node.Node) error { - errs := m.deleteMessages([]*sqs.Message{messages[0]}) + errs := m.deleteMessages([]*sqs.Message{message}) if errs != nil { return errs[0] } diff --git a/pkg/monitor/sqsevent/sqs-monitor.go b/pkg/monitor/sqsevent/sqs-monitor.go index 5786238f..cbb4a472 100644 --- a/pkg/monitor/sqsevent/sqs-monitor.go +++ b/pkg/monitor/sqsevent/sqs-monitor.go @@ -56,35 +56,48 @@ func (m SQSMonitor) Kind() string { // Monitor continuously monitors SQS for events and sends interruption events to the passed in channel func (m SQSMonitor) Monitor() error { - interruptionEvent, err := m.checkForSQSMessage() + log.Debug().Msg("Checking for queue messages") + messages, err := m.receiveQueueMessages(m.QueueURL) if err != nil { - if errors.Is(err, ErrNodeStateNotRunning) { - log.Warn().Err(err).Msg("dropping event for an already terminated node") - return nil - } return err } - if interruptionEvent != nil && interruptionEvent.Kind == SQSTerminateKind { - log.Debug().Msgf("Sending %s interruption event to the interruption channel", SQSTerminateKind) - m.InterruptionChan <- *interruptionEvent - } - return nil -} -// checkForSpotInterruptionNotice checks sqs for new messages and returns interruption events -func (m SQSMonitor) checkForSQSMessage() (*monitor.InterruptionEvent, error) { + failedEvents := 0 + for _, message := range messages { + interruptionEvent, err := m.processSQSMessage(message) + switch { + case errors.Is(err, ErrNodeStateNotRunning): + // If the node is no longer running, just log and delete the message. If message deletion fails, count it as an error. + log.Warn().Err(err).Msg("dropping event for an already terminated node") + errs := m.deleteMessages([]*sqs.Message{message}) + if len(errs) > 0 { + log.Warn().Err(errs[0]).Msg("error deleting event for already terminated node") + failedEvents++ + } - log.Debug().Msg("Checking for queue messages") - messages, err := m.receiveQueueMessages(m.QueueURL) - if err != nil { - return nil, err + case err != nil: + // Log errors and record as failed events + log.Warn().Err(err).Msg("ignoring event due to error") + failedEvents++ + + case err == nil && interruptionEvent != nil && interruptionEvent.Kind == SQSTerminateKind: + // Successfully processed SQS message into a SQSTerminateKind interruption event + log.Debug().Msgf("Sending %s interruption event to the interruption channel", SQSTerminateKind) + m.InterruptionChan <- *interruptionEvent + } } - if len(messages) == 0 { - return nil, nil + + if len(messages) > 0 && failedEvents == len(messages) { + return fmt.Errorf("All of the waiting queue events could not be processed") } + return nil +} + +// processSQSMessage checks sqs for new messages and returns interruption events +func (m SQSMonitor) processSQSMessage(message *sqs.Message) (*monitor.InterruptionEvent, error) { event := EventBridgeEvent{} - err = json.Unmarshal([]byte(*messages[0].Body), &event) + err := json.Unmarshal([]byte(*message.Body), &event) if err != nil { return nil, err } @@ -93,17 +106,17 @@ func (m SQSMonitor) checkForSQSMessage() (*monitor.InterruptionEvent, error) { switch event.Source { case "aws.autoscaling": - interruptionEvent, err = m.asgTerminationToInterruptionEvent(event, messages) + interruptionEvent, err = m.asgTerminationToInterruptionEvent(event, message) if err != nil { return nil, err } case "aws.ec2": if event.DetailType == "EC2 Instance State-change Notification" { - interruptionEvent, err = m.ec2StateChangeToInterruptionEvent(event, messages) + interruptionEvent, err = m.ec2StateChangeToInterruptionEvent(event, message) } else if event.DetailType == "EC2 Spot Instance Interruption Warning" { - interruptionEvent, err = m.spotITNTerminationToInterruptionEvent(event, messages) + interruptionEvent, err = m.spotITNTerminationToInterruptionEvent(event, message) } else if event.DetailType == "EC2 Instance Rebalance Recommendation" { - interruptionEvent, err = m.rebalanceRecommendationToInterruptionEvent(event, messages) + interruptionEvent, err = m.rebalanceRecommendationToInterruptionEvent(event, message) } if err != nil { return nil, err @@ -140,7 +153,7 @@ func (m SQSMonitor) receiveQueueMessages(qURL string) ([]*sqs.Message, error) { aws.String(sqs.QueueAttributeNameAll), }, QueueUrl: &qURL, - MaxNumberOfMessages: aws.Int64(2), + MaxNumberOfMessages: aws.Int64(5), VisibilityTimeout: aws.Int64(20), // 20 seconds WaitTimeSeconds: aws.Int64(0), }) diff --git a/pkg/monitor/sqsevent/sqs-monitor_test.go b/pkg/monitor/sqsevent/sqs-monitor_test.go index b220406b..c5f231dc 100644 --- a/pkg/monitor/sqsevent/sqs-monitor_test.go +++ b/pkg/monitor/sqsevent/sqs-monitor_test.go @@ -104,71 +104,138 @@ func TestMonitor_Success(t *testing.T) { ec2Mock := h.MockedEC2{ DescribeInstancesResp: getDescribeInstancesResp(dnsNodeName), } - drainChan := make(chan monitor.InterruptionEvent) + drainChan := make(chan monitor.InterruptionEvent, 1) sqsMonitor := sqsevent.SQSMonitor{ SQS: sqsMock, EC2: ec2Mock, + ManagedAsgTag: "aws-node-termination-handler/managed", ASG: mockIsManagedTrue(nil), CheckIfManaged: true, QueueURL: "https://test-queue", InterruptionChan: drainChan, } - go func() { - result := <-drainChan + + err = sqsMonitor.Monitor() + h.Ok(t, err) + + select { + case result := <-drainChan: h.Equals(t, sqsevent.SQSTerminateKind, result.Kind) h.Equals(t, result.NodeName, dnsNodeName) h.Assert(t, result.PostDrainTask != nil, "PostDrainTask should have been set") h.Assert(t, result.PreDrainTask != nil, "PreDrainTask should have been set") - }() + err = result.PostDrainTask(result, node.Node{}) + h.Ok(t, err) + default: + h.Ok(t, fmt.Errorf("Expected an event to be generated")) + } - err = sqsMonitor.Monitor() - h.Ok(t, err) } } func TestMonitor_DrainTasks(t *testing.T) { - for _, event := range []sqsevent.EventBridgeEvent{spotItnEvent, asgLifecycleEvent, rebalanceRecommendationEvent} { + testEvents := []sqsevent.EventBridgeEvent{spotItnEvent, asgLifecycleEvent, rebalanceRecommendationEvent} + messages := make([]*sqs.Message, 0, len(testEvents)) + for _, event := range testEvents { msg, err := getSQSMessageFromEvent(event) h.Ok(t, err) - messages := []*sqs.Message{ - &msg, - } - sqsMock := h.MockedSQS{ - ReceiveMessageResp: sqs.ReceiveMessageOutput{Messages: messages}, - ReceiveMessageErr: nil, - DeleteMessageResp: sqs.DeleteMessageOutput{}, - } - dnsNodeName := "ip-10-0-0-157.us-east-2.compute.internal" - ec2Mock := h.MockedEC2{ - DescribeInstancesResp: getDescribeInstancesResp(dnsNodeName), - } - asgMock := h.MockedASG{ - CompleteLifecycleActionResp: autoscaling.CompleteLifecycleActionOutput{}, - } - drainChan := make(chan monitor.InterruptionEvent) + messages = append(messages, &msg) + } - sqsMonitor := sqsevent.SQSMonitor{ - SQS: sqsMock, - EC2: ec2Mock, - ASG: mockIsManagedTrue(&asgMock), - CheckIfManaged: true, - QueueURL: "https://test-queue", - InterruptionChan: drainChan, - } - go func() { + sqsMock := h.MockedSQS{ + ReceiveMessageResp: sqs.ReceiveMessageOutput{Messages: messages}, + ReceiveMessageErr: nil, + DeleteMessageResp: sqs.DeleteMessageOutput{}, + } + dnsNodeName := "ip-10-0-0-157.us-east-2.compute.internal" + ec2Mock := h.MockedEC2{ + DescribeInstancesResp: getDescribeInstancesResp(dnsNodeName), + } + asgMock := h.MockedASG{ + CompleteLifecycleActionResp: autoscaling.CompleteLifecycleActionOutput{}, + } + drainChan := make(chan monitor.InterruptionEvent, len(testEvents)) + + sqsMonitor := sqsevent.SQSMonitor{ + SQS: sqsMock, + EC2: ec2Mock, + ManagedAsgTag: "aws-node-termination-handler/managed", + ASG: mockIsManagedTrue(&asgMock), + CheckIfManaged: true, + QueueURL: "https://test-queue", + InterruptionChan: drainChan, + } + + err := sqsMonitor.Monitor() + h.Ok(t, err) + + for _, event := range testEvents { + t.Run(event.DetailType, func(st *testing.T) { result := <-drainChan + h.Equals(st, sqsevent.SQSTerminateKind, result.Kind) + h.Equals(st, result.NodeName, dnsNodeName) + h.Assert(st, result.PostDrainTask != nil, "PostDrainTask should have been set") + h.Assert(st, result.PreDrainTask != nil, "PreDrainTask should have been set") + err := result.PostDrainTask(result, node.Node{}) + h.Ok(st, err) + }) + } +} + +func TestMonitor_DrainTasks_Errors(t *testing.T) { + testEvents := []sqsevent.EventBridgeEvent{spotItnEvent, asgLifecycleEvent, sqsevent.EventBridgeEvent{}, rebalanceRecommendationEvent} + messages := make([]*sqs.Message, 0, len(testEvents)) + for _, event := range testEvents { + msg, err := getSQSMessageFromEvent(event) + h.Ok(t, err) + messages = append(messages, &msg) + } + + sqsMock := h.MockedSQS{ + ReceiveMessageResp: sqs.ReceiveMessageOutput{Messages: messages}, + ReceiveMessageErr: nil, + DeleteMessageResp: sqs.DeleteMessageOutput{}, + } + dnsNodeName := "ip-10-0-0-157.us-east-2.compute.internal" + ec2Mock := h.MockedEC2{ + DescribeInstancesResp: getDescribeInstancesResp(dnsNodeName), + } + asgMock := h.MockedASG{ + CompleteLifecycleActionResp: autoscaling.CompleteLifecycleActionOutput{}, + } + drainChan := make(chan monitor.InterruptionEvent, len(testEvents)) + + sqsMonitor := sqsevent.SQSMonitor{ + SQS: sqsMock, + EC2: ec2Mock, + ManagedAsgTag: "aws-node-termination-handler/managed", + ASG: mockIsManagedTrue(&asgMock), + CheckIfManaged: true, + QueueURL: "https://test-queue", + InterruptionChan: drainChan, + } + + err := sqsMonitor.Monitor() + h.Ok(t, err) + + count := 0 + done := false + for !done { + select { + case result := <-drainChan: + count++ h.Equals(t, sqsevent.SQSTerminateKind, result.Kind) h.Equals(t, result.NodeName, dnsNodeName) h.Assert(t, result.PostDrainTask != nil, "PostDrainTask should have been set") h.Assert(t, result.PreDrainTask != nil, "PreDrainTask should have been set") err := result.PostDrainTask(result, node.Node{}) h.Ok(t, err) - }() - - err = sqsMonitor.Monitor() - h.Ok(t, err) + default: + done = true + } } + h.Equals(t, count, 3) } func TestMonitor_DrainTasksASGFailure(t *testing.T) { @@ -190,28 +257,33 @@ func TestMonitor_DrainTasksASGFailure(t *testing.T) { CompleteLifecycleActionResp: autoscaling.CompleteLifecycleActionOutput{}, CompleteLifecycleActionErr: awserr.NewRequestFailure(aws.ErrMissingEndpoint, 500, "bad-request"), } - drainChan := make(chan monitor.InterruptionEvent) + drainChan := make(chan monitor.InterruptionEvent, 1) sqsMonitor := sqsevent.SQSMonitor{ SQS: sqsMock, EC2: ec2Mock, + ManagedAsgTag: "aws-node-termination-handler/managed", ASG: mockIsManagedTrue(&asgMock), CheckIfManaged: true, QueueURL: "https://test-queue", InterruptionChan: drainChan, } - go func() { - result := <-drainChan + + err = sqsMonitor.Monitor() + h.Ok(t, err) + + select { + case result := <-drainChan: h.Equals(t, sqsevent.SQSTerminateKind, result.Kind) h.Equals(t, result.NodeName, dnsNodeName) h.Assert(t, result.PostDrainTask != nil, "PostDrainTask should have been set") h.Assert(t, result.PreDrainTask != nil, "PreDrainTask should have been set") - err := result.PostDrainTask(result, node.Node{}) + err = result.PostDrainTask(result, node.Node{}) h.Nok(t, err) - }() + default: + h.Ok(t, fmt.Errorf("Expected to get an event with a failing post drain task")) + } - err = sqsMonitor.Monitor() - h.Ok(t, err) } func TestMonitor_Failure(t *testing.T) { @@ -226,20 +298,23 @@ func TestMonitor_Failure(t *testing.T) { ReceiveMessageResp: sqs.ReceiveMessageOutput{Messages: messages}, ReceiveMessageErr: nil, } - drainChan := make(chan monitor.InterruptionEvent) + drainChan := make(chan monitor.InterruptionEvent, 1) sqsMonitor := sqsevent.SQSMonitor{ SQS: sqsMock, QueueURL: "https://test-queue", InterruptionChan: drainChan, } - go func() { - result := <-drainChan - h.Equals(t, result, monitor.InterruptionEvent{}) - }() err = sqsMonitor.Monitor() h.Nok(t, err) + + select { + case <-drainChan: + h.Ok(t, fmt.Errorf("Expected no events")) + default: + h.Ok(t, nil) + } } } @@ -254,20 +329,24 @@ func TestMonitor_SQSFailure(t *testing.T) { ReceiveMessageResp: sqs.ReceiveMessageOutput{Messages: messages}, ReceiveMessageErr: fmt.Errorf("error"), } - drainChan := make(chan monitor.InterruptionEvent) + drainChan := make(chan monitor.InterruptionEvent, 1) sqsMonitor := sqsevent.SQSMonitor{ SQS: sqsMock, QueueURL: "https://test-queue", InterruptionChan: drainChan, } - go func() { - result := <-drainChan - h.Equals(t, result, monitor.InterruptionEvent{}) - }() err = sqsMonitor.Monitor() h.Nok(t, err) + + select { + case <-drainChan: + h.Ok(t, fmt.Errorf("Expected no events")) + default: + h.Ok(t, nil) + } + } } @@ -277,14 +356,27 @@ func TestMonitor_SQSNoMessages(t *testing.T) { ReceiveMessageResp: sqs.ReceiveMessageOutput{Messages: messages}, ReceiveMessageErr: nil, } + + drainChan := make(chan monitor.InterruptionEvent, 1) + sqsMonitor := sqsevent.SQSMonitor{ - SQS: sqsMock, - QueueURL: "https://test-queue", + SQS: sqsMock, + QueueURL: "https://test-queue", + InterruptionChan: drainChan, } err := sqsMonitor.Monitor() h.Ok(t, err) + + select { + case <-drainChan: + h.Ok(t, fmt.Errorf("Expected no events")) + default: + h.Ok(t, nil) + } + } +// Test processing invalid sqs message func TestMonitor_SQSJsonErr(t *testing.T) { replaceStr := `{"test":"test-string-to-replace"}` badJson := []*sqs.Message{{Body: aws.String(`?`)}} @@ -303,12 +395,22 @@ func TestMonitor_SQSJsonErr(t *testing.T) { ReceiveMessageResp: sqs.ReceiveMessageOutput{Messages: badMessages}, ReceiveMessageErr: nil, } + + drainChan := make(chan monitor.InterruptionEvent, 1) sqsMonitor := sqsevent.SQSMonitor{ - SQS: sqsMock, - QueueURL: "https://test-queue", + SQS: sqsMock, + QueueURL: "https://test-queue", + InterruptionChan: drainChan, } err := sqsMonitor.Monitor() h.Nok(t, err) + + select { + case <-drainChan: + h.Ok(t, fmt.Errorf("Expected no events")) + default: + h.Ok(t, nil) + } } } @@ -327,7 +429,7 @@ func TestMonitor_EC2Failure(t *testing.T) { DescribeInstancesResp: getDescribeInstancesResp(""), DescribeInstancesErr: fmt.Errorf("error"), } - drainChan := make(chan monitor.InterruptionEvent) + drainChan := make(chan monitor.InterruptionEvent, 1) sqsMonitor := sqsevent.SQSMonitor{ SQS: sqsMock, @@ -335,13 +437,16 @@ func TestMonitor_EC2Failure(t *testing.T) { QueueURL: "https://test-queue", InterruptionChan: drainChan, } - go func() { - result := <-drainChan - h.Equals(t, result, monitor.InterruptionEvent{}) - }() err = sqsMonitor.Monitor() h.Nok(t, err) + + select { + case <-drainChan: + h.Ok(t, fmt.Errorf("Expected no events")) + default: + h.Ok(t, nil) + } } } @@ -359,7 +464,7 @@ func TestMonitor_EC2NoInstances(t *testing.T) { ec2Mock := h.MockedEC2{ DescribeInstancesResp: ec2.DescribeInstancesOutput{}, } - drainChan := make(chan monitor.InterruptionEvent) + drainChan := make(chan monitor.InterruptionEvent, 1) sqsMonitor := sqsevent.SQSMonitor{ SQS: sqsMock, @@ -367,13 +472,16 @@ func TestMonitor_EC2NoInstances(t *testing.T) { QueueURL: "https://test-queue", InterruptionChan: drainChan, } - go func() { - result := <-drainChan - h.Equals(t, result, monitor.InterruptionEvent{}) - }() err = sqsMonitor.Monitor() h.Nok(t, err) + + select { + case <-drainChan: + h.Ok(t, fmt.Errorf("Expected no events")) + default: + h.Ok(t, nil) + } } } @@ -391,23 +499,27 @@ func TestMonitor_EC2NoDNSName(t *testing.T) { ec2Mock := h.MockedEC2{ DescribeInstancesResp: getDescribeInstancesResp(""), } - drainChan := make(chan monitor.InterruptionEvent) + drainChan := make(chan monitor.InterruptionEvent, 1) sqsMonitor := sqsevent.SQSMonitor{ SQS: sqsMock, EC2: ec2Mock, + ManagedAsgTag: "aws-node-termination-handler/managed", ASG: mockIsManagedTrue(nil), CheckIfManaged: true, QueueURL: "https://test-queue", InterruptionChan: drainChan, } - go func() { - result := <-drainChan - h.Equals(t, result.Kind, sqsevent.SQSTerminateKind) - }() err = sqsMonitor.Monitor() h.Ok(t, err) + + select { + case <-drainChan: + h.Ok(t, fmt.Errorf("Expected no events")) + default: + h.Ok(t, nil) + } } func TestMonitor_EC2NoDNSNameOnTerminatedInstance(t *testing.T) { @@ -427,23 +539,27 @@ func TestMonitor_EC2NoDNSNameOnTerminatedInstance(t *testing.T) { ec2Mock.DescribeInstancesResp.Reservations[0].Instances[0].State = &ec2.InstanceState{ Name: aws.String("running"), } - drainChan := make(chan monitor.InterruptionEvent) + drainChan := make(chan monitor.InterruptionEvent, 1) sqsMonitor := sqsevent.SQSMonitor{ SQS: sqsMock, EC2: ec2Mock, + ManagedAsgTag: "aws-node-termination-handler/managed", ASG: mockIsManagedTrue(nil), CheckIfManaged: true, QueueURL: "https://test-queue", InterruptionChan: drainChan, } - go func() { - result := <-drainChan - h.Equals(t, result.Kind, sqsevent.SQSTerminateKind) - }() err = sqsMonitor.Monitor() h.Nok(t, err) + + select { + case <-drainChan: + h.Ok(t, fmt.Errorf("Expected no events")) + default: + h.Ok(t, nil) + } } func TestMonitor_SQSDeleteFailure(t *testing.T) { @@ -461,23 +577,27 @@ func TestMonitor_SQSDeleteFailure(t *testing.T) { ec2Mock := h.MockedEC2{ DescribeInstancesResp: getDescribeInstancesResp(""), } - drainChan := make(chan monitor.InterruptionEvent) + drainChan := make(chan monitor.InterruptionEvent, 1) sqsMonitor := sqsevent.SQSMonitor{ SQS: sqsMock, EC2: ec2Mock, + ManagedAsgTag: "aws-node-termination-handler/managed", ASG: mockIsManagedTrue(nil), CheckIfManaged: true, QueueURL: "https://test-queue", InterruptionChan: drainChan, } - go func() { - result := <-drainChan - h.Equals(t, result.Kind, sqsevent.SQSTerminateKind) - }() err = sqsMonitor.Monitor() - h.Ok(t, err) + h.Nok(t, err) + + select { + case <-drainChan: + h.Ok(t, fmt.Errorf("Expected no events")) + default: + h.Ok(t, nil) + } } func TestMonitor_InstanceNotManaged(t *testing.T) { @@ -496,16 +616,26 @@ func TestMonitor_InstanceNotManaged(t *testing.T) { DescribeInstancesResp: getDescribeInstancesResp(dnsNodeName), } + drainChan := make(chan monitor.InterruptionEvent, 1) + sqsMonitor := sqsevent.SQSMonitor{ - SQS: sqsMock, - EC2: ec2Mock, - ASG: mockIsManagedFalse(nil), - CheckIfManaged: true, - QueueURL: "https://test-queue", + SQS: sqsMock, + EC2: ec2Mock, + ASG: mockIsManagedFalse(nil), + CheckIfManaged: true, + QueueURL: "https://test-queue", + InterruptionChan: drainChan, } err = sqsMonitor.Monitor() h.Ok(t, err) + + select { + case <-drainChan: + h.Ok(t, fmt.Errorf("Expected no events")) + default: + h.Ok(t, nil) + } } } @@ -525,16 +655,26 @@ func TestMonitor_InstanceManagedErr(t *testing.T) { DescribeInstancesResp: getDescribeInstancesResp(dnsNodeName), } + drainChan := make(chan monitor.InterruptionEvent, 1) + sqsMonitor := sqsevent.SQSMonitor{ - SQS: sqsMock, - EC2: ec2Mock, - ASG: mockIsManagedErr(nil), - CheckIfManaged: true, - QueueURL: "https://test-queue", + SQS: sqsMock, + EC2: ec2Mock, + ASG: mockIsManagedErr(nil), + CheckIfManaged: true, + QueueURL: "https://test-queue", + InterruptionChan: drainChan, } err = sqsMonitor.Monitor() h.Nok(t, err) + + select { + case <-drainChan: + h.Ok(t, fmt.Errorf("Expected no events")) + default: + h.Ok(t, nil) + } } }