@@ -729,7 +729,22 @@ def sample(
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
         loader = AutoWeightsLoader(self, skip_prefixes=["proj_out."])
-        loaded_weights = [(name, loaded_weight)
-                          for name, loaded_weight in weights]
         mapper = WeightsMapper({".fc1.": ".mlp.fc1.", ".fc2.": ".mlp.fc2."})
-        return loader.load_weights(loaded_weights, mapper=mapper)
+        # Add a fake zeros bias for k_proj to the state dict
+        weights = _create_fake_bias_for_k_proj(weights)
+        return loader.load_weights(weights, mapper=mapper)
+
+
+def _create_fake_bias_for_k_proj(
+    weights: Iterable[Tuple[str, torch.Tensor]]
+) -> Iterable[Tuple[str, torch.Tensor]]:
+    """
+    Create an all-zeros bias for each k_proj weight in self-attention layers,
+    so that the bias for k_proj in the fused qkv_proj can be initialized to zeros.
+    """
+    for name, weight in weights:
+        yield name, weight
+        if ".self_attn.k_proj.weight" in name:
+            bias = torch.zeros(weight.size(0))
+            bias_name = name.replace("weight", "bias")
+            yield bias_name, bias
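Why the fake bias is needed (context, not part of the diff): the Hugging Face Whisper checkpoint creates `k_proj` with `bias=False` while `q_proj` and `v_proj` do carry biases, and vLLM fuses the three projections into a single `qkv_proj` whose loader expects a bias entry for each sub-projection. A zero bias is a mathematical no-op, so synthesizing one keeps the fused load consistent. Below is a minimal, self-contained sketch of what the generator yields; the checkpoint names and shapes are made up for illustration:

```python
from typing import Iterable, Tuple

import torch


def _create_fake_bias_for_k_proj(
    weights: Iterable[Tuple[str, torch.Tensor]]
) -> Iterable[Tuple[str, torch.Tensor]]:
    for name, weight in weights:
        yield name, weight
        if ".self_attn.k_proj.weight" in name:
            # Append a zero bias of matching output width right after the weight.
            yield name.replace("weight", "bias"), torch.zeros(weight.size(0))


# Hypothetical checkpoint entries, just to show the yielded stream.
dummy = [
    ("model.encoder.layers.0.self_attn.k_proj.weight", torch.randn(4, 4)),
    ("model.encoder.layers.0.self_attn.v_proj.bias", torch.randn(4)),
]
for name, tensor in _create_fake_bias_for_k_proj(dummy):
    print(name, tuple(tensor.shape))
# model.encoder.layers.0.self_attn.k_proj.weight (4, 4)
# model.encoder.layers.0.self_attn.k_proj.bias (4,)
# model.encoder.layers.0.self_attn.v_proj.bias (4,)
```

The generator passes every entry through unchanged and only appends the extra zero-bias entry immediately after a matching `k_proj` weight, so downstream loaders see a checkpoint stream that looks as if the bias had been there all along.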