21
21
from dataclasses import dataclass , field
22
22
import pathlib
23
23
import typing
24
+ import warnings
24
25
25
26
# Repository root: two directory levels above this file, resolved to an
# absolute path.
# NOTE(review): the name shadows the `dir` builtin; left unchanged because
# other modules in this package may reference it — confirm before renaming.
dir = pathlib.Path(__file__).parent.parent.resolve()
26
27
28
# https://docs.ray.io/en/latest/ray-core/scheduling/accelerators.html
# Default mapping from Kubernetes extended-resource names (as they appear in
# pod resource requests) to the Ray resource names used for scheduling.
# User-supplied entries are merged over these defaults in
# ClusterConfiguration._combine_custom_resource_mapping().
DEFAULT_RESOURCE_MAPPING = {
    "nvidia.com/gpu": "GPU",
    "intel.com/gpu": "GPU",
    "amd.com/gpu": "GPU",
    "aws.amazon.com/neuroncore": "neuron_cores",
    "google.com/tpu": "TPU",
    "habana.ai/gaudi": "HPU",
    "huawei.com/Ascend910": "NPU",
    "huawei.com/Ascend310": "NPU",
}
27
40
28
41
@dataclass
29
42
class ClusterConfiguration :
@@ -38,6 +51,7 @@ class ClusterConfiguration:
38
51
head_cpus : typing .Union [int , str ] = 2
39
52
head_memory : typing .Union [int , str ] = 8
40
53
head_gpus : int = 0
54
+ head_custom_resource_requests : typing .Dict [str , int ] = field (default_factory = dict )
41
55
machine_types : list = field (default_factory = list ) # ["m4.xlarge", "g4dn.xlarge"]
42
56
min_cpus : typing .Union [int , str ] = 1
43
57
max_cpus : typing .Union [int , str ] = 1
@@ -54,6 +68,9 @@ class ClusterConfiguration:
54
68
dispatch_priority : str = None
55
69
write_to_file : bool = False
56
70
verify_tls : bool = True
71
+ worker_custom_resource_requests : typing .Dict [str , int ] = field (default_factory = dict )
72
+ custom_resource_mapping : typing .Dict [str , str ] = field (default_factory = dict )
73
+ overwrite_default_resource_mapping : bool = False
57
74
58
75
def __post_init__ (self ):
59
76
if not self .verify_tls :
@@ -63,6 +80,36 @@ def __post_init__(self):
63
80
self ._memory_to_string ()
64
81
self ._str_mem_no_unit_add_GB ()
65
82
83
+ def _combine_custom_resource_mapping (self ):
84
+ if self .overwrite_default_resource_mapping :
85
+ self .custom_resource_mapping = self .worker_custom_resource_requests
86
+ else :
87
+ if overwritten := self .worker_custom_resource_requests .keys ().intersection (
88
+ DEFAULT_RESOURCE_MAPPING .keys ()
89
+ ):
90
+ warnings .warn (
91
+ f"Overwriting default resource mapping for { overwritten } " ,
92
+ UserWarning ,
93
+ )
94
+ self .custom_resource_mapping = {
95
+ ** DEFAULT_RESOURCE_MAPPING ,
96
+ ** self .worker_custom_resource_requests ,
97
+ }
98
+
99
+ def _gpu_to_resource (self ):
100
+ if self .head_gpus :
101
+ if "nvidia.com/gpu" in self .head_custom_resource_requests :
102
+ raise ValueError (
103
+ "nvidia.com/gpu already exists in head_custom_resource_requests"
104
+ )
105
+ self .head_custom_resource_requests ["nvidia.com/gpu" ] = self .head_gpus
106
+ if self .num_gpus :
107
+ if "nvidia.com/gpu" in self .worker_custom_resource_requests :
108
+ raise ValueError (
109
+ "nvidia.com/gpu already exists in worker_custom_resource_requests"
110
+ )
111
+ self .worker_custom_resource_requests ["nvidia.com/gpu" ] = self .num_gpus
112
+
66
113
def _str_mem_no_unit_add_GB (self ):
67
114
if isinstance (self .head_memory , str ) and self .head_memory .isdecimal ():
68
115
self .head_memory = f"{ self .head_memory } G"
0 commit comments