
Commit 0ca1288

[Deps] Upgrade to transformers 4.56.x (#587)
* [Deps] Upgrade to transformers 4.56.x
* remove triton conv1d backend for mamba2
* skip mom test
1 parent 33e6b35 commit 0ca1288

34 files changed: +149 −1017 lines

.github/workflows/reusable-ci-tests.yml

Lines changed: 4 additions & 4 deletions
@@ -137,13 +137,13 @@ jobs:
          NIGHTLY_URL="https://download.pytorch.org/whl/nightly/${{ inputs.pytorch_cuda_version }}"
          echo "Using nightly index URL: $NIGHTLY_URL"
          $CONDA_BIN_PATH/pip install -U torch pytorch-triton --index-url $NIGHTLY_URL
-         $CONDA_BIN_PATH/pip install -U numpy packaging psutil ninja einops datasets
+         $CONDA_BIN_PATH/pip install -U numpy packaging psutil ninja einops datasets transformers
          $CONDA_BIN_PATH/pip install --no-deps .
        else
          STABLE_URL="https://download.pytorch.org/whl/${{ inputs.pytorch_cuda_version }}"
          echo "Using stable index URL: $STABLE_URL"
          $CONDA_BIN_PATH/pip install -U torch~=${{ inputs.pytorch_version }} triton --index-url $STABLE_URL
-         $CONDA_BIN_PATH/pip install -U numpy packaging psutil ninja einops datasets
+         $CONDA_BIN_PATH/pip install -U numpy packaging psutil ninja einops datasets transformers
          $CONDA_BIN_PATH/pip install .
          if [[ "${{ inputs.runner }}" == nvidia-h100* ]]; then
            echo "Installing causal-conv1d for H100"
@@ -156,7 +156,7 @@ jobs:
          XPU_URL="https://download.pytorch.org/whl/xpu"
          echo "Using XPU index URL: $XPU_URL"
          $CONDA_BIN_PATH/pip install -U torch~=${{ inputs.pytorch_version }} pytorch-triton-xpu --index-url $XPU_URL
-         $CONDA_BIN_PATH/pip install -U numpy packaging psutil ninja einops datasets
+         $CONDA_BIN_PATH/pip install -U numpy packaging psutil ninja einops datasets transformers
          $CONDA_BIN_PATH/pip install .
        else
          echo "::error::Unsupported GPU type: ${{ inputs.gpu_type }}"
@@ -319,7 +319,7 @@ jobs:
          NIGHTLY_URL="https://download.pytorch.org/whl/nightly/${{ inputs.pytorch_cuda_version }}"
          echo "Using nightly index URL: $NIGHTLY_URL"
          $CONDA_BIN_PATH/pip install -U torch pytorch-triton --index-url $NIGHTLY_URL
-         $CONDA_BIN_PATH/pip install -U numpy packaging psutil ninja einops datasets
+         $CONDA_BIN_PATH/pip install -U numpy packaging psutil ninja einops datasets transformers
          $CONDA_BIN_PATH/pip install --no-deps .
        else
          STABLE_URL="https://download.pytorch.org/whl/${{ inputs.pytorch_cuda_version }}"
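Note on the CI change: transformers is now installed explicitly alongside the other runtime deps before the package itself is installed with --no-deps, so whatever version pip resolves is the one the tests run against. A minimal sanity check one might run locally to confirm the environment matches the 4.56.x floor the commit title targets (the exact lower bound lives in the package metadata, which is not part of the hunks shown here, so treat 4.56.0 as an assumption):

# Hedged sketch: verify the locally installed transformers meets the 4.56.x
# requirement this commit targets. The exact lower bound is an assumption.
from packaging import version

import transformers

required = version.parse("4.56.0")
installed = version.parse(transformers.__version__)
if installed < required:
    raise RuntimeError(
        f"flash-linear-attention now expects transformers>={required}, "
        f"but {installed} is installed; run `pip install -U transformers`."
    )
print(f"transformers {installed} OK")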

fla/layers/mamba2.py

Lines changed: 4 additions & 0 deletions
@@ -211,6 +211,10 @@ def __init__(
            from fla.modules.convolution import causal_conv1d_update as causal_conv1d_update_triton
            self.causal_conv1d_fn = causal_conv1d_triton
            self.causal_conv1d_update = causal_conv1d_update_triton
+           logger.warning(
+               "Mamba2 does not recommend using Triton's conv1d backend, "
+               "as it is untested and may contain bugs."
+           )
        else:
            self.causal_conv1d_fn = causal_conv1d_fn
            self.causal_conv1d_update = causal_conv1d_update
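For context (not part of the diff): the Triton conv1d path is the fallback branch shown above, while the preferred branch uses the kernels from the optional causal-conv1d package, which is why the CI installs causal-conv1d on H100 runners. A hedged, standalone helper to see which backend a local install can reach; the assumption that selection follows import availability is ours, not the repo's:

# Hedged helper, not from the repo: report which conv1d backend is importable.
# Assumes the fast path requires the optional `causal_conv1d` package.
def conv1d_backend() -> str:
    try:
        import causal_conv1d  # noqa: F401  -- fused CUDA kernels
        return "causal-conv1d (CUDA kernels)"
    except ImportError:
        # Without causal-conv1d, fla falls back to its Triton conv1d kernels,
        # which this commit now warns are untested for Mamba2.
        return "fla Triton fallback (warned against for Mamba2)"


if __name__ == "__main__":
    print(conv1d_backend())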

fla/models/abc/modeling_abc.py

Lines changed: 2 additions & 37 deletions
@@ -8,7 +8,6 @@

 import torch
 import torch.nn as nn
-from transformers.generation import GenerationMixin
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import logging
@@ -17,7 +16,7 @@
 from fla.layers.abc import ABCAttention
 from fla.layers.attn import Attention
 from fla.models.abc.configuration_abc import ABCConfig
-from fla.models.utils import Cache
+from fla.models.utils import Cache, FLAGenerationMixin
 from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss
 from fla.modules import GatedMLP as ABCMLP
 from fla.modules import RMSNorm
@@ -259,7 +258,7 @@ def forward(
         )


-class ABCForCausalLM(ABCPreTrainedModel, GenerationMixin):
+class ABCForCausalLM(ABCPreTrainedModel, FLAGenerationMixin):

     _tied_weights_keys = ["lm_head.weight"]

@@ -306,40 +305,6 @@ def generate(self, *args, **kwargs):
         else:
             raise exception

-    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
-    def prepare_inputs_for_generation(
-        self,
-        input_ids: torch.LongTensor = None,
-        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-        use_cache: bool = True,
-        logits_to_keep: Optional[int] = None,
-        **kwargs
-    ):
-        # only last token for `inputs_ids` if the `past_key_values` is not empty.
-        if past_key_values is not None and len(past_key_values) > 0:
-            input_ids = input_ids[:, -1:]
-        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and len(past_key_values) == 0:
-            model_inputs = {'inputs_embeds': inputs_embeds}
-        else:
-            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
-            # recompiles graphs as the stride of the inputs is a guard.
-            # Ref: https://github.com/huggingface/transformers/pull/29114
-            # TODO: use `next_tokens` directly instead.
-            model_inputs = {'input_ids': input_ids.contiguous()}
-
-        if logits_to_keep is not None:
-            model_inputs['logits_to_keep'] = logits_to_keep
-
-        model_inputs.update({
-            'past_key_values': past_key_values,
-            'use_cache': use_cache,
-            'attention_mask': attention_mask,
-        })
-        return model_inputs
-
     @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
     def forward(
         self,
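The same pattern repeats in every modeling file below: each *ForCausalLM class swaps transformers' GenerationMixin for FLAGenerationMixin, and its private copy of prepare_inputs_for_generation is deleted. The shared implementation in fla/models/utils.py is not part of this diff; a hedged reconstruction, assuming it simply hoists the method removed above, would look roughly like this:

# Hedged sketch of FLAGenerationMixin, reconstructed from the method deleted in
# each model above; the real class in fla/models/utils.py may differ.
from typing import Optional

import torch
from transformers.generation import GenerationMixin


class FLAGenerationMixin(GenerationMixin):

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor = None,
        past_key_values=None,  # fla Cache or list of tensors
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        use_cache: bool = True,
        logits_to_keep: Optional[int] = None,
        **kwargs
    ):
        # Only feed the last token once the cache is non-empty.
        if past_key_values is not None and len(past_key_values) > 0:
            input_ids = input_ids[:, -1:]
        # `inputs_embeds` are only used for the first generation step.
        if inputs_embeds is not None and len(past_key_values) == 0:
            model_inputs = {'inputs_embeds': inputs_embeds}
        else:
            # contiguous() keeps a static stride for torchdynamo
            # (see https://github.com/huggingface/transformers/pull/29114).
            model_inputs = {'input_ids': input_ids.contiguous()}

        if logits_to_keep is not None:
            model_inputs['logits_to_keep'] = logits_to_keep

        model_inputs.update({
            'past_key_values': past_key_values,
            'use_cache': use_cache,
            'attention_mask': attention_mask,
        })
        return model_inputs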

fla/models/bitnet/modeling_bitnet.py

Lines changed: 2 additions & 37 deletions
@@ -8,15 +8,14 @@

 import torch
 import torch.nn as nn
-from transformers.generation import GenerationMixin
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import logging
 from transformers.utils.deprecation import deprecate_kwarg

 from fla.layers.bitattn import BitAttention
 from fla.models.bitnet.configuration_bitnet import BitNetConfig
-from fla.models.utils import Cache
+from fla.models.utils import Cache, FLAGenerationMixin
 from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss, RMSNorm
 from fla.modules.activations import swiglu
 from fla.modules.fused_bitlinear import FusedBitLinear
@@ -296,7 +295,7 @@ def forward(
         )


-class BitNetForCausalLM(BitNetPreTrainedModel, GenerationMixin):
+class BitNetForCausalLM(BitNetPreTrainedModel, FLAGenerationMixin):

     _tied_weights_keys = ["lm_head.weight"]

@@ -328,40 +327,6 @@ def set_decoder(self, decoder):
     def get_decoder(self):
         return self.model

-    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
-    def prepare_inputs_for_generation(
-        self,
-        input_ids: torch.LongTensor = None,
-        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-        use_cache: bool = True,
-        logits_to_keep: Optional[int] = None,
-        **kwargs
-    ):
-        # only last token for `inputs_ids` if the `past_key_values` is not empty.
-        if past_key_values is not None and len(past_key_values) > 0:
-            input_ids = input_ids[:, -1:]
-        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and len(past_key_values) == 0:
-            model_inputs = {'inputs_embeds': inputs_embeds}
-        else:
-            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
-            # recompiles graphs as the stride of the inputs is a guard.
-            # Ref: https://github.com/huggingface/transformers/pull/29114
-            # TODO: use `next_tokens` directly instead.
-            model_inputs = {'input_ids': input_ids.contiguous()}
-
-        if logits_to_keep is not None:
-            model_inputs['logits_to_keep'] = logits_to_keep
-
-        model_inputs.update({
-            'past_key_values': past_key_values,
-            'use_cache': use_cache,
-            'attention_mask': attention_mask,
-        })
-        return model_inputs
-
     @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
     def forward(
         self,

fla/models/comba/modeling_comba.py

Lines changed: 2 additions & 37 deletions
@@ -8,7 +8,6 @@

 import torch
 import torch.nn as nn
-from transformers.generation import GenerationMixin
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import logging
@@ -17,7 +16,7 @@
 from fla.layers.attn import Attention
 from fla.layers.comba import Comba
 from fla.models.comba.configuration_comba import CombaConfig
-from fla.models.utils import Cache
+from fla.models.utils import Cache, FLAGenerationMixin
 from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss
 from fla.modules import GatedMLP as CombaMLP
 from fla.modules import RMSNorm
@@ -266,7 +265,7 @@ def forward(
         )


-class CombaForCausalLM(CombaPreTrainedModel, GenerationMixin):
+class CombaForCausalLM(CombaPreTrainedModel, FLAGenerationMixin):

     _tied_weights_keys = ["lm_head.weight"]

@@ -313,40 +312,6 @@ def generate(self, *args, **kwargs):
         else:
             raise exception

-    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
-    def prepare_inputs_for_generation(
-        self,
-        input_ids: torch.LongTensor = None,
-        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-        use_cache: bool = True,
-        logits_to_keep: Optional[int] = None,
-        **kwargs
-    ):
-        # only last token for `inputs_ids` if the `past_key_values` is not empty.
-        if past_key_values is not None and len(past_key_values) > 0:
-            input_ids = input_ids[:, -1:]
-        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and len(past_key_values) == 0:
-            model_inputs = {'inputs_embeds': inputs_embeds}
-        else:
-            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
-            # recompiles graphs as the stride of the inputs is a guard.
-            # Ref: https://github.com/huggingface/transformers/pull/29114
-            # TODO: use `next_tokens` directly instead.
-            model_inputs = {'input_ids': input_ids.contiguous()}
-
-        if logits_to_keep is not None:
-            model_inputs['logits_to_keep'] = logits_to_keep
-
-        model_inputs.update({
-            'past_key_values': past_key_values,
-            'use_cache': use_cache,
-            'attention_mask': attention_mask,
-        })
-        return model_inputs
-
     @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
     def forward(
         self,

fla/models/delta_net/modeling_delta_net.py

Lines changed: 2 additions & 37 deletions
@@ -8,7 +8,6 @@

 import torch
 import torch.nn as nn
-from transformers.generation import GenerationMixin
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import logging
@@ -17,7 +16,7 @@
 from fla.layers.attn import Attention
 from fla.layers.delta_net import DeltaNet
 from fla.models.delta_net.configuration_delta_net import DeltaNetConfig
-from fla.models.utils import Cache
+from fla.models.utils import Cache, FLAGenerationMixin
 from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss
 from fla.modules import GatedMLP as DeltaNetMLP
 from fla.modules import RMSNorm
@@ -256,7 +255,7 @@ def forward(
         )


-class DeltaNetForCausalLM(DeltaNetPreTrainedModel, GenerationMixin):
+class DeltaNetForCausalLM(DeltaNetPreTrainedModel, FLAGenerationMixin):

     _tied_weights_keys = ["lm_head.weight"]

@@ -303,40 +302,6 @@ def generate(self, *args, **kwargs):
         else:
             raise exception

-    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
-    def prepare_inputs_for_generation(
-        self,
-        input_ids: torch.LongTensor = None,
-        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-        use_cache: bool = True,
-        logits_to_keep: Optional[int] = None,
-        **kwargs
-    ):
-        # only last token for `inputs_ids` if the `past_key_values` is not empty.
-        if past_key_values is not None and len(past_key_values) > 0:
-            input_ids = input_ids[:, -1:]
-        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and len(past_key_values) == 0:
-            model_inputs = {'inputs_embeds': inputs_embeds}
-        else:
-            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
-            # recompiles graphs as the stride of the inputs is a guard.
-            # Ref: https://github.com/huggingface/transformers/pull/29114
-            # TODO: use `next_tokens` directly instead.
-            model_inputs = {'input_ids': input_ids.contiguous()}
-
-        if logits_to_keep is not None:
-            model_inputs['logits_to_keep'] = logits_to_keep
-
-        model_inputs.update({
-            'past_key_values': past_key_values,
-            'use_cache': use_cache,
-            'attention_mask': attention_mask,
-        })
-        return model_inputs
-
     @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
     def forward(
         self,

fla/models/forgetting_transformer/modeling_forgetting_transformer.py

Lines changed: 2 additions & 37 deletions
@@ -8,15 +8,14 @@

 import torch
 import torch.nn as nn
-from transformers.generation import GenerationMixin
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import logging
 from transformers.utils.deprecation import deprecate_kwarg

 from fla.layers.forgetting_attn import ForgettingAttention
 from fla.models.forgetting_transformer.configuration_forgetting_transformer import ForgettingTransformerConfig
-from fla.models.utils import Cache
+from fla.models.utils import Cache, FLAGenerationMixin
 from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss
 from fla.modules import GatedMLP as ForgettingTransformerMLP
 from fla.modules import RMSNorm
@@ -260,7 +259,7 @@ def forward(
         )


-class ForgettingTransformerForCausalLM(ForgettingTransformerPreTrainedModel, GenerationMixin):
+class ForgettingTransformerForCausalLM(ForgettingTransformerPreTrainedModel, FLAGenerationMixin):

     _tied_weights_keys = ["lm_head.weight"]

@@ -292,40 +291,6 @@ def set_decoder(self, decoder):
     def get_decoder(self):
         return self.model

-    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
-    def prepare_inputs_for_generation(
-        self,
-        input_ids: torch.LongTensor = None,
-        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-        use_cache: bool = True,
-        logits_to_keep: Optional[int] = None,
-        **kwargs
-    ):
-        # only last token for `inputs_ids` if the `past_key_values` is not empty.
-        if past_key_values is not None and len(past_key_values) > 0:
-            input_ids = input_ids[:, -1:]
-        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and len(past_key_values) == 0:
-            model_inputs = {'inputs_embeds': inputs_embeds}
-        else:
-            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
-            # recompiles graphs as the stride of the inputs is a guard.
-            # Ref: https://github.com/huggingface/transformers/pull/29114
-            # TODO: use `next_tokens` directly instead.
-            model_inputs = {'input_ids': input_ids.contiguous()}
-
-        if logits_to_keep is not None:
-            model_inputs['logits_to_keep'] = logits_to_keep
-
-        model_inputs.update({
-            'past_key_values': past_key_values,
-            'use_cache': use_cache,
-            'attention_mask': attention_mask,
-        })
-        return model_inputs
-
     @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
     def forward(
         self,
