Skip to content

Commit 80b1aad

Browse files
committed
Simplify function calls and add an option for custom resources
Signed-off-by: Kevin <[email protected]>
1 parent 1ab5421 commit 80b1aad

12 files changed

+267
-185
lines changed

src/codeflare_sdk/cluster/cluster.py

+40-53
Original file line numberDiff line numberDiff line change
@@ -131,48 +131,7 @@ def create_app_wrapper(self):
131131
# Validate image configuration
132132
self.validate_image_config()
133133

134-
# Before attempting to create the cluster AW, let's evaluate the ClusterConfig
135-
136-
name = self.config.name
137-
namespace = self.config.namespace
138-
head_cpus = self.config.head_cpus
139-
head_memory = self.config.head_memory
140-
num_head_gpus = self.config.num_head_gpus
141-
worker_cpu_requests = self.config.worker_cpu_requests
142-
worker_cpu_limits = self.config.worker_cpu_limits
143-
worker_memory_requests = self.config.worker_memory_requests
144-
worker_memory_limits = self.config.worker_memory_limits
145-
num_worker_gpus = self.config.num_worker_gpus
146-
workers = self.config.num_workers
147-
template = self.config.template
148-
image = self.config.image
149-
appwrapper = self.config.appwrapper
150-
env = self.config.envs
151-
image_pull_secrets = self.config.image_pull_secrets
152-
write_to_file = self.config.write_to_file
153-
local_queue = self.config.local_queue
154-
labels = self.config.labels
155-
return generate_appwrapper(
156-
name=name,
157-
namespace=namespace,
158-
head_cpus=head_cpus,
159-
head_memory=head_memory,
160-
num_head_gpus=num_head_gpus,
161-
worker_cpu_requests=worker_cpu_requests,
162-
worker_cpu_limits=worker_cpu_limits,
163-
worker_memory_requests=worker_memory_requests,
164-
worker_memory_limits=worker_memory_limits,
165-
num_worker_gpus=num_worker_gpus,
166-
workers=workers,
167-
template=template,
168-
image=image,
169-
appwrapper=appwrapper,
170-
env=env,
171-
image_pull_secrets=image_pull_secrets,
172-
write_to_file=write_to_file,
173-
local_queue=local_queue,
174-
labels=labels,
175-
)
134+
return generate_appwrapper(self)
176135

177136
# creates a new cluster with the provided or default spec
178137
def up(self):
@@ -456,6 +415,29 @@ def job_logs(self, job_id: str) -> str:
456415
"""
457416
return self.job_client.get_job_logs(job_id)
458417

418+
@staticmethod
419+
def _head_worker_extended_resources_from_rc_dict(rc: Dict) -> Tuple[dict, dict]:
420+
head_extended_resources, worker_extended_resources = {}, {}
421+
for resource in rc["spec"]["workerGroupSpecs"][0]["template"]["spec"][
422+
"containers"
423+
][0]["resources"]["limits"].keys():
424+
if resource in ["memory", "cpu"]:
425+
continue
426+
worker_extended_resources[resource] = rc["spec"]["workerGroupSpecs"][0][
427+
"template"
428+
]["spec"]["containers"][0]["resources"]["limits"][resource]
429+
430+
for resource in rc["spec"]["headGroupSpec"]["template"]["spec"]["containers"][
431+
0
432+
]["resources"]["limits"].keys():
433+
if resource in ["memory", "cpu"]:
434+
continue
435+
head_extended_resources[resource] = rc["spec"]["headGroupSpec"]["template"][
436+
"spec"
437+
]["containers"][0]["resources"]["limits"][resource]
438+
439+
return head_extended_resources, worker_extended_resources
440+
459441
def from_k8_cluster_object(
460442
rc,
461443
appwrapper=True,
@@ -469,6 +451,11 @@ def from_k8_cluster_object(
469451
else []
470452
)
471453

454+
(
455+
head_extended_resources,
456+
worker_extended_resources,
457+
) = Cluster._head_worker_extended_resources_from_rc_dict(rc)
458+
472459
cluster_config = ClusterConfiguration(
473460
name=rc["metadata"]["name"],
474461
namespace=rc["metadata"]["namespace"],
@@ -486,11 +473,8 @@ def from_k8_cluster_object(
486473
worker_memory_limits=rc["spec"]["workerGroupSpecs"][0]["template"]["spec"][
487474
"containers"
488475
][0]["resources"]["limits"]["memory"],
489-
num_worker_gpus=int(
490-
rc["spec"]["workerGroupSpecs"][0]["template"]["spec"]["containers"][0][
491-
"resources"
492-
]["limits"]["nvidia.com/gpu"]
493-
),
476+
worker_extended_resource_requests=worker_extended_resources,
477+
head_extended_resource_requests=head_extended_resources,
494478
image=rc["spec"]["workerGroupSpecs"][0]["template"]["spec"]["containers"][
495479
0
496480
]["image"],
@@ -871,6 +855,11 @@ def _map_to_ray_cluster(rc) -> Optional[RayCluster]:
871855
protocol = "https"
872856
dashboard_url = f"{protocol}://{ingress.spec.rules[0].host}"
873857

858+
(
859+
head_extended_resources,
860+
worker_extended_resources,
861+
) = Cluster._head_worker_extended_resources_from_rc_dict(rc)
862+
874863
return RayCluster(
875864
name=rc["metadata"]["name"],
876865
status=status,
@@ -885,17 +874,15 @@ def _map_to_ray_cluster(rc) -> Optional[RayCluster]:
885874
worker_cpu=rc["spec"]["workerGroupSpecs"][0]["template"]["spec"]["containers"][
886875
0
887876
]["resources"]["limits"]["cpu"],
888-
worker_gpu=0, # hard to detect currently how many gpus, can override it with what the user asked for
877+
worker_extended_resources=worker_extended_resources,
889878
namespace=rc["metadata"]["namespace"],
890879
head_cpus=rc["spec"]["headGroupSpec"]["template"]["spec"]["containers"][0][
891880
"resources"
892881
]["limits"]["cpu"],
893882
head_mem=rc["spec"]["headGroupSpec"]["template"]["spec"]["containers"][0][
894883
"resources"
895884
]["limits"]["memory"],
896-
head_gpu=rc["spec"]["headGroupSpec"]["template"]["spec"]["containers"][0][
897-
"resources"
898-
]["limits"]["nvidia.com/gpu"],
885+
head_extended_resources=head_extended_resources,
899886
dashboard=dashboard_url,
900887
)
901888

@@ -920,12 +907,12 @@ def _copy_to_ray(cluster: Cluster) -> RayCluster:
920907
worker_mem_min=cluster.config.worker_memory_requests,
921908
worker_mem_max=cluster.config.worker_memory_limits,
922909
worker_cpu=cluster.config.worker_cpu_requests,
923-
worker_gpu=cluster.config.num_worker_gpus,
910+
worker_extended_resources=cluster.config.worker_extended_resource_requests,
924911
namespace=cluster.config.namespace,
925912
dashboard=cluster.cluster_dashboard_uri(),
926913
head_cpus=cluster.config.head_cpus,
927914
head_mem=cluster.config.head_memory,
928-
head_gpu=cluster.config.num_head_gpus,
915+
head_extended_resources=cluster.config.head_extended_resource_requests,
929916
)
930917
if ray.status == CodeFlareClusterStatus.READY:
931918
ray.status = RayClusterStatus.READY

src/codeflare_sdk/cluster/config.py

+102-4
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,51 @@
2525

2626
dir = pathlib.Path(__file__).parent.parent.resolve()
2727

28+
# https://docs.ray.io/en/latest/ray-core/scheduling/accelerators.html
29+
DEFAULT_RESOURCE_MAPPING = {
30+
"nvidia.com/gpu": "GPU",
31+
"intel.com/gpu": "GPU",
32+
"amd.com/gpu": "GPU",
33+
"aws.amazon.com/neuroncore": "neuron_cores",
34+
"google.com/tpu": "TPU",
35+
"habana.ai/gaudi": "HPU",
36+
"huawei.com/Ascend910": "NPU",
37+
"huawei.com/Ascend310": "NPU",
38+
}
39+
2840

2941
@dataclass
3042
class ClusterConfiguration:
3143
"""
3244
This dataclass is used to specify resource requirements and other details, and
3345
is passed in as an argument when creating a Cluster object.
46+
47+
Attributes:
48+
- name: The name of the cluster.
49+
- namespace: The namespace in which the cluster should be created.
50+
- head_info: A list of strings containing information about the head node.
51+
- head_cpus: The number of CPUs to allocate to the head node.
52+
- head_memory: The amount of memory to allocate to the head node.
53+
- head_gpus: The number of GPUs to allocate to the head node. (Deprecated, use head_extended_resource_requests)
54+
- head_extended_resource_requests: A dictionary of extended resource requests for the head node. ex: {"nvidia.com/gpu": 1}
55+
- machine_types: A list of machine types to use for the cluster.
56+
- min_cpus: The minimum number of CPUs to allocate to each worker.
57+
- max_cpus: The maximum number of CPUs to allocate to each worker.
58+
- num_workers: The number of workers to create.
59+
- min_memory: The minimum amount of memory to allocate to each worker.
60+
- max_memory: The maximum amount of memory to allocate to each worker.
61+
- num_gpus: The number of GPUs to allocate to each worker. (Deprecated, use worker_extended_resource_requests)
62+
- template: The path to the template file to use for the cluster.
63+
- appwrapper: A boolean indicating whether to use an AppWrapper.
64+
- envs: A dictionary of environment variables to set for the cluster.
65+
- image: The image to use for the cluster.
66+
- image_pull_secrets: A list of image pull secrets to use for the cluster.
67+
- write_to_file: A boolean indicating whether to write the cluster configuration to a file.
68+
- verify_tls: A boolean indicating whether to verify TLS when connecting to the cluster.
69+
- labels: A dictionary of labels to apply to the cluster.
70+
- worker_extended_resource_requests: A dictionary of extended resource requests for each worker. ex: {"nvidia.com/gpu": 1}
71+
- extended_resource_mapping: A dictionary of custom resource mappings to map extended resource requests to RayCluster resource names
72+
- overwrite_default_resource_mapping: A boolean indicating whether to overwrite the default resource mapping.
3473
"""
3574

3675
name: str
@@ -39,7 +78,7 @@ class ClusterConfiguration:
3978
head_cpus: typing.Union[int, str] = 2
4079
head_memory: typing.Union[int, str] = 8
4180
head_gpus: int = None # Deprecating
42-
num_head_gpus: int = 0
81+
head_extended_resource_requests: typing.Dict[str, int] = field(default_factory=dict)
4382
machine_types: list = field(default_factory=list) # ["m4.xlarge", "g4dn.xlarge"]
4483
worker_cpu_requests: typing.Union[int, str] = 1
4584
worker_cpu_limits: typing.Union[int, str] = 1
@@ -50,7 +89,6 @@ class ClusterConfiguration:
5089
worker_memory_limits: typing.Union[int, str] = 2
5190
min_memory: typing.Union[int, str] = None # Deprecating
5291
max_memory: typing.Union[int, str] = None # Deprecating
53-
num_worker_gpus: int = 0
5492
num_gpus: int = None # Deprecating
5593
template: str = f"{dir}/templates/base-template.yaml"
5694
appwrapper: bool = False
@@ -60,6 +98,11 @@ class ClusterConfiguration:
6098
write_to_file: bool = False
6199
verify_tls: bool = True
62100
labels: dict = field(default_factory=dict)
101+
worker_extended_resource_requests: typing.Dict[str, int] = field(
102+
default_factory=dict
103+
)
104+
extended_resource_mapping: typing.Dict[str, str] = field(default_factory=dict)
105+
overwrite_default_resource_mapping: bool = False
63106

64107
def __post_init__(self):
65108
if not self.verify_tls:
@@ -69,9 +112,64 @@ def __post_init__(self):
69112

70113
self._memory_to_string()
71114
self._str_mem_no_unit_add_GB()
115+
self._old_gpu_arg_conversion()
72116
self._memory_to_resource()
73-
self._gpu_to_resource()
74117
self._cpu_to_resource()
118+
self._gpu_to_resource()
119+
self._combine_extended_resource_mapping()
120+
self._validate_extended_resource_requests(self.head_extended_resource_requests)
121+
self._validate_extended_resource_requests(
122+
self.worker_extended_resource_requests
123+
)
124+
125+
def _combine_extended_resource_mapping(self):
126+
if overwritten := set(self.extended_resource_mapping.keys()).intersection(
127+
DEFAULT_RESOURCE_MAPPING.keys()
128+
):
129+
if self.overwrite_default_resource_mapping:
130+
warnings.warn(
131+
f"Overwriting default resource mapping for {overwritten}",
132+
UserWarning,
133+
)
134+
else:
135+
raise ValueError(
136+
f"Resource mapping already exists for {overwritten}, set overwrite_default_resource_mapping to True to overwrite"
137+
)
138+
self.extended_resource_mapping = {
139+
**DEFAULT_RESOURCE_MAPPING,
140+
**self.extended_resource_mapping,
141+
}
142+
143+
def _validate_extended_resource_requests(
144+
self, extended_resources: typing.Dict[str, int]
145+
):
146+
for k in extended_resources.keys():
147+
if k not in self.extended_resource_mapping.keys():
148+
raise ValueError(
149+
f"extended resource '{k}' not found in extended_resource_mapping, available resources are {list(self.extended_resource_mapping.keys())}, to add more supported resources use extended_resource_mapping. i.e. extended_resource_mapping = {{'{k}': 'FOO_BAR'}}"
150+
)
151+
152+
def _gpu_to_resource(self):
153+
if self.head_gpus:
154+
warnings.warn(
155+
"head_gpus is being deprecated, use head_extended_resource_requests i.e. head_extended_resource_requests = {'nvidia.com/gpu': 1}"
156+
)
157+
if "nvidia.com/gpu" in self.head_extended_resource_requests:
158+
raise ValueError(
159+
"nvidia.com/gpu already exists in head_extended_resource_requests"
160+
)
161+
self.head_extended_resource_requests["nvidia.com/gpu"] = self.num_head_gpus
162+
if self.num_gpus:
163+
warnings.warn(
164+
"num_gpus is being deprecated, use worker_extended_resource_requests instead i.e. worker_extended_resource_requests = {'nvidia.com/gpu': 1}"
165+
)
166+
if "nvidia.com/gpu" in self.worker_extended_resource_requests:
167+
raise ValueError(
168+
"nvidia.com/gpu already exists in worker_extended_resource_requests"
169+
)
170+
self.worker_extended_resource_requests[
171+
"nvidia.com/gpu"
172+
] = self.num_worker_gpus
75173

76174
def _str_mem_no_unit_add_GB(self):
77175
if isinstance(self.head_memory, str) and self.head_memory.isdecimal():
@@ -95,7 +193,7 @@ def _memory_to_string(self):
95193
if isinstance(self.worker_memory_limits, int):
96194
self.worker_memory_limits = f"{self.worker_memory_limits}G"
97195

98-
def _gpu_to_resource(self):
196+
def _old_gpu_arg_conversion(self):
99197
if self.head_gpus:
100198
warnings.warn("head_gpus is being deprecated, use num_head_gpus")
101199
self.num_head_gpus = self.head_gpus

src/codeflare_sdk/cluster/model.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,9 @@
1818
dataclasses to store information for Ray clusters and AppWrappers.
1919
"""
2020

21-
from dataclasses import dataclass
21+
from dataclasses import dataclass, field
2222
from enum import Enum
23+
import typing
2324

2425

2526
class RayClusterStatus(Enum):
@@ -74,14 +75,14 @@ class RayCluster:
7475
status: RayClusterStatus
7576
head_cpus: int
7677
head_mem: str
77-
head_gpu: int
7878
workers: int
7979
worker_mem_min: str
8080
worker_mem_max: str
8181
worker_cpu: int
82-
worker_gpu: int
8382
namespace: str
8483
dashboard: str
84+
worker_extended_resources: typing.Dict[str, int] = field(default_factory=dict)
85+
head_extended_resources: typing.Dict[str, int] = field(default_factory=dict)
8586

8687

8788
@dataclass

src/codeflare_sdk/templates/base-template.yaml

-4
Original file line numberDiff line numberDiff line change
@@ -86,11 +86,9 @@ spec:
8686
limits:
8787
cpu: 2
8888
memory: "8G"
89-
nvidia.com/gpu: 0
9089
requests:
9190
cpu: 2
9291
memory: "8G"
93-
nvidia.com/gpu: 0
9492
volumeMounts:
9593
- mountPath: /etc/pki/tls/certs/odh-trusted-ca-bundle.crt
9694
name: odh-trusted-ca-cert
@@ -163,11 +161,9 @@ spec:
163161
limits:
164162
cpu: "2"
165163
memory: "12G"
166-
nvidia.com/gpu: "1"
167164
requests:
168165
cpu: "2"
169166
memory: "12G"
170-
nvidia.com/gpu: "1"
171167
volumeMounts:
172168
- mountPath: /etc/pki/tls/certs/odh-trusted-ca-bundle.crt
173169
name: odh-trusted-ca-cert

0 commit comments

Comments (0)