Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pkg/interruptionevent/draincordon/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,9 @@ func (h *Handler) HandleEvent(drainEvent *monitor.InterruptionEvent) error {
}

if err != nil {
if drainEvent.CancelDrainTask != nil {
h.commonHandler.RunCancelDrainTask(nodeName, drainEvent)
}
h.commonHandler.InterruptionEventStore.CancelInterruptionEvent(drainEvent.EventID)
} else {
h.commonHandler.InterruptionEventStore.MarkAllAsProcessed(nodeName)
Expand Down
10 changes: 10 additions & 0 deletions pkg/interruptionevent/internal/common/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,16 @@ func (h *Handler) RunPreDrainTask(nodeName string, drainEvent *monitor.Interrupt
h.Metrics.NodeActionsInc("pre-drain", nodeName, drainEvent.EventID, err)
}

// RunCancelDrainTask runs the event's CancelDrainTask and reports the outcome
// through the recorder: a Warning event carrying the error text when the task
// fails, a Normal event when it succeeds.
//
// CancelDrainTask is an optional field, so a nil event or an event without a
// cancel task is treated as a no-op instead of panicking on a nil function
// call.
func (h *Handler) RunCancelDrainTask(nodeName string, drainEvent *monitor.InterruptionEvent) {
	if drainEvent == nil || drainEvent.CancelDrainTask == nil {
		return
	}

	if err := drainEvent.CancelDrainTask(*drainEvent, h.Node); err != nil {
		log.Err(err).Msg("There was a problem executing the early exit task")
		h.Recorder.Emit(nodeName, observability.Warning, observability.CancelDrainErrReason, observability.CancelDrainErrMsgFmt, err.Error())
		return
	}

	h.Recorder.Emit(nodeName, observability.Normal, observability.CancelDrainReason, observability.CancelDrainMsg)
}

func (h *Handler) RunPostDrainTask(nodeName string, drainEvent *monitor.InterruptionEvent) {
err := drainEvent.PostDrainTask(*drainEvent, h.Node)
if err != nil {
Expand Down
17 changes: 15 additions & 2 deletions pkg/monitor/sqsevent/asg-lifecycle-event.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ func (m SQSMonitor) asgTerminationToInterruptionEvent(event *EventBridgeEvent, m
}

stopHeartbeatCh := make(chan struct{})
cancelHeartbeatCh := make(chan struct{})

interruptionEvent.PostDrainTask = func(interruptionEvent monitor.InterruptionEvent, _ node.Node) error {

Expand All @@ -111,13 +112,18 @@ func (m SQSMonitor) asgTerminationToInterruptionEvent(event *EventBridgeEvent, m
close(stopHeartbeatCh)
return m.deleteMessage(message)
}

// CancelDrainTask closes cancelHeartbeatCh, the channel handed to
// SendHeartbeats as its cancel signal, so a running heartbeat loop returns.
// Both arguments are ignored and the task always reports success.
// It must run at most once per event: closing an already-closed channel panics.
interruptionEvent.CancelDrainTask = func(_ monitor.InterruptionEvent, _ node.Node) error {
close(cancelHeartbeatCh)
return nil
}

interruptionEvent.PreDrainTask = func(interruptionEvent monitor.InterruptionEvent, n node.Node) error {
nthConfig := n.GetNthConfig()
// If only HeartbeatInterval is set, HeartbeatUntil will default to 172800.
if nthConfig.HeartbeatInterval != -1 && nthConfig.HeartbeatUntil != -1 {
go m.checkHeartbeatTimeout(nthConfig.HeartbeatInterval, lifecycleDetail)
go m.SendHeartbeats(nthConfig.HeartbeatInterval, nthConfig.HeartbeatUntil, lifecycleDetail, stopHeartbeatCh)
go m.SendHeartbeats(nthConfig.HeartbeatInterval, nthConfig.HeartbeatUntil, lifecycleDetail, stopHeartbeatCh, cancelHeartbeatCh)
}

err := n.TaintASGLifecycleTermination(interruptionEvent.NodeName, interruptionEvent.EventID)
Expand Down Expand Up @@ -167,13 +173,20 @@ func (m SQSMonitor) checkHeartbeatTimeout(heartbeatInterval int, lifecycleDetail
}

// Issue lifecycle heartbeats to reset the heartbeat timeout timer in ASG
func (m SQSMonitor) SendHeartbeats(heartbeatInterval int, heartbeatUntil int, lifecycleDetail *LifecycleDetail, stopCh <-chan struct{}) {
func (m SQSMonitor) SendHeartbeats(heartbeatInterval int, heartbeatUntil int, lifecycleDetail *LifecycleDetail, stopCh <-chan struct{}, cancelCh <-chan struct{}) {
ticker := time.NewTicker(time.Duration(heartbeatInterval) * time.Second)
defer ticker.Stop()
timeout := time.After(time.Duration(heartbeatUntil) * time.Second)

for {
select {
case <-cancelCh:
log.Info().Str("asgName", lifecycleDetail.AutoScalingGroupName).
Str("lifecycleHookName", lifecycleDetail.LifecycleHookName).
Str("lifecycleActionToken", lifecycleDetail.LifecycleActionToken).
Str("instanceID", lifecycleDetail.EC2InstanceID).
Msg("Failed to cordon and drain the node, stopping heartbeat")
return
case <-stopCh:
log.Info().Str("asgName", lifecycleDetail.AutoScalingGroupName).
Str("lifecycleHookName", lifecycleDetail.LifecycleHookName).
Expand Down
31 changes: 26 additions & 5 deletions pkg/monitor/sqsevent/sqs-monitor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ func TestMonitor_EventBridgeSuccess(t *testing.T) {
h.Equals(t, result.NodeName, dnsNodeName)
h.Assert(t, result.PostDrainTask != nil, "PostDrainTask should have been set")
h.Assert(t, result.PreDrainTask != nil, "PreDrainTask should have been set")
if event.ID == asgLifecycleEvent.ID { h.Assert(t, result.CancelDrainTask != nil, "CancelDrainTask should have been set") }
err = result.PostDrainTask(result, node.Node{})
h.Ok(t, err)
default:
Expand Down Expand Up @@ -273,6 +274,7 @@ func TestMonitor_AsgDirectToSqsSuccess(t *testing.T) {
h.Equals(t, result.NodeName, dnsNodeName)
h.Assert(t, result.PostDrainTask != nil, "PostDrainTask should have been set")
h.Assert(t, result.PreDrainTask != nil, "PreDrainTask should have been set")
h.Assert(t, result.CancelDrainTask != nil, "CancelDrainTask should have been set")
err = result.PostDrainTask(result, node.Node{})
h.Ok(t, err)
default:
Expand Down Expand Up @@ -365,6 +367,7 @@ func TestMonitor_DrainTasks(t *testing.T) {
h.Equals(st, result.NodeName, dnsNodeName)
h.Assert(st, result.PostDrainTask != nil, "PostDrainTask should have been set")
h.Assert(st, result.PreDrainTask != nil, "PreDrainTask should have been set")
if event.ID == asgLifecycleEvent.ID { h.Assert(t, result.CancelDrainTask != nil, "CancelDrainTask should have been set") }
err := result.PostDrainTask(result, node.Node{})
h.Ok(st, err)
})
Expand Down Expand Up @@ -466,6 +469,7 @@ func TestMonitor_DrainTasks_Errors(t *testing.T) {
h.Equals(t, result.NodeName, dnsNodeName)
h.Assert(t, result.PostDrainTask != nil, "PostDrainTask should have been set")
h.Assert(t, result.PreDrainTask != nil, "PreDrainTask should have been set")
if i == 1 { h.Assert(t, result.CancelDrainTask != nil, "CancelDrainTask should have been set") }
err := result.PostDrainTask(result, node.Node{})
h.Ok(t, err)
default:
Expand Down Expand Up @@ -909,32 +913,39 @@ func TestMonitor_InstanceNotManaged(t *testing.T) {
}

// TestSendHeartbeats_EarlyClosure runs heartbeatTestHelper with no injected
// heartbeat error, a 3500 ms sleep, a 1 s heartbeat interval and a 5 s
// heartbeatUntil, without cancelling the drain, and expects exactly 3
// heartbeats to have been counted.
// NOTE(review): presumably the post-drain task stops the loop before
// heartbeatUntil elapses — confirm against heartbeatTestHelper.
func TestSendHeartbeats_EarlyClosure(t *testing.T) {
err := heartbeatTestHelper(nil, 3500, 1, 5)
err := heartbeatTestHelper(nil, 3500, 1, 5, false)
h.Ok(t, err)
h.Assert(t, h.HeartbeatCallCount == 3, "3 Heartbeat Expected, got %d", h.HeartbeatCallCount)
}

// TestSendHeartbeats_HeartbeatUntilExpire runs heartbeatTestHelper with no
// injected heartbeat error, an 8000 ms sleep, a 1 s heartbeat interval and a
// 5 s heartbeatUntil, without cancelling the drain, and expects exactly 5
// heartbeats: the sleep outlasts heartbeatUntil, so the count is bounded by
// the heartbeatUntil timeout rather than by the sleep.
func TestSendHeartbeats_HeartbeatUntilExpire(t *testing.T) {
err := heartbeatTestHelper(nil, 8000, 1, 5)
err := heartbeatTestHelper(nil, 8000, 1, 5, false)
h.Ok(t, err)
h.Assert(t, h.HeartbeatCallCount == 5, "5 Heartbeat Expected, got %d", h.HeartbeatCallCount)
}

// TestSendHeartbeats_ErrThrottlingASG passes a "Throttling" AWS error to
// heartbeatTestHelper as the heartbeat error, with an 8000 ms sleep, a 1 s
// heartbeat interval and a 6 s heartbeatUntil, and expects 6 heartbeat calls.
// NOTE(review): this suggests throttling errors do not stop the heartbeat
// loop — confirm against SendHeartbeats.
func TestSendHeartbeats_ErrThrottlingASG(t *testing.T) {
RecordLifecycleActionHeartbeatErr := awserr.New("Throttling", "Rate exceeded", nil)
err := heartbeatTestHelper(RecordLifecycleActionHeartbeatErr, 8000, 1, 6)
err := heartbeatTestHelper(RecordLifecycleActionHeartbeatErr, 8000, 1, 6, false)
h.Ok(t, err)
h.Assert(t, h.HeartbeatCallCount == 6, "6 Heartbeat Expected, got %d", h.HeartbeatCallCount)
}

// TestSendHeartbeats_ErrInvalidTarget passes a "ValidationError" AWS error to
// heartbeatTestHelper as the heartbeat error, with a 6000 ms sleep, a 1 s
// heartbeat interval and a 4 s heartbeatUntil, and expects only 1 heartbeat
// call.
// NOTE(review): this suggests the loop gives up after the first validation
// failure — confirm against SendHeartbeats.
func TestSendHeartbeats_ErrInvalidTarget(t *testing.T) {
RecordLifecycleActionHeartbeatErr := awserr.New("ValidationError", "No active Lifecycle Action found", nil)
err := heartbeatTestHelper(RecordLifecycleActionHeartbeatErr, 6000, 1, 4)
err := heartbeatTestHelper(RecordLifecycleActionHeartbeatErr, 6000, 1, 4, false)
h.Ok(t, err)
h.Assert(t, h.HeartbeatCallCount == 1, "1 Heartbeat Expected, got %d", h.HeartbeatCallCount)
}

func heartbeatTestHelper(RecordLifecycleActionHeartbeatErr error, sleepMilliSeconds int, heartbeatInterval int, heartbeatUntil int) error {

func TestSendHeartbeats_CancelHeartbeat(t *testing.T) {
err := heartbeatTestHelper(nil, 6000, 1, 4, true)
h.Ok(t, err)
h.Assert(t, h.HeartbeatCallCount == 2, "2 Heartbeat Expected, got %d", h.HeartbeatCallCount)
}

func heartbeatTestHelper(RecordLifecycleActionHeartbeatErr error, sleepMilliSeconds int, heartbeatInterval int, heartbeatUntil int, cancelDrain bool) error {
h.HeartbeatCallCount = 0

msg, err := getSQSMessageFromEvent(asgLifecycleEvent)
Expand Down Expand Up @@ -986,6 +997,16 @@ func heartbeatTestHelper(RecordLifecycleActionHeartbeatErr error, sleepMilliSeco
return err
}

// When the caller asks for a cancelled drain, wait 2.1 s and then run the
// event's cancel task, failing the helper if the task is missing or errors.
if cancelDrain {
	if result.CancelDrainTask == nil {
		return fmt.Errorf("CancelDrainTask should have been set")
	}
	time.Sleep(2100 * time.Millisecond)
	if err := result.CancelDrainTask(result, *testNode); err != nil {
		return err
	}
}

if result.PostDrainTask == nil {
return fmt.Errorf("PostDrainTask should have been set")
}
Expand Down
1 change: 1 addition & 0 deletions pkg/monitor/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ type InterruptionEvent struct {
InProgress bool
PreDrainTask DrainTask `json:"-"`
PostDrainTask DrainTask `json:"-"`
CancelDrainTask DrainTask `json:"-"`
}

// TimeUntilEvent returns the duration until the event start time
Expand Down
4 changes: 4 additions & 0 deletions pkg/observability/k8s-events.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ const (
PostDrainErrMsgFmt = "There was a problem executing the post-drain task: %s"
PostDrainReason = "PostDrain"
PostDrainMsg = "Post-drain task successfully executed"
CancelDrainErrReason = "CancelDrainError"
CancelDrainErrMsgFmt = "There was a problem executing the early exit task: %s"
CancelDrainReason = "CancelDrain"
CancelDrainMsg = "Early exit task successfully executed"
)

// Interruption event reasons
Expand Down
Loading