Commit 536f0c8

simplify function calls and add option for custom resources

Signed-off-by: Kevin <[email protected]>
1 parent 1ab5421 commit 536f0c8

12 files changed: +246 -183 lines changed

src/codeflare_sdk/cluster/cluster.py (+40 -53)
@@ -131,48 +131,7 @@ def create_app_wrapper(self):
         # Validate image configuration
         self.validate_image_config()

-        # Before attempting to create the cluster AW, let's evaluate the ClusterConfig
-
-        name = self.config.name
-        namespace = self.config.namespace
-        head_cpus = self.config.head_cpus
-        head_memory = self.config.head_memory
-        num_head_gpus = self.config.num_head_gpus
-        worker_cpu_requests = self.config.worker_cpu_requests
-        worker_cpu_limits = self.config.worker_cpu_limits
-        worker_memory_requests = self.config.worker_memory_requests
-        worker_memory_limits = self.config.worker_memory_limits
-        num_worker_gpus = self.config.num_worker_gpus
-        workers = self.config.num_workers
-        template = self.config.template
-        image = self.config.image
-        appwrapper = self.config.appwrapper
-        env = self.config.envs
-        image_pull_secrets = self.config.image_pull_secrets
-        write_to_file = self.config.write_to_file
-        local_queue = self.config.local_queue
-        labels = self.config.labels
-        return generate_appwrapper(
-            name=name,
-            namespace=namespace,
-            head_cpus=head_cpus,
-            head_memory=head_memory,
-            num_head_gpus=num_head_gpus,
-            worker_cpu_requests=worker_cpu_requests,
-            worker_cpu_limits=worker_cpu_limits,
-            worker_memory_requests=worker_memory_requests,
-            worker_memory_limits=worker_memory_limits,
-            num_worker_gpus=num_worker_gpus,
-            workers=workers,
-            template=template,
-            image=image,
-            appwrapper=appwrapper,
-            env=env,
-            image_pull_secrets=image_pull_secrets,
-            write_to_file=write_to_file,
-            local_queue=local_queue,
-            labels=labels,
-        )
+        return generate_appwrapper(self)

     # creates a new cluster with the provided or default spec
     def up(self):
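
Note: this hunk replaces roughly forty lines of keyword plumbing with a single call that hands the whole Cluster object to the generator, which then reads the configuration off it directly. A minimal sketch of the pattern (the class and field names below are illustrative stand-ins, not the SDK's exact internals):

from dataclasses import dataclass, field


@dataclass
class Config:
    name: str
    namespace: str = "default"
    envs: dict = field(default_factory=dict)


@dataclass
class Cluster:
    config: Config

    def create_app_wrapper(self) -> dict:
        # Pass the cluster itself; the generator reads self.config directly,
        # so new configuration fields no longer have to be threaded through here.
        return generate_appwrapper(self)


def generate_appwrapper(cluster: "Cluster") -> dict:
    cfg = cluster.config
    return {"metadata": {"name": cfg.name, "namespace": cfg.namespace}}


print(Cluster(Config(name="demo")).create_app_wrapper())

The upside of this design is that adding a new ClusterConfiguration option only touches the config class and the generator, not every call site in between.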
@@ -456,6 +415,29 @@ def job_logs(self, job_id: str) -> str:
         """
         return self.job_client.get_job_logs(job_id)

+    @staticmethod
+    def _head_worker_extended_resources_from_rc_dict(rc: Dict) -> Tuple[dict, dict]:
+        head_extended_resources, worker_extended_resources = {}, {}
+        for resource in rc["spec"]["workerGroupSpecs"][0]["template"]["spec"][
+            "containers"
+        ][0]["resources"]["limits"].keys():
+            if resource in ["memory", "cpu"]:
+                continue
+            worker_extended_resources[resource] = rc["spec"]["workerGroupSpecs"][0][
+                "template"
+            ]["spec"]["containers"][0]["resources"]["limits"][resource]
+
+        for resource in rc["spec"]["headGroupSpec"]["template"]["spec"]["containers"][
+            0
+        ]["resources"]["limits"].keys():
+            if resource in ["memory", "cpu"]:
+                continue
+            head_extended_resources[resource] = rc["spec"]["headGroupSpec"]["template"][
+                "spec"
+            ]["containers"][0]["resources"]["limits"][resource]
+
+        return head_extended_resources, worker_extended_resources
+
     def from_k8_cluster_object(
         rc,
         appwrapper=True,
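
Note: the new helper walks the head and worker container limits and keeps every key that is not cpu or memory. A standalone sketch of the same extraction applied to a trimmed RayCluster dict (this re-implements the logic for illustration; it is not the SDK call itself):

# Trimmed RayCluster spec: only the fields the helper actually reads.
rc = {
    "spec": {
        "headGroupSpec": {
            "template": {
                "spec": {
                    "containers": [
                        {"resources": {"limits": {"cpu": 2, "memory": "8G", "nvidia.com/gpu": 1}}}
                    ]
                }
            }
        },
        "workerGroupSpecs": [
            {
                "template": {
                    "spec": {
                        "containers": [
                            {"resources": {"limits": {"cpu": 4, "memory": "16G", "intel.com/gpu": 2}}}
                        ]
                    }
                }
            }
        ],
    }
}


def extended_resources(limits: dict) -> dict:
    # Everything except cpu/memory counts as an extended resource.
    return {k: v for k, v in limits.items() if k not in ("cpu", "memory")}


head = extended_resources(
    rc["spec"]["headGroupSpec"]["template"]["spec"]["containers"][0]["resources"]["limits"]
)
worker = extended_resources(
    rc["spec"]["workerGroupSpecs"][0]["template"]["spec"]["containers"][0]["resources"]["limits"]
)
print(head)    # {'nvidia.com/gpu': 1}
print(worker)  # {'intel.com/gpu': 2}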
@@ -469,6 +451,11 @@ def from_k8_cluster_object(
             else []
         )

+        (
+            head_extended_resources,
+            worker_extended_resources,
+        ) = Cluster._head_worker_extended_resources_from_rc_dict(rc)
+
         cluster_config = ClusterConfiguration(
             name=rc["metadata"]["name"],
             namespace=rc["metadata"]["namespace"],
@@ -486,11 +473,8 @@ def from_k8_cluster_object(
             worker_memory_limits=rc["spec"]["workerGroupSpecs"][0]["template"]["spec"][
                 "containers"
             ][0]["resources"]["limits"]["memory"],
-            num_worker_gpus=int(
-                rc["spec"]["workerGroupSpecs"][0]["template"]["spec"]["containers"][0][
-                    "resources"
-                ]["limits"]["nvidia.com/gpu"]
-            ),
+            worker_extended_resource_requests=worker_extended_resources,
+            head_extended_resource_requests=head_extended_resources,
             image=rc["spec"]["workerGroupSpecs"][0]["template"]["spec"]["containers"][
                 0
             ]["image"],
@@ -871,6 +855,11 @@ def _map_to_ray_cluster(rc) -> Optional[RayCluster]:
             protocol = "https"
         dashboard_url = f"{protocol}://{ingress.spec.rules[0].host}"

+    (
+        head_extended_resources,
+        worker_extended_resources,
+    ) = Cluster._head_worker_extended_resources_from_rc_dict(rc)
+
     return RayCluster(
         name=rc["metadata"]["name"],
         status=status,
@@ -885,17 +874,15 @@ def _map_to_ray_cluster(rc) -> Optional[RayCluster]:
         worker_cpu=rc["spec"]["workerGroupSpecs"][0]["template"]["spec"]["containers"][
             0
         ]["resources"]["limits"]["cpu"],
-        worker_gpu=0,  # hard to detect currently how many gpus, can override it with what the user asked for
+        worker_extended_resources=worker_extended_resources,
         namespace=rc["metadata"]["namespace"],
         head_cpus=rc["spec"]["headGroupSpec"]["template"]["spec"]["containers"][0][
             "resources"
         ]["limits"]["cpu"],
         head_mem=rc["spec"]["headGroupSpec"]["template"]["spec"]["containers"][0][
             "resources"
         ]["limits"]["memory"],
-        head_gpu=rc["spec"]["headGroupSpec"]["template"]["spec"]["containers"][0][
-            "resources"
-        ]["limits"]["nvidia.com/gpu"],
+        head_extended_resources=head_extended_resources,
         dashboard=dashboard_url,
     )

@@ -920,12 +907,12 @@ def _copy_to_ray(cluster: Cluster) -> RayCluster:
         worker_mem_min=cluster.config.worker_memory_requests,
         worker_mem_max=cluster.config.worker_memory_limits,
         worker_cpu=cluster.config.worker_cpu_requests,
-        worker_gpu=cluster.config.num_worker_gpus,
+        worker_extended_resources=cluster.config.worker_extended_resource_requests,
         namespace=cluster.config.namespace,
         dashboard=cluster.cluster_dashboard_uri(),
         head_cpus=cluster.config.head_cpus,
         head_mem=cluster.config.head_memory,
-        head_gpu=cluster.config.num_head_gpus,
+        head_extended_resources=cluster.config.head_extended_resource_requests,
     )
     if ray.status == CodeFlareClusterStatus.READY:
         ray.status = RayClusterStatus.READY

src/codeflare_sdk/cluster/config.py (+81 -2)
@@ -25,12 +25,51 @@

 dir = pathlib.Path(__file__).parent.parent.resolve()

+# https://docs.ray.io/en/latest/ray-core/scheduling/accelerators.html
+DEFAULT_RESOURCE_MAPPING = {
+    "nvidia.com/gpu": "GPU",
+    "intel.com/gpu": "GPU",
+    "amd.com/gpu": "GPU",
+    "aws.amazon.com/neuroncore": "neuron_cores",
+    "google.com/tpu": "TPU",
+    "habana.ai/gaudi": "HPU",
+    "huawei.com/Ascend910": "NPU",
+    "huawei.com/Ascend310": "NPU",
+}
+

 @dataclass
 class ClusterConfiguration:
     """
     This dataclass is used to specify resource requirements and other details, and
     is passed in as an argument when creating a Cluster object.
+
+    Attributes:
+    - name: The name of the cluster.
+    - namespace: The namespace in which the cluster should be created.
+    - head_info: A list of strings containing information about the head node.
+    - head_cpus: The number of CPUs to allocate to the head node.
+    - head_memory: The amount of memory to allocate to the head node.
+    - head_gpus: The number of GPUs to allocate to the head node. (Deprecated, use head_extended_resource_requests)
+    - head_extended_resource_requests: A dictionary of extended resource requests for the head node. ex: {"nvidia.com/gpu": 1}
+    - machine_types: A list of machine types to use for the cluster.
+    - min_cpus: The minimum number of CPUs to allocate to each worker.
+    - max_cpus: The maximum number of CPUs to allocate to each worker.
+    - num_workers: The number of workers to create.
+    - min_memory: The minimum amount of memory to allocate to each worker.
+    - max_memory: The maximum amount of memory to allocate to each worker.
+    - num_gpus: The number of GPUs to allocate to each worker. (Deprecated, use worker_extended_resource_requests)
+    - template: The path to the template file to use for the cluster.
+    - appwrapper: A boolean indicating whether to use an AppWrapper.
+    - envs: A dictionary of environment variables to set for the cluster.
+    - image: The image to use for the cluster.
+    - image_pull_secrets: A list of image pull secrets to use for the cluster.
+    - write_to_file: A boolean indicating whether to write the cluster configuration to a file.
+    - verify_tls: A boolean indicating whether to verify TLS when connecting to the cluster.
+    - labels: A dictionary of labels to apply to the cluster.
+    - worker_extended_resource_requests: A dictionary of extended resource requests for each worker. ex: {"nvidia.com/gpu": 1}
+    - custom_resource_mapping: A dictionary of custom resource mappings to map extended resource requests to RayCluster resource names.
+    - overwrite_default_resource_mapping: A boolean indicating whether to overwrite the default resource mapping.
     """

     name: str
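
Note: the default mapping translates Kubernetes extended-resource names into the resource names Ray schedules on (see the Ray accelerators documentation linked in the hunk above). A quick sketch of how a user-supplied mapping layers on top of it; the "example.com/fpga" key is a hypothetical device used only for illustration:

DEFAULT_RESOURCE_MAPPING = {
    "nvidia.com/gpu": "GPU",
    "intel.com/gpu": "GPU",
    "amd.com/gpu": "GPU",
    "aws.amazon.com/neuroncore": "neuron_cores",
    "google.com/tpu": "TPU",
    "habana.ai/gaudi": "HPU",
    "huawei.com/Ascend910": "NPU",
    "huawei.com/Ascend310": "NPU",
}

# A custom mapping can add devices the defaults do not cover.
custom_resource_mapping = {"example.com/fpga": "FPGA"}  # hypothetical device key

combined = {**DEFAULT_RESOURCE_MAPPING, **custom_resource_mapping}
print(combined["nvidia.com/gpu"])    # GPU
print(combined["example.com/fpga"])  # FPGA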
@@ -40,6 +79,7 @@ class ClusterConfiguration:
     head_memory: typing.Union[int, str] = 8
     head_gpus: int = None  # Deprecating
     num_head_gpus: int = 0
+    head_extended_resource_requests: typing.Dict[str, int] = field(default_factory=dict)
     machine_types: list = field(default_factory=list)  # ["m4.xlarge", "g4dn.xlarge"]
     worker_cpu_requests: typing.Union[int, str] = 1
     worker_cpu_limits: typing.Union[int, str] = 1
@@ -60,6 +100,11 @@ class ClusterConfiguration:
     write_to_file: bool = False
     verify_tls: bool = True
     labels: dict = field(default_factory=dict)
+    worker_extended_resource_requests: typing.Dict[str, int] = field(
+        default_factory=dict
+    )
+    custom_resource_mapping: typing.Dict[str, str] = field(default_factory=dict)
+    overwrite_default_resource_mapping: bool = False

     def __post_init__(self):
         if not self.verify_tls:
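
Note: with these fields in place, extended resources are requested per head and per worker as dictionaries rather than a single GPU count. A minimal usage sketch, assuming the codeflare-sdk from this commit is installed; the cluster name, resource values, and the "example.com/fpga" key are illustrative:

from codeflare_sdk.cluster.config import ClusterConfiguration

config = ClusterConfiguration(
    name="demo-cluster",                                   # illustrative name
    namespace="default",
    num_workers=2,
    head_extended_resource_requests={"nvidia.com/gpu": 1},
    worker_extended_resource_requests={"nvidia.com/gpu": 1},
    # Map a device the default mapping does not know about to a Ray resource name.
    custom_resource_mapping={"example.com/fpga": "FPGA"},  # hypothetical device key
)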
@@ -69,9 +114,43 @@ def __post_init__(self):

         self._memory_to_string()
         self._str_mem_no_unit_add_GB()
+        self._old_gpu_arg_conversion()
         self._memory_to_resource()
-        self._gpu_to_resource()
         self._cpu_to_resource()
+        self._gpu_to_resource()
+        self._combine_custom_resource_mapping()
+
+    def _combine_custom_resource_mapping(self):
+        if overwritten := set(self.custom_resource_mapping.keys()).intersection(
+            DEFAULT_RESOURCE_MAPPING.keys()
+        ):
+            if self.overwrite_default_resource_mapping:
+                warnings.warn(
+                    f"Overwriting default resource mapping for {overwritten}",
+                    UserWarning,
+                )
+            else:
+                raise ValueError(
+                    f"Resource mapping already exists for {overwritten}, set overwrite_default_resource_mapping to True to overwrite"
+                )
+        self.custom_resource_mapping = {
+            **DEFAULT_RESOURCE_MAPPING,
+            **self.custom_resource_mapping,
+        }
+
+    def _gpu_to_resource(self):
+        if self.num_head_gpus:
+            if "nvidia.com/gpu" in self.head_extended_resource_requests:
+                raise ValueError(
+                    "nvidia.com/gpu already exists in head_custom_resource_requests"
+                )
+            self.head_extended_resource_requests["nvidia.com/gpu"] = self.num_head_gpus
+        if self.num_worker_gpus:
+            if "nvidia.com/gpu" in self.worker_extended_resource_requests:
+                raise ValueError(
+                    "nvidia.com/gpu already exists in worker_custom_resource_requests"
+                )
+            self.worker_extended_resource_requests["nvidia.com/gpu"] = self.num_worker_gpus

     def _str_mem_no_unit_add_GB(self):
         if isinstance(self.head_memory, str) and self.head_memory.isdecimal():
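
Note: __post_init__ now folds the legacy GPU counts into the extended-resource dictionaries and refuses to silently overwrite an explicit "nvidia.com/gpu" entry. A standalone sketch of that check, with the logic mirrored outside the SDK purely for illustration:

def fold_gpus_into_extended_resources(num_worker_gpus: int, worker_extended: dict) -> dict:
    # Mirrors the conflict check above: a legacy GPU count plus an explicit
    # nvidia.com/gpu entry is ambiguous, so it is rejected rather than merged.
    if num_worker_gpus:
        if "nvidia.com/gpu" in worker_extended:
            raise ValueError("nvidia.com/gpu already exists in worker_custom_resource_requests")
        worker_extended["nvidia.com/gpu"] = num_worker_gpus
    return worker_extended


print(fold_gpus_into_extended_resources(2, {}))  # {'nvidia.com/gpu': 2}
try:
    fold_gpus_into_extended_resources(2, {"nvidia.com/gpu": 1})
except ValueError as err:
    print(err)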
@@ -95,7 +174,7 @@ def _memory_to_string(self):
         if isinstance(self.worker_memory_limits, int):
             self.worker_memory_limits = f"{self.worker_memory_limits}G"

-    def _gpu_to_resource(self):
+    def _old_gpu_arg_conversion(self):
         if self.head_gpus:
             warnings.warn("head_gpus is being deprecated, use num_head_gpus")
             self.num_head_gpus = self.head_gpus

src/codeflare_sdk/cluster/model.py (+4 -3)
@@ -18,8 +18,9 @@
 dataclasses to store information for Ray clusters and AppWrappers.
 """

-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from enum import Enum
+import typing


 class RayClusterStatus(Enum):
@@ -74,14 +75,14 @@ class RayCluster:
     status: RayClusterStatus
     head_cpus: int
     head_mem: str
-    head_gpu: int
     workers: int
     worker_mem_min: str
     worker_mem_max: str
     worker_cpu: int
-    worker_gpu: int
     namespace: str
     dashboard: str
+    worker_extended_resources: typing.Dict[str, int] = field(default_factory=dict)
+    head_extended_resources: typing.Dict[str, int] = field(default_factory=dict)


 @dataclass
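
Note: the status-side RayCluster model drops the single head_gpu/worker_gpu integers in favor of per-resource dictionaries, so accelerators other than NVIDIA GPUs can be reported as well. A trimmed sketch of the new shape (fields reduced for brevity; RayClusterView is an illustrative stand-in, not the SDK class):

import typing
from dataclasses import dataclass, field


@dataclass
class RayClusterView:  # illustrative stand-in for the SDK's RayCluster model
    name: str
    worker_cpu: int
    worker_extended_resources: typing.Dict[str, int] = field(default_factory=dict)
    head_extended_resources: typing.Dict[str, int] = field(default_factory=dict)


view = RayClusterView(
    name="demo-cluster",
    worker_cpu=4,
    worker_extended_resources={"nvidia.com/gpu": 1, "habana.ai/gaudi": 1},
    head_extended_resources={},
)
print(view.worker_extended_resources)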

src/codeflare_sdk/templates/base-template.yaml (-4)
@@ -86,11 +86,9 @@ spec:
           limits:
             cpu: 2
             memory: "8G"
-            nvidia.com/gpu: 0
           requests:
             cpu: 2
             memory: "8G"
-            nvidia.com/gpu: 0
         volumeMounts:
         - mountPath: /etc/pki/tls/certs/odh-trusted-ca-bundle.crt
           name: odh-trusted-ca-cert
@@ -163,11 +161,9 @@ spec:
           limits:
             cpu: "2"
             memory: "12G"
-            nvidia.com/gpu: "1"
           requests:
             cpu: "2"
             memory: "12G"
-            nvidia.com/gpu: "1"
         volumeMounts:
         - mountPath: /etc/pki/tls/certs/odh-trusted-ca-bundle.crt
           name: odh-trusted-ca-cert
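
Note: the base template no longer hard-codes nvidia.com/gpu in the head and worker container limits; accelerator requests now come from the extended-resource dictionaries in ClusterConfiguration. A rough sketch of the general approach of merging such a dictionary into a container spec (illustrative only, not the SDK's exact generator code):

# Extended resources requested by the user (illustrative values).
worker_extended_resource_requests = {"nvidia.com/gpu": 1}

container = {
    "resources": {
        "limits": {"cpu": "2", "memory": "12G"},
        "requests": {"cpu": "2", "memory": "12G"},
    }
}

# Merge each requested extended resource into both limits and requests.
for resource, amount in worker_extended_resource_requests.items():
    container["resources"]["limits"][resource] = amount
    container["resources"]["requests"][resource] = amount

print(container["resources"]["limits"])  # {'cpu': '2', 'memory': '12G', 'nvidia.com/gpu': 1}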
