from ..utils import pretty_print
from ..utils.generate_yaml import (
    generate_appwrapper,
+    head_worker_gpu_count_from_cluster,
)
from ..utils.kube_api_helpers import _kube_api_error_handling
from ..utils.generate_yaml import is_openshift_cluster
@@ -118,48 +119,7 @@ def create_app_wrapper(self):
                f"Namespace {self.config.namespace} is of type {type(self.config.namespace)}. Check your Kubernetes Authentication."
            )

-        # Before attempting to create the cluster AW, let's evaluate the ClusterConfig
-
-        name = self.config.name
-        namespace = self.config.namespace
-        head_cpus = self.config.head_cpus
-        head_memory = self.config.head_memory
-        num_head_gpus = self.config.num_head_gpus
-        worker_cpu_requests = self.config.worker_cpu_requests
-        worker_cpu_limits = self.config.worker_cpu_limits
-        worker_memory_requests = self.config.worker_memory_requests
-        worker_memory_limits = self.config.worker_memory_limits
-        num_worker_gpus = self.config.num_worker_gpus
-        workers = self.config.num_workers
-        template = self.config.template
-        image = self.config.image
-        appwrapper = self.config.appwrapper
-        env = self.config.envs
-        image_pull_secrets = self.config.image_pull_secrets
-        write_to_file = self.config.write_to_file
-        local_queue = self.config.local_queue
-        labels = self.config.labels
-        return generate_appwrapper(
-            name=name,
-            namespace=namespace,
-            head_cpus=head_cpus,
-            head_memory=head_memory,
-            num_head_gpus=num_head_gpus,
-            worker_cpu_requests=worker_cpu_requests,
-            worker_cpu_limits=worker_cpu_limits,
-            worker_memory_requests=worker_memory_requests,
-            worker_memory_limits=worker_memory_limits,
-            num_worker_gpus=num_worker_gpus,
-            workers=workers,
-            template=template,
-            image=image,
-            appwrapper=appwrapper,
-            env=env,
-            image_pull_secrets=image_pull_secrets,
-            write_to_file=write_to_file,
-            local_queue=local_queue,
-            labels=labels,
-        )
+        return generate_appwrapper(self)

    # creates a new cluster with the provided or default spec
    def up(self):
@@ -305,7 +265,7 @@ def status(
            if print_to_console:
                # overriding the number of gpus with requested
-                cluster.worker_gpu = self.config.num_worker_gpus
+                _, cluster.worker_gpu = head_worker_gpu_count_from_cluster(self)
                pretty_print.print_cluster_status(cluster)
        elif print_to_console:
            if status == CodeFlareClusterStatus.UNKNOWN:
@@ -443,6 +403,29 @@ def job_logs(self, job_id: str) -> str:
        """
        return self.job_client.get_job_logs(job_id)

+    @staticmethod
+    def _head_worker_extended_resources_from_rc_dict(rc: Dict) -> Tuple[dict, dict]:
+        head_extended_resources, worker_extended_resources = {}, {}
+        for resource in rc["spec"]["workerGroupSpecs"][0]["template"]["spec"][
+            "containers"
+        ][0]["resources"]["limits"].keys():
+            if resource in ["memory", "cpu"]:
+                continue
+            worker_extended_resources[resource] = rc["spec"]["workerGroupSpecs"][0][
+                "template"
+            ]["spec"]["containers"][0]["resources"]["limits"][resource]
+
+        for resource in rc["spec"]["headGroupSpec"]["template"]["spec"]["containers"][
+            0
+        ]["resources"]["limits"].keys():
+            if resource in ["memory", "cpu"]:
+                continue
+            head_extended_resources[resource] = rc["spec"]["headGroupSpec"]["template"][
+                "spec"
+            ]["containers"][0]["resources"]["limits"][resource]
+
+        return head_extended_resources, worker_extended_resources
+
    def from_k8_cluster_object(
        rc,
        appwrapper=True,
@@ -456,6 +439,11 @@ def from_k8_cluster_object(
            else []
        )

+        (
+            head_extended_resources,
+            worker_extended_resources,
+        ) = Cluster._head_worker_extended_resources_from_rc_dict(rc)
+
        cluster_config = ClusterConfiguration(
            name=rc["metadata"]["name"],
            namespace=rc["metadata"]["namespace"],
@@ -473,11 +461,8 @@ def from_k8_cluster_object(
            worker_memory_limits=rc["spec"]["workerGroupSpecs"][0]["template"]["spec"][
                "containers"
            ][0]["resources"]["limits"]["memory"],
-            num_worker_gpus=int(
-                rc["spec"]["workerGroupSpecs"][0]["template"]["spec"]["containers"][0][
-                    "resources"
-                ]["limits"]["nvidia.com/gpu"]
-            ),
+            worker_extended_resource_requests=worker_extended_resources,
+            head_extended_resource_requests=head_extended_resources,
            image=rc["spec"]["workerGroupSpecs"][0]["template"]["spec"]["containers"][
                0
            ]["image"],
@@ -858,6 +843,11 @@ def _map_to_ray_cluster(rc) -> Optional[RayCluster]:
            protocol = "https"
        dashboard_url = f"{protocol}://{ingress.spec.rules[0].host}"

+    (
+        head_extended_resources,
+        worker_extended_resources,
+    ) = Cluster._head_worker_extended_resources_from_rc_dict(rc)
+
    return RayCluster(
        name=rc["metadata"]["name"],
        status=status,
@@ -872,17 +862,15 @@ def _map_to_ray_cluster(rc) -> Optional[RayCluster]:
        worker_cpu=rc["spec"]["workerGroupSpecs"][0]["template"]["spec"]["containers"][
            0
        ]["resources"]["limits"]["cpu"],
-        worker_gpu=0,  # hard to detect currently how many gpus, can override it with what the user asked for
+        worker_extended_resources=worker_extended_resources,
        namespace=rc["metadata"]["namespace"],
        head_cpus=rc["spec"]["headGroupSpec"]["template"]["spec"]["containers"][0][
            "resources"
        ]["limits"]["cpu"],
        head_mem=rc["spec"]["headGroupSpec"]["template"]["spec"]["containers"][0][
            "resources"
        ]["limits"]["memory"],
-        head_gpu=rc["spec"]["headGroupSpec"]["template"]["spec"]["containers"][0][
-            "resources"
-        ]["limits"]["nvidia.com/gpu"],
+        head_extended_resources=head_extended_resources,
        dashboard=dashboard_url,
    )
@@ -907,12 +895,12 @@ def _copy_to_ray(cluster: Cluster) -> RayCluster:
        worker_mem_min=cluster.config.worker_memory_requests,
        worker_mem_max=cluster.config.worker_memory_limits,
        worker_cpu=cluster.config.worker_cpu_requests,
-        worker_gpu=cluster.config.num_worker_gpus,
+        worker_extended_resources=cluster.config.worker_extended_resource_requests,
        namespace=cluster.config.namespace,
        dashboard=cluster.cluster_dashboard_uri(),
        head_cpus=cluster.config.head_cpus,
        head_mem=cluster.config.head_memory,
-        head_gpu=cluster.config.num_head_gpus,
+        head_extended_resources=cluster.config.head_extended_resource_requests,
    )
    if ray.status == CodeFlareClusterStatus.READY:
        ray.status = RayClusterStatus.READY
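
For reference, the new `_head_worker_extended_resources_from_rc_dict` helper added in this diff just walks the container resource limits of the head group and first worker group and keeps everything that is not plain `cpu` or `memory`. Below is a minimal standalone sketch of that behaviour against a made-up RayCluster spec dict; the GPU counts and resource names are illustrative only and not taken from this PR.

```python
from typing import Dict

def extended_resources_from_limits(limits: Dict) -> Dict:
    # Keep every resource limit except cpu/memory, mirroring the helper in this diff.
    return {k: v for k, v in limits.items() if k not in ("memory", "cpu")}

# Illustrative RayCluster dict containing only the fields the helper reads.
rc = {
    "spec": {
        "headGroupSpec": {
            "template": {
                "spec": {
                    "containers": [
                        {"resources": {"limits": {"cpu": 2, "memory": "8G", "nvidia.com/gpu": 1}}}
                    ]
                }
            }
        },
        "workerGroupSpecs": [
            {
                "template": {
                    "spec": {
                        "containers": [
                            {"resources": {"limits": {"cpu": 4, "memory": "16G", "nvidia.com/gpu": 2}}}
                        ]
                    }
                }
            }
        ],
    }
}

head = extended_resources_from_limits(
    rc["spec"]["headGroupSpec"]["template"]["spec"]["containers"][0]["resources"]["limits"]
)
worker = extended_resources_from_limits(
    rc["spec"]["workerGroupSpecs"][0]["template"]["spec"]["containers"][0]["resources"]["limits"]
)
print(head)    # {'nvidia.com/gpu': 1}
print(worker)  # {'nvidia.com/gpu': 2}
```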