
Commit 7a3d5f5

simplify function calls and add option for custom resources

Signed-off-by: Kevin <[email protected]>
1 parent 2a85469 commit 7a3d5f5

19 files changed: +289 -205 lines

src/codeflare_sdk/cluster/cluster.py (+42 -54)
@@ -29,6 +29,7 @@
 from ..utils import pretty_print
 from ..utils.generate_yaml import (
     generate_appwrapper,
+    head_worker_gpu_count_from_cluster,
 )
 from ..utils.kube_api_helpers import _kube_api_error_handling
 from ..utils.generate_yaml import is_openshift_cluster
@@ -118,48 +119,7 @@ def create_app_wrapper(self):
                 f"Namespace {self.config.namespace} is of type {type(self.config.namespace)}. Check your Kubernetes Authentication."
             )

-        # Before attempting to create the cluster AW, let's evaluate the ClusterConfig
-
-        name = self.config.name
-        namespace = self.config.namespace
-        head_cpus = self.config.head_cpus
-        head_memory = self.config.head_memory
-        num_head_gpus = self.config.num_head_gpus
-        worker_cpu_requests = self.config.worker_cpu_requests
-        worker_cpu_limits = self.config.worker_cpu_limits
-        worker_memory_requests = self.config.worker_memory_requests
-        worker_memory_limits = self.config.worker_memory_limits
-        num_worker_gpus = self.config.num_worker_gpus
-        workers = self.config.num_workers
-        template = self.config.template
-        image = self.config.image
-        appwrapper = self.config.appwrapper
-        env = self.config.envs
-        image_pull_secrets = self.config.image_pull_secrets
-        write_to_file = self.config.write_to_file
-        local_queue = self.config.local_queue
-        labels = self.config.labels
-        return generate_appwrapper(
-            name=name,
-            namespace=namespace,
-            head_cpus=head_cpus,
-            head_memory=head_memory,
-            num_head_gpus=num_head_gpus,
-            worker_cpu_requests=worker_cpu_requests,
-            worker_cpu_limits=worker_cpu_limits,
-            worker_memory_requests=worker_memory_requests,
-            worker_memory_limits=worker_memory_limits,
-            num_worker_gpus=num_worker_gpus,
-            workers=workers,
-            template=template,
-            image=image,
-            appwrapper=appwrapper,
-            env=env,
-            image_pull_secrets=image_pull_secrets,
-            write_to_file=write_to_file,
-            local_queue=local_queue,
-            labels=labels,
-        )
+        return generate_appwrapper(self)

     # creates a new cluster with the provided or default spec
     def up(self):
@@ -305,7 +265,7 @@ def status(

             if print_to_console:
                 # overriding the number of gpus with requested
-                cluster.worker_gpu = self.config.num_worker_gpus
+                _, cluster.worker_gpu = head_worker_gpu_count_from_cluster(self)
                 pretty_print.print_cluster_status(cluster)
         elif print_to_console:
             if status == CodeFlareClusterStatus.UNKNOWN:
@@ -443,6 +403,29 @@ def job_logs(self, job_id: str) -> str:
         """
         return self.job_client.get_job_logs(job_id)

+    @staticmethod
+    def _head_worker_extended_resources_from_rc_dict(rc: Dict) -> Tuple[dict, dict]:
+        head_extended_resources, worker_extended_resources = {}, {}
+        for resource in rc["spec"]["workerGroupSpecs"][0]["template"]["spec"][
+            "containers"
+        ][0]["resources"]["limits"].keys():
+            if resource in ["memory", "cpu"]:
+                continue
+            worker_extended_resources[resource] = rc["spec"]["workerGroupSpecs"][0][
+                "template"
+            ]["spec"]["containers"][0]["resources"]["limits"][resource]
+
+        for resource in rc["spec"]["headGroupSpec"]["template"]["spec"]["containers"][
+            0
+        ]["resources"]["limits"].keys():
+            if resource in ["memory", "cpu"]:
+                continue
+            head_extended_resources[resource] = rc["spec"]["headGroupSpec"]["template"][
+                "spec"
+            ]["containers"][0]["resources"]["limits"][resource]
+
+        return head_extended_resources, worker_extended_resources
+
     def from_k8_cluster_object(
         rc,
         appwrapper=True,
@@ -456,6 +439,11 @@ def from_k8_cluster_object(
             else []
         )

+        (
+            head_extended_resources,
+            worker_extended_resources,
+        ) = Cluster._head_worker_extended_resources_from_rc_dict(rc)
+
         cluster_config = ClusterConfiguration(
             name=rc["metadata"]["name"],
             namespace=rc["metadata"]["namespace"],
@@ -473,11 +461,8 @@ def from_k8_cluster_object(
             worker_memory_limits=rc["spec"]["workerGroupSpecs"][0]["template"]["spec"][
                 "containers"
             ][0]["resources"]["limits"]["memory"],
-            num_worker_gpus=int(
-                rc["spec"]["workerGroupSpecs"][0]["template"]["spec"]["containers"][0][
-                    "resources"
-                ]["limits"]["nvidia.com/gpu"]
-            ),
+            worker_extended_resource_requests=worker_extended_resources,
+            head_extended_resource_requests=head_extended_resources,
             image=rc["spec"]["workerGroupSpecs"][0]["template"]["spec"]["containers"][
                 0
             ]["image"],
@@ -858,6 +843,11 @@ def _map_to_ray_cluster(rc) -> Optional[RayCluster]:
         protocol = "https"
     dashboard_url = f"{protocol}://{ingress.spec.rules[0].host}"

+    (
+        head_extended_resources,
+        worker_extended_resources,
+    ) = Cluster._head_worker_extended_resources_from_rc_dict(rc)
+
     return RayCluster(
         name=rc["metadata"]["name"],
         status=status,
@@ -872,17 +862,15 @@ def _map_to_ray_cluster(rc) -> Optional[RayCluster]:
         worker_cpu=rc["spec"]["workerGroupSpecs"][0]["template"]["spec"]["containers"][
             0
         ]["resources"]["limits"]["cpu"],
-        worker_gpu=0,  # hard to detect currently how many gpus, can override it with what the user asked for
+        worker_extended_resources=worker_extended_resources,
         namespace=rc["metadata"]["namespace"],
         head_cpus=rc["spec"]["headGroupSpec"]["template"]["spec"]["containers"][0][
            "resources"
         ]["limits"]["cpu"],
         head_mem=rc["spec"]["headGroupSpec"]["template"]["spec"]["containers"][0][
             "resources"
         ]["limits"]["memory"],
-        head_gpu=rc["spec"]["headGroupSpec"]["template"]["spec"]["containers"][0][
-            "resources"
-        ]["limits"]["nvidia.com/gpu"],
+        head_extended_resources=head_extended_resources,
         dashboard=dashboard_url,
     )
@@ -907,12 +895,12 @@ def _copy_to_ray(cluster: Cluster) -> RayCluster:
         worker_mem_min=cluster.config.worker_memory_requests,
         worker_mem_max=cluster.config.worker_memory_limits,
         worker_cpu=cluster.config.worker_cpu_requests,
-        worker_gpu=cluster.config.num_worker_gpus,
+        worker_extended_resources=cluster.config.worker_extended_resource_requests,
         namespace=cluster.config.namespace,
         dashboard=cluster.cluster_dashboard_uri(),
         head_cpus=cluster.config.head_cpus,
         head_mem=cluster.config.head_memory,
-        head_gpu=cluster.config.num_head_gpus,
+        head_extended_resources=cluster.config.head_extended_resource_requests,
     )
     if ray.status == CodeFlareClusterStatus.READY:
         ray.status = RayClusterStatus.READY

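For reference, here is a condensed, runnable sketch of what the new _head_worker_extended_resources_from_rc_dict helper computes; the rc fixture is invented for illustration and carries only the fields the helper reads:

from typing import Dict, Tuple

def head_worker_extended_resources(rc: Dict) -> Tuple[dict, dict]:
    # Same rule as the new helper: every key under "limits" except
    # "cpu" and "memory" counts as an extended resource.
    def extended(limits: Dict) -> dict:
        return {k: v for k, v in limits.items() if k not in ("cpu", "memory")}

    head_limits = rc["spec"]["headGroupSpec"]["template"]["spec"]["containers"][0][
        "resources"
    ]["limits"]
    worker_limits = rc["spec"]["workerGroupSpecs"][0]["template"]["spec"][
        "containers"
    ][0]["resources"]["limits"]
    return extended(head_limits), extended(worker_limits)

# Invented RayCluster fragment, trimmed to the fields read above:
rc = {
    "spec": {
        "headGroupSpec": {
            "template": {
                "spec": {
                    "containers": [
                        {"resources": {"limits": {"cpu": 2, "memory": "8G"}}}
                    ]
                }
            }
        },
        "workerGroupSpecs": [
            {
                "template": {
                    "spec": {
                        "containers": [
                            {
                                "resources": {
                                    "limits": {
                                        "cpu": 2,
                                        "memory": "12G",
                                        "nvidia.com/gpu": 1,
                                    }
                                }
                            }
                        ]
                    }
                }
            }
        ],
    }
}

print(head_worker_extended_resources(rc))  # ({}, {'nvidia.com/gpu': 1})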
src/codeflare_sdk/cluster/config.py (+98 -11)
@@ -25,12 +25,51 @@

 dir = pathlib.Path(__file__).parent.parent.resolve()

+# https://docs.ray.io/en/latest/ray-core/scheduling/accelerators.html
+DEFAULT_RESOURCE_MAPPING = {
+    "nvidia.com/gpu": "GPU",
+    "intel.com/gpu": "GPU",
+    "amd.com/gpu": "GPU",
+    "aws.amazon.com/neuroncore": "neuron_cores",
+    "google.com/tpu": "TPU",
+    "habana.ai/gaudi": "HPU",
+    "huawei.com/Ascend910": "NPU",
+    "huawei.com/Ascend310": "NPU",
+}
+

 @dataclass
 class ClusterConfiguration:
     """
     This dataclass is used to specify resource requirements and other details, and
     is passed in as an argument when creating a Cluster object.
+
+    Attributes:
+    - name: The name of the cluster.
+    - namespace: The namespace in which the cluster should be created.
+    - head_info: A list of strings containing information about the head node.
+    - head_cpus: The number of CPUs to allocate to the head node.
+    - head_memory: The amount of memory to allocate to the head node.
+    - head_gpus: The number of GPUs to allocate to the head node. (Deprecated, use head_extended_resource_requests)
+    - head_extended_resource_requests: A dictionary of extended resource requests for the head node. ex: {"nvidia.com/gpu": 1}
+    - machine_types: A list of machine types to use for the cluster.
+    - min_cpus: The minimum number of CPUs to allocate to each worker.
+    - max_cpus: The maximum number of CPUs to allocate to each worker.
+    - num_workers: The number of workers to create.
+    - min_memory: The minimum amount of memory to allocate to each worker.
+    - max_memory: The maximum amount of memory to allocate to each worker.
+    - num_gpus: The number of GPUs to allocate to each worker. (Deprecated, use worker_extended_resource_requests)
+    - template: The path to the template file to use for the cluster.
+    - appwrapper: A boolean indicating whether to use an AppWrapper.
+    - envs: A dictionary of environment variables to set for the cluster.
+    - image: The image to use for the cluster.
+    - image_pull_secrets: A list of image pull secrets to use for the cluster.
+    - write_to_file: A boolean indicating whether to write the cluster configuration to a file.
+    - verify_tls: A boolean indicating whether to verify TLS when connecting to the cluster.
+    - labels: A dictionary of labels to apply to the cluster.
+    - worker_extended_resource_requests: A dictionary of extended resource requests for each worker. ex: {"nvidia.com/gpu": 1}
+    - extended_resource_mapping: A dictionary of custom resource mappings to map extended resource requests to RayCluster resource names
+    - overwrite_default_resource_mapping: A boolean indicating whether to overwrite the default resource mapping.
     """

     name: str
@@ -39,7 +78,7 @@ class ClusterConfiguration:
     head_cpus: typing.Union[int, str] = 2
     head_memory: typing.Union[int, str] = 8
     head_gpus: int = None  # Deprecating
-    num_head_gpus: int = 0
+    head_extended_resource_requests: typing.Dict[str, int] = field(default_factory=dict)
     machine_types: list = field(default_factory=list)  # ["m4.xlarge", "g4dn.xlarge"]
     worker_cpu_requests: typing.Union[int, str] = 1
     worker_cpu_limits: typing.Union[int, str] = 1
@@ -50,7 +89,6 @@ class ClusterConfiguration:
     worker_memory_limits: typing.Union[int, str] = 2
     min_memory: typing.Union[int, str] = None  # Deprecating
     max_memory: typing.Union[int, str] = None  # Deprecating
-    num_worker_gpus: int = 0
     num_gpus: int = None  # Deprecating
     template: str = f"{dir}/templates/base-template.yaml"
     appwrapper: bool = False
@@ -60,6 +98,11 @@ class ClusterConfiguration:
     write_to_file: bool = False
     verify_tls: bool = True
     labels: dict = field(default_factory=dict)
+    worker_extended_resource_requests: typing.Dict[str, int] = field(
+        default_factory=dict
+    )
+    extended_resource_mapping: typing.Dict[str, str] = field(default_factory=dict)
+    overwrite_default_resource_mapping: bool = False

     def __post_init__(self):
         if not self.verify_tls:
@@ -70,8 +113,60 @@ def __post_init__(self):
         self._memory_to_string()
         self._str_mem_no_unit_add_GB()
         self._memory_to_resource()
-        self._gpu_to_resource()
         self._cpu_to_resource()
+        self._gpu_to_resource()
+        self._combine_extended_resource_mapping()
+        self._validate_extended_resource_requests(self.head_extended_resource_requests)
+        self._validate_extended_resource_requests(
+            self.worker_extended_resource_requests
+        )
+
+    def _combine_extended_resource_mapping(self):
+        if overwritten := set(self.extended_resource_mapping.keys()).intersection(
+            DEFAULT_RESOURCE_MAPPING.keys()
+        ):
+            if self.overwrite_default_resource_mapping:
+                warnings.warn(
+                    f"Overwriting default resource mapping for {overwritten}",
+                    UserWarning,
+                )
+            else:
+                raise ValueError(
+                    f"Resource mapping already exists for {overwritten}, set overwrite_default_resource_mapping to True to overwrite"
+                )
+        self.extended_resource_mapping = {
+            **DEFAULT_RESOURCE_MAPPING,
+            **self.extended_resource_mapping,
+        }
+
+    def _validate_extended_resource_requests(
+        self, extended_resources: typing.Dict[str, int]
+    ):
+        for k in extended_resources.keys():
+            if k not in self.extended_resource_mapping.keys():
+                raise ValueError(
+                    f"extended resource '{k}' not found in extended_resource_mapping, available resources are {list(self.extended_resource_mapping.keys())}, to add more supported resources use extended_resource_mapping. i.e. extended_resource_mapping = {{'{k}': 'FOO_BAR'}}"
+                )
+
+    def _gpu_to_resource(self):
+        if self.head_gpus:
+            warnings.warn(
+                f"head_gpus is being deprecated, replacing with head_extended_resource_requests['nvidia.com/gpu'] = {self.head_gpus}"
+            )
+            if "nvidia.com/gpu" in self.head_extended_resource_requests:
+                raise ValueError(
+                    "nvidia.com/gpu already exists in head_extended_resource_requests"
+                )
+            self.head_extended_resource_requests["nvidia.com/gpu"] = self.head_gpus
+        if self.num_gpus:
+            warnings.warn(
+                f"num_gpus is being deprecated, replacing with worker_extended_resource_requests['nvidia.com/gpu'] = {self.num_gpus}"
+            )
+            if "nvidia.com/gpu" in self.worker_extended_resource_requests:
+                raise ValueError(
+                    "nvidia.com/gpu already exists in worker_extended_resource_requests"
+                )
+            self.worker_extended_resource_requests["nvidia.com/gpu"] = self.num_gpus

     def _str_mem_no_unit_add_GB(self):
         if isinstance(self.head_memory, str) and self.head_memory.isdecimal():
@@ -95,14 +190,6 @@ def _memory_to_string(self):
         if isinstance(self.worker_memory_limits, int):
             self.worker_memory_limits = f"{self.worker_memory_limits}G"

-    def _gpu_to_resource(self):
-        if self.head_gpus:
-            warnings.warn("head_gpus is being deprecated, use num_head_gpus")
-            self.num_head_gpus = self.head_gpus
-        if self.num_gpus:
-            warnings.warn("num_gpus is being deprecated, use num_worker_gpus")
-            self.num_worker_gpus = self.num_gpus
-
     def _cpu_to_resource(self):
         if self.min_cpus:
             warnings.warn("min_cpus is being deprecated, use worker_cpu_requests")

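Taken together, the config.py changes replace the GPU-count integers with generic extended-resource dictionaries. A usage sketch of the new surface (cluster name, namespace, and resource values are made up; the import path follows the file location shown above):

from codeflare_sdk.cluster.config import ClusterConfiguration

config = ClusterConfiguration(
    name="demo-cluster",  # illustrative values throughout
    namespace="default",
    num_workers=2,
    # Replaces the deprecated head_gpus / num_gpus integers:
    head_extended_resource_requests={"nvidia.com/gpu": 1},
    worker_extended_resource_requests={"nvidia.com/gpu": 1},
    # Resources absent from DEFAULT_RESOURCE_MAPPING must be mapped to a
    # Ray resource name here, or _validate_extended_resource_requests
    # raises a ValueError listing the known keys.
    extended_resource_mapping={"example.com/fpga": "FPGA"},
)

Passing the old num_gpus=1 still works during the deprecation window: _gpu_to_resource warns and folds the value into worker_extended_resource_requests["nvidia.com/gpu"]. Remapping a key that already exists in DEFAULT_RESOURCE_MAPPING additionally requires overwrite_default_resource_mapping=True.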
src/codeflare_sdk/cluster/model.py (+4 -3)
@@ -18,8 +18,9 @@
 dataclasses to store information for Ray clusters and AppWrappers.
 """

-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from enum import Enum
+import typing


 class RayClusterStatus(Enum):
@@ -74,14 +75,14 @@ class RayCluster:
     status: RayClusterStatus
     head_cpus: int
     head_mem: str
-    head_gpu: int
     workers: int
     worker_mem_min: str
     worker_mem_max: str
     worker_cpu: int
-    worker_gpu: int
     namespace: str
     dashboard: str
+    worker_extended_resources: typing.Dict[str, int] = field(default_factory=dict)
+    head_extended_resources: typing.Dict[str, int] = field(default_factory=dict)


 @dataclass

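A note on the field order above: because the two new fields take field(default_factory=dict), dataclass rules require them to follow all non-default fields, which is why head_gpu and worker_gpu are removed from the middle of the class and the replacements appended at the end. A hypothetical construction (every value invented, field names as shown in the hunks above):

from codeflare_sdk.cluster.model import RayCluster, RayClusterStatus

ray = RayCluster(
    name="demo",
    status=RayClusterStatus.READY,
    head_cpus=2,
    head_mem="8G",
    workers=2,
    worker_mem_min="8G",
    worker_mem_max="12G",
    worker_cpu=2,
    namespace="default",
    dashboard="https://dashboard.example.com",
    worker_extended_resources={"nvidia.com/gpu": 1},
    # head_extended_resources is omitted and defaults to {}
)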
src/codeflare_sdk/templates/base-template.yaml (-4)
@@ -86,11 +86,9 @@ spec:
             limits:
               cpu: 2
               memory: "8G"
-              nvidia.com/gpu: 0
             requests:
               cpu: 2
               memory: "8G"
-              nvidia.com/gpu: 0
           volumeMounts:
             - mountPath: /etc/pki/tls/certs/odh-trusted-ca-bundle.crt
               name: odh-trusted-ca-cert
@@ -163,11 +161,9 @@ spec:
             limits:
               cpu: "2"
               memory: "12G"
-              nvidia.com/gpu: "1"
             requests:
               cpu: "2"
               memory: "12G"
-              nvidia.com/gpu: "1"
           volumeMounts:
             - mountPath: /etc/pki/tls/certs/odh-trusted-ca-bundle.crt
               name: odh-trusted-ca-cert

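With the hard-coded nvidia.com/gpu entries gone, the base template is accelerator-neutral and extended resources are written into the generated pod specs from ClusterConfiguration instead. A rough sketch of that injection step (hypothetical helper, not the actual generate_yaml code):

def inject_extended_resources(container: dict, requests: dict) -> None:
    # Copy each extended resource (e.g. {"nvidia.com/gpu": 1}) into both
    # the limits and the requests of a container spec from the template.
    for resource, count in requests.items():
        container["resources"]["limits"][resource] = count
        container["resources"]["requests"][resource] = count

head = {"resources": {"limits": {"cpu": 2, "memory": "8G"},
                      "requests": {"cpu": 2, "memory": "8G"}}}
inject_extended_resources(head, {"nvidia.com/gpu": 1})
print(head["resources"]["limits"])  # now includes 'nvidia.com/gpu': 1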