
Commit 1997878

Batch invariant torch.compile

Signed-off-by: PaulZhang12 <[email protected]>
1 parent 9466661 commit 1997878

File tree (4 files changed: +80 −9 lines)

  vllm/config/model.py
  vllm/envs.py
  vllm/model_executor/layers/batch_invariant.py
  vllm/model_executor/layers/quantization/fp8.py

vllm/config/model.py — 0 additions & 7 deletions

@@ -20,9 +20,6 @@
 from vllm.config.scheduler import RunnerType
 from vllm.config.utils import assert_hashable, config, getattr_iter
 from vllm.logger import init_logger
-from vllm.model_executor.layers.batch_invariant import (
-    vllm_is_batch_invariant,
-)
 from vllm.platforms import current_platform
 from vllm.transformers_utils.config import (
     ConfigFormat,
@@ -437,10 +434,6 @@ def __post_init__(
         skip_mm_profiling: bool | None,
         video_pruning_rate: float | None,
     ) -> None:
-        # Enable batch invariance settings if requested
-        if vllm_is_batch_invariant():
-            self.enforce_eager = True
-
         # Set the default seed to 0 in V1.
         # NOTE(woosuk): In V0, we set the default seed to None because the
        # driver worker shares the same process as the user process, and thus
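With this removal, batch-invariant mode no longer forces eager execution, so the compiled path stays available. A minimal usage sketch; the VLLM_BATCH_INVARIANT toggle and the LLM constructor arguments are assumptions for illustration, not something this diff itself shows:

# Hedged sketch: run with batch invariance without forcing eager execution.
import os

os.environ["VLLM_BATCH_INVARIANT"] = "1"  # assumed toggle read by vllm_is_batch_invariant()

from vllm import LLM

# After this commit, the config no longer flips enforce_eager to True,
# so torch.compile-based execution remains in play.
llm = LLM(model="facebook/opt-125m", enforce_eager=False)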

vllm/envs.py — 7 additions & 1 deletion

@@ -247,10 +247,16 @@ def maybe_convert_bool(value: str | None) -> bool | None:


 def use_aot_compile() -> bool:
+    from vllm.model_executor.layers.batch_invariant import (
+        vllm_is_batch_invariant,
+    )
     from vllm.utils.torch_utils import is_torch_equal_or_newer

     default_value = "1" if is_torch_equal_or_newer("2.10.0.dev") else "0"
-    return os.environ.get("VLLM_USE_AOT_COMPILE", default_value) == "1"
+    return (
+        not vllm_is_batch_invariant()
+        and os.environ.get("VLLM_USE_AOT_COMPILE", default_value) == "1"
+    )


 def env_with_choices(
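A self-contained sketch of the gate this hunk adds: AOT compile is used only when batch-invariant mode is off and VLLM_USE_AOT_COMPILE (or its torch-version-dependent default) allows it. The two helper functions below are stand-ins so the snippet runs outside vLLM:

# Hedged, standalone sketch of the new gating in use_aot_compile().
import os


def is_torch_equal_or_newer(version: str) -> bool:
    # Stand-in: pretend torch is new enough for illustration.
    return True


def vllm_is_batch_invariant() -> bool:
    # Stand-in mirroring the VLLM_BATCH_INVARIANT toggle (assumed name).
    return os.environ.get("VLLM_BATCH_INVARIANT", "0") == "1"


def use_aot_compile() -> bool:
    default_value = "1" if is_torch_equal_or_newer("2.10.0.dev") else "0"
    return (
        not vllm_is_batch_invariant()
        and os.environ.get("VLLM_USE_AOT_COMPILE", default_value) == "1"
    )


os.environ["VLLM_BATCH_INVARIANT"] = "1"
assert use_aot_compile() is False  # batch invariance switches AOT compile off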

vllm/model_executor/layers/batch_invariant.py — 71 additions & 0 deletions

@@ -11,6 +11,7 @@
 import vllm.envs as envs
 from vllm.logger import init_logger
 from vllm.triton_utils import tl, triton
+from vllm.utils.torch_utils import is_torch_equal_or_newer

 logger = init_logger(__name__)

@@ -676,6 +677,10 @@ def linear_batch_invariant(input, weight, bias=None):
 _batch_invariant_MODE = False
 _batch_invariant_LIB = None
 _original_torch_bmm = None
+_original_fp16_reduction_precision = None
+_original_bf16_reduction_precision = None
+_original_cublas_workspace_cfg = None
+_original_cublaslt_workspace_size = None


 def is_batch_invariant_mode_enabled():
@@ -684,6 +689,8 @@ def is_batch_invariant_mode_enabled():

 def enable_batch_invariant_mode():
     global _batch_invariant_MODE, _batch_invariant_LIB, _original_torch_bmm
+    global _original_fp16_reduction_precision, _original_bf16_reduction_precision
+    global _original_cublas_workspace_cfg, _original_cublaslt_workspace_size
     if _batch_invariant_MODE:
         return

@@ -705,14 +712,75 @@ def enable_batch_invariant_mode():
     _original_torch_bmm = torch.bmm
     torch.bmm = bmm_batch_invariant

+    _original_bf16_reduction_precision = (
+        torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction
+    )
+    _original_fp16_reduction_precision = (
+        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction
+    )
+
+    reduced_precision_val = (
+        (False, False) if is_torch_equal_or_newer("2.10.0.dev") else False
+    )
+    torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = (
+        reduced_precision_val
+    )
+    torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = (
+        reduced_precision_val
+    )
+    torch.backends.cuda.preferred_blas_library(backend="cublaslt")
+
+    if not is_torch_equal_or_newer("2.10.0.dev"):
+        _original_cublas_workspace_cfg = os.environ.get("CUBLAS_WORKSPACE_CONFIG", None)
+        _original_cublaslt_workspace_size = os.environ.get(
+            "CUBLASLT_WORKSPACE_SIZE", None
+        )
+        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
+        os.environ["CUBLASLT_WORKSPACE_SIZE"] = "1"
+

 def disable_batch_invariant_mode():
     global _batch_invariant_MODE, _batch_invariant_LIB, _original_torch_bmm
+    global _original_fp16_reduction_precision, _original_bf16_reduction_precision
+    global _original_cublas_workspace_cfg, _original_cublaslt_workspace_size
+    if not _batch_invariant_MODE:
+        return
+
     if _batch_invariant_LIB is not None:
         _batch_invariant_LIB._destroy()
     if _original_torch_bmm is not None:
         torch.bmm = _original_torch_bmm
         _original_torch_bmm = None
+
+    if _original_bf16_reduction_precision is not None:
+        torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = (
+            _original_bf16_reduction_precision
+        )
+        _original_bf16_reduction_precision = None
+    if _original_fp16_reduction_precision is not None:
+        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = (
+            _original_fp16_reduction_precision
+        )
+        _original_fp16_reduction_precision = None
+
+    torch.backends.cuda.preferred_blas_library(backend="default")
+
+    if not is_torch_equal_or_newer("2.10.0.dev"):
+        # Set cublas envvars to previous results. If previous results are None,
+        # that means the envvars were not set, so we should remove them.
+        if _original_cublas_workspace_cfg:
+            os.environ["CUBLAS_WORKSPACE_CONFIG"] = _original_cublas_workspace_cfg
+        elif "CUBLAS_WORKSPACE_CONFIG" in os.environ:
+            del os.environ["CUBLAS_WORKSPACE_CONFIG"]
+
+        if _original_cublaslt_workspace_size:
+            os.environ["CUBLASLT_WORKSPACE_SIZE"] = _original_cublaslt_workspace_size
+        elif "CUBLASLT_WORKSPACE_SIZE" in os.environ:
+            del os.environ["CUBLASLT_WORKSPACE_SIZE"]
+
+        _original_cublas_workspace_cfg = None
+        _original_cublaslt_workspace_size = None
+
     _batch_invariant_MODE = False
     _batch_invariant_LIB = None

@@ -791,6 +859,9 @@ def override_envs_for_invariance():
     os.environ["NCCL_NTHREADS"] = "1"
     os.environ["NCCL_SOCKET_NTHREADS"] = "1"

+    # torch.compile settings
+    os.environ["VLLM_USE_AOT_COMPILE"] = "0"
+


 def init_batch_invariance():
     # this will hit all the csrc overrides as well
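Both enable_batch_invariant_mode() and disable_batch_invariant_mode() now follow a save-override-restore pattern for process-wide state: the matmul reduced-precision flags, the preferred BLAS backend, and the cuBLAS workspace environment variables. A standalone illustration of that pattern (not vLLM's API) using a context manager around a single environment variable:

# Hedged sketch of the save/override/restore pattern applied to
# CUBLAS_WORKSPACE_CONFIG and CUBLASLT_WORKSPACE_SIZE above.
import os
from contextlib import contextmanager


@contextmanager
def temporary_env(name: str, value: str):
    original = os.environ.get(name)  # None means the variable was not set
    os.environ[name] = value
    try:
        yield
    finally:
        if original is not None:
            os.environ[name] = original  # restore the previous value
        else:
            os.environ.pop(name, None)  # it was unset before, so remove it


with temporary_env("CUBLAS_WORKSPACE_CONFIG", ":16:8"):
    assert os.environ["CUBLAS_WORKSPACE_CONFIG"] == ":16:8"
# After the block, the previous value (or absence) is back in place.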

vllm/model_executor/layers/quantization/fp8.py — 2 additions & 1 deletion

@@ -363,6 +363,7 @@ def __init__(self, quant_config: Fp8Config):
         self.use_marlin = False

         self.use_aiter_and_is_supported = check_aiter_fp8_linear_support()
+        self.use_deep_gemm = is_deep_gemm_supported()

         self.weight_block_size = self.quant_config.weight_block_size
         self.block_quant = self.weight_block_size is not None
@@ -546,7 +547,7 @@ def apply(
         # we will use BF16 dequant when DeepGEMM is not supported.
         if vllm_is_batch_invariant():
             if self.block_quant and should_use_deepgemm_for_fp8_linear(
-                torch.bfloat16, layer.weight, None
+                torch.bfloat16, layer.weight, self.use_deep_gemm
             ):
                 # use group quant consistent with block size across K
                 assert self.act_q_group_shape is not None
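The fp8.py change caches the result of is_deep_gemm_supported() once in __init__ and passes it to should_use_deepgemm_for_fp8_linear() instead of None, so the per-call decision reflects whether DeepGEMM is actually available. A rough sketch of that pattern with illustrative names (the probe and method names below are hypothetical, not vLLM's):

# Hedged sketch: cache a capability probe at construction time and thread it
# into the per-call decision, rather than passing None on every forward.
class Fp8LinearSketch:
    def __init__(self, block_quant: bool) -> None:
        self.block_quant = block_quant
        # Done once, analogous to self.use_deep_gemm = is_deep_gemm_supported()
        self.use_deep_gemm = self._probe_deep_gemm()

    @staticmethod
    def _probe_deep_gemm() -> bool:
        try:
            import deep_gemm  # noqa: F401  # hypothetical availability check
        except ImportError:
            return False
        return True

    def should_use_deep_gemm(self) -> bool:
        # Mirrors the gate in apply(): block-quantized weights AND DeepGEMM support.
        return self.block_quant and self.use_deep_gemm


print(Fp8LinearSketch(block_quant=True).should_use_deep_gemm())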
