@@ -102,18 +102,14 @@ def bgmv_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor):
 
     jax_import_guard()
 
-    TOKEN_BLOCK = 16
+    TOKEN_BLOCK = get_bounded_value(16, next_multiple_of(T, 16), 128)
     if is_expand:  # Expand
         LORA_BLOCK = 1024
         DIM_BLOCK = 256
     else:  # Shrink
         LORA_BLOCK = 256
         DIM_BLOCK = 1024
 
-    TOKEN_BLOCK = min(max(TOKEN_BLOCK, pl.next_power_of_2(T)), 128)
-    LORA_BLOCK = min(max(LORA_BLOCK, pl.next_power_of_2(L)), 4096)
-    DIM_BLOCK = min(max(DIM_BLOCK, pl.next_power_of_2(D)), 4096)
-
     kernel = make_kernel_from_pallas(
         functools.partial(
             _bgmv,
@@ -128,15 +124,15 @@ def bgmv_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor):
     # register. This has to happen in pytorch, doing it in Jax will lead to NaNs
     pad_L = 0
     if LORA_BLOCK > L or L % LORA_BLOCK != 0:
-        pad_L = (L // LORA_BLOCK + 1) * LORA_BLOCK - L
+        pad_L = next_multiple_of(L, LORA_BLOCK) - L
 
     pad_D = 0
     if DIM_BLOCK > D or D % DIM_BLOCK != 0:
-        pad_D = (D // DIM_BLOCK + 1) * DIM_BLOCK - D
+        pad_D = next_multiple_of(D, DIM_BLOCK) - D
 
     pad_T = 0
     if TOKEN_BLOCK > T or T % TOKEN_BLOCK != 0:
-        pad_T = (T // TOKEN_BLOCK + 1) * TOKEN_BLOCK - T
+        pad_T = next_multiple_of(T, TOKEN_BLOCK) - T
 
     if pad_D != 0 or pad_L != 0:
         loras = torch.nn.functional.pad(loras, (0, pad_D, 0, pad_L, 0, 0))
@@ -159,3 +155,12 @@ def bgmv_non_xla(inputs: torch.Tensor, loras: torch.Tensor,
     _, L, _ = loras.shape
 
     return torch.empty((T, L), device=inputs.device)
+
+
+def next_multiple_of(n: int, mult: int) -> int:
+    if n % mult == 0:
+        return n
+    return (n // mult + 1) * mult
+
+def get_bounded_value(_min: int, val: int, _max: int) -> int:
+    return min(max(_min, val), _max)
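
For illustration only, here is a minimal, standalone sketch of how the new helpers combine to pick `TOKEN_BLOCK` and the padding amounts; the example shapes `T`, `L`, `D` below are assumptions, not values taken from this commit:

```python
# Standalone copies of the helpers added in this commit.

def next_multiple_of(n: int, mult: int) -> int:
    # Round n up to the nearest multiple of mult (n itself if already aligned).
    if n % mult == 0:
        return n
    return (n // mult + 1) * mult


def get_bounded_value(_min: int, val: int, _max: int) -> int:
    # Clamp val into the inclusive range [_min, _max].
    return min(max(_min, val), _max)


# Assumed example shapes (not from the commit): T tokens, L lora dim, D hidden dim.
T, L, D = 37, 16, 4096

# TOKEN_BLOCK is T rounded up to a multiple of 16, clamped to [16, 128].
TOKEN_BLOCK = get_bounded_value(16, next_multiple_of(T, 16), 128)  # -> 48

# Padding needed so each dimension becomes a whole number of blocks (shrink case).
LORA_BLOCK, DIM_BLOCK = 256, 1024
pad_T = next_multiple_of(T, TOKEN_BLOCK) - T  # -> 11
pad_L = next_multiple_of(L, LORA_BLOCK) - L   # -> 240
pad_D = next_multiple_of(D, DIM_BLOCK) - D    # -> 0
```

Compared with the previous `pl.next_power_of_2(T)` clamping, rounding `T` up to a multiple of 16 never yields a larger `TOKEN_BLOCK` (e.g. 48 vs. 64 for T = 37), so the padded token dimension stays closer to the real token count.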