
musa: enable MMA #13149

Draft: wants to merge 1 commit into base: master
8 changes: 7 additions & 1 deletion ggml/src/ggml-cuda/common.cuh
@@ -215,6 +215,10 @@ typedef float2 dfloat2;
#define FP16_MMA_AVAILABLE
#endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || defined(RDNA4))

#if defined(GGML_USE_MUSA) && !GGML_CUDA_MUSA_ARCH_IS_QY1
#define FP16_MMA_AVAILABLE
#endif // defined(GGML_USE_MUSA) && !GGML_CUDA_MUSA_ARCH_IS_QY1

#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
#define NEW_MMA_AVAILABLE
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
@@ -237,7 +241,7 @@ static bool fast_fp16_available(const int cc) {

// To be used for feature selection of external libraries, e.g. cuBLAS.
static bool fast_fp16_hardware_available(const int cc) {
return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_PASCAL && cc != 610) || GGML_CUDA_CC_IS_AMD(cc);
return cc >= GGML_CUDA_CC_PASCAL && cc != 610 && cc != GGML_CUDA_CC_QY1;
}

// Any FP16 tensor core instructions are available for ggml code.
@@ -246,13 +250,15 @@ static bool fp16_mma_available(const int cc) {
return false;
#else
return (GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ||
(GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2) ||
GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc);
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN)
}

// To be used for feature selection of external libraries, e.g. cuBLAS.
static bool fp16_mma_hardware_available(const int cc) {
return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA) ||
(GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2) ||
GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc);
}
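
As an aside, a minimal sketch (not part of this PR) of how these two kinds of helpers are typically consumed: the *_hardware_available() checks gate feature selection for external libraries such as cuBLAS/muBLAS, while fp16_mma_available() also requires a suitable compiled arch and gates ggml's own tensor-core kernels. Both wrapper names below are hypothetical.

// illustrative only; these wrapper names do not exist in ggml
static bool use_batched_fp16_gemm(const int cc) {
    // external-library path: only the physical hardware capability matters
    return fast_fp16_hardware_available(cc) && fp16_mma_hardware_available(cc);
}

static bool use_wmma_flash_attn(const int cc) {
    // ggml kernel path: an MMA-capable arch must also have been compiled in
    return fp16_mma_available(cc);
}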

4 changes: 4 additions & 0 deletions ggml/src/ggml-cuda/fattn-wmma-f16.cu
@@ -9,7 +9,11 @@
#ifdef FP16_MMA_AVAILABLE
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
#include <mma.h>
#ifdef GGML_USE_MUSA
namespace wmma = mtmusa::wmma;
#else // GGML_USE_MUSA
namespace wmma = nvcuda::wmma;
#endif // GGML_USE_MUSA
#elif defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)
#undef HIP_ENABLE_WARP_SYNC_BUILTINS // conflicts with rocWMMA headers
#include <rocwmma/rocwmma.hpp>
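
The alias above lets the same fragment-based code compile against either nvcuda::wmma (CUDA) or mtmusa::wmma (MUSA). A self-contained sketch of that API for a single warp, with an illustrative 16x16x16 tile shape that is not taken from fattn-wmma-f16.cu:

#include <cuda_fp16.h>
#include <mma.h>
#ifdef GGML_USE_MUSA
namespace wmma = mtmusa::wmma;
#else
namespace wmma = nvcuda::wmma;
#endif

// one warp computes a single 16x16x16 FP16 GEMM tile with FP32 accumulation
static __global__ void wmma_tile_gemm(const half * A, const half * B, float * C) {
    wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> a_frag;
    wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::col_major> b_frag;
    wmma::fragment<wmma::accumulator, 16, 16, 16, float> c_frag;

    wmma::fill_fragment(c_frag, 0.0f);
    wmma::load_matrix_sync(a_frag, A, 16); // leading dimension = 16
    wmma::load_matrix_sync(b_frag, B, 16);
    wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
    wmma::store_matrix_sync(C, c_frag, 16, wmma::mem_row_major);
}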
57 changes: 42 additions & 15 deletions ggml/src/ggml-cuda/ggml-cuda.cu
Collaborator:
Why are you manually allocating and deallocating memory instead of using ggml_cuda_pool_alloc? Batched FP16 GEMM is used for attention without FlashAttention, so most likely this is where the bug is. I don't remember what the synchronization behavior of cudaFree is, but if it runs asynchronously with respect to the kernel executions, that would explain why you get incorrect results.
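
For reference, the pool-based pattern the comment refers to is the one already used on the non-MUSA branch of this diff; a trimmed sketch of it, assuming ggml_cuda_pool_alloc behaves as a scope-bound (RAII) pool allocation:

// temporary pointer arrays come from the backend's pool and are returned to it
// automatically when these objects go out of scope, so no explicit cudaFree
ggml_cuda_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);
ggml_cuda_pool_alloc<      void *> ptrs_dst(ctx.pool(), 1*ne23);

// .get() yields the raw device pointers; they are passed to both
// k_compute_batched_ptrs and cublasGemmBatchedEx on the same stream,
// so no device-wide synchronization is needed before the GEMM call
const void ** Aarray = (const void **) (ptrs_src.get() + 0*ne23);
const void ** Barray = (const void **) (ptrs_src.get() + 1*ne23);
      void ** Carray = (      void **) (ptrs_dst.get() + 0*ne23);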

@@ -1851,13 +1851,24 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
// use cublasGemmBatchedEx
const int ne23 = ne12*ne13;

#ifdef GGML_USE_MUSA
const void ** ptrs_src;
void ** ptrs_dst;
CUDA_CHECK(cudaMalloc((void **)&ptrs_src, sizeof(void *)*2*ne23));
CUDA_CHECK(cudaMalloc((void **)&ptrs_dst, sizeof(void *)*1*ne23));
#else // GGML_USE_MUSA
ggml_cuda_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);
ggml_cuda_pool_alloc< void *> ptrs_dst(ctx.pool(), 1*ne23);
#endif // GGML_USE_MUSA

dim3 block_dims(ne13, ne12);
k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
src0_f16, src1_f16, dst_t,
#ifdef GGML_USE_MUSA
ptrs_src, ptrs_dst,
#else // GGML_USE_MUSA
ptrs_src.get(), ptrs_dst.get(),
#endif // GGML_USE_MUSA
ne12, ne13,
ne23,
nb02, nb03,
@@ -1867,15 +1878,31 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
r2, r3);
CUDA_CHECK(cudaGetLastError());

CUBLAS_CHECK(
#ifdef GGML_USE_MUSA
cudaDeviceSynchronize();
const void **Aarray = (const void **) (ptrs_src + 0 * ne23);
const void **Barray = (const void **) (ptrs_src + 1 * ne23);
void **Carray = (void **) (ptrs_dst + 0 * ne23);
#else // GGML_USE_MUSA
const void **Aarray = (const void **) (ptrs_src.get() + 0 * ne23);
const void **Barray = (const void **) (ptrs_src.get() + 1 * ne23);
void **Carray = (void **) (ptrs_dst.get() + 0 * ne23);
#endif // GGML_USE_MUSA

CUBLAS_CHECK(
cublasGemmBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
ne01, ne11, ne10,
alpha, (const void **) (ptrs_src.get() + 0*ne23), CUDA_R_16F, nb01/nb00,
(const void **) (ptrs_src.get() + 1*ne23), CUDA_R_16F, nb11/nb10,
beta, ( void **) (ptrs_dst.get() + 0*ne23), cu_data_type, ne01,
alpha, Aarray, CUDA_R_16F, nb01/nb00,
Barray, CUDA_R_16F, nb11/nb10,
beta, Carray, cu_data_type, ne01,
ne23,
cu_compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));

#ifdef GGML_USE_MUSA
CUDA_CHECK(cudaFree(ptrs_src));
CUDA_CHECK(cudaFree(ptrs_dst));
#endif // GGML_USE_MUSA
}
#endif
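
For reference, the two halves of ptrs_src correspond to the A and B operands of the batched GEMM: entries [0, ne23) hold per-batch src0 pointers and entries [ne23, 2*ne23) hold per-batch src1 pointers, which is what the +0*ne23 / +1*ne23 offsets above select. A simplified sketch of a pointer-setup kernel using that layout (1D indexing, no ne12/ne13 broadcasting; the name and parameters are illustrative, not the actual k_compute_batched_ptrs):

// fills the pointer arrays consumed by cublasGemmBatchedEx:
//   ptrs_src[0*ne23 + i] -> A_i, ptrs_src[1*ne23 + i] -> B_i, ptrs_dst[i] -> C_i
static __global__ void k_fill_batched_ptrs(
        const half * src0, const half * src1, float * dst,
        const void ** ptrs_src, void ** ptrs_dst,
        const int64_t ne23, const size_t nb0, const size_t nb1, const size_t nbd) {
    const int64_t i = blockIdx.x*blockDim.x + threadIdx.x;
    if (i >= ne23) {
        return;
    }
    ptrs_src[0*ne23 + i] = (const char *) src0 + i*nb0; // batch i of A (FP16)
    ptrs_src[1*ne23 + i] = (const char *) src1 + i*nb1; // batch i of B (FP16)
    ptrs_dst[0*ne23 + i] = (      char *) dst  + i*nbd; // batch i of C
}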

@@ -3011,12 +3038,12 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
if (b->type == GGML_TYPE_F16 && a->type != GGML_TYPE_F16) {
return false;
}
#ifdef GGML_USE_MUSA
if (b->type == GGML_TYPE_F16 && b->ne[2]*b->ne[3] > 1 &&
!ggml_is_transposed(a) && !ggml_is_transposed(b)) {
return false;
}
#endif // GGML_USE_MUSA
// #ifdef GGML_USE_MUSA
// if (b->type == GGML_TYPE_F16 && b->ne[2]*b->ne[3] > 1 &&
// !ggml_is_transposed(a) && !ggml_is_transposed(b)) {
// return false;
// }
// #endif // GGML_USE_MUSA
switch (a->type) {
case GGML_TYPE_F32:
case GGML_TYPE_F16:
Expand All @@ -3041,11 +3068,11 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_BF16:
#ifdef GGML_USE_MUSA
if (a->type == GGML_TYPE_Q3_K) {
return false;
}
#endif // GGML_USE_MUSA
// #ifdef GGML_USE_MUSA
// if (a->type == GGML_TYPE_Q3_K) {
// return false;
// }
// #endif // GGML_USE_MUSA
return true;
default:
return false;