Improve Gemma3n model and tests #39764
Conversation
@@ -227,7 +227,7 @@ def __init__(
        altup_num_inputs: int = 4,
        num_kv_shared_layers: int = 15,
        laurel_rank: int = 64,
        activation_sparsity_pattern: Optional[Union[float, Sequence[float]]] = (0.95,) * 10 + (0.0,) * 25,
Having the number of layers hardcoded is no good.
Yeah, agreed! Also, imo having it as a single float 0.95 isn't very intuitive; we could default to None in the signature and later do if None: pattern = (0.95,) * 10 [...]
Agreed, I thought the same. The problem with doing that is that it changes the behavior of None, which people might rely on in the wild:

transformers/src/transformers/models/gemma3n/configuration_gemma3n.py, lines 291 to 292 in 83f2599:

        if activation_sparsity_pattern is None:
            activation_sparsity_pattern = [0.0] * num_hidden_layers

Maybe default to -1 or an empty tuple? Or do you think it is safe to change the None behaviour?
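For illustration, here is a minimal sketch of the sentinel approach being discussed, assuming the goal is a default tied to num_hidden_layers while an explicit None keeps its current all-dense meaning. The class name and validation are hypothetical, not the PR's actual change:

```python
from typing import Optional, Sequence, Union

_UNSET = object()  # sentinel: distinguishes "argument omitted" from an explicit None


class Gemma3nTextConfigSketch:
    def __init__(
        self,
        num_hidden_layers: int = 35,
        activation_sparsity_pattern=_UNSET,  # float, sequence of floats, None, or omitted
    ):
        if activation_sparsity_pattern is _UNSET:
            # library default: sparsity on the first 10 layers, dense afterwards
            activation_sparsity_pattern = (0.95,) * 10 + (0.0,) * (num_hidden_layers - 10)
        elif activation_sparsity_pattern is None:
            # preserve today's None behavior: every layer dense
            activation_sparsity_pattern = (0.0,) * num_hidden_layers
        elif isinstance(activation_sparsity_pattern, float):
            # a single float applies to every layer
            activation_sparsity_pattern = (activation_sparsity_pattern,) * num_hidden_layers

        if len(activation_sparsity_pattern) != num_hidden_layers:
            raise ValueError(
                f"activation_sparsity_pattern must have {num_hidden_layers} entries, "
                f"got {len(activation_sparsity_pattern)}"
            )

        self.num_hidden_layers = num_hidden_layers
        self.activation_sparsity_pattern = tuple(activation_sparsity_pattern)
```

The sentinel only lives in the signature; the stored attribute is always a concrete tuple, so config serialization round-trips and the existing None contract are unaffected.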
@@ -659,7 +658,6 @@ def test_automodelforcausallm(self):
        self.assertIsInstance(for_causal_lm, Gemma3nForCausalLM)


    @unittest.skip("Skipped for now!")
These tests were copied from gemma3 and were skipped; I updated and enabled them.
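For context, enabling one of those copied tests roughly amounts to dropping the skip decorator and pointing the assertion at the Gemma3n classes. A rough sketch only; the checkpoint name and class name here are illustrative, not the PR's exact test:

```python
import unittest

from transformers import AutoModelForCausalLM, Gemma3nForCausalLM
from transformers.testing_utils import require_torch


@require_torch
class Gemma3nIntegrationSketch(unittest.TestCase):
    # Previously this carried @unittest.skip("Skipped for now!") from the gemma3 copy;
    # the skip is removed and the assertion targets the Gemma3n class instead.
    def test_automodelforcausallm(self):
        checkpoint = "google/gemma-3n-E2B-it"  # illustrative checkpoint name
        for_causal_lm = AutoModelForCausalLM.from_pretrained(checkpoint)
        self.assertIsInstance(for_causal_lm, Gemma3nForCausalLM)
```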
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
run-slow: gemma3n, gemma3
This comment contains run-slow, running the specified jobs: models: ['models/gemma3', 'models/gemma3n']
All Gemma 3n tests passing! Thanks a lot @ydshieh for the help! This is ready to merge :) (There is only a gemma3 custom test failing due to multiple GPUs.) Unsure who to tag for review, lmk if I didn't hit the gemma3n experts :)
@@ -875,12 +859,13 @@ def test_model_1b_text_only(self):
    @require_flash_attn
    @require_torch_gpu
    @pytest.mark.flash_attn_test
    @unittest.skip("Timm models do not support Flash Attention 2 yet")
Then let's delete the test or use the CausalLM.
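If the test is kept, one possible direction along the lines of this suggestion is to exercise Flash Attention 2 through the text-only CausalLM path, so the Timm vision tower is never instantiated. A hedged sketch, assuming the text stack supports FA2 and that AutoModelForCausalLM resolves the multimodal checkpoint to the text-only model; the checkpoint name is illustrative:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "google/gemma-3n-E2B-it"  # illustrative checkpoint name

tokenizer = AutoTokenizer.from_pretrained(model_id)
# Text-only load: the Timm-based vision tower (no FA2 support) is not built.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
).to("cuda")

inputs = tokenizer("The capital of France is", return_tensors="pt").to("cuda")
output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```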
[For maintainers] Suggested jobs to run (before merge): run-slow: gemma3n
Improves the Gemma3n model and tests by: