diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py
index 64c2dac524f2..48cdebee9161 100644
--- a/vllm/model_executor/layers/rotary_embedding.py
+++ b/vllm/model_executor/layers/rotary_embedding.py
@@ -161,8 +161,13 @@ def forward_cuda(
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         from vllm import _custom_ops as ops
 
-        self.cos_sin_cache = self.cos_sin_cache.to(query.device,
-                                                   dtype=query.dtype)
+        # __setattr__ in nn.Module (called by `self.cos_sin_cache = ...`)
+        # is expensive, so avoid calling it if possible
+        if self.cos_sin_cache.device != query.device or \
+                self.cos_sin_cache.dtype != query.dtype:
+            self.cos_sin_cache = self.cos_sin_cache.to(query.device,
+                                                       dtype=query.dtype)
+
         # ops.rotary_embedding()/batched_rotary_embedding()
         # are in-place operations that update the query and key tensors.
         if offsets is not None:
diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index 0b55854de94a..5b9a4b5ca4e5 100644
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -222,8 +222,8 @@
     apply_fp8_linear_generic, current_platform_fp8_dtype, is_fp8)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     scaled_quantize)
-from vllm.model_executor.layers.rotary_embedding import (
-    DeepseekScalingRotaryEmbedding, RotaryEmbedding)
+from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
+from vllm.platforms import current_platform
 from vllm.utils import cdiv, round_down
 
 try:
@@ -627,8 +627,15 @@ def __init__(
         self.v_head_dim = v_head_dim
 
         self.rotary_emb = rotary_emb
-        self.use_yarn_rope = isinstance(rotary_emb,
-                                        DeepseekScalingRotaryEmbedding)
+
+        if current_platform.is_cuda():
+            # Hack for V1 for now to avoid torch library overhead (since we are
+            # already inside an attention custom op), pull out the forward
+            # method from the rotary embedding and call it directly (and avoid
+            # calling forward_native, when we can call forward_cuda)
+            # TODO(lucas): we should probably find a cleaner way to do this
+            self.rotary_emb = rotary_emb.forward_cuda
+
         self.q_proj = q_proj
         self.kv_b_proj = kv_b_proj
         self.o_proj = o_proj
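
A minimal sketch of the pattern behind the first hunk (illustrative only, not part of the patch; the CachedCosSin class, shapes, and method body are invented for the example). nn.Module.__setattr__ intercepts every attribute assignment, so unconditionally re-assigning self.cos_sin_cache on each forward call pays that cost even when .to() is a no-op; the guard skips the assignment unless a device move or dtype cast is actually needed.

import torch
import torch.nn as nn


class CachedCosSin(nn.Module):
    """Toy module using the same cos/sin caching pattern as the patch."""

    def __init__(self, max_positions: int = 8, rot_dim: int = 4):
        super().__init__()
        # Plain tensor attribute; assignments still go through
        # nn.Module.__setattr__.
        self.cos_sin_cache = torch.randn(max_positions, rot_dim)

    def forward(self, query: torch.Tensor) -> torch.Tensor:
        # Only pay the __setattr__ cost when a move/cast is really needed,
        # mirroring the guarded assignment in the patched forward_cuda().
        if (self.cos_sin_cache.device != query.device
                or self.cos_sin_cache.dtype != query.dtype):
            self.cos_sin_cache = self.cos_sin_cache.to(query.device,
                                                       dtype=query.dtype)
        # A real rotary embedding would apply the cached cos/sin here; this
        # toy just returns the query unchanged.
        return query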
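
A sketch of the direct-call trick from the second file (assumptions: the MLAToyImpl class and its apply_rope method are invented for the example; only the import paths come from the patch). Storing the bound forward_cuda method instead of the module means later calls skip nn.Module.__call__ and the custom-op dispatch that would otherwise choose between forward_native and forward_cuda.

import torch

from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.platforms import current_platform


class MLAToyImpl:
    """Toy stand-in for the MLA attention impl; not vLLM's actual class."""

    def __init__(self, rotary_emb: RotaryEmbedding):
        self.rotary_emb = rotary_emb
        if current_platform.is_cuda():
            # Same trick as the patch: keep the bound method so later calls
            # go straight to forward_cuda() without module-call overhead.
            self.rotary_emb = rotary_emb.forward_cuda

    def apply_rope(self, positions: torch.Tensor, q_pe: torch.Tensor,
                   k_pe: torch.Tensor):
        # Works whether self.rotary_emb is the module or the bound method;
        # both accept (positions, query, key) and return the rotated pair.
        return self.rotary_emb(positions, q_pe, k_pe)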