Skip to content

Commit f195490

Browse files
try AMD fix (use 4 MMVQ warps on pre-RDNA2 AMD GPUs, keep 1 warp on RDNA2 and newer)
1 parent 7a0f63a commit f195490

File tree

1 file changed

+9
-3
lines changed

1 file changed

+9
-3
lines changed

ggml-cuda.cu

Lines changed: 9 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -5310,8 +5310,9 @@ template <bool need_check> static __global__ void
5310 5310
#endif // __CUDA_ARCH__ >= CC_VOLTA
5311 5311
}
5312 5312

5313-
#define MMVQ_NWARPS_NVIDIA 4
5314-
#define MMVQ_NWARPS_AMD 1
5313+
#define MMVQ_NWARPS_NVIDIA 4
5314+
#define MMVQ_NWARPS_AMD_RDNA2 1
5315+
#define MMVQ_NWARPS_AMD_OLD 4
5315 5316

5316 5317
template <int nwarps, int ncols_y_template, int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
5317 5318
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
@@ -6855,7 +6856,12 @@ static void mul_mat_vec_q_cuda(
6855 6856
int id;
6856 6857
CUDA_CHECK(cudaGetDevice(&id));
6857 6858

6858-
const int nwarps = g_device_caps[id].cc >= CC_OFFSET_AMD ? MMVQ_NWARPS_AMD : MMVQ_NWARPS_NVIDIA;
6859+
int nwarps;
6860+
if (g_device_caps[id].cc >= CC_OFFSET_AMD) {
6861+
nwarps = g_device_caps[id].cc >= CC_RDNA2 ? MMVQ_NWARPS_AMD_RDNA2 : MMVQ_NWARPS_AMD_OLD;
6862+
} else {
6863+
nwarps = MMVQ_NWARPS_NVIDIA;
6864+
}
6859 6865

6860 6866
const dim3 block_nums(nrows_x, 1, 1);
6861 6867
const dim3 block_dims(WARP_SIZE, nwarps, 1);

0 commit comments

Comments
 (0)