Skip to content

Commit f839396

Browse files
committed
WIP
1 parent 6675aa7 commit f839396

File tree

2 files changed

+78
-19
lines changed

2 files changed

+78
-19
lines changed

ggml/src/ggml-cuda/mmf.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
139139
if (type == GGML_TYPE_F32 && src1_ncols > 512) {
140140
return false;
141141
}
142-
if ((type == GGML_TYPE_F16 || type == GGML_TYPE_BF16) && src1_ncols > 1024) {
142+
if ((type == GGML_TYPE_F16 || type == GGML_TYPE_BF16) && src1_ncols > 512) {
143143
return false;
144144
}
145145
} else {

ggml/src/ggml-cuda/mmf.cuh

Lines changed: 77 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,7 @@ static __global__ void mul_mat_f_ids(
291291
extern __shared__ char data_mmv[];
292292
char * compute_base = data_mmv;
293293

294-
const float2 * y2 = (const float2 *) y;
294+
//const float2 * y2 = (const float2 *) y;
295295

296296
tile_C C[ntA][ntB];
297297

@@ -311,13 +311,12 @@ static __global__ void mul_mat_f_ids(
311311
}
312312
}
313313

314-
#pragma unroll
315-
for (int itB = 0; itB < ntB; ++itB) {
316-
if constexpr (std::is_same_v<T, float>) {
314+
if constexpr (std::is_same_v<T, float>) {
315+
float vals_buf[2][tile_B::I];
316+
auto gather_tile = [&](int tile_idx_local, float *vals) {
317317
#pragma unroll
318318
for (int j0 = 0; j0 < tile_B::I; ++j0) {
319-
const int j = j0 + itB*tile_B::I;
320-
319+
const int j = j0 + tile_idx_local*tile_B::I;
321320
const int global_j = col_base + j;
322321
float val = 0.0f;
323322
if (j < cols_per_block && global_j < ncols_expert) {
@@ -329,13 +328,48 @@ static __global__ void mul_mat_f_ids(
329328
val = y[channel*stride_channel_y + token*stride_col_y + col];
330329
}
331330
}
332-
tile_xy[j0*tile_k_padded + threadIdx.x] = val;
331+
vals[j0] = val;
333332
}
334-
} else if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
333+
};
334+
335+
if (ntB > 0) {
336+
gather_tile(0, vals_buf[0]);
337+
}
338+
339+
int curr_buf = 0;
340+
int next_buf = 1;
341+
#pragma unroll
342+
for (int itB = 0; itB < ntB; ++itB) {
335343
#pragma unroll
336344
for (int j0 = 0; j0 < tile_B::I; ++j0) {
337-
const int j = j0 + itB*tile_B::I;
345+
tile_xy[j0*tile_k_padded + threadIdx.x] = vals_buf[curr_buf][j0];
346+
}
347+
348+
if (itB + 1 < ntB) {
349+
gather_tile(itB + 1, vals_buf[next_buf]);
350+
}
338351

352+
#pragma unroll
353+
for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) {
354+
tile_B B;
355+
load_ldmatrix(B, tile_xy + k0, tile_k_padded);
356+
#pragma unroll
357+
for (int itA = 0; itA < ntA; ++itA) {
358+
mma(C[itA][itB], A[itA][k0/tile_B::J], B);
359+
}
360+
}
361+
362+
if (itB + 1 < ntB) {
363+
curr_buf ^= 1;
364+
next_buf ^= 1;
365+
}
366+
}
367+
} else if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
368+
float2 vals_buf[2][tile_B::I];
369+
auto gather_tile = [&](int tile_idx_local, float2 *vals) {
370+
#pragma unroll
371+
for (int j0 = 0; j0 < tile_B::I; ++j0) {
372+
const int j = j0 + tile_idx_local*tile_B::I;
339373
const int global_j = col_base + j;
340374
float2 tmp = make_float2(0.0f, 0.0f);
341375
if (j < cols_per_block && global_j < ncols_expert) {
@@ -344,23 +378,48 @@ static __global__ void mul_mat_f_ids(
344378
const int token = (int) qrm.x;
345379
const int channel = (int) qrm.y;
346380
if (token < ncols_dst_total) {
347-
tmp = y2[channel*stride_channel_y/2 + token*stride_col_y + col];
381+
tmp = *(const float2*) &y[channel*stride_channel_y + 2*(token*stride_col_y + col)];
348382
}
349383
}
350-
tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
384+
vals[j0] = tmp;
351385
}
352-
} else {
353-
static_assert(std::is_same_v<T, void>, "unsupported type");
386+
};
387+
388+
if (ntB > 0) {
389+
gather_tile(0, vals_buf[0]);
354390
}
391+
392+
int curr_buf = 0;
393+
int next_buf = 1;
355394
#pragma unroll
356-
for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) {
357-
tile_B B;
358-
load_ldmatrix(B, tile_xy + k0, tile_k_padded);
395+
for (int itB = 0; itB < ntB; ++itB) {
359396
#pragma unroll
360-
for (int itA = 0; itA < ntA; ++itA) {
361-
mma(C[itA][itB], A[itA][k0/tile_B::J], B);
397+
for (int j0 = 0; j0 < tile_B::I; ++j0) {
398+
const float2 tmp = vals_buf[curr_buf][j0];
399+
tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
400+
}
401+
402+
if (itB + 1 < ntB) {
403+
gather_tile(itB + 1, vals_buf[next_buf]);
404+
}
405+
406+
#pragma unroll
407+
for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) {
408+
tile_B B;
409+
load_ldmatrix(B, tile_xy + k0, tile_k_padded);
410+
#pragma unroll
411+
for (int itA = 0; itA < ntA; ++itA) {
412+
mma(C[itA][itB], A[itA][k0/tile_B::J], B);
413+
}
414+
}
415+
416+
if (itB + 1 < ntB) {
417+
curr_buf ^= 1;
418+
next_buf ^= 1;
362419
}
363420
}
421+
} else {
422+
static_assert(std::is_same_v<T, void>, "unsupported type");
364423
}
365424
}
366425

0 commit comments

Comments
 (0)