|
6 | 6 | #include <atomic>
|
7 | 7 | #include <assert.h>
|
8 | 8 |
|
| 9 | +#if defined(GGML_USE_HIPBLAS) |
| 10 | +#include <hip/hip_runtime.h> |
| 11 | +#include <hipblas/hipblas.h> |
| 12 | +#include <hip/hip_fp16.h> |
| 13 | +#include <rocblas/rocblas.h> |
| 14 | +#define CUBLAS_OP_N HIPBLAS_OP_N |
| 15 | +#define CUBLAS_OP_T HIPBLAS_OP_T |
| 16 | +#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS |
| 17 | +#define CUBLAS_TF32_TENSOR_OP_MATH 0 |
| 18 | +#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width) |
| 19 | +#define cublasCreate hipblasCreate |
| 20 | +#define cublasGetStatusString rocblas_status_to_string |
| 21 | +#define cublasHandle_t hipblasHandle_t |
| 22 | +#define cublasLoggerConfigure(logIsOn, logToStdOut, logToStdErr, logFileName) CUBLAS_STATUS_SUCCESS |
| 23 | +#define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS |
| 24 | +#define cublasSetStream hipblasSetStream |
| 25 | +#define cublasSgemm hipblasSgemm |
| 26 | +#define cublasStatus_t hipblasStatus_t |
| 27 | +#define cudaDeviceProp hipDeviceProp_t |
| 28 | +#define cudaDeviceSynchronize hipDeviceSynchronize |
| 29 | +#define cudaError_t hipError_t |
| 30 | +#define cudaEventCreateWithFlags hipEventCreateWithFlags |
| 31 | +#define cudaEventDestroy hipEventDestroy |
| 32 | +#define cudaEventDisableTiming hipEventDisableTiming |
| 33 | +#define cudaEventRecord hipEventRecord |
| 34 | +#define cudaEvent_t hipEvent_t |
| 35 | +#define cudaFree hipFree |
| 36 | +#define cudaFreeHost hipHostFree |
| 37 | +#define cudaGetDevice hipGetDevice |
| 38 | +#define cudaGetDeviceCount hipGetDeviceCount |
| 39 | +#define cudaGetDeviceProperties hipGetDeviceProperties |
| 40 | +#define cudaGetErrorString hipGetErrorString |
| 41 | +#define cudaGetLastError hipGetLastError |
| 42 | +#define cudaMalloc hipMalloc |
| 43 | +#define cudaMallocHost hipHostMalloc |
| 44 | +#define cudaMemcpy hipMemcpy |
| 45 | +#define cudaMemcpy2DAsync hipMemcpy2DAsync |
| 46 | +#define cudaMemcpyAsync hipMemcpyAsync |
| 47 | +#define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice |
| 48 | +#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost |
| 49 | +#define cudaMemcpyHostToDevice hipMemcpyHostToDevice |
| 50 | +#define cudaMemcpyKind hipMemcpyKind |
| 51 | +#define cudaMemset hipMemset |
| 52 | +#define cudaSetDevice hipSetDevice |
| 53 | +#define cudaStreamCreateWithFlags hipStreamCreateWithFlags |
| 54 | +#define cudaStreamNonBlocking hipStreamNonBlocking |
| 55 | +#define cudaStreamWaitEvent(stream, event) hipStreamWaitEvent(stream, event, 0) |
| 56 | +#define cudaStream_t hipStream_t |
| 57 | +#define cudaSuccess hipSuccess |
| 58 | +#else |
9 | 59 | #include <cuda_runtime.h>
|
10 | 60 | #include <cublas_v2.h>
|
11 | 61 | #include <cuda_fp16.h>
|
| 62 | +#endif |
12 | 63 |
|
13 | 64 | #include "ggml-cuda.h"
|
14 | 65 | #include "ggml.h"
|
|
0 commit comments