@@ -1820,6 +1820,31 @@ def test_custom_4d_attention_mask(self):
         normalized_1 = torch.nn.functional.softmax(out_shared_prefix_last_tokens)
         torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4)
 
+    def test_generate_output_type(self):
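+        # in each of the four cases below, generate is expected to return a plain torch.Tensor of token ids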
+        for model_class in self.all_generative_model_classes:
+            config, inputs = self.model_tester.prepare_config_and_inputs()
+            model = model_class(config).to(torch_device).eval()
+
+            # short-form generation without fallback
+            pred_ids = model.generate(**inputs)
+            assert isinstance(pred_ids, torch.Tensor)
+
+            # short-form generation with fallback
+            pred_ids = model.generate(**inputs, logprob_threshold=-1.0, temperature=[0.0, 0.1])
+            assert isinstance(pred_ids, torch.Tensor)
+
+            # create artificial long-form inputs
+            inputs["input_features"] = torch.cat([inputs["input_features"], inputs["input_features"]], dim=-1)
+            inputs["attention_mask"] = torch.ones(
+                inputs["input_features"].shape[:2], dtype=torch.int, device=inputs["input_features"].device
+            )
+            model.generation_config.no_timestamps_token_id = model.generation_config.decoder_start_token_id
+
+            # long-form generation without fallback
+            pred_ids = model.generate(**inputs)
+            assert isinstance(pred_ids, torch.Tensor)
+
+            # long-form generation with fallback
+            pred_ids = model.generate(**inputs, logprob_threshold=-1.0, temperature=[0.0, 0.1])
+            assert isinstance(pred_ids, torch.Tensor)
 
 @require_torch
 @require_torchaudio