@@ -1426,10 +1426,15 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
 
         # Prepare dummy inputs. These will be reused for all batch sizes.
         max_batch_size = self.max_batchsize_to_capture
-        input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda()
-        input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda()
+        input_tokens = torch.zeros(max_batch_size,
+                                   dtype=torch.long,
+                                   device=self.device)
+        input_positions = torch.zeros(max_batch_size,
+                                      dtype=torch.long,
+                                      device=self.device)
         if self.model_config.uses_mrope:
-            input_positions = torch.tile(input_positions, (3, 1))
+            input_positions = torch.tile(input_positions,
+                                         (3, 1)).cuda(device=self.device)
         # Prepare dummy previous_hidden_states only if needed by the model.
         # This is used by draft models such as EAGLE.
         previous_hidden_states = None
@@ -1448,8 +1453,8 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
                 dtype=self.model_config.dtype,
                 device=self.device)
 
-        with self.attn_state.graph_capture(
-                max_batch_size), graph_capture() as graph_capture_context:
+        with self.attn_state.graph_capture(max_batch_size), graph_capture(
+                self.device) as graph_capture_context:
             # NOTE: Capturing the largest batch size first may help reduce the
             # memory usage of CUDA graph.
             for virtual_engine in range(
@@ -1549,10 +1554,12 @@ def _update_inputs_to_capture_for_enc_dec_model(self,
         """
         # During the decode phase encoder_input_ids and encoder_positions are
         # unset. Do the same thing for graph capture.
-        capture_inputs["encoder_input_ids"] = torch.tensor(
-            [], dtype=torch.long).cuda()
-        capture_inputs["encoder_positions"] = torch.tensor(
-            [], dtype=torch.long).cuda()
+        capture_inputs["encoder_input_ids"] = torch.tensor([],
+                                                           dtype=torch.long,
+                                                           device=self.device)
+        capture_inputs["encoder_positions"] = torch.tensor([],
+                                                           dtype=torch.long,
+                                                           device=self.device)
 
     @property
     def vocab_size(self) -> int:
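A minimal standalone sketch of the pattern this patch adopts, with hypothetical stand-ins (`device`, `max_batch_size`) for the runner's `self.device` and `self.max_batchsize_to_capture`: allocate dummy tensors directly on the configured device via the `device=` argument, rather than calling `.cuda()`, which would always place them on the default GPU.

import torch

# Assumed stand-ins for the runner's self.device and
# self.max_batchsize_to_capture.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
max_batch_size = 8

# Allocated directly on the target device; no implicit transfer to the
# default GPU as with .cuda().
input_tokens = torch.zeros(max_batch_size, dtype=torch.long, device=device)
input_positions = torch.zeros(max_batch_size, dtype=torch.long, device=device)

# torch.tile preserves the input's device, so the tiled (3, N) positions
# used for M-RoPE-style inputs stay on `device` as well.
input_positions_mrope = torch.tile(input_positions, (3, 1))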