Skip to content

Commit e2df725

Browse files
authored
[Core] Support custom image on Azure (#3362)
* init. TODO: test * query location * fix unittest * add skypilot-tag * refactor & add default image size * add smoke test * retry tests after merging catalog * use exceptions in adaptor * remove assertion
1 parent f8ed1f1 commit e2df725

File tree

8 files changed

+128
-46
lines changed

8 files changed

+128
-46
lines changed

docs/source/reference/yaml-spec.rst

+4
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,10 @@ Available fields:
194194
# Or machine image: https://cloud.google.com/compute/docs/machine-images
195195
# image_id: projects/my-project/global/machineImages/my-machine-image
196196
#
197+
# Azure
198+
# To find Azure images: https://docs.microsoft.com/en-us/azure/virtual-machines/linux/cli-ps-findimage
199+
# image_id: microsoft-dsvm:ubuntu-2004:2004:21.11.04
200+
#
197201
# IBM
198202
# Create a private VPC image and paste its ID in the following format:
199203
# image_id: <unique_image_id>

sky/adaptors/azure.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,10 @@ def get_current_account_user() -> str:
4040

4141

4242
@import_package
43-
def http_error_exception():
44-
"""HttpError exception."""
45-
from azure.core import exceptions
46-
return exceptions.HttpResponseError
43+
def exceptions():
    """Return the ``azure.core.exceptions`` module.

    The import is performed inside the function body, so ``azure.core`` is
    only loaded when this function is actually called. Exception classes
    are then read as attributes of the returned module.
    """
    from azure.core import exceptions as azure_core_exceptions
    return azure_core_exceptions
4747

4848

4949
@functools.lru_cache()

sky/clouds/azure.py

+80-37
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@
3737

3838
_MAX_IDENTITY_FETCH_RETRY = 10
3939

40+
_DEFAULT_AZURE_UBUNTU_HPC_IMAGE_GB = 30
41+
_DEFAULT_AZURE_UBUNTU_2004_IMAGE_GB = 150
42+
4043

4144
def _run_output(cmd):
4245
proc = subprocess.run(cmd,
@@ -75,9 +78,6 @@ def _unsupported_features_for_resources(
7578
features = {
7679
clouds.CloudImplementationFeatures.CLONE_DISK_FROM_CLUSTER:
7780
(f'Migrating disk is currently not supported on {cls._REPR}.'),
78-
clouds.CloudImplementationFeatures.IMAGE_ID:
79-
('Specifying image ID is currently not supported on '
80-
f'{cls._REPR}.'),
8181
}
8282
if resources.use_spot:
8383
features[clouds.CloudImplementationFeatures.STOP] = (
@@ -137,6 +137,50 @@ def get_egress_cost(self, num_gigabytes: float):
137137
def is_same_cloud(self, other):
138138
return isinstance(other, Azure)
139139

140+
@classmethod
def get_image_size(cls, image_id: str, region: Optional[str]) -> float:
    """Return the OS disk size, in GB, of an Azure image.

    Args:
        image_id: Either a SkyPilot image tag (``skypilot:<tag>``) or an
            Azure image in the form
            ``<publisher>:<offer>:<sku>:<version>``.
        region: Region the size query is sent to. Defaults to ``eastus``
            when None.

    Returns:
        The image size in GB. For SkyPilot image tags a module-level
        default size is returned without querying Azure.

    Raises:
        ValueError: If the image id is malformed, the image cannot be
            found, or Azure does not report a size for it.
    """
    if region is None:
        # The region used here is only for where to send the query,
        # not the image location. Azure's image is globally available.
        region = 'eastus'

    is_skypilot_image_tag = image_id.startswith('skypilot:')
    if is_skypilot_image_tag:
        # Resolve the tag to a concrete publisher:offer:sku:version id.
        image_id = service_catalog.get_image_id_from_tag(image_id,
                                                         clouds='azure')

    image_id_splitted = image_id.split(':')
    if len(image_id_splitted) != 4:
        with ux_utils.print_exception_no_traceback():
            raise ValueError(f'Invalid image id: {image_id}. Expected '
                             'format: <publisher>:<offer>:<sku>:<version>')
    publisher, offer, sku, version = image_id_splitted

    if is_skypilot_image_tag:
        # SkyPilot-provided images use fixed default sizes, so no API
        # call is made for them.
        if offer == 'ubuntu-hpc':
            return _DEFAULT_AZURE_UBUNTU_HPC_IMAGE_GB
        return _DEFAULT_AZURE_UBUNTU_2004_IMAGE_GB

    compute_client = azure.get_client('compute', cls.get_project_id())
    try:
        image = compute_client.virtual_machine_images.get(
            region, publisher, offer, sku, version)
    # NOTE: the except clause must name the exception *class*. The
    # previous code instantiated it (``ResourceNotFoundError()``), which
    # makes Python raise TypeError while matching the handler instead of
    # catching the lookup failure.
    except azure.exceptions().ResourceNotFoundError as e:
        with ux_utils.print_exception_no_traceback():
            raise ValueError(f'Image not found: {image_id}') from e

    if image.os_disk_image is None:
        with ux_utils.print_exception_no_traceback():
            raise ValueError(f'Retrieve image size for {image_id} failed.')

    ap = image.os_disk_image.additional_properties
    # Prefer the size reported in GB; otherwise fall back to bytes.
    size_in_gb = ap.get('sizeInGb')
    if size_in_gb is not None:
        return float(size_in_gb)

    size_in_bytes = ap.get('sizeInBytes')
    if size_in_bytes is None:
        with ux_utils.print_exception_no_traceback():
            raise ValueError(f'Retrieve image size for {image_id} failed. '
                             f'Got additional_properties: {ap}')
    # Coerce like the sizeInGb branch so both paths return a float.
    return float(size_in_bytes) / (1024**3)
183+
140184
@classmethod
141185
def get_default_instance_type(
142186
cls,
@@ -149,33 +193,13 @@ def get_default_instance_type(
149193
disk_tier=disk_tier,
150194
clouds='azure')
151195

152-
def _get_image_config(self, gen_version, instance_type):
153-
# TODO(tian): images for Azure is not well organized. We should refactor
154-
# it to images.csv like AWS.
155-
# az vm image list \
156-
# --publisher microsoft-dsvm --all --output table
157-
# nvidia-driver: 535.54.03, cuda: 12.2
158-
# see: https://github.com/Azure/azhpc-images/releases/tag/ubuntu-hpc-20230803
159-
# All A100 instances is of gen2, so it will always use
160-
# the latest ubuntu-hpc:2204 image.
161-
image_config = {
162-
'image_publisher': 'microsoft-dsvm',
163-
'image_offer': 'ubuntu-hpc',
164-
'image_sku': '2204',
165-
'image_version': '22.04.2023080201'
166-
}
167-
196+
def _get_default_image_tag(self, gen_version, instance_type) -> str:
168197
# ubuntu-2004 v21.08.30, K80 requires image with old NVIDIA driver version
169198
acc = self.get_accelerators_from_instance_type(instance_type)
170199
if acc is not None:
171200
acc_name = list(acc.keys())[0]
172201
if acc_name == 'K80':
173-
image_config = {
174-
'image_publisher': 'microsoft-dsvm',
175-
'image_offer': 'ubuntu-2004',
176-
'image_sku': '2004-gen2',
177-
'image_version': '21.08.30'
178-
}
202+
return 'skypilot:k80-ubuntu-2004'
179203

180204
# ubuntu-2004 v21.11.04, the previous image we used in the past for
181205
# V1 HyperV instance before we change default image to ubuntu-hpc.
@@ -184,14 +208,13 @@ def _get_image_config(self, gen_version, instance_type):
184208
# (Basic_A, Standard_D, ...) are V1 instance. For these instances,
185209
# we use the previous image.
186210
if gen_version == 'V1':
187-
image_config = {
188-
'image_publisher': 'microsoft-dsvm',
189-
'image_offer': 'ubuntu-2004',
190-
'image_sku': '2004',
191-
'image_version': '21.11.04'
192-
}
211+
return 'skypilot:v1-ubuntu-2004'
193212

194-
return image_config
213+
# nvidia-driver: 535.54.03, cuda: 12.2
214+
# see: https://github.com/Azure/azhpc-images/releases/tag/ubuntu-hpc-20230803
215+
# All A100 instances is of gen2, so it will always use
216+
# the latest ubuntu-hpc:2204 image.
217+
return 'skypilot:gpu-ubuntu-2204'
195218

196219
@classmethod
197220
def regions_with_offering(cls, instance_type: str,
@@ -270,11 +293,31 @@ def make_deploy_resources_variables(
270293
acc_count = str(sum(acc_dict.values()))
271294
else:
272295
custom_resources = None
273-
# pylint: disable=import-outside-toplevel
274-
from sky.clouds.service_catalog import azure_catalog
275-
gen_version = azure_catalog.get_gen_version_from_instance_type(
276-
r.instance_type)
277-
image_config = self._get_image_config(gen_version, r.instance_type)
296+
297+
if resources.image_id is None:
298+
# pylint: disable=import-outside-toplevel
299+
from sky.clouds.service_catalog import azure_catalog
300+
gen_version = azure_catalog.get_gen_version_from_instance_type(
301+
r.instance_type)
302+
image_id = self._get_default_image_tag(gen_version, r.instance_type)
303+
else:
304+
if None in resources.image_id:
305+
image_id = resources.image_id[None]
306+
else:
307+
assert region_name in resources.image_id, resources.image_id
308+
image_id = resources.image_id[region_name]
309+
if image_id.startswith('skypilot:'):
310+
image_id = service_catalog.get_image_id_from_tag(image_id,
311+
clouds='azure')
312+
# Already checked in resources.py
313+
publisher, offer, sku, version = image_id.split(':')
314+
image_config = {
315+
'image_publisher': publisher,
316+
'image_offer': offer,
317+
'image_sku': sku,
318+
'image_version': version,
319+
}
320+
278321
# Setup commands to eliminate the banner and restart sshd.
279322
# This script will modify /etc/ssh/sshd_config and add a bash script
280323
# into .bashrc. The bash script will restart sshd if it has not been

sky/clouds/service_catalog/azure_catalog.py

+17
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@
2121
_df = common.read_catalog('azure/vms.csv',
2222
pull_frequency_hours=_PULL_FREQUENCY_HOURS)
2323

24+
_image_df = common.read_catalog('azure/images.csv',
25+
pull_frequency_hours=_PULL_FREQUENCY_HOURS)
26+
2427
# We will select from the following three instance families:
2528
_DEFAULT_INSTANCE_FAMILY = [
2629
# The latest general-purpose instance family as of Mar. 2023.
@@ -168,3 +171,17 @@ def list_accelerators(
168171
return common.list_accelerators_impl('Azure', _df, gpus_only, name_filter,
169172
region_filter, quantity_filter,
170173
case_sensitive, all_regions)
174+
175+
176+
def get_image_id_from_tag(tag: str, region: Optional[str]) -> Optional[str]:
    """Look up the image id for ``tag`` in the Azure image catalog.

    ``region`` is accepted but ignored: the lookup is always performed
    with a region of None.
    """
    # Azure images are not region-specific, so discard whatever region the
    # caller supplied before delegating to the shared implementation.
    region = None
    return common.get_image_id_from_tag_impl(_image_df, tag, region)
181+
182+
183+
def is_image_tag_valid(tag: str, region: Optional[str]) -> bool:
    """Report whether ``tag`` is a valid tag in the Azure image catalog.

    ``region`` is accepted but ignored: the check is always performed
    with a region of None.
    """
    # Azure images are not region-specific, so discard whatever region the
    # caller supplied before delegating to the shared implementation.
    region = None
    return common.is_image_tag_valid_impl(_image_df, tag, region)

sky/provision/azure/instance.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def open_ports(
7979
with ux_utils.print_exception_no_traceback():
8080
raise ValueError(f'Failed to open ports {ports} in NSG '
8181
f'{nsg.name}: {poller.status()}')
82-
except azure.http_error_exception() as e:
82+
except azure.exceptions().HttpResponseError() as e:
8383
with ux_utils.print_exception_no_traceback():
8484
raise ValueError(
8585
f'Failed to open ports {ports} in NSG {nsg.name}.') from e

sky/resources.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -816,8 +816,8 @@ def _try_validate_image_id(self) -> None:
816816
except exceptions.NotSupportedError as e:
817817
with ux_utils.print_exception_no_traceback():
818818
raise ValueError(
819-
'image_id is only supported for AWS/GCP/IBM/OCI/Kubernetes,'
820-
' please explicitly specify the cloud.') from e
819+
'image_id is only supported for AWS/GCP/Azure/IBM/OCI/'
820+
'Kubernetes, please explicitly specify the cloud.') from e
821821

822822
if self._region is not None:
823823
if self._region not in self._image_id:

tests/test_optimizer_dryruns.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -468,8 +468,8 @@ def test_invalid_image(monkeypatch):
468468
assert 'Cloud must be specified' in str(e.value)
469469

470470
with pytest.raises(ValueError) as e:
471-
_test_resources(monkeypatch, cloud=sky.Azure(), image_id='some-image')
472-
assert 'only supported for AWS/GCP/IBM/OCI' in str(e.value)
471+
_test_resources(monkeypatch, cloud=sky.Lambda(), image_id='some-image')
472+
assert 'only supported for AWS/GCP/Azure/IBM/OCI/Kubernetes' in str(e.value)
473473

474474

475475
def test_valid_image(monkeypatch):

tests/test_smoke.py

+18
Original file line numberDiff line numberDiff line change
@@ -444,6 +444,24 @@ def test_gcp_images():
444444
run_one_test(test)
445445

446446

447+
@pytest.mark.azure
def test_azure_images():
    """Smoke test for launching on Azure with SkyPilot image tags.

    Steps, in order:
      1. Launch a cluster with the ``skypilot:gpu-ubuntu-2204`` image tag
         and check that job 1 succeeded.
      2. Re-launch on the same cluster with a different image tag; the
         ``&& exit 1 || true`` suffix fails the step if that launch
         succeeds and passes it if the launch fails.
      3. Re-launch without an image id and check that job 2 succeeded.

    The cluster is torn down afterwards with ``sky down``.
    """
    cluster = _get_cluster_name()
    commands = [
        f'sky launch -y -c {cluster} --image-id skypilot:gpu-ubuntu-2204 --cloud azure tests/test_yamls/minimal.yaml',
        # Ensure the job succeeded.
        f'sky logs {cluster} 1 --status',
        f'sky launch -c {cluster} --image-id skypilot:v1-ubuntu-2004 --cloud azure tests/test_yamls/minimal.yaml && exit 1 || true',
        f'sky launch -y -c {cluster} tests/test_yamls/minimal.yaml',
        f'sky logs {cluster} 2 --status',
        # Equivalent to the status check on the previous line.
        f'sky logs {cluster} --status | grep "Job 2: SUCCEEDED"',
    ]
    teardown = f'sky down -y {cluster}'
    run_one_test(Test('azure_images', commands, teardown))
463+
464+
447465
@pytest.mark.aws
448466
def test_aws_image_id_dict():
449467
name = _get_cluster_name()

0 commit comments

Comments
 (0)