musa: add support for muBLAS and MMA #13149

Closed · wants to merge 2 commits
ggml/src/ggml-cuda/common.cuh (27 changes: 17 additions & 10 deletions)

@@ -76,17 +76,17 @@
 #define GGML_CUDA_CC_IS_CDNA(cc) (cc >= GGML_CUDA_CC_CDNA && cc < GGML_CUDA_CC_RDNA1)
 
 // Moore Threads
-#define GGML_CUDA_MUSA_ARCH_IS_QY1 (__MUSA_ARCH__ <= 210)
-
-#define GGML_CUDA_CC_QY1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x210) // MTT S80, MTT S3000
-#define GGML_CUDA_CC_QY2 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x220) // MTT S4000
-#define GGML_CUDA_CC_NG  (GGML_CUDA_CC_OFFSET_MTHREADS + 0x310) // TBD
+#define GGML_CUDA_CC_QY1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x210) // MTT S80, MTT S3000
+#define GGML_CUDA_CC_QY2 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x220) // MTT S4000
+#define GGML_CUDA_CC_NG  (GGML_CUDA_CC_OFFSET_MTHREADS + 0x310) // TBD
 
 #define GGML_CUDA_CC_IS_MTHREADS(cc) (cc >= GGML_CUDA_CC_OFFSET_MTHREADS && cc < GGML_CUDA_CC_OFFSET_AMD)
 #define GGML_CUDA_CC_IS_QY1(cc)      (cc >= GGML_CUDA_CC_QY1 && cc < GGML_CUDA_CC_QY2)
 #define GGML_CUDA_CC_IS_QY2(cc)      (cc >= GGML_CUDA_CC_QY2 && cc < GGML_CUDA_CC_NG)
 #define GGML_CUDA_CC_IS_NG(cc)       (cc >= GGML_CUDA_CC_NG)
 
+#define GGML_CUDA_CC_IS_QY1_OR_EARLIER (__MUSA_ARCH__ < 220)
+
 #ifdef __CUDA_ARCH_LIST__
 constexpr bool ggml_cuda_has_arch_impl(int) {
     return false;
@@ -211,6 +211,10 @@ typedef float2 dfloat2;
 #define FP16_MMA_AVAILABLE
 #endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || defined(RDNA4))
 
+#if defined(GGML_USE_MUSA)
+#define FP16_MMA_AVAILABLE
+#endif // defined(GGML_USE_MUSA)
+
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
 #define NEW_MMA_AVAILABLE
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
@@ -219,9 +223,9 @@ typedef float2 dfloat2;
 #define CP_ASYNC_AVAILABLE
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
 
-#if !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && GGML_CUDA_MUSA_ARCH_IS_QY1)
+#if !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && GGML_CUDA_CC_IS_QY1_OR_EARLIER)
 #define FLASH_ATTN_AVAILABLE
-#endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && GGML_CUDA_MUSA_ARCH_IS_QY1)
+#endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && GGML_CUDA_CC_IS_QY1_OR_EARLIER)
 
 static bool fp16_available(const int cc) {
     return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL;
@@ -233,7 +237,8 @@ static bool fast_fp16_available(const int cc) {

 // To be used for feature selection of external libraries, e.g. cuBLAS.
 static bool fast_fp16_hardware_available(const int cc) {
-    return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_PASCAL && cc != 610) || GGML_CUDA_CC_IS_AMD(cc);
+    return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_PASCAL && cc != 610) ||
+           GGML_CUDA_CC_IS_AMD(cc) || GGML_CUDA_CC_IS_MTHREADS(cc);
 }
 
 // Any FP16 tensor core instructions are available for ggml code.
@@ -242,14 +247,16 @@ static bool fp16_mma_available(const int cc) {
     return false;
 #else
     return (GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ||
-           GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc);
+           GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc) ||
+           GGML_CUDA_CC_IS_MTHREADS(cc);
 #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN)
 }
 
 // To be used for feature selection of external libraries, e.g. cuBLAS.
 static bool fp16_mma_hardware_available(const int cc) {
     return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA) ||
-           GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc);
+           GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc) ||
+           GGML_CUDA_CC_IS_MTHREADS(cc);
 }
 
 // Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.
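For context, these predicates are what the rest of the backend keys feature selection on: a device "has" a capability when its compute-capability code lands in the right vendor range. A minimal, self-contained sketch of that scheme (the offsets and helper names are illustrative assumptions mirroring common.cuh, not code from this PR):

```cpp
// Sketch of range-based capability checks; all constants are assumptions.
#include <cstdio>

constexpr int CC_OFFSET_MTHREADS = 0x0100000; // assumed vendor offset
constexpr int CC_OFFSET_AMD      = 0x1000000; // assumed vendor offset
constexpr int CC_QY1 = CC_OFFSET_MTHREADS + 0x210; // MTT S80, MTT S3000
constexpr int CC_QY2 = CC_OFFSET_MTHREADS + 0x220; // MTT S4000

static bool is_mthreads(int cc) {
    return cc >= CC_OFFSET_MTHREADS && cc < CC_OFFSET_AMD;
}

// After this PR, Moore Threads devices count as having fast FP16 and FP16
// MMA hardware, which is what makes the muBLAS paths below eligible.
static bool fp16_mma_hw(int cc) {
    return is_mthreads(cc); // NVIDIA/AMD branches omitted for brevity
}

int main() {
    const int ccs[] = { CC_QY1, CC_QY2 };
    for (int cc : ccs) {
        printf("cc=0x%07x mthreads=%d fp16_mma=%d qy2_or_newer=%d\n",
               cc, is_mthreads(cc), fp16_mma_hw(cc), cc >= CC_QY2);
    }
    return 0;
}
```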
ggml/src/ggml-cuda/fattn-wmma-f16.cu (4 changes: 4 additions & 0 deletions)

@@ -9,7 +9,11 @@
 #ifdef FP16_MMA_AVAILABLE
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 #include <mma.h>
+#ifdef GGML_USE_MUSA
+namespace wmma = mtmusa::wmma;
+#else // GGML_USE_MUSA
 namespace wmma = nvcuda::wmma;
+#endif // GGML_USE_MUSA
 #elif defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)
 #undef HIP_ENABLE_WARP_SYNC_BUILTINS // conflicts with rocWMMA headers
 #include <rocwmma/rocwmma.hpp>
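The namespace alias is the entire porting trick for this file: every later `wmma::` reference resolves to the vendor implementation at compile time. A standalone sketch of the pattern, assuming (as the alias implies) that mtmusa::wmma mirrors the nvcuda::wmma fragment API:

```cpp
// One 16x16x16 half-precision tile multiply through the aliased namespace.
// Launch with a single warp, e.g. <<<1, 32>>>; each buffer holds 16*16 values.
#include <cuda_fp16.h>
#include <mma.h>

#ifdef GGML_USE_MUSA
namespace wmma = mtmusa::wmma; // assumed API-compatible, per the diff above
#else
namespace wmma = nvcuda::wmma;
#endif

__global__ void tile_mma_16x16x16(const half * a, const half * b, float * c) {
    wmma::fragment<wmma::matrix_a,    16, 16, 16, half, wmma::row_major> fa;
    wmma::fragment<wmma::matrix_b,    16, 16, 16, half, wmma::col_major> fb;
    wmma::fragment<wmma::accumulator, 16, 16, 16, float>                 fc;

    wmma::fill_fragment(fc, 0.0f);     // zero the accumulator tile
    wmma::load_matrix_sync(fa, a, 16); // leading dimension 16
    wmma::load_matrix_sync(fb, b, 16);
    wmma::mma_sync(fc, fa, fb, fc);    // fc += fa * fb on the MMA units
    wmma::store_matrix_sync(c, fc, 16, wmma::mem_row_major);
}
```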
ggml/src/ggml-cuda/ggml-cuda.cu (84 changes: 61 additions & 23 deletions)

@@ -1200,7 +1200,9 @@ static void ggml_cuda_op_mul_mat_cublas(

     const bool use_fp16 = (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT;
 
-    if (src0->type == GGML_TYPE_BF16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) {
+    if ((GGML_CUDA_CC_IS_NVIDIA(cc) || GGML_CUDA_CC_IS_AMD(cc) ||
+         (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2)) &&
+        src0->type == GGML_TYPE_BF16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) {
         ggml_cuda_pool_alloc<nv_bfloat16> src1_as_bf16(ctx.pool(id));
         if (src1->type != GGML_TYPE_BF16) {
             const to_bf16_cuda_t to_bf16_cuda = ggml_get_to_bf16_cuda(src1->type);
@@ -1228,7 +1230,9 @@

         const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_BF16);
         to_fp32_cuda(dst_bf16.get(), dst_dd_i, row_diff*src1_ncols, stream);
-    } else if (((GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA) || GGML_CUDA_CC_IS_AMD(cc)) && use_fp16) {
+    } else if (((GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA) ||
+                (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2) ||
+                GGML_CUDA_CC_IS_AMD(cc)) && use_fp16) {
         // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
         ggml_cuda_pool_alloc<half> src0_as_f16(ctx.pool(id));
         if (src0->type != GGML_TYPE_F16) {
@@ -1872,13 +1876,24 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
         // use cublasGemmBatchedEx
         const int64_t ne23 = ne12*ne13;
 
+#ifdef GGML_USE_MUSA
+        const void ** ptrs_src;
+        void ** ptrs_dst;
+        CUDA_CHECK(cudaMalloc((void **)&ptrs_src, sizeof(void *)*2*ne23));
+        CUDA_CHECK(cudaMalloc((void **)&ptrs_dst, sizeof(void *)*1*ne23));
+#else // GGML_USE_MUSA
         ggml_cuda_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);
         ggml_cuda_pool_alloc<      void *> ptrs_dst(ctx.pool(), 1*ne23);
+#endif // GGML_USE_MUSA
 
         dim3 block_dims(ne13, ne12);
         k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
                 src0_f16, src1_f16, dst_t,
+#ifdef GGML_USE_MUSA
+                ptrs_src, ptrs_dst,
+#else // GGML_USE_MUSA
                 ptrs_src.get(), ptrs_dst.get(),
+#endif // GGML_USE_MUSA
                 ne12, ne13,
                 ne23,
                 nb02, nb03,
@@ -1888,15 +1903,31 @@
                 r2, r3);
         CUDA_CHECK(cudaGetLastError());
 
-        CUBLAS_CHECK(
+#ifdef GGML_USE_MUSA
+        CUDA_CHECK(cudaDeviceSynchronize());
+        const void **Aarray = (const void **) (ptrs_src + 0*ne23);
+        const void **Barray = (const void **) (ptrs_src + 1*ne23);
+              void **Carray = (      void **) (ptrs_dst + 0*ne23);
+#else // GGML_USE_MUSA
+        const void **Aarray = (const void **) (ptrs_src.get() + 0*ne23);
+        const void **Barray = (const void **) (ptrs_src.get() + 1*ne23);
+              void **Carray = (      void **) (ptrs_dst.get() + 0*ne23);
+#endif // GGML_USE_MUSA
+
+        CUBLAS_CHECK(
         cublasGemmBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
                 ne01, ne11, ne10,
-                alpha, (const void **) (ptrs_src.get() + 0*ne23), CUDA_R_16F,   nb01/nb00,
-                       (const void **) (ptrs_src.get() + 1*ne23), CUDA_R_16F,   s11,
-                beta,  (      void **) (ptrs_dst.get() + 0*ne23), cu_data_type, ne0,
+                alpha, Aarray, CUDA_R_16F,   nb01/nb00,
+                       Barray, CUDA_R_16F,   s11,
+                beta,  Carray, cu_data_type, ne0,
                 ne23,
                 cu_compute_type,
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+
+#ifdef GGML_USE_MUSA
+        CUDA_CHECK(cudaFree(ptrs_src));
+        CUDA_CHECK(cudaFree(ptrs_dst));
+#endif // GGML_USE_MUSA
 }
 #endif
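The MUSA path above replaces the stream-ordered pool allocation with a plain cudaMalloc/cudaFree pair and synchronizes the device before handing the pointer tables to muBLAS, since they are filled on the GPU by k_compute_batched_ptrs. A standalone sketch of the pointer-table contract cublasGemmBatchedEx expects (dimensions, the host-side fill, and the omitted error checks are illustrative, not from the PR):

```cpp
// A and B pointer tables live back to back in one allocation (2*batch
// entries), C in a second one, matching the Aarray/Barray/Carray split above.
#include <cublas_v2.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <vector>

static void gemm_batched_f16(cublasHandle_t handle,
                             const half * A, const half * B, half * C,
                             int m, int n, int k, int batch) {
    // host-side pointer tables, one entry per batch element
    std::vector<const void *> h_src(2 * batch);
    std::vector<void *>       h_dst(batch);
    for (int i = 0; i < batch; ++i) {
        h_src[0*batch + i] = A + (size_t) i*m*k; // A pointers first ...
        h_src[1*batch + i] = B + (size_t) i*k*n; // ... then B pointers
        h_dst[i]           = C + (size_t) i*m*n;
    }

    const void ** d_src = nullptr;
    void       ** d_dst = nullptr;
    cudaMalloc((void **) &d_src, sizeof(void *)*2*batch);
    cudaMalloc((void **) &d_dst, sizeof(void *)*1*batch);
    cudaMemcpy(d_src, h_src.data(), sizeof(void *)*2*batch, cudaMemcpyHostToDevice);
    cudaMemcpy(d_dst, h_dst.data(), sizeof(void *)*1*batch, cudaMemcpyHostToDevice);

    const float alpha = 1.0f, beta = 0.0f;
    cublasGemmBatchedEx(handle, CUBLAS_OP_N, CUBLAS_OP_N,
                        m, n, k,
                        &alpha, d_src + 0*batch, CUDA_R_16F, m,
                                d_src + 1*batch, CUDA_R_16F, k,
                        &beta,  d_dst,           CUDA_R_16F, m,
                        batch,
                        CUBLAS_COMPUTE_32F,
                        CUBLAS_GEMM_DEFAULT_TENSOR_OP);

    cudaFree(d_src);
    cudaFree(d_dst);
}
```

In the PR itself the tables are written by a kernel rather than a memcpy, which is why the MUSA branch inserts a cudaDeviceSynchronize before the muBLAS call.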

@@ -1926,6 +1957,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {

     bool any_gpus_with_slow_fp16   = false;
     bool any_gpus_without_fp16_mma = false;
+    bool any_gpus_without_cublas_gemm = false;
 
     if (split) {
         ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *) src0->buffer->buft->context;
@@ -1936,16 +1968,18 @@
                 continue;
             }
 
-            const int cc = ggml_cuda_info().devices[id].cc;
-            use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
-            any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
-            any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_hardware_available(cc);
+            const int cc                 = ggml_cuda_info().devices[id].cc;
+            use_mul_mat_q                = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
+            any_gpus_with_slow_fp16      = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
+            any_gpus_without_fp16_mma    = any_gpus_without_fp16_mma || !fp16_mma_hardware_available(cc);
+            any_gpus_without_cublas_gemm = any_gpus_without_cublas_gemm || !(GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2);
         }
     } else {
-        const int cc = ggml_cuda_info().devices[ctx.device].cc;
-        use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
-        any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
-        any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_hardware_available(cc);
+        const int cc                 = ggml_cuda_info().devices[ctx.device].cc;
+        use_mul_mat_q                = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
+        any_gpus_with_slow_fp16      = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
+        any_gpus_without_fp16_mma    = any_gpus_without_fp16_mma || !fp16_mma_hardware_available(cc);
+        any_gpus_without_cublas_gemm = any_gpus_without_cublas_gemm || !(GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2);
     }
 
     // debug helpers
@@ -1964,8 +1998,9 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
         ggml_cuda_mul_mat_vec_q(ctx, src0, src1, nullptr, dst);
     } else if (!split && use_mul_mat_q) {
         ggml_cuda_mul_mat_q(ctx, src0, src1, nullptr, dst);
-    } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16) &&
-        !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
+    } else if (!split && !any_gpus_without_cublas_gemm && src0->type == GGML_TYPE_F16 &&
+        (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16) &&
+        !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
         // general KQ + KQV multi-batch without FlashAttention
         ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
     } else if (use_mul_mat_vec) {
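The new any_gpus_without_cublas_gemm flag follows the same pattern as the two fp16 flags: every participating device ORs in a negative, and the batched-GEMM branch requires the accumulated flag to stay false. A toy version of that accumulation (the device list and constants are stand-ins for ggml_cuda_info() and common.cuh; the condition mirrors the PR, under which only Moore Threads devices at QY2 or newer clear the check):

```cpp
// Toy version of the "no participating GPU lacks the feature" guard above.
#include <cstdio>

constexpr int CC_OFFSET_MTHREADS = 0x0100000; // assumed vendor offset
constexpr int CC_OFFSET_AMD      = 0x1000000; // assumed vendor offset
constexpr int CC_QY2 = CC_OFFSET_MTHREADS + 0x220;

static bool is_mthreads(int cc) {
    return cc >= CC_OFFSET_MTHREADS && cc < CC_OFFSET_AMD;
}

int main() {
    const int device_ccs[] = { CC_OFFSET_MTHREADS + 0x210, CC_QY2 }; // QY1 + QY2 mix
    bool any_without_gemm = false;
    for (int cc : device_ccs) {
        any_without_gemm = any_without_gemm || !(is_mthreads(cc) && cc >= CC_QY2);
    }
    // A single QY1 device in the split is enough to veto the batched path.
    printf("batched muBLAS GEMM allowed: %s\n", any_without_gemm ? "no" : "yes");
    return 0;
}
```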
@@ -3005,9 +3040,17 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
                 return false;
             }
 #ifdef GGML_USE_MUSA
-            if (b->type == GGML_TYPE_F16 && b->ne[2]*b->ne[3] > 1 &&
+            const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;
+            if (GGML_CUDA_CC_IS_MTHREADS(cc) && b->ne[2]*b->ne[3] > 1 &&
                 !ggml_is_transposed(a) && !ggml_is_transposed(b)) {
-                return false;
+                if (GGML_CUDA_CC_IS_QY1(cc) && op->op == GGML_OP_MUL_MAT &&
+                    a->type == GGML_TYPE_F16 && b->type == GGML_TYPE_F16) {
+                    return false;
+                }
+                if (GGML_CUDA_CC_IS_QY2(cc) && op->op == GGML_OP_MUL_MAT_ID &&
+                    a->type == GGML_TYPE_Q2_K && b->type == GGML_TYPE_F32) {
+                    return false;
+                }
             }
 #endif // GGML_USE_MUSA
             switch (a->type) {
@@ -3034,11 +3077,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
             case GGML_TYPE_IQ4_NL:
             case GGML_TYPE_IQ4_XS:
             case GGML_TYPE_BF16:
-#ifdef GGML_USE_MUSA
-                if (a->type == GGML_TYPE_Q3_K) {
-                    return false;
-                }
-#endif // GGML_USE_MUSA
                 return true;
             default:
                 return false;
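Taken together, the two supports_op hunks replace blanket MUSA denials (all batched F16 matrix multiplication, all Q3_K) with two targeted ones keyed on architecture, op, and operand types. A condensed sketch of the resulting decision table (the enums are illustrative stand-ins for ggml_op and ggml_type; the cc constants mirror common.cuh):

```cpp
// Condensed sketch of the per-arch op filtering added in supports_op.
#include <cstdio>

enum Op   { MUL_MAT, MUL_MAT_ID };
enum Type { F16, F32, Q2_K };

constexpr int CC_OFFSET_MTHREADS = 0x0100000; // assumed vendor offset
constexpr int CC_QY1 = CC_OFFSET_MTHREADS + 0x210;
constexpr int CC_QY2 = CC_OFFSET_MTHREADS + 0x220;
constexpr int CC_NG  = CC_OFFSET_MTHREADS + 0x310;

static bool is_qy1(int cc) { return cc >= CC_QY1 && cc < CC_QY2; }
static bool is_qy2(int cc) { return cc >= CC_QY2 && cc < CC_NG;  }

// Returns false only for the two combinations the PR rules out, instead of
// rejecting every batched, non-transposed F16 mul_mat on MUSA as before.
static bool musa_supports(int cc, Op op, Type a, Type b,
                          bool batched, bool transposed) {
    if (batched && !transposed) {
        if (is_qy1(cc) && op == MUL_MAT    && a == F16  && b == F16) return false;
        if (is_qy2(cc) && op == MUL_MAT_ID && a == Q2_K && b == F32) return false;
    }
    return true;
}

int main() {
    printf("QY1 batched F16xF16 MUL_MAT:   %d\n",
           musa_supports(CC_QY1, MUL_MAT,    F16,  F16, true, false)); // 0
    printf("QY2 MUL_MAT_ID Q2_K x F32:     %d\n",
           musa_supports(CC_QY2, MUL_MAT_ID, Q2_K, F32, true, false)); // 0
    return 0;
}
```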