Commit df79623

mtgpu: disable flash attention on qy1 (MTT S80); disable q3_k and mul_mat_batched_cublas
Signed-off-by: Xiaodong Ye <[email protected]>
1 parent 30724dc · commit df79623

File tree

3 files changed: 25 additions, 1 deletion

ggml/src/ggml-cuda.cu

Lines changed: 14 additions & 0 deletions
@@ -2785,6 +2785,12 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
                 if (op->op == GGML_OP_MUL_MAT && a->ne[3] != b->ne[3]) {
                     return false;
                 }
+#ifdef GGML_USE_MUSA
+                if (b->type == GGML_TYPE_F16 && b->ne[2]*b->ne[3] > 1 &&
+                    !ggml_is_transposed(a) && !ggml_is_transposed(b)) {
+                    return false;
+                }
+#endif // GGML_USE_MUSA
                 switch (a->type) {
                     case GGML_TYPE_F32:
                     case GGML_TYPE_F16:
@@ -2808,6 +2814,11 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
                     case GGML_TYPE_IQ3_XXS:
                     case GGML_TYPE_IQ4_NL:
                     case GGML_TYPE_IQ4_XS:
+#ifdef GGML_USE_MUSA
+                        if (a->type == GGML_TYPE_Q3_K) {
+                            return false;
+                        }
+#endif // GGML_USE_MUSA
                         return true;
                     default:
                         return false;
@@ -2924,6 +2935,9 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_OP_LEAKY_RELU:
             return true;
         case GGML_OP_FLASH_ATTN_EXT:
+#ifndef FLASH_ATTN_AVAILABLE
+            return false;
+#endif
 #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
             return (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) || op->src[0]->ne[0] == 128;
 #else
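
Note (not part of the commit): the effect of the new GGML_USE_MUSA checks can be observed through the public ggml-backend API. The sketch below is illustrative only; it assumes a MUSA build of ggml (GGML_USE_MUSA defined), a Moore Threads GPU at device 0, and the header layout of upstream ggml at the time of this commit.

// Illustrative only: probe the new supports_op behaviour via the public ggml-backend API.
// Assumes a MUSA build of ggml (GGML_USE_MUSA) and a Moore Threads GPU at device 0.
#include "ggml.h"
#include "ggml-backend.h"
#include "ggml-cuda.h"   // MUSA builds reuse the CUDA backend entry points
#include <stdio.h>

int main(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true,     // only tensor metadata is needed for supports_op
    };
    struct ggml_context * ctx = ggml_init(params);

    // Batched F16 mul_mat: src1 is F16 with ne[2]*ne[3] > 1, i.e. the shape that
    // the new GGML_USE_MUSA check in ggml_backend_cuda_supports_op rejects.
    struct ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 64, 64, 4, 1);
    struct ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 64, 64, 4, 1);
    struct ggml_tensor * c = ggml_mul_mat(ctx, a, b);

    ggml_backend_t backend = ggml_backend_cuda_init(0);  // NULL if no device is present

    // With this commit, a MUSA build prints 0 here.
    printf("batched F16 mul_mat supported: %d\n", ggml_backend_supports_op(backend, c));

    ggml_backend_free(backend);
    ggml_free(ctx);
    return 0;
}

When an op is reported as unsupported, the backend scheduler assigns it to another backend (typically the CPU), which is how the q3_k and batched-cuBLAS cases are handled here without any kernel-side changes.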

ggml/src/ggml-cuda/common.cuh

Lines changed: 6 additions & 0 deletions
@@ -50,6 +50,8 @@
 #define CC_RDNA1 (CC_OFFSET_AMD + 1010)
 #define CC_RDNA2 (CC_OFFSET_AMD + 1030)
 #define CC_RDNA3 (CC_OFFSET_AMD + 1100)
+#define CC_QY1   210
+#define CC_QY2   220

 #define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses

@@ -134,6 +136,10 @@ typedef float2 dfloat2;
 #define INT8_MMA_AVAILABLE
 #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING

+#if !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= CC_QY1)
+#define FLASH_ATTN_AVAILABLE
+#endif // !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= CC_QY1)
+
 static constexpr bool fast_fp16_available(const int cc) {
     return cc >= CC_PASCAL && cc != 610;
 }
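
Note (not part of the commit): a runtime counterpart to this compile-time guard, written in the style of the existing helpers such as fast_fp16_available(), could look like the sketch below. flash_attn_available_musa() is hypothetical and does not exist in ggml; the commit gates flash attention at compile time instead.

// Hypothetical helper in the style of fast_fp16_available(); NOT part of this
// commit. cc is the MUSA compute capability (CC_QY1 == 210 for the MTT S80).
static constexpr bool flash_attn_available_musa(const int cc) {
    return cc > CC_QY1;
}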

ggml/src/ggml-cuda/fattn-tile-f32.cu

Lines changed: 5 additions & 1 deletion
@@ -44,13 +44,17 @@ static __global__ void flash_attn_tile_ext_f32(
         const int ne1,
         const int ne2,
         const int ne3) {
+#ifndef FLASH_ATTN_AVAILABLE
+    NO_DEVICE_CODE;
+    return;
+#endif // FLASH_ATTN_AVAILABLE
     // Skip unused kernel variants for faster compilation:
     if (use_logit_softcap && !(D == 128 || D == 256)) {
         NO_DEVICE_CODE;
         return;
     }

-    //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
+    // In this kernel Q, K, V are matrices while i, j, k are matrix indices.

     const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
     const int ip  = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
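
Note (not part of the commit): the compile-out idiom used above can be reproduced outside ggml. The sketch below is plain CUDA, using __CUDA_ARCH__ in place of __MUSA_ARCH__; every name in it is illustrative and nothing is taken from ggml.

// Standalone illustration of the compile-out idiom, in plain CUDA.
// Build with e.g.: nvcc -arch=sm_70 demo.cu
#include <cstdio>

#define MY_MIN_ARCH 700  // require Volta or newer for the "real" kernel body

// Defined on the host pass (where __CUDA_ARCH__ is undefined) and on every
// device pass whose architecture is at least MY_MIN_ARCH.
#if !(defined(__CUDA_ARCH__) && __CUDA_ARCH__ < MY_MIN_ARCH)
#define MY_FEATURE_AVAILABLE
#endif

__global__ void gated_kernel(int * out) {
#ifndef MY_FEATURE_AVAILABLE
    // Device passes below MY_MIN_ARCH compile an empty body.
    (void) out;
    return;
#else
    out[threadIdx.x] = (int) threadIdx.x;
#endif
}

int main() {
    int * d_out;
    cudaMalloc(&d_out, 32 * sizeof(int));
    cudaMemset(d_out, 0, 32 * sizeof(int));   // so the gated case reads back zeros

    gated_kernel<<<1, 32>>>(d_out);
    cudaDeviceSynchronize();

    int h_out[32];
    cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
    printf("out[5] = %d\n", h_out[5]);  // 5 where the feature is compiled in, 0 otherwise

    cudaFree(d_out);
    return 0;
}

The in-kernel early return keeps the body from being compiled for architectures where it cannot work, while the supports_op change in ggml-cuda.cu keeps the op from being scheduled onto the backend at all.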
