Skip to content

Commit 0e0d51c

Browse files
SageMooreyewentao256
authored andcommitted
Suppress benign cuBLAS warning when capturing cudagraphs with DBO (#25596)
Signed-off-by: Sage Moore <[email protected]> Signed-off-by: yewentao256 <[email protected]>
1 parent 72a5101 commit 0e0d51c

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

vllm/v1/worker/gpu_ubatch_wrapper.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ def __init__(self, runnable: Callable, vllm_config: VllmConfig,
104104
self.graph_pool = current_platform.get_global_graph_pool()
105105

106106
self.sm_control = self._create_sm_control_context(vllm_config)
107+
self.device = device
107108

108109
@staticmethod
109110
def _create_sm_control_context(vllm_config: VllmConfig):
@@ -168,6 +169,7 @@ def _capture_ubatches(self, ubatch_metadata, model) -> torch.Tensor:
168169

169170
@torch.inference_mode()
170171
def _capture_ubatch_thread(results, ubatch_metadata):
172+
torch.cuda.set_device(self.device)
171173
ubatch_context = ubatch_metadata.context
172174
with torch.cuda.stream(ubatch_context.compute_stream):
173175
_ = torch.cuda.current_blas_handle()

0 commit comments

Comments
 (0)