@@ -164,7 +164,7 @@ def __init__(
         super().__init__()
         self.hidden_size = hidden_size
         tp_size = get_tensor_model_parallel_world_size()
-        assert tp_size == 1, 'quantization does not support TP'
+        assert tp_size == 1, "quantization does not support TP"
         self.total_num_heads = num_heads
         assert self.total_num_heads % tp_size == 0
         self.num_heads = self.total_num_heads // tp_size
@@ -178,7 +178,7 @@ def __init__(
 
         self.qkv_proj = get_quantized_layer(
             hidden_size,
-            (self.total_num_heads + 2 * self.total_num_kv_heads) * self.head_dim,
+            self.q_size + 2 * self.kv_size,
             quant_config
         )
 
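Note: with the `tp_size == 1` assertion above, the new output width is the same quantity as before, since `q_size = num_heads * head_dim` and `kv_size = num_kv_heads * head_dim` come from un-sharded head counts when there is no tensor parallelism. A quick sanity check of the arithmetic, using assumed LLaMA-7B-style shapes (not values taken from this diff):

```python
# Assumed example shapes (LLaMA-7B-like); only the arithmetic matters here.
head_dim = 128
total_num_heads = 32
total_num_kv_heads = 32
tp_size = 1  # enforced by the assert in this patch

num_heads = total_num_heads // tp_size
num_kv_heads = total_num_kv_heads // tp_size
q_size = num_heads * head_dim        # 4096
kv_size = num_kv_heads * head_dim    # 4096

old_width = (total_num_heads + 2 * total_num_kv_heads) * head_dim
new_width = q_size + 2 * kv_size
assert old_width == new_width == 12288
```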
@@ -220,8 +220,17 @@ def __init__(
         quant_config: QuantizationConfig
     ):
         super().__init__()
-        self.gate_up_proj = get_quantized_layer(hidden_size, 2 * intermediate_size, quant_config)
-        self.down_proj = get_quantized_layer(intermediate_size, hidden_size, quant_config)
+
+        self.gate_up_proj = get_quantized_layer(
+            hidden_size,
+            2 * intermediate_size, quant_config
+        )
+
+        self.down_proj = get_quantized_layer(
+            intermediate_size,
+            hidden_size,
+            quant_config
+        )
 
         if hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {hidden_act}. "
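The merged `gate_up_proj` emits the gate and up projections in a single matmul of width `2 * intermediate_size`; the forward pass is expected to split that output in half and apply SiLU gating before `down_proj`, which is why any other `hidden_act` is rejected above. A minimal sketch of that forward, assuming the standard SiLU-gated LLaMA MLP; the function name is illustrative, not part of the patch:

```python
import torch
import torch.nn.functional as F

def gated_mlp_forward(x: torch.Tensor, gate_up_proj, down_proj) -> torch.Tensor:
    # One fused projection produces [gate | up] along the last dimension.
    gate_up = gate_up_proj(x)            # (..., 2 * intermediate_size)
    gate, up = gate_up.chunk(2, dim=-1)  # split back into the two halves
    return down_proj(F.silu(gate) * up)  # SiLU-gated output, (..., hidden_size)
```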
@@ -313,9 +322,12 @@ def __init__(self, config: LlamaConfig, quant_config: QuantizationConfig):
         vocab_size = ((config.vocab_size + 63) // 64) * 64
         self.embed_tokens = VocabParallelEmbedding(
             vocab_size, config.hidden_size, perform_initialization=False)
+
         self.layers = nn.ModuleList([
-            LlamaDecoderLayer(config, quant_config) for _ in range(config.num_hidden_layers)
+            LlamaDecoderLayer(config, quant_config)
+            for _ in range(config.num_hidden_layers)
         ])
+
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
     def forward(
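`((config.vocab_size + 63) // 64) * 64` rounds the vocabulary up to the next multiple of 64 before building the parallel embedding. A small worked example of that rounding (the helper and the second vocab size are illustrative):

```python
def pad_vocab(vocab_size: int, multiple: int = 64) -> int:
    # Same rounding as ((vocab_size + 63) // 64) * 64 in the patch.
    return ((vocab_size + multiple - 1) // multiple) * multiple

assert pad_vocab(32000) == 32000  # already a multiple of 64 (LLaMA vocab)
assert pad_vocab(32003) == 32064  # padded up to the next multiple
```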
@@ -414,82 +426,60 @@ def load_weights(self,
                 extra_rows = extra_rows.to(loaded_weight)
                 loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
 
-            is_quantized = self.quant_config is not None and self.quant_config.method is not None
+            is_quantized = self.quant_config is not None
 
-            # merge linear layers
-            if not is_quantized:
-                is_attention_weight = False
-                for weight_name, shard_size, offset in attention_weight_specs:
-                    if weight_name not in name:
-                        continue
-                    param = state_dict[name.replace(weight_name, "qkv_proj")]
+            is_attention_weight = False
+            for weight_name, shard_size, offset in attention_weight_specs:
+                if weight_name not in name:
+                    continue
+                param = state_dict[name.replace(weight_name, "qkv_proj")]
 
+                if not is_quantized:
                     loaded_weight = loaded_weight[
                         shard_size * tensor_model_parallel_rank:shard_size *
                         (tensor_model_parallel_rank + 1)]
                     param_slice = param.data[offset:offset + shard_size]
-                    assert param_slice.shape == loaded_weight.shape
+                else:
+                    # TODO: this is specific to AWQ
+                    if "qweight" in name or "qzeros" in name:
+                        adjustment = 32 / self.quant_config.bits
+                        shard_size = int(shard_size // adjustment)
+                        offset = int(offset // adjustment)
+                    param_slice = param.data[:, offset:offset + shard_size]
+
+                assert param_slice.shape == loaded_weight.shape
 
-                    param_slice.copy_(loaded_weight)
-                    is_attention_weight = True
-                    break
-                if is_attention_weight:
+                param_slice.copy_(loaded_weight)
+                is_attention_weight = True
+                break
+            if is_attention_weight:
+                continue
+
+            is_gate_up_weight = False
+            for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
+                if weight_name not in name:
                     continue
+                param = state_dict[name.replace(weight_name, "gate_up_proj")]
 
-                is_gate_up_weight = False
-                for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
-                    if weight_name not in name:
-                        continue
-                    param = state_dict[name.replace(weight_name, "gate_up_proj")]
+                if not is_quantized:
                     shard_size = param.shape[0] // 2
                     loaded_weight = loaded_weight[
                         shard_size * tensor_model_parallel_rank:shard_size *
                         (tensor_model_parallel_rank + 1)]
                     param_slice = param.data[shard_size * stride_id:shard_size *
                                              (stride_id + 1)]
-                    assert param_slice.shape == loaded_weight.shape
-                    param_slice.copy_(loaded_weight)
-                    is_gate_up_weight = True
-                    break
-                if is_gate_up_weight:
-                    continue
-            else:
-                # TODO: improve this block of code (not DRY, hacky, specific to AWQ)
-                is_attention_weight = False
-                for stride_id, (weight_name, shard_size, offset) in enumerate(attention_weight_specs):
-                    if weight_name not in name:
-                        continue
-                    param = state_dict[name.replace(weight_name, "qkv_proj")]
-
-                    # TODO: this is specific to AWQ (should be more general)
-                    if 'qweight' in name or 'qzeros' in name:
-                        shard_size = int(shard_size // (32 / self.quant_config.bits))
-                        offset = int(offset // (32 / self.quant_config.bits))
-
-                    param_slice = param.data[:, offset:offset + shard_size]
-                    assert param_slice.shape == loaded_weight.shape
-
-                    param_slice.copy_(loaded_weight)
-                    is_attention_weight = True
-                    break
-                if is_attention_weight:
-                    continue
-
-                is_gate_up_weight = False
-                for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
-                    if weight_name not in name:
-                        continue
-                    param = state_dict[name.replace(weight_name, "gate_up_proj")]
+                else:
                     shard_size = param.shape[1] // 2
-
-                    start, end = shard_size * stride_id, shard_size * (stride_id + 1)
+                    start = shard_size * stride_id
+                    end = shard_size * (stride_id + 1)
                     param_slice = param.data[:, start:end]
-                    assert param_slice.shape == loaded_weight.shape
-                    param_slice.copy_(loaded_weight)
-                    is_gate_up_weight = True
-                    break
-                if is_gate_up_weight:
-                    continue
+
+                assert param_slice.shape == loaded_weight.shape
+                param_slice.copy_(loaded_weight)
+                is_gate_up_weight = True
+                break
+            if is_gate_up_weight:
+                continue
 
             param = state_dict[name]
             load_tensor_parallel_weights(param, loaded_weight, name,
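The `32 / self.quant_config.bits` adjustment reflects how AWQ packs several low-bit values into each int32 along the output dimension of `qweight` and `qzeros`, so column offsets and shard sizes shrink by that packing factor; the quantized branch also slices along dim 1 because the packed tensors keep the output dimension last, unlike a regular `nn.Linear` weight. A sketch of the arithmetic, assuming 4-bit AWQ and an illustrative 4096-column query shard (numbers not taken from the patch):

```python
# Assumed example: 4-bit AWQ packs 32 // 4 = 8 quantized values per int32 column.
bits = 4
pack_factor = 32 // bits        # 8 values per packed int32

shard_size = 4096               # unpacked output columns of the q shard (illustrative)
offset = 4096                   # unpacked column offset of the k shard (illustrative)

packed_shard_size = shard_size // pack_factor  # 512 int32 columns
packed_offset = offset // pack_factor          # 512 int32 columns
assert (packed_shard_size, packed_offset) == (512, 512)
```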