@@ -1179,6 +1179,43 @@ def _dummy_run(
             )
         return hidden_states
 
+    @torch.inference_mode()
+    def _dummy_sampler_run(
+        self,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
+
+        logits = self.model.compute_logits(hidden_states, None)
+        num_reqs = logits.size(0)
+
+        dummy_tensors = lambda v: torch.full(
+            (num_reqs, ), v, device=self.device)
+
+        dummy_metadata = SamplingMetadata(
+            temperature=dummy_tensors(0.5),
+            all_greedy=False,
+            all_random=False,
+            spec_token_ids=None,
+            top_p=dummy_tensors(0.9),
+            top_k=dummy_tensors(logits.size(1) - 1),
+            min_p=None,
+            generators={},
+            max_num_logprobs=None,
+            no_penalties=True,
+            prompt_token_ids=None,
+            frequency_penalties=dummy_tensors(0.1),
+            presence_penalties=dummy_tensors(0.1),
+            repetition_penalties=dummy_tensors(0.1),
+            output_token_ids=[[] for _ in range(num_reqs)],
+            min_tokens={},
+            logit_bias=[None for _ in range(num_reqs)],
+            allowed_token_ids_mask=None,
+        )
+        sampler_output = self.model.sample(logits=logits,
+                                           sampling_metadata=dummy_metadata)
+
+        return sampler_output
+
 
     def profile_run(self) -> None:
         # use an empty tensor instead of `None`` to force Dynamo to pass
         # it by reference, rather by specializing on the value `None`.
@@ -1306,38 +1343,11 @@ def profile_run(self) -> None:
                                         dummy_kv_caches)
         if get_pp_group().is_last_rank:
             hidden_states = hidden_states[logit_indices]
-            logits = self.model.compute_logits(hidden_states, None)
-            dummy_tensors = lambda v: torch.full(
-                (num_reqs, ), v, device=self.device)
-            dummy_metadata = SamplingMetadata(
-                temperature=dummy_tensors(0.5),
-                all_greedy=False,
-                all_random=False,
-                spec_token_ids=None,
-                top_p=dummy_tensors(0.9),
-                top_k=dummy_tensors(logits.size(1) - 1),
-                min_p=None,
-                generators={},
-                max_num_logprobs=None,
-                no_penalties=True,
-                prompt_token_ids=torch.ones_like(logits,
-                                                 dtype=torch.int64),
-                frequency_penalties=dummy_tensors(0.1),
-                presence_penalties=dummy_tensors(0.1),
-                repetition_penalties=dummy_tensors(0.1),
-                output_token_ids=[[] for _ in range(num_reqs)],
-                min_tokens={},
-                logit_bias=[None for _ in range(num_reqs)],
-                allowed_token_ids_mask=None,
-            )
-            sampler_output = self.model.sample(
-                logits=logits, sampling_metadata=dummy_metadata)
+            sampler_output = self._dummy_sampler_run(hidden_states)
         else:
-            logits = None
             sampler_output = None
-            dummy_metadata = None
         torch.cuda.synchronize()
-        del hidden_states, logits, sampler_output, dummy_metadata
+        del hidden_states, sampler_output
         self.encoder_cache.clear()
         gc.collect()
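The change hoists the duplicated dummy-sampler warm-up out of `profile_run` into a reusable `_dummy_sampler_run` helper; note it also passes `prompt_token_ids=None` rather than a ones tensor, which is consistent with `no_penalties=True`. A minimal, self-contained sketch of the same pattern in plain PyTorch (every name below is a toy stand-in, not vLLM's actual API):

```python
import torch

class Runner:
    """Toy stand-in for the model runner; all names here are hypothetical."""

    def __init__(self, hidden_size: int = 16, vocab_size: int = 100) -> None:
        self.device = torch.device("cpu")
        # Stand-in for model.compute_logits: a plain LM head.
        self.lm_head = torch.nn.Linear(hidden_size, vocab_size)

    @torch.inference_mode()
    def _dummy_sampler_run(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Compute logits once, then "sample" with fixed dummy settings
        # (greedy argmax here), mirroring how the extracted helper warms
        # up the sampler without tracking gradients.
        logits = self.lm_head(hidden_states)
        return logits.argmax(dim=-1)

runner = Runner()
hidden_states = torch.zeros(4, 16)  # (num_reqs, hidden_size)
print(runner._dummy_sampler_run(hidden_states).shape)  # torch.Size([4])
```

Wrapping the helper in `@torch.inference_mode()` keeps the throwaway profiling pass from recording autograd state, and centralizing the dummy `SamplingMetadata` construction gives any future warm-up path a single helper to call.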