Skip to content

Commit bf08ea5

Browse files
committed
Remove tempalte from n_expert_used
1 parent bb831b2 commit bf08ea5

File tree

1 file changed

+11
-93
lines changed

1 file changed

+11
-93
lines changed

ggml/src/ggml-cuda/mmf.cuh

Lines changed: 11 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
1111

1212
bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * scr0_ne, const int src1_ncols);
1313

14-
template <typename T, int rows_per_block, int cols_per_block, int nwarps, bool has_ids, size_t n_expert_used = 0>
14+
template <typename T, int rows_per_block, int cols_per_block, int nwarps, bool has_ids>
1515
__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1)
1616
static __global__ void mul_mat_f(
1717
const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
@@ -57,10 +57,8 @@ static __global__ void mul_mat_f(
5757
T * tile_xy = (T *) compute_base + threadIdx.y*(tile_A::I * tile_k_padded);
5858

5959
if constexpr (has_ids) {
60-
6160
int found = 0;
6261

63-
#pragma unroll
6462
for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) {
6563
const int j = j0 + threadIdx.y;
6664
const int32_t * __restrict__ id_row = ids + j*stride_row_id;
@@ -69,28 +67,14 @@ static __global__ void mul_mat_f(
6967
slot_map[j] = -1;
7068
}
7169

72-
if constexpr (n_expert_used == 0) {
73-
for (int k_base = 0; k_base < nchannels_dst; k_base += warp_size) {
74-
int k = k_base + threadIdx.x;
75-
int match = (k < nchannels_dst) && (id_row[k*stride_col_id] == expert_idx);
70+
for (int k_base = 0; k_base < nchannels_dst; k_base += warp_size) {
71+
int k = k_base + threadIdx.x;
72+
int match = (k < nchannels_dst) && (id_row[k*stride_col_id] == expert_idx);
7673

77-
if (match) {
78-
slot_map[j] = k;
79-
found = 1;
80-
break;
81-
}
82-
}
83-
} else {
84-
#pragma unroll
85-
for (int k_base = 0; k_base < n_expert_used; k_base += warp_size) {
86-
int k = k_base + threadIdx.x;
87-
int match = (k < n_expert_used) && (id_row[k*stride_col_id] == expert_idx);
88-
89-
if (match) {
90-
slot_map[j] = k;
91-
found = 1;
92-
break;
93-
}
74+
if (match) {
75+
slot_map[j] = k;
76+
found = 1;
77+
break;
9478
}
9579
}
9680
}
@@ -215,71 +199,6 @@ static __global__ void mul_mat_f(
215199
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
216200
}
217201

218-
template<typename T, int rows_per_block, int cols_per_block, int nwarps>
219-
static inline void launch_mul_mat_ids(
220-
const T * x, const float * y, const int32_t * ids, float * dst,
221-
const int64_t ncols_x, const int64_t nchannels_dst,
222-
const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
223-
const int64_t stride_col_id, const int64_t stride_row_id,
224-
const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
225-
const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
226-
const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared_total, cudaStream_t stream) {
227-
228-
const int n_expert_used = nchannels_dst;
229-
230-
switch (n_expert_used) {
231-
case 1: {
232-
mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true, 1><<<block_nums, block_dims, nbytes_shared_total, stream>>>
233-
(x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
234-
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
235-
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
236-
} break;
237-
case 2: {
238-
mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true, 2><<<block_nums, block_dims, nbytes_shared_total, stream>>>
239-
(x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
240-
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
241-
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
242-
} break;
243-
case 4: {
244-
mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true, 4><<<block_nums, block_dims, nbytes_shared_total, stream>>>
245-
(x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
246-
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
247-
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
248-
} break;
249-
case 6: {
250-
mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true, 6><<<block_nums, block_dims, nbytes_shared_total, stream>>>
251-
(x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
252-
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
253-
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
254-
} break;
255-
case 8: {
256-
mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true, 8><<<block_nums, block_dims, nbytes_shared_total, stream>>>
257-
(x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
258-
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
259-
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
260-
} break;
261-
case 16: {
262-
mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true, 16><<<block_nums, block_dims, nbytes_shared_total, stream>>>
263-
(x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
264-
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
265-
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
266-
} break;
267-
case 32: {
268-
mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true, 32><<<block_nums, block_dims, nbytes_shared_total, stream>>>
269-
(x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
270-
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
271-
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
272-
} break;
273-
default: {
274-
mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true, 0><<<block_nums, block_dims, nbytes_shared_total, stream>>>
275-
(x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
276-
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
277-
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
278-
} break;
279-
}
280-
}
281-
282-
283202
template<typename T, int cols_per_block, int nwarps>
284203
static inline void mul_mat_f_switch_ids(
285204
const T * x, const float * y, const int32_t * ids, float * dst,
@@ -290,11 +209,10 @@ static inline void mul_mat_f_switch_ids(
290209
const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
291210
const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared_total, cudaStream_t stream) {
292211
if (ids) {
293-
launch_mul_mat_ids<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps>(
294-
x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
212+
mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true><<<block_nums, block_dims, nbytes_shared_total, stream>>>
213+
(x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
295214
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
296-
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst,
297-
block_nums, block_dims, nbytes_shared_total, stream);
215+
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
298216
} else {
299217
mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, false><<<block_nums, block_dims, nbytes_shared_total, stream>>>
300218
(x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,

0 commit comments

Comments
 (0)