Skip to content

Commit 49bb0f1

Browse files
committed
simplify function calls and add option for custom resources
Signed-off-by: Kevin <[email protected]>
1 parent 20476aa commit 49bb0f1

12 files changed

+239
-193
lines changed

src/codeflare_sdk/cluster/cluster.py

+40-57
Original file line numberDiff line numberDiff line change
@@ -131,52 +131,7 @@ def create_app_wrapper(self):
131131
# Validate image configuration
132132
self.validate_image_config()
133133

134-
# Before attempting to create the cluster AW, let's evaluate the ClusterConfig
135-
136-
name = self.config.name
137-
namespace = self.config.namespace
138-
head_cpus = self.config.head_cpus
139-
head_memory = self.config.head_memory
140-
head_gpus = self.config.head_gpus
141-
min_cpu = self.config.min_cpus
142-
max_cpu = self.config.max_cpus
143-
min_memory = self.config.min_memory
144-
max_memory = self.config.max_memory
145-
gpu = self.config.num_gpus
146-
workers = self.config.num_workers
147-
template = self.config.template
148-
image = self.config.image
149-
appwrapper = self.config.appwrapper
150-
env = self.config.envs
151-
image_pull_secrets = self.config.image_pull_secrets
152-
write_to_file = self.config.write_to_file
153-
local_queue = self.config.local_queue
154-
labels = self.config.labels
155-
volumes = self.config.volumes
156-
volume_mounts = self.config.volume_mounts
157-
return generate_appwrapper(
158-
name=name,
159-
namespace=namespace,
160-
head_cpus=head_cpus,
161-
head_memory=head_memory,
162-
head_gpus=head_gpus,
163-
min_cpu=min_cpu,
164-
max_cpu=max_cpu,
165-
min_memory=min_memory,
166-
max_memory=max_memory,
167-
gpu=gpu,
168-
workers=workers,
169-
template=template,
170-
image=image,
171-
appwrapper=appwrapper,
172-
env=env,
173-
image_pull_secrets=image_pull_secrets,
174-
write_to_file=write_to_file,
175-
local_queue=local_queue,
176-
labels=labels,
177-
volumes=volumes,
178-
volume_mounts=volume_mounts,
179-
)
134+
return generate_appwrapper(self)
180135

181136
# creates a new cluster with the provided or default spec
182137
def up(self):
@@ -460,6 +415,29 @@ def job_logs(self, job_id: str) -> str:
460415
"""
461416
return self.job_client.get_job_logs(job_id)
462417

418+
@staticmethod
419+
def _head_worker_extended_resources_from_rc_dict(rc: Dict) -> Tuple[dict, dict]:
420+
head_extended_resources, worker_extended_resources = {}, {}
421+
for resource in rc["spec"]["workerGroupSpecs"][0]["template"]["spec"][
422+
"containers"
423+
][0]["resources"]["limits"].keys():
424+
if resource in ["memory", "cpu"]:
425+
continue
426+
worker_extended_resources[resource] = rc["spec"]["workerGroupSpecs"][0][
427+
"template"
428+
]["spec"]["containers"][0]["resources"]["limits"][resource]
429+
430+
for resource in rc["spec"]["headGroupSpec"]["template"]["spec"]["containers"][
431+
0
432+
]["resources"]["limits"].keys():
433+
if resource in ["memory", "cpu"]:
434+
continue
435+
head_extended_resources[resource] = rc["spec"]["headGroupSpec"]["template"][
436+
"spec"
437+
]["containers"][0]["resources"]["limits"][resource]
438+
439+
return head_extended_resources, worker_extended_resources
440+
463441
def from_k8_cluster_object(
464442
rc,
465443
appwrapper=True,
@@ -473,6 +451,11 @@ def from_k8_cluster_object(
473451
else []
474452
)
475453

454+
(
455+
head_extended_resources,
456+
worker_extended_resources,
457+
) = Cluster._head_worker_extended_resources_from_rc_dict(rc)
458+
476459
cluster_config = ClusterConfiguration(
477460
name=rc["metadata"]["name"],
478461
namespace=rc["metadata"]["namespace"],
@@ -490,11 +473,8 @@ def from_k8_cluster_object(
490473
max_memory=rc["spec"]["workerGroupSpecs"][0]["template"]["spec"][
491474
"containers"
492475
][0]["resources"]["limits"]["memory"],
493-
num_gpus=int(
494-
rc["spec"]["workerGroupSpecs"][0]["template"]["spec"]["containers"][0][
495-
"resources"
496-
]["limits"]["nvidia.com/gpu"]
497-
),
476+
worker_extended_resource_requests=worker_extended_resources,
477+
head_extended_resource_requests=head_extended_resources,
498478
image=rc["spec"]["workerGroupSpecs"][0]["template"]["spec"]["containers"][
499479
0
500480
]["image"],
@@ -875,6 +855,11 @@ def _map_to_ray_cluster(rc) -> Optional[RayCluster]:
875855
protocol = "https"
876856
dashboard_url = f"{protocol}://{ingress.spec.rules[0].host}"
877857

858+
(
859+
head_extended_resources,
860+
worker_extended_resources,
861+
) = Cluster._head_worker_extended_resources_from_rc_dict(rc)
862+
878863
return RayCluster(
879864
name=rc["metadata"]["name"],
880865
status=status,
@@ -889,17 +874,15 @@ def _map_to_ray_cluster(rc) -> Optional[RayCluster]:
889874
worker_cpu=rc["spec"]["workerGroupSpecs"][0]["template"]["spec"]["containers"][
890875
0
891876
]["resources"]["limits"]["cpu"],
892-
worker_gpu=0, # hard to detect currently how many gpus, can override it with what the user asked for
877+
worker_extended_resources=worker_extended_resources,
893878
namespace=rc["metadata"]["namespace"],
894879
head_cpus=rc["spec"]["headGroupSpec"]["template"]["spec"]["containers"][0][
895880
"resources"
896881
]["limits"]["cpu"],
897882
head_mem=rc["spec"]["headGroupSpec"]["template"]["spec"]["containers"][0][
898883
"resources"
899884
]["limits"]["memory"],
900-
head_gpu=rc["spec"]["headGroupSpec"]["template"]["spec"]["containers"][0][
901-
"resources"
902-
]["limits"]["nvidia.com/gpu"],
885+
head_extended_resources=head_extended_resources,
903886
dashboard=dashboard_url,
904887
)
905888

@@ -924,12 +907,12 @@ def _copy_to_ray(cluster: Cluster) -> RayCluster:
924907
worker_mem_min=cluster.config.min_memory,
925908
worker_mem_max=cluster.config.max_memory,
926909
worker_cpu=cluster.config.min_cpus,
927-
worker_gpu=cluster.config.num_gpus,
910+
worker_extended_resources=cluster.config.worker_extended_resource_requests,
928911
namespace=cluster.config.namespace,
929912
dashboard=cluster.cluster_dashboard_uri(),
930913
head_cpus=cluster.config.head_cpus,
931914
head_mem=cluster.config.head_memory,
932-
head_gpu=cluster.config.head_gpus,
915+
head_extended_resources=cluster.config.head_extended_resource_requests,
933916
)
934917
if ray.status == CodeFlareClusterStatus.READY:
935918
ray.status = RayClusterStatus.READY

src/codeflare_sdk/cluster/config.py

+59
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,22 @@
2121
from dataclasses import dataclass, field
2222
import pathlib
2323
import typing
24+
import warnings
2425

2526
dir = pathlib.Path(__file__).parent.parent.resolve()
2627

28+
# https://docs.ray.io/en/latest/ray-core/scheduling/accelerators.html
29+
DEFAULT_RESOURCE_MAPPING = {
30+
"nvidia.com/gpu": "GPU",
31+
"intel.com/gpu": "GPU",
32+
"amd.com/gpu": "GPU",
33+
"aws.amazon.com/neuroncore": "neuron_cores",
34+
"google.com/tpu": "TPU",
35+
"habana.ai/gaudi": "HPU",
36+
"huawei.com/Ascend910": "NPU",
37+
"huawei.com/Ascend310": "NPU",
38+
}
39+
2740

2841
@dataclass
2942
class ClusterConfiguration:
@@ -38,6 +51,7 @@ class ClusterConfiguration:
3851
head_cpus: typing.Union[int, str] = 2
3952
head_memory: typing.Union[int, str] = 8
4053
head_gpus: int = 0
54+
head_extended_resource_requests: typing.Dict[str, int] = field(default_factory=dict)
4155
machine_types: list = field(default_factory=list) # ["m4.xlarge", "g4dn.xlarge"]
4256
min_cpus: typing.Union[int, str] = 1
4357
max_cpus: typing.Union[int, str] = 1
@@ -55,6 +69,11 @@ class ClusterConfiguration:
5569
labels: dict = field(default_factory=dict)
5670
volumes: list = field(default_factory=list)
5771
volume_mounts: list = field(default_factory=list)
72+
worker_extended_resource_requests: typing.Dict[str, int] = field(
73+
default_factory=dict
74+
)
75+
custom_resource_mapping: typing.Dict[str, str] = field(default_factory=dict)
76+
overwrite_default_resource_mapping: bool = False
5877

5978
def __post_init__(self):
6079
if not self.verify_tls:
@@ -63,6 +82,46 @@ def __post_init__(self):
6382
)
6483
self._memory_to_string()
6584
self._str_mem_no_unit_add_GB()
85+
self._gpu_to_resource()
86+
self._combine_custom_resource_mapping()
87+
88+
def _combine_custom_resource_mapping(self):
89+
if overwritten := set(self.custom_resource_mapping.keys()).intersection(
90+
DEFAULT_RESOURCE_MAPPING.keys()
91+
):
92+
if self.overwrite_default_resource_mapping:
93+
warnings.warn(
94+
f"Overwriting default resource mapping for {overwritten}",
95+
UserWarning,
96+
)
97+
else:
98+
raise ValueError(
99+
f"Resource mapping already exists for {overwritten}, set overwrite_default_resource_mapping to True to overwrite"
100+
)
101+
self.custom_resource_mapping = {
102+
**DEFAULT_RESOURCE_MAPPING,
103+
**self.custom_resource_mapping,
104+
}
105+
106+
def _gpu_to_resource(self):
107+
if self.head_gpus:
108+
warnings.warn(
109+
"head_gpus is being deprecated, use head_extended_resource_requests"
110+
)
111+
if "nvidia.com/gpu" in self.head_extended_resource_requests:
112+
raise ValueError(
113+
"nvidia.com/gpu already exists in head_extended_resource_requests"
114+
)
115+
self.head_extended_resource_requests["nvidia.com/gpu"] = self.head_gpus
116+
if self.num_gpus:
117+
warnings.warn(
118+
"num_gpus is being deprecated, use worker_extended_resource_requests"
119+
)
120+
if "nvidia.com/gpu" in self.worker_extended_resource_requests:
121+
raise ValueError(
122+
"nvidia.com/gpu already exists in worker_extended_resource_requests"
123+
)
124+
self.worker_extended_resource_requests["nvidia.com/gpu"] = self.num_gpus
66125

67126
def _str_mem_no_unit_add_GB(self):
68127
if isinstance(self.head_memory, str) and self.head_memory.isdecimal():

src/codeflare_sdk/cluster/model.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,9 @@
1818
dataclasses to store information for Ray clusters and AppWrappers.
1919
"""
2020

21-
from dataclasses import dataclass
21+
from dataclasses import dataclass, field
2222
from enum import Enum
23+
import typing
2324

2425

2526
class RayClusterStatus(Enum):
@@ -74,14 +75,14 @@ class RayCluster:
7475
status: RayClusterStatus
7576
head_cpus: int
7677
head_mem: str
77-
head_gpu: int
7878
workers: int
7979
worker_mem_min: str
8080
worker_mem_max: str
8181
worker_cpu: int
82-
worker_gpu: int
8382
namespace: str
8483
dashboard: str
84+
worker_extended_resources: typing.Dict[str, int] = field(default_factory=dict)
85+
head_extended_resources: typing.Dict[str, int] = field(default_factory=dict)
8586

8687

8788
@dataclass

src/codeflare_sdk/templates/base-template.yaml

-4
Original file line numberDiff line numberDiff line change
@@ -86,11 +86,9 @@ spec:
8686
limits:
8787
cpu: 2
8888
memory: "8G"
89-
nvidia.com/gpu: 0
9089
requests:
9190
cpu: 2
9291
memory: "8G"
93-
nvidia.com/gpu: 0
9492
volumeMounts:
9593
- mountPath: /etc/pki/tls/certs/odh-trusted-ca-bundle.crt
9694
name: odh-trusted-ca-cert
@@ -163,11 +161,9 @@ spec:
163161
limits:
164162
cpu: "2"
165163
memory: "12G"
166-
nvidia.com/gpu: "1"
167164
requests:
168165
cpu: "2"
169166
memory: "12G"
170-
nvidia.com/gpu: "1"
171167
volumeMounts:
172168
- mountPath: /etc/pki/tls/certs/odh-trusted-ca-bundle.crt
173169
name: odh-trusted-ca-cert

0 commit comments

Comments
 (0)