Commit 7c79683

Restricted block sizes to prevent memory from blowing up
Signed-off-by: Akshat Tripathi <[email protected]>
1 parent e66067c commit 7c79683

File tree

1 file changed (+2, -2 lines)


vllm/lora/ops/xla_ops/pallas.py

Lines changed: 2 additions & 2 deletions
@@ -111,8 +111,8 @@ def bgmv_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor):
     DIM_BLOCK = 1024

     TOKEN_BLOCK = min(max(TOKEN_BLOCK, pl.next_power_of_2(T)), 128)
-    LORA_BLOCK = max(LORA_BLOCK, pl.next_power_of_2(L))
-    DIM_BLOCK = max(DIM_BLOCK, pl.next_power_of_2(D))
+    LORA_BLOCK = min(max(LORA_BLOCK, pl.next_power_of_2(L)), 4096)
+    DIM_BLOCK = min(max(DIM_BLOCK, pl.next_power_of_2(D)), 4096)

     kernel = make_kernel_from_pallas(
         functools.partial(
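
For context: TOKEN_BLOCK was already capped at 128, and this commit adds the same treatment for LORA_BLOCK and DIM_BLOCK. Each block size is clamped between its configured default and a hard cap of 4096, so a very large lora count L or hidden dimension D can no longer inflate the block (and with it the kernel's memory footprint) without bound. Below is a minimal runnable sketch of that clamping pattern; the TOKEN_BLOCK and LORA_BLOCK defaults and the pure-Python next_power_of_2 (a stand-in for pl.next_power_of_2) are assumptions for illustration, not values taken from the file.

def next_power_of_2(n: int) -> int:
    # Smallest power of two >= n; stand-in for pl.next_power_of_2.
    return 1 if n <= 1 else 1 << (n - 1).bit_length()

def clamp_block_sizes(T: int, L: int, D: int,
                      token_block: int = 16,   # assumed default
                      lora_block: int = 256,   # assumed default
                      dim_block: int = 1024):  # default shown in the diff
    # Round each dimension up to a power of two, but cap the result so a
    # single Pallas block cannot grow arbitrarily large for big inputs.
    token_block = min(max(token_block, next_power_of_2(T)), 128)
    lora_block = min(max(lora_block, next_power_of_2(L)), 4096)
    dim_block = min(max(dim_block, next_power_of_2(D)), 4096)
    return token_block, lora_block, dim_block

# Before this commit, D = 16384 would round DIM_BLOCK all the way up to
# 16384; with the cap, the block stays at 4096.
print(clamp_block_sizes(T=7, L=300, D=16384))  # -> (16, 512, 4096)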
