@@ -211,9 +211,12 @@ def group_infos(self) -> dict[str, ProcessGroupInfo]:
211211 def _initialize_hosts (self ) -> None :
212212 with self ._queue_lock :
213213 self ._hosts : t .List [str ] = sorted (
214- dragon_machine .Node (node ).hostname
215- for node in dragon_machine .System ().nodes
214+ node for node in dragon_machine .System ().nodes
216215 )
216+ self ._nodes = [dragon_machine .Node (node ) for node in self ._hosts ]
217+ self ._cpus = [node .num_cpus for node in self ._nodes ]
218+ self ._gpus = [node .num_gpus for node in self ._nodes ]
219+
217220 """List of hosts available in allocation"""
218221 self ._free_hosts : t .Deque [str ] = collections .deque (self ._hosts )
219222 """List of hosts on which steps can be launched"""
@@ -285,6 +288,34 @@ def current_time(self) -> float:
285288 """Current time for DragonBackend object, in seconds since the Epoch"""
286289 return time .time ()
287290
291+ def _can_honor_policy (
292+ self , request : DragonRunRequest
293+ ) -> t .Tuple [bool , t .Optional [str ]]:
294+ # ensure the policy can be honored
295+ if request .policy :
296+ if request .policy .device == "gpu" :
297+ # make sure nodes w/GPUs exist
298+ if not any (self ._gpus ):
299+ return False , "Cannot satisfy request, no GPUs available"
300+
301+ if request .policy .cpu_affinity :
302+ # make sure some node has enough CPUs
303+ available = max (self ._cpus )
304+ requested = max (request .policy .cpu_affinity )
305+
306+ if requested >= available :
307+ return False , "Cannot satisfy request, not enough CPUs available"
308+
309+ if request .policy .gpu_affinity :
310+ # make sure some node has enough GPUs
311+ available = max (self ._gpus )
312+ requested = max (request .policy .gpu_affinity )
313+
314+ if requested >= available :
315+ return False , "Cannot satisfy request, not enough GPUs available"
316+
317+ return True , None
318+
288319 def _can_honor (self , request : DragonRunRequest ) -> t .Tuple [bool , t .Optional [str ]]:
289320 """Check if request can be honored with resources available in the allocation.
290321
@@ -299,6 +330,11 @@ def _can_honor(self, request: DragonRunRequest) -> t.Tuple[bool, t.Optional[str]
299330 if self ._shutdown_requested :
300331 message = "Cannot satisfy request, server is shutting down."
301332 return False , message
333+
334+ honorable , err = self ._can_honor_policy (request )
335+ if not honorable :
336+ return False , err
337+
302338 return True , None
303339
304340 def _allocate_step (
@@ -391,6 +427,44 @@ def _stop_steps(self) -> None:
391427 self ._group_infos [step_id ].status = SmartSimStatus .STATUS_CANCELLED
392428 self ._group_infos [step_id ].return_codes = [- 9 ]
393429
430+ @staticmethod
431+ def create_run_policy (
432+ request : DragonRunRequest , node_name : str
433+ ) -> "dragon_policy.Policy" :
434+ if isinstance (request , DragonRunRequest ):
435+ run_request : DragonRunRequest = request
436+
437+ device = dragon_policy .Policy .Device .DEFAULT
438+ affinity = dragon_policy .Policy .Affinity .DEFAULT
439+ cpu_affinity : t .List [int ] = []
440+ gpu_affinity : t .List [int ] = []
441+
442+ if run_request .policy is not None :
443+ if run_request .policy .cpu_affinity :
444+ affinity = dragon_policy .Policy .Affinity .SPECIFIC
445+ cpu_affinity = run_request .policy .cpu_affinity
446+ device = dragon_policy .Policy .Device .CPU
447+
448+ if run_request .policy .gpu_affinity :
449+ affinity = dragon_policy .Policy .Affinity .SPECIFIC
450+ gpu_affinity = run_request .policy .gpu_affinity
451+ device = dragon_policy .Policy .Device .GPU
452+
453+ if affinity != dragon_policy .Policy .Affinity .DEFAULT :
454+ return dragon_policy .Policy (
455+ placement = dragon_policy .Policy .Placement .HOST_NAME ,
456+ host_name = node_name ,
457+ affinity = affinity ,
458+ device = device ,
459+ cpu_affinity = cpu_affinity ,
460+ gpu_affinity = gpu_affinity ,
461+ )
462+
463+ return dragon_policy .Policy (
464+ placement = dragon_policy .Policy .Placement .HOST_NAME ,
465+ host_name = node_name ,
466+ )
467+
394468 def _start_steps (self ) -> None :
395469 self ._heartbeat ()
396470 with self ._queue_lock :
@@ -412,10 +486,7 @@ def _start_steps(self) -> None:
412486
413487 policies = []
414488 for node_name in hosts :
415- local_policy = dragon_policy .Policy (
416- placement = dragon_policy .Policy .Placement .HOST_NAME ,
417- host_name = node_name ,
418- )
489+ local_policy = self .create_run_policy (request , node_name )
419490 policies .extend ([local_policy ] * request .tasks_per_node )
420491 tmp_proc = dragon_process .ProcessTemplate (
421492 target = request .exe ,
0 commit comments