
Commit db4db0c

improve the quant weight loaded code
1 parent fbaf889

File tree

1 file changed: +36 −62 lines changed

vllm/model_executor/models/llama.py

Lines changed: 36 additions & 62 deletions
@@ -426,86 +426,60 @@ def load_weights(self,
                 extra_rows = extra_rows.to(loaded_weight)
                 loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)

-            # merge linear layers
-            if self.quant_config is not None:
-                is_attention_weight = False
-                for weight_name, shard_size, offset in attention_weight_specs:
-                    if weight_name not in name:
-                        continue
-                    param = state_dict[name.replace(weight_name, "qkv_proj")]
+            is_quantized = self.quant_config is not None

-                    loaded_weight = loaded_weight[
-                        shard_size * tensor_model_parallel_rank:shard_size *
-                        (tensor_model_parallel_rank + 1)]
-                    param_slice = param.data[offset:offset + shard_size]
-                    assert param_slice.shape == loaded_weight.shape
-
-                    param_slice.copy_(loaded_weight)
-                    is_attention_weight = True
-                    break
-                if is_attention_weight:
+            is_attention_weight = False
+            for weight_name, shard_size, offset in attention_weight_specs:
+                if weight_name not in name:
                     continue
+                param = state_dict[name.replace(weight_name, "qkv_proj")]

-                is_gate_up_weight = False
-                for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
-                    if weight_name not in name:
-                        continue
-                    param = state_dict[name.replace(weight_name, "gate_up_proj")]
-                    shard_size = param.shape[0] // 2
+                if not is_quantized:
                     loaded_weight = loaded_weight[
                         shard_size * tensor_model_parallel_rank:shard_size *
                         (tensor_model_parallel_rank + 1)]
-                    param_slice = param.data[shard_size * stride_id:shard_size *
-                                             (stride_id + 1)]
-                    assert param_slice.shape == loaded_weight.shape
-                    param_slice.copy_(loaded_weight)
-                    is_gate_up_weight = True
-                    break
-                if is_gate_up_weight:
-                    continue
-            else:
-                # TODO: improve this block of code
-                is_attention_weight = False
-                for stride_id, weight_spec in enumerate(attention_weight_specs):
-                    weight_name, shard_size, offset = weight_spec
-
-                    if weight_name not in name:
-                        continue
-
-                    param = state_dict[name.replace(weight_name, "qkv_proj")]
-
-                    # TODO: this is specific to AWQ (should be more general)
+                    param_slice = param.data[offset:offset + shard_size]
+                else:
+                    # TODO: this is specific to AWQ
                     if "qweight" in name or "qzeros" in name:
                         adjustment = 32 / self.quant_config.bits
                         shard_size = int(shard_size // adjustment)
                         offset = int(offset // adjustment)
-
                     param_slice = param.data[:, offset:offset + shard_size]
-                    assert param_slice.shape == loaded_weight.shape

-                    param_slice.copy_(loaded_weight)
-                    is_attention_weight = True
-                    break
-                if is_attention_weight:
+                assert param_slice.shape == loaded_weight.shape
+
+                param_slice.copy_(loaded_weight)
+                is_attention_weight = True
+                break
+            if is_attention_weight:
+                continue
+
+            is_gate_up_weight = False
+            for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
+                if weight_name not in name:
                     continue
+                param = state_dict[name.replace(weight_name, "gate_up_proj")]

-                is_gate_up_weight = False
-                for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
-                    if weight_name not in name:
-                        continue
-                    param = state_dict[name.replace(weight_name, "gate_up_proj")]
+                if not is_quantized:
+                    shard_size = param.shape[0] // 2
+                    loaded_weight = loaded_weight[
+                        shard_size * tensor_model_parallel_rank:shard_size *
+                        (tensor_model_parallel_rank + 1)]
+                    param_slice = param.data[shard_size * stride_id:shard_size *
+                                             (stride_id + 1)]
+                else:
                     shard_size = param.shape[1] // 2
-
                     start = shard_size * stride_id
                     end = shard_size * (stride_id + 1)
-
                     param_slice = param.data[:, start:end]
-                    assert param_slice.shape == loaded_weight.shape
-                    param_slice.copy_(loaded_weight)
-                    is_gate_up_weight = True
-                    break
-                if is_gate_up_weight:
-                    continue
+
+                assert param_slice.shape == loaded_weight.shape
+                param_slice.copy_(loaded_weight)
+                is_gate_up_weight = True
+                break
+            if is_gate_up_weight:
+                continue

             param = state_dict[name]
             load_tensor_parallel_weights(param, loaded_weight, name,
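
For context, the change collapses the duplicated quantized/non-quantized branches into a single loop per merged layer: the unquantized path shards each q/k/v (or gate/up) weight along dim 0 and copies it into a row slice of the merged parameter, while the AWQ path shards the packed "qweight"/"qzeros" tensors along dim 1, scaling shard sizes by 32 / bits. A minimal, standalone sketch of that pattern in plain PyTorch follows; hidden, tp_rank, tp_size, and bits are made-up illustrative values, not vLLM's actual API or configuration:

import torch

# Illustrative configuration (hypothetical values, not from the commit).
hidden = 128             # hidden size
tp_rank, tp_size = 0, 2  # tensor-parallel rank / world size
bits = 4
pack = 32 // bits        # int4 values packed per int32 (the 32 / bits adjustment)

# Unquantized path: the merged qkv parameter holds this rank's rows of q, k, v.
shard = hidden // tp_size
qkv = torch.empty(3 * shard, hidden)
for i, full_weight in enumerate(torch.randn(3, hidden, hidden)):
    loaded = full_weight[shard * tp_rank:shard * (tp_rank + 1)]       # row shard
    param_slice = qkv.data[shard * i:shard * (i + 1)]
    assert param_slice.shape == loaded.shape
    param_slice.copy_(loaded)

# AWQ-style path: the packed qweight is sliced along dim 1, with sizes // pack.
shard_q = shard // pack
qkv_qweight = torch.empty(hidden, 3 * shard_q, dtype=torch.int32)
for i in range(3):
    loaded = torch.randint(0, 2**16, (hidden, shard_q), dtype=torch.int32)
    param_slice = qkv_qweight.data[:, shard_q * i:shard_q * (i + 1)]  # column shard
    assert param_slice.shape == loaded.shape
    param_slice.copy_(loaded)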
