From ac8f036f7c829472a1f59c0145df4974e6c707ab Mon Sep 17 00:00:00 2001 From: Shand Date: Thu, 26 Jun 2025 16:44:08 -0700 Subject: [PATCH] Adding CancelDrainTask to ASG Termination Events to close generated heartbeat when failing to process event --- pkg/interruptionevent/draincordon/handler.go | 3 ++ .../internal/common/handler.go | 10 ++++++ pkg/monitor/sqsevent/asg-lifecycle-event.go | 17 ++++++++-- pkg/monitor/sqsevent/sqs-monitor_test.go | 31 ++++++++++++++++--- pkg/monitor/types.go | 1 + pkg/observability/k8s-events.go | 4 +++ 6 files changed, 59 insertions(+), 7 deletions(-) diff --git a/pkg/interruptionevent/draincordon/handler.go b/pkg/interruptionevent/draincordon/handler.go index 9fac6b07..a36388bf 100644 --- a/pkg/interruptionevent/draincordon/handler.go +++ b/pkg/interruptionevent/draincordon/handler.go @@ -111,6 +111,9 @@ func (h *Handler) HandleEvent(drainEvent *monitor.InterruptionEvent) error { } if err != nil { + if drainEvent.CancelDrainTask != nil { + h.commonHandler.RunCancelDrainTask(nodeName, drainEvent) + } h.commonHandler.InterruptionEventStore.CancelInterruptionEvent(drainEvent.EventID) } else { h.commonHandler.InterruptionEventStore.MarkAllAsProcessed(nodeName) diff --git a/pkg/interruptionevent/internal/common/handler.go b/pkg/interruptionevent/internal/common/handler.go index 0c58366a..023289da 100644 --- a/pkg/interruptionevent/internal/common/handler.go +++ b/pkg/interruptionevent/internal/common/handler.go @@ -55,6 +55,16 @@ func (h *Handler) RunPreDrainTask(nodeName string, drainEvent *monitor.Interrupt h.Metrics.NodeActionsInc("pre-drain", nodeName, drainEvent.EventID, err) } +func (h *Handler) RunCancelDrainTask(nodeName string, drainEvent *monitor.InterruptionEvent) { + err := drainEvent.CancelDrainTask(*drainEvent, h.Node) + if err != nil { + log.Err(err).Msg("There was a problem executing the early exit task") + h.Recorder.Emit(nodeName, observability.Warning, observability.CancelDrainErrReason, observability.CancelDrainErrMsgFmt, err.Error()) + } else { + h.Recorder.Emit(nodeName, observability.Normal, observability.CancelDrainReason, observability.CancelDrainMsg) + } +} + func (h *Handler) RunPostDrainTask(nodeName string, drainEvent *monitor.InterruptionEvent) { err := drainEvent.PostDrainTask(*drainEvent, h.Node) if err != nil { diff --git a/pkg/monitor/sqsevent/asg-lifecycle-event.go b/pkg/monitor/sqsevent/asg-lifecycle-event.go index a442b824..6fd4d852 100644 --- a/pkg/monitor/sqsevent/asg-lifecycle-event.go +++ b/pkg/monitor/sqsevent/asg-lifecycle-event.go @@ -99,6 +99,7 @@ func (m SQSMonitor) asgTerminationToInterruptionEvent(event *EventBridgeEvent, m } stopHeartbeatCh := make(chan struct{}) + cancelHeartbeatCh := make(chan struct{}) interruptionEvent.PostDrainTask = func(interruptionEvent monitor.InterruptionEvent, _ node.Node) error { @@ -111,13 +112,18 @@ func (m SQSMonitor) asgTerminationToInterruptionEvent(event *EventBridgeEvent, m close(stopHeartbeatCh) return m.deleteMessage(message) } + + interruptionEvent.CancelDrainTask = func(_ monitor.InterruptionEvent, _ node.Node) error { + close(cancelHeartbeatCh) + return nil + } interruptionEvent.PreDrainTask = func(interruptionEvent monitor.InterruptionEvent, n node.Node) error { nthConfig := n.GetNthConfig() // If only HeartbeatInterval is set, HeartbeatUntil will default to 172800. if nthConfig.HeartbeatInterval != -1 && nthConfig.HeartbeatUntil != -1 { go m.checkHeartbeatTimeout(nthConfig.HeartbeatInterval, lifecycleDetail) - go m.SendHeartbeats(nthConfig.HeartbeatInterval, nthConfig.HeartbeatUntil, lifecycleDetail, stopHeartbeatCh) + go m.SendHeartbeats(nthConfig.HeartbeatInterval, nthConfig.HeartbeatUntil, lifecycleDetail, stopHeartbeatCh, cancelHeartbeatCh) } err := n.TaintASGLifecycleTermination(interruptionEvent.NodeName, interruptionEvent.EventID) @@ -167,13 +173,20 @@ func (m SQSMonitor) checkHeartbeatTimeout(heartbeatInterval int, lifecycleDetail } // Issue lifecycle heartbeats to reset the heartbeat timeout timer in ASG -func (m SQSMonitor) SendHeartbeats(heartbeatInterval int, heartbeatUntil int, lifecycleDetail *LifecycleDetail, stopCh <-chan struct{}) { +func (m SQSMonitor) SendHeartbeats(heartbeatInterval int, heartbeatUntil int, lifecycleDetail *LifecycleDetail, stopCh <-chan struct{}, cancelCh <-chan struct{}) { ticker := time.NewTicker(time.Duration(heartbeatInterval) * time.Second) defer ticker.Stop() timeout := time.After(time.Duration(heartbeatUntil) * time.Second) for { select { + case <-cancelCh: + log.Info().Str("asgName", lifecycleDetail.AutoScalingGroupName). + Str("lifecycleHookName", lifecycleDetail.LifecycleHookName). + Str("lifecycleActionToken", lifecycleDetail.LifecycleActionToken). + Str("instanceID", lifecycleDetail.EC2InstanceID). + Msg("Failed to cordon and drain the node, stopping heartbeat") + return case <-stopCh: log.Info().Str("asgName", lifecycleDetail.AutoScalingGroupName). Str("lifecycleHookName", lifecycleDetail.LifecycleHookName). diff --git a/pkg/monitor/sqsevent/sqs-monitor_test.go b/pkg/monitor/sqsevent/sqs-monitor_test.go index 1884dddc..959a1516 100644 --- a/pkg/monitor/sqsevent/sqs-monitor_test.go +++ b/pkg/monitor/sqsevent/sqs-monitor_test.go @@ -188,6 +188,7 @@ func TestMonitor_EventBridgeSuccess(t *testing.T) { h.Equals(t, result.NodeName, dnsNodeName) h.Assert(t, result.PostDrainTask != nil, "PostDrainTask should have been set") h.Assert(t, result.PreDrainTask != nil, "PreDrainTask should have been set") + if event.ID == asgLifecycleEvent.ID { h.Assert(t, result.CancelDrainTask != nil, "CancelDrainTask should have been set") } err = result.PostDrainTask(result, node.Node{}) h.Ok(t, err) default: @@ -273,6 +274,7 @@ func TestMonitor_AsgDirectToSqsSuccess(t *testing.T) { h.Equals(t, result.NodeName, dnsNodeName) h.Assert(t, result.PostDrainTask != nil, "PostDrainTask should have been set") h.Assert(t, result.PreDrainTask != nil, "PreDrainTask should have been set") + h.Assert(t, result.CancelDrainTask != nil, "CancelDrainTask should have been set") err = result.PostDrainTask(result, node.Node{}) h.Ok(t, err) default: @@ -365,6 +367,7 @@ func TestMonitor_DrainTasks(t *testing.T) { h.Equals(st, result.NodeName, dnsNodeName) h.Assert(st, result.PostDrainTask != nil, "PostDrainTask should have been set") h.Assert(st, result.PreDrainTask != nil, "PreDrainTask should have been set") + if event.ID == asgLifecycleEvent.ID { h.Assert(t, result.CancelDrainTask != nil, "CancelDrainTask should have been set") } err := result.PostDrainTask(result, node.Node{}) h.Ok(st, err) }) @@ -466,6 +469,7 @@ func TestMonitor_DrainTasks_Errors(t *testing.T) { h.Equals(t, result.NodeName, dnsNodeName) h.Assert(t, result.PostDrainTask != nil, "PostDrainTask should have been set") h.Assert(t, result.PreDrainTask != nil, "PreDrainTask should have been set") + if i == 1 { h.Assert(t, result.CancelDrainTask != nil, "CancelDrainTask should have been set") } err := result.PostDrainTask(result, node.Node{}) h.Ok(t, err) default: @@ -909,32 +913,39 @@ func TestMonitor_InstanceNotManaged(t *testing.T) { } func TestSendHeartbeats_EarlyClosure(t *testing.T) { - err := heartbeatTestHelper(nil, 3500, 1, 5) + err := heartbeatTestHelper(nil, 3500, 1, 5, false) h.Ok(t, err) h.Assert(t, h.HeartbeatCallCount == 3, "3 Heartbeat Expected, got %d", h.HeartbeatCallCount) } func TestSendHeartbeats_HeartbeatUntilExpire(t *testing.T) { - err := heartbeatTestHelper(nil, 8000, 1, 5) + err := heartbeatTestHelper(nil, 8000, 1, 5, false) h.Ok(t, err) h.Assert(t, h.HeartbeatCallCount == 5, "5 Heartbeat Expected, got %d", h.HeartbeatCallCount) } func TestSendHeartbeats_ErrThrottlingASG(t *testing.T) { RecordLifecycleActionHeartbeatErr := awserr.New("Throttling", "Rate exceeded", nil) - err := heartbeatTestHelper(RecordLifecycleActionHeartbeatErr, 8000, 1, 6) + err := heartbeatTestHelper(RecordLifecycleActionHeartbeatErr, 8000, 1, 6, false) h.Ok(t, err) h.Assert(t, h.HeartbeatCallCount == 6, "6 Heartbeat Expected, got %d", h.HeartbeatCallCount) } func TestSendHeartbeats_ErrInvalidTarget(t *testing.T) { RecordLifecycleActionHeartbeatErr := awserr.New("ValidationError", "No active Lifecycle Action found", nil) - err := heartbeatTestHelper(RecordLifecycleActionHeartbeatErr, 6000, 1, 4) + err := heartbeatTestHelper(RecordLifecycleActionHeartbeatErr, 6000, 1, 4, false) h.Ok(t, err) h.Assert(t, h.HeartbeatCallCount == 1, "1 Heartbeat Expected, got %d", h.HeartbeatCallCount) } -func heartbeatTestHelper(RecordLifecycleActionHeartbeatErr error, sleepMilliSeconds int, heartbeatInterval int, heartbeatUntil int) error { + +func TestSendHeartbeats_CancelHeartbeat(t *testing.T) { + err := heartbeatTestHelper(nil, 6000, 1, 4, true) + h.Ok(t, err) + h.Assert(t, h.HeartbeatCallCount == 2, "2 Heartbeat Expected, got %d", h.HeartbeatCallCount) +} + +func heartbeatTestHelper(RecordLifecycleActionHeartbeatErr error, sleepMilliSeconds int, heartbeatInterval int, heartbeatUntil int, cancelDrain bool) error { h.HeartbeatCallCount = 0 msg, err := getSQSMessageFromEvent(asgLifecycleEvent) @@ -986,6 +997,16 @@ func heartbeatTestHelper(RecordLifecycleActionHeartbeatErr error, sleepMilliSeco return err } + if cancelDrain == true { + if result.CancelDrainTask == nil { + return fmt.Errorf("CancelDrainTask should have been set") + } + time.Sleep(2100 * time.Millisecond) + if err := result.CancelDrainTask(result, *testNode); err != nil { + return err + } + } + if result.PostDrainTask == nil { return fmt.Errorf("PostDrainTask should have been set") } diff --git a/pkg/monitor/types.go b/pkg/monitor/types.go index 6367868c..522ba5eb 100644 --- a/pkg/monitor/types.go +++ b/pkg/monitor/types.go @@ -61,6 +61,7 @@ type InterruptionEvent struct { InProgress bool PreDrainTask DrainTask `json:"-"` PostDrainTask DrainTask `json:"-"` + CancelDrainTask DrainTask `json:"-"` } // TimeUntilEvent returns the duration until the event start time diff --git a/pkg/observability/k8s-events.go b/pkg/observability/k8s-events.go index 5ee893ad..5e6f224c 100644 --- a/pkg/observability/k8s-events.go +++ b/pkg/observability/k8s-events.go @@ -56,6 +56,10 @@ const ( PostDrainErrMsgFmt = "There was a problem executing the post-drain task: %s" PostDrainReason = "PostDrain" PostDrainMsg = "Post-drain task successfully executed" + CancelDrainErrReason = "CancelDrainError" + CancelDrainErrMsgFmt = "There was a problem executing the early exit task: %s" + CancelDrainReason = "CancelDrain" + CancelDrainMsg = "Early exit task successfully executed" ) // Interruption event reasons