diff --git a/vllm/model_executor/layers/rejection_sampler.py b/vllm/model_executor/layers/rejection_sampler.py index f173cbde03f4..9d6c3797c62f 100644 --- a/vllm/model_executor/layers/rejection_sampler.py +++ b/vllm/model_executor/layers/rejection_sampler.py @@ -9,6 +9,7 @@ from vllm.logger import init_logger from vllm.model_executor.layers.spec_decode_base_sampler import ( SpecDecodeStochasticBaseSampler) +from vllm.platforms import current_platform logger = init_logger(__name__) @@ -368,7 +369,7 @@ def _smallest_positive_value(self) -> float: # Note that we always sample with replacement. # probs will be modified in place, but this is fine, as we pass # in a copy already. -@torch.compile(dynamic=True) +@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend) def _multinomial( probs: torch.Tensor, num_samples: int, diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py index 30548e656c55..65920aa61ba1 100644 --- a/vllm/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm/model_executor/layers/vocab_parallel_embedding.py @@ -133,7 +133,7 @@ def __post_init__(self): assert self.num_added_elements <= self.num_added_elements_padded -@torch.compile(dynamic=True) +@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend) def get_masked_input_and_mask( input_: torch.Tensor, org_vocab_start_index: int, org_vocab_end_index: int, num_org_vocab_padding: int, diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py index d22d1f317146..8d61ece28941 100644 --- a/vllm/model_executor/models/commandr.py +++ b/vllm/model_executor/models/commandr.py @@ -45,6 +45,7 @@ row_parallel_weight_loader) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs +from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP @@ -53,7 +54,7 @@ maybe_prefix) -@torch.compile +@torch.compile(backend=current_platform.simple_compile_backend) def layer_norm_func(hidden_states, weight, variance_epsilon): input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) diff --git a/vllm/model_executor/models/phi3_small.py b/vllm/model_executor/models/phi3_small.py index da7e4cdbc694..f47676b934e4 100644 --- a/vllm/model_executor/models/phi3_small.py +++ b/vllm/model_executor/models/phi3_small.py @@ -20,6 +20,7 @@ DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP @@ -54,12 +55,12 @@ def weight_loader(self, param: torch.nn.Parameter, return load_column_parallel_weight(param, loaded_weight) -@torch.compile(dynamic=True) +@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend) def quick_gelu(x): return x * torch.sigmoid(1.702 * x) -@torch.compile(dynamic=True) +@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend) def gegelu(input, limit: Optional[float] = None): a_gelu, a_linear = input[..., ::2], input[..., 1::2] if limit is not None: diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 01d753408e6d..fe398801c5dd 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -82,6 +82,12 @@ class Platform: # check https://github.com/pytorch/pytorch/blob/313dac6c1ca0fa0cde32477509cce32089f8532a/torchgen/model.py#L134 # noqa # use "CPU" as a fallback for platforms not registered in PyTorch dispatch_key: str = "CPU" + # The torch.compile backend for compiling simple and + # standalone functions. The default value is "inductor" to keep + # the same behavior as PyTorch. + # NOTE: for the forward part of the model, vLLM has another separate + # compilation strategy. + simple_compile_backend: str = "inductor" supported_quantization: list[str] = [] def is_cuda(self) -> bool: