1 change: 1 addition & 0 deletions .gitignore
@@ -12,6 +12,7 @@ build-em/
build-debug/
build-release/
build-static/
build-cublas/
build-no-accel/
build-sanitize-addr/
build-sanitize-thread/
39 changes: 36 additions & 3 deletions CMakeLists.txt
@@ -51,7 +51,7 @@ option(WHISPER_SANITIZE_UNDEFINED "whisper: enable undefined sanitizer" OFF)
option(WHISPER_BUILD_TESTS "whisper: build tests" ${WHISPER_STANDALONE})
option(WHISPER_BUILD_EXAMPLES "whisper: build examples" ${WHISPER_STANDALONE})

option(WHISPER_SUPPORT_SDL2 "whisper: support for libSDL2" OFF)
option(WHISPER_SDL2 "whisper: support for libSDL2" OFF)

if (APPLE)
option(WHISPER_NO_ACCELERATE "whisper: disable Accelerate framework" OFF)
@@ -62,7 +62,8 @@ if (APPLE)
option(WHISPER_COREML "whisper: enable Core ML framework" OFF)
option(WHISPER_COREML_ALLOW_FALLBACK "whisper: allow non-CoreML fallback" OFF)
else()
option(WHISPER_SUPPORT_OPENBLAS "whisper: support for OpenBLAS" OFF)
option(WHISPER_OPENBLAS "whisper: support for OpenBLAS" OFF)
option(WHISPER_CUBLAS "whisper: support for cuBLAS" OFF)
endif()

option(WHISPER_PERF "whisper: enable perf timings" OFF)
@@ -127,7 +128,7 @@ if (APPLE)
endif()
endif()

if (WHISPER_SUPPORT_OPENBLAS)
if (WHISPER_OPENBLAS)
find_library(OPENBLAS_LIB
NAMES openblas libopenblas
)
@@ -141,6 +142,31 @@ if (WHISPER_SUPPORT_OPENBLAS)
endif()
endif()

if (WHISPER_CUBLAS)
cmake_minimum_required(VERSION 3.17)

find_package(CUDAToolkit)

if (CUDAToolkit_FOUND)
message(STATUS "cuBLAS found")

enable_language(CUDA)

set(GGML_CUDA_SOURCES ggml-cuda.cu ggml-cuda.h)

add_compile_definitions(GGML_USE_CUBLAS)

if (WHISPER_STATIC)
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
else()
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt)
endif()

else()
message(WARNING "cuBLAS not found")
endif()
endif()

# compiler flags

if (NOT CMAKE_BUILD_TYPE AND NOT CMAKE_CONFIGURATION_TYPES)
@@ -247,6 +273,7 @@ set(TARGET whisper)
add_library(${TARGET}
ggml.h
ggml.c
${GGML_CUDA_SOURCES}
whisper.h
whisper.cpp
)
@@ -279,6 +306,12 @@ if (BUILD_SHARED_LIBS)
)
endif()

if (GGML_CUDA_SOURCES)
message(STATUS "GGML CUDA sources found, configuring CUDA architecture")
set_property(TARGET whisper PROPERTY CUDA_ARCHITECTURES OFF)
set_property(TARGET whisper PROPERTY CUDA_SELECT_NVCC_ARCH_FLAGS "Auto")
endif()

if (EMSCRIPTEN)
set_target_properties(${TARGET} PROPERTIES COMPILE_FLAGS "-msimd128")
endif()
28 changes: 20 additions & 8 deletions Makefile
@@ -1,3 +1,5 @@
default: main bench

ifndef UNAME_S
UNAME_S := $(shell uname -s)
endif
@@ -157,6 +159,18 @@ ifdef WHISPER_OPENBLAS
LDFLAGS += -lopenblas
endif

ifdef WHISPER_CUBLAS
CFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include
CXXFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include
LDFLAGS += -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/x86_64-linux/lib
WHISPER_OBJ += ggml-cuda.o
NVCC = nvcc
NVCCFLAGS = --forward-unknown-to-host-compiler -arch=native

ggml-cuda.o: ggml-cuda.cu ggml-cuda.h
$(NVCC) $(NVCCFLAGS) $(CXXFLAGS) -Wno-pedantic -c $< -o $@
endif

ifdef WHISPER_GPROF
CFLAGS += -pg
CXXFLAGS += -pg
@@ -200,28 +214,26 @@ $(info I CC: $(CCV))
$(info I CXX: $(CXXV))
$(info )

default: main bench

#
# Build library
#

ggml.o: ggml.c ggml.h
$(CC) $(CFLAGS) -c ggml.c -o ggml.o
ggml.o: ggml.c ggml.h ggml-cuda.h
$(CC) $(CFLAGS) -c $< -o $@

whisper.o: whisper.cpp whisper.h ggml.h
$(CXX) $(CXXFLAGS) -c whisper.cpp -o whisper.o
whisper.o: whisper.cpp whisper.h ggml.h ggml-cuda.h
$(CXX) $(CXXFLAGS) -c $< -o $@

ifndef WHISPER_COREML
WHISPER_OBJ = whisper.o
WHISPER_OBJ += whisper.o
else
whisper-encoder.o: coreml/whisper-encoder.mm coreml/whisper-encoder.h
$(CXX) -O3 -I . -c coreml/whisper-encoder.mm -o whisper-encoder.o

whisper-encoder-impl.o: coreml/whisper-encoder-impl.m coreml/whisper-encoder-impl.h
$(CXX) -O3 -I . -fobjc-arc -c coreml/whisper-encoder-impl.m -o whisper-encoder-impl.o

WHISPER_OBJ = whisper.o whisper-encoder.o whisper-encoder-impl.o
WHISPER_OBJ += whisper.o whisper-encoder.o whisper-encoder-impl.o
endif

libwhisper.a: ggml.o $(WHISPER_OBJ)
26 changes: 20 additions & 6 deletions README.md
@@ -18,6 +18,7 @@ High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisp
- Low memory usage (Flash Attention)
- Zero memory allocations at runtime
- Runs on the CPU
- [Partial GPU support for NVIDIA via cuBLAS](https://github.com/ggerganov/whisper.cpp#nvidia-gpu-support-via-cublas)
- [C-style API](https://github.com/ggerganov/whisper.cpp/blob/master/whisper.h)

Supported platforms:
@@ -254,7 +255,7 @@ speed-up - more than x3 faster compared with CPU-only execution. Here are the in
# using Makefile
make clean
WHISPER_COREML=1 make -j

# using CMake
cd build
cmake -DWHISPER_COREML=1 ..
@@ -271,20 +272,33 @@ speed-up - more than x3 faster compared with CPU-only execution. Here are the in
whisper_init_state: first run on a device may take a while ...
whisper_init_state: Core ML model loaded

system_info: n_threads = 4 / 10 | AVX = 0 | AVX2 = 0 | AVX512 = 0 | FMA = 0 | NEON = 1 | ARM_FMA = 1 | F16C = 0 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 0 | VSX = 0 | COREML = 1 |
system_info: n_threads = 4 / 10 | AVX = 0 | AVX2 = 0 | AVX512 = 0 | FMA = 0 | NEON = 1 | ARM_FMA = 1 | F16C = 0 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 0 | VSX = 0 | COREML = 1 |

...
```

The first run on a device is slow, since the ANE service compiles the Core ML model to some device-specific format.
Next runs are faster.

For more information about the Core ML implementation please refer to PR [#566](https://github.com/ggerganov/whisper.cpp/pull/566).


## NVIDIA GPU support via cuBLAS

With NVIDIA cards, the Encoder processing can be offloaded to the GPU to a large extent through cuBLAS.
First, make sure you have installed `cuda`: https://developer.nvidia.com/cuda-downloads

Now build `whisper.cpp` with cuBLAS support:

```
make clean
WHISPER_CUBLAS=1 make -j
```
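
This change also adds a `WHISPER_CUBLAS` option to `CMakeLists.txt`, so a CMake-based build should work as well. A minimal sketch, assuming CMake >= 3.17 and the CUDA Toolkit on the default search path:

```
# using CMake (a sketch, assuming an out-of-source build directory)
mkdir build && cd build
cmake -DWHISPER_CUBLAS=1 ..
make -j
```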

Run all the examples as usual.
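
For example, a transcription run should now offload the Encoder matrix multiplications to the GPU (a sketch, assuming the `base.en` model has already been downloaded to `models/`):

```
./main -m models/ggml-base.en.bin -f samples/jfk.wav
```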

## Limitations

- Inference only
- No GPU support (yet)

## Another example

Expand Down Expand Up @@ -429,7 +443,7 @@ system_info: n_threads = 4 / 10 | AVX2 = 0 | AVX512 = 0 | NEON = 1 | FP16_VA = 1

main: processing './samples/jfk.wav' (176000 samples, 11.0 sec), 4 threads, 1 processors, lang = en, task = transcribe, timestamps = 1 ...

[00:00:00.000 --> 00:00:00.320]
[00:00:00.000 --> 00:00:00.320]
[00:00:00.320 --> 00:00:00.370] And
[00:00:00.370 --> 00:00:00.690] so
[00:00:00.690 --> 00:00:00.850] my
4 changes: 2 additions & 2 deletions examples/CMakeLists.txt
@@ -4,7 +4,7 @@ find_package(Threads REQUIRED)

# third-party

if (WHISPER_SUPPORT_SDL2)
if (WHISPER_SDL2)
# SDL2
find_package(SDL2 REQUIRED)

@@ -27,7 +27,7 @@ include(DefaultTargetOptions)

set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON)

if (WHISPER_SUPPORT_SDL2)
if (WHISPER_SDL2)
# common-sdl

set(TARGET common-sdl)
2 changes: 1 addition & 1 deletion examples/command/CMakeLists.txt
@@ -1,4 +1,4 @@
if (WHISPER_SUPPORT_SDL2)
if (WHISPER_SDL2)
# command
set(TARGET command)
add_executable(${TARGET} command.cpp)
2 changes: 1 addition & 1 deletion examples/stream/CMakeLists.txt
@@ -1,4 +1,4 @@
if (WHISPER_SUPPORT_SDL2)
if (WHISPER_SDL2)
# stream
set(TARGET stream)
add_executable(${TARGET} stream.cpp)
2 changes: 1 addition & 1 deletion examples/talk-llama/CMakeLists.txt
@@ -1,4 +1,4 @@
if (WHISPER_SUPPORT_SDL2)
if (WHISPER_SDL2)
# talk-llama
set(TARGET talk-llama)
#add_executable(${TARGET} talk-llama.cpp llama.cpp)
2 changes: 1 addition & 1 deletion examples/talk/CMakeLists.txt
@@ -1,4 +1,4 @@
if (WHISPER_SUPPORT_SDL2)
if (WHISPER_SDL2)
# talk
set(TARGET talk)
#add_executable(${TARGET} talk.cpp gpt-2.cpp)
37 changes: 14 additions & 23 deletions whisper.cpp
@@ -102,7 +102,7 @@ static void byteswap_tensor(ggml_tensor * tensor) {
#define WHISPER_PRINT_DEBUG(...)
#endif

#define WHISPER_USE_FLASH_ATTN
//#define WHISPER_USE_FLASH_ATTN
//#define WHISPER_USE_FLASH_FF
#define WHISPER_MAX_DECODERS 16

@@ -224,11 +224,11 @@ static const std::map<std::string, std::pair<int, std::string>> g_lang = {
static const size_t MB = 1ull*1024*1024;

static const std::map<e_model, size_t> MEM_REQ_SCRATCH0 = {
{ MODEL_TINY, 14ull*MB },
{ MODEL_BASE, 18ull*MB },
{ MODEL_SMALL, 28ull*MB },
{ MODEL_MEDIUM, 36ull*MB },
{ MODEL_LARGE, 44ull*MB },
{ MODEL_TINY, 62ull*MB },
{ MODEL_BASE, 80ull*MB },
{ MODEL_SMALL, 120ull*MB },
{ MODEL_MEDIUM, 158ull*MB },
{ MODEL_LARGE, 198ull*MB },
};

static const std::map<e_model, size_t> MEM_REQ_SCRATCH1 = {
@@ -280,11 +280,11 @@ static const std::map<e_model, size_t> MEM_REQ_KV_CROSS = {
};

static const std::map<e_model, size_t> MEM_REQ_ENCODE = {
{ MODEL_TINY, 6ull*MB },
{ MODEL_BASE, 8ull*MB },
{ MODEL_SMALL, 13ull*MB },
{ MODEL_MEDIUM, 22ull*MB },
{ MODEL_LARGE, 33ull*MB },
{ MODEL_TINY, 30ull*MB },
{ MODEL_BASE, 38ull*MB },
{ MODEL_SMALL, 56ull*MB },
{ MODEL_MEDIUM, 74ull*MB },
{ MODEL_LARGE, 94ull*MB },
};

static const std::map<e_model, size_t> MEM_REQ_DECODE = {
@@ -1554,26 +1554,17 @@ static bool whisper_encode_internal(

struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_scaled);

//struct ggml_tensor * V_trans =
// ggml_permute(ctx0,
// ggml_cpy(ctx0,
// Vcur,
// ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)),
// 1, 2, 0, 3);

//struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);

struct ggml_tensor * V =
ggml_cpy(ctx0,
ggml_permute(ctx0,
ggml_reshape_3d(ctx0,
Vcur,
n_state/n_head, n_head, n_ctx),
0, 2, 1, 3),
ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_ctx, n_head)
1, 2, 0, 3),
ggml_new_tensor_3d(ctx0, wctx.wtype, n_ctx, n_state/n_head, n_head)
);

struct ggml_tensor * KQV = ggml_mul_mat(ctx0, ggml_transpose(ctx0, V), KQ_soft_max);
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
#endif
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
