diff --git a/docs/source/reference/yaml-spec.rst b/docs/source/reference/yaml-spec.rst
index 402dfef4cfe..782fd9fd208 100644
--- a/docs/source/reference/yaml-spec.rst
+++ b/docs/source/reference/yaml-spec.rst
@@ -194,6 +194,10 @@ Available fields:
   # Or machine image: https://cloud.google.com/compute/docs/machine-images
   # image_id: projects/my-project/global/machineImages/my-machine-image
   #
+  # Azure
+  # To find Azure images: https://docs.microsoft.com/en-us/azure/virtual-machines/linux/cli-ps-findimage
+  # image_id: microsoft-dsvm:ubuntu-2004:2004:21.11.04
+  #
   # IBM
   # Create a private VPC image and paste its ID in the following format:
   # image_id:
diff --git a/sky/adaptors/azure.py b/sky/adaptors/azure.py
index b1efc349fbc..b1c8f3cd58d 100644
--- a/sky/adaptors/azure.py
+++ b/sky/adaptors/azure.py
@@ -40,10 +40,10 @@ def get_current_account_user() -> str:
 
 
 @import_package
-def http_error_exception():
-    """HttpError exception."""
-    from azure.core import exceptions
-    return exceptions.HttpResponseError
+def exceptions():
+    """Azure exceptions."""
+    from azure.core import exceptions as azure_exceptions
+    return azure_exceptions
 
 
 @functools.lru_cache()
diff --git a/sky/clouds/azure.py b/sky/clouds/azure.py
index f942a4b6d42..a0e16c83471 100644
--- a/sky/clouds/azure.py
+++ b/sky/clouds/azure.py
@@ -37,6 +37,9 @@
 
 _MAX_IDENTITY_FETCH_RETRY = 10
 
+_DEFAULT_AZURE_UBUNTU_HPC_IMAGE_GB = 30
+_DEFAULT_AZURE_UBUNTU_2004_IMAGE_GB = 150
+
 
 def _run_output(cmd):
     proc = subprocess.run(cmd,
@@ -75,9 +78,6 @@ def _unsupported_features_for_resources(
         features = {
             clouds.CloudImplementationFeatures.CLONE_DISK_FROM_CLUSTER:
                 (f'Migrating disk is currently not supported on {cls._REPR}.'),
-            clouds.CloudImplementationFeatures.IMAGE_ID:
-                ('Specifying image ID is currently not supported on '
-                 f'{cls._REPR}.'),
         }
         if resources.use_spot:
             features[clouds.CloudImplementationFeatures.STOP] = (
@@ -137,6 +137,50 @@ def get_egress_cost(self, num_gigabytes: float):
     def is_same_cloud(self, other):
         return isinstance(other, Azure)
 
+    @classmethod
+    def get_image_size(cls, image_id: str, region: Optional[str]) -> float:
+        if region is None:
+            # The region used here is only for where to send the query,
+            # not the image location. Azure images are globally available.
+            region = 'eastus'
+        is_skypilot_image_tag = False
+        if image_id.startswith('skypilot:'):
+            is_skypilot_image_tag = True
+            image_id = service_catalog.get_image_id_from_tag(image_id,
+                                                             clouds='azure')
+        image_id_splitted = image_id.split(':')
+        if len(image_id_splitted) != 4:
+            with ux_utils.print_exception_no_traceback():
+                raise ValueError(f'Invalid image id: {image_id}. Expected '
+                                 'format: <publisher>:<offer>:<sku>:<version>')
+        publisher, offer, sku, version = image_id_splitted
+        if is_skypilot_image_tag:
+            if offer == 'ubuntu-hpc':
+                return _DEFAULT_AZURE_UBUNTU_HPC_IMAGE_GB
+            else:
+                return _DEFAULT_AZURE_UBUNTU_2004_IMAGE_GB
+        compute_client = azure.get_client('compute', cls.get_project_id())
+        try:
+            image = compute_client.virtual_machine_images.get(
+                region, publisher, offer, sku, version)
+        except azure.exceptions().ResourceNotFoundError as e:
+            with ux_utils.print_exception_no_traceback():
+                raise ValueError(f'Image not found: {image_id}') from e
+        if image.os_disk_image is None:
+            with ux_utils.print_exception_no_traceback():
+                raise ValueError(f'Retrieve image size for {image_id} failed.')
+        ap = image.os_disk_image.additional_properties
+        size_in_gb = ap.get('sizeInGb')
+        if size_in_gb is not None:
+            return float(size_in_gb)
+        size_in_bytes = ap.get('sizeInBytes')
+        if size_in_bytes is None:
+            with ux_utils.print_exception_no_traceback():
+                raise ValueError(f'Retrieve image size for {image_id} failed. '
+                                 f'Got additional_properties: {ap}')
+        size_in_gb = size_in_bytes / (1024**3)
+        return size_in_gb
+
     @classmethod
     def get_default_instance_type(
             cls,
@@ -149,33 +193,13 @@ def get_default_instance_type(
                                                           disk_tier=disk_tier,
                                                           clouds='azure')
 
-    def _get_image_config(self, gen_version, instance_type):
-        # TODO(tian): images for Azure is not well organized. We should refactor
-        # it to images.csv like AWS.
-        # az vm image list \
-        #   --publisher microsoft-dsvm --all --output table
-        # nvidia-driver: 535.54.03, cuda: 12.2
-        # see: https://github.com/Azure/azhpc-images/releases/tag/ubuntu-hpc-20230803
-        # All A100 instances is of gen2, so it will always use
-        # the latest ubuntu-hpc:2204 image.
-        image_config = {
-            'image_publisher': 'microsoft-dsvm',
-            'image_offer': 'ubuntu-hpc',
-            'image_sku': '2204',
-            'image_version': '22.04.2023080201'
-        }
-
+    def _get_default_image_tag(self, gen_version, instance_type) -> str:
         # ubuntu-2004 v21.08.30, K80 requires image with old NVIDIA driver version
         acc = self.get_accelerators_from_instance_type(instance_type)
         if acc is not None:
             acc_name = list(acc.keys())[0]
             if acc_name == 'K80':
-                image_config = {
-                    'image_publisher': 'microsoft-dsvm',
-                    'image_offer': 'ubuntu-2004',
-                    'image_sku': '2004-gen2',
-                    'image_version': '21.08.30'
-                }
+                return 'skypilot:k80-ubuntu-2004'
 
         # ubuntu-2004 v21.11.04, the previous image we used in the past for
         # V1 HyperV instance before we change default image to ubuntu-hpc.
@@ -184,14 +208,13 @@ def _get_image_config(self, gen_version, instance_type):
         # (Basic_A, Standard_D, ...) are V1 instance. For these instances,
         # we use the previous image.
         if gen_version == 'V1':
-            image_config = {
-                'image_publisher': 'microsoft-dsvm',
-                'image_offer': 'ubuntu-2004',
-                'image_sku': '2004',
-                'image_version': '21.11.04'
-            }
+            return 'skypilot:v1-ubuntu-2004'
 
-        return image_config
+        # nvidia-driver: 535.54.03, cuda: 12.2
+        # see: https://github.com/Azure/azhpc-images/releases/tag/ubuntu-hpc-20230803
+        # All A100 instances are of gen2, so they will always use
+        # the latest ubuntu-hpc:2204 image.
+        return 'skypilot:gpu-ubuntu-2204'
 
     @classmethod
     def regions_with_offering(cls, instance_type: str,
@@ -270,11 +293,31 @@ def make_deploy_resources_variables(
             acc_count = str(sum(acc_dict.values()))
         else:
             custom_resources = None
-        # pylint: disable=import-outside-toplevel
-        from sky.clouds.service_catalog import azure_catalog
-        gen_version = azure_catalog.get_gen_version_from_instance_type(
-            r.instance_type)
-        image_config = self._get_image_config(gen_version, r.instance_type)
+
+        if resources.image_id is None:
+            # pylint: disable=import-outside-toplevel
+            from sky.clouds.service_catalog import azure_catalog
+            gen_version = azure_catalog.get_gen_version_from_instance_type(
+                r.instance_type)
+            image_id = self._get_default_image_tag(gen_version, r.instance_type)
+        else:
+            if None in resources.image_id:
+                image_id = resources.image_id[None]
+            else:
+                assert region_name in resources.image_id, resources.image_id
+                image_id = resources.image_id[region_name]
+        if image_id.startswith('skypilot:'):
+            image_id = service_catalog.get_image_id_from_tag(image_id,
+                                                             clouds='azure')
+        # Already checked in resources.py
+        publisher, offer, sku, version = image_id.split(':')
+        image_config = {
+            'image_publisher': publisher,
+            'image_offer': offer,
+            'image_sku': sku,
+            'image_version': version,
+        }
+
         # Setup commands to eliminate the banner and restart sshd.
         # This script will modify /etc/ssh/sshd_config and add a bash script
         # into .bashrc. The bash script will restart sshd if it has not been
diff --git a/sky/clouds/service_catalog/azure_catalog.py b/sky/clouds/service_catalog/azure_catalog.py
index 6d7ebadac70..141b356712e 100644
--- a/sky/clouds/service_catalog/azure_catalog.py
+++ b/sky/clouds/service_catalog/azure_catalog.py
@@ -21,6 +21,9 @@
 _df = common.read_catalog('azure/vms.csv',
                           pull_frequency_hours=_PULL_FREQUENCY_HOURS)
 
+_image_df = common.read_catalog('azure/images.csv',
+                                pull_frequency_hours=_PULL_FREQUENCY_HOURS)
+
 # We will select from the following three instance families:
 _DEFAULT_INSTANCE_FAMILY = [
     # The latest general-purpose instance family as of Mar. 2023.
@@ -168,3 +171,17 @@ def list_accelerators(
     return common.list_accelerators_impl('Azure', _df, gpus_only, name_filter,
                                          region_filter, quantity_filter,
                                          case_sensitive, all_regions)
+
+
+def get_image_id_from_tag(tag: str, region: Optional[str]) -> Optional[str]:
+    """Returns the image id from the tag."""
+    # Azure images are not region-specific.
+    del region  # Unused.
+    return common.get_image_id_from_tag_impl(_image_df, tag, None)
+
+
+def is_image_tag_valid(tag: str, region: Optional[str]) -> bool:
+    """Returns whether the image tag is valid."""
+    # Azure images are not region-specific.
+    del region  # Unused.
+    return common.is_image_tag_valid_impl(_image_df, tag, None)
diff --git a/sky/provision/azure/instance.py b/sky/provision/azure/instance.py
index a3129d8794c..3264f3e9350 100644
--- a/sky/provision/azure/instance.py
+++ b/sky/provision/azure/instance.py
@@ -79,7 +79,7 @@ def open_ports(
                 with ux_utils.print_exception_no_traceback():
                     raise ValueError(f'Failed to open ports {ports} in NSG '
                                      f'{nsg.name}: {poller.status()}')
-    except azure.http_error_exception() as e:
+    except azure.exceptions().HttpResponseError as e:
         with ux_utils.print_exception_no_traceback():
             raise ValueError(
                 f'Failed to open ports {ports} in NSG {nsg.name}.') from e
diff --git a/sky/resources.py b/sky/resources.py
index 257b05442fc..aeb37b79649 100644
--- a/sky/resources.py
+++ b/sky/resources.py
@@ -816,8 +816,8 @@ def _try_validate_image_id(self) -> None:
             except exceptions.NotSupportedError as e:
                 with ux_utils.print_exception_no_traceback():
                     raise ValueError(
-                        'image_id is only supported for AWS/GCP/IBM/OCI/Kubernetes,'
-                        ' please explicitly specify the cloud.') from e
+                        'image_id is only supported for AWS/GCP/Azure/IBM/OCI/'
+                        'Kubernetes, please explicitly specify the cloud.') from e
 
         if self._region is not None:
             if self._region not in self._image_id:
diff --git a/tests/test_optimizer_dryruns.py b/tests/test_optimizer_dryruns.py
index e246e859435..3d557ea475f 100644
--- a/tests/test_optimizer_dryruns.py
+++ b/tests/test_optimizer_dryruns.py
@@ -468,8 +468,8 @@ def test_invalid_image(monkeypatch):
         assert 'Cloud must be specified' in str(e.value)
 
     with pytest.raises(ValueError) as e:
-        _test_resources(monkeypatch, cloud=sky.Azure(), image_id='some-image')
-    assert 'only supported for AWS/GCP/IBM/OCI' in str(e.value)
+        _test_resources(monkeypatch, cloud=sky.Lambda(), image_id='some-image')
+    assert 'only supported for AWS/GCP/Azure/IBM/OCI/Kubernetes' in str(e.value)
 
 
 def test_valid_image(monkeypatch):
diff --git a/tests/test_smoke.py b/tests/test_smoke.py
index 340de5cf842..77dd367180e 100644
--- a/tests/test_smoke.py
+++ b/tests/test_smoke.py
@@ -444,6 +444,24 @@ def test_gcp_images():
     run_one_test(test)
 
 
+@pytest.mark.azure
+def test_azure_images():
+    name = _get_cluster_name()
+    test = Test(
+        'azure_images',
+        [
+            f'sky launch -y -c {name} --image-id skypilot:gpu-ubuntu-2204 --cloud azure tests/test_yamls/minimal.yaml',
+            f'sky logs {name} 1 --status',  # Ensure the job succeeded.
+            f'sky launch -c {name} --image-id skypilot:v1-ubuntu-2004 --cloud azure tests/test_yamls/minimal.yaml && exit 1 || true',
+            f'sky launch -y -c {name} tests/test_yamls/minimal.yaml',
+            f'sky logs {name} 2 --status',
+            f'sky logs {name} --status | grep "Job 2: SUCCEEDED"',  # Equivalent.
+        ],
+        f'sky down -y {name}',
+    )
+    run_one_test(test)
+
+
 @pytest.mark.aws
 def test_aws_image_id_dict():
     name = _get_cluster_name()
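
Because the IMAGE_ID entry is removed from Azure's unsupported features above, a custom image can now be requested programmatically as well as through the YAML field documented in yaml-spec.rst. A hedged usage sketch, assuming the standard sky.Task / sky.Resources / sky.launch entry points and one of the new catalog tags introduced in this diff:

    import sky

    # Request the new default GPU image explicitly; any valid
    # <publisher>:<offer>:<sku>:<version> URN should also be accepted here.
    task = sky.Task(run='nvidia-smi')
    task.set_resources(
        sky.Resources(cloud=sky.Azure(), image_id='skypilot:gpu-ubuntu-2204'))

    # sky.launch(task, cluster_name='azure-image-test')  # Uncomment to provision.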
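The new azure/images.csv catalog backs the two azure_catalog helpers added above, and the generic service_catalog layer dispatches to them when called with clouds='azure' (the same call get_image_size and make_deploy_resources_variables make). A small sketch of resolving a tag by hand; the printed URN depends on the fetched catalog, so the value shown is only an example:

    from sky.clouds import service_catalog

    tag = 'skypilot:v1-ubuntu-2004'
    # Region is ignored for Azure, since its images are not region-specific.
    if service_catalog.is_image_tag_valid(tag, region=None, clouds='azure'):
        image_id = service_catalog.get_image_id_from_tag(tag, clouds='azure')
        print(image_id)  # e.g. 'microsoft-dsvm:ubuntu-2004:2004:21.11.04'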
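Azure.get_image_size, added above, resolves any skypilot: tag through the catalog, splits the resulting URN into publisher:offer:sku:version, and falls back from the SDK's 'sizeInGb' property to a raw byte count. A minimal, self-contained sketch of that parsing and unit conversion; the helper names (parse_azure_image_id, bytes_to_gib) are illustrative only and not part of the patch:

    from typing import Tuple


    def parse_azure_image_id(image_id: str) -> Tuple[str, str, str, str]:
        # Mirrors the validation in Azure.get_image_size: exactly four
        # colon-separated fields are expected.
        parts = image_id.split(':')
        if len(parts) != 4:
            raise ValueError(f'Invalid image id: {image_id}. Expected '
                             'format: <publisher>:<offer>:<sku>:<version>')
        publisher, offer, sku, version = parts
        return publisher, offer, sku, version


    def bytes_to_gib(size_in_bytes: int) -> float:
        # Same conversion used when the catalog only reports 'sizeInBytes'.
        return size_in_bytes / (1024**3)


    if __name__ == '__main__':
        print(parse_azure_image_id('microsoft-dsvm:ubuntu-2004:2004:21.11.04'))
        print(bytes_to_gib(32213303296))  # ~30.0 GiB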