Skip to content

Commit fb85bf8

Browse files
committed
simplify function calls and add option for custom resources
Signed-off-by: Kevin <[email protected]>
1 parent 6798b74 commit fb85bf8

12 files changed

+466
-222
lines changed

src/codeflare_sdk/cluster/cluster.py

+40-57
Original file line numberDiff line numberDiff line change
@@ -131,52 +131,7 @@ def create_app_wrapper(self):
131131
# Validate image configuration
132132
self.validate_image_config()
133133

134-
# Before attempting to create the cluster AW, let's evaluate the ClusterConfig
135-
136-
name = self.config.name
137-
namespace = self.config.namespace
138-
head_cpus = self.config.head_cpus
139-
head_memory = self.config.head_memory
140-
head_gpus = self.config.head_gpus
141-
min_cpu = self.config.min_cpus
142-
max_cpu = self.config.max_cpus
143-
min_memory = self.config.min_memory
144-
max_memory = self.config.max_memory
145-
gpu = self.config.num_gpus
146-
workers = self.config.num_workers
147-
template = self.config.template
148-
image = self.config.image
149-
appwrapper = self.config.appwrapper
150-
instance_types = self.config.machine_types
151-
env = self.config.envs
152-
image_pull_secrets = self.config.image_pull_secrets
153-
write_to_file = self.config.write_to_file
154-
verify_tls = self.config.verify_tls
155-
local_queue = self.config.local_queue
156-
labels = self.config.labels
157-
return generate_appwrapper(
158-
name=name,
159-
namespace=namespace,
160-
head_cpus=head_cpus,
161-
head_memory=head_memory,
162-
head_gpus=head_gpus,
163-
min_cpu=min_cpu,
164-
max_cpu=max_cpu,
165-
min_memory=min_memory,
166-
max_memory=max_memory,
167-
gpu=gpu,
168-
workers=workers,
169-
template=template,
170-
image=image,
171-
appwrapper=appwrapper,
172-
instance_types=instance_types,
173-
env=env,
174-
image_pull_secrets=image_pull_secrets,
175-
write_to_file=write_to_file,
176-
verify_tls=verify_tls,
177-
local_queue=local_queue,
178-
labels=labels,
179-
)
134+
return generate_appwrapper(self)
180135

181136
# creates a new cluster with the provided or default spec
182137
def up(self):
@@ -460,6 +415,29 @@ def job_logs(self, job_id: str) -> str:
460415
"""
461416
return self.job_client.get_job_logs(job_id)
462417

418+
@staticmethod
419+
def _head_worker_resources_from_rc_dict(rc: Dict) -> Tuple[dict, dict]:
420+
head_custom_resources, worker_custom_resources = {}, {}
421+
for resource in rc["spec"]["workerGroupSpecs"][0]["template"]["spec"][
422+
"containers"
423+
][0]["resources"]["limits"].keys():
424+
if resource in ["memory", "cpu"]:
425+
continue
426+
worker_custom_resources[resource] = rc["spec"]["workerGroupSpecs"][0][
427+
"template"
428+
]["spec"]["containers"][0]["resources"]["limits"][resource]
429+
430+
for resource in rc["spec"]["headGroupSpec"]["template"]["spec"]["containers"][
431+
0
432+
]["resources"]["limits"].keys():
433+
if resource in ["memory", "cpu"]:
434+
continue
435+
head_custom_resources[resource] = rc["spec"]["headGroupSpec"][0][
436+
"template"
437+
]["spec"]["containers"][0]["resources"]["limits"][resource]
438+
439+
return head_custom_resources, worker_custom_resources
440+
463441
def from_k8_cluster_object(
464442
rc,
465443
appwrapper=True,
@@ -473,6 +451,11 @@ def from_k8_cluster_object(
473451
else []
474452
)
475453

454+
(
455+
head_custom_resources,
456+
worker_custom_resources,
457+
) = Cluster._head_worker_resources_from_rc_dict(rc)
458+
476459
cluster_config = ClusterConfiguration(
477460
name=rc["metadata"]["name"],
478461
namespace=rc["metadata"]["namespace"],
@@ -490,11 +473,8 @@ def from_k8_cluster_object(
490473
max_memory=rc["spec"]["workerGroupSpecs"][0]["template"]["spec"][
491474
"containers"
492475
][0]["resources"]["limits"]["memory"],
493-
num_gpus=int(
494-
rc["spec"]["workerGroupSpecs"][0]["template"]["spec"]["containers"][0][
495-
"resources"
496-
]["limits"]["nvidia.com/gpu"]
497-
),
476+
worker_custom_resource_requests=worker_custom_resources,
477+
head_custom_resource_requests=head_custom_resources,
498478
image=rc["spec"]["workerGroupSpecs"][0]["template"]["spec"]["containers"][
499479
0
500480
]["image"],
@@ -875,6 +855,11 @@ def _map_to_ray_cluster(rc) -> Optional[RayCluster]:
875855
protocol = "https"
876856
dashboard_url = f"{protocol}://{ingress.spec.rules[0].host}"
877857

858+
(
859+
head_custom_resources,
860+
worker_custom_resources,
861+
) = Cluster._head_worker_resources_from_rc_dict(rc)
862+
878863
return RayCluster(
879864
name=rc["metadata"]["name"],
880865
status=status,
@@ -889,17 +874,15 @@ def _map_to_ray_cluster(rc) -> Optional[RayCluster]:
889874
worker_cpu=rc["spec"]["workerGroupSpecs"][0]["template"]["spec"]["containers"][
890875
0
891876
]["resources"]["limits"]["cpu"],
892-
worker_gpu=0, # hard to detect currently how many gpus, can override it with what the user asked for
877+
worker_custom_resources=worker_custom_resources,
893878
namespace=rc["metadata"]["namespace"],
894879
head_cpus=rc["spec"]["headGroupSpec"]["template"]["spec"]["containers"][0][
895880
"resources"
896881
]["limits"]["cpu"],
897882
head_mem=rc["spec"]["headGroupSpec"]["template"]["spec"]["containers"][0][
898883
"resources"
899884
]["limits"]["memory"],
900-
head_gpu=rc["spec"]["headGroupSpec"]["template"]["spec"]["containers"][0][
901-
"resources"
902-
]["limits"]["nvidia.com/gpu"],
885+
head_custom_resources=head_custom_resources,
903886
dashboard=dashboard_url,
904887
)
905888

@@ -924,12 +907,12 @@ def _copy_to_ray(cluster: Cluster) -> RayCluster:
924907
worker_mem_min=cluster.config.min_memory,
925908
worker_mem_max=cluster.config.max_memory,
926909
worker_cpu=cluster.config.min_cpus,
927-
worker_gpu=cluster.config.num_gpus,
910+
worker_custom_resources=cluster.config.worker_custom_resource_requests,
928911
namespace=cluster.config.namespace,
929912
dashboard=cluster.cluster_dashboard_uri(),
930913
head_cpus=cluster.config.head_cpus,
931914
head_mem=cluster.config.head_memory,
932-
head_gpu=cluster.config.head_gpus,
915+
head_custom_resources=cluster.config.head_custom_resource_requests,
933916
)
934917
if ray.status == CodeFlareClusterStatus.READY:
935918
ray.status = RayClusterStatus.READY

src/codeflare_sdk/cluster/config.py

+57
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,22 @@
2121
from dataclasses import dataclass, field
2222
import pathlib
2323
import typing
24+
import warnings
2425

2526
dir = pathlib.Path(__file__).parent.parent.resolve()
2627

28+
# https://docs.ray.io/en/latest/ray-core/scheduling/accelerators.html
29+
DEFAULT_RESOURCE_MAPPING = {
30+
"nvidia.com/gpu": "GPU",
31+
"intel.com/gpu": "GPU",
32+
"amd.com/gpu": "GPU",
33+
"aws.amazon.com/neuroncore": "neuron_cores",
34+
"google.com/tpu": "TPU",
35+
"habana.ai/gaudi": "HPU",
36+
"huawei.com/Ascend910": "NPU",
37+
"huawei.com/Ascend310": "NPU",
38+
}
39+
2740

2841
@dataclass
2942
class ClusterConfiguration:
@@ -38,6 +51,7 @@ class ClusterConfiguration:
3851
head_cpus: typing.Union[int, str] = 2
3952
head_memory: typing.Union[int, str] = 8
4053
head_gpus: int = 0
54+
head_custom_resource_requests: typing.Dict[str, int] = field(default_factory=dict)
4155
machine_types: list = field(default_factory=list) # ["m4.xlarge", "g4dn.xlarge"]
4256
min_cpus: typing.Union[int, str] = 1
4357
max_cpus: typing.Union[int, str] = 1
@@ -53,6 +67,9 @@ class ClusterConfiguration:
5367
write_to_file: bool = False
5468
verify_tls: bool = True
5569
labels: dict = field(default_factory=dict)
70+
worker_custom_resource_requests: typing.Dict[str, int] = field(default_factory=dict)
71+
custom_resource_mapping: typing.Dict[str, str] = field(default_factory=dict)
72+
overwrite_default_resource_mapping: bool = False
5673

5774
def __post_init__(self):
5875
if not self.verify_tls:
@@ -61,6 +78,46 @@ def __post_init__(self):
6178
)
6279
self._memory_to_string()
6380
self._str_mem_no_unit_add_GB()
81+
self._gpu_to_resource()
82+
self._combine_custom_resource_mapping()
83+
84+
def _combine_custom_resource_mapping(self):
    """Merge the user-supplied ``custom_resource_mapping`` on top of the
    module-level ``DEFAULT_RESOURCE_MAPPING``.

    If a user key shadows a default entry, either raise ``ValueError`` (the
    default) or, when ``overwrite_default_resource_mapping`` is set, emit a
    ``UserWarning`` and let the user value win.
    """
    overwritten = set(self.custom_resource_mapping.keys()).intersection(
        DEFAULT_RESOURCE_MAPPING.keys()
    )
    if overwritten:
        if not self.overwrite_default_resource_mapping:
            raise ValueError(
                f"Resource mapping already exists for {overwritten}, set overwrite_default_resource_mapping to True to overwrite"
            )
        warnings.warn(
            f"Overwriting default resource mapping for {overwritten}",
            UserWarning,
        )
    # Defaults first, then user entries, so user values take precedence.
    merged = dict(DEFAULT_RESOURCE_MAPPING)
    merged.update(self.custom_resource_mapping)
    self.custom_resource_mapping = merged
def _gpu_to_resource(self):
103+
if self.head_gpus:
104+
warnings.warn(
105+
"head_gpus is being deprecated, use head_custom_resource_requests"
106+
)
107+
if "nvidia.com/gpu" in self.head_custom_resource_requests:
108+
raise ValueError(
109+
"nvidia.com/gpu already exists in head_custom_resource_requests"
110+
)
111+
self.head_custom_resource_requests["nvidia.com/gpu"] = self.head_gpus
112+
if self.num_gpus:
113+
warnings.warn(
114+
"num_gpus is being deprecated, use worker_custom_resource_requests"
115+
)
116+
if "nvidia.com/gpu" in self.worker_custom_resource_requests:
117+
raise ValueError(
118+
"nvidia.com/gpu already exists in worker_custom_resource_requests"
119+
)
120+
self.worker_custom_resource_requests["nvidia.com/gpu"] = self.num_gpus
64121

65122
def _str_mem_no_unit_add_GB(self):
66123
if isinstance(self.head_memory, str) and self.head_memory.isdecimal():

src/codeflare_sdk/cluster/model.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,9 @@
1818
dataclasses to store information for Ray clusters and AppWrappers.
1919
"""
2020

21-
from dataclasses import dataclass
21+
from dataclasses import dataclass, field
2222
from enum import Enum
23+
import typing
2324

2425

2526
class RayClusterStatus(Enum):
@@ -74,14 +75,14 @@ class RayCluster:
7475
status: RayClusterStatus
7576
head_cpus: int
7677
head_mem: str
77-
head_gpu: int
7878
workers: int
7979
worker_mem_min: str
8080
worker_mem_max: str
8181
worker_cpu: int
82-
worker_gpu: int
8382
namespace: str
8483
dashboard: str
84+
worker_custom_resources: typing.Dict[str, int] = field(default_factory=dict)
85+
head_custom_resources: typing.Dict[str, int] = field(default_factory=dict)
8586

8687

8788
@dataclass

src/codeflare_sdk/templates/base-template.yaml

-6
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,6 @@ spec:
7171
# The value of `resources` is a string-integer mapping.
7272
# Currently, `resources` must be provided in the specific format demonstrated below:
7373
# resources: '"{\"Custom1\": 1, \"Custom2\": 5}"'
74-
num-gpus: '0'
7574
#pod template
7675
template:
7776
spec:
@@ -95,11 +94,9 @@ spec:
9594
limits:
9695
cpu: 2
9796
memory: "8G"
98-
nvidia.com/gpu: 0
9997
requests:
10098
cpu: 2
10199
memory: "8G"
102-
nvidia.com/gpu: 0
103100
volumeMounts:
104101
- mountPath: /etc/pki/tls/certs/odh-trusted-ca-bundle.crt
105102
name: odh-trusted-ca-cert
@@ -147,7 +144,6 @@ spec:
147144
# the following params are used to complete the ray start: ray start --block ...
148145
rayStartParams:
149146
block: 'true'
150-
num-gpus: 1
151147
#pod template
152148
template:
153149
metadata:
@@ -172,11 +168,9 @@ spec:
172168
limits:
173169
cpu: "2"
174170
memory: "12G"
175-
nvidia.com/gpu: "1"
176171
requests:
177172
cpu: "2"
178173
memory: "12G"
179-
nvidia.com/gpu: "1"
180174
volumeMounts:
181175
- mountPath: /etc/pki/tls/certs/odh-trusted-ca-bundle.crt
182176
name: odh-trusted-ca-cert

0 commit comments

Comments
 (0)