Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion vllm/lora/ops/triton_ops/bgmv_expand.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from vllm.utils import direct_register_custom_op

from .utils import get_lora_op_configs
from .utils import _set_cuda_device, get_lora_op_configs


@triton.jit
Expand Down Expand Up @@ -142,6 +142,7 @@ def _bgmv_expand(
META["SPLIT_N"],
batches,
)
_set_cuda_device(inputs.device)
_bgmv_expand_kernel[grid](
inputs,
lora_b_weights,
Expand Down
3 changes: 2 additions & 1 deletion vllm/lora/ops/triton_ops/bgmv_expand_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from vllm.utils import direct_register_custom_op

from .utils import get_lora_op_configs
from .utils import _set_cuda_device, get_lora_op_configs


@triton.jit
Expand Down Expand Up @@ -158,6 +158,7 @@ def _bgmv_expand_slice(
META["SPLIT_N"],
batches,
)
_set_cuda_device(inputs.device)
_bgmv_expand_slice_kernel[grid](
inputs,
lora_b_weights,
Expand Down
3 changes: 2 additions & 1 deletion vllm/lora/ops/triton_ops/bgmv_shrink.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from vllm.utils import direct_register_custom_op

from .utils import get_lora_op_configs
from .utils import _set_cuda_device, get_lora_op_configs


@triton.jit
Expand Down Expand Up @@ -124,6 +124,7 @@ def _bgmv_shrink(
META["SPLIT_K"],
batches,
)
_set_cuda_device(inputs.device)
_bgmv_shrink_kernel[grid](
inputs,
lora_a_weights,
Expand Down
3 changes: 2 additions & 1 deletion vllm/lora/ops/triton_ops/sgmv_expand.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from vllm.utils import direct_register_custom_op

from .utils import _get_lora_b_ptr
from .utils import _get_lora_b_ptr, _set_cuda_device


@triton.jit
Expand Down Expand Up @@ -218,6 +218,7 @@ def _sgmv_expand(
batches,
len(lora_b_weights),
)
_set_cuda_device(inputs.device)
_sgmv_expand_kernel[grid](
inputs,
lora_ptr_tensor,
Expand Down
3 changes: 2 additions & 1 deletion vllm/lora/ops/triton_ops/sgmv_shrink.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from vllm.utils import direct_register_custom_op

from .utils import _get_lora_a_ptr
from .utils import _get_lora_a_ptr, _set_cuda_device


@triton.jit
Expand Down Expand Up @@ -184,6 +184,7 @@ def _sgmv_shrink(
SPLIT_K * len(lora_a_weights),
batches,
)
_set_cuda_device(inputs.device)
_sgmv_shrink_kernel[grid](
inputs,
lora_ptr_tensor,
Expand Down
9 changes: 9 additions & 0 deletions vllm/lora/ops/triton_ops/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0

import functools
from functools import lru_cache
from typing import Dict, List, Tuple

import torch
Expand Down Expand Up @@ -50,6 +51,14 @@ def get_lora_op_configs(op_type: str, batch: int,
return config


@lru_cache
def _set_cuda_device(device: torch.device):
    """
    Set the current CUDA device to ``device`` via ``torch.cuda.set_device``.

    Because of the ``@lru_cache`` decorator, the underlying
    ``torch.cuda.set_device`` call is executed at most once per distinct
    ``device`` argument for the lifetime of the process; later calls with a
    previously seen device are cached no-ops. This avoids paying the
    ``set_device`` cost on every Triton LoRA kernel launch.

    NOTE(review): if other code switches the active CUDA device between
    calls, a repeated call here with an already-seen ``device`` will NOT
    switch back (the cached call is skipped) — presumably the callers only
    ever use one device per process; confirm that assumption holds.

    NOTE(review): ``torch.device`` hashability/equality drives the cache
    key, so ``device('cuda')`` and ``device('cuda:0')`` are distinct cache
    entries even when they refer to the same physical device.
    """
    torch.cuda.set_device(device)


_LORA_A_PTR_DICT: Dict[Tuple[int, ...], Tuple[torch.tensor, ...]] = {}
_LORA_B_PTR_DICT: Dict[Tuple[int, ...], Tuple[torch.tensor, ...]] = {}

Expand Down
Loading