
Commit 38ba39f

[jobs] backoff cluster teardown (#4562)
* move terminate_cluster into utils
* [jobs] backoff cluster teardown
* use terminate_cluster in update_managed_job_status
* fix unit test
* add details on backoff
1 parent f8494b5 commit 38ba39f

File tree

4 files changed: +52 -46 lines changed


sky/jobs/controller.py (+3 -3)

@@ -243,7 +243,7 @@ def _run_one_task(self, task_id: int, task: 'sky.Task') -> bool:
             self._download_log_and_stream(task_id, handle)
             # Only clean up the cluster, not the storages, because tasks may
             # share storages.
-            recovery_strategy.terminate_cluster(cluster_name=cluster_name)
+            managed_job_utils.terminate_cluster(cluster_name=cluster_name)
             return True

             # For single-node jobs, non-terminated job_status indicates a
@@ -342,7 +342,7 @@ def _run_one_task(self, task_id: int, task: 'sky.Task') -> bool:
             # those clusters again may fail.
             logger.info('Cleaning up the preempted or failed cluster'
                         '...')
-            recovery_strategy.terminate_cluster(cluster_name)
+            managed_job_utils.terminate_cluster(cluster_name)

             # Try to recover the managed jobs, when the cluster is preempted or
             # failed or the job status is failed to be fetched.
@@ -478,7 +478,7 @@ def _cleanup(job_id: int, dag_yaml: str):
         assert task.name is not None, task
         cluster_name = managed_job_utils.generate_managed_job_cluster_name(
             task.name, job_id)
-        recovery_strategy.terminate_cluster(cluster_name)
+        managed_job_utils.terminate_cluster(cluster_name)
         # Clean up Storages with persistent=False.
         # TODO(zhwu): this assumes the specific backend.
         backend = cloud_vm_ray_backend.CloudVmRayBackend()

sky/jobs/recovery_strategy.py (+4 -28)

@@ -43,30 +43,6 @@
 _AUTODOWN_MINUTES = 5


-def terminate_cluster(cluster_name: str, max_retry: int = 3) -> None:
-    """Terminate the cluster."""
-    retry_cnt = 0
-    while True:
-        try:
-            usage_lib.messages.usage.set_internal()
-            sky.down(cluster_name)
-            return
-        except exceptions.ClusterDoesNotExist:
-            # The cluster is already down.
-            logger.debug(f'The cluster {cluster_name} is already down.')
-            return
-        except Exception as e:  # pylint: disable=broad-except
-            retry_cnt += 1
-            if retry_cnt >= max_retry:
-                raise RuntimeError(
-                    f'Failed to terminate the cluster {cluster_name}.') from e
-            logger.error(
-                f'Failed to terminate the cluster {cluster_name}. Retrying.'
-                f'Details: {common_utils.format_exception(e)}')
-            with ux_utils.enable_traceback():
-                logger.error(f' Traceback: {traceback.format_exc()}')
-
-
 class StrategyExecutor:
     """Handle the launching, recovery and termination of managed job clusters"""

@@ -193,7 +169,7 @@ def _try_cancel_all_jobs(self):
                 f'{common_utils.format_exception(e)}\n'
                 'Terminating the cluster explicitly to ensure no '
                 'remaining job process interferes with recovery.')
-            terminate_cluster(self.cluster_name)
+            managed_job_utils.terminate_cluster(self.cluster_name)

     def _wait_until_job_starts_on_cluster(self) -> Optional[float]:
         """Wait for MAX_JOB_CHECKING_RETRY times until job starts on the cluster
@@ -380,7 +356,7 @@ def _launch(self,

             # If we get here, the launch did not succeed. Tear down the
             # cluster and retry.
-            terminate_cluster(self.cluster_name)
+            managed_job_utils.terminate_cluster(self.cluster_name)
             if max_retry is not None and retry_cnt >= max_retry:
                 # Retry forever if max_retry is None.
                 if raise_on_failure:
@@ -473,7 +449,7 @@ def recover(self) -> float:
         # Step 2
         logger.debug('Terminating unhealthy cluster and reset cloud '
                      'region.')
-        terminate_cluster(self.cluster_name)
+        managed_job_utils.terminate_cluster(self.cluster_name)

         # Step 3
         logger.debug('Relaunch the cluster without constraining to prior '
@@ -531,7 +507,7 @@ def recover(self) -> float:

         # Step 1
         logger.debug('Terminating unhealthy cluster and reset cloud region.')
-        terminate_cluster(self.cluster_name)
+        managed_job_utils.terminate_cluster(self.cluster_name)

         # Step 2
         logger.debug('Relaunch the cluster skipping the previously launched '

sky/jobs/utils.py (+42 -12)

@@ -13,6 +13,7 @@
 import shutil
 import textwrap
 import time
+import traceback
 import typing
 from typing import Any, Dict, List, Optional, Set, Tuple, Union

@@ -21,6 +22,7 @@
 import psutil
 from typing_extensions import Literal

+import sky
 from sky import backends
 from sky import exceptions
 from sky import global_user_state
@@ -32,14 +34,14 @@
 from sky.skylet import constants
 from sky.skylet import job_lib
 from sky.skylet import log_lib
+from sky.usage import usage_lib
 from sky.utils import common_utils
 from sky.utils import log_utils
 from sky.utils import rich_utils
 from sky.utils import subprocess_utils
 from sky.utils import ux_utils

 if typing.TYPE_CHECKING:
-    import sky
     from sky import dag as dag_lib

 logger = sky_logging.init_logger(__name__)
@@ -85,6 +87,43 @@ class UserSignal(enum.Enum):


 # ====== internal functions ======
+def terminate_cluster(cluster_name: str, max_retry: int = 6) -> None:
+    """Terminate the cluster."""
+    retry_cnt = 0
+    # In some cases, e.g. botocore.exceptions.NoCredentialsError due to AWS
+    # metadata service throttling, the failed sky.down attempt can take 10-11
+    # seconds. In this case, we need the backoff to significantly reduce the
+    # rate of requests - that is, significantly increase the time between
+    # requests. We set the initial backoff to 15 seconds, so that once it grows
+    # exponentially it will quickly dominate the 10-11 seconds that we already
+    # see between requests. We set the max backoff very high, since it's
+    # generally much more important to eventually succeed than to fail fast.
+    backoff = common_utils.Backoff(
+        initial_backoff=15,
+        # 1.6 ** 5 = 10.48576 < 20, so we won't hit this with default max_retry
+        max_backoff_factor=20)
+    while True:
+        try:
+            usage_lib.messages.usage.set_internal()
+            sky.down(cluster_name)
+            return
+        except exceptions.ClusterDoesNotExist:
+            # The cluster is already down.
+            logger.debug(f'The cluster {cluster_name} is already down.')
+            return
+        except Exception as e:  # pylint: disable=broad-except
+            retry_cnt += 1
+            if retry_cnt >= max_retry:
+                raise RuntimeError(
+                    f'Failed to terminate the cluster {cluster_name}.') from e
+            logger.error(
+                f'Failed to terminate the cluster {cluster_name}. Retrying.'
+                f'Details: {common_utils.format_exception(e)}')
+            with ux_utils.enable_traceback():
+                logger.error(f' Traceback: {traceback.format_exc()}')
+            time.sleep(backoff.current_backoff())
+
+
 def get_job_status(backend: 'backends.CloudVmRayBackend',
                    cluster_name: str) -> Optional['job_lib.JobStatus']:
     """Check the status of the job running on a managed job cluster.
@@ -202,18 +241,9 @@ def update_managed_job_status(job_id: Optional[int] = None):
         cluster_name = generate_managed_job_cluster_name(task_name, job_id_)
         handle = global_user_state.get_handle_from_cluster_name(
             cluster_name)
+        # If the cluster exists, terminate it.
         if handle is not None:
-            backend = backend_utils.get_backend_from_handle(handle)
-            # TODO(cooperc): Add backoff
-            max_retry = 3
-            for retry_cnt in range(max_retry):
-                try:
-                    backend.teardown(handle, terminate=True)
-                    break
-                except RuntimeError:
-                    logger.error('Failed to tear down the cluster '
-                                 f'{cluster_name!r}. Retrying '
-                                 f'[{retry_cnt}/{max_retry}].')
+            terminate_cluster(cluster_name)

         # The controller process for this managed job is not running: it must
         # have exited abnormally, and we should set the job status to
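
The comments in the new terminate_cluster above explain why the backoff starts high, but not what retry schedule that produces. The standalone sketch below (not part of this commit) computes the delays under stated assumptions: common_utils.Backoff grows the delay by roughly 1.6x per call and caps it at initial_backoff * max_backoff_factor; the constant names and helper function are illustrative only.

# Standalone sketch of the retry schedule implied by the comments above.
# Assumption: the backoff grows the delay by ~1.6x per failed attempt and caps
# it at initial_backoff * max_backoff_factor; this helper only mirrors that
# behavior for illustration and is not SkyPilot code.
INITIAL_BACKOFF = 15  # seconds, matches the diff above
GROWTH_FACTOR = 1.6  # assumed per-retry multiplier
MAX_BACKOFF_FACTOR = 20  # cap multiplier, matches the diff above
MAX_RETRY = 6  # default max_retry of terminate_cluster


def sketch_backoff_schedule() -> list:
    """Return the assumed sleep durations between teardown attempts."""
    delays = []
    delay = float(INITIAL_BACKOFF)
    # With max_retry=6, attempts 1-5 sleep after failing; the 6th failure
    # raises RuntimeError instead of sleeping.
    for _ in range(MAX_RETRY - 1):
        delays.append(min(delay, INITIAL_BACKOFF * MAX_BACKOFF_FACTOR))
        delay *= GROWTH_FACTOR
    return delays


print(sketch_backoff_schedule())
# Roughly [15.0, 24.0, 38.4, 61.4, 98.3]: the 300-second cap is never reached,
# which is what the '1.6 ** 5 = 10.48576 < 20' comment in the diff points out.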

tests/unit_tests/test_recovery_strategy.py renamed to tests/unit_tests/test_jobs_utils.py (+3 -3)

@@ -1,7 +1,7 @@
 from unittest import mock

 from sky.exceptions import ClusterDoesNotExist
-from sky.jobs import recovery_strategy
+from sky.jobs import utils


 @mock.patch('sky.down')
@@ -16,7 +16,7 @@ def test_terminate_cluster_retry_on_value_error(mock_set_internal,
     ]

     # Call should succeed after retries
-    recovery_strategy.terminate_cluster('test-cluster')
+    utils.terminate_cluster('test-cluster')

     # Verify sky.down was called 3 times
     assert mock_sky_down.call_count == 3
@@ -38,7 +38,7 @@ def test_terminate_cluster_handles_nonexistent_cluster(mock_set_internal,
     mock_sky_down.side_effect = ClusterDoesNotExist('test-cluster')

     # Call should succeed silently
-    recovery_strategy.terminate_cluster('test-cluster')
+    utils.terminate_cluster('test-cluster')

     # Verify sky.down was called once
     assert mock_sky_down.call_count == 1
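
The renamed tests above cover the retry and already-down paths but not the new backoff sleep. Since terminate_cluster now waits 15+ seconds between attempts, a test of that path would want to stub out the delay; below is a hedged sketch in the style of the existing tests. The time.sleep patch, the test name, and the set_internal patch target are assumptions, not part of this commit.

from unittest import mock

from sky.jobs import utils


# Hypothetical test, not part of this commit: check that the new backoff path
# sleeps between failed teardown attempts without actually waiting.
@mock.patch('sky.down')
@mock.patch('sky.usage.usage_lib.messages.usage.set_internal')
@mock.patch('time.sleep')
def test_terminate_cluster_backs_off_between_retries(mock_sleep,
                                                     mock_set_internal,
                                                     mock_sky_down) -> None:
    # Fail twice, then succeed on the third attempt.
    mock_sky_down.side_effect = [ValueError('err'), ValueError('err'), None]

    utils.terminate_cluster('test-cluster')

    assert mock_sky_down.call_count == 3
    # One backoff sleep per failed attempt that is retried.
    assert mock_sleep.call_count == 2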
