 from tests.quantization.utils import is_quant_method_supported

 from ...utils import compare_two_settings, multi_gpu_test
-from ..utils import check_embeddings_close
+from ..utils import check_embeddings_close, check_logprobs_close

 models_4bit_to_test = [
     ("facebook/opt-125m", "quantize opt model inflight"),
     ("intfloat/e5-mistral-7b-instruct", "quantize embedding model inflight"),
 ]

+models_4bit_to_moe_test = [
+    ("allenai/OLMoE-1B-7B-0125-Instruct", "quantize moe model inflight"),
+]
+
 models_pre_qaunt_4bit_to_test = [
     ('PrunaAI/Einstein-v6.1-Llama3-8B-bnb-4bit-smashed',
      'read pre-quantized 4-bit FP4 model'),
@@ -115,6 +119,35 @@ def test_load_pp_4bit_bnb_model(model_name, description) -> None:
     compare_two_settings(model_name, common_args, pp_args)


+@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
+                    reason='bitsandbytes is not supported on this GPU type.')
+@pytest.mark.parametrize("model_name, description", models_4bit_to_moe_test)
+def test_4bit_bnb_moe_model(hf_runner, vllm_runner, example_prompts,
+                            model_name, description) -> None:
+
+    hf_model_kwargs = dict(quantization_config=BitsAndBytesConfig(
+        load_in_4bit=True,
+        bnb_4bit_quant_type="nf4",
+        bnb_4bit_use_double_quant=True,
+    ))
+    with vllm_runner(model_name,
+                     quantization='bitsandbytes',
+                     enforce_eager=False) as llm:
+        vllm_outputs = llm.generate_greedy_logprobs(example_prompts,
+                                                    max_tokens=32,
+                                                    num_logprobs=5)
+
+    with hf_runner(model_name, model_kwargs=hf_model_kwargs) as llm:
+        transformers_outputs = llm.generate_greedy_logprobs_limit(
+            example_prompts, max_tokens=32, num_logprobs=5)
+    check_logprobs_close(
+        outputs_0_lst=transformers_outputs,
+        outputs_1_lst=vllm_outputs,
+        name_0="transformers",
+        name_1="vllm",
+    )
+
+
 @pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
                     reason='bitsandbytes is not supported on this GPU type.')
 @pytest.mark.parametrize("model_name, description",
@@ -182,15 +215,17 @@ def validate_generated_texts(hf_runner,
                              model_name,
                              pre_quant=False,
                              hf_model_kwargs=None,
-                             vllm_tp_size=1):
+                             vllm_tp_size=1,
+                             max_tokens=8):

     # NOTE: run vLLM first, as it requires a clean process
     # when using distributed inference
     with vllm_runner(model_name,
                      quantization=None if pre_quant else 'bitsandbytes',
                      tensor_parallel_size=vllm_tp_size,
                      enforce_eager=False) as llm:
-        vllm_outputs = llm.generate_greedy(prompts, 8)
+
+        vllm_outputs = llm.generate_greedy(prompts, max_tokens)
         vllm_logs = log_generated_texts(prompts, vllm_outputs, "VllmRunner")

     # Clean up the GPU memory for the next test
@@ -202,19 +237,17 @@ def validate_generated_texts(hf_runner,

     # Run with HF runner
     with hf_runner(model_name, model_kwargs=hf_model_kwargs) as llm:
-        hf_outputs = llm.generate_greedy(prompts, 8)
+        hf_outputs = llm.generate_greedy(prompts, max_tokens)
         hf_logs = log_generated_texts(prompts, hf_outputs, "HfRunner")

     # Clean up the GPU memory for the next test
     gc.collect()
     torch.cuda.empty_cache()
-
     # Compare the generated strings
     for hf_log, vllm_log in zip(hf_logs, vllm_logs):
         hf_str = hf_log["generated_text"]
         vllm_str = vllm_log["generated_text"]
         prompt = hf_log["prompt"]
-
         assert hf_str == vllm_str, (f"Model: {model_name}"
                                     f"Mismatch between HF and vLLM outputs:\n"
                                     f"Prompt: {prompt}\n"