#else
#define GGML_COMMON_DECL_CUDA
#define GGML_COMMON_IMPL_CUDA
+ #if defined(GGML_USE_MUSA)
+ #define GGML_COMMON_DECL_MUSA
+ #define GGML_COMMON_IMPL_MUSA
+ #endif
#endif
#include "ggml-common.h"

#define CUBLAS_STATUS_EXECUTION_FAILED HIPBLAS_STATUS_EXECUTION_FAILED
#define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR
#define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED
+ #elif defined(GGML_USE_MUSA)
+ #include <musa_runtime.h>
+ #include <musa.h>
+ #include <mublas.h>
+ #include <musa_fp16.h>
+ // XXX: Keep the following order the same as hipBLAS
+ // #define CUBLAS_COMPUTE_16F MUBLAS_COMPUTE_16F
+ // #define CUBLAS_COMPUTE_32F MUBLAS_COMPUTE_32F
+ #define CUBLAS_COMPUTE_32F_FAST_16F MUBLAS_COMPUTE_32F_FAST_16F
+ #define CUBLAS_GEMM_DEFAULT MUBLAS_GEMM_DEFAULT
+ #define CUBLAS_GEMM_DEFAULT_TENSOR_OP MUBLAS_GEMM_DEFAULT
+ #define CUBLAS_OP_N MUBLAS_OP_N
+ #define CUBLAS_OP_T MUBLAS_OP_T
+ #define CUBLAS_STATUS_SUCCESS MUBLAS_STATUS_SUCCESS
+ // #define CUBLAS_TF32_TENSOR_OP_MATH 0
+ #define CUDA_R_16F MUSA_R_16F
+ #define CUDA_R_32F MUSA_R_32F
+ // #define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
+ // #define cublasComputeType_t mublasComputeType_t
+ #define cublasCreate mublasCreate
+ #define cublasDestroy mublasDestroy
+ #define cublasGemmEx mublasGemmEx
+ #define cublasGemmBatchedEx mublasGemmBatchedEx
+ #define cublasGemmStridedBatchedEx mublasGemmStridedBatchedEx
+ #define cublasHandle_t mublasHandle_t
+ // #define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS
+ #define cublasSetMathMode mublasSetMathMode
+ #define cublasSetStream mublasSetStream
+ #define cublasSgemm mublasSgemm
+ #define cublasStatus_t mublasStatus_t
+ #define cudaDataType_t musaDataType_t // deprecated, new hipblasDatatype not in 5.6
+ #define cudaDeviceCanAccessPeer musaDeviceCanAccessPeer
+ #define cudaDeviceDisablePeerAccess musaDeviceDisablePeerAccess
+ #define cudaDeviceEnablePeerAccess musaDeviceEnablePeerAccess
+ #define cudaDeviceProp musaDeviceProp
+ #define cudaDeviceSynchronize musaDeviceSynchronize
+ #define cudaError_t musaError_t
+ #define cudaErrorPeerAccessAlreadyEnabled musaErrorPeerAccessAlreadyEnabled
+ #define cudaErrorPeerAccessNotEnabled musaErrorPeerAccessNotEnabled
+ #define cudaEventCreateWithFlags musaEventCreateWithFlags
+ #define cudaEventDisableTiming musaEventDisableTiming
+ #define cudaEventRecord musaEventRecord
+ #define cudaEventSynchronize musaEventSynchronize
+ #define cudaEvent_t musaEvent_t
+ #define cudaEventDestroy musaEventDestroy
+ #define cudaFree musaFree
+ #define cudaFreeHost musaFreeHost
+ #define cudaGetDevice musaGetDevice
+ #define cudaGetDeviceCount musaGetDeviceCount
+ #define cudaGetDeviceProperties musaGetDeviceProperties
+ #define cudaGetErrorString musaGetErrorString
+ #define cudaGetLastError musaGetLastError
+ #define cudaHostRegister musaHostRegister
+ #define cudaHostRegisterPortable musaHostRegisterPortable
+ #define cudaHostRegisterReadOnly musaHostRegisterReadOnly
+ #define cudaHostUnregister musaHostUnregister
+ #define cudaLaunchHostFunc musaLaunchHostFunc
+ #define cudaMalloc musaMalloc
+ #define cudaMallocHost musaMallocHost
+ #define cudaMemcpy musaMemcpy
+ #define cudaMemcpyAsync musaMemcpyAsync
+ #define cudaMemcpyPeerAsync musaMemcpyPeerAsync
+ #define cudaMemcpy2DAsync musaMemcpy2DAsync
+ #define cudaMemcpyDeviceToDevice musaMemcpyDeviceToDevice
+ #define cudaMemcpyDeviceToHost musaMemcpyDeviceToHost
+ #define cudaMemcpyHostToDevice musaMemcpyHostToDevice
+ #define cudaMemcpyKind musaMemcpyKind
+ #define cudaMemset musaMemset
+ #define cudaMemsetAsync musaMemsetAsync
+ #define cudaMemGetInfo musaMemGetInfo
+ #define cudaOccupancyMaxPotentialBlockSize musaOccupancyMaxPotentialBlockSize
+ #define cudaSetDevice musaSetDevice
+ #define cudaStreamCreateWithFlags musaStreamCreateWithFlags
+ #define cudaStreamDestroy musaStreamDestroy
+ #define cudaStreamFireAndForget musaStreamFireAndForget
+ #define cudaStreamNonBlocking musaStreamNonBlocking
+ #define cudaStreamPerThread musaStreamPerThread
+ #define cudaStreamSynchronize musaStreamSynchronize
+ #define cudaStreamWaitEvent musaStreamWaitEvent
+ #define cudaStream_t musaStream_t
+ #define cudaSuccess musaSuccess
+
+ // XXX: Other CUDA => MUSA mapping
+ #define CU_MEM_ACCESS_FLAGS_PROT_READWRITE MU_MEM_ACCESS_FLAGS_PROT_READWRITE
+ #define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED MU_MEM_ALLOC_GRANULARITY_RECOMMENDED
+ #define CU_MEM_ALLOCATION_TYPE_PINNED MU_MEM_ALLOCATION_TYPE_PINNED
+ #define CU_MEM_LOCATION_TYPE_DEVICE MU_MEM_LOCATION_TYPE_DEVICE
+ #define CUdevice MUdevice
+ #define CUdeviceptr MUdeviceptr
+ #define CUmemAccessDesc MUmemAccessDesc
+ #define CUmemAllocationProp MUmemAllocationProp
+ #define CUmemGenericAllocationHandle MUmemGenericAllocationHandle
+ #define cuDeviceGet muDeviceGet
+ #define cuDeviceGetAttribute muDeviceGetAttribute
+ #define cuMemAddressFree muMemAddressFree
+ #define cuMemAddressReserve muMemAddressReserve
+ #define cuMemCreate muMemCreate
+ #define cuMemGetAllocationGranularity muMemGetAllocationGranularity
+ #define cuMemMap muMemMap
+ #define cuMemRelease muMemRelease
+ #define cuMemSetAccess muMemSetAccess
+ #define cuMemUnmap muMemUnmap
+ #define cudaFuncAttributeMaxDynamicSharedMemorySize musaFuncAttributeMaxDynamicSharedMemorySize
+ #define cudaFuncSetAttribute musaFuncSetAttribute
+ #define cudaMemcpy3DPeerParms musaMemcpy3DPeerParms
+ #define make_cudaExtent make_musaExtent
+ #define make_cudaPitchedPtr make_musaPitchedPtr
+
+ // XXX: USE_CUDA_GRAPH
+ #define CUDA_SUCCESS MUSA_SUCCESS
+ #define CUresult MUresult
+ #define cuGetErrorString muGetErrorString
+ #define cudaErrorGraphExecUpdateFailure musaErrorGraphExecUpdateFailure
+ #define cudaErrorInvalidDeviceFunction musaErrorInvalidDeviceFunction
+ #define cudaGraphDestroy musaGraphDestroy
+ #define cudaGraphExecDestroy musaGraphExecDestroy
+ #define cudaGraphExec_t musaGraphExec_t
+ #define cudaGraphExecUpdate musaGraphExecUpdate
+ #define cudaGraphExecUpdateResultInfo musaGraphExecUpdateResult
+ #define cudaGraphGetNodes musaGraphGetNodes
+ #define cudaGraphInstantiate musaGraphInstantiate
+ #define cudaGraphKernelNodeGetParams musaGraphKernelNodeGetParams
+ #define cudaGraphKernelNodeSetParams musaGraphKernelNodeSetParams
+ #define cudaGraphLaunch musaGraphLaunch
+ #define cudaGraphNodeGetType musaGraphNodeGetType
+ #define cudaGraphNode_t musaGraphNode_t
+ #define cudaGraphNodeType musaGraphNodeType
+ #define cudaGraphNodeTypeKernel musaGraphNodeTypeKernel
+ #define cudaGraph_t musaGraph_t
+ #define cudaKernelNodeParams musaKernelNodeParams
+ #define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed
+ #define cudaStreamEndCapture musaStreamEndCapture
+
+ // XXX: cuBLAS => muBLAS mapping
+ #define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED MU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
+ #define CUBLAS_TF32_TENSOR_OP_MATH MUBLAS_MATH_MODE_DEFAULT
+ #define CUBLAS_COMPUTE_16F CUDA_R_16F
+ #define CUBLAS_COMPUTE_32F CUDA_R_32F
+ #define cublasComputeType_t cudaDataType_t
+
+ // XXX: Clang builtins mapping
+ #define __vsubss4 __vsubss4_musa
+ #define __vsub4 __vsub4_musa
+ #define __vcmpeq4 __vcmpeq4_musa
+ #define __vcmpne4 __vcmpne4_musa
#else
#include <cuda_runtime.h>
#include <cuda.h>
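Note for readers: the block above is a pure spelling map, so the shared CUDA-flavored source builds against MUSA with no per-call #ifdefs. A minimal sketch of the effect (hypothetical snippet, not part of this diff; the function and its error handling are illustrative only):

// Hypothetical illustration: under GGML_USE_MUSA every CUDA spelling below
// expands to its MUSA counterpart at preprocessing time.
static void example_init() {
    cudaStream_t stream = nullptr;                             // -> musaStream_t
    cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking); // -> musaStreamCreateWithFlags
    cublasHandle_t handle = nullptr;                           // -> mublasHandle_t
    cublasCreate(&handle);                                     // -> mublasCreate
    cublasSetStream(handle, stream);                           // -> mublasSetStream
    // ... GEMMs go through cublasGemmEx -> mublasGemmEx ...
    cublasDestroy(handle);                                     // -> mublasDestroy
    cudaStreamDestroy(stream);                                 // -> musaStreamDestroy
}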
@@ -168,9 +317,13 @@ void ggml_cuda_error(const char * stmt, const char * func, const char * file, in

#define CUDA_CHECK(err) CUDA_CHECK_GEN(err, cudaSuccess, cudaGetErrorString)

- #if CUDART_VERSION >= 12000
- static const char * cublas_get_error_str(const cublasStatus_t err) {
+ #if CUDART_VERSION >= 12000 || defined(GGML_USE_MUSA)
+ static const char * cublas_get_error_str(const cublasStatus_t err) {
+ #ifndef GGML_USE_MUSA
return cublasGetStatusString(err);
+ #else
+ return mublasStatus_to_string(err);
+ #endif // GGML_USE_MUSA
}
#else
static const char * cublas_get_error_str(const cublasStatus_t err) {
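Review note: the added signature spells the parameter `cublasStatus_t` (which the mapping above aliases to `mublasStatus_t` under MUSA), so the same function compiles on both backends and call sites stay backend-agnostic. A hedged sketch of a caller (hypothetical helper, not from this patch; assumes <cstdio> is available):

// Hypothetical: one error-reporting path for cuBLAS and muBLAS alike.
static void check_blas(cublasStatus_t status, const char * what) {
    if (status != CUBLAS_STATUS_SUCCESS) {
        fprintf(stderr, "%s failed: %s\n", what, cublas_get_error_str(status));
    }
}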
@@ -200,7 +353,7 @@ static const char * cu_get_error_str(CUresult err) {
#define CU_CHECK(err) CUDA_CHECK_GEN(err, CUDA_SUCCESS, cu_get_error_str)
#endif

- #if CUDART_VERSION >= 11100
+ #if CUDART_VERSION >= 11100 || defined(GGML_USE_MUSA)
#define GGML_CUDA_ASSUME(x) __builtin_assume(x)
#else
#define GGML_CUDA_ASSUME(x)
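Also visible here: GGML_CUDA_ASSUME now expands to __builtin_assume for MUSA as well (the toolchain is clang-based, per the "Clang builtins mapping" above), not only for CUDA >= 11.1. A minimal sketch of what the hint buys (hypothetical device helper, not part of this diff):

// Hypothetical: the assume tells the optimizer the index is in range, letting
// it drop provably redundant checks; it expands to nothing on older CUDA.
static __device__ void scale_lane(float * row, int lane) {
    GGML_CUDA_ASSUME(lane >= 0 && lane < 32);
    row[lane] *= 2.0f;
}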
@@ -214,6 +367,62 @@ typedef float dfloat; // dequantize float
typedef float2 dfloat2;
#endif // GGML_CUDA_F16

+ #if defined(GGML_USE_MUSA)
+ #ifndef __has_builtin
+ #define __has_builtin(x) 0
+ #endif
+
+ typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));
+ typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4)));
+ static __device__ __forceinline__ int __vsubss4_musa(const int a, const int b) {
+ const int8x4_t va = reinterpret_cast<const int8x4_t &>(a);
+ const int8x4_t vb = reinterpret_cast<const int8x4_t &>(b);
+ #if __has_builtin(__builtin_elementwise_sub_sat)
+ const int8x4_t c = __builtin_elementwise_sub_sat(va, vb);
+ return reinterpret_cast<const int &>(c);
+ #else
+ int8x4_t c;
+ int16_t tmp;
+ #pragma unroll
+ for (int i = 0; i < 4; i++) {
+ tmp = va[i] - vb[i];
+ if (tmp > std::numeric_limits<int8_t>::max()) tmp = std::numeric_limits<int8_t>::max();
+ if (tmp < std::numeric_limits<int8_t>::min()) tmp = std::numeric_limits<int8_t>::min();
+ c[i] = tmp;
+ }
+ return reinterpret_cast<int &>(c);
+ #endif // __has_builtin(__builtin_elementwise_sub_sat)
+ }
+
+ static __device__ __forceinline__ int __vsub4_musa(const int a, const int b) {
+ return __vsubss4_musa(a, b);
+ }
+
+ static __device__ __forceinline__ unsigned int __vcmpeq4_musa(unsigned int a, unsigned int b) {
+ const uint8x4_t & va = reinterpret_cast<const uint8x4_t &>(a);
+ const uint8x4_t & vb = reinterpret_cast<const uint8x4_t &>(b);
+ unsigned int c;
+ uint8x4_t & vc = reinterpret_cast<uint8x4_t &>(c);
+ #pragma unroll
+ for (int i = 0; i < 4; ++i) {
+ vc[i] = va[i] == vb[i] ? 0xff : 0x00;
+ }
+ return c;
+ }
+
+ static __device__ __forceinline__ unsigned int __vcmpne4_musa(unsigned int a, unsigned int b) {
+ const uint8x4_t & va = reinterpret_cast<const uint8x4_t &>(a);
+ const uint8x4_t & vb = reinterpret_cast<const uint8x4_t &>(b);
+ unsigned int c;
+ uint8x4_t & vc = reinterpret_cast<uint8x4_t &>(c);
+ #pragma unroll
+ for (int i = 0; i < 4; ++i) {
+ vc[i] = va[i] == vb[i] ? 0x00 : 0xff;
+ }
+ return c;
+ }
+ #endif // defined(GGML_USE_MUSA)
+
#if defined(GGML_USE_HIPBLAS)
#define __CUDA_ARCH__ 1300
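The emulated builtins in the hunk above follow CUDA's SIMD-in-a-word semantics: __vsubss4 is a per-byte signed saturating subtract, and __vcmpeq4/__vcmpne4 produce per-byte 0xFF/0x00 masks. A host-side reference one could test the device versions against (a sketch under that assumption, not part of this diff):

// Host reference for __vsubss4: subtract each signed byte, saturating to [-128, 127].
#include <cstdint>
static int vsubss4_ref(int a, int b) {
    unsigned int r = 0;
    for (int i = 0; i < 4; ++i) {
        const int va = (int8_t)((unsigned int)a >> (8 * i));
        const int vb = (int8_t)((unsigned int)b >> (8 * i));
        int d = va - vb;
        if (d >  127) { d =  127; }
        if (d < -128) { d = -128; }
        r |= (unsigned int)(d & 0xff) << (8 * i);
    }
    return (int)r;
}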
@@ -455,7 +664,7 @@ static __device__ __forceinline__ uint32_t __hgt2_mask(const half2 a, const half
const uint32_t mask_high = 0xFFFF0000 * (float(__high2half(a)) > float(__high2half(b)));
return mask_low | mask_high;
}
- #endif // CUDART_VERSION < 12000
+ #endif // CUDART_VERSION < CUDART_HMASK

static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, int c) {
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)