
Conversation

@zhiyuan1i zhiyuan1i (Collaborator) commented Sep 9, 2025

Summary by CodeRabbit

  • New Features

    • More consistent generation behavior across models with improved handling for cache-aware decoding and clearer errors for unsupported decoding strategies.
  • Dependency

    • Transformers requirement relaxed to >=4.53.0 (no upper bound) for broader compatibility.
  • Refactor

    • Many causal‑LM models moved to the new unified generation backend while preserving the public generate API.
  • Chores

    • CI: smarter Conda detection, env propagation, dependency installs, and extra diagnostic logging.
    • Runtime: added warning when using an untested Triton conv1d backend.

@coderabbitai coderabbitai bot (Contributor) commented Sep 9, 2025

Warning

Rate limit exceeded

@zhiyuan1i has exceeded the limit for the number of commits or files that can be reviewed per hour. Please wait 27 minutes and 37 seconds before requesting another review.

⌛ How to resolve this issue?

After the wait time has elapsed, a review can be triggered using the @coderabbitai review command as a PR comment. Alternatively, push new commits to this PR.

We recommend that you space out your commits to avoid hitting the rate limit.

🚦 How do rate limits work?

CodeRabbit enforces hourly rate limits for each developer per organization.

Our paid plans have higher rate limits than the trial, open-source and free plans. In all cases, we re-allow further reviews after a brief timeout.

Please see our FAQ for further information.

📥 Commits

Reviewing files that changed from the base of the PR and between bfc34d3 and 4839b6a.

📒 Files selected for processing (1)
  • tests/models/test_modeling_mom.py (1 hunks)

Walkthrough

Replaces per-model GenerationMixin with a new FLAGenerationMixin, centralizes version-gated prepare_inputs_for_generation and Cache logic in fla/models/utils.py, relaxes the transformers version bound, adjusts CI to install transformers and discover/export Conda envs, adds a Triton warning in fla/layers/mamba2.py, and updates a test to use cuda backend.
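
As a rough illustration of the version gating described in the walkthrough, the flag and the conditional Cache base might look like the following. This is a sketch only; the exact imports, fallbacks, and class body in fla/models/utils.py may differ.

```python
# Illustrative sketch of the version gate; not the verbatim contents of fla/models/utils.py.
import transformers
from packaging import version

# Pre-computed once so prepare_inputs_for_generation can branch cheaply.
_IS_TRANSFORMERS_4_56_PLUS = version.parse(transformers.__version__) >= version.parse("4.56.0")

try:
    # Recent transformers releases expose the Cache base class here.
    from transformers.cache_utils import Cache as _BaseCache
except ImportError:
    # Assumed fallback for versions without it; the real module may handle this differently.
    _BaseCache = object


class Cache(_BaseCache):
    """Cache type re-exported by fla.models.utils; subclassing depends on the installed version."""
```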

Changes

Cohort / File(s) Summary of changes
CI workflow
.github/workflows/reusable-ci-tests.yml
Add transformers to dependency installs across multiple PyTorch paths; add dynamic Conda discovery/export (CONDA, CONDA_ENV_NAME, CONDA_BIN_PATH), runner-specific TARGET_CONDA_ENV selection, environment verification prints, and extra logging/env propagation for subsequent steps.
Dependency constraints
pyproject.toml, setup.py
Loosen the transformers constraint from transformers>=4.53.0,<4.56.0 to transformers>=4.53.0 (remove upper bound).
Central generation utilities
fla/models/utils.py
Add FLAGenerationMixin(GenerationMixin) with a version-gated prepare_inputs_for_generation (supports Transformers >=4.56 behavior and legacy behavior), introduce _IS_TRANSFORMERS_4_56_PLUS flag, and make Cache subclassing conditional on transformers version. (Note: file now contains paired/duplicated public declarations introduced by the change.)
Model mixin swap (bulk)
fla/models/*/modeling_*.py
Replace GenerationMixin imports/usage with FLAGenerationMixin and add FLAGenerationMixin/Cache imports from fla.models.utils; remove per-model prepare_inputs_for_generation helpers across many CausalLM model files (e.g., abc, bitnet, comba, delta_net, forgetting_transformer, gated_deltanet, gated_deltaproduct, gla, gsa, hgrn, hgrn2, lightnet, linear_attn, log_linear_mamba2, mamba, mamba2, mesa_net, mla, mom, nsa, path_attn, retnet, rodimus, rwkv6, rwkv7, samba, transformer).
Generate wrappers / error handling
fla/models/*/modeling_*.py
For several models (examples: gated_deltaproduct, gla, hgrn, rwkv7, ...) add/adjust generate wrappers to catch AttributeError related to past_key_values and raise clearer errors or re-raise.
RWKV7 state migration
fla/models/rwkv7/modeling_rwkv7.py
Switch to FLAGenerationMixin, remove per-model prepare_inputs_for_generation, and add state-dict migration logic in load_state_dict to remap v1→v2 parameter keys before delegating to superclass load.
Mamba2 runtime warning & test
fla/layers/mamba2.py, tests/models/test_modeling_mamba2.py
Add runtime warning when Triton conv1d backend is selected in Mamba2.__init__; change test parameterization to use cuda instead of triton for one case.
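
For the Mamba2 row above, the added runtime warning is presumably something along these lines. This is a hedged sketch: the attribute name conv_backend, the trigger condition, and the message text are assumptions, not the actual code in fla/layers/mamba2.py.

```python
import warnings


class Mamba2Sketch:
    """Stand-in for the real Mamba2 layer; only the warning pattern is shown."""

    def __init__(self, conv_backend: str = "cuda"):
        if conv_backend == "triton":
            # Assumed wording: flag the untested path and point at the env-var override
            # (FLA_CONV_BACKEND is the switch exercised by tests/models/test_modeling_mamba2.py).
            warnings.warn(
                "The Triton conv1d backend for Mamba2 is untested; "
                "consider FLA_CONV_BACKEND=cuda instead.",
                stacklevel=2,
            )
        self.conv_backend = conv_backend
```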

Sequence Diagram(s)

sequenceDiagram
  autonumber
  actor User
  participant Model as CausalLM Model
  participant Mixin as FLAGenerationMixin
  participant HF as Transformers Runtime

  User->>Model: generate(input_ids, past_key_values?, cache_position?, ...)
  Model->>Mixin: prepare_inputs_for_generation(...)
  alt Transformers >= 4.56
    Mixin->>Mixin: if _cache_dependant_input_preparation exists
    opt cache_position provided
      Mixin->>Model: _cache_dependant_input_preparation(...)
      Model-->>Mixin: model_inputs
    end
    Mixin->>Mixin: slice by cache_position or last-token, include cache_position/logits_to_keep
    Mixin-->>Model: model_inputs
  else Older Transformers
    Mixin->>Mixin: legacy last-token / inputs_embeds handling
    Mixin-->>Model: model_inputs
  end
  Model->>HF: forward(**model_inputs)
  HF-->>User: logits / generated_ids
sequenceDiagram
  autonumber
  participant Caller as Caller
  participant RWKV as RWKV7Model
  participant PT as PreTrainedModel

  Caller->>RWKV: load_state_dict(state_dict, strict=..., ...)
  RWKV->>RWKV: detect v1 key patterns under model.layers.*
  alt v1 keys present
    RWKV->>RWKV: map v1 attn/* keys → v2 names (x_r, x_w, x_k, x_v, x_a, x_g)
  end
  RWKV->>PT: super().load_state_dict(migrated_state_dict, strict=...)
  PT-->>Caller: LoadResult

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

Suggested reviewers

  • yzhangcs

Pre-merge checks (1 passed, 1 warning, 1 inconclusive)

❌ Failed checks (1 warning, 1 inconclusive)
  • Docstring Coverage: ⚠️ Warning. Docstring coverage is 4.76%, below the required threshold of 80.00%. Resolution: run @coderabbitai generate docstrings to improve coverage.
  • Description Check: ❓ Inconclusive. The pull request description is not present in the provided context, so its relevance to the changeset cannot be assessed. Resolution: update the pull request with a clear description summarizing the major changes, including the dependency upgrade to transformers 4.56.x, the introduction and application of FLAGenerationMixin, and any CI pipeline adjustments.
✅ Passed checks (1 passed)
  • Title Check: ✅ Passed. The title "[Deps] Upgrade to transformers 4.56.x" succinctly captures the primary intent of loosening the transformers version constraint to support the 4.56.x release, which is the core focus of the changeset. It is concise and clearly conveys the main change without implementation specifics.

Poem

A rabbit hops through code at speed,
Swapping mixins where new rules lead.
Versions loosen, caches align,
Conda found, CI prints a sign.
Models generate with tidy grace — hooray! 🐇✨


@gemini-code-assist gemini-code-assist bot left a comment

Summary of Changes

Hello @zhiyuan1i, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request primarily focuses on updating the project's dependency on the Hugging Face Transformers library to support newer versions, specifically 4.56.x and beyond. To maintain compatibility and abstract away API differences, a custom generation mixin has been implemented. This ensures that all models within the project can leverage the latest transformers features while gracefully handling older versions, without requiring individual model adjustments for generation input preparation.

Highlights

  • Dependency Upgrade: The transformers library dependency has been updated to allow versions 4.56.x and higher, removing the previous upper bound.
  • Custom Generation Mixin: A new FLAGenerationMixin has been introduced in fla/models/utils.py to encapsulate generation logic and ensure compatibility across different transformers library versions.
  • Model Class Refactoring: All causal language model classes (e.g., ABCForCausalLM, MambaForCausalLM) now inherit from the newly introduced FLAGenerationMixin instead of the native transformers.generation.GenerationMixin.
  • Generation Input Preparation: The FLAGenerationMixin provides a version-aware prepare_inputs_for_generation method, handling API changes introduced in transformers versions 4.56.0 and above, particularly concerning cache_position.
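
A minimal sketch of that version-aware dispatch, assuming the FLAGenerationMixin name from this PR. The real method handles more cases (for example, delegating to _cache_dependant_input_preparation and additional inputs_embeds edge cases); this is an illustration, not the shipped implementation.

```python
from typing import Optional

import torch
import transformers
from packaging import version
from transformers.generation import GenerationMixin

_IS_TRANSFORMERS_4_56_PLUS = version.parse(transformers.__version__) >= version.parse("4.56.0")


class FLAGenerationMixinSketch(GenerationMixin):
    """Simplified illustration of the version-gated input preparation, not the shipped code."""

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        past_key_values=None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        use_cache: bool = True,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Optional[int] = None,
        **kwargs,
    ):
        if _IS_TRANSFORMERS_4_56_PLUS and cache_position is not None:
            # 4.56+ path: slice the prompt down to the positions the cache has not seen yet.
            if input_ids is not None and input_ids.shape[1] != cache_position.shape[-1]:
                input_ids = input_ids[:, cache_position]
        elif past_key_values is not None and input_ids is not None:
            # Legacy path: during decoding, only the last token is fed.
            input_ids = input_ids[:, -1:]

        model_inputs = {
            "past_key_values": past_key_values,
            "use_cache": use_cache,
            "attention_mask": attention_mask,
            "cache_position": cache_position,
            "logits_to_keep": logits_to_keep,
        }
        # Prefer inputs_embeds only while they still cover the positions being decoded.
        if inputs_embeds is not None and (cache_position is None or cache_position.shape[-1] == inputs_embeds.shape[1]):
            model_inputs["inputs_embeds"] = inputs_embeds
            model_inputs["input_ids"] = None
        else:
            model_inputs["input_ids"] = input_ids
        return model_inputs
```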
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

  • Code Review: /gemini review performs a code review for the current pull request in its current state.
  • Pull Request Summary: /gemini summary provides a summary of the current pull request in its current state.
  • Comment: @gemini-code-assist responds in comments when explicitly tagged, both in issue comments and review comments.
  • Help: /gemini help displays a list of available commands.

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request upgrades the transformers dependency to support version 4.56.x and introduces a FLAGenerationMixin to handle API changes for text generation, which is a good approach for maintaining compatibility. However, I've identified a critical bug in the new mixin that could cause crashes with older transformers versions, and an incomplete refactoring where many models won't benefit from this new compatibility layer. Please see my detailed comments for suggestions on how to address these issues.

@zhiyuan1i zhiyuan1i force-pushed the upstream-transformers branch from a982c2e to fec46c7 on September 9, 2025 at 16:59
@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 5

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (5)
setup.py (1)

41-47: Add missing runtime dependency: packaging

This repo imports from packaging import version (e.g., fla/models/utils.py) but packaging isn’t in install_requires; end users may hit ImportError. Add it here.

   install_requires=[
     'torch>=2.5',
-    'transformers>=4.53.0',
+    'transformers>=4.53.0',
+    'packaging>=23.0',
     'datasets>=3.3.0',
     'einops',
     'pytest'
   ],
.github/workflows/reusable-ci-tests.yml (1)

80-87: Unify runner-to-env mapping: handle nvidia-h100-3 in both jobs

test-models handles nvidia-h100-3 (non-nightly) with _2 suffix; test-ops does not. Mirror the logic to avoid selecting the wrong env on H100-3.

-          if [[ "${{ runner.name }}" == "nvidia-h100-1" ]]; then
+          if [[ "${{ runner.name }}" == "nvidia-h100-1" ]]; then
             TARGET_CONDA_ENV="${{ inputs.conda_env_name }}"
           elif [[ "${{ runner.name }}" == "nvidia-h100-2" ]]; then
             TARGET_CONDA_ENV="${{ inputs.conda_env_name }}_1"
+          elif [[ "${{ runner.name }}" == "nvidia-h100-3" && ! "${{ inputs.conda_env_name }}" == *"nightly"* ]]; then
+            TARGET_CONDA_ENV="${{ inputs.conda_env_name }}_2"
           else
             TARGET_CONDA_ENV="${{ inputs.conda_env_name }}"
             echo "Runner is not a special case, using input env: '${TARGET_CONDA_ENV}'"
           fi

Also applies to: 244-257

fla/models/comba/modeling_comba.py (1)

80-99: Propagate cache_position through Attention and Comba
Add cache_position: Optional[torch.LongTensor] = None to the forward signatures in fla/layers/attn.py (around line 80) and fla/layers/comba.py (around line 208), and pass it into your calls to get_seq_length/get_mask_sizes and past_key_values.update() so that the cache_position forwarded by CombaBlock’s **kwargs is actually consumed.
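
Schematically, the change being asked for here is just threading cache_position through the layer forward. The following is an abbreviated illustration; the real forward signatures in fla/layers/attn.py and fla/layers/comba.py carry many more arguments and use their own cache objects rather than a plain dict.

```python
from typing import Optional

import torch
import torch.nn as nn


class AttentionSketch(nn.Module):
    """Illustration of accepting and consuming cache_position; not the real attention layer."""

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[dict] = None,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,  # newly accepted instead of being swallowed by **kwargs
        **kwargs,
    ) -> torch.Tensor:
        if past_key_values is not None and cache_position is not None:
            # Hand the positions to whatever cache bookkeeping the layer does
            # (seq-length queries, mask sizing, past_key_values.update(...)).
            past_key_values["cache_position"] = cache_position
        return hidden_states
```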

fla/models/transformer/modeling_transformer.py (1)

326-327: Handle logits_to_keep=None to avoid TypeError and match other models.

Other CausalLMs in this PR accept None to skip slicing. Here it would error if None is passed.

Apply:

-        logits = None if self.config.fuse_linear_cross_entropy else self.lm_head(hidden_states[:, -logits_to_keep:])
+        logits = None if self.config.fuse_linear_cross_entropy else self.lm_head(
+            hidden_states if logits_to_keep is None else hidden_states[:, -int(logits_to_keep):]
+        )
fla/models/path_attn/modeling_path_attention.py (1)

328-349: Guard logits slicing when logits_to_keep is None to avoid TypeError.
Other models handle None explicitly; here a None value would break the negative slice.

Apply:

-        logits = None if self.config.fuse_linear_cross_entropy else self.lm_head(hidden_states[:, -logits_to_keep:])
+        logits = None if self.config.fuse_linear_cross_entropy else self.lm_head(
+            hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:]
+        )
🧹 Nitpick comments (11)
setup.py (1)

41-47: Consider moving pytest to extras

pytest is a test-only dependency; including it in install_requires bloats downstream envs. Suggest moving to an extra like extras_require['dev'].

.github/workflows/reusable-ci-tests.yml (1)

166-188: Log transformers version during env verification

Quick sanity check to ensure the intended version is active.

           echo "Python executable path: $CONDA_BIN_PATH/python"
           echo "PyTorch version: $($CONDA_BIN_PATH/python -c 'import torch; print(torch.__version__)')"
+          echo "Transformers version: $($CONDA_BIN_PATH/python -c 'import transformers; print(transformers.__version__)')"
fla/models/utils.py (3)

391-393: Remove no-op __init__ from mixin

Mixins generally avoid __init__; this one does nothing and can confuse MRO. Safe to drop.

-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-

395-405: Silence Ruff ARG002 without altering behavior

kwargs is intentionally accepted for signature parity. Explicitly consume it to appease linters.

     def prepare_inputs_for_generation(
         self,
@@
-        **kwargs
+        **kwargs
     ):
+        # Intentionally unused: kept for signature compatibility with HF APIs
+        _ = kwargs

416-427: Guard against missing input_ids when inputs_embeds provided mid-generation

If inputs_embeds is provided with nonzero cache_position, current fallback may try to use input_ids=None. Mirror HF behavior: in that case, slice last tokens instead of indexing by cache_position.

-                elif cache_position is not None:
-                    # Fallback: manually slice using cache_position
-                    if input_ids is not None and input_ids.shape[1] != cache_position.shape[0]:
-                        input_ids = input_ids[:, cache_position]
+                elif cache_position is not None:
+                    # Fallback: manually slice using cache_position (if input_ids present), else keep last tokens
+                    if input_ids is not None:
+                        if input_ids.shape[1] != cache_position.shape[-1]:
+                            input_ids = input_ids[:, cache_position]
+                    else:
+                        # Keep last token(s) matching cache length to avoid None propagation
+                        step = int(cache_position.shape[-1]) if cache_position is not None else 1
+                        input_ids = None if step == 0 else input_ids
fla/models/bitnet/modeling_bitnet.py (1)

365-386: Double-check logits slicing when logits_to_keep=0

hidden_states[:, -0:] returns the full sequence; intended? If you expect “no slicing” when 0, consider None as sentinel instead.

-        logits = None if self.config.fuse_linear_cross_entropy else self.lm_head(hidden_states[:, -logits_to_keep:])
+        if self.config.fuse_linear_cross_entropy:
+            logits = None
+        else:
+            to_slice = None if logits_to_keep in (None, 0) else logits_to_keep
+            logits = self.lm_head(hidden_states if to_slice is None else hidden_states[:, -to_slice:])
fla/models/rwkv6/modeling_rwkv6.py (1)

418-437: Consistent logits slicing handling

Same optional refactor as BitNet to treat logits_to_keep=0 as “no slice”.

fla/models/retnet/modeling_retnet.py (1)

350-370: Optional: unify logits_to_keep behavior

Mirror the suggested pattern to treat 0 as “no slice”.

fla/models/hgrn2/modeling_hgrn2.py (1)

346-366: Optional: same logits_to_keep treatment

Consider the same small refactor as other heads for clarity with 0 vs None.

fla/models/transformer/modeling_transformer.py (1)

184-185: Type hint nit: past_key_values annotation could include Cache.

Forward converts legacy caches to Cache; consider annotating as Union[Cache, List[Tensor]] for accuracy (purely typing).

fla/models/path_attn/modeling_path_attention.py (1)

191-193: Return type annotation mismatch.
Function returns BaseModelOutputWithPast but type hints say CausalLMOutputWithPast. Align for clarity.

-    ) -> Union[Tuple, CausalLMOutputWithPast]:
+    ) -> Union[Tuple, BaseModelOutputWithPast]:
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between a982c2e and fec46c7.

📒 Files selected for processing (31)
  • .github/workflows/reusable-ci-tests.yml (3 hunks)
  • fla/models/abc/modeling_abc.py (2 hunks)
  • fla/models/bitnet/modeling_bitnet.py (2 hunks)
  • fla/models/comba/modeling_comba.py (2 hunks)
  • fla/models/delta_net/modeling_delta_net.py (2 hunks)
  • fla/models/forgetting_transformer/modeling_forgetting_transformer.py (2 hunks)
  • fla/models/gated_deltanet/modeling_gated_deltanet.py (2 hunks)
  • fla/models/gated_deltaproduct/modeling_gated_deltaproduct.py (2 hunks)
  • fla/models/gla/modeling_gla.py (2 hunks)
  • fla/models/gsa/modeling_gsa.py (2 hunks)
  • fla/models/hgrn/modeling_hgrn.py (2 hunks)
  • fla/models/hgrn2/modeling_hgrn2.py (2 hunks)
  • fla/models/lightnet/modeling_lightnet.py (2 hunks)
  • fla/models/linear_attn/modeling_linear_attn.py (2 hunks)
  • fla/models/log_linear_mamba2/modeling_log_linear_mamba2.py (2 hunks)
  • fla/models/mamba/modeling_mamba.py (2 hunks)
  • fla/models/mamba2/modeling_mamba2.py (2 hunks)
  • fla/models/mesa_net/modeling_mesa_net.py (2 hunks)
  • fla/models/mla/modeling_mla.py (2 hunks)
  • fla/models/mom/modeling_mom.py (2 hunks)
  • fla/models/nsa/modeling_nsa.py (2 hunks)
  • fla/models/path_attn/modeling_path_attention.py (2 hunks)
  • fla/models/retnet/modeling_retnet.py (2 hunks)
  • fla/models/rodimus/modeling_rodimus.py (2 hunks)
  • fla/models/rwkv6/modeling_rwkv6.py (2 hunks)
  • fla/models/rwkv7/modeling_rwkv7.py (2 hunks)
  • fla/models/samba/modeling_samba.py (2 hunks)
  • fla/models/transformer/modeling_transformer.py (2 hunks)
  • fla/models/utils.py (2 hunks)
  • pyproject.toml (1 hunks)
  • setup.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (9)
  • fla/models/mom/modeling_mom.py
  • fla/models/gsa/modeling_gsa.py
  • fla/models/forgetting_transformer/modeling_forgetting_transformer.py
  • pyproject.toml
  • fla/models/lightnet/modeling_lightnet.py
  • fla/models/mamba/modeling_mamba.py
  • fla/models/rodimus/modeling_rodimus.py
  • fla/models/mamba2/modeling_mamba2.py
  • fla/models/mla/modeling_mla.py
🧰 Additional context used
🧬 Code graph analysis (20)
fla/models/samba/modeling_samba.py (1)
fla/models/utils.py (1)
  • FLAGenerationMixin (385-462)
fla/models/nsa/modeling_nsa.py (1)
fla/models/utils.py (3)
  • Cache (466-468)
  • Cache (470-472)
  • FLAGenerationMixin (385-462)
fla/models/hgrn2/modeling_hgrn2.py (1)
fla/models/utils.py (3)
  • Cache (466-468)
  • Cache (470-472)
  • FLAGenerationMixin (385-462)
fla/models/path_attn/modeling_path_attention.py (1)
fla/models/utils.py (3)
  • Cache (466-468)
  • Cache (470-472)
  • FLAGenerationMixin (385-462)
fla/models/mesa_net/modeling_mesa_net.py (1)
fla/models/utils.py (3)
  • Cache (466-468)
  • Cache (470-472)
  • FLAGenerationMixin (385-462)
fla/models/gla/modeling_gla.py (1)
fla/models/utils.py (3)
  • Cache (466-468)
  • Cache (470-472)
  • FLAGenerationMixin (385-462)
fla/models/rwkv7/modeling_rwkv7.py (1)
fla/models/utils.py (3)
  • Cache (466-468)
  • Cache (470-472)
  • FLAGenerationMixin (385-462)
fla/models/rwkv6/modeling_rwkv6.py (1)
fla/models/utils.py (3)
  • Cache (466-468)
  • Cache (470-472)
  • FLAGenerationMixin (385-462)
fla/models/transformer/modeling_transformer.py (1)
fla/models/utils.py (3)
  • Cache (466-468)
  • Cache (470-472)
  • FLAGenerationMixin (385-462)
fla/models/bitnet/modeling_bitnet.py (1)
fla/models/utils.py (3)
  • Cache (466-468)
  • Cache (470-472)
  • FLAGenerationMixin (385-462)
fla/models/comba/modeling_comba.py (1)
fla/models/utils.py (3)
  • Cache (466-468)
  • Cache (470-472)
  • FLAGenerationMixin (385-462)
fla/models/abc/modeling_abc.py (1)
fla/models/utils.py (3)
  • Cache (466-468)
  • Cache (470-472)
  • FLAGenerationMixin (385-462)
fla/models/linear_attn/modeling_linear_attn.py (1)
fla/models/utils.py (3)
  • Cache (466-468)
  • Cache (470-472)
  • FLAGenerationMixin (385-462)
fla/models/utils.py (2)
fla/utils.py (1)
  • deprecate_kwarg (190-334)
fla/models/mom/modeling_mom.py (1)
  • prepare_inputs_for_generation (394-426)
fla/models/gated_deltanet/modeling_gated_deltanet.py (1)
fla/models/utils.py (3)
  • Cache (466-468)
  • Cache (470-472)
  • FLAGenerationMixin (385-462)
fla/models/gated_deltaproduct/modeling_gated_deltaproduct.py (1)
fla/models/utils.py (3)
  • Cache (466-468)
  • Cache (470-472)
  • FLAGenerationMixin (385-462)
fla/models/log_linear_mamba2/modeling_log_linear_mamba2.py (1)
fla/models/utils.py (1)
  • FLAGenerationMixin (385-462)
fla/models/hgrn/modeling_hgrn.py (1)
fla/models/utils.py (3)
  • Cache (466-468)
  • Cache (470-472)
  • FLAGenerationMixin (385-462)
fla/models/retnet/modeling_retnet.py (1)
fla/models/utils.py (3)
  • Cache (466-468)
  • Cache (470-472)
  • FLAGenerationMixin (385-462)
fla/models/delta_net/modeling_delta_net.py (1)
fla/models/utils.py (3)
  • Cache (466-468)
  • Cache (470-472)
  • FLAGenerationMixin (385-462)
🪛 Ruff (0.12.2)
fla/models/utils.py

404-404: Unused method argument: kwargs

(ARG002)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Test H100 (PyTorch 2.7) / test-models
🔇 Additional comments (41)
fla/models/utils.py (3)

454-462: Approve: common fields assembly looks good

The merged model_inputs block aligns with HF expectations (past_key_values/use_cache/attention_mask/logits_to_keep).


15-18: Minor: version gate constant is fine

Pre-computing _IS_TRANSFORMERS_4_56_PLUS avoids repeated parsing.


395-463: Sanity tests for 4.56+ generation path

The sandbox lacked the transformers package (dry-run errored) and there’s still a per-model override in fla/models/mom/modeling_mom.py you’ll need to remove or update to use the unified mixin. Please re-run these checks in an environment with transformers>=4.56.0 installed:

  • invoke prepare_inputs_for_generation with and without cache_position
  • supply inputs_embeds mid-generation and confirm it’s passed through
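
For reference, a throwaway check along those lines might look like this. It is hypothetical: any FLA *ForCausalLM class would do, and the asserted keys follow the model_inputs contract described above rather than a verified API.

```python
import torch

# Hypothetical setup: replace with any FLA causal LM, e.g.
#   from fla.models import GLAConfig, GLAForCausalLM
#   model = GLAForCausalLM(GLAConfig(hidden_size=64, num_hidden_layers=2))

def sanity_check_prepare_inputs(model):
    input_ids = torch.randint(0, 100, (1, 8))

    # 1) with and without cache_position
    without_pos = model.prepare_inputs_for_generation(input_ids)
    with_pos = model.prepare_inputs_for_generation(input_ids, cache_position=torch.arange(8))
    assert "past_key_values" in without_pos and "past_key_values" in with_pos

    # 2) inputs_embeds mid-generation should be passed through, not silently dropped
    embeds = torch.randn(1, 8, model.config.hidden_size)
    prepared = model.prepare_inputs_for_generation(
        input_ids, inputs_embeds=embeds, cache_position=torch.arange(8)
    )
    assert prepared.get("inputs_embeds") is not None
```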
fla/models/comba/modeling_comba.py (1)

19-20: LGTM: migrate to FLAGenerationMixin

Import and base-class swap look correct; generation routes should now be version-aware via the mixin.

Also applies to: 268-269

fla/models/bitnet/modeling_bitnet.py (1)

18-19: LGTM: mixin/import swap

BitNetForCausalLM now inherits the shared, version-aware generation path.

Also applies to: 298-299

fla/models/rwkv6/modeling_rwkv6.py (1)

19-20: LGTM: mixin/import swap

RWKV6 adopts the shared mixin; generate override still guards against past_key_values usage.

Also applies to: 336-336

fla/models/retnet/modeling_retnet.py (2)

19-20: LGTM: mixin/import swap

RetNet model now relies on FLAGenerationMixin.

Also applies to: 266-266


298-313: Good error message on unsupported past_key_values strategies

Helpful UX when users select incompatible strategies.

fla/models/hgrn2/modeling_hgrn2.py (1)

19-20: LGTM: mixin/import swap

HGRN2 integrates with FLAGenerationMixin; generate override remains intact.

Also applies to: 264-265

fla/models/transformer/modeling_transformer.py (2)

18-18: Migration to FLAGenerationMixin looks correct.

Import switch aligns with the PR intent; no issues spotted here.


259-259: Class now inherits FLAGenerationMixin — good alignment with utils.FLAGenerationMixin.

This enables centralized, version-aware generation handling.

fla/models/nsa/modeling_nsa.py (3)

18-18: Import migration to FLAGenerationMixin — OK.


242-242: NSAForCausalLM inheriting FLAGenerationMixin — OK.

Consistent with repo-wide change; no functional concerns here.


274-288: Helpful generate() error wrapping.

Clearer message for unsupported past_key_values strategies.

fla/models/delta_net/modeling_delta_net.py (3)

19-19: Import migration to FLAGenerationMixin — OK.


258-258: DeltaNetForCausalLM now using FLAGenerationMixin — OK.


290-304: generate() wrapper improves UX for unsupported strategies.

fla/models/mesa_net/modeling_mesa_net.py (3)

19-19: Import migration to FLAGenerationMixin — OK.


257-257: MesaNetForCausalLM now using FLAGenerationMixin — OK.


289-303: generate() wrapper matches pattern elsewhere — good.

fla/models/hgrn/modeling_hgrn.py (3)

19-19: Import migration to FLAGenerationMixin — OK.


263-263: HGRNForCausalLM now using FLAGenerationMixin — OK.


295-309: generate() wrapper with targeted error is good.

fla/models/abc/modeling_abc.py (3)

19-19: Import migration to FLAGenerationMixin — OK.


261-261: ABCForCausalLM now using FLAGenerationMixin — OK.


293-307: generate() wrapper matches others — good.

fla/models/samba/modeling_samba.py (2)

19-19: Switched to FLAGenerationMixin (import) — OK.


356-359: Safe slicing with logits_to_keep=None is already handled — good.

No action needed here.

fla/models/gated_deltanet/modeling_gated_deltanet.py (1)

19-19: Switch to FLAGenerationMixin/Cache import looks good.
No issues spotted with the new dependency points.

fla/models/gated_deltaproduct/modeling_gated_deltaproduct.py (2)

19-19: Import swap to utils.FLAGenerationMixin/Cache — LGTM.


261-261: Class now mixes in FLAGenerationMixin — LGTM.
Matches the repo-wide migration.

fla/models/log_linear_mamba2/modeling_log_linear_mamba2.py (2)

14-14: Switch to FLAGenerationMixin import — LGTM.


152-163: Stateful generation contract correctly implemented; no overrides needed. Mamba2’s prepare_inputs_for_generation and _update_model_kwargs_for_generation already handle cache_params/cache_position as expected by Transformers 4.56’s GenerationMixin, so the mixin interop is sound. Optionally, you can add the same AttributeError guard in generate for UX consistency.
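
The AttributeError guard mentioned here is the same generate() wrapper pattern used across the other models in this PR. A sketch of it follows; the message wording is illustrative, and in the real models the method sits on the *ForCausalLM classes that also inherit their PreTrainedModel base.

```python
from transformers.generation import GenerationMixin


class GuardedGenerateSketch(GenerationMixin):
    """Sketch only: shows the error-wrapping pattern, not the shipped model code."""

    def generate(self, *args, **kwargs):
        try:
            return super().generate(*args, **kwargs)
        except AttributeError as exc:
            # Re-raise with a clearer hint when a decoding strategy tries to
            # manipulate past_key_values on a cache that does not support it.
            if "past_key_values" in str(exc):
                raise AttributeError(
                    "You tried to use a decoding strategy that manipulates past_key_values, "
                    "which this model's recurrent cache does not support. "
                    "Try greedy search or sampling instead."
                ) from exc
            raise
```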

fla/models/gla/modeling_gla.py (2)

19-19: Import consolidation to utils.FLAGenerationMixin/Cache — LGTM.


261-261: GLAForCausalLM now uses FLAGenerationMixin — LGTM.
Behavior aligned with other models.

fla/models/linear_attn/modeling_linear_attn.py (2)

19-19: Import change to FLAGenerationMixin/Cache — LGTM.


256-256: LinearAttentionForCausalLM adopts FLAGenerationMixin — LGTM.

fla/models/path_attn/modeling_path_attention.py (2)

18-18: Import swap to utils.FLAGenerationMixin/Cache — LGTM.


261-261: PaTHAttentionForCausalLM now uses FLAGenerationMixin — LGTM.

fla/models/rwkv7/modeling_rwkv7.py (2)

19-19: Import change to utils.FLAGenerationMixin/Cache — LGTM.


428-428: RWKV7ForCausalLM inherits FLAGenerationMixin — LGTM.
Consistent with the rest of the migration.

Comment on lines +140 to 147
$CONDA_BIN_PATH/pip install -U numpy packaging psutil ninja einops datasets transformers
$CONDA_BIN_PATH/pip install --no-deps .
else
STABLE_URL="https://download.pytorch.org/whl/${{ inputs.pytorch_cuda_version }}"
echo "Using stable index URL: $STABLE_URL"
$CONDA_BIN_PATH/pip install -U torch~=${{ inputs.pytorch_version }} triton --index-url $STABLE_URL
$CONDA_BIN_PATH/pip install -U numpy packaging psutil ninja einops datasets
$CONDA_BIN_PATH/pip install -U numpy packaging psutil ninja einops datasets transformers
$CONDA_BIN_PATH/pip install .

🛠️ Refactor suggestion

Install transformers consistently across jobs/branches

Nightly paths install transformers, but stable paths in test-models (both NVIDIA and Intel) do not. This can leave an older preinstalled version and mask issues. Install it in those paths too.

-              $CONDA_BIN_PATH/pip install -U torch~=${{ inputs.pytorch_version }} triton --index-url $STABLE_URL
-              $CONDA_BIN_PATH/pip install -U numpy
+              $CONDA_BIN_PATH/pip install -U torch~=${{ inputs.pytorch_version }} triton --index-url $STABLE_URL
+              $CONDA_BIN_PATH/pip install -U numpy packaging psutil ninja einops datasets transformers
               $CONDA_BIN_PATH/pip install .
-            $CONDA_BIN_PATH/pip install -U torch~=${{ inputs.pytorch_version }} pytorch-triton-xpu --index-url $XPU_URL
-            $CONDA_BIN_PATH/pip install -U numpy
+            $CONDA_BIN_PATH/pip install -U torch~=${{ inputs.pytorch_version }} pytorch-triton-xpu --index-url $XPU_URL
+            $CONDA_BIN_PATH/pip install -U numpy packaging psutil ninja einops datasets transformers
             $CONDA_BIN_PATH/pip install .

Also applies to: 158-160, 321-329, 340-342

🤖 Prompt for AI Agents
.github/workflows/reusable-ci-tests.yml around lines 140-147 (also apply same
change at 158-160, 321-329, 340-342): the stable-path branches do not explicitly
install the transformers package, which can leave an older preinstalled version
and hide regressions; modify those stable-path install blocks to include
transformers in the pip install commands (matching the nightly paths) so that
transformers is installed/updated explicitly in the stable flows as well,
ensuring each affected block lists transformers alongside numpy packaging psutil
ninja einops datasets (and any other packages already present).



-class GatedDeltaNetForCausalLM(GatedDeltaNetPreTrainedModel, GenerationMixin):
+class GatedDeltaNetForCausalLM(GatedDeltaNetPreTrainedModel, FLAGenerationMixin):

💡 Verification agent

🧩 Analysis chain

Adopt FLAGenerationMixin for generation — OK. Verify no stale per-model overrides remain.
Since the mixin centralizes prepare_inputs_for_generation, ensure no lingering model-level overrides cause shadowing.

Run to confirm no stray definitions remain:


🏁 Script executed:

#!/bin/bash
rg -nP --type=py '\bdef\s+prepare_inputs_for_generation\b' fla/models | sed -n '1,200p'

Length of output: 136


Remove per-model override: Delete or delegate the prepare_inputs_for_generation method in fla/models/mom/modeling_mom.py (line 394) so it no longer shadows FLAGenerationMixin.

🤖 Prompt for AI Agents
In fla/models/gated_deltanet/modeling_gated_deltanet.py around line 269: the
class GatedDeltaNetForCausalLM is being shadowed by a per-model override of
prepare_inputs_for_generation in fla/models/mom/modeling_mom.py at line 394;
remove that method from modeling_mom.py or change it to delegate to the mixin by
replacing its body with a simple return
super().prepare_inputs_for_generation(*args, **kwargs) (or calling
FLAGenerationMixin.prepare_inputs_for_generation(self, *args, **kwargs)) so the
mixin implementation is used and no longer shadowed.



-class SambaForCausalLM(SambaPreTrainedModel, GenerationMixin):
+class SambaForCausalLM(SambaPreTrainedModel, FLAGenerationMixin):

⚠️ Potential issue

Breaking: removed Samba’s prepare_inputs_for_generation but forward expects cache_params, not past_key_values.

FLAGenerationMixin’s default prepare_inputs_for_generation returns past_key_values; Samba forward does not accept it and instead needs cache_params. Generation will raise “got an unexpected keyword argument 'past_key_values'”.

Add Samba-specific prepare_inputs_for_generation to map to cache_params:

 class SambaForCausalLM(SambaPreTrainedModel, FLAGenerationMixin):
@@
     def _update_model_kwargs_for_generation(
         self, outputs: ModelOutput, model_kwargs: Dict[str, Any], **kwargs
     ) -> Dict[str, Any]:
         model_kwargs["cache_params"] = outputs.get("cache_params", None)
         return model_kwargs
+
+    def prepare_inputs_for_generation(
+        self,
+        input_ids: torch.LongTensor = None,
+        cache_params: Optional[MambaCache] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        use_cache: bool = True,
+        logits_to_keep: Optional[int] = None,
+        **kwargs: Any,
+    ) -> Dict[str, Any]:
+        model_inputs: Dict[str, Any] = {}
+        # If continuing with cache, only feed the next step token unless using inputs_embeds
+        if inputs_embeds is not None:
+            model_inputs["inputs_embeds"] = inputs_embeds
+        else:
+            # Keep static stride during decoding
+            model_inputs["input_ids"] = input_ids.contiguous() if input_ids is not None else None
+        if logits_to_keep is not None:
+            model_inputs["logits_to_keep"] = logits_to_keep
+        model_inputs.update(
+            {
+                "cache_params": cache_params,
+                "use_cache": use_cache,
+                "attention_mask": attention_mask,
+            }
+        )
+        return model_inputs

Comment on lines +12 to 14
from transformers.generation import GenerationMixin
from transformers.utils.deprecation import deprecate_kwarg


🛠️ Refactor suggestion

Harden GenerationMixin import for cross-version compatibility

Some Transformers versions relocate GenerationMixin; a single import path can break (seen intermittently around 4.52–4.56). Add a fallback chain.

-from transformers.generation import GenerationMixin
-from transformers.utils.deprecation import deprecate_kwarg
+try:
+    from transformers.generation import GenerationMixin
+except Exception:
+    try:
+        from transformers.generation.utils import GenerationMixin
+    except Exception:
+        # Final fallback for edge versions
+        from transformers import GenerationMixin
+from transformers.utils.deprecation import deprecate_kwarg
🤖 Prompt for AI Agents
In fla/models/utils.py around lines 12 to 14, the direct import of
GenerationMixin can break across Transformers versions; replace the single
import with a fallback chain that first tries from transformers.generation
import GenerationMixin, then from transformers.generation_utils import
GenerationMixin, and finally from transformers import GenerationMixin (or other
plausible locations) inside try/except blocks to catch ImportError; ensure the
final fallback raises a clear ImportError if none succeed and keep the rest of
the file using the resolved GenerationMixin name.

Comment on lines +429 to +437
if inputs_embeds is not None and (cache_position is None or len(cache_position) == inputs_embeds.shape[1]):
model_inputs['inputs_embeds'] = inputs_embeds
model_inputs['input_ids'] = None
else:
model_inputs['input_ids'] = input_ids.contiguous() if input_ids is not None else None
model_inputs['inputs_embeds'] = None

model_inputs['cache_position'] = cache_position


🛠️ Refactor suggestion

Handle cache_position shape robustly when using inputs_embeds

Comparing len(cache_position) to inputs_embeds.shape[1] breaks if cache_position is 2D. Use last-dim length.

-            if inputs_embeds is not None and (cache_position is None or len(cache_position) == inputs_embeds.shape[1]):
+            cache_len = None
+            if cache_position is not None:
+                cache_len = cache_position.shape[-1]
+            if inputs_embeds is not None and (cache_len is None or cache_len == inputs_embeds.shape[1]):
                 model_inputs['inputs_embeds'] = inputs_embeds
                 model_inputs['input_ids'] = None
             else:
                 model_inputs['input_ids'] = input_ids.contiguous() if input_ids is not None else None
                 model_inputs['inputs_embeds'] = None
🤖 Prompt for AI Agents
In fla/models/utils.py around lines 429 to 437, the code compares
len(cache_position) to inputs_embeds.shape[1], which fails if cache_position is
a 2D tensor; change the comparison to use the last dimension length (e.g.
cache_position.shape[-1]) and guard for None so the condition becomes: if
inputs_embeds is not None and (cache_position is None or
cache_position.shape[-1] == inputs_embeds.shape[1]); leave the rest of the
branch logic unchanged.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (3)
tests/models/test_modeling_mamba2.py (3)

15-25: Guard the suite when CUDA isn’t available to avoid spurious CI failures

All parametrizations now require conv_backend='cuda'. Add a collection-time skip so CPU-only environments don’t fail this test.

+@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA for conv_backend='cuda'")
 @pytest.mark.parametrize(
     ['L', 'B', 'T', 'H', 'D', 'use_l2warp', 'dtype', 'conv_backend'],

36-41: Use pytest’s monkeypatch to avoid leaking environment variables across tests

Setting os.environ directly is process-global and can bleed into other tests. Use monkeypatch.setenv, which auto-restores after the test.

-def test_modeling(
+def test_modeling(
     L: int,
     B: int,
     T: int,
     H: int,
     D: int,
     use_l2warp: bool,
     dtype: torch.dtype,
-    conv_backend: str,
+    conv_backend: str,
+    monkeypatch,
 ):
@@
-    os.environ['FLA_CONV_BACKEND'] = conv_backend
+    monkeypatch.setenv('FLA_CONV_BACKEND', conv_backend)

36-73: Skip BF16 runs on pre-Ampere GPUs to prevent backend errors

CUDA BF16 requires SM80+ (Ampere). Add a runtime skip for that combo to reduce flakiness on mixed CI fleets.

 def test_modeling(
@@
-    monkeypatch.setenv('FLA_CONV_BACKEND', conv_backend)
+    monkeypatch.setenv('FLA_CONV_BACKEND', conv_backend)
+
+    # CUDA BF16 requires Ampere+ (SM80)
+    if (
+        conv_backend == 'cuda'
+        and dtype == torch.bfloat16
+        and torch.cuda.is_available()
+    ):
+        major, _ = torch.cuda.get_device_capability(0)
+        if major < 8:
+            pytest.skip("CUDA bfloat16 requires SM80+ (Ampere or newer).")
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between fec46c7 and bfc34d3.

📒 Files selected for processing (2)
  • fla/layers/mamba2.py (1 hunks)
  • tests/models/test_modeling_mamba2.py (1 hunks)
✅ Files skipped from review due to trivial changes (1)
  • fla/layers/mamba2.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: Test H100 (PyTorch 2.7) / test-models
  • GitHub Check: check-pt-python-compatibility
🔇 Additional comments (1)
tests/models/test_modeling_mamba2.py (1)

21-21: Switching to CUDA backend for this case looks good

Aligns the test with the project’s guidance to prefer CUDA over Triton for Mamba2 conv1d. No functional concerns.

@zhiyuan1i zhiyuan1i merged commit 0ca1288 into main Sep 9, 2025
4 of 5 checks passed
@zhiyuan1i zhiyuan1i deleted the upstream-transformers branch September 9, 2025 18:44
zhiyuan1i added a commit that referenced this pull request Sep 11, 2025
* [Deps] Upgrade to transformers 4.56.x

* remove triton conv1d backend for mamba2

* skip mom test