@@ -269,10 +269,6 @@ def __init__(self,
269269 self .register_parameter ("bias" , None )
270270
271271 def weight_loader (self , param : Parameter , loaded_weight : torch .Tensor ):
272- # Special case for Fp8 scales.
273- fp8_scales_shard_indexer = getattr (param , "fp8_scales_shard_indexer" ,
274- None )
275-
276272 tp_rank = get_tensor_model_parallel_rank ()
277273 output_dim = getattr (param , "output_dim" , None )
278274 param_data = param .data
@@ -281,11 +277,11 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
281277 start_idx = tp_rank * shard_size
282278 loaded_weight = loaded_weight .narrow (output_dim , start_idx ,
283279 shard_size )
284- # Special case for Fp8 scales.
285- elif fp8_scales_shard_indexer is not None :
286- param_data , loaded_weight = fp8_scales_shard_indexer ( param_data ,
287- loaded_weight ,
288- shard_id = 0 )
280+
281+ # Special case for loading scales off disk, which often do not
282+ # have a shape (such as in the case of AutoFP8).
283+ if len ( loaded_weight . shape ) == 0 :
284+ loaded_weight = loaded_weight . reshape ( 1 )
289285
290286 assert param_data .shape == loaded_weight .shape
291287 param_data .copy_ (loaded_weight )
@@ -751,10 +747,6 @@ def __init__(self,
751747 self .register_parameter ("bias" , None )
752748
753749 def weight_loader (self , param : Parameter , loaded_weight : torch .Tensor ):
754- # Special case for Fp8 scales.
755- fp8_scales_shard_indexer = getattr (param , "fp8_scales_shard_indexer" ,
756- None )
757-
758750 tp_rank = get_tensor_model_parallel_rank ()
759751 input_dim = getattr (param , "input_dim" , None )
760752 param_data = param .data
@@ -764,13 +756,9 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
764756 loaded_weight = loaded_weight .narrow (input_dim , start_idx ,
765757 shard_size )
766758
767- # Special case for Fp8 scales.
768- elif fp8_scales_shard_indexer is not None :
769- param_data , loaded_weight = fp8_scales_shard_indexer (param_data ,
770- loaded_weight ,
771- shard_id = 0 )
772-
773- if fp8_scales_shard_indexer is None and len (loaded_weight .shape ) == 0 :
759+ # Special case for loading scales off disk, which often do not
760+ # have a shape (such as in the case of AutoFP8).
761+ if len (loaded_weight .shape ) == 0 :
774762 loaded_weight = loaded_weight .reshape (1 )
775763
776764 assert param_data .shape == loaded_weight .shape
0 commit comments