@@ -220,6 +220,7 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
        case CUBLAS_STATUS_LICENSE_ERROR: return "CUBLAS_STATUS_LICENSE_ERROR";
        default: return "unknown error";
    }
+}
#endif // CUDART_VERSION >= 12000

static const char * cu_get_error_str(CUresult err) {
@@ -6739,6 +6740,39 @@ static void ggml_cuda_pool_free(void * ptr, size_t size) {
#define ggml_cuda_pool_free ggml_cuda_pool_free_leg
#endif

+template<typename T>
+struct cuda_pool_alloc {
+    T * ptr = nullptr;
+    size_t act_size = 0;
+
+    // size is in number of elements
+    T * alloc(size_t size) {
+        GGML_ASSERT(ptr == nullptr);
+        ptr = (T *) ggml_cuda_pool_malloc(size * sizeof(T), &this->act_size);
+        return ptr;
+    }
+
+    cuda_pool_alloc(size_t size) {
+        alloc(size);
+    }
+
+    ~cuda_pool_alloc() {
+        if (ptr != nullptr) {
+            ggml_cuda_pool_free(ptr, act_size);
+        }
+    }
+
+    T * get() {
+        return ptr;
+    }
+
+    cuda_pool_alloc() = default;
+    cuda_pool_alloc(const cuda_pool_alloc &) = delete;
+    cuda_pool_alloc(cuda_pool_alloc &&) = delete;
+    cuda_pool_alloc& operator=(const cuda_pool_alloc &) = delete;
+    cuda_pool_alloc& operator=(cuda_pool_alloc &&) = delete;
+};
+
static bool g_cublas_loaded = false;

bool ggml_cublas_loaded(void) {
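Note (illustrative, not part of the diff): the hunks below all follow the same pattern enabled by this wrapper, so a minimal usage sketch may help. It assumes the ggml-cuda pool (ggml_cuda_pool_malloc / ggml_cuda_pool_free) is already initialized for the current device; the kernel and function names are made up for the example.

    // hypothetical kernel, only to give the scratch buffer something to do
    __global__ void scale_kernel(float * x, size_t n) {
        const size_t i = blockIdx.x*blockDim.x + threadIdx.x;
        if (i < n) {
            x[i] *= 2.0f;
        }
    }

    static void scale_scratch_buffer(size_t n, cudaStream_t stream) {
        cuda_pool_alloc<float> tmp(n);   // size is in elements, not bytes
        scale_kernel<<<(int)((n + 255)/256), 256, 0, stream>>>(tmp.get(), n);
        // no explicit free: ~cuda_pool_alloc() returns the buffer to the pool
        // on every exit path, which is what lets the hunks below drop their
        // manual ggml_cuda_pool_free calls
    }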
@@ -7432,16 +7466,16 @@ inline void ggml_cuda_op_dequantize_mul_mat_vec(

    // on some GPUs it is faster to convert src1 to half and to use half precision intrinsics
#ifdef GGML_CUDA_F16
-    size_t ash;
-    dfloat * src1_dfloat = nullptr; // dfloat == half
+    cuda_pool_alloc<half> src1_dfloat_a;
+    half * src1_dfloat = nullptr; // dfloat == half

    bool src1_convert_f16 =
        src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 ||
        src0->type == GGML_TYPE_Q5_0 || src0->type == GGML_TYPE_Q5_1 ||
        src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16;

    if (src1_convert_f16) {
-        src1_dfloat = (half *) ggml_cuda_pool_malloc(ne00*sizeof(half), &ash);
+        src1_dfloat = src1_dfloat_a.alloc(ne00);
        ggml_cpy_f32_f16_cuda((const char *) src1_ddf_i, (char *) src1_dfloat, ne00,
                              ne00, 1, sizeof(float), 0, 0,
                              ne00, 1, sizeof(half), 0, 0, stream);
@@ -7489,12 +7523,6 @@ inline void ggml_cuda_op_dequantize_mul_mat_vec(
            break;
    }

-#ifdef GGML_CUDA_F16
-    if (src1_convert_f16) {
-        ggml_cuda_pool_free(src1_dfloat, ash);
-    }
-#endif // GGML_CUDA_F16
-
    (void) src1;
    (void) dst;
    (void) src1_ddq_i;
@@ -7529,29 +7557,26 @@ inline void ggml_cuda_op_mul_mat_cublas(

    if (compute_capability >= CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) {
        // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
-        half * src0_as_f16 = nullptr;
-        size_t src0_as = 0;
+        cuda_pool_alloc<half> src0_as_f16;
        if (src0->type != GGML_TYPE_F16) {
            const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src0->type);
            GGML_ASSERT(to_fp16_cuda != nullptr);
            size_t ne = row_diff*ne00;
-            src0_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src0_as);
-            to_fp16_cuda(src0_dd_i, src0_as_f16, ne, stream);
+            src0_as_f16.alloc(ne);
+            to_fp16_cuda(src0_dd_i, src0_as_f16.get(), ne, stream);
        }
-        const half * src0_ptr = src0->type == GGML_TYPE_F16 ? (const half *) src0_dd_i : src0_as_f16;
+        const half * src0_ptr = src0->type == GGML_TYPE_F16 ? (const half *) src0_dd_i : src0_as_f16.get();

-        half * src1_as_f16 = nullptr;
-        size_t src1_as = 0;
+        cuda_pool_alloc<half> src1_as_f16;
        if (src1->type != GGML_TYPE_F16) {
            const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
            GGML_ASSERT(to_fp16_cuda != nullptr);
            size_t ne = src1_ncols*ne10;
-            src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src1_as);
-            to_fp16_cuda(src1_ddf_i, src1_as_f16, ne, stream);
+            src1_as_f16.alloc(ne);
+            to_fp16_cuda(src1_ddf_i, src1_as_f16.get(), ne, stream);
        }
-        const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16;
-        size_t dst_as = 0;
-        half * dst_f16 = (half *) ggml_cuda_pool_malloc(row_diff*src1_ncols * sizeof(half), &dst_as);
+        const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16.get();
+        cuda_pool_alloc<half> dst_f16(row_diff*src1_ncols);

        const half alpha_f16 = 1.0f;
        const half beta_f16 = 0.0f;
@@ -7560,36 +7585,25 @@ inline void ggml_cuda_op_mul_mat_cublas(
        CUBLAS_CHECK(
            cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
                    row_diff, src1_ncols, ne10,
-                   &alpha_f16, src0_ptr, CUDA_R_16F, ne00,
-                               src1_ptr, CUDA_R_16F, ne10,
-                   &beta_f16,   dst_f16, CUDA_R_16F, ldc,
+                   &alpha_f16, src0_ptr,       CUDA_R_16F, ne00,
+                               src1_ptr,       CUDA_R_16F, ne10,
+                   &beta_f16,   dst_f16.get(), CUDA_R_16F, ldc,
                    CUBLAS_COMPUTE_16F,
                    CUBLAS_GEMM_DEFAULT_TENSOR_OP));

        const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
-        to_fp32_cuda(dst_f16, dst_dd_i, row_diff*src1_ncols, stream);
-
-        ggml_cuda_pool_free(dst_f16, dst_as);
-
-        if (src1_as != 0) {
-            ggml_cuda_pool_free(src1_as_f16, src1_as);
-        }
-
-        if (src0_as != 0) {
-            ggml_cuda_pool_free(src0_as_f16, src0_as);
-        }
+        to_fp32_cuda(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
    }
    else {
-        float * src0_ddq_as_f32 = nullptr;
-        size_t src0_as = 0;
+        cuda_pool_alloc<float> src0_ddq_as_f32;

        if (src0->type != GGML_TYPE_F32) {
            const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);
            GGML_ASSERT(to_fp32_cuda != nullptr);
-            src0_ddq_as_f32 = (float *) ggml_cuda_pool_malloc(row_diff*ne00 * sizeof(float), &src0_as); // NOLINT
-            to_fp32_cuda(src0_dd_i, src0_ddq_as_f32, row_diff*ne00, stream);
+            src0_ddq_as_f32.alloc(row_diff*ne00);
+            to_fp32_cuda(src0_dd_i, src0_ddq_as_f32.get(), row_diff*ne00, stream);
        }
-        const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32;
+        const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32.get();

        const float alpha = 1.0f;
        const float beta = 0.0f;
@@ -7601,10 +7615,6 @@ inline void ggml_cuda_op_mul_mat_cublas(
                    &alpha, src0_ddf_i, ne00,
                            src1_ddf_i, ne10,
                    &beta,  dst_dd_i,   ldc));
-
-        if (src0_as != 0) {
-            ggml_cuda_pool_free(src0_ddq_as_f32, src0_as);
-        }
    }

    (void) dst;
@@ -7896,33 +7906,32 @@ static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * s
    float * src1_ddf = nullptr;
    float * dst_ddf  = nullptr;

-    // as = actual size
-    size_t src0_asf = 0;
-    size_t src1_asf = 0;
-    size_t dst_asf  = 0;
+    cuda_pool_alloc<float> src0_f;
+    cuda_pool_alloc<float> src1_f;
+    cuda_pool_alloc<float> dst_f;

    ggml_cuda_set_device(g_main_device);
-    const cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
+    cudaStream_t main_stream = g_cudaStreams[g_main_device][0];

    if (src0_on_device) {
        src0_ddf = (float *) src0_extra->data_device[g_main_device];
    } else {
-        src0_ddf = (float *) ggml_cuda_pool_malloc(ggml_nbytes(src0), &src0_asf);
+        src0_ddf = src0_f.alloc(ggml_nelements(src0));
        CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddf, src0, 0, 0, 0, nrows0, main_stream));
    }

    if (use_src1) {
        if (src1_on_device) {
            src1_ddf = (float *) src1_extra->data_device[g_main_device];
        } else {
-            src1_ddf = (float *) ggml_cuda_pool_malloc(ggml_nbytes(src1), &src1_asf);
+            src1_ddf = src1_f.alloc(ggml_nelements(src1));
            CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src1_ddf, src1, 0, 0, 0, nrows1, main_stream));
        }
    }
    if (dst_on_device) {
        dst_ddf = (float *) dst_extra->data_device[g_main_device];
    } else {
-        dst_ddf = (float *) ggml_cuda_pool_malloc(ggml_nbytes(dst), &dst_asf);
+        dst_ddf = dst_f.alloc(ggml_nelements(dst));
    }

    // do the computation
@@ -7934,16 +7943,6 @@ static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * s
        CUDA_CHECK(cudaMemcpyAsync(dst->data, dst_ddf, ggml_nbytes(dst), cudaMemcpyDeviceToHost, main_stream));
    }

-    if (dst_asf > 0) {
-        ggml_cuda_pool_free(dst_ddf, dst_asf);
-    }
-    if (src1_asf > 0) {
-        ggml_cuda_pool_free(src1_ddf, src1_asf);
-    }
-    if (src0_asf > 0) {
-        ggml_cuda_pool_free(src0_ddf, src0_asf);
-    }
-
    if (dst->backend == GGML_BACKEND_CPU) {
        CUDA_CHECK(cudaDeviceSynchronize());
    }
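Worth noting in the ggml_cuda_op_flatten hunks above: alloc() takes a count of elements and multiplies by sizeof(T) itself, which is why the call sites switch from ggml_nbytes to ggml_nelements. For the contiguous F32 buffers handled here the two describe the same amount of memory. A side-by-side sketch of the src0 case, quoting the lines from the diff:

    // before: size passed to the pool in bytes, freed manually further down
    src0_ddf = (float *) ggml_cuda_pool_malloc(ggml_nbytes(src0), &src0_asf);
    // after: size passed in elements; cuda_pool_alloc<float> scales by sizeof(float)
    // and returns the buffer to the pool when src0_f goes out of scope
    src0_ddf = src0_f.alloc(ggml_nelements(src0));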
@@ -8516,14 +8515,11 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
    const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
    GGML_ASSERT(to_fp16_cuda != nullptr);

-    size_t src1_as = 0;
-    half * src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne1 * sizeof(half), &src1_as);
-    to_fp16_cuda(src1_ddf, src1_as_f16, ne1, main_stream);
+    cuda_pool_alloc<half> src1_as_f16(ne1);
+    to_fp16_cuda(src1_ddf, src1_as_f16.get(), ne1, main_stream);

-    size_t dst_as = 0;
-
-    half * dst_f16 = nullptr;
-    char * dst_t   = nullptr;
+    cuda_pool_alloc<half> dst_f16;
+    char * dst_t;

    cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
    cudaDataType_t      cu_data_type    = CUDA_R_16F;
@@ -8542,8 +8538,7 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
    const void * beta  = &beta_f16;

    if (dst->op_params[0] == GGML_PREC_DEFAULT) {
-        dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as);
-        dst_t   = (char *) dst_f16;
+        dst_t = (char *) dst_f16.alloc(ne);

        nbd2 /= sizeof(float) / sizeof(half);
        nbd3 /= sizeof(float) / sizeof(half);
@@ -8590,29 +8585,23 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
        CUBLAS_CHECK(
        cublasGemmStridedBatchedEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
                ne01, ne11, ne10,
-                alpha, (const char *) src0_as_f16, CUDA_R_16F, nb01/sizeof(half),  src0->nb[2]/sizeof(half),  // strideA
-                       (const char *) src1_as_f16, CUDA_R_16F, nb11/sizeof(float), src1->nb[2]/sizeof(float), // strideB
-                beta,  (      char *) dst_t,       cu_data_type, ne01,             dst->nb[2]/sizeof(float),  // strideC
+                alpha, (const char *) src0_as_f16,       CUDA_R_16F, nb01/sizeof(half),  src0->nb[2]/sizeof(half),  // strideA
+                       (const char *) src1_as_f16.get(), CUDA_R_16F, nb11/sizeof(float), src1->nb[2]/sizeof(float), // strideB
+                beta,  (      char *) dst_t,             cu_data_type, ne01,             dst->nb[2]/sizeof(float),  // strideC
                ne12*ne13,
                cu_compute_type,
                CUBLAS_GEMM_DEFAULT_TENSOR_OP));
    } else {
        // use cublasGemmBatchedEx
        const int ne23 = ne12*ne13;

-        const void ** ptrs_src = nullptr;
-              void ** ptrs_dst = nullptr;
-
-        size_t ptrs_src_s = 0;
-        size_t ptrs_dst_s = 0;
-
-        ptrs_src = (const void **) ggml_cuda_pool_malloc(2*ne23*sizeof(void *), &ptrs_src_s);
-        ptrs_dst = (      void **) ggml_cuda_pool_malloc(1*ne23*sizeof(void *), &ptrs_dst_s);
+        cuda_pool_alloc<const void *> ptrs_src(2*ne23);
+        cuda_pool_alloc<      void *> ptrs_dst(1*ne23);

        dim3 block_dims(ne13, ne12);
        k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
-            src0_as_f16, src1_as_f16, dst_t,
-            ptrs_src, ptrs_dst,
+            src0_as_f16, src1_as_f16.get(), dst_t,
+            ptrs_src.get(), ptrs_dst.get(),
            ne12, ne13,
            ne23,
            nb02, nb03,
@@ -8624,30 +8613,19 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
        CUBLAS_CHECK(
        cublasGemmBatchedEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
                ne01, ne11, ne10,
-                alpha, (const void **) (ptrs_src + 0*ne23), CUDA_R_16F, nb01/sizeof(half),
-                       (const void **) (ptrs_src + 1*ne23), CUDA_R_16F, nb11/sizeof(float),
-                beta,  (      void **) (ptrs_dst + 0*ne23), cu_data_type, ne01,
+                alpha, (const void **) (ptrs_src.get() + 0*ne23), CUDA_R_16F, nb01/sizeof(half),
+                       (const void **) (ptrs_src.get() + 1*ne23), CUDA_R_16F, nb11/sizeof(float),
+                beta,  (      void **) (ptrs_dst.get() + 0*ne23), cu_data_type, ne01,
                ne23,
                cu_compute_type,
                CUBLAS_GEMM_DEFAULT_TENSOR_OP));
-
-        if (ptrs_dst_s != 0) {
-            ggml_cuda_pool_free(ptrs_dst, ptrs_dst_s);
-        }
-        if (ptrs_src_s != 0) {
-            ggml_cuda_pool_free(ptrs_src, ptrs_src_s);
-        }
    }
#endif

    if (dst->op_params[0] == GGML_PREC_DEFAULT) {
        const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
-        to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream);
-
-        ggml_cuda_pool_free(dst_f16, dst_as);
+        to_fp32_cuda(dst_f16.get(), dst_ddf, ne, main_stream);
    }
-
-    ggml_cuda_pool_free(src1_as_f16, src1_as);
}

static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
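One detail of the cublasGemmBatchedEx path above that is easy to miss: a single pool block of 2*ne23 pointers backs both per-batch input pointer arrays, and k_compute_batched_ptrs fills them on the device. A layout sketch, illustrative only (sketch_batched_ptrs is not a real function in the file):

    static void sketch_batched_ptrs(int ne23) {           // ne23 = ne12*ne13 batches
        cuda_pool_alloc<const void *> ptrs_src(2*ne23);   // one block, two halves
        const void ** ptrs_a = ptrs_src.get() + 0*ne23;   // per-batch pointers into src0 (fp16)
        const void ** ptrs_b = ptrs_src.get() + 1*ne23;   // per-batch pointers into src1 (fp16)
        (void) ptrs_a; (void) ptrs_b;
        // cublasGemmBatchedEx reads these two device-side pointer arrays; the whole
        // block is returned to the pool when ptrs_src goes out of scope.
    }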
@@ -8974,12 +8952,11 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
            ggml_cuda_mul_mat(src0_row, &src1_row, &dst_row);
        }
    } else {
-        size_t as_src1, as_dst;
-        char * src1_contiguous = (char *) ggml_cuda_pool_malloc(sizeof(float)*ggml_nelements(src1), &as_src1);
-        char * dst_contiguous  = (char *) ggml_cuda_pool_malloc(sizeof(float)*ggml_nelements(dst),  &as_dst);
+        cuda_pool_alloc<char> src1_contiguous(sizeof(float)*ggml_nelements(src1));
+        cuda_pool_alloc<char> dst_contiguous(sizeof(float)*ggml_nelements(dst));

-        src1_row_extra.data_device[g_main_device] = src1_contiguous;
-        dst_row_extra.data_device[g_main_device]  = dst_contiguous;
+        src1_row_extra.data_device[g_main_device] = src1_contiguous.get();
+        dst_row_extra.data_device[g_main_device]  = dst_contiguous.get();

        const cudaMemcpyKind src1_kind = src1->backend == GGML_BACKEND_CPU ?
            cudaMemcpyHostToDevice : cudaMemcpyDeviceToDevice;
@@ -8999,7 +8976,7 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s

                GGML_ASSERT(row_id >= 0 && row_id < n_as);

-                CUDA_CHECK(cudaMemcpyAsync(src1_contiguous + num_src1_rows*nb11, src1_original + i01*nb11,
+                CUDA_CHECK(cudaMemcpyAsync(src1_contiguous.get() + num_src1_rows*nb11, src1_original + i01*nb11,
                                            nb11, src1_kind, stream));
                num_src1_rows++;
            }
@@ -9031,14 +9008,11 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s

                GGML_ASSERT(row_id >= 0 && row_id < n_as);

-                CUDA_CHECK(cudaMemcpyAsync(dst_original + i01*nb1, dst_contiguous + num_src1_rows*nb1,
+                CUDA_CHECK(cudaMemcpyAsync(dst_original + i01*nb1, dst_contiguous.get() + num_src1_rows*nb1,
                                            nb1, dst_kind, stream));
                num_src1_rows++;
            }
        }
-
-        ggml_cuda_pool_free(dst_contiguous,  as_dst);
-        ggml_cuda_pool_free(src1_contiguous, as_src1);
    }

    if (dst->backend == GGML_BACKEND_CPU) {