@@ -135,11 +135,11 @@ def _apply_op(
135
135
return tensor
136
136
137
137
def _collective_op (
138
- self , tensor : Union [torch .Tensor , Number , str ], fn : Callable , * args : Any , ** kwargs : Any
139
- ) -> Union [torch .Tensor , Number , List [Number ], List [str ]]:
138
+ self , tensor : Union [torch .Tensor , float , str ], fn : Callable , * args : Any , ** kwargs : Any
139
+ ) -> Union [torch .Tensor , float , List [float ], List [str ]]:
140
140
tensor_to_number = tensor_to_str = False
141
141
device = self .device ()
142
- if isinstance (tensor , Number ):
142
+ if isinstance (tensor , ( Number , float ) ):
143
143
tensor_to_number = True
144
144
tensor = torch .tensor (tensor , device = device , dtype = self ._collective_op_dtype )
145
145
elif isinstance (tensor , str ):
@@ -150,28 +150,26 @@ def _collective_op(
150
150
151
151
if tensor_to_number :
152
152
if tensor .numel () == 1 :
153
- return cast ( Number , tensor .item () )
153
+ return tensor .item ()
154
154
else :
155
155
return tensor .tolist ()
156
156
elif tensor_to_str :
157
157
return self ._decode_str (tensor )
158
158
return tensor
159
159
160
- def all_reduce (self , tensor : Union [torch .Tensor , Number ], op : str = "sum" ) -> Union [torch .Tensor , Number ]:
160
+ def all_reduce (self , tensor : Union [torch .Tensor , float ], op : str = "sum" ) -> Union [torch .Tensor , float ]:
161
161
if not isinstance (tensor , (torch .Tensor , Number )):
162
162
raise TypeError ("Unhandled input type {}" .format (type (tensor )))
163
163
164
- return cast (Union [torch .Tensor , Number ], self ._collective_op (tensor , self ._do_all_reduce , op ))
164
+ return cast (Union [torch .Tensor , float ], self ._collective_op (tensor , self ._do_all_reduce , op ))
165
165
166
- def all_gather (
167
- self , tensor : Union [torch .Tensor , Number , str ]
168
- ) -> Union [torch .Tensor , Number , List [Number ], List [str ]]:
166
+ def all_gather (self , tensor : Union [torch .Tensor , float , str ]) -> Union [torch .Tensor , float , List [float ], List [str ]]:
169
167
if not isinstance (tensor , (torch .Tensor , Number , str )):
170
168
raise TypeError ("Unhandled input type {}" .format (type (tensor )))
171
169
172
170
return self ._collective_op (tensor , self ._do_all_gather )
173
171
174
- def broadcast (self , tensor : Union [torch .Tensor , Number , str ], src : int = 0 ) -> Union [torch .Tensor , Number , str ]:
172
+ def broadcast (self , tensor : Union [torch .Tensor , float , str ], src : int = 0 ) -> Union [torch .Tensor , float , str ]:
175
173
if not isinstance (tensor , (torch .Tensor , Number , str )):
176
174
raise TypeError ("Unhandled input type {}" .format (type (tensor )))
177
175
@@ -196,7 +194,7 @@ def broadcast(self, tensor: Union[torch.Tensor, Number, str], src: int = 0) -> U
196
194
tensor = self ._apply_op (tensor , device , self ._do_broadcast , src )
197
195
198
196
if tensor_to_number :
199
- return cast ( Number , tensor .item () )
197
+ return tensor .item ()
200
198
if tensor_to_str :
201
199
list_str = self ._decode_str (tensor )
202
200
return list_str [0 ]
@@ -273,17 +271,15 @@ def create_from_backend(backend: Optional[str] = None, **kwargs: Any) -> "_Seria
273
271
def spawn (* args : Any , ** kwargs : Any ) -> None :
274
272
raise NotImplementedError ("Serial computation model does not implement spawn method" )
275
273
276
- def all_reduce (self , tensor : Union [torch .Tensor , Number ], op : str = "sum" ) -> Union [torch .Tensor , Number ]:
274
+ def all_reduce (self , tensor : Union [torch .Tensor , float ], op : str = "sum" ) -> Union [torch .Tensor , float ]:
277
275
return tensor
278
276
279
- def all_gather (
280
- self , tensor : Union [torch .Tensor , Number , str ]
281
- ) -> Union [torch .Tensor , Number , List [Number ], List [str ]]:
277
+ def all_gather (self , tensor : Union [torch .Tensor , float , str ]) -> Union [torch .Tensor , float , List [float ], List [str ]]:
282
278
if isinstance (tensor , torch .Tensor ):
283
279
return tensor
284
- return cast (Union [List [Number ], List [str ]], [tensor ])
280
+ return cast (Union [List [float ], List [str ]], [tensor ])
285
281
286
- def broadcast (self , tensor : Union [torch .Tensor , Number , str ], src : int = 0 ) -> Union [torch .Tensor , Number , str ]:
282
+ def broadcast (self , tensor : Union [torch .Tensor , float , str ], src : int = 0 ) -> Union [torch .Tensor , float , str ]:
287
283
return tensor
288
284
289
285
def _do_all_reduce (self , tensor : torch .Tensor , op : str = "sum" ) -> torch .Tensor :
0 commit comments