@@ -54,7 +54,7 @@ def test_get_field():
5454 ("jason9693/Qwen2.5-1.5B-apeach" , "pooling" , "classify" ),
5555 ("cross-encoder/ms-marco-MiniLM-L-6-v2" , "pooling" , "classify" ),
5656 ("Qwen/Qwen2.5-Math-RM-72B" , "pooling" , "reward" ),
57- ("openai/whisper-small" , "transcription " , "transcription" ),
57+ ("openai/whisper-small" , "generate " , "transcription" ),
5858 ],
5959)
6060def test_auto_task (model_id , expected_runner_type , expected_task ):
@@ -69,7 +69,11 @@ def test_auto_task(model_id, expected_runner_type, expected_task):
6969 )
7070
7171 assert config .runner_type == expected_runner_type
72- assert config .task == expected_task
72+
73+ if config .runner_type == "pooling" :
74+ assert config .task == expected_task
75+ else :
76+ assert expected_task in config .supported_tasks
7377
7478
7579@pytest .mark .parametrize (
@@ -98,11 +102,50 @@ def test_score_task(model_id, expected_runner_type, expected_task):
98102 assert config .task == expected_task
99103
100104
105+ @pytest .mark .parametrize (("model_id" , "expected_runner_type" , "expected_task" ),
106+ [
107+ ("Qwen/Qwen2.5-1.5B-Instruct" , "draft" , "auto" ),
108+ ])
109+ def test_draft_task (model_id , expected_runner_type , expected_task ):
110+ config = ModelConfig (
111+ model_id ,
112+ runner = "draft" ,
113+ tokenizer = model_id ,
114+ seed = 0 ,
115+ dtype = "float16" ,
116+ )
117+
118+ assert config .runner_type == expected_runner_type
119+ assert config .task == expected_task
120+
121+
122+ @pytest .mark .parametrize (
123+ ("model_id" , "expected_runner_type" , "expected_task" ),
124+ [
125+ ("openai/whisper-small" , "generate" , "transcription" ),
126+ ],
127+ )
128+ def test_transcription_task (model_id , expected_runner_type , expected_task ):
129+ config = ModelConfig (
130+ model_id ,
131+ task = "transcription" ,
132+ tokenizer = model_id ,
133+ tokenizer_mode = "auto" ,
134+ trust_remote_code = False ,
135+ seed = 0 ,
136+ dtype = "float16" ,
137+ )
138+
139+ assert config .runner_type == expected_runner_type
140+ assert config .task == expected_task
141+
142+
101143@pytest .mark .parametrize (("model_id" , "bad_task" ), [
102144 ("Qwen/Qwen2.5-Math-RM-72B" , "generate" ),
145+ ("Qwen/Qwen3-0.6B" , "transcription" ),
103146])
104147def test_incorrect_task (model_id , bad_task ):
105- with pytest .raises (ValueError , match = r"does not support the .* task " ):
148+ with pytest .raises (ValueError , match = r"does not support task=.* " ):
106149 ModelConfig (
107150 model_id ,
108151 task = bad_task ,
0 commit comments