@@ -426,86 +426,60 @@ def load_weights(self,
                  extra_rows = extra_rows.to(loaded_weight)
                  loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
 
-            # merge linear layers
-            if self.quant_config is None:
-                is_attention_weight = False
-                for weight_name, shard_size, offset in attention_weight_specs:
-                    if weight_name not in name:
-                        continue
-                    param = state_dict[name.replace(weight_name, "qkv_proj")]
+            is_quantized = self.quant_config is not None
 
-                    loaded_weight = loaded_weight[
-                        shard_size * tensor_model_parallel_rank:shard_size *
-                        (tensor_model_parallel_rank + 1)]
-                    param_slice = param.data[offset:offset + shard_size]
-                    assert param_slice.shape == loaded_weight.shape
-
-                    param_slice.copy_(loaded_weight)
-                    is_attention_weight = True
-                    break
-                if is_attention_weight:
+            is_attention_weight = False
+            for weight_name, shard_size, offset in attention_weight_specs:
+                if weight_name not in name:
                     continue
+                param = state_dict[name.replace(weight_name, "qkv_proj")]
 
-                is_gate_up_weight = False
-                for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
-                    if weight_name not in name:
-                        continue
-                    param = state_dict[name.replace(weight_name, "gate_up_proj")]
-                    shard_size = param.shape[0] // 2
+                if not is_quantized:
                     loaded_weight = loaded_weight[
                         shard_size * tensor_model_parallel_rank:shard_size *
                         (tensor_model_parallel_rank + 1)]
-                    param_slice = param.data[shard_size * stride_id:shard_size *
-                                             (stride_id + 1)]
-                    assert param_slice.shape == loaded_weight.shape
-                    param_slice.copy_(loaded_weight)
-                    is_gate_up_weight = True
-                    break
-                if is_gate_up_weight:
-                    continue
-            else:
-                # TODO: improve this block of code
-                is_attention_weight = False
-                for stride_id, weight_spec in enumerate(attention_weight_specs):
-                    weight_name, shard_size, offset = weight_spec
-
-                    if weight_name not in name:
-                        continue
-
-                    param = state_dict[name.replace(weight_name, "qkv_proj")]
-
-                    # TODO: this is specific to AWQ (should be more general)
+                    param_slice = param.data[offset:offset + shard_size]
+                else:
+                    # TODO: this is specific to AWQ
                     if "qweight" in name or "qzeros" in name:
                         adjustment = 32 / self.quant_config.bits
                         shard_size = int(shard_size // adjustment)
                         offset = int(offset // adjustment)
-
                     param_slice = param.data[:, offset:offset + shard_size]
-                    assert param_slice.shape == loaded_weight.shape
 
-                    param_slice.copy_(loaded_weight)
-                    is_attention_weight = True
-                    break
-                if is_attention_weight:
+                assert param_slice.shape == loaded_weight.shape
+
+                param_slice.copy_(loaded_weight)
+                is_attention_weight = True
+                break
+            if is_attention_weight:
+                continue
+
+            is_gate_up_weight = False
+            for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
+                if weight_name not in name:
                     continue
+                param = state_dict[name.replace(weight_name, "gate_up_proj")]
 
-                is_gate_up_weight = False
-                for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
-                    if weight_name not in name:
-                        continue
-                    param = state_dict[name.replace(weight_name, "gate_up_proj")]
+                if not is_quantized:
+                    shard_size = param.shape[0] // 2
+                    loaded_weight = loaded_weight[
+                        shard_size * tensor_model_parallel_rank:shard_size *
+                        (tensor_model_parallel_rank + 1)]
+                    param_slice = param.data[shard_size * stride_id:shard_size *
+                                             (stride_id + 1)]
+                else:
                     shard_size = param.shape[1] // 2
-
                     start = shard_size * stride_id
                     end = shard_size * (stride_id + 1)
-
                     param_slice = param.data[:, start:end]
-                    assert param_slice.shape == loaded_weight.shape
-                    param_slice.copy_(loaded_weight)
-                    is_gate_up_weight = True
-                    break
-                if is_gate_up_weight:
-                    continue
+
+                assert param_slice.shape == loaded_weight.shape
+                param_slice.copy_(loaded_weight)
+                is_gate_up_weight = True
+                break
+            if is_gate_up_weight:
+                continue
 
             param = state_dict[name]
             load_tensor_parallel_weights(param, loaded_weight, name,
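
For context beyond the diff: both branches of the unified path perform the same shard copy and differ only in which dimension holds the output features. Below is a minimal, self-contained sketch of that rule; the helper name slice_merged_shard, the toy shapes, and the defaults are illustrative assumptions, not vLLM's actual API.

import torch

def slice_merged_shard(param: torch.Tensor,
                       loaded_weight: torch.Tensor,
                       offset: int,
                       shard_size: int,
                       is_quantized: bool,
                       bits: int = 4,
                       name: str = "qweight",
                       tp_rank: int = 0) -> None:
    # Copy one checkpoint shard into a merged parameter (e.g. qkv_proj).
    if not is_quantized:
        # FP16 weights are [out_features, in_features]: take this rank's
        # row slice of the checkpoint and write it at the shard's row offset.
        loaded_weight = loaded_weight[shard_size * tp_rank:
                                      shard_size * (tp_rank + 1)]
        param_slice = param.data[offset:offset + shard_size]
    else:
        # AWQ qweight/qzeros are transposed and packed 32 // bits values per
        # int32, so out_features sit on dim 1, shrunk by the same factor.
        if "qweight" in name or "qzeros" in name:
            adjustment = 32 // bits
            shard_size //= adjustment
            offset //= adjustment
        param_slice = param.data[:, offset:offset + shard_size]
    assert param_slice.shape == loaded_weight.shape
    param_slice.copy_(loaded_weight)

# FP16: a 128-row Q shard lands in rows 0..128 of a merged [384, 64] QKV weight.
qkv = torch.empty(384, 64)
slice_merged_shard(qkv, torch.randn(128, 64), offset=0, shard_size=128,
                   is_quantized=False)

# AWQ (4-bit): the same logical 128 output features occupy 128 // 8 = 16
# int32 columns of a [64, 48] packed qweight.
qweight = torch.zeros(64, 48, dtype=torch.int32)
slice_merged_shard(qweight, torch.zeros(64, 16, dtype=torch.int32),
                   offset=0, shard_size=128, is_quantized=True, bits=4)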