Commit c35e586
musa: enable building fat binaries, enable unified memory, and disable Flash Attention on QY1 (MTT S80) (#9526)
* mtgpu: add mp_21 support
* mtgpu: disable flash attention on qy1 (MTT S80); disable q3_k and mul_mat_batched_cublas
* mtgpu: enable unified memory
* mtgpu: map cublasOperation_t to mublasOperation_t (sync code to latest)

Signed-off-by: Xiaodong Ye <[email protected]>
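Taken together: the build now targets both MUSA architectures at once (a fat binary), MUSA builds gain the same GGML_CUDA_ENABLE_UNIFIED_MEMORY runtime switch that CUDA builds already honor, and Flash Attention plus two problematic matmul paths are switched off for the QY1 (MTT S80) generation, whose compute capability is mapped to 210 in common.cuh below.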
1 parent 912c331 commit c35e586

File tree

6 files changed: +31 -5 lines


Makefile

Lines changed: 1 addition & 1 deletion

@@ -611,7 +611,7 @@ ifdef GGML_CUDA

     MK_CPPFLAGS  += -DGGML_USE_CUDA -I$(CUDA_PATH)/include
     MK_LDFLAGS   += -lmusa -lmublas -lmusart -lpthread -ldl -lrt -L$(CUDA_PATH)/lib -L/usr/lib64
-    MK_NVCCFLAGS += -x musa -mtgpu --cuda-gpu-arch=mp_22
+    MK_NVCCFLAGS += -x musa -mtgpu --cuda-gpu-arch=mp_21 --cuda-gpu-arch=mp_22
 else
     ifneq ('', '$(wildcard /opt/cuda)')
         CUDA_PATH ?= /opt/cuda

ggml/src/CMakeLists.txt

Lines changed: 1 addition & 1 deletion

@@ -364,7 +364,7 @@ if (GGML_CUDA)
     if (GGML_MUSA)
         set_source_files_properties(${GGML_SOURCES_CUDA} PROPERTIES LANGUAGE CXX)
         foreach(SOURCE ${GGML_SOURCES_CUDA})
-            set_property(SOURCE ${SOURCE} PROPERTY COMPILE_FLAGS "-x musa -mtgpu --cuda-gpu-arch=mp_22")
+            set_property(SOURCE ${SOURCE} PROPERTY COMPILE_FLAGS "-x musa -mtgpu --cuda-gpu-arch=mp_21 --cuda-gpu-arch=mp_22")
         endforeach()
     endif()
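Passing two --cuda-gpu-arch flags (here and in the Makefile above) makes mcc compile the MUSA sources once per target and bundle both results into a single fat binary, so one build serves the MTT S80 (mp_21) and mp_22 parts. A minimal sketch of how per-architecture code is then selected at compile time, assuming mcc defines __MUSA_ARCH__ per pass (210 for mp_21, 220 for mp_22) the way nvcc defines __CUDA_ARCH__:

    // Each --cuda-gpu-arch pass compiles this kernel with its own
    // __MUSA_ARCH__ value, so the preprocessor picks a body per slice of
    // the fat binary; common.cuh uses the same trick for FLASH_ATTN_AVAILABLE.
    __global__ void add_one(int * x, int n) {
        const int i = blockIdx.x*blockDim.x + threadIdx.x;
        if (i >= n) {
            return;
        }
    #if defined(__MUSA_ARCH__) && __MUSA_ARCH__ <= 210
        x[i] += 1;   // compiled into the mp_21 (QY1 / MTT S80) slice
    #else
        x[i] += 1;   // compiled into the mp_22 (QY2) slice
    #endif
    }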

ggml/src/ggml-cuda.cu

Lines changed: 16 additions & 2 deletions

@@ -136,7 +136,7 @@ static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device)
     return res;
 #else

-#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
+#if !defined(GGML_USE_HIPBLAS)
     cudaError_t err;
     if (getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr)
     {
@@ -149,7 +149,7 @@ static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device)
     return err;
 #else
     return cudaMalloc(ptr, size);
-#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
+#endif // !defined(GGML_USE_HIPBLAS)

 #endif
 }
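Dropping the GGML_USE_MUSA exclusion means the existing unified-memory escape hatch now applies to MUSA builds as well. A condensed sketch of the allocation logic these two hunks enable (the real code sits inside ggml_cuda_device_malloc; under MUSA the calls resolve to musaMallocManaged/musaMalloc via the vendors/musa.h mappings below):

    #include <cstdlib>  // std::getenv

    // With GGML_CUDA_ENABLE_UNIFIED_MEMORY set in the environment, allocate
    // managed memory the driver can migrate between host and device;
    // otherwise allocate plain device memory.
    static cudaError_t device_malloc_sketch(void ** ptr, size_t size) {
        if (std::getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr) {
            return cudaMallocManaged(ptr, size);
        }
        return cudaMalloc(ptr, size);
    }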
@@ -2830,6 +2830,12 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
     if (op->op == GGML_OP_MUL_MAT && a->ne[3] != b->ne[3]) {
         return false;
     }
+#ifdef GGML_USE_MUSA
+    if (b->type == GGML_TYPE_F16 && b->ne[2]*b->ne[3] > 1 &&
+        !ggml_is_transposed(a) && !ggml_is_transposed(b)) {
+        return false;
+    }
+#endif // GGML_USE_MUSA
     switch (a->type) {
         case GGML_TYPE_F32:
         case GGML_TYPE_F16:
@@ -2853,6 +2859,11 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_TYPE_IQ3_XXS:
         case GGML_TYPE_IQ4_NL:
         case GGML_TYPE_IQ4_XS:
+#ifdef GGML_USE_MUSA
+            if (a->type == GGML_TYPE_Q3_K) {
+                return false;
+            }
+#endif // GGML_USE_MUSA
             return true;
         default:
             return false;
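These two MUSA-only guards are the "disable q3_k and mul_mat_batched_cublas" items from the commit message: the F16 check rejects exactly the batched, non-transposed shapes that would otherwise be routed to the batched cuBLAS (mublas) path, and Q3_K matrix multiplications are rejected outright so the scheduler falls back to another backend for them.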
@@ -2978,6 +2989,9 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_OP_RWKV_WKV:
             return true;
         case GGML_OP_FLASH_ATTN_EXT: {
+#ifndef FLASH_ATTN_AVAILABLE
+            return false;
+#endif
            if (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) {
                return true;
            }

ggml/src/ggml-cuda/common.cuh

Lines changed: 6 additions & 0 deletions

@@ -50,6 +50,8 @@
 #define CC_RDNA1 (CC_OFFSET_AMD + 1010)
 #define CC_RDNA2 (CC_OFFSET_AMD + 1030)
 #define CC_RDNA3 (CC_OFFSET_AMD + 1100)
+#define CC_QY1 210
+#define CC_QY2 220

 #define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses

@@ -134,6 +136,10 @@ typedef float2 dfloat2;
 #define INT8_MMA_AVAILABLE
 #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING

+#if !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= CC_QY1)
+#define FLASH_ATTN_AVAILABLE
+#endif // !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= CC_QY1)
+
 static constexpr bool fast_fp16_available(const int cc) {
     return cc >= CC_PASCAL && cc != 610;
 }
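Note how this guard behaves when __MUSA_ARCH__ is not defined: an unknown identifier evaluates to 0 inside an #if expression, so in a MUSA build FLASH_ATTN_AVAILABLE is left undefined both in device passes targeting mp_21 and in host-side code, which is what arms the FLASH_ATTN_EXT early return in ggml_backend_cuda_supports_op above.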

ggml/src/ggml-cuda/fattn-tile-f32.cu

Lines changed: 5 additions & 1 deletion

@@ -44,13 +44,17 @@ static __global__ void flash_attn_tile_ext_f32(
         const int ne1,
         const int ne2,
         const int ne3) {
+#ifndef FLASH_ATTN_AVAILABLE
+    NO_DEVICE_CODE;
+    return;
+#endif // FLASH_ATTN_AVAILABLE
     // Skip unused kernel variants for faster compilation:
     if (use_logit_softcap && !(D == 128 || D == 256)) {
         NO_DEVICE_CODE;
         return;
     }

-    //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
+    // In this kernel Q, K, V are matrices while i, j, k are matrix indices.

     const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
     const int ip  = blockIdx.x % parallel_blocks;           // Index in group of blocks running for the same column in parallel.
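NO_DEVICE_CODE is the stub this kernel already uses for configurations it cannot run (see the use_logit_softcap check just below the new guard); hoisting it behind #ifndef FLASH_ATTN_AVAILABLE compiles the whole kernel body out of mp_21 device code while keeping the symbol defined for the linker.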

ggml/src/ggml-cuda/vendors/musa.h

Lines changed: 2 additions & 0 deletions

@@ -26,6 +26,7 @@
 #define cublasSetStream mublasSetStream
 #define cublasSgemm mublasSgemm
 #define cublasStatus_t mublasStatus_t
+#define cublasOperation_t mublasOperation_t
 #define cublasGetStatusString mublasStatus_to_string
 #define cudaDataType_t musaDataType_t
 #define cudaDeviceCanAccessPeer musaDeviceCanAccessPeer
@@ -56,6 +57,7 @@
 #define cudaLaunchHostFunc musaLaunchHostFunc
 #define cudaMalloc musaMalloc
 #define cudaMallocHost musaMallocHost
+#define cudaMallocManaged musaMallocManaged
 #define cudaMemcpy musaMemcpy
 #define cudaMemcpyAsync musaMemcpyAsync
 #define cudaMemcpyPeerAsync musaMemcpyPeerAsync
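With these two mappings in place, CUDA-spelled call sites compile unchanged against MUSA. A minimal illustration, assuming the header's existing mappings for CUBLAS_OP_T and the runtime types as context:

    // Both names below are rewritten by vendors/musa.h at preprocessing time:
    // cublasOperation_t -> mublasOperation_t, cudaMallocManaged -> musaMallocManaged.
    static cudaError_t alloc_managed_example(void ** buf, size_t size) {
        cublasOperation_t trans = CUBLAS_OP_T;  // CUBLAS_OP_T mapping assumed from elsewhere in this header
        (void) trans;                           // only here to show the type alias compiling
        return cudaMallocManaged(buf, size);
    }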
