     AutoModel,
     AutoModelForImageClassification,
     AutoModelForSequenceClassification,
+    BartConfig,
+    BartForConditionalGeneration,
     CLIPTextModelWithProjection,
     DynamicCache,
+    GPT2Config,
+    GPT2LMHeadModel,
+    LlavaConfig,
     LlavaForConditionalGeneration,
+    MistralConfig,
     MistralForCausalLM,
+    OPTConfig,
+    OPTForCausalLM,
     OwlViTForObjectDetection,
     PretrainedConfig,
+    T5Config,
+    T5ForConditionalGeneration,
     is_torch_available,
     logging,
 )
 from transformers.modeling_flash_attention_utils import is_flash_attn_available
+from transformers.models.mistral.modeling_mistral import MistralModel
 from transformers.testing_utils import (
     TOKEN,
     CaptureLogger,
@@ -2871,3 +2882,170 @@ def forward(self, hidden_states, attention_mask): |
             model.save_pretrained(tmpdirname)
             model = MyModel.from_pretrained(tmpdirname)
             self.assertEqual(model.my_layer.some_counter, 42)
+
+
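+# NOTE (assumption, not spelled out in this diff): these tests pin down a generic
+# get_decoder() resolution order, mirrored by MockWrapperModel further down:
+# return self.decoder when that attribute exists; otherwise delegate to
+# self.model, recursing only when the inner module is a different type (the
+# guard against same-type infinite recursion from issue #40815); otherwise
+# fall back to returning self.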
+class TestGetDecoder(unittest.TestCase):
+    def test_causal_lm_get_decoder_returns_underlying_model(self):
+        cfg = MistralConfig(
+            vocab_size=128,
+            hidden_size=32,
+            intermediate_size=64,
+            num_hidden_layers=2,
+            num_attention_heads=4,
+        )
+        model = MistralForCausalLM(cfg)
+        dec = model.get_decoder()
+
+        assert dec is model.model, f"Expected get_decoder() to return model.model, got {type(dec)}"
+
+    def test_seq2seq_get_decoder_still_returns_decoder_module(self):
+        cfg = BartConfig(
+            vocab_size=128,
+            d_model=32,
+            encoder_layers=2,
+            decoder_layers=2,
+            encoder_attention_heads=4,
+            decoder_attention_heads=4,
+            encoder_ffn_dim=64,
+            decoder_ffn_dim=64,
+        )
+        model = BartForConditionalGeneration(cfg)
+        dec = model.get_decoder()
+
+        assert dec is model.model.decoder, "Seq2seq get_decoder() should return the decoder submodule"
+
+    def test_base_model_returns_self(self):
+        """Test that base transformer models (no decoder/model attributes) return self."""
+        cfg = MistralConfig(
+            vocab_size=128,
+            hidden_size=32,
+            intermediate_size=64,
+            num_hidden_layers=2,
+            num_attention_heads=4,
+        )
+        base_model = MistralModel(cfg)
+        dec = base_model.get_decoder()
+
+        assert dec is base_model, f"Base model get_decoder() should return self, got {type(dec)}"
+
+    def test_explicit_decoder_attribute_opt(self):
+        """Test models with an explicit decoder attribute (OPT style)."""
+        cfg = OPTConfig(
+            vocab_size=128,
+            hidden_size=32,
+            ffn_dim=64,
+            num_hidden_layers=2,
+            num_attention_heads=4,
+            max_position_embeddings=512,
+        )
+        model = OPTForCausalLM(cfg)
+        dec = model.get_decoder()
+
+        assert dec is model.model.decoder, f"OPT get_decoder() should return model.model.decoder, got {type(dec)}"
+
+    def test_explicit_decoder_attribute_t5(self):
+        """Test encoder-decoder models with an explicit decoder attribute."""
+        cfg = T5Config(
+            vocab_size=128,
+            d_model=32,
+            d_ff=64,
+            num_layers=2,
+            num_heads=4,
+        )
+        model = T5ForConditionalGeneration(cfg)
+        dec = model.get_decoder()
+
+        assert dec is model.decoder, f"T5 get_decoder() should return the decoder attribute, got {type(dec)}"
+
+    def test_same_type_recursion_prevention(self):
+        """Test that same-type recursion is prevented (see issue #40815)."""
+        cfg = MistralConfig(
+            vocab_size=128,
+            hidden_size=32,
+            intermediate_size=64,
+            num_hidden_layers=2,
+            num_attention_heads=4,
+        )
+        model = MistralForCausalLM(cfg)
+
+        assert type(model) is not type(model.model), "Types should be different to prevent recursion"
+
+        dec = model.get_decoder()
+        assert dec is model.model, f"Should return model.model without infinite recursion, got {type(dec)}"
+
+        inner_dec = model.model.get_decoder()
+        assert inner_dec is model.model, f"Inner model should return itself, got {type(inner_dec)}"
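+        # Without the `type(inner) is not type(self)` guard (mirrored in the
+        # mock below), a wrapper whose `self.model` is the same class could
+        # presumably bounce between get_decoder() calls forever.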
+
+    def test_no_decoder_or_model_attribute_returns_self(self):
+        """Test that models without model/decoder attributes return self."""
+        cfg = GPT2Config(
+            vocab_size=128,
+            n_embd=32,
+            n_layer=2,
+            n_head=4,
+            n_positions=512,
+        )
+        model = GPT2LMHeadModel(cfg)
+        dec = model.get_decoder()
+
+        assert dec is model, f"GPT2 get_decoder() should return self (fallback), got {type(dec)}"
+
+    def test_model_without_get_decoder(self):
+        """Test the edge case where the wrapper has a model attribute but the inner model has no get_decoder."""
+
+        class MockInnerModel:
+            """Mock inner model without a get_decoder method."""
+
+        class MockWrapperModel:
+            """Mock wrapper with a model attribute whose inner model has no get_decoder."""
+
+            def __init__(self):
+                self.model = MockInnerModel()
+
+            def get_decoder(self):
+                # Prefer an explicit decoder attribute when one exists.
+                if hasattr(self, "decoder"):
+                    return self.decoder
+                if hasattr(self, "model"):
+                    inner = self.model
+                    # Only recurse into a different type, guarding against
+                    # same-type infinite recursion (see issue #40815).
+                    if hasattr(inner, "get_decoder") and type(inner) is not type(self):
+                        return inner.get_decoder()
+                    return inner
+                # Fallback: the model is its own decoder.
+                return self
+
+        wrapper = MockWrapperModel()
+        dec = wrapper.get_decoder()
+
+        assert dec is wrapper.model, f"Should return inner model when no get_decoder, got {type(dec)}"
+
+    def test_vision_language_model(self):
+        """Test vision-language models like LLaVA that delegate to language_model."""
+        text_config = MistralConfig(
+            vocab_size=128,
+            hidden_size=32,
+            intermediate_size=64,
+            num_hidden_layers=2,
+            num_attention_heads=4,
+        )
+
+        vision_config = {
+            "hidden_size": 32,
+            "intermediate_size": 64,
+            "num_hidden_layers": 2,
+            "num_attention_heads": 4,
+            "num_channels": 3,
+            "image_size": 224,
+            "patch_size": 16,
+        }
+
+        cfg = LlavaConfig(
+            text_config=text_config.to_dict(),
+            vision_config=vision_config,
+            vocab_size=128,
+        )
+
+        model = LlavaForConditionalGeneration(cfg)
+        dec = model.get_decoder()
+
+        assert dec is model.language_model, f"LLaVA get_decoder() should return language_model, got {type(dec)}"
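+
+    # Usage sketch (assumption, not part of this change): the shared contract
+    # means callers can fetch the decoder uniformly, e.g.
+    #   decoder = model.get_decoder()
+    # behaves consistently across the causal LM, seq2seq, and VLM cases above.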