[Deps] Upgrade to transformers 4.56.x #587
Conversation
Walkthrough

Replaces per-model GenerationMixin with a new FLAGenerationMixin and centralizes version-gated prepare_inputs_for_generation and Cache logic in fla/models/utils.py.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    actor User
    participant Model as CausalLM Model
    participant Mixin as FLAGenerationMixin
    participant HF as Transformers Runtime
    User->>Model: generate(input_ids, past_key_values?, cache_position?, ...)
    Model->>Mixin: prepare_inputs_for_generation(...)
    alt Transformers >= 4.56
        Mixin->>Mixin: if _cache_dependant_input_preparation exists
        opt cache_position provided
            Mixin->>Model: _cache_dependant_input_preparation(...)
            Model-->>Mixin: model_inputs
        end
        Mixin->>Mixin: slice by cache_position or last-token, include cache_position/logits_to_keep
        Mixin-->>Model: model_inputs
    else Older Transformers
        Mixin->>Mixin: legacy last-token / inputs_embeds handling
        Mixin-->>Model: model_inputs
    end
    Model->>HF: forward(**model_inputs)
    HF-->>User: logits / generated_ids
```

```mermaid
sequenceDiagram
    autonumber
    participant Caller as Caller
    participant RWKV as RWKV7Model
    participant PT as PreTrainedModel
    Caller->>RWKV: load_state_dict(state_dict, strict=..., ...)
    RWKV->>RWKV: detect v1 key patterns under model.layers.*
    alt v1 keys present
        RWKV->>RWKV: map v1 attn/* keys → v2 names (x_r, x_w, x_k, x_v, x_a, x_g)
    end
    RWKV->>PT: super().load_state_dict(migrated_state_dict, strict=...)
    PT-->>Caller: LoadResult
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
Pre-merge checks: 1 passed, 1 warning, 1 inconclusive.
Summary of Changes
Hello @zhiyuan1i, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request primarily focuses on updating the project's dependency on the Hugging Face Transformers library to support newer versions, specifically 4.56.x and beyond. To maintain compatibility and abstract away API differences, a custom generation mixin has been implemented. This ensures that all models within the project can leverage the latest transformers features while gracefully handling older versions, without requiring individual model adjustments for generation input preparation.
Highlights

- Dependency Upgrade: The `transformers` library dependency has been updated to allow versions 4.56.x and higher, removing the previous upper bound.
- Custom Generation Mixin: A new `FLAGenerationMixin` has been introduced in fla/models/utils.py to encapsulate generation logic and ensure compatibility across different `transformers` library versions.
- Model Class Refactoring: All causal language model classes (e.g., `ABCForCausalLM`, `MambaForCausalLM`) now inherit from the newly introduced `FLAGenerationMixin` instead of the native `transformers.generation.GenerationMixin`.
- Generation Input Preparation: The `FLAGenerationMixin` provides a version-aware `prepare_inputs_for_generation` method, handling API changes introduced in `transformers` versions 4.56.0 and above, particularly concerning `cache_position` (a sketch follows below).
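To make the version gating concrete, here is a minimal sketch of what such a mixin can look like. This is an illustration written for this summary, not the code in fla/models/utils.py: the constant `_IS_TRANSFORMERS_4_56_PLUS` is named elsewhere in this review, but the method body below is a simplified approximation of the described behavior.

```python
from typing import Any, Dict, Optional

import torch
import transformers
from packaging import version
from transformers.generation import GenerationMixin

# Version gate computed once at import time (mirrors the constant mentioned in the review).
_IS_TRANSFORMERS_4_56_PLUS = version.parse(transformers.__version__) >= version.parse("4.56.0")


class FLAGenerationMixin(GenerationMixin):
    """Simplified, version-aware prepare_inputs_for_generation (illustrative only)."""

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[Any] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        use_cache: bool = True,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Optional[int] = None,
        **kwargs,
    ) -> Dict[str, Any]:
        model_inputs: Dict[str, Any] = {}
        if _IS_TRANSFORMERS_4_56_PLUS and cache_position is not None:
            # 4.56+ path: slice the prompt down to the positions the cache has not seen yet.
            if input_ids is not None and input_ids.shape[1] != cache_position.shape[-1]:
                input_ids = input_ids[:, cache_position]
            model_inputs["cache_position"] = cache_position
        elif past_key_values is not None and input_ids is not None:
            # Legacy path: during decoding, only the last token is fed to the model.
            input_ids = input_ids[:, -1:]

        if inputs_embeds is not None and past_key_values is None:
            # Prompt encoding from embeddings: do not pass input_ids alongside them.
            model_inputs["inputs_embeds"] = inputs_embeds
        else:
            model_inputs["input_ids"] = input_ids

        model_inputs.update({
            "past_key_values": past_key_values,
            "use_cache": use_cache,
            "attention_mask": attention_mask,
            "logits_to_keep": logits_to_keep,
        })
        return model_inputs
```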
Code Review
This pull request upgrades the transformers dependency to support version 4.56.x and introduces a FLAGenerationMixin to handle API changes for text generation, which is a good approach for maintaining compatibility. However, I've identified a critical bug in the new mixin that could cause crashes with older transformers versions, and an incomplete refactoring where many models won't benefit from this new compatibility layer. Please see my detailed comments for suggestions on how to address these issues.
Force-pushed from a982c2e to fec46c7
Actionable comments posted: 5
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (5)
setup.py (1)
41-47: Add missing runtime dependency: packaging

This repo imports `from packaging import version` (e.g., fla/models/utils.py) but `packaging` isn't in install_requires; end users may hit ImportError. Add it here.

```diff
 install_requires=[
     'torch>=2.5',
-    'transformers>=4.53.0',
+    'transformers>=4.53.0',
+    'packaging>=23.0',
     'datasets>=3.3.0',
     'einops',
     'pytest'
 ],
```

.github/workflows/reusable-ci-tests.yml (1)
80-87: Unify runner-to-env mapping: handle nvidia-h100-3 in both jobs

`test-models` handles `nvidia-h100-3` (non-nightly) with the `_2` suffix; `test-ops` does not. Mirror the logic to avoid selecting the wrong env on H100-3.

```diff
-if [[ "${{ runner.name }}" == "nvidia-h100-1" ]]; then
+if [[ "${{ runner.name }}" == "nvidia-h100-1" ]]; then
   TARGET_CONDA_ENV="${{ inputs.conda_env_name }}"
 elif [[ "${{ runner.name }}" == "nvidia-h100-2" ]]; then
   TARGET_CONDA_ENV="${{ inputs.conda_env_name }}_1"
+elif [[ "${{ runner.name }}" == "nvidia-h100-3" && ! "${{ inputs.conda_env_name }}" == *"nightly"* ]]; then
+  TARGET_CONDA_ENV="${{ inputs.conda_env_name }}_2"
 else
   TARGET_CONDA_ENV="${{ inputs.conda_env_name }}"
   echo "Runner is not a special case, using input env: '${TARGET_CONDA_ENV}'"
 fi
```

Also applies to: 244-257
fla/models/comba/modeling_comba.py (1)
80-99: Propagate `cache_position` through Attention and Comba

Add `cache_position: Optional[torch.LongTensor] = None` to the `forward` signatures in fla/layers/attn.py (around line 80) and fla/layers/comba.py (around line 208), and pass it into your calls to `get_seq_length`/`get_mask_sizes` and `past_key_values.update()` so that the `cache_position` forwarded by `CombaBlock`'s `**kwargs` is actually consumed.
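As a rough sketch of the shape of that change (the class and argument names below are placeholders, not the actual code in fla/layers/attn.py or fla/layers/comba.py):

```python
# Placeholder layer showing how a forward signature can accept and thread cache_position.
# The real change belongs in fla/layers/attn.py and fla/layers/comba.py.
from typing import Any, Optional, Tuple

import torch
import torch.nn as nn


class ExampleAttentionLayer(nn.Module):
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Any] = None,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,  # newly accepted keyword
        **kwargs: Any,
    ) -> Tuple[torch.Tensor, Optional[Any]]:
        # Body elided: the point is that cache_position is now part of the signature,
        # so it can be forwarded to get_seq_length / get_mask_sizes and
        # past_key_values.update(...) instead of being silently dropped by **kwargs.
        return hidden_states, past_key_values
```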
fla/models/transformer/modeling_transformer.py (1)

326-327: Handle logits_to_keep=None to avoid TypeError and match other models.

Other CausalLMs in this PR accept None to skip slicing. Here it would error if None is passed.

Apply:

```diff
-logits = None if self.config.fuse_linear_cross_entropy else self.lm_head(hidden_states[:, -logits_to_keep:])
+logits = None if self.config.fuse_linear_cross_entropy else self.lm_head(
+    hidden_states if logits_to_keep is None else hidden_states[:, -int(logits_to_keep):]
+)
```

fla/models/path_attn/modeling_path_attention.py (1)
328-349: Guard logits slicing when logits_to_keep is None to avoid TypeError.

Other models handle None explicitly; here a None value would break the negative slice.

Apply:

```diff
-logits = None if self.config.fuse_linear_cross_entropy else self.lm_head(hidden_states[:, -logits_to_keep:])
+logits = None if self.config.fuse_linear_cross_entropy else self.lm_head(
+    hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:]
+)
```
🧹 Nitpick comments (11)
setup.py (1)
41-47: Consider moving pytest to extras

pytest is a test-only dependency; including it in install_requires bloats downstream envs. Suggest moving to an extra like `extras_require['dev']`.

.github/workflows/reusable-ci-tests.yml (1)
166-188: Log transformers version during env verification

Quick sanity check to ensure the intended version is active.

```diff
 echo "Python executable path: $CONDA_BIN_PATH/python"
 echo "PyTorch version: $($CONDA_BIN_PATH/python -c 'import torch; print(torch.__version__)')"
+echo "Transformers version: $($CONDA_BIN_PATH/python -c 'import transformers; print(transformers.__version__)')"
```

fla/models/utils.py (3)
391-393: Remove no-op `__init__` from mixin

Mixins generally avoid `__init__`; this one does nothing and can confuse MRO. Safe to drop.

```diff
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-
```
395-405: Silence Ruff ARG002 without altering behavior

`kwargs` is intentionally accepted for signature parity. Explicitly consume it to appease linters.

```diff
 def prepare_inputs_for_generation(
     self,
     @@
-    **kwargs
+    **kwargs
 ):
+    # Intentionally unused: kept for signature compatibility with HF APIs
+    _ = kwargs
```
416-427: Guard against missing input_ids when inputs_embeds provided mid-generation

If `inputs_embeds` is provided with nonzero `cache_position`, the current fallback may try to use `input_ids=None`. Mirror HF behavior: in that case, slice last tokens instead of indexing by `cache_position`.

```diff
-        elif cache_position is not None:
-            # Fallback: manually slice using cache_position
-            if input_ids is not None and input_ids.shape[1] != cache_position.shape[0]:
-                input_ids = input_ids[:, cache_position]
+        elif cache_position is not None:
+            # Fallback: manually slice using cache_position (if input_ids present), else keep last tokens
+            if input_ids is not None:
+                if input_ids.shape[1] != cache_position.shape[-1]:
+                    input_ids = input_ids[:, cache_position]
+            else:
+                # Keep last token(s) matching cache length to avoid None propagation
+                step = int(cache_position.shape[-1]) if cache_position is not None else 1
+                input_ids = None if step == 0 else input_ids
```

fla/models/bitnet/modeling_bitnet.py (1)
365-386: Double-check logits slicing when logits_to_keep=0

`hidden_states[:, -0:]` returns the full sequence; intended? If you expect "no slicing" when 0, consider `None` as sentinel instead.

```diff
-logits = None if self.config.fuse_linear_cross_entropy else self.lm_head(hidden_states[:, -logits_to_keep:])
+if self.config.fuse_linear_cross_entropy:
+    logits = None
+else:
+    to_slice = None if logits_to_keep in (None, 0) else logits_to_keep
+    logits = self.lm_head(hidden_states if to_slice is None else hidden_states[:, -to_slice:])
```
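A quick, repo-independent illustration of the `-0` pitfall this comment points at:

```python
# -0 == 0 in Python, so `[:, -0:]` starts the slice at index 0 and keeps the whole
# sequence dimension rather than keeping "zero" positions.
import torch

hidden_states = torch.randn(2, 5, 8)      # (batch, seq_len, hidden)
print(hidden_states[:, -0:].shape)        # torch.Size([2, 5, 8]) -- full sequence
print(hidden_states[:, -1:].shape)        # torch.Size([2, 1, 8]) -- last position only

logits_to_keep = 0
kept = hidden_states if logits_to_keep in (None, 0) else hidden_states[:, -logits_to_keep:]
print(kept.shape)                         # sentinel handling makes the "no slicing" intent explicit
```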
fla/models/rwkv6/modeling_rwkv6.py (1)

418-437: Consistent logits slicing handling

Same optional refactor as BitNet to treat `logits_to_keep=0` as "no slice".

fla/models/retnet/modeling_retnet.py (1)
350-370: Optional: unify logits_to_keep behavior

Mirror the suggested pattern to treat 0 as "no slice".
fla/models/hgrn2/modeling_hgrn2.py (1)
346-366: Optional: same logits_to_keep treatment

Consider the same small refactor as other heads for clarity with 0 vs None.
fla/models/transformer/modeling_transformer.py (1)
184-185: Type hint nit: past_key_values annotation could include Cache.

Forward converts legacy caches to Cache; consider annotating as `Union[Cache, List[Tensor]]` for accuracy (purely typing).
fla/models/path_attn/modeling_path_attention.py (1)
191-193: Return type annotation mismatch.
Function returns BaseModelOutputWithPast but the type hints say CausalLMOutputWithPast. Align for clarity.

```diff
-    ) -> Union[Tuple, CausalLMOutputWithPast]:
+    ) -> Union[Tuple, BaseModelOutputWithPast]:
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (31)
- .github/workflows/reusable-ci-tests.yml (3 hunks)
- fla/models/abc/modeling_abc.py (2 hunks)
- fla/models/bitnet/modeling_bitnet.py (2 hunks)
- fla/models/comba/modeling_comba.py (2 hunks)
- fla/models/delta_net/modeling_delta_net.py (2 hunks)
- fla/models/forgetting_transformer/modeling_forgetting_transformer.py (2 hunks)
- fla/models/gated_deltanet/modeling_gated_deltanet.py (2 hunks)
- fla/models/gated_deltaproduct/modeling_gated_deltaproduct.py (2 hunks)
- fla/models/gla/modeling_gla.py (2 hunks)
- fla/models/gsa/modeling_gsa.py (2 hunks)
- fla/models/hgrn/modeling_hgrn.py (2 hunks)
- fla/models/hgrn2/modeling_hgrn2.py (2 hunks)
- fla/models/lightnet/modeling_lightnet.py (2 hunks)
- fla/models/linear_attn/modeling_linear_attn.py (2 hunks)
- fla/models/log_linear_mamba2/modeling_log_linear_mamba2.py (2 hunks)
- fla/models/mamba/modeling_mamba.py (2 hunks)
- fla/models/mamba2/modeling_mamba2.py (2 hunks)
- fla/models/mesa_net/modeling_mesa_net.py (2 hunks)
- fla/models/mla/modeling_mla.py (2 hunks)
- fla/models/mom/modeling_mom.py (2 hunks)
- fla/models/nsa/modeling_nsa.py (2 hunks)
- fla/models/path_attn/modeling_path_attention.py (2 hunks)
- fla/models/retnet/modeling_retnet.py (2 hunks)
- fla/models/rodimus/modeling_rodimus.py (2 hunks)
- fla/models/rwkv6/modeling_rwkv6.py (2 hunks)
- fla/models/rwkv7/modeling_rwkv7.py (2 hunks)
- fla/models/samba/modeling_samba.py (2 hunks)
- fla/models/transformer/modeling_transformer.py (2 hunks)
- fla/models/utils.py (2 hunks)
- pyproject.toml (1 hunks)
- setup.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (9)
- fla/models/mom/modeling_mom.py
- fla/models/gsa/modeling_gsa.py
- fla/models/forgetting_transformer/modeling_forgetting_transformer.py
- pyproject.toml
- fla/models/lightnet/modeling_lightnet.py
- fla/models/mamba/modeling_mamba.py
- fla/models/rodimus/modeling_rodimus.py
- fla/models/mamba2/modeling_mamba2.py
- fla/models/mla/modeling_mla.py
🧰 Additional context used
🧬 Code graph analysis (20)

- fla/models/samba/modeling_samba.py (1) → fla/models/utils.py (1): FLAGenerationMixin (385-462)
- fla/models/nsa/modeling_nsa.py (1) → fla/models/utils.py (3): Cache (466-468), Cache (470-472), FLAGenerationMixin (385-462)
- fla/models/hgrn2/modeling_hgrn2.py (1) → fla/models/utils.py (3): Cache (466-468), Cache (470-472), FLAGenerationMixin (385-462)
- fla/models/path_attn/modeling_path_attention.py (1) → fla/models/utils.py (3): Cache (466-468), Cache (470-472), FLAGenerationMixin (385-462)
- fla/models/mesa_net/modeling_mesa_net.py (1) → fla/models/utils.py (3): Cache (466-468), Cache (470-472), FLAGenerationMixin (385-462)
- fla/models/gla/modeling_gla.py (1) → fla/models/utils.py (3): Cache (466-468), Cache (470-472), FLAGenerationMixin (385-462)
- fla/models/rwkv7/modeling_rwkv7.py (1) → fla/models/utils.py (3): Cache (466-468), Cache (470-472), FLAGenerationMixin (385-462)
- fla/models/rwkv6/modeling_rwkv6.py (1) → fla/models/utils.py (3): Cache (466-468), Cache (470-472), FLAGenerationMixin (385-462)
- fla/models/transformer/modeling_transformer.py (1) → fla/models/utils.py (3): Cache (466-468), Cache (470-472), FLAGenerationMixin (385-462)
- fla/models/bitnet/modeling_bitnet.py (1) → fla/models/utils.py (3): Cache (466-468), Cache (470-472), FLAGenerationMixin (385-462)
- fla/models/comba/modeling_comba.py (1) → fla/models/utils.py (3): Cache (466-468), Cache (470-472), FLAGenerationMixin (385-462)
- fla/models/abc/modeling_abc.py (1) → fla/models/utils.py (3): Cache (466-468), Cache (470-472), FLAGenerationMixin (385-462)
- fla/models/linear_attn/modeling_linear_attn.py (1) → fla/models/utils.py (3): Cache (466-468), Cache (470-472), FLAGenerationMixin (385-462)
- fla/models/utils.py (2) → fla/utils.py (1): deprecate_kwarg (190-334); fla/models/mom/modeling_mom.py (1): prepare_inputs_for_generation (394-426)
- fla/models/gated_deltanet/modeling_gated_deltanet.py (1) → fla/models/utils.py (3): Cache (466-468), Cache (470-472), FLAGenerationMixin (385-462)
- fla/models/gated_deltaproduct/modeling_gated_deltaproduct.py (1) → fla/models/utils.py (3): Cache (466-468), Cache (470-472), FLAGenerationMixin (385-462)
- fla/models/log_linear_mamba2/modeling_log_linear_mamba2.py (1) → fla/models/utils.py (1): FLAGenerationMixin (385-462)
- fla/models/hgrn/modeling_hgrn.py (1) → fla/models/utils.py (3): Cache (466-468), Cache (470-472), FLAGenerationMixin (385-462)
- fla/models/retnet/modeling_retnet.py (1) → fla/models/utils.py (3): Cache (466-468), Cache (470-472), FLAGenerationMixin (385-462)
- fla/models/delta_net/modeling_delta_net.py (1) → fla/models/utils.py (3): Cache (466-468), Cache (470-472), FLAGenerationMixin (385-462)
🪛 Ruff (0.12.2)
fla/models/utils.py
404-404: Unused method argument: kwargs
(ARG002)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Test H100 (PyTorch 2.7) / test-models
🔇 Additional comments (41)
fla/models/utils.py (3)
454-462: Approve: common fields assembly looks good

The merged model_inputs block aligns with HF expectations (past_key_values/use_cache/attention_mask/logits_to_keep).
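For context, a single decoding-step payload of that shape might look like the following; the concrete values are illustrative, not taken from the repo:

```python
# Illustrative example of the dict prepare_inputs_for_generation hands back to forward()
# during one decoding step.
import torch

model_inputs = {
    "input_ids": torch.tensor([[42]]),                     # last generated token only
    "past_key_values": None,                               # or the recurrent/KV cache object
    "use_cache": True,
    "attention_mask": torch.ones(1, 8, dtype=torch.long),  # full mask over prompt + generated tokens
    "logits_to_keep": 1,                                    # only the final position's logits are needed
}
```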
15-18: Minor: version gate constant is fine

Pre-computing `_IS_TRANSFORMERS_4_56_PLUS` avoids repeated parsing.
395-463: Sanity tests for 4.56+ generation path

The sandbox lacked the `transformers` package (dry-run errored) and there's still a per-model override in fla/models/mom/modeling_mom.py you'll need to remove or update to use the unified mixin. Please re-run these checks in an environment with `transformers>=4.56.0` installed (a test sketch follows the list):

- invoke `prepare_inputs_for_generation` with and without `cache_position`
- supply `inputs_embeds` mid-generation and confirm it's passed through
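A hedged pytest-style sketch of those checks. It assumes `FLAGenerationMixin` is importable from fla.models.utils and that a bare subclass is enough to exercise `prepare_inputs_for_generation`; if the mixin needs model state, swap the dummy for a small real model.

```python
# Sketch only: sanity checks for the unified generation path on transformers>=4.56.0.
import pytest
import torch

pytest.importorskip("transformers", minversion="4.56.0")
from fla.models.utils import FLAGenerationMixin  # noqa: E402


class _Dummy(FLAGenerationMixin):
    """Bare stand-in so prepare_inputs_for_generation can be called in isolation."""


@pytest.mark.parametrize("with_cache_position", [False, True])
def test_prepare_inputs_with_and_without_cache_position(with_cache_position):
    dummy = _Dummy()
    input_ids = torch.arange(8).unsqueeze(0)
    cache_position = torch.arange(8) if with_cache_position else None
    out = dummy.prepare_inputs_for_generation(
        input_ids, cache_position=cache_position, use_cache=True
    )
    assert "input_ids" in out or "inputs_embeds" in out
    if with_cache_position:
        assert torch.equal(out["cache_position"], cache_position)


def test_inputs_embeds_passthrough_mid_generation():
    dummy = _Dummy()
    inputs_embeds = torch.randn(1, 4, 16)
    out = dummy.prepare_inputs_for_generation(
        None, inputs_embeds=inputs_embeds, cache_position=torch.arange(4)
    )
    assert out.get("inputs_embeds") is not None
```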
fla/models/comba/modeling_comba.py (1)

19-20: LGTM: migrate to FLAGenerationMixin

Import and base-class swap look correct; generation routes should now be version-aware via the mixin.

Also applies to: 268-269
fla/models/bitnet/modeling_bitnet.py (1)
18-19: LGTM: mixin/import swap

BitNetForCausalLM now inherits the shared, version-aware generation path.
Also applies to: 298-299
fla/models/rwkv6/modeling_rwkv6.py (1)
19-20: LGTM: mixin/import swap

RWKV6 adopts the shared mixin; generate override still guards against `past_key_values` usage.

Also applies to: 336-336
fla/models/retnet/modeling_retnet.py (2)
19-20: LGTM: mixin/import swap

RetNet model now relies on FLAGenerationMixin.
Also applies to: 266-266
298-313: Good error message on unsupported past_key_values strategies

Helpful UX when users select incompatible strategies.
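The wrapper itself is not reproduced on this page; below is a hedged reconstruction of the general pattern being praised (an AttributeError guard around `super().generate()` that re-raises with a clearer message), not the exact code in fla/models/retnet/modeling_retnet.py:

```python
# Hedged reconstruction of the "generate() error wrapping" pattern, runnable standalone.
class _Base:
    """Stand-in for the GenerationMixin-powered parent; fails like an unsupported cache strategy."""

    def generate(self, *args, **kwargs):
        raise AttributeError("'Cache' object has no attribute 'past_key_values'")


class ExampleForCausalLM(_Base):
    def generate(self, *args, **kwargs):
        try:
            return super().generate(*args, **kwargs)
        except AttributeError as exc:
            if "past_key_values" in str(exc):
                raise AttributeError(
                    "You tried a decoding strategy that manipulates `past_key_values`, "
                    "which this model's recurrent cache does not support; "
                    "please pick a different generation strategy."
                ) from exc
            raise


if __name__ == "__main__":
    try:
        ExampleForCausalLM().generate()
    except AttributeError as err:
        print(err)  # the clearer, model-specific message
```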
fla/models/hgrn2/modeling_hgrn2.py (1)
19-20: LGTM: mixin/import swap

HGRN2 integrates with FLAGenerationMixin; generate override remains intact.
Also applies to: 264-265
fla/models/transformer/modeling_transformer.py (2)
18-18: Migration to FLAGenerationMixin looks correct.

Import switch aligns with the PR intent; no issues spotted here.

259-259: Class now inherits FLAGenerationMixin — good alignment with utils.FLAGenerationMixin.

This enables centralized, version-aware generation handling.
fla/models/nsa/modeling_nsa.py (3)
18-18: Import migration to FLAGenerationMixin — OK.
242-242: NSAForCausalLM inheriting FLAGenerationMixin — OK.

Consistent with the repo-wide change; no functional concerns here.

274-288: Helpful generate() error wrapping.

Clearer message for unsupported past_key_values strategies.
fla/models/delta_net/modeling_delta_net.py (3)
19-19: Import migration to FLAGenerationMixin — OK.
258-258: DeltaNetForCausalLM now using FLAGenerationMixin — OK.
290-304: generate() wrapper improves UX for unsupported strategies.

fla/models/mesa_net/modeling_mesa_net.py (3)
19-19: Import migration to FLAGenerationMixin — OK.
257-257: MesaNetForCausalLM now using FLAGenerationMixin — OK.
289-303: generate() wrapper matches pattern elsewhere — good.

fla/models/hgrn/modeling_hgrn.py (3)
19-19: Import migration to FLAGenerationMixin — OK.
263-263: HGRNForCausalLM now using FLAGenerationMixin — OK.
295-309: generate() wrapper with targeted error is good.

fla/models/abc/modeling_abc.py (3)
19-19: Import migration to FLAGenerationMixin — OK.
261-261: ABCForCausalLM now using FLAGenerationMixin — OK.
293-307: generate() wrapper matches others — good.

fla/models/samba/modeling_samba.py (2)
19-19: Switched to FLAGenerationMixin (import) — OK.
356-359: Safe slicing with logits_to_keep=None is already handled — good.

No action needed here.
fla/models/gated_deltanet/modeling_gated_deltanet.py (1)
19-19: Switch to FLAGenerationMixin/Cache import looks good.
No issues spotted with the new dependency points.

fla/models/gated_deltaproduct/modeling_gated_deltaproduct.py (2)
19-19: Import swap to utils.FLAGenerationMixin/Cache — LGTM.
261-261: Class now mixes in FLAGenerationMixin — LGTM.
Matches the repo-wide migration.

fla/models/log_linear_mamba2/modeling_log_linear_mamba2.py (2)
14-14: Switch to FLAGenerationMixin import — LGTM.
152-163: Stateful generation contract correctly implemented; no overrides needed.

Mamba2's `prepare_inputs_for_generation` and `_update_model_kwargs_for_generation` already handle `cache_params`/`cache_position` as expected by Transformers 4.56's `GenerationMixin`, so the mixin interop is sound. Optionally, you can add the same `AttributeError` guard in `generate` for UX consistency.

fla/models/gla/modeling_gla.py (2)
19-19: Import consolidation to utils.FLAGenerationMixin/Cache — LGTM.
261-261: GLAForCausalLM now uses FLAGenerationMixin — LGTM.
Behavior aligned with other models.

fla/models/linear_attn/modeling_linear_attn.py (2)
19-19: Import change to FLAGenerationMixin/Cache — LGTM.
256-256: LinearAttentionForCausalLM adopts FLAGenerationMixin — LGTM.

fla/models/path_attn/modeling_path_attention.py (2)
18-18: Import swap to utils.FLAGenerationMixin/Cache — LGTM.
261-261: PaTHAttentionForCausalLM now uses FLAGenerationMixin — LGTM.

fla/models/rwkv7/modeling_rwkv7.py (2)
19-19: Import change to utils.FLAGenerationMixin/Cache — LGTM.
428-428: RWKV7ForCausalLM inherits FLAGenerationMixin — LGTM.
Consistent with the rest of the migration.
```shell
$CONDA_BIN_PATH/pip install -U numpy packaging psutil ninja einops datasets transformers
$CONDA_BIN_PATH/pip install --no-deps .
else
  STABLE_URL="https://download.pytorch.org/whl/${{ inputs.pytorch_cuda_version }}"
  echo "Using stable index URL: $STABLE_URL"
  $CONDA_BIN_PATH/pip install -U torch~=${{ inputs.pytorch_version }} triton --index-url $STABLE_URL
  $CONDA_BIN_PATH/pip install -U numpy packaging psutil ninja einops datasets
  $CONDA_BIN_PATH/pip install -U numpy packaging psutil ninja einops datasets transformers
  $CONDA_BIN_PATH/pip install .
```
🛠️ Refactor suggestion
Install transformers consistently across jobs/branches
Nightly paths install transformers, but stable paths in test-models (both NVIDIA and Intel) do not. This can leave an older preinstalled version and mask issues. Install it in those paths too.
```diff
-    $CONDA_BIN_PATH/pip install -U torch~=${{ inputs.pytorch_version }} triton --index-url $STABLE_URL
-    $CONDA_BIN_PATH/pip install -U numpy
+    $CONDA_BIN_PATH/pip install -U torch~=${{ inputs.pytorch_version }} triton --index-url $STABLE_URL
+    $CONDA_BIN_PATH/pip install -U numpy packaging psutil ninja einops datasets transformers
     $CONDA_BIN_PATH/pip install .
-    $CONDA_BIN_PATH/pip install -U torch~=${{ inputs.pytorch_version }} pytorch-triton-xpu --index-url $XPU_URL
-    $CONDA_BIN_PATH/pip install -U numpy
+    $CONDA_BIN_PATH/pip install -U torch~=${{ inputs.pytorch_version }} pytorch-triton-xpu --index-url $XPU_URL
+    $CONDA_BIN_PATH/pip install -U numpy packaging psutil ninja einops datasets transformers
     $CONDA_BIN_PATH/pip install .
```

Also applies to: 158-160, 321-329, 340-342
```diff
-class GatedDeltaNetForCausalLM(GatedDeltaNetPreTrainedModel, GenerationMixin):
+class GatedDeltaNetForCausalLM(GatedDeltaNetPreTrainedModel, FLAGenerationMixin):
```
💡 Verification agent
🧩 Analysis chain
Adopt FLAGenerationMixin for generation — OK. Verify no stale per-model overrides remain.
Since the mixin centralizes prepare_inputs_for_generation, ensure no lingering model-level overrides cause shadowing.
Run to confirm no stray definitions remain:
🏁 Script executed:
```bash
#!/bin/bash
rg -nP --type=py '\bdef\s+prepare_inputs_for_generation\b' fla/models | sed -n '1,200p'
```

Length of output: 136
Remove per-model override: Delete or delegate the prepare_inputs_for_generation method in fla/models/mom/modeling_mom.py (line 394) so it no longer shadows FLAGenerationMixin.
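If deleting the override outright is not desirable, it can instead delegate to the mixin. The snippet below shows only the method body that would replace the existing override inside the model class in fla/models/mom/modeling_mom.py; the surrounding class definition is assumed, not shown:

```python
# Delegation sketch: keep the method but defer entirely to FLAGenerationMixin so the
# centralized, version-aware logic is used and no longer shadowed.
def prepare_inputs_for_generation(self, *args, **kwargs):
    return super().prepare_inputs_for_generation(*args, **kwargs)
```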
```diff
-class SambaForCausalLM(SambaPreTrainedModel, GenerationMixin):
+class SambaForCausalLM(SambaPreTrainedModel, FLAGenerationMixin):
```
Breaking: removed Samba’s prepare_inputs_for_generation but forward expects cache_params, not past_key_values.
FLAGenerationMixin’s default prepare_inputs_for_generation returns past_key_values; Samba forward does not accept it and instead needs cache_params. Generation will raise “got an unexpected keyword argument 'past_key_values'”.
Add Samba-specific prepare_inputs_for_generation to map to cache_params:
```diff
 class SambaForCausalLM(SambaPreTrainedModel, FLAGenerationMixin):
@@
     def _update_model_kwargs_for_generation(
         self, outputs: ModelOutput, model_kwargs: Dict[str, Any], **kwargs
     ) -> Dict[str, Any]:
         model_kwargs["cache_params"] = outputs.get("cache_params", None)
         return model_kwargs
+
+    def prepare_inputs_for_generation(
+        self,
+        input_ids: torch.LongTensor = None,
+        cache_params: Optional[MambaCache] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        use_cache: bool = True,
+        logits_to_keep: Optional[int] = None,
+        **kwargs: Any,
+    ) -> Dict[str, Any]:
+        model_inputs: Dict[str, Any] = {}
+        # If continuing with cache, only feed the next step token unless using inputs_embeds
+        if inputs_embeds is not None:
+            model_inputs["inputs_embeds"] = inputs_embeds
+        else:
+            # Keep static stride during decoding
+            model_inputs["input_ids"] = input_ids.contiguous() if input_ids is not None else None
+        if logits_to_keep is not None:
+            model_inputs["logits_to_keep"] = logits_to_keep
+        model_inputs.update(
+            {
+                "cache_params": cache_params,
+                "use_cache": use_cache,
+                "attention_mask": attention_mask,
+            }
+        )
+        return model_inputs
```
+ return model_inputs📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| class SambaForCausalLM(SambaPreTrainedModel, FLAGenerationMixin): | |
| class SambaForCausalLM(SambaPreTrainedModel, FLAGenerationMixin): | |
| def _update_model_kwargs_for_generation( | |
| self, outputs: ModelOutput, model_kwargs: Dict[str, Any], **kwargs | |
| ) -> Dict[str, Any]: | |
| model_kwargs["cache_params"] = outputs.get("cache_params", None) | |
| return model_kwargs | |
| def prepare_inputs_for_generation( | |
| self, | |
| input_ids: torch.LongTensor = None, | |
| cache_params: Optional[MambaCache] = None, | |
| inputs_embeds: Optional[torch.Tensor] = None, | |
| attention_mask: Optional[torch.Tensor] = None, | |
| use_cache: bool = True, | |
| logits_to_keep: Optional[int] = None, | |
| **kwargs: Any, | |
| ) -> Dict[str, Any]: | |
| model_inputs: Dict[str, Any] = {} | |
| # If continuing with cache, only feed the next step token unless using inputs_embeds | |
| if inputs_embeds is not None: | |
| model_inputs["inputs_embeds"] = inputs_embeds | |
| else: | |
| # Keep static stride during decoding | |
| model_inputs["input_ids"] = input_ids.contiguous() if input_ids is not None else None | |
| if logits_to_keep is not None: | |
| model_inputs["logits_to_keep"] = logits_to_keep | |
| model_inputs.update( | |
| { | |
| "cache_params": cache_params, | |
| "use_cache": use_cache, | |
| "attention_mask": attention_mask, | |
| } | |
| ) | |
| return model_inputs |
```python
from transformers.generation import GenerationMixin
from transformers.utils.deprecation import deprecate_kwarg
```
🛠️ Refactor suggestion
Harden GenerationMixin import for cross-version compatibility
Some Transformers versions relocate GenerationMixin; a single import path can break (seen intermittently around 4.52–4.56). Add a fallback chain.
```diff
-from transformers.generation import GenerationMixin
-from transformers.utils.deprecation import deprecate_kwarg
+try:
+    from transformers.generation import GenerationMixin
+except Exception:
+    try:
+        from transformers.generation.utils import GenerationMixin
+    except Exception:
+        # Final fallback for edge versions
+        from transformers import GenerationMixin
+from transformers.utils.deprecation import deprecate_kwarg
```
+from transformers.utils.deprecation import deprecate_kwarg🤖 Prompt for AI Agents
In fla/models/utils.py around lines 12 to 14, the direct import of
GenerationMixin can break across Transformers versions; replace the single
import with a fallback chain that first tries from transformers.generation
import GenerationMixin, then from transformers.generation_utils import
GenerationMixin, and finally from transformers import GenerationMixin (or other
plausible locations) inside try/except blocks to catch ImportError; ensure the
final fallback raises a clear ImportError if none succeed and keep the rest of
the file using the resolved GenerationMixin name.
```python
if inputs_embeds is not None and (cache_position is None or len(cache_position) == inputs_embeds.shape[1]):
    model_inputs['inputs_embeds'] = inputs_embeds
    model_inputs['input_ids'] = None
else:
    model_inputs['input_ids'] = input_ids.contiguous() if input_ids is not None else None
    model_inputs['inputs_embeds'] = None

model_inputs['cache_position'] = cache_position
```
🛠️ Refactor suggestion
Handle cache_position shape robustly when using inputs_embeds
Comparing len(cache_position) to inputs_embeds.shape[1] breaks if cache_position is 2D. Use last-dim length.
```diff
-        if inputs_embeds is not None and (cache_position is None or len(cache_position) == inputs_embeds.shape[1]):
+        cache_len = None
+        if cache_position is not None:
+            cache_len = cache_position.shape[-1]
+        if inputs_embeds is not None and (cache_len is None or cache_len == inputs_embeds.shape[1]):
             model_inputs['inputs_embeds'] = inputs_embeds
             model_inputs['input_ids'] = None
         else:
             model_inputs['input_ids'] = input_ids.contiguous() if input_ids is not None else None
             model_inputs['inputs_embeds'] = None
```
model_inputs['inputs_embeds'] = None📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| if inputs_embeds is not None and (cache_position is None or len(cache_position) == inputs_embeds.shape[1]): | |
| model_inputs['inputs_embeds'] = inputs_embeds | |
| model_inputs['input_ids'] = None | |
| else: | |
| model_inputs['input_ids'] = input_ids.contiguous() if input_ids is not None else None | |
| model_inputs['inputs_embeds'] = None | |
| model_inputs['cache_position'] = cache_position | |
| cache_len = None | |
| if cache_position is not None: | |
| cache_len = cache_position.shape[-1] | |
| if inputs_embeds is not None and (cache_len is None or cache_len == inputs_embeds.shape[1]): | |
| model_inputs['inputs_embeds'] = inputs_embeds | |
| model_inputs['input_ids'] = None | |
| else: | |
| model_inputs['input_ids'] = input_ids.contiguous() if input_ids is not None else None | |
| model_inputs['inputs_embeds'] = None | |
| model_inputs['cache_position'] = cache_position |
Actionable comments posted: 0
🧹 Nitpick comments (3)
tests/models/test_modeling_mamba2.py (3)
15-25: Guard the suite when CUDA isn't available to avoid spurious CI failures

All parametrizations now require `conv_backend='cuda'`. Add a collection-time skip so CPU-only environments don't fail this test.

```diff
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA for conv_backend='cuda'")
 @pytest.mark.parametrize(
     ['L', 'B', 'T', 'H', 'D', 'use_l2warp', 'dtype', 'conv_backend'],
```
36-41: Use pytest's monkeypatch to avoid leaking environment variables across tests

Setting `os.environ` directly is process-global and can bleed into other tests. Use `monkeypatch.setenv`, which auto-restores after the test.

```diff
-def test_modeling(
+def test_modeling(
     L: int, B: int, T: int, H: int, D: int,
     use_l2warp: bool,
     dtype: torch.dtype,
-    conv_backend: str,
+    conv_backend: str,
+    monkeypatch,
 ):
 @@
-    os.environ['FLA_CONV_BACKEND'] = conv_backend
+    monkeypatch.setenv('FLA_CONV_BACKEND', conv_backend)
```
36-73: Skip BF16 runs on pre-Ampere GPUs to prevent backend errors

CUDA BF16 requires SM80+ (Ampere). Add a runtime skip for that combo to reduce flakiness on mixed CI fleets.

```diff
 def test_modeling(
 @@
-    monkeypatch.setenv('FLA_CONV_BACKEND', conv_backend)
+    monkeypatch.setenv('FLA_CONV_BACKEND', conv_backend)
+
+    # CUDA BF16 requires Ampere+ (SM80)
+    if (
+        conv_backend == 'cuda'
+        and dtype == torch.bfloat16
+        and torch.cuda.is_available()
+    ):
+        major, _ = torch.cuda.get_device_capability(0)
+        if major < 8:
+            pytest.skip("CUDA bfloat16 requires SM80+ (Ampere or newer).")
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- fla/layers/mamba2.py (1 hunks)
- tests/models/test_modeling_mamba2.py (1 hunks)
✅ Files skipped from review due to trivial changes (1)
- fla/layers/mamba2.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: Test H100 (PyTorch 2.7) / test-models
- GitHub Check: check-pt-python-compatibility
🔇 Additional comments (1)
tests/models/test_modeling_mamba2.py (1)
21-21: Switching to CUDA backend for this case looks good

Aligns the test with the project's guidance to prefer CUDA over Triton for Mamba2 conv1d. No functional concerns.
* [Deps] Upgrade to transformers 4.56.x
* remove triton conv1d backend for mamba2
* skip mom test
Summary by CodeRabbit
New Features
Dependency
Refactor
Chores