@@ -1238,42 +1238,6 @@ def _dummy_run(
         )
         return hidden_states
 
-    @torch.inference_mode()
-    def _dummy_sampler_run(
-        self,
-        hidden_states: torch.Tensor,
-    ) -> torch.Tensor:
-
-        logits = self.model.compute_logits(hidden_states, None)
-        num_reqs = logits.size(0)
-
-        dummy_tensors = lambda v: torch.full(
-            (num_reqs, ), v, device=self.device)
-
-        dummy_metadata = SamplingMetadata(
-            temperature=dummy_tensors(0.5),
-            all_greedy=False,
-            all_random=False,
-            top_p=dummy_tensors(0.9),
-            top_k=dummy_tensors(logits.size(1) - 1),
-            min_p=None,
-            generators={},
-            max_num_logprobs=None,
-            no_penalties=True,
-            prompt_token_ids=None,
-            frequency_penalties=dummy_tensors(0.1),
-            presence_penalties=dummy_tensors(0.1),
-            repetition_penalties=dummy_tensors(0.1),
-            output_token_ids=[[] for _ in range(num_reqs)],
-            min_tokens={},
-            logit_bias=[None for _ in range(num_reqs)],
-            allowed_token_ids_mask=None,
-        )
-        sampler_output = self.model.sample(logits=logits,
-                                           sampling_metadata=dummy_metadata)
-
-        return sampler_output
-
     def profile_run(self) -> None:
         # Profile with multimodal encoder & encoder cache.
         # TODO: handle encoder-decoder models once we support them.
@@ -1389,11 +1353,37 @@ def profile_run(self) -> None:
         hidden_states = self._dummy_run(self.max_num_tokens)
         if get_pp_group().is_last_rank:
             hidden_states = hidden_states[logit_indices]
-            sampler_output = self._dummy_sampler_run(hidden_states)
+            logits = self.model.compute_logits(hidden_states, None)
+            dummy_tensors = lambda v: torch.full(
+                (num_reqs, ), v, device=self.device)
+            dummy_metadata = SamplingMetadata(
+                temperature=dummy_tensors(0.5),
+                all_greedy=False,
+                all_random=False,
+                top_p=dummy_tensors(0.9),
+                top_k=dummy_tensors(logits.size(1) - 1),
+                min_p=None,
+                generators={},
+                max_num_logprobs=None,
+                no_penalties=True,
+                prompt_token_ids=torch.ones_like(logits,
+                                                 dtype=torch.int64),
+                frequency_penalties=dummy_tensors(0.1),
+                presence_penalties=dummy_tensors(0.1),
+                repetition_penalties=dummy_tensors(0.1),
+                output_token_ids=[[] for _ in range(num_reqs)],
+                min_tokens={},
+                logit_bias=[None for _ in range(num_reqs)],
+                allowed_token_ids_mask=None,
+            )
+            sampler_output = self.model.sample(
+                logits=logits, sampling_metadata=dummy_metadata)
         else:
+            logits = None
             sampler_output = None
+            dummy_metadata = None
         torch.cuda.synchronize()
-        del hidden_states, sampler_output
+        del hidden_states, logits, sampler_output, dummy_metadata
         self.encoder_cache.clear()
         gc.collect()
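
For context, the change above removes the standalone `_dummy_sampler_run` helper and inlines its body into `profile_run`, also switching `prompt_token_ids` from `None` to a dense `torch.ones_like(logits, dtype=torch.int64)` tensor. Below is a minimal, self-contained sketch of that inlined dummy-sampler step; `DummySamplingMetadata`, `dummy_sampler_profile`, and the simplified multinomial sampling are illustrative stand-ins so the example runs without vLLM, not the repository's actual API.

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class DummySamplingMetadata:  # stand-in for vLLM's SamplingMetadata
    temperature: torch.Tensor
    top_p: torch.Tensor
    top_k: torch.Tensor
    prompt_token_ids: Optional[torch.Tensor]


def dummy_sampler_profile(hidden_states: torch.Tensor,
                          lm_head: torch.nn.Linear) -> torch.Tensor:
    # One row of logits per dummy request.
    logits = lm_head(hidden_states)
    num_reqs = logits.size(0)

    # Fixed per-request sampling parameters, mirroring dummy_tensors(...)
    # in the diff above.
    dummy_tensors = lambda v: torch.full((num_reqs, ), v,
                                         device=logits.device)
    meta = DummySamplingMetadata(
        temperature=dummy_tensors(0.5),
        top_p=dummy_tensors(0.9),
        top_k=dummy_tensors(float(logits.size(1) - 1)),
        # Dense int64 tensor instead of None, presumably so the penalty
        # code paths allocate their worst-case buffers while profiling.
        prompt_token_ids=torch.ones_like(logits, dtype=torch.int64),
    )

    # Simplified stand-in for self.model.sample(): a temperature-scaled
    # multinomial draw (no top-p/top-k masking, no penalties).
    probs = torch.softmax(logits / meta.temperature.unsqueeze(1), dim=-1)
    return torch.multinomial(probs, num_samples=1)


# Example: profile the sampler with 4 dummy requests and a 32-token vocab.
lm_head = torch.nn.Linear(16, 32)
sampled = dummy_sampler_profile(torch.randn(4, 16), lm_head)
```

The dense `prompt_token_ids` tensor presumably lets the profiling pass account for the worst-case memory touched when penalties are applied, which passing `None` could not exercise.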