dir = pathlib.Path(__file__).parent.parent.resolve()

+# https://docs.ray.io/en/latest/ray-core/scheduling/accelerators.html
+DEFAULT_RESOURCE_MAPPING = {
+    "nvidia.com/gpu": "GPU",
+    "intel.com/gpu": "GPU",
+    "amd.com/gpu": "GPU",
+    "aws.amazon.com/neuroncore": "neuron_cores",
+    "google.com/tpu": "TPU",
+    "habana.ai/gaudi": "HPU",
+    "huawei.com/Ascend910": "NPU",
+    "huawei.com/Ascend310": "NPU",
+}
+

@dataclass
class ClusterConfiguration:
    """
    This dataclass is used to specify resource requirements and other details, and
    is passed in as an argument when creating a Cluster object.
+
+    Attributes:
+    - name: The name of the cluster.
+    - namespace: The namespace in which the cluster should be created.
+    - head_info: A list of strings containing information about the head node.
+    - head_cpus: The number of CPUs to allocate to the head node.
+    - head_memory: The amount of memory to allocate to the head node.
+    - head_gpus: The number of GPUs to allocate to the head node. (Deprecated, use head_extended_resource_requests)
+    - head_extended_resource_requests: A dictionary of extended resource requests for the head node. ex: {"nvidia.com/gpu": 1}
+    - machine_types: A list of machine types to use for the cluster.
+    - min_cpus: The minimum number of CPUs to allocate to each worker.
+    - max_cpus: The maximum number of CPUs to allocate to each worker.
+    - num_workers: The number of workers to create.
+    - min_memory: The minimum amount of memory to allocate to each worker.
+    - max_memory: The maximum amount of memory to allocate to each worker.
+    - num_gpus: The number of GPUs to allocate to each worker. (Deprecated, use worker_extended_resource_requests)
+    - template: The path to the template file to use for the cluster.
+    - appwrapper: A boolean indicating whether to use an AppWrapper.
+    - envs: A dictionary of environment variables to set for the cluster.
+    - image: The image to use for the cluster.
+    - image_pull_secrets: A list of image pull secrets to use for the cluster.
+    - write_to_file: A boolean indicating whether to write the cluster configuration to a file.
+    - verify_tls: A boolean indicating whether to verify TLS when connecting to the cluster.
+    - labels: A dictionary of labels to apply to the cluster.
+    - worker_extended_resource_requests: A dictionary of extended resource requests for each worker. ex: {"nvidia.com/gpu": 1}
+    - extended_resource_mapping: A dictionary of custom resource mappings to map extended resource requests to RayCluster resource names.
+    - overwrite_default_resource_mapping: A boolean indicating whether to overwrite the default resource mapping.
    """

    name: str
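For orientation, here is a minimal usage sketch of the new extended-resource fields documented above; the import path, names, and values are illustrative assumptions and not part of this diff:

# Illustrative only: the import path below is an assumption and may differ.
from codeflare_sdk.cluster.config import ClusterConfiguration

config = ClusterConfiguration(
    name="raytest",        # example values, not defaults
    namespace="default",
    num_workers=2,
    head_extended_resource_requests={"nvidia.com/gpu": 1},
    worker_extended_resource_requests={"nvidia.com/gpu": 1},
)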
@@ -39,7 +78,7 @@ class ClusterConfiguration:
    head_cpus: typing.Union[int, str] = 2
    head_memory: typing.Union[int, str] = 8
    head_gpus: int = None  # Deprecating
-    num_head_gpus: int = 0
+    head_extended_resource_requests: typing.Dict[str, int] = field(default_factory=dict)
    machine_types: list = field(default_factory=list)  # ["m4.xlarge", "g4dn.xlarge"]
    worker_cpu_requests: typing.Union[int, str] = 1
    worker_cpu_limits: typing.Union[int, str] = 1
@@ -50,7 +89,6 @@ class ClusterConfiguration:
    worker_memory_limits: typing.Union[int, str] = 2
    min_memory: typing.Union[int, str] = None  # Deprecating
    max_memory: typing.Union[int, str] = None  # Deprecating
-    num_worker_gpus: int = 0
    num_gpus: int = None  # Deprecating
    template: str = f"{dir}/templates/base-template.yaml"
    appwrapper: bool = False
@@ -60,6 +98,11 @@ class ClusterConfiguration:
    write_to_file: bool = False
    verify_tls: bool = True
    labels: dict = field(default_factory=dict)
+    worker_extended_resource_requests: typing.Dict[str, int] = field(
+        default_factory=dict
+    )
+    extended_resource_mapping: typing.Dict[str, str] = field(default_factory=dict)
+    overwrite_default_resource_mapping: bool = False

    def __post_init__(self):
        if not self.verify_tls:
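A hedged sketch of the two new mapping fields added in the hunk above; the resource key example.com/accel and the Ray resource name ACCEL are made-up placeholders, not values defined by this diff:

# Map a custom Kubernetes extended resource onto a Ray resource name.
config = ClusterConfiguration(
    name="custom-accel",
    worker_extended_resource_requests={"example.com/accel": 1},
    extended_resource_mapping={"example.com/accel": "ACCEL"},
)

# Re-mapping a key that already exists in DEFAULT_RESOURCE_MAPPING requires opting in;
# otherwise __post_init__ raises ValueError (see _combine_extended_resource_mapping below).
config = ClusterConfiguration(
    name="remapped-gpu",
    extended_resource_mapping={"nvidia.com/gpu": "CUSTOM_GPU"},
    overwrite_default_resource_mapping=True,
)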
@@ -69,9 +112,64 @@ def __post_init__(self):

        self._memory_to_string()
        self._str_mem_no_unit_add_GB()
+        self._old_gpu_arg_conversion()
        self._memory_to_resource()
-        self._gpu_to_resource()
        self._cpu_to_resource()
+        self._gpu_to_resource()
+        self._combine_extended_resource_mapping()
+        self._validate_extended_resource_requests(self.head_extended_resource_requests)
+        self._validate_extended_resource_requests(
+            self.worker_extended_resource_requests
+        )
+
+    def _combine_extended_resource_mapping(self):
+        if overwritten := set(self.extended_resource_mapping.keys()).intersection(
+            DEFAULT_RESOURCE_MAPPING.keys()
+        ):
+            if self.overwrite_default_resource_mapping:
+                warnings.warn(
+                    f"Overwriting default resource mapping for {overwritten}",
+                    UserWarning,
+                )
+            else:
+                raise ValueError(
+                    f"Resource mapping already exists for {overwritten}, set overwrite_default_resource_mapping to True to overwrite"
+                )
+        self.extended_resource_mapping = {
+            **DEFAULT_RESOURCE_MAPPING,
+            **self.extended_resource_mapping,
+        }
+
+    def _validate_extended_resource_requests(
+        self, extended_resources: typing.Dict[str, int]
+    ):
+        for k in extended_resources.keys():
+            if k not in self.extended_resource_mapping.keys():
+                raise ValueError(
+                    f"extended resource '{k}' not found in extended_resource_mapping, available resources are {list(self.extended_resource_mapping.keys())}, to add more supported resources use extended_resource_mapping. i.e. extended_resource_mapping = {{'{k}': 'FOO_BAR'}}"
+                )
+
+    def _gpu_to_resource(self):
+        if self.head_gpus:
+            warnings.warn(
+                "head_gpus is being deprecated, use head_extended_resource_requests i.e. head_extended_resource_requests = {'nvidia.com/gpu': 1}"
+            )
+            if "nvidia.com/gpu" in self.head_extended_resource_requests:
+                raise ValueError(
+                    "nvidia.com/gpu already exists in head_extended_resource_requests"
+                )
+            self.head_extended_resource_requests["nvidia.com/gpu"] = self.num_head_gpus
+        if self.num_gpus:
+            warnings.warn(
+                "num_gpus is being deprecated, use worker_extended_resource_requests instead i.e. worker_extended_resource_requests = {'nvidia.com/gpu': 1}"
+            )
+            if "nvidia.com/gpu" in self.worker_extended_resource_requests:
+                raise ValueError(
+                    "nvidia.com/gpu already exists in worker_extended_resource_requests"
+                )
+            self.worker_extended_resource_requests[
+                "nvidia.com/gpu"
+            ] = self.num_worker_gpus

    def _str_mem_no_unit_add_GB(self):
        if isinstance(self.head_memory, str) and self.head_memory.isdecimal():
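Two quick illustrations of the behaviour added in this hunk, assuming only what the code above shows; the resource key example.com/unknown and the name CUSTOM_GPU are placeholders:

# An unknown extended resource key raises ValueError from
# _validate_extended_resource_requests(), with a hint to extend extended_resource_mapping.
try:
    ClusterConfiguration(
        name="bad-resource",
        worker_extended_resource_requests={"example.com/unknown": 1},
    )
except ValueError as err:
    print(err)

# Merge order in _combine_extended_resource_mapping(): the user mapping is unpacked last,
# so its entries win over DEFAULT_RESOURCE_MAPPING (only permitted when
# overwrite_default_resource_mapping=True).
merged = {**DEFAULT_RESOURCE_MAPPING, **{"nvidia.com/gpu": "CUSTOM_GPU"}}
print(merged["nvidia.com/gpu"])  # CUSTOM_GPU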
@@ -95,7 +193,7 @@ def _memory_to_string(self):
        if isinstance(self.worker_memory_limits, int):
            self.worker_memory_limits = f"{self.worker_memory_limits}G"

-    def _gpu_to_resource(self):
+    def _old_gpu_arg_conversion(self):
        if self.head_gpus:
            warnings.warn("head_gpus is being deprecated, use num_head_gpus")
            self.num_head_gpus = self.head_gpus
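Based only on the code shown in this diff, the renamed _old_gpu_arg_conversion together with the new _gpu_to_resource keeps the deprecated integer arguments working by converting them into extended resource requests; a small sketch of the expected behaviour:

# head_gpus=1 first populates num_head_gpus (with a deprecation warning), then
# _gpu_to_resource() converts it into an extended resource request.
legacy = ClusterConfiguration(name="legacy-gpus", head_gpus=1)
print(legacy.head_extended_resource_requests)  # expected: {'nvidia.com/gpu': 1}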