Commit d7338f8

Removed larger lora/dim block sizes since they reduce perf outside of microbenchmarks
Signed-off-by: Akshat Tripathi <[email protected]>
1 parent 7c79683 commit d7338f8

File tree: 1 file changed, +13 -8 lines


vllm/lora/ops/xla_ops/pallas.py

Lines changed: 13 additions & 8 deletions
@@ -102,18 +102,14 @@ def bgmv_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor):
 
     jax_import_guard()
 
-    TOKEN_BLOCK = 16
+    TOKEN_BLOCK = get_bounded_value(16, next_multiple_of(T, 16), 128)
     if is_expand:  # Expand
         LORA_BLOCK = 1024
         DIM_BLOCK = 256
     else:  # Shrink
         LORA_BLOCK = 256
         DIM_BLOCK = 1024
 
-    TOKEN_BLOCK = min(max(TOKEN_BLOCK, pl.next_power_of_2(T)), 128)
-    LORA_BLOCK = min(max(LORA_BLOCK, pl.next_power_of_2(L)), 4096)
-    DIM_BLOCK = min(max(DIM_BLOCK, pl.next_power_of_2(D)), 4096)
-
     kernel = make_kernel_from_pallas(
         functools.partial(
             _bgmv,
@@ -128,15 +124,15 @@ def bgmv_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor):
     # register. This has to happen in pytorch, doing it in Jax will lead to NaNs
     pad_L = 0
     if LORA_BLOCK > L or L % LORA_BLOCK != 0:
-        pad_L = (L // LORA_BLOCK + 1) * LORA_BLOCK - L
+        pad_L = next_multiple_of(L, LORA_BLOCK) - L
 
     pad_D = 0
     if DIM_BLOCK > D or D % DIM_BLOCK != 0:
-        pad_D = (D // DIM_BLOCK + 1) * DIM_BLOCK - D
+        pad_D = next_multiple_of(D, DIM_BLOCK) - D
 
     pad_T = 0
     if TOKEN_BLOCK > T or T % TOKEN_BLOCK != 0:
-        pad_T = (T // TOKEN_BLOCK + 1) * TOKEN_BLOCK - T
+        pad_T = next_multiple_of(T, TOKEN_BLOCK) - T
 
     if pad_D != 0 or pad_L != 0:
         loras = torch.nn.functional.pad(loras, (0, pad_D, 0, pad_L, 0, 0))
@@ -159,3 +155,12 @@ def bgmv_non_xla(inputs: torch.Tensor, loras: torch.Tensor,
     _, L, _ = loras.shape
 
     return torch.empty((T, L), device=inputs.device)
+
+
+def next_multiple_of(n: int, mult: int) -> int:
+    if n % mult == 0:
+        return n
+    return (n // mult + 1) * mult
+
+def get_bounded_value(_min: int, val: int, _max: int) -> int:
+    return min(max(_min, val), _max)
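
For context, a minimal standalone sketch of how the two helpers added at the end of pallas.py fit together: TOKEN_BLOCK is now the token count T rounded up to a multiple of 16 and clamped to [16, 128], and each dimension is padded up to the next multiple of its block size. The helper bodies mirror the diff above; the value T = 50 and the worked numbers are purely illustrative, not taken from the source.

def next_multiple_of(n: int, mult: int) -> int:
    # Round n up to the nearest multiple of mult (returns n if already aligned).
    if n % mult == 0:
        return n
    return (n // mult + 1) * mult

def get_bounded_value(_min: int, val: int, _max: int) -> int:
    # Clamp val into the closed range [_min, _max].
    return min(max(_min, val), _max)

# Illustrative only: a hypothetical batch of T = 50 tokens.
T = 50
TOKEN_BLOCK = get_bounded_value(16, next_multiple_of(T, 16), 128)  # -> 64
pad_T = next_multiple_of(T, TOKEN_BLOCK) - T                       # -> 14
assert (T + pad_T) % TOKEN_BLOCK == 0

Compared with the removed pl.next_power_of_2-based clamping, LORA_BLOCK and DIM_BLOCK now stay at their fixed per-mode values (1024/256 for expand, 256/1024 for shrink), which reflects the commit message: the larger, input-dependent block sizes only helped in microbenchmarks.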
