@@ -84,7 +84,7 @@ def __init__(
         backend: BackendInterface[RequestT, ResponseT],
         strategy: SchedulingStrategy,
         startup_duration: float,
-        **constraints: dict[str, Constraint],
+        **constraints: Constraint,
     ):
         """
         Initialize a worker process group for distributed request processing.
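
Note on the annotation change above: under PEP 484, a `**kwargs: T` annotation types each keyword value as `T`, so `**constraints: Constraint` is what lets callers pass individual constraint objects. A minimal sketch, using a stand-in `Constraint` protocol and a hypothetical `MaxErrors` constraint rather than the library's real classes:

from typing import Protocol


class Constraint(Protocol):
    """Stand-in for the scheduler's Constraint interface (assumption)."""

    def __call__(self, *args, **kwargs) -> bool: ...


class MaxErrors:
    """Hypothetical constraint, used only for illustration."""

    def __init__(self, limit: int) -> None:
        self.limit = limit

    def __call__(self, *args, **kwargs) -> bool:
        return True


def init_group(**constraints: Constraint) -> dict[str, Constraint]:
    # Each keyword VALUE is a Constraint; `constraints` itself is dict[str, Constraint].
    return dict(constraints)


init_group(max_errors=MaxErrors(5))  # type-checks with `**constraints: Constraint`
# Under the old `**constraints: dict[str, Constraint]`, each keyword value would
# itself have to be a whole dict, which no call site actually passes.
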
@@ -232,7 +232,7 @@ async def create_processes(self):
                     worker_index=rank,
                     max_buffer_send_size=None,
                     max_buffer_receive_size=per_proc_max_buffer_size,
-                ),
+                ),  # The non-group worker lacks the SchedulerState type, so this fails type checking.
                 backend=self.backend,
                 strategy=self.strategy,
                 async_limit=async_limit,
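
On the type error flagged in the new comment: if refactoring the non-group worker's messaging type is out of scope, one conventional workaround is a narrow `typing.cast` at the construction site (or an inline `# type: ignore`, as used elsewhere in this diff). The sketch below uses made-up `Messaging`/`SchedulerState` stand-ins, not this repository's actual classes:

from typing import Generic, TypeVar, cast

StateT = TypeVar("StateT")


class Messaging(Generic[StateT]):
    """Stand-in for the inter-process messaging class (assumption)."""


class SchedulerState:
    """Stand-in for the state type the group worker expects (assumption)."""


plain = Messaging()  # inferred without the SchedulerState type parameter
# Generics are erased at runtime, so the cast only affects the type checker.
typed: Messaging[SchedulerState] = cast(Messaging[SchedulerState], plain)
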
@@ -478,9 +478,9 @@ def __init__(
             num_processes=len(processes),
             start_time=start_time,
         )
-        self._queued_requests: set[RequestT | MultiTurnRequestT[RequestT]] = set()
-        self._pending_requests: set[RequestT | MultiTurnRequestT[RequestT]] = set()
-        self._processing_requests: set[RequestT | MultiTurnRequestT[RequestT]] = set()
+        self._queued_request_ids: set[str] = set()
+        self._pending_request_ids: set[str] = set()
+        self._processing_request_ids: set[str] = set()
 
     def requests_generator(
         self, requests: Iterable[RequestT | MultiTurnRequestT[RequestT]]
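
The switch above from sets of request objects to sets of ID strings matches how the counting code later in this diff uses them (`info.request_id`), and avoids requiring request payloads to be hashable or keeping them alive just for bookkeeping. A small self-contained sketch of the same idea, with a simplified tracker rather than the real class:

from dataclasses import dataclass, field


@dataclass
class LifecycleTracker:
    """Simplified stand-in that tracks lifecycle membership by request ID."""

    queued_ids: set[str] = field(default_factory=set)
    pending_ids: set[str] = field(default_factory=set)

    def mark_queued(self, request_id: str) -> None:
        self.queued_ids.add(request_id)

    def mark_pending(self, request_id: str) -> None:
        self.queued_ids.discard(request_id)
        self.pending_ids.add(request_id)


tracker = LifecycleTracker()
tracker.mark_queued("req-1")
tracker.mark_pending("req-1")
assert "req-1" in tracker.pending_ids and not tracker.queued_ids
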
@@ -517,11 +517,13 @@ def requests_generator(
                 )
                 state_update = self._locked_update(request_info)
                 request_info.timings.queued = time.time()
+                if self.messaging.buffer_receive_queue is None:
+                    raise RuntimeError("buffer receive queue is None")
                 self.messaging.buffer_receive_queue.sync_put(
                     (None, request, request_info, state_update.state)
                 )
 
-                yield (request, request_info)
+                yield request, request_info
 
                 if state_update.stop_queueing:
                     self.stop_send_requests_event.set()
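
The inserted `is None` guard narrows the optional `buffer_receive_queue` attribute before `.sync_put(...)`, turning a possible `AttributeError` on `None` into an explicit error and satisfying the type checker. A standalone sketch of that pattern, using the standard-library `queue.Queue` in place of the project's messaging queue:

import queue


class Messenger:
    """Simplified stand-in; the real messaging object is more involved."""

    def __init__(self) -> None:
        # The queue may not exist until the process group has started.
        self.buffer_receive_queue: queue.Queue | None = None

    def put_request(self, item: object) -> None:
        if self.buffer_receive_queue is None:
            raise RuntimeError("buffer receive queue is None")
        # After the guard, the checker treats the attribute as queue.Queue.
        self.buffer_receive_queue.put(item)


m = Messenger()
m.buffer_receive_queue = queue.Queue()
m.put_request(("req-1", {"prompt": "hello"}))
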
@@ -530,8 +532,8 @@ def requests_generator(
             # Reached the end, inject a RequestsExhaustedConstraint to record
             self._locked_update(
                 info=None,
-                requests_exhausted={
-                    "requests_exhausted": RequestsExhaustedConstraint(
+                add_constraints={
+                    "requests_exhausted": RequestsExhaustedConstraint(  # type: ignore[dict-item]
                         num_requests=count
                     )
                 },
@@ -610,10 +612,10 @@ def received_callback(
     def _locked_update(
         self,
         info: RequestInfo | None = None,
-        **add_constraints: dict[str, Constraint],
+        add_constraints: dict[str, Constraint] | None = None,
     ) -> _StateUpdate:
         with self._update_lock:
-            if add_constraints:
+            if add_constraints is not None:
                 self.constraints.update(add_constraints)
 
             if info is not None:
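
The old `**add_constraints: dict[str, Constraint]` annotation typed every keyword value as a whole dict, while the call site above passes a single dict of constraints, so an explicit optional parameter is the better fit. A sketch of the difference, using a callable alias as a stand-in for `Constraint`:

from typing import Any, Callable

Constraint = Callable[..., Any]  # stand-in for the real Constraint type


def update_old(**add_constraints: dict[str, Constraint]) -> None:
    # Under PEP 484, each keyword VALUE must be a dict[str, Constraint] here,
    # so update_old(requests_exhausted=some_constraint) would not type-check.
    ...


def update_new(add_constraints: dict[str, Constraint] | None = None) -> None:
    if add_constraints is not None:
        print(sorted(add_constraints))


update_new(add_constraints={"requests_exhausted": lambda: None})
update_new()  # also valid: nothing to merge

The `is not None` check also distinguishes "no argument" from an explicitly passed empty dict, although both end up as no-ops when merged into the constraint map.
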
@@ -631,34 +633,34 @@ def _locked_update(
 
     def _update_state_request_counts(self, info: RequestInfo):
         if info.status == "queued":
-            self._queued_requests.add(info.request_id)
-            self._state.queued_requests = len(self._queued_requests)
+            self._queued_request_ids.add(info.request_id)
+            self._state.queued_requests = len(self._queued_request_ids)
             self._state.created_requests += 1
         elif info.status == "pending":
-            self._queued_requests.remove(info.request_id)
-            self._state.queued_requests = len(self._queued_requests)
-            self._pending_requests.add(info.request_id)
-            self._state.pending_requests = len(self._pending_requests)
+            self._queued_request_ids.remove(info.request_id)
+            self._state.queued_requests = len(self._queued_request_ids)
+            self._pending_request_ids.add(info.request_id)
+            self._state.pending_requests = len(self._pending_request_ids)
         elif info.status == "in_progress":
-            self._pending_requests.remove(info.request_id)
-            self._state.pending_requests = len(self._pending_requests)
-            self._processing_requests.add(info.request_id)
-            self._state.processing_requests = len(self._processing_requests)
+            self._pending_request_ids.remove(info.request_id)
+            self._state.pending_requests = len(self._pending_request_ids)
+            self._processing_request_ids.add(info.request_id)
+            self._state.processing_requests = len(self._processing_request_ids)
         elif info.status == "completed":
-            self._processing_requests.remove(info.request_id)
-            self._state.processing_requests = len(self._processing_requests)
+            self._processing_request_ids.remove(info.request_id)
+            self._state.processing_requests = len(self._processing_request_ids)
             self._state.processed_requests += 1
             self._state.successful_requests += 1
         elif info.status in ("errored", "cancelled"):
-            if info.request_id in self._queued_requests:
-                self._queued_requests.remove(info.request_id)
-                self._state.queued_requests = len(self._queued_requests)
-            elif info.request_id in self._pending_requests:
-                self._pending_requests.remove(info.request_id)
-                self._state.pending_requests = len(self._pending_requests)
-            elif info.request_id in self._processing_requests:
-                self._processing_requests.remove(info.request_id)
-                self._state.processing_requests = len(self._processing_requests)
+            if info.request_id in self._queued_request_ids:
+                self._queued_request_ids.remove(info.request_id)
+                self._state.queued_requests = len(self._queued_request_ids)
+            elif info.request_id in self._pending_request_ids:
+                self._pending_request_ids.remove(info.request_id)
+                self._state.pending_requests = len(self._pending_request_ids)
+            elif info.request_id in self._processing_request_ids:
+                self._processing_request_ids.remove(info.request_id)
+                self._state.processing_requests = len(self._processing_request_ids)
 
             self._state.processed_requests += 1
             self._state.errored_requests += 1 if info.status == "errored" else 0
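
For reference, the status bookkeeping above follows a queued → pending → in_progress → completed path, with errored/cancelled removing the ID from whichever set still holds it. A compact runnable sketch of that lifecycle with a simplified counter class (not the scheduler's real state object):

from dataclasses import dataclass, field


@dataclass
class Counts:
    """Simplified lifecycle counter keyed by request ID."""

    queued: set[str] = field(default_factory=set)
    pending: set[str] = field(default_factory=set)
    processing: set[str] = field(default_factory=set)
    processed: int = 0

    def update(self, request_id: str, status: str) -> None:
        if status == "queued":
            self.queued.add(request_id)
        elif status == "pending":
            self.queued.remove(request_id)
            self.pending.add(request_id)
        elif status == "in_progress":
            self.pending.remove(request_id)
            self.processing.add(request_id)
        elif status == "completed":
            self.processing.remove(request_id)
            self.processed += 1
        elif status in ("errored", "cancelled"):
            # Drop the ID from whichever stage still holds it.
            for bucket in (self.queued, self.pending, self.processing):
                bucket.discard(request_id)
            self.processed += 1


counts = Counts()
for status in ("queued", "pending", "in_progress", "completed"):
    counts.update("req-1", status)
counts.update("req-2", "queued")
counts.update("req-2", "errored")
assert counts.processed == 2 and not counts.queued
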