Skip to content

Commit 6aa4232

Browse files
committed
CUDA => MUSA
Signed-off-by: Xiaodong Ye <[email protected]>
1 parent 2cb8bda commit 6aa4232

File tree

4 files changed

+234
-8
lines changed

4 files changed

+234
-8
lines changed

ggml/include/ggml-cuda.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66
#ifdef GGML_USE_HIPBLAS
77
#define GGML_CUDA_NAME "ROCm"
88
#define GGML_CUBLAS_NAME "hipBLAS"
9+
#elif defined(GGML_USE_MUSA)
10+
#define GGML_CUDA_NAME "MUSA"
11+
#define GGML_CUBLAS_NAME "muBLAS"
912
#else
1013
#define GGML_CUDA_NAME "CUDA"
1114
#define GGML_CUBLAS_NAME "cuBLAS"

ggml/src/ggml-common.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,11 @@ typedef half2 ggml_half2;
1919

2020
#define GGML_COMMON_DECL
2121
#elif defined(GGML_COMMON_DECL_CUDA)
22+
#if defined(GGML_COMMON_DECL_MUSA)
23+
#include <musa_fp16.h>
24+
#else
2225
#include <cuda_fp16.h>
26+
#endif
2327
#include <cstdint>
2428

2529
typedef half ggml_half;
@@ -415,7 +419,7 @@ static_assert(sizeof(block_iq4_xs) == sizeof(ggml_half) + sizeof(uint16_t) + QK_
415419
#define GGML_TABLE_END() };
416420

417421
#define GGML_COMMON_IMPL
418-
#elif defined(GGML_COMMON_IMPL_CUDA) || defined(GGML_COMMON_IMPL_HIP)
422+
#elif defined(GGML_COMMON_IMPL_CUDA) || defined(GGML_COMMON_IMPL_HIP) || defined(GGML_COMMON_IMPL_MUSA)
419423
#include <cstdint>
420424

421425
#define GGML_TABLE_BEGIN(type, name, size) static const __device__ type name[size] = {

ggml/src/ggml-cuda.cu

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1341,7 +1341,7 @@ static void ggml_cuda_set_peer_access(const int n_tokens, int main_device) {
13411341
static cudaError_t ggml_cuda_Memcpy2DPeerAsync(
13421342
void * dst, int dstDevice, size_t dpitch, void * src, int srcDevice, size_t spitch, size_t width, size_t height, cudaStream_t stream) {
13431343

1344-
#if !defined(GGML_USE_HIPBLAS)
1344+
#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
13451345
// cudaMemcpy2DAsync may fail with copies between vmm pools of different devices
13461346
cudaMemcpy3DPeerParms p = {};
13471347
p.dstDevice = dstDevice;
@@ -1355,7 +1355,7 @@ static cudaError_t ggml_cuda_Memcpy2DPeerAsync(
13551355
GGML_UNUSED(dstDevice);
13561356
GGML_UNUSED(srcDevice);
13571357
return cudaMemcpy2DAsync(dst, dpitch, src, spitch, width, height, cudaMemcpyDeviceToDevice, stream);
1358-
#endif // !defined(GGML_USE_HIPBLAS)
1358+
#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
13591359
}
13601360

13611361
static void ggml_cuda_op_mul_mat(
@@ -1821,6 +1821,9 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
18211821
}
18221822
}
18231823
#else
1824+
#ifdef GGML_USE_MUSA
1825+
GGML_ASSERT(false);
1826+
#else // !GGML_USE_MUSA
18241827
if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
18251828
// there is no broadcast and src0, src1 are contiguous across dims 2, 3
18261829
// use cublasGemmStridedBatchedEx
@@ -1863,6 +1866,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
18631866
cu_compute_type,
18641867
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
18651868
}
1869+
#endif // GGML_USE_MUSA
18661870
#endif
18671871

18681872
if (dst->op_params[0] == GGML_PREC_DEFAULT) {
@@ -1902,11 +1906,17 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
19021906
const int cc = ggml_cuda_info().devices[id].cc;
19031907
use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
19041908
any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_available(cc);
1909+
#ifdef GGML_USE_MUSA
1910+
use_mul_mat_vec_q = false;
1911+
#endif // GGML_USE_MUSA
19051912
}
19061913
} else {
19071914
const int cc = ggml_cuda_info().devices[ctx.device].cc;
19081915
use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
19091916
any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_available(cc);
1917+
#ifdef GGML_USE_MUSA
1918+
use_mul_mat_vec_q = false;
1919+
#endif // GGML_USE_MUSA
19101920
}
19111921

19121922
// debug helpers
@@ -3019,7 +3029,7 @@ GGML_CALL bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size
30193029
return false;
30203030
}
30213031

3022-
#if CUDART_VERSION >= 11100
3032+
#if CUDART_VERSION >= 11100 || defined(GGML_USE_MUSA)
30233033
cudaError_t err = cudaHostRegister(buffer, size, cudaHostRegisterPortable | cudaHostRegisterReadOnly);
30243034
if (err != cudaSuccess) {
30253035
// clear the error

ggml/src/ggml-cuda/common.cuh

Lines changed: 213 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@
1212
#else
1313
#define GGML_COMMON_DECL_CUDA
1414
#define GGML_COMMON_IMPL_CUDA
15+
#if defined(GGML_USE_MUSA)
16+
#define GGML_COMMON_DECL_MUSA
17+
#define GGML_COMMON_IMPL_MUSA
18+
#endif
1519
#endif
1620
#include "ggml-common.h"
1721

@@ -114,6 +118,151 @@
114118
#define CUBLAS_STATUS_EXECUTION_FAILED HIPBLAS_STATUS_EXECUTION_FAILED
115119
#define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR
116120
#define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED
121+
#elif defined(GGML_USE_MUSA)
122+
#include <musa_runtime.h>
123+
#include <musa.h>
124+
#include <mublas.h>
125+
#include <musa_fp16.h>
126+
// XXX: Keep the following order the same as hipBLAS
127+
// #define CUBLAS_COMPUTE_16F MUBLAS_COMPUTE_16F
128+
// #define CUBLAS_COMPUTE_32F MUBLAS_COMPUTE_32F
129+
#define CUBLAS_COMPUTE_32F_FAST_16F MUBLAS_COMPUTE_32F_FAST_16F
130+
#define CUBLAS_GEMM_DEFAULT MUBLAS_GEMM_DEFAULT
131+
#define CUBLAS_GEMM_DEFAULT_TENSOR_OP MUBLAS_GEMM_DEFAULT
132+
#define CUBLAS_OP_N MUBLAS_OP_N
133+
#define CUBLAS_OP_T MUBLAS_OP_T
134+
#define CUBLAS_STATUS_SUCCESS MUBLAS_STATUS_SUCCESS
135+
// #define CUBLAS_TF32_TENSOR_OP_MATH 0
136+
#define CUDA_R_16F MUSA_R_16F
137+
#define CUDA_R_32F MUSA_R_32F
138+
// #define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
139+
// #define cublasComputeType_t mublasComputeType_t
140+
#define cublasCreate mublasCreate
141+
#define cublasDestroy mublasDestroy
142+
#define cublasGemmEx mublasGemmEx
143+
#define cublasGemmBatchedEx mublasGemmBatchedEx
144+
#define cublasGemmStridedBatchedEx mublasGemmStridedBatchedEx
145+
#define cublasHandle_t mublasHandle_t
146+
// #define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS
147+
#define cublasSetMathMode mublasSetMathMode
148+
#define cublasSetStream mublasSetStream
149+
#define cublasSgemm mublasSgemm
150+
#define cublasStatus_t mublasStatus_t
151+
#define cudaDataType_t musaDataType_t //deprecated, new hipblasDatatype not in 5.6
152+
#define cudaDeviceCanAccessPeer musaDeviceCanAccessPeer
153+
#define cudaDeviceDisablePeerAccess musaDeviceDisablePeerAccess
154+
#define cudaDeviceEnablePeerAccess musaDeviceEnablePeerAccess
155+
#define cudaDeviceProp musaDeviceProp
156+
#define cudaDeviceSynchronize musaDeviceSynchronize
157+
#define cudaError_t musaError_t
158+
#define cudaErrorPeerAccessAlreadyEnabled musaErrorPeerAccessAlreadyEnabled
159+
#define cudaErrorPeerAccessNotEnabled musaErrorPeerAccessNotEnabled
160+
#define cudaEventCreateWithFlags musaEventCreateWithFlags
161+
#define cudaEventDisableTiming musaEventDisableTiming
162+
#define cudaEventRecord musaEventRecord
163+
#define cudaEventSynchronize musaEventSynchronize
164+
#define cudaEvent_t musaEvent_t
165+
#define cudaEventDestroy musaEventDestroy
166+
#define cudaFree musaFree
167+
#define cudaFreeHost musaFreeHost
168+
#define cudaGetDevice musaGetDevice
169+
#define cudaGetDeviceCount musaGetDeviceCount
170+
#define cudaGetDeviceProperties musaGetDeviceProperties
171+
#define cudaGetErrorString musaGetErrorString
172+
#define cudaGetLastError musaGetLastError
173+
#define cudaHostRegister musaHostRegister
174+
#define cudaHostRegisterPortable musaHostRegisterPortable
175+
#define cudaHostRegisterReadOnly musaHostRegisterReadOnly
176+
#define cudaHostUnregister musaHostUnregister
177+
#define cudaLaunchHostFunc musaLaunchHostFunc
178+
#define cudaMalloc musaMalloc
179+
#define cudaMallocHost musaMallocHost
180+
#define cudaMemcpy musaMemcpy
181+
#define cudaMemcpyAsync musaMemcpyAsync
182+
#define cudaMemcpyPeerAsync musaMemcpyPeerAsync
183+
#define cudaMemcpy2DAsync musaMemcpy2DAsync
184+
#define cudaMemcpyDeviceToDevice musaMemcpyDeviceToDevice
185+
#define cudaMemcpyDeviceToHost musaMemcpyDeviceToHost
186+
#define cudaMemcpyHostToDevice musaMemcpyHostToDevice
187+
#define cudaMemcpyKind musaMemcpyKind
188+
#define cudaMemset musaMemset
189+
#define cudaMemsetAsync musaMemsetAsync
190+
#define cudaMemGetInfo musaMemGetInfo
191+
#define cudaOccupancyMaxPotentialBlockSize musaOccupancyMaxPotentialBlockSize
192+
#define cudaSetDevice musaSetDevice
193+
#define cudaStreamCreateWithFlags musaStreamCreateWithFlags
194+
#define cudaStreamDestroy musaStreamDestroy
195+
#define cudaStreamFireAndForget musaStreamFireAndForget
196+
#define cudaStreamNonBlocking musaStreamNonBlocking
197+
#define cudaStreamPerThread musaStreamPerThread
198+
#define cudaStreamSynchronize musaStreamSynchronize
199+
#define cudaStreamWaitEvent musaStreamWaitEvent
200+
#define cudaStream_t musaStream_t
201+
#define cudaSuccess musaSuccess
202+
203+
// XXX: Other CUDA => MUSA mapping
204+
#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE MU_MEM_ACCESS_FLAGS_PROT_READWRITE
205+
#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED MU_MEM_ALLOC_GRANULARITY_RECOMMENDED
206+
#define CU_MEM_ALLOCATION_TYPE_PINNED MU_MEM_ALLOCATION_TYPE_PINNED
207+
#define CU_MEM_LOCATION_TYPE_DEVICE MU_MEM_LOCATION_TYPE_DEVICE
208+
#define CUdevice MUdevice
209+
#define CUdeviceptr MUdeviceptr
210+
#define CUmemAccessDesc MUmemAccessDesc
211+
#define CUmemAllocationProp MUmemAllocationProp
212+
#define CUmemGenericAllocationHandle MUmemGenericAllocationHandle
213+
#define cuDeviceGet muDeviceGet
214+
#define cuDeviceGetAttribute muDeviceGetAttribute
215+
#define cuMemAddressFree muMemAddressFree
216+
#define cuMemAddressReserve muMemAddressReserve
217+
#define cuMemCreate muMemCreate
218+
#define cuMemGetAllocationGranularity muMemGetAllocationGranularity
219+
#define cuMemMap muMemMap
220+
#define cuMemRelease muMemRelease
221+
#define cuMemSetAccess muMemSetAccess
222+
#define cuMemUnmap muMemUnmap
223+
#define cudaFuncAttributeMaxDynamicSharedMemorySize musaFuncAttributeMaxDynamicSharedMemorySize
224+
#define cudaFuncSetAttribute musaFuncSetAttribute
225+
#define cudaMemcpy3DPeerParms musaMemcpy3DPeerParms
226+
#define make_cudaExtent make_musaExtent
227+
#define make_cudaPitchedPtr make_musaPitchedPtr
228+
229+
// XXX: USE_CUDA_GRAPH
230+
#define CUDA_SUCCESS MUSA_SUCCESS
231+
#define CUresult MUresult
232+
#define cuGetErrorString muGetErrorString
233+
#define cudaErrorGraphExecUpdateFailure musaErrorGraphExecUpdateFailure
234+
#define cudaErrorInvalidDeviceFunction musaErrorInvalidDeviceFunction
235+
#define cudaGraphDestroy musaGraphDestroy
236+
#define cudaGraphExecDestroy musaGraphExecDestroy
237+
#define cudaGraphExec_t musaGraphExec_t
238+
#define cudaGraphExecUpdate musaGraphExecUpdate
239+
#define cudaGraphExecUpdateResultInfo musaGraphExecUpdateResult
240+
#define cudaGraphGetNodes musaGraphGetNodes
241+
#define cudaGraphInstantiate musaGraphInstantiate
242+
#define cudaGraphKernelNodeGetParams musaGraphKernelNodeGetParams
243+
#define cudaGraphKernelNodeSetParams musaGraphKernelNodeSetParams
244+
#define cudaGraphLaunch musaGraphLaunch
245+
#define cudaGraphNodeGetType musaGraphNodeGetType
246+
#define cudaGraphNode_t musaGraphNode_t
247+
#define cudaGraphNodeType musaGraphNodeType
248+
#define cudaGraphNodeTypeKernel musaGraphNodeTypeKernel
249+
#define cudaGraph_t musaGraph_t
250+
#define cudaKernelNodeParams musaKernelNodeParams
251+
#define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed
252+
#define cudaStreamEndCapture musaStreamEndCapture
253+
254+
// XXX: cuBLAS => muBLAS mapping
255+
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED MU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
256+
#define CUBLAS_TF32_TENSOR_OP_MATH MUBLAS_MATH_MODE_DEFAULT
257+
#define CUBLAS_COMPUTE_16F CUDA_R_16F
258+
#define CUBLAS_COMPUTE_32F CUDA_R_32F
259+
#define cublasComputeType_t cudaDataType_t
260+
261+
// XXX: Clang builtins mapping
262+
#define __vsubss4 __vsubss4_musa
263+
#define __vsub4 __vsub4_musa
264+
#define __vcmpeq4 __vcmpeq4_musa
265+
#define __vcmpne4 __vcmpne4_musa
117266
#else
118267
#include <cuda_runtime.h>
119268
#include <cuda.h>
@@ -168,9 +317,13 @@ void ggml_cuda_error(const char * stmt, const char * func, const char * file, in
168317

169318
#define CUDA_CHECK(err) CUDA_CHECK_GEN(err, cudaSuccess, cudaGetErrorString)
170319

171-
#if CUDART_VERSION >= 12000
172-
static const char * cublas_get_error_str(const cublasStatus_t err) {
320+
#if CUDART_VERSION >= 12000 || defined(GGML_USE_MUSA)
321+
static const char * cublas_get_error_str(const mublasStatus_t err) {
322+
#ifndef GGML_USE_MUSA
173323
return cublasGetStatusString(err);
324+
#else
325+
return mublasStatus_to_string(err);
326+
#endif // GGML_USE_MUSA
174327
}
175328
#else
176329
static const char * cublas_get_error_str(const cublasStatus_t err) {
@@ -200,7 +353,7 @@ static const char * cu_get_error_str(CUresult err) {
200353
#define CU_CHECK(err) CUDA_CHECK_GEN(err, CUDA_SUCCESS, cu_get_error_str)
201354
#endif
202355

203-
#if CUDART_VERSION >= 11100
356+
#if CUDART_VERSION >= 11100 || defined(GGML_USE_MUSA)
204357
#define GGML_CUDA_ASSUME(x) __builtin_assume(x)
205358
#else
206359
#define GGML_CUDA_ASSUME(x)
@@ -214,6 +367,62 @@ typedef float dfloat; // dequantize float
214367
typedef float2 dfloat2;
215368
#endif //GGML_CUDA_F16
216369

370+
#if defined(GGML_USE_MUSA)
371+
#ifndef __has_builtin
372+
#define __has_builtin(x) 0
373+
#endif
374+
375+
typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));
376+
typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4)));
377+
static __device__ __forceinline__ int __vsubss4_musa(const int a, const int b) {
378+
const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);
379+
const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
380+
#if __has_builtin(__builtin_elementwise_sub_sat)
381+
const int8x4_t c = __builtin_elementwise_sub_sat(va, vb);
382+
return reinterpret_cast<const int &>(c);
383+
#else
384+
int8x4_t c;
385+
int16_t tmp;
386+
#pragma unroll
387+
for (int i = 0; i < 4; i++) {
388+
tmp = va[i] - vb[i];
389+
if(tmp > std::numeric_limits<int8_t>::max()) tmp = std::numeric_limits<int8_t>::max();
390+
if(tmp < std::numeric_limits<int8_t>::min()) tmp = std::numeric_limits<int8_t>::min();
391+
c[i] = tmp;
392+
}
393+
return reinterpret_cast<int &>(c);
394+
#endif // __has_builtin(__builtin_elementwise_sub_sat)
395+
}
396+
397+
static __device__ __forceinline__ int __vsub4_musa(const int a, const int b) {
398+
return __vsubss4_musa(a, b);
399+
}
400+
401+
static __device__ __forceinline__ unsigned int __vcmpeq4_musa(unsigned int a, unsigned int b) {
402+
const uint8x4_t& va = reinterpret_cast<const uint8x4_t&>(a);
403+
const uint8x4_t& vb = reinterpret_cast<const uint8x4_t&>(b);
404+
unsigned int c;
405+
uint8x4_t& vc = reinterpret_cast<uint8x4_t&>(c);
406+
#pragma unroll
407+
for (int i = 0; i < 4; ++i) {
408+
vc[i] = va[i] == vb[i] ? 0xff : 0x00;
409+
}
410+
return c;
411+
}
412+
413+
static __device__ __forceinline__ unsigned int __vcmpne4_musa(unsigned int a, unsigned int b) {
414+
const uint8x4_t& va = reinterpret_cast<const uint8x4_t&>(a);
415+
const uint8x4_t& vb = reinterpret_cast<const uint8x4_t&>(b);
416+
unsigned int c;
417+
uint8x4_t& vc = reinterpret_cast<uint8x4_t&>(c);
418+
#pragma unroll
419+
for (int i = 0; i < 4; ++i) {
420+
vc[i] = va[i] == vb[i] ? 0x00 : 0xff;
421+
}
422+
return c;
423+
}
424+
#endif // defined(GGML_USE_MUSA)
425+
217426
#if defined(GGML_USE_HIPBLAS)
218427
#define __CUDA_ARCH__ 1300
219428

@@ -455,7 +664,7 @@ static __device__ __forceinline__ uint32_t __hgt2_mask(const half2 a, const half
455664
const uint32_t mask_high = 0xFFFF0000 * (float(__high2half(a)) > float(__high2half(b)));
456665
return mask_low | mask_high;
457666
}
458-
#endif // CUDART_VERSION < 12000
667+
#endif // CUDART_VERSION < CUDART_HMASK
459668

460669
static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, int c) {
461670
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)

0 commit comments

Comments
 (0)