Commit e69eec2

Commit message: tmp

1 parent 93c9e2e commit e69eec2

File tree

1 file changed: +2 -13 lines changed

src/transformers/models/llama/modeling_llama.py

Lines changed: 2 additions & 13 deletions
@@ -1262,8 +1262,7 @@ def prepare_inputs_for_generation(
             # create position_ids on the fly for batch generation
             position_ids = attention_mask.long().cumsum(-1) - 1
             position_ids.masked_fill_(attention_mask == 0, 1)
-            if past_key_values:
-                position_ids = position_ids[:, -input_ids.shape[1] :]
+            position_ids = position_ids[:, -input_ids.shape[1] :]
 
             # TODO @gante we should only keep a `cache_position` in generate, and do +=1.
             # same goes for position ids. Could also help with continued generation.
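The first hunk drops the `if past_key_values:` guard so the position-id slice always runs. A minimal sketch of what the changed lines compute, with an illustrative left-padded batch (the example tensors are an assumption, not part of the commit):

import torch

# Illustrative left-padded batch; values are made up for this sketch.
attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]])
input_ids = torch.ones(2, 5, dtype=torch.long)

# cumsum(-1) - 1 maps each real token to its 0-based position;
# padded slots are then filled with a dummy index of 1.
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
# position_ids == [[1, 1, 0, 1, 2], [0, 1, 2, 3, 4]]

# The slice now runs unconditionally. With a cache, input_ids holds only the
# new tokens and the slice keeps their positions; without one, input_ids spans
# the full sequence and the slice is a no-op, so behavior is unchanged.
position_ids = position_ids[:, -input_ids.shape[1] :]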
@@ -1274,7 +1273,7 @@ def prepare_inputs_for_generation(
             )
 
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and past_key_values is None:
+        if inputs_embeds is not None and past_length == 0:
             model_inputs = {"inputs_embeds": inputs_embeds}
         else:
             model_inputs = {"input_ids": input_ids}
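The second hunk swaps `past_key_values is None` for `past_length == 0` (the new line implies `past_length` is already in scope earlier in the function). With `Cache` objects, a cache can be allocated before any token is written to it, so identity with `None` no longer marks the first generation step, while the cached length still does. A hedged sketch of the distinction, assuming `past_length` is taken from the cache's sequence length:

from transformers.cache_utils import DynamicCache

# A pre-allocated cache exists before the first forward pass but holds no tokens.
past_key_values = DynamicCache()
past_length = past_key_values.get_seq_length()  # 0 on the first step

# Old test: the cache object is not None, so step one would be misclassified
# and the user-provided inputs_embeds would be skipped.
assert past_key_values is not None
# New test: still identifies step one, so inputs_embeds are used exactly once.
assert past_length == 0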
@@ -1283,22 +1282,12 @@ def prepare_inputs_for_generation(
             {
                 "position_ids": position_ids,
                 "cache_position": cache_position,
-                "past_key_values": past_key_values,
                 "use_cache": kwargs.get("use_cache"),
                 "attention_mask": attention_mask,
             }
         )
         return model_inputs
 
-    @staticmethod
-    def _reorder_cache(past_key_values, beam_idx):
-        reordered_past = ()
-        for layer_past in past_key_values:
-            reordered_past += (
-                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
-            )
-        return reordered_past
-
 
 @add_start_docstrings(
     """
