
Commit 9afc7e5

Robert Shaw (robertgshaw2-redhat) authored and committed
[ Misc ] Remove fp8_shard_indexer from Col/Row Parallel Linear (Simplify Weight Loading) (vllm-project#5928)
Co-authored-by: Robert Shaw <rshaw@neuralmagic>
1 parent 28e9598 commit 9afc7e5

File tree

1 file changed: +8 −20 lines changed


vllm/model_executor/layers/linear.py

Lines changed: 8 additions & 20 deletions
@@ -269,10 +269,6 @@ def __init__(self,
             self.register_parameter("bias", None)
 
     def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
-        # Special case for Fp8 scales.
-        fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
-                                           None)
-
         tp_rank = get_tensor_model_parallel_rank()
         output_dim = getattr(param, "output_dim", None)
         param_data = param.data
@@ -281,11 +277,11 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
             start_idx = tp_rank * shard_size
             loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                  shard_size)
-        # Special case for Fp8 scales.
-        elif fp8_scales_shard_indexer is not None:
-            param_data, loaded_weight = fp8_scales_shard_indexer(param_data,
-                                                                 loaded_weight,
-                                                                 shard_id=0)
+
+        # Special case for loading scales off disk, which often do not
+        # have a shape (such as in the case of AutoFP8).
+        if len(loaded_weight.shape) == 0:
+            loaded_weight = loaded_weight.reshape(1)
 
         assert param_data.shape == loaded_weight.shape
         param_data.copy_(loaded_weight)
@@ -751,10 +747,6 @@ def __init__(self,
             self.register_parameter("bias", None)
 
     def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
-        # Special case for Fp8 scales.
-        fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
-                                           None)
-
        tp_rank = get_tensor_model_parallel_rank()
        input_dim = getattr(param, "input_dim", None)
        param_data = param.data
@@ -764,13 +756,9 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
             loaded_weight = loaded_weight.narrow(input_dim, start_idx,
                                                  shard_size)
 
-        # Special case for Fp8 scales.
-        elif fp8_scales_shard_indexer is not None:
-            param_data, loaded_weight = fp8_scales_shard_indexer(param_data,
-                                                                 loaded_weight,
-                                                                 shard_id=0)
-
-        if fp8_scales_shard_indexer is None and len(loaded_weight.shape) == 0:
+        # Special case for loading scales off disk, which often do not
+        # have a shape (such as in the case of AutoFP8).
+        if len(loaded_weight.shape) == 0:
             loaded_weight = loaded_weight.reshape(1)
 
         assert param_data.shape == loaded_weight.shape
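
For reference, a minimal standalone sketch of the weight-loading logic these hunks leave behind (not the vLLM source; load_weight and the example tensors below are hypothetical, and only the torch calls are real): shard along the parallel dimension with narrow() when the parameter declares one, then give 0-dim scales (as serialized by, e.g., AutoFP8) a shape of (1,) so the final shape assertion passes.

from typing import Optional

import torch

def load_weight(param_data: torch.Tensor,
                loaded_weight: torch.Tensor,
                output_dim: Optional[int] = None,
                tp_rank: int = 0) -> None:
    # Take this TP rank's shard of the checkpoint tensor, if sharded.
    if output_dim is not None:
        shard_size = param_data.shape[output_dim]
        start_idx = tp_rank * shard_size
        loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)

    # Scales loaded off disk are often 0-dim scalars; reshape so the
    # shape check against the registered parameter buffer passes.
    if len(loaded_weight.shape) == 0:
        loaded_weight = loaded_weight.reshape(1)

    assert param_data.shape == loaded_weight.shape
    param_data.copy_(loaded_weight)

# A 0-dim fp8 scale loads into a shape-(1,) parameter buffer.
scale = torch.empty(1)
load_weight(scale, torch.tensor(0.5))
print(scale)  # tensor([0.5000])

# A sharded weight: rank 1 of 2 takes rows [4, 8) of an 8x4 checkpoint.
weight_shard = torch.empty(4, 4)
load_weight(weight_shard, torch.arange(32.0).reshape(8, 4),
            output_dim=0, tp_rank=1)

Because the 0-dim reshape now applies unconditionally in both ColumnParallelLinear and RowParallelLinear, the per-parameter fp8_scales_shard_indexer attribute and its shard_id plumbing are no longer needed, which is the simplification the commit title describes.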
