11 changes: 10 additions & 1 deletion tests/basic_correctness/test_cumem.py
@@ -142,7 +142,16 @@ def test_end_to_end(model: str, use_v1: bool):
used_bytes = total - free_gpu_bytes_after_sleep - used_bytes_baseline
# now the memory usage is mostly cudagraph memory pool,
# and it should be less than the model weights (1B model, 2GiB weights)
assert used_bytes < 2 * GiB_bytes

# NOTE: In V1, the memory buffer for logits (max_num_reqs x vocab_size)
# is captured but cannot be released from PyTorch due to a known bug,
Comment on lines +146 to +147

Collaborator:
Could you please elaborate on this?

ywang96 (Member, Author), Mar 10, 2025:
See the discussion here: https://vllm-dev.slack.com/archives/C087WBWC5AQ/p1741398800083509?thread_ts=1741386694.452939&cid=C087WBWC5AQ - the TL;DR is that empty_cache cannot be called when we turn on sleep mode.

Collaborator:
Hmm... Why do we need empty_cache?

ywang96 (Member, Author):
The difference here is that we never (in either V0 or V1) warmed up the sampler, so the memory fragmentation issue was always there, just not as pronounced in V0 (since the default batch size is 256).

Now we're adding the sampler warmup in V1, but when we call sleep(), the memory buffer for logits can't be cleared from the PyTorch caching allocator (the bug mentioned in this comment), so the memory usage will be a lot higher.

Collaborator:
@ywang96 Thanks for the explanation. Just want to double check: we don't want to call empty_cache anyway, because we intentionally reserve the (max_num_reqs x vocab_size)-sized tensor in the PyTorch allocator, right?

ywang96 (Member, Author), Mar 10, 2025:
That is correct, though I do think there should be a better and cleaner fix for this to work with sleep mode in the long term. We should probably free the memory when sleep is called, then warm up the sampler again within wake_up, but this is currently blocked since we can't free the memory anyway.

Collaborator:
Hmm... How is the logits tensor different from other intermediate activation tensors?

Collaborator:
I don't understand why this specific tensor becomes a problem.

ywang96 (Member, Author):
Because dummy_run doesn't include/activate the sampler tensors; this is why we made dummy_sampler_run in the first place.

# therefore high memory usage after `llm.sleep` is called is expected.
# FIXME(youkaichao & ywang96): Fix memory buffer issue with sleep mode
# in V1.
if use_v1:
assert used_bytes < 7 * GiB_bytes
else:
assert used_bytes < 2 * GiB_bytes

llm.wake_up()
output2 = llm.generate(prompt, sampling_params)
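
As context for the relaxed 7 GiB threshold above, a rough back-of-the-envelope sketch in Python of the preallocated logits buffer that stays resident after llm.sleep() in V1. The values below are assumptions for illustration only, not taken from this PR or the test model:

max_num_reqs = 1024   # assumed scheduler max_num_seqs used at sampler warm-up
vocab_size = 128256   # assumed vocabulary size
dtype_bytes = 4       # assumed fp32 logits
logits_gib = max_num_reqs * vocab_size * dtype_bytes / 2**30
print(f"~{logits_gib:.2f} GiB held by the (max_num_reqs x vocab_size) buffer")
# Per the discussion above, this buffer (plus other sampling tensors and
# caching-allocator fragmentation) cannot be released while sleep mode is on,
# which is why the V1 bound is looser than the 2 GiB V0 bound.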
5 changes: 5 additions & 0 deletions vllm/config.py
@@ -3525,6 +3525,11 @@ def _set_cudagraph_sizes(self):
not self.model_config.enforce_eager:
batch_size_capture_list = [1, 2, 4
] + [i for i in range(8, 513, 8)]
max_num_tokens = self.scheduler_config.max_num_batched_tokens
batch_size_capture_list = [
size for size in batch_size_capture_list
if size <= max_num_tokens
]

self.compilation_config.init_with_cudagraph_sizes(
batch_size_capture_list)
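
For illustration, a minimal standalone sketch of the filtering added above, assuming a max_num_batched_tokens of 64 (an arbitrary value, not from this PR): cudagraph capture sizes larger than the token budget are dropped, since a batch of that size could never be scheduled.

batch_size_capture_list = [1, 2, 4] + [i for i in range(8, 513, 8)]
max_num_tokens = 64  # stands in for scheduler_config.max_num_batched_tokens
batch_size_capture_list = [
    size for size in batch_size_capture_list if size <= max_num_tokens
]
print(batch_size_capture_list)  # [1, 2, 4, 8, 16, 24, 32, 40, 48, 56, 64]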
185 changes: 98 additions & 87 deletions vllm/v1/worker/gpu_model_runner.py
@@ -1202,41 +1202,98 @@ def _dummy_run(
self,
num_tokens: int,
) -> torch.Tensor:
model = self.model
if self.is_multimodal_model:
input_ids = None
inputs_embeds = self.inputs_embeds[:num_tokens]
else:
input_ids = self.input_ids[:num_tokens]
inputs_embeds = None
if self.uses_mrope:
positions = self.mrope_positions[:, :num_tokens]
else:
positions = self.positions[:num_tokens]

if get_pp_group().is_first_rank:
intermediate_tensors = None
else:
if self.intermediate_tensors is None:
self.intermediate_tensors = (
self.model.make_empty_intermediate_tensors(
batch_size=self.max_num_tokens,
dtype=self.model_config.dtype,
device=self.device))
intermediate_tensors = IntermediateTensors({
k: v[:num_tokens]
for k, v in self.intermediate_tensors.items()
})
# Set num_scheduled_tokens based on num_tokens and max_num_seqs
# for a dummy run with LoRA so that the num_reqs collectively
# have num_tokens in total.
assert num_tokens <= self.scheduler_config.max_num_batched_tokens
max_num_reqs = self.scheduler_config.max_num_seqs
num_reqs = max_num_reqs if num_tokens >= max_num_reqs else num_tokens
min_tokens_per_req = num_tokens // num_reqs
num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
num_scheduled_tokens_list[-1] += num_tokens % num_reqs
assert sum(num_scheduled_tokens_list) == num_tokens
assert len(num_scheduled_tokens_list) == num_reqs
num_scheduled_tokens = np.array(num_scheduled_tokens_list,
dtype=np.int32)

with set_forward_context(None, self.vllm_config,
num_tokens=num_tokens):
hidden_states = model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
)
return hidden_states
with self.maybe_dummy_run_with_lora(self.lora_config,
num_scheduled_tokens):
model = self.model
if self.is_multimodal_model:
input_ids = None
inputs_embeds = self.inputs_embeds[:num_tokens]
else:
input_ids = self.input_ids[:num_tokens]
inputs_embeds = None
if self.uses_mrope:
positions = self.mrope_positions[:, :num_tokens]
else:
positions = self.positions[:num_tokens]

if get_pp_group().is_first_rank:
intermediate_tensors = None
else:
if self.intermediate_tensors is None:
self.intermediate_tensors = (
self.model.make_empty_intermediate_tensors(
batch_size=self.max_num_tokens,
dtype=self.model_config.dtype,
device=self.device))
intermediate_tensors = IntermediateTensors({
k: v[:num_tokens]
for k, v in self.intermediate_tensors.items()
})

with set_forward_context(None,
self.vllm_config,
num_tokens=num_tokens):
hidden_states = model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
)

logit_indices = np.cumsum(num_scheduled_tokens) - 1
return hidden_states[logit_indices]
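
As a quick aside, a minimal standalone sketch (toy values, not from this PR) of how the dummy run above spreads num_tokens across requests and then selects the last-token position of each request for computing logits:

import numpy as np

num_tokens, max_num_reqs = 10, 4  # toy values for illustration
num_reqs = max_num_reqs if num_tokens >= max_num_reqs else num_tokens
min_tokens_per_req = num_tokens // num_reqs
num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
num_scheduled_tokens_list[-1] += num_tokens % num_reqs
print(num_scheduled_tokens_list)  # [2, 2, 2, 4]; the remainder goes last
logit_indices = np.cumsum(num_scheduled_tokens_list) - 1
print(logit_indices)  # [1 3 5 9]: the final token of each request, the only
# positions whose hidden states are needed for sampling.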

@torch.inference_mode()
def _dummy_sampler_run(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor:

logits = self.model.compute_logits(hidden_states, None)
num_reqs = logits.size(0)

dummy_tensors = lambda v: torch.full(
(num_reqs, ), v, device=self.device)

dummy_metadata = SamplingMetadata(
temperature=dummy_tensors(0.5),
all_greedy=False,
all_random=False,
top_p=dummy_tensors(0.9),
top_k=dummy_tensors(logits.size(1) - 1),
min_p=None,
generators={},
max_num_logprobs=None,
no_penalties=True,
prompt_token_ids=None,
frequency_penalties=dummy_tensors(0.1),
presence_penalties=dummy_tensors(0.1),
repetition_penalties=dummy_tensors(0.1),
output_token_ids=[[] for _ in range(num_reqs)],
min_tokens={},
logit_bias=[None for _ in range(num_reqs)],
allowed_token_ids_mask=None,
bad_words_token_ids={},
)
sampler_output = self.model.sample(logits=logits,
sampling_metadata=dummy_metadata)

return sampler_output

def profile_run(self) -> None:
# Profile with multimodal encoder & encoder cache.
@@ -1332,60 +1389,14 @@ def profile_run(self) -> None:
# Cache the dummy encoder outputs.
self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))

# For profile, have maximum num_reqs and that collectively have
# maximum num_tokens.
num_reqs = self.scheduler_config.max_num_seqs
num_tokens = self.max_num_tokens
min_tokens_per_req = num_tokens // num_reqs

num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
num_scheduled_tokens_list[-1] += num_tokens % num_reqs
assert sum(num_scheduled_tokens_list) == num_tokens
assert len(num_scheduled_tokens_list) == num_reqs

num_scheduled_tokens = np.array(num_scheduled_tokens_list,
dtype=np.int32)
logit_indices = np.cumsum(num_scheduled_tokens) - 1

with self.maybe_profile_with_lora(self.lora_config,
num_scheduled_tokens):
# Trigger compilation for general shape.
hidden_states = self._dummy_run(self.max_num_tokens)
if get_pp_group().is_last_rank:
hidden_states = hidden_states[logit_indices]
logits = self.model.compute_logits(hidden_states, None)
dummy_tensors = lambda v: torch.full(
(num_reqs, ), v, device=self.device)
dummy_metadata = SamplingMetadata(
temperature=dummy_tensors(0.5),
all_greedy=False,
all_random=False,
top_p=dummy_tensors(0.9),
top_k=dummy_tensors(logits.size(1) - 1),
min_p=None,
generators={},
max_num_logprobs=None,
no_penalties=True,
prompt_token_ids=torch.ones_like(logits,
dtype=torch.int64),
frequency_penalties=dummy_tensors(0.1),
presence_penalties=dummy_tensors(0.1),
repetition_penalties=dummy_tensors(0.1),
output_token_ids=[[] for _ in range(num_reqs)],
min_tokens={},
logit_bias=[None for _ in range(num_reqs)],
allowed_token_ids_mask=None,
bad_words_token_ids={},
)
sampler_output = self.model.sample(
logits=logits, sampling_metadata=dummy_metadata)
else:
logits = None
sampler_output = None
dummy_metadata = None
torch.cuda.synchronize()
del hidden_states, logits, sampler_output, dummy_metadata
self.encoder_cache.clear()
hidden_states = self._dummy_run(self.max_num_tokens)
if get_pp_group().is_last_rank:
sampler_output = self._dummy_sampler_run(hidden_states)
else:
sampler_output = None
torch.cuda.synchronize()
del hidden_states, sampler_output
self.encoder_cache.clear()
gc.collect()

def capture_model(self) -> None:
23 changes: 23 additions & 0 deletions vllm/v1/worker/gpu_worker.py
@@ -119,6 +119,8 @@ def init_device(self):
self.model_runner: GPUModelRunner = GPUModelRunner(
self.vllm_config, self.device)

# FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool
# to hijack tensor allocation.
def load_model(self) -> None:
if self.vllm_config.model_config.enable_sleep_mode:
allocator = CuMemAllocator.get_instance()
@@ -211,6 +213,27 @@ def compile_or_warm_up_model(self) -> None:
self.model_runner._dummy_run(size)
if not self.model_config.enforce_eager:
self.model_runner.capture_model()

# Warm up the sampler and preallocate the memory buffer for logits and
# other sampling-related tensors of the maximum possible shape to avoid
# memory fragmentation issues.
# NOTE: This is called after `capture_model` on purpose to prevent
# memory buffers from being cleared by `torch.cuda.empty_cache`.
try:
max_num_reqs = min(self.scheduler_config.max_num_seqs,
self.scheduler_config.max_num_batched_tokens)
self.model_runner._dummy_sampler_run(
hidden_states=self.model_runner._dummy_run(
num_tokens=max_num_reqs))
except RuntimeError as e:
if 'out of memory' in str(e):
raise RuntimeError(
"CUDA out of memory occurred when warming up sampler. "
"Please try lowering `gpu_memory_utilization` when "
"initializing the engine.") from None
else:
raise e

# Reset the seed to ensure that the random state is not affected by
# the model initialization and profiling.
set_random_seed(self.model_config.seed)
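
As a rough sketch of why the warm-up above sizes the dummy run with min(max_num_seqs, max_num_batched_tokens); the values below are assumptions for illustration, not from this PR:

max_num_seqs = 1024            # assumed scheduler_config.max_num_seqs
max_num_batched_tokens = 512   # assumed scheduler_config.max_num_batched_tokens
max_num_reqs = min(max_num_seqs, max_num_batched_tokens)  # -> 512
# _dummy_run asserts num_tokens <= max_num_batched_tokens, so the dummy batch
# is capped by the token budget. With num_tokens == max_num_reqs, each dummy
# request gets exactly one token, and compute_logits in _dummy_sampler_run
# therefore allocates a (max_num_reqs, vocab_size) buffer, the largest logits
# shape a real batch should ever need, before any user request arrives.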
6 changes: 3 additions & 3 deletions vllm/v1/worker/lora_model_runner_mixin.py
@@ -83,8 +83,8 @@ def set_active_loras(self, input_batch: InputBatch,
lora_requests)

@contextmanager
def maybe_profile_with_lora(self, lora_config: LoRAConfig,
num_scheduled_tokens: np.ndarray):
def maybe_dummy_run_with_lora(self, lora_config: LoRAConfig,
num_scheduled_tokens: np.ndarray):
if lora_config is None:
yield
else:
@@ -145,4 +145,4 @@ def pin_lora(self, lora_id: int) -> bool:
def list_loras(self) -> set[int]:
if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.")
return self.lora_manager.list_adapters()
return self.lora_manager.list_adapters()