from typing import Dict

import torch
import torch.nn as nn

from vllm import envs
from vllm.logger import init_logger
from vllm.platforms import current_platform

logger = init_logger(__name__)

try:
    import flashinfer.sampling
    is_flashinfer_available = True
except ImportError:
    is_flashinfer_available = False


class TopKTopPSampler(nn.Module):

    def __init__(self):
        super().__init__()
        if current_platform.is_cuda():
            if is_flashinfer_available:
                if envs.VLLM_USE_FLASHINFER_SAMPLER is not False:
                    # NOTE(woosuk): The V0 sampler doesn't use FlashInfer for
                    # sampling unless VLLM_USE_FLASHINFER_SAMPLER=1 (i.e., by
                    # default it is unused). For backward compatibility, we set
                    # `VLLM_USE_FLASHINFER_SAMPLER` to None by default and
                    # interpret it differently in the V0 and V1 samplers: in
                    # V0, None means False, while in V1, None means True. This
                    # is why we use the condition
                    # `envs.VLLM_USE_FLASHINFER_SAMPLER is not False` here.
                    logger.info("Using FlashInfer for top-p & top-k sampling.")
                    self.forward = self.forward_cuda
                else:
                    logger.warning(
                        "FlashInfer is available, but it is not enabled. "
                        "Falling back to the PyTorch-native implementation of "
                        "top-p & top-k sampling. For the best performance, "
                        "please set VLLM_USE_FLASHINFER_SAMPLER=1.")
                    self.forward = self.forward_native
            else:
                logger.warning(
                    "FlashInfer is not available. Falling back to the PyTorch-"
                    "native implementation of top-p & top-k sampling. For the "
                    "best performance, please install FlashInfer.")
                self.forward = self.forward_native
        else:
            self.forward = self.forward_native

    def forward_native(
        self,
        logits: torch.Tensor,
        generators: Dict[int, torch.Generator],
        no_top_k: bool,
        k: torch.Tensor,
        no_top_p: bool,
        p: torch.Tensor,
    ) -> torch.Tensor:
        """PyTorch-native implementation of top-k and top-p sampling."""
        logits = apply_top_k_top_p(logits, no_top_k, k, no_top_p, p)
        probs = logits.softmax(dim=-1, dtype=torch.float32)
        return random_sample(probs, generators)

    def forward_cuda(
        self,
        logits: torch.Tensor,
        generators: Dict[int, torch.Generator],
        no_top_k: bool,
        k: torch.Tensor,
        no_top_p: bool,
        p: torch.Tensor,
    ) -> torch.Tensor:
        """More optimized implementation of top-k and top-p sampling."""
        probs = logits.softmax(dim=-1, dtype=torch.float32)
        if no_top_k and no_top_p:
            # We prefer `random_sample` over `flashinfer_sample` when sorting
            # is not needed, because `random_sample` does not require CPU-GPU
            # synchronization while `flashinfer_sample` does.
            return random_sample(probs, generators)
        return flashinfer_sample(probs, no_top_k, k, no_top_p, p, generators)


def apply_top_k_top_p(
    logits: torch.Tensor,
    no_top_k: bool,
    k: torch.Tensor,
    no_top_p: bool,
    p: torch.Tensor,
) -> torch.Tensor:
    """Apply top-k and top-p masks to the logits.

    This function sorts the logits tensor, which can be slow for large batches.
    """
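    # Example with made-up values: for a row with logits [1.0, 3.0, 2.0, 0.5]
    # and k=2, the cutoff is the 2nd-largest logit (2.0), so 1.0 and 0.5 are
    # masked to -inf. With p=0.9, tokens at the start of the ascending sort
    # whose cumulative probability stays <= 1 - p are masked too, while the
    # most likely token is always kept.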
    if no_top_k and no_top_p:
        return logits
    # Sort in ascending order so the last column holds the largest logit.
    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)

    if not no_top_k:
        # Apply top-k.
        top_k_mask = logits_sort.size(1) - k.to(torch.long)
        # Gather the k-th largest logit per row as the cutoff value.
        top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
        top_k_mask = logits_sort < top_k_mask
        logits_sort.masked_fill_(top_k_mask, -float("inf"))

    if not no_top_p:
        # Apply top-p.
        probs_sort = logits_sort.softmax(dim=-1)
        probs_sum = probs_sort.cumsum(dim=-1)
        top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
        # Always keep at least one token (the most likely one).
        top_p_mask[:, -1] = False
        logits_sort.masked_fill_(top_p_mask, -float("inf"))

    # Scatter the masked logits back to the original token order.
    logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort)
    return logits


def random_sample(
    probs: torch.Tensor,
    generators: Dict[int, torch.Generator],
) -> torch.Tensor:
    """Randomly sample from the probabilities.

    We use this function instead of torch.multinomial because torch.multinomial
    causes CPU-GPU synchronization.
    """
    q = torch.empty_like(probs)
    # NOTE(woosuk): To batch-process the requests without their own seeds,
    # which is the common case, we first assume that every request does
    # not have its own seed. Then, we overwrite the values for the requests
    # that have their own seeds.
    if len(generators) != probs.shape[0]:
        q.exponential_()
    if generators:
        # TODO(woosuk): This can be slow because we handle each request
        # one by one. Optimize this.
        for i, generator in generators.items():
            q[i].exponential_(generator=generator)
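    # Dividing the probabilities by i.i.d. Exponential(1) noise and taking the
    # argmax is an exponential-race (Gumbel-max) trick: token i wins with
    # probability probs[i], so no sorting or cumulative sum is needed.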
    return probs.div_(q).argmax(dim=-1).view(-1)


def flashinfer_sample(
    probs: torch.Tensor,
    no_top_k: bool,
    k: torch.Tensor,
    no_top_p: bool,
    p: torch.Tensor,
    generators: Dict[int, torch.Generator],
) -> torch.Tensor:
    """Sample from the probabilities using FlashInfer.

    Statistically, this function is equivalent to the `random_sample` function.
    However, this function is faster because it uses rejection sampling and
    thus avoids sorting the probability tensor.

    NOTE: The outputs of this function do not necessarily match the outputs of
    the `random_sample` function. It only guarantees that the outputs are
    statistically equivalent.

    NOTE: This function includes CPU-GPU synchronization, while `random_sample`
    does not. Call this function at the end of the forward pass to minimize
    the synchronization overhead.
    """
    assert not (no_top_k and no_top_p)
    max_top_k_round = 32
    batch_size = probs.shape[0]
    # One uniform sample per rejection-sampling round per request.
    uniform_samples = torch.empty((max_top_k_round, batch_size),
                                  device=probs.device)
    if len(generators) != batch_size:
        uniform_samples.uniform_()
    if generators:
        for i, generator in generators.items():
            uniform_samples[:, i].uniform_(generator=generator)

    if no_top_k:
        # Top-p only.
        next_token_ids, success = flashinfer.sampling.top_p_sampling_from_probs(
            probs, uniform_samples, p, deterministic=True)
    elif no_top_p:
        # Top-k only.
        next_token_ids, success = flashinfer.sampling.top_k_sampling_from_probs(
            probs, uniform_samples, k, deterministic=True)
    else:
        # Both top-k and top-p.
        next_token_ids, success = (
            flashinfer.sampling.top_k_top_p_sampling_from_probs(
                probs, uniform_samples, k, p, deterministic=True))

    # NOTE: CPU-GPU synchronization happens here.
    if not success.all():
        # Fall back: renormalize the filtered probabilities, then sample.
        if not no_top_k:
            probs = flashinfer.sampling.top_k_renorm_prob(probs, k)
        if not no_top_p:
            probs = flashinfer.sampling.top_p_renorm_prob(probs, p)
        next_token_ids = flashinfer.sampling.sampling_from_probs(
            probs, uniform_samples[0], deterministic=True)
    return next_token_ids.view(-1)
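

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal, CPU-only example
# that calls the PyTorch-native path directly, showing how the per-request
# `k`/`p` tensors and the `generators` dict fit together. The batch size,
# vocabulary size, thresholds, and seed below are made-up illustrative values.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    sampler = TopKTopPSampler()
    batch_size, vocab_size = 4, 128
    logits = torch.randn(batch_size, vocab_size)
    k = torch.full((batch_size, ), 10, dtype=torch.long)  # per-request top-k
    p = torch.full((batch_size, ), 0.9)  # per-request top-p
    generators = {0: torch.Generator().manual_seed(0)}  # only request 0 seeded
    # Use forward_native so the sketch does not require CUDA or FlashInfer.
    token_ids = sampler.forward_native(logits, generators,
                                       no_top_k=False, k=k,
                                       no_top_p=False, p=p)
    print(token_ids)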