Commit 371d04d

[V1] Use FlashInfer Sampling Kernel for Top-P & Top-K Sampling (#11394)
Signed-off-by: Woosuk Kwon <[email protected]>
1 parent 0c0c201 commit 371d04d

File tree

6 files changed: +355 -190 lines changed


tests/v1/sample/test_sampler.py

Lines changed: 22 additions & 32 deletions
@@ -68,7 +68,7 @@ def _create_default_sampling_metadata(
         no_top_p=True,
         no_top_k=True,
         generators={},
-        max_num_logprobs=VOCAB_SIZE,
+        max_num_logprobs=0,
         prompt_token_ids=_create_prompt_tokens_tensor(prompt_token_ids,
                                                       vocab_size, device),
         output_token_ids=output_token_ids,
@@ -169,20 +169,14 @@ def test_sampler_min_tokens_penalty(device: str, batch_size: int):
     sampling_metadata.min_tokens = min_tokens
     sampling_metadata.stop_token_ids = stop_token_ids
     sampler = Sampler()
-    sampler_output = sampler(fake_logits, sampling_metadata)
+    logits = sampler.apply_penalties(fake_logits, sampling_metadata)
+    logits = logits.cpu()
     for batch_idx in range(batch_size):
-        for vocab in range(VOCAB_SIZE):
-            # Verify that the logprobs for stop token ids is set
-            # to -inf.
-            logprob_index = torch.where(
-                sampler_output.logprob_token_ids[batch_idx] ==
-                vocab)[0].item()
-            if vocab in stop_token_ids[batch_idx]:
-                assert sampler_output.logprobs[batch_idx][
-                    logprob_index] == -float("inf")
+        for token_id in range(VOCAB_SIZE):
+            if token_id in stop_token_ids[batch_idx]:
+                assert logits[batch_idx][token_id] == -float("inf")
             else:
-                assert sampler_output.logprobs[batch_idx][
-                    logprob_index] != -float("inf")
+                assert logits[batch_idx][token_id] != -float("inf")


 @pytest.mark.parametrize("device", CUDA_DEVICES)
@@ -205,18 +199,14 @@ def test_sampler_presence_penalty(device: str, batch_size: int,
         batch_size, presence_penalty, torch.device(device))
     sampling_metadata.no_penalties = False
     sampler = Sampler()
-    sampler_output = sampler(fake_logits, sampling_metadata)
+    logits = sampler.apply_penalties(fake_logits, sampling_metadata)
+    logits = logits.cpu()
     for batch_idx in range(batch_size):
-        # The logprobs in the SamplerOutput are arranged in descending order.
-        # Since all tokens initially have the same logprobs, the non-penalized
-        # tokens will appear at the beginning, while the penalized tokens
-        # will appear at the end of the list.
-        penalized_token_id = sampler_output.logprob_token_ids[batch_idx][
-            VOCAB_SIZE - 1]
-        penalized_log_prod = sampler_output.logprobs[batch_idx][VOCAB_SIZE - 1]
-        non_penalized_token_id = sampler_output.logprob_token_ids[batch_idx][0]
-        non_penalized_log_prod = sampler_output.logprobs[batch_idx][0]
-        assert non_penalized_log_prod > penalized_log_prod
+        # Since all tokens initially have the same logits, the non-penalized
+        # token ID will be the one with the highest logit value, while the
+        # penalized token ID will be the one with the lowest logit value.
+        non_penalized_token_id = logits[batch_idx].argmax().item()
+        penalized_token_id = logits[batch_idx].argmin().item()
         if presence_penalty > 0:
             # If `presence_penalty` is set to a value greater than 0, it
             # indicates a preference for new tokens over those already
@@ -256,11 +246,11 @@ def test_sampler_frequency_penalty(device: str, batch_size: int,
     sampling_metadata.output_token_ids = output_token_ids
     sampling_metadata.no_penalties = False
     sampler = Sampler()
-    sampler_output = sampler(fake_logits, sampling_metadata)
+    logits = sampler.apply_penalties(fake_logits, sampling_metadata)
+    logits = logits.cpu()
     for batch_idx in range(batch_size):
-        logprobs_token_ids = sampler_output.logprob_token_ids[batch_idx]
-        non_penalized_token_id = logprobs_token_ids[0]
-        penalized_token_id = logprobs_token_ids[VOCAB_SIZE - 1]
+        non_penalized_token_id = logits[batch_idx].argmax().item()
+        penalized_token_id = logits[batch_idx].argmin().item()
         distinct_sorted_token_ids_in_output = \
             sorted_token_ids_in_output[batch_idx]
         most_frequent_token_id = distinct_sorted_token_ids_in_output[
@@ -305,11 +295,11 @@ def test_sampler_repetition_penalty(device: str, batch_size: int,
         batch_size, repetition_penalty, torch.device(device))
     sampling_metadata.no_penalties = False
     sampler = Sampler()
-    sampler_output = sampler(fake_logits, sampling_metadata)
+    logits = sampler.apply_penalties(fake_logits, sampling_metadata)
+    logits = logits.cpu()
     for batch_idx in range(batch_size):
-        logprobs_token_ids = sampler_output.logprob_token_ids[batch_idx]
-        non_penalized_token_id = logprobs_token_ids[0]
-        penalized_token_id = logprobs_token_ids[VOCAB_SIZE - 1]
+        non_penalized_token_id = logits[batch_idx].argmax().item()
+        penalized_token_id = logits[batch_idx].argmin().item()
         prompt_tokens = sampling_metadata.prompt_token_ids[
             batch_idx][:].tolist()
         output_tokens = sampling_metadata.output_token_ids[batch_idx]
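
The updated tests no longer dig through `sampler_output.logprob_token_ids`; they assert directly on the penalized logits via `argmax`/`argmin`. A minimal, self-contained sketch of that assertion pattern on toy values (plain PyTorch, no vLLM imports):

import torch

VOCAB_SIZE = 4
# Toy logits for a batch of 2 requests, all tokens starting at the same value.
logits = torch.full((2, VOCAB_SIZE), 1.0)

# Pretend the penalty lowered token 2 of request 0 and token 1 of request 1.
logits[0, 2] -= 0.5
logits[1, 1] -= 0.5

# Because every token started from the same logit, the penalized token is
# exactly the argmin, and the argmax is guaranteed to be a non-penalized token.
assert logits[0].argmin().item() == 2
assert logits[1].argmin().item() == 1
assert logits[0].argmax().item() != 2
assert logits[1].argmax().item() != 1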

vllm/envs.py

Lines changed: 3 additions & 2 deletions
@@ -30,7 +30,7 @@
     VLLM_LOGGING_CONFIG_PATH: Optional[str] = None
     VLLM_TRACE_FUNCTION: int = 0
     VLLM_ATTENTION_BACKEND: Optional[str] = None
-    VLLM_USE_FLASHINFER_SAMPLER: bool = False
+    VLLM_USE_FLASHINFER_SAMPLER: Optional[bool] = None
     VLLM_USE_FLASHINFER_REJECTION_SAMPLER: bool = False
     VLLM_FLASHINFER_FORCE_TENSOR_CORES: bool = False
     VLLM_PP_LAYER_PARTITION: Optional[str] = None
@@ -277,7 +277,8 @@ def get_default_config_root():

     # If set, vllm will use flashinfer sampler
     "VLLM_USE_FLASHINFER_SAMPLER":
-    lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_SAMPLER", "0"))),
+    lambda: bool(int(os.environ["VLLM_USE_FLASHINFER_SAMPLER"]))
+    if "VLLM_USE_FLASHINFER_SAMPLER" in os.environ else None,

     # If set, vllm will force flashinfer to use tensor cores;
     # otherwise will use heuristic based on model architecture.
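
The new getter makes `VLLM_USE_FLASHINFER_SAMPLER` tri-state: unset maps to `None`, while an explicit "0"/"1" maps to `False`/`True`. As the `TopKTopPSampler` added below explains, V0 interprets `None` as `False` and V1 interprets it as `True`. A small sketch of the same parsing logic in isolation (the helper name here is made up for illustration):

import os
from typing import Optional

def read_flashinfer_sampler_flag() -> Optional[bool]:
    # Unset -> None; "0" -> False; "1" -> True.
    if "VLLM_USE_FLASHINFER_SAMPLER" not in os.environ:
        return None
    return bool(int(os.environ["VLLM_USE_FLASHINFER_SAMPLER"]))

flag = read_flashinfer_sampler_flag()
# V0 keeps its old default (FlashInfer off unless explicitly enabled),
# while the V1 sampler enables FlashInfer unless explicitly disabled.
use_flashinfer_v0 = flag is True
use_flashinfer_v1 = flag is not False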

vllm/v1/sample/ops/__init__.py

Whitespace-only changes.

vllm/v1/sample/ops/penalties.py

Lines changed: 57 additions & 0 deletions
@@ -0,0 +1,57 @@
+from typing import List, Set, Tuple
+
+import torch
+
+from vllm.model_executor.layers.utils import (
+    apply_penalties as _apply_penalties)
+from vllm.utils import is_pin_memory_available, make_tensor_with_pad
+
+
+def apply_min_token_penalties(logits: torch.Tensor,
+                              output_token_ids: List[List[int]],
+                              stop_token_ids: List[Set[int]],
+                              min_tokens: List[int]) -> None:
+    """
+    Applies minimum token penalty by setting the logits of the stop tokens
+    to -inf.
+    """
+    min_tokens_logits_to_penalize: List[Tuple[int, int]] = []
+    for index, min_token in enumerate(min_tokens):
+        if (len(output_token_ids[index]) < min_token):
+            for stop_token_id in stop_token_ids[index]:
+                min_tokens_logits_to_penalize.append((index, stop_token_id))
+    if min_tokens_logits_to_penalize:
+        logits[tuple(zip(*min_tokens_logits_to_penalize))] = -float("inf")
+
+
+def apply_penalties(logits: torch.Tensor, prompt_token_ids: torch.Tensor,
+                    presence_penalties: torch.Tensor,
+                    frequency_penalties: torch.Tensor,
+                    repetition_penalties: torch.Tensor,
+                    output_token_ids: List[List[int]]) -> torch.Tensor:
+    """
+    Applies presence, frequency and repetition penalties to the logits.
+    """
+    _, vocab_size = logits.shape
+    output_tokens_t = _convert_to_tensors(output_token_ids, vocab_size,
+                                          logits.device)
+    return _apply_penalties(logits, prompt_token_ids, output_tokens_t,
+                            presence_penalties, frequency_penalties,
+                            repetition_penalties)
+
+
+def _convert_to_tensors(output_token_ids: List[List[int]], vocab_size: int,
+                        device: torch.device) -> torch.Tensor:
+    """
+    Convert the different list data structures to tensors.
+    """
+    output_tokens_tensor = make_tensor_with_pad(
+        output_token_ids,
+        # Use the value of vocab_size as a pad since we don't have a
+        # token_id of this value.
+        pad=vocab_size,
+        device="cpu",
+        dtype=torch.int64,
+        pin_memory=is_pin_memory_available(),
+    )
+    return output_tokens_tensor.to(device, non_blocking=True)
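
`apply_min_token_penalties` masks all (request, stop-token) pairs in one shot through paired advanced indexing. A minimal sketch of that indexing trick with toy values (plain PyTorch, independent of vLLM):

import torch

logits = torch.zeros(3, 5)  # 3 requests, vocab of 5 tokens

# (request index, stop token id) pairs, as collected by
# apply_min_token_penalties for requests still below their min_tokens.
pairs = [(0, 4), (2, 1), (2, 3)]

# zip(*pairs) regroups the pairs into row indices and column indices, so one
# advanced-indexing assignment masks every pair at once.
logits[tuple(zip(*pairs))] = -float("inf")

assert logits[0, 4] == -float("inf")
assert logits[2, 1] == -float("inf")
assert logits[2, 3] == -float("inf")
assert torch.isfinite(logits[1]).all()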

vllm/v1/sample/ops/topk_topp_sampler.py

Lines changed: 201 additions & 0 deletions

@@ -0,0 +1,201 @@
+from typing import Dict
+
+import torch
+import torch.nn as nn
+
+from vllm import envs
+from vllm.logger import init_logger
+from vllm.platforms import current_platform
+
+logger = init_logger(__name__)
+
+try:
+    import flashinfer.sampling
+    is_flashinfer_available = True
+except ImportError:
+    is_flashinfer_available = False
+
+
+class TopKTopPSampler(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        if current_platform.is_cuda:
+            if is_flashinfer_available:
+                if envs.VLLM_USE_FLASHINFER_SAMPLER is not False:
+                    # NOTE(woosuk): The V0 sampler doesn't use FlashInfer for
+                    # sampling unless VLLM_USE_FLASHINFER_SAMPLER=1 (i.e., by
+                    # default it is unused). For backward compatibility, we set
+                    # `VLLM_USE_FLASHINFER_SAMPLER` as None by default and
+                    # interpret it differently in V0 and V1 samplers: In V0,
+                    # None means False, while in V1, None means True. This is
+                    # why we use the condition
+                    # `envs.VLLM_USE_FLASHINFER_SAMPLER is not False` here.
+                    logger.info("Using FlashInfer for top-p & top-k sampling.")
+                    self.forward = self.forward_cuda
+                else:
+                    logger.warning(
+                        "FlashInfer is available, but it is not enabled. "
+                        "Falling back to the PyTorch-native implementation of "
+                        "top-p & top-k sampling. For the best performance, "
+                        "please set VLLM_USE_FLASHINFER_SAMPLER=1.")
+                    self.forward = self.forward_native
+            else:
+                logger.warning(
+                    "FlashInfer is not available. Falling back to the PyTorch-"
+                    "native implementation of top-p & top-k sampling. For the "
+                    "best performance, please install FlashInfer.")
+                self.forward = self.forward_native
+        else:
+            self.forward = self.forward_native
+
+    def forward_native(
+        self,
+        logits: torch.Tensor,
+        generators: Dict[int, torch.Generator],
+        no_top_k: bool,
+        k: torch.Tensor,
+        no_top_p: bool,
+        p: torch.Tensor,
+    ) -> torch.Tensor:
+        """PyTorch-native implementation of top-k and top-p sampling."""
+        logits = apply_top_k_top_p(logits, no_top_k, k, no_top_p, p)
+        probs = logits.softmax(dim=-1, dtype=torch.float32)
+        return random_sample(probs, generators)
+
+    def forward_cuda(
+        self,
+        logits: torch.Tensor,
+        generators: Dict[int, torch.Generator],
+        no_top_k: bool,
+        k: torch.Tensor,
+        no_top_p: bool,
+        p: torch.Tensor,
+    ) -> torch.Tensor:
+        """More optimized implementation for top-k and top-p sampling."""
+        probs = logits.softmax(dim=-1, dtype=torch.float32)
+        if no_top_k and no_top_p:
+            # We prefer `random_sample` over `flashinfer_sample` when sorting is
+            # not needed. This is because `random_sample` does not require
+            # CPU-GPU synchronization while `flashinfer_sample` does.
+            return random_sample(probs, generators)
+        return flashinfer_sample(probs, no_top_k, k, no_top_p, p, generators)
+
+
+def apply_top_k_top_p(
+    logits: torch.Tensor,
+    no_top_k: bool,
+    k: torch.Tensor,
+    no_top_p: bool,
+    p: torch.Tensor,
+) -> torch.Tensor:
+    """Apply top-k and top-p masks to the logits.
+
+    This function sorts the logits tensor, which can be slow for large batches.
+    """
+    if no_top_k and no_top_p:
+        return logits
+    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
+
+    if not no_top_k:
+        # Apply top-k.
+        top_k_mask = logits_sort.size(1) - k.to(torch.long)
+        # Get all the top_k values.
+        top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
+        top_k_mask = logits_sort < top_k_mask
+        logits_sort.masked_fill_(top_k_mask, -float("inf"))
+
+    if not no_top_p:
+        # Apply top-p.
+        probs_sort = logits_sort.softmax(dim=-1)
+        probs_sum = probs_sort.cumsum(dim=-1)
+        top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
+        # at least one
+        top_p_mask[:, -1] = False
+        logits_sort.masked_fill_(top_p_mask, -float("inf"))
+
+    # Re-sort the probabilities.
+    logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort)
+    return logits
+
+
+def random_sample(
+    probs: torch.Tensor,
+    generators: Dict[int, torch.Generator],
+) -> torch.Tensor:
+    """Randomly sample from the probabilities.
+
+    We use this function instead of torch.multinomial because torch.multinomial
+    causes CPU-GPU synchronization.
+    """
+    q = torch.empty_like(probs)
+    # NOTE(woosuk): To batch-process the requests without their own seeds,
+    # which is the common case, we first assume that every request does
+    # not have its own seed. Then, we overwrite the values for the requests
+    # that have their own seeds.
+    if len(generators) != probs.shape[0]:
+        q.exponential_()
+    if generators:
+        # TODO(woosuk): This can be slow because we handle each request
+        # one by one. Optimize this.
+        for i, generator in generators.items():
+            q[i].exponential_(generator=generator)
+    return probs.div_(q).argmax(dim=-1).view(-1)
+
+
+def flashinfer_sample(
+    probs: torch.Tensor,
+    no_top_k: bool,
+    k: torch.Tensor,
+    no_top_p: bool,
+    p: torch.Tensor,
+    generators: Dict[int, torch.Generator],
+) -> torch.Tensor:
+    """Sample from the probabilities using FlashInfer.
+
+    Statistically, this function is equivalent to the `random_sample` function.
+    However, this function is faster because it avoids sorting the logits tensor
+    via rejection sampling.
+
+    NOTE: The outputs of this function do not necessarily match the outputs of
+    the `random_sample` function. It only guarantees that the outputs are
+    statistically equivalent.
+
+    NOTE: This function includes CPU-GPU synchronization, while `random_sample`
+    does not. Call this function at the end of the forward pass to minimize
+    the synchronization overhead.
+    """
+    assert not (no_top_k and no_top_p)
+    max_top_k_round = 32
+    batch_size = probs.shape[0]
+    uniform_samples = torch.empty((max_top_k_round, batch_size),
+                                  device=probs.device)
+    if len(generators) != batch_size:
+        uniform_samples.uniform_()
+    if generators:
+        for i, generator in generators.items():
+            uniform_samples[:, i].uniform_(generator=generator)
+
+    if no_top_k:
+        # Top-p only.
+        next_token_ids, success = flashinfer.sampling.top_p_sampling_from_probs(
+            probs, uniform_samples, p, deterministic=True)
+    elif no_top_p:
+        # Top-k only.
+        next_token_ids, success = flashinfer.sampling.top_k_sampling_from_probs(
+            probs, uniform_samples, k, deterministic=True)
+    else:
+        # Both top-k and top-p.
+        next_token_ids, success = (
+            flashinfer.sampling.top_k_top_p_sampling_from_probs(
+                probs, uniform_samples, k, p, deterministic=True))
+
+    # NOTE: CPU-GPU synchronization happens here.
+    if not success.all():
+        if not no_top_k:
+            probs = flashinfer.sampling.top_k_renorm_prob(probs, k)
+        if not no_top_p:
+            probs = flashinfer.sampling.top_p_renorm_prob(probs, p)
+        next_token_ids = flashinfer.sampling.sampling_from_probs(
+            probs, uniform_samples[0], deterministic=True)
+    return next_token_ids.view(-1)
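
The `random_sample` helper above avoids `torch.multinomial` by using the exponential-race (Gumbel-max style) trick: dividing each probability by an i.i.d. Exp(1) draw and taking the argmax yields a sample from the categorical distribution, with no CPU-GPU synchronization. A small self-contained sanity check of that equivalence (toy distribution; the trial count and tolerance are arbitrary choices for illustration):

import torch

torch.manual_seed(0)
probs = torch.tensor([0.1, 0.2, 0.3, 0.4])

# Exponential-race sampling: argmax of probs / Exp(1) noise is a categorical
# sample from `probs` (same trick as `random_sample` above, batched here over
# many trials instead of many requests).
num_trials = 200_000
q = torch.empty(num_trials, 4).exponential_()
samples = (probs / q).argmax(dim=-1)

# Empirical frequencies should closely match the target distribution.
freqs = torch.bincount(samples, minlength=4).float() / num_trials
assert torch.allclose(freqs, probs, atol=0.01)

The same `q` tensor is also what lets per-request seeds be honored cheaply: the batched fill covers the common unseeded case, and only the rows with their own generators are overwritten.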
