@@ -141,10 +141,9 @@ template <ggml_type type, int ncols_dst>
141141__launch_bounds__ (calc_nwarps(ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
142142static __global__ void mul_mat_vec_q(
143143 const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, float * __restrict__ dst,
144- const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,
145- const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x,
146- const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio,
147- const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst) {
144+ const int ncols_x, const int nchannels_y, const int stride_row_x, const int stride_col_y, const int stride_col_dst,
145+ const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
146+ const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
148147
149148 constexpr int qk = ggml_cuda_type_traits<type>::qk;
150149 constexpr int qi = ggml_cuda_type_traits<type>::qi;
@@ -162,12 +161,12 @@ static __global__ void mul_mat_vec_q(
162161 constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi;
163162
164163 // The MUL_MAT_ID code path with ids != nullptr is only implemented for ncols_dst == 1.
165- const uint32_t channel_dst = blockIdx .y ;
166- const uint32_t channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : fastdiv ( channel_dst, channel_ratio) ;
167- const uint32_t channel_y = ncols_dst == 1 && ids ? fastmodulo ( channel_dst, nchannels_y) : channel_dst;
168- const uint32_t sample_dst = blockIdx .z ;
169- const uint32_t sample_x = fastdiv ( sample_dst, sample_ratio) ;
170- const uint32_t sample_y = sample_dst;
164+ const int channel_dst = blockIdx .y ;
165+ const int channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : channel_dst / channel_ratio;
166+ const int channel_y = ncols_dst == 1 && ids ? channel_dst % nchannels_y : channel_dst;
167+ const int sample_dst = blockIdx .z ;
168+ const int sample_x = sample_dst / sample_ratio;
169+ const int sample_y = sample_dst;
171170
172171 // partial sum for each thread
173172 float tmp[ncols_dst][rows_per_cuda_block] = {{0 .0f }};
@@ -248,80 +247,95 @@ static void mul_mat_vec_q_switch_ncols_dst(
248247 GGML_ASSERT (ncols_x % ggml_blck_size (type) == 0 );
249248 GGML_ASSERT (ncols_dst <= MMVQ_MAX_BATCH_SIZE);
250249
251- const uint3 nchannels_y_fd = ids ? init_fastdiv_values (nchannels_y) : make_uint3 (0 , 0 , 0 );
252- const uint3 channel_ratio_fd = ids ? make_uint3 (0 , 0 , 0 ) : init_fastdiv_values (nchannels_dst / nchannels_x);
253- const uint3 sample_ratio_fd = init_fastdiv_values (nsamples_dst / nsamples_x);
250+ const int channel_ratio = nchannels_dst / nchannels_x;
251+ const int sample_ratio = nsamples_dst / nsamples_x;
254252
255253 const int device = ggml_cuda_get_device ();
256254 const int warp_size = ggml_cuda_info ().devices [device].warp_size ;
257255 const mmvq_parameter_table_id table_id = get_device_table_id (ggml_cuda_info ().devices [device].cc );
258256
259257 GGML_ASSERT (!ids || ncols_dst == 1 );
260258 switch (ncols_dst) {
261- case 1 : {
259+ case 1 :
260+ {
262261 constexpr int c_ncols_dst = 1 ;
263262 std::pair<dim3 , dim3 > dims = calc_launch_params (c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
264263 mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0 , stream>>>
265- (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
266- channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
267- sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
268- } break ;
269- case 2 : {
264+ (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
265+ channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
266+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
267+ break ;
268+ }
269+ case 2 :
270+ {
270271 constexpr int c_ncols_dst = 2 ;
271272 std::pair<dim3 , dim3 > dims = calc_launch_params (c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
272273 mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0 , stream>>>
273- (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
274- channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
275- sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
276- } break ;
277- case 3 : {
274+ (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
275+ channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
276+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
277+ break ;
278+ }
279+ case 3 :
280+ {
278281 constexpr int c_ncols_dst = 3 ;
279282 std::pair<dim3 , dim3 > dims = calc_launch_params (c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
280283 mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0 , stream>>>
281- (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
282- channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
283- sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
284- } break ;
285- case 4 : {
284+ (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
285+ channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
286+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
287+ break ;
288+ }
289+ case 4 :
290+ {
286291 constexpr int c_ncols_dst = 4 ;
287292 std::pair<dim3 , dim3 > dims = calc_launch_params (c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
288293 mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0 , stream>>>
289- (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
290- channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
291- sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
292- } break ;
293- case 5 : {
294+ (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
295+ channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
296+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
297+ break ;
298+ }
299+ case 5 :
300+ {
294301 constexpr int c_ncols_dst = 5 ;
295302 std::pair<dim3 , dim3 > dims = calc_launch_params (c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
296303 mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0 , stream>>>
297- (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
298- channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
299- sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
300- } break ;
301- case 6 : {
304+ (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
305+ channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
306+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
307+ break ;
308+ }
309+ case 6 :
310+ {
302311 constexpr int c_ncols_dst = 6 ;
303312 std::pair<dim3 , dim3 > dims = calc_launch_params (c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
304313 mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0 , stream>>>
305- (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
306- channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
307- sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
308- } break ;
309- case 7 : {
314+ (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
315+ channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
316+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
317+ break ;
318+ }
319+ case 7 :
320+ {
310321 constexpr int c_ncols_dst = 7 ;
311322 std::pair<dim3 , dim3 > dims = calc_launch_params (c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
312323 mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0 , stream>>>
313- (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
314- channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
315- sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
316- } break ;
317- case 8 : {
324+ (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
325+ channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
326+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
327+ break ;
328+ }
329+ case 8 :
330+ {
318331 constexpr int c_ncols_dst = 8 ;
319332 std::pair<dim3 , dim3 > dims = calc_launch_params (c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
320333 mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0 , stream>>>
321- (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
322- channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
323- sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
324- } break ;
334+ (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
335+ channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
336+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
337+ break ;
338+ }
325339 default :
326340 GGML_ABORT (" fatal error" );
327341 break ;
0 commit comments