Skip to content

Commit 4dbad7f

Browse files
ggerganov authored and iThalay committed
whisper : add full CUDA and Metal offloading (ggml-org#1472)
* whisper : migrate to ggml-backend
* whisper : fix logit reading
* whisper : fix tensor allocation during load
* whisper : fix beam-search with CUDA
* whisper : free backends + fix compile warning
* whisper : print when CUDA is enabled
* whisper : fix CoreML
* make : clean-up
* talk : fix compile warning
* whisper : support ggml_conv with CUDA and Metal (ggml-org#1473)
* ggml : add CUDA support for ggml_conv
* whisper : remove ggml_repeat for conv bias + single backend
* cuda : fix im2col kernel
* metal : add im2col support + mul mat-vec f16 x f16
* bench-all : add q4 models
* whisper : clean-up
* quantize-all : fix
* ggml : im2col opts
* whisper : avoid whisper_model_data wrapper
* whisper : add note that ggml_mul_mat_pad does not work with CUDA
* whisper : factor out graph compute in common function
* whisper : fixes
* whisper : fix UB with measure buffers
* whisper : try to fix the parallel whisper_state functionality (ggml-org#1479)
* whisper : try to fix the parallel whisper_state functionality
* whisper : fix multi-state Metal
* whisper : free backend instances in whisper_state
1 parent 29647e5 commit 4dbad7f

File tree

14 files changed

+1029
-1707
lines changed

14 files changed

+1029
-1707
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
.DS_Store
99

1010
build/
11+
build-coreml/
1112
build-em/
1213
build-debug/
1314
build-release/

Makefile

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,7 @@ ggml-backend.o: ggml-backend.c ggml.h ggml-backend.h
307307
ggml-quants.o: ggml-quants.c ggml.h ggml-quants.h
308308
$(CC) $(CFLAGS) -c $< -o $@
309309

310-
WHISPER_OBJ += ggml-alloc.o ggml-backend.o ggml-quants.o
310+
WHISPER_OBJ += ggml.o ggml-alloc.o ggml-backend.o ggml-quants.o
311311

312312
whisper.o: whisper.cpp whisper.h ggml.h ggml-cuda.h
313313
$(CXX) $(CXXFLAGS) -c $< -o $@
@@ -331,11 +331,11 @@ ggml-metal.o: ggml-metal.m ggml-metal.h
331331
WHISPER_OBJ += ggml-metal.o
332332
endif
333333

334-
libwhisper.a: ggml.o $(WHISPER_OBJ)
335-
$(AR) rcs libwhisper.a ggml.o $(WHISPER_OBJ)
334+
libwhisper.a: $(WHISPER_OBJ)
335+
$(AR) rcs libwhisper.a $(WHISPER_OBJ)
336336

337-
libwhisper.so: ggml.o $(WHISPER_OBJ)
338-
$(CXX) $(CXXFLAGS) -shared -o libwhisper.so ggml.o $(WHISPER_OBJ) $(LDFLAGS)
337+
libwhisper.so: $(WHISPER_OBJ)
338+
$(CXX) $(CXXFLAGS) -shared -o libwhisper.so $(WHISPER_OBJ) $(LDFLAGS)
339339

340340
clean:
341341
rm -f *.o main stream command talk talk-llama bench quantize lsp libwhisper.a libwhisper.so
@@ -349,30 +349,30 @@ CC_SDL=`sdl2-config --cflags --libs`
349349
SRC_COMMON = examples/common.cpp examples/common-ggml.cpp
350350
SRC_COMMON_SDL = examples/common-sdl.cpp
351351

352-
main: examples/main/main.cpp $(SRC_COMMON) ggml.o $(WHISPER_OBJ)
353-
$(CXX) $(CXXFLAGS) examples/main/main.cpp $(SRC_COMMON) ggml.o $(WHISPER_OBJ) -o main $(LDFLAGS)
352+
main: examples/main/main.cpp $(SRC_COMMON) $(WHISPER_OBJ)
353+
$(CXX) $(CXXFLAGS) examples/main/main.cpp $(SRC_COMMON) $(WHISPER_OBJ) -o main $(LDFLAGS)
354354
./main -h
355355

356-
bench: examples/bench/bench.cpp ggml.o $(WHISPER_OBJ)
357-
$(CXX) $(CXXFLAGS) examples/bench/bench.cpp ggml.o $(WHISPER_OBJ) -o bench $(LDFLAGS)
356+
bench: examples/bench/bench.cpp $(WHISPER_OBJ)
357+
$(CXX) $(CXXFLAGS) examples/bench/bench.cpp $(WHISPER_OBJ) -o bench $(LDFLAGS)
358358

359-
quantize: examples/quantize/quantize.cpp ggml.o $(WHISPER_OBJ) $(SRC_COMMON)
360-
$(CXX) $(CXXFLAGS) examples/quantize/quantize.cpp $(SRC_COMMON) ggml.o $(WHISPER_OBJ) -o quantize $(LDFLAGS)
359+
quantize: examples/quantize/quantize.cpp $(WHISPER_OBJ) $(SRC_COMMON)
360+
$(CXX) $(CXXFLAGS) examples/quantize/quantize.cpp $(SRC_COMMON) $(WHISPER_OBJ) -o quantize $(LDFLAGS)
361361

362-
stream: examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
363-
$(CXX) $(CXXFLAGS) examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o stream $(CC_SDL) $(LDFLAGS)
362+
stream: examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ)
363+
$(CXX) $(CXXFLAGS) examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ) -o stream $(CC_SDL) $(LDFLAGS)
364364

365-
command: examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
366-
$(CXX) $(CXXFLAGS) examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o command $(CC_SDL) $(LDFLAGS)
365+
command: examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ)
366+
$(CXX) $(CXXFLAGS) examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ) -o command $(CC_SDL) $(LDFLAGS)
367367

368-
lsp: examples/lsp/lsp.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
369-
$(CXX) $(CXXFLAGS) examples/lsp/lsp.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o lsp $(CC_SDL) $(LDFLAGS)
368+
lsp: examples/lsp/lsp.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ)
369+
$(CXX) $(CXXFLAGS) examples/lsp/lsp.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ) -o lsp $(CC_SDL) $(LDFLAGS)
370370

371-
talk: examples/talk/talk.cpp examples/talk/gpt-2.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
372-
$(CXX) $(CXXFLAGS) examples/talk/talk.cpp examples/talk/gpt-2.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o talk $(CC_SDL) $(LDFLAGS)
371+
talk: examples/talk/talk.cpp examples/talk/gpt-2.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ)
372+
$(CXX) $(CXXFLAGS) examples/talk/talk.cpp examples/talk/gpt-2.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ) -o talk $(CC_SDL) $(LDFLAGS)
373373

374-
talk-llama: examples/talk-llama/talk-llama.cpp examples/talk-llama/llama.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
375-
$(CXX) $(CXXFLAGS) examples/talk-llama/talk-llama.cpp examples/talk-llama/llama.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o talk-llama $(CC_SDL) $(LDFLAGS)
374+
talk-llama: examples/talk-llama/talk-llama.cpp examples/talk-llama/llama.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ)
375+
$(CXX) $(CXXFLAGS) examples/talk-llama/talk-llama.cpp examples/talk-llama/llama.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ) -o talk-llama $(CC_SDL) $(LDFLAGS)
376376

377377
#
378378
# Audio samples

examples/common.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ class wav_writer {
181181
// It is assumed that PCM data is normalized to a range from -1 to 1
182182
bool write_audio(const float * data, size_t length) {
183183
for (size_t i = 0; i < length; ++i) {
184-
const auto intSample = static_cast<const int16_t>(data[i] * 32767);
184+
const int16_t intSample = data[i] * 32767;
185185
file.write(reinterpret_cast<const char *>(&intSample), sizeof(int16_t));
186186
dataSize += sizeof(int16_t);
187187
}

examples/talk/gpt-2.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -121,13 +121,13 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
121121
return false;
122122
}
123123

124-
std::string word;
124+
char word[129];
125+
125126
for (int i = 0; i < n_vocab; i++) {
126127
uint32_t len;
127128
fin.read((char *) &len, sizeof(len));
128-
129-
word.resize(len);
130-
fin.read((char *) word.data(), len);
129+
word[len] = '\0';
130+
fin.read((char *) word, len);
131131

132132
vocab.token_to_id[word] = i;
133133
vocab.id_to_token[i] = word;

extra/bench-all.sh

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,11 @@ else
1818
fi
1919

2020
models=( \
21-
"tiny" "tiny-q5_0" "tiny-q5_1" "tiny-q8_0" \
22-
"base" "base-q5_0" "base-q5_1" "base-q8_0" \
23-
"small" "small-q5_0" "small-q5_1" "small-q8_0" \
24-
"medium" "medium-q5_0" "medium-q5_1" "medium-q8_0" \
25-
"large" "large-q5_0" "large-q5_1" "large-q8_0" \
21+
"tiny" "tiny-q4_0" "tiny-q4_1" "tiny-q5_0" "tiny-q5_1" "tiny-q8_0" \
22+
"base" "base-q4_0" "base-q4_1" "base-q5_0" "base-q5_1" "base-q8_0" \
23+
"small" "small-q4_0" "small-q4_1" "small-q5_0" "small-q5_1" "small-q8_0" \
24+
"medium" "medium-q4_0" "medium-q4_1" "medium-q5_0" "medium-q5_1" "medium-q8_0" \
25+
"large" "large-q4_0" "large-q4_1" "large-q5_0" "large-q5_1" "large-q8_0" \
2626
)
2727

2828
if [ "$encoder_only" -eq 0 ]; then
@@ -83,6 +83,10 @@ for model in "${models[@]}"; do
8383
config="$config COREML"
8484
fi
8585

86+
if [[ $system_info == *"CUDA = 1"* ]]; then
87+
config="$config CUDA"
88+
fi
89+
8690
if [[ $system_info == *"METAL = 1"* ]]; then
8791
config="$config METAL"
8892
fi

extra/quantize-all.sh

Lines changed: 7 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -15,33 +15,13 @@ declare -a filedex
1515
cd `dirname $0`
1616
cd ../
1717

18-
# Let's loop across all the objects in the 'models' dir:
19-
for i in ./models/*; do
20-
# Check to see if it's a file or directory
21-
if [ -d "$i" ]; then
22-
# It's a directory! We should make sure it's not empty first:
23-
if [ "$(ls -A $i)" ]; then
24-
# Passed! Let's go searching for bin files (shouldn't need to go more than a layer deep here)
25-
for f in "$i"/*.bin; do
26-
# [Neuron Activation]
27-
newfile=`echo "${f##*/}" | cut -d _ -f 1`;
28-
if [ "$newfile" != "q5" ]; then
29-
./quantize "${f}" "${i:-4}/${i:9:${#i}-4}-${qtype1}.bin" ${qtype1};
30-
./quantize "${f}" "${i:-4}/${i:9:${#i}-4}-${qtype0}.bin" ${qtype0};
31-
filedex+=( "${i:-4}/${i:9:${#i}-4}-${qtype1}.bin" "${i:-4}/${i:9:${#i}-4}-${qtype0}.bin" )
32-
fi
33-
done
34-
fi
35-
else
36-
# It's a file! Let's make sure it's the right type:
37-
if [ "${i##*.}" == "bin" ]; then
38-
# And we probably want to skip the testing files
39-
if [ "${i:9:8}" != "for-test" ]; then
40-
# [Neuron Activation]
41-
./quantize "${i}" "${i:-4}-${qtype1}.bin" ${qtype1};
42-
./quantize "${i}" "${i:-4}-${qtype0}.bin" ${qtype0};
43-
filedex+=( "${i:-4}-${qtype1}.bin" "${i:-4}-${qtype0}.bin" )
44-
fi
18+
for i in `ls ./models | grep ^ggml-.*.bin | grep -v "\-q"`; do
19+
m="models/$i"
20+
if [ -f "$m" ]; then
21+
if [ "${m##*.}" == "bin" ]; then
22+
./quantize "${m}" "${m::${#m}-4}-${qtype1}.bin" ${qtype1};
23+
./quantize "${m}" "${m::${#m}-4}-${qtype0}.bin" ${qtype0};
24+
filedex+=( "${m::${#m}-4}-${qtype1}.bin" "${m::${#m}-4}-${qtype0}.bin" )
4525
fi
4626
fi
4727
done

ggml-cuda.cu

Lines changed: 95 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4476,6 +4476,13 @@ static __device__ void cpy_1_f32_f16(const char * cxi, char * cdsti) {
44764476
*dsti = __float2half(*xi);
44774477
}
44784478

4479+
static __device__ void cpy_1_f16_f16(const char * cxi, char * cdsti) {
4480+
const half * xi = (const half *) cxi;
4481+
half * dsti = (half *) cdsti;
4482+
4483+
*dsti = *xi;
4484+
}
4485+
44794486
template <cpy_kernel_t cpy_1>
44804487
static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
44814488
const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
@@ -4729,6 +4736,25 @@ static __global__ void clamp_f32(const float * x, float * dst, const float min,
47294736
dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
47304737
}
47314738

4739+
static __global__ void im2col_f32_f16(
4740+
const float * x, half * dst,
4741+
int ofs0, int ofs1, int IW, int IH, int CHW,
4742+
int s0, int s1, int p0, int p1, int d0, int d1) {
4743+
const int iiw = blockIdx.z * s0 + threadIdx.z * d0 - p0;
4744+
const int iih = blockIdx.y * s1 + threadIdx.y * d1 - p1;
4745+
4746+
const int offset_dst =
4747+
(threadIdx.x * gridDim.y * gridDim.z + blockIdx.y * gridDim.z + blockIdx.z) * CHW +
4748+
(blockIdx.x * (blockDim.y * blockDim.z) + threadIdx.y * blockDim.z + threadIdx.z);
4749+
4750+
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
4751+
dst[offset_dst] = __float2half(0.0f);
4752+
} else {
4753+
const int offset_src = threadIdx.x * ofs0 + blockIdx.x * ofs1;
4754+
dst[offset_dst] = __float2half(x[offset_src + iih * IW + iiw]);
4755+
}
4756+
}
4757+
47324758
template<int qk, int qr, dequantize_kernel_t dq>
47334759
static void get_rows_cuda(const void * x, const int32_t * y, float * dst, const int nrows, const int ncols, cudaStream_t stream) {
47344760
const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
@@ -5618,6 +5644,16 @@ static void ggml_cpy_f32_f16_cuda(
56185644
(cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
56195645
}
56205646

5647+
static void ggml_cpy_f16_f16_cuda(
5648+
const char * cx, char * cdst, const int ne,
5649+
const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
5650+
const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) {
5651+
5652+
const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
5653+
cpy_f32_f16<cpy_1_f16_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
5654+
(cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
5655+
}
5656+
56215657
static void scale_f32_cuda(const float * x, float * dst, const float scale, const int k, cudaStream_t stream) {
56225658
const int num_blocks = (k + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE;
56235659
scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, k);
@@ -5701,6 +5737,15 @@ static void soft_max_f32_cuda(const float * x, float * dst, const int ncols_x, c
57015737
soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x);
57025738
}
57035739

5740+
static void im2col_f32_f16_cuda(const float * x, half * dst,
5741+
int OH, int IW, int IH, int OW, int IC,
5742+
int KH, int KW, int N, int ofs0, int ofs1,
5743+
int s0, int s1, int p0, int p1, int d0, int d1, cudaStream_t stream) {
5744+
dim3 block_nums(IC, OH, OW);
5745+
dim3 block_dims(N, KH, KW);
5746+
im2col_f32_f16<<<block_nums, block_dims, 0, stream>>>(x, dst, ofs0, ofs1, IW, IH, (IC * KH * KW), s0, s1, p0, p1, d0, d1);
5747+
}
5748+
57045749
// buffer pool for cuda
57055750
#define MAX_CUDA_BUFFERS 256
57065751

@@ -6483,7 +6528,7 @@ inline void ggml_cuda_op_mul_mat_cublas(
64836528
src1_as_f16 = (half *) ggml_cuda_pool_malloc_async(ne * sizeof(half), &src1_as, id, stream);
64846529
to_fp16_cuda(src1_ddf_i, src1_as_f16, ne, stream);
64856530
}
6486-
const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddq_i : src1_as_f16;
6531+
const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16;
64876532
size_t dst_f16_as = 0;
64886533
half * dst_f16 = (half *) ggml_cuda_pool_malloc_async(row_diff*src1_ncols * sizeof(half), &dst_f16_as, id, stream);
64896534

@@ -6659,6 +6704,45 @@ inline void ggml_cuda_op_alibi(
66596704
(void) src1_dd;
66606705
}
66616706

6707+
inline void ggml_cuda_op_im2col(
6708+
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
6709+
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
6710+
6711+
GGML_ASSERT(src0->type == GGML_TYPE_F16);
6712+
GGML_ASSERT(src1->type == GGML_TYPE_F32);
6713+
GGML_ASSERT( dst->type == GGML_TYPE_F16);
6714+
6715+
const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
6716+
const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
6717+
const int32_t p0 = ((const int32_t*)(dst->op_params))[2];
6718+
const int32_t p1 = ((const int32_t*)(dst->op_params))[3];
6719+
const int32_t d0 = ((const int32_t*)(dst->op_params))[4];
6720+
const int32_t d1 = ((const int32_t*)(dst->op_params))[5];
6721+
6722+
const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1;
6723+
6724+
const int64_t N = src1->ne[is_2D ? 3 : 2];
6725+
const int64_t IC = src1->ne[is_2D ? 2 : 1];
6726+
const int64_t IH = is_2D ? src1->ne[1] : 1;
6727+
const int64_t IW = src1->ne[0];
6728+
6729+
const int64_t KH = is_2D ? src0->ne[1] : 1;
6730+
const int64_t KW = src0->ne[0];
6731+
6732+
const int64_t OH = is_2D ? dst->ne[2] : 1;
6733+
const int64_t OW = dst->ne[1];
6734+
6735+
const size_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32
6736+
const size_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
6737+
6738+
im2col_f32_f16_cuda(src1_dd, (half*) dst_dd,
6739+
OH, IW, IH, OW, IC, KH, KW, N,
6740+
ofs0, ofs1, s0, s1, p0, p1, d0, d1, main_stream);
6741+
6742+
(void) src0;
6743+
(void) src0_dd;
6744+
}
6745+
66626746
inline void ggml_cuda_op_diag_mask_inf(
66636747
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
66646748
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
@@ -7549,6 +7633,9 @@ static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, gg
75497633
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
75507634
ggml_cpy_f32_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
75517635
ne10, ne11, nb10, nb11, nb12, main_stream);
7636+
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
7637+
ggml_cpy_f16_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
7638+
ne10, ne11, nb10, nb11, nb12, main_stream);
75527639
} else {
75537640
fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__,
75547641
ggml_type_name(src0->type), ggml_type_name(src1->type));
@@ -7580,6 +7667,10 @@ static void ggml_cuda_alibi(const ggml_tensor * src0, const ggml_tensor * src1,
75807667
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_alibi);
75817668
}
75827669

7670+
void ggml_cuda_im2col(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
7671+
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_im2col);
7672+
}
7673+
75837674
static void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
75847675
(void) src0;
75857676
(void) src1;
@@ -7943,6 +8034,9 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
79438034
case GGML_OP_ALIBI:
79448035
func = ggml_cuda_alibi;
79458036
break;
8037+
case GGML_OP_IM2COL:
8038+
func = ggml_cuda_im2col;
8039+
break;
79468040
default:
79478041
return false;
79488042
}

ggml-metal.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
#include <stdbool.h>
2727

2828
// max memory buffers that can be mapped to the device
29-
#define GGML_METAL_MAX_BUFFERS 16
29+
#define GGML_METAL_MAX_BUFFERS 64
3030
#define GGML_METAL_MAX_COMMAND_BUFFERS 32
3131

3232
struct ggml_tensor;

0 commit comments

Comments (0)