1111import vllm .envs as envs
1212from vllm .logger import init_logger
1313from vllm .triton_utils import tl , triton
14+ from vllm .utils .torch_utils import is_torch_equal_or_newer
1415
1516logger = init_logger (__name__ )
1617
@@ -676,6 +677,10 @@ def linear_batch_invariant(input, weight, bias=None):
# True while batch-invariant mode is active. Both the enable and disable
# functions check this flag first so that repeated calls are no-ops.
_batch_invariant_MODE = False
# Object whose ``_destroy()`` is called when the mode is disabled.
# NOTE(review): presumably the torch library handle that registers the
# batch-invariant overrides -- confirm against enable_batch_invariant_mode.
_batch_invariant_LIB = None
# The original ``torch.bmm``, kept while it is replaced by the
# batch-invariant implementation so it can be put back on disable.
_original_torch_bmm = None
# Values of torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction
# and allow_bf16_reduced_precision_reduction captured on enable; ``None``
# means nothing has been saved.
_original_fp16_reduction_precision = None
_original_bf16_reduction_precision = None
# Values of the CUBLAS_WORKSPACE_CONFIG / CUBLASLT_WORKSPACE_SIZE environment
# variables captured on enable (only on torch older than 2.10.0.dev);
# ``None`` means the variable was not set at that time.
_original_cublas_workspace_cfg = None
_original_cublaslt_workspace_size = None
679684
680685
681686def is_batch_invariant_mode_enabled ():
@@ -684,6 +689,8 @@ def is_batch_invariant_mode_enabled():
684689
685690def enable_batch_invariant_mode ():
686691 global _batch_invariant_MODE , _batch_invariant_LIB , _original_torch_bmm
692+ global _original_fp16_reduction_precision , _original_bf16_reduction_precision
693+ global _original_cublas_workspace_cfg , _original_cublaslt_workspace_size
687694 if _batch_invariant_MODE :
688695 return
689696
@@ -705,14 +712,75 @@ def enable_batch_invariant_mode():
705712 _original_torch_bmm = torch .bmm
706713 torch .bmm = bmm_batch_invariant
707714
715+ _original_bf16_reduction_precision = (
716+ torch .backends .cuda .matmul .allow_bf16_reduced_precision_reduction
717+ )
718+ _original_fp16_reduction_precision = (
719+ torch .backends .cuda .matmul .allow_fp16_reduced_precision_reduction
720+ )
721+
722+ reduced_precision_val = (
723+ (False , False ) if is_torch_equal_or_newer ("2.10.0.dev" ) else False
724+ )
725+ torch .backends .cuda .matmul .allow_fp16_reduced_precision_reduction = (
726+ reduced_precision_val
727+ )
728+ torch .backends .cuda .matmul .allow_bf16_reduced_precision_reduction = (
729+ reduced_precision_val
730+ )
731+ torch .backends .cuda .preferred_blas_library (backend = "cublaslt" )
732+
733+ if not is_torch_equal_or_newer ("2.10.0.dev" ):
734+ _original_cublas_workspace_cfg = os .environ .get ("CUBLAS_WORKSPACE_CONFIG" , None )
735+ _original_cublaslt_workspace_size = os .environ .get (
736+ "CUBLASLT_WORKSPACE_SIZE" , None
737+ )
738+ os .environ ["CUBLAS_WORKSPACE_CONFIG" ] = ":16:8"
739+ os .environ ["CUBLASLT_WORKSPACE_SIZE" ] = "1"
740+
708741
def _restore_env_var(name, original_value):
    """Put environment variable *name* back to *original_value*.

    ``None`` means the variable was not set before it was overridden, so it is
    removed. Any other value -- including an empty string -- is written back.
    """
    if original_value is None:
        os.environ.pop(name, None)
    else:
        os.environ[name] = original_value


def disable_batch_invariant_mode():
    """Undo the process-wide changes made by ``enable_batch_invariant_mode``.

    Does nothing when batch-invariant mode is not active. Otherwise it:

    * destroys the override library and restores the original ``torch.bmm``;
    * restores the saved fp16/bf16 reduced-precision reduction flags;
    * resets the preferred BLAS library to ``"default"``;
    * on torch older than ``2.10.0.dev``, restores the
      ``CUBLAS_WORKSPACE_CONFIG`` and ``CUBLASLT_WORKSPACE_SIZE`` environment
      variables to the values they had before the mode was enabled.
    """
    global _batch_invariant_MODE, _batch_invariant_LIB, _original_torch_bmm
    global _original_fp16_reduction_precision, _original_bf16_reduction_precision
    global _original_cublas_workspace_cfg, _original_cublaslt_workspace_size
    if not _batch_invariant_MODE:
        return

    if _batch_invariant_LIB is not None:
        _batch_invariant_LIB._destroy()
    if _original_torch_bmm is not None:
        torch.bmm = _original_torch_bmm
        _original_torch_bmm = None

    if _original_bf16_reduction_precision is not None:
        torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = (
            _original_bf16_reduction_precision
        )
        _original_bf16_reduction_precision = None
    if _original_fp16_reduction_precision is not None:
        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = (
            _original_fp16_reduction_precision
        )
        _original_fp16_reduction_precision = None

    # NOTE(review): this resets the preference to "default" rather than to
    # whatever was selected before enabling; the previous choice is not saved.
    torch.backends.cuda.preferred_blas_library(backend="default")

    if not is_torch_equal_or_newer("2.10.0.dev"):
        # The env vars are only overridden on torch < 2.10.0.dev, so only
        # restore them there. Compare against None (not truthiness) so that a
        # variable that was originally set to "" is restored, not deleted.
        _restore_env_var("CUBLAS_WORKSPACE_CONFIG", _original_cublas_workspace_cfg)
        _restore_env_var(
            "CUBLASLT_WORKSPACE_SIZE", _original_cublaslt_workspace_size
        )
        _original_cublas_workspace_cfg = None
        _original_cublaslt_workspace_size = None

    _batch_invariant_MODE = False
    _batch_invariant_LIB = None
718786
@@ -791,6 +859,9 @@ def override_envs_for_invariance():
791859 os .environ ["NCCL_NTHREADS" ] = "1"
792860 os .environ ["NCCL_SOCKET_NTHREADS" ] = "1"
793861
862+ # torch.compile settings
863+ os .environ ["VLLM_USE_AOT_COMPILE" ] = "0"
864+
794865
795866def init_batch_invariance ():
796867 # this will hit all the csrc overrides as well
0 commit comments