Skip to content

Commit 8a6cfa4

Browse files
committed
CUDA: MUL_MAT_ID optimizations for mmf
1 parent: 10d8b2b · commit: 8a6cfa4

File tree

1 file changed

+25
-33
lines changed

1 file changed

+25
-33
lines changed

ggml/src/ggml-cuda/mmf.cuh

Lines changed: 25 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -57,31 +57,37 @@ static __global__ void mul_mat_f(
5757
T * tile_xy = (T *) compute_base + threadIdx.y*(tile_A::I * tile_k_padded);
5858

5959
if constexpr (has_ids) {
60-
__shared__ int has_any;
61-
if (threadIdx.y == 0) {
62-
int local_has_any = 0;
63-
for (int j = threadIdx.x; j < cols_per_block; j += warp_size) {
64-
int slot = -1;
65-
for (int k = 0; k < nchannels_dst; ++k) {
66-
const int idv = ids[j*stride_row_id + k*stride_col_id];
67-
if (idv == expert_idx) {
68-
slot = k;
69-
break;
60+
int found = 0;
61+
62+
for (int j = threadIdx.y; j < cols_per_block; j += nwarps) {
63+
const int32_t * __restrict__ id_row = ids + j*stride_row_id;
64+
65+
if (threadIdx.x == 0) {
66+
slot_map[j] = -1;
67+
}
68+
69+
for (int k_base = 0; k_base < nchannels_dst; k_base += warp_size) {
70+
int k = k_base + threadIdx.x;
71+
int match = (k < nchannels_dst) && (id_row[k*stride_col_id] == expert_idx);
72+
73+
unsigned mask = __ballot_sync(0xffffffff, match);
74+
if (mask) {
75+
int leader = __ffs(mask) - 1;
76+
if (threadIdx.x == leader) {
77+
slot_map[j] = k_base + leader;
7078
}
71-
}
72-
if (j < cols_per_block) {
73-
local_has_any |= (slot >= 0);
74-
slot_map[j] = slot;
79+
found = 1;
80+
break;
7581
}
7682
}
77-
has_any = warp_reduce_any(local_has_any);
7883
}
79-
__syncthreads();
80-
if (has_any == 0) {
84+
85+
if (!__syncthreads_or(found)) {
8186
return;
8287
}
8388
}
8489

90+
8591
for (int col = threadIdx.y*warp_size + threadIdx.x; col < ncols; col += nwarps*warp_size) {
8692
tile_A A[ntA][warp_size / tile_A::J];
8793
#pragma unroll
@@ -106,14 +112,7 @@ static __global__ void mul_mat_f(
106112
if constexpr (!has_ids) {
107113
tile_xy[j0*tile_k_padded + threadIdx.x] = j < cols_per_block ? y[j*stride_col_y + col] : 0.0f;
108114
} else {
109-
float val = 0.0f;
110-
if (j < cols_per_block) {
111-
const int slot = slot_map[j];
112-
if (slot >= 0) {
113-
val = y[slot*stride_channel_y + j*stride_col_y + col];
114-
}
115-
}
116-
tile_xy[j0*tile_k_padded + threadIdx.x] = val;
115+
tile_xy[j0*tile_k_padded + threadIdx.x] = j < cols_per_block ? y[slot_map[j]*stride_channel_y + j*stride_col_y + col] : 0.0f;
117116
}
118117
}
119118
} else if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
@@ -125,14 +124,7 @@ static __global__ void mul_mat_f(
125124
const float2 tmp = j < cols_per_block ? y2[j*stride_col_y + col] : make_float2(0.0f, 0.0f);
126125
tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
127126
} else {
128-
float2 tmp = make_float2(0.0f, 0.0f);
129-
if (j < cols_per_block) {
130-
const int slot = slot_map[j];
131-
if (slot >= 0) {
132-
const float2 * y2_slot = (const float2 *)(y + slot*stride_channel_y);
133-
tmp = y2_slot[j*stride_col_y + col];
134-
}
135-
}
127+
float2 tmp = j < cols_per_block && slot_map[j] >= 0 ? *(const float2*) &y[slot_map[j]*stride_channel_y + 2*(j*stride_col_y + col)] : make_float2(0.0f, 0.0f);
136128
tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
137129
}
138130
}

0 commit comments

Comments
 (0)