Skip to content

Commit 9bcc88a

Browse files
JohannesGaessler authored and hodlen committed
CUDA: more warps for mmvq on NVIDIA (ggml-org#5394)
1 parent 1ea16d0 commit 9bcc88a

File tree

1 file changed

+86
-47
lines changed

1 file changed

+86
-47
lines changed

ggml-cuda.cu

Lines changed: 86 additions & 47 deletions
Original file line number | Diff line number | Diff line change
@@ -5310,45 +5310,65 @@ template <bool need_check> static __global__ void
53105310
#endif // __CUDA_ARCH__ >= CC_VOLTA
53115311
}
53125312

5313-
template <int ncols_y_template, int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
5313+
#define MMVQ_NWARPS_NVIDIA 4
5314+
#define MMVQ_NWARPS_AMD_RDNA2 1
5315+
#define MMVQ_NWARPS_AMD_OLD 4
5316+
5317+
template <int nwarps, int ncols_y_template, int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
5318+
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
5319+
__launch_bounds__(nwarps*WARP_SIZE, 1) // tells the compiler to use as many registers as it wants
5320+
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
53145321
static __global__ void mul_mat_vec_q(
53155322
const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
53165323
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y_par, const int nrows_dst) {
53175324

53185325
const int ncols_y = ncols_y_template != 0 ? ncols_y_template : ncols_y_par;
53195326

5320-
const int row = blockIdx.x*blockDim.y + threadIdx.y;
5321-
5322-
if (row >= nrows_x) {
5323-
return;
5324-
}
5327+
const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
5328+
const int row = blockIdx.x;
53255329

53265330
const int blocks_per_row_x = ncols_x / qk;
53275331
const int blocks_per_col_y = nrows_y / QK8_1;
5328-
const int blocks_per_warp = vdr * WARP_SIZE / qi;
5332+
const int blocks_per_iter = vdr * nwarps*WARP_SIZE / qi;
53295333

53305334
// partial sum for each thread
53315335
float tmp[ncols_y_template != 0 ? ncols_y_template : 8] = {0.0f};
53325336

53335337
const block_q_t * x = (const block_q_t *) vx;
53345338
const block_q8_1 * y = (const block_q8_1 *) vy;
53355339

5336-
for (int i = threadIdx.x / (qi/vdr); i < blocks_per_row_x; i += blocks_per_warp) {
5340+
for (int i = tid / (qi/vdr); i < blocks_per_row_x; i += blocks_per_iter) {
53375341
const int ibx = row*blocks_per_row_x + i; // x block index
53385342

53395343
const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
53405344

5341-
const int iqs = vdr * (threadIdx.x % (qi/vdr)); // x block quant index when casting the quants to int
5345+
const int iqs = vdr * (tid % (qi/vdr)); // x block quant index when casting the quants to int
53425346

53435347
#pragma unroll
53445348
for (int j = 0; j < ncols_y; ++j) {
53455349
tmp[j] += vec_dot_q_cuda(&x[ibx], &y[j*blocks_per_col_y + iby], iqs);
53465350
}
53475351
}
53485352

5353+
__shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_y_template != 0 ? ncols_y_template : 8][WARP_SIZE];
5354+
if (threadIdx.y > 0) {
5355+
#pragma unroll
5356+
for (int j = 0; j < ncols_y; ++j) {
5357+
tmp_shared[threadIdx.y-1][j][threadIdx.x] = tmp[j];
5358+
}
5359+
}
5360+
__syncthreads();
5361+
if (threadIdx.y > 0) {
5362+
return;
5363+
}
5364+
53495365
// sum up partial sums and write back result
53505366
#pragma unroll
53515367
for (int j = 0; j < ncols_y; ++j) {
5368+
#pragma unroll
5369+
for (int i = 0; i < nwarps-1; ++i) {
5370+
tmp[j] += tmp_shared[i][j][threadIdx.x];
5371+
}
53525372
tmp[j] = warp_reduce_sum(tmp[j]);
53535373

53545374
if (threadIdx.x == 0) {
@@ -6833,46 +6853,65 @@ static void mul_mat_vec_q_cuda(
68336853
GGML_ASSERT(ncols_x % qk == 0);
68346854
GGML_ASSERT(ncols_y <= 4);
68356855

6836-
const int block_num_y = (nrows_x + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
6837-
const dim3 block_nums(block_num_y, 1, 1);
6838-
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
6839-
switch (ncols_y) {
6840-
case 1:
6841-
mul_mat_vec_q<1, qk, qi, block_q_t, vdr, vec_dot>
6842-
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
6843-
break;
6844-
case 2:
6845-
mul_mat_vec_q<2, qk, qi, block_q_t, vdr, vec_dot>
6846-
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
6847-
break;
6848-
case 3:
6849-
mul_mat_vec_q<3, qk, qi, block_q_t, vdr, vec_dot>
6850-
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
6851-
break;
6852-
case 4:
6853-
mul_mat_vec_q<4, qk, qi, block_q_t, vdr, vec_dot>
6854-
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
6855-
break;
6856-
// case 5:
6857-
// mul_mat_vec_q<5, qk, qi, block_q_t, vdr, vec_dot>
6858-
// <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
6859-
// break;
6860-
// case 6:
6861-
// mul_mat_vec_q<6, qk, qi, block_q_t, vdr, vec_dot>
6862-
// <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
6863-
// break;
6864-
// case 7:
6865-
// mul_mat_vec_q<7, qk, qi, block_q_t, vdr, vec_dot>
6866-
// <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
6867-
// break;
6868-
// case 8:
6869-
// mul_mat_vec_q<8, qk, qi, block_q_t, vdr, vec_dot>
6870-
// <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
6871-
// break;
6856+
int id;
6857+
CUDA_CHECK(cudaGetDevice(&id));
6858+
6859+
int nwarps;
6860+
if (g_device_caps[id].cc >= CC_OFFSET_AMD) {
6861+
nwarps = g_device_caps[id].cc >= CC_RDNA2 ? MMVQ_NWARPS_AMD_RDNA2 : MMVQ_NWARPS_AMD_OLD;
6862+
} else {
6863+
nwarps = MMVQ_NWARPS_NVIDIA;
6864+
}
6865+
6866+
const dim3 block_nums(nrows_x, 1, 1);
6867+
const dim3 block_dims(WARP_SIZE, nwarps, 1);
6868+
6869+
switch (nwarps) {
6870+
case 1: switch(ncols_y) {
6871+
case 1:
6872+
mul_mat_vec_q<1, 1, qk, qi, block_q_t, vdr, vec_dot>
6873+
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
6874+
break;
6875+
case 2:
6876+
mul_mat_vec_q<1, 2, qk, qi, block_q_t, vdr, vec_dot>
6877+
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
6878+
break;
6879+
case 3:
6880+
mul_mat_vec_q<1, 3, qk, qi, block_q_t, vdr, vec_dot>
6881+
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
6882+
break;
6883+
case 4:
6884+
mul_mat_vec_q<1, 4, qk, qi, block_q_t, vdr, vec_dot>
6885+
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
6886+
break;
6887+
default:
6888+
GGML_ASSERT(false);
6889+
break;
6890+
} break;
6891+
case 4: switch(ncols_y) {
6892+
case 1:
6893+
mul_mat_vec_q<4, 1, qk, qi, block_q_t, vdr, vec_dot>
6894+
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
6895+
break;
6896+
case 2:
6897+
mul_mat_vec_q<4, 2, qk, qi, block_q_t, vdr, vec_dot>
6898+
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
6899+
break;
6900+
case 3:
6901+
mul_mat_vec_q<4, 3, qk, qi, block_q_t, vdr, vec_dot>
6902+
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
6903+
break;
6904+
case 4:
6905+
mul_mat_vec_q<4, 4, qk, qi, block_q_t, vdr, vec_dot>
6906+
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
6907+
break;
6908+
default:
6909+
GGML_ASSERT(false);
6910+
break;
6911+
} break;
6912+
68726913
default:
68736914
GGML_ASSERT(false);
6874-
// mul_mat_vec_q<0, qk, qi, block_q_t, vdr, vec_dot>
6875-
// <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
68766915
break;
68776916
}
68786917
}

0 commit comments

Comments (0)