Commit 240ebfe

molbap authored and LysandreJik committed
Fix getter regression (#40824)
* test things
* style
* move tests to a sane place
1 parent a55e503 commit 240ebfe

File tree

src/transformers/modeling_utils.py
tests/utils/test_modeling_utils.py

2 files changed: +184 -3 lines changed

src/transformers/modeling_utils.py

Lines changed: 6 additions & 3 deletions
@@ -3000,11 +3000,14 @@ def get_decoder(self):
 
         if hasattr(self, "model"):
             inner = self.model
-            if hasattr(inner, "get_decoder"):
+            # See: https://github.com/huggingface/transformers/issues/40815
+            if hasattr(inner, "get_decoder") and type(inner) is not type(self):
                 return inner.get_decoder()
             return inner
 
-        return None  # raise AttributeError(f"{self.__class__.__name__} has no decoder; override `get_decoder()` if needed.")
+        # If this is a base transformer model (no decoder/model attributes), return self
+        # This handles cases like MistralModel which is itself the decoder
+        return self
 
     def set_decoder(self, decoder):
         """
@@ -3023,7 +3026,7 @@ def set_decoder(self, decoder):
             self.model = decoder
             return
 
-        return  # raise AttributeError(f"{self.__class__.__name__} cannot accept a decoder; override `set_decoder()`.")
+        return
 
     def _init_weights(self, module):
         """

tests/utils/test_modeling_utils.py

Lines changed: 178 additions & 0 deletions
@@ -39,16 +39,27 @@
     AutoModel,
     AutoModelForImageClassification,
     AutoModelForSequenceClassification,
+    BartConfig,
+    BartForConditionalGeneration,
     CLIPTextModelWithProjection,
     DynamicCache,
+    GPT2Config,
+    GPT2LMHeadModel,
+    LlavaConfig,
     LlavaForConditionalGeneration,
+    MistralConfig,
     MistralForCausalLM,
+    OPTConfig,
+    OPTForCausalLM,
     OwlViTForObjectDetection,
     PretrainedConfig,
+    T5Config,
+    T5ForConditionalGeneration,
     is_torch_available,
     logging,
 )
 from transformers.modeling_flash_attention_utils import is_flash_attn_available
+from transformers.models.mistral.modeling_mistral import MistralModel
 from transformers.testing_utils import (
     TOKEN,
     CaptureLogger,
@@ -2871,3 +2882,170 @@ def forward(self, hidden_states, attention_mask):
             model.save_pretrained(tmpdirname)
             model = MyModel.from_pretrained(tmpdirname)
             self.assertEqual(model.my_layer.some_counter, 42)
+
+
+class TestGetDecoder(unittest.TestCase):
+    def test_causal_lm_get_decoder_returns_underlying_model(self):
+        cfg = MistralConfig(
+            vocab_size=128,
+            hidden_size=32,
+            intermediate_size=64,
+            num_hidden_layers=2,
+            num_attention_heads=4,
+        )
+        model = MistralForCausalLM(cfg)
+        dec = model.get_decoder()
+
+        assert dec is model.model, f"Expected get_decoder() to return model.model, got {type(dec)}"
+
+    def test_seq2seq_get_decoder_still_returns_decoder_module(self):
+        cfg = BartConfig(
+            vocab_size=128,
+            d_model=32,
+            encoder_layers=2,
+            decoder_layers=2,
+            encoder_attention_heads=4,
+            decoder_attention_heads=4,
+            encoder_ffn_dim=64,
+            decoder_ffn_dim=64,
+        )
+        model = BartForConditionalGeneration(cfg)
+        dec = model.get_decoder()
+
+        assert dec is model.model.decoder, "Seq2seq get_decoder() should return the decoder submodule"
+
+    def test_base_model_returns_self(self):
+        """Test that base transformer models (no decoder/model attributes) return self."""
+        cfg = MistralConfig(
+            vocab_size=128,
+            hidden_size=32,
+            intermediate_size=64,
+            num_hidden_layers=2,
+            num_attention_heads=4,
+        )
+        base_model = MistralModel(cfg)
+        dec = base_model.get_decoder()
+
+        assert dec is base_model, f"Base model get_decoder() should return self, got {type(dec)}"
+
+    def test_explicit_decoder_attribute_opt(self):
+        """Test models with explicit decoder attribute (OPT style)."""
+        cfg = OPTConfig(
+            vocab_size=128,
+            hidden_size=32,
+            ffn_dim=64,
+            num_hidden_layers=2,
+            num_attention_heads=4,
+            max_position_embeddings=512,
+        )
+        model = OPTForCausalLM(cfg)
+        dec = model.get_decoder()
+
+        assert dec is model.model.decoder, f"OPT get_decoder() should return model.decoder, got {type(dec)}"
+
+    def test_explicit_decoder_attribute_t5(self):
+        """Test encoder-decoder models with explicit decoder attribute."""
+        cfg = T5Config(
+            vocab_size=128,
+            d_model=32,
+            d_ff=64,
+            num_layers=2,
+            num_heads=4,
+        )
+        model = T5ForConditionalGeneration(cfg)
+        dec = model.get_decoder()
+
+        assert dec is model.decoder, f"T5 get_decoder() should return decoder attribute, got {type(dec)}"
+
+    def test_same_type_recursion_prevention(self):
+        """Test that same-type recursion is prevented (see issue #40815)."""
+        cfg = MistralConfig(
+            vocab_size=128,
+            hidden_size=32,
+            intermediate_size=64,
+            num_hidden_layers=2,
+            num_attention_heads=4,
+        )
+        model = MistralForCausalLM(cfg)
+
+        assert type(model) is not type(model.model), "Types should be different to prevent recursion"
+
+        dec = model.get_decoder()
+        assert dec is model.model, f"Should return model.model without infinite recursion, got {type(dec)}"
+
+        inner_dec = model.model.get_decoder()
+        assert inner_dec is model.model, f"Inner model should return itself, got {type(inner_dec)}"
+
+    def test_nested_wrapper_recursion(self):
+        """Test models that don't have model/decoder attributes return self."""
+        cfg = GPT2Config(
+            vocab_size=128,
+            n_embd=32,
+            n_layer=2,
+            n_head=4,
+            n_positions=512,
+        )
+        model = GPT2LMHeadModel(cfg)
+        dec = model.get_decoder()
+
+        assert dec is model, f"GPT2 get_decoder() should return self (fallback), got {type(dec)}"
+
+    def test_model_without_get_decoder(self):
+        """Test edge case where model has model attribute but no get_decoder method."""
+
+        class MockInnerModel:
+            """Mock model without get_decoder method."""
+
+            pass
+
+        class MockWrapperModel:
+            """Mock wrapper with model attribute but inner has no get_decoder."""
+
+            def __init__(self):
+                self.model = MockInnerModel()
+
+            def get_decoder(self):
+                if hasattr(self, "decoder"):
+                    return self.decoder
+                if hasattr(self, "model"):
+                    inner = self.model
+                    if hasattr(inner, "get_decoder") and type(inner) is not type(self):
+                        return inner.get_decoder()
+                    return inner
+                return self
+
+        wrapper = MockWrapperModel()
+        dec = wrapper.get_decoder()
+
+        assert dec is wrapper.model, f"Should return inner model when no get_decoder, got {type(dec)}"
+
+    def test_vision_language_model(self):
+        """Test vision-language models like LLaVA that delegate to language_model."""
+        text_config = MistralConfig(
+            vocab_size=128,
+            hidden_size=32,
+            intermediate_size=64,
+            num_hidden_layers=2,
+            num_attention_heads=4,
+        )
+
+        vision_config = {
+            "hidden_size": 32,
+            "intermediate_size": 64,
+            "num_hidden_layers": 2,
+            "num_attention_heads": 4,
+            "num_channels": 3,
+            "image_size": 224,
+            "patch_size": 16,
+        }
+
+        cfg = LlavaConfig(
+            text_config=text_config.to_dict(),
+            vision_config=vision_config,
+            vocab_size=128,
+        )
+
+        model = LlavaForConditionalGeneration(cfg)
+        dec = model.get_decoder()
+
+        assert dec is model.language_model, f"LLaVA get_decoder() should return language_model, got {type(dec)}"
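
Assuming a development checkout of transformers with the test dependencies installed, the new cases can be run on their own with a standard pytest selector, for example:

    python -m pytest tests/utils/test_modeling_utils.py::TestGetDecoder -v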
