@@ -11,7 +11,7 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
 
 bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * scr0_ne, const int src1_ncols);
 
-template <typename T, int rows_per_block, int cols_per_block, int nwarps, bool has_ids, size_t n_expert_used = 0>
+template <typename T, int rows_per_block, int cols_per_block, int nwarps, bool has_ids>
 __launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1)
 static __global__ void mul_mat_f(
         const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
@@ -57,10 +57,8 @@ static __global__ void mul_mat_f(
     T * tile_xy = (T *) compute_base + threadIdx.y*(tile_A::I * tile_k_padded);
 
     if constexpr (has_ids) {
-
         int found = 0;
 
-        #pragma unroll
         for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) {
             const int j = j0 + threadIdx.y;
             const int32_t * __restrict__ id_row = ids + j*stride_row_id;
@@ -69,28 +67,14 @@ static __global__ void mul_mat_f(
                 slot_map[j] = -1;
             }
 
-            if constexpr (n_expert_used == 0) {
-                for (int k_base = 0; k_base < nchannels_dst; k_base += warp_size) {
-                    int k = k_base + threadIdx.x;
-                    int match = (k < nchannels_dst) && (id_row[k*stride_col_id] == expert_idx);
+            for (int k_base = 0; k_base < nchannels_dst; k_base += warp_size) {
+                int k = k_base + threadIdx.x;
+                int match = (k < nchannels_dst) && (id_row[k*stride_col_id] == expert_idx);
 
-                    if (match) {
-                        slot_map[j] = k;
-                        found = 1;
-                        break;
-                    }
-                }
-            } else {
-                #pragma unroll
-                for (int k_base = 0; k_base < n_expert_used; k_base += warp_size) {
-                    int k = k_base + threadIdx.x;
-                    int match = (k < n_expert_used) && (id_row[k*stride_col_id] == expert_idx);
-
-                    if (match) {
-                        slot_map[j] = k;
-                        found = 1;
-                        break;
-                    }
+                if (match) {
+                    slot_map[j] = k;
+                    found = 1;
+                    break;
                 }
             }
         }
@@ -215,71 +199,6 @@ static __global__ void mul_mat_f(
 #endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
 }
 
-template <typename T, int rows_per_block, int cols_per_block, int nwarps>
-static inline void launch_mul_mat_ids(
-        const T * x, const float * y, const int32_t * ids, float * dst,
-        const int64_t ncols_x, const int64_t nchannels_dst,
-        const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
-        const int64_t stride_col_id, const int64_t stride_row_id,
-        const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
-        const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
-        const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared_total, cudaStream_t stream) {
-
-    const int n_expert_used = nchannels_dst;
-
-    switch (n_expert_used) {
-        case 1: {
-            mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true, 1><<<block_nums, block_dims, nbytes_shared_total, stream>>>
-                (x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
-                stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
-        } break;
-        case 2: {
-            mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true, 2><<<block_nums, block_dims, nbytes_shared_total, stream>>>
-                (x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
-                stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
-        } break;
-        case 4: {
-            mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true, 4><<<block_nums, block_dims, nbytes_shared_total, stream>>>
-                (x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
-                stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
-        } break;
-        case 6: {
-            mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true, 6><<<block_nums, block_dims, nbytes_shared_total, stream>>>
-                (x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
-                stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
-        } break;
-        case 8: {
-            mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true, 8><<<block_nums, block_dims, nbytes_shared_total, stream>>>
-                (x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
-                stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
-        } break;
-        case 16: {
-            mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true, 16><<<block_nums, block_dims, nbytes_shared_total, stream>>>
-                (x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
-                stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
-        } break;
-        case 32: {
-            mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true, 32><<<block_nums, block_dims, nbytes_shared_total, stream>>>
-                (x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
-                stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
-        } break;
-        default: {
-            mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true, 0><<<block_nums, block_dims, nbytes_shared_total, stream>>>
-                (x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
-                stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
-        } break;
-    }
-}
-
-
 template <typename T, int cols_per_block, int nwarps>
 static inline void mul_mat_f_switch_ids(
         const T * x, const float * y, const int32_t * ids, float * dst,
@@ -290,11 +209,10 @@ static inline void mul_mat_f_switch_ids(
         const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
         const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared_total, cudaStream_t stream) {
     if (ids) {
-        launch_mul_mat_ids<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps>(
-            x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+        mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true><<<block_nums, block_dims, nbytes_shared_total, stream>>>
+            (x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
             stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-            sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst,
-            block_nums, block_dims, nbytes_shared_total, stream);
+            sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
     } else {
         mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, false><<<block_nums, block_dims, nbytes_shared_total, stream>>>
             (x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,