
Commit 2c3ba73

mgoin authored and ProExpertProg committed
[Perf] Use FlashInfer RoPE for RotaryEmbedding.forward_cuda when available (#21126)
Signed-off-by: mgoin <[email protected]>
Signed-off-by: Michael Goin <[email protected]>
Co-authored-by: Luka Govedič <[email protected]>
Signed-off-by: yewentao256 <[email protected]>
1 parent bfd3267 commit 2c3ba73

File tree (5 files changed, +78 −14 lines):

- vllm/model_executor/layers/rotary_embedding/base.py
- vllm/model_executor/layers/rotary_embedding/common.py
- vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py
- vllm/model_executor/layers/rotary_embedding/llama4_vision_rope.py
- vllm/model_executor/layers/rotary_embedding/mrope.py

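At a high level, the commit routes RotaryEmbedding.forward_cuda through FlashInfer's fused apply_rope_with_cos_sin_cache_inplace kernel whenever the gating conditions below hold, instead of always calling vLLM's own ops.rotary_embedding CUDA op. For orientation, the following is a minimal reference sketch (not vLLM code; the function name and shapes are illustrative) of what a NeoX-style rotary embedding with a combined cos/sin cache computes for a batch of token positions:

import torch

def rope_neox_reference(positions: torch.Tensor,     # [num_tokens]
                        x: torch.Tensor,              # [num_tokens, num_heads, head_size]
                        cos_sin_cache: torch.Tensor,  # [max_position, head_size]
                        ) -> torch.Tensor:
    # Each cache row is [cos(theta_0..theta_{d/2-1}) | sin(theta_0..theta_{d/2-1})].
    cos, sin = cos_sin_cache[positions].chunk(2, dim=-1)
    cos = cos.unsqueeze(-2)  # broadcast over the head dimension
    sin = sin.unsqueeze(-2)
    # NeoX style rotates (first half, second half) pairs rather than interleaved pairs.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)

The FlashInfer kernel applies this rotation to query and key in place, mirroring the in-place contract of vLLM's existing ops.rotary_embedding.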

vllm/model_executor/layers/rotary_embedding/base.py

Lines changed: 28 additions & 10 deletions
@@ -6,6 +6,8 @@
 import torch
 
 from vllm.model_executor.custom_op import CustomOp
+from vllm.platforms import current_platform
+from vllm.utils.flashinfer import has_flashinfer
 
 from .common import apply_rotary_emb_torch
 
@@ -30,9 +32,17 @@ def __init__(
         self.base = base
         self.is_neox_style = is_neox_style
         self.dtype = dtype
+        # Flashinfer only supports head_size=64, 128, 256, 512.
+        # https://github.com/flashinfer-ai/flashinfer/blob/ebfd655efe830048dba5d582aaa61d61d1cf9a87/include/flashinfer/utils.cuh#L174-L202
+        self.use_flashinfer = (self.enabled()
+                               and dtype in (torch.float16, torch.bfloat16)
+                               and current_platform.is_cuda()
+                               and has_flashinfer()
+                               and self.head_size in [64, 128, 256, 512])
 
         cache = self._compute_cos_sin_cache()
-        cache = cache.to(dtype)
+        if not self.use_flashinfer:
+            cache = cache.to(dtype)
         self.cos_sin_cache: torch.Tensor
         self.register_buffer("cos_sin_cache", cache, persistent=False)
 
@@ -57,6 +67,14 @@ def _compute_cos_sin_cache(self) -> torch.Tensor:
         cache = torch.cat((cos, sin), dim=-1)
         return cache
 
+    def _match_cos_sin_cache_dtype(self, query: torch.Tensor) -> None:
+        # __setattr__ in nn.Module (called by `self.cos_sin_cache = ...`)
+        # is expensive, so avoid calling it if possible
+        if self.cos_sin_cache.device != query.device or \
+                self.cos_sin_cache.dtype != query.dtype:
+            self.cos_sin_cache = self.cos_sin_cache.to(query.device,
+                                                       dtype=query.dtype)
+
     def forward_native(
         self,
         positions: torch.Tensor,
@@ -94,15 +112,16 @@ def forward_cuda(
         query: torch.Tensor,
         key: Optional[torch.Tensor] = None,
     ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
-        from vllm import _custom_ops as ops
+        if self.use_flashinfer:
+            torch.ops.vllm.flashinfer_rotary_embedding(positions, query, key,
+                                                       self.head_size,
+                                                       self.cos_sin_cache,
+                                                       self.is_neox_style)
+            return query, key
 
-        # __setattr__ in nn.Module (called by `self.cos_sin_cache = ...`)
-        # is expensive, so avoid calling it if possible
-        if self.cos_sin_cache.device != query.device or \
-                self.cos_sin_cache.dtype != query.dtype:
-            self.cos_sin_cache = self.cos_sin_cache.to(query.device,
-                                                       dtype=query.dtype)
+        from vllm import _custom_ops as ops
 
+        self._match_cos_sin_cache_dtype(query)
         # ops.rotary_embedding() is an in-place operation
         # that updates the query and key tensors.
         ops.rotary_embedding(positions, query, key, self.head_size,
@@ -117,8 +136,7 @@ def forward_xpu(
     ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        from vllm._ipex_ops import ipex_ops as ops
 
-        self.cos_sin_cache = self.cos_sin_cache.to(positions.device,
-                                                   dtype=query.dtype)
+        self._match_cos_sin_cache_dtype(query)
         # ops.rotary_embedding() is an in-place operation
         # that updates the query and key tensors.
         if key is None:
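The FlashInfer fast path is only taken when every condition of the use_flashinfer gate holds. As a quick way to reason about eligibility, here is a standalone mirror of that gate (the helper name is made up; custom_op_enabled stands in for the CustomOp.enabled() check in the constructor):

import torch
from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer

def can_use_flashinfer_rope(head_size: int,
                            dtype: torch.dtype,
                            custom_op_enabled: bool = True) -> bool:
    # Mirrors the use_flashinfer gate in RotaryEmbedding.__init__ above.
    return (custom_op_enabled
            and dtype in (torch.float16, torch.bfloat16)
            and current_platform.is_cuda()
            and has_flashinfer()                    # FlashInfer importable
            and head_size in (64, 128, 256, 512))   # kernel-supported head sizes

Note also that when the gate passes, the constructor skips the cache.to(dtype) cast, so the cos/sin cache keeps its original (typically float32) precision for the FlashInfer kernel, while _match_cos_sin_cache_dtype lets the fallback paths cast it lazily to the query's device and dtype.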

vllm/model_executor/layers/rotary_embedding/common.py

Lines changed: 46 additions & 0 deletions
@@ -6,6 +6,7 @@
 import torch
 
 from vllm.platforms import current_platform
+from vllm.utils import direct_register_custom_op
 
 if current_platform.is_cuda():
     from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
@@ -103,3 +104,48 @@ def yarn_get_mscale(scale: float = 1) -> float:
     if scale <= 1:
         return 1.0
     return 0.1 * math.log(scale) + 1.0
+
+
+def _flashinfer_rotary_embedding(
+    positions: torch.Tensor,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    head_size: int,
+    cos_sin_cache: torch.Tensor,
+    is_neox: bool,
+) -> None:
+    """Custom op wrapper for flashinfer's rotary embedding.
+
+    This is an in-place operation that modifies query and key tensors directly.
+    """
+    from flashinfer.rope import apply_rope_with_cos_sin_cache_inplace
+
+    apply_rope_with_cos_sin_cache_inplace(
+        positions=positions,
+        query=query,
+        key=key,
+        head_size=head_size,
+        cos_sin_cache=cos_sin_cache,
+        is_neox=is_neox,
+    )
+
+
+def _flashinfer_rotary_embedding_fake(
+    positions: torch.Tensor,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    head_size: int,
+    cos_sin_cache: torch.Tensor,
+    is_neox: bool,
+) -> None:
+    return
+
+
+# Register flashinfer rotary embedding custom op
+direct_register_custom_op(
+    op_name="flashinfer_rotary_embedding",
+    op_func=_flashinfer_rotary_embedding,
+    mutates_args=["query", "key"],  # These tensors are modified in-place
+    fake_impl=_flashinfer_rotary_embedding_fake,
+    dispatch_key=current_platform.dispatch_key,
+)
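The wrapper is exposed through direct_register_custom_op so the call in forward_cuda goes through torch.ops.vllm.flashinfer_rotary_embedding: mutates_args tells the dispatcher that query and key are written in place, and the no-op fake implementation lets fake-tensor tracing (for example under torch.compile) pass through the op without launching the kernel. A hypothetical invocation might look like the following sketch (shapes, sizes, and the 2D [num_tokens, num_heads * head_size] layout are illustrative assumptions; it needs a CUDA device with FlashInfer installed, and it sets rotary_dim equal to head_size):

import torch
import vllm.model_executor.layers.rotary_embedding.common  # noqa: F401  (module import registers the op)

num_tokens, num_q_heads, num_kv_heads, head_size = 8, 32, 8, 128
max_position, base = 4096, 10000.0

# Mirrors the [cos | sin] cache layout built by RotaryEmbedding._compute_cos_sin_cache.
inv_freq = 1.0 / (base**(torch.arange(0, head_size, 2, dtype=torch.float32) / head_size))
freqs = torch.einsum("i,j->ij", torch.arange(max_position, dtype=torch.float32), inv_freq)
cos_sin_cache = torch.cat((freqs.cos(), freqs.sin()), dim=-1).cuda()

positions = torch.arange(num_tokens, device="cuda")
query = torch.randn(num_tokens, num_q_heads * head_size, device="cuda", dtype=torch.bfloat16)
key = torch.randn(num_tokens, num_kv_heads * head_size, device="cuda", dtype=torch.bfloat16)

# In-place: rotates query/key per position; returns None.
torch.ops.vllm.flashinfer_rotary_embedding(positions, query, key, head_size,
                                           cos_sin_cache, True)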

vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py

Lines changed: 1 addition & 3 deletions
@@ -97,15 +97,13 @@ def forward_native(
     ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
         """PyTorch-native implementation equivalent to forward()."""
         assert key is not None
+        self._match_cos_sin_cache_dtype(query)
         query_rot = query[..., :self.rotary_dim]
         key_rot = key[..., :self.rotary_dim]
         if self.rotary_dim < self.head_size:
             query_pass = query[..., self.rotary_dim:]
             key_pass = key[..., self.rotary_dim:]
 
-        if self.cos_sin_cache.device != positions.device:
-            self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(
-                positions.device)
         cos_sin = self.cos_sin_cache[torch.add(positions, offsets)
                                      if offsets is not None else positions]
         cos, sin = cos_sin.chunk(2, dim=-1)

vllm/model_executor/layers/rotary_embedding/llama4_vision_rope.py

Lines changed: 1 addition & 1 deletion
@@ -59,7 +59,7 @@ def forward_native(  # type: ignore[override]
         key: Optional[torch.Tensor] = None,
     ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
         assert key is not None
-        self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(query.device)
+        self._match_cos_sin_cache_dtype(query)
         query_ = torch.view_as_complex(query.float().reshape(
             *query.shape[:-1], -1, 2))
         key_ = torch.view_as_complex(key.float().reshape(

vllm/model_executor/layers/rotary_embedding/mrope.py

Lines changed: 2 additions & 0 deletions
@@ -245,6 +245,7 @@ def forward_native(
         assert positions.ndim == 1 or positions.ndim == 2
         assert key is not None
 
+        self._match_cos_sin_cache_dtype(query)
         num_tokens = positions.shape[-1]
         cos_sin = self.cos_sin_cache[positions]
         cos, sin = cos_sin.chunk(2, dim=-1)
@@ -293,6 +294,7 @@ def forward_cuda(
         assert positions.ndim == 1 or positions.ndim == 2
         assert key is not None
 
+        self._match_cos_sin_cache_dtype(query)
         num_tokens = positions.shape[-1]
         cos_sin = self.cos_sin_cache[positions]
         cos, sin = cos_sin.chunk(2, dim=-1)
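The remaining files (deepseek_scaling_rope.py, llama4_vision_rope.py, mrope.py) replace their ad-hoc cos_sin_cache.to(...) calls with the shared _match_cos_sin_cache_dtype helper. The point of the helper's guard is that re-assigning a registered buffer goes through nn.Module.__setattr__, which is comparatively expensive, so the cast is only paid when the cache's device or dtype actually differs from the query's. A minimal self-contained sketch of that pattern (illustrative names, not vLLM code):

import torch
from torch import nn

class TinyRoPE(nn.Module):
    def __init__(self, cos_sin_cache: torch.Tensor):
        super().__init__()
        self.register_buffer("cos_sin_cache", cos_sin_cache, persistent=False)

    def _match_cos_sin_cache_dtype(self, query: torch.Tensor) -> None:
        # Skip the (slow) nn.Module.__setattr__ path when nothing changed.
        if (self.cos_sin_cache.device != query.device
                or self.cos_sin_cache.dtype != query.dtype):
            self.cos_sin_cache = self.cos_sin_cache.to(query.device,
                                                       dtype=query.dtype)

Compared with the previous per-file versions, this also makes the check uniform: deepseek_scaling_rope and llama4_vision_rope previously matched only the device, and mrope did not adjust the cache in these paths at all.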
