#define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer
#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
- #define cudaDeviceGetMemPool hipDeviceGetMemPool
- #define cudaMemPoolAttrReleaseThreshold hipMemPoolAttrReleaseThreshold
- #define cudaMemPoolSetAttribute hipMemPoolSetAttribute
- #define cudaMemPool_t hipMemPool_t
#define cudaDeviceProp hipDeviceProp_t
#define cudaDeviceSynchronize hipDeviceSynchronize
#define cudaError_t hipError_t
#define cudaEvent_t hipEvent_t
#define cudaEventDestroy hipEventDestroy
#define cudaFree hipFree
- #define cudaFreeAsync hipFreeAsync
#define cudaFreeHost hipHostFree
#define cudaGetDevice hipGetDevice
#define cudaGetDeviceCount hipGetDeviceCount
#define cudaGetDeviceProperties hipGetDeviceProperties
#define cudaGetErrorString hipGetErrorString
#define cudaGetLastError hipGetLastError
#define cudaMalloc hipMalloc
- #define cudaMallocFromPoolAsync hipMallocFromPoolAsync
#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)
#define cudaMemcpy hipMemcpy
#define cudaMemcpy2DAsync hipMemcpy2DAsync
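
The alias block above is what lets ggml-cuda.cu compile against ROCm as well as CUDA; the removed lines drop the memory-pool aliases that are no longer referenced once the pool code later in this commit is reverted. A minimal sketch of the pattern, assuming the GGML_USE_HIPBLAS guard this file uses (the helper function below is hypothetical):

    #if defined(GGML_USE_HIPBLAS)
    #include <hip/hip_runtime.h>
    // aliases like the ones in the hunk above map CUDA runtime names to HIP names
    #define cudaMalloc  hipMalloc
    #define cudaFree    hipFree
    #define cudaSuccess hipSuccess
    #define cudaError_t hipError_t
    #else
    #include <cuda_runtime.h>
    #endif

    // hypothetical helper: the same source compiles as hipMalloc under HIP
    static void * alloc_device_buffer(size_t n_bytes) {
        void * ptr = nullptr;
        if (cudaMalloc(&ptr, n_bytes) != cudaSuccess) {
            return nullptr;
        }
        return ptr;
    }
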
@@ -187,11 +181,11 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
    do {                                                                                \
        cudaError_t err_ = (err);                                                       \
        if (err_ != cudaSuccess) {                                                      \
-            int dev_id;                                                                \
-            cudaGetDevice(&dev_id);                                                    \
+            int id;                                                                    \
+            cudaGetDevice(&id);                                                        \
            fprintf(stderr, "\nCUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \
                cudaGetErrorString(err_));                                              \
-            fprintf(stderr, "current device: %d\n", dev_id);                           \
+            fprintf(stderr, "current device: %d\n", id);                               \
            exit(1);                                                                    \
        }                                                                               \
    } while (0)
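
For context, CUDA_CHECK (and the CUBLAS_CHECK variant in the next hunk) wraps every runtime call in this file; a minimal usage sketch with a hypothetical allocation size:

    void * buf = nullptr;
    CUDA_CHECK(cudaMalloc(&buf, 4096));   // on failure: prints file/line and the current device, then exits
    CUDA_CHECK(cudaMemset(buf, 0, 4096));
    CUDA_CHECK(cudaFree(buf));
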
@@ -201,11 +195,11 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
    do {                                                                        \
        cublasStatus_t err_ = (err);                                            \
        if (err_ != CUBLAS_STATUS_SUCCESS) {                                    \
-            int dev_id;                                                        \
-            cudaGetDevice(&dev_id);                                            \
+            int id;                                                            \
+            cudaGetDevice(&id);                                                \
            fprintf(stderr, "\ncuBLAS error %d at %s:%d: %s\n",                 \
                    err_, __FILE__, __LINE__, cublasGetStatusString(err_));     \
-            fprintf(stderr, "current device: %d\n", dev_id);                   \
+            fprintf(stderr, "current device: %d\n", id);                       \
            exit(1);                                                            \
        }                                                                       \
    } while (0)
@@ -471,7 +465,6 @@ static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUA

#define MAX_STREAMS 8
static cudaStream_t g_cudaStreams[GGML_CUDA_MAX_DEVICES][MAX_STREAMS] = { nullptr };
- static cudaMemPool_t g_cudaMemPools[GGML_CUDA_MAX_DEVICES] = { nullptr };

struct ggml_tensor_extra_gpu {
    void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors
@@ -5780,16 +5773,6 @@ static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
    return ptr;
}

- static void * ggml_cuda_pool_malloc_async(size_t size, size_t * actual_size, int id, cudaStream_t stream) {
-     if (g_cudaMemPools[id] == nullptr) {
-         return ggml_cuda_pool_malloc(size, actual_size);
-     }
-     void *ptr;
-     CUDA_CHECK(cudaMallocFromPoolAsync(&ptr, size, g_cudaMemPools[id], stream));
-     *actual_size = size;
-     return ptr;
- }
-
static void ggml_cuda_pool_free(void * ptr, size_t size) {
    scoped_spin_lock lock(g_cuda_pool_lock);
    int id;
@@ -5808,13 +5791,6 @@ static void ggml_cuda_pool_free(void * ptr, size_t size) {
}


- static void ggml_cuda_pool_free_async(void * ptr, size_t actual_size, int id, cudaStream_t stream) {
-     if (g_cudaMemPools[id] == nullptr) {
-         return ggml_cuda_pool_free(ptr, actual_size);
-     }
-     CUDA_CHECK(cudaFreeAsync(ptr, stream));
- }
-
void ggml_init_cublas() {
    static bool initialized = false;

@@ -5869,13 +5845,6 @@ void ggml_init_cublas() {
        // create cublas handle
        CUBLAS_CHECK(cublasCreate(&g_cublas_handles[id]));
        CUBLAS_CHECK(cublasSetMathMode(g_cublas_handles[id], CUBLAS_TF32_TENSOR_OP_MATH));
-
-         // configure memory pool
-         cudaError_t err = cudaDeviceGetMemPool(&g_cudaMemPools[id], id);
-         if (err == cudaSuccess) {
-             size_t treshold = UINT64_MAX;
-             CUDA_CHECK(cudaMemPoolSetAttribute(g_cudaMemPools[id], cudaMemPoolAttrReleaseThreshold, &treshold));
-         }
    }

    // configure logging to stdout
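
The removed block configured CUDA's stream-ordered memory pools (available since CUDA 11.2), which the deleted ggml_cuda_pool_*_async helpers relied on. A standalone sketch of that API, with hypothetical `bytes` and `stream` variables and error checks elided:

    cudaMemPool_t pool;
    cudaDeviceGetMemPool(&pool, /*device=*/0);

    // keep freed blocks cached in the pool rather than returning them to the driver
    uint64_t threshold = UINT64_MAX;
    cudaMemPoolSetAttribute(pool, cudaMemPoolAttrReleaseThreshold, &threshold);

    void * ptr = nullptr;
    cudaMallocFromPoolAsync(&ptr, bytes, pool, stream); // allocation is ordered on `stream`
    // ... launch work that uses `ptr` on `stream` ...
    cudaFreeAsync(ptr, stream);                         // the free is stream-ordered too
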
@@ -6469,7 +6438,7 @@ inline void ggml_cuda_op_mul_mat_cublas(
        const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src0->type);
        GGML_ASSERT(to_fp16_cuda != nullptr);
        size_t ne = row_diff*ne00;
-         src0_as_f16 = (half *) ggml_cuda_pool_malloc_async(ne * sizeof(half), &src0_as, id, stream);
+         src0_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src0_as);
        to_fp16_cuda(src0_dd_i, src0_as_f16, ne, stream);
    }
    const half * src0_ptr = src0->type == GGML_TYPE_F16 ? (const half *) src0_dd_i : src0_as_f16;
@@ -6480,12 +6449,13 @@ inline void ggml_cuda_op_mul_mat_cublas(
        const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
        GGML_ASSERT(to_fp16_cuda != nullptr);
        size_t ne = src1_ncols*ne10;
-         src1_as_f16 = (half *) ggml_cuda_pool_malloc_async(ne * sizeof(half), &src1_as, id, stream);
+         src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src1_as);
        to_fp16_cuda(src1_ddf_i, src1_as_f16, ne, stream);
    }
    const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddq_i : src1_as_f16;
-     size_t dst_f16_as = 0;
-     half * dst_f16 = (half *) ggml_cuda_pool_malloc_async(row_diff*src1_ncols * sizeof(half), &dst_f16_as, id, stream);
+
+     size_t dst_as = 0;
+     half * dst_f16 = (half *) ggml_cuda_pool_malloc(row_diff*src1_ncols * sizeof(half), &dst_as);

    const half alpha_f16 = 1.0f;
    const half beta_f16 = 0.0f;
@@ -6503,15 +6473,14 @@ inline void ggml_cuda_op_mul_mat_cublas(
        const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
        to_fp32_cuda(dst_f16, dst_dd_i, row_diff*src1_ncols, stream);

-         if (dst_f16_as != 0) {
-             ggml_cuda_pool_free_async(dst_f16, dst_f16_as, id, stream);
-         }
+         ggml_cuda_pool_free(dst_f16, dst_as);

        if (src0_as != 0) {
-             ggml_cuda_pool_free_async(src0_as_f16, src0_as, id, stream);
+             ggml_cuda_pool_free(src0_as_f16, src0_as);
        }
+
        if (src1_as != 0) {
-             ggml_cuda_pool_free_async(src1_as_f16, src1_as, id, stream);
+             ggml_cuda_pool_free(src1_as_f16, src1_as);
        }
    }
    else {
@@ -6521,7 +6490,7 @@ inline void ggml_cuda_op_mul_mat_cublas(
        if (src0->type != GGML_TYPE_F32) {
            const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);
            GGML_ASSERT(to_fp32_cuda != nullptr);
-             src0_ddq_as_f32 = (float *) ggml_cuda_pool_malloc_async(row_diff*ne00 * sizeof(float), &src0_as, id, stream); // NOLINT
+             src0_ddq_as_f32 = (float *) ggml_cuda_pool_malloc(row_diff*ne00 * sizeof(float), &src0_as); // NOLINT
            to_fp32_cuda(src0_dd_i, src0_ddq_as_f32, row_diff*ne00, stream);
        }
        const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32;
@@ -6538,7 +6507,7 @@ inline void ggml_cuda_op_mul_mat_cublas(
                &beta,  dst_dd_i, ldc));

        if (src0_as != 0) {
-             ggml_cuda_pool_free_async(src0_ddq_as_f32, src0_as, id, stream);
+             ggml_cuda_pool_free(src0_ddq_as_f32, src0_as);
        }
    }

@@ -6961,30 +6930,29 @@ static void ggml_cuda_op_mul_mat(
            src0_dd[id] = (char *) src0_extra->data_device[id];
        } else {
            const size_t size_src0_ddq = split ? (row_high[id]-row_low[id])*ne00 * src0_ts/src0_bs : ggml_nbytes(src0);
-             src0_dd[id] = (char *) ggml_cuda_pool_malloc_async(ggml_nbytes(src0), &src0_as[id], id, stream);
+             src0_dd[id] = (char *) ggml_cuda_pool_malloc(ggml_nbytes(src0), &src0_as[id]);
        }

        if (src1_on_device && src1_is_contiguous) {
            src1_ddf[id] = (float *) src1_extra->data_device[id];
        } else {
-             src1_ddf[id] = (float *) ggml_cuda_pool_malloc_async(ggml_nbytes(src1), &src1_asf[id], id, stream);
+             src1_ddf[id] = (float *) ggml_cuda_pool_malloc(ggml_nbytes(src1), &src1_asf[id]);
        }

        if (convert_src1_to_q8_1) {
-             const size_t size_dst_ddq = nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs;
-             src1_ddq[id] = (char *) ggml_cuda_pool_malloc_async(size_dst_ddq, &src1_asq[id], id, stream);
+             src1_ddq[id] = (char *) ggml_cuda_pool_malloc(nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs, &src1_asq[id]);

            if (src1_on_device && src1_is_contiguous) {
                quantize_row_q8_1_cuda(src1_ddf[id], src1_ddq[id], ne10, nrows1, src1_padded_col_size, stream);
-                 // CUDA_CHECK(cudaGetLastError());
+                 CUDA_CHECK(cudaGetLastError());
            }
        }

        if (dst_on_device) {
            dst_dd[id] = (float *) dst_extra->data_device[id];
        } else {
            const size_t size_dst_ddf = split ? (row_high[id]-row_low[id])*ne1*sizeof(float) : ggml_nbytes(dst);
-             dst_dd[id] = (float *) ggml_cuda_pool_malloc_async(size_dst_ddf, &dst_as[id], id, stream);
+             dst_dd[id] = (float *) ggml_cuda_pool_malloc(size_dst_ddf, &dst_as[id]);
        }
    }

@@ -7110,6 +7078,24 @@ static void ggml_cuda_op_mul_mat(
        }
    }

+     for (int64_t id = 0; id < g_device_count; ++id) {
+         CUDA_CHECK(ggml_cuda_set_device(id));
+
+         // free buffers again when done
+         if (src0_as[id] > 0) {
+             ggml_cuda_pool_free(src0_dd[id], src0_as[id]);
+         }
+         if (src1_asf[id] > 0) {
+             ggml_cuda_pool_free(src1_ddf[id], src1_asf[id]);
+         }
+         if (src1_asq[id] > 0) {
+             ggml_cuda_pool_free(src1_ddq[id], src1_asq[id]);
+         }
+         if (dst_as[id] > 0) {
+             ggml_cuda_pool_free(dst_dd[id], dst_as[id]);
+         }
+     }
+
    // main device waits for all other devices to be finished
    if (split && g_device_count > 1) {
        int64_t is_max = (ne11 + MUL_MAT_SRC1_COL_STRIDE - 1) / MUL_MAT_SRC1_COL_STRIDE;
@@ -7127,21 +7113,6 @@ static void ggml_cuda_op_mul_mat(
        CUDA_CHECK(ggml_cuda_set_device(g_main_device));
        CUDA_CHECK(cudaDeviceSynchronize());
    }
-
-     for (int64_t id = 0; id < g_device_count; ++id) {
-         if (src0_as[id] > 0) {
-             ggml_cuda_pool_free_async(src0_dd[id], src0_as[id], id, g_cudaStreams[id][0]);
-         }
-         if (src1_asf[id] > 0) {
-             ggml_cuda_pool_free_async(src1_ddf[id], src1_asf[id], id, g_cudaStreams[id][0]);
-         }
-         if (src1_asq[id] > 0) {
-             ggml_cuda_pool_free_async(src1_ddq[id], src1_asq[id], id, g_cudaStreams[id][0]);
-         }
-         if (dst_as[id] > 0) {
-             ggml_cuda_pool_free_async(dst_dd[id], dst_as[id], id, g_cudaStreams[id][0]);
-         }
-     }
}

static void ggml_cuda_repeat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
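
Throughout this commit the `+` lines fall back to the pre-existing synchronous ggml pool; its calling convention, as used in the hunks above, is roughly the following (the element count `n` is hypothetical):

    size_t actual_size = 0;
    half * tmp = (half *) ggml_cuda_pool_malloc(n * sizeof(half), &actual_size);

    // ... enqueue kernels that read/write `tmp` ...

    // returns the buffer to the pool for reuse; pass the size written by the malloc
    ggml_cuda_pool_free(tmp, actual_size);
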
@@ -7328,11 +7299,11 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
    GGML_ASSERT(to_fp16_cuda != nullptr);

    size_t src1_as = 0;
-     half * src1_as_f16 = (half *) ggml_cuda_pool_malloc_async(ne1 * sizeof(half), &src1_as, id, main_stream);
+     half * src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne1 * sizeof(half), &src1_as);
    to_fp16_cuda(src1_ddf, src1_as_f16, ne1, main_stream);

    size_t dst_as = 0;
-     half * dst_f16 = (half *) ggml_cuda_pool_malloc_async(ne * sizeof(half), &dst_as, id, main_stream);
+     half * dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as);

    GGML_ASSERT(ne12 % ne02 == 0);
    GGML_ASSERT(ne13 % ne03 == 0);
@@ -7386,8 +7357,8 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
        size_t ptrs_src_s = 0;
        size_t ptrs_dst_s = 0;

-         ptrs_src = (const void **) ggml_cuda_pool_malloc_async(2*ne23*sizeof(void *), &ptrs_src_s, id, main_stream);
-         ptrs_dst = (      void **) ggml_cuda_pool_malloc_async(1*ne23*sizeof(void *), &ptrs_dst_s, id, main_stream);
+         ptrs_src = (const void **) ggml_cuda_pool_malloc(2*ne23*sizeof(void *), &ptrs_src_s);
+         ptrs_dst = (      void **) ggml_cuda_pool_malloc(1*ne23*sizeof(void *), &ptrs_dst_s);

        dim3 block_dims(ne13, ne12);
        k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
@@ -7400,6 +7371,7 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
                dst->nb[2], dst->nb[3],
                r2, r3);
        CUDA_CHECK(cudaGetLastError());
+
        CUBLAS_CHECK(
        cublasGemmBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
                ne01, ne11, ne10,
@@ -7411,22 +7383,19 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
                CUBLAS_GEMM_DEFAULT_TENSOR_OP));

        if (ptrs_src_s != 0) {
-             ggml_cuda_pool_free_async(ptrs_src, ptrs_src_s, id, main_stream);
+             ggml_cuda_pool_free(ptrs_src, ptrs_src_s);
        }
        if (ptrs_dst_s != 0) {
-             ggml_cuda_pool_free_async(ptrs_dst, ptrs_dst_s, id, main_stream);
+             ggml_cuda_pool_free(ptrs_dst, ptrs_dst_s);
        }
    }
#endif

    const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
    to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream);
-     if (src1_as != 0) {
-         ggml_cuda_pool_free_async(src1_as_f16, src1_as, id, main_stream);
-     }
-     if (dst_as != 0) {
-         ggml_cuda_pool_free_async(dst_f16, dst_as, id, main_stream);
-     }
+
+     ggml_cuda_pool_free(src1_as_f16, src1_as);
+     ggml_cuda_pool_free(dst_f16, dst_as);
}

static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {