@@ -222,6 +222,7 @@ def __init__(self,
222
222
self ._verify_embedding_mode ()
223
223
self ._verify_quantization ()
224
224
self ._verify_cuda_graph ()
225
+ self ._verify_bnb_config ()
225
226
226
227
def _init_multimodal_config (
227
228
self , limit_mm_per_prompt : Optional [Mapping [str , int ]]
@@ -337,6 +338,28 @@ def _verify_cuda_graph(self) -> None:
337
338
self .max_seq_len_to_capture = min (self .max_seq_len_to_capture ,
338
339
self .max_model_len )
339
340
341
+ def _verify_bnb_config (self ) -> None :
342
+ """
343
+ The current version of bitsandbytes (0.44.0) with 8-bit models does not
344
+ yet support CUDA graph.
345
+ """
346
+ is_bitsandbytes = self .quantization == "bitsandbytes"
347
+ has_quantization_config = (getattr (self .hf_config ,
348
+ "quantization_config" , None )
349
+ is not None )
350
+ is_8bit = (self .hf_config .quantization_config .get (
351
+ "load_in_8bit" , False ) if has_quantization_config else False )
352
+ if all ([
353
+ is_bitsandbytes ,
354
+ has_quantization_config ,
355
+ is_8bit ,
356
+ not self .enforce_eager ,
357
+ ]):
358
+ logger .warning (
359
+ "CUDA graph is not supported on BitAndBytes 8bit yet, "
360
+ "fallback to the eager mode." )
361
+ self .enforce_eager = True
362
+
340
363
def verify_async_output_proc (self , parallel_config , speculative_config ,
341
364
device_config ) -> None :
342
365
if not self .use_async_output_proc :
@@ -401,13 +424,6 @@ def verify_with_parallel_config(
401
424
"Pipeline parallelism is only supported for the following "
402
425
f" architectures: { _PP_SUPPORTED_MODELS } ." )
403
426
404
- # Remove the constraint after the bitsandbytes issue is fixed:
405
- # https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1308
406
- if self .quantization == "bitsandbytes" and self .enforce_eager is False :
407
- logger .warning ("CUDA graph is not supported on BitAndBytes yet, "
408
- "fallback to the eager mode." )
409
- self .enforce_eager = True
410
-
411
427
if pipeline_parallel_size > 1 and self .use_async_output_proc :
412
428
logger .warning ("Async output processor is not supported with "
413
429
"pipeline parallelism currently. Disabling it." )
0 commit comments