diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index 886295ee895c..0b0f521672b0 100644
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -222,8 +222,7 @@
     Fp8LinearGenericOp, current_platform_fp8_dtype, is_fp8)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     scaled_quantize)
-from vllm.model_executor.layers.rotary_embedding import (
-    DeepseekScalingRotaryEmbedding, RotaryEmbedding)
+from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
 from vllm.utils import cdiv, round_down
 
 try:
@@ -626,9 +625,12 @@ def __init__(
         self.qk_head_dim = qk_head_dim
         self.v_head_dim = v_head_dim
 
-        self.rotary_emb = rotary_emb
-        self.use_yarn_rope = isinstance(rotary_emb,
-                                        DeepseekScalingRotaryEmbedding)
+        # Hack for V1 for now to avoid torch library overhead (since we are
+        # already inside an attention custom op), pull out the forward
+        # method from the rotary embedding and call it directly
+        # TODO(lucas): we should probably find a cleaner way to do this
+        self.rotary_emb = rotary_emb._forward_method
+
         self.q_proj = q_proj
         self.kv_b_proj = kv_b_proj
         self.o_proj = o_proj
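
Note on the second hunk: the diff stores the rotary embedding's pre-resolved `_forward_method` instead of the module itself, so the MLA code calls the rotary implementation directly rather than going back through the module's call path (and its torch library overhead) while it is already inside an attention custom op. Below is a minimal, self-contained sketch of that pattern; `ToyCustomOp`, `ToyRotaryEmbedding`, and `ToyMLAImpl` are hypothetical stand-ins for illustration, not vLLM classes.

```python
# Hypothetical sketch of the dispatch-bypass pattern used in the diff.
# The names below are illustrative only; they mimic the idea that a custom
# op resolves its platform-specific forward into `_forward_method` once, so
# a caller that is already inside a custom op can bind and call it directly.
import torch


class ToyCustomOp(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Pick the concrete implementation once (CUDA vs. native, etc.).
        self._forward_method = (self.forward_cuda
                                if torch.cuda.is_available()
                                else self.forward_native)

    def forward(self, *args, **kwargs):
        # Normal path: extra indirection on every call (plus, in the real
        # library, the custom-op registration overhead the comment mentions).
        return self._forward_method(*args, **kwargs)


class ToyRotaryEmbedding(ToyCustomOp):
    def forward_native(self, positions, q, k):
        return q, k  # placeholder for the actual rotary math

    def forward_cuda(self, positions, q, k):
        return q, k  # placeholder for the actual rotary math


class ToyMLAImpl:
    def __init__(self, rotary_emb: ToyRotaryEmbedding):
        # Mirrors the diff: bind the already-resolved method directly.
        self.rotary_emb = rotary_emb._forward_method

    def apply_rope(self, positions, q, k):
        # Calls the implementation without re-entering forward()/dispatch.
        return self.rotary_emb(positions, q, k)
```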