@@ -1262,8 +1262,7 @@ def prepare_inputs_for_generation(
             # create position_ids on the fly for batch generation
             position_ids = attention_mask.long().cumsum(-1) - 1
             position_ids.masked_fill_(attention_mask == 0, 1)
-            if past_key_values:
-                position_ids = position_ids[:, -input_ids.shape[1] :]
+            position_ids = position_ids[:, -input_ids.shape[1] :]
 
         # TODO @gante we should only keep a `cache_position` in generate, and do +=1.
         # same goes for position ids. Could also help with continued generation.
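Note on the hunk above: the position_ids arithmetic derives positions from a (possibly left-padded) attention mask. A minimal standalone sketch of that technique, using only torch (the mask values here are illustrative assumptions, not data from this model):

    import torch

    # Two sequences; the first is left-padded with two pad tokens.
    attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                                   [1, 1, 1, 1, 1]])
    # Running index of real tokens: cumulative count minus one.
    position_ids = attention_mask.long().cumsum(-1) - 1
    # Padding slots get a dummy position of 1; they are masked out anyway.
    position_ids.masked_fill_(attention_mask == 0, 1)
    # position_ids is now tensor([[1, 1, 0, 1, 2],
    #                             [0, 1, 2, 3, 4]])
    # After the change, the trailing input_ids.shape[1] positions are kept
    # unconditionally, rather than only when a cache is present.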
@@ -1274,7 +1273,7 @@ def prepare_inputs_for_generation(
         )
 
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and past_key_values is None:
+        if inputs_embeds is not None and past_length == 0:
             model_inputs = {"inputs_embeds": inputs_embeds}
         else:
             model_inputs = {"input_ids": input_ids}
@@ -1283,22 +1282,12 @@ def prepare_inputs_for_generation(
             {
                 "position_ids": position_ids,
                 "cache_position": cache_position,
-                "past_key_values": past_key_values,
                 "use_cache": kwargs.get("use_cache"),
                 "attention_mask": attention_mask,
             }
         )
         return model_inputs
 
-    @staticmethod
-    def _reorder_cache(past_key_values, beam_idx):
-        reordered_past = ()
-        for layer_past in past_key_values:
-            reordered_past += (
-                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
-            )
-        return reordered_past
-
 
 @add_start_docstrings(
     """
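The larger change in this diff is dropping the per-model `_reorder_cache` override and the explicit `past_key_values` plumbing. For reference, a minimal standalone sketch of what the deleted static method did during beam search (the shapes and beam_idx values are illustrative assumptions):

    import torch

    num_layers, beams, heads, seq_len, head_dim = 2, 4, 1, 3, 8
    # Legacy cache format: a tuple of (key, value) tensor pairs, one per layer.
    past_key_values = tuple(
        (torch.randn(beams, heads, seq_len, head_dim),
         torch.randn(beams, heads, seq_len, head_dim))
        for _ in range(num_layers)
    )
    # Which surviving hypothesis each beam continues from, e.g. beams 0 and 1
    # both fork from hypothesis 0.
    beam_idx = torch.tensor([0, 0, 2, 1])

    # Reindex every cached tensor along the batch/beam dimension (dim 0), so
    # each beam carries the cache of the hypothesis it was forked from.
    reordered_past = tuple(
        tuple(state.index_select(0, beam_idx.to(state.device)) for state in layer_past)
        for layer_past in past_key_values
    )

With the newer Cache classes (e.g. DynamicCache), this reordering is handled via the cache object itself rather than a tuple of tensors, which is why the per-model static method becomes dead code.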