Skip to content

Commit 66b8635

Browse files
Azure: update fetch_azure to support two H100 families. (#2844)
* Azure: update fetch_azure to support two H100 families. * format
1 parent 84313ed commit 66b8635

File tree

2 files changed

+38
-22
lines changed

2 files changed

+38
-22
lines changed

sky/clouds/service_catalog/data_fetchers/fetch_azure.py

+30-20
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,34 @@
3636

3737
SINGLE_THREADED = False
3838

39+
# Family name to SkyPilot GPU name mapping.
40+
#
41+
# When adding a new accelerator:
42+
# - The instance type is typically already fetched, but we need to find the
43+
# family name and add it to this mapping.
44+
# - To inspect family names returned by Azure API, check the dataframes in
45+
# get_all_regions_instance_types_df().
46+
FAMILY_NAME_TO_SKYPILOT_GPU_NAME = {
47+
'standardNCFamily': 'K80',
48+
'standardNCSv2Family': 'P100',
49+
'standardNCSv3Family': 'V100',
50+
'standardNCPromoFamily': 'K80',
51+
'StandardNCASv3_T4Family': 'T4',
52+
'standardNDSv2Family': 'V100-32GB',
53+
'StandardNCADSA100v4Family': 'A100-80GB',
54+
'standardNDAMSv4_A100Family': 'A100-80GB',
55+
'StandardNDASv4_A100Family': 'A100',
56+
'standardNVFamily': 'M60',
57+
'standardNVSv2Family': 'M60',
58+
'standardNVSv3Family': 'M60',
59+
'standardNVPromoFamily': 'M60',
60+
'standardNVSv4Family': 'Radeon MI25',
61+
'standardNDSFamily': 'P40',
62+
'StandardNVADSA10v5Family': 'A10',
63+
'StandardNCadsH100v5Family': 'H100',
64+
'standardNDSH100v5Family': 'H100',
65+
}
66+
3967

4068
def get_regions() -> List[str]:
4169
"""Get all available regions."""
@@ -78,7 +106,7 @@ def get_pricing_url(region: Optional[str] = None) -> str:
78106
def get_pricing_df(region: Optional[str] = None) -> pd.DataFrame:
79107
all_items = []
80108
url = get_pricing_url(region)
81-
print(f'Getting pricing for {region}')
109+
print(f'Getting pricing for {region}, url: {url}')
82110
page = 0
83111
while url is not None:
84112
page += 1
@@ -125,29 +153,11 @@ def get_sku_df(region_set: Set[str]) -> pd.DataFrame:
125153

126154

127155
def get_gpu_name(family: str) -> Optional[str]:
128-
gpu_data = {
129-
'standardNCFamily': 'K80',
130-
'standardNCSv2Family': 'P100',
131-
'standardNCSv3Family': 'V100',
132-
'standardNCPromoFamily': 'K80',
133-
'StandardNCASv3_T4Family': 'T4',
134-
'standardNDSv2Family': 'V100-32GB',
135-
'StandardNCADSA100v4Family': 'A100-80GB',
136-
'standardNDAMSv4_A100Family': 'A100-80GB',
137-
'StandardNDASv4_A100Family': 'A100',
138-
'standardNVFamily': 'M60',
139-
'standardNVSv2Family': 'M60',
140-
'standardNVSv3Family': 'M60',
141-
'standardNVPromoFamily': 'M60',
142-
'standardNVSv4Family': 'Radeon MI25',
143-
'standardNDSFamily': 'P40',
144-
'StandardNVADSA10v5Family': 'A10',
145-
}
146156
# NP-series offer Xilinx U250 FPGAs which are not GPUs,
147157
# so we do not include them here.
148158
# https://docs.microsoft.com/en-us/azure/virtual-machines/np-series
149159
family = family.replace(' ', '')
150-
return gpu_data.get(family)
160+
return FAMILY_NAME_TO_SKYPILOT_GPU_NAME.get(family)
151161

152162

153163
def get_all_regions_instance_types_df(region_set: Set[str]):

sky/utils/accelerator_registry.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,22 @@
66
# NOTE: Must include accelerators supported for local clusters.
77
#
88
# 1. What if a name is in this list, but not in any catalog?
9+
#
910
# The name will be canonicalized, but the accelerator will not be supported.
1011
# Optimizer will print an error message.
12+
#
1113
# 2. What if a name is not in this list, but in a catalog?
14+
#
1215
# The list is simply an optimization to short-circuit the search in the catalog.
1316
# If the name is not found in the list, it will be searched in the catalog
1417
# with its case being ignored. If a match is found, the name will be
1518
# canonicalized to that in the catalog. Note that this lookup can be an
1619
# expensive operation, as it requires reading the catalog or making external
1720
# API calls (such as for Kubernetes). Thus it is desirable to keep this list
1821
# up-to-date with commonly used accelerators.
22+
1923
# 3. (For SkyPilot dev) What to do if I want to add a new accelerator?
24+
#
2025
# Append its case-sensitive canonical name to this list. The name must match
2126
# `AcceleratorName` in the service catalog, or what we define in
2227
# `onprem_utils.get_local_cluster_accelerators`.
@@ -42,6 +47,7 @@
4247
'Radeon MI25',
4348
'P4',
4449
'L4',
50+
'H100',
4551
]
4652

4753

@@ -72,11 +78,11 @@ def canonicalize_accelerator_name(accelerator: str) -> str:
7278
if len(names) == 1:
7379
return names[0]
7480

75-
# Do not print an error meessage here. Optimizer will handle it.
81+
# Do not print an error message here. Optimizer will handle it.
7682
if len(names) == 0:
7783
return accelerator
7884

79-
# Currenlty unreachable.
85+
# Currently unreachable.
8086
# This can happen if catalogs have the same accelerator with
8187
# different names (e.g., A10g and A10G).
8288
assert len(names) > 1

0 commit comments

Comments
 (0)