16
16
This sub-module exists primarily to be used internally by the Cluster object
17
17
(in the cluster sub-module) for RayCluster/AppWrapper generation.
18
18
"""
19
- from typing import Union , Tuple , Dict
19
+ from typing import List , Union , Tuple , Dict
20
20
from ...common import _kube_api_error_handling
21
21
from ...common .kubernetes_cluster import get_api_client , config_check
22
22
from kubernetes .client .exceptions import ApiException
40
40
V1PodTemplateSpec ,
41
41
V1PodSpec ,
42
42
V1LocalObjectReference ,
43
+ V1Toleration ,
43
44
)
44
45
45
46
import yaml
@@ -139,7 +140,11 @@ def build_ray_cluster(cluster: "codeflare_sdk.ray.cluster.Cluster"):
139
140
"resources" : head_resources ,
140
141
},
141
142
"template" : {
142
- "spec" : get_pod_spec (cluster , [get_head_container_spec (cluster )])
143
+ "spec" : get_pod_spec (
144
+ cluster ,
145
+ [get_head_container_spec (cluster )],
146
+ cluster .config .head_tolerations ,
147
+ )
143
148
},
144
149
},
145
150
"workerGroupSpecs" : [
@@ -154,7 +159,11 @@ def build_ray_cluster(cluster: "codeflare_sdk.ray.cluster.Cluster"):
154
159
"resources" : worker_resources ,
155
160
},
156
161
"template" : V1PodTemplateSpec (
157
- spec = get_pod_spec (cluster , [get_worker_container_spec (cluster )])
162
+ spec = get_pod_spec (
163
+ cluster ,
164
+ [get_worker_container_spec (cluster )],
165
+ cluster .config .tolerations ,
166
+ )
158
167
),
159
168
}
160
169
],
@@ -243,14 +252,17 @@ def update_image(image) -> str:
243
252
return image
244
253
245
254
246
- def get_pod_spec (cluster : "codeflare_sdk.ray.cluster.Cluster" , containers ):
255
+ def get_pod_spec (cluster : "codeflare_sdk.ray.cluster.Cluster" , containers , tolerations ):
247
256
"""
248
257
The get_pod_spec() function generates a V1PodSpec for the head/worker containers
249
258
"""
250
- pod_spec = V1PodSpec (
251
- containers = containers ,
252
- volumes = VOLUMES ,
253
- )
259
+ if tolerations is None :
260
+ pod_spec = V1PodSpec (containers = containers , volumes = VOLUMES )
261
+ else :
262
+ pod_spec = V1PodSpec (
263
+ containers = containers , volumes = VOLUMES , tolerations = tolerations
264
+ )
265
+
254
266
if cluster .config .image_pull_secrets != []:
255
267
pod_spec .image_pull_secrets = generate_image_pull_secrets (cluster )
256
268
@@ -402,28 +414,18 @@ def head_worker_extended_resources_from_cluster(
402
414
resource_type = cluster .config .extended_resource_mapping [k ]
403
415
if resource_type in FORBIDDEN_CUSTOM_RESOURCE_TYPES :
404
416
continue
405
- head_worker_extended_resources [0 ][
406
- resource_type
407
- ] = cluster .config .head_extended_resource_requests [
408
- k
409
- ] + head_worker_extended_resources [
410
- 0
411
- ].get (
412
- resource_type , 0
417
+ head_worker_extended_resources [0 ][resource_type ] = (
418
+ cluster .config .head_extended_resource_requests [k ]
419
+ + head_worker_extended_resources [0 ].get (resource_type , 0 )
413
420
)
414
421
415
422
for k in cluster .config .worker_extended_resource_requests .keys ():
416
423
resource_type = cluster .config .extended_resource_mapping [k ]
417
424
if resource_type in FORBIDDEN_CUSTOM_RESOURCE_TYPES :
418
425
continue
419
- head_worker_extended_resources [1 ][
420
- resource_type
421
- ] = cluster .config .worker_extended_resource_requests [
422
- k
423
- ] + head_worker_extended_resources [
424
- 1
425
- ].get (
426
- resource_type , 0
426
+ head_worker_extended_resources [1 ][resource_type ] = (
427
+ cluster .config .worker_extended_resource_requests [k ]
428
+ + head_worker_extended_resources [1 ].get (resource_type , 0 )
427
429
)
428
430
return head_worker_extended_resources
429
431
0 commit comments