
Commit e36ee02

ywang96 authored and shreyankg committed
[V1][Core] Fix memory issue with logits & sampling (vllm-project#13721)
1 parent be5a887 commit e36ee02

File tree

2 files changed: +49 −29 lines changed


vllm/v1/worker/gpu_model_runner.py

Lines changed: 39 additions & 29 deletions
@@ -1179,6 +1179,43 @@ def _dummy_run(
             )
         return hidden_states
 
+    @torch.inference_mode()
+    def _dummy_sampler_run(
+        self,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
+
+        logits = self.model.compute_logits(hidden_states, None)
+        num_reqs = logits.size(0)
+
+        dummy_tensors = lambda v: torch.full(
+            (num_reqs, ), v, device=self.device)
+
+        dummy_metadata = SamplingMetadata(
+            temperature=dummy_tensors(0.5),
+            all_greedy=False,
+            all_random=False,
+            spec_token_ids=None,
+            top_p=dummy_tensors(0.9),
+            top_k=dummy_tensors(logits.size(1) - 1),
+            min_p=None,
+            generators={},
+            max_num_logprobs=None,
+            no_penalties=True,
+            prompt_token_ids=None,
+            frequency_penalties=dummy_tensors(0.1),
+            presence_penalties=dummy_tensors(0.1),
+            repetition_penalties=dummy_tensors(0.1),
+            output_token_ids=[[] for _ in range(num_reqs)],
+            min_tokens={},
+            logit_bias=[None for _ in range(num_reqs)],
+            allowed_token_ids_mask=None,
+        )
+        sampler_output = self.model.sample(logits=logits,
+                                           sampling_metadata=dummy_metadata)
+
+        return sampler_output
+
     def profile_run(self) -> None:
         # use an empty tensor instead of `None`` to force Dynamo to pass
         # it by reference, rather by specializing on the value `None`.
@@ -1306,38 +1343,11 @@ def profile_run(self) -> None:
                                         dummy_kv_caches)
         if get_pp_group().is_last_rank:
             hidden_states = hidden_states[logit_indices]
-            logits = self.model.compute_logits(hidden_states, None)
-            dummy_tensors = lambda v: torch.full(
-                (num_reqs, ), v, device=self.device)
-            dummy_metadata = SamplingMetadata(
-                temperature=dummy_tensors(0.5),
-                all_greedy=False,
-                all_random=False,
-                spec_token_ids=None,
-                top_p=dummy_tensors(0.9),
-                top_k=dummy_tensors(logits.size(1) - 1),
-                min_p=None,
-                generators={},
-                max_num_logprobs=None,
-                no_penalties=True,
-                prompt_token_ids=torch.ones_like(logits,
-                                                 dtype=torch.int64),
-                frequency_penalties=dummy_tensors(0.1),
-                presence_penalties=dummy_tensors(0.1),
-                repetition_penalties=dummy_tensors(0.1),
-                output_token_ids=[[] for _ in range(num_reqs)],
-                min_tokens={},
-                logit_bias=[None for _ in range(num_reqs)],
-                allowed_token_ids_mask=None,
-            )
-            sampler_output = self.model.sample(
-                logits=logits, sampling_metadata=dummy_metadata)
+            sampler_output = self._dummy_sampler_run(hidden_states)
         else:
-            logits = None
             sampler_output = None
-            dummy_metadata = None
         torch.cuda.synchronize()
-        del hidden_states, logits, sampler_output, dummy_metadata
+        del hidden_states, sampler_output
         self.encoder_cache.clear()
         gc.collect()
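
The new `_dummy_sampler_run` exists so that the logits and sampling tensors are materialized once at their maximum possible shape. A minimal standalone sketch of that idea follows (an illustration only, not part of the diff or of vLLM's code; the sizes and the simplified sampler math are assumptions): allocating max-shape buffers up front lets PyTorch's caching allocator keep one large block that smaller batches can reuse, instead of growing and fragmenting GPU memory during serving.

import torch


def warm_up_sampler_buffers(max_num_reqs: int, vocab_size: int,
                            device: str = "cuda") -> None:
    # Allocate at the largest shape the sampler will ever see; after `del`,
    # the blocks stay in PyTorch's caching allocator for later reuse.
    logits = torch.empty(max_num_reqs, vocab_size, device=device)
    # Mimic typical sampler math so intermediate buffers get sized too.
    probs = torch.softmax(logits.float(), dim=-1)
    next_tokens = torch.argmax(probs, dim=-1)
    del logits, probs, next_tokens  # memory returns to the cache, not the OS


if torch.cuda.is_available():
    # Placeholder sizes; a real runner would use its own maximum batch size
    # and the model's vocabulary size.
    warm_up_sampler_buffers(max_num_reqs=256, vocab_size=32_000)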

vllm/v1/worker/gpu_worker.py

Lines changed: 10 additions & 0 deletions
@@ -211,6 +211,16 @@ def compile_or_warm_up_model(self) -> None:
             self.model_runner._dummy_run(size)
         if not self.model_config.enforce_eager:
             self.model_runner.capture_model()
+
+        # Warm up sampler and preallocate memory buffer for logits and other
+        # sampling related tensors of max possible shape to avoid memory
+        # fragmentation issue.
+        # NOTE: This is called after `capture_model` on purpose to prevent
+        # memory buffers from being cleared by `torch.cuda.empty_cache`.
+        self.model_runner._dummy_sampler_run(
+            hidden_states=self.model_runner._dummy_run(
+                num_tokens=self.scheduler_config.max_num_seqs))
+
         # Reset the seed to ensure that the random state is not affected by
         # the model initialization and profiling.
         set_random_seed(self.model_config.seed)
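
As a rough illustration of the NOTE in the hunk above (stub names below are hypothetical; this is not vLLM's actual worker code): the sampler warm-up has to run after graph capture because capture paths may call `torch.cuda.empty_cache()`, which would release the blocks the warm-up just reserved.

import torch


def capture_model_stub() -> None:
    # Stands in for graph capture; such paths may call empty_cache(),
    # releasing unused blocks held by the caching allocator.
    torch.cuda.empty_cache()


def dummy_sampler_warmup_stub(max_num_seqs: int, vocab_size: int) -> None:
    # Touch sampler-shaped buffers at the maximum batch size so the
    # allocator caches blocks of the right size.
    logits = torch.empty(max_num_seqs, vocab_size, device="cuda")
    _ = torch.argmax(logits, dim=-1)


if torch.cuda.is_available():
    capture_model_stub()                    # 1) capture (may empty the cache)
    dummy_sampler_warmup_stub(256, 32_000)  # 2) then warm up, so the blocks
                                            #    stay cached for real requests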
