2 changes: 1 addition & 1 deletion vllm/v1/sample/metadata.py

@@ -9,7 +9,7 @@
 @dataclass
 class SamplingMetadata:
 
-    temperature: torch.Tensor
+    temperature: Optional[torch.Tensor]
     all_greedy: bool
     all_random: bool
 
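Note on the contract change: `temperature` becomes `Optional` because the sampler never reads it when every request in the batch is greedy. A minimal sketch of how a consumer honors that contract — the names `SamplingMetadataSketch` and `pick_tokens` are illustrative, not vLLM APIs:

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class SamplingMetadataSketch:
    temperature: Optional[torch.Tensor]  # None only when all requests are greedy
    all_greedy: bool
    all_random: bool


def pick_tokens(logits: torch.Tensor, meta: SamplingMetadataSketch) -> torch.Tensor:
    greedy = logits.argmax(dim=-1)
    if meta.all_greedy:
        return greedy  # temperature (None) is never dereferenced on this path
    assert meta.temperature is not None
    probs = torch.softmax(logits / meta.temperature.unsqueeze(1), dim=-1)
    return probs.multinomial(num_samples=1).squeeze(1)
```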
8 changes: 4 additions & 4 deletions vllm/v1/sample/sampler.py

@@ -77,11 +77,8 @@ def apply_temperature(
         logits: torch.Tensor,
         temp: torch.Tensor,
     ) -> torch.Tensor:
-        # Avoid division by zero.
-        temp = torch.where(temp < _SAMPLING_EPS, 1.0, temp)
         # Use in-place division to avoid creating a new tensor.
-        logits.div_(temp.unsqueeze(dim=1))
-        return logits
+        return logits.div_(temp.unsqueeze(dim=1))
 
     def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor:
         return logits.argmax(dim=-1).view(-1)
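The two-line body collapses to one because `Tensor.div_` both mutates in place and returns its receiver; the division-by-zero guard moves out of this hot path and into request setup (the `-1.0` sentinel below). A quick illustrative check, not vLLM code:

```python
import torch

logits = torch.randn(4, 8)
temp = torch.tensor([1.0, 0.5, 2.0, -1.0])  # -1.0 is the greedy sentinel set at add_request time
out = logits.div_(temp.unsqueeze(dim=1))
assert out is logits  # div_ returns the same tensor object; no new allocation
```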
@@ -100,6 +97,8 @@ def sample(
         if sampling_metadata.all_greedy:
             return greedy_sampled
 
+        assert sampling_metadata.temperature is not None
+
         # Apply temperature.
         logits = self.apply_temperature(logits, sampling_metadata.temperature)
 
@@ -122,6 +121,7 @@ def sample(
             sampling_metadata.temperature < _SAMPLING_EPS,
             greedy_sampled,
             random_sampled,
+            out=greedy_sampled,  # Reuse tensor
         )
         return sampled
 
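Passing `out=greedy_sampled` makes `torch.where` write its result into an existing buffer instead of allocating a fresh tensor; aliasing an input as the output is safe here because the op is elementwise. A sketch with illustrative values:

```python
import torch

_SAMPLING_EPS = 1e-5
temperature = torch.tensor([-1.0, 0.7, -1.0, 1.0])
greedy_sampled = torch.tensor([11, 22, 33, 44])
random_sampled = torch.tensor([55, 66, 77, 88])

# Rows with sub-epsilon temperature keep the greedy token; the rest take
# the randomly sampled token. The result is written into greedy_sampled.
sampled = torch.where(
    temperature < _SAMPLING_EPS,
    greedy_sampled,
    random_sampled,
    out=greedy_sampled,
)
assert sampled is greedy_sampled
print(sampled)  # tensor([11, 66, 33, 88])
```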
6 changes: 4 additions & 2 deletions vllm/v1/utils.py

@@ -191,11 +191,13 @@ def bind_kv_cache(
 
 
 def copy_slice(from_tensor: torch.Tensor, to_tensor: torch.Tensor,
-               length: int) -> None:
+               length: int) -> torch.Tensor:
     """
     Copy the first length elements of a tensor into another tensor in a
     non-blocking manner.
 
     Used to copy pinned CPU tensor data to pre-allocated GPU tensors.
 
+    Returns the sliced target tensor.
     """
-    to_tensor[:length].copy_(from_tensor[:length], non_blocking=True)
+    return to_tensor[:length].copy_(from_tensor[:length], non_blocking=True)
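This works because `Tensor.copy_` returns its destination — here the view `to_tensor[:length]` — so callers get the sliced target in one expression. A self-contained sketch (CPU tensors stand in for the pinned-CPU/GPU pair; `non_blocking` is simply ignored for host-to-host copies):

```python
import torch

def copy_slice(from_tensor: torch.Tensor, to_tensor: torch.Tensor,
               length: int) -> torch.Tensor:
    return to_tensor[:length].copy_(from_tensor[:length], non_blocking=True)

cpu_buf = torch.arange(8, dtype=torch.float32)  # stands in for pinned CPU data
gpu_buf = torch.zeros(8)                        # stands in for the GPU tensor
view = copy_slice(cpu_buf, gpu_buf, 5)
assert view.shape == (5,) and torch.equal(view, cpu_buf[:5])
```

Returning the view rather than `None` is what lets `_make_sampling_metadata` below bind `temperature` in a single expression.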
12 changes: 9 additions & 3 deletions vllm/v1/worker/gpu_input_batch.py

@@ -242,10 +242,12 @@ def add_request(
         self.block_table.add_row(req_index, request.block_ids)
 
         sampling_params = request.sampling_params
-        self.temperature_cpu[req_index] = sampling_params.temperature
         if sampling_params.sampling_type == SamplingType.GREEDY:
+            # Avoid later division by zero.
+            self.temperature_cpu[req_index] = -1.0
             self.greedy_reqs.add(req_id)
         else:
+            self.temperature_cpu[req_index] = sampling_params.temperature
            self.random_reqs.add(req_id)
 
         self.top_p_cpu[req_index] = sampling_params.top_p
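The `-1.0` sentinel is safe even in mixed batches: greedy rows do get their logits divided by `-1.0`, but `greedy_sampled` is computed from the unscaled logits first, and the final `torch.where(temperature < _SAMPLING_EPS, ...)` selects the greedy token for exactly those rows, so the sign-flipped logits never reach their output. An illustrative walk-through (assuming, as the sampler diff shows, greedy sampling runs before temperature scaling):

```python
import torch

_SAMPLING_EPS = 1e-5
logits = torch.tensor([[1.0, 3.0, 2.0],    # greedy request, sentinel temp
                       [1.0, 3.0, 2.0]])   # random request, temp=0.5
temperature = torch.tensor([-1.0, 0.5])

greedy_sampled = logits.argmax(dim=-1)       # taken before temperature scaling
logits.div_(temperature.unsqueeze(dim=1))    # row 0 is now sign-flipped
probs = torch.softmax(logits, dim=-1)
random_sampled = probs.multinomial(num_samples=1).squeeze(1)
sampled = torch.where(temperature < _SAMPLING_EPS, greedy_sampled, random_sampled)
assert sampled[0] == 1  # greedy row is unaffected by the bogus division
```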
@@ -410,7 +412,11 @@ def refresh_sampling_metadata(self):
 
     def _make_sampling_metadata(self) -> SamplingMetadata:
         num_reqs = self.num_reqs
-        copy_slice(self.temperature_cpu_tensor, self.temperature, num_reqs)
+        if not self.all_greedy:
+            temperature = copy_slice(self.temperature_cpu_tensor,
+                                     self.temperature, num_reqs)
+        else:
+            temperature = None
         if not self.no_top_p:
             copy_slice(self.top_p_cpu_tensor, self.top_p, num_reqs)
         if not self.no_top_k:
@@ -437,7 +443,7 @@ def _make_sampling_metadata(self) -> SamplingMetadata:
             prompt_token_ids = None
 
         return SamplingMetadata(
-            temperature=self.temperature[:num_reqs],
+            temperature=temperature,
             all_greedy=self.all_greedy,
             all_random=self.all_random,
             top_p=None if self.no_top_p else self.top_p[:num_reqs],
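Taken together with the sampler's early return, the batch now skips the host-to-device temperature copy entirely in the all-greedy case and propagates `None` instead of a sliced tensor. A condensed sketch of the pattern, with hypothetical names:

```python
import torch

def make_temperature(all_greedy: bool,
                     temperature_cpu: torch.Tensor,
                     temperature_gpu: torch.Tensor,
                     num_reqs: int):
    if all_greedy:
        return None  # sampler early-exits on all_greedy; the tensor is never read
    # copy_slice returns temperature_gpu[:num_reqs] after issuing the async copy
    return temperature_gpu[:num_reqs].copy_(temperature_cpu[:num_reqs],
                                            non_blocking=True)
```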