Commit abbc1f2 (1 parent: 30724dc)

mtgpu: disable flash attention on qy1 (MTT S80)

Signed-off-by: Xiaodong Ye <[email protected]>

3 files changed: 16 additions (+), 1 deletion (−)

common/arg.cpp

Lines changed: 5 additions & 0 deletions
@@ -703,7 +703,12 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
         {"-fa", "--flash-attn"},
         format("enable Flash Attention (default: %s)", params.flash_attn ? "enabled" : "disabled"),
         [](gpt_params & params) {
+#ifdef FLASH_ATTN_AVAILABLE
             params.flash_attn = true;
+#else
+            GGML_UNUSED(params);
+            fprintf(stderr, "warning: flash attention is not supported\n");
+#endif // FLASH_ATTN_AVAILABLE
         }
     ).set_env("LLAMA_ARG_FLASH_ATTN"));
     add_opt(llama_arg(
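
The gate above makes -fa degrade gracefully: when FLASH_ATTN_AVAILABLE is undefined, the flag prints a warning instead of enabling an unsupported path. A minimal standalone sketch of the same pattern, where gpt_params, GGML_UNUSED, and the handler are stripped-down stand-ins rather than the real llama.cpp definitions:

// Standalone sketch of the compile-time gate in the lambda above; all
// names here are local stand-ins for the real llama.cpp definitions.
#include <cstdio>

struct gpt_params { bool flash_attn = false; };
#define GGML_UNUSED(x) (void)(x)
// #define FLASH_ATTN_AVAILABLE   // defined by the build on supported GPUs

static void on_flash_attn_flag(gpt_params & params) {
#ifdef FLASH_ATTN_AVAILABLE
    params.flash_attn = true;
#else
    GGML_UNUSED(params);   // keep the signature identical in both branches
    fprintf(stderr, "warning: flash attention is not supported\n");
#endif // FLASH_ATTN_AVAILABLE
}

int main() {
    gpt_params params;
    on_flash_attn_flag(params);   // simulates passing -fa on the command line
    fprintf(stderr, "flash_attn = %s\n", params.flash_attn ? "true" : "false");
    return 0;
}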

ggml/src/ggml-cuda/common.cuh

Lines changed: 6 additions & 0 deletions
@@ -50,6 +50,8 @@
 #define CC_RDNA1 (CC_OFFSET_AMD + 1010)
 #define CC_RDNA2 (CC_OFFSET_AMD + 1030)
 #define CC_RDNA3 (CC_OFFSET_AMD + 1100)
+#define CC_QY1 210
+#define CC_QY2 220
 
 #define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
 
@@ -134,6 +136,10 @@ typedef float2 dfloat2;
 #define INT8_MMA_AVAILABLE
 #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING
 
+#if !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= CC_QY1)
+#define FLASH_ATTN_AVAILABLE
+#endif // !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= CC_QY1)
+
 static constexpr bool fast_fp16_available(const int cc) {
     return cc >= CC_PASCAL && cc != 610;
 }
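
FLASH_ATTN_AVAILABLE is the single switch the other two files key off: it is defined everywhere except in MUSA builds whose device arch is at or below QY1 (210, the MTT S80). A host-side restatement of that predicate, assuming __MUSA_ARCH__ takes the numeric values mirrored by the CC_QY1/CC_QY2 defines:

// Host-side sketch of the preprocessor condition above; the arch values
// are assumptions taken from the CC_QY1/CC_QY2 defines in this commit.
#include <cstdio>

#define CC_QY1 210   // MTT S80
#define CC_QY2 220

static bool flash_attn_available(const bool is_musa_build, const int arch) {
    // mirrors: #if !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= CC_QY1)
    return !(is_musa_build && arch <= CC_QY1);
}

int main() {
    printf("MUSA, QY1 (MTT S80): %d\n", flash_attn_available(true,  CC_QY1)); // 0 -> disabled
    printf("MUSA, QY2:           %d\n", flash_attn_available(true,  CC_QY2)); // 1 -> available
    printf("non-MUSA build:      %d\n", flash_attn_available(false, CC_QY1)); // 1 -> available
    return 0;
}

Note that QY2 (220) is deliberately left above the cutoff, so only S80-class parts lose the feature.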

ggml/src/ggml-cuda/fattn-tile-f32.cu

Lines changed: 5 additions & 1 deletion
@@ -44,13 +44,17 @@ static __global__ void flash_attn_tile_ext_f32(
         const int ne1,
         const int ne2,
         const int ne3) {
+#ifndef FLASH_ATTN_AVAILABLE
+    NO_DEVICE_CODE;
+    return;
+#endif // FLASH_ATTN_AVAILABLE
     // Skip unused kernel variants for faster compilation:
     if (use_logit_softcap && !(D == 128 || D == 256)) {
         NO_DEVICE_CODE;
         return;
     }
 
-    //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
+    // In this kernel Q, K, V are matrices while i, j, k are matrix indices.
 
     const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
     const int ip  = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
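
On a QY1 build the kernel body above reduces to an early-out stub. A compilable sketch of that pattern, using a trap-only NO_DEVICE_CODE as a simplified stand-in for the real macro in common.cuh (which also prints a diagnostic before trapping):

// nvcc-compilable sketch of the stub pattern; FLASH_ATTN_AVAILABLE and
// NO_DEVICE_CODE are simplified stand-ins for the common.cuh versions.
#include <cuda_runtime.h>

// #define FLASH_ATTN_AVAILABLE   // defined on supported archs by common.cuh
#define NO_DEVICE_CODE __trap()

static __global__ void flash_attn_stub(float * dst) {
#ifndef FLASH_ATTN_AVAILABLE
    NO_DEVICE_CODE;   // the body compiles to a trap on QY1-class builds
    return;
#endif // FLASH_ATTN_AVAILABLE
    dst[threadIdx.x] = 1.0f;   // the real flash-attention work would go here
}

int main() {
    float * dst = nullptr;
    cudaMalloc(&dst, 32 * sizeof(float));
    flash_attn_stub<<<1, 32>>>(dst);   // traps at runtime when the stub is active
    cudaDeviceSynchronize();
    cudaFree(dst);
    return 0;
}

Guarding inside the kernel, rather than compiling the file out entirely, keeps the host-side symbol and launch code intact: the build still links, and only the device path is disabled.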
