@@ -45,7 +45,7 @@ class GraphCaptureContext:
4545
4646
4747def  _split_tensor_dict (
48-         tensor_dict : Dict [Any , Union [torch .Tensor , Any ]],
48+         tensor_dict : Dict [str , Union [torch .Tensor , Any ]],
4949        prefix : str  =  "" ) ->  Tuple [List [Tuple [str , Any ]], List [torch .Tensor ]]:
5050    """Split the tensor dictionary into two parts: 
5151    1. A list of (key, value) pairs. If the value is a tensor, it is replaced 
@@ -473,11 +473,11 @@ def recv_object(self, src: int) -> Any:
473473
474474    def  broadcast_tensor_dict (
475475        self ,
476-         tensor_dict : Optional [Dict [Any , Union [torch .Tensor , Any ]]] =  None ,
476+         tensor_dict : Optional [Dict [str , Union [torch .Tensor , Any ]]] =  None ,
477477        src : int  =  0 ,
478478        group : Optional [ProcessGroup ] =  None ,
479479        metadata_group : Optional [ProcessGroup ] =  None 
480-     ) ->  Optional [Dict [Any , Union [torch .Tensor , Any ]]]:
480+     ) ->  Optional [Dict [str , Union [torch .Tensor , Any ]]]:
481481        """Broadcast the input tensor dictionary. 
482482        NOTE: `src` is the local rank of the source rank. 
483483        """ 
@@ -558,9 +558,9 @@ def broadcast_tensor_dict(
558558
559559    def  send_tensor_dict (
560560        self ,
561-         tensor_dict : Dict [Any , Union [torch .Tensor , Any ]],
561+         tensor_dict : Dict [str , Union [torch .Tensor , Any ]],
562562        dst : Optional [int ] =  None 
563-     ) ->  Optional [Dict [Any , Union [torch .Tensor , Any ]]]:
563+     ) ->  Optional [Dict [str , Union [torch .Tensor , Any ]]]:
564564        """Send the input tensor dictionary. 
565565        NOTE: `dst` is the local rank of the source rank. 
566566        """ 
@@ -599,7 +599,7 @@ def send_tensor_dict(
599599    def  recv_tensor_dict (
600600        self ,
601601        src : Optional [int ] =  None 
602-     ) ->  Optional [Dict [Any , Union [torch .Tensor , Any ]]]:
602+     ) ->  Optional [Dict [str , Union [torch .Tensor , Any ]]]:
603603        """Recv the input tensor dictionary. 
604604        NOTE: `src` is the local rank of the source rank. 
605605        """ 
@@ -615,15 +615,15 @@ def recv_tensor_dict(
615615        assert  src  <  self .world_size , f"Invalid src rank ({ src }  )" 
616616
617617        recv_metadata_list  =  self .recv_object (src = src )
618-         tensor_dict  =  {}
618+         tensor_dict :  Dict [ str ,  Any ]  =  {}
619619        for  key , value  in  recv_metadata_list :
620620            if  isinstance (value , TensorMetadata ):
621621                tensor  =  torch .empty (value .size ,
622622                                     dtype = value .dtype ,
623623                                     device = value .device )
624624                if  tensor .numel () ==  0 :
625625                    # Skip broadcasting empty tensors. 
626-                     tensor_dict [ key ]  =   tensor 
626+                     _update_nested_dict ( tensor_dict ,  key ,  tensor ) 
627627                    continue 
628628                if  tensor .is_cpu :
629629                    # use metadata_group for CPU tensors 
@@ -633,9 +633,9 @@ def recv_tensor_dict(
633633                else :
634634                    # use group for GPU tensors 
635635                    torch .distributed .recv (tensor , src = src , group = group )
636-                 tensor_dict [ key ]  =   tensor 
636+                 _update_nested_dict ( tensor_dict ,  key ,  tensor ) 
637637            else :
638-                 tensor_dict [ key ]  =   value 
638+                 _update_nested_dict ( tensor_dict ,  key ,  value ) 
639639        return  tensor_dict 
640640
641641    def  barrier (self ):
0 commit comments