Skip to content

Commit 1d1f7e2

Browse files
authored
whisper : initial hipBLAS support (ggml-org#1209)
1 parent 0164e3f commit 1d1f7e2

File tree

3 files changed

+98
-0
lines changed

3 files changed

+98
-0
lines changed

CMakeLists.txt

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ else()
6565
option(WHISPER_BLAS_VENDOR "whisper: BLAS library vendor" Generic)
6666
option(WHISPER_OPENBLAS "whisper: prefer OpenBLAS" OFF)
6767
option(WHISPER_CUBLAS "whisper: support for cuBLAS" OFF)
68+
option(WHISPER_HIPBLAS "whisper: support for hipBLAS" OFF)
6869
option(WHISPER_CLBLAST "whisper: use CLBlast" OFF)
6970
endif()
7071

@@ -191,6 +192,37 @@ if (WHISPER_CUBLAS)
191192
endif()
192193
endif()
193194

195+
196+
# Optional hipBLAS (AMD ROCm) backend.
#
# When WHISPER_HIPBLAS is ON this block looks up the hip, hipblas and rocblas
# packages, compiles ggml-cuda.cu as C++ into the ggml-rocm object library and
# appends that library to WHISPER_EXTRA_LIBS. If any package is missing, only
# a warning is printed and the build continues without the backend.
if (WHISPER_HIPBLAS)
    # Make the default ROCm install prefix visible to find_package().
    list(APPEND CMAKE_PREFIX_PATH /opt/rocm)

    # Variables are named bare (no ${...}) inside if(): if() dereferences them
    # itself, so an empty or unset value cannot produce a malformed condition
    # and the value is never expanded a second time.
    if (NOT CMAKE_C_COMPILER_ID MATCHES "Clang")
        message(WARNING "Only LLVM is supported for HIP, hint: CC=/opt/rocm/llvm/bin/clang")
    endif()
    if (NOT CMAKE_CXX_COMPILER_ID MATCHES "Clang")
        message(WARNING "Only LLVM is supported for HIP, hint: CXX=/opt/rocm/llvm/bin/clang++")
    endif()

    find_package(hip)
    find_package(hipblas)
    find_package(rocblas)

    # rocblas is part of the condition because roc::rocblas is linked below;
    # without the check a missing rocBLAS would surface as a missing-target
    # error instead of the warning in the else() branch.
    if (hipblas_FOUND AND hip_FOUND AND rocblas_FOUND)
        message(STATUS "HIP and hipBLAS found")

        # Reject unsupported configurations before any target is created.
        if (WHISPER_STATIC)
            message(FATAL_ERROR "Static linking not supported for HIP/ROCm")
        endif()

        # Directory-wide on purpose: GGML_USE_CUBLAS is defined together with
        # GGML_USE_HIPBLAS, presumably so the existing cuBLAS code paths in the
        # other sources are reused for HIP -- TODO confirm against ggml.c.
        add_compile_definitions(GGML_USE_HIPBLAS GGML_USE_CUBLAS)

        add_library(ggml-rocm OBJECT ggml-cuda.cu ggml-cuda.h)
        set_property(TARGET ggml-rocm PROPERTY POSITION_INDEPENDENT_CODE ON)
        # Treat the .cu file as plain C++ so the C++ compiler builds it
        # instead of CMake's CUDA language support.
        set_source_files_properties(ggml-cuda.cu PROPERTIES LANGUAGE CXX)
        target_link_libraries(ggml-rocm PRIVATE hip::device PUBLIC hip::host roc::rocblas roc::hipblas)

        list(APPEND WHISPER_EXTRA_LIBS ggml-rocm)
    else()
        message(WARNING "HIP, hipBLAS or rocBLAS not found. Try setting CMAKE_PREFIX_PATH=/opt/rocm")
    endif()
endif()
225+
194226
if (WHISPER_CLBLAST)
195227
find_package(CLBlast)
196228
if (CLBlast_FOUND)

Makefile

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,21 @@ ggml-cuda.o: ggml-cuda.cu ggml-cuda.h
161161
$(NVCC) $(NVCCFLAGS) $(CXXFLAGS) -Wno-pedantic -c $< -o $@
162162
endif
163163

164+
# hipBLAS (AMD ROCm) backend, enabled by defining WHISPER_HIPBLAS.
# ROCM_PATH, HIPCC and GPU_TARGETS use ?= and can be overridden by the caller.
ifdef WHISPER_HIPBLAS

# Toolchain locations and target GPU architectures.
ROCM_PATH   ?= /opt/rocm
HIPCC       ?= $(ROCM_PATH)/bin/hipcc
GPU_TARGETS ?= $(shell $(ROCM_PATH)/llvm/bin/amdgpu-arch)

# GGML_USE_CUBLAS is defined together with GGML_USE_HIPBLAS for both C and C++.
CFLAGS   += -DGGML_USE_HIPBLAS -DGGML_USE_CUBLAS
CXXFLAGS += -DGGML_USE_HIPBLAS -DGGML_USE_CUBLAS

# One --offload-arch=<arch> option per entry in GPU_TARGETS.
HIPFLAGS += $(addprefix --offload-arch=,$(GPU_TARGETS))

# Library search path and rpath first, then the ROCm libraries themselves.
LDFLAGS += -L$(ROCM_PATH)/lib -Wl,-rpath=$(ROCM_PATH)/lib -lhipblas -lamdhip64 -lrocblas

WHISPER_OBJ += ggml-cuda.o

# The .cu source is compiled by hipcc with the input language forced to HIP.
ggml-cuda.o: ggml-cuda.cu ggml-cuda.h
	$(HIPCC) $(CXXFLAGS) $(HIPFLAGS) -x hip -c -o $@ $<

endif
178+
164179
ifdef WHISPER_CLBLAST
165180
CFLAGS += -DGGML_USE_CLBLAST
166181
CXXFLAGS += -DGGML_USE_CLBLAST

ggml-cuda.cu

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,60 @@
66
#include <atomic>
77
#include <assert.h>
88

9+
// Backend selection: with GGML_USE_HIPBLAS defined, the HIP/hipBLAS/rocBLAS
// headers are included and the CUDA/cuBLAS identifiers used by this file are
// redefined as their HIP counterparts; otherwise the CUDA headers are used.
#ifdef GGML_USE_HIPBLAS

#include <hip/hip_runtime.h>
#include <hipblas/hipblas.h>
#include <hip/hip_fp16.h>
#include <rocblas/rocblas.h>

// --- cuBLAS constants ---
#define CUBLAS_OP_N HIPBLAS_OP_N
#define CUBLAS_OP_T HIPBLAS_OP_T
#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
#define CUBLAS_TF32_TENSOR_OP_MATH 0

// --- cuBLAS types and functions ---
#define cublasHandle_t hipblasHandle_t
#define cublasStatus_t hipblasStatus_t
#define cublasCreate hipblasCreate
#define cublasSetStream hipblasSetStream
#define cublasSgemm hipblasSgemm
// NOTE(review): cublasStatus_t maps to hipblasStatus_t above, while this maps
// to a rocBLAS function -- confirm the argument types agree wherever
// cublasGetStatusString is actually called.
#define cublasGetStatusString rocblas_status_to_string

// --- cuBLAS calls replaced by stubs ---
// These ignore their arguments and expand to CUBLAS_STATUS_SUCCESS.
#define cublasLoggerConfigure(logIsOn, logToStdOut, logToStdErr, logFileName) CUBLAS_STATUS_SUCCESS
#define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS

// --- warp shuffle ---
// The mask argument is dropped; the remaining arguments are forwarded.
#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)

// --- runtime types and constants ---
#define cudaDeviceProp hipDeviceProp_t
#define cudaError_t hipError_t
#define cudaEvent_t hipEvent_t
#define cudaStream_t hipStream_t
#define cudaMemcpyKind hipMemcpyKind
#define cudaSuccess hipSuccess
#define cudaEventDisableTiming hipEventDisableTiming
#define cudaStreamNonBlocking hipStreamNonBlocking
#define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice
#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost
#define cudaMemcpyHostToDevice hipMemcpyHostToDevice

// --- device management and error reporting ---
#define cudaDeviceSynchronize hipDeviceSynchronize
#define cudaGetDevice hipGetDevice
#define cudaGetDeviceCount hipGetDeviceCount
#define cudaGetDeviceProperties hipGetDeviceProperties
#define cudaSetDevice hipSetDevice
#define cudaGetErrorString hipGetErrorString
#define cudaGetLastError hipGetLastError

// --- memory allocation and transfer ---
#define cudaMalloc hipMalloc
#define cudaMallocHost hipHostMalloc
#define cudaFree hipFree
#define cudaFreeHost hipHostFree
#define cudaMemcpy hipMemcpy
#define cudaMemcpy2DAsync hipMemcpy2DAsync
#define cudaMemcpyAsync hipMemcpyAsync
#define cudaMemset hipMemset

// --- streams and events ---
#define cudaEventCreateWithFlags hipEventCreateWithFlags
#define cudaEventDestroy hipEventDestroy
#define cudaEventRecord hipEventRecord
#define cudaStreamCreateWithFlags hipStreamCreateWithFlags
// The two-argument form used in this file is forwarded with a trailing 0.
#define cudaStreamWaitEvent(stream, event) hipStreamWaitEvent(stream, event, 0)

#else // !GGML_USE_HIPBLAS

#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <cuda_fp16.h>

#endif // GGML_USE_HIPBLAS
1263

1364
#include "ggml-cuda.h"
1465
#include "ggml.h"

0 commit comments

Comments
 (0)