 
 import torch
 import torch.nn as nn
-from transformers.generation import GenerationMixin
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import logging
 from transformers.utils.deprecation import deprecate_kwarg
 
 from fla.layers.forgetting_attn import ForgettingAttention
 from fla.models.forgetting_transformer.configuration_forgetting_transformer import ForgettingTransformerConfig
-from fla.models.utils import Cache
+from fla.models.utils import Cache, FLAGenerationMixin
 from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss
 from fla.modules import GatedMLP as ForgettingTransformerMLP
 from fla.modules import RMSNorm
@@ -260,7 +259,7 @@ def forward( |
         )
 
 
-class ForgettingTransformerForCausalLM(ForgettingTransformerPreTrainedModel, GenerationMixin):
+class ForgettingTransformerForCausalLM(ForgettingTransformerPreTrainedModel, FLAGenerationMixin):
 
     _tied_weights_keys = ["lm_head.weight"]
 
@@ -292,40 +291,6 @@ def set_decoder(self, decoder): |
     def get_decoder(self):
         return self.model
 
-    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
-    def prepare_inputs_for_generation(
-        self,
-        input_ids: torch.LongTensor = None,
-        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-        use_cache: bool = True,
-        logits_to_keep: Optional[int] = None,
-        **kwargs
-    ):
-        # only last token for `inputs_ids` if the `past_key_values` is not empty.
-        if past_key_values is not None and len(past_key_values) > 0:
-            input_ids = input_ids[:, -1:]
-        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and len(past_key_values) == 0:
-            model_inputs = {'inputs_embeds': inputs_embeds}
-        else:
-            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
-            # recompiles graphs as the stride of the inputs is a guard.
-            # Ref: https://github.com/huggingface/transformers/pull/29114
-            # TODO: use `next_tokens` directly instead.
-            model_inputs = {'input_ids': input_ids.contiguous()}
-
-        if logits_to_keep is not None:
-            model_inputs['logits_to_keep'] = logits_to_keep
-
-        model_inputs.update({
-            'past_key_values': past_key_values,
-            'use_cache': use_cache,
-            'attention_mask': attention_mask,
-        })
-        return model_inputs
-
     @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
     def forward(
         self,
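
The override deleted in the hunk above does not simply disappear: the class now inherits its generation-input handling from the shared FLAGenerationMixin imported from fla.models.utils. That mixin's real implementation is not part of this diff; the sketch below is only an assumption of what it centralizes, reconstructed from the method removed above (with a small added guard for past_key_values=None), and the names mirror that removed code rather than the actual fla.models.utils source.

# Hedged sketch only -- NOT the actual fla.models.utils implementation.
from typing import List, Optional, Union

import torch
from transformers.generation import GenerationMixin


class FLAGenerationMixin(GenerationMixin):
    """Assumed shared mixin reproducing the generation-input logic removed above."""

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor = None,
        past_key_values: Optional[Union["Cache", List[torch.FloatTensor]]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        use_cache: bool = True,
        logits_to_keep: Optional[int] = None,
        **kwargs,
    ):
        # When the cache already holds previous steps, only the newest token is fed in.
        if past_key_values is not None and len(past_key_values) > 0:
            input_ids = input_ids[:, -1:]
        # `inputs_embeds` are only usable on the first step, before any cache exists.
        if inputs_embeds is not None and (past_key_values is None or len(past_key_values) == 0):
            model_inputs = {'inputs_embeds': inputs_embeds}
        else:
            # `contiguous()` keeps a static stride during decoding so torchdynamo
            # does not recompile on every step (see transformers PR #29114).
            model_inputs = {'input_ids': input_ids.contiguous()}

        if logits_to_keep is not None:
            model_inputs['logits_to_keep'] = logits_to_keep

        model_inputs.update({
            'past_key_values': past_key_values,
            'use_cache': use_cache,
            'attention_mask': attention_mask,
        })
        return model_inputs

Under that assumption, every FLA model gets identical cache-aware input pruning from the mixin, which is why the per-model copy of this boilerplate can be dropped here while model.generate() keeps working unchanged.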