Skip to content

Commit 7b5bacc

Browse files
committed
simplify function calls and add option for custom resources
Signed-off-by: Kevin <[email protected]>
1 parent 6a9b185 commit 7b5bacc

File tree

3 files changed

+138
-190
lines changed

3 files changed

+138
-190
lines changed

src/codeflare_sdk/cluster/cluster.py

+1-47
Original file line numberDiff line numberDiff line change
@@ -165,53 +165,7 @@ def create_app_wrapper(self):
165165
else:
166166
priority_val = None
167167

168-
name = self.config.name
169-
namespace = self.config.namespace
170-
head_cpus = self.config.head_cpus
171-
head_memory = self.config.head_memory
172-
head_gpus = self.config.head_gpus
173-
min_cpu = self.config.min_cpus
174-
max_cpu = self.config.max_cpus
175-
min_memory = self.config.min_memory
176-
max_memory = self.config.max_memory
177-
gpu = self.config.num_gpus
178-
workers = self.config.num_workers
179-
template = self.config.template
180-
image = self.config.image
181-
instascale = self.config.instascale
182-
mcad = self.config.mcad
183-
instance_types = self.config.machine_types
184-
env = self.config.envs
185-
image_pull_secrets = self.config.image_pull_secrets
186-
dispatch_priority = self.config.dispatch_priority
187-
write_to_file = self.config.write_to_file
188-
verify_tls = self.config.verify_tls
189-
local_queue = self.config.local_queue
190-
return generate_appwrapper(
191-
name=name,
192-
namespace=namespace,
193-
head_cpus=head_cpus,
194-
head_memory=head_memory,
195-
head_gpus=head_gpus,
196-
min_cpu=min_cpu,
197-
max_cpu=max_cpu,
198-
min_memory=min_memory,
199-
max_memory=max_memory,
200-
gpu=gpu,
201-
workers=workers,
202-
template=template,
203-
image=image,
204-
instascale=instascale,
205-
mcad=mcad,
206-
instance_types=instance_types,
207-
env=env,
208-
image_pull_secrets=image_pull_secrets,
209-
dispatch_priority=dispatch_priority,
210-
priority_val=priority_val,
211-
write_to_file=write_to_file,
212-
verify_tls=verify_tls,
213-
local_queue=local_queue,
214-
)
168+
return generate_appwrapper(self)
215169

216170
# creates a new cluster with the provided or default spec
217171
def up(self):

src/codeflare_sdk/cluster/config.py

+47
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,22 @@
2121
from dataclasses import dataclass, field
2222
import pathlib
2323
import typing
24+
import warnings
2425

2526
dir = pathlib.Path(__file__).parent.parent.resolve()
2627

28+
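# Default translation from Kubernetes extended-resource names (keys) to the
# resource names Ray uses for them (values). See
# https://docs.ray.io/en/latest/ray-core/scheduling/accelerators.html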
29+
DEFAULT_RESOURCE_MAPPING = {
30+
"nvidia.com/gpu": "GPU",
31+
"intel.com/gpu": "GPU",
32+
"amd.com/gpu": "GPU",
33+
"aws.amazon.com/neuroncore": "neuron_cores",
34+
"google.com/tpu": "TPU",
35+
"habana.ai/gaudi": "HPU",
36+
"huawei.com/Ascend910": "NPU",
37+
"huawei.com/Ascend310": "NPU",
38+
}
39+
2740

2841
@dataclass
2942
class ClusterConfiguration:
@@ -38,6 +51,7 @@ class ClusterConfiguration:
3851
head_cpus: typing.Union[int, str] = 2
3952
head_memory: typing.Union[int, str] = 8
4053
head_gpus: int = 0
54+
head_custom_resource_requests: typing.Dict[str, int] = field(default_factory=dict)
4155
machine_types: list = field(default_factory=list) # ["m4.xlarge", "g4dn.xlarge"]
4256
min_cpus: typing.Union[int, str] = 1
4357
max_cpus: typing.Union[int, str] = 1
@@ -54,6 +68,9 @@ class ClusterConfiguration:
5468
dispatch_priority: str = None
5569
write_to_file: bool = False
5670
verify_tls: bool = True
71+
worker_custom_resource_requests: typing.Dict[str, int] = field(default_factory=dict)
72+
custom_resource_mapping: typing.Dict[str, str] = field(default_factory=dict)
73+
overwrite_default_resource_mapping: bool = False
5774

5875
def __post_init__(self):
5976
if not self.verify_tls:
@@ -63,6 +80,36 @@ def __post_init__(self):
6380
self._memory_to_string()
6481
self._str_mem_no_unit_add_GB()
6582

83+
def _combine_custom_resource_mapping(self):
84+
if self.overwrite_default_resource_mapping:
85+
self.custom_resource_mapping = self.worker_custom_resource_requests
86+
else:
87+
if overwritten := self.worker_custom_resource_requests.keys().intersection(
88+
DEFAULT_RESOURCE_MAPPING.keys()
89+
):
90+
warnings.warn(
91+
f"Overwriting default resource mapping for {overwritten}",
92+
UserWarning,
93+
)
94+
self.custom_resource_mapping = {
95+
**DEFAULT_RESOURCE_MAPPING,
96+
**self.worker_custom_resource_requests,
97+
}
98+
99+
def _gpu_to_resource(self):
100+
if self.head_gpus:
101+
if "nvidia.com/gpu" in self.head_custom_resource_requests:
102+
raise ValueError(
103+
"nvidia.com/gpu already exists in head_custom_resource_requests"
104+
)
105+
self.head_custom_resource_requests["nvidia.com/gpu"] = self.head_gpus
106+
if self.num_gpus:
107+
if "nvidia.com/gpu" in self.worker_custom_resource_requests:
108+
raise ValueError(
109+
"nvidia.com/gpu already exists in worker_custom_resource_requests"
110+
)
111+
self.worker_custom_resource_requests["nvidia.com/gpu"] = self.num_gpus
112+
66113
def _str_mem_no_unit_add_GB(self):
67114
if isinstance(self.head_memory, str) and self.head_memory.isdecimal():
68115
self.head_memory = f"{self.head_memory}G"

0 commit comments

Comments
 (0)