Skip to content

Commit 7f60594

Browse files
author
sanchit-gandhi
committed
add test
1 parent eaa5afe commit 7f60594

File tree

1 file changed

+25
-0
lines changed

1 file changed

+25
-0
lines changed

tests/models/whisper/test_modeling_whisper.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1820,6 +1820,31 @@ def test_custom_4d_attention_mask(self):
18201820
normalized_1 = torch.nn.functional.softmax(out_shared_prefix_last_tokens)
18211821
torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4)
18221822

1823+
def test_generate_output_type(self):
1824+
for model_class in self.all_generative_model_classes:
1825+
config, inputs = self.model_tester.prepare_config_and_inputs()
1826+
model = model_class(config).to(torch_device).eval()
1827+
1828+
# short-form generation without fallback
1829+
pred_ids = model.generate(**inputs)
1830+
assert isinstance(pred_ids, torch.Tensor)
1831+
1832+
# short-form generation with fallback
1833+
pred_ids = model.generate(**inputs, logprob_threshold=-1.0, temperature=[0.0, 0.1])
1834+
assert isinstance(pred_ids, torch.Tensor)
1835+
1836+
# create artificial long-form inputs
1837+
inputs["input_features"] = torch.cat([inputs["input_features"], inputs["input_features"]], dim=-1)
1838+
inputs["attention_mask"] = torch.ones(inputs["input_features"].shape[:2], dtype=torch.int, device=inputs["input_features"].device)
1839+
model.generation_config.no_timestamps_token_id = model.generation_config.decoder_start_token_id
1840+
1841+
# long-form generation without fallback
1842+
pred_ids = model.generate(**inputs)
1843+
assert isinstance(pred_ids, torch.Tensor)
1844+
1845+
# short-form generation with fallback
1846+
pred_ids = model.generate(**inputs, logprob_threshold=-1.0, temperature=[0.0, 0.1])
1847+
assert isinstance(pred_ids, torch.Tensor)
18231848

18241849
@require_torch
18251850
@require_torchaudio

0 commit comments

Comments
 (0)